Skip to content

Commit

Permalink
🚜 Refactor event handling and print handlers
Browse files Browse the repository at this point in the history
  • Loading branch information
wesen committed Mar 26, 2024
1 parent 6dc1bdf commit cbdc316
Show file tree
Hide file tree
Showing 10 changed files with 118 additions and 85 deletions.
2 changes: 1 addition & 1 deletion pkg/cmds/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ func (g *GeppettoCommand) RunIntoWriter(
}
}()

router.AddHandler("chat", "chat", chat.StepPrinterFunc("chat", w))
router.AddHandler("chat", "chat", chat.StepPrinterFunc("", w))

contextManager := conversation.NewManager()

Expand Down
75 changes: 75 additions & 0 deletions pkg/events/publish.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package events

import (
"encoding/json"
"fmt"
"github.com/ThreeDotsLabs/watermill"
"github.com/ThreeDotsLabs/watermill/message"
"github.com/rs/zerolog/log"
"sync"
)

// NOTE(manuel, 2024-03-24) This might be worth moving / integrating into the event router
// It sounds also logical that this is the thing that would add sequence numbers to events?

// PublisherManager is used to distribute messages to a set of Publishers.
// As such, you "subscribe" a publisher to the given topic.
// When you Publish a message, it will get distributed to all publishers
// on the channel they were subscribed with.
//
// The Manager also keeps a sequence number for each outgoing message,
// in the order they are handled by Publish.
type PublisherManager struct {
Publishers map[string][]message.Publisher
sequenceNumber uint64
mutex sync.Mutex
}

func NewPublisherManager() *PublisherManager {
return &PublisherManager{
Publishers: make(map[string][]message.Publisher),
}
}

func (s *PublisherManager) SubscribePublisher(topic string, sub message.Publisher) {
s.mutex.Lock()
defer s.mutex.Unlock()
s.Publishers[topic] = append(s.Publishers[topic], sub)
}

// Publish distributes a message to all Publishers across all topics.
// Serializing the payload to JSON is done by Publish itself.
//
// Returns an error for any processing or distribution issues.
func (s *PublisherManager) Publish(payload interface{}) error {
// lock for the sequence number hash
s.mutex.Lock()
defer s.mutex.Unlock()

b, err := json.Marshal(payload)
if err != nil {
return err
}

msg := message.NewMessage(watermill.NewUUID(), b)
msg.Metadata.Set("sequence_number", fmt.Sprintf("%d", s.sequenceNumber))
s.sequenceNumber++

for topic, subs := range s.Publishers {
for _, sub := range subs {
err = sub.Publish(topic, msg)
if err != nil {
log.Warn().Err(err).Msg("failed to publish")
}
}
}

return nil
}

func (s *PublisherManager) PublishBlind(payload interface{}) {
err := s.Publish(payload)
if err != nil {
log.Warn().Err(err).Msg("failed to publish")
}
}
49 changes: 0 additions & 49 deletions pkg/helpers/pubsub.go

This file was deleted.

7 changes: 4 additions & 3 deletions pkg/steps/ai/chat/echo.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"github.com/ThreeDotsLabs/watermill/message"
"github.com/go-go-golems/bobatea/pkg/chat/conversation"
"github.com/go-go-golems/geppetto/pkg/events"
"github.com/go-go-golems/geppetto/pkg/helpers"
"github.com/go-go-golems/geppetto/pkg/steps"
"github.com/pkg/errors"
Expand All @@ -15,13 +16,13 @@ type EchoStep struct {
TimePerCharacter time.Duration
cancel context.CancelFunc
eg *errgroup.Group
subscriptionManager *helpers.PublisherManager
subscriptionManager *events.PublisherManager
}

func NewEchoStep() *EchoStep {
return &EchoStep{
TimePerCharacter: 100 * time.Millisecond,
subscriptionManager: helpers.NewPublisherManager(),
subscriptionManager: events.NewPublisherManager(),
}
}

Expand All @@ -32,7 +33,7 @@ func (e *EchoStep) Interrupt() {
}

func (e *EchoStep) AddPublishedTopic(publisher message.Publisher, topic string) error {
e.subscriptionManager.AddPublishedTopic(topic, publisher)
e.subscriptionManager.SubscribePublisher(topic, publisher)
return nil
}

Expand Down
7 changes: 4 additions & 3 deletions pkg/steps/ai/claude/step.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"github.com/ThreeDotsLabs/watermill/message"
"github.com/go-go-golems/bobatea/pkg/chat/conversation"
"github.com/go-go-golems/geppetto/pkg/events"
"github.com/go-go-golems/geppetto/pkg/helpers"
"github.com/go-go-golems/geppetto/pkg/steps"
"github.com/go-go-golems/geppetto/pkg/steps/ai/chat"
Expand All @@ -18,18 +19,18 @@ import (
type Step struct {
Settings *settings.StepSettings
cancel context.CancelFunc
subscriptionManager *helpers.PublisherManager
subscriptionManager *events.PublisherManager
}

func NewStep(settings *settings.StepSettings) *Step {
return &Step{
Settings: settings,
subscriptionManager: helpers.NewPublisherManager(),
subscriptionManager: events.NewPublisherManager(),
}
}

func (csf *Step) AddPublishedTopic(publisher message.Publisher, topic string) error {
csf.subscriptionManager.AddPublishedTopic(topic, publisher)
csf.subscriptionManager.SubscribePublisher(topic, publisher)
return nil
}

Expand Down
5 changes: 3 additions & 2 deletions pkg/steps/ai/ollama/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package ollama
import (
"context"
"github.com/go-go-golems/bobatea/pkg/chat/conversation"
"github.com/go-go-golems/geppetto/pkg/events"
"github.com/go-go-golems/geppetto/pkg/helpers"
"github.com/go-go-golems/geppetto/pkg/steps"
"github.com/go-go-golems/geppetto/pkg/steps/ai/chat"
Expand All @@ -15,14 +16,14 @@ import (
type ChatCompletionStep struct {
Client *api.Client
Settings *settings.StepSettings
subscriptionManager *helpers.PublisherManager
subscriptionManager *events.PublisherManager
}

func NewChatCompletionStep(client *api.Client, settings *settings.StepSettings) *ChatCompletionStep {
return &ChatCompletionStep{
Client: client,
Settings: settings,
subscriptionManager: helpers.NewPublisherManager(),
subscriptionManager: events.NewPublisherManager(),
}
}

Expand Down
33 changes: 17 additions & 16 deletions pkg/steps/ai/openai/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"github.com/ThreeDotsLabs/watermill/message"
"github.com/go-go-golems/bobatea/pkg/chat/conversation"
"github.com/go-go-golems/geppetto/pkg/events"
"github.com/go-go-golems/geppetto/pkg/helpers"
"github.com/go-go-golems/geppetto/pkg/steps"
"github.com/go-go-golems/geppetto/pkg/steps/ai/chat"
Expand All @@ -16,28 +17,28 @@ import (
var _ steps.Step[conversation.Conversation, string] = &Step{}

type Step struct {
Settings *settings.StepSettings
subscriptionManager *helpers.PublisherManager
Settings *settings.StepSettings
publisherManager *events.PublisherManager
}

func (csf *Step) AddPublishedTopic(publisher message.Publisher, topic string) error {
csf.subscriptionManager.AddPublishedTopic(topic, publisher)
csf.publisherManager.SubscribePublisher(topic, publisher)
return nil
}

type StepOption func(*Step) error

func WithSubscriptionManager(subscriptionManager *helpers.PublisherManager) StepOption {
func WithSubscriptionManager(subscriptionManager *events.PublisherManager) StepOption {
return func(step *Step) error {
step.subscriptionManager = subscriptionManager
step.publisherManager = subscriptionManager
return nil
}
}

func NewStep(settings *settings.StepSettings, options ...StepOption) (*Step, error) {
ret := &Step{
Settings: settings,
subscriptionManager: helpers.NewPublisherManager(),
Settings: settings,
publisherManager: events.NewPublisherManager(),
}

for _, option := range options {
Expand Down Expand Up @@ -99,7 +100,7 @@ func (csf *Step) Start(

stream := csf.Settings.Chat.Stream

csf.subscriptionManager.PublishBlind(&chat.Event{
csf.publisherManager.PublishBlind(&chat.Event{
Type: chat.EventTypeStart,
Metadata: metadata,
Step: stepMetadata,
Expand Down Expand Up @@ -131,7 +132,7 @@ func (csf *Step) Start(
for {
select {
case <-cancellableCtx.Done():
csf.subscriptionManager.PublishBlind(&chat.EventText{
csf.publisherManager.PublishBlind(&chat.EventText{
Event: chat.Event{
Type: chat.EventTypeInterrupt,
Metadata: metadata,
Expand All @@ -146,7 +147,7 @@ func (csf *Step) Start(
response, err := stream.Recv()

if errors.Is(err, io.EOF) {
csf.subscriptionManager.PublishBlind(&chat.EventText{
csf.publisherManager.PublishBlind(&chat.EventText{
Event: chat.Event{
Type: chat.EventTypeFinal,
Metadata: metadata,
Expand All @@ -160,7 +161,7 @@ func (csf *Step) Start(
}
if err != nil {
if errors.Is(err, context.Canceled) {
csf.subscriptionManager.PublishBlind(&chat.EventText{
csf.publisherManager.PublishBlind(&chat.EventText{
Event: chat.Event{
Type: chat.EventTypeInterrupt,
Metadata: metadata,
Expand All @@ -172,7 +173,7 @@ func (csf *Step) Start(
return
}

csf.subscriptionManager.PublishBlind(&chat.Event{
csf.publisherManager.PublishBlind(&chat.Event{
Type: chat.EventTypeError,
Error: err,
Metadata: metadata,
Expand All @@ -184,7 +185,7 @@ func (csf *Step) Start(

message += response.Choices[0].Delta.Content

csf.subscriptionManager.PublishBlind(&chat.EventPartialCompletion{
csf.publisherManager.PublishBlind(&chat.EventPartialCompletion{
Event: chat.Event{
Type: chat.EventTypePartial,
Metadata: metadata,
Expand All @@ -201,7 +202,7 @@ func (csf *Step) Start(
} else {
resp, err := client.CreateChatCompletion(cancellableCtx, *req)
if errors.Is(err, context.Canceled) {
csf.subscriptionManager.PublishBlind(&chat.EventText{
csf.publisherManager.PublishBlind(&chat.EventText{
Event: chat.Event{
Type: chat.EventTypeInterrupt,
Metadata: metadata,
Expand All @@ -213,7 +214,7 @@ func (csf *Step) Start(
}

if err != nil {
csf.subscriptionManager.PublishBlind(&chat.Event{
csf.publisherManager.PublishBlind(&chat.Event{
Type: chat.EventTypeError,
Error: err,
Metadata: metadata,
Expand All @@ -222,7 +223,7 @@ func (csf *Step) Start(
return steps.Reject[string](err, steps.WithMetadata[string](stepMetadata)), nil
}

csf.subscriptionManager.PublishBlind(&chat.EventText{
csf.publisherManager.PublishBlind(&chat.EventText{
Event: chat.Event{
Type: chat.EventTypeFinal,
Metadata: metadata,
Expand Down
9 changes: 5 additions & 4 deletions pkg/steps/ai/openai/execute-tool-step.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"github.com/ThreeDotsLabs/watermill/message"
"github.com/go-go-golems/bobatea/pkg/chat/conversation"
"github.com/go-go-golems/geppetto/pkg/events"
"github.com/go-go-golems/geppetto/pkg/helpers"
"github.com/go-go-golems/geppetto/pkg/steps"
"github.com/go-go-golems/geppetto/pkg/steps/ai/chat"
Expand All @@ -17,7 +18,7 @@ import (

type ExecuteToolStep struct {
Tools map[string]interface{}
subscriptionManager *helpers.PublisherManager
subscriptionManager *events.PublisherManager
messageID conversation.NodeID
parentID conversation.NodeID
}
Expand All @@ -26,7 +27,7 @@ var _ steps.Step[ToolCompletionResponse, map[string]interface{}] = (*ExecuteTool

type ExecuteToolStepOption func(*ExecuteToolStep) error

func WithExecuteToolStepSubscriptionManager(subscriptionManager *helpers.PublisherManager) ExecuteToolStepOption {
func WithExecuteToolStepSubscriptionManager(subscriptionManager *events.PublisherManager) ExecuteToolStepOption {
return func(step *ExecuteToolStep) error {
step.subscriptionManager = subscriptionManager
return nil
Expand All @@ -53,7 +54,7 @@ func NewExecuteToolStep(
) (*ExecuteToolStep, error) {
ret := &ExecuteToolStep{
Tools: tools,
subscriptionManager: helpers.NewPublisherManager(),
subscriptionManager: events.NewPublisherManager(),
}

for _, option := range options {
Expand All @@ -69,7 +70,7 @@ func NewExecuteToolStep(
var _ steps.Step[ToolCompletionResponse, map[string]interface{}] = (*ExecuteToolStep)(nil)

func (e *ExecuteToolStep) AddPublishedTopic(publisher message.Publisher, topic string) error {
e.subscriptionManager.AddPublishedTopic(topic, publisher)
e.subscriptionManager.SubscribePublisher(topic, publisher)
return nil
}

Expand Down
Loading

0 comments on commit cbdc316

Please sign in to comment.