diff --git a/README.md b/README.md index bcceed1..7c2ba1d 100644 --- a/README.md +++ b/README.md @@ -66,7 +66,7 @@ By default a `LimitOffsetPagination` pagination is enabled Examples: `http://127.0.0.1:8000/api/v1/products/?limit=10&offset=0` -The response data has the following keyspayload: +The response data has the following keys: ``` dict_keys(['count', 'next', 'previous', 'results']) ``` @@ -294,4 +294,6 @@ To create a dataset of fake companies and products: `python manage.py addtestdata` -Creates 10 Companies, with 100 products each. +Creates 10 Companies, with 10 products each. + +WARNING: the script deletes existing instances of both Company and Product diff --git a/back_latienda/settings/base.py b/back_latienda/settings/base.py index afc4c78..ccb5e4b 100644 --- a/back_latienda/settings/base.py +++ b/back_latienda/settings/base.py @@ -42,6 +42,7 @@ INSTALLED_APPS = [ 'django.contrib.messages', 'django.contrib.staticfiles', 'django.contrib.gis', + 'django.contrib.postgres', # 3rd party 'rest_framework', diff --git a/core/management/commands/addtestdata.py b/core/management/commands/addtestdata.py index 7d98462..35d7bf9 100644 --- a/core/management/commands/addtestdata.py +++ b/core/management/commands/addtestdata.py @@ -70,7 +70,10 @@ class Command(BaseCommand): # TODO: apply automatic tags from tag list # TODO: write image to S3 storage # create instance - product = ProductFactory(name=name, description=description) + product = ProductFactory( + company=company, + name=name, + description=description) # get image response = requests.get(self.logo_url, stream=True) diff --git a/products/serializers.py b/products/serializers.py index c2fa24c..6bd7497 100644 --- a/products/serializers.py +++ b/products/serializers.py @@ -18,12 +18,13 @@ class ProductSerializer(TaggitSerializer, serializers.ModelSerializer): exclude = ['created', 'updated', 'creator'] -class ProductSearchSerializer(TaggitSerializer, serializers.ModelSerializer): +class 
SearchResultSerializer(TaggitSerializer, serializers.ModelSerializer): tags = TagListSerializerField(required=False) category = SingleTagSerializerField(required=False) # main tag category attributes = TagListSerializerField(required=False) company = CompanySerializer(read_only=True) + rank = serializers.FloatField() class Meta: model = Product diff --git a/products/tests.py b/products/tests.py index a71db66..46f57c7 100644 --- a/products/tests.py +++ b/products/tests.py @@ -13,6 +13,7 @@ from rest_framework import status from companies.factories import CompanyFactory from products.factories import ProductFactory from products.models import Product +from products.utils import find_related_products_v3 from core.factories import CustomUserFactory from core.utils import get_tokens_for_user @@ -461,7 +462,7 @@ class ProductSearchTest(TestCase): def test_anon_user_can_search(self): expected_instances = [ - self.factory(tags="lunares/blancos",description="zapatos verdes"), + self.factory(tags="lunares/rojos", category='zapatos', description="zapatos verdes"), self.factory(tags="colores/rojos, tono/brillante"), self.factory(tags="lunares/azules", description="zapatos rojos"), self.factory(tags="lunares/rojos", description="zapatos"), @@ -477,15 +478,51 @@ class ProductSearchTest(TestCase): url = f"{self.endpoint}?query_string={query_string}" # send in request response = self.client.get(url) - payload = response.json() + # check response self.assertEqual(response.status_code, 200) + # load response data + payload = response.json() # check for object creation self.assertEquals(len(payload['products']), len(expected_instances)) + # check ids + for i in range(len(payload['products'])): + self.assertTrue(payload['products'][i]['id'] == expected_instances[i].id) + # check results ordered by rank + current = 1 + for i in range(len(payload['products'])): + self.assertTrue(payload['products'][i]['rank'] <= current ) + current = payload['products'][i]['rank'] # check for filters 
self.assertNotEquals([], payload['filters']['singles']) self.assertTrue(len(payload['filters']) >= 2 ) + def test_anon_user_can_paginate_search(self): + expected_instances = [ + self.factory(tags="lunares/rojos", category='zapatos', description="zapatos verdes"), + self.factory(tags="colores/rojos, tono/brillante"), + self.factory(tags="lunares/azules", description="zapatos rojos"), + self.factory(tags="lunares/rojos", description="zapatos"), + self.factory(attributes='"zapatos de campo", tono/oscuro'), + ] + unexpected_instances = [ + self.factory(description="chanclas"), + self.factory(tags="azules"), + ] + + query_string = quote("zapatos rojos") + limit = 2 + + url = f"{self.endpoint}?query_string={query_string}&limit=2" + # send in request + response = self.client.get(url) + + # check response + self.assertEqual(response.status_code, 200) + # load response data + payload = response.json() + # check for object creation + self.assertEquals(len(payload['products']), limit) class MyProductsViewTest(APITestCase): """my_products tests @@ -529,3 +566,46 @@ class MyProductsViewTest(APITestCase): # check response self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) + + +class FindRelatedProductsTest(APITestCase): + + def setUp(self): + """Tests setup + """ + self.factory = ProductFactory + self.model = Product + # clear table + self.model.objects.all().delete() + + def test_v3_find_by_tags(self): + # create tagged product + tag = 'cool' + expected_instances = [ + self.factory(tags=tag), + self.factory(tags=f'{tag} hat'), + self.factory(tags=f'temperatures/{tag}'), + self.factory(tags=f'temperatures/{tag}, body/hot'), + self.factory(tags=f'temperatures/{tag}, hats/{tag}'), + # multiple hits + self.factory(tags=tag, attributes=tag), + self.factory(tags=tag, attributes=tag, category=tag), + self.factory(tags=tag, attributes=tag, category=tag, name=tag), + self.factory(tags=tag, attributes=tag, category=tag, name=tag, description=tag), + ] + + 
unexpected_instances = [ + self.factory(tags="notcool"), # shouldn't catch it + self.factory(tags="azules"), + ] + + # search for it + results = find_related_products_v3(tag) + + # assert result + self.assertTrue(len(results) == len(expected_instances)) + + + + + diff --git a/products/utils.py b/products/utils.py index 7a08a70..2b6f8e0 100644 --- a/products/utils.py +++ b/products/utils.py @@ -1,5 +1,11 @@ import logging +from django.db.models import Q +from django.contrib.postgres.search import SearchQuery, SearchRank, SearchVector, TrigramSimilarity + +from products.models import Product + + def extract_search_filters(result_set): """ @@ -49,3 +55,85 @@ def extract_search_filters(result_set): except Exception as e: logging.error(f'Extacting filters for {item}') return filter_dict + + +def find_related_products_v1(keyword): + """ + Classical approach to the search + + Using Q objects + + """ + # search in tags + tags = Product.tags.tag_model.objects.filter(name__icontains=keyword) + # search in category + categories = Product.category.tag_model.objects.filter(name__icontains=keyword) + # search in attributes + attributes = Product.attributes.tag_model.objects.filter(name__icontains=keyword) + # unified tag search + products_qs = Product.objects.filter( + Q(name__icontains=keyword)| + Q(description__icontains=keyword)| + Q(tags__in=tags)| + Q(category__in=categories)| + Q(attributes__in=attributes) + ) + return products_qs + + +def find_related_products_v5(keyword): + """ + Single query solution, using Q objects + """ + products_qs = Product.objects.filter( + Q(name__icontains=keyword)| + Q(description__icontains=keyword)| + Q(tags__label__icontains=keyword)| + Q(category__name__icontains=keyword)| + Q(attributes__label__icontains=keyword) + ) + return set(products_qs) + + +def find_related_products_v2(keyword): + """ + More advanced: using search vectors + """ + fields=('name', 'description', 'tags__label', 'attributes__label', 'category__name') + vector = 
SearchVector(*fields) + products_qs = Product.objects.annotate( + search=vector + ).filter(search=keyword) + return set(products_qs) + + +def find_related_products_v3(keyword): + """ + Ranked product search + + SearchVectors for the fields + SearchQuery for the value + SearchRank for relevancy scoring and ranking + """ + vector = SearchVector('name') + SearchVector('description') + SearchVector('tags__label') + SearchVector('attributes__label') + SearchVector('category__name') + query = SearchQuery(keyword) + + products_qs = Product.objects.annotate( + rank=SearchRank(vector, query) + ).filter(rank__gt=0.05) # removed order_by because ordering is lost when casting to a set + + return set(products_qs) + + +def find_related_products_v4(keyword): + """ + Similarity-ranked search using trigrams + Not working + """ + # fields=('name', 'description', 'tags__label', 'attributes__label', 'category__name') + + products_qs = Product.objects.annotate( + similarity=TrigramSimilarity('name', keyword), + ).order_by('-similarity') + + return set(products_qs) diff --git a/products/views.py b/products/views.py index 61a332e..a54ea68 100644 --- a/products/views.py +++ b/products/views.py @@ -7,6 +7,7 @@ from functools import reduce from django.shortcuts import render from django.conf import settings from django.db.models import Q +from django.core import serializers # Create your views here. 
from rest_framework import status @@ -18,12 +19,12 @@ from rest_framework.decorators import api_view, permission_classes, action import requests from products.models import Product -from products.serializers import ProductSerializer, TagFilterSerializer, ProductSearchSerializer +from products.serializers import ProductSerializer, TagFilterSerializer, SearchResultSerializer from companies.models import Company from history.models import HistorySync from back_latienda.permissions import IsCreator -from .utils import extract_search_filters +from .utils import extract_search_filters, find_related_products_v3 from utils.tag_serializers import TaggitSerializer from utils.tag_filters import ProductTagFilter @@ -144,8 +145,14 @@ def load_coop_products(request): def product_search(request): """ Takes a string of data, return relevant products + + Params: + - query_string: used for search [MANDATORY] + - limit: max number of returned instances [OPTIONAL] + - offset: where to start counting results [OPTIONAL] """ query_string = request.GET.get('query_string', None) + if query_string is None: return Response({"errors": {"details": "No query string to parse"}}) try: @@ -155,27 +162,29 @@ def product_search(request): chunks = query_string.split(' ') for chunk in chunks: - # search in tags - tags = Product.tags.tag_model.objects.filter(name__icontains=chunk) - # search in category - categories = Product.category.tag_model.objects.filter(name__icontains=chunk) - # search in attributes - attributes = Product.attributes.tag_model.objects.filter(name__icontains=chunk) - # unified tag search - products_qs = Product.objects.filter( - Q(name__icontains=chunk)| - Q(description__icontains=chunk)| - Q(tags__in=tags)| - Q(category__in=categories)| - Q(attributes__in=attributes) - ) - for instance in products_qs: - result_set.add(instance) + product_set = find_related_products_v3(chunk) + # add to result set + result_set.update(product_set) + # TODO: add search for entire phrase # extract 
filters from result_set filters = extract_search_filters(result_set) - # serialize and respond - product_serializer = ProductSearchSerializer(result_set, many=True, context={'request': request}) - return Response(data={"filters": filters, "products": product_serializer.data}) + # order results and respond + result_list = list(result_set) + ranked_products = sorted(result_list, key= lambda rank:rank.rank, reverse=True) + serializer = SearchResultSerializer(ranked_products, many=True) + product_results = [dict(i) for i in serializer.data] + # check for pagination + limit = request.GET.get('limit', None) + offset = request.GET.get('offset', None) + if limit is not None and offset is not None: + limit = int(limit) + offset = int(offset) + product_results = product_results[offset:(limit+offset)] + elif limit is not None: + limit = int(limit) + product_results = product_results[:limit] + + return Response(data={"filters": filters, "products": product_results}) except Exception as e: - return Response({"errors": {"details": str(type(e))}}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + return Response({"errors": {"details": str(e)}}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)