Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/add req param #12

Merged
merged 7 commits into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ type Client struct {
connected bool
encodingBase64 bool
lastEventID atomic.Value // []byte
body []byte
}

var defaultClient, _ = client.NewClient(client.WithDialer(standard.NewDialer()), client.WithResponseBodyStream(true))
Expand Down Expand Up @@ -170,16 +171,21 @@ func (c *Client) SetOnConnectCallback(fn ConnCallback) {
c.connectedCallback = fn
}

// SetMaxBufferSize set sse client MaxBufferSize
// SetMaxBufferSize set sse client MaxBufferSize
func (c *Client) SetMaxBufferSize(size int) {
c.maxBufferSize = size
}

// SetURL set sse client url
// SetURL set sse client url
func (c *Client) SetURL(url string) {
c.url = url
}

// SetBody set sse client request body
func (c *Client) SetBody(body []byte) {
c.body = body
}

// SetMethod set sse client request method
func (c *Client) SetMethod(method string) {
c.method = method
Expand Down Expand Up @@ -225,6 +231,11 @@ func (c *Client) GetHertzClient() *client.Client {
return c.hertzClient
}

// GetBody get sse client request body
func (c *Client) GetBody() []byte {
return c.body
}

// GetLastEventID get sse client lastEventID
func (c *Client) GetLastEventID() []byte {
return c.lastEventID.Load().([]byte)
Expand All @@ -247,6 +258,10 @@ func (c *Client) request(ctx context.Context, req *protocol.Request, resp *proto
req.Header.Set(k, v)
}

if len(c.body) != 0 {
req.SetBody(c.body)
}

err := c.hertzClient.Do(ctx, req, resp)
return err
}
Expand Down
50 changes: 50 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ import (
"testing"
"time"

"github.com/cloudwego/hertz/pkg/protocol/consts"

"github.com/cloudwego/hertz/pkg/app"
"github.com/cloudwego/hertz/pkg/app/server"
"github.com/cloudwego/hertz/pkg/common/hlog"
Expand Down Expand Up @@ -119,6 +121,28 @@ func newServer401(port string) {
h.Spin()
}

func newServerWithPOSTBody(empty bool, port string) {
h := server.Default(server.WithHostPorts("0.0.0.0:" + port))

h.POST("/sse", func(ctx context.Context, c *app.RequestContext) {
// client can tell server last event it received with Last-Event-ID header
lastEventID := GetLastEventID(c)
hlog.CtxInfof(ctx, "last event ID: %s", lastEventID)

// you must set status code and response headers before first render call
c.SetStatusCode(http.StatusOK)
s := NewStream(c)
body, err := c.Body()
if err != nil {
return
}
for a := 0; a < 10; a++ {
s.Publish(&Event{Data: body})
}
})
h.Run()
}

func publishMsgs(s *Stream, empty bool, count int) {
for a := 0; a < count; a++ {
if empty {
Expand Down Expand Up @@ -275,3 +299,29 @@ func TestTrimHeader(t *testing.T) {
assert.DeepEqual(t, tc.want, got)
}
}

func TestRequestWithBody(t *testing.T) {
go newServerWithPOSTBody(false, "9006")
time.Sleep(time.Second)
c := NewClient("http://127.0.0.1:9006/sse")
c.SetMethod(consts.MethodPost)
c.body = []byte(`{"msg":"echo"}`)
events := make(chan *Event)
var cErr error
go func() {
cErr = c.Subscribe(func(msg *Event) {
if msg.Data != nil {
events <- msg
return
}
})
}()

for i := 0; i < 5; i++ {
msg, err := wait(events, time.Second*1)
assert.Nil(t, err)
assert.DeepEqual(t, []byte(`{"msg":"echo"}`), msg)
}

assert.Nil(t, cErr)
}
Loading