Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
wd0517 committed Mar 17, 2024
1 parent 5ca5ef1 commit 46f73a0
Show file tree
Hide file tree
Showing 14 changed files with 171 additions and 9 deletions.
47 changes: 47 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,50 @@ jobs:
run: tox
env:
DJANGO_VERSION: ${{ matrix.django-version }}

# CI job that runs only the vector test app (tests/tidb_vector) against a
# remote TiDB cluster whose host, user and password come from repository secrets.
# It runs for every Python version in the matrix against Django 4.2.5.
# The "Hack for vector tests" step patches django_test_suite.sh in place:
#   * after line 27 it inserts a `cp` of ./tests/tidb_vector/ into the Django test tree;
#   * after line 31 (counted after the first insertion) it inserts `pip install numpy~=1.0`.
# It then overwrites django_test_apps.txt so that it lists only `tidb_vector`.
# NOTE(review): the sed line numbers are tied to the current layout of
# django_test_suite.sh; re-check them whenever that script is edited.
vector-tests:
strategy:
fail-fast: false
matrix:
python-version:
- '3.8'
- '3.9'
- '3.10'
- '3.11'
django-version:
- '4.2.5'

name: vector-py${{ matrix.python-version }}_django${{ matrix.django-version }}
runs-on: ubuntu-latest
env:
TIDB_HOST: ${{ secrets.SERVRLESS_TEST_TIDB_HOST }}
TIDB_USER: ${{ secrets.SERVRLESS_TEST_TIDB_USER }}
TIDB_PASSWORD: ${{ secrets.SERVRLESS_TEST_TIDB_PASSWORD }}
steps:
- name: Checkout
uses: actions/checkout@v3

- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}

- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install tox tox-gh-actions
sudo apt-get update
sudo apt-get install -y libmemcached-dev zlib1g-dev
- name: Hack for vector tests
run: |
sed '27a cp -r ./tests/tidb_vector/ $DJANGO_TESTS_DIR/django/tests/tidb_vector/' -i django_test_suite.sh
sed '31a pip install numpy~=1.0' -i django_test_suite.sh
cat django_test_suite.sh
echo "tidb_vector" > django_test_apps.txt
cat django_test_apps.txt
- name: Run tests
run: tox
env:
DJANGO_VERSION: ${{ matrix.django-version }}
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ SECRET_KEY = 'django_tests_secret_key'

