diff --git a/products/tests.py b/products/tests.py index 07339f2..cbdff57 100644 --- a/products/tests.py +++ b/products/tests.py @@ -98,8 +98,15 @@ class ProductViewSetTest(APITestCase): def test_anon_user_can_filter_tags(self): # create instances - self.factory(name='product1', tags="zapatos, verdes") - self.factory(name='product2', tags="rojos") + expected_instance = [ + self.factory(name='product1', tags="zapatos, rojos"), + self.factory(name='product2', tags="rojos") + ] + unexpected_instance = [ + self.factory(name='sadfdsa', tags="zapatos, azules"), + self.factory(name='qwerw', tags="xxl") + ] + url = f"{self.endpoint}?tags=rojos" @@ -107,12 +114,10 @@ class ProductViewSetTest(APITestCase): response = self.client.get(url) payload = response.json() - # import ipdb; ipdb.set_trace() - # Assert access is granted self.assertEqual(response.status_code, status.HTTP_200_OK) # Assert number of instnaces in response - self.assertEquals(1, len(payload)) + self.assertEquals(len(expected_instance), len(payload)) # authenticated user diff --git a/products/views.py b/products/views.py index 2b82e88..6521908 100644 --- a/products/views.py +++ b/products/views.py @@ -14,7 +14,6 @@ from rest_framework import viewsets from rest_framework.response import Response from rest_framework.permissions import IsAuthenticatedOrReadOnly, IsAdminUser, IsAuthenticated from rest_framework.decorators import api_view, permission_classes, action -import django_filters.rest_framework import requests @@ -41,7 +40,7 @@ class ProductViewSet(viewsets.ModelViewSet): queryset = Product.objects.all() serializer_class = ProductSerializer permission_classes = [IsAuthenticatedOrReadOnly, IsCreator] - filter_backends = ProductTagFilter + filterset_class = ProductTagFilter filterset_fields = ['name', 'tags'] def perform_create(self, serializer): diff --git a/utils/tag_filters.py b/utils/tag_filters.py index f1f7abd..997dbf9 100644 --- a/utils/tag_filters.py +++ b/utils/tag_filters.py @@ -3,9 +3,15 @@ from products.models import Product class ProductTagFilter(django_filters.FilterSet): - tags = django_filters.CharFilter(field_name='tags.name', lookup_expr='iexact') + tags = django_filters.CharFilter(method='tag_filter') class Meta: model = Product - fields = ['name',] + fields = ['name', 'tags'] + + def tag_filter(self, queryset, name, value): + if name == 'tags': + return queryset.filter(tags=value) + return [] +