diff --git a/products/tests.py b/products/tests.py index da69855..64f2474 100644 --- a/products/tests.py +++ b/products/tests.py @@ -128,7 +128,6 @@ class ProductViewSetTest(APITestCase): # Query endpoint response = self.client.post(self.endpoint, data=data, format='json') - import ipdb; ipdb.set_trace() # Assert endpoint returns created status self.assertEqual(response.status_code, status.HTTP_201_CREATED) @@ -162,7 +161,7 @@ class ProductViewSetTest(APITestCase): 'discount': '0.05', 'stock': 22, 'tags': ['tag1x, tag2x'], - 'category': 'MayorTagCategory2', + # 'category': 'MayorTagCategory2', 'attributes': ['color/blue', 'size/m'], 'identifiers': '34rf34f43c43', } diff --git a/utils/tag_serializers.py b/utils/tag_serializers.py index 6a9a72d..cbf54b9 100644 --- a/utils/tag_serializers.py +++ b/utils/tag_serializers.py @@ -5,7 +5,7 @@ from rest_framework.fields import CharField, ListField from rest_framework.serializers import raise_errors_on_nested_writes from rest_framework.utils import model_meta -from tagulous.models.descriptors import FakeTagRelatedManager +from tagulous.models.managers import TagRelatedManagerMixin from taggit_serializer.serializers import TagListSerializerField, TaggitSerializer @@ -100,8 +100,7 @@ class CustomTagSerializer(serializers.ModelSerializer): if many_to_many: for field_name, value in many_to_many.items(): field = getattr(instance, field_name) - import ipdb; ipdb.set_trace() - if type(field) == "": + if field_name in ('tags', 'category', 'attributes'): for item in value: field.set(item) else: @@ -112,10 +111,34 @@ class CustomTagSerializer(serializers.ModelSerializer): def update(self, instance, validated_data): to_be_tagged, validated_data = self._pop_tags(validated_data) - tag_object = super(CustomTagSerializer, self).update( - instance, validated_data) + raise_errors_on_nested_writes('update', self, validated_data) + info = model_meta.get_field_info(instance) - return self._save_tags(tag_object, to_be_tagged) + # Simply set each attribute on the instance, and then save it. + # Note that unlike `.create()` we don't need to treat many-to-many + # relationships as being a special case. During updates we already + # have an instance pk for the relationships to be associated with. + m2m_fields = [] + for attr, value in validated_data.items(): + if attr in info.relations and info.relations[attr].to_many: + m2m_fields.append((attr, value)) + else: + setattr(instance, attr, value) + + instance.save() + + # Note that many-to-many fields are set after updating instance. + # Setting m2m fields triggers signals which could potentially change + # updated instance and we do not want it to collide with .update() + for attr, value in m2m_fields: + field = getattr(instance, attr) + if attr in ('tags', 'category', 'attributes'): + for item in value: + field.set(item) + else: + field.set(value) + + return self._save_tags(instance, to_be_tagged) def _save_tags(self, tag_object, tags): for key in tags.keys():