diff --git a/back_latienda/permissions.py b/back_latienda/permissions.py index 992f43e..9bb10d0 100644 --- a/back_latienda/permissions.py +++ b/back_latienda/permissions.py @@ -35,6 +35,19 @@ class IsStaff(permissions.BasePermission): return request.user.is_staff +class IsSiteAdmin(permissions.BasePermission): + """ + Grant permission if request.user.role == 'SITE_ADMIN' + """ + + admin_role = 'SITE_ADMIN' + + def has_object_permission(self, request, view, obj): + return request.user.role == self.admin_role + + def has_permission(self, request, view): + return request.user.role == self.admin_role + class ReadOnly(permissions.BasePermission): def has_permission(self, request, view): diff --git a/back_latienda/routers.py b/back_latienda/routers.py index f0c9c26..b734217 100644 --- a/back_latienda/routers.py +++ b/back_latienda/routers.py @@ -1,8 +1,8 @@ from rest_framework import routers from core.views import CustomUserViewSet -from companies.views import CompanyViewSet -from products.views import ProductViewSet +from companies.views import CompanyViewSet, MyCompanyViewSet +from products.views import ProductViewSet, MyProductsViewSet from history.views import HistorySyncViewSet from stats.views import StatsLogViewSet @@ -13,7 +13,9 @@ router = routers.DefaultRouter() router.register('users', CustomUserViewSet, basename='users') router.register('companies', CompanyViewSet, basename='company') +router.register('my_company', MyCompanyViewSet, basename='my-company') router.register('products', ProductViewSet, basename='product') +router.register('my_products', MyProductsViewSet, basename='my-products') router.register('history', HistorySyncViewSet, basename='history') router.register('stats', StatsLogViewSet, basename='stats') diff --git a/back_latienda/urls.py b/back_latienda/urls.py index b60dd0c..99f4276 100644 --- a/back_latienda/urls.py +++ b/back_latienda/urls.py @@ -39,9 +39,9 @@ urlpatterns = [ path('api/v1/search_products/', product_views.product_search, name='product-search'), path('api/v1/create_company_user/', core_views.create_company_user, name='create-company-user'), path('api/v1/my_user/', core_views.my_user, name='my-user'), - path('api/v1/my_company/', company_views.my_company , name='my-company'), + # path('api/v1/my_company/', company_views.my_company , name='my-company'), path('api/v1/companies/sample/', company_views.random_company_sample , name='company-sample'), - path('api/v1/my_products/', product_views.my_products, name='my-products'), + # path('api/v1/my_products/', product_views.my_products, name='my-products'), path('api/v1/stats/me/', stat_views.track_user, name='user-tracker'), path('api/v1/autocomplete/category-tag/', product_views.CategoryTagAutocomplete.as_view(), name='category-autocomplete'), path('api/v1/', include(router.urls)), diff --git a/companies/tests.py b/companies/tests.py index 07189da..167cfb1 100644 --- a/companies/tests.py +++ b/companies/tests.py @@ -311,6 +311,9 @@ class MyCompanyViewTest(APITestCase): self.user.set_password(self.password) self.user.save() + def tearDown(self): + self.model.objects.all().delete() + def test_auth_user_gets_data(self): # create instance user_instances = [self.factory(creator=self.user) for i in range(5)] @@ -346,7 +349,7 @@ class MyCompanyViewTest(APITestCase): self.assertEqual(response.status_code, status.HTTP_200_OK) # assert only 2 instances in response payload = response.json() - self.assertEquals(2, len(payload)) + self.assertEquals(2, len(payload['results'])) def test_anon_user_cannot_access(self): # send in request diff --git a/companies/views.py b/companies/views.py index c92ab46..eb1cc75 100644 --- a/companies/views.py +++ b/companies/views.py @@ -23,6 +23,8 @@ from back_latienda.permissions import IsCreator from utils import woocommerce + + class CompanyViewSet(viewsets.ModelViewSet): queryset = Company.objects.filter(is_validated=True).order_by('-created') serializer_class = CompanySerializer @@ -155,23 +157,16 @@ class CompanyViewSet(viewsets.ModelViewSet): return Response(message) -@api_view(['GET',]) -@permission_classes([IsAuthenticated,]) -def my_company(request): - limit = request.GET.get('limit') - offset = request.GET.get('offset') - qs = Company.objects.filter(creator=request.user) - company_serializer = CompanySerializer(qs, many=True) - data = company_serializer.data - # RESULTS PAGINATION - if limit is not None and offset is not None: - limit = int(limit) - offset = int(offset) - data = data[offset:(limit+offset)] - elif limit is not None: - limit = int(limit) - data = data[:limit] - return Response(data=data) +class MyCompanyViewSet(viewsets.ModelViewSet): + model = Company + serializer_class = CompanySerializer + permission_classes = [IsAuthenticated] + + def get_queryset(self): + return self.model.objects.filter(creator=self.request.user) + + def perform_create(self, serializer): + serializer.save(creator=self.request.user) @api_view(['GET',]) diff --git a/products/tests.py b/products/tests.py index e654ea3..7aa43de 100644 --- a/products/tests.py +++ b/products/tests.py @@ -946,11 +946,9 @@ class MyProductsViewTest(APITestCase): # Query endpoint response = self.client.get(self.endpoint) payload = response.json() - # Assert forbidden code self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEquals(payload['count'], len(payload['results'])) - self.assertEquals(len(user_instances), payload['count']) + self.assertEquals(len(user_instances), len(payload)) def test_auth_user_can_paginate_instances(self): """authenticated user can paginate instances @@ -970,8 +968,8 @@ class MyProductsViewTest(APITestCase): self.assertEqual(response.status_code, status.HTTP_200_OK) # assert only 2 instances in response payload = response.json() - self.assertEquals(payload['count'], len(payload['results'])) - self.assertEquals(2, payload['count']) + self.assertEquals(payload['count'], self.model.objects.count()) + self.assertEquals(2, len(payload['results'])) def test_anon_user_cannot_access(self): # send in request diff --git a/products/views.py b/products/views.py index 0febf05..84c2340 100644 --- a/products/views.py +++ b/products/views.py @@ -52,27 +52,16 @@ class ProductViewSet(viewsets.ModelViewSet): return Response(data=[]) -@api_view(['GET',]) -@permission_classes([IsAuthenticated,]) -def my_products(request): - limit = request.GET.get('limit') - offset = request.GET.get('offset') - qs = Product.objects.filter(creator=request.user) - product_serializer = ProductSerializer(qs, many=True) - data = product_serializer.data - # RESULTS PAGINATION - if limit is not None and offset is not None: - limit = int(limit) - offset = int(offset) - data = data[offset:(limit+offset)] - elif limit is not None: - limit = int(limit) - data = data[:limit] - # prepare response payload - payload = {} - payload['results'] = data - payload['count'] = len(payload['results']) - return Response(data=payload) +class MyProductsViewSet(viewsets.ModelViewSet): + model = Product + serializer_class = ProductSerializer + permission_classes = [IsAuthenticated] + + def get_queryset(self): + return self.model.objects.filter(creator=self.request.user) + + def perform_create(self, serializer): + serializer.save(creator=self.request.user) @api_view(['POST',])