added AdminCompanyViewSet for users with role SITE_ADMIN

This commit is contained in:
Sam
2021-03-08 13:02:57 +00:00
parent 642688b98d
commit 3b28238a62
4 changed files with 182 additions and 7 deletions

View File

@@ -1,7 +1,7 @@
from rest_framework import routers from rest_framework import routers
from core.views import CustomUserViewSet from core.views import CustomUserViewSet
from companies.views import CompanyViewSet, MyCompanyViewSet from companies.views import CompanyViewSet, MyCompanyViewSet, AdminCompanyViewSet
from products.views import ProductViewSet, MyProductsViewSet, AdminProductsViewSet from products.views import ProductViewSet, MyProductsViewSet, AdminProductsViewSet
from history.views import HistorySyncViewSet from history.views import HistorySyncViewSet
from stats.views import StatsLogViewSet from stats.views import StatsLogViewSet
@@ -14,9 +14,10 @@ router = routers.DefaultRouter()
router.register('users', CustomUserViewSet, basename='users') router.register('users', CustomUserViewSet, basename='users')
router.register('companies', CompanyViewSet, basename='company') router.register('companies', CompanyViewSet, basename='company')
router.register('my_company', MyCompanyViewSet, basename='my-company') router.register('my_company', MyCompanyViewSet, basename='my-company')
router.register('admin_companies', AdminCompanyViewSet, basename='admin-companies')
router.register('products', ProductViewSet, basename='product') router.register('products', ProductViewSet, basename='product')
router.register('my_products', MyProductsViewSet, basename='my-products') router.register('my_products', MyProductsViewSet, basename='my-products')
router.register('admin_products', AdminProductsViewSet, basename='admin-product') router.register('admin_products', AdminProductsViewSet, basename='admin-products')
router.register('history', HistorySyncViewSet, basename='history') router.register('history', HistorySyncViewSet, basename='history')
router.register('stats', StatsLogViewSet, basename='stats') router.register('stats', StatsLogViewSet, basename='stats')

View File

