Commit a38ee9d8 by Florian Apolloner

Merge pull request #213 from nicholasserra/nicholasserra-pluggable-manager

Allow pluggable manager for TaggableManager.
parents 1682de06 a00fbc0a
......@@ -69,6 +69,16 @@ playing around with the API.
>>> apple.tags.slugs()
[u'green-and-juicy', u'red']
.. hint::
You can subclass ``_TaggableManager`` (note the underscore) to add
methods or functionality. ``TaggableManager`` takes an optional
manager keyword argument for your custom class, like this::
class Food(models.Model):
# ... fields here
tags = TaggableManager(manager=_CustomTaggableManager)
Filtering
~~~~~~~~~
......
......@@ -72,21 +72,166 @@ class ExtraJoinRestriction(object):
return self.__class__(self.alias, self.col, self.content_types[:])
class _TaggableManager(models.Manager):
def __init__(self, through, model, instance, prefetch_cache_name):
self.through = through
self.model = model
self.instance = instance
self.prefetch_cache_name = prefetch_cache_name
self._db = None
def is_cached(self, instance):
return self.prefetch_cache_name in instance._prefetched_objects_cache
def get_queryset(self):
try:
return self.instance._prefetched_objects_cache[self.prefetch_cache_name]
except (AttributeError, KeyError):
return self.through.tags_for(self.model, self.instance)
def get_prefetch_queryset(self, instances, queryset=None):
if queryset is not None:
raise ValueError("Custom queryset can't be used for this lookup.")
instance = instances[0]
from django.db import connections
db = self._db or router.db_for_read(instance.__class__, instance=instance)
fieldname = ('object_id' if issubclass(self.through, GenericTaggedItemBase)
else 'content_object')
fk = self.through._meta.get_field(fieldname)
query = {
'%s__%s__in' % (self.through.tag_relname(), fk.name) :
set(obj._get_pk_val() for obj in instances)
}
join_table = self.through._meta.db_table
source_col = fk.column
connection = connections[db]
qn = connection.ops.quote_name
qs = self.get_queryset().using(db)._next_is_sticky().filter(**query).extra(
select = {
'_prefetch_related_val' : '%s.%s' % (qn(join_table), qn(source_col))
}
)
return (qs,
attrgetter('_prefetch_related_val'),
attrgetter(instance._meta.pk.name),
False,
self.prefetch_cache_name)
# Django < 1.6 uses the previous name of query_set
get_query_set = get_queryset
get_prefetch_query_set = get_prefetch_queryset
def _lookup_kwargs(self):
return self.through.lookup_kwargs(self.instance)
@require_instance_manager
def add(self, *tags):
str_tags = set([
t
for t in tags
if not isinstance(t, self.through.tag_model())
])
tag_objs = set(tags) - str_tags
# If str_tags has 0 elements Django actually optimizes that to not do a
# query. Malcolm is very smart.
existing = self.through.tag_model().objects.filter(
name__in=str_tags
)
tag_objs.update(existing)
for new_tag in str_tags - set(t.name for t in existing):
tag_objs.add(self.through.tag_model().objects.create(name=new_tag))
for tag in tag_objs:
self.through.objects.get_or_create(tag=tag, **self._lookup_kwargs())
@require_instance_manager
def names(self):
return self.get_queryset().values_list('name', flat=True)
@require_instance_manager
def slugs(self):
return self.get_queryset().values_list('slug', flat=True)
@require_instance_manager
def set(self, *tags):
self.clear()
self.add(*tags)
@require_instance_manager
def remove(self, *tags):
self.through.objects.filter(**self._lookup_kwargs()).filter(
tag__name__in=tags).delete()
@require_instance_manager
def clear(self):
self.through.objects.filter(**self._lookup_kwargs()).delete()
def most_common(self):
return self.get_queryset().annotate(
num_times=models.Count(self.through.tag_relname())
).order_by('-num_times')
@require_instance_manager
def similar_objects(self):
lookup_kwargs = self._lookup_kwargs()
lookup_keys = sorted(lookup_kwargs)
qs = self.through.objects.values(*six.iterkeys(lookup_kwargs))
qs = qs.annotate(n=models.Count('pk'))
qs = qs.exclude(**lookup_kwargs)
qs = qs.filter(tag__in=self.all())
qs = qs.order_by('-n')
# TODO: This all feels like a bit of a hack.
items = {}
if len(lookup_keys) == 1:
# Can we do this without a second query by using a select_related()
# somehow?
f = self.through._meta.get_field_by_name(lookup_keys[0])[0]
objs = f.rel.to._default_manager.filter(**{
"%s__in" % f.rel.field_name: [r["content_object"] for r in qs]
})
for obj in objs:
items[(getattr(obj, f.rel.field_name),)] = obj
else:
preload = {}
for result in qs:
preload.setdefault(result['content_type'], set())
preload[result["content_type"]].add(result["object_id"])
for ct, obj_ids in preload.items():
ct = ContentType.objects.get_for_id(ct)
for obj in ct.model_class()._default_manager.filter(pk__in=obj_ids):
items[(ct.pk, obj.pk)] = obj
results = []
for result in qs:
obj = items[
tuple(result[k] for k in lookup_keys)
]
obj.similar_tags = result["n"]
results.append(obj)
return results
class TaggableManager(RelatedField, Field):
_related_name_counter = 0
def __init__(self, verbose_name=_("Tags"), help_text=_("A comma-separated list of tags."),
through=None, blank=False, related_name=None):
through=None, blank=False, related_name=None, manager=_TaggableManager):
Field.__init__(self, verbose_name=verbose_name, help_text=help_text, blank=blank, null=True, serialize=False)
self.through = through or TaggedItem
self.rel = TaggableRel(self, related_name, self.through)
self.swappable = False
self.manager = manager
def __get__(self, instance, model):
if instance is not None and instance.pk is None:
raise ValueError("%s objects need to have a primary key value "
"before you can access their tags." % model.__name__)
manager = _TaggableManager(
manager = self.manager(
through=self.through,
model=model,
instance=instance,
......@@ -288,150 +433,6 @@ class TaggableManager(RelatedField, Field):
return [self.related_fields[0][1]]
class _TaggableManager(models.Manager):
def __init__(self, through, model, instance, prefetch_cache_name):
self.through = through
self.model = model
self.instance = instance
self.prefetch_cache_name = prefetch_cache_name
self._db = None
def is_cached(self, instance):
return self.prefetch_cache_name in instance._prefetched_objects_cache
def get_queryset(self):
try:
return self.instance._prefetched_objects_cache[self.prefetch_cache_name]
except (AttributeError, KeyError):
return self.through.tags_for(self.model, self.instance)
def get_prefetch_queryset(self, instances, queryset=None):
if queryset is not None:
raise ValueError("Custom queryset can't be used for this lookup.")
instance = instances[0]
from django.db import connections
db = self._db or router.db_for_read(instance.__class__, instance=instance)
fieldname = ('object_id' if issubclass(self.through, GenericTaggedItemBase)
else 'content_object')
fk = self.through._meta.get_field(fieldname)
query = {
'%s__%s__in' % (self.through.tag_relname(), fk.name) :
set(obj._get_pk_val() for obj in instances)
}
join_table = self.through._meta.db_table
source_col = fk.column
connection = connections[db]
qn = connection.ops.quote_name
qs = self.get_queryset().using(db)._next_is_sticky().filter(**query).extra(
select = {
'_prefetch_related_val' : '%s.%s' % (qn(join_table), qn(source_col))
}
)
return (qs,
attrgetter('_prefetch_related_val'),
attrgetter(instance._meta.pk.name),
False,
self.prefetch_cache_name)
# Django < 1.6 uses the previous name of query_set
get_query_set = get_queryset
get_prefetch_query_set = get_prefetch_queryset
def _lookup_kwargs(self):
return self.through.lookup_kwargs(self.instance)
@require_instance_manager
def add(self, *tags):
str_tags = set([
t
for t in tags
if not isinstance(t, self.through.tag_model())
])
tag_objs = set(tags) - str_tags
# If str_tags has 0 elements Django actually optimizes that to not do a
# query. Malcolm is very smart.
existing = self.through.tag_model().objects.filter(
name__in=str_tags
)
tag_objs.update(existing)
for new_tag in str_tags - set(t.name for t in existing):
tag_objs.add(self.through.tag_model().objects.create(name=new_tag))
for tag in tag_objs:
self.through.objects.get_or_create(tag=tag, **self._lookup_kwargs())
@require_instance_manager
def names(self):
return self.get_queryset().values_list('name', flat=True)
@require_instance_manager
def slugs(self):
return self.get_queryset().values_list('slug', flat=True)
@require_instance_manager
def set(self, *tags):
self.clear()
self.add(*tags)
@require_instance_manager
def remove(self, *tags):
self.through.objects.filter(**self._lookup_kwargs()).filter(
tag__name__in=tags).delete()
@require_instance_manager
def clear(self):
self.through.objects.filter(**self._lookup_kwargs()).delete()
def most_common(self):
return self.get_queryset().annotate(
num_times=models.Count(self.through.tag_relname())
).order_by('-num_times')
@require_instance_manager
def similar_objects(self):
lookup_kwargs = self._lookup_kwargs()
lookup_keys = sorted(lookup_kwargs)
qs = self.through.objects.values(*six.iterkeys(lookup_kwargs))
qs = qs.annotate(n=models.Count('pk'))
qs = qs.exclude(**lookup_kwargs)
qs = qs.filter(tag__in=self.all())
qs = qs.order_by('-n')
# TODO: This all feels like a bit of a hack.
items = {}
if len(lookup_keys) == 1:
# Can we do this without a second query by using a select_related()
# somehow?
f = self.through._meta.get_field_by_name(lookup_keys[0])[0]
objs = f.rel.to._default_manager.filter(**{
"%s__in" % f.rel.field_name: [r["content_object"] for r in qs]
})
for obj in objs:
items[(getattr(obj, f.rel.field_name),)] = obj
else:
preload = {}
for result in qs:
preload.setdefault(result['content_type'], set())
preload[result["content_type"]].add(result["object_id"])
for ct, obj_ids in preload.items():
ct = ContentType.objects.get_for_id(ct)
for obj in ct.model_class()._default_manager.filter(pk__in=obj_ids):
items[(ct.pk, obj.pk)] = obj
results = []
for result in qs:
obj = items[
tuple(result[k] for k in lookup_keys)
]
obj.similar_tags = result["n"]
results.append(obj)
return results
def _get_subclasses(model):
subclasses = [model]
for f in model._meta.get_all_field_names():
......
......@@ -184,3 +184,11 @@ class Article(models.Model):
title = models.CharField(max_length=100)
tags = TaggableManager(through=ArticleTaggedItem)
class CustomManager(models.Model):
class Foo(object):
def __init__(*args, **kwargs):
pass
tags = TaggableManager(manager=Foo)
......@@ -13,14 +13,14 @@ from django.utils.encoding import force_text
from django.contrib.contenttypes.models import ContentType
from taggit.managers import TaggableManager, _model_name
from taggit.managers import TaggableManager, _TaggableManager, _model_name
from taggit.models import Tag, TaggedItem
from .forms import (FoodForm, DirectFoodForm, CustomPKFoodForm,
OfficialFoodForm)
from .models import (Food, Pet, HousePet, DirectFood, DirectPet,
DirectHousePet, TaggedPet, CustomPKFood, CustomPKPet, CustomPKHousePet,
TaggedCustomPKPet, OfficialFood, OfficialPet, OfficialHousePet,
OfficialThroughModel, OfficialTag, Photo, Movie, Article)
OfficialThroughModel, OfficialTag, Photo, Movie, Article, CustomManager)
from taggit.utils import parse_tags, edit_string_for_tags
......@@ -355,7 +355,6 @@ class TaggableManagerTestCase(BaseTaggingTestCase):
'apple': set(['1', '2'])
})
class TaggableManagerDirectTestCase(TaggableManagerTestCase):
food_model = DirectFood
pet_model = DirectPet
......@@ -391,6 +390,16 @@ class TaggableManagerOfficialTestCase(TaggableManagerTestCase):
self.assertEqual(apple, self.food_model.objects.get(tags__official=False))
class TaggableManagerInitializationTestCase(TaggableManagerTestCase):
"""Make sure manager override defaults and sets correctly."""
food_model = Food
custom_manager_model = CustomManager
def test_default_manager(self):
self.assertEqual(self.food_model.tags.__class__, _TaggableManager)
def test_custom_manager(self):
self.assertEqual(self.custom_manager_model.tags.__class__, CustomManager.Foo)
class TaggableFormTestCase(BaseTaggingTestCase):
form_class = FoodForm
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment