Commit 4f2e47f8 by idan

Added support for prefetch_related on tags fields

parent 45952169
from __future__ import unicode_literals from __future__ import unicode_literals
from operator import attrgetter
from django import VERSION from django import VERSION
from django.contrib.contenttypes.generic import GenericRelation from django.contrib.contenttypes.generic import GenericRelation
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
from django.db import models from django.db import models, router
from django.db.models.fields import Field from django.db.models.fields import Field
from django.db.models.fields.related import ManyToManyRel, RelatedField, add_lazy_relation from django.db.models.fields.related import ManyToManyRel, RelatedField, add_lazy_relation
from django.db.models.related import RelatedObject from django.db.models.related import RelatedObject
...@@ -82,7 +83,7 @@ class TaggableManager(RelatedField, Field): ...@@ -82,7 +83,7 @@ class TaggableManager(RelatedField, Field):
raise ValueError("%s objects need to have a primary key value " raise ValueError("%s objects need to have a primary key value "
"before you can access their tags." % model.__name__) "before you can access their tags." % model.__name__)
manager = _TaggableManager( manager = _TaggableManager(
through=self.through, model=model, instance=instance through=self.through, model=model, instance=instance, prefetch_cache_name = self.name
) )
return manager return manager
...@@ -314,14 +315,42 @@ class TaggableManager(RelatedField, Field): ...@@ -314,14 +315,42 @@ class TaggableManager(RelatedField, Field):
class _TaggableManager(models.Manager): class _TaggableManager(models.Manager):
def __init__(self, through, model, instance): def __init__(self, through, model, instance, prefetch_cache_name):
self.through = through self.through = through
self.model = model self.model = model
self.instance = instance self.instance = instance
self.prefetch_cache_name = prefetch_cache_name
self._db = None
def is_cached(self, instance):
return self.prefetch_cache_name in instance._prefetched_objects_cache
def get_query_set(self): def get_query_set(self):
try:
return self.instance._prefetched_objects_cache[self.prefetch_cache_name]
except (AttributeError, KeyError):
return self.through.tags_for(self.model, self.instance) return self.through.tags_for(self.model, self.instance)
def get_prefetch_query_set(self, instances):
instance = instances[0]
from django.db import connections
db = self._db or router.db_for_read(instance.__class__, instance=instance)
fk = self.through._meta.get_field('object_id' if issubclass(self.through, GenericTaggedItemBase) else 'content_object')
query = {
'%s__%s__in' % (self.through.tag_relname(), fk.name) : set(obj._get_pk_val() for obj in instances)
}
join_table = self.through._meta.db_table
source_col = fk.column
connection = connections[db]
qn = connection.ops.quote_name
qs = self.get_query_set().using(db)._next_is_sticky().filter(**query).extra(select = { '_prefetch_related_val' : '%s.%s' % (qn(join_table), qn(source_col))})
return (qs,
attrgetter('_prefetch_related_val'),
attrgetter(instance._meta.pk.name),
False,
self.prefetch_cache_name)
# Django 1.6 renamed this # Django 1.6 renamed this
get_queryset = get_query_set get_queryset = get_query_set
......
...@@ -341,6 +341,20 @@ class TaggableManagerTestCase(BaseTaggingTestCase): ...@@ -341,6 +341,20 @@ class TaggableManagerTestCase(BaseTaggingTestCase):
apple = self.food_model.objects.create(name="apple") apple = self.food_model.objects.create(name="apple")
serializers.serialize("json", (apple,)) serializers.serialize("json", (apple,))
def test_prefetch_related(self):
apple = self.food_model.objects.create(name="apple")
apple.tags.add('1', '2')
orange = self.food_model.objects.create(name="orange")
orange.tags.add('2', '4')
with self.assertNumQueries(2):
l = list(self.food_model.objects.prefetch_related('tags').all())
with self.assertNumQueries(0):
foods = {f.name : set(t.name for t in f.tags.all()) for f in l}
self.assertEqual(foods, {
u'orange': {'2', '4'},
u'apple': {'1', '2'}
})
class TaggableManagerDirectTestCase(TaggableManagerTestCase): class TaggableManagerDirectTestCase(TaggableManagerTestCase):
food_model = DirectFood food_model = DirectFood
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment