diff --git a/products/tests.py b/products/tests.py index 8b497ff..f503677 100644 --- a/products/tests.py +++ b/products/tests.py @@ -195,19 +195,23 @@ class ProductViewSetTest(APITestCase): self.assertEquals(len(expected_instance), len(payload)) def test_anon_can_get_related_products(self): + tag = 'cosa' + company = CompanyFactory() # Create instances instance = self.factory() # make our user the creator instance.creator = self.user instance.save() - url = f"{self.endpoint}{instance.id}/related/" + instances = [self.factory(tags=tag, company=company) for i in range(10)] + + url = f"{self.endpoint}{instances[0].id}/related/" response = self.client.get(url) self.assertEquals(response.status_code, 200) payload= response.json() - self.assertTrue(len(payload) <= 6) + self.assertTrue(len(payload) <= 10) # authenticated user def test_auth_user_can_paginate_instances(self): diff --git a/products/utils.py b/products/utils.py index 80c1010..8c938dc 100644 --- a/products/utils.py +++ b/products/utils.py @@ -85,14 +85,42 @@ def extract_search_filters(result_set): return filter_dict -def get_related_products(description, tags, attributes, category): - products_qs = Product.objects.filter( - Q(description=description) | - Q(tags__in=tags) | - Q(attributes__in=attributes) | - Q(category=category) - )[:10] - return products_qs +def get_related_products(product): + """Make different db searches until you get 10 instances to return + """ + total_results = [] + + # search by category + category_qs = Product.objects.filter(category=product.category)[:10] + # add to results + for item in category_qs: + total_results.append(item) + + # check size + if len(total_results) < 10: + # search by tags + tags_qs = Product.objects.filter(tags__in=product.tags.all())[:10] + # add to results + for item in tags_qs: + total_results.append(item) + + # check size + if len(total_results) < 10: + # search by coop + coop_qs = Product.objects.filter(company=product.company)[:10] + # add to results + for item in coop_qs: + total_results.append(item) + + # check size + if len(total_results) < 10: + # search by latest + latest_qs = Product.objects.order_by('-created')[:10] + # add to results + for item in coop_qs: + total_results.append(item) + + return total_results[:10] def ranked_product_search(keyword, shipping_cost=None, discount=None, category=None, tags=None, price_min=None,price_max=None): diff --git a/products/views.py b/products/views.py index 360b935..8103f67 100644 --- a/products/views.py +++ b/products/views.py @@ -59,7 +59,7 @@ class ProductViewSet(viewsets.ModelViewSet): """ # TODO: find the most similar products product = self.get_object() - qs = get_related_products(product.description, product.tags.all(), product.attributes.all(), product.category) + qs = get_related_products(product) serializer = self.serializer_class(qs, many=True) return Response(data=serializer.data)