diff --git a/back_latienda/routers.py b/back_latienda/routers.py index bad6b0c..83385ca 100644 --- a/back_latienda/routers.py +++ b/back_latienda/routers.py @@ -1,7 +1,7 @@ from rest_framework import routers from core.views import CustomUserViewSet -from companies.views import CompanyViewSet, MyCompanyViewSet, AdminCompanyViewSet +from companies.views import CompanyViewSet, AdminCompanyViewSet from products.views import ProductViewSet, MyProductsViewSet, AdminProductsViewSet from history.views import HistorySyncViewSet from stats.views import StatsLogViewSet @@ -13,7 +13,6 @@ router = routers.DefaultRouter() router.register('users', CustomUserViewSet, basename='users') router.register('companies', CompanyViewSet, basename='company') -router.register('my_company', MyCompanyViewSet, basename='my-company') router.register('admin_companies', AdminCompanyViewSet, basename='admin-companies') router.register('products', ProductViewSet, basename='product') router.register('my_products', MyProductsViewSet, basename='my-products') diff --git a/back_latienda/urls.py b/back_latienda/urls.py index 91431ad..c6743e3 100644 --- a/back_latienda/urls.py +++ b/back_latienda/urls.py @@ -39,6 +39,7 @@ urlpatterns = [ path('api/v1/search_products/', product_views.product_search, name='product-search'), path('api/v1/create_company_user/', core_views.create_company_user, name='create-company-user'), path('api/v1/my_user/', core_views.my_user, name='my-user'), + path('api/v1/my_company/', company_views.my_company, name='my-company'), path('api/v1/companies/sample/', company_views.random_company_sample , name='company-sample'), path('api/v1/purchase_email/', product_views.purchase_email, name='purchase-email'), path('api/v1/stats/me/', stat_views.track_user, name='user-tracker'), diff --git a/companies/tests.py b/companies/tests.py index 8473b02..1ad084f 100644 --- a/companies/tests.py +++ b/companies/tests.py @@ -312,12 +312,11 @@ class MyCompanyViewTest(APITestCase): self.user.set_password(self.password) self.user.save() - def tearDown(self): - self.model.objects.all().delete() - def test_auth_user_gets_data(self): # create instance - user_instances = [self.factory(creator=self.user) for i in range(5)] + company = CompanyFactory() + self.user.company = company + self.user.save() # Authenticate token = get_tokens_for_user(self.user) @@ -325,32 +324,10 @@ class MyCompanyViewTest(APITestCase): # Query endpoint response = self.client.get(self.endpoint) - payload = response.json() - # Assert forbidden code self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEquals(len(user_instances), len(payload)) - - def test_auth_user_can_paginate_instances(self): - """authenticated user can paginate instances - """ - - # Authenticate - token = get_tokens_for_user(self.user) - self.client.credentials(HTTP_AUTHORIZATION=f"Bearer {token['access']}") - - # create instances - instances = [self.factory(creator=self.user) for n in range(12)] - - # Request list - url = f"{self.endpoint}?limit=5&offset=10" - response = self.client.get(url) - - # Assert access is allowed - self.assertEqual(response.status_code, status.HTTP_200_OK) - # assert only 2 instances in response payload = response.json() - self.assertEquals(2, len(payload['results'])) + self.assertEquals(payload['company']['id'], company.id) def test_anon_user_cannot_access(self): # send in request diff --git a/companies/views.py b/companies/views.py index 5133d43..010c4b4 100644 --- a/companies/views.py +++ b/companies/views.py @@ -155,6 +155,16 @@ class CompanyViewSet(viewsets.ModelViewSet): return Response(message) +@api_view(['GET']) +@permission_classes([IsAuthenticated]) +def my_company(request): + if request.user.company: + serializer = CompanySerializer(request.user.company) + return Response({'company': serializer.data}) + else: + return Response(status=status.HTTP_406_NOT_ACCEPTABLE) + +''' class MyCompanyViewSet(viewsets.ModelViewSet): model = Company serializer_class = CompanySerializer @@ -165,7 +175,7 @@ class MyCompanyViewSet(viewsets.ModelViewSet): def perform_create(self, serializer): serializer.save(creator=self.request.user) - +''' class AdminCompanyViewSet(viewsets.ModelViewSet): """ Allows user with role 'SITE_ADMIN' to access all company instances diff --git a/core/serializers.py b/core/serializers.py index e71c71f..6a2039e 100644 --- a/core/serializers.py +++ b/core/serializers.py @@ -23,7 +23,7 @@ class CustomUserWriteSerializer(serializers.ModelSerializer): class Meta: model = models.CustomUser - fields = ('email', 'full_name', 'role', 'password', 'provider') + fields = ('email', 'full_name', 'role', 'password', 'provider', 'notify') class CreatorSerializer(serializers.ModelSerializer): diff --git a/products/tests.py b/products/tests.py index f503677..194fee3 100644 --- a/products/tests.py +++ b/products/tests.py @@ -952,9 +952,13 @@ class MyProductsViewTest(APITestCase): def test_auth_user_gets_data(self): # create instance + company = CompanyFactory() + self.user.company = company + self.user.save() + user_instances = [ - self.factory(creator=self.user), - self.factory(creator=self.user), + self.factory(company=company), + self.factory(company=company), ] # Authenticate @@ -976,7 +980,11 @@ class MyProductsViewTest(APITestCase): self.client.credentials(HTTP_AUTHORIZATION=f"Bearer {token['access']}") # create instances - instances = [self.factory(creator=self.user) for n in range(12)] + company = CompanyFactory() + self.user.company = company + self.user.save() + + instances = [self.factory(company=company) for n in range(12)] # Request list url = f"{self.endpoint}?limit=5&offset=10" diff --git a/products/views.py b/products/views.py index 228f6e1..4e132ae 100644 --- a/products/views.py +++ b/products/views.py @@ -72,10 +72,7 @@ class MyProductsViewSet(viewsets.ModelViewSet): permission_classes = [IsAuthenticated] def get_queryset(self): - return self.model.objects.filter(creator=self.request.user).order_by('-created') - - def perform_create(self, serializer): - serializer.save(creator=self.request.user) + return self.model.objects.filter(company=self.request.user.company).order_by('-created') class AdminProductsViewSet(viewsets.ModelViewSet):