diff --git a/products/serializers.py b/products/serializers.py index d39b689..42e3d8a 100644 --- a/products/serializers.py +++ b/products/serializers.py @@ -3,14 +3,14 @@ from rest_framework import serializers from taggit_serializer.serializers import TagListSerializerField, TaggitSerializer from products.models import Product -from utils.tag_serializers import SingleTagSerializerField +from utils.tag_serializers import SingleTagSerializerField, CustomTagSerializer -class ProductSerializer(TaggitSerializer, serializers.ModelSerializer): +class ProductSerializer(CustomTagSerializer): - tags = TagListSerializerField( ) - category = SingleTagSerializerField() # main tag category - attributes = TagListSerializerField() + tags = TagListSerializerField(required=False) + category = SingleTagSerializerField(required=False) # main tag category + attributes = TagListSerializerField(required=False) class Meta: model = Product diff --git a/products/tests.py b/products/tests.py index c8f964e..64f2474 100644 --- a/products/tests.py +++ b/products/tests.py @@ -116,9 +116,9 @@ class ProductViewSetTest(APITestCase): 'update_date': datetime.datetime.now().isoformat()+'Z', 'discount': '0.05', 'stock': 22, - # tags = models.ManyToMany(Tag, null=True, blank=True ) - # category = models.ForeignKey(Tag, null=true) # main tag category - # attributes = models.ManyToMany(Tag, null=True, blank=True ) + 'tags': ['tag1, tag2'], + # 'category': 'MayorTagCategory', + 'attributes': ['color/red', 'size/xxl'], 'identifiers': '34rf34f43c43', } @@ -160,9 +160,9 @@ class ProductViewSetTest(APITestCase): 'update_date': datetime.datetime.now().isoformat()+'Z', 'discount': '0.05', 'stock': 22, - # tags = models.ManyToMany(Tag, null=True, blank=True ) - # category = models.ForeignKey(Tag, null=true) # main tag category - # attributes = models.ManyToMany(Tag, null=True, blank=True ) + 'tags': ['tag1x, tag2x'], + # 'category': 'MayorTagCategory2', + 'attributes': ['color/blue', 'size/m'], 'identifiers': '34rf34f43c43', } diff --git a/utils/tag_serializers.py b/utils/tag_serializers.py index b5893a6..cbf54b9 100644 --- a/utils/tag_serializers.py +++ b/utils/tag_serializers.py @@ -1,5 +1,13 @@ -from rest_framework import serializers +import traceback +from rest_framework import serializers +from rest_framework.fields import CharField, ListField +from rest_framework.serializers import raise_errors_on_nested_writes +from rest_framework.utils import model_meta + +from tagulous.models.managers import TagRelatedManagerMixin + +from taggit_serializer.serializers import TagListSerializerField, TaggitSerializer class SingleTag(str): @@ -37,3 +45,114 @@ class SingleTagSerializerField(serializers.Field): value = value.name value = SingleTag(value) return value + + +class CustomTagSerializer(serializers.ModelSerializer): + """ + Differentiate between tags and single-tags + """ + + def __init__(self, instance=None, data='', **kwargs): + self.serializer_field_mapping[SingleTagSerializerField] = CharField + self.serializer_field_mapping[TagListSerializerField] = ListField + super(CustomTagSerializer, self).__init__(instance, data, **kwargs) + + def create(self, validated_data): + to_be_tagged, validated_data = self._pop_tags(validated_data) + + # tag_object = super(CustomTagSerializer, self).create(validated_data) + raise_errors_on_nested_writes('create', self, validated_data) + + ModelClass = self.Meta.model + + # Remove many-to-many relationships from validated_data. + # They are not valid arguments to the default `.create()` method, + # as they require that the instance has already been saved. + info = model_meta.get_field_info(ModelClass) + many_to_many = {} + for field_name, relation_info in info.relations.items(): + if relation_info.to_many and (field_name in validated_data): + many_to_many[field_name] = validated_data.pop(field_name) + + try: + instance = ModelClass._default_manager.create(**validated_data) + except TypeError: + tb = traceback.format_exc() + msg = ( + 'Got a `TypeError` when calling `%s.%s.create()`. ' + 'This may be because you have a writable field on the ' + 'serializer class that is not a valid argument to ' + '`%s.%s.create()`. You may need to make the field ' + 'read-only, or override the %s.create() method to handle ' + 'this correctly.\nOriginal exception was:\n %s' % + ( + ModelClass.__name__, + ModelClass._default_manager.name, + ModelClass.__name__, + ModelClass._default_manager.name, + self.__class__.__name__, + tb + ) + ) + raise TypeError(msg) + + # Save many-to-many relationships after the instance is created. + if many_to_many: + for field_name, value in many_to_many.items(): + field = getattr(instance, field_name) + if field_name in ('tags', 'category', 'attributes'): + for item in value: + field.set(item) + else: + field.set(value) + + return self._save_tags(instance, to_be_tagged) + + def update(self, instance, validated_data): + to_be_tagged, validated_data = self._pop_tags(validated_data) + + raise_errors_on_nested_writes('update', self, validated_data) + info = model_meta.get_field_info(instance) + + # Simply set each attribute on the instance, and then save it. + # Note that unlike `.create()` we don't need to treat many-to-many + # relationships as being a special case. During updates we already + # have an instance pk for the relationships to be associated with. + m2m_fields = [] + for attr, value in validated_data.items(): + if attr in info.relations and info.relations[attr].to_many: + m2m_fields.append((attr, value)) + else: + setattr(instance, attr, value) + + instance.save() + + # Note that many-to-many fields are set after updating instance. + # Setting m2m fields triggers signals which could potentially change + # updated instance and we do not want it to collide with .update() + for attr, value in m2m_fields: + field = getattr(instance, attr) + if attr in ('tags', 'category', 'attributes'): + for item in value: + field.set(item) + else: + field.set(value) + + return self._save_tags(instance, to_be_tagged) + + def _save_tags(self, tag_object, tags): + for key in tags.keys(): + tag_values = tags.get(key) + getattr(tag_object, key).set(*tag_values) + + return tag_object + + def _pop_tags(self, validated_data): + to_be_tagged = {} + for key in self.serializer_field_mapping.keys(): + field = self.serializer_field_mapping[key] + if isinstance(field, TagListSerializerField): + if key in validated_data: + to_be_tagged[key] = validated_data.pop(key) + + return (to_be_tagged, validated_data)