bunch of stuff

This commit is contained in:
Sam
2021-03-11 13:51:14 +00:00
parent 5fe3883fcd
commit 4cf22fd969
7 changed files with 30 additions and 38 deletions

View File

@@ -1,7 +1,7 @@
from rest_framework import routers from rest_framework import routers
from core.views import CustomUserViewSet from core.views import CustomUserViewSet
from companies.views import CompanyViewSet, MyCompanyViewSet, AdminCompanyViewSet from companies.views import CompanyViewSet, AdminCompanyViewSet
from products.views import ProductViewSet, MyProductsViewSet, AdminProductsViewSet from products.views import ProductViewSet, MyProductsViewSet, AdminProductsViewSet
from history.views import HistorySyncViewSet from history.views import HistorySyncViewSet
from stats.views import StatsLogViewSet from stats.views import StatsLogViewSet
@@ -13,7 +13,6 @@ router = routers.DefaultRouter()
router.register('users', CustomUserViewSet, basename='users') router.register('users', CustomUserViewSet, basename='users')
router.register('companies', CompanyViewSet, basename='company') router.register('companies', CompanyViewSet, basename='company')
router.register('my_company', MyCompanyViewSet, basename='my-company')
router.register('admin_companies', AdminCompanyViewSet, basename='admin-companies') router.register('admin_companies', AdminCompanyViewSet, basename='admin-companies')
router.register('products', ProductViewSet, basename='product') router.register('products', ProductViewSet, basename='product')
router.register('my_products', MyProductsViewSet, basename='my-products') router.register('my_products', MyProductsViewSet, basename='my-products')

View File

@@ -39,6 +39,7 @@ urlpatterns = [
path('api/v1/search_products/', product_views.product_search, name='product-search'), path('api/v1/search_products/', product_views.product_search, name='product-search'),
path('api/v1/create_company_user/', core_views.create_company_user, name='create-company-user'), path('api/v1/create_company_user/', core_views.create_company_user, name='create-company-user'),
path('api/v1/my_user/', core_views.my_user, name='my-user'), path('api/v1/my_user/', core_views.my_user, name='my-user'),
path('api/v1/my_company/', company_views.my_company, name='my-company'),
path('api/v1/companies/sample/', company_views.random_company_sample , name='company-sample'), path('api/v1/companies/sample/', company_views.random_company_sample , name='company-sample'),
path('api/v1/purchase_email/', product_views.purchase_email, name='purchase-email'), path('api/v1/purchase_email/', product_views.purchase_email, name='purchase-email'),
path('api/v1/stats/me/', stat_views.track_user, name='user-tracker'), path('api/v1/stats/me/', stat_views.track_user, name='user-tracker'),

View File

@@ -312,12 +312,11 @@ class MyCompanyViewTest(APITestCase):
self.user.set_password(self.password) self.user.set_password(self.password)
self.user.save() self.user.save()
def tearDown(self):
self.model.objects.all().delete()
def test_auth_user_gets_data(self): def test_auth_user_gets_data(self):
# create instance # create instance
user_instances = [self.factory(creator=self.user) for i in range(5)] company = CompanyFactory()
self.user.company = company
self.user.save()
# Authenticate # Authenticate
token = get_tokens_for_user(self.user) token = get_tokens_for_user(self.user)
@@ -325,32 +324,10 @@ class MyCompanyViewTest(APITestCase):
# Query endpoint # Query endpoint
response = self.client.get(self.endpoint) response = self.client.get(self.endpoint)
payload = response.json()
# Assert forbidden code # Assert forbidden code
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEquals(len(user_instances), len(payload))
def test_auth_user_can_paginate_instances(self):
"""authenticated user can paginate instances
"""
# Authenticate
token = get_tokens_for_user(self.user)
self.client.credentials(HTTP_AUTHORIZATION=f"Bearer {token['access']}")
# create instances
instances = [self.factory(creator=self.user) for n in range(12)]
# Request list
url = f"{self.endpoint}?limit=5&offset=10"
response = self.client.get(url)
# Assert access is allowed
self.assertEqual(response.status_code, status.HTTP_200_OK)
# assert only 2 instances in response
payload = response.json() payload = response.json()
self.assertEquals(2, len(payload['results'])) self.assertEquals(payload['company']['id'], company.id)
def test_anon_user_cannot_access(self): def test_anon_user_cannot_access(self):
# send in request # send in request

View File

@@ -155,6 +155,16 @@ class CompanyViewSet(viewsets.ModelViewSet):
return Response(message) return Response(message)
@api_view(['GET'])
@permission_classes([IsAuthenticated])
def my_company(request):
if request.user.company:
serializer = CompanySerializer(request.user.company)
return Response({'company': serializer.data})
else:
return Response(status=status.HTTP_406_NOT_ACCEPTABLE)
'''
class MyCompanyViewSet(viewsets.ModelViewSet): class MyCompanyViewSet(viewsets.ModelViewSet):
model = Company model = Company
serializer_class = CompanySerializer serializer_class = CompanySerializer
@@ -165,7 +175,7 @@ class MyCompanyViewSet(viewsets.ModelViewSet):
def perform_create(self, serializer): def perform_create(self, serializer):
serializer.save(creator=self.request.user) serializer.save(creator=self.request.user)
'''
class AdminCompanyViewSet(viewsets.ModelViewSet): class AdminCompanyViewSet(viewsets.ModelViewSet):
""" Allows user with role 'SITE_ADMIN' to access all company instances """ Allows user with role 'SITE_ADMIN' to access all company instances

View File

@@ -23,7 +23,7 @@ class CustomUserWriteSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = models.CustomUser model = models.CustomUser
fields = ('email', 'full_name', 'role', 'password', 'provider') fields = ('email', 'full_name', 'role', 'password', 'provider', 'notify')
class CreatorSerializer(serializers.ModelSerializer): class CreatorSerializer(serializers.ModelSerializer):

View File

@@ -952,9 +952,13 @@ class MyProductsViewTest(APITestCase):
def test_auth_user_gets_data(self): def test_auth_user_gets_data(self):
# create instance # create instance
company = CompanyFactory()
self.user.company = company
self.user.save()
user_instances = [ user_instances = [
self.factory(creator=self.user), self.factory(company=company),
self.factory(creator=self.user), self.factory(company=company),
] ]
# Authenticate # Authenticate
@@ -976,7 +980,11 @@ class MyProductsViewTest(APITestCase):
self.client.credentials(HTTP_AUTHORIZATION=f"Bearer {token['access']}") self.client.credentials(HTTP_AUTHORIZATION=f"Bearer {token['access']}")
# create instances # create instances
instances = [self.factory(creator=self.user) for n in range(12)] company = CompanyFactory()
self.user.company = company
self.user.save()
instances = [self.factory(company=company) for n in range(12)]
# Request list # Request list
url = f"{self.endpoint}?limit=5&offset=10" url = f"{self.endpoint}?limit=5&offset=10"

View File

@@ -72,10 +72,7 @@ class MyProductsViewSet(viewsets.ModelViewSet):
permission_classes = [IsAuthenticated] permission_classes = [IsAuthenticated]
def get_queryset(self): def get_queryset(self):
return self.model.objects.filter(creator=self.request.user).order_by('-created') return self.model.objects.filter(company=self.request.user.company).order_by('-created')
def perform_create(self, serializer):
serializer.save(creator=self.request.user)
class AdminProductsViewSet(viewsets.ModelViewSet): class AdminProductsViewSet(viewsets.ModelViewSet):