Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement limit connections #1527

Merged
merged 6 commits into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions internal/ssh/listener.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
// Copyright 2025 Canonical.
package ssh
alesstimec marked this conversation as resolved.
Show resolved Hide resolved

import (
"net"
"sync"
"time"
)

// N.B.:
// This is a copypaste of netutil.LimiLister (link: https://cs.opensource.google/go/x/net/+/refs/tags/v0.34.0:netutil/listen.go),
// but we add a timeout so when we are at the limit we actively close connections instead of waiting indefinetely. (Look at line 44)

// LimitListenerWithTimeout returns a Listener that accepts at most n simultaneous
// connections from the provided Listener, and it timeouts when the max
// has been reached and no seats has been freed for the timeout period.
func LimitListenerWithTimeout(l net.Listener, n int, timeout time.Duration) net.Listener {
SimoneDutto marked this conversation as resolved.
Show resolved Hide resolved
return &limitListener{
Listener: l,
sem: make(chan struct{}, n),
done: make(chan struct{}),
alesstimec marked this conversation as resolved.
Show resolved Hide resolved
timeout: timeout,
}
}

type limitListener struct {
net.Listener
sem chan struct{}
closeOnce sync.Once // ensures the done chan is only closed once
done chan struct{} // no values sent; closed when Close is called
timeout time.Duration // timeout for acquiring the connection
}

// acquire acquires the limiting semaphore. Returns true if successfully
// acquired, false if the listener is closed and the semaphore is not
// acquired.
func (l *limitListener) acquire() bool {
select {
case <-l.done:
return false
case l.sem <- struct{}{}:
return true
// we add a timeout here, so the connection is closed when the timeout has passed instead of waiting.
case <-time.After(l.timeout):
return false
}
}
func (l *limitListener) release() { <-l.sem }

func (l *limitListener) Accept() (net.Conn, error) {
alesstimec marked this conversation as resolved.
Show resolved Hide resolved
if !l.acquire() {
// If the semaphore isn't acquired because the listener was closed, expect
// that this call to accept won't block, but immediately return an error.
// If it instead returns a spurious connection (due to a bug in the
// Listener, such as https://golang.org/issue/50216), we immediately close
// it and try again. Some buggy Listener implementations (like the one in
// the aforementioned issue) seem to assume that Accept will be called to
// completion, and may otherwise fail to clean up the client end of pending
// connections.
for {
c, err := l.Listener.Accept()
if err != nil {
return nil, err
}
c.Close()
}
}

c, err := l.Listener.Accept()
if err != nil {
l.release()
return nil, err
}
return &limitListenerConn{Conn: c, release: l.release}, nil
}

func (l *limitListener) Close() error {
SimoneDutto marked this conversation as resolved.
Show resolved Hide resolved
err := l.Listener.Close()
l.closeOnce.Do(func() { close(l.done) })
return err
}

type limitListenerConn struct {
net.Conn
releaseOnce sync.Once
release func()
}

func (l *limitListenerConn) Close() error {
SimoneDutto marked this conversation as resolved.
Show resolved Hide resolved
err := l.Conn.Close()
l.releaseOnce.Do(l.release)
return err
}
85 changes: 65 additions & 20 deletions internal/ssh/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"io"
"net"
"time"

"github.com/gliderlabs/ssh"
"github.com/juju/names/v5"
Expand All @@ -17,8 +18,9 @@ import (
"github.com/canonical/jimm/v3/internal/openfga"
)

// juju_ssh_default_port is the default port we expect the juju controllers to respond on.
const juju_ssh_default_port = 17022
// jujuSSHDefaultPort is the default port we expect the juju controllers to respond on.
const jujuSSHDefaultPort = 17022
const defaultAcceptConnectionTimeout = time.Second

type publicKeySSHUserKey struct{}

