Skip to content

Commit

Permalink
Merge pull request #9 from pavel-one/websocket-fix
Browse files Browse the repository at this point in the history
Websocket fix
  • Loading branch information
pavel-one authored May 2, 2023
2 parents 30fffcb + fa0ae13 commit 4c3ce6d
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 49 deletions.
28 changes: 1 addition & 27 deletions gpt.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@ package EdgeGPT

import (
"encoding/json"
"errors"
"fmt"
"github.com/gorilla/websocket"
"github.com/pavel-one/EdgeGPT-Go/config"
"github.com/pavel-one/EdgeGPT-Go/internal/CookieManager"
"github.com/pavel-one/EdgeGPT-Go/internal/Helpers"
Expand Down Expand Up @@ -54,7 +52,7 @@ func NewGPT(conf *config.GPT) (*GPT, error) {
return nil, err
}

hub, err := gpt.createHub()
hub, err := NewHub(gpt.Conversation, conf)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -163,27 +161,3 @@ func (g *GPT) AskSync(style, message string) (*responses.MessageWrapper, error)
log.Infoln("New ask:", message)
return m, nil
}

// createHub create websocket hub
func (g *GPT) createHub() (*Hub, error) {
if g.Conversation == nil {
return nil, errors.New("not set conversation")
}

conn, _, err := websocket.DefaultDialer.Dial(g.Config.WssUrl.String(), Helpers.GetHeaders(g.Config.Headers))
if err != nil {
return nil, err
}

h := &Hub{
conversation: g.Conversation,
conn: conn,
}

if err := h.initialHandshake(); err != nil {
return nil, err
}

log.Infoln("New hub for conversation:", g.Conversation.ConversationId)
return h, nil
}
107 changes: 85 additions & 22 deletions hub.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,62 +2,125 @@ package EdgeGPT

import (
"encoding/json"
"errors"
"github.com/gorilla/websocket"
"github.com/pavel-one/EdgeGPT-Go/config"
"github.com/pavel-one/EdgeGPT-Go/internal/Helpers"
"github.com/pavel-one/EdgeGPT-Go/responses"
"net/url"
"sync"
)

type Hub struct {
conversation *Conversation
conn *websocket.Conn
wssUrl *url.URL
headers map[string]string
InvocationId int
mu sync.Mutex
}

// initialHandshake request for initial session
func (c *Hub) initialHandshake() error {
message := []byte("{\"protocol\": \"json\", \"version\": 1}" + Delimiter)
func NewHub(conversation *Conversation, config *config.GPT) (*Hub, error) {
if conversation == nil {
return nil, errors.New("not set conversation")
}

h := &Hub{
conversation: conversation,
conn: nil,
wssUrl: config.WssUrl,
headers: config.Headers,
}

conn, err := h.NewConnect()
if err != nil {
return nil, err
}
h.conn = conn

log.Infoln("New hub for conversation:", conversation.ConversationId)

return h, nil
}

// NewConnect create new websocket connection
func (h *Hub) NewConnect() (*websocket.Conn, error) {
conn, _, err := websocket.DefaultDialer.Dial(h.wssUrl.String(), Helpers.GetHeaders(h.headers))
if err != nil {
return nil, err
}

if err := c.conn.WriteMessage(websocket.TextMessage, message); err != nil {
return err
message := []byte("{\"protocol\": \"json\", \"version\": 1}" + Delimiter)
if err := conn.WriteMessage(websocket.TextMessage, message); err != nil {
return nil, err
}
if _, _, err := conn.ReadMessage(); err != nil { //wait initial
return nil, err
}

if _, _, err := c.conn.ReadMessage(); err != nil { //wait initial
return err
return conn, nil
}

// CheckAndReconnect check active connection and reconnect
func (h *Hub) CheckAndReconnect() error {
if h.conn == nil {
return errors.New("not set connection")
}

c.InvocationId = 0
if err := h.conn.WriteMessage(websocket.PingMessage, []byte{}); err != nil {
log.Infoln("Reconnection")

h.Close()
h.conn = nil

conn, err := h.NewConnect()
if err != nil {
return err
}
h.conn = conn
}

return nil
}

// send new message to websocket
func (c *Hub) send(style, message string) (*responses.MessageWrapper, error) {
c.mu.Lock()
func (h *Hub) send(style, message string) (*responses.MessageWrapper, error) {
if h.conn == nil {
return nil, errors.New("not set connection")
}
h.mu.Lock()

m, err := json.Marshal(c.getRequest(style, message))
if err := h.CheckAndReconnect(); err != nil {
return nil, err
}

m, err := json.Marshal(h.getRequest(style, message))
if err != nil {
return nil, err
}

m = append(m, DelimiterByte)

if err := c.conn.WriteMessage(websocket.TextMessage, m); err != nil {
if err := h.conn.WriteMessage(websocket.TextMessage, m); err != nil {
return nil, err
}

return responses.NewMessageWrapper(message, &c.mu, c.conn), nil
return responses.NewMessageWrapper(message, &h.mu, h.conn), nil
}

// Close hub and connection
// TODO: Use this!
func (c *Hub) Close() {
c.conn.Close()
func (h *Hub) Close() {
if h.conn == nil {
return
}

log.Infoln("Close connection")
h.conn.Close()
}

// getRequest generate struct for new request websocket
func (c *Hub) getRequest(style, message string) map[string]any {
func (h *Hub) getRequest(style, message string) map[string]any {
switch style {
case "creative":
style = StyleCreative
Expand All @@ -82,7 +145,7 @@ func (c *Hub) getRequest(style, message string) map[string]any {
}

m := map[string]any{
"invocationId": string(rune(c.InvocationId)),
"invocationId": string(rune(h.InvocationId)),
"target": "chat",
"type": 4,
"arguments": []map[string]any{
Expand All @@ -106,22 +169,22 @@ func (c *Hub) getRequest(style, message string) map[string]any {
"224locals0",
},
"traceId": Helpers.RandomHex(32),
"isStartOfSession": c.InvocationId == 0,
"isStartOfSession": h.InvocationId == 0,
"message": map[string]any{
"author": "user",
"inputMethod": "Keyboard",
"text": message,
"messageType": "Chat",
},
"conversationSignature": c.conversation.ConversationSignature,
"conversationSignature": h.conversation.ConversationSignature,
"participant": map[string]any{
"id": c.conversation.ClientId,
"id": h.conversation.ClientId,
},
"conversationId": c.conversation.ConversationId,
"conversationId": h.conversation.ConversationId,
},
},
}
c.InvocationId++
h.InvocationId++

return m
}

0 comments on commit 4c3ce6d

Please sign in to comment.