Commit 4f2e47f8 by idan

Added support for prefetch_related on tags fields

parent 45952169
from __future__ import unicode_literals
from operator import attrgetter
from django import VERSION
from django.contrib.contenttypes.generic import GenericRelation
from django.contrib.contenttypes.models import ContentType
from django.db import models
from django.db import models, router
from django.db.models.fields import Field
from django.db.models.fields.related import ManyToManyRel, RelatedField, add_lazy_relation
from django.db.models.related import RelatedObject
......@@ -82,7 +83,7 @@ class TaggableManager(RelatedField, Field):
raise ValueError("%s objects need to have a primary key value "
"before you can access their tags." % model.__name__)
manager = _TaggableManager(
through=self.through, model=model, instance=instance
through=self.through, model=model, instance=instance, prefetch_cache_name = self.name
)
return manager
......@@ -314,13 +315,41 @@ class TaggableManager(RelatedField, Field):
class _TaggableManager(models.Manager):
def __init__(self, through, model, instance):
def __init__(self, through, model, instance, prefetch_cache_name):
self.through = through
self.model = model
self.instance = instance
self.prefetch_cache_name = prefetch_cache_name
self._db = None
def is_cached(self, instance):
return self.prefetch_cache_name in instance._prefetched_objects_cache
def get_query_set(self):
return self.through.tags_for(self.model, self.instance)
try:
return self.instance._prefetched_objects_cache[self.prefetch_cache_name]
except (AttributeError, KeyError):
return self.through.tags_for(self.model, self.instance)
def get_prefetch_query_set(self, instances):
instance = instances[0]
from django.db import connections
db = self._db or router.db_for_read(instance.__class__, instance=instance)
fk = self.through._meta.get_field('object_id' if issubclass(self.through, GenericTaggedItemBase) else 'content_object')
query = {
'%s__%s__in' % (self.through.tag_relname(), fk.name) : set(obj._get_pk_val() for obj in instances)
}
join_table = self.through._meta.db_table
source_col = fk.column
connection = connections[db]
qn = connection.ops.quote_name
qs = self.get_query_set().using(db)._next_is_sticky().filter(**query).extra(select = { '_prefetch_related_val' : '%s.%s' % (qn(join_table), qn(source_col))})
return (qs,
attrgetter('_prefetch_related_val'),
attrgetter(instance._meta.pk.name),
False,
self.prefetch_cache_name)
# Django 1.6 renamed this
get_queryset = get_query_set
......
......@@ -341,6 +341,20 @@ class TaggableManagerTestCase(BaseTaggingTestCase):
apple = self.food_model.objects.create(name="apple")
serializers.serialize("json", (apple,))
def test_prefetch_related(self):
apple = self.food_model.objects.create(name="apple")
apple.tags.add('1', '2')
orange = self.food_model.objects.create(name="orange")
orange.tags.add('2', '4')
with self.assertNumQueries(2):
l = list(self.food_model.objects.prefetch_related('tags').all())
with self.assertNumQueries(0):
foods = {f.name : set(t.name for t in f.tags.all()) for f in l}
self.assertEqual(foods, {
u'orange': {'2', '4'},
u'apple': {'1', '2'}
})
class TaggableManagerDirectTestCase(TaggableManagerTestCase):
food_model = DirectFood
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment