Commit 8313c77f by Alex Gaynor

A tiny bit of cleanup. Thanks Carl.

parent 7ff1d161
...@@ -7,4 +7,5 @@ fakeempire <adam@fakeempire.com> ...@@ -7,4 +7,5 @@ fakeempire <adam@fakeempire.com>
Ben Firshman <ben@firshman.co.uk> Ben Firshman <ben@firshman.co.uk>
Alex Gaynor <alex.gaynor@gmail.com> Alex Gaynor <alex.gaynor@gmail.com>
Rob Hudson <rob@cogit8.org> Rob Hudson <rob@cogit8.org>
Carl Meyer
Frank Wiles Frank Wiles
from collections import defaultdict from collections import defaultdict
import django import django
from django.contrib.contenttypes.models import ContentType
from django.contrib.contenttypes.generic import GenericForeignKey from django.contrib.contenttypes.generic import GenericForeignKey
from django.contrib.contenttypes.models import ContentType
from django.db import models from django.db import models
from django.db.models.related import RelatedObject
from django.db.models.fields.related import ManyToManyRel from django.db.models.fields.related import ManyToManyRel
from django.db.models.related import RelatedObject
from django.db.models.query_utils import QueryWrapper from django.db.models.query_utils import QueryWrapper
from taggit.forms import TagField from taggit.forms import TagField
...@@ -102,14 +102,12 @@ class TaggableManager(object): ...@@ -102,14 +102,12 @@ class TaggableManager(object):
def m2m_reverse_name(self): def m2m_reverse_name(self):
if self.use_gfk: if self.use_gfk:
return "id" return "id"
else: return self.through._meta.get_field('content_object').rel.to._meta.pk.column
return self.through._meta.get_field('content_object').rel.to._meta.pk.column
def m2m_column_name(self): def m2m_column_name(self):
if self.use_gfk: if self.use_gfk:
return self.through._meta.virtual_fields[0].fk_field return self.through._meta.virtual_fields[0].fk_field
else: return self.through._meta.get_field('content_object').column
return self.through._meta.get_field('content_object').column
def db_type(self, connection=None): def db_type(self, connection=None):
return None return None
...@@ -142,8 +140,7 @@ class _TaggableManager(models.Manager): ...@@ -142,8 +140,7 @@ class _TaggableManager(models.Manager):
for tag in tags: for tag in tags:
if not isinstance(tag, Tag): if not isinstance(tag, Tag):
tag, _ = Tag.objects.get_or_create(name=tag) tag, _ = Tag.objects.get_or_create(name=tag)
self.through.objects.get_or_create(**dict(self._lookup_kwargs(), self.through.objects.get_or_create(tag=tag, **self._lookup_kwargs())
tag=tag))
@require_instance_manager @require_instance_manager
def set(self, *tags): def set(self, *tags):
...@@ -166,17 +163,22 @@ class _TaggableManager(models.Manager): ...@@ -166,17 +163,22 @@ class _TaggableManager(models.Manager):
@require_instance_manager @require_instance_manager
def similar_objects(self): def similar_objects(self):
qs = self.through.objects.values(*self._lookup_kwargs().keys()) lookup_kwargs = self._lookup_kwargs()
qs = self.through.objects.values(*lookup_kwargs.keys())
qs = qs.annotate(n=models.Count('pk')) qs = qs.annotate(n=models.Count('pk'))
qs = qs.exclude(**self._lookup_kwargs()) qs = qs.exclude(**lookup_kwargs)
qs = qs.filter(tag__in=self.all()) qs = qs.filter(tag__in=self.all())
qs = qs.order_by('-n') qs = qs.order_by('-n')
if 'content_object' in self._lookup_kwargs(): # TODO: This all feels like a giant hack... giant.
if len(lookup_kwargs) == 1:
using_gfk = False using_gfk = False
items = dict([(o.pk, o) for o in # Can this just be select_related? I think so.
self.through._meta.get_field('content_object').rel.to.objects.filter( items = self.through._meta.get_field_by_name(
pk__in=[r['content_object'] for r in qs])]) lookup_kwargs.keys()[0]
)[0].rel.to._default_manager.in_bulk(
[r["content_object"] for r in qs]
)
else: else:
using_gfk = True using_gfk = True
preload = defaultdict(set) preload = defaultdict(set)
...@@ -186,11 +188,12 @@ class _TaggableManager(models.Manager): ...@@ -186,11 +188,12 @@ class _TaggableManager(models.Manager):
items = {} items = {}
for ct, obj_ids in preload.iteritems(): for ct, obj_ids in preload.iteritems():
ct = ContentType.objects.get_for_id(ct) ct = ContentType.objects.get_for_id(ct)
items[ct.pk] = dict((o.pk, o) for o in items[ct.pk] = ct.model_class()._default_manager.in_bulk(obj_ids)
ct.model_class()._default_manager.filter(pk__in=obj_ids))
results = [] results = []
for result in qs: for result in qs:
# TODO: Consolidate this into dicts keyed by a tuple of the
# (content_type, object_id) instead of the nesting.
if using_gfk: if using_gfk:
obj = items[result["content_type"]][result["object_id"]] obj = items[result["content_type"]][result["object_id"]]
else: else:
...@@ -204,6 +207,7 @@ def _get_subclasses(model): ...@@ -204,6 +207,7 @@ def _get_subclasses(model):
subclasses = [model] subclasses = [model]
for f in model._meta.get_all_field_names(): for f in model._meta.get_all_field_names():
field = model._meta.get_field_by_name(f)[0] field = model._meta.get_field_by_name(f)[0]
if isinstance(field, RelatedObject) and getattr(field.field.rel, "parent_link", None): if (isinstance(field, RelatedObject) and
getattr(field.field.rel, "parent_link", None)):
subclasses.extend(_get_subclasses(field.model)) subclasses.extend(_get_subclasses(field.model))
return subclasses return subclasses
...@@ -40,18 +40,23 @@ class TaggedItemBase(models.Model): ...@@ -40,18 +40,23 @@ class TaggedItemBase(models.Model):
@classmethod @classmethod
def tag_relname(cls): def tag_relname(cls):
return cls._meta.get_field('tag').rel.related_name return cls._meta.get_field_by_name('tag')[0].rel.related_name
@classmethod @classmethod
def lookup_kwargs(cls, instance): def lookup_kwargs(cls, instance):
return {'content_object': instance} return {
'content_object': instance
}
@classmethod @classmethod
def tags_for(cls, model, instance=None): def tags_for(cls, model, instance=None):
if instance is not None: if instance is not None:
return Tag.objects.filter(**{'%s__content_object' % cls.tag_relname(): instance}) return Tag.objects.filter(**{
else: '%s__content_object' % cls.tag_relname(): instance
return Tag.objects.filter(**{'%s__content_object__isnull' % cls.tag_relname(): False}).distinct() })
return Tag.objects.filter(**{
'%s__content_object__isnull' % cls.tag_relname(): False
}).distinct()
class TaggedItem(TaggedItemBase): class TaggedItem(TaggedItemBase):
...@@ -61,15 +66,20 @@ class TaggedItem(TaggedItemBase): ...@@ -61,15 +66,20 @@ class TaggedItem(TaggedItemBase):
@classmethod @classmethod
def lookup_kwargs(cls, instance): def lookup_kwargs(cls, instance):
return {'object_id': instance.pk, return {
'content_type': ContentType.objects.get_for_model(instance)} 'object_id': instance.pk,
'content_type': ContentType.objects.get_for_model(instance)
}
@classmethod @classmethod
def tags_for(cls, model, instance=None): def tags_for(cls, model, instance=None):
ct = ContentType.objects.get_for_model(model) ct = ContentType.objects.get_for_model(model)
if instance is not None: if instance is not None:
return Tag.objects.filter(**{'%s__object_id' % cls.tag_relname(): instance.pk, return Tag.objects.filter(**{
'%s__content_type' % cls.tag_relname(): ct}) '%s__object_id' % cls.tag_relname(): instance.pk,
else: '%s__content_type' % cls.tag_relname(): ct
return Tag.objects.filter(**{'%s__content_type' % cls.tag_relname(): ct}).distinct() })
return Tag.objects.filter(**{
'%s__content_type' % cls.tag_relname(): ct
}).distinct()
...@@ -50,7 +50,11 @@ class TaggableManagerTestCase(BaseTaggingTest): ...@@ -50,7 +50,11 @@ class TaggableManagerTestCase(BaseTaggingTest):
self.assert_tags_equal(apple.tags.all(), ['green', 'red']) self.assert_tags_equal(apple.tags.all(), ['green', 'red'])
self.assert_tags_equal(self.food_model.tags.all(), ['green', 'red']) self.assert_tags_equal(self.food_model.tags.all(), ['green', 'red'])
self.assert_tags_equal(self.food_model.tags.most_common(), ['green', 'red'], sort=False) self.assert_tags_equal(
self.food_model.tags.most_common(),
['green', 'red'],
sort=False
)
apple.tags.remove('green') apple.tags.remove('green')
self.assert_tags_equal(apple.tags.all(), ['red']) self.assert_tags_equal(apple.tags.all(), ['red'])
...@@ -85,14 +89,23 @@ class TaggableManagerTestCase(BaseTaggingTest): ...@@ -85,14 +89,23 @@ class TaggableManagerTestCase(BaseTaggingTest):
pear = self.food_model.objects.create(name="pear") pear = self.food_model.objects.create(name="pear")
pear.tags.add("green") pear.tags.add("green")
self.assertEqual(list(self.food_model.objects.filter(tags__in=["red"])), [apple]) self.assertEqual(
self.assertEqual(list(self.food_model.objects.filter(tags__in=["green"])), [apple, pear]) list(self.food_model.objects.filter(tags__in=["red"])),
[apple]
)
self.assertEqual(
list(self.food_model.objects.filter(tags__in=["green"])),
[apple, pear]
)
kitty = self.pet_model.objects.create(name="kitty") kitty = self.pet_model.objects.create(name="kitty")
kitty.tags.add("fuzzy", "red") kitty.tags.add("fuzzy", "red")
dog = self.pet_model.objects.create(name="dog") dog = self.pet_model.objects.create(name="dog")
dog.tags.add("woof", "red") dog.tags.add("woof", "red")
self.assertEqual(list(self.food_model.objects.filter(tags__in=["red"]).distinct()), [apple]) self.assertEqual(
list(self.food_model.objects.filter(tags__in=["red"]).distinct()),
[apple]
)
tag = Tag.objects.get(name="woof") tag = Tag.objects.get(name="woof")
self.assertEqual(list(self.pet_model.objects.filter(tags__in=[tag])), [dog]) self.assertEqual(list(self.pet_model.objects.filter(tags__in=[tag])), [dog])
...@@ -130,7 +143,8 @@ class TaggableManagerDirectTestCase(TaggableManagerTestCase): ...@@ -130,7 +143,8 @@ class TaggableManagerDirectTestCase(TaggableManagerTestCase):
food_model = DirectFood food_model = DirectFood
pet_model = DirectPet pet_model = DirectPet
housepet_model = DirectHousePet housepet_model = DirectHousePet
class TaggableFormTestCase(BaseTaggingTest): class TaggableFormTestCase(BaseTaggingTest):
form_class = FoodForm form_class = FoodForm
food_model = Food food_model = Food
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment