switched my_company and my_products from method to class views

This commit is contained in:
Sam
2021-03-08 12:24:44 +00:00
parent ebd1b6744c
commit ec27937b85
7 changed files with 48 additions and 48 deletions

View File

@@ -35,6 +35,19 @@ class IsStaff(permissions.BasePermission):
return request.user.is_staff return request.user.is_staff
class IsSiteAdmin(permissions.BasePermission):
"""
Grant permission if request.user.role == 'SITE_ADMIN'
"""
admin_role = 'SITE_ADMIN'
def has_object_permission(self, request, view, obj):
return request.user.role == self.admin_role
def has_permission(self, request, view):
return request.user.role == self.admin_role
class ReadOnly(permissions.BasePermission): class ReadOnly(permissions.BasePermission):
def has_permission(self, request, view): def has_permission(self, request, view):

View File

@@ -1,8 +1,8 @@
from rest_framework import routers from rest_framework import routers
from core.views import CustomUserViewSet from core.views import CustomUserViewSet
from companies.views import CompanyViewSet from companies.views import CompanyViewSet, MyCompanyViewSet
from products.views import ProductViewSet from products.views import ProductViewSet, MyProductsViewSet
from history.views import HistorySyncViewSet from history.views import HistorySyncViewSet
from stats.views import StatsLogViewSet from stats.views import StatsLogViewSet
@@ -13,7 +13,9 @@ router = routers.DefaultRouter()
router.register('users', CustomUserViewSet, basename='users') router.register('users', CustomUserViewSet, basename='users')
router.register('companies', CompanyViewSet, basename='company') router.register('companies', CompanyViewSet, basename='company')
router.register('my_company', MyCompanyViewSet, basename='my-company')
router.register('products', ProductViewSet, basename='product') router.register('products', ProductViewSet, basename='product')
router.register('my_products', MyProductsViewSet, basename='my-products')
router.register('history', HistorySyncViewSet, basename='history') router.register('history', HistorySyncViewSet, basename='history')
router.register('stats', StatsLogViewSet, basename='stats') router.register('stats', StatsLogViewSet, basename='stats')

View File

