diff --git a/pkg/chatgpt/chatgpt.go b/pkg/chatgpt/chatgpt.go index 34f5667..83d19d4 100644 --- a/pkg/chatgpt/chatgpt.go +++ b/pkg/chatgpt/chatgpt.go @@ -4,6 +4,7 @@ import ( "context" "encoding/base64" "encoding/json" + "errors" "fmt" "io" "log" @@ -143,7 +144,51 @@ func (c *Client) Chat(ctx context.Context, model string) (io.ReadWriteCloser, er chromedp.Navigate("https://chat.openai.com/?"+suffix), chromedp.WaitVisible("textarea", chromedp.ByQuery), ); err != nil { - return nil, fmt.Errorf("chatgpt: couldn't click on model selector: %w", err) + return nil, fmt.Errorf("chatgpt: couldn't navigate to url: %w", err) + } + + // Wait because there could be redirects + time.Sleep(1 * time.Second) + + // The url might have changed due to redirects + var url string + if err := chromedp.Run(tabCtx, chromedp.Location(&url)); err != nil { + return nil, fmt.Errorf("chatgpt: couldn't get url: %w", err) + } + if !strings.Contains(url, suffix) { + // Navigating to the URL didn't work, try clicking on the model selector + + // Determine which model option to select + option := 1 + if model == "gpt-4" { + option = 2 + } + + // Click on model selector + ctx, cancel := context.WithTimeout(tabCtx, 5*time.Second) + defer cancel() + if err := chromedp.Run(ctx, + chromedp.Click("button.relative.flex", chromedp.ByQuery), + ); err != nil && !errors.Is(err, context.DeadlineExceeded) { + return nil, fmt.Errorf("chatgpt: couldn't click on model selector: %w", err) + } + time.Sleep(200 * time.Millisecond) + + // Click on model option + if err := chromedp.Run(ctx, + chromedp.Click(fmt.Sprintf("ul li:nth-child(%d)", option), chromedp.ByQuery), + ); err != nil && !errors.Is(err, context.DeadlineExceeded) { + return nil, fmt.Errorf("chatgpt: couldn't click on model option: %w", err) + } + + // Test if the url is correct, if not, return an error + var url string + if err := chromedp.Run(tabCtx, chromedp.Location(&url)); err != nil { + return nil, fmt.Errorf("chatgpt: couldn't get url: %w", err) + } + if !strings.Contains(url, suffix) { + return nil, fmt.Errorf("chatgpt: couldn't click on model option %s", model) + } } rd, wr := io.Pipe() @@ -213,8 +258,45 @@ func (r *rw) Write(b []byte) (n int, err error) { unlock := r.rateLimit.Lock(r.ctx) defer unlock() - // Send the message msg := strings.TrimSpace(string(b)) + + for { + err := r.sendMessage(r.ctx, msg) + if errors.Is(err, errTooManyRequests) { + // Too many requests, wait for 5 minutes and try again + log.Println("chatgpt: too many requests, waiting for 5 minutes...") + select { + case <-time.After(5 * time.Minute): + case <-r.ctx.Done(): + return 0, r.ctx.Err() + } + // Load the page again using the conversation ID + if err := chromedp.Run(r.ctx, + chromedp.Navigate("https://chat.openai.com/c/"+r.conversationID), + chromedp.WaitVisible("textarea", chromedp.ByQuery), + ); err != nil { + return 0, fmt.Errorf("chatgpt: couldn't navigate to conversation url: %w", err) + } + continue + } + if err != nil { + return 0, err + } + break + } + go func() { + response := r.lastResponse + "\n" + if _, err := r.pipeWriter.Write([]byte(response)); err != nil { + log.Printf("chatgpt: could not write to pipe: %v", err) + } + }() + return len(b), nil +} + +var errTooManyRequests = errors.New("chatgpt: too many requests") + +func (r *rw) sendMessage(ctx context.Context, msg string) error { + // Send the message for { ctx, cancel := context.WithTimeout(r.ctx, 10*time.Second) if err := chromedp.Run(ctx, @@ -237,7 +319,7 @@ func (r *rw) Write(b []byte) (n int, err error) { if err := chromedp.Run(r.ctx, chromedp.Value("textarea", &textarea, chromedp.ByQuery), ); err != nil { - return 0, fmt.Errorf("chatgpt: couldn't obtain textarea value: %w", err) + return fmt.Errorf("chatgpt: couldn't obtain textarea value: %w", err) } if strings.TrimSpace(textarea) == strings.TrimSpace(msg) { break @@ -245,7 +327,7 @@ func (r *rw) Write(b []byte) (n int, err error) { log.Println("chatgpt: waiting for textarea to be updated...") select { case <-r.ctx.Done(): - return 0, r.ctx.Err() + return r.ctx.Err() case <-time.After(100 * time.Millisecond): } } @@ -258,77 +340,82 @@ func (r *rw) Write(b []byte) (n int, err error) { chromedp.ListenTarget( wait, func(ev interface{}) { - e, ok := ev.(*network.EventRequestWillBeSent) - if !ok { - return - } - switch e.Request.URL { - case "https://chat.openai.com/backend-api/conversation": - lck.Lock() - defer lck.Unlock() - if len(e.Request.PostDataEntries) == 0 { - return - } - v, err := base64.StdEncoding.DecodeString(e.Request.PostDataEntries[0].Bytes) - if err != nil { - return - } - var c conversation - if err := json.Unmarshal(v, &c); err != nil { - return - } - if len(c.Messages) == 0 || c.Messages[0].ID == "" { - log.Println("chatgpt: messsage id not found", string(v)) - return - } - if len(c.Messages[0].Content.Parts) == 0 { - log.Println("chatgpt: message content not found", string(v)) - return - } - convMsg := c.Messages[0].Content.Parts[0] - if strings.TrimSpace(convMsg) != strings.TrimSpace(msg) { - // Skip mismatched messages - return - } - conv = &c - case "https://chat.openai.com/backend-api/moderations": - lck.Lock() - defer lck.Unlock() - if len(e.Request.PostDataEntries) == 0 { - return - } - v, err := base64.StdEncoding.DecodeString(e.Request.PostDataEntries[0].Bytes) - if err != nil { - return - } - var m moderation - if err := json.Unmarshal(v, &m); err != nil { - return - } - if conv == nil { - log.Printf("chatgpt: moderation received before conversation: %s\n", string(v)) - return - } - if m.MessageID == conv.Messages[0].ID { + switch e := ev.(type) { + case *network.EventResponseReceived: + switch e.Response.URL { + case "https://chat.openai.com/backend-api/conversation": + if e.Response.Status == 429 { + // TODO: handle rate limit + // We should detect this and retry after a while + log.Println("chatgpt: rate limited detected") + return + } + default: return } - prefix := fmt.Sprintf("%s\n%s\n\n", r.lastResponse, conv.Messages[0].Content.Parts[0]) - if !strings.HasPrefix(m.Input, prefix) { + case *network.EventRequestWillBeSent: + switch e.Request.URL { + case "https://chat.openai.com/backend-api/conversation": + lck.Lock() + defer lck.Unlock() + if len(e.Request.PostDataEntries) == 0 { + return + } + v, err := base64.StdEncoding.DecodeString(e.Request.PostDataEntries[0].Bytes) + if err != nil { + return + } + var c conversation + if err := json.Unmarshal(v, &c); err != nil { + return + } + if len(c.Messages) == 0 || c.Messages[0].ID == "" { + log.Println("chatgpt: messsage id not found", string(v)) + return + } + if len(c.Messages[0].Content.Parts) == 0 { + log.Println("chatgpt: message content not found", string(v)) + return + } + convMsg := c.Messages[0].Content.Parts[0] + if strings.TrimSpace(convMsg) != strings.TrimSpace(msg) { + // Skip mismatched messages + return + } + conv = &c + case "https://chat.openai.com/backend-api/moderations": + lck.Lock() + defer lck.Unlock() + if len(e.Request.PostDataEntries) == 0 { + return + } + v, err := base64.StdEncoding.DecodeString(e.Request.PostDataEntries[0].Bytes) + if err != nil { + return + } + var m moderation + if err := json.Unmarshal(v, &m); err != nil { + return + } + if conv == nil { + log.Printf("chatgpt: moderation received before conversation: %s\n", string(v)) + return + } + if m.MessageID == conv.Messages[0].ID { + return + } + prefix := fmt.Sprintf("%s\n%s\n\n", r.lastResponse, conv.Messages[0].Content.Parts[0]) + if !strings.HasPrefix(m.Input, prefix) { + return + } + if r.conversationID == "" { + r.conversationID = m.ConversationID + } + r.lastResponse = strings.TrimPrefix(m.Input, prefix) + done() + default: return } - if r.conversationID == "" { - r.conversationID = m.ConversationID - } - r.lastResponse = strings.TrimPrefix(m.Input, prefix) - go func() { - response := r.lastResponse + "\n" - if _, err := r.pipeWriter.Write([]byte(response)); err != nil { - log.Printf("chatgpt: could not write to pipe: %v", err) - } - }() - done() - default: - return } }, ) @@ -341,15 +428,15 @@ func (r *rw) Write(b []byte) (n int, err error) { chromedp.Click("textarea", chromedp.ByQuery), chromedp.Click("textarea + button", chromedp.ByQuery), ); err != nil { - return 0, fmt.Errorf("chatgpt: couldn't click button: %w", err) + return fmt.Errorf("chatgpt: couldn't click button: %w", err) } // Wait for the response select { case <-wait.Done(): - return len(b), nil + return nil case <-r.ctx.Done(): - return 0, r.ctx.Err() + return r.ctx.Err() } }