diff --git a/companies/views.py b/companies/views.py index 010c4b4..62c4767 100644 --- a/companies/views.py +++ b/companies/views.py @@ -171,7 +171,7 @@ class MyCompanyViewSet(viewsets.ModelViewSet): permission_classes = [IsAuthenticated] def get_queryset(self): - return self.model.objects.filter(creator=self.request.user) + return self.model.objects.filter(company=self.request.user.company) def perform_create(self, serializer): serializer.save(creator=self.request.user) diff --git a/products/tests.py b/products/tests.py index 194fee3..d9db3ea 100644 --- a/products/tests.py +++ b/products/tests.py @@ -1004,6 +1004,18 @@ class MyProductsViewTest(APITestCase): # check response self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) + def test_auth_user_without_company(self): + # Authenticate + token = get_tokens_for_user(self.user) + self.client.credentials(HTTP_AUTHORIZATION=f"Bearer {token['access']}") + + # Query endpoint + response = self.client.get(self.endpoint) + payload = response.json() + # Assert forbidden code + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEquals([], payload) + class AdminProductViewSetTest(APITestCase):