From d8662fe5a28c93c9e7add9485ee759374c28d853 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Fri, 31 May 2024 16:28:26 -0400 Subject: [PATCH] [Go] add DefineAction --- go/ai/embedder.go | 2 +- go/ai/retriever.go | 6 ++---- go/ai/tools.go | 3 +-- go/core/action.go | 13 ++++++++++++- 4 files changed, 16 insertions(+), 8 deletions(-) diff --git a/go/ai/embedder.go b/go/ai/embedder.go index 0aa5a0cdd..51d4138d5 100644 --- a/go/ai/embedder.go +++ b/go/ai/embedder.go @@ -35,5 +35,5 @@ type EmbedRequest struct { // RegisterEmbedder registers the actions for a specific embedder. func RegisterEmbedder(name string, embedder Embedder) { - core.RegisterAction(name, core.NewAction(name, core.ActionTypeEmbedder, nil, embedder.Embed)) + core.DefineAction(name, core.ActionTypeEmbedder, nil, embedder.Embed) } diff --git a/go/ai/retriever.go b/go/ai/retriever.go index 4b2def15d..c22ba01c5 100644 --- a/go/ai/retriever.go +++ b/go/ai/retriever.go @@ -56,12 +56,10 @@ func DefineRetriever( index func(context.Context, *IndexerRequest) error, retrieve func(context.Context, *RetrieverRequest) (*RetrieverResponse, error), ) Retriever { - ia := core.NewAction(name, core.ActionTypeIndexer, nil, func(ctx context.Context, req *IndexerRequest) (struct{}, error) { + ia := core.DefineAction(name, core.ActionTypeIndexer, nil, func(ctx context.Context, req *IndexerRequest) (struct{}, error) { return struct{}{}, index(ctx, req) }) - core.RegisterAction(name, ia) - ra := core.NewAction(name, core.ActionTypeRetriever, nil, retrieve) - core.RegisterAction(name, ra) + ra := core.DefineAction(name, core.ActionTypeRetriever, nil, retrieve) return &retriever{ia, ra} } diff --git a/go/ai/tools.go b/go/ai/tools.go index 7e89a89fe..30b7dfc4e 100644 --- a/go/ai/tools.go +++ b/go/ai/tools.go @@ -41,8 +41,7 @@ func RegisterTool(name string, definition *ToolDefinition, metadata map[string]a } metadata["type"] = "tool" - // TODO: There is no provider for a tool. - core.RegisterAction("tool", core.NewAction(definition.Name, core.ActionTypeTool, metadata, fn)) + core.DefineAction(definition.Name, core.ActionTypeTool, metadata, fn) } // toolActionType is the instantiated core.Action type registered diff --git a/go/core/action.go b/go/core/action.go index 10d2a886e..8201e30e6 100644 --- a/go/core/action.go +++ b/go/core/action.go @@ -64,7 +64,18 @@ type Action[In, Out, Stream any] struct { Metadata map[string]any } -// See js/common/src/types.ts +// See js/core/src/action.ts + +// DefineAction creates a new Action and registers it. +func DefineAction[In, Out any](name string, atype ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] { + return defineAction(globalRegistry, name, atype, metadata, fn) +} + +func defineAction[In, Out any](r *registry, name string, atype ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] { + a := NewAction(name, atype, metadata, fn) + r.registerAction(name, a) + return a +} // NewAction creates a new Action with the given name and non-streaming function. func NewAction[In, Out any](name string, atype ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] {