diff --git a/lib/util/errors/merror.go b/lib/util/errors/merror.go index 2ef7bc75..31a53aef 100644 --- a/lib/util/errors/merror.go +++ b/lib/util/errors/merror.go @@ -64,7 +64,7 @@ func (e *MError) Error() string { func (e *MError) Is(s error) bool { is := errors.Is(e.cerr, s) for _, e := range e.uerr { - is = is || errors.Is(e, s) + is = is || errors.Is(e, s) if is { break } @@ -78,6 +78,14 @@ func (e *MError) Cause() []error { // Collect is used to collect multiple errors. `Unwrap` is noop and `Is(err, ErrMine) == true`. While `As(err, underlyingError)` do not work, you can still get underlying errors by `MError.Cause`. func Collect(cerr error, uerr ...error) error { + n := 0 + for _, e := range uerr { + if e != nil { + uerr[n] = e + n++ + } + } + uerr = uerr[:n] if len(uerr) == 0 { return nil } diff --git a/lib/util/errors/merror_test.go b/lib/util/errors/merror_test.go index 37664123..d6c5576f 100644 --- a/lib/util/errors/merror_test.go +++ b/lib/util/errors/merror_test.go @@ -32,4 +32,8 @@ func TestCollect(t *testing.T) { require.ErrorIsf(t, e, e1, "but errors.Is works for all errors") require.Equal(t, e.(*serr.MError).Cause(), []error{e2, e3}, "get underlying errors") require.NoError(t, serr.Collect(e3), "nil if there is no underlying error") + + e4 := serr.Collect(e1, e2, nil).(*serr.MError) + require.Len(t, e4.Cause(), 1, "collect non-nil erros only") + require.NoError(t, serr.Collect(e3, nil, nil), "nil if all errors are nil") } diff --git a/pkg/proxy/client/client_conn.go b/pkg/proxy/client/client_conn.go index ee873c12..3eac1a5d 100644 --- a/pkg/proxy/client/client_conn.go +++ b/pkg/proxy/client/client_conn.go @@ -109,12 +109,5 @@ func (cc *ClientConnection) processMsg(ctx context.Context) error { } func (cc *ClientConnection) Close() error { - var errs []error - if err := cc.pkt.Close(); err != nil { - errs = append(errs, err) - } - if err := cc.connMgr.Close(); err != nil { - errs = append(errs, err) - } - return errors.Collect(ErrCloseConn, errs...) + return errors.Collect(ErrCloseConn, cc.pkt.Close(), cc.connMgr.Close()) } diff --git a/pkg/proxy/proxy.go b/pkg/proxy/proxy.go index 93d22cbb..7f90bad8 100644 --- a/pkg/proxy/proxy.go +++ b/pkg/proxy/proxy.go @@ -158,11 +158,9 @@ func (s *SQLServer) onConn(ctx context.Context, conn net.Conn) { // Close closes the server. func (s *SQLServer) Close() error { - var errs []error + errs := make([]error, 0, 4) if s.listener != nil { - if err := s.listener.Close(); err != nil { - errs = append(errs, err) - } + errs = append(errs, s.listener.Close()) } s.mu.Lock() diff --git a/pkg/server/server.go b/pkg/server/server.go index 416fbdc2..b66bbf27 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -205,26 +205,18 @@ func (s *Server) Run(ctx context.Context) { } func (s *Server) Close() error { - var errs []error + errs := make([]error, 0, 4) if s.Proxy != nil { - if err := s.Proxy.Close(); err != nil { - errs = append(errs, err) - } + errs = append(errs, s.Proxy.Close()) } if s.NamespaceManager != nil { - if err := s.NamespaceManager.Close(); err != nil { - errs = append(errs, err) - } + errs = append(errs, s.NamespaceManager.Close()) } if s.ConfigManager != nil { - if err := s.ConfigManager.Close(); err != nil { - errs = append(errs, err) - } + errs = append(errs, s.ConfigManager.Close()) } if s.ObserverClient != nil { - if err := s.ObserverClient.Close(); err != nil { - errs = append(errs, err) - } + errs = append(errs, s.ObserverClient.Close()) } if s.Etcd != nil { var wg waitgroup.WaitGroup