diff --git a/edgedb/protocol/codecs/codecs.pyx b/edgedb/protocol/codecs/codecs.pyx index d0cfa7b5..da40e283 100644 --- a/edgedb/protocol/codecs/codecs.pyx +++ b/edgedb/protocol/codecs/codecs.pyx @@ -17,6 +17,7 @@ # +import array import decimal import uuid import datetime @@ -24,6 +25,8 @@ from edgedb import describe from edgedb import enums from edgedb.datatypes import datatypes +from libc.string cimport memcpy + include "./edb_types.pxi" @@ -347,14 +350,16 @@ cdef dict BASE_SCALAR_CODECS = {} cdef register_base_scalar_codec( str name, pgproto.encode_func encoder, - pgproto.decode_func decoder): + pgproto.decode_func decoder, + object tid = None): cdef: BaseCodec codec - tid = TYPE_IDS.get(name) if tid is None: - raise RuntimeError(f'cannot find known ID for type {name!r}') + tid = TYPE_IDS.get(name) + if tid is None: + raise RuntimeError(f'cannot find known ID for type {name!r}') tid = tid.bytes if tid in BASE_SCALAR_CODECS: @@ -510,6 +515,94 @@ cdef config_memory_decode(pgproto.CodecContext settings, FRBuffer *buf): return datatypes.ConfigMemory(bytes=bytes) +DEF PGVECTOR_MAX_DIM = (1 << 16) - 1 + + +cdef pgvector_encode_memview(pgproto.CodecContext settings, WriteBuffer buf, + float[:] obj): + cdef: + float item + Py_ssize_t objlen + Py_ssize_t i + + objlen = len(obj) + if objlen > PGVECTOR_MAX_DIM: + raise ValueError('too many elements in vector value') + + buf.write_int32(4 + objlen*4) + buf.write_int16(objlen) + buf.write_int16(0) + for i in range(objlen): + buf.write_float(obj[i]) + + +cdef pgvector_encode(pgproto.CodecContext settings, WriteBuffer buf, + object obj): + cdef: + float item + Py_ssize_t objlen + float[:] memview + Py_ssize_t i + + # If we can take a typed memview of the object, we use that. + # That is good, because it means we can consume array.array and + # numpy.ndarray without needing to unbox. + # Otherwise we take the slow path, indexing into the array using + # the normal protocol. + try: + memview = obj + except (ValueError, TypeError) as e: + pass + else: + pgvector_encode_memview(settings, buf, memview) + return + + if not _is_array_iterable(obj): + raise TypeError( + 'a sized iterable container expected (got type {!r})'.format( + type(obj).__name__)) + + # Annoyingly, this is literally identical code to the fast path... + # but the types are different in critical ways. + objlen = len(obj) + if objlen > PGVECTOR_MAX_DIM: + raise ValueError('too many elements in vector value') + + buf.write_int32(4 + objlen*4) + buf.write_int16(objlen) + buf.write_int16(0) + for i in range(objlen): + buf.write_float(obj[i]) + + +cdef object ONE_EL_ARRAY = array.array('f', [0.0]) + + +cdef pgvector_decode(pgproto.CodecContext settings, FRBuffer *buf): + cdef: + int32_t dim + Py_ssize_t size + Py_buffer view + char *p + float[:] array_view + + dim = hton.unpack_uint16(frb_read(buf, 2)) + frb_read(buf, 2) + + size = dim * 4 + p = frb_read(buf, size) + + # Create a float array with size dim + val = ONE_EL_ARRAY * dim + + # And fill it with the buffer contents + array_view = val + memcpy(&array_view[0], p, size) + val.byteswap() + + return val + + cdef checked_decimal_encode( pgproto.CodecContext settings, WriteBuffer buf, obj ): @@ -712,4 +805,12 @@ cdef register_base_scalar_codecs(): pgproto.text_encode, pgproto.text_decode) + register_base_scalar_codec( + 'ext::pgvector::vector', + pgvector_encode, + pgvector_decode, + uuid.UUID('9565dd88-04f5-11ee-a691-0b6ebe179825'), + ) + + register_base_scalar_codecs() diff --git a/tests/test_vector.py b/tests/test_vector.py new file mode 100644 index 00000000..a96119d3 --- /dev/null +++ b/tests/test_vector.py @@ -0,0 +1,141 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2019-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from edgedb import _testbase as tb +import edgedb + +import array + + +# An array.array subtype where indexing doesn't work. +# We use this to verify that the non-boxing memoryview based +# fast path works, since the slow path won't work on this object. +class brokenarray(array.array): + def __getitem__(self, i): + raise AssertionError("the fast path wasn't used!") + + +class TestVector(tb.SyncQueryTestCase): + def setUp(self): + super().setUp() + + if not self.client.query_required_single(''' + select exists ( + select sys::ExtensionPackage filter .name = 'vector' + ) + '''): + self.skipTest("feature not implemented") + + self.client.execute(''' + create extension vector version '1.0' + ''') + + def tearDown(self): + try: + self.client.execute(''' + drop extension vector version '1.0' + ''') + finally: + super().tearDown() + + async def test_vector_01(self): + # if not self.client.query_required_single(''' + # select exists ( + # select sys::ExtensionPackage filter .name = 'vector' + # ) + # '''): + # self.skipTest("feature not implemented") + + # self.client.execute(''' + # create extension vector version '1.0' + # ''') + + val = self.client.query_single(''' + select '[1.5,2.0,3.8]' + ''') + self.assertTrue(isinstance(val, array.array)) + self.assertEqual(val, array.array('f', [1.5, 2.0, 3.8])) + + val = self.client.query_single( + ''' + select $0 + ''', + [3.0, 9.0, -42.5], + ) + self.assertEqual(val, '[3,9,-42.5]') + + val = self.client.query_single( + ''' + select $0 + ''', + array.array('f', [3.0, 9.0, -42.5]) + ) + self.assertEqual(val, '[3,9,-42.5]') + + val = self.client.query_single( + ''' + select $0 + ''', + array.array('i', [1, 2, 3]), + ) + self.assertEqual(val, '[1,2,3]') + + # Test that the fast-path works: if the encoder tries to + # call __getitem__ on this brokenarray, it will fail. + val = self.client.query_single( + ''' + select $0 + ''', + brokenarray('f', [3.0, 9.0, -42.5]) + ) + self.assertEqual(val, '[3,9,-42.5]') + + # I don't think it's worth adding a dependency to test this, + # but this works too: + # import numpy as np + # val = self.client.query_single( + # ''' + # select $0 + # ''', + # np.asarray([3.0, 9.0, -42.5], dtype=np.float32), + # ) + + # Some sad path tests + with self.assertRaises(edgedb.InvalidArgumentError): + self.client.query_single( + ''' + select $0 + ''', + [3.0, None, -42.5], + ) + + with self.assertRaises(edgedb.InvalidArgumentError): + self.client.query_single( + ''' + select $0 + ''', + [3.0, 'x', -42.5], + ) + + with self.assertRaises(edgedb.InvalidArgumentError): + self.client.query_single( + ''' + select $0 + ''', + 'foo', + )