Skip to content

Commit

Permalink
Added pgx package
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Jul 23, 2024
1 parent 1d5ab1b commit f83a210
Show file tree
Hide file tree
Showing 6 changed files with 206 additions and 58 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
## 0.2.1 (unreleased)

- Added `pgx` package
- Added `EncodeBinary` and `DecodeBinary` methods to `Vector`
- Added `EncodeBinary` and `DecodeBinary` methods to `SparseVector`

Expand Down
60 changes: 2 additions & 58 deletions examples/loading_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@ package pgvector_test

import (
"context"
"database/sql/driver"
"fmt"
"math/rand"
"testing"

"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/pgvector/pgvector-go"
pgxvector "github.com/pgvector/pgvector-go/pgx"
)

func TestLoading(t *testing.T) {
Expand Down Expand Up @@ -40,7 +39,7 @@ func TestLoading(t *testing.T) {
panic(err)
}

err = RegisterType(ctx, conn)
err = pgxvector.RegisterTypes(ctx, conn)
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -108,58 +107,3 @@ func TestLoading(t *testing.T) {
panic(err)
}
}

type VectorCodec struct{}

func (VectorCodec) FormatSupported(format int16) bool {
return format == pgx.BinaryFormatCode
}

func (VectorCodec) PreferredFormat() int16 {
return pgx.BinaryFormatCode
}

func (VectorCodec) PlanEncode(m *pgtype.Map, oid uint32, format int16, value any) pgtype.EncodePlan {
_, ok := value.(pgvector.Vector)
if !ok {
return nil
}

if format == pgx.BinaryFormatCode {
return encodePlanVectorCodecBinary{}
}

return nil
}

type encodePlanVectorCodecBinary struct{}

func (encodePlanVectorCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) {
v := value.(pgvector.Vector)
return v.EncodeBinary(buf)
}

func (VectorCodec) PlanScan(m *pgtype.Map, oid uint32, format int16, target any) pgtype.ScanPlan {
return nil
}

func (VectorCodec) DecodeDatabaseSQLValue(m *pgtype.Map, oid uint32, format int16, src []byte) (driver.Value, error) {
return nil, fmt.Errorf("Not implemented")
}

func (VectorCodec) DecodeValue(m *pgtype.Map, oid uint32, format int16, src []byte) (any, error) {
return nil, fmt.Errorf("Not implemented")
}

func RegisterType(ctx context.Context, conn *pgx.Conn) error {
name := "vector"
var oid uint32
err := conn.QueryRow(ctx, "SELECT oid FROM pg_type WHERE typname = $1", name).Scan(&oid)
if err != nil {
return err
}
codec := &VectorCodec{}
ty := &pgtype.Type{Name: name, OID: oid, Codec: codec}
conn.TypeMap().RegisterType(ty)
return nil
}
31 changes: 31 additions & 0 deletions pgx/register.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package pgx

import (
"context"
"fmt"

"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
)

func RegisterTypes(ctx context.Context, conn *pgx.Conn) error {
var vectorOid *uint32
var sparsevecOid *uint32
err := conn.QueryRow(ctx, "SELECT to_regtype('vector')::oid, to_regtype('sparsevec')::oid").Scan(&vectorOid, &sparsevecOid)
if err != nil {
return err
}

if vectorOid == nil {
return fmt.Errorf("vector type not found in the database")
}

tm := conn.TypeMap()
tm.RegisterType(&pgtype.Type{Name: "vector", OID: *vectorOid, Codec: &VectorCodec{}})

if sparsevecOid != nil {
tm.RegisterType(&pgtype.Type{Name: "sparsevec", OID: *sparsevecOid, Codec: &SparseVectorCodec{}})
}

return nil
}
83 changes: 83 additions & 0 deletions pgx/sparsevec.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package pgx

import (
"database/sql/driver"
"fmt"

"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/pgvector/pgvector-go"
)

type SparseVectorCodec struct{}

func (SparseVectorCodec) FormatSupported(format int16) bool {
return format == pgx.BinaryFormatCode
}

func (SparseVectorCodec) PreferredFormat() int16 {
return pgx.BinaryFormatCode
}

func (SparseVectorCodec) PlanEncode(m *pgtype.Map, oid uint32, format int16, value any) pgtype.EncodePlan {
_, ok := value.(pgvector.SparseVector)
if !ok {
return nil
}

if format == pgx.BinaryFormatCode {
return encodePlanSparseVectorCodecBinary{}
}

return nil
}