- [AUTO_RANDOM](#using-auto_random)
- [AUTO_ID_CACHE](#using-auto_id_cache)
- [VectorField (Beta)](#vectorfield-beta)
- [Vector (Beta)](#vector-beta)

### Using `AUTO_RANDOM`

Expand Down Expand Up @@ -137,7 +137,7 @@ But there are some limitations:
- `tidb_auto_id_cache` can only affect the table creation, after that it will be ignored even if you change it.
- `tidb_auto_id_cache` only affects the `AUTO_INCREMENT` column.

### VectorField (Beta)
### Vector (Beta)

Currently, only TiDB Cloud Serverless clusters support the vector data type; see [Integrating Vector Search into TiDB Serverless for AI Applications](https://www.pingcap.com/blog/integrating-vector-search-into-tidb-for-ai-applications/).

Expand Down
2 changes: 1 addition & 1 deletion django_test_suite.sh
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pip3 install -e .
git clone --depth 1 --branch $DJANGO_VERSION https://github.com/django/django.git $DJANGO_TESTS_DIR/django
cp tidb_settings.py $DJANGO_TESTS_DIR/django/tidb_settings.py
cp tidb_settings.py $DJANGO_TESTS_DIR/django/tests/tidb_settings.py
cp -r ./tests/ $DJANGO_TESTS_DIR/django/tests/tidb/
cp -r ./tests/tidb/ $DJANGO_TESTS_DIR/django/tests/tidb/

cd $DJANGO_TESTS_DIR/django && pip3 install -e . && pip3 install -r tests/requirements/py3.txt && pip3 install -r tests/requirements/mysql.txt; cd ../../
cd $DJANGO_TESTS_DIR/django/tests
Expand Down
25 changes: 23 additions & 2 deletions django_tidb/fields/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,27 @@ class VectorField(Field):
Status: Beta
Info: https://www.pingcap.com/blog/integrating-vector-search-into-tidb-for-ai-applications/
Example:
```python
from django.db import models
from django_tidb.fields.vector import VectorField, CosineDistance
class Document(models.Model):
content = models.TextField()
embedding = VectorField(dimensions=3)
# Create a document
Document.objects.create(
content="test content",
embedding=[1, 2, 3],
)
# Query with distance
Document.objects.alias(
distance=CosineDistance('embedding', [3, 1, 2])
).filter(distance__lt=5)
```
"""

description = "Vector"
Expand All @@ -62,8 +83,8 @@ def deconstruct(self):

def db_type(self, connection):
if self.dimensions is None:
return "vector"
return "vector(%d)" % self.dimensions
return "vector<float>"
return "vector<float>(%d)" % self.dimensions

# Converts the raw column value read from the database into the Python-side
# value by delegating to decode_vector (defined outside this excerpt).
# `expression` and `connection` are accepted but not used here.
def from_db_value(self, value, expression, connection):
return decode_vector(value)
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Empty file added tests/tidb_vector/__init__.py
Empty file.
13 changes: 13 additions & 0 deletions tests/tidb_vector/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from django.db import models

from django_tidb.fields.vector import VectorField


class Document(models.Model):
    # Test model: free-form text plus an embedding column declared
    # without an explicit `dimensions` argument.
    content = models.TextField()
    embedding = VectorField()


class DocumentExplicitDimension(models.Model):
    # Test model: same shape as Document, but the embedding column is
    # declared with a fixed size of three dimensions.
    content = models.TextField()
    embedding = VectorField(dimensions=3)
79 changes: 79 additions & 0 deletions tests/tidb_vector/test_vector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import numpy as np
from math import sqrt
from django.db.utils import OperationalError
from django.test import TestCase
from django_tidb.fields.vector import (
CosineDistance, L1Distance, L2Distance, NegativeInnerProduct
)

from .models import Document, DocumentExplicitDimension


class TiDBVectorFieldTests(TestCase):
    """Round-trip and distance-function tests for ``VectorField``.

    ``model`` is a class attribute so that subclasses can re-run the same
    tests against a different model by overriding it.
    """

    model = Document

    def assertDistancesAlmostEqual(self, actual, expected, places=6):
        """Assert two distance sequences match element-wise within tolerance.

        Distances come back from the database as floating-point values, so
        comparing them with exact equality is brittle; the lengths are
        compared exactly and each pair with ``assertAlmostEqual``.
        """
        self.assertEqual(len(actual), len(expected))
        for got, want in zip(actual, expected):
            self.assertAlmostEqual(got, want, places=places)

    def test_create_get(self):
        """A stored vector is read back as a float32 numpy array."""
        obj = self.model.objects.create(
            content="test content",
            embedding=[1, 2, 3],
        )
        obj = self.model.objects.get(pk=obj.pk)
        self.assertTrue(np.array_equal(obj.embedding, np.array([1, 2, 3])))
        self.assertEqual(obj.embedding.dtype, np.float32)

    def test_get_with_different_dimension(self):
        """Comparing against a vector of another dimension raises an error."""
        self.model.objects.create(
            content="test content",
            embedding=[1, 2, 3],
        )
        with self.assertRaises(OperationalError) as cm:
            list(
                self.model.objects.annotate(
                    distance=CosineDistance('embedding', [3, 1, 2, 4])
                ).values_list('distance', flat=True)
            )
        self.assertIn('vectors have different dimensions', str(cm.exception))

    def create_documents(self):
        """Insert three documents whose content is their 1-based index."""
        vectors = [
            [1, 1, 1],
            [2, 2, 2],
            [1, 1, 2],
        ]
        for i, v in enumerate(vectors):
            self.model.objects.create(
                content=f"{i + 1}",
                embedding=v,
            )

    def test_l1_distance(self):
        self.create_documents()
        distance = L1Distance('embedding', [1, 1, 1])
        docs = list(
            self.model.objects.annotate(distance=distance).order_by('distance')
        )
        self.assertEqual([d.content for d in docs], ['1', '3', '2'])
        self.assertEqual([d.distance for d in docs], [0, 1, 3])

    def test_l2_distance(self):
        self.create_documents()
        distance = L2Distance('embedding', [1, 1, 1])
        docs = list(
            self.model.objects.annotate(distance=distance).order_by('distance')
        )
        self.assertEqual([d.content for d in docs], ['1', '3', '2'])
        # sqrt(3) is irrational: compare with a tolerance, not exact equality.
        self.assertDistancesAlmostEqual(
            [d.distance for d in docs], [0, 1, sqrt(3)]
        )

    def test_cosine_distance(self):
        self.create_documents()
        distance = CosineDistance('embedding', [1, 1, 1])
        # [1, 1, 1] and [2, 2, 2] are both at distance 0 from the query, so
        # ordering by distance alone leaves their relative order unspecified.
        # `content` is added as a deterministic tie-breaker.
        docs = list(
            self.model.objects.annotate(distance=distance).order_by(
                'distance', 'content'
            )
        )
        self.assertEqual([d.content for d in docs], ['1', '2', '3'])
        self.assertDistancesAlmostEqual(
            [d.distance for d in docs], [0, 0, 0.05719095841793653]
        )

    def test_negative_inner_product(self):
        self.create_documents()
        distance = NegativeInnerProduct('embedding', [1, 1, 1])
        docs = list(
            self.model.objects.annotate(distance=distance).order_by('distance')
        )
        self.assertEqual([d.content for d in docs], ['2', '3', '1'])
        self.assertEqual([d.distance for d in docs], [-6, -4, -3])


class TiDBVectorFieldExplicitDimensionTests(TiDBVectorFieldTests):
    # Re-runs every inherited test against the model whose embedding
    # column is declared with dimensions=3.
    model = DocumentExplicitDimension
10 changes: 6 additions & 4 deletions tidb_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@

# Connection parameters are read from the environment so the same settings
# file can target a non-local cluster; the defaults are 127.0.0.1:4000 with
# user "root" and an empty password.
# NOTE: os.getenv returns a str when TIDB_PORT is set, while the fallback
# is the int 4000, so `port` may be either type.
hosts = os.getenv("TIDB_HOST", "127.0.0.1")
port = os.getenv("TIDB_PORT", 4000)
user = os.getenv("TIDB_USER", "root")
password = os.getenv("TIDB_PASSWORD", "")

DATABASES = {
"default": {
"ENGINE": "django_tidb",
"USER": "root",
"PASSWORD": "",
"USER": user,
"PASSWORD": password,
"HOST": hosts,
"PORT": port,
"TEST": {
Expand All @@ -31,8 +33,8 @@
},
"other": {
"ENGINE": "django_tidb",
"USER": "root",
"PASSWORD": "",
"USER": user,
"PASSWORD": password,
"HOST": hosts,
"PORT": port,
"TEST": {
Expand Down

0 comments on commit 46f73a0

Please sign in to comment.