fixes to the my_ views, and their tests

This commit is contained in:
Sam
2021-02-12 12:03:16 +00:00
parent 0a61cea599
commit c4faa89a99
7 changed files with 32 additions and 7 deletions

View File

@@ -2,7 +2,7 @@ from rest_framework import serializers
from companies.models import Company from companies.models import Company
from drf_extra_fields.geo_fields import PointField from drf_extra_fields.geo_fields import PointField
from tagulous.serializers.json import Serializer as TagSerializer
from utils.tag_serializers import TagListSerializerField, TaggitSerializer from utils.tag_serializers import TagListSerializerField, TaggitSerializer
class CompanySerializer(TaggitSerializer, serializers.ModelSerializer): class CompanySerializer(TaggitSerializer, serializers.ModelSerializer):

View File

@@ -259,14 +259,23 @@ class MyCompanyViewTest(APITestCase):
self.user.save() self.user.save()
def test_auth_user_gets_data(self): def test_auth_user_gets_data(self):
# create instance
user_instances = [
self.factory(creator=self.user),
self.factory(creator=self.user),
]
# Authenticate # Authenticate
token = get_tokens_for_user(self.user) token = get_tokens_for_user(self.user)
self.client.credentials(HTTP_AUTHORIZATION=f"Bearer {token['access']}") self.client.credentials(HTTP_AUTHORIZATION=f"Bearer {token['access']}")
# Query endpoint # Query endpoint
response = self.client.get(self.endpoint) response = self.client.get(self.endpoint)
payload = response.json()
# Assert forbidden code # Assert forbidden code
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEquals(len(user_instances), len(payload))
def test_anon_user_cannot_access(self): def test_anon_user_cannot_access(self):
# send in request # send in request

View File

@@ -125,5 +125,5 @@ class CompanyViewSet(viewsets.ModelViewSet):
@permission_classes([IsAuthenticated,]) @permission_classes([IsAuthenticated,])
def my_company(request): def my_company(request):
qs = Company.objects.filter(creator=request.user) qs = Company.objects.filter(creator=request.user)
company_serializer = CompanySerializer(qs) company_serializer = CompanySerializer(qs, many=True)
return Response(data=company_serializer.data) return Response(data=company_serializer.data)

View File

@@ -521,8 +521,11 @@ class MyUserViewTest(APITestCase):
# Query endpoint # Query endpoint
response = self.client.get(self.endpoint) response = self.client.get(self.endpoint)
payload = response.json()
# Assert forbidden code # Assert forbidden code
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEquals(self.email, payload['email'])
def test_anon_user_cannot_access(self): def test_anon_user_cannot_access(self):
# send in request # send in request

View File

@@ -128,7 +128,7 @@ def create_company_user(request):
} }
try: try:
user = models.CustomUser.objects.create(email=user_data['email']) user = models.CustomUser.objects.create(email=user_data['email'])
except IntegrityError as e: except IntegrityError as e:
return Response({"errors": {"details": str(e)}}, status=status.HTTP_409_CONFLICT) return Response({"errors": {"details": str(e)}}, status=status.HTTP_409_CONFLICT)
try: try:
@@ -159,9 +159,13 @@ def create_company_user(request):
@api_view(['GET',]) @api_view(['GET',])
@permission_classes([IsAuthenticated,]) @permission_classes([IsAuthenticated,])
def my_user(request): def my_user(request):
qs = User.objects.filter(email=request.user.email) try:
user_serializer = core_serializers.CustomUserReadSerializer(qs, many=True) instance = User.objects.get(email=request.user.email)
return Response(data=user_serializer.data) user_serializer = core_serializers.CustomUserReadSerializer(instance)
return Response(data=user_serializer.data)
except Exception as e:
return Response({'error': {str(type(e))}}, status=500)
@api_view(['POST',]) @api_view(['POST',])

View File

@@ -500,14 +500,23 @@ class MyProductsViewTest(APITestCase):
self.user.save() self.user.save()
def test_auth_user_gets_data(self): def test_auth_user_gets_data(self):
# create instance
user_instances = [
self.factory(creator=self.user),
self.factory(creator=self.user),
]
# Authenticate # Authenticate
token = get_tokens_for_user(self.user) token = get_tokens_for_user(self.user)
self.client.credentials(HTTP_AUTHORIZATION=f"Bearer {token['access']}") self.client.credentials(HTTP_AUTHORIZATION=f"Bearer {token['access']}")
# Query endpoint # Query endpoint
response = self.client.get(self.endpoint) response = self.client.get(self.endpoint)
payload = response.json()
# Assert forbidden code # Assert forbidden code
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEquals(len(user_instances), len(payload))
def test_anon_user_cannot_access(self): def test_anon_user_cannot_access(self):
# send in request # send in request

View File

@@ -56,7 +56,7 @@ class ProductViewSet(viewsets.ModelViewSet):
@permission_classes([IsAuthenticated,]) @permission_classes([IsAuthenticated,])
def my_products(request): def my_products(request):
qs = Product.objects.filter(creator=request.user) qs = Product.objects.filter(creator=request.user)
product_serializer = ProductSerializer(qs) product_serializer = ProductSerializer(qs, many=True)
return Response(data=product_serializer.data) return Response(data=product_serializer.data)