diff --git a/companies/serializers.py b/companies/serializers.py index 231aea9..7f6ed4c 100644 --- a/companies/serializers.py +++ b/companies/serializers.py @@ -2,7 +2,7 @@ from rest_framework import serializers from companies.models import Company from drf_extra_fields.geo_fields import PointField - +from tagulous.serializers.json import Serializer as TagSerializer from utils.tag_serializers import TagListSerializerField, TaggitSerializer class CompanySerializer(TaggitSerializer, serializers.ModelSerializer): diff --git a/companies/tests.py b/companies/tests.py index fe81bfb..a247b4b 100644 --- a/companies/tests.py +++ b/companies/tests.py @@ -29,8 +29,8 @@ class CompanyViewSetTest(APITestCase): self.password = ''.join(random.choices(string.ascii_uppercase, k = 10)) self.user = CustomUserFactory(email="test@mail.com", password=self.password, is_active=True) - # user not authenticated - def test_not_logged_user_cannot_create_instance(self): + # anonymous user + def test_anon_user_cannot_create_instance(self): """Not logged-in user cannot create new instance """ instances = [self.factory() for n in range(random.randint(1,5))] @@ -40,7 +40,7 @@ class CompanyViewSetTest(APITestCase): # Assert access is forbidden self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - def test_not_logged_user_cannot_modify_existing_instance(self): + def test_anon_user_cannot_modify_existing_instance(self): """Not logged-in user cannot modify existing instance """ # Create instance @@ -53,7 +53,7 @@ class CompanyViewSetTest(APITestCase): # Assert forbidden code self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - def test_not_logged_user_cannot_delete_existing_instance(self): + def test_anon_user_cannot_delete_existing_instance(self): """Not logged-in user cannot delete existing instance """ # Create instances @@ -67,7 +67,7 @@ class CompanyViewSetTest(APITestCase): # Assert instance still exists on db self.assertTrue(self.model.objects.get(id=instance.pk)) - def test_not_logged_user_can_list_instance(self): + def test_anon_user_can_list_instance(self): """Not logged-in user can list instance """ # Request list @@ -76,6 +76,28 @@ class CompanyViewSetTest(APITestCase): # Assert access is forbidden self.assertEqual(response.status_code, status.HTTP_200_OK) + def test_anon_user_can_filter_tags(self): + # create instances + expected_instance = [ + self.factory(tags='ropa'), + self.factory(tags='tejidos, ropa') + ] + unexpected_instance = [ + self.factory(tags="zapatos, azules"), + self.factory(tags="xxl") + ] + # prepare url + url = f"{self.endpoint}?tags=ropa" + + # Request list + response = self.client.get(url) + payload = response.json() + + # Assert access is granted + self.assertEqual(response.status_code, status.HTTP_200_OK) + # Assert number of instnaces in response + self.assertEquals(len(expected_instance), len(payload)) + # authenticated user def test_logged_user_can_list_instance(self): """Regular logged-in user can list instance @@ -241,6 +263,7 @@ class CompanyViewSetTest(APITestCase): self.assertFalse(self.model.objects.filter(id=instance.pk).exists()) + class MyCompanyViewTest(APITestCase): """CompanyViewset tests """ @@ -259,14 +282,23 @@ class MyCompanyViewTest(APITestCase): self.user.save() def test_auth_user_gets_data(self): + # create instance + user_instances = [ + self.factory(creator=self.user), + self.factory(creator=self.user), + ] + # Authenticate token = get_tokens_for_user(self.user) self.client.credentials(HTTP_AUTHORIZATION=f"Bearer {token['access']}") # Query endpoint response = self.client.get(self.endpoint) + payload = response.json() + # Assert forbidden code self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEquals(len(user_instances), len(payload)) def test_anon_user_cannot_access(self): # send in request diff --git a/companies/views.py b/companies/views.py index a659776..a2114ae 100644 --- a/companies/views.py +++ b/companies/views.py @@ -16,7 +16,7 @@ from ipware import get_client_ip from stats.models import StatsLog from companies.models import Company from companies.serializers import CompanySerializer - +from utils.tag_filters import CompanyTagFilter from back_latienda.permissions import IsCreator @@ -24,6 +24,7 @@ class CompanyViewSet(viewsets.ModelViewSet): queryset = Company.objects.all() serializer_class = CompanySerializer permission_classes = [IsAuthenticatedOrReadOnly, IsCreator] + filterset_class = CompanyTagFilter def perform_create(self, serializer): serializer.save(creator=self.request.user) @@ -124,6 +125,6 @@ class CompanyViewSet(viewsets.ModelViewSet): @api_view(['GET',]) @permission_classes([IsAuthenticated,]) def my_company(request): - qs = Company.objects.filter(creator=request.user).first() - company_serializer = CompanySerializer(qs) + qs = Company.objects.filter(creator=request.user) + company_serializer = CompanySerializer(qs, many=True) return Response(data=company_serializer.data) diff --git a/core/tests.py b/core/tests.py index a4273c3..10312d3 100644 --- a/core/tests.py +++ b/core/tests.py @@ -521,8 +521,11 @@ class MyUserViewTest(APITestCase): # Query endpoint response = self.client.get(self.endpoint) + payload = response.json() + # Assert forbidden code self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEquals(self.email, payload['email']) def test_anon_user_cannot_access(self): # send in request diff --git a/core/views.py b/core/views.py index de3955e..a882eab 100644 --- a/core/views.py +++ b/core/views.py @@ -128,7 +128,7 @@ def create_company_user(request): } try: user = models.CustomUser.objects.create(email=user_data['email']) - except IntegrityError as e: + except IntegrityError as e: return Response({"errors": {"details": str(e)}}, status=status.HTTP_409_CONFLICT) try: @@ -159,9 +159,13 @@ def create_company_user(request): @api_view(['GET',]) @permission_classes([IsAuthenticated,]) def my_user(request): - qs = User.objects.filter(email=request.user.email) - user_serializer = core_serializers.CustomUserReadSerializer(qs, many=True) - return Response(data=user_serializer.data) + try: + instance = User.objects.get(email=request.user.email) + user_serializer = core_serializers.CustomUserReadSerializer(instance) + return Response(data=user_serializer.data) + except Exception as e: + return Response({'error': {str(type(e))}}, status=500) + @api_view(['POST',]) diff --git a/products/tests.py b/products/tests.py index 9f9d471..46497f3 100644 --- a/products/tests.py +++ b/products/tests.py @@ -500,14 +500,23 @@ class MyProductsViewTest(APITestCase): self.user.save() def test_auth_user_gets_data(self): + # create instance + user_instances = [ + self.factory(creator=self.user), + self.factory(creator=self.user), + ] + # Authenticate token = get_tokens_for_user(self.user) self.client.credentials(HTTP_AUTHORIZATION=f"Bearer {token['access']}") # Query endpoint response = self.client.get(self.endpoint) + payload = response.json() + # Assert forbidden code self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEquals(len(user_instances), len(payload)) def test_anon_user_cannot_access(self): # send in request diff --git a/products/views.py b/products/views.py index 790f785..45d765a 100644 --- a/products/views.py +++ b/products/views.py @@ -56,7 +56,7 @@ class ProductViewSet(viewsets.ModelViewSet): @permission_classes([IsAuthenticated,]) def my_products(request): qs = Product.objects.filter(creator=request.user) - product_serializer = ProductSerializer(qs) + product_serializer = ProductSerializer(qs, many=True) return Response(data=product_serializer.data) diff --git a/utils/tag_filters.py b/utils/tag_filters.py index 7394ac0..3cc8502 100644 --- a/utils/tag_filters.py +++ b/utils/tag_filters.py @@ -1,7 +1,23 @@ import django_filters + +from companies.models import Company from products.models import Product +class CompanyTagFilter(django_filters.FilterSet): + + tags = django_filters.CharFilter(method='tag_filter') + + class Meta: + model = Company + fields = ['tags', 'city',] + + def tag_filter(self, queryset, name, value): + return queryset.filter(**{ + name: value, + }) + + class ProductTagFilter(django_filters.FilterSet): tags = django_filters.CharFilter(method='tag_filter') @@ -16,3 +32,4 @@ class ProductTagFilter(django_filters.FilterSet): return queryset.filter(**{ name: value, }) +