@@ -39,9 +39,9 @@ urlpatterns = [
path('api/v1/search_products/', product_views.product_search, name='product-search'), path('api/v1/search_products/', product_views.product_search, name='product-search'),
path('api/v1/create_company_user/', core_views.create_company_user, name='create-company-user'), path('api/v1/create_company_user/', core_views.create_company_user, name='create-company-user'),
path('api/v1/my_user/', core_views.my_user, name='my-user'), path('api/v1/my_user/', core_views.my_user, name='my-user'),
path('api/v1/my_company/', company_views.my_company , name='my-company'), # path('api/v1/my_company/', company_views.my_company , name='my-company'),
path('api/v1/companies/sample/', company_views.random_company_sample , name='company-sample'), path('api/v1/companies/sample/', company_views.random_company_sample , name='company-sample'),
path('api/v1/my_products/', product_views.my_products, name='my-products'), # path('api/v1/my_products/', product_views.my_products, name='my-products'),
path('api/v1/stats/me/', stat_views.track_user, name='user-tracker'), path('api/v1/stats/me/', stat_views.track_user, name='user-tracker'),
path('api/v1/autocomplete/category-tag/', product_views.CategoryTagAutocomplete.as_view(), name='category-autocomplete'), path('api/v1/autocomplete/category-tag/', product_views.CategoryTagAutocomplete.as_view(), name='category-autocomplete'),
path('api/v1/', include(router.urls)), path('api/v1/', include(router.urls)),

View File

@@ -311,6 +311,9 @@ class MyCompanyViewTest(APITestCase):
self.user.set_password(self.password) self.user.set_password(self.password)
self.user.save() self.user.save()
def tearDown(self):
self.model.objects.all().delete()
def test_auth_user_gets_data(self): def test_auth_user_gets_data(self):
# create instance # create instance
user_instances = [self.factory(creator=self.user) for i in range(5)] user_instances = [self.factory(creator=self.user) for i in range(5)]
@@ -346,7 +349,7 @@ class MyCompanyViewTest(APITestCase):
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
# assert only 2 instances in response # assert only 2 instances in response
payload = response.json() payload = response.json()
self.assertEquals(2, len(payload)) self.assertEquals(2, len(payload['results']))
def test_anon_user_cannot_access(self): def test_anon_user_cannot_access(self):
# send in request # send in request

View File

@@ -23,6 +23,8 @@ from back_latienda.permissions import IsCreator
from utils import woocommerce from utils import woocommerce
class CompanyViewSet(viewsets.ModelViewSet): class CompanyViewSet(viewsets.ModelViewSet):
queryset = Company.objects.filter(is_validated=True).order_by('-created') queryset = Company.objects.filter(is_validated=True).order_by('-created')
serializer_class = CompanySerializer serializer_class = CompanySerializer
@@ -155,23 +157,16 @@ class CompanyViewSet(viewsets.ModelViewSet):
return Response(message) return Response(message)
@api_view(['GET',]) class MyCompanyViewSet(viewsets.ModelViewSet):
@permission_classes([IsAuthenticated,]) model = Company
def my_company(request): serializer_class = CompanySerializer
limit = request.GET.get('limit') permission_classes = [IsAuthenticated]
offset = request.GET.get('offset')
qs = Company.objects.filter(creator=request.user) def get_queryset(self):
company_serializer = CompanySerializer(qs, many=True) return self.model.objects.filter(creator=self.request.user)
data = company_serializer.data
# RESULTS PAGINATION def perform_create(self, serializer):
if limit is not None and offset is not None: serializer.save(creator=self.request.user)
limit = int(limit)
offset = int(offset)
data = data[offset:(limit+offset)]
elif limit is not None:
limit = int(limit)
data = data[:limit]
return Response(data=data)
@api_view(['GET',]) @api_view(['GET',])

View File

@@ -946,11 +946,9 @@ class MyProductsViewTest(APITestCase):
# Query endpoint # Query endpoint
response = self.client.get(self.endpoint) response = self.client.get(self.endpoint)
payload = response.json() payload = response.json()
# Assert forbidden code # Assert forbidden code
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEquals(payload['count'], len(payload['results'])) self.assertEquals(len(user_instances), len(payload))
self.assertEquals(len(user_instances), payload['count'])
def test_auth_user_can_paginate_instances(self): def test_auth_user_can_paginate_instances(self):
"""authenticated user can paginate instances """authenticated user can paginate instances
@@ -970,8 +968,8 @@ class MyProductsViewTest(APITestCase):
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
# assert only 2 instances in response # assert only 2 instances in response
payload = response.json() payload = response.json()
self.assertEquals(payload['count'], len(payload['results'])) self.assertEquals(payload['count'], self.model.objects.count())
self.assertEquals(2, payload['count']) self.assertEquals(2, len(payload['results']))
def test_anon_user_cannot_access(self): def test_anon_user_cannot_access(self):
# send in request # send in request

View File

@@ -52,27 +52,16 @@ class ProductViewSet(viewsets.ModelViewSet):
return Response(data=[]) return Response(data=[])
@api_view(['GET',]) class MyProductsViewSet(viewsets.ModelViewSet):
@permission_classes([IsAuthenticated,]) model = Product
def my_products(request): serializer_class = ProductSerializer
limit = request.GET.get('limit') permission_classes = [IsAuthenticated]
offset = request.GET.get('offset')
qs = Product.objects.filter(creator=request.user) def get_queryset(self):
product_serializer = ProductSerializer(qs, many=True) return self.model.objects.filter(creator=self.request.user)
data = product_serializer.data
# RESULTS PAGINATION def perform_create(self, serializer):
if limit is not None and offset is not None: serializer.save(creator=self.request.user)
limit = int(limit)
offset = int(offset)
data = data[offset:(limit+offset)]
elif limit is not None:
limit = int(limit)
data = data[:limit]
# prepare response payload
payload = {}
payload['results'] = data
payload['count'] = len(payload['results'])
return Response(data=payload)
@api_view(['POST',]) @api_view(['POST',])