diff --git a/app/services/search/service.go b/app/services/search/service.go index 5cb964f4..69f755d4 100644 --- a/app/services/search/service.go +++ b/app/services/search/service.go @@ -15,7 +15,7 @@ func New( semdexSearcher semdex.Searcher, ) searcher.Searcher { switch cfg.SemdexProvider { - case "chromem", "weaviate": + case "chromem", "weaviate", "pinecone": return semdexSearcher default: diff --git a/app/services/semdex/semdexer/pinecone_semdexer/index.go b/app/services/semdex/semdexer/pinecone_semdexer/index.go new file mode 100644 index 00000000..f278b0c7 --- /dev/null +++ b/app/services/semdex/semdexer/pinecone_semdexer/index.go @@ -0,0 +1,111 @@ +package pinecone_semdexer + +import ( + "context" + "runtime" + "sync" + + "github.com/Southclaws/fault" + "github.com/Southclaws/fault/fctx" + "github.com/rs/xid" + "google.golang.org/protobuf/types/known/structpb" + + "github.com/Southclaws/storyden/app/resources/datagraph" + "github.com/Southclaws/storyden/internal/infrastructure/vector/pinecone" +) + +func (c *pineconeSemdexer) Index(ctx context.Context, object datagraph.Item) error { + chunks := object.GetContent().Split() + + if len(chunks) == 0 { + return fault.New("no text chunks to index", fctx.With(ctx)) + } + + numWorkers := min(runtime.NumCPU(), len(chunks)) + chunkQueue := make(chan string, len(chunks)) + errChan := make(chan error, len(chunks)) + chunkChan := make(chan *pinecone.Vector, len(chunks)) + + var wg sync.WaitGroup + + for i := 0; i < numWorkers; i++ { + wg.Add(1) + go func(workerID int) { + defer wg.Done() + for chunk := range chunkQueue { + vec, err := c.ef(ctx, chunk) + if err != nil { + errChan <- err + } + + objectID := object.GetID() + + metadata, err := structpb.NewStruct(map[string]any{ + "datagraph_id": objectID.String(), + "datagraph_type": object.GetKind().String(), + "name": object.GetName(), + "content": chunk, + }) + if err != nil { + errChan <- err + } + + chunkID := generateChunkID(objectID, chunk).String() + + chunkChan <- 
&pinecone.Vector{ + Id: chunkID, + Values: vec, + Metadata: metadata, + } + } + }(i) + } + + go func() { + for _, chunk := range chunks { + chunkQueue <- chunk + } + close(chunkQueue) + }() + + go func() { + wg.Wait() + + close(errChan) + close(chunkChan) + }() + + for err := range errChan { + if err != nil { + return err + } + } + + var vecs []*pinecone.Vector + for vec := range chunkChan { + vecs = append(vecs, vec) + } + + _, err := c.index.UpsertVectors(ctx, vecs) + if err != nil { + return fault.Wrap(err, fctx.With(ctx)) + } + + return nil +} + +func (c *pineconeSemdexer) Delete(ctx context.Context, object xid.ID) error { + filter, err := structpb.NewStruct(map[string]any{ + "datagraph_id": object.String(), + }) + if err != nil { + return fault.Wrap(err, fctx.With(ctx)) + } + + err = c.index.DeleteVectorsByFilter(ctx, filter) + if err != nil { + return fault.Wrap(err, fctx.With(ctx)) + } + + return nil +} diff --git a/app/services/semdex/semdexer/pinecone_semdexer/object.go b/app/services/semdex/semdexer/pinecone_semdexer/object.go new file mode 100644 index 00000000..b43220dc --- /dev/null +++ b/app/services/semdex/semdexer/pinecone_semdexer/object.go @@ -0,0 +1,122 @@ +package pinecone_semdexer + +import ( + "fmt" + "net/url" + + "github.com/Southclaws/dt" + "github.com/Southclaws/fault" + "github.com/rs/xid" + + "github.com/Southclaws/storyden/app/resources/datagraph" + "github.com/Southclaws/storyden/app/services/semdex" + "github.com/Southclaws/storyden/internal/infrastructure/vector/pinecone" +) + +type Object struct { + ID xid.ID + Kind datagraph.Kind + Relevance float64 + URL url.URL + Content string +} + +type Objects []*Object + +func (o *Object) ToChunk() *semdex.Chunk { + return &semdex.Chunk{ + ID: o.ID, + Kind: o.Kind, + URL: o.URL, + Content: o.Content, + } +} + +func (o *Object) ToRef() *datagraph.Ref { + return &datagraph.Ref{ + ID: o.ID, + Kind: o.Kind, + Relevance: o.Relevance, + } +} + +func (o Objects) ToChunks() []*semdex.Chunk { + chunks 
:= make([]*semdex.Chunk, len(o)) + for i, object := range o { + chunks[i] = object.ToChunk() + } + return chunks +} + +func (o Objects) ToRefs() datagraph.RefList { + refs := make(datagraph.RefList, len(o)) + for i, object := range o { + refs[i] = object.ToRef() + } + return refs +} + +func mapObject(v *pinecone.ScoredVector) (*Object, error) { + meta := v.Vector.Metadata.AsMap() + + idRaw, ok := meta["datagraph_id"] + if !ok { + return nil, fault.New("missing datagraph_id in metadata") + } + + typeRaw, ok := meta["datagraph_type"] + if !ok { + return nil, fault.New("missing datagraph_type in metadata") + } + + contentRaw, ok := meta["content"] + if !ok { + return nil, fault.New("missing content in metadata") + } + + // + + idString, ok := idRaw.(string) + if !ok { + return nil, fault.New("datagraph_id in metadata is not a string") + } + + typeString, ok := typeRaw.(string) + if !ok { + return nil, fault.New("datagraph_type in metadata is not a string") + } + + content, ok := contentRaw.(string) + if !ok { + return nil, fault.New("content in metadata is not a string") + } + + // + + id, err := xid.FromString(idString) + if err != nil { + return nil, fault.Wrap(err) + } + + dk, err := datagraph.NewKind(typeString) + if err != nil { + return nil, fault.Wrap(err) + } + + sdr, err := url.Parse(fmt.Sprintf("%s:%s/%s", datagraph.RefScheme, dk.String(), id.String())) + if err != nil { + return nil, err + } + + return &Object{ + ID: id, + Kind: dk, + Relevance: float64((v.Score + 1) / 2), + URL: *sdr, + Content: content, + }, nil +} + +func mapObjects(objects []*pinecone.ScoredVector) (Objects, error) { + return dt.MapErr(objects, mapObject) +} diff --git a/app/services/semdex/semdexer/pinecone_semdexer/pinecone.go b/app/services/semdex/semdexer/pinecone_semdexer/pinecone.go new file mode 100644 index 00000000..1dd30fcc --- /dev/null +++ b/app/services/semdex/semdexer/pinecone_semdexer/pinecone.go @@ -0,0 +1,67 @@ +package pinecone_semdexer + +import ( + "context" + 
"hash/fnv" + + "github.com/Southclaws/dt" + "github.com/Southclaws/fault" + "github.com/google/uuid" + "github.com/rs/xid" + + "github.com/Southclaws/storyden/app/resources/datagraph" + "github.com/Southclaws/storyden/app/resources/datagraph/hydrate" + "github.com/Southclaws/storyden/app/services/semdex" + "github.com/Southclaws/storyden/internal/config" + "github.com/Southclaws/storyden/internal/infrastructure/ai" + "github.com/Southclaws/storyden/internal/infrastructure/vector/pinecone" +) + +type pineconeSemdexer struct { + client *pinecone.Client + index *pinecone.Index + hydrator *hydrate.Hydrator + ef ai.Embedder +} + +func New(ctx context.Context, cfg config.Config, pc *pinecone.Client, rh *hydrate.Hydrator, aip ai.Prompter) (semdex.Semdexer, error) { + if _, ok := aip.(*ai.Disabled); ok { + return nil, fault.New("a language model provider must be enabled for the pinecone semdexer to be enabled") + } + + ef := aip.EmbeddingFunc() + + index, err := pc.GetOrCreateIndex(ctx, cfg.PineconeIndex) + if err != nil { + return nil, err + } + + return &pineconeSemdexer{ + client: pc, + index: index, + hydrator: rh, + ef: ef, + }, nil +} + +func generateChunkID(id xid.ID, chunk string) uuid.UUID { + // We don't currently support sharing chunks across content nodes, so append + // the object's ID to the chunk's hash, to ensure it's unique to the object. 
+ payload := []byte(append(id.Bytes(), chunk...)) + + return uuid.NewHash(fnv.New128(), uuid.NameSpaceOID, payload, 4) +} + +func chunkIDsFor(id xid.ID) func(chunk string) string { + return func(chunk string) string { + return generateChunkID(id, chunk).String() + } +} + +func chunkIDsForItem(object datagraph.Item) []string { + return dt.Map(object.GetContent().Split(), chunkIDsFor(object.GetID())) +} + +func (c *pineconeSemdexer) GetMany(ctx context.Context, limit uint, ids ...xid.ID) (datagraph.RefList, error) { + return nil, nil +} diff --git a/app/services/semdex/semdexer/pinecone_semdexer/recommend.go b/app/services/semdex/semdexer/pinecone_semdexer/recommend.go new file mode 100644 index 00000000..1d9cc02b --- /dev/null +++ b/app/services/semdex/semdexer/pinecone_semdexer/recommend.go @@ -0,0 +1,106 @@ +package pinecone_semdexer + +import ( + "context" + + "github.com/Southclaws/dt" + "github.com/Southclaws/fault" + "github.com/Southclaws/fault/fctx" + "github.com/pinecone-io/go-pinecone/pinecone" + + "github.com/Southclaws/storyden/app/resources/datagraph" +) + +func (s *pineconeSemdexer) Recommend(ctx context.Context, object datagraph.Item) (datagraph.ItemList, error) { + refs, err := s.RecommendRefs(ctx, object) + if err != nil { + return nil, fault.Wrap(err, fctx.With(ctx)) + } + + items, err := s.hydrator.Hydrate(ctx, refs...) 
// averageVectors returns the element-wise arithmetic mean of the given
// vectors as a newly allocated slice; the inputs are not modified.
//
// An empty (non-nil) slice is returned when no vectors are supplied or when
// the first vector has no elements. The function panics if any vector's
// length differs from the length of the first one.
func averageVectors(datagraphID [][]float32) []float32 {
	// Despite its name, the parameter holds the vectors to be averaged.
	vectors := datagraphID

	if len(vectors) == 0 || len(vectors[0]) == 0 {
		return []float32{}
	}

	dims := len(vectors[0])
	mean := make([]float32, dims)

	// Accumulate the per-dimension totals.
	for _, vec := range vectors {
		if len(vec) != dims {
			panic("Vectors must have the same length")
		}
		for i, component := range vec {
			mean[i] += component
		}
	}

	// Turn the totals into averages in place.
	n := float32(len(vectors))
	for i := range mean {
		mean[i] /= n
	}

	return mean
}
b/app/services/semdex/semdexer/pinecone_semdexer/search.go new file mode 100644 index 00000000..e1d3dd50 --- /dev/null +++ b/app/services/semdex/semdexer/pinecone_semdexer/search.go @@ -0,0 +1,143 @@ +package pinecone_semdexer + +import ( + "context" + "sort" + + "github.com/Southclaws/dt" + "github.com/Southclaws/fault" + "github.com/Southclaws/fault/fctx" + "github.com/rs/xid" + "github.com/samber/lo" + "google.golang.org/protobuf/types/known/structpb" + + "github.com/Southclaws/storyden/app/resources/datagraph" + "github.com/Southclaws/storyden/app/resources/pagination" + "github.com/Southclaws/storyden/app/services/search/searcher" + "github.com/Southclaws/storyden/app/services/semdex" + "github.com/Southclaws/storyden/internal/infrastructure/vector/pinecone" +) + +func (s *pineconeSemdexer) Search(ctx context.Context, q string, p pagination.Parameters, opts searcher.Options) (*pagination.Result[datagraph.Item], error) { + refs, err := s.SearchRefs(ctx, q, p, opts) + if err != nil { + return nil, fault.Wrap(err, fctx.With(ctx)) + } + + items, err := s.hydrator.Hydrate(ctx, refs.Items...) 
+ if err != nil { + return nil, fault.Wrap(err, fctx.With(ctx)) + } + + result := pagination.NewPageResult(p, refs.Results, items) + return &result, nil +} + +func (s *pineconeSemdexer) SearchRefs(ctx context.Context, q string, p pagination.Parameters, opts searcher.Options) (*pagination.Result[*datagraph.Ref], error) { + objects, err := s.searchObjects(ctx, q, p, opts) + if err != nil { + return nil, fault.Wrap(err, fctx.With(ctx)) + } + + results := objects.ToRefs() + + deduped := dedupeChunks(results) + + filtered := filterChunks(deduped) + + pagedResult := pagination.NewPageResult(p, len(results), filtered) + + return &pagedResult, nil +} + +func (s *pineconeSemdexer) SearchChunks(ctx context.Context, q string, p pagination.Parameters, opts searcher.Options) ([]*semdex.Chunk, error) { + objects, err := s.searchObjects(ctx, q, p, opts) + if err != nil { + return nil, fault.Wrap(err, fctx.With(ctx)) + } + + return objects.ToChunks(), nil +} + +func (s *pineconeSemdexer) searchObjects(ctx context.Context, q string, p pagination.Parameters, opts searcher.Options) (Objects, error) { + vec, err := s.ef(ctx, q) + if err != nil { + return nil, fault.Wrap(err, fctx.With(ctx)) + } + + filterMap := map[string]any{} + + opts.Kinds.Call(func(kind []datagraph.Kind) { + filterMap["datagraph_type"] = map[string]any{ + "$in": dt.Map(kind, func(k datagraph.Kind) any { return k.String() }), + } + }) + + filter, err := structpb.NewStruct(filterMap) + if err != nil { + return nil, fault.Wrap(err, fctx.With(ctx)) + } + + response, err := s.index.QueryByVectorValues(ctx, &pinecone.QueryByVectorValuesRequest{ + Vector: vec, + TopK: uint32(p.Limit()), + MetadataFilter: filter, + IncludeValues: false, + IncludeMetadata: true, + }) + if err != nil { + return nil, fault.Wrap(err, fctx.With(ctx)) + } + + return mapObjects(response.Matches) +} + +func filterChunks(results []*datagraph.Ref) []*datagraph.Ref { + filtered := dt.Filter(results, func(r *datagraph.Ref) bool { + return r.Relevance 
// maxFloat64 returns the largest of the given values.
//
// It returns 0 when called with no arguments instead of panicking on the
// empty slice. Comparison uses the > operator, so a NaN in the first position
// is returned as-is and NaNs in later positions are ignored.
func maxFloat64(a ...float64) float64 {
	if len(a) == 0 {
		return 0
	}

	// Named "largest" rather than "max" to avoid shadowing the builtin.
	largest := a[0]
	for _, n := range a[1:] {
		if n > largest {
			largest = n
		}
	}

	return largest
}
ctx context.Context, cfg config.Config, wc *weaviate.Client, + pc *pinecone.Client, weaviateClassName weaviate_infra.WeaviateClassName, hydrator *hydrate.Hydrator, @@ -34,6 +40,9 @@ func newSemdexer( case "weaviate": return weaviate_semdexer.New(wc, weaviateClassName, hydrator), nil + case "pinecone": + return pinecone_semdexer.New(ctx, cfg, pc, hydrator, prompter) + default: return &semdex.Disabled{}, nil } diff --git a/go.mod b/go.mod index 470779ba..3e5c90d7 100644 --- a/go.mod +++ b/go.mod @@ -135,6 +135,7 @@ require ( github.com/oklog/ulid v1.3.1 // indirect github.com/olekukonko/tablewriter v0.0.5 // indirect github.com/perimeterx/marshmallow v1.1.5 // indirect + github.com/pinecone-io/go-pinecone v1.1.1 // indirect github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect github.com/rabbitmq/amqp091-go v1.10.0 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect @@ -163,6 +164,7 @@ require ( go.opentelemetry.io/otel/trace v1.30.0 // indirect golang.org/x/image v0.20.0 // indirect golang.org/x/tools v0.25.0 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20241104194629-dd2ea8efbc28 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20241113202542-65e8d215514f // indirect google.golang.org/grpc v1.67.1 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect diff --git a/go.sum b/go.sum index d4a93ada..f675c5bb 100644 --- a/go.sum +++ b/go.sum @@ -406,6 +406,8 @@ github.com/perimeterx/marshmallow v1.1.5 h1:a2LALqQ1BlHM8PZblsDdidgv1mWi1DgC2UmX github.com/perimeterx/marshmallow v1.1.5/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw= github.com/philippgille/chromem-go v0.7.0 h1:4jfvfyKymjKNfGxBUhHUcj1kp7B17NL/I1P+vGh1RvY= github.com/philippgille/chromem-go v0.7.0/go.mod h1:hTd+wGEm/fFPQl7ilfCwQXkgEUxceYh86iIdoKMolPo= +github.com/pinecone-io/go-pinecone v1.1.1 h1:pKoIiYcBIbrR7gaq0JXPiVnNEtevFYeq/AYL7T0NbbE= +github.com/pinecone-io/go-pinecone v1.1.1/go.mod 
h1:KfJhn4yThX293+fbtrZLnxe2PJYo8557Py062W4FYKk= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -715,8 +717,11 @@ google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoA google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= google.golang.org/genproto v0.0.0-20241021214115-324edc3d5d38 h1:Q3nlH8iSQSRUwOskjbcSMcF2jiYMNiQYZ0c2KEJLKKU= +google.golang.org/genproto v0.0.0-20241113202542-65e8d215514f h1:zDoHYmMzMacIdjNe+P2XiTmPsLawi/pCbSPfxt6lTfw= google.golang.org/genproto/googleapis/api v0.0.0-20240827150818-7e3bb234dfed h1:3RgNmBoI9MZhsj3QxC+AP/qQhNwpCLOvYDYYsFrhFt0= google.golang.org/genproto/googleapis/api v0.0.0-20240827150818-7e3bb234dfed/go.mod h1:OCdP9MfskevB/rbYvHTsXTtKC+3bHWajPdoKgjcYkfo= +google.golang.org/genproto/googleapis/api v0.0.0-20241104194629-dd2ea8efbc28 h1:M0KvPgPmDZHPlbRbaNU1APr28TvwvvdUPlSv7PUvy8g= +google.golang.org/genproto/googleapis/api v0.0.0-20241104194629-dd2ea8efbc28/go.mod h1:dguCy7UOdZhTvLzDyt15+rOrawrpM4q7DD9dQ1P11P4= google.golang.org/genproto/googleapis/rpc v0.0.0-20241021214115-324edc3d5d38 h1:zciRKQ4kBpFgpfC5QQCVtnnNAcLIqweL7plyZRQHVpI= google.golang.org/genproto/googleapis/rpc v0.0.0-20241021214115-324edc3d5d38/go.mod h1:GX3210XPVPUjJbTUbvwI8f2IpZDMZuPJWDzDuebbviI= google.golang.org/genproto/googleapis/rpc v0.0.0-20241113202542-65e8d215514f h1:C1QccEa9kUwvMgEUORqQD9S17QesQijxjZ84sO82mfo= diff --git a/internal/config/config.go b/internal/config/config.go index 56c8f174..5d9180c4 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -51,11 +51,21 @@ type Config struct { LanguageModelProvider string 
`envconfig:"LANGUAGE_MODEL_PROVIDER"` OpenAIKey string `envconfig:"OPENAI_API_KEY"` - SemdexProvider string `envconfig:"SEMDEX_PROVIDER" default:""` + // chromem (local), weaviate, pinecone + SemdexProvider string `envconfig:"SEMDEX_PROVIDER" default:""` + + // Weaviate WeaviateURL string `envconfig:"WEAVIATE_URL"` WeaviateToken string `envconfig:"WEAVIATE_API_TOKEN"` WeaviateClassName string `envconfig:"WEAVIATE_CLASS_NAME"` + // Pinecone + PineconeAPIKey string `envconfig:"PINECONE_API_KEY"` + PineconeIndex string `envconfig:"PINECONE_INDEX"` + PineconeDimentions int32 `envconfig:"PINECONE_DIMENSIONS"` + PineconeCloud string `envconfig:"PINECONE_CLOUD"` + PineconeRegion string `envconfig:"PINECONE_REGION"` + SemdexLocalPath string `envconfig:"SEMDEX_LOCAL_PATH" default:"data/semdex"` } diff --git a/internal/infrastructure/ai/ai.go b/internal/infrastructure/ai/ai.go index d6eb94c0..d43a5922 100644 --- a/internal/infrastructure/ai/ai.go +++ b/internal/infrastructure/ai/ai.go @@ -10,6 +10,8 @@ type Result struct { Answer string } +type Embedder func(ctx context.Context, text string) ([]float32, error) + type Prompter interface { Prompt(ctx context.Context, input string) (*Result, error) PromptStream(ctx context.Context, input string) (chan string, chan error) diff --git a/internal/infrastructure/infrastructure.go b/internal/infrastructure/infrastructure.go index dd6b565d..6d89d11e 100644 --- a/internal/infrastructure/infrastructure.go +++ b/internal/infrastructure/infrastructure.go @@ -17,6 +17,7 @@ import ( "github.com/Southclaws/storyden/internal/infrastructure/pubsub/queue" "github.com/Southclaws/storyden/internal/infrastructure/rate" "github.com/Southclaws/storyden/internal/infrastructure/sms" + "github.com/Southclaws/storyden/internal/infrastructure/vector/pinecone" "github.com/Southclaws/storyden/internal/infrastructure/weaviate" "github.com/Southclaws/storyden/internal/infrastructure/webauthn" ) @@ -34,6 +35,7 @@ func Build() fx.Option { object.Build(), 
frontend.Build(), weaviate.Build(), + pinecone.Build(), fx.Provide(ai.New), jwt.Build(), queue.Build(), diff --git a/internal/infrastructure/vector/pinecone/pinecone.go b/internal/infrastructure/vector/pinecone/pinecone.go new file mode 100644 index 00000000..00020f96 --- /dev/null +++ b/internal/infrastructure/vector/pinecone/pinecone.go @@ -0,0 +1,99 @@ +package pinecone + +import ( + "context" + "errors" + + "github.com/Southclaws/fault" + "github.com/Southclaws/fault/fctx" + "github.com/pinecone-io/go-pinecone/pinecone" + "go.uber.org/fx" + + "github.com/Southclaws/storyden/internal/config" +) + +type Client struct { + *pinecone.Client + size int32 + cloud pinecone.Cloud + region string +} + +type Index = pinecone.IndexConnection + +type Vector = pinecone.Vector + +type Metadata = pinecone.Metadata + +type MetadataFilter = pinecone.MetadataFilter + +type QueryByVectorValuesRequest = pinecone.QueryByVectorValuesRequest + +type ScoredVector = pinecone.ScoredVector + +func Build() fx.Option { + return fx.Provide(newPinecone) +} + +func newPinecone(cfg config.Config) (*Client, error) { + c, err := pinecone.NewClient(pinecone.NewClientParams{ + ApiKey: cfg.PineconeAPIKey, + }) + if err != nil { + return nil, err + } + + return &Client{ + Client: c, + size: cfg.PineconeDimentions, + cloud: pinecone.Cloud(cfg.PineconeCloud), + region: cfg.PineconeRegion, + }, nil +} + +func (c *Client) GetOrCreateIndex(ctx context.Context, name string) (*Index, error) { + desc, err := func() (*pinecone.Index, error) { + index, err := c.DescribeIndex(ctx, name) + if err == nil { + return index, nil + } + + if !isNotFound(err) { + return nil, err + } + + index, err = c.CreateServerlessIndex(ctx, &pinecone.CreateServerlessIndexRequest{ + Name: name, + Dimension: c.size, + Metric: "cosine", + Cloud: c.cloud, + Region: c.region, + }) + if err != nil { + return nil, fault.Wrap(err, fctx.With(ctx)) + } + + return index, nil + }() + if err != nil { + return nil, fault.Wrap(err, 
fctx.With(ctx)) + } + + idxConnection, err := c.Index(pinecone.NewIndexConnParams{Host: desc.Host, Namespace: "storyden"}) + if err != nil { + return nil, err + } + + return idxConnection, nil +} + +func isNotFound(err error) bool { + pe := &pinecone.PineconeError{} + if errors.As(err, &pe) { + if pe.Code == 404 { + return true + } + } + + return false +}