Skip to content

Commit

Permalink
implement limit connections (#1527)
Browse files Browse the repository at this point in the history
* implement limit connections
  • Loading branch information
SimoneDutto authored Jan 21, 2025
1 parent daeb980 commit 8de5428
Show file tree
Hide file tree
Showing 3 changed files with 216 additions and 23 deletions.
121 changes: 121 additions & 0 deletions internal/ssh/listener.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
// nolint:goheader
// Copyright 2009 The Go Authors.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
// - Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// - Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// - Neither the name of Google LLC nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

package ssh

import (
"net"
"sync"
"time"
)

// N.B.:
// This is a copypaste of netutil.LimiLister (link: https://cs.opensource.google/go/x/net/+/refs/tags/v0.34.0:netutil/listen.go),
// but we add a timeout so when we are at the limit we actively close connections instead of waiting indefinetely. (Look at line 44)

// limitListenerWithTimeout returns a Listener that accepts at most n simultaneous
// connections from the provided Listener, and it timeouts when the max
// has been reached and no seats has been freed for the timeout period.
func limitListenerWithTimeout(l net.Listener, n int, timeout time.Duration) net.Listener {
return &limitListener{
Listener: l,
sem: make(chan struct{}, n),
done: make(chan struct{}),
timeout: timeout,
}
}

type limitListener struct {
net.Listener
sem chan struct{}
closeOnce sync.Once // ensures the done chan is only closed once
done chan struct{} // no values sent; closed when Close is called
timeout time.Duration // timeout for acquiring the connection
}

// acquire acquires the limiting semaphore. Returns true if successfully
// acquired, false if the listener is closed and the semaphore is not
// acquired.
func (l *limitListener) acquire() bool {
select {
case <-l.done:
return false
case l.sem <- struct{}{}:
return true
// we add a timeout here, so the connection is closed when the timeout has passed instead of waiting.
case <-time.After(l.timeout):
return false
}
}
func (l *limitListener) release() { <-l.sem }

// Accept waits for and returns the next connection to the listener, by checking the semaphore and the timeout.
func (l *limitListener) Accept() (net.Conn, error) {
if !l.acquire() {
// If the semaphore isn't acquired because the listener was closed, expect
// that this call to accept won't block, but immediately return an error.
// If it instead returns a spurious connection (due to a bug in the
// Listener, such as https://golang.org/issue/50216), we immediately close
// it and try again. Some buggy Listener implementations (like the one in
// the aforementioned issue) seem to assume that Accept will be called to
// completion, and may otherwise fail to clean up the client end of pending
// connections.
for {
c, err := l.Listener.Accept()
if err != nil {
return nil, err
}
c.Close()
}
}

c, err := l.Listener.Accept()
if err != nil {
l.release()
return nil, err
}
return &limitListenerConn{Conn: c, release: l.release}, nil
}

func (l *limitListener) Close() error {
err := l.Listener.Close()
l.closeOnce.Do(func() { close(l.done) })
return err
}

type limitListenerConn struct {
net.Conn
releaseOnce sync.Once
release func()
}

func (l *limitListenerConn) Close() error {
err := l.Conn.Close()
l.releaseOnce.Do(l.release)
return err
}
81 changes: 61 additions & 20 deletions internal/ssh/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"io"
"net"
"time"

"github.com/gliderlabs/ssh"
"github.com/juju/names/v5"
Expand All @@ -17,8 +18,10 @@ import (
"github.com/canonical/jimm/v3/internal/openfga"
)

// juju_ssh_default_port is the default port we expect the juju controllers to respond on.
const juju_ssh_default_port = 17022
// jujuSSHDefaultPort is the default port we expect the juju controllers to respond on.
const jujuSSHDefaultPort = 17022
const defaultAcceptConnectionTimeout = time.Second
const defaultMaxConcurrentConnections = 100

type publicKeySSHUserKey struct{}

Expand Down Expand Up @@ -47,40 +50,78 @@ type forwardMessage struct {
type Config struct {
Port string
HostKey []byte
MaxConcurrentConnections string
MaxConcurrentConnections int
AcceptConnectionTimeout time.Duration
}

// Server is the struct holding the jump server and some
type Server struct {
*ssh.Server

maxConcurrentConnections int
acceptConnectionTimeout time.Duration
}

// NewJumpServer creates the jump server struct.
func NewJumpServer(ctx context.Context, config Config, sshAuthorizer SSHAuthorizer, sshResolver SSHResolver) (*ssh.Server, error) {
func NewJumpServer(ctx context.Context, config Config, sshAuthorizer SSHAuthorizer, sshResolver SSHResolver) (Server, error) {
zapctx.Info(ctx, "NewJumpServer")

if sshResolver == nil {
return nil, fmt.Errorf("Cannot create JumpSSHServer with a nil resolver.")
return Server{}, fmt.Errorf("Cannot create JumpSSHServer with a nil resolver.")
}
server := &ssh.Server{
Addr: fmt.Sprintf(":%s", config.Port),
ChannelHandlers: map[string]ssh.ChannelHandler{
"direct-tcpip": directTCPIPHandler(sshResolver),
},
PublicKeyHandler: func(ctx ssh.Context, key ssh.PublicKey) bool {
user, err := sshAuthorizer.PublicKeyHandler(ctx, ctx.User(), key.Marshal())
if err != nil {
zapctx.Debug(ctx, fmt.Sprintf("cannot verify key for user %s", ctx.User()), zap.Error(err))
return false
}
ctx.SetValue(publicKeySSHUserKey{}, user)
return true
config = setConfigDefaults(config)
server := Server{
Server: &ssh.Server{
Addr: fmt.Sprintf(":%s", config.Port),
ChannelHandlers: map[string]ssh.ChannelHandler{
"direct-tcpip": directTCPIPHandler(sshResolver),
},
PublicKeyHandler: func(ctx ssh.Context, key ssh.PublicKey) bool {
user, err := sshAuthorizer.PublicKeyHandler(ctx, ctx.User(), key.Marshal())
if err != nil {
zapctx.Debug(ctx, fmt.Sprintf("cannot verify key for user %s", ctx.User()), zap.Error(err))
return false
}
ctx.SetValue(publicKeySSHUserKey{}, user)
return true
},
},
maxConcurrentConnections: config.MaxConcurrentConnections,
acceptConnectionTimeout: config.AcceptConnectionTimeout,
}
hostKey, err := gossh.ParsePrivateKey([]byte(config.HostKey))
if err != nil {
return nil, fmt.Errorf("Cannot parse hostkey.")
return Server{}, fmt.Errorf("Cannot parse hostkey.")
}
server.AddHostKey(hostKey)

return server, nil
}

// setConfigDefaults sets the default values for the configuration.
func setConfigDefaults(config Config) Config {
if config.Port == "" {
config.Port = fmt.Sprint(jujuSSHDefaultPort)
}
if config.MaxConcurrentConnections <= 0 {
config.MaxConcurrentConnections = defaultMaxConcurrentConnections
}
if config.AcceptConnectionTimeout <= 0 {
config.AcceptConnectionTimeout = defaultAcceptConnectionTimeout
}
return config
}

// ListenAndServe create a LimitListenerWithTimeout and Serve requests.
func (srv Server) ListenAndServe() error {
ln, err := net.Listen("tcp", srv.Addr)
ln = limitListenerWithTimeout(ln, srv.maxConcurrentConnections, srv.acceptConnectionTimeout)
if err != nil {
return err
}
return srv.Serve(ln)
}

func directTCPIPHandler(sshResolver SSHResolver) func(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) {
return func(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) {
d := forwardMessage{}
Expand All @@ -91,7 +132,7 @@ func directTCPIPHandler(sshResolver SSHResolver) func(srv *ssh.Server, conn *gos
return
}
if d.DestPort == 0 {
d.DestPort = juju_ssh_default_port
d.DestPort = jujuSSHDefaultPort
}
if !names.IsValidModel(d.DestAddr) {
rejectConnectionAndLogError(ctx, newChan, "invalid model uuid", nil)
Expand Down
37 changes: 34 additions & 3 deletions internal/ssh/ssh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import (
type sshSuite struct {
destinationJujuSSHServer *gliderssh.Server
destinationServerPort int
jumpSSHServer *gliderssh.Server
jumpSSHServer ssh.Server
jumpServerPort int
privateKey gossh.Signer
hostKey gossh.Signer
Expand Down Expand Up @@ -112,8 +112,9 @@ func (s *sshSuite) Init(c *qt.C) {

jumpServer, err := ssh.NewJumpServer(context.Background(),
ssh.Config{
Port: fmt.Sprint(port),
HostKey: hostKey,
Port: fmt.Sprint(port),
HostKey: hostKey,
MaxConcurrentConnections: 10,
},
mocks.SSHAuthorizer{
PublicKeyHandler_: func(ctx context.Context, claimUser string, key []byte) (*openfga.User, error) {
Expand Down Expand Up @@ -269,6 +270,36 @@ func (s *sshSuite) TestSSHFinalDestinationDialFail(c *qt.C) {
c.Assert(err, qt.ErrorMatches, ".*connect failed.*")
}

func (s *sshSuite) TestMaxConcurrentConnections(c *qt.C) {
// fill the max of concurrent connection
maxConcurrentConnections := 10
clients := make([]*gossh.Client, 0)
for range maxConcurrentConnections {
client, err := gossh.Dial("tcp", fmt.Sprintf(":%d", s.jumpServerPort), &gossh.ClientConfig{
HostKeyCallback: gossh.FixedHostKey(s.hostKey.PublicKey()),
Auth: []gossh.AuthMethod{
gossh.PublicKeys(s.privateKey),
},
User: "alice",
})
c.Check(err, qt.IsNil)
clients = append(clients, client)
}
// this connection is dropped when we are at maximum connections
_, err := gossh.Dial("tcp", fmt.Sprintf(":%d", s.jumpServerPort), &gossh.ClientConfig{
HostKeyCallback: gossh.FixedHostKey(s.hostKey.PublicKey()),
Auth: []gossh.AuthMethod{
gossh.PublicKeys(s.privateKey),
},
User: "alice",
Timeout: 50 * time.Millisecond,
})
c.Check(err, qt.ErrorMatches, ".*connection reset.*")
for _, client := range clients {
client.Close()
}
}

func TestIdentityManager(t *testing.T) {
qtsuite.Run(qt.New(t), &sshSuite{})
}

0 comments on commit 8de5428

Please sign in to comment.