Commit e364191b by Alex Gaynor

Optimize TaggableManager.add, it should do fewer queries now.

parent 9229a50a
...@@ -126,9 +126,21 @@ class _TaggableManager(models.Manager): ...@@ -126,9 +126,21 @@ class _TaggableManager(models.Manager):
@require_instance_manager @require_instance_manager
def add(self, *tags): def add(self, *tags):
for tag in tags: str_tags = set([
if not isinstance(tag, self.through.tag_model()): t
tag, _ = self.through.tag_model().objects.get_or_create(name=tag) for t in tags
if not isinstance(t, self.through.tag_model())
])
tag_objs = set(tags) - str_tags
existing = self.through.tag_model().objects.filter(
name__in=str_tags
)
tag_objs.update(existing)
for new_tag in str_tags - set(t.name for t in existing):
tag_objs.add(self.through.tag_model().objects.create(name=new_tag))
for tag in tag_objs:
self.through.objects.get_or_create(tag=tag, **self._lookup_kwargs()) self.through.objects.get_or_create(tag=tag, **self._lookup_kwargs())
@require_instance_manager @require_instance_manager
......
from unittest import TestCase as UnitTestCase from unittest import TestCase as UnitTestCase
from django.conf import settings
from django.db import connection
from django.test import TestCase, TransactionTestCase from django.test import TestCase, TransactionTestCase
from taggit.models import Tag, TaggedItem from taggit.models import Tag, TaggedItem
...@@ -19,6 +21,19 @@ class BaseTaggingTest(object): ...@@ -19,6 +21,19 @@ class BaseTaggingTest(object):
got.sort() got.sort()
tags.sort() tags.sort()
self.assertEqual(got, tags) self.assertEqual(got, tags)
def assert_num_queries(self, n, f, *args, **kwargs):
original_DEBUG = settings.DEBUG
settings.DEBUG = True
current = len(connection.queries)
try:
f(*args, **kwargs)
self.assertEqual(
len(connection.queries) - current,
n,
)
finally:
settings.DEBUG = original_DEBUG
class BaseTaggingTestCase(TestCase, BaseTaggingTest): class BaseTaggingTestCase(TestCase, BaseTaggingTest):
pass pass
...@@ -96,6 +111,20 @@ class TaggableManagerTestCase(BaseTaggingTestCase): ...@@ -96,6 +111,20 @@ class TaggableManagerTestCase(BaseTaggingTestCase):
apple.delete() apple.delete()
self.assert_tags_equal(self.food_model.tags.all(), ["green"]) self.assert_tags_equal(self.food_model.tags.all(), ["green"])
def test_add_queries(self):
apple = self.food_model.objects.create(name="apple")
# 1 query to see which tags exist
# + 3 queries to create the tags.
# + 6 queries to create the intermediary things (including SELECTs, to
# make sure we don't double create.
self.assert_num_queries(10, apple.tags.add, "red", "delicious", "green")
pear = self.food_model.objects.create(name="pear")
# 1 query to see which tags exist
# + 4 queries to create the intermeidary things (including SELECTs, to
# make sure we dont't double create.
self.assert_num_queries(5, pear.tags.add, "green", "delicious")
def test_require_pk(self): def test_require_pk(self):
food_instance = self.food_model() food_instance = self.food_model()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment