Commit a16c126b by Carl Meyer

abstract base: all previous tests but one passing; no new tests yet

parent 9443861c
...@@ -2,8 +2,10 @@ from collections import defaultdict ...@@ -2,8 +2,10 @@ from collections import defaultdict
import django import django
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
from django.contrib.contenttypes.generic import GenericForeignKey
from django.db import models from django.db import models
from django.db.models.related import RelatedObject from django.db.models.related import RelatedObject
from django.db.models.fields import FieldDoesNotExist
from django.db.models.fields.related import ManyToManyRel from django.db.models.fields.related import ManyToManyRel
from django.db.models.query_utils import QueryWrapper from django.db.models.query_utils import QueryWrapper
...@@ -12,9 +14,10 @@ from taggit.models import Tag, TaggedItem ...@@ -12,9 +14,10 @@ from taggit.models import Tag, TaggedItem
from taggit.utils import require_instance_manager from taggit.utils import require_instance_manager
class TaggableRel(ManyToManyRel): class TaggableRel(ManyToManyRel):
def __init__(self): def __init__(self, to):
self.to = TaggedItem self.to = to
self.related_name = None self.related_name = None
self.limit_choices_to = {} self.limit_choices_to = {}
self.symmetrical = True self.symmetrical = True
...@@ -23,8 +26,9 @@ class TaggableRel(ManyToManyRel): ...@@ -23,8 +26,9 @@ class TaggableRel(ManyToManyRel):
class TaggableManager(object): class TaggableManager(object):
def __init__(self, verbose_name="Tags"): def __init__(self, verbose_name="Tags", through=None):
self.rel = TaggableRel() self.through = through or TaggedItem
self.rel = TaggableRel(to=self.through)
self.verbose_name = verbose_name self.verbose_name = verbose_name
self.editable = True self.editable = True
self.unique = False self.unique = False
...@@ -35,16 +39,13 @@ class TaggableManager(object): ...@@ -35,16 +39,13 @@ class TaggableManager(object):
self.creation_counter = models.Field.creation_counter self.creation_counter = models.Field.creation_counter
models.Field.creation_counter += 1 models.Field.creation_counter += 1
def __get__(self, instance, type): def __get__(self, instance, model):
manager = _TaggableManager() manager = _TaggableManager(through=self.through)
manager.model = type manager.model = model
if instance is None: if instance is not None and instance.pk is None:
manager.object_id = None
elif instance.pk is None:
raise ValueError("%s objects need to have a primary key value " raise ValueError("%s objects need to have a primary key value "
"before you can access their tags." % type.__name__) "before you can access their tags." % model.__name__)
else: manager.instance = instance
manager.object_id = instance.pk
return manager return manager
def contribute_to_class(self, cls, name): def contribute_to_class(self, cls, name):
...@@ -60,9 +61,9 @@ class TaggableManager(object): ...@@ -60,9 +61,9 @@ class TaggableManager(object):
if lookup_type != "in": if lookup_type != "in":
raise ValueError("You can't do lookups other than \"in\" on Tags") raise ValueError("You can't do lookups other than \"in\" on Tags")
if all(isinstance(v, Tag) for v in value): if all(isinstance(v, Tag) for v in value):
qs = TaggedItem.objects.filter(tag__in=value) qs = self.through.objects.filter(tag__in=value)
elif all(isinstance(v, basestring) for v in value): elif all(isinstance(v, basestring) for v in value):
qs = TaggedItem.objects.filter(tag__name__in=value) qs = self.through.objects.filter(tag__name__in=value)
elif all(isinstance(v, (int, long)) for v in value): elif all(isinstance(v, (int, long)) for v in value):
# This one is really ackward, just don't do it. The ORM does it # This one is really ackward, just don't do it. The ORM does it
# for deletes, but no one else gets to. # for deletes, but no one else gets to.
...@@ -92,53 +93,48 @@ class TaggableManager(object): ...@@ -92,53 +93,48 @@ class TaggableManager(object):
def value_from_object(self, instance): def value_from_object(self, instance):
if instance.pk: if instance.pk:
return TaggedItem.objects.filter( return self.through.objects.filter(**self.through.lookup_kwargs(instance))
object_id=instance.pk, return self.through.objects.none()
content_type=ContentType.objects.get_for_model(instance)
)
return TaggedItem.objects.none()
def related_query_name(self): def related_query_name(self):
return None return None
def m2m_reverse_name(self): def m2m_reverse_name(self):
return "id" try:
return self.through._meta.get_field('content_object').rel.to._meta.pk.column
except FieldDoesNotExist:
return "id"
def m2m_column_name(self): def m2m_column_name(self):
return "object_id" try:
return self.through._meta.get_field('content_object').column
except FieldDoesNotExist :
return self.through._meta.virtual_fields[0].fk_field
def db_type(self, connection=None): def db_type(self, connection=None):
return None return None
def m2m_db_table(self): def m2m_db_table(self):
return self.rel.to._meta.db_table return self.through._meta.db_table
def extra_filters(self, pieces, pos, negate):
if negate:
return []
prefix = "__".join(pieces[:pos+1])
cts = map(ContentType.objects.get_for_model, _get_subclasses(self.model))
if len(cts) == 1:
return [("%s__content_type" % prefix, cts[0])]
return [("%s__content_type__in" % prefix, cts)]
class _TaggableManager(models.Manager): class _TaggableManager(models.Manager):
def __init__(self, through=None):
self.through = through or TaggedItem
def get_query_set(self): def get_query_set(self):
ct = ContentType.objects.get_for_model(self.model) return self.through.tags_for(self.model, self.instance)
if self.object_id is not None:
return Tag.objects.filter(items__object_id=self.object_id,
items__content_type=ct)
else:
return Tag.objects.filter(items__content_type=ct).distinct()
def lookup_kwargs(self):
return self.through.lookup_kwargs(self.instance)
@require_instance_manager @require_instance_manager
def add(self, *tags): def add(self, *tags):
for tag in tags: for tag in tags:
if not isinstance(tag, Tag): if not isinstance(tag, Tag):
tag, _ = Tag.objects.get_or_create(name=tag) tag, _ = Tag.objects.get_or_create(name=tag)
TaggedItem.objects.get_or_create(object_id=self.object_id, self.through.objects.get_or_create(**dict(self.lookup_kwargs(),
content_type=ContentType.objects.get_for_model(self.model), tag=tag) tag=tag))
@require_instance_manager @require_instance_manager
def set(self, *tags): def set(self, *tags):
...@@ -147,53 +143,43 @@ class _TaggableManager(models.Manager): ...@@ -147,53 +143,43 @@ class _TaggableManager(models.Manager):
@require_instance_manager @require_instance_manager
def remove(self, *tags): def remove(self, *tags):
TaggedItem.objects.filter(object_id=self.object_id, self.through.objects.filter(**self.lookup_kwargs()).filter(
content_type=ContentType.objects.get_for_model(self.model)).filter(
tag__name__in=tags).delete() tag__name__in=tags).delete()
@require_instance_manager @require_instance_manager
def clear(self): def clear(self):
TaggedItem.objects.filter(object_id=self.object_id, self.through.objects.filter(**self.lookup_kwargs()).delete()
content_type=ContentType.objects.get_for_model(self.model)).delete()
def most_common(self): def most_common(self):
return self.get_query_set().annotate( return self.get_query_set().annotate(
num_times=models.Count('items') num_times=models.Count(self.through.tag_relname())
).order_by('-num_times') ).order_by('-num_times')
@require_instance_manager @require_instance_manager
def similar_objects(self): def similar_objects(self):
qs = TaggedItem.objects.values('object_id', 'content_type') qs = self.through.objects.values(*self.lookup_kwargs().keys())
qs = qs.annotate(n=models.Count('pk')) qs = qs.annotate(n=models.Count('pk'))
qs = qs.exclude( qs = qs.exclude(**self.lookup_kwargs())
object_id=self.object_id,
content_type=ContentType.objects.get_for_model(self.model)
)
qs = qs.filter(tag__in=self.all()) qs = qs.filter(tag__in=self.all())
qs = qs.order_by('-n') qs = qs.order_by('-n')
preload = defaultdict(set) if not 'content_object' in self.lookup_kwargs():
for result in qs: preload = defaultdict(set)
preload[result["content_type"]].add(result["object_id"]) for result in qs:
preload[result["content_type"]].add(result["object_id"])
items = {} items = {}
for ct, obj_ids in preload.iteritems(): for ct, obj_ids in preload.iteritems():
ct = ContentType.objects.get_for_id(ct) ct = ContentType.objects.get_for_id(ct)
items[ct.pk] = dict((o.pk, o) for o in items[ct.pk] = dict((o.pk, o) for o in
ct.model_class()._default_manager.filter(pk__in=obj_ids) ct.model_class()._default_manager.filter(pk__in=obj_ids))
)
results = [] results = []
for result in qs: for result in qs:
obj = items[result["content_type"]][result["object_id"]] try:
obj = result['content_object']
except KeyError:
obj = items[result["content_type"]][result["object_id"]]
obj.similar_tags = result["n"] obj.similar_tags = result["n"]
results.append(obj) results.append(obj)
return results return results
def _get_subclasses(model):
subclasses = [model]
for f in model._meta.get_all_field_names():
field = model._meta.get_field_by_name(f)[0]
if isinstance(field, RelatedObject) and getattr(field.field.rel, "parent_link", None):
subclasses.extend(_get_subclasses(field.model))
return subclasses
...@@ -25,12 +25,47 @@ class Tag(models.Model): ...@@ -25,12 +25,47 @@ class Tag(models.Model):
return super(Tag, self).save(*args, **kwargs) return super(Tag, self).save(*args, **kwargs)
class TaggedItem(models.Model): class TaggedItemBase(models.Model):
tag = models.ForeignKey(Tag, related_name="%(app_label)s_%(class)s_items")
def __unicode__(self):
return "%s tagged with %s" % (self.content_object, self.tag)
class Meta:
abstract = True
@classmethod
def tag_relname(cls):
return cls._meta.get_field('tag').rel.related_name
@classmethod
def lookup_kwargs(cls, instance):
return {'content_object': instance}
@classmethod
def tags_for(cls, model, instance=None):
if instance is not None:
return Tag.objects.filter(**{'%s__content_object' % cls.tag_relname(): instance})
else:
return Tag.objects.filter(**{'%s__content_object__isnull' % cls.tag_relname(): False})
class TaggedItem(TaggedItemBase):
object_id = models.IntegerField() object_id = models.IntegerField()
content_type = models.ForeignKey(ContentType, related_name="tagged_items") content_type = models.ForeignKey(ContentType, related_name="tagged_items")
content_object = GenericForeignKey() content_object = GenericForeignKey()
tag = models.ForeignKey(Tag, related_name="items") @classmethod
def lookup_kwargs(cls, instance):
def __unicode__(self): return {'object_id': instance.pk,
return "%s tagged with %s" % (self.content_object, self.tag) 'content_type': ContentType.objects.get_for_model(instance)}
@classmethod
def tags_for(cls, model, instance=None):
ct = ContentType.objects.get_for_model(model)
if instance is not None:
return Tag.objects.filter(**{'%s__object_id' % cls.tag_relname(): instance.pk,
'%s__content_type' % cls.tag_relname(): ct})
else:
return Tag.objects.filter(**{'%s__content_type' % cls.tag_relname(): ct}).distinct()
...@@ -84,7 +84,7 @@ class LookupByTagTestCase(BaseTaggingTest): ...@@ -84,7 +84,7 @@ class LookupByTagTestCase(BaseTaggingTest):
apple.tags.add("red", "green") apple.tags.add("red", "green")
pear = Food.objects.create(name="pear") pear = Food.objects.create(name="pear")
pear.tags.add("green") pear.tags.add("green")
self.assertEqual(list(Food.objects.filter(tags__in=["red"])), [apple]) self.assertEqual(list(Food.objects.filter(tags__in=["red"])), [apple])
self.assertEqual(list(Food.objects.filter(tags__in=["green"])), [apple, pear]) self.assertEqual(list(Food.objects.filter(tags__in=["green"])), [apple, pear])
......
...@@ -8,7 +8,7 @@ def parse_tags(tags): ...@@ -8,7 +8,7 @@ def parse_tags(tags):
def require_instance_manager(func): def require_instance_manager(func):
@wraps(func) @wraps(func)
def inner(self, *args, **kwargs): def inner(self, *args, **kwargs):
if self.object_id is None: if self.instance is None:
raise TypeError("Can't call %s with a non-instance manager" % func.__name__) raise TypeError("Can't call %s with a non-instance manager" % func.__name__)
return func(self, *args, **kwargs) return func(self, *args, **kwargs)
return inner return inner
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment