Skip to content

Commit

Permalink
✨ 🚑 Properly stream new messages using the new conversation management
Browse files Browse the repository at this point in the history
  • Loading branch information
wesen committed Jan 15, 2024
1 parent 0b4ba69 commit 61b8c8c
Show file tree
Hide file tree
Showing 8 changed files with 94 additions and 137 deletions.
66 changes: 2 additions & 64 deletions cmd/experiments/tool-ui/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (
"github.com/go-go-golems/glazed/pkg/middlewares"
glazed_settings "github.com/go-go-golems/glazed/pkg/settings"
"github.com/invopop/jsonschema"
"github.com/pkg/errors"
"github.com/rs/zerolog/log"
"github.com/spf13/cobra"
"golang.org/x/sync/errgroup"
Expand Down Expand Up @@ -223,69 +222,8 @@ func (t *ToolUiCommand) runWithUi(ctx context.Context,
_ = p

t.router.AddNoPublisherHandler("ui",
"ui", t.pubSub,
func(msg *message.Message) error {
msg.Ack()

e, err := chat.NewEventFromJson(msg.Payload)
if err != nil {
return err
}

metadata := boba_chat.StreamMetadata{
ID: e.Metadata.ID,
ParentID: e.Metadata.ParentID,
ConversationID: e.Metadata.ConversationID,
}

switch e.Type {
case chat.EventTypeError:
p.Send(boba_chat.StreamCompletionError{
Err: e.Error,
StreamMetadata: metadata,
})
case chat.EventTypePartial:
p_, ok := e.ToPartialCompletion()
if !ok {
return errors.New("payload is not of type EventPartialCompletionPayload")
}
p.Send(boba_chat.StreamCompletionMsg{
Delta: p_.Delta,
Completion: p_.Completion,
StreamMetadata: metadata,
})
case chat.EventTypeFinal:
p_, ok := e.ToText()
if !ok {
return errors.New("payload is not of type EventTextPayload")
}
p.Send(boba_chat.StreamDoneMsg{
StreamMetadata: metadata,
Completion: p_.Text,
})
case chat.EventTypeInterrupt:
p.Send(boba_chat.StreamDoneMsg{
StreamMetadata: metadata,
})
case chat.EventTypeStart:
p.Send(boba_chat.StreamStartMsg{
StreamMetadata: metadata,
})
case chat.EventTypeStatus:
p_, ok := e.ToText()
if !ok {
return errors.New("payload is not of type EventTextPayload")
}
p.Send(boba_chat.StreamStatusMsg{
Text: p_.Text,
StreamMetadata: metadata,
})
}

_ = metadata

return nil
})
"ui", t.pubSub, ui.StepChatForwardFunc(p),
)

ctx, cancel := context.WithCancel(ctx)

Expand Down
2 changes: 2 additions & 0 deletions cmd/pinocchio/prompts/examples/test-chat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,7 @@ flags:
default: "you"
help: Of what am I asking?

system-prompt: |
You are a {{.pretend}}. You are {{.what}} {{.of}}.
prompt: |
Say "hello".
27 changes: 19 additions & 8 deletions pkg/steps/ai/chat/step.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ type EventPartialCompletion struct {
Completion string `json:"completion"`
}

type ToolCall struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
}

// EventMetadata contains all the information that is passed along with watermill message,
// specific to chat steps.
type EventMetadata struct {
Expand All @@ -60,24 +65,30 @@ func NewEventFromJson(b []byte) (Event, error) {
return e, nil
}

func (e Event) ToText() (EventText, bool) {
var ret EventText
func ToTypedEvent[T any](e Event) (*T, bool) {
var ret *T
err := json.Unmarshal(e.payload, &ret)
if err != nil {
return EventText{}, false
return nil, false
}

return ret, true
}

func (e Event) ToText() (EventText, bool) {
ret, ok := ToTypedEvent[EventText](e)
if !ok || ret == nil {
return EventText{}, false
}
return *ret, true
}

func (e Event) ToPartialCompletion() (EventPartialCompletion, bool) {
var ret EventPartialCompletion
err := json.Unmarshal(e.payload, &ret)
if err != nil {
ret, ok := ToTypedEvent[EventPartialCompletion](e)
if !ok || ret == nil {
return EventPartialCompletion{}, false
}

return ret, true
return *ret, true
}

type StepOption func(Step) error
Expand Down
30 changes: 13 additions & 17 deletions pkg/steps/ai/ollama/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,44 +26,40 @@ func NewChatCompletionStep(client *api.Client, settings *settings.StepSettings)
}
}

func ConvertMessage(ollamaMsg *api.Message) *conversation.Message {
gepMsg := conversation.NewMessage(ollamaMsg.Content, ollamaMsg.Role)

return gepMsg
}

func (ccs *ChatCompletionStep) Start(
ctx context.Context,
messages []*conversation.Message,
) (steps.StepResult[string], error) {
ollamaMessages := []api.Message{}
for _, msg := range messages {
ollamaMessages = append(ollamaMessages, api.Message{
Content: msg.Text,
Role: msg.Role,
})
switch content := msg.Content.(type) {
case *conversation.ChatMessageContent:
ollamaMessages = append(ollamaMessages, api.Message{
Content: content.Text,
Role: string(content.Role),
})
}
}
var parentMessage *conversation.Message
parentID := uuid.Nil
conversationID := uuid.New()
parentID := conversation.NullNode

if len(messages) > 0 {
parentMessage = messages[len(messages)-1]
parentID = parentMessage.ID
conversationID = parentMessage.ConversationID
}

metadata := chat.EventMetadata{
ID: uuid.New(),
ParentID: parentID,
ConversationID: conversationID,
ID: conversation.NewNodeID(),
ParentID: parentID,
}
stepMetadata := &steps.StepMetadata{
StepID: uuid.New(),
Type: "openai-chat",
InputType: "conversation.Conversation",
OutputType: "string",
Metadata: ccs.Settings.GetMetadata(),
Metadata: map[string]interface{}{
steps.MetadataSettingsSlug: ccs.Settings.GetMetadata(),
},
}

stream := ccs.Settings.Chat.Stream
Expand Down
8 changes: 6 additions & 2 deletions pkg/steps/ai/openai/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ func (csf *Step) Start(
Type: "openai-chat",
InputType: "conversation.Conversation",
OutputType: "string",
Metadata: csf.Settings.GetMetadata(),
Metadata: map[string]interface{}{
steps.MetadataSettingsSlug: csf.Settings.GetMetadata(),
},
}

stream := csf.Settings.Chat.Stream
Expand All @@ -112,7 +114,9 @@ func (csf *Step) Start(

// TODO(manuel, 2023-11-28) We need to collect this goroutine in Close(), or at least I think so?
go func() {
defer close(c)
defer func() {
close(c)
}()
defer stream.Close()

message := ""
Expand Down
27 changes: 22 additions & 5 deletions pkg/steps/ai/openai/tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ func (csf *ToolStep) SetStreaming(b bool) {
csf.Settings.Chat.Stream = b
}

const MetadataToolCallsSlug = "tool-calls"

func (csf *ToolStep) Start(
ctx context.Context,
messages []*conversation.Message,
Expand Down Expand Up @@ -116,7 +118,9 @@ func (csf *ToolStep) Start(
Type: "openai-tool-completion",
InputType: "conversation.Conversation",
OutputType: "ToolCompletionResponse",
Metadata: csf.Settings.GetMetadata(),
Metadata: map[string]interface{}{
steps.MetadataSettingsSlug: csf.Settings.GetMetadata(),
},
}

csf.subscriptionManager.PublishBlind(&chat.Event{
Expand Down Expand Up @@ -168,15 +172,26 @@ func (csf *ToolStep) Start(
default:
response, err := stream_.Recv()
if errors.Is(err, io.EOF) {
csf.subscriptionManager.PublishBlind(&chat.EventText{
toolCalls := toolCallMerger.GetToolCalls()
toolCalls_ := []chat.ToolCall{}
for _, toolCall := range toolCalls {
toolCalls_ = append(toolCalls_, chat.ToolCall{
Name: toolCall.Function.Name,
Arguments: toolCall.Function.Arguments,
})
}
stepMetadata.Metadata[MetadataToolCallsSlug] = toolCalls_

msg := &chat.EventText{
Event: chat.Event{
Type: chat.EventTypeFinal,
Metadata: metadata,
Step: stepMetadata,
},
Text: message,
})
toolCalls := toolCallMerger.GetToolCalls()
}

csf.subscriptionManager.PublishBlind(msg)

ret.ToolCalls = toolCalls
ret.Content = message
Expand Down Expand Up @@ -316,6 +331,8 @@ func (e *ExecuteToolStep) AddPublishedTopic(publisher message.Publisher, topic s
return nil
}

const MetadataToolsSlug = "tools"

func (e *ExecuteToolStep) Start(
ctx context.Context,
input ToolCompletionResponse,
Expand Down Expand Up @@ -345,7 +362,7 @@ func (e *ExecuteToolStep) Start(
InputType: "ToolCompletionResponse",
OutputType: "map[string]interface{}",
Metadata: map[string]interface{}{
"tools": toolMetadata,
MetadataToolsSlug: toolMetadata,
},
}

Expand Down
2 changes: 2 additions & 0 deletions pkg/steps/step.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ type StepMetadata struct {
Metadata map[string]interface{} `json:"meta"`
}

const MetadataSettingsSlug = "settings"

type StepResultImpl[T any] struct {
value <-chan helpers.Result[T]
cancel func()
Expand Down
Loading

0 comments on commit 61b8c8c

Please sign in to comment.