@@ -7,7 +7,7 @@ from django.test import TestCase
from rest_framework.test import APITestCase from rest_framework.test import APITestCase
from rest_framework import status from rest_framework import status
from companies.factories import ValidatedCompanyFactory from companies.factories import ValidatedCompanyFactory, CompanyFactory
from companies.models import Company from companies.models import Company
from core.factories import CustomUserFactory from core.factories import CustomUserFactory
@@ -407,3 +407,168 @@ class RandomCompanySampleTest(APITestCase):
self.assertEquals(size, len(payload)) self.assertEquals(size, len(payload))
# test IDs not correlative (eventually it could be, because it's random) # test IDs not correlative (eventually it could be, because it's random)
self.assertTrue(payload[0]['id'] != (payload[1]['id'] + 1)) self.assertTrue(payload[0]['id'] != (payload[1]['id'] + 1))
class AdminCompanyViewSetTest(APITestCase):
def setUp(self):
"""Tests setup
"""
self.endpoint = '/api/v1/admin_companies/'
self.factory = CompanyFactory
self.model = Company
# create user
self.email = f"user@mail.com"
self.password = ''.join(random.choices(string.ascii_uppercase, k = 10))
self.user = CustomUserFactory(email=self.email, is_active=True)
self.user.set_password(self.password)
self.user.save()
def test_anon_user_cannot_access(self):
instance = self.factory()
url = f"{self.endpoint}{instance.id}/"
# GET
response = self.client.get(self.endpoint)
# check response
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
# POST
response = self.client.post(self.endpoint, data={})
# check response
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
# PUT
response = self.client.get(url, data={})
# check response
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
# delete
response = self.client.get(url)
# check response
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
def test_auth_user_cannot_access(self):
# Authenticate
token = get_tokens_for_user(self.user)
self.client.credentials(HTTP_AUTHORIZATION=f"Bearer {token['access']}")
instance = self.factory()
url = f"{self.endpoint}{instance.id}/"
# GET
response = self.client.get(self.endpoint)
# check response
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
# POST
response = self.client.post(self.endpoint, data={})
# check response
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
# PUT
response = self.client.get(url, data={})
# check response
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
# delete
response = self.client.get(url)
# check response
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
def test_admin_user_can_list(self):
# make user site amdin
self.user.role = 'SITE_ADMIN'
self.user.save()
# Authenticate
token = get_tokens_for_user(self.user)
self.client.credentials(HTTP_AUTHORIZATION=f"Bearer {token['access']}")
# create instances
instance = [self.factory() for i in range(random.randint(1,5))]
# query endpoint
response = self.client.get(self.endpoint)
# assertions
self.assertEquals(response.status_code, 200)
payload = response.json()
self.assertEquals(len(instance), len(payload))
def test_admin_user_can_get_details(self):
# make user site amdin
self.user.role = 'SITE_ADMIN'
self.user.save()
# Authenticate
token = get_tokens_for_user(self.user)
self.client.credentials(HTTP_AUTHORIZATION=f"Bearer {token['access']}")
# create instances
instance = self.factory()
url = f"{self.endpoint}{instance.id}/"
# query endpoint
response = self.client.get(url)
# assertions
self.assertEquals(response.status_code, 200)
payload = response.json()
self.assertEquals(instance.id, payload['id'])
def test_admin_can_create_instance(self):
# make user site amdin
self.user.role = 'SITE_ADMIN'
self.user.save()
# Authenticate
token = get_tokens_for_user(self.user)
self.client.credentials(HTTP_AUTHORIZATION=f"Bearer {token['access']}")
# create instances
data = {
'short_name': 'test_compnay short _name',
}
# query endpoint
response = self.client.post(self.endpoint, data=data)
# assertions
self.assertEquals(response.status_code, 201)
payload = response.json()
self.assertEquals(data['short_name'], payload['short_name'])
def test_admin_can_update_instance(self):
# make user site amdin
self.user.role = 'SITE_ADMIN'
self.user.save()
# Authenticate
token = get_tokens_for_user(self.user)
self.client.credentials(HTTP_AUTHORIZATION=f"Bearer {token['access']}")
# create instance
instance = self.factory()
url = f"{self.endpoint}{instance.id}/"
# data
data = {
'short_name': 'test_compnay short _name',
}
# query endpoint
response = self.client.put(url, data=data)
# assertions
self.assertEquals(response.status_code, 200)
payload = response.json()
self.assertEquals(data['short_name'], payload['short_name'])
def test_admin_can_delete_instance(self):
# make user site amdin
self.user.role = 'SITE_ADMIN'
self.user.save()
# Authenticate
token = get_tokens_for_user(self.user)
self.client.credentials(HTTP_AUTHORIZATION=f"Bearer {token['access']}")
# create instance
instance = self.factory()
url = f"{self.endpoint}{instance.id}/"
# query endpoint
response = self.client.delete(url)
# assertions
self.assertEquals(response.status_code, 204)

View File

@@ -18,13 +18,11 @@ from stats.models import StatsLog
from companies.models import Company from companies.models import Company
from companies.serializers import CompanySerializer from companies.serializers import CompanySerializer
from utils.tag_filters import CompanyTagFilter from utils.tag_filters import CompanyTagFilter
from back_latienda.permissions import IsCreator from back_latienda.permissions import IsCreator, IsSiteAdmin
from utils import woocommerce from utils import woocommerce
class CompanyViewSet(viewsets.ModelViewSet): class CompanyViewSet(viewsets.ModelViewSet):
queryset = Company.objects.filter(is_validated=True).order_by('-created') queryset = Company.objects.filter(is_validated=True).order_by('-created')
serializer_class = CompanySerializer serializer_class = CompanySerializer
@@ -169,6 +167,17 @@ class MyCompanyViewSet(viewsets.ModelViewSet):
serializer.save(creator=self.request.user) serializer.save(creator=self.request.user)
class AdminCompanyViewSet(viewsets.ModelViewSet):
""" Allows user with role 'SITE_ADMIN' to access all company instances
"""
queryset = Company.objects.all()
serializer_class = CompanySerializer
permission_classes = [IsSiteAdmin]
def perform_create(self, serializer):
serializer.save(creator=self.request.user)
@api_view(['GET',]) @api_view(['GET',])
@permission_classes([IsAuthenticatedOrReadOnly,]) @permission_classes([IsAuthenticatedOrReadOnly,])
def random_company_sample(request): def random_company_sample(request):

View File

@@ -979,7 +979,7 @@ class MyProductsViewTest(APITestCase):
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
class AdminProductViewSet(APITestCase): class AdminProductViewSetTest(APITestCase):
def setUp(self): def setUp(self):
"""Tests setup """Tests setup