From ca3e61c8b701f87614dc961b4702ed890bca068d Mon Sep 17 00:00:00 2001 From: kinggo Date: Sat, 3 Dec 2022 11:50:40 +0800 Subject: [PATCH] optimize: add WithDialFunc --- pkg/app/client/client_test.go | 67 +++++++++++++++++++++++++++++++++++ pkg/app/client/option.go | 35 ++++++++++++++++++ 2 files changed, 102 insertions(+) diff --git a/pkg/app/client/client_test.go b/pkg/app/client/client_test.go index 8c86aff30..7df82e217 100644 --- a/pkg/app/client/client_test.go +++ b/pkg/app/client/client_test.go @@ -2051,3 +2051,70 @@ func TestClientDialerName(t *testing.T) { t.Errorf("expected 'empty string', but get %s", dName) } } + +func TestClientDoWithDialFunc(t *testing.T) { + t.Parallel() + + ch := make(chan error, 1) + uri := "/foo/bar/baz" + body := "request body" + opt := config.NewOptions([]config.Option{}) + + opt.Addr = "unix-test-10021" + opt.Network = "unix" + engine := route.NewEngine(opt) + + engine.POST("/foo/bar/baz", func(c context.Context, ctx *app.RequestContext) { + if string(ctx.Request.Header.Method()) != consts.MethodPost { + ch <- fmt.Errorf("unexpected request method: %q. Expecting %q", ctx.Request.Header.Method(), consts.MethodPost) + return + } + reqURI := ctx.Request.RequestURI() + if string(reqURI) != uri { + ch <- fmt.Errorf("unexpected request uri: %q. Expecting %q", reqURI, uri) + return + } + cl := ctx.Request.Header.ContentLength() + if cl != len(body) { + ch <- fmt.Errorf("unexpected content-length %d. Expecting %d", cl, len(body)) + return + } + reqBody := ctx.Request.Body() + if string(reqBody) != body { + ch <- fmt.Errorf("unexpected request body: %q. Expecting %q", reqBody, body) + return + } + ch <- nil + }) + go engine.Run() + defer func() { + engine.Close() + }() + time.Sleep(time.Millisecond * 500) + + c, _ := NewClient(WithDialFunc(func(addr string) (network.Conn, error) { + return dialer.DialConnection(opt.Network, opt.Addr, time.Second, nil) + })) + + var req protocol.Request + req.Header.SetMethod(consts.MethodPost) + req.SetRequestURI(uri) + req.SetHost("xxx.com") + req.SetBodyString(body) + + var resp protocol.Response + + err := c.Do(context.Background(), &req, &resp) + if err != nil { + t.Fatalf("error when doing request: %s", err) + } + + select { + case err = <-ch: + if err != nil { + t.Fatalf("err = %s", err.Error()) + } + case <-time.After(5 * time.Second): + t.Fatalf("timeout") + } +} diff --git a/pkg/app/client/option.go b/pkg/app/client/option.go index 8a64ed844..2051ad817 100644 --- a/pkg/app/client/option.go +++ b/pkg/app/client/option.go @@ -23,6 +23,7 @@ import ( "github.com/cloudwego/hertz/pkg/app/client/retry" "github.com/cloudwego/hertz/pkg/common/config" "github.com/cloudwego/hertz/pkg/network" + "github.com/cloudwego/hertz/pkg/network/dialer" "github.com/cloudwego/hertz/pkg/network/standard" "github.com/cloudwego/hertz/pkg/protocol/consts" ) @@ -147,3 +148,37 @@ func WithWriteTimeout(t time.Duration) config.ClientOption { o.WriteTimeout = t }} } + +// WithDialFunc is used to set dialer function. +func WithDialFunc(f DialFunc, dialers ...network.Dialer) config.ClientOption { + return config.ClientOption{F: func(o *config.ClientOptions) { + d := dialer.DefaultDialer() + if len(dialers) != 0 { + d = dialers[0] + } + o.Dialer = newCustomDialerWithDialFunc(d, f) + }} +} + +type DialFunc func(addr string) (network.Conn, error) + +// customDialer set customDialerFunc and params to set dailFunc +type customDialer struct { + network.Dialer + dialFunc DialFunc +} + +func (m *customDialer) DialConnection(network, address string, timeout time.Duration, + tlsConfig *tls.Config) (conn network.Conn, err error) { + if m.dialFunc != nil { + return m.dialFunc(address) + } + return m.Dialer.DialConnection(network, address, timeout, tlsConfig) +} + +func newCustomDialerWithDialFunc(dialer network.Dialer, dialFunc DialFunc) network.Dialer { + return &customDialer{ + Dialer: dialer, + dialFunc: dialFunc, + } +}