import traceback

from rest_framework import serializers
from rest_framework.fields import CharField, ListField, empty
from rest_framework.serializers import raise_errors_on_nested_writes
from rest_framework.utils import model_meta
from tagulous.models.managers import TagRelatedManagerMixin
from taggit_serializer.serializers import TagListSerializerField, TaggitSerializer


class SingleTag(str):
    """``str`` subclass used to mark a value as a single tag rather than a list."""

    def __init__(self, *args, **kwargs):
        # No extra state: the string value itself is set by ``str.__new__``.
        pass

    def __str__(self):
        return self


class SingleTagSerializerField(serializers.Field):
    """Serializer field for a single tag exchanged as a plain string.

    Input must be a ``str``.  On output, the stored value (a string, or an
    object exposing ``.name``) is returned as a :class:`SingleTag`.
    """

    child = serializers.CharField()
    default_error_messages = {
        'not_a_str': 'Expected a string but got type "{input_type}".',
    }
    order_by = None

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def to_internal_value(self, value):
        """Return *value* unchanged if it is a ``str``; otherwise fail with ``not_a_str``."""
        if not isinstance(value, str):
            self.fail('not_a_str', input_type=type(value).__name__)
        return value

    def to_representation(self, value):
        """Return *value* as a :class:`SingleTag`, reading ``.name`` from non-strings."""
        if isinstance(value, SingleTag):
            return value
        if not isinstance(value, str):
            value = value.name
        return SingleTag(value)


class CustomTagSerializer(serializers.ModelSerializer):
    """ModelSerializer that differentiates between tag lists and single tags.

    ``SingleTagSerializerField`` and ``TagListSerializerField`` are registered
    in this serializer's ``serializer_field_mapping`` (mapped to ``CharField``
    and ``ListField``).  ``create``/``update`` follow DRF's
    ``ModelSerializer`` implementations, except that many-to-many fields named
    in ``ITEMWISE_M2M_FIELDS`` are assigned one item at a time.
    """

    # Many-to-many attributes whose incoming value is assigned item by item
    # instead of with a single ``set(value)`` call.
    ITEMWISE_M2M_FIELDS = ('tags', 'category', 'attributes')

    def __init__(self, instance=None, data=empty, **kwargs):
        # ``empty`` is DRF's "no data supplied" sentinel.  Defaulting to ``''``
        # made every instance look bound to input data, so reading ``.data``
        # on a serializer built only from an instance failed unless
        # ``.is_valid()`` had been called.
        #
        # ``serializer_field_mapping`` is a class attribute shared by every
        # ModelSerializer; build an extended per-instance copy instead of
        # writing into the shared dict on each instantiation.
        self.serializer_field_mapping = {
            **self.serializer_field_mapping,
            SingleTagSerializerField: CharField,
            TagListSerializerField: ListField,
        }
        super().__init__(instance, data, **kwargs)

    def create(self, validated_data):
        """Create and return a model instance from *validated_data*.

        Mirrors ``ModelSerializer.create``: nested writes are rejected, to-many
        relations are removed before ``_default_manager.create()`` and assigned
        afterwards, then :meth:`_save_tags` is applied to the new instance.

        Raises:
            TypeError: if the manager's ``create()`` rejects the remaining
                keyword arguments; the message explains the likely cause and
                includes the original traceback.
        """
        to_be_tagged, validated_data = self._pop_tags(validated_data)
        raise_errors_on_nested_writes('create', self, validated_data)
        model_class = self.Meta.model

        # Remove many-to-many relationships from validated_data.
        # They are not valid arguments to the default `.create()` method,
        # as they require that the instance has already been saved.
        info = model_meta.get_field_info(model_class)
        many_to_many = {}
        for field_name, relation_info in info.relations.items():
            if relation_info.to_many and field_name in validated_data:
                many_to_many[field_name] = validated_data.pop(field_name)

        try:
            instance = model_class._default_manager.create(**validated_data)
        except TypeError:
            tb = traceback.format_exc()
            msg = (
                'Got a `TypeError` when calling `%s.%s.create()`. '
                'This may be because you have a writable field on the '
                'serializer class that is not a valid argument to '
                '`%s.%s.create()`. You may need to make the field '
                'read-only, or override the %s.create() method to handle '
                'this correctly.\nOriginal exception was:\n %s' %
                (
                    model_class.__name__,
                    model_class._default_manager.name,
                    model_class.__name__,
                    model_class._default_manager.name,
                    self.__class__.__name__,
                    tb
                )
            )
            raise TypeError(msg)

        # Save many-to-many relationships after the instance is created.
        self._assign_many_to_many(instance, many_to_many.items())
        return self._save_tags(instance, to_be_tagged)

    def update(self, instance, validated_data):
        """Apply *validated_data* to *instance*, save it, and return it.

        Plain attributes are set and saved first; to-many relations are
        assigned afterwards, then :meth:`_save_tags` is applied.
        """
        to_be_tagged, validated_data = self._pop_tags(validated_data)
        raise_errors_on_nested_writes('update', self, validated_data)
        info = model_meta.get_field_info(instance)

        # Simply set each attribute on the instance, and then save it.
        # Note that unlike `.create()` we don't need to treat many-to-many
        # relationships as being a special case. During updates we already
        # have an instance pk for the relationships to be associated with.
        m2m_fields = []
        for attr, value in validated_data.items():
            if attr in info.relations and info.relations[attr].to_many:
                m2m_fields.append((attr, value))
            else:
                setattr(instance, attr, value)
        instance.save()

        # Note that many-to-many fields are set after updating instance.
        # Setting m2m fields triggers signals which could potentially change
        # updated instance and we do not want it to collide with .update()
        self._assign_many_to_many(instance, m2m_fields)
        return self._save_tags(instance, to_be_tagged)

    def _assign_many_to_many(self, instance, pairs):
        """Assign each ``(field_name, value)`` pair to *instance*'s related manager."""
        for field_name, value in pairs:
            manager = getattr(instance, field_name)
            if field_name in self.ITEMWISE_M2M_FIELDS:
                # NOTE(review): ``set()`` is called once per element of
                # *value*.  If ``set`` replaces the relation (as Django's
                # related-manager ``set`` does), only the last element
                # survives -- confirm against the managers used here.
                for item in value:
                    manager.set(item)
            else:
                manager.set(value)

    def _save_tags(self, tag_object, tags):
        """Apply popped tag values to *tag_object* and return it.

        For each ``key -> values`` entry in *tags*, calls
        ``getattr(tag_object, key).set(*values)``.
        """
        for key, tag_values in tags.items():
            getattr(tag_object, key).set(*tag_values)
        return tag_object

    def _pop_tags(self, validated_data):
        """Split tag-list values out of *validated_data*.

        Returns ``(to_be_tagged, validated_data)``.  Popped keys are removed
        from *validated_data* in place.

        NOTE(review): this scans ``serializer_field_mapping``, whose values
        are field *classes*, so ``isinstance(field, TagListSerializerField)``
        is never true and nothing is ever popped; tag values stay in
        *validated_data*.  The method looks like a port of taggit_serializer's
        ``TaggitSerializer._pop_tags``, which (verify) iterates
        ``self.fields``.  Switching would route tags through ``_save_tags``
        (``set(*values)``), so confirm the tag manager's ``set`` signature
        before changing it.
        """
        to_be_tagged = {}
        for key, field in self.serializer_field_mapping.items():
            if isinstance(field, TagListSerializerField) and key in validated_data:
                to_be_tagged[key] = validated_data.pop(key)
        return (to_be_tagged, validated_data)