Commit d513882d by Alex Gaynor

Preload related tags.

parent 56e1227f
from collections import defaultdict
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
from django.db import models from django.db import models
from django.db.models.fields.related import ManyToManyRel from django.db.models.fields.related import ManyToManyRel
...@@ -28,7 +30,7 @@ class TaggableManager(object): ...@@ -28,7 +30,7 @@ class TaggableManager(object):
self.choices = None self.choices = None
self.creation_counter = models.Field.creation_counter self.creation_counter = models.Field.creation_counter
models.Field.creation_counter += 1 models.Field.creation_counter += 1
def __get__(self, instance, type): def __get__(self, instance, type):
manager = _TaggableManager() manager = _TaggableManager()
manager.model = type manager.model = type
...@@ -37,7 +39,7 @@ class TaggableManager(object): ...@@ -37,7 +39,7 @@ class TaggableManager(object):
else: else:
manager.object_id = instance.pk manager.object_id = instance.pk
return manager return manager
def contribute_to_class(self, cls, name): def contribute_to_class(self, cls, name):
self.name = self.column = name self.name = self.column = name
self.model = cls self.model = cls
...@@ -46,7 +48,7 @@ class TaggableManager(object): ...@@ -46,7 +48,7 @@ class TaggableManager(object):
def save_form_data(self, instance, value): def save_form_data(self, instance, value):
getattr(instance, self.name).set(*value) getattr(instance, self.name).set(*value)
def get_db_prep_lookup(self, lookup_type, value): def get_db_prep_lookup(self, lookup_type, value):
if lookup_type != "in": if lookup_type != "in":
raise ValueError("You can't do lookups other than \"in\" on Tags") raise ValueError("You can't do lookups other than \"in\" on Tags")
...@@ -63,7 +65,7 @@ class TaggableManager(object): ...@@ -63,7 +65,7 @@ class TaggableManager(object):
raise ValueError("You can't combine Tag objects and strings. '%s' was provided." % value) raise ValueError("You can't combine Tag objects and strings. '%s' was provided." % value)
sql, params = qs.values_list("pk", flat=True).query.as_sql() sql, params = qs.values_list("pk", flat=True).query.as_sql()
return QueryWrapper(("(%s)" % sql), params) return QueryWrapper(("(%s)" % sql), params)
def formfield(self, form_class=TagField, **kwargs): def formfield(self, form_class=TagField, **kwargs):
defaults = { defaults = {
"label": "Tags", "label": "Tags",
...@@ -71,25 +73,25 @@ class TaggableManager(object): ...@@ -71,25 +73,25 @@ class TaggableManager(object):
} }
defaults.update(kwargs) defaults.update(kwargs)
return form_class(**kwargs) return form_class(**kwargs)
def value_from_object(self, instance): def value_from_object(self, instance):
return ", ".join(map(unicode, getattr(instance, self.name).all())) return ", ".join(map(unicode, getattr(instance, self.name).all()))
def related_query_name(self): def related_query_name(self):
return None return None
def m2m_reverse_name(self): def m2m_reverse_name(self):
return "id" return "id"
def m2m_column_name(self): def m2m_column_name(self):
return "object_id" return "object_id"
def db_type(self): def db_type(self):
return None return None
def m2m_db_table(self): def m2m_db_table(self):
return self.rel.to._meta.db_table return self.rel.to._meta.db_table
def extra_filters(self, pieces, pos, negate): def extra_filters(self, pieces, pos, negate):
if negate: if negate:
return [] return []
...@@ -101,45 +103,62 @@ class _TaggableManager(models.Manager): ...@@ -101,45 +103,62 @@ class _TaggableManager(models.Manager):
def get_query_set(self): def get_query_set(self):
ct = ContentType.objects.get_for_model(self.model) ct = ContentType.objects.get_for_model(self.model)
if self.object_id is not None: if self.object_id is not None:
return Tag.objects.filter(items__object_id=self.object_id, return Tag.objects.filter(items__object_id=self.object_id,
items__content_type=ct) items__content_type=ct)
else: else:
return Tag.objects.filter(items__content_type=ct).distinct() return Tag.objects.filter(items__content_type=ct).distinct()
@require_instance_manager @require_instance_manager
def add(self, *tags): def add(self, *tags):
for tag in tags: for tag in tags:
if not isinstance(tag, Tag): if not isinstance(tag, Tag):
tag, _ = Tag.objects.get_or_create(name=tag) tag, _ = Tag.objects.get_or_create(name=tag)
TaggedItem.objects.create(object_id=self.object_id, TaggedItem.objects.create(object_id=self.object_id,
content_type=ContentType.objects.get_for_model(self.model), tag=tag) content_type=ContentType.objects.get_for_model(self.model), tag=tag)
@require_instance_manager @require_instance_manager
def set(self, *tags): def set(self, *tags):
self.clear() self.clear()
self.add(*tags) self.add(*tags)
@require_instance_manager @require_instance_manager
def remove(self, *tags): def remove(self, *tags):
TaggedItem.objects.filter(object_id=self.object_id, TaggedItem.objects.filter(object_id=self.object_id,
content_type=ContentType.objects.get_for_model(self.model)).filter( content_type=ContentType.objects.get_for_model(self.model)).filter(
tag__name__in=tags).delete() tag__name__in=tags).delete()
@require_instance_manager @require_instance_manager
def clear(self): def clear(self):
TaggedItem.objects.filter(object_id=self.object_id, TaggedItem.objects.filter(object_id=self.object_id,
content_type=ContentType.objects.get_for_model(self.model)).delete() content_type=ContentType.objects.get_for_model(self.model)).delete()
def most_common(self): def most_common(self):
return self.get_query_set().annotate( return self.get_query_set().annotate(
num_times=models.Count('items') num_times=models.Count('items')
).order_by('-num_times') ).order_by('-num_times')
@require_instance_manager @require_instance_manager
def similar_objects(self): def similar_objects(self):
return TaggedItem.objects.values('object_id', 'content_type') \ qs = TaggedItem.objects.values('object_id', 'content_type') \
.annotate(models.Count('pk')) \ .annotate(n=models.Count('pk')) \
.exclude(object_id=self.object_id) \ .exclude(object_id=self.object_id) \
.filter(tag__in=self.all()) \ .filter(tag__in=self.all()) \
.order_by('-pk__count') .order_by('-n')
preload = defaultdict(set)
for result in qs:
preload[result["content_type"]].add(result["object_id"])
items = {}
for ct, obj_ids in preload.iteritems():
ct = ContentType.objects.get_for_id(ct)
items[ct.pk] = dict((o.pk, o) for o in
ct.model_class()._default_manager.filter(pk__in=obj_ids)
)
results = []
for result in qs:
obj = items[result["content_type"]][result["object_id"]]
obj.similar_tags = result["n"]
results.append(obj)
return results
...@@ -19,22 +19,22 @@ class AddTagTestCase(BaseTaggingTest): ...@@ -19,22 +19,22 @@ class AddTagTestCase(BaseTaggingTest):
apple = Food.objects.create(name="apple") apple = Food.objects.create(name="apple")
self.assertEqual(list(apple.tags.all()), []) self.assertEqual(list(apple.tags.all()), [])
self.assertEqual(list(Food.tags.all()), []) self.assertEqual(list(Food.tags.all()), [])
apple.tags.add('green') apple.tags.add('green')
self.assert_tags_equal(apple.tags.all(), ['green']) self.assert_tags_equal(apple.tags.all(), ['green'])
self.assert_tags_equal(Food.tags.all(), ['green']) self.assert_tags_equal(Food.tags.all(), ['green'])
pear = Food.objects.create(name="pear") pear = Food.objects.create(name="pear")
pear.tags.add('green') pear.tags.add('green')
self.assert_tags_equal(pear.tags.all(), ['green']) self.assert_tags_equal(pear.tags.all(), ['green'])
self.assert_tags_equal(Food.tags.all(), ['green']) self.assert_tags_equal(Food.tags.all(), ['green'])
apple.tags.add('red') apple.tags.add('red')
self.assert_tags_equal(apple.tags.all(), ['green', 'red']) self.assert_tags_equal(apple.tags.all(), ['green', 'red'])
self.assert_tags_equal(Food.tags.all(), ['green', 'red']) self.assert_tags_equal(Food.tags.all(), ['green', 'red'])
self.assert_tags_equal(Food.tags.most_common(), ['green', 'red'], sort=False) self.assert_tags_equal(Food.tags.most_common(), ['green', 'red'], sort=False)
apple.tags.remove('green') apple.tags.remove('green')
self.assert_tags_equal(apple.tags.all(), ['red']) self.assert_tags_equal(apple.tags.all(), ['red'])
self.assert_tags_equal(Food.tags.all(), ['green', 'red']) self.assert_tags_equal(Food.tags.all(), ['green', 'red'])
...@@ -49,10 +49,10 @@ class LookupByTagTestCase(BaseTaggingTest): ...@@ -49,10 +49,10 @@ class LookupByTagTestCase(BaseTaggingTest):
apple.tags.add("red", "green") apple.tags.add("red", "green")
pear = Food.objects.create(name="pear") pear = Food.objects.create(name="pear")
pear.tags.add("green") pear.tags.add("green")
self.assertEqual(list(Food.objects.filter(tags__in=["red"])), [apple]) self.assertEqual(list(Food.objects.filter(tags__in=["red"])), [apple])
self.assertEqual(list(Food.objects.filter(tags__in=["green"])), [apple, pear]) self.assertEqual(list(Food.objects.filter(tags__in=["green"])), [apple, pear])
kitty = Pet.objects.create(name="kitty") kitty = Pet.objects.create(name="kitty")
kitty.tags.add("fuzzy", "red") kitty.tags.add("fuzzy", "red")
dog = Pet.objects.create(name="dog") dog = Pet.objects.create(name="dog")
...@@ -66,12 +66,12 @@ class LookupByTagTestCase(BaseTaggingTest): ...@@ -66,12 +66,12 @@ class LookupByTagTestCase(BaseTaggingTest):
class TaggableFormTestCase(BaseTaggingTest): class TaggableFormTestCase(BaseTaggingTest):
def test_form(self): def test_form(self):
self.assertEqual(FoodForm.base_fields.keys(), ['name', 'tags']) self.assertEqual(FoodForm.base_fields.keys(), ['name', 'tags'])
f = FoodForm({'name': 'apple', 'tags': 'green, red, yummy'}) f = FoodForm({'name': 'apple', 'tags': 'green, red, yummy'})
f.save() f.save()
apple = Food.objects.get(name='apple') apple = Food.objects.get(name='apple')
self.assert_tags_equal(apple.tags.all(), ['green', 'red', 'yummy']) self.assert_tags_equal(apple.tags.all(), ['green', 'red', 'yummy'])
f = FoodForm({'name': 'apple', 'tags': 'green, red, yummy, delicious'}, instance=apple) f = FoodForm({'name': 'apple', 'tags': 'green, red, yummy, delicious'}, instance=apple)
f.save() f.save()
apple = Food.objects.get(name='apple') apple = Food.objects.get(name='apple')
...@@ -90,6 +90,6 @@ class SimilarityByTagTestCase(BaseTaggingTest): ...@@ -90,6 +90,6 @@ class SimilarityByTagTestCase(BaseTaggingTest):
watermelon = Food.objects.create(name="watermelon") watermelon = Food.objects.create(name="watermelon")
watermelon.tags.add("green", "juicy", "large", "sweet") watermelon.tags.add("green", "juicy", "large", "sweet")
self.assertEqual(apple.tags.similar_objects(), similar_objs = apple.tags.similar_objects()
[{'pk__count': 3, 'content_type': 13, 'object_id': 6}, {'pk__count': 2, 'content_type': 13, 'object_id': 7}]) self.assertEqual(similar_objs, [pear, watermelon])
self.assertEqual(map(lambda x: x.similar_tags, similar_objs), [3, 2])
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment