bunch of stuff

This commit is contained in:
Sam
2021-03-11 13:51:14 +00:00
parent 5fe3883fcd
commit 4cf22fd969
7 changed files with 30 additions and 38 deletions

View File

@@ -1,7 +1,7 @@
from rest_framework import routers
from core.views import CustomUserViewSet
from companies.views import CompanyViewSet, MyCompanyViewSet, AdminCompanyViewSet
from companies.views import CompanyViewSet, AdminCompanyViewSet
from products.views import ProductViewSet, MyProductsViewSet, AdminProductsViewSet
from history.views import HistorySyncViewSet
from stats.views import StatsLogViewSet
@@ -13,7 +13,6 @@ router = routers.DefaultRouter()
router.register('users', CustomUserViewSet, basename='users')
router.register('companies', CompanyViewSet, basename='company')
router.register('my_company', MyCompanyViewSet, basename='my-company')
router.register('admin_companies', AdminCompanyViewSet, basename='admin-companies')
router.register('products', ProductViewSet, basename='product')
router.register('my_products', MyProductsViewSet, basename='my-products')

View File

@@ -39,6 +39,7 @@ urlpatterns = [
path('api/v1/search_products/', product_views.product_search, name='product-search'),
path('api/v1/create_company_user/', core_views.create_company_user, name='create-company-user'),
path('api/v1/my_user/', core_views.my_user, name='my-user'),
path('api/v1/my_company/', company_views.my_company, name='my-company'),
path('api/v1/companies/sample/', company_views.random_company_sample , name='company-sample'),
path('api/v1/purchase_email/', product_views.purchase_email, name='purchase-email'),
path('api/v1/stats/me/', stat_views.track_user, name='user-tracker'),

View File

@@ -312,12 +312,11 @@ class MyCompanyViewTest(APITestCase):
self.user.set_password(self.password)
self.user.save()
def tearDown(self):
self.model.objects.all().delete()
def test_auth_user_gets_data(self):
# create instance
user_instances = [self.factory(creator=self.user) for i in range(5)]
company = CompanyFactory()
self.user.company = company
self.user.save()
# Authenticate
token = get_tokens_for_user(self.user)
@@ -325,32 +324,10 @@ class MyCompanyViewTest(APITestCase):
# Query endpoint
response = self.client.get(self.endpoint)
payload = response.json()
# Assert forbidden code
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEquals(len(user_instances), len(payload))
def test_auth_user_can_paginate_instances(self):
"""authenticated user can paginate instances
"""
# Authenticate
token = get_tokens_for_user(self.user)
self.client.credentials(HTTP_AUTHORIZATION=f"Bearer {token['access']}")
# create instances
instances = [self.factory(creator=self.user) for n in range(12)]
# Request list
url = f"{self.endpoint}?limit=5&offset=10"
response = self.client.get(url)
# Assert access is allowed
self.assertEqual(response.status_code, status.HTTP_200_OK)
# assert only 2 instances in response
payload = response.json()
self.assertEquals(2, len(payload['results']))
self.assertEquals(payload['company']['id'], company.id)
def test_anon_user_cannot_access(self):
# send in request

View File

@@ -155,6 +155,16 @@ class CompanyViewSet(viewsets.ModelViewSet):
return Response(message)
@api_view(['GET'])
@permission_classes([IsAuthenticated])
def my_company(request):
if request.user.company:
serializer = CompanySerializer(request.user.company)
return Response({'company': serializer.data})
else:
return Response(status=status.HTTP_406_NOT_ACCEPTABLE)
'''
class MyCompanyViewSet(viewsets.ModelViewSet):
model = Company
serializer_class = CompanySerializer
@@ -165,7 +175,7 @@ class MyCompanyViewSet(viewsets.ModelViewSet):
def perform_create(self, serializer):
serializer.save(creator=self.request.user)
'''
class AdminCompanyViewSet(viewsets.ModelViewSet):
""" Allows user with role 'SITE_ADMIN' to access all company instances

View File

@@ -23,7 +23,7 @@ class CustomUserWriteSerializer(serializers.ModelSerializer):
class Meta:
model = models.CustomUser
fields = ('email', 'full_name', 'role', 'password', 'provider')
fields = ('email', 'full_name', 'role', 'password', 'provider', 'notify')
class CreatorSerializer(serializers.ModelSerializer):

View File

@@ -952,9 +952,13 @@ class MyProductsViewTest(APITestCase):
def test_auth_user_gets_data(self):
# create instance
company = CompanyFactory()
self.user.company = company
self.user.save()
user_instances = [
self.factory(creator=self.user),
self.factory(creator=self.user),
self.factory(company=company),
self.factory(company=company),
]
# Authenticate
@@ -976,7 +980,11 @@ class MyProductsViewTest(APITestCase):
self.client.credentials(HTTP_AUTHORIZATION=f"Bearer {token['access']}")
# create instances
instances = [self.factory(creator=self.user) for n in range(12)]
company = CompanyFactory()
self.user.company = company
self.user.save()
instances = [self.factory(company=company) for n in range(12)]
# Request list
url = f"{self.endpoint}?limit=5&offset=10"

View File

@@ -72,10 +72,7 @@ class MyProductsViewSet(viewsets.ModelViewSet):
permission_classes = [IsAuthenticated]
def get_queryset(self):
return self.model.objects.filter(creator=self.request.user).order_by('-created')
def perform_create(self, serializer):
serializer.save(creator=self.request.user)
return self.model.objects.filter(company=self.request.user.company).order_by('-created')
class AdminProductsViewSet(viewsets.ModelViewSet):