# -*- coding: utf-8 -*-
"""Django REST framework fields and a serializer mixin for taggit tags."""
import json

# Third party
import six
from django.utils.translation import ugettext_lazy as _
from rest_framework import serializers


class TagList(list):
    """A list of tag names whose ``str()`` is its JSON encoding.

    Accepts the same arguments as ``list`` plus an optional ``pretty_print``
    keyword (default ``True``). It controls whether ``str()`` emits indented,
    key-sorted JSON or compact JSON.
    """

    def __init__(self, *args, **kwargs):
        pretty_print = kwargs.pop("pretty_print", True)
        list.__init__(self, *args, **kwargs)
        self.pretty_print = pretty_print

    def __add__(self, rhs):
        """Concatenate like ``list``, returning a ``TagList``.

        The result keeps the left operand's ``pretty_print`` setting.
        """
        result = list.__add__(self, rhs)
        if result is NotImplemented:
            # Non-list operand: let Python try ``rhs.__radd__`` or raise its
            # usual TypeError, rather than failing inside ``TagList(...)``.
            return result
        return TagList(result, pretty_print=self.pretty_print)

    def __getitem__(self, item):
        """Return a ``TagList`` for slices and the bare element otherwise."""
        result = list.__getitem__(self, item)
        if isinstance(item, slice):
            return TagList(result, pretty_print=self.pretty_print)
        # A single element (normally a tag-name string) is returned as-is;
        # wrapping a string in TagList would split it into characters.
        return result

    def __str__(self):
        if self.pretty_print:
            return json.dumps(
                self, sort_keys=True, indent=4, separators=(',', ': '))
        else:
            return json.dumps(self)


class SingleTag(str):
    """A tag name whose ``str()`` is its JSON encoding (a quoted string)."""

    def __init__(self, *args, **kwargs):
        # The value is built entirely by ``str.__new__``; nothing to store.
        pass

    def __str__(self):
        return json.dumps(self)


class SingleTagSerializerField(serializers.Field):
    """Serializer field for a single tag, exchanged as its name string."""

    child = serializers.CharField()
    default_error_messages = {
        'invalid_json': _('Invalid json str. A tag list submitted in string'
                          ' form must be valid json.'),
        'not_a_str': _('Expected a string but got type "{input_type}".')
    }
    order_by = None

    def __init__(self, **kwargs):
        super(SingleTagSerializerField, self).__init__(**kwargs)

    def to_internal_value(self, value):
        """Return *value* unchanged if it is a string.

        Raises:
            ValidationError: via ``self.fail('not_a_str')`` when *value* is
                not a string.
        """
        # Accept every text type (``six.string_types``), consistent with
        # TagListSerializerField; a plain ``str`` check rejected unicode
        # input on Python 2.
        if not isinstance(value, six.string_types):
            self.fail('not_a_str', input_type=type(value).__name__)
        return value

    def to_representation(self, value):
        """Return *value* as a ``SingleTag``.

        Non-string inputs are converted via their ``.name`` attribute.
        """
        if not isinstance(value, SingleTag):
            if not isinstance(value, str):
                value = value.name
            value = SingleTag(value)
        return value


class TagListSerializerField(serializers.Field):
    """Serializer field for a list of tag names.

    Input may be a list of strings or a JSON-encoded string of such a list.
    Output is a ``TagList``. The ``pretty_print`` keyword (default ``True``)
    is passed to the produced ``TagList``. Any ``style`` keyword is merged
    over a default of ``{'base_template': 'textarea.html'}``.
    """

    child = serializers.CharField()
    default_error_messages = {
        'not_a_list': _(
            'Expected a list of items but got type "{input_type}".'),
        'invalid_json': _('Invalid json list. A tag list submitted in string'
                          ' form must be valid json.'),
        'not_a_str': _('All list items must be of string type.')
    }
    # Optional sequence of field names used to order tags in
    # to_representation; None leaves the queryset order untouched.
    order_by = None

    def __init__(self, **kwargs):
        pretty_print = kwargs.pop("pretty_print", True)

        style = kwargs.pop("style", {})
        kwargs["style"] = {'base_template': 'textarea.html'}
        kwargs["style"].update(style)

        super(TagListSerializerField, self).__init__(**kwargs)

        self.pretty_print = pretty_print

    def to_internal_value(self, value):
        """Validate and return a list of tag-name strings.

        A string is parsed as JSON, and an empty string is treated as ``[]``.
        Each item is passed through ``self.child.run_validation``; that
        result is discarded and the original list is returned.

        Raises:
            ValidationError: ``invalid_json`` for unparsable strings,
                ``not_a_list`` for non-list values, and ``not_a_str`` for
                non-string items.
        """
        if isinstance(value, six.string_types):
            if not value:
                value = "[]"
            try:
                value = json.loads(value)
            except ValueError:
                self.fail('invalid_json')

        if not isinstance(value, list):
            self.fail('not_a_list', input_type=type(value).__name__)

        for s in value:
            if not isinstance(s, six.string_types):
                self.fail('not_a_str')

            self.child.run_validation(s)

        return value

    def to_representation(self, value):
        """Return the tags in *value* as a ``TagList``.

        *value* may already be a ``TagList``, a plain list, or an object with
        ``.all()`` returning tag objects with a ``.name``. That last case is
        presumably a taggit manager; ``order_by`` is applied if set.
        """
        if not isinstance(value, TagList):
            if not isinstance(value, list):
                if self.order_by:
                    tags = value.all().order_by(*self.order_by)
                else:
                    tags = value.all()
                value = [tag.name for tag in tags]
            value = TagList(value, pretty_print=self.pretty_print)

        return value


class TaggitSerializer(serializers.Serializer):
    """Serializer mixin that saves ``TagListSerializerField`` data.

    Tag fields are removed from ``validated_data`` before the parent
    ``create``/``update`` runs. Afterwards they are applied to the saved
    object with ``getattr(obj, field_name).set(...)``.
    """

    def create(self, validated_data):
        to_be_tagged, validated_data = self._pop_tags(validated_data)

        tag_object = super(TaggitSerializer, self).create(validated_data)

        return self._save_tags(tag_object, to_be_tagged)

    def update(self, instance, validated_data):
        to_be_tagged, validated_data = self._pop_tags(validated_data)

        tag_object = super(TaggitSerializer, self).update(
            instance, validated_data)

        return self._save_tags(tag_object, to_be_tagged)

    def _save_tags(self, tag_object, tags):
        """Apply each popped tag list to *tag_object* and return it."""
        for key, tag_values in tags.items():
            # NOTE(review): unpacks the list into ``set(*tags)``; newer
            # django-taggit releases expect ``set(tags)`` — confirm against
            # the pinned taggit version.
            getattr(tag_object, key).set(*tag_values)

        return tag_object

    def _pop_tags(self, validated_data):
        """Split tag-field values out of *validated_data*.

        Returns a ``(to_be_tagged, validated_data)`` tuple. The input dict
        is mutated in place and returned as the second item.
        """
        to_be_tagged = {}

        for key, field in self.fields.items():
            if isinstance(field, TagListSerializerField) and key in validated_data:
                to_be_tagged[key] = validated_data.pop(key)

        return (to_be_tagged, validated_data)