Skip to content

Commit

Permalink
use custom embedding provider in chromem
Browse files Browse the repository at this point in the history
  • Loading branch information
barneyferry committed Dec 15, 2024
1 parent 9854967 commit 9e38766
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
9 changes: 5 additions & 4 deletions app/services/semdex/semdexer/chromem_semdexer/chromem.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/Southclaws/storyden/app/services/search/searcher"
"github.com/Southclaws/storyden/app/services/semdex"
"github.com/Southclaws/storyden/internal/config"
"github.com/Southclaws/storyden/internal/infrastructure/ai"
)

type chromemRefIndex struct {
Expand All @@ -25,17 +26,17 @@ type chromemRefIndex struct {
hydrator *hydrate.Hydrator
}

func New(cfg config.Config, rh *hydrate.Hydrator) (semdex.Semdexer, error) {
func New(cfg config.Config, rh *hydrate.Hydrator, aip ai.Prompter) (semdex.Semdexer, error) {
db, err := chromem.NewPersistentDB(cfg.SemdexLocalPath, false)
if err != nil {
return nil, err
}

if cfg.OpenAIKey == "" {
return nil, fault.New("OpenAI API key is required for embedded semdexer")
if _, ok := aip.(*ai.Disabled); ok {
return nil, fault.New("a language model provider must be enabled for the embedded semdexer to be enabled")
}

ef := chromem.NewEmbeddingFuncOpenAI(cfg.OpenAIKey, chromem.EmbeddingModelOpenAI3Large)
ef := aip.EmbeddingFunc()

collection, err := db.GetOrCreateCollection("semdex", nil, ef)
if err != nil {
Expand Down
6 changes: 3 additions & 3 deletions app/services/semdex/semdexer/semdexer.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/Southclaws/storyden/app/services/semdex/semdexer/chromem_semdexer"
"github.com/Southclaws/storyden/app/services/semdex/semdexer/weaviate_semdexer"
"github.com/Southclaws/storyden/internal/config"
"github.com/Southclaws/storyden/internal/infrastructure/ai"
weaviate_infra "github.com/Southclaws/storyden/internal/infrastructure/weaviate"
)

Expand All @@ -19,18 +20,17 @@ func newSemdexer(

weaviateClassName weaviate_infra.WeaviateClassName,
hydrator *hydrate.Hydrator,
prompter ai.Prompter,
) (semdex.Semdexer, error) {
if cfg.SemdexProvider != "" && cfg.LanguageModelProvider == "" {
return nil, fault.New("semdex requires a language model provider to be enabled")
}

switch cfg.SemdexProvider {
case "chromem":

return chromem_semdexer.New(cfg, hydrator)
return chromem_semdexer.New(cfg, hydrator, prompter)

case "weaviate":

return weaviate_semdexer.New(wc, weaviateClassName, hydrator), nil

default:
Expand Down

0 comments on commit 9e38766

Please sign in to comment.