Commit be367e47 by Őry Máté

Merge branch 'feature-abortable-operations-rebased'

Feature Abortable Operations
Make possible the implementation of abortable operations. Implement it for
shutdown.
parents 681af7af e7abb4a5
......@@ -18,6 +18,7 @@
from collections import deque
from contextlib import contextmanager
from hashlib import sha224
from itertools import chain, imap
from logging import getLogger
from time import time
......@@ -56,12 +57,52 @@ activity_context = contextmanager(activitycontextimpl)
activity_code_separator = '.'
def has_prefix(activity_code, *prefixes):
"""Determine whether the activity code has the specified prefix.
E.g.: has_prefix('foo.bar.buz', 'foo.bar') == True
has_prefix('foo.bar.buz', 'foo', 'bar') == True
has_prefix('foo.bar.buz', 'foo.bar', 'buz') == True
has_prefix('foo.bar.buz', 'foo', 'bar', 'buz') == True
has_prefix('foo.bar.buz', 'foo', 'buz') == False
"""
equal = lambda a, b: a == b
act_code_parts = split_activity_code(activity_code)
prefixes = chain(*imap(split_activity_code, prefixes))
return all(imap(equal, act_code_parts, prefixes))
def has_suffix(activity_code, *suffixes):
"""Determine whether the activity code has the specified suffix.
E.g.: has_suffix('foo.bar.buz', 'bar.buz') == True
has_suffix('foo.bar.buz', 'bar', 'buz') == True
has_suffix('foo.bar.buz', 'foo.bar', 'buz') == True
has_suffix('foo.bar.buz', 'foo', 'bar', 'buz') == True
has_suffix('foo.bar.buz', 'foo', 'buz') == False
"""
equal = lambda a, b: a == b
act_code_parts = split_activity_code(activity_code)
suffixes = list(chain(*imap(split_activity_code, suffixes)))
return all(imap(equal, reversed(act_code_parts), reversed(suffixes)))
def join_activity_code(*args):
"""Join the specified parts into an activity code.
:returns: Activity code string.
"""
return activity_code_separator.join(args)
def split_activity_code(activity_code):
"""Split the specified activity code into its parts.
:returns: A list of activity code parts.
"""
return activity_code.split(activity_code_separator)
class ActivityModel(TimeStampedModel):
activity_code = CharField(max_length=100, verbose_name=_('activity code'))
parent = ForeignKey('self', blank=True, null=True, related_name='children')
......
......@@ -15,9 +15,10 @@
# You should have received a copy of the GNU General Public License along
# with CIRCLE. If not, see <http://www.gnu.org/licenses/>.
from inspect import getargspec
from logging import getLogger
from .models import activity_context
from .models import activity_context, has_suffix
from django.core.exceptions import PermissionDenied
......@@ -31,6 +32,7 @@ class Operation(object):
async_queue = 'localhost.man'
required_perms = ()
do_not_call_in_templates = True
abortable = False
def __call__(self, **kwargs):
return self.call(**kwargs)
......@@ -46,23 +48,50 @@ class Operation(object):
def __prelude(self, kwargs):
"""This method contains the shared prelude of call and async.
"""
skip_auth_check = kwargs.setdefault('system', False)
user = kwargs.setdefault('user', None)
parent_activity = kwargs.pop('parent_activity', None)
defaults = {'parent_activity': None, 'system': False, 'user': None}
allargs = dict(defaults, **kwargs) # all arguments
auxargs = allargs.copy() # auxiliary (i.e. only for _operation) args
# NOTE: consumed items should be removed from auxargs, and no new items
# should be added to it
skip_auth_check = auxargs.pop('system')
user = auxargs.pop('user')
parent_activity = auxargs.pop('parent_activity')
# check for unexpected keyword arguments
argspec = getargspec(self._operation)
if argspec.keywords is None: # _operation doesn't take ** args
unexpected_kwargs = set(auxargs) - set(argspec.args)
if unexpected_kwargs:
raise TypeError("Operation got unexpected keyword arguments: "
"%s" % ", ".join(unexpected_kwargs))
if not skip_auth_check:
self.check_auth(user)
self.check_precond()
return self.create_activity(parent=parent_activity, user=user)
def _exec_op(self, activity, user, **kwargs):
activity = self.create_activity(parent=parent_activity, user=user)
return activity, allargs, auxargs
def _exec_op(self, allargs, auxargs):
"""Execute the operation inside the specified activity's context.
"""
with activity_context(activity, on_abort=self.on_abort,
# compile arguments for _operation
argspec = getargspec(self._operation)
if argspec.keywords is not None: # _operation takes ** args
arguments = allargs.copy()
else: # _operation doesn't take ** args
arguments = {k: v for (k, v) in allargs.iteritems()
if k in argspec.args}
arguments.update(auxargs)
with activity_context(allargs['activity'], on_abort=self.on_abort,
on_commit=self.on_commit):
return self._operation(activity=activity, user=user, **kwargs)
return self._operation(**arguments)
def _operation(self, activity, user, system, **kwargs):
def _operation(self, **kwargs):
"""This method is the operation's particular implementation.
Deriving classes should implement this method.
......@@ -82,12 +111,10 @@ class Operation(object):
logger.info("%s called asynchronously on %s with the following "
"parameters: %r", self.__class__.__name__, self.subject,
kwargs)
activity = self.__prelude(kwargs)
return self.async_operation.apply_async(args=(self.id,
self.subject.pk,
activity.pk),
kwargs=kwargs,
queue=self.async_queue)
activity, allargs, auxargs = self.__prelude(kwargs)
return self.async_operation.apply_async(
args=(self.id, self.subject.pk, activity.pk, allargs, auxargs, ),
queue=self.async_queue)
def call(self, **kwargs):
"""Execute the operation (synchronously).
......@@ -105,8 +132,9 @@ class Operation(object):
logger.info("%s called (synchronously) on %s with the following "
"parameters: %r", self.__class__.__name__, self.subject,
kwargs)
activity = self.__prelude(kwargs)
return self._exec_op(activity=activity, **kwargs)
activity, allargs, auxargs = self.__prelude(kwargs)
allargs['activity'] = activity
return self._exec_op(allargs, auxargs)
def check_precond(self):
pass
......@@ -160,6 +188,19 @@ class OperatedMixin(object):
else:
yield op
def get_operation_from_activity_code(self, activity_code):
"""Get an instance of the Operation corresponding to the specified
activity code.
:returns: A bound instance of an operation, or None if no matching
operation could be found.
"""
for op in getattr(self, operation_registry_name, {}).itervalues():
if has_suffix(activity_code, op.activity_code_suffix):
return op(self)
else:
return None
def register_operation(op_cls, op_id=None, target_cls=None):
"""Register the specified operation with the target class.
......
......@@ -75,3 +75,41 @@ class OperationTestCase(TestCase):
patch.object(Operation, 'create_activity'), \
patch.object(Operation, '_exec_op'):
op.call(system=True)
def test_no_exception_for_more_arguments_when_operation_takes_kwargs(self):
class KwargOp(Operation):
activity_code_suffix = 'test'
id = 'test'
def _operation(self, **kwargs):
pass
op = KwargOp(MagicMock())
with patch.object(KwargOp, 'create_activity'), \
patch.object(KwargOp, '_exec_op'):
op.call(system=True, foo=42)
def test_exception_for_unexpected_arguments(self):
class TestOp(Operation):
activity_code_suffix = 'test'
id = 'test'
def _operation(self):
pass
op = TestOp(MagicMock())
with patch.object(TestOp, 'create_activity'), \
patch.object(TestOp, '_exec_op'):
self.assertRaises(TypeError, op.call, system=True, foo=42)
def test_exception_for_missing_arguments(self):
class TestOp(Operation):
activity_code_suffix = 'test'
id = 'test'
def _operation(self, foo):
pass
op = TestOp(MagicMock())
with patch.object(TestOp, 'create_activity'):
self.assertRaises(TypeError, op.call, system=True)
......@@ -9,6 +9,14 @@
{{ a.get_readable_name }}{% if user.is_superuser %}</a>{% endif %}
</strong>
{{ a.started|date:"Y-m-d H:i" }}{% if a.user %}, {{ a.user }}{% endif %}
{% if a.is_abortable_for_user %}
<form action="{{ a.instance.get_absolute_url }}" method="POST" class="pull-right">
{% csrf_token %}
<input type="hidden" name="abort_operation"/>
<input type="hidden" name="activity" value="{{ a.pk }}"/>
<button class="btn btn-danger btn-xs"><i class="icon-bolt"></i> {% trans "Abort" %}</button>
</form>
{% endif %}
{% if a.children.count > 0 %}
<div class="sub-timeline">
{% for s in a.children.all %}
......
......@@ -224,11 +224,7 @@ class VmDetailView(CheckedDetailView):
})
# activity data
context['activities'] = (
InstanceActivity.objects.filter(
instance=self.object, parent=None).
order_by('-started').
select_related('user').prefetch_related('children'))
context['activities'] = self.object.get_activities(self.request.user)
context['vlans'] = Vlan.get_objects_with_level(
'user', self.request.user
......@@ -260,6 +256,7 @@ class VmDetailView(CheckedDetailView):
'to_remove': self.__remove_tag,
'port': self.__add_port,
'new_network_vlan': self.__new_network,
'abort_operation': self.__abort_operation,
}
for k, v in options.iteritems():
if request.POST.get(k) is not None:
......@@ -445,6 +442,16 @@ class VmDetailView(CheckedDetailView):
return redirect("%s#network" % reverse_lazy(
"dashboard.views.detail", kwargs={'pk': self.object.pk}))
def __abort_operation(self, request):
self.object = self.get_object()
activity = get_object_or_404(InstanceActivity,
pk=request.POST.get("activity"))
if not activity.is_abortable_for(request.user):
raise PermissionDenied()
activity.abort()
return redirect("%s#activity" % self.object.get_absolute_url())
class OperationView(DetailView):
......@@ -1736,9 +1743,7 @@ def vm_activity(request, pk):
if only_status == "false": # instance activity
context = {
'instance': instance,
'activities': InstanceActivity.objects.filter(
instance=instance, parent=None
).order_by('-started').select_related(),
'activities': instance.get_activities(request.user),
'ops': get_operations(instance, request.user),
}
......@@ -2343,10 +2348,8 @@ class InstanceActivityDetail(SuperuserRequiredMixin, DetailView):
def get_context_data(self, **kwargs):
ctx = super(InstanceActivityDetail, self).get_context_data(**kwargs)
ctx['activities'] = (
self.object.instance.activity_log.filter(parent=None).
order_by('-started').select_related('user').
prefetch_related('children'))
ctx['activities'] = self.object.instance.get_activities(
self.request.user)
return ctx
......
......@@ -20,15 +20,19 @@ from contextlib import contextmanager
from logging import getLogger
from celery.signals import worker_ready
from celery.contrib.abortable import AbortableAsyncResult
from django.core.urlresolvers import reverse
from django.db.models import CharField, ForeignKey
from django.utils import timezone
from django.utils.translation import ugettext_lazy as _
from common.models import (
ActivityModel, activitycontextimpl, join_activity_code,
ActivityModel, activitycontextimpl, join_activity_code, split_activity_code
)
from manager.mancelery import celery
logger = getLogger(__name__)
......@@ -66,19 +70,8 @@ class InstanceActivity(ActivityModel):
return '{}({})'.format(self.activity_code,
self.instance)
def get_absolute_url(self):
return reverse('dashboard.views.vm-activity', args=[self.pk])
def get_readable_name(self):
return self.activity_code.split('.')[-1].replace('_', ' ').capitalize()
def get_status_id(self):
if self.succeeded is None:
return 'wait'
elif self.succeeded:
return 'success'
else:
return 'failed'
def abort(self):
AbortableAsyncResult(self.task_uuid, backend=celery.backend).abort()
@classmethod
def create(cls, code_suffix, instance, task_uuid=None, user=None,
......@@ -108,6 +101,51 @@ class InstanceActivity(ActivityModel):
act.save()
return act
def get_absolute_url(self):
return reverse('dashboard.views.vm-activity', args=[self.pk])
def get_readable_name(self):
activity_code_last_suffix = split_activity_code(self.activity_code)[-1]
return activity_code_last_suffix.replace('_', ' ').capitalize()
def get_status_id(self):
if self.succeeded is None:
return 'wait'
elif self.succeeded:
return 'success'
else:
return 'failed'
@property
def is_abortable(self):
"""Can the activity be aborted?
:returns: True if the activity can be aborted; otherwise, False.
"""
op = self.instance.get_operation_from_activity_code(self.activity_code)
return self.task_uuid and op and op.abortable and not self.finished
def is_abortable_for(self, user):
"""Can the given user abort the activity?
"""
return self.is_abortable and (
user.is_superuser or user in (self.instance.owner, self.user))
@property
def is_aborted(self):
"""Has the activity been aborted?
:returns: True if the activity has been aborted; otherwise, False.
"""
return self.task_uuid and AbortableAsyncResult(self.task_uuid
).is_aborted()
def save(self, *args, **kwargs):
ret = super(InstanceActivity, self).save(*args, **kwargs)
self.instance._update_status()
return ret
@contextmanager
def sub_activity(self, code_suffix, on_abort=None, on_commit=None,
task_uuid=None, concurrency_check=True):
......@@ -116,11 +154,6 @@ class InstanceActivity(ActivityModel):
act = self.create_sub(code_suffix, task_uuid, concurrency_check)
return activitycontextimpl(act, on_abort=on_abort, on_commit=on_commit)
def save(self, *args, **kwargs):
ret = super(InstanceActivity, self).save(*args, **kwargs)
self.instance._update_status()
return ret
@contextmanager
def instance_activity(code_suffix, instance, on_abort=None, on_commit=None,
......
......@@ -17,11 +17,14 @@
from __future__ import absolute_import, unicode_literals
from datetime import timedelta
from functools import partial
from importlib import import_module
from logging import getLogger
from string import ascii_lowercase
from warnings import warn
from celery.exceptions import TimeoutError
from celery.contrib.abortable import AbortableAsyncResult
import django.conf
from django.contrib.auth.models import User
from django.core import signing
......@@ -832,13 +835,20 @@ class Instance(AclBase, VirtualMachineDescModel, StatusModel, OperatedMixin,
queue=queue_name
).get(timeout=timeout)
def shutdown_vm(self, timeout=120):
def shutdown_vm(self, task=None, step=5):
queue_name = self.get_remote_queue_name('vm')
logger.debug("RPC Shutdown at queue: %s, for vm: %s.", queue_name,
self.vm_name)
return vm_tasks.shutdown.apply_async(kwargs={'name': self.vm_name},
queue=queue_name
).get(timeout=timeout)
remote = vm_tasks.shutdown.apply_async(kwargs={'name': self.vm_name},
queue=queue_name)
while True:
try:
return remote.get(timeout=step)
except TimeoutError:
if task is not None and task.is_aborted():
AbortableAsyncResult(remote.id).abort()
raise Exception("Shutdown aborted by user.")
def suspend_vm(self, timeout=60):
queue_name = self.get_remote_queue_name('vm')
......@@ -891,3 +901,13 @@ class Instance(AclBase, VirtualMachineDescModel, StatusModel, OperatedMixin,
'PENDING': 'icon-rocket',
'DESTROYED': 'icon-trash',
'MIGRATING': 'icon-truck'}.get(self.status, 'icon-question-sign')
def get_activities(self, user=None):
acts = (self.activity_log.filter(parent=None).
order_by('-started').
select_related('user').prefetch_related('children'))
if user is not None:
for i in acts:
i.is_abortable_for_user = partial(i.is_abortable_for,
user=user)
return acts
......@@ -26,7 +26,9 @@ from django.utils.translation import ugettext_lazy as _
from celery.exceptions import TimeLimitExceeded
from common.operations import Operation, register_operation
from .tasks.local_tasks import async_instance_operation, async_node_operation
from .tasks.local_tasks import (
abortable_async_instance_operation, abortable_async_node_operation,
)
from .models import (
Instance, InstanceActivity, InstanceTemplate, Interface, Node,
NodeActivity,
......@@ -38,7 +40,7 @@ logger = getLogger(__name__)
class InstanceOperation(Operation):
acl_level = 'owner'
async_operation = async_instance_operation
async_operation = abortable_async_instance_operation
host_cls = Instance
def __init__(self, instance):
......@@ -126,7 +128,7 @@ class DeployOperation(InstanceOperation):
def on_commit(self, activity):
activity.resultant_state = 'RUNNING'
def _operation(self, activity, user, system, timeout=15):
def _operation(self, activity, timeout=15):
# Allocate VNC port and host node
self.instance.allocate_vnc_port()
self.instance.allocate_node()
......@@ -162,7 +164,7 @@ class DestroyOperation(InstanceOperation):
def on_commit(self, activity):
activity.resultant_state = 'DESTROYED'
def _operation(self, activity, user, system):
def _operation(self, activity):
if self.instance.node:
# Destroy networks
with activity.sub_activity('destroying_net'):
......@@ -200,7 +202,7 @@ class MigrateOperation(InstanceOperation):
name = _("migrate")
description = _("Live migrate running VM to another node.")
def _operation(self, activity, user, system, to_node=None, timeout=120):
def _operation(self, activity, to_node=None, timeout=120):
if not to_node:
with activity.sub_activity('scheduling') as sa:
to_node = self.instance.select_node()
......@@ -230,7 +232,7 @@ class RebootOperation(InstanceOperation):
name = _("reboot")
description = _("Reboot virtual machine with Ctrl+Alt+Del signal.")
def _operation(self, activity, user, system, timeout=5):
def _operation(self, timeout=5):
self.instance.reboot_vm(timeout=timeout)
......@@ -280,7 +282,7 @@ class ResetOperation(InstanceOperation):
name = _("reset")
description = _("Reset virtual machine (reset button).")
def _operation(self, activity, user, system, timeout=5):
def _operation(self, timeout=5):
self.instance.reset_vm(timeout=timeout)
register_operation(ResetOperation)
......@@ -295,6 +297,7 @@ class SaveAsTemplateOperation(InstanceOperation):
Template can be shared with groups and users.
Users can instantiate Virtual Machines from Templates.
""")
abortable = True
@staticmethod
def _rename(name):
......@@ -307,11 +310,11 @@ class SaveAsTemplateOperation(InstanceOperation):
return "%s v%d" % (name, v)
def _operation(self, activity, user, system, timeout=300, name=None,
with_shutdown=True, **kwargs):
with_shutdown=True, task=None, **kwargs):
if with_shutdown:
try:
ShutdownOperation(self.instance).call(parent_activity=activity,
user=user)
user=user, task=task)
except Instance.WrongStateError:
pass
......@@ -370,23 +373,18 @@ class ShutdownOperation(InstanceOperation):
id = 'shutdown'
name = _("shutdown")
description = _("Shutdown virtual machine with ACPI signal.")
abortable = True
def check_precond(self):
super(ShutdownOperation, self).check_precond()
if self.instance.status not in ['RUNNING']:
raise self.instance.WrongStateError(self.instance)
def on_abort(self, activity, error):
if isinstance(error, TimeLimitExceeded):
activity.resultant_state = None
else:
activity.resultant_state = 'ERROR'
def on_commit(self, activity):
activity.resultant_state = 'STOPPED'
def _operation(self, activity, user, system, timeout=120):
self.instance.shutdown_vm(timeout=timeout)
def _operation(self, task=None):
self.instance.shutdown_vm(task=task)
self.instance.yield_node()
self.instance.yield_vnc_port()
......@@ -403,7 +401,7 @@ class ShutOffOperation(InstanceOperation):
def on_commit(self, activity):
activity.resultant_state = 'STOPPED'
def _operation(self, activity, user, system):
def _operation(self, activity):
# Shutdown networks
with activity.sub_activity('shutdown_net'):
self.instance.shutdown_net()
......@@ -440,7 +438,7 @@ class SleepOperation(InstanceOperation):
def on_commit(self, activity):
activity.resultant_state = 'SUSPENDED'
def _operation(self, activity, user, system, timeout=60):
def _operation(self, activity, timeout=60):
# Destroy networks
with activity.sub_activity('shutdown_net'):
self.instance.shutdown_net()
......@@ -476,7 +474,7 @@ class WakeUpOperation(InstanceOperation):
def on_commit(self, activity):
activity.resultant_state = 'RUNNING'
def _operation(self, activity, user, system, timeout=60):
def _operation(self, activity, timeout=60):
# Schedule vm
self.instance.allocate_vnc_port()
self.instance.allocate_node()
......@@ -497,7 +495,7 @@ register_operation(WakeUpOperation)
class NodeOperation(Operation):
async_operation = async_node_operation
async_operation = abortable_async_node_operation
host_cls = Node
def __init__(self, node):
......@@ -527,7 +525,7 @@ class FlushOperation(NodeOperation):
name = _("flush")
description = _("Disable node and move all instances to other ones.")
def _operation(self, activity, user, system):
def _operation(self, activity, user):
self.node.disable(user, activity)
for i in self.node.instance_set.all():
with activity.sub_activity('migrate_instance_%d' % i.pk):
......
......@@ -15,32 +15,41 @@
# You should have received a copy of the GNU General Public License along
# with CIRCLE. If not, see <http://www.gnu.org/licenses/>.
from celery.contrib.abortable import AbortableTask
from manager.mancelery import celery
@celery.task
def async_instance_operation(operation_id, instance_pk, activity_pk, **kwargs):
@celery.task(base=AbortableTask, bind=True)
def abortable_async_instance_operation(task, operation_id, instance_pk,
activity_pk, allargs, auxargs):
from vm.models import Instance, InstanceActivity
instance = Instance.objects.get(pk=instance_pk)
operation = getattr(instance, operation_id)
activity = InstanceActivity.objects.get(pk=activity_pk)
# save async task UUID to activity
activity.task_uuid = async_instance_operation.request.id
activity.task_uuid = task.request.id
activity.save()
return operation._exec_op(activity=activity, **kwargs)
allargs['activity'] = activity
allargs['task'] = task
return operation._exec_op(allargs, auxargs)
@celery.task
def async_node_operation(operation_id, node_pk, activity_pk, **kwargs):
@celery.task(base=AbortableTask, bind=True)
def abortable_async_node_operation(task, operation_id, node_pk, activity_pk,
allargs, auxargs):
from vm.models import Node, NodeActivity
node = Node.objects.get(pk=node_pk)
operation = getattr(node, operation_id)
activity = NodeActivity.objects.get(pk=activity_pk)
# save async task UUID to activity
activity.task_uuid = async_node_operation.request.id
activity.task_uuid = task.request.id
activity.save()
return operation._exec_op(activity=activity, **kwargs)
allargs['activity'] = activity
allargs['task'] = task
return operation._exec_op(allargs, auxargs)
......@@ -19,6 +19,7 @@ from datetime import datetime
from mock import Mock, MagicMock, patch, call
import types
from celery.contrib.abortable import AbortableAsyncResult
from django.contrib.auth.models import User
from django.test import TestCase
from django.utils.translation import ugettext_lazy as _
......@@ -231,6 +232,92 @@ class InstanceActivityTestCase(TestCase):
raise AssertionError("'create_sub' method checked for "
"concurrent activities.")
def test_is_abortable(self):
get_op = MagicMock(return_value=MagicMock(abortable=True))
instance = MagicMock(get_operation_from_activity_code=get_op)
iaobj = MagicMock(spec=InstanceActivity, activity_code='test',
finished=False, instance=instance, task_uuid='test')
self.assertTrue(InstanceActivity.is_abortable.fget(iaobj))
def test_not_abortable_when_not_associated_with_task(self):
get_op = MagicMock(return_value=MagicMock(abortable=True))
instance = MagicMock(get_operation_from_activity_code=get_op)
iaobj = MagicMock(spec=InstanceActivity, activity_code='test',
finished=False, instance=instance, task_uuid=None)
self.assertFalse(InstanceActivity.is_abortable.fget(iaobj))
def test_not_abortable_when_finished(self):
get_op = MagicMock(return_value=MagicMock(abortable=True))
instance = MagicMock(get_operation_from_activity_code=get_op)
iaobj = MagicMock(spec=InstanceActivity, activity_code='test',
finished=True, instance=instance, task_uuid='test')
self.assertFalse(InstanceActivity.is_abortable.fget(iaobj))
def test_not_abortable_when_operation_not_abortable(self):
get_op = MagicMock(return_value=MagicMock(abortable=False))
instance = MagicMock(get_operation_from_activity_code=get_op)
iaobj = MagicMock(spec=InstanceActivity, activity_code='test',
finished=False, instance=instance, task_uuid='test')
self.assertFalse(InstanceActivity.is_abortable.fget(iaobj))
def test_not_abortable_when_no_matching_operation(self):
get_op = MagicMock(return_value=None)
instance = MagicMock(get_operation_from_activity_code=get_op)
iaobj = MagicMock(spec=InstanceActivity, activity_code='test',
finished=False, instance=instance, task_uuid='test')
self.assertFalse(InstanceActivity.is_abortable.fget(iaobj))
def test_not_aborted_when_not_associated_with_task(self):
iaobj = MagicMock(task_uuid=None)
self.assertFalse(InstanceActivity.is_aborted.fget(iaobj))
def test_is_aborted_when_associated_task_is_aborted(self):
expected = object()
iaobj = MagicMock(task_uuid='test')
with patch.object(AbortableAsyncResult, 'is_aborted',
return_value=expected):
self.assertEquals(expected,
InstanceActivity.is_aborted.fget(iaobj))
def test_is_abortable_for_activity_owner_if_not_abortable(self):
iaobj = MagicMock(spec=InstanceActivity, is_abortable=False,
user=MagicMock(spec=User, is_superuser=False))
self.assertFalse(InstanceActivity.is_abortable_for(iaobj, iaobj.user))
def test_is_abortable_for_instance_owner(self):
get_op = MagicMock(return_value=MagicMock(abortable=True))
instance = MagicMock(get_operation_from_activity_code=get_op,
owner=MagicMock(spec=User, is_superuser=False))
iaobj = MagicMock(spec=InstanceActivity, activity_code='test',
finished=False, instance=instance, task_uuid='test',