This commit is contained in:
Sam
2021-02-19 12:34:15 +00:00
7 changed files with 212 additions and 28 deletions

View File

@@ -66,7 +66,7 @@ By default a `LimitOffsetPagination` pagination is enabled
Examples: `http://127.0.0.1:8000/api/v1/products/?limit=10&offset=0` Examples: `http://127.0.0.1:8000/api/v1/products/?limit=10&offset=0`
The response data has the following keyspayload: The response data has the following keys:
``` ```
dict_keys(['count', 'next', 'previous', 'results']) dict_keys(['count', 'next', 'previous', 'results'])
``` ```
@@ -294,4 +294,6 @@ To create a dataset of fake companies and products:
`python manage.py addtestdata` `python manage.py addtestdata`
Creates 10 Companies, with 100 products each. Creates 10 Companies, with 10 products each.
WARNING: the script deletes existing instances of both Company and Product

View File

@@ -42,6 +42,7 @@ INSTALLED_APPS = [
'django.contrib.messages', 'django.contrib.messages',
'django.contrib.staticfiles', 'django.contrib.staticfiles',
'django.contrib.gis', 'django.contrib.gis',
'django.contrib.postgres',
# 3rd party # 3rd party
'rest_framework', 'rest_framework',

View File

@@ -70,7 +70,10 @@ class Command(BaseCommand):
# TODO: apply automatic tags from tag list # TODO: apply automatic tags from tag list
# TODO: write image to S3 storage # TODO: write image to S3 storage
# create instance # create instance
product = ProductFactory(name=name, description=description) product = ProductFactory(
company=company,
name=name,
description=description)
# get image # get image
response = requests.get(self.logo_url, stream=True) response = requests.get(self.logo_url, stream=True)

View File

@@ -18,12 +18,13 @@ class ProductSerializer(TaggitSerializer, serializers.ModelSerializer):
exclude = ['created', 'updated', 'creator'] exclude = ['created', 'updated', 'creator']
class ProductSearchSerializer(TaggitSerializer, serializers.ModelSerializer): class SearchResultSerializer(TaggitSerializer, serializers.ModelSerializer):
tags = TagListSerializerField(required=False) tags = TagListSerializerField(required=False)
category = SingleTagSerializerField(required=False) # main tag category category = SingleTagSerializerField(required=False) # main tag category
attributes = TagListSerializerField(required=False) attributes = TagListSerializerField(required=False)
company = CompanySerializer(read_only=True) company = CompanySerializer(read_only=True)
rank = serializers.FloatField()
class Meta: class Meta:
model = Product model = Product

View File

@@ -13,6 +13,7 @@ from rest_framework import status
from companies.factories import CompanyFactory from companies.factories import CompanyFactory
from products.factories import ProductFactory from products.factories import ProductFactory
from products.models import Product from products.models import Product
from products.utils import find_related_products_v3
from core.factories import CustomUserFactory from core.factories import CustomUserFactory
from core.utils import get_tokens_for_user from core.utils import get_tokens_for_user
@@ -461,7 +462,7 @@ class ProductSearchTest(TestCase):
def test_anon_user_can_search(self): def test_anon_user_can_search(self):
expected_instances = [ expected_instances = [
self.factory(tags="lunares/blancos",description="zapatos verdes"), self.factory(tags="lunares/rojos", category='zapatos', description="zapatos verdes"),
self.factory(tags="colores/rojos, tono/brillante"), self.factory(tags="colores/rojos, tono/brillante"),
self.factory(tags="lunares/azules", description="zapatos rojos"), self.factory(tags="lunares/azules", description="zapatos rojos"),
self.factory(tags="lunares/rojos", description="zapatos"), self.factory(tags="lunares/rojos", description="zapatos"),
@@ -477,15 +478,51 @@ class ProductSearchTest(TestCase):
url = f"{self.endpoint}?query_string={query_string}" url = f"{self.endpoint}?query_string={query_string}"
# send in request # send in request
response = self.client.get(url) response = self.client.get(url)
payload = response.json()
# check response # check response
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
# load response data
payload = response.json()
# check for object creation # check for object creation
self.assertEquals(len(payload['products']), len(expected_instances)) self.assertEquals(len(payload['products']), len(expected_instances))
# check ids
for i in range(len(payload['products'])):
self.assertTrue(payload['products'][i]['id'] == expected_instances[i].id)
# check results ordered by rank
current = 1
for i in range(len(payload['products'])):
self.assertTrue(payload['products'][i]['rank'] <= current )
current = payload['products'][i]['rank']
# check for filters # check for filters
self.assertNotEquals([], payload['filters']['singles']) self.assertNotEquals([], payload['filters']['singles'])
self.assertTrue(len(payload['filters']) >= 2 ) self.assertTrue(len(payload['filters']) >= 2 )
def test_anon_user_can_paginate_search(self):
    """An anonymous search request with a `limit` parameter returns
    exactly `limit` products, even though more products match.
    """
    # products expected to match the query (more than one page worth)
    self.factory(tags="lunares/rojos", category='zapatos', description="zapatos verdes")
    self.factory(tags="colores/rojos, tono/brillante")
    self.factory(tags="lunares/azules", description="zapatos rojos")
    self.factory(tags="lunares/rojos", description="zapatos")
    self.factory(attributes='"zapatos de campo", tono/oscuro')
    # products not expected to match the query
    self.factory(description="chanclas")
    self.factory(tags="azules")

    query_string = quote("zapatos rojos")
    limit = 2
    # build the url from `limit` so the request and the assertion cannot drift apart
    url = f"{self.endpoint}?query_string={query_string}&limit={limit}"
    # send in request
    response = self.client.get(url)
    # check response
    self.assertEqual(response.status_code, 200)
    # load response data
    payload = response.json()
    # check the page size (assertEqual: the assertEquals alias is deprecated)
    self.assertEqual(len(payload['products']), limit)
class MyProductsViewTest(APITestCase): class MyProductsViewTest(APITestCase):
"""my_products tests """my_products tests
@@ -529,3 +566,46 @@ class MyProductsViewTest(APITestCase):
# check response # check response
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
class FindRelatedProductsTest(APITestCase):
    """Tests for the find_related_products_* search helpers
    """

    def setUp(self):
        """Tests setup
        """
        self.factory = ProductFactory
        self.model = Product
        # clear table
        self.model.objects.all().delete()

    def test_v3_find_by_tags(self):
        """find_related_products_v3 returns every product mentioning the
        keyword and none of the products that do not.
        """
        # create tagged products
        tag = 'cool'
        expected_instances = [
            self.factory(tags=tag),
            self.factory(tags=f'{tag} hat'),
            self.factory(tags=f'temperatures/{tag}'),
            self.factory(tags=f'temperatures/{tag}, body/hot'),
            self.factory(tags=f'temperatures/{tag}, hats/{tag}'),
            # multiple hits
            self.factory(tags=tag, attributes=tag),
            self.factory(tags=tag, attributes=tag, category=tag),
            self.factory(tags=tag, attributes=tag, category=tag, name=tag),
            self.factory(tags=tag, attributes=tag, category=tag, name=tag, description=tag),
        ]
        unexpected_instances = [
            self.factory(tags="notcool"),  # shouldn't catch it
            self.factory(tags="azules"),
        ]
        # search for it
        results = find_related_products_v3(tag)
        # assert result size; assertEqual reports both values on failure
        self.assertEqual(len(results), len(expected_instances))
        # the unrelated products must not be part of the results
        for instance in unexpected_instances:
            self.assertNotIn(instance, results)

View File

@@ -1,5 +1,11 @@
import logging import logging
from django.db.models import Q
from django.contrib.postgres.search import SearchQuery, SearchRank, SearchVector, TrigramSimilarity
from products.models import Product
def extract_search_filters(result_set): def extract_search_filters(result_set):
""" """
@@ -49,3 +55,85 @@ def extract_search_filters(result_set):
except Exception as e: except Exception as e:
logging.error(f'Extacting filters for {item}') logging.error(f'Extacting filters for {item}')
return filter_dict return filter_dict
def find_related_products_v1(keyword):
    """
    Classical approach to the search
    Using Q objects

    First collects the tag, category and attribute tag instances whose name
    contains `keyword` (case-insensitive), then selects the products whose
    name or description contains `keyword` or that reference one of them.

    Params:
    - keyword: text fragment to look for

    Returns a Product QuerySet. `distinct()` is applied because the OR-ed
    lookups join the tag relations, so a product matching through several
    tags would otherwise be returned once per matching row.
    """
    # search in tags
    tags = Product.tags.tag_model.objects.filter(name__icontains=keyword)
    # search in category
    categories = Product.category.tag_model.objects.filter(name__icontains=keyword)
    # search in attributes
    attributes = Product.attributes.tag_model.objects.filter(name__icontains=keyword)
    # unified tag search
    products_qs = Product.objects.filter(
        Q(name__icontains=keyword)
        | Q(description__icontains=keyword)
        | Q(tags__in=tags)
        | Q(category__in=categories)
        | Q(attributes__in=attributes)
    ).distinct()
    return products_qs
def find_related_products_v5(keyword):
    """
    Single query solution, using Q objects

    Matches `keyword` (case-insensitive containment) against the product
    name and description, the tag and attribute labels and the category
    name, all in one filter call.

    Returns a set of Product instances.
    """
    in_name = Q(name__icontains=keyword)
    in_description = Q(description__icontains=keyword)
    in_tags = Q(tags__label__icontains=keyword)
    in_category = Q(category__name__icontains=keyword)
    in_attributes = Q(attributes__label__icontains=keyword)
    # a hit in any of the fields is enough
    any_field = in_name | in_description | in_tags | in_category | in_attributes
    matches = Product.objects.filter(any_field)
    # the set collapses repeated rows for the same product
    return set(matches)
def find_related_products_v2(keyword):
    """
    More advanced: using search vectors

    Annotates every product with a single SearchVector built from its name,
    description, tag labels, attribute labels and category name, and keeps
    the rows whose vector matches `keyword`.

    Returns a set of Product instances.
    """
    searched_fields = (
        'name',
        'description',
        'tags__label',
        'attributes__label',
        'category__name',
    )
    annotated = Product.objects.annotate(search=SearchVector(*searched_fields))
    matches = annotated.filter(search=keyword)
    return set(matches)
def find_related_products_v3(keyword, min_rank=0.05):
    """
    Ranked product search
    SearchVectors for the fields
    SearchQuery for the value
    SearchRank for relevancy scoring and ranking

    Params:
    - keyword: value to search for
    - min_rank: results must rank strictly above this value (default 0.05)

    Returns an unordered set of Product instances, each carrying a `rank`
    attribute. Callers that need ordering must sort on `rank` themselves.
    """
    vector = (
        SearchVector('name')
        + SearchVector('description')
        + SearchVector('tags__label')
        + SearchVector('attributes__label')
        + SearchVector('category__name')
    )
    query = SearchQuery(keyword)
    products_qs = Product.objects.annotate(
        rank=SearchRank(vector, query)
    ).filter(rank__gt=min_rank)  # no order_by: ordering is lost in the set
    # The joins on the tag relations can yield several rows for one product,
    # each with its own rank. Model instances compare equal by primary key,
    # so a plain set() would keep an arbitrary one of them; keep the
    # best-ranked row for every product instead.
    best_by_pk = {}
    for product in products_qs:
        kept = best_by_pk.get(product.pk)
        if kept is None or product.rank > kept.rank:
            best_by_pk[product.pk] = product
    return set(best_by_pk.values())
def find_related_products_v4(keyword):
    """
    Similarity-ranked search using trigrams
    Not working

    Annotates every product with the trigram similarity between its name
    and `keyword` and orders by it, highest first. No filter is applied,
    and the ordering does not survive the conversion to a set.

    Returns a set of Product instances carrying a `similarity` attribute.
    """
    # only 'name' is compared; the other searchable fields
    # ('description', 'tags__label', 'attributes__label', 'category__name')
    # are not used here
    name_similarity = TrigramSimilarity('name', keyword)
    ranked = Product.objects.annotate(similarity=name_similarity)
    ranked = ranked.order_by('-similarity')
    return set(ranked)

View File

@@ -7,6 +7,7 @@ from functools import reduce
from django.shortcuts import render from django.shortcuts import render
from django.conf import settings from django.conf import settings
from django.db.models import Q from django.db.models import Q
from django.core import serializers
# Create your views here. # Create your views here.
from rest_framework import status from rest_framework import status
@@ -18,12 +19,12 @@ from rest_framework.decorators import api_view, permission_classes, action
import requests import requests
from products.models import Product from products.models import Product
from products.serializers import ProductSerializer, TagFilterSerializer, ProductSearchSerializer from products.serializers import ProductSerializer, TagFilterSerializer, SearchResultSerializer
from companies.models import Company from companies.models import Company
from history.models import HistorySync from history.models import HistorySync
from back_latienda.permissions import IsCreator from back_latienda.permissions import IsCreator
from .utils import extract_search_filters from .utils import extract_search_filters, find_related_products_v3
from utils.tag_serializers import TaggitSerializer from utils.tag_serializers import TaggitSerializer
from utils.tag_filters import ProductTagFilter from utils.tag_filters import ProductTagFilter
@@ -144,8 +145,14 @@ def load_coop_products(request):
def product_search(request): def product_search(request):
""" """
Takes a string of data, return relevant products Takes a string of data, return relevant products
Params:
- query_string: used for search [MANDATORY]
- limit: max number of returned instances [OPTIONAL]
- offset: where to start counting results [OPTIONAL]
""" """
query_string = request.GET.get('query_string', None) query_string = request.GET.get('query_string', None)
if query_string is None: if query_string is None:
return Response({"errors": {"details": "No query string to parse"}}) return Response({"errors": {"details": "No query string to parse"}})
try: try:
@@ -155,27 +162,29 @@ def product_search(request):
chunks = query_string.split(' ') chunks = query_string.split(' ')
for chunk in chunks: for chunk in chunks:
# search in tags product_set = find_related_products_v3(chunk)
tags = Product.tags.tag_model.objects.filter(name__icontains=chunk) # add to result set
# search in category result_set.update(product_set)
categories = Product.category.tag_model.objects.filter(name__icontains=chunk) # TODO: add search for entire phrase
# search in attributes
attributes = Product.attributes.tag_model.objects.filter(name__icontains=chunk)
# unified tag search
products_qs = Product.objects.filter(
Q(name__icontains=chunk)|
Q(description__icontains=chunk)|
Q(tags__in=tags)|
Q(category__in=categories)|
Q(attributes__in=attributes)
)
for instance in products_qs:
result_set.add(instance)
# extract filters from result_set # extract filters from result_set
filters = extract_search_filters(result_set) filters = extract_search_filters(result_set)
# serialize and respond # order results and respond
product_serializer = ProductSearchSerializer(result_set, many=True, context={'request': request}) result_list = list(result_set)
return Response(data={"filters": filters, "products": product_serializer.data}) ranked_products = sorted(result_list, key= lambda rank:rank.rank, reverse=True)
serializer = SearchResultSerializer(ranked_products, many=True)
product_results = [dict(i) for i in serializer.data]
# check for pagination
limit = request.GET.get('limit', None)
offset = request.GET.get('offset', None)
if limit is not None and offset is not None:
limit = int(limit)
offset = int(offset)
product_results = product_results[offset:(limit+offset)]
elif limit is not None:
limit = int(limit)
product_results = product_results[:limit]
return Response(data={"filters": filters, "products": product_results})
except Exception as e: except Exception as e:
return Response({"errors": {"details": str(type(e))}}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) return Response({"errors": {"details": str(e)}}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)