switched my_company and my_products from method to class views

This commit is contained in:
Sam
2021-03-08 12:24:44 +00:00
parent ebd1b6744c
commit ec27937b85
7 changed files with 48 additions and 48 deletions

View File

@@ -35,6 +35,19 @@ class IsStaff(permissions.BasePermission):
return request.user.is_staff
class IsSiteAdmin(permissions.BasePermission):
"""
Grant permission if request.user.role == 'SITE_ADMIN'
"""
admin_role = 'SITE_ADMIN'
def has_object_permission(self, request, view, obj):
return request.user.role == self.admin_role
def has_permission(self, request, view):
return request.user.role == self.admin_role
class ReadOnly(permissions.BasePermission):
def has_permission(self, request, view):

View File

@@ -1,8 +1,8 @@
from rest_framework import routers
from core.views import CustomUserViewSet
from companies.views import CompanyViewSet
from products.views import ProductViewSet
from companies.views import CompanyViewSet, MyCompanyViewSet
from products.views import ProductViewSet, MyProductsViewSet
from history.views import HistorySyncViewSet
from stats.views import StatsLogViewSet
@@ -13,7 +13,9 @@ router = routers.DefaultRouter()
router.register('users', CustomUserViewSet, basename='users')
router.register('companies', CompanyViewSet, basename='company')
router.register('my_company', MyCompanyViewSet, basename='my-company')
router.register('products', ProductViewSet, basename='product')
router.register('my_products', MyProductsViewSet, basename='my-products')
router.register('history', HistorySyncViewSet, basename='history')
router.register('stats', StatsLogViewSet, basename='stats')

View File

@@ -39,9 +39,9 @@ urlpatterns = [
path('api/v1/search_products/', product_views.product_search, name='product-search'),
path('api/v1/create_company_user/', core_views.create_company_user, name='create-company-user'),
path('api/v1/my_user/', core_views.my_user, name='my-user'),
path('api/v1/my_company/', company_views.my_company , name='my-company'),
# path('api/v1/my_company/', company_views.my_company , name='my-company'),
path('api/v1/companies/sample/', company_views.random_company_sample , name='company-sample'),
path('api/v1/my_products/', product_views.my_products, name='my-products'),
# path('api/v1/my_products/', product_views.my_products, name='my-products'),
path('api/v1/stats/me/', stat_views.track_user, name='user-tracker'),
path('api/v1/autocomplete/category-tag/', product_views.CategoryTagAutocomplete.as_view(), name='category-autocomplete'),
path('api/v1/', include(router.urls)),

View File

@@ -311,6 +311,9 @@ class MyCompanyViewTest(APITestCase):
self.user.set_password(self.password)
self.user.save()
def tearDown(self):
self.model.objects.all().delete()
def test_auth_user_gets_data(self):
# create instance
user_instances = [self.factory(creator=self.user) for i in range(5)]
@@ -346,7 +349,7 @@ class MyCompanyViewTest(APITestCase):
self.assertEqual(response.status_code, status.HTTP_200_OK)
# assert only 2 instances in response
payload = response.json()
self.assertEquals(2, len(payload))
self.assertEquals(2, len(payload['results']))
def test_anon_user_cannot_access(self):
# send in request

View File

@@ -23,6 +23,8 @@ from back_latienda.permissions import IsCreator
from utils import woocommerce
class CompanyViewSet(viewsets.ModelViewSet):
queryset = Company.objects.filter(is_validated=True).order_by('-created')
serializer_class = CompanySerializer
@@ -155,23 +157,16 @@ class CompanyViewSet(viewsets.ModelViewSet):
return Response(message)
@api_view(['GET',])
@permission_classes([IsAuthenticated,])
def my_company(request):
limit = request.GET.get('limit')
offset = request.GET.get('offset')
qs = Company.objects.filter(creator=request.user)
company_serializer = CompanySerializer(qs, many=True)
data = company_serializer.data
# RESULTS PAGINATION
if limit is not None and offset is not None:
limit = int(limit)
offset = int(offset)
data = data[offset:(limit+offset)]
elif limit is not None:
limit = int(limit)
data = data[:limit]
return Response(data=data)
class MyCompanyViewSet(viewsets.ModelViewSet):
model = Company
serializer_class = CompanySerializer
permission_classes = [IsAuthenticated]
def get_queryset(self):
return self.model.objects.filter(creator=self.request.user)
def perform_create(self, serializer):
serializer.save(creator=self.request.user)
@api_view(['GET',])

View File

@@ -946,11 +946,9 @@ class MyProductsViewTest(APITestCase):
# Query endpoint
response = self.client.get(self.endpoint)
payload = response.json()
# Assert forbidden code
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEquals(payload['count'], len(payload['results']))
self.assertEquals(len(user_instances), payload['count'])
self.assertEquals(len(user_instances), len(payload))
def test_auth_user_can_paginate_instances(self):
"""authenticated user can paginate instances
@@ -970,8 +968,8 @@ class MyProductsViewTest(APITestCase):
self.assertEqual(response.status_code, status.HTTP_200_OK)
# assert only 2 instances in response
payload = response.json()
self.assertEquals(payload['count'], len(payload['results']))
self.assertEquals(2, payload['count'])
self.assertEquals(payload['count'], self.model.objects.count())
self.assertEquals(2, len(payload['results']))
def test_anon_user_cannot_access(self):
# send in request

View File

@@ -52,27 +52,16 @@ class ProductViewSet(viewsets.ModelViewSet):
return Response(data=[])
@api_view(['GET',])
@permission_classes([IsAuthenticated,])
def my_products(request):
limit = request.GET.get('limit')
offset = request.GET.get('offset')
qs = Product.objects.filter(creator=request.user)
product_serializer = ProductSerializer(qs, many=True)
data = product_serializer.data
# RESULTS PAGINATION
if limit is not None and offset is not None:
limit = int(limit)
offset = int(offset)
data = data[offset:(limit+offset)]
elif limit is not None:
limit = int(limit)
data = data[:limit]
# prepare response payload
payload = {}
payload['results'] = data
payload['count'] = len(payload['results'])
return Response(data=payload)
class MyProductsViewSet(viewsets.ModelViewSet):
model = Product
serializer_class = ProductSerializer
permission_classes = [IsAuthenticated]
def get_queryset(self):
return self.model.objects.filter(creator=self.request.user)
def perform_create(self, serializer):
serializer.save(creator=self.request.user)
@api_view(['POST',])