Expand Down Expand Up @@ -47,40 +49,83 @@ type forwardMessage struct {
type Config struct {
Port string
HostKey []byte
MaxConcurrentConnections string
MaxConcurrentConnections int
AcceptConnectionTimeout time.Duration
}

type Server struct {
SimoneDutto marked this conversation as resolved.
Show resolved Hide resolved
*ssh.Server
SimoneDutto marked this conversation as resolved.
Show resolved Hide resolved

MaxConcurrentConnections int
SimoneDutto marked this conversation as resolved.
Show resolved Hide resolved
AcceptConnectionTimeout time.Duration
}

// NewJumpServer creates the jump server struct.
func NewJumpServer(ctx context.Context, config Config, sshAuthorizer SSHAuthorizer, sshResolver SSHResolver) (*ssh.Server, error) {
func NewJumpServer(ctx context.Context, config Config, sshAuthorizer SSHAuthorizer, sshResolver SSHResolver) (Server, error) {
SimoneDutto marked this conversation as resolved.
Show resolved Hide resolved
zapctx.Info(ctx, "NewJumpServer")

if sshResolver == nil {
return nil, fmt.Errorf("Cannot create JumpSSHServer with a nil resolver.")
return Server{}, fmt.Errorf("Cannot create JumpSSHServer with a nil resolver.")
}
server := &ssh.Server{
Addr: fmt.Sprintf(":%s", config.Port),
ChannelHandlers: map[string]ssh.ChannelHandler{
"direct-tcpip": directTCPIPHandler(sshResolver),
},
PublicKeyHandler: func(ctx ssh.Context, key ssh.PublicKey) bool {
user, err := sshAuthorizer.PublicKeyHandler(ctx, ctx.User(), key.Marshal())
if err != nil {
zapctx.Debug(ctx, fmt.Sprintf("cannot verify key for user %s", ctx.User()), zap.Error(err))
return false
}
ctx.SetValue(publicKeySSHUserKey{}, user)
return true
config = setConfigDefaults(config)
server := Server{
Server: &ssh.Server{
Addr: fmt.Sprintf(":%s", config.Port),
ChannelHandlers: map[string]ssh.ChannelHandler{
"direct-tcpip": directTCPIPHandler(sshResolver),
},
PublicKeyHandler: func(ctx ssh.Context, key ssh.PublicKey) bool {
user, err := sshAuthorizer.PublicKeyHandler(ctx, ctx.User(), key.Marshal())
if err != nil {
zapctx.Debug(ctx, fmt.Sprintf("cannot verify key for user %s", ctx.User()), zap.Error(err))
return false
}
ctx.SetValue(publicKeySSHUserKey{}, user)
return true
},
},
MaxConcurrentConnections: config.MaxConcurrentConnections,
AcceptConnectionTimeout: config.AcceptConnectionTimeout,
}
hostKey, err := gossh.ParsePrivateKey([]byte(config.HostKey))
if err != nil {
return nil, fmt.Errorf("Cannot parse hostkey.")
return Server{}, fmt.Errorf("Cannot parse hostkey.")
}
server.AddHostKey(hostKey)

return server, nil
}

// setConfigDefaults sets the default values for the configuration.
func setConfigDefaults(config Config) Config {
if config.Port == "" {
config.Port = fmt.Sprint(jujuSSHDefaultPort)
}
if config.MaxConcurrentConnections <= 0 {
config.MaxConcurrentConnections = 100
SimoneDutto marked this conversation as resolved.
Show resolved Hide resolved
}
if config.AcceptConnectionTimeout <= 0 {
config.AcceptConnectionTimeout = defaultAcceptConnectionTimeout
}
return config
}

// ListenAndServe create a LimitListenerWithTimeout and Serve requests.
func (srv Server) ListenAndServe() error {
ln, err := net.Listen("tcp", srv.Addr)
if srv.MaxConcurrentConnections == 0 {
SimoneDutto marked this conversation as resolved.
Show resolved Hide resolved
srv.MaxConcurrentConnections = 100
}
if srv.AcceptConnectionTimeout == 0 {
srv.AcceptConnectionTimeout = defaultAcceptConnectionTimeout
}
ln = LimitListenerWithTimeout(ln, srv.MaxConcurrentConnections, srv.AcceptConnectionTimeout)
if err != nil {
return err
}
return srv.Serve(ln)
}

func directTCPIPHandler(sshResolver SSHResolver) func(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) {
return func(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) {
d := forwardMessage{}
Expand All @@ -91,7 +136,7 @@ func directTCPIPHandler(sshResolver SSHResolver) func(srv *ssh.Server, conn *gos
return
}
if d.DestPort == 0 {
d.DestPort = juju_ssh_default_port
d.DestPort = jujuSSHDefaultPort
}
if !names.IsValidModel(d.DestAddr) {
rejectConnectionAndLogError(ctx, newChan, "invalid model uuid", nil)
Expand Down
37 changes: 34 additions & 3 deletions internal/ssh/ssh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import (
type sshSuite struct {
destinationJujuSSHServer *gliderssh.Server
destinationServerPort int
jumpSSHServer *gliderssh.Server
jumpSSHServer ssh.Server
jumpServerPort int
privateKey gossh.Signer
hostKey gossh.Signer
Expand Down Expand Up @@ -112,8 +112,9 @@ func (s *sshSuite) Init(c *qt.C) {

jumpServer, err := ssh.NewJumpServer(context.Background(),
ssh.Config{
Port: fmt.Sprint(port),
HostKey: hostKey,
Port: fmt.Sprint(port),
HostKey: hostKey,
MaxConcurrentConnections: 10,
},
mocks.SSHAuthorizer{
PublicKeyHandler_: func(ctx context.Context, claimUser string, key []byte) (*openfga.User, error) {
Expand Down Expand Up @@ -269,6 +270,36 @@ func (s *sshSuite) TestSSHFinalDestinationDialFail(c *qt.C) {
c.Assert(err, qt.ErrorMatches, ".*connect failed.*")
}

func (s *sshSuite) TestMaxConcurrentConnections(c *qt.C) {
// fill the max of concurrent connection
maxConcurrentConnections := 10
clients := make([]*gossh.Client, 0)
for range maxConcurrentConnections {
client, err := gossh.Dial("tcp", fmt.Sprintf(":%d", s.jumpServerPort), &gossh.ClientConfig{
HostKeyCallback: gossh.FixedHostKey(s.hostKey.PublicKey()),
Auth: []gossh.AuthMethod{
gossh.PublicKeys(s.privateKey),
},
User: "alice",
})
c.Check(err, qt.IsNil)
clients = append(clients, client)
}
// this connection is dropped when we are at maximum connections
_, err := gossh.Dial("tcp", fmt.Sprintf(":%d", s.jumpServerPort), &gossh.ClientConfig{
HostKeyCallback: gossh.FixedHostKey(s.hostKey.PublicKey()),
Auth: []gossh.AuthMethod{
gossh.PublicKeys(s.privateKey),
},
User: "alice",
Timeout: 50 * time.Millisecond,
})
c.Check(err, qt.ErrorMatches, ".*connection reset.*")
for _, client := range clients {
client.Close()
}
}

func TestIdentityManager(t *testing.T) {
qtsuite.Run(qt.New(t), &sshSuite{})
}
Loading