diff --git a/back_latienda/routers.py b/back_latienda/routers.py index ef0d46c..bad6b0c 100644 --- a/back_latienda/routers.py +++ b/back_latienda/routers.py @@ -1,7 +1,7 @@ from rest_framework import routers from core.views import CustomUserViewSet -from companies.views import CompanyViewSet, MyCompanyViewSet +from companies.views import CompanyViewSet, MyCompanyViewSet, AdminCompanyViewSet from products.views import ProductViewSet, MyProductsViewSet, AdminProductsViewSet from history.views import HistorySyncViewSet from stats.views import StatsLogViewSet @@ -14,9 +14,10 @@ router = routers.DefaultRouter() router.register('users', CustomUserViewSet, basename='users') router.register('companies', CompanyViewSet, basename='company') router.register('my_company', MyCompanyViewSet, basename='my-company') +router.register('admin_companies', AdminCompanyViewSet, basename='admin-companies') router.register('products', ProductViewSet, basename='product') router.register('my_products', MyProductsViewSet, basename='my-products') -router.register('admin_products', AdminProductsViewSet, basename='admin-product') +router.register('admin_products', AdminProductsViewSet, basename='admin-products') router.register('history', HistorySyncViewSet, basename='history') router.register('stats', StatsLogViewSet, basename='stats') diff --git a/companies/tests.py b/companies/tests.py index 167cfb1..1bd5c51 100644 --- a/companies/tests.py +++ b/companies/tests.py @@ -7,7 +7,7 @@ from django.test import TestCase from rest_framework.test import APITestCase from rest_framework import status -from companies.factories import ValidatedCompanyFactory +from companies.factories import ValidatedCompanyFactory, CompanyFactory from companies.models import Company from core.factories import CustomUserFactory @@ -407,3 +407,168 @@ class RandomCompanySampleTest(APITestCase): self.assertEquals(size, len(payload)) # test IDs not correlative (eventually it could be, because it's random) self.assertTrue(payload[0]['id'] != (payload[1]['id'] + 1)) + + +class AdminCompanyViewSetTest(APITestCase): + + def setUp(self): + """Tests setup + """ + self.endpoint = '/api/v1/admin_companies/' + self.factory = CompanyFactory + self.model = Company + # create user + self.email = f"user@mail.com" + self.password = ''.join(random.choices(string.ascii_uppercase, k = 10)) + self.user = CustomUserFactory(email=self.email, is_active=True) + self.user.set_password(self.password) + self.user.save() + + def test_anon_user_cannot_access(self): + instance = self.factory() + url = f"{self.endpoint}{instance.id}/" + # GET + response = self.client.get(self.endpoint) + # check response + self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) + # POST + response = self.client.post(self.endpoint, data={}) + # check response + self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) + # PUT + response = self.client.get(url, data={}) + # check response + self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) + # delete + response = self.client.get(url) + # check response + self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) + + def test_auth_user_cannot_access(self): + # Authenticate + token = get_tokens_for_user(self.user) + self.client.credentials(HTTP_AUTHORIZATION=f"Bearer {token['access']}") + + instance = self.factory() + url = f"{self.endpoint}{instance.id}/" + # GET + response = self.client.get(self.endpoint) + # check response + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + # POST + response = self.client.post(self.endpoint, data={}) + # check response + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + # PUT + response = self.client.get(url, data={}) + # check response + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + # delete + response = self.client.get(url) + # check response + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + def test_admin_user_can_list(self): + # make user site amdin + self.user.role = 'SITE_ADMIN' + self.user.save() + + # Authenticate + token = get_tokens_for_user(self.user) + self.client.credentials(HTTP_AUTHORIZATION=f"Bearer {token['access']}") + + # create instances + instance = [self.factory() for i in range(random.randint(1,5))] + # query endpoint + response = self.client.get(self.endpoint) + + # assertions + self.assertEquals(response.status_code, 200) + payload = response.json() + self.assertEquals(len(instance), len(payload)) + + def test_admin_user_can_get_details(self): + # make user site amdin + self.user.role = 'SITE_ADMIN' + self.user.save() + + # Authenticate + token = get_tokens_for_user(self.user) + self.client.credentials(HTTP_AUTHORIZATION=f"Bearer {token['access']}") + + # create instances + instance = self.factory() + url = f"{self.endpoint}{instance.id}/" + # query endpoint + response = self.client.get(url) + + # assertions + self.assertEquals(response.status_code, 200) + payload = response.json() + self.assertEquals(instance.id, payload['id']) + + def test_admin_can_create_instance(self): + # make user site amdin + self.user.role = 'SITE_ADMIN' + self.user.save() + + # Authenticate + token = get_tokens_for_user(self.user) + self.client.credentials(HTTP_AUTHORIZATION=f"Bearer {token['access']}") + + # create instances + data = { + 'short_name': 'test_compnay short _name', + } + # query endpoint + response = self.client.post(self.endpoint, data=data) + + # assertions + self.assertEquals(response.status_code, 201) + payload = response.json() + self.assertEquals(data['short_name'], payload['short_name']) + + def test_admin_can_update_instance(self): + # make user site amdin + self.user.role = 'SITE_ADMIN' + self.user.save() + + # Authenticate + token = get_tokens_for_user(self.user) + self.client.credentials(HTTP_AUTHORIZATION=f"Bearer {token['access']}") + + # create instance + instance = self.factory() + url = f"{self.endpoint}{instance.id}/" + + # data + data = { + 'short_name': 'test_compnay short _name', + } + # query endpoint + response = self.client.put(url, data=data) + + # assertions + self.assertEquals(response.status_code, 200) + payload = response.json() + self.assertEquals(data['short_name'], payload['short_name']) + + def test_admin_can_delete_instance(self): + # make user site amdin + self.user.role = 'SITE_ADMIN' + self.user.save() + + # Authenticate + token = get_tokens_for_user(self.user) + self.client.credentials(HTTP_AUTHORIZATION=f"Bearer {token['access']}") + + # create instance + instance = self.factory() + url = f"{self.endpoint}{instance.id}/" + + # query endpoint + response = self.client.delete(url) + + # assertions + self.assertEquals(response.status_code, 204) + diff --git a/companies/views.py b/companies/views.py index eb1cc75..5133d43 100644 --- a/companies/views.py +++ b/companies/views.py @@ -18,13 +18,11 @@ from stats.models import StatsLog from companies.models import Company from companies.serializers import CompanySerializer from utils.tag_filters import CompanyTagFilter -from back_latienda.permissions import IsCreator +from back_latienda.permissions import IsCreator, IsSiteAdmin from utils import woocommerce - - class CompanyViewSet(viewsets.ModelViewSet): queryset = Company.objects.filter(is_validated=True).order_by('-created') serializer_class = CompanySerializer @@ -169,6 +167,17 @@ class MyCompanyViewSet(viewsets.ModelViewSet): serializer.save(creator=self.request.user) +class AdminCompanyViewSet(viewsets.ModelViewSet): + """ Allows user with role 'SITE_ADMIN' to access all company instances + """ + queryset = Company.objects.all() + serializer_class = CompanySerializer + permission_classes = [IsSiteAdmin] + + def perform_create(self, serializer): + serializer.save(creator=self.request.user) + + @api_view(['GET',]) @permission_classes([IsAuthenticatedOrReadOnly,]) def random_company_sample(request): diff --git a/products/tests.py b/products/tests.py index d15e5c6..e1c0d7c 100644 --- a/products/tests.py +++ b/products/tests.py @@ -979,7 +979,7 @@ class MyProductsViewTest(APITestCase): self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) -class AdminProductViewSet(APITestCase): +class AdminProductViewSetTest(APITestCase): def setUp(self): """Tests setup