Commit 7ee9fd3a by Carl Meyer

fix last failing test; all previous tests pass

parent a16c126b
...@@ -27,6 +27,7 @@ class TaggableRel(ManyToManyRel): ...@@ -27,6 +27,7 @@ class TaggableRel(ManyToManyRel):
class TaggableManager(object): class TaggableManager(object):
def __init__(self, verbose_name="Tags", through=None): def __init__(self, verbose_name="Tags", through=None):
self.use_gfk = through is None
self.through = through or TaggedItem self.through = through or TaggedItem
self.rel = TaggableRel(to=self.through) self.rel = TaggableRel(to=self.through)
self.verbose_name = verbose_name self.verbose_name = verbose_name
...@@ -100,16 +101,16 @@ class TaggableManager(object): ...@@ -100,16 +101,16 @@ class TaggableManager(object):
return None return None
def m2m_reverse_name(self): def m2m_reverse_name(self):
try: if self.use_gfk:
return self.through._meta.get_field('content_object').rel.to._meta.pk.column
except FieldDoesNotExist:
return "id" return "id"
else:
return self.through._meta.get_field('content_object').rel.to._meta.pk.column
def m2m_column_name(self): def m2m_column_name(self):
try: if self.use_gfk:
return self.through._meta.get_field('content_object').column
except FieldDoesNotExist :
return self.through._meta.virtual_fields[0].fk_field return self.through._meta.virtual_fields[0].fk_field
else:
return self.through._meta.get_field('content_object').column
def db_type(self, connection=None): def db_type(self, connection=None):
return None return None
...@@ -117,6 +118,16 @@ class TaggableManager(object): ...@@ -117,6 +118,16 @@ class TaggableManager(object):
def m2m_db_table(self): def m2m_db_table(self):
return self.through._meta.db_table return self.through._meta.db_table
def extra_filters(self, pieces, pos, negate):
if negate or not self.use_gfk:
return []
prefix = "__".join(pieces[:pos+1])
cts = map(ContentType.objects.get_for_model, _get_subclasses(self.model))
if len(cts) == 1:
return [("%s__content_type" % prefix, cts[0])]
return [("%s__content_type__in" % prefix, cts)]
return self.through._meta.db_table
class _TaggableManager(models.Manager): class _TaggableManager(models.Manager):
def __init__(self, through=None): def __init__(self, through=None):
...@@ -183,3 +194,12 @@ class _TaggableManager(models.Manager): ...@@ -183,3 +194,12 @@ class _TaggableManager(models.Manager):
obj.similar_tags = result["n"] obj.similar_tags = result["n"]
results.append(obj) results.append(obj)
return results return results
def _get_subclasses(model):
subclasses = [model]
for f in model._meta.get_all_field_names():
field = model._meta.get_field_by_name(f)[0]
if isinstance(field, RelatedObject) and getattr(field.field.rel, "parent_link", None):
subclasses.extend(_get_subclasses(field.model))
return subclasses
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment