Commit 536b16d6 by Kálmán Viktor

Merge remote-tracking branch 'origin/master' into feature-template-wizard

Conflicts:
	circle/dashboard/static/dashboard/dashboard.css
parents 7af77c17 be367e47
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
from collections import deque from collections import deque
from contextlib import contextmanager from contextlib import contextmanager
from hashlib import sha224 from hashlib import sha224
from itertools import chain, imap
from logging import getLogger from logging import getLogger
from time import time from time import time
...@@ -56,12 +57,52 @@ activity_context = contextmanager(activitycontextimpl) ...@@ -56,12 +57,52 @@ activity_context = contextmanager(activitycontextimpl)
activity_code_separator = '.' activity_code_separator = '.'
def has_prefix(activity_code, *prefixes):
"""Determine whether the activity code has the specified prefix.
E.g.: has_prefix('foo.bar.buz', 'foo.bar') == True
has_prefix('foo.bar.buz', 'foo', 'bar') == True
has_prefix('foo.bar.buz', 'foo.bar', 'buz') == True
has_prefix('foo.bar.buz', 'foo', 'bar', 'buz') == True
has_prefix('foo.bar.buz', 'foo', 'buz') == False
"""
equal = lambda a, b: a == b
act_code_parts = split_activity_code(activity_code)
prefixes = chain(*imap(split_activity_code, prefixes))
return all(imap(equal, act_code_parts, prefixes))
def has_suffix(activity_code, *suffixes):
"""Determine whether the activity code has the specified suffix.
E.g.: has_suffix('foo.bar.buz', 'bar.buz') == True
has_suffix('foo.bar.buz', 'bar', 'buz') == True
has_suffix('foo.bar.buz', 'foo.bar', 'buz') == True
has_suffix('foo.bar.buz', 'foo', 'bar', 'buz') == True
has_suffix('foo.bar.buz', 'foo', 'buz') == False
"""
equal = lambda a, b: a == b
act_code_parts = split_activity_code(activity_code)
suffixes = list(chain(*imap(split_activity_code, suffixes)))
return all(imap(equal, reversed(act_code_parts), reversed(suffixes)))
def join_activity_code(*args): def join_activity_code(*args):
"""Join the specified parts into an activity code. """Join the specified parts into an activity code.
:returns: Activity code string.
""" """
return activity_code_separator.join(args) return activity_code_separator.join(args)
def split_activity_code(activity_code):
"""Split the specified activity code into its parts.
:returns: A list of activity code parts.
"""
return activity_code.split(activity_code_separator)
class ActivityModel(TimeStampedModel): class ActivityModel(TimeStampedModel):
activity_code = CharField(max_length=100, verbose_name=_('activity code')) activity_code = CharField(max_length=100, verbose_name=_('activity code'))
parent = ForeignKey('self', blank=True, null=True, related_name='children') parent = ForeignKey('self', blank=True, null=True, related_name='children')
......
...@@ -15,9 +15,10 @@ ...@@ -15,9 +15,10 @@
# You should have received a copy of the GNU General Public License along # You should have received a copy of the GNU General Public License along
# with CIRCLE. If not, see <http://www.gnu.org/licenses/>. # with CIRCLE. If not, see <http://www.gnu.org/licenses/>.
from inspect import getargspec
from logging import getLogger from logging import getLogger
from .models import activity_context from .models import activity_context, has_suffix
from django.core.exceptions import PermissionDenied from django.core.exceptions import PermissionDenied
...@@ -31,6 +32,7 @@ class Operation(object): ...@@ -31,6 +32,7 @@ class Operation(object):
async_queue = 'localhost.man' async_queue = 'localhost.man'
required_perms = () required_perms = ()
do_not_call_in_templates = True do_not_call_in_templates = True
abortable = False
def __call__(self, **kwargs): def __call__(self, **kwargs):
return self.call(**kwargs) return self.call(**kwargs)
...@@ -46,23 +48,50 @@ class Operation(object): ...@@ -46,23 +48,50 @@ class Operation(object):
def __prelude(self, kwargs): def __prelude(self, kwargs):
"""This method contains the shared prelude of call and async. """This method contains the shared prelude of call and async.
""" """
skip_auth_check = kwargs.setdefault('system', False) defaults = {'parent_activity': None, 'system': False, 'user': None}
user = kwargs.setdefault('user', None)
parent_activity = kwargs.pop('parent_activity', None) allargs = dict(defaults, **kwargs) # all arguments
auxargs = allargs.copy() # auxiliary (i.e. only for _operation) args
# NOTE: consumed items should be removed from auxargs, and no new items
# should be added to it
skip_auth_check = auxargs.pop('system')
user = auxargs.pop('user')
parent_activity = auxargs.pop('parent_activity')
# check for unexpected keyword arguments
argspec = getargspec(self._operation)
if argspec.keywords is None: # _operation doesn't take ** args
unexpected_kwargs = set(auxargs) - set(argspec.args)
if unexpected_kwargs:
raise TypeError("Operation got unexpected keyword arguments: "
"%s" % ", ".join(unexpected_kwargs))
if not skip_auth_check: if not skip_auth_check:
self.check_auth(user) self.check_auth(user)
self.check_precond() self.check_precond()
return self.create_activity(parent=parent_activity, user=user)
def _exec_op(self, activity, user, **kwargs): activity = self.create_activity(parent=parent_activity, user=user)
return activity, allargs, auxargs
def _exec_op(self, allargs, auxargs):
"""Execute the operation inside the specified activity's context. """Execute the operation inside the specified activity's context.
""" """
with activity_context(activity, on_abort=self.on_abort, # compile arguments for _operation
argspec = getargspec(self._operation)
if argspec.keywords is not None: # _operation takes ** args
arguments = allargs.copy()
else: # _operation doesn't take ** args
arguments = {k: v for (k, v) in allargs.iteritems()
if k in argspec.args}
arguments.update(auxargs)
with activity_context(allargs['activity'], on_abort=self.on_abort,
on_commit=self.on_commit): on_commit=self.on_commit):
return self._operation(activity=activity, user=user, **kwargs) return self._operation(**arguments)
def _operation(self, activity, user, system, **kwargs): def _operation(self, **kwargs):
"""This method is the operation's particular implementation. """This method is the operation's particular implementation.
Deriving classes should implement this method. Deriving classes should implement this method.
...@@ -82,12 +111,10 @@ class Operation(object): ...@@ -82,12 +111,10 @@ class Operation(object):
logger.info("%s called asynchronously on %s with the following " logger.info("%s called asynchronously on %s with the following "
"parameters: %r", self.__class__.__name__, self.subject, "parameters: %r", self.__class__.__name__, self.subject,
kwargs) kwargs)
activity = self.__prelude(kwargs) activity, allargs, auxargs = self.__prelude(kwargs)
return self.async_operation.apply_async(args=(self.id, return self.async_operation.apply_async(
self.subject.pk, args=(self.id, self.subject.pk, activity.pk, allargs, auxargs, ),
activity.pk), queue=self.async_queue)
kwargs=kwargs,
queue=self.async_queue)
def call(self, **kwargs): def call(self, **kwargs):
"""Execute the operation (synchronously). """Execute the operation (synchronously).
...@@ -105,8 +132,9 @@ class Operation(object): ...@@ -105,8 +132,9 @@ class Operation(object):
logger.info("%s called (synchronously) on %s with the following " logger.info("%s called (synchronously) on %s with the following "
"parameters: %r", self.__class__.__name__, self.subject, "parameters: %r", self.__class__.__name__, self.subject,
kwargs) kwargs)
activity = self.__prelude(kwargs) activity, allargs, auxargs = self.__prelude(kwargs)
return self._exec_op(activity=activity, **kwargs) allargs['activity'] = activity
return self._exec_op(allargs, auxargs)
def check_precond(self): def check_precond(self):
pass pass
...@@ -160,6 +188,19 @@ class OperatedMixin(object): ...@@ -160,6 +188,19 @@ class OperatedMixin(object):
else: else:
yield op yield op
def get_operation_from_activity_code(self, activity_code):
"""Get an instance of the Operation corresponding to the specified
activity code.
:returns: A bound instance of an operation, or None if no matching
operation could be found.
"""
for op in getattr(self, operation_registry_name, {}).itervalues():
if has_suffix(activity_code, op.activity_code_suffix):
return op(self)
else:
return None
def register_operation(op_cls, op_id=None, target_cls=None): def register_operation(op_cls, op_id=None, target_cls=None):
"""Register the specified operation with the target class. """Register the specified operation with the target class.
......
...@@ -75,3 +75,41 @@ class OperationTestCase(TestCase): ...@@ -75,3 +75,41 @@ class OperationTestCase(TestCase):
patch.object(Operation, 'create_activity'), \ patch.object(Operation, 'create_activity'), \
patch.object(Operation, '_exec_op'): patch.object(Operation, '_exec_op'):
op.call(system=True) op.call(system=True)
def test_no_exception_for_more_arguments_when_operation_takes_kwargs(self):
class KwargOp(Operation):
activity_code_suffix = 'test'
id = 'test'
def _operation(self, **kwargs):
pass
op = KwargOp(MagicMock())
with patch.object(KwargOp, 'create_activity'), \
patch.object(KwargOp, '_exec_op'):
op.call(system=True, foo=42)
def test_exception_for_unexpected_arguments(self):
class TestOp(Operation):
activity_code_suffix = 'test'
id = 'test'
def _operation(self):
pass
op = TestOp(MagicMock())
with patch.object(TestOp, 'create_activity'), \
patch.object(TestOp, '_exec_op'):
self.assertRaises(TypeError, op.call, system=True, foo=42)
def test_exception_for_missing_arguments(self):
class TestOp(Operation):
activity_code_suffix = 'test'
id = 'test'
def _operation(self, foo):
pass
op = TestOp(MagicMock())
with patch.object(TestOp, 'create_activity'):
self.assertRaises(TypeError, op.call, system=True)
...@@ -23,6 +23,46 @@ html { ...@@ -23,6 +23,46 @@ html {
padding-right: 15px; padding-right: 15px;
} }
/* values for 45px tall navbar */
.navbar {
min-height: 45px;
}
.navbar-brand {
height: 45px;
padding: 12.5px 12.5px;
}
.navbar-toggle {
margin-top: 5.5px;
margin-bottom: 5.5px;
}
.navbar-form {
margin-top: 5.5px;
margin-bottom: 5.5px;
}
.navbar-btn {
margin-top: 5.5px;
margin-bottom: 5.5px;
}
.navbar-btn.btn-sm {
margin-top: 7.5px;
margin-bottom: 7.5px;
}
.navbar-btn.btn-xs {
margin-top: 11.5px;
margin-bottom: 11.5px;
}
.navbar-text {
margin-top: 12.5px;
margin-bottom: 12.5px;
}
/* --- */
/* Responsive: Portrait tablets and up */ /* Responsive: Portrait tablets and up */
@media screen and (min-width: 768px) { @media screen and (min-width: 768px) {
/* Let the jumbotron breathe */ /* Let the jumbotron breathe */
...@@ -33,6 +73,12 @@ html { ...@@ -33,6 +73,12 @@ html {
.body-content { .body-content {
padding: 0; padding: 0;
} }
.navbar-nav > li > a {
padding-top: 12.5px;
padding-bottom: 12.5px;
}
} }
.no-margin { .no-margin {
margin: 0!important; margin: 0!important;
...@@ -552,3 +598,12 @@ footer a, footer a:hover, footer a:visited { ...@@ -552,3 +598,12 @@ footer a, footer a:hover, footer a:visited {
#ops { #ops {
padding: 15px 0 15px 15px; padding: 15px 0 15px 15px;
} }
#vm-access-table th:last-child, #vm-access-table td:last-child,
#template-access-table th:last-child, #template-access-table td:last-child {
text-align: center;
}
#notifications-button {
margin: 0;
}
...@@ -183,6 +183,7 @@ $(function() { ...@@ -183,6 +183,7 @@ $(function() {
$("#vm-details-h1-name").hide(); $("#vm-details-h1-name").hide();
$("#vm-details-rename").css('display', 'inline'); $("#vm-details-rename").css('display', 'inline');
$("#vm-details-rename-name").focus(); $("#vm-details-rename-name").focus();
return false;
}); });
/* rename in home tab */ /* rename in home tab */
...@@ -190,6 +191,7 @@ $(function() { ...@@ -190,6 +191,7 @@ $(function() {
$(".vm-details-home-edit-name-click").hide(); $(".vm-details-home-edit-name-click").hide();
$("#vm-details-home-rename").show(); $("#vm-details-home-rename").show();
$("input", $("#vm-details-home-rename")).focus(); $("input", $("#vm-details-home-rename")).focus();
return false;
}); });
/* rename ajax */ /* rename ajax */
...@@ -219,6 +221,11 @@ $(function() { ...@@ -219,6 +221,11 @@ $(function() {
$(".vm-details-home-edit-description-click").click(function() { $(".vm-details-home-edit-description-click").click(function() {
$(".vm-details-home-edit-description-click").hide(); $(".vm-details-home-edit-description-click").hide();
$("#vm-details-home-description").show(); $("#vm-details-home-description").show();
var ta = $("#vm-details-home-description textarea");
var tmp = ta.val();
ta.val("");
ta.focus();
ta.val(tmp)
return false; return false;
}); });
......
...@@ -26,7 +26,9 @@ ...@@ -26,7 +26,9 @@
<body> <body>
<div class="navbar navbar-inverse navbar-fixed-top"> <div class="navbar navbar-inverse navbar-fixed-top">
<div class="navbar-header"> <div class="navbar-header">
<a class="navbar-brand" href="{% url "dashboard.index" %}">CIRCLE</a> <a class="navbar-brand" href="{% url "dashboard.index" %}" style="padding: 10px 15px;">
<img src="{{ STATIC_URL}}dashboard/img/logo.png" style="height: 25px;"/>
</a>
<button type="button" class="navbar-toggle" data-toggle="collapse" data-target=".navbar-collapse"> <button type="button" class="navbar-toggle" data-toggle="collapse" data-target=".navbar-collapse">
<span class="icon-bar"></span> <span class="icon-bar"></span>
<span class="icon-bar"></span> <span class="icon-bar"></span>
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
{% block content %} {% block content %}
<div class="row"> <div class="row">
<div class="col-md-8"> <div class="col-md-7">
<div class="panel panel-default"> <div class="panel panel-default">
<div class="panel-heading"> <div class="panel-heading">
<a class="pull-right btn btn-default btn-xs" href="{% url "dashboard.views.template-list" %}">{% trans "Back" %}</a> <a class="pull-right btn btn-default btn-xs" href="{% url "dashboard.views.template-list" %}">{% trans "Back" %}</a>
...@@ -23,33 +23,51 @@ ...@@ -23,33 +23,51 @@
</div> </div>
</div> </div>
<div class="col-md-4"> <div class="col-md-5">
<div class="panel panel-default"> <div class="panel panel-default">
<div class="panel-heading"> <div class="panel-heading">
<h4 class="no-margin"><i class="icon-group"></i> {% trans "Manage access" %}</h4> <h4 class="no-margin"><i class="icon-group"></i> {% trans "Manage access" %}</h4>
</div> </div>
<div class="panel-body"> <div class="panel-body">
<form action="{% url "dashboard.views.template-acl" pk=object.pk %}" method="post">{% csrf_token %} <form action="{% url "dashboard.views.template-acl" pk=object.pk %}" method="post">{% csrf_token %}
<table class="table table-striped table-with-form-fields"> <table class="table table-striped table-with-form-fields" id="template-access-table">
<thead><tr><th></th><th>{% trans "Who" %}</th><th>{% trans "What" %}</th><th></th></tr></thead> <thead>
<tr>
<th></th>
<th>{% trans "Who" %}</th>
<th>{% trans "What" %}</th>
<th><i class="icon-remove"></i></th>
</tr></thead>
<tbody> <tbody>
{% for i in acl.users %} {% for i in acl.users %}
<tr><td><i class="icon-user"></i></td><td>{{i.user}}</td> <tr>
<td><select class="form-control" name="perm-u-{{i.user.id}}"> <td><i class="icon-user"></i></td><td>{{i.user}}</td>
{% for id, name in acl.levels %} <td>
<option{%if id = i.level%} selected="selected"{%endif%} value="{{id}}">{{name}}</option> <select class="form-control" name="perm-u-{{i.user.id}}">
{% endfor %} {% for id, name in acl.levels %}
</select></td> <option{%if id = i.level%} selected="selected"{%endif%} value="{{id}}">{{name}}</option>
<td><a href="#" class="btn btn-link btn-xs"><i class="icon-remove"><span class="sr-only">{% trans "remove" %}</span></i></a></td></tr> {% endfor %}
</select>
</td>
<td>
<input type="checkbox" name="remove-u-{{i.user.id}}" title="{% trans "Remove" %}"/>
</td>
</tr>
{% endfor %} {% endfor %}
{% for i in acl.groups %} {% for i in acl.groups %}
<tr><td><i class="icon-group"></i></td><td>{{i.group}}</td> <tr>
<td><select class="form-control" name="perm-g-{{i.group.id}}"> <td><i class="icon-group"></i></td><td>{{i.group}}</td>
{% for id, name in acl.levels %} <td>
<option{%if id = i.level%} selected="selected"{%endif%} value="{{id}}">{{name}}</option> <select class="form-control" name="perm-g-{{i.group.id}}">
{% endfor %} {% for id, name in acl.levels %}
</select></td> <option{%if id = i.level%} selected="selected"{%endif%} value="{{id}}">{{name}}</option>
<td><a href="#" class="btn btn-link btn-xs"><i class="icon-remove"><span class="sr-only">{% trans "remove" %}</span></i></a></td></tr> {% endfor %}
</select>
</td>
<td>
<input type="checkbox" name="remove-g-{{i.group.id}}" title="{% trans "Remove" %}"/>
</td>
</tr>
{% endfor %} {% endfor %}
<tr><td><i class="icon-plus"></i></td> <tr><td><i class="icon-plus"></i></td>
<td><input type="text" class="form-control" name="perm-new-name" <td><input type="text" class="form-control" name="perm-new-name"
......
...@@ -9,6 +9,14 @@ ...@@ -9,6 +9,14 @@
{{ a.get_readable_name }}{% if user.is_superuser %}</a>{% endif %} {{ a.get_readable_name }}{% if user.is_superuser %}</a>{% endif %}
</strong> </strong>
{{ a.started|date:"Y-m-d H:i" }}{% if a.user %}, {{ a.user }}{% endif %} {{ a.started|date:"Y-m-d H:i" }}{% if a.user %}, {{ a.user }}{% endif %}
{% if a.is_abortable_for_user %}
<form action="{{ a.instance.get_absolute_url }}" method="POST" class="pull-right">
{% csrf_token %}
<input type="hidden" name="abort_operation"/>
<input type="hidden" name="activity" value="{{ a.pk }}"/>
<button class="btn btn-danger btn-xs"><i class="icon-bolt"></i> {% trans "Abort" %}</button>
</form>
{% endif %}
{% if a.children.count > 0 %} {% if a.children.count > 0 %}
<div class="sub-timeline"> <div class="sub-timeline">
{% for s in a.children.all %} {% for s in a.children.all %}
......
...@@ -15,26 +15,43 @@ ...@@ -15,26 +15,43 @@
</p> </p>
<h3>{% trans "Permissions"|capfirst %}</h3> <h3>{% trans "Permissions"|capfirst %}</h3>
<form action="{{acl.url}}" method="post">{% csrf_token %} <form action="{{acl.url}}" method="post">{% csrf_token %}
<table class="table table-striped table-with-form-fields"> <table class="table table-striped table-with-form-fields" id="vm-access-table">
<thead><tr><th></th><th>{% trans "Who" %}</th><th>{% trans "What" %}</th><th></th></tr></thead> <thead><tr>
<th></th>
<th>{% trans "Who" %}</th>
<th>{% trans "What" %}</th>
<th>{% trans "Remove" %}</th>
</tr></thead>
<tbody> <tbody>
{% for i in acl.users %} {% for i in acl.users %}
<tr><td><i class="icon-user"></i></td><td>{{i.user}}</td> <tr>
<td><select class="form-control" name="perm-u-{{i.user.id}}"> <td><i class="icon-user"></i></td>
{% for id, name in acl.levels %} <td>{{i.user}}</td>
<option{%if id = i.level%} selected="selected"{%endif%} value="{{id}}">{{name}}</option> <td>
{% endfor %} <select class="form-control" name="perm-u-{{i.user.id}}">
</select></td> {% for id, name in acl.levels %}
<td><a href="#" class="btn btn-link btn-xs"><i class="icon-remove"><span class="sr-only">{% trans "remove" %}</span></i></a></td></tr> <option{%if id = i.level%} selected="selected"{%endif%} value="{{id}}">{{name}}</option>
{% endfor %}
</select>
</td>
<td>
<input type="checkbox" name="remove-u-{{i.user.id}}"/>
</td>
</tr>
{% endfor %} {% endfor %}
{% for i in acl.groups %} {% for i in acl.groups %}
<tr><td><i class="icon-group"></i></td><td>{{i.group}}</td> <tr>
<td><select class="form-control" name="perm-g-{{i.group.id}}"> <td><i class="icon-group"></i></td><td>{{i.group}}</td>
{% for id, name in acl.levels %} <td>
<option{%if id = i.level%} selected="selected"{%endif%} value="{{id}}">{{name}}</option> <select class="form-control" name="perm-g-{{i.group.id}}">
{% endfor %} {% for id, name in acl.levels %}
<option{%if id = i.level%} selected="selected"{%endif%} value="{{id}}">{{name}}</option>
{% endfor %}
</select></td> </select></td>
<td><a href="#" class="btn btn-link btn-xs"><i class="icon-remove"><span class="sr-only">{% trans "remove" %}</span></i></a></td></tr> <td>
<input type="checkbox" name="remove-g-{{i.group.id}}"/>
</td>
</tr>
{% endfor %} {% endfor %}
<tr><td><i class="icon-plus"></i></td> <tr><td><i class="icon-plus"></i></td>
<td><input type="text" class="form-control" name="perm-new-name" <td><input type="text" class="form-control" name="perm-new-name"
......
...@@ -1161,3 +1161,121 @@ class ProfileViewTest(LoginMixin, TestCase): ...@@ -1161,3 +1161,121 @@ class ProfileViewTest(LoginMixin, TestCase):
self.assertIsNotNone(authenticate(username="user1", self.assertIsNotNone(authenticate(username="user1",
password="password")) password="password"))
self.assertIsNone(authenticate(username="user1", password="asd")) self.assertIsNone(authenticate(username="user1", password="asd"))
class AclViewTest(LoginMixin, TestCase):
fixtures = ['test-vm-fixture.json', 'node.json']
def setUp(self):
Instance.get_remote_queue_name = Mock(return_value='test')
self.u1 = User.objects.create(username='user1')
self.u1.set_password('password')
self.u1.save()
self.u2 = User.objects.create(username='user2', is_staff=True)
self.u2.set_password('password')
self.u2.save()
self.us = User.objects.create(username='superuser', is_superuser=True)
self.us.set_password('password')
self.us.save()
self.ut = User.objects.get(username="test")
self.g1 = Group.objects.create(name='group1')
self.g1.user_set.add(self.u1)
self.g1.user_set.add(self.u2)
self.g1.save()
settings["default_vlangroup"] = 'public'
VlanGroup.objects.create(name='public')
def tearDown(self):
super(AclViewTest, self).tearDown()
self.u1.delete()
self.u2.delete()
self.us.delete()
self.g1.delete()
def test_permitted_instance_access_revoke(self):
c = Client()
# this is from the fixtures
self.login(c, "test", "test")
inst = Instance.objects.get(id=1)
inst.set_level(self.u1, "user")
resp = c.post("/dashboard/vm/1/acl/", {
'remove-u-%d' % self.u1.pk: "",
'perm-new-name': "",
'perm-new': "",
})
self.assertFalse((self.u1, "user") in inst.get_users_with_level())
self.assertEqual(resp.status_code, 302)
def test_unpermitted_instance_access_revoke(self):
c = Client()
self.login(c, self.u2)
inst = Instance.objects.get(id=1)
inst.set_level(self.u1, "user")
resp = c.post("/dashboard/vm/1/acl/", {
'remove-u-%d' % self.u1.pk: "",
'perm-new-name': "",
'perm-new': "",
})
self.assertTrue((self.u1, "user") in inst.get_users_with_level())
self.assertEqual(resp.status_code, 403)
def test_instance_original_owner_access_revoke(self):
c = Client()
self.login(c, self.u1)
inst = Instance.objects.get(id=1)
inst.set_level(self.u1, "owner")
inst.set_level(self.ut, "owner")
resp = c.post("/dashboard/vm/1/acl/", {
'remove-u-%d' % self.ut.pk: "",
'perm-new-name': "",
'perm-new': "",
})
self.assertEqual(self.ut, Instance.objects.get(id=1).owner)
self.assertTrue((self.ut, "owner") in inst.get_users_with_level())
self.assertEqual(resp.status_code, 302)
def test_permitted_template_access_revoke(self):
c = Client()
# this is from the fixtures
self.login(c, "test", "test")
tmpl = InstanceTemplate.objects.get(id=1)
tmpl.set_level(self.u1, "user")
resp = c.post("/dashboard/template/1/acl/", {
'remove-u-%d' % self.u1.pk: "",
'perm-new-name': "",
'perm-new': "",
})
self.assertFalse((self.u1, "user") in tmpl.get_users_with_level())
self.assertEqual(resp.status_code, 302)
def test_unpermitted_template_access_revoke(self):
c = Client()
self.login(c, self.u2)
tmpl = InstanceTemplate.objects.get(id=1)
tmpl.set_level(self.u1, "user")
resp = c.post("/dashboard/template/1/acl/", {
'remove-u-%d' % self.u1.pk: "",
'perm-new-name': "",
'perm-new': "",
})
self.assertTrue((self.u1, "user") in tmpl.get_users_with_level())
self.assertEqual(resp.status_code, 403)
def test_template_original_owner_access_revoke(self):
c = Client()
self.login(c, self.u1)
tmpl = InstanceTemplate.objects.get(id=1)
tmpl.set_level(self.u1, "owner")
tmpl.set_level(self.ut, "owner")
resp = c.post("/dashboard/template/1/acl/", {
'remove-u-%d' % self.ut.pk: "",
'perm-new-name': "",
'perm-new': "",
})
self.assertEqual(self.ut, InstanceTemplate.objects.get(id=1).owner)
self.assertTrue((self.ut, "owner") in tmpl.get_users_with_level())
self.assertEqual(resp.status_code, 302)
...@@ -224,11 +224,7 @@ class VmDetailView(CheckedDetailView): ...@@ -224,11 +224,7 @@ class VmDetailView(CheckedDetailView):
}) })
# activity data # activity data
context['activities'] = ( context['activities'] = self.object.get_activities(self.request.user)
InstanceActivity.objects.filter(
instance=self.object, parent=None).
order_by('-started').
select_related('user').prefetch_related('children'))
context['vlans'] = Vlan.get_objects_with_level( context['vlans'] = Vlan.get_objects_with_level(
'user', self.request.user 'user', self.request.user
...@@ -260,6 +256,7 @@ class VmDetailView(CheckedDetailView): ...@@ -260,6 +256,7 @@ class VmDetailView(CheckedDetailView):
'to_remove': self.__remove_tag, 'to_remove': self.__remove_tag,
'port': self.__add_port, 'port': self.__add_port,
'new_network_vlan': self.__new_network, 'new_network_vlan': self.__new_network,
'abort_operation': self.__abort_operation,
} }
for k, v in options.iteritems(): for k, v in options.iteritems():
if request.POST.get(k) is not None: if request.POST.get(k) is not None:
...@@ -445,6 +442,16 @@ class VmDetailView(CheckedDetailView): ...@@ -445,6 +442,16 @@ class VmDetailView(CheckedDetailView):
return redirect("%s#network" % reverse_lazy( return redirect("%s#network" % reverse_lazy(
"dashboard.views.detail", kwargs={'pk': self.object.pk})) "dashboard.views.detail", kwargs={'pk': self.object.pk}))
def __abort_operation(self, request):
self.object = self.get_object()
activity = get_object_or_404(InstanceActivity,
pk=request.POST.get("activity"))
if not activity.is_abortable_for(request.user):
raise PermissionDenied()
activity.abort()
return redirect("%s#activity" % self.object.get_absolute_url())
class OperationView(DetailView): class OperationView(DetailView):
...@@ -714,8 +721,9 @@ class AclUpdateView(LoginRequiredMixin, View, SingleObjectMixin): ...@@ -714,8 +721,9 @@ class AclUpdateView(LoginRequiredMixin, View, SingleObjectMixin):
unicode(instance), unicode(request.user)) unicode(instance), unicode(request.user))
raise PermissionDenied() raise PermissionDenied()
self.set_levels(request, instance) self.set_levels(request, instance)
self.remove_levels(request, instance)
self.add_levels(request, instance) self.add_levels(request, instance)
return redirect(instance) return redirect("%s#access" % instance.get_absolute_url())
def set_levels(self, request, instance): def set_levels(self, request, instance):
for key, value in request.POST.items(): for key, value in request.POST.items():
...@@ -732,6 +740,24 @@ class AclUpdateView(LoginRequiredMixin, View, SingleObjectMixin): ...@@ -732,6 +740,24 @@ class AclUpdateView(LoginRequiredMixin, View, SingleObjectMixin):
unicode(entity), unicode(instance), unicode(entity), unicode(instance),
value, unicode(request.user)) value, unicode(request.user))
def remove_levels(self, request, instance):
for key, value in request.POST.items():
if key.startswith("remove"):
typ = key[7:8] # len("remove-")
id = key[9:] # len("remove-x-")
entity = {'u': User, 'g': Group}[typ].objects.get(id=id)
if getattr(instance, "owner", None) == entity:
logger.info("Tried to remove owner from %s by %s.",
unicode(instance), unicode(request.user))
msg = _("The original owner cannot be removed, however "
"you can transfer ownership!")
messages.warning(request, msg)
continue
instance.set_level(entity, None)
logger.info("Revoked %s's access to %s by %s.",
unicode(entity), unicode(instance),
unicode(request.user))
def add_levels(self, request, instance): def add_levels(self, request, instance):
name = request.POST['perm-new-name'] name = request.POST['perm-new-name']
value = request.POST['perm-new'] value = request.POST['perm-new']
...@@ -772,6 +798,7 @@ class TemplateAclUpdateView(AclUpdateView): ...@@ -772,6 +798,7 @@ class TemplateAclUpdateView(AclUpdateView):
else: else:
self.set_levels(request, template) self.set_levels(request, template)
self.add_levels(request, template) self.add_levels(request, template)
self.remove_levels(request, template)
post_for_disk = request.POST.copy() post_for_disk = request.POST.copy()
post_for_disk['perm-new'] = 'user' post_for_disk['perm-new'] = 'user'
...@@ -779,8 +806,7 @@ class TemplateAclUpdateView(AclUpdateView): ...@@ -779,8 +806,7 @@ class TemplateAclUpdateView(AclUpdateView):
for d in template.disks.all(): for d in template.disks.all():
self.add_levels(request, d) self.add_levels(request, d)
return redirect(reverse("dashboard.views.template-detail", return redirect(template)
kwargs=self.kwargs))
class GroupAclUpdateView(AclUpdateView): class GroupAclUpdateView(AclUpdateView):
...@@ -1791,9 +1817,7 @@ def vm_activity(request, pk): ...@@ -1791,9 +1817,7 @@ def vm_activity(request, pk):
if only_status == "false": # instance activity if only_status == "false": # instance activity
context = { context = {
'instance': instance, 'instance': instance,
'activities': InstanceActivity.objects.filter( 'activities': instance.get_activities(request.user),
instance=instance, parent=None
).order_by('-started').select_related(),
'ops': get_operations(instance, request.user), 'ops': get_operations(instance, request.user),
} }
...@@ -2398,10 +2422,8 @@ class InstanceActivityDetail(SuperuserRequiredMixin, DetailView): ...@@ -2398,10 +2422,8 @@ class InstanceActivityDetail(SuperuserRequiredMixin, DetailView):
def get_context_data(self, **kwargs): def get_context_data(self, **kwargs):
ctx = super(InstanceActivityDetail, self).get_context_data(**kwargs) ctx = super(InstanceActivityDetail, self).get_context_data(**kwargs)
ctx['activities'] = ( ctx['activities'] = self.object.instance.get_activities(
self.object.instance.activity_log.filter(parent=None). self.request.user)
order_by('-started').select_related('user').
prefetch_related('children'))
return ctx return ctx
......
...@@ -24,6 +24,7 @@ import logging ...@@ -24,6 +24,7 @@ import logging
from os.path import join from os.path import join
import uuid import uuid
from celery.signals import worker_ready
from django.db.models import (Model, CharField, DateTimeField, from django.db.models import (Model, CharField, DateTimeField,
ForeignKey) ForeignKey)
from django.utils import timezone from django.utils import timezone
...@@ -631,3 +632,12 @@ def disk_activity(code_suffix, disk, task_uuid=None, user=None, ...@@ -631,3 +632,12 @@ def disk_activity(code_suffix, disk, task_uuid=None, user=None,
on_abort=None, on_commit=None): on_abort=None, on_commit=None):
act = DiskActivity.create(code_suffix, disk, task_uuid, user) act = DiskActivity.create(code_suffix, disk, task_uuid, user)
return activitycontextimpl(act, on_abort=on_abort, on_commit=on_commit) return activitycontextimpl(act, on_abort=on_abort, on_commit=on_commit)
@worker_ready.connect()
def cleanup(conf=None, **kwargs):
# TODO check if other manager workers are running
for i in DiskActivity.objects.filter(finished__isnull=True):
i.finish(False, "Manager is restarted, activity is cleaned up. "
"You can try again now.")
logger.error('Forced finishing stale activity %s', i)
...@@ -20,15 +20,19 @@ from contextlib import contextmanager ...@@ -20,15 +20,19 @@ from contextlib import contextmanager
from logging import getLogger from logging import getLogger
from celery.signals import worker_ready from celery.signals import worker_ready
from celery.contrib.abortable import AbortableAsyncResult
from django.core.urlresolvers import reverse from django.core.urlresolvers import reverse
from django.db.models import CharField, ForeignKey from django.db.models import CharField, ForeignKey
from django.utils import timezone from django.utils import timezone
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from common.models import ( from common.models import (
ActivityModel, activitycontextimpl, join_activity_code, ActivityModel, activitycontextimpl, join_activity_code, split_activity_code
) )
from manager.mancelery import celery
logger = getLogger(__name__) logger = getLogger(__name__)
...@@ -66,19 +70,8 @@ class InstanceActivity(ActivityModel): ...@@ -66,19 +70,8 @@ class InstanceActivity(ActivityModel):
return '{}({})'.format(self.activity_code, return '{}({})'.format(self.activity_code,
self.instance) self.instance)
def get_absolute_url(self): def abort(self):
return reverse('dashboard.views.vm-activity', args=[self.pk]) AbortableAsyncResult(self.task_uuid, backend=celery.backend).abort()
def get_readable_name(self):
return self.activity_code.split('.')[-1].replace('_', ' ').capitalize()
def get_status_id(self):
if self.succeeded is None:
return 'wait'
elif self.succeeded:
return 'success'
else:
return 'failed'
@classmethod @classmethod
def create(cls, code_suffix, instance, task_uuid=None, user=None, def create(cls, code_suffix, instance, task_uuid=None, user=None,
...@@ -108,6 +101,51 @@ class InstanceActivity(ActivityModel): ...@@ -108,6 +101,51 @@ class InstanceActivity(ActivityModel):
act.save() act.save()
return act return act
def get_absolute_url(self):
return reverse('dashboard.views.vm-activity', args=[self.pk])
def get_readable_name(self):
activity_code_last_suffix = split_activity_code(self.activity_code)[-1]
return activity_code_last_suffix.replace('_', ' ').capitalize()
def get_status_id(self):
if self.succeeded is None:
return 'wait'
elif self.succeeded:
return 'success'
else:
return 'failed'
@property
def is_abortable(self):
"""Can the activity be aborted?
:returns: True if the activity can be aborted; otherwise, False.
"""
op = self.instance.get_operation_from_activity_code(self.activity_code)
return self.task_uuid and op and op.abortable and not self.finished
def is_abortable_for(self, user):
"""Can the given user abort the activity?
"""
return self.is_abortable and (
user.is_superuser or user in (self.instance.owner, self.user))
@property
def is_aborted(self):
"""Has the activity been aborted?
:returns: True if the activity has been aborted; otherwise, False.
"""
return self.task_uuid and AbortableAsyncResult(self.task_uuid
).is_aborted()
def save(self, *args, **kwargs):
ret = super(InstanceActivity, self).save(*args, **kwargs)
self.instance._update_status()
return ret
@contextmanager @contextmanager
def sub_activity(self, code_suffix, on_abort=None, on_commit=None, def sub_activity(self, code_suffix, on_abort=None, on_commit=None,
task_uuid=None, concurrency_check=True): task_uuid=None, concurrency_check=True):
...@@ -116,11 +154,6 @@ class InstanceActivity(ActivityModel): ...@@ -116,11 +154,6 @@ class InstanceActivity(ActivityModel):
act = self.create_sub(code_suffix, task_uuid, concurrency_check) act = self.create_sub(code_suffix, task_uuid, concurrency_check)
return activitycontextimpl(act, on_abort=on_abort, on_commit=on_commit) return activitycontextimpl(act, on_abort=on_abort, on_commit=on_commit)
def save(self, *args, **kwargs):
ret = super(InstanceActivity, self).save(*args, **kwargs)
self.instance._update_status()
return ret
@contextmanager @contextmanager
def instance_activity(code_suffix, instance, on_abort=None, on_commit=None, def instance_activity(code_suffix, instance, on_abort=None, on_commit=None,
......
...@@ -17,11 +17,14 @@ ...@@ -17,11 +17,14 @@
from __future__ import absolute_import, unicode_literals from __future__ import absolute_import, unicode_literals
from datetime import timedelta from datetime import timedelta
from functools import partial
from importlib import import_module from importlib import import_module
from logging import getLogger from logging import getLogger
from string import ascii_lowercase from string import ascii_lowercase
from warnings import warn from warnings import warn
from celery.exceptions import TimeoutError
from celery.contrib.abortable import AbortableAsyncResult
import django.conf import django.conf
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.core import signing from django.core import signing
...@@ -834,13 +837,20 @@ class Instance(AclBase, VirtualMachineDescModel, StatusModel, OperatedMixin, ...@@ -834,13 +837,20 @@ class Instance(AclBase, VirtualMachineDescModel, StatusModel, OperatedMixin,
queue=queue_name queue=queue_name
).get(timeout=timeout) ).get(timeout=timeout)
def shutdown_vm(self, timeout=120): def shutdown_vm(self, task=None, step=5):
queue_name = self.get_remote_queue_name('vm') queue_name = self.get_remote_queue_name('vm')
logger.debug("RPC Shutdown at queue: %s, for vm: %s.", queue_name, logger.debug("RPC Shutdown at queue: %s, for vm: %s.", queue_name,
self.vm_name) self.vm_name)
return vm_tasks.shutdown.apply_async(kwargs={'name': self.vm_name}, remote = vm_tasks.shutdown.apply_async(kwargs={'name': self.vm_name},
queue=queue_name queue=queue_name)
).get(timeout=timeout)
while True:
try:
return remote.get(timeout=step)
except TimeoutError:
if task is not None and task.is_aborted():
AbortableAsyncResult(remote.id).abort()
raise Exception("Shutdown aborted by user.")
def suspend_vm(self, timeout=60): def suspend_vm(self, timeout=60):
queue_name = self.get_remote_queue_name('vm') queue_name = self.get_remote_queue_name('vm')
...@@ -893,3 +903,13 @@ class Instance(AclBase, VirtualMachineDescModel, StatusModel, OperatedMixin, ...@@ -893,3 +903,13 @@ class Instance(AclBase, VirtualMachineDescModel, StatusModel, OperatedMixin,
'PENDING': 'icon-rocket', 'PENDING': 'icon-rocket',
'DESTROYED': 'icon-trash', 'DESTROYED': 'icon-trash',
'MIGRATING': 'icon-truck'}.get(self.status, 'icon-question-sign') 'MIGRATING': 'icon-truck'}.get(self.status, 'icon-question-sign')
def get_activities(self, user=None):
acts = (self.activity_log.filter(parent=None).
order_by('-started').
select_related('user').prefetch_related('children'))
if user is not None:
for i in acts:
i.is_abortable_for_user = partial(i.is_abortable_for,
user=user)
return acts
...@@ -26,7 +26,9 @@ from django.utils.translation import ugettext_lazy as _ ...@@ -26,7 +26,9 @@ from django.utils.translation import ugettext_lazy as _
from celery.exceptions import TimeLimitExceeded from celery.exceptions import TimeLimitExceeded
from common.operations import Operation, register_operation from common.operations import Operation, register_operation
from .tasks.local_tasks import async_instance_operation, async_node_operation from .tasks.local_tasks import (
abortable_async_instance_operation, abortable_async_node_operation,
)
from .models import ( from .models import (
Instance, InstanceActivity, InstanceTemplate, Interface, Node, Instance, InstanceActivity, InstanceTemplate, Interface, Node,
NodeActivity, NodeActivity,
...@@ -38,7 +40,7 @@ logger = getLogger(__name__) ...@@ -38,7 +40,7 @@ logger = getLogger(__name__)
class InstanceOperation(Operation): class InstanceOperation(Operation):
acl_level = 'owner' acl_level = 'owner'
async_operation = async_instance_operation async_operation = abortable_async_instance_operation
host_cls = Instance host_cls = Instance
def __init__(self, instance): def __init__(self, instance):
...@@ -126,7 +128,7 @@ class DeployOperation(InstanceOperation): ...@@ -126,7 +128,7 @@ class DeployOperation(InstanceOperation):
def on_commit(self, activity): def on_commit(self, activity):
activity.resultant_state = 'RUNNING' activity.resultant_state = 'RUNNING'
def _operation(self, activity, user, system, timeout=15): def _operation(self, activity, timeout=15):
# Allocate VNC port and host node # Allocate VNC port and host node
self.instance.allocate_vnc_port() self.instance.allocate_vnc_port()
self.instance.allocate_node() self.instance.allocate_node()
...@@ -162,7 +164,7 @@ class DestroyOperation(InstanceOperation): ...@@ -162,7 +164,7 @@ class DestroyOperation(InstanceOperation):
def on_commit(self, activity): def on_commit(self, activity):
activity.resultant_state = 'DESTROYED' activity.resultant_state = 'DESTROYED'
def _operation(self, activity, user, system): def _operation(self, activity):
if self.instance.node: if self.instance.node:
# Destroy networks # Destroy networks
with activity.sub_activity('destroying_net'): with activity.sub_activity('destroying_net'):
...@@ -200,7 +202,7 @@ class MigrateOperation(InstanceOperation): ...@@ -200,7 +202,7 @@ class MigrateOperation(InstanceOperation):
name = _("migrate") name = _("migrate")
description = _("Live migrate running VM to another node.") description = _("Live migrate running VM to another node.")
def _operation(self, activity, user, system, to_node=None, timeout=120): def _operation(self, activity, to_node=None, timeout=120):
if not to_node: if not to_node:
with activity.sub_activity('scheduling') as sa: with activity.sub_activity('scheduling') as sa:
to_node = self.instance.select_node() to_node = self.instance.select_node()
...@@ -230,7 +232,7 @@ class RebootOperation(InstanceOperation): ...@@ -230,7 +232,7 @@ class RebootOperation(InstanceOperation):
name = _("reboot") name = _("reboot")
description = _("Reboot virtual machine with Ctrl+Alt+Del signal.") description = _("Reboot virtual machine with Ctrl+Alt+Del signal.")
def _operation(self, activity, user, system, timeout=5): def _operation(self, timeout=5):
self.instance.reboot_vm(timeout=timeout) self.instance.reboot_vm(timeout=timeout)
...@@ -280,7 +282,7 @@ class ResetOperation(InstanceOperation): ...@@ -280,7 +282,7 @@ class ResetOperation(InstanceOperation):
name = _("reset") name = _("reset")
description = _("Reset virtual machine (reset button).") description = _("Reset virtual machine (reset button).")
def _operation(self, activity, user, system, timeout=5): def _operation(self, timeout=5):
self.instance.reset_vm(timeout=timeout) self.instance.reset_vm(timeout=timeout)
register_operation(ResetOperation) register_operation(ResetOperation)
...@@ -295,6 +297,7 @@ class SaveAsTemplateOperation(InstanceOperation): ...@@ -295,6 +297,7 @@ class SaveAsTemplateOperation(InstanceOperation):
Template can be shared with groups and users. Template can be shared with groups and users.
Users can instantiate Virtual Machines from Templates. Users can instantiate Virtual Machines from Templates.
""") """)
abortable = True
@staticmethod @staticmethod
def _rename(name): def _rename(name):
...@@ -307,11 +310,11 @@ class SaveAsTemplateOperation(InstanceOperation): ...@@ -307,11 +310,11 @@ class SaveAsTemplateOperation(InstanceOperation):
return "%s v%d" % (name, v) return "%s v%d" % (name, v)
def _operation(self, activity, user, system, timeout=300, name=None, def _operation(self, activity, user, system, timeout=300, name=None,
with_shutdown=True, **kwargs): with_shutdown=True, task=None, **kwargs):
if with_shutdown: if with_shutdown:
try: try:
ShutdownOperation(self.instance).call(parent_activity=activity, ShutdownOperation(self.instance).call(parent_activity=activity,
user=user) user=user, task=task)
except Instance.WrongStateError: except Instance.WrongStateError:
pass pass
...@@ -370,23 +373,18 @@ class ShutdownOperation(InstanceOperation): ...@@ -370,23 +373,18 @@ class ShutdownOperation(InstanceOperation):
id = 'shutdown' id = 'shutdown'
name = _("shutdown") name = _("shutdown")
description = _("Shutdown virtual machine with ACPI signal.") description = _("Shutdown virtual machine with ACPI signal.")
abortable = True
def check_precond(self): def check_precond(self):
super(ShutdownOperation, self).check_precond() super(ShutdownOperation, self).check_precond()
if self.instance.status not in ['RUNNING']: if self.instance.status not in ['RUNNING']:
raise self.instance.WrongStateError(self.instance) raise self.instance.WrongStateError(self.instance)
def on_abort(self, activity, error):
if isinstance(error, TimeLimitExceeded):
activity.resultant_state = None
else:
activity.resultant_state = 'ERROR'
def on_commit(self, activity): def on_commit(self, activity):
activity.resultant_state = 'STOPPED' activity.resultant_state = 'STOPPED'
def _operation(self, activity, user, system, timeout=120): def _operation(self, task=None):
self.instance.shutdown_vm(timeout=timeout) self.instance.shutdown_vm(task=task)
self.instance.yield_node() self.instance.yield_node()
self.instance.yield_vnc_port() self.instance.yield_vnc_port()
...@@ -403,7 +401,7 @@ class ShutOffOperation(InstanceOperation): ...@@ -403,7 +401,7 @@ class ShutOffOperation(InstanceOperation):
def on_commit(self, activity): def on_commit(self, activity):
activity.resultant_state = 'STOPPED' activity.resultant_state = 'STOPPED'
def _operation(self, activity, user, system): def _operation(self, activity):
# Shutdown networks # Shutdown networks
with activity.sub_activity('shutdown_net'): with activity.sub_activity('shutdown_net'):
self.instance.shutdown_net() self.instance.shutdown_net()
...@@ -440,7 +438,7 @@ class SleepOperation(InstanceOperation): ...@@ -440,7 +438,7 @@ class SleepOperation(InstanceOperation):
def on_commit(self, activity): def on_commit(self, activity):
activity.resultant_state = 'SUSPENDED' activity.resultant_state = 'SUSPENDED'
def _operation(self, activity, user, system, timeout=60): def _operation(self, activity, timeout=60):
# Destroy networks # Destroy networks
with activity.sub_activity('shutdown_net'): with activity.sub_activity('shutdown_net'):
self.instance.shutdown_net() self.instance.shutdown_net()
...@@ -476,7 +474,7 @@ class WakeUpOperation(InstanceOperation): ...@@ -476,7 +474,7 @@ class WakeUpOperation(InstanceOperation):
def on_commit(self, activity): def on_commit(self, activity):
activity.resultant_state = 'RUNNING' activity.resultant_state = 'RUNNING'
def _operation(self, activity, user, system, timeout=60): def _operation(self, activity, timeout=60):
# Schedule vm # Schedule vm
self.instance.allocate_vnc_port() self.instance.allocate_vnc_port()
self.instance.allocate_node() self.instance.allocate_node()
...@@ -497,7 +495,7 @@ register_operation(WakeUpOperation) ...@@ -497,7 +495,7 @@ register_operation(WakeUpOperation)
class NodeOperation(Operation): class NodeOperation(Operation):
async_operation = async_node_operation async_operation = abortable_async_node_operation
host_cls = Node host_cls = Node
def __init__(self, node): def __init__(self, node):
...@@ -527,7 +525,7 @@ class FlushOperation(NodeOperation): ...@@ -527,7 +525,7 @@ class FlushOperation(NodeOperation):
name = _("flush") name = _("flush")
description = _("Disable node and move all instances to other ones.") description = _("Disable node and move all instances to other ones.")
def _operation(self, activity, user, system): def _operation(self, activity, user):
self.node.disable(user, activity) self.node.disable(user, activity)
for i in self.node.instance_set.all(): for i in self.node.instance_set.all():
with activity.sub_activity('migrate_instance_%d' % i.pk): with activity.sub_activity('migrate_instance_%d' % i.pk):
......
...@@ -15,32 +15,41 @@ ...@@ -15,32 +15,41 @@
# You should have received a copy of the GNU General Public License along # You should have received a copy of the GNU General Public License along
# with CIRCLE. If not, see <http://www.gnu.org/licenses/>. # with CIRCLE. If not, see <http://www.gnu.org/licenses/>.
from celery.contrib.abortable import AbortableTask
from manager.mancelery import celery from manager.mancelery import celery
@celery.task @celery.task(base=AbortableTask, bind=True)
def async_instance_operation(operation_id, instance_pk, activity_pk, **kwargs): def abortable_async_instance_operation(task, operation_id, instance_pk,
activity_pk, allargs, auxargs):
from vm.models import Instance, InstanceActivity from vm.models import Instance, InstanceActivity
instance = Instance.objects.get(pk=instance_pk) instance = Instance.objects.get(pk=instance_pk)
operation = getattr(instance, operation_id) operation = getattr(instance, operation_id)
activity = InstanceActivity.objects.get(pk=activity_pk) activity = InstanceActivity.objects.get(pk=activity_pk)
# save async task UUID to activity # save async task UUID to activity
activity.task_uuid = async_instance_operation.request.id activity.task_uuid = task.request.id
activity.save() activity.save()
return operation._exec_op(activity=activity, **kwargs) allargs['activity'] = activity
allargs['task'] = task
return operation._exec_op(allargs, auxargs)
@celery.task
def async_node_operation(operation_id, node_pk, activity_pk, **kwargs): @celery.task(base=AbortableTask, bind=True)
def abortable_async_node_operation(task, operation_id, node_pk, activity_pk,
allargs, auxargs):
from vm.models import Node, NodeActivity from vm.models import Node, NodeActivity
node = Node.objects.get(pk=node_pk) node = Node.objects.get(pk=node_pk)
operation = getattr(node, operation_id) operation = getattr(node, operation_id)
activity = NodeActivity.objects.get(pk=activity_pk) activity = NodeActivity.objects.get(pk=activity_pk)
# save async task UUID to activity # save async task UUID to activity
activity.task_uuid = async_node_operation.request.id activity.task_uuid = task.request.id
activity.save() activity.save()
return operation._exec_op(activity=activity, **kwargs) allargs['activity'] = activity
allargs['task'] = task
return operation._exec_op(allargs, auxargs)
...@@ -19,6 +19,7 @@ from datetime import datetime ...@@ -19,6 +19,7 @@ from datetime import datetime
from mock import Mock, MagicMock, patch, call from mock import Mock, MagicMock, patch, call
import types import types
from celery.contrib.abortable import AbortableAsyncResult
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.test import TestCase from django.test import TestCase
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
...@@ -231,6 +232,92 @@ class InstanceActivityTestCase(TestCase): ...@@ -231,6 +232,92 @@ class InstanceActivityTestCase(TestCase):
raise AssertionError("'create_sub' method checked for " raise AssertionError("'create_sub' method checked for "
"concurrent activities.") "concurrent activities.")
def test_is_abortable(self):
get_op = MagicMock(return_value=MagicMock(abortable=True))
instance = MagicMock(get_operation_from_activity_code=get_op)
iaobj = MagicMock(spec=InstanceActivity, activity_code='test',
finished=False, instance=instance, task_uuid='test')
self.assertTrue(InstanceActivity.is_abortable.fget(iaobj))
def test_not_abortable_when_not_associated_with_task(self):
get_op = MagicMock(return_value=MagicMock(abortable=True))
instance = MagicMock(get_operation_from_activity_code=get_op)
iaobj = MagicMock(spec=InstanceActivity, activity_code='test',
finished=False, instance=instance, task_uuid=None)
self.assertFalse(InstanceActivity.is_abortable.fget(iaobj))
def test_not_abortable_when_finished(self):
get_op = MagicMock(return_value=MagicMock(abortable=True))
instance = MagicMock(get_operation_from_activity_code=get_op)
iaobj = MagicMock(spec=InstanceActivity, activity_code='test',
finished=True, instance=instance, task_uuid='test')
self.assertFalse(InstanceActivity.is_abortable.fget(iaobj))
def test_not_abortable_when_operation_not_abortable(self):
get_op = MagicMock(return_value=MagicMock(abortable=False))
instance = MagicMock(get_operation_from_activity_code=get_op)
iaobj = MagicMock(spec=InstanceActivity, activity_code='test',
finished=False, instance=instance, task_uuid='test')
self.assertFalse(InstanceActivity.is_abortable.fget(iaobj))
def test_not_abortable_when_no_matching_operation(self):
get_op = MagicMock(return_value=None)
instance = MagicMock(get_operation_from_activity_code=get_op)
iaobj = MagicMock(spec=InstanceActivity, activity_code='test',
finished=False, instance=instance, task_uuid='test')
self.assertFalse(InstanceActivity.is_abortable.fget(iaobj))
def test_not_aborted_when_not_associated_with_task(self):
iaobj = MagicMock(task_uuid=None)
self.assertFalse(InstanceActivity.is_aborted.fget(iaobj))
def test_is_aborted_when_associated_task_is_aborted(self):
expected = object()
iaobj = MagicMock(task_uuid='test')
with patch.object(AbortableAsyncResult, 'is_aborted',
return_value=expected):
self.assertEquals(expected,
InstanceActivity.is_aborted.fget(iaobj))
def test_is_abortable_for_activity_owner_if_not_abortable(self):
iaobj = MagicMock(spec=InstanceActivity, is_abortable=False,
user=MagicMock(spec=User, is_superuser=False))
self.assertFalse(InstanceActivity.is_abortable_for(iaobj, iaobj.user))
def test_is_abortable_for_instance_owner(self):
get_op = MagicMock(return_value=MagicMock(abortable=True))
instance = MagicMock(get_operation_from_activity_code=get_op,
owner=MagicMock(spec=User, is_superuser=False))
iaobj = MagicMock(spec=InstanceActivity, activity_code='test',
finished=False, instance=instance, task_uuid='test',
user=MagicMock(spec=User, is_superuser=False))
self.assertTrue(
InstanceActivity.is_abortable_for(iaobj, iaobj.instance.owner))
def test_is_abortable_for_activity_owner(self):
get_op = MagicMock(return_value=MagicMock(abortable=True))
instance = MagicMock(get_operation_from_activity_code=get_op)
iaobj = MagicMock(spec=InstanceActivity, activity_code='test',
finished=False, instance=instance, task_uuid='test',
user=MagicMock(spec=User, is_superuser=False))
self.assertTrue(InstanceActivity.is_abortable_for(iaobj, iaobj.user))
def test_not_abortable_for_foreign(self):
get_op = MagicMock(return_value=MagicMock(abortable=True))
instance = MagicMock(get_operation_from_activity_code=get_op)
iaobj = MagicMock(spec=InstanceActivity, activity_code='test',
finished=False, instance=instance, task_uuid='test')
self.assertFalse(InstanceActivity.is_abortable_for(
iaobj, MagicMock(spec=User, is_superuser=False)))
def test_is_abortable_for_superuser(self):
get_op = MagicMock(return_value=MagicMock(abortable=True))
instance = MagicMock(get_operation_from_activity_code=get_op)
iaobj = MagicMock(spec=InstanceActivity, activity_code='test',
finished=False, instance=instance, task_uuid='test')
su = MagicMock(spec=User, is_superuser=True)
self.assertTrue(InstanceActivity.is_abortable_for(iaobj, su))
def test_disable_enabled(self): def test_disable_enabled(self):
node = MagicMock(spec=Node, enabled=True) node = MagicMock(spec=Node, enabled=True)
with patch('vm.models.node.node_activity') as nac: with patch('vm.models.node.node_activity') as nac:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment