diff --git a/internal/test/ouroboros_mock/connection.go b/internal/test/ouroboros_mock/connection.go index b098d039..a6ba3254 100644 --- a/internal/test/ouroboros_mock/connection.go +++ b/internal/test/ouroboros_mock/connection.go @@ -185,6 +185,9 @@ func (c *Connection) processInputEntry(entry ConversationEntry) error { if msg == nil { return fmt.Errorf("received unknown message type: %d", msgType) } + // Set CBOR for expected message to match received to make comparison easier + entry.InputMessage.SetCbor(msg.Cbor()) + // Compare received message to expected message if !reflect.DeepEqual(msg, entry.InputMessage) { return fmt.Errorf( "parsed message does not match expected value: got %#v, expected %#v", diff --git a/protocol/handshake/client_test.go b/protocol/handshake/client_test.go index b1ae0ce2..ad392a6e 100644 --- a/protocol/handshake/client_test.go +++ b/protocol/handshake/client_test.go @@ -22,7 +22,7 @@ import ( "github.com/blinklabs-io/gouroboros/internal/test/ouroboros_mock" ) -func TestBasicHandshake(t *testing.T) { +func TestClientBasicHandshake(t *testing.T) { mockConn := ouroboros_mock.NewConnection( ouroboros_mock.ProtocolRoleClient, []ouroboros_mock.ConversationEntry{ @@ -52,7 +52,7 @@ func TestBasicHandshake(t *testing.T) { } } -func TestDoubleStart(t *testing.T) { +func TestClientDoubleStart(t *testing.T) { mockConn := ouroboros_mock.NewConnection( ouroboros_mock.ProtocolRoleClient, []ouroboros_mock.ConversationEntry{ diff --git a/protocol/handshake/server.go b/protocol/handshake/server.go index 3358a7ec..cf981880 100644 --- a/protocol/handshake/server.go +++ b/protocol/handshake/server.go @@ -62,40 +62,89 @@ func (s *Server) handleMessage(msg protocol.Message, isResponse bool) error { return err } -func (s *Server) handleProposeVersions(msgGeneric protocol.Message) error { +func (s *Server) handleProposeVersions(msg protocol.Message) error { if s.config.FinishedFunc == nil { return fmt.Errorf( "received handshake ProposeVersions message but no callback function is defined", ) } - msg := msgGeneric.(*MsgProposeVersions) - var highestVersion uint16 - var versionData protocol.VersionData - for proposedVersion := range msg.VersionMap { - if proposedVersion > highestVersion { - for allowedVersion := range s.config.ProtocolVersionMap { - if allowedVersion == proposedVersion { - highestVersion = proposedVersion - versionConfig := protocol.GetProtocolVersion(proposedVersion) - tmpVersionData, err := versionConfig.NewVersionDataFromCborFunc(msg.VersionMap[proposedVersion]) - versionData = tmpVersionData - if err != nil { - return err - } - break - } - } + msgProposeVersions := msg.(*MsgProposeVersions) + // Compute intersection of supported and proposed protocol versions + var versionIntersect []uint16 + for proposedVersion := range msgProposeVersions.VersionMap { + if _, ok := s.config.ProtocolVersionMap[proposedVersion]; ok { + versionIntersect = append(versionIntersect, proposedVersion) } } - if highestVersion > 0 { - resp := NewMsgAcceptVersion(highestVersion, versionData) - if err := s.SendMessage(resp); err != nil { + // Send refusal if there are no matching versions + if len(versionIntersect) == 0 { + var supportedVersions []uint16 + for supportedVersion := range s.config.ProtocolVersionMap { + supportedVersions = append(supportedVersions, supportedVersion) + } + msgRefuse := NewMsgRefuse( + []any{ + RefuseReasonVersionMismatch, + supportedVersions, + }, + ) + if err := s.SendMessage(msgRefuse); err != nil { return err } - return s.config.FinishedFunc(highestVersion, versionData) - } else { - // TODO: handle failures - // https://github.com/blinklabs-io/gouroboros/issues/32 - return fmt.Errorf("handshake failed, but we don't yet support this") + return fmt.Errorf("handshake failed: refused due to version mismatch") + } + // Compute highest version from intersection + var proposedVersion uint16 + for _, version := range versionIntersect { + if version > proposedVersion { + proposedVersion = version + } + } + // Decode protocol parameters for selected version + versionInfo := protocol.GetProtocolVersion(proposedVersion) + versionData := s.config.ProtocolVersionMap[proposedVersion] + proposedVersionData, err := versionInfo.NewVersionDataFromCborFunc( + msgProposeVersions.VersionMap[proposedVersion], + ) + if err != nil { + msgRefuse := NewMsgRefuse( + []any{ + RefuseReasonDecodeError, + proposedVersion, + err.Error(), + }, + ) + if err := s.SendMessage(msgRefuse); err != nil { + return err + } + return fmt.Errorf( + "handshake failed: refused due to protocol parameters decode failure: %s", + err, + ) + } + // Check network magic + if proposedVersionData.NetworkMagic() != versionData.NetworkMagic() { + errMsg := fmt.Sprintf("network magic mismatch: %#v /= %#v", versionData, proposedVersionData) + msgRefuse := NewMsgRefuse( + []any{ + RefuseReasonRefused, + proposedVersion, + errMsg, + }, + ) + if err := s.SendMessage(msgRefuse); err != nil { + return err + } + return fmt.Errorf( + "handshake failed: refused due to protocol parameters mismatch: %s", + errMsg, + ) + } + // Accept the proposed version + // We send our version data in the response and the proposed version data in the callback + msgAcceptVersion := NewMsgAcceptVersion(proposedVersion, versionData) + if err := s.SendMessage(msgAcceptVersion); err != nil { + return err } + return s.config.FinishedFunc(proposedVersion, proposedVersionData) } diff --git a/protocol/handshake/server_test.go b/protocol/handshake/server_test.go new file mode 100644 index 00000000..7742cce1 --- /dev/null +++ b/protocol/handshake/server_test.go @@ -0,0 +1,135 @@ +// Copyright 2023 Blink Labs Software +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package handshake_test + +import ( + "fmt" + "testing" + + ouroboros "github.com/blinklabs-io/gouroboros" + "github.com/blinklabs-io/gouroboros/internal/test/ouroboros_mock" + "github.com/blinklabs-io/gouroboros/protocol" + "github.com/blinklabs-io/gouroboros/protocol/handshake" +) + +func TestServerBasicHandshake(t *testing.T) { + mockConn := ouroboros_mock.NewConnection( + ouroboros_mock.ProtocolRoleServer, + []ouroboros_mock.ConversationEntry{ + // MsgProposeVersions from mock client + { + Type: ouroboros_mock.EntryTypeOutput, + ProtocolId: handshake.ProtocolId, + OutputMessages: []protocol.Message{ + handshake.NewMsgProposeVersions( + protocol.ProtocolVersionMap{ + (10 + protocol.ProtocolVersionNtCOffset): protocol.VersionDataNtC9to14(ouroboros_mock.MockNetworkMagic), + (11 + protocol.ProtocolVersionNtCOffset): protocol.VersionDataNtC9to14(ouroboros_mock.MockNetworkMagic), + (12 + protocol.ProtocolVersionNtCOffset): protocol.VersionDataNtC9to14(ouroboros_mock.MockNetworkMagic), + }, + ), + }, + }, + // MsgAcceptVersion from server + { + Type: ouroboros_mock.EntryTypeInput, + IsResponse: true, + ProtocolId: handshake.ProtocolId, + MsgFromCborFunc: handshake.NewMsgFromCbor, + InputMessageType: handshake.MessageTypeAcceptVersion, + InputMessage: handshake.NewMsgAcceptVersion( + (12 + protocol.ProtocolVersionNtCOffset), + protocol.VersionDataNtC9to14(ouroboros_mock.MockNetworkMagic), + ), + }, + }, + ) + oConn, err := ouroboros.New( + ouroboros.WithConnection(mockConn), + ouroboros.WithNetworkMagic(ouroboros_mock.MockNetworkMagic), + ouroboros.WithServer(true), + ) + if err != nil { + t.Fatalf("unexpected error when creating Ouroboros object: %s", err) + } + // Async error handler + go func() { + err, ok := <-oConn.ErrorChan() + if !ok { + return + } + // We can't call t.Fatalf() from a different Goroutine, so we panic instead + panic(fmt.Sprintf("unexpected Ouroboros error: %s", err)) + }() + // Close Ouroboros connection + if err := oConn.Close(); err != nil { + t.Fatalf("unexpected error when closing Ouroboros object: %s", err) + } +} + +func TestServerHandshakeRefuseVersionMismatch(t *testing.T) { + expectedErr := fmt.Errorf("handshake failed: refused due to version mismatch") + mockConn := ouroboros_mock.NewConnection( + ouroboros_mock.ProtocolRoleServer, + []ouroboros_mock.ConversationEntry{ + // MsgProposeVersions from mock client + { + Type: ouroboros_mock.EntryTypeOutput, + ProtocolId: handshake.ProtocolId, + OutputMessages: []protocol.Message{ + handshake.NewMsgProposeVersions( + protocol.ProtocolVersionMap{ + (100 + protocol.ProtocolVersionNtCOffset): protocol.VersionDataNtC9to14(ouroboros_mock.MockNetworkMagic), + (101 + protocol.ProtocolVersionNtCOffset): protocol.VersionDataNtC9to14(ouroboros_mock.MockNetworkMagic), + (102 + protocol.ProtocolVersionNtCOffset): protocol.VersionDataNtC9to14(ouroboros_mock.MockNetworkMagic), + }, + ), + }, + }, + // MsgRefuse from server + { + Type: ouroboros_mock.EntryTypeInput, + IsResponse: true, + ProtocolId: handshake.ProtocolId, + MsgFromCborFunc: handshake.NewMsgFromCbor, + InputMessageType: handshake.MessageTypeRefuse, + InputMessage: handshake.NewMsgRefuse( + []any{ + handshake.RefuseReasonVersionMismatch, + protocol.GetProtocolVersionsNtC(), + }, + ), + }, + }, + ) + oConn, err := ouroboros.New( + ouroboros.WithConnection(mockConn), + ouroboros.WithNetworkMagic(ouroboros_mock.MockNetworkMagic), + ouroboros.WithServer(true), + ) + if err != nil { + if err.Error() != expectedErr.Error() { + t.Fatalf("unexpected error when creating Ouroboros object: %s", err) + } + } + // Async error handler + go func() { + err, ok := <-oConn.ErrorChan() + if !ok { + return + } + panic(fmt.Sprintf("unexpected Ouroboros error: %s", err)) + }() +}