diff --git a/internal/api/message.go b/internal/api/message.go index 59586d7..06980be 100644 --- a/internal/api/message.go +++ b/internal/api/message.go @@ -1,14 +1,21 @@ package api import ( + "crypto/rand" + "encoding/base64" "encoding/binary" "errors" "io" ) const ( - // MaxPayloadLength is the maximum allowed length of a message payload. - MaxPayloadLength = 2097152 + // TokenLength is the length of the message token in bytes. + TokenLength = 16 +) + +var ( + // token is the message token. + token [TokenLength]byte ) // Message types. @@ -24,6 +31,7 @@ const ( type Header struct { Type uint16 Length uint32 + Token [TokenLength]byte } // Message is an API message. @@ -34,13 +42,11 @@ type Message struct { // NewMessage returns a new message with type t and payload p. func NewMessage(t uint16, p []byte) *Message { - if len(p) > MaxPayloadLength { - return nil - } return &Message{ Header: Header{ Type: t, Length: uint32(len(p)), + Token: token, }, Value: p, } @@ -69,8 +75,8 @@ func ReadMessage(r io.Reader) (*Message, error) { if h.Type == TypeNone || h.Type >= TypeUndefined { return nil, errors.New("invalid message type") } - if h.Length > MaxPayloadLength { - return nil, errors.New("invalid message length") + if h.Token != token { + return nil, errors.New("invalid message token") } // read payload @@ -107,3 +113,26 @@ func WriteMessage(w io.Writer, m *Message) error { return nil } + +// GetToken generates and returns the message token as string. This should be +// used once on the server side before the server is started. Token must be +// passed to the client side. +func GetToken() (string, error) { + _, err := rand.Read(token[:]) + if err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(token[:]), nil +} + +// SetToken sets the message token from string. This should be used on the +// client side before sending requests to the server. Token must match token on +// the server side. +func SetToken(s string) error { + b, err := base64.RawURLEncoding.DecodeString(s) + if err != nil { + return err + } + copy(token[:], b) + return nil +} diff --git a/internal/api/message_test.go b/internal/api/message_test.go index 13afd86..f6991c3 100644 --- a/internal/api/message_test.go +++ b/internal/api/message_test.go @@ -2,6 +2,7 @@ package api import ( "bytes" + "encoding/base64" "errors" "log" "reflect" @@ -24,12 +25,6 @@ func TestNewMessage(t *testing.T) { t.Errorf("got %d, want %d", msg.Type, typ) } } - - // invalid payload length - p := [MaxPayloadLength + 1]byte{} - if NewMessage(TypeOK, p[:]) != nil { - t.Error("should not create message with invalid payload length") - } } // TestNewOK tests NewOK. @@ -62,11 +57,11 @@ func TestReadMessageErrors(t *testing.T) { // invalid type {Header: Header{Type: TypeUndefined}}, - // invalid length - {Header: Header{Type: TypeOK, Length: MaxPayloadLength + 1}}, - // short message - {Header: Header{Type: TypeOK, Length: MaxPayloadLength}}, + {Header: Header{Type: TypeOK, Length: 4096}}, + + // invalid token + {Header: Header{Type: TypeOK, Token: [16]byte{1}}}, } { if err := WriteMessage(buf, msg); err != nil { t.Fatal(err) @@ -132,3 +127,35 @@ func TestReadWriteMessage(t *testing.T) { t.Errorf("got %v, want %v", got, want) } } + +// TestGetSetToken tests GetToken and SetToken. +func TestGetSetToken(t *testing.T) { + // reset token after tests + defer func() { token = [TokenLength]byte{} }() + + // get new test token + testToken, err := GetToken() + if err != nil { + t.Fatal(err) + } + s := base64.RawURLEncoding.EncodeToString(token[:]) + if testToken != s { + t.Fatal("encoded token should match internal token") + } + + // set token + if err := SetToken(testToken); err != nil { + t.Fatal(err) + } + + // check token + s = base64.RawURLEncoding.EncodeToString(token[:]) + if s != testToken { + t.Fatal("internal token should match encoded token") + } + + // setting invalid token + if err := SetToken("not a valid encoded token!"); err == nil { + t.Fatal("invalid token should return error") + } +} diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go index 4451816..4afe915 100644 --- a/internal/daemon/daemon.go +++ b/internal/daemon/daemon.go @@ -3,8 +3,6 @@ package daemon import ( "context" - "crypto/rand" - "encoding/base64" "fmt" "net" "reflect" @@ -344,13 +342,6 @@ func (d *Daemon) updateVPNConfig(request *api.Request) { return } - // check token - if configUpdate.Token != d.token { - log.Error("Daemon got invalid token in vpn config update") - request.Error("invalid token in config update message") - return - } - // handle config update for vpn (dis)connect if configUpdate.Reason == "disconnect" { d.updateVPNConfigDown() @@ -508,13 +499,11 @@ func (d *Daemon) cleanup(ctx context.Context) { // initToken creates the daemon token for client authentication. func (d *Daemon) initToken() error { - // TODO: is this good enough for us? - b := make([]byte, 16) - _, err := rand.Read(b) + token, err := api.GetToken() if err != nil { return err } - d.token = base64.RawURLEncoding.EncodeToString(b) + d.token = token return nil } diff --git a/internal/daemon/vpnconfigupdate.go b/internal/daemon/vpnconfigupdate.go index f4a401f..9706967 100644 --- a/internal/daemon/vpnconfigupdate.go +++ b/internal/daemon/vpnconfigupdate.go @@ -9,7 +9,6 @@ import ( // VPNConfigUpdate is a VPN configuration update. type VPNConfigUpdate struct { Reason string - Token string Config *vpnconfig.Config } @@ -17,16 +16,13 @@ type VPNConfigUpdate struct { func (c *VPNConfigUpdate) Valid() bool { switch c.Reason { case "disconnect": - // token must be valid and config nil - if c.Token == "" || c.Config != nil { + // config must be nil + if c.Config != nil { return false } case "connect": - // token and config must be valid - if c.Token == "" || c.Config == nil { - return false - } - if !c.Config.Valid() { + // config must be valid + if c.Config == nil || !c.Config.Valid() { return false } default: diff --git a/internal/daemon/vpnconfigupdate_test.go b/internal/daemon/vpnconfigupdate_test.go index 5bf2f4b..97f049b 100644 --- a/internal/daemon/vpnconfigupdate_test.go +++ b/internal/daemon/vpnconfigupdate_test.go @@ -22,6 +22,7 @@ func TestVPNConfigUpdateValid(t *testing.T) { // test invalid disconnect u = NewVPNConfigUpdate() u.Reason = "disconnect" + u.Config = vpnconfig.New() got = u.Valid() want = false @@ -29,7 +30,7 @@ func TestVPNConfigUpdateValid(t *testing.T) { t.Errorf("got %t, want %t", got, want) } - // test invalid connect, no token and no config + // test invalid connect, no config u = NewVPNConfigUpdate() u.Reason = "connect" @@ -42,7 +43,6 @@ func TestVPNConfigUpdateValid(t *testing.T) { // test invalid connect, invalid config u = NewVPNConfigUpdate() u.Reason = "connect" - u.Token = "some test token" u.Config = vpnconfig.New() u.Config.Device.Name = "name is too long for a network device" @@ -55,7 +55,6 @@ func TestVPNConfigUpdateValid(t *testing.T) { // test valid disconnect u = NewVPNConfigUpdate() u.Reason = "disconnect" - u.Token = "some test token" got = u.Valid() want = true @@ -66,7 +65,6 @@ func TestVPNConfigUpdateValid(t *testing.T) { // test valid connect u = NewVPNConfigUpdate() u.Reason = "connect" - u.Token = "some test token" u.Config = vpnconfig.New() got = u.Valid() @@ -87,13 +85,11 @@ func TestVPNConfigUpdateJSON(t *testing.T) { // valid disconnect u = NewVPNConfigUpdate() u.Reason = "disconnect" - u.Token = "some test token" updates = append(updates, u) // valid connect u = NewVPNConfigUpdate() u.Reason = "connect" - u.Token = "some test token" u.Config = vpnconfig.New() updates = append(updates, u) diff --git a/internal/vpncscript/client_test.go b/internal/vpncscript/client_test.go index 65e45c6..8fd6c4b 100644 --- a/internal/vpncscript/client_test.go +++ b/internal/vpncscript/client_test.go @@ -69,7 +69,7 @@ func TestRunClient(t *testing.T) { return confUpdate } - // test with maximum payload length + // test with varying payload lengths server = api.NewServer(config) go func() { for r := range server.Requests() { @@ -79,23 +79,12 @@ func TestRunClient(t *testing.T) { if err := server.Start(); err != nil { t.Fatal(err) } - if err := runClient(sockfile, getConfUpdate(api.MaxPayloadLength)); err != nil { - t.Fatal(err) - } - server.Stop() - - // test with more than maximum payload length - server = api.NewServer(config) - go func() { - for r := range server.Requests() { - r.Close() + for _, length := range []int{ + 2048, 4096, 8192, 65536, 2097152, + } { + if err := runClient(sockfile, getConfUpdate(length)); err != nil { + t.Errorf("length %d returned error: %v", length, err) } - }() - if err := server.Start(); err != nil { - t.Fatal(err) - } - if err := runClient(sockfile, getConfUpdate(api.MaxPayloadLength+1)); err == nil { - t.Fatal("too long message should return error") } server.Stop() } diff --git a/internal/vpncscript/cmd.go b/internal/vpncscript/cmd.go index cc66fdc..cbe7583 100644 --- a/internal/vpncscript/cmd.go +++ b/internal/vpncscript/cmd.go @@ -8,6 +8,7 @@ import ( "os" log "github.com/sirupsen/logrus" + "github.com/telekom-mms/oc-daemon/internal/api" "github.com/telekom-mms/oc-daemon/internal/daemon" ) @@ -47,6 +48,11 @@ func run(args []string) error { socketFile = e.socketFile } + // set token from environemt + if err := api.SetToken(e.token); err != nil { + return fmt.Errorf("VPNCScript could not set token: %w", err) + } + printDebugEnvironment() log.WithField("env", e).Debug("VPNCScript parsed environment") diff --git a/internal/vpncscript/cmd_test.go b/internal/vpncscript/cmd_test.go index 57bb397..f91da05 100644 --- a/internal/vpncscript/cmd_test.go +++ b/internal/vpncscript/cmd_test.go @@ -24,6 +24,12 @@ func TestRun(t *testing.T) { t.Errorf("help should return ErrHelp, got: %v", err) } + // test with invalid token + t.Setenv("oc_daemon_token", "this is not a valid encoded token!") + if err := run([]string{"test"}); err == nil { + t.Errorf("invalid token should return error") + } + // prepare environment with not existing sockfile os.Clearenv() sockfile := filepath.Join(t.TempDir(), "sockfile") diff --git a/internal/vpncscript/config.go b/internal/vpncscript/config.go index 749851c..f603481 100644 --- a/internal/vpncscript/config.go +++ b/internal/vpncscript/config.go @@ -221,7 +221,6 @@ func createConfig(env *env) (*vpnconfig.Config, error) { func createConfigUpdate(env *env) (*daemon.VPNConfigUpdate, error) { update := daemon.NewVPNConfigUpdate() update.Reason = env.reason - update.Token = env.token if env.reason == "connect" { c, err := createConfig(env) if err != nil { diff --git a/internal/vpncscript/config_test.go b/internal/vpncscript/config_test.go index 988e995..1e1e2d7 100644 --- a/internal/vpncscript/config_test.go +++ b/internal/vpncscript/config_test.go @@ -87,7 +87,6 @@ func TestCreateConfigUpdate(t *testing.T) { // create expected values based on test environment reason := "connect" - token := "some token" config := &vpnconfig.Config{ Gateway: net.IPv4(10, 1, 1, 1), PID: 12345, @@ -129,9 +128,6 @@ func TestCreateConfigUpdate(t *testing.T) { if got.Reason != reason { t.Errorf("got %s, want %s", got.Reason, reason) } - if got.Token != token { - t.Errorf("got %s, want %s", got.Token, token) - } if !reflect.DeepEqual(got.Config, config) { t.Errorf("got:\n%#v\nwant:\n%#v", got.Config, config) }