Commit be367e47 by Őry Máté

Merge branch 'feature-abortable-operations-rebased'

Feature Abortable Operations
Make possible the implementation of abortable operations. Implement it for
shutdown.
parents 681af7af e7abb4a5
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
from collections import deque from collections import deque
from contextlib import contextmanager from contextlib import contextmanager
from hashlib import sha224 from hashlib import sha224
from itertools import chain, imap
from logging import getLogger from logging import getLogger
from time import time from time import time
...@@ -56,12 +57,52 @@ activity_context = contextmanager(activitycontextimpl) ...@@ -56,12 +57,52 @@ activity_context = contextmanager(activitycontextimpl)
activity_code_separator = '.' activity_code_separator = '.'
def has_prefix(activity_code, *prefixes):
"""Determine whether the activity code has the specified prefix.
E.g.: has_prefix('foo.bar.buz', 'foo.bar') == True
has_prefix('foo.bar.buz', 'foo', 'bar') == True
has_prefix('foo.bar.buz', 'foo.bar', 'buz') == True
has_prefix('foo.bar.buz', 'foo', 'bar', 'buz') == True
has_prefix('foo.bar.buz', 'foo', 'buz') == False
"""
equal = lambda a, b: a == b
act_code_parts = split_activity_code(activity_code)
prefixes = chain(*imap(split_activity_code, prefixes))
return all(imap(equal, act_code_parts, prefixes))
def has_suffix(activity_code, *suffixes):
"""Determine whether the activity code has the specified suffix.
E.g.: has_suffix('foo.bar.buz', 'bar.buz') == True
has_suffix('foo.bar.buz', 'bar', 'buz') == True
has_suffix('foo.bar.buz', 'foo.bar', 'buz') == True
has_suffix('foo.bar.buz', 'foo', 'bar', 'buz') == True
has_suffix('foo.bar.buz', 'foo', 'buz') == False
"""
equal = lambda a, b: a == b
act_code_parts = split_activity_code(activity_code)
suffixes = list(chain(*imap(split_activity_code, suffixes)))
return all(imap(equal, reversed(act_code_parts), reversed(suffixes)))
def join_activity_code(*args): def join_activity_code(*args):
"""Join the specified parts into an activity code. """Join the specified parts into an activity code.
:returns: Activity code string.
""" """
return activity_code_separator.join(args) return activity_code_separator.join(args)
def split_activity_code(activity_code):
"""Split the specified activity code into its parts.
:returns: A list of activity code parts.
"""
return activity_code.split(activity_code_separator)
class ActivityModel(TimeStampedModel): class ActivityModel(TimeStampedModel):
activity_code = CharField(max_length=100, verbose_name=_('activity code')) activity_code = CharField(max_length=100, verbose_name=_('activity code'))
parent = ForeignKey('self', blank=True, null=True, related_name='children') parent = ForeignKey('self', blank=True, null=True, related_name='children')
......
...@@ -15,9 +15,10 @@ ...@@ -15,9 +15,10 @@
# You should have received a copy of the GNU General Public License along # You should have received a copy of the GNU General Public License along
# with CIRCLE. If not, see <http://www.gnu.org/licenses/>. # with CIRCLE. If not, see <http://www.gnu.org/licenses/>.
from inspect import getargspec
from logging import getLogger from logging import getLogger
from .models import activity_context from .models import activity_context, has_suffix
from django.core.exceptions import PermissionDenied from django.core.exceptions import PermissionDenied
...@@ -31,6 +32,7 @@ class Operation(object): ...@@ -31,6 +32,7 @@ class Operation(object):
async_queue = 'localhost.man' async_queue = 'localhost.man'
required_perms = () required_perms = ()
do_not_call_in_templates = True do_not_call_in_templates = True
abortable = False
def __call__(self, **kwargs): def __call__(self, **kwargs):
return self.call(**kwargs) return self.call(**kwargs)
...@@ -46,23 +48,50 @@ class Operation(object): ...@@ -46,23 +48,50 @@ class Operation(object):
def __prelude(self, kwargs): def __prelude(self, kwargs):
"""This method contains the shared prelude of call and async. """This method contains the shared prelude of call and async.
""" """
skip_auth_check = kwargs.setdefault('system', False) defaults = {'parent_activity': None, 'system': False, 'user': None}
user = kwargs.setdefault('user', None)
parent_activity = kwargs.pop('parent_activity', None) allargs = dict(defaults, **kwargs) # all arguments
auxargs = allargs.copy() # auxiliary (i.e. only for _operation) args
# NOTE: consumed items should be removed from auxargs, and no new items
# should be added to it
skip_auth_check = auxargs.pop('system')
user = auxargs.pop('user')
parent_activity = auxargs.pop('parent_activity')
# check for unexpected keyword arguments
argspec = getargspec(self._operation)
if argspec.keywords is None: # _operation doesn't take ** args
unexpected_kwargs = set(auxargs) - set(argspec.args)
if unexpected_kwargs:
raise TypeError("Operation got unexpected keyword arguments: "
"%s" % ", ".join(unexpected_kwargs))
if not skip_auth_check: if not skip_auth_check:
self.check_auth(user) self.check_auth(user)
self.check_precond() self.check_precond()
return self.create_activity(parent=parent_activity, user=user)
def _exec_op(self, activity, user, **kwargs): activity = self.create_activity(parent=parent_activity, user=user)
return activity, allargs, auxargs
def _exec_op(self, allargs, auxargs):
"""Execute the operation inside the specified activity's context. """Execute the operation inside the specified activity's context.
""" """
with activity_context(activity, on_abort=self.on_abort, # compile arguments for _operation
argspec = getargspec(self._operation)
if argspec.keywords is not None: # _operation takes ** args
arguments = allargs.copy()
else: # _operation doesn't take ** args
arguments = {k: v for (k, v) in allargs.iteritems()
if k in argspec.args}
arguments.update(auxargs)
with activity_context(allargs['activity'], on_abort=self.on_abort,
on_commit=self.on_commit): on_commit=self.on_commit):
return self._operation(activity=activity, user=user, **kwargs) return self._operation(**arguments)
def _operation(self, activity, user, system, **kwargs): def _operation(self, **kwargs):
"""This method is the operation's particular implementation. """This method is the operation's particular implementation.
Deriving classes should implement this method. Deriving classes should implement this method.
...@@ -82,11 +111,9 @@ class Operation(object): ...@@ -82,11 +111,9 @@ class Operation(object):
logger.info("%s called asynchronously on %s with the following " logger.info("%s called asynchronously on %s with the following "
"parameters: %r", self.__class__.__name__, self.subject, "parameters: %r", self.__class__.__name__, self.subject,
kwargs) kwargs)
activity = self.__prelude(kwargs) activity, allargs, auxargs = self.__prelude(kwargs)
return self.async_operation.apply_async(args=(self.id, return self.async_operation.apply_async(
self.subject.pk, args=(self.id, self.subject.pk, activity.pk, allargs, auxargs, ),
activity.pk),
kwargs=kwargs,
queue=self.async_queue) queue=self.async_queue)
def call(self, **kwargs): def call(self, **kwargs):
...@@ -105,8 +132,9 @@ class Operation(object): ...@@ -105,8 +132,9 @@ class Operation(object):
logger.info("%s called (synchronously) on %s with the following " logger.info("%s called (synchronously) on %s with the following "
"parameters: %r", self.__class__.__name__, self.subject, "parameters: %r", self.__class__.__name__, self.subject,
kwargs) kwargs)
activity = self.__prelude(kwargs) activity, allargs, auxargs = self.__prelude(kwargs)
return self._exec_op(activity=activity, **kwargs) allargs['activity'] = activity
return self._exec_op(allargs, auxargs)
def check_precond(self): def check_precond(self):
pass pass
...@@ -160,6 +188,19 @@ class OperatedMixin(object): ...@@ -160,6 +188,19 @@ class OperatedMixin(object):
else: else:
yield op yield op
def get_operation_from_activity_code(self, activity_code):
"""Get an instance of the Operation corresponding to the specified
activity code.
:returns: A bound instance of an operation, or None if no matching
operation could be found.
"""
for op in getattr(self, operation_registry_name, {}).itervalues():
if has_suffix(activity_code, op.activity_code_suffix):
return op(self)
else:
return None
def register_operation(op_cls, op_id=None, target_cls=None): def register_operation(op_cls, op_id=None, target_cls=None):
"""Register the specified operation with the target class. """Register the specified operation with the target class.
......
...@@ -75,3 +75,41 @@ class OperationTestCase(TestCase): ...@@ -75,3 +75,41 @@ class OperationTestCase(TestCase):
patch.object(Operation, 'create_activity'), \ patch.object(Operation, 'create_activity'), \
patch.object(Operation, '_exec_op'): patch.object(Operation, '_exec_op'):
op.call(system=True) op.call(system=True)
def test_no_exception_for_more_arguments_when_operation_takes_kwargs(self):
class KwargOp(Operation):
activity_code_suffix = 'test'
id = 'test'
def _operation(self, **kwargs):
pass
op = KwargOp(MagicMock())
with patch.object(KwargOp, 'create_activity'), \
patch.object(KwargOp, '_exec_op'):
op.call(system=True, foo=42)
def test_exception_for_unexpected_arguments(self):
class TestOp(Operation):
activity_code_suffix = 'test'
id = 'test'
def _operation(self):
pass
op = TestOp(MagicMock())
with patch.object(TestOp, 'create_activity'), \
patch.object(TestOp, '_exec_op'):
self.assertRaises(TypeError, op.call, system=True, foo=42)
def test_exception_for_missing_arguments(self):
class TestOp(Operation):
activity_code_suffix = 'test'
id = 'test'
def _operation(self, foo):
pass
op = TestOp(MagicMock())
with patch.object(TestOp, 'create_activity'):
self.assertRaises(TypeError, op.call, system=True)
...@@ -9,6 +9,14 @@ ...@@ -9,6 +9,14 @@
{{ a.get_readable_name }}{% if user.is_superuser %}</a>{% endif %} {{ a.get_readable_name }}{% if user.is_superuser %}</a>{% endif %}
</strong> </strong>
{{ a.started|date:"Y-m-d H:i" }}{% if a.user %}, {{ a.user }}{% endif %} {{ a.started|date:"Y-m-d H:i" }}{% if a.user %}, {{ a.user }}{% endif %}
{% if a.is_abortable_for_user %}
<form action="{{ a.instance.get_absolute_url }}" method="POST" class="pull-right">
{% csrf_token %}
<input type="hidden" name="abort_operation"/>
<input type="hidden" name="activity" value="{{ a.pk }}"/>
<button class="btn btn-danger btn-xs"><i class="icon-bolt"></i> {% trans "Abort" %}</button>
</form>
{% endif %}
{% if a.children.count > 0 %} {% if a.children.count > 0 %}
<div class="sub-timeline"> <div class="sub-timeline">
{% for s in a.children.all %} {% for s in a.children.all %}
......
...@@ -224,11 +224,7 @@ class VmDetailView(CheckedDetailView): ...@@ -224,11 +224,7 @@ class VmDetailView(CheckedDetailView):
}) })
# activity data # activity data
context['activities'] = ( context['activities'] = self.object.get_activities(self.request.user)
InstanceActivity.objects.filter(
instance=self.object, parent=None).
order_by('-started').
select_related('user').prefetch_related('children'))
context['vlans'] = Vlan.get_objects_with_level( context['vlans'] = Vlan.get_objects_with_level(
'user', self.request.user 'user', self.request.user
...@@ -260,6 +256,7 @@ class VmDetailView(CheckedDetailView): ...@@ -260,6 +256,7 @@ class VmDetailView(CheckedDetailView):
'to_remove': self.__remove_tag, 'to_remove': self.__remove_tag,
'port': self.__add_port, 'port': self.__add_port,
'new_network_vlan': self.__new_network, 'new_network_vlan': self.__new_network,
'abort_operation': self.__abort_operation,
} }
for k, v in options.iteritems(): for k, v in options.iteritems():
if request.POST.get(k) is not None: if request.POST.get(k) is not None:
...@@ -445,6 +442,16 @@ class VmDetailView(CheckedDetailView): ...@@ -445,6 +442,16 @@ class VmDetailView(CheckedDetailView):
return redirect("%s#network" % reverse_lazy( return redirect("%s#network" % reverse_lazy(
"dashboard.views.detail", kwargs={'pk': self.object.pk})) "dashboard.views.detail", kwargs={'pk': self.object.pk}))
def __abort_operation(self, request):
self.object = self.get_object()
activity = get_object_or_404(InstanceActivity,
pk=request.POST.get("activity"))
if not activity.is_abortable_for(request.user):
raise PermissionDenied()
activity.abort()
return redirect("%s#activity" % self.object.get_absolute_url())
class OperationView(DetailView): class OperationView(DetailView):
...@@ -1736,9 +1743,7 @@ def vm_activity(request, pk): ...@@ -1736,9 +1743,7 @@ def vm_activity(request, pk):
if only_status == "false": # instance activity if only_status == "false": # instance activity
context = { context = {
'instance': instance, 'instance': instance,
'activities': InstanceActivity.objects.filter( 'activities': instance.get_activities(request.user),
instance=instance, parent=None
).order_by('-started').select_related(),
'ops': get_operations(instance, request.user), 'ops': get_operations(instance, request.user),
} }
...@@ -2343,10 +2348,8 @@ class InstanceActivityDetail(SuperuserRequiredMixin, DetailView): ...@@ -2343,10 +2348,8 @@ class InstanceActivityDetail(SuperuserRequiredMixin, DetailView):
def get_context_data(self, **kwargs): def get_context_data(self, **kwargs):
ctx = super(InstanceActivityDetail, self).get_context_data(**kwargs) ctx = super(InstanceActivityDetail, self).get_context_data(**kwargs)
ctx['activities'] = ( ctx['activities'] = self.object.instance.get_activities(
self.object.instance.activity_log.filter(parent=None). self.request.user)
order_by('-started').select_related('user').
prefetch_related('children'))
return ctx return ctx
......
...@@ -20,15 +20,19 @@ from contextlib import contextmanager ...@@ -20,15 +20,19 @@ from contextlib import contextmanager
from logging import getLogger from logging import getLogger
from celery.signals import worker_ready from celery.signals import worker_ready
from celery.contrib.abortable import AbortableAsyncResult
from django.core.urlresolvers import reverse from django.core.urlresolvers import reverse
from django.db.models import CharField, ForeignKey from django.db.models import CharField, ForeignKey
from django.utils import timezone from django.utils import timezone
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from common.models import ( from common.models import (
ActivityModel, activitycontextimpl, join_activity_code, ActivityModel, activitycontextimpl, join_activity_code, split_activity_code
) )
from manager.mancelery import celery
logger = getLogger(__name__) logger = getLogger(__name__)
...@@ -66,19 +70,8 @@ class InstanceActivity(ActivityModel): ...@@ -66,19 +70,8 @@ class InstanceActivity(ActivityModel):
return '{}({})'.format(self.activity_code, return '{}({})'.format(self.activity_code,
self.instance) self.instance)
def get_absolute_url(self): def abort(self):
return reverse('dashboard.views.vm-activity', args=[self.pk]) AbortableAsyncResult(self.task_uuid, backend=celery.backend).abort()
def get_readable_name(self):
return self.activity_code.split('.')[-1].replace('_', ' ').capitalize()
def get_status_id(self):
if self.succeeded is None:
return 'wait'
elif self.succeeded:
return 'success'
else:
return 'failed'
@classmethod @classmethod
def create(cls, code_suffix, instance, task_uuid=None, user=None, def create(cls, code_suffix, instance, task_uuid=None, user=None,
...@@ -108,6 +101,51 @@ class InstanceActivity(ActivityModel): ...@@ -108,6 +101,51 @@ class InstanceActivity(ActivityModel):
act.save() act.save()
return act return act
def get_absolute_url(self):
return reverse('dashboard.views.vm-activity', args=[self.pk])
def get_readable_name(self):
activity_code_last_suffix = split_activity_code(self.activity_code)[-1]
return activity_code_last_suffix.replace('_', ' ').capitalize()
def get_status_id(self):
if self.succeeded is None:
return 'wait'
elif self.succeeded:
return 'success'
else:
return 'failed'
@property
def is_abortable(self):
"""Can the activity be aborted?
:returns: True if the activity can be aborted; otherwise, False.
"""
op = self.instance.get_operation_from_activity_code(self.activity_code)
return self.task_uuid and op and op.abortable and not self.finished
def is_abortable_for(self, user):
"""Can the given user abort the activity?
"""
return self.is_abortable and (
user.is_superuser or user in (self.instance.owner, self.user))
@property
def is_aborted(self):
"""Has the activity been aborted?
:returns: True if the activity has been aborted; otherwise, False.
"""
return self.task_uuid and AbortableAsyncResult(self.task_uuid
).is_aborted()
def save(self, *args, **kwargs):
ret = super(InstanceActivity, self).save(*args, **kwargs)
self.instance._update_status()
return ret
@contextmanager @contextmanager
def sub_activity(self, code_suffix, on_abort=None, on_commit=None, def sub_activity(self, code_suffix, on_abort=None, on_commit=None,
task_uuid=None, concurrency_check=True): task_uuid=None, concurrency_check=True):
...@@ -116,11 +154,6 @@ class InstanceActivity(ActivityModel): ...@@ -116,11 +154,6 @@ class InstanceActivity(ActivityModel):
act = self.create_sub(code_suffix, task_uuid, concurrency_check) act = self.create_sub(code_suffix, task_uuid, concurrency_check)
return activitycontextimpl(act, on_abort=on_abort, on_commit=on_commit) return activitycontextimpl(act, on_abort=on_abort, on_commit=on_commit)
def save(self, *args, **kwargs):
ret = super(InstanceActivity, self).save(*args, **kwargs)
self.instance._update_status()
return ret
@contextmanager @contextmanager
def instance_activity(code_suffix, instance, on_abort=None, on_commit=None, def instance_activity(code_suffix, instance, on_abort=None, on_commit=None,
......
...@@ -17,11 +17,14 @@ ...@@ -17,11 +17,14 @@
from __future__ import absolute_import, unicode_literals from __future__ import absolute_import, unicode_literals
from datetime import timedelta from datetime import timedelta
from functools import partial
from importlib import import_module from importlib import import_module
from logging import getLogger from logging import getLogger
from string import ascii_lowercase from string import ascii_lowercase
from warnings import warn from warnings import warn
from celery.exceptions import TimeoutError
from celery.contrib.abortable import AbortableAsyncResult
import django.conf import django.conf
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.core import signing from django.core import signing
...@@ -832,13 +835,20 @@ class Instance(AclBase, VirtualMachineDescModel, StatusModel, OperatedMixin, ...@@ -832,13 +835,20 @@ class Instance(AclBase, VirtualMachineDescModel, StatusModel, OperatedMixin,
queue=queue_name queue=queue_name
).get(timeout=timeout) ).get(timeout=timeout)
def shutdown_vm(self, timeout=120): def shutdown_vm(self, task=None, step=5):
queue_name = self.get_remote_queue_name('vm') queue_name = self.get_remote_queue_name('vm')
logger.debug("RPC Shutdown at queue: %s, for vm: %s.", queue_name, logger.debug("RPC Shutdown at queue: %s, for vm: %s.", queue_name,
self.vm_name) self.vm_name)
return vm_tasks.shutdown.apply_async(kwargs={'name': self.vm_name}, remote = vm_tasks.shutdown.apply_async(kwargs={'name': self.vm_name},
queue=queue_name queue=queue_name)
).get(timeout=timeout)
while True:
try:
return remote.get(timeout=step)
except TimeoutError:
if task is not None and task.is_aborted():
AbortableAsyncResult(remote.id).abort()
raise Exception("Shutdown aborted by user.")
def suspend_vm(self, timeout=60): def suspend_vm(self, timeout=60):
queue_name = self.get_remote_queue_name('vm') queue_name = self.get_remote_queue_name('vm')
...@@ -891,3 +901,13 @@ class Instance(AclBase, VirtualMachineDescModel, StatusModel, OperatedMixin, ...@@ -891,3 +901,13 @@ class Instance(AclBase, VirtualMachineDescModel, StatusModel, OperatedMixin,
'PENDING': 'icon-rocket', 'PENDING': 'icon-rocket',
'DESTROYED': 'icon-trash', 'DESTROYED': 'icon-trash',
'MIGRATING': 'icon-truck'}.get(self.status, 'icon-question-sign') 'MIGRATING': 'icon-truck'}.get(self.status, 'icon-question-sign')
def get_activities(self, user=None):
acts = (self.activity_log.filter(parent=None).
order_by('-started').
select_related('user').prefetch_related('children'))
if user is not None:
for i in acts:
i.is_abortable_for_user = partial(i.is_abortable_for,
user=user)
return acts
...@@ -26,7 +26,9 @@ from django.utils.translation import ugettext_lazy as _ ...@@ -26,7 +26,9 @@ from django.utils.translation import ugettext_lazy as _
from celery.exceptions import TimeLimitExceeded from celery.exceptions import TimeLimitExceeded
from common.operations import Operation, register_operation from common.operations import Operation, register_operation
from .tasks.local_tasks import async_instance_operation, async_node_operation from .tasks.local_tasks import (
abortable_async_instance_operation, abortable_async_node_operation,
)
from .models import ( from .models import (
Instance, InstanceActivity, InstanceTemplate, Interface, Node, Instance, InstanceActivity, InstanceTemplate, Interface, Node,
NodeActivity, NodeActivity,
...@@ -38,7 +40,7 @@ logger = getLogger(__name__) ...@@ -38,7 +40,7 @@ logger = getLogger(__name__)
class InstanceOperation(Operation): class InstanceOperation(Operation):
acl_level = 'owner' acl_level = 'owner'
async_operation = async_instance_operation async_operation = abortable_async_instance_operation
host_cls = Instance host_cls = Instance
def __init__(self, instance): def __init__(self, instance):
...@@ -126,7 +128,7 @@ class DeployOperation(InstanceOperation): ...@@ -126,7 +128,7 @@ class DeployOperation(InstanceOperation):
def on_commit(self, activity): def on_commit(self, activity):
activity.resultant_state = 'RUNNING' activity.resultant_state = 'RUNNING'
def _operation(self, activity, user, system, timeout=15): def _operation(self, activity, timeout=15):
# Allocate VNC port and host node # Allocate VNC port and host node
self.instance.allocate_vnc_port() self.instance.allocate_vnc_port()
self.instance.allocate_node() self.instance.allocate_node()
...@@ -162,7 +164,7 @@ class DestroyOperation(InstanceOperation): ...@@ -162,7 +164,7 @@ class DestroyOperation(InstanceOperation):
def on_commit(self, activity): def on_commit(self, activity):
activity.resultant_state = 'DESTROYED' activity.resultant_state = 'DESTROYED'
def _operation(self, activity, user, system): def _operation(self, activity):
if self.instance.node: if self.instance.node:
# Destroy networks # Destroy networks
with activity.sub_activity('destroying_net'): with activity.sub_activity('destroying_net'):
...@@ -200,7 +202,7 @@ class MigrateOperation(InstanceOperation): ...@@ -200,7 +202,7 @@ class MigrateOperation(InstanceOperation):
name = _("migrate") name = _("migrate")
description = _("Live migrate running VM to another node.") description = _("Live migrate running VM to another node.")
def _operation(self, activity, user, system, to_node=None, timeout=120): def _operation(self, activity, to_node=None, timeout=120):
if not to_node: if not to_node:
with activity.sub_activity('scheduling') as sa: with activity.sub_activity('scheduling') as sa:
to_node = self.instance.select_node() to_node = self.instance.select_node()
...@@ -230,7 +232,7 @@ class RebootOperation(InstanceOperation): ...@@ -230,7 +232,7 @@ class RebootOperation(InstanceOperation):
name = _("reboot") name = _("reboot")
description = _("Reboot virtual machine with Ctrl+Alt+Del signal.") description = _("Reboot virtual machine with Ctrl+Alt+Del signal.")
def _operation(self, activity, user, system, timeout=5): def _operation(self, timeout=5):
self.instance.reboot_vm(timeout=timeout) self.instance.reboot_vm(timeout=timeout)
...@@ -280,7 +282,7 @@ class ResetOperation(InstanceOperation): ...@@ -280,7 +282,7 @@ class ResetOperation(InstanceOperation):
name = _("reset") name = _("reset")
description = _("Reset virtual machine (reset button).") description = _("Reset virtual machine (reset button).")
def _operation(self, activity, user, system, timeout=5): def _operation(self, timeout=5):
self.instance.reset_vm(timeout=timeout) self.instance.reset_vm(timeout=timeout)
register_operation(ResetOperation) register_operation(ResetOperation)
...@@ -295,6 +297,7 @@ class SaveAsTemplateOperation(InstanceOperation): ...@@ -295,6 +297,7 @@ class SaveAsTemplateOperation(InstanceOperation):
Template can be shared with groups and users. Template can be shared with groups and users.
Users can instantiate Virtual Machines from Templates. Users can instantiate Virtual Machines from Templates.
""") """)
abortable = True
@staticmethod @staticmethod
def _rename(name): def _rename(name):
...@@ -307,11 +310,11 @@ class SaveAsTemplateOperation(InstanceOperation): ...@@ -307,11 +310,11 @@ class SaveAsTemplateOperation(InstanceOperation):
return "%s v%d" % (name, v) return "%s v%d" % (name, v)
def _operation(self, activity, user, system, timeout=300, name=None, def _operation(self, activity, user, system, timeout=300, name=None,
with_shutdown=True, **kwargs): with_shutdown=True, task=None, **kwargs):
if with_shutdown: if with_shutdown:
try: try:
ShutdownOperation(self.instance).call(parent_activity=activity, ShutdownOperation(self.instance).call(parent_activity=activity,
user=user) user=user, task=task)
except Instance.WrongStateError: except Instance.WrongStateError:
pass pass
...@@ -370,23 +373,18 @@ class ShutdownOperation(InstanceOperation): ...@@ -370,23 +373,18 @@ class ShutdownOperation(InstanceOperation):
id = 'shutdown' id = 'shutdown'
name = _("shutdown") name = _("shutdown")
description = _("Shutdown virtual machine with ACPI signal.") description = _("Shutdown virtual machine with ACPI signal.")
abortable = True
def check_precond(self): def check_precond(self):
super(ShutdownOperation, self).check_precond() super(ShutdownOperation, self).check_precond()
if self.instance.status not in ['RUNNING']: if self.instance.status not in ['RUNNING']:
raise self.instance.WrongStateError(self.instance) raise self.instance.WrongStateError(self.instance)
def on_abort(self, activity, error):
if isinstance(error, TimeLimitExceeded):
activity.resultant_state = None
else:
activity.resultant_state = 'ERROR'
def on_commit(self, activity): def on_commit(self, activity):
activity.resultant_state = 'STOPPED' activity.resultant_state = 'STOPPED'
def _operation(self, activity, user, system, timeout=120): def _operation(self, task=None):
self.instance.shutdown_vm(timeout=timeout) self.instance.shutdown_vm(task=task)
self.instance.yield_node() self.instance.yield_node()
self.instance.yield_vnc_port() self.instance.yield_vnc_port()
...@@ -403,7 +401,7 @@ class ShutOffOperation(InstanceOperation): ...@@ -403,7 +401,7 @@ class ShutOffOperation(InstanceOperation):
def on_commit(self, activity): def on_commit(self, activity):
activity.resultant_state = 'STOPPED' activity.resultant_state = 'STOPPED'
def _operation(self, activity, user, system): def _operation(self, activity):
# Shutdown networks # Shutdown networks
with activity.sub_activity('shutdown_net'): with activity.sub_activity('shutdown_net'):
self.instance.shutdown_net() self.instance.shutdown_net()
...@@ -440,7 +438,7 @@ class SleepOperation(InstanceOperation): ...@@ -440,7 +438,7 @@ class SleepOperation(InstanceOperation):
def on_commit(self, activity): def on_commit(self, activity):
activity.resultant_state = 'SUSPENDED' activity.resultant_state = 'SUSPENDED'
def _operation(self, activity, user, system, timeout=60): def _operation(self, activity, timeout=60):
# Destroy networks # Destroy networks
with activity.sub_activity('shutdown_net'): with activity.sub_activity('shutdown_net'):
self.instance.shutdown_net() self.instance.shutdown_net()
...@@ -476,7 +474,7 @@ class WakeUpOperation(InstanceOperation): ...@@ -476,7 +474,7 @@ class WakeUpOperation(InstanceOperation):
def on_commit(self, activity): def on_commit(self, activity):
activity.resultant_state = 'RUNNING' activity.resultant_state = 'RUNNING'
def _operation(self, activity, user, system, timeout=60): def _operation(self, activity, timeout=60):
# Schedule vm # Schedule vm
self.instance.allocate_vnc_port() self.instance.allocate_vnc_port()
self.instance.allocate_node() self.instance.allocate_node()
...@@ -497,7 +495,7 @@ register_operation(WakeUpOperation) ...@@ -497,7 +495,7 @@ register_operation(WakeUpOperation)
class NodeOperation(Operation): class NodeOperation(Operation):
async_operation = async_node_operation async_operation = abortable_async_node_operation
host_cls = Node host_cls = Node
def __init__(self, node): def __init__(self, node):
...@@ -527,7 +525,7 @@ class FlushOperation(NodeOperation): ...@@ -527,7 +525,7 @@ class FlushOperation(NodeOperation):
name = _("flush") name = _("flush")
description = _("Disable node and move all instances to other ones.") description = _("Disable node and move all instances to other ones.")
def _operation(self, activity, user, system): def _operation(self, activity, user):
self.node.disable(user, activity) self.node.disable(user, activity)
for i in self.node.instance_set.all(): for i in self.node.instance_set.all():
with activity.sub_activity('migrate_instance_%d' % i.pk): with activity.sub_activity('migrate_instance_%d' % i.pk):
......
...@@ -15,32 +15,41 @@ ...@@ -15,32 +15,41 @@
# You should have received a copy of the GNU General Public License along # You should have received a copy of the GNU General Public License along
# with CIRCLE. If not, see <http://www.gnu.org/licenses/>. # with CIRCLE. If not, see <http://www.gnu.org/licenses/>.
from celery.contrib.abortable import AbortableTask
from manager.mancelery import celery from manager.mancelery import celery
@celery.task @celery.task(base=AbortableTask, bind=True)
def async_instance_operation(operation_id, instance_pk, activity_pk, **kwargs): def abortable_async_instance_operation(task, operation_id, instance_pk,
activity_pk, allargs, auxargs):
from vm.models import Instance, InstanceActivity from vm.models import Instance, InstanceActivity
instance = Instance.objects.get(pk=instance_pk) instance = Instance.objects.get(pk=instance_pk)
operation = getattr(instance, operation_id) operation = getattr(instance, operation_id)
activity = InstanceActivity.objects.get(pk=activity_pk) activity = InstanceActivity.objects.get(pk=activity_pk)
# save async task UUID to activity # save async task UUID to activity
activity.task_uuid = async_instance_operation.request.id activity.task_uuid = task.request.id
activity.save() activity.save()
return operation._exec_op(activity=activity, **kwargs) allargs['activity'] = activity
allargs['task'] = task
return operation._exec_op(allargs, auxargs)
@celery.task
def async_node_operation(operation_id, node_pk, activity_pk, **kwargs): @celery.task(base=AbortableTask, bind=True)
def abortable_async_node_operation(task, operation_id, node_pk, activity_pk,
allargs, auxargs):
from vm.models import Node, NodeActivity from vm.models import Node, NodeActivity
node = Node.objects.get(pk=node_pk) node = Node.objects.get(pk=node_pk)
operation = getattr(node, operation_id) operation = getattr(node, operation_id)
activity = NodeActivity.objects.get(pk=activity_pk) activity = NodeActivity.objects.get(pk=activity_pk)
# save async task UUID to activity # save async task UUID to activity
activity.task_uuid = async_node_operation.request.id activity.task_uuid = task.request.id
activity.save() activity.save()
return operation._exec_op(activity=activity, **kwargs) allargs['activity'] = activity
allargs['task'] = task
return operation._exec_op(allargs, auxargs)
...@@ -19,6 +19,7 @@ from datetime import datetime ...@@ -19,6 +19,7 @@ from datetime import datetime
from mock import Mock, MagicMock, patch, call from mock import Mock, MagicMock, patch, call
import types import types
from celery.contrib.abortable import AbortableAsyncResult
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.test import TestCase from django.test import TestCase
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
...@@ -231,6 +232,92 @@ class InstanceActivityTestCase(TestCase): ...@@ -231,6 +232,92 @@ class InstanceActivityTestCase(TestCase):
raise AssertionError("'create_sub' method checked for " raise AssertionError("'create_sub' method checked for "
"concurrent activities.") "concurrent activities.")
def test_is_abortable(self):
get_op = MagicMock(return_value=MagicMock(abortable=True))
instance = MagicMock(get_operation_from_activity_code=get_op)
iaobj = MagicMock(spec=InstanceActivity, activity_code='test',
finished=False, instance=instance, task_uuid='test')
self.assertTrue(InstanceActivity.is_abortable.fget(iaobj))
def test_not_abortable_when_not_associated_with_task(self):
get_op = MagicMock(return_value=MagicMock(abortable=True))
instance = MagicMock(get_operation_from_activity_code=get_op)
iaobj = MagicMock(spec=InstanceActivity, activity_code='test',
finished=False, instance=instance, task_uuid=None)
self.assertFalse(InstanceActivity.is_abortable.fget(iaobj))
def test_not_abortable_when_finished(self):
get_op = MagicMock(return_value=MagicMock(abortable=True))
instance = MagicMock(get_operation_from_activity_code=get_op)
iaobj = MagicMock(spec=InstanceActivity, activity_code='test',
finished=True, instance=instance, task_uuid='test')
self.assertFalse(InstanceActivity.is_abortable.fget(iaobj))
def test_not_abortable_when_operation_not_abortable(self):
get_op = MagicMock(return_value=MagicMock(abortable=False))
instance = MagicMock(get_operation_from_activity_code=get_op)
iaobj = MagicMock(spec=InstanceActivity, activity_code='test',
finished=False, instance=instance, task_uuid='test')
self.assertFalse(InstanceActivity.is_abortable.fget(iaobj))
def test_not_abortable_when_no_matching_operation(self):
get_op = MagicMock(return_value=None)
instance = MagicMock(get_operation_from_activity_code=get_op)
iaobj = MagicMock(spec=InstanceActivity, activity_code='test',
finished=False, instance=instance, task_uuid='test')
self.assertFalse(InstanceActivity.is_abortable.fget(iaobj))
def test_not_aborted_when_not_associated_with_task(self):
iaobj = MagicMock(task_uuid=None)
self.assertFalse(InstanceActivity.is_aborted.fget(iaobj))
def test_is_aborted_when_associated_task_is_aborted(self):
expected = object()
iaobj = MagicMock(task_uuid='test')
with patch.object(AbortableAsyncResult, 'is_aborted',
return_value=expected):
self.assertEquals(expected,
InstanceActivity.is_aborted.fget(iaobj))
def test_is_abortable_for_activity_owner_if_not_abortable(self):
iaobj = MagicMock(spec=InstanceActivity, is_abortable=False,
user=MagicMock(spec=User, is_superuser=False))
self.assertFalse(InstanceActivity.is_abortable_for(iaobj, iaobj.user))
def test_is_abortable_for_instance_owner(self):
get_op = MagicMock(return_value=MagicMock(abortable=True))
instance = MagicMock(get_operation_from_activity_code=get_op,
owner=MagicMock(spec=User, is_superuser=False))
iaobj = MagicMock(spec=InstanceActivity, activity_code='test',
finished=False, instance=instance, task_uuid='test',
user=MagicMock(spec=User, is_superuser=False))
self.assertTrue(
InstanceActivity.is_abortable_for(iaobj, iaobj.instance.owner))
def test_is_abortable_for_activity_owner(self):
get_op = MagicMock(return_value=MagicMock(abortable=True))
instance = MagicMock(get_operation_from_activity_code=get_op)
iaobj = MagicMock(spec=InstanceActivity, activity_code='test',
finished=False, instance=instance, task_uuid='test',
user=MagicMock(spec=User, is_superuser=False))
self.assertTrue(InstanceActivity.is_abortable_for(iaobj, iaobj.user))
def test_not_abortable_for_foreign(self):
get_op = MagicMock(return_value=MagicMock(abortable=True))
instance = MagicMock(get_operation_from_activity_code=get_op)
iaobj = MagicMock(spec=InstanceActivity, activity_code='test',
finished=False, instance=instance, task_uuid='test')
self.assertFalse(InstanceActivity.is_abortable_for(
iaobj, MagicMock(spec=User, is_superuser=False)))
def test_is_abortable_for_superuser(self):
get_op = MagicMock(return_value=MagicMock(abortable=True))
instance = MagicMock(get_operation_from_activity_code=get_op)
iaobj = MagicMock(spec=InstanceActivity, activity_code='test',
finished=False, instance=instance, task_uuid='test')
su = MagicMock(spec=User, is_superuser=True)
self.assertTrue(InstanceActivity.is_abortable_for(iaobj, su))
def test_disable_enabled(self): def test_disable_enabled(self):
node = MagicMock(spec=Node, enabled=True) node = MagicMock(spec=Node, enabled=True)
with patch('vm.models.node.node_activity') as nac: with patch('vm.models.node.node_activity') as nac:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment