Commit 8b527a9e by Őry Máté

dashboard: use Operation.check_perms in MassOperationView

parent 245e39dd
...@@ -77,7 +77,9 @@ from .tables import ( ...@@ -77,7 +77,9 @@ from .tables import (
NodeListTable, TemplateListTable, LeaseListTable, NodeListTable, TemplateListTable, LeaseListTable,
GroupListTable, UserKeyListTable GroupListTable, UserKeyListTable
) )
from common.models import HumanReadableObject, HumanReadableException from common.models import (
HumanReadableObject, HumanReadableException, fetch_human_exception
)
from vm.models import ( from vm.models import (
Instance, instance_activity, InstanceActivity, InstanceTemplate, Interface, Instance, instance_activity, InstanceActivity, InstanceTemplate, Interface,
InterfaceTemplate, Lease, Node, NodeActivity, Trait, InterfaceTemplate, Lease, Node, NodeActivity, Trait,
...@@ -561,6 +563,10 @@ class OperationView(RedirectToLoginMixin, DetailView): ...@@ -561,6 +563,10 @@ class OperationView(RedirectToLoginMixin, DetailView):
setattr(self, '_opobj', getattr(self.get_object(), self.op)) setattr(self, '_opobj', getattr(self.get_object(), self.op))
return self._opobj return self._opobj
@classmethod
def get_operation_class(cls):
return cls.model.get_operation_class(cls.op)
def get_context_data(self, **kwargs): def get_context_data(self, **kwargs):
ctx = super(OperationView, self).get_context_data(**kwargs) ctx = super(OperationView, self).get_context_data(**kwargs)
ctx['op'] = self.get_op() ctx['op'] = self.get_op()
...@@ -576,6 +582,10 @@ class OperationView(RedirectToLoginMixin, DetailView): ...@@ -576,6 +582,10 @@ class OperationView(RedirectToLoginMixin, DetailView):
logger.debug("OperationView.check_auth(%s)", unicode(self)) logger.debug("OperationView.check_auth(%s)", unicode(self))
self.get_op().check_auth(self.request.user) self.get_op().check_auth(self.request.user)
@classmethod
def check_perms(cls, user):
cls.get_operation_class().check_perms(user)
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
self.check_auth() self.check_auth()
return super(OperationView, self).get(request, *args, **kwargs) return super(OperationView, self).get(request, *args, **kwargs)
...@@ -1013,6 +1023,9 @@ def get_operations(instance, user): ...@@ -1013,6 +1023,9 @@ def get_operations(instance, user):
class MassOperationView(OperationView): class MassOperationView(OperationView):
template_name = 'dashboard/mass-operate.html' template_name = 'dashboard/mass-operate.html'
def check_auth(self):
pass # OperationView.get calls this
@classmethod @classmethod
def get_urlname(cls): def get_urlname(cls):
return 'dashboard.vm.mass-op.%s' % cls.op return 'dashboard.vm.mass-op.%s' % cls.op
...@@ -1027,61 +1040,51 @@ class MassOperationView(OperationView): ...@@ -1027,61 +1040,51 @@ class MassOperationView(OperationView):
else: else:
return Instance._ops[self.op] return Instance._ops[self.op]
def dispatch(self, *args, **kwargs):
user = self.request.user
self.objects_of_user = Instance.get_objects_with_level("user", user)
return super(MassOperationView, self).dispatch(*args, **kwargs)
def get_context_data(self, **kwargs): def get_context_data(self, **kwargs):
ctx = super(MassOperationView, self).get_context_data(**kwargs) ctx = super(MassOperationView, self).get_context_data(**kwargs)
instances = self.request.GET.getlist("vm") instances = self.get_object()
instances = Instance.objects.filter(pk__in=instances) ctx['instances'] = self._get_operable_instances(
ctx['instances'], ctx['vm_count'] = self._check_instances(
instances, self.request.user) instances, self.request.user)
ctx['vm_count'] = sum(1 for i in ctx['instances'] if not i.disabled)
return ctx return ctx
@classmethod def _call_operations(self, extra):
def check_auth(self, user=None): request = self.request
pass user = request.user
instances = self.get_object()
for i in instances:
try:
self.get_op(i).async(user=user, **extra)
except HumanReadableException as e:
e.send_message(request)
except Exception as e:
# pre-existing errors should have been catched when the
# confirmation dialog was constructed
messages.error(request, _(
"Failed to execute %(op)s operation on "
"instance %(instance)s.") % {"op": self.name,
"instance": i})
def get_object(self): def get_object(self):
return None vms = getattr(self.request, self.request.method).getlist("vm")
return Instance.objects.filter(pk__in=vms)
def _check_instances(self, instances, user): def _get_operable_instances(self, instances, user):
vms = []
ok_vm_count = 0
for i in instances: for i in instances:
try: try:
self._op_checks(i, user) op = self.get_op(i)
except HumanReadableException as e: op.check_auth(user)
setattr(i, "disabled", e.get_user_text()) op.check_precond()
except SuspiciousOperation: except Exception as e:
continue i.disabled = fetch_human_exception(e)
except PermissionDenied:
setattr(i, "disabled", _("Permission denied"))
except Exception:
raise
else: else:
ok_vm_count += 1 i.disabled = False
vms.append(i) return instances
return vms, ok_vm_count
def post(self, request, extra=None, *args, **kwargs): def post(self, request, extra=None, *args, **kwargs):
if extra is None: if extra is None:
extra = {} extra = {}
user = self.request.user self._call_operations(extra)
vms = request.POST.getlist("vm")
instances = Instance.objects.filter(pk__in=vms)
for i in instances:
try:
op = self._op_checks(i, user)
op.async(user=user, **extra)
except HumanReadableException as e:
e.send_message(request)
except Exception as e:
pass
if request.is_ajax(): if request.is_ajax():
store = messages.get_messages(request) store = messages.get_messages(request)
store.used = True store.used = True
...@@ -1092,14 +1095,6 @@ class MassOperationView(OperationView): ...@@ -1092,14 +1095,6 @@ class MassOperationView(OperationView):
else: else:
return redirect(reverse("dashboard.views.vm-list")) return redirect(reverse("dashboard.views.vm-list"))
def _op_checks(self, instance, user):
if instance not in self.objects_of_user:
raise SuspiciousOperation()
op = self.get_op(instance)
op.check_auth(user)
op.check_precond()
return op
@classmethod @classmethod
def factory(cls, vm_op, extra_bases=(), **kwargs): def factory(cls, vm_op, extra_bases=(), **kwargs):
return type(str(cls.__name__ + vm_op.op), return type(str(cls.__name__ + vm_op.op),
...@@ -1724,7 +1719,7 @@ class VmList(LoginRequiredMixin, FilterMixin, ListView): ...@@ -1724,7 +1719,7 @@ class VmList(LoginRequiredMixin, FilterMixin, ListView):
context['ops'] = [] context['ops'] = []
for k, v in vm_mass_ops.iteritems(): for k, v in vm_mass_ops.iteritems():
try: try:
v.check_auth(user=self.request.user) v.check_perms(user=self.request.user)
except PermissionDenied: except PermissionDenied:
pass pass
else: else:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment