diff --git a/companies/serializers.py b/companies/serializers.py index 9ff7fa8..b0ca1ce 100644 --- a/companies/serializers.py +++ b/companies/serializers.py @@ -1,7 +1,8 @@ from rest_framework import serializers -from companies.models import Company -from utils.tag_serializers import TagListSerializerField, TaggitSerializer +from taggit_serializer.serializers import TagListSerializerField, TaggitSerializer + +from companies.models import Company class CompanySerializer(TaggitSerializer, serializers.ModelSerializer): diff --git a/products/serializers.py b/products/serializers.py index e7d2fcf..ddbc05b 100644 --- a/products/serializers.py +++ b/products/serializers.py @@ -1,8 +1,10 @@ from rest_framework import serializers -from products.models import Product +from taggit_serializer.serializers import TagListSerializerField, TaggitSerializer -from utils.tag_serializers import TagListSerializerField, TaggitSerializer, SingleTagSerializerField +from utils.tag_serializers import SingleTagSerializerField + +from products.models import Product class ProductSerializer(TaggitSerializer, serializers.ModelSerializer): diff --git a/utils/tag_serializers.py b/utils/tag_serializers.py index 9f74c72..70b0277 100644 --- a/utils/tag_serializers.py +++ b/utils/tag_serializers.py @@ -7,30 +7,6 @@ from django.utils.translation import ugettext_lazy as _ from rest_framework import serializers -class TagList(list): - def __init__(self, *args, **kwargs): - pretty_print = kwargs.pop("pretty_print", True) - list.__init__(self, *args, **kwargs) - self.pretty_print = pretty_print - - def __add__(self, rhs): - return TagList(list.__add__(self, rhs)) - - def __getitem__(self, item): - result = list.__getitem__(self, item) - try: - return TagList(result) - except TypeError: - return result - - def __str__(self): - if self.pretty_print: - return json.dumps( - self, sort_keys=True, indent=4, separators=(',', ': ')) - else: - return json.dumps(self) - - class SingleTag(str): def __init__(self, *args, **kwargs): @@ -69,92 +45,3 @@ class SingleTagSerializerField(serializers.Field): value = SingleTag(value) return value - -class TagListSerializerField(serializers.Field): - child = serializers.CharField() - default_error_messages = { - 'not_a_list': _( - 'Expected a list of items but got type "{input_type}".'), - 'invalid_json': _('Invalid json list. A tag list submitted in string' - ' form must be valid json.'), - 'not_a_str': _('All list items must be of string type.') - } - order_by = None - - def __init__(self, **kwargs): - pretty_print = kwargs.pop("pretty_print", True) - - style = kwargs.pop("style", {}) - kwargs["style"] = {'base_template': 'textarea.html'} - kwargs["style"].update(style) - - super(TagListSerializerField, self).__init__(**kwargs) - - self.pretty_print = pretty_print - - def to_internal_value(self, value): - if isinstance(value, six.string_types): - if not value: - value = "[]" - try: - value = json.loads(value) - except ValueError: - self.fail('invalid_json') - - if not isinstance(value, list): - self.fail('not_a_list', input_type=type(value).__name__) - - for s in value: - if not isinstance(s, six.string_types): - self.fail('not_a_str') - - self.child.run_validation(s) - - return value - - def to_representation(self, value): - if not isinstance(value, TagList): - if not isinstance(value, list): - if self.order_by: - tags = value.all().order_by(*self.order_by) - else: - tags = value.all() - value = [tag.name for tag in tags] - value = TagList(value, pretty_print=self.pretty_print) - - return value - - -class TaggitSerializer(serializers.Serializer): - def create(self, validated_data): - to_be_tagged, validated_data = self._pop_tags(validated_data) - - tag_object = super(TaggitSerializer, self).create(validated_data) - - return self._save_tags(tag_object, to_be_tagged) - - def update(self, instance, validated_data): - to_be_tagged, validated_data = self._pop_tags(validated_data) - - tag_object = super(TaggitSerializer, self).update( - instance, validated_data) - - return self._save_tags(tag_object, to_be_tagged) - - def _save_tags(self, tag_object, tags): - for key in tags.keys(): - tag_values = tags.get(key) - getattr(tag_object, key).set(*tag_values) - - return tag_object - - def _pop_tags(self, validated_data): - to_be_tagged = {} - - for key in self.fields.keys(): - field = self.fields[key] - if isinstance(field, TagListSerializerField): - if key in validated_data: - to_be_tagged[key] = validated_data.pop(key) - - return (to_be_tagged, validated_data)