From 178000f4a364e7489cd1d11b25ed50735837e73d Mon Sep 17 00:00:00 2001 From: Billy Zha Date: Fri, 23 Dec 2022 16:22:09 +0800 Subject: [PATCH] code clean Signed-off-by: Billy Zha --- cmd/oras/internal/option/remote.go | 29 ++++++++++++++++++++++++----- internal/net/net.go | 17 ++++++----------- 2 files changed, 30 insertions(+), 16 deletions(-) diff --git a/cmd/oras/internal/option/remote.go b/cmd/oras/internal/option/remote.go index 4334cee10..118d75559 100644 --- a/cmd/oras/internal/option/remote.go +++ b/cmd/oras/internal/option/remote.go @@ -46,8 +46,9 @@ type Remote struct { Username string PasswordFromStdin bool Password string - resolveFlag []string - onet.Dialer + + resolveFlag []string + resolveDialContext func(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) } // ApplyFlags applies flags to a command flag set. @@ -104,9 +105,14 @@ func (opts *Remote) ReadPassword() (err error) { // parseResolve parses resolve flag. func (opts *Remote) parseResolve() error { + if len(opts.resolveFlag) == 0 { + return nil + } + formatError := func(param, message string) error { return fmt.Errorf("failed to parse resolve flag %q: %s", param, message) } + var dialer onet.Dialer for _, r := range opts.resolveFlag { parts := strings.SplitN(r, ":", 3) if len(parts) < 3 { @@ -123,7 +129,11 @@ func (opts *Remote) parseResolve() error { if to == nil { return formatError(r, "invalid IP address") } - opts.Dialer.Add(parts[0], port, to) + dialer.Add(parts[0], port, to) + } + opts.resolveDialContext = func(base *net.Dialer) func(context.Context, string, string) (net.Conn, error) { + dialer.Dialer = base + return dialer.DialContext } return nil } @@ -152,12 +162,21 @@ func (opts *Remote) authClient(registry string, debug bool) (client *auth.Client if err := opts.parseResolve(); err != nil { return nil, err } + resolveDialContext := opts.resolveDialContext + if resolveDialContext == nil { + resolveDialContext = func(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) { + return dialer.DialContext + } + } client = &auth.Client{ Client: &http.Client{ // default value are derived from http.DefaultTransport Transport: &http.Transport{ - Proxy: http.ProxyFromEnvironment, - DialContext: opts.Dialer.DialContext, + Proxy: http.ProxyFromEnvironment, + DialContext: resolveDialContext(&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }), ForceAttemptHTTP2: true, MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, diff --git a/internal/net/net.go b/internal/net/net.go index a25408705..538d21271 100644 --- a/internal/net/net.go +++ b/internal/net/net.go @@ -19,18 +19,15 @@ import ( "context" "fmt" "net" - "time" ) -var defaultDialer = &net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, -} - +// Dialer struct provides dialing function with predefined DNS resolves. type Dialer struct { + *net.Dialer resolve map[string]string } +// Add adds an entry for DNS resolve. func (d *Dialer) Add(from string, port int, to net.IP) { if d.resolve == nil { d.resolve = make(map[string]string) @@ -41,10 +38,8 @@ func (d *Dialer) Add(from string, port int, to net.IP) { // DialContext connects to the addr on the named network using // the provided context. func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { - for k := range d.resolve { - if k == addr { - addr = d.resolve[k] - } + if resolve, ok := d.resolve[addr]; ok { + addr = resolve } - return defaultDialer.DialContext(ctx, network, addr) + return d.Dialer.DialContext(ctx, network, addr) }