diff --git a/cmd/root.go b/cmd/root.go index 53814dbb..cd9650fe 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -139,6 +139,12 @@ without having to manage any client SSL certificates.`, cmd.PersistentFlags().Uint64Var(&c.conf.MaxConnections, "max-connections", 0, `Limits the number of connections by refusing any additional connections. When this flag is not set, there is no limit.`) + cmd.PersistentFlags().DurationVar(&c.conf.WaitOnClose, "max-sigterm-delay", 0, + `Maximum amount of time to wait for any open connections +to close after receiving a TERM signal. The proxy will shut +down when the number of open connections reaches 0 or when +the maximum time has passed. Defaults to 0s.`) + cmd.PersistentFlags().StringVar(&c.telemetryProject, "telemetry-project", "", "Enable Cloud Monitoring and Cloud Trace integration with the provided project ID.") cmd.PersistentFlags().BoolVar(&c.disableTraces, "disable-traces", false, @@ -389,7 +395,7 @@ func runSignalWrapper(cmd *Command) error { cmd.Println("The proxy has started successfully and is ready for new connections!") defer func() { if cErr := p.Close(); cErr != nil { - cmd.PrintErrf("error during shutdown: %v\n", cErr) + cmd.PrintErrf("The proxy failed to close cleanly: %v\n", cErr) } }() @@ -400,9 +406,9 @@ func runSignalWrapper(cmd *Command) error { err := <-shutdownCh switch { case errors.Is(err, errSigInt): - cmd.PrintErrln("SIGINT signal received. Shuting down...") + cmd.PrintErrln("SIGINT signal received. Shutting down...") case errors.Is(err, errSigTerm): - cmd.PrintErrln("SIGTERM signal received. Shuting down...") + cmd.PrintErrln("SIGTERM signal received. 
Shutting down...") default: cmd.PrintErrf("The proxy has encountered a terminal error: %v\n", err) } diff --git a/cmd/root_test.go b/cmd/root_test.go index 8dd663af..de56747e 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -170,6 +170,13 @@ func TestNewCommandArguments(t *testing.T) { MaxConnections: 1, }), }, + { + desc: "using wait after sigterm flag", + args: []string{"--max-sigterm-delay", "10s", "/projects/proj/locations/region/clusters/clust/instances/inst"}, + want: withDefaults(&proxy.Config{ + WaitOnClose: 10 * time.Second, + }), + }, } for _, tc := range tcs { diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 94e430d0..0aa1a44a 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -85,6 +85,11 @@ type Config struct { // connections. A zero-value indicates no limit. MaxConnections uint64 + // WaitOnClose sets the duration to wait for connections to close before + // shutting down. A zero value (the default) means to close immediately + // regardless of any open connections. + WaitOnClose time.Duration + // Dialer specifies the dialer to use when connecting to AlloyDB // instances. Dialer alloydb.Dialer @@ -172,6 +177,10 @@ type Client struct { // mnts is a list of all mounted sockets for this client mnts []*socketMount + + // waitOnClose is the maximum duration to wait for open connections to close + // when shutting down. + waitOnClose time.Duration } // NewClient completes the initial setup required to get the proxy to a "steady" state. 
@@ -210,10 +219,11 @@ func NewClient(ctx context.Context, cmd *cobra.Command, conf *Config) (*Client, } c := &Client{ - mnts: mnts, - cmd: cmd, - dialer: d, - maxConns: conf.MaxConnections, + mnts: mnts, + cmd: cmd, + dialer: d, + maxConns: conf.MaxConnections, + waitOnClose: conf.WaitOnClose, } return c, nil } @@ -262,16 +272,40 @@ func (m MultiErr) Error() string { func (c *Client) Close() error { var mErr MultiErr + // First, close all open socket listeners to prevent additional connections. for _, m := range c.mnts { err := m.Close() if err != nil { mErr = append(mErr, err) } } + // Next, close the dialer to prevent any additional refreshes. cErr := c.dialer.Close() if cErr != nil { mErr = append(mErr, cErr) } + if c.waitOnClose == 0 { + if len(mErr) > 0 { + return mErr + } + return nil + } + timeout := time.After(c.waitOnClose) + tick := time.Tick(100 * time.Millisecond) + for { + select { + case <-tick: + if atomic.LoadUint64(&c.connCount) > 0 { + continue + } + case <-timeout: + } + break + } + open := atomic.LoadUint64(&c.connCount) + if open > 0 { + mErr = append(mErr, fmt.Errorf("%d connection(s) still open after waiting %v", open, c.waitOnClose)) + } if len(mErr) > 0 { return mErr } diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index f3504071..f0555ca5 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -65,7 +65,7 @@ type errorDialer struct { fakeDialer } -func (errorDialer) Close() error { +func (*errorDialer) Close() error { return errors.New("errorDialer returns error on Close") } @@ -143,15 +143,15 @@ func TestClientInitialization(t *testing.T) { desc: "with incrementing automatic port selection", in: &proxy.Config{ Addr: "127.0.0.1", - Port: 5432, // default port + Port: 6000, Instances: []proxy.InstanceConnConfig{ {Name: inst1}, {Name: inst2}, }, }, wantTCPAddrs: []string{ - "127.0.0.1:5432", - "127.0.0.1:5433", + "127.0.0.1:6000", + "127.0.0.1:6001", }, }, { @@ -238,25 +238,6 @@ func 
TestClientInitialization(t *testing.T) { } } -func tryTCPDial(t *testing.T, addr string) net.Conn { - attempts := 10 - var ( - conn net.Conn - err error - ) - for i := 0; i < attempts; i++ { - conn, err = net.Dial("tcp", addr) - if err != nil { - time.Sleep(100 * time.Millisecond) - continue - } - return conn - } - - t.Fatalf("failed to dial in %v attempts: %v", attempts, err) - return nil -} - func TestClientLimitsMaxConnections(t *testing.T) { d := &fakeDialer{} in := &proxy.Config{ @@ -291,17 +272,95 @@ func TestClientLimitsMaxConnections(t *testing.T) { // wait only a second for the result (since nothing is writing to the // socket) conn2.SetReadDeadline(time.Now().Add(time.Second)) - _, rErr := conn2.Read(make([]byte, 1)) - if rErr != io.EOF { - t.Fatalf("conn.Read should return io.EOF, got = %v", rErr) + + wantEOF := func(t *testing.T, c net.Conn) { + t.Helper() + var got error + for i := 0; i < 10; i++ { + _, got = c.Read(make([]byte, 1)) + if got == io.EOF { + return + } + time.Sleep(100 * time.Millisecond) + } + t.Fatalf("conn.Read should return io.EOF, got = %v", got) } + wantEOF(t, conn2) + want := 1 if got := d.dialAttempts(); got != want { t.Fatalf("dial attempts did not match expected, want = %v, got = %v", want, got) } } +// tryTCPDial dials addr over TCP, retrying up to 10 times at 100ms intervals, and fails the test if no attempt succeeds. +func tryTCPDial(t *testing.T, addr string) net.Conn { + t.Helper() + attempts := 10 + var ( + conn net.Conn + err error + ) + for i := 0; i < attempts; i++ { + conn, err = net.Dial("tcp", addr) + if err != nil { + time.Sleep(100 * time.Millisecond) + continue + } + return conn + } + + t.Fatalf("failed to dial in %v attempts: %v", attempts, err) + return nil +} + +func TestClientCloseWaitsForActiveConnections(t *testing.T) { + in := &proxy.Config{ + Addr: "127.0.0.1", + Port: 5000, + Instances: []proxy.InstanceConnConfig{ + {Name: "proj:region:pg"}, + }, + Dialer: &fakeDialer{}, + } + c, err := proxy.NewClient(context.Background(), &cobra.Command{}, in) + if err != nil { + t.Fatalf("proxy.NewClient error: %v", err) + } + go c.Serve(context.Background()) + 
conn := tryTCPDial(t, "127.0.0.1:5000") + _ = conn.Close() + + if err := c.Close(); err != nil { + t.Fatalf("c.Close error: %v", err) + } + + in.WaitOnClose = time.Second + in.Port = 5001 + c, err = proxy.NewClient(context.Background(), &cobra.Command{}, in) + if err != nil { + t.Fatalf("proxy.NewClient error: %v", err) + } + go c.Serve(context.Background()) + + var open []net.Conn + for i := 0; i < 5; i++ { + conn = tryTCPDial(t, "127.0.0.1:5001") + open = append(open, conn) + } + defer func() { + for _, o := range open { + o.Close() + } + }() + + if err := c.Close(); err == nil { + t.Fatal("c.Close should error, got = nil") + } +} + func TestClientClosesCleanly(t *testing.T) { in := &proxy.Config{ Addr: "127.0.0.1", @@ -316,12 +375,8 @@ func TestClientClosesCleanly(t *testing.T) { t.Fatalf("proxy.NewClient error want = nil, got = %v", err) } go c.Serve(context.Background()) - time.Sleep(time.Second) // allow the socket to start listening - conn, dErr := net.Dial("tcp", "127.0.0.1:5000") - if dErr != nil { - t.Fatalf("net.Dial error = %v", dErr) - } + conn := tryTCPDial(t, "127.0.0.1:5000") _ = conn.Close() if err := c.Close(); err != nil { @@ -343,7 +398,9 @@ func TestClosesWithError(t *testing.T) { t.Fatalf("proxy.NewClient error want = nil, got = %v", err) } go c.Serve(context.Background()) - time.Sleep(time.Second) // allow the socket to start listening + + conn := tryTCPDial(t, "127.0.0.1:5000") + defer conn.Close() if err = c.Close(); err == nil { t.Fatal("c.Close() should error, got nil")