type encodePlanSparseVectorCodecBinary struct{}

func (encodePlanSparseVectorCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) {
v := value.(pgvector.SparseVector)
return v.EncodeBinary(buf)
}

type scanPlanSparseVectorCodecBinary struct{}

func (SparseVectorCodec) PlanScan(m *pgtype.Map, oid uint32, format int16, target any) pgtype.ScanPlan {
_, ok := target.(*pgvector.SparseVector)
if !ok {
return nil
}

if format == pgx.BinaryFormatCode {
return scanPlanSparseVectorCodecBinary{}
}

return nil
}

func (scanPlanSparseVectorCodecBinary) Scan(src []byte, dst any) error {
v := (dst).(*pgvector.SparseVector)
return v.DecodeBinary(src)
}

func (c SparseVectorCodec) DecodeDatabaseSQLValue(m *pgtype.Map, oid uint32, format int16, src []byte) (driver.Value, error) {
return c.DecodeValue(m, oid, format, src)
}

func (c SparseVectorCodec) DecodeValue(m *pgtype.Map, oid uint32, format int16, src []byte) (any, error) {
if src == nil {
return nil, nil
}

var vec pgvector.SparseVector
scanPlan := c.PlanScan(m, oid, format, &vec)
if scanPlan == nil {
return nil, fmt.Errorf("Unable to decode sparsevec type")
}

err := scanPlan.Scan(src, &vec)
if err != nil {
return nil, err
}

return vec, nil
}
83 changes: 83 additions & 0 deletions pgx/vector.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package pgx

import (
"database/sql/driver"
"fmt"

"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/pgvector/pgvector-go"
)

type VectorCodec struct{}

func (VectorCodec) FormatSupported(format int16) bool {
return format == pgx.BinaryFormatCode
}

func (VectorCodec) PreferredFormat() int16 {
return pgx.BinaryFormatCode
}

func (VectorCodec) PlanEncode(m *pgtype.Map, oid uint32, format int16, value any) pgtype.EncodePlan {
_, ok := value.(pgvector.Vector)
if !ok {
return nil
}

if format == pgx.BinaryFormatCode {
return encodePlanVectorCodecBinary{}
}

return nil
}

type encodePlanVectorCodecBinary struct{}

func (encodePlanVectorCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) {
v := value.(pgvector.Vector)
return v.EncodeBinary(buf)
}

type scanPlanVectorCodecBinary struct{}

func (VectorCodec) PlanScan(m *pgtype.Map, oid uint32, format int16, target any) pgtype.ScanPlan {
_, ok := target.(*pgvector.Vector)
if !ok {
return nil
}

if format == pgx.BinaryFormatCode {
return scanPlanVectorCodecBinary{}
}

return nil
}

func (scanPlanVectorCodecBinary) Scan(src []byte, dst any) error {
v := (dst).(*pgvector.Vector)
return v.DecodeBinary(src)
}

func (c VectorCodec) DecodeDatabaseSQLValue(m *pgtype.Map, oid uint32, format int16, src []byte) (driver.Value, error) {
return c.DecodeValue(m, oid, format, src)
}

func (c VectorCodec) DecodeValue(m *pgtype.Map, oid uint32, format int16, src []byte) (any, error) {
if src == nil {
return nil, nil
}

var vec pgvector.Vector
scanPlan := c.PlanScan(m, oid, format, &vec)
if scanPlan == nil {
return nil, fmt.Errorf("Unable to decode vector type")
}

err := scanPlan.Scan(src, &vec)
if err != nil {
return nil, err
}

return vec, nil
}
6 changes: 6 additions & 0 deletions pgx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/jackc/pgx/v5"
"github.com/pgvector/pgvector-go"
pgxvector "github.com/pgvector/pgvector-go/pgx"
)

type PgxItem struct {
Expand Down Expand Up @@ -47,6 +48,11 @@ func TestPgx(t *testing.T) {
panic(err)
}

err = pgxvector.RegisterTypes(ctx, conn)
if err != nil {
panic(err)
}

_, err = conn.Exec(ctx, "DROP TABLE IF EXISTS pgx_items")
if err != nil {
panic(err)
Expand Down

0 comments on commit f83a210

Please sign in to comment.