diff --git a/go.mod b/go.mod index ed1d9ad..f0e798b 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/avast/retry-go/v4 v4.5.1 github.com/gin-contrib/pprof v1.5.0 github.com/gin-gonic/gin v1.10.0 + github.com/go-ping/ping v1.1.0 github.com/gorilla/mux v1.8.0 github.com/mitchellh/go-ps v1.0.0 github.com/prometheus/client_golang v1.12.1 @@ -69,6 +70,7 @@ require ( golang.org/x/arch v0.8.0 // indirect golang.org/x/net v0.25.0 // indirect golang.org/x/oauth2 v0.15.0 // indirect + golang.org/x/sync v0.5.0 // indirect golang.org/x/sys v0.20.0 // indirect golang.org/x/term v0.20.0 // indirect golang.org/x/text v0.15.0 // indirect diff --git a/go.sum b/go.sum index cbf83e7..1da2948 100644 --- a/go.sum +++ b/go.sum @@ -101,6 +101,8 @@ github.com/go-openapi/jsonreference v0.20.2 h1:3sVjiK66+uXK/6oQ8xgcRKcFgQ5KXa2Kv github.com/go-openapi/jsonreference v0.20.2/go.mod h1:Bl1zwGIM8/wsvqjsOQLJ/SH+En5Ap4rVB5KVcIDZG2k= github.com/go-openapi/swag v0.22.3 h1:yMBqmnQ0gyZvEb/+KzuWZOXgllrXT4SADYbvDaXHv/g= github.com/go-openapi/swag v0.22.3/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+VcZ0yl14= +github.com/go-ping/ping v1.1.0 h1:3MCGhVX4fyEUuhsfwPrsEdQw6xspHkv5zHsiSoDFZYw= +github.com/go-ping/ping v1.1.0/go.mod h1:xIFjORFzTxqIV/tDVGO4eDy/bLuSyawEeojSm3GfRGk= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= @@ -177,6 +179,7 @@ github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99/go.mod h1:ZgVRPoUq/hf github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1 h1:K6RDEckDVWvDI9JAJYCmNdQXq6neHJOYx3V6jnqNEec= github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= +github.com/google/uuid v1.2.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.4.0 h1:MtMxsa51/r9yyhkyLsVeVt0B+BGQZzpQiTQ4eHZ8bc4= github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= @@ -390,6 +393,7 @@ golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81R golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20210316092652-d523dce5a7f4/go.mod h1:RBQZq4jEuRlivfhVLdyRGr576XBO4/greRjx4P4O3yc= golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= @@ -411,6 +415,7 @@ golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE= golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -448,6 +453,7 @@ golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210315160823-c6e025ad8005/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/kubernetes/helm/templates/client-deployment.yaml b/kubernetes/helm/templates/client-deployment.yaml index 22a1e58..f0e0c21 100644 --- a/kubernetes/helm/templates/client-deployment.yaml +++ b/kubernetes/helm/templates/client-deployment.yaml @@ -148,6 +148,7 @@ spec: - --server - {{ .Values.client.serverDsn | required "Please set client.serverDsn" }} - '--key-storage-db=/storage/keys.db' + - '--pairing-client-cache-db=/storage/keycache.db' {{ end }} \ No newline at end of file diff --git a/pkg/cmd/client.go b/pkg/cmd/client.go index 2ad2d7f..9084325 100644 --- a/pkg/cmd/client.go +++ b/pkg/cmd/client.go @@ -38,12 +38,13 @@ var clientCommand *cli.Command = &cli.Command{ helloRetryIntervalFlag, nginxExposerConfdPathFlag, wireguardConfigFilePathFlag, + pairingClientCacheDBPath, keyStorageDBFlag, }, Action: func(c *cli.Context) error { privateKey, publicKey, keyErr := wg.GetOrGenerateKeyPair(getKeyStorage(c)) if keyErr != nil { - logrus.Fatalf("Failed to get or generate key pair: %v", keyErr) + logrus.Fatalf("Failed to get key pair: %v", keyErr) } startPrometheusServer(c) @@ -85,21 +86,37 @@ var clientCommand *cli.Command = &cli.Command{ ) } - client := hello.NewPairingClient( - c.String(peerNameFlag.Name), - &wg.Config{ - PrivateKey: privateKey, - Subnet: "32", - }, - - hello.KeyPair{ - PublicKey: publicKey, - PrivateKey: privateKey, - }, - wg.NewWatcher(c.String(wireguardConfigFilePathFlag.Name)), - hello.NewJSONPairingEncoder(), - transport, + pairingKeyCache := hello.NewInMemoryKeyCachingPairingClientStorage() + if c.String(pairingClientCacheDBPath.Name) != "" { + var err error + pairingKeyCache, err = hello.NewBoltKeyCachingPairingClientStorage(c.String(pairingClientCacheDBPath.Name)) + if err != nil { + logrus.Fatalf("Failed to create pairing key cache: %v", err) + } + } + wgReloader := wg.NewWatcher(c.String(wireguardConfigFilePathFlag.Name)) + wgConfig := &wg.Config{ + PrivateKey: privateKey, + Subnet: "32", + } + keyPair := hello.KeyPair{ + PublicKey: publicKey, + PrivateKey: privateKey, + } + client := hello.NewKeyCachingPairingClient( + pairingKeyCache, + wgConfig, + wgReloader, + hello.NewDefaultPairingClient( + c.String(peerNameFlag.Name), + wgConfig, + keyPair, + wgReloader, + hello.NewJSONPairingEncoder(), + transport, + ), ) + var pairingResponse hello.PairingResponse for { var err error diff --git a/pkg/cmd/flags.go b/pkg/cmd/flags.go index f42bc5e..5a1a04f 100644 --- a/pkg/cmd/flags.go +++ b/pkg/cmd/flags.go @@ -22,6 +22,11 @@ var keyStorageDBFlag *cli.StringFlag = &cli.StringFlag{ Value: "", } +var pairingClientCacheDBPath *cli.StringFlag = &cli.StringFlag{ + Name: "pairing-client-cache-db", + Value: "", +} + var kubernetesFlag *cli.BoolFlag = &cli.BoolFlag{ Name: "kubernetes", Usage: "Use kubernetes to create proxy services", diff --git a/pkg/cmd/root.go b/pkg/cmd/root.go index 69560ba..0de8bc1 100644 --- a/pkg/cmd/root.go +++ b/pkg/cmd/root.go @@ -16,8 +16,9 @@ var debugFlag = &cli.BoolFlag{ // Run starts wormgole func Run() { app := &cli.App{ - Name: "wormhole", - Usage: "Wormhole is an utility to create reverse websocket tunnels, similar to ngrok", + Name: "wormhole", + Usage: ("Wormhole is an utility to create reverse websocket tunnels, " + + "similar to ngrok, but designed to be used in a kubernetes cluster"), EnableBashCompletion: true, Commands: []*cli.Command{ serverCommand, diff --git a/pkg/hello/pairing.go b/pkg/hello/pairing.go index c5d20f6..1915f7c 100644 --- a/pkg/hello/pairing.go +++ b/pkg/hello/pairing.go @@ -1,14 +1,160 @@ package hello import ( + "encoding/json" "fmt" "github.com/glothriel/wormhole/pkg/wg" "github.com/sirupsen/logrus" + bolt "go.etcd.io/bbolt" ) -// PairingClient is a client that can pair with a server -type PairingClient struct { +// PairingClient allows pairing with a server +type PairingClient interface { + Pair() (PairingResponse, error) +} + +type keyCachingPairingClient struct { + client PairingClient + storage KeyCachingPairingClientStorage + wgConfig *wg.Config + wgReloader WireguardConfigReloader + + pinger pinger +} + +func (c *keyCachingPairingClient) Pair() (PairingResponse, error) { + response, getErr := c.storage.Get() + if getErr == nil { + c.wgConfig.Address = response.AssignedIP + c.wgConfig.Upsert(wg.Peer{ + Name: response.Name, + Endpoint: response.Wireguard.Endpoint, + PublicKey: response.Wireguard.PublicKey, + AllowedIPs: fmt.Sprintf("%s/32,%s/32", response.InternalServerIP, response.AssignedIP), + PersistentKeepalive: 10, + }) + + updateErr := c.wgReloader.Update(*c.wgConfig) + if updateErr != nil { + logrus.Errorf("Failed to update Wireguard config: %v", updateErr) + } + logrus.Infof( + "Trying to ping server %s with the config from the cache", response.InternalServerIP, + ) + pingerErr := c.pinger.Ping(response.InternalServerIP) + if pingerErr == nil { + logrus.Infof("Successfully pinged server %s, using IP from the cache", response.InternalServerIP) + return response, nil + } + logrus.Warnf("Failed to ping server %s: %v, attempting to pair using PSK", response.InternalServerIP, pingerErr) + } else { + logrus.Info("No cached pairing response found, pairing with server") + } + childResponse, pairErr := c.client.Pair() + if pairErr != nil { + return PairingResponse{}, pairErr + } + setErr := c.storage.Set(childResponse) + if setErr != nil { + logrus.Errorf("Failed to store pairing response: %v", setErr) + } + + return childResponse, nil +} + +// NewKeyCachingPairingClient is a decorator that tries to cache the keys obtained by child client +func NewKeyCachingPairingClient( + storage KeyCachingPairingClientStorage, + wgConfig *wg.Config, + + wgReloader WireguardConfigReloader, + client PairingClient, +) PairingClient { + return &keyCachingPairingClient{ + client: client, + storage: storage, + wgReloader: wgReloader, + wgConfig: wgConfig, + + pinger: &retryingPinger{&defaultPinger{}}, + } +} + +// KeyCachingPairingClientStorage is a storage for pairing responses cache +type KeyCachingPairingClientStorage interface { + Set(PairingResponse) error + Get() (PairingResponse, error) +} + +type boltKeyCachingPairingClientStorage struct { + db *bolt.DB +} + +func (s *boltKeyCachingPairingClientStorage) Get() (PairingResponse, error) { + var response PairingResponse + err := s.db.View(func(tx *bolt.Tx) error { + bucket := tx.Bucket([]byte("pairing")) + if bucket == nil { + return fmt.Errorf("bucket does not exist") + } + data := bucket.Get([]byte("response")) + if data == nil { + return fmt.Errorf("response does not exist") + } + return json.Unmarshal(data, &response) + }) + return response, err +} + +func (s *boltKeyCachingPairingClientStorage) Set(response PairingResponse) error { + return s.db.Update(func(tx *bolt.Tx) error { + bucket, createErr := tx.CreateBucketIfNotExists([]byte("pairing")) + if createErr != nil { + return createErr + } + encoded, encodeErr := json.Marshal(response) + if encodeErr != nil { + return encodeErr + } + return bucket.Put([]byte("response"), encoded) + }) +} + +// NewBoltKeyCachingPairingClientStorage creates a new KeyCachingPairingClientStorage backed by a bolt database +func NewBoltKeyCachingPairingClientStorage(path string) (KeyCachingPairingClientStorage, error) { + db, err := bolt.Open(path, 0600, nil) + if err != nil { + return nil, err + } + return &boltKeyCachingPairingClientStorage{db: db}, nil +} + +type inMemoryKeyCachingPairingClientStorage struct { + isSet bool + response PairingResponse +} + +func (s *inMemoryKeyCachingPairingClientStorage) Get() (PairingResponse, error) { + if !s.isSet { + return PairingResponse{}, fmt.Errorf("response not set") + } + return s.response, nil +} + +func (s *inMemoryKeyCachingPairingClientStorage) Set(response PairingResponse) error { + s.response = response + s.isSet = true + return nil +} + +// NewInMemoryKeyCachingPairingClientStorage creates a new KeyCachingPairingClientStorage backed by memory +func NewInMemoryKeyCachingPairingClientStorage() KeyCachingPairingClientStorage { + return &inMemoryKeyCachingPairingClientStorage{} +} + +// defaultPairingClient is a client that can pair with a server +type defaultPairingClient struct { clientName string keyPair KeyPair wgConfig *wg.Config @@ -19,7 +165,7 @@ type PairingClient struct { } // Pair sends a pairing request to the server and returns the response -func (c *PairingClient) Pair() (PairingResponse, error) { +func (c *defaultPairingClient) Pair() (PairingResponse, error) { request := PairingRequest{ Name: c.clientName, Wireguard: PairingRequestWireguardConfig{ @@ -53,16 +199,16 @@ func (c *PairingClient) Pair() (PairingResponse, error) { return decoded, c.wgReloader.Update(*c.wgConfig) } -// NewPairingClient creates a new PairingClient instance -func NewPairingClient( +// NewDefaultPairingClient executes HTTP pairing requests to the server +func NewDefaultPairingClient( clientName string, wgConfig *wg.Config, keyPair KeyPair, wgReloader WireguardConfigReloader, encoder PairingEncoder, transport PairingClientTransport, -) *PairingClient { - return &PairingClient{ +) PairingClient { + return &defaultPairingClient{ clientName: clientName, keyPair: keyPair, wgConfig: wgConfig, diff --git a/pkg/hello/pinger.go b/pkg/hello/pinger.go new file mode 100644 index 0000000..380e070 --- /dev/null +++ b/pkg/hello/pinger.go @@ -0,0 +1,43 @@ +package hello + +import ( + "fmt" + "time" + + "github.com/avast/retry-go/v4" + "github.com/go-ping/ping" +) + +type pinger interface { + Ping(address string) error +} + +type defaultPinger struct{} + +func (p *defaultPinger) Ping(address string) error { + pinger, pingerErr := ping.NewPinger(address) + if pingerErr != nil { + return fmt.Errorf("failed to create pinger: %v", pingerErr) + } + pinger.Count = 3 + pinger.Timeout = 3 * time.Second + runErr := pinger.Run() + if runErr != nil { + return fmt.Errorf("failed to run pinger: %v", runErr) + } + if pinger.Statistics().PacketsRecv == 0 { + return fmt.Errorf("failed to ping server %s over the tunnel", address) + } + return nil +} + +type retryingPinger struct { + pinger pinger +} + +// uses avast/retry-go to retry pings +func (p *retryingPinger) Ping(address string) error { + return retry.Do(func() error { + return p.pinger.Ping(address) + }, retry.Attempts(5), retry.Delay(time.Second)) +} diff --git a/pkg/k8s/svcdetector/service.go b/pkg/k8s/svcdetector/service.go index c86104e..62e605d 100644 --- a/pkg/k8s/svcdetector/service.go +++ b/pkg/k8s/svcdetector/service.go @@ -2,10 +2,12 @@ package svcdetector import ( "fmt" + "math" "strconv" "strings" "github.com/glothriel/wormhole/pkg/peers" + "github.com/sirupsen/logrus" corev1 "k8s.io/api/core/v1" ) @@ -67,7 +69,13 @@ func (wrapper defaultServiceWrapper) ports() []corev1.ServicePort { } } else { for _, portDefinition := range wrapper.k8sSvc.Spec.Ports { - if portDefinition.Port == int32(portAsNumber) { + portAsInt32, portErr := safePortConversion(portAsNumber) + if portErr != nil { + logrus.Errorf("invalid port number: %v", portErr) + continue + } + + if portDefinition.Port == portAsInt32 { thePorts = append(thePorts, *portDefinition.DeepCopy()) } } @@ -76,6 +84,20 @@ func (wrapper defaultServiceWrapper) ports() []corev1.ServicePort { return thePorts } +func safePortConversion(portNumber int64) (int32, error) { + // Check lower bound + if portNumber < 0 { + return 0, fmt.Errorf("port number cannot be negative: %d", portNumber) + } + + // Check upper bound + if portNumber > math.MaxInt32 { + return 0, fmt.Errorf("port number exceeds maximum int32 value: %d", portNumber) + } + + return int32(portNumber), nil // nolint: gosec +} + func (wrapper defaultServiceWrapper) apps() []peers.App { apps := make([]peers.App, 0) exposedPorts := wrapper.ports() diff --git a/tests/test_kubernetes.py b/tests/test_kubernetes.py index 399614d..cf29aae 100644 --- a/tests/test_kubernetes.py +++ b/tests/test_kubernetes.py @@ -235,3 +235,30 @@ def _ensure_that_proxied_service_is_reachable(): 'http://server-nginx-nginx.client.svc.cluster.local', max_time_seconds=10, ) + + + +def test_reconnecting_clients_with_keys( + kubectl, + k8s_server, + k8s_client, + mock_server, +): + + @retry(tries=int(DEFAULT_RETRY_TRIES / 10), delay=DEFAULT_RETRY_DELAY) + def _wait_for_peers_paired_using_psk(): + assert "Paired with server, assigned IP" in kubectl.run( + ["-n", "client", "logs", "-l", "application=wormhole-client-client", "-c", "wormhole"] + ).stdout.decode() + + _wait_for_peers_paired_using_psk() + + kubectl.run(["-n", "client", "delete", "pod", "-l", "application=wormhole-client-client"]) + + @retry(tries=int(DEFAULT_RETRY_TRIES / 10), delay=DEFAULT_RETRY_DELAY) + def _wait_for_peers_paired_using_keys(): + assert "using IP from the cache" in kubectl.run( + ["-n", "client", "logs", "-l", "application=wormhole-client-client", "-c", "wormhole"] + ).stdout.decode() + + _wait_for_peers_paired_using_keys()