Skip to content

Commit

Permalink
optimize: add WithDialFunc
Browse files Browse the repository at this point in the history
  • Loading branch information
li-jin-gou committed Dec 8, 2022
1 parent a12c7a0 commit ca3e61c
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 0 deletions.
67 changes: 67 additions & 0 deletions pkg/app/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2051,3 +2051,70 @@ func TestClientDialerName(t *testing.T) {
t.Errorf("expected 'empty string', but get %s", dName)
}
}

func TestClientDoWithDialFunc(t *testing.T) {
t.Parallel()

ch := make(chan error, 1)
uri := "/foo/bar/baz"
body := "request body"
opt := config.NewOptions([]config.Option{})

opt.Addr = "unix-test-10021"
opt.Network = "unix"
engine := route.NewEngine(opt)

engine.POST("/foo/bar/baz", func(c context.Context, ctx *app.RequestContext) {
if string(ctx.Request.Header.Method()) != consts.MethodPost {
ch <- fmt.Errorf("unexpected request method: %q. Expecting %q", ctx.Request.Header.Method(), consts.MethodPost)
return
}
reqURI := ctx.Request.RequestURI()
if string(reqURI) != uri {
ch <- fmt.Errorf("unexpected request uri: %q. Expecting %q", reqURI, uri)
return
}
cl := ctx.Request.Header.ContentLength()
if cl != len(body) {
ch <- fmt.Errorf("unexpected content-length %d. Expecting %d", cl, len(body))
return
}
reqBody := ctx.Request.Body()
if string(reqBody) != body {
ch <- fmt.Errorf("unexpected request body: %q. Expecting %q", reqBody, body)
return
}
ch <- nil
})
go engine.Run()
defer func() {
engine.Close()
}()
time.Sleep(time.Millisecond * 500)

c, _ := NewClient(WithDialFunc(func(addr string) (network.Conn, error) {
return dialer.DialConnection(opt.Network, opt.Addr, time.Second, nil)
}))

var req protocol.Request
req.Header.SetMethod(consts.MethodPost)
req.SetRequestURI(uri)
req.SetHost("xxx.com")
req.SetBodyString(body)

var resp protocol.Response

err := c.Do(context.Background(), &req, &resp)
if err != nil {
t.Fatalf("error when doing request: %s", err)
}

select {
case err = <-ch:
if err != nil {
t.Fatalf("err = %s", err.Error())
}
case <-time.After(5 * time.Second):
t.Fatalf("timeout")
}
}
35 changes: 35 additions & 0 deletions pkg/app/client/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/cloudwego/hertz/pkg/app/client/retry"
"github.com/cloudwego/hertz/pkg/common/config"
"github.com/cloudwego/hertz/pkg/network"
"github.com/cloudwego/hertz/pkg/network/dialer"
"github.com/cloudwego/hertz/pkg/network/standard"
"github.com/cloudwego/hertz/pkg/protocol/consts"
)
Expand Down Expand Up @@ -147,3 +148,37 @@ func WithWriteTimeout(t time.Duration) config.ClientOption {
o.WriteTimeout = t
}}
}

// WithDialFunc is used to set dialer function.
func WithDialFunc(f DialFunc, dialers ...network.Dialer) config.ClientOption {
return config.ClientOption{F: func(o *config.ClientOptions) {
d := dialer.DefaultDialer()
if len(dialers) != 0 {
d = dialers[0]
}
o.Dialer = newCustomDialerWithDialFunc(d, f)
}}
}

type DialFunc func(addr string) (network.Conn, error)

// customDialer set customDialerFunc and params to set dailFunc
type customDialer struct {
network.Dialer
dialFunc DialFunc
}

func (m *customDialer) DialConnection(network, address string, timeout time.Duration,
tlsConfig *tls.Config) (conn network.Conn, err error) {
if m.dialFunc != nil {
return m.dialFunc(address)
}
return m.Dialer.DialConnection(network, address, timeout, tlsConfig)
}

func newCustomDialerWithDialFunc(dialer network.Dialer, dialFunc DialFunc) network.Dialer {
return &customDialer{
Dialer: dialer,
dialFunc: dialFunc,
}
}

0 comments on commit ca3e61c

Please sign in to comment.