diff --git a/companies/serializers.py b/companies/serializers.py index 231aea9..7f6ed4c 100644 --- a/companies/serializers.py +++ b/companies/serializers.py @@ -2,7 +2,7 @@ from rest_framework import serializers from companies.models import Company from drf_extra_fields.geo_fields import PointField - +from tagulous.serializers.json import Serializer as TagSerializer from utils.tag_serializers import TagListSerializerField, TaggitSerializer class CompanySerializer(TaggitSerializer, serializers.ModelSerializer): diff --git a/companies/tests.py b/companies/tests.py index fe81bfb..51596ac 100644 --- a/companies/tests.py +++ b/companies/tests.py @@ -259,14 +259,23 @@ class MyCompanyViewTest(APITestCase): self.user.save() def test_auth_user_gets_data(self): + # create instance + user_instances = [ + self.factory(creator=self.user), + self.factory(creator=self.user), + ] + # Authenticate token = get_tokens_for_user(self.user) self.client.credentials(HTTP_AUTHORIZATION=f"Bearer {token['access']}") # Query endpoint response = self.client.get(self.endpoint) + payload = response.json() + # Assert forbidden code self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEquals(len(user_instances), len(payload)) def test_anon_user_cannot_access(self): # send in request diff --git a/companies/views.py b/companies/views.py index 845e8a8..2460b0e 100644 --- a/companies/views.py +++ b/companies/views.py @@ -125,5 +125,5 @@ class CompanyViewSet(viewsets.ModelViewSet): @permission_classes([IsAuthenticated,]) def my_company(request): qs = Company.objects.filter(creator=request.user) - company_serializer = CompanySerializer(qs) + company_serializer = CompanySerializer(qs, many=True) return Response(data=company_serializer.data) diff --git a/core/tests.py b/core/tests.py index a4273c3..10312d3 100644 --- a/core/tests.py +++ b/core/tests.py @@ -521,8 +521,11 @@ class MyUserViewTest(APITestCase): # Query endpoint response = self.client.get(self.endpoint) + payload = response.json() + # Assert forbidden code self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEquals(self.email, payload['email']) def test_anon_user_cannot_access(self): # send in request diff --git a/core/views.py b/core/views.py index de3955e..a882eab 100644 --- a/core/views.py +++ b/core/views.py @@ -128,7 +128,7 @@ def create_company_user(request): } try: user = models.CustomUser.objects.create(email=user_data['email']) - except IntegrityError as e: + except IntegrityError as e: return Response({"errors": {"details": str(e)}}, status=status.HTTP_409_CONFLICT) try: @@ -159,9 +159,13 @@ def create_company_user(request): @api_view(['GET',]) @permission_classes([IsAuthenticated,]) def my_user(request): - qs = User.objects.filter(email=request.user.email) - user_serializer = core_serializers.CustomUserReadSerializer(qs, many=True) - return Response(data=user_serializer.data) + try: + instance = User.objects.get(email=request.user.email) + user_serializer = core_serializers.CustomUserReadSerializer(instance) + return Response(data=user_serializer.data) + except Exception as e: + return Response({'error': {str(type(e))}}, status=500) + @api_view(['POST',]) diff --git a/products/tests.py b/products/tests.py index 9f9d471..46497f3 100644 --- a/products/tests.py +++ b/products/tests.py @@ -500,14 +500,23 @@ class MyProductsViewTest(APITestCase): self.user.save() def test_auth_user_gets_data(self): + # create instance + user_instances = [ + self.factory(creator=self.user), + self.factory(creator=self.user), + ] + # Authenticate token = get_tokens_for_user(self.user) self.client.credentials(HTTP_AUTHORIZATION=f"Bearer {token['access']}") # Query endpoint response = self.client.get(self.endpoint) + payload = response.json() + # Assert forbidden code self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEquals(len(user_instances), len(payload)) def test_anon_user_cannot_access(self): # send in request diff --git a/products/views.py b/products/views.py index 790f785..45d765a 100644 --- a/products/views.py +++ b/products/views.py @@ -56,7 +56,7 @@ class ProductViewSet(viewsets.ModelViewSet): @permission_classes([IsAuthenticated,]) def my_products(request): qs = Product.objects.filter(creator=request.user) - product_serializer = ProductSerializer(qs) + product_serializer = ProductSerializer(qs, many=True) return Response(data=product_serializer.data)