From c8cbce0a484d404ba6e85559e2294f937eed922a Mon Sep 17 00:00:00 2001 From: Aurora Gaffney Date: Sat, 23 Mar 2024 16:15:58 -0500 Subject: [PATCH] feat: split conversation entry types Fixes #2 --- connection.go | 30 ++++++++++---------- entry.go | 77 +++++++++++++++++++++++++++------------------------ 2 files changed, 56 insertions(+), 51 deletions(-) diff --git a/connection.go b/connection.go index fbc7949..dc4415d 100644 --- a/connection.go +++ b/connection.go @@ -143,24 +143,24 @@ func (c *Connection) asyncLoop() { return default: } - switch entry.Type { - case EntryTypeInput: + switch entry := entry.(type) { + case ConversationEntryInput: if err := c.processInputEntry(entry); err != nil { panic(err.Error()) } - case EntryTypeOutput: + case ConversationEntryOutput: if err := c.processOutputEntry(entry); err != nil { panic(fmt.Sprintf("output error: %s", err)) } - case EntryTypeClose: + case ConversationEntryClose: c.Close() - case EntryTypeSleep: + case ConversationEntrySleep: time.Sleep(entry.Duration) default: panic( fmt.Sprintf( - "unknown conversation entry type: %d: %#v", - entry.Type, + "unknown conversation entry type: %T: %#v", + entry, entry, ), ) @@ -168,7 +168,7 @@ func (c *Connection) asyncLoop() { } } -func (c *Connection) processInputEntry(entry ConversationEntry) error { +func (c *Connection) processInputEntry(entry ConversationEntryInput) error { // Wait for segment to be received from muxer segment, ok := <-c.muxerRecvChan if !ok { @@ -193,7 +193,7 @@ func (c *Connection) processInputEntry(entry ConversationEntry) error { if err != nil { return fmt.Errorf("decode error: %s", err) } - if entry.InputMessage != nil { + if entry.Message != nil { // Create Message object from CBOR msg, err := entry.MsgFromCborFunc(uint(msgType), segment.Payload) if err != nil { @@ -208,25 +208,25 @@ func (c *Connection) processInputEntry(entry ConversationEntry) error { // As changing the CBOR of the expected message is not thread-safe, we instead clear the // CBOR of the received message msg.SetCbor(nil) - if !reflect.DeepEqual(msg, entry.InputMessage) { + if !reflect.DeepEqual(msg, entry.Message) { return fmt.Errorf( "parsed message does not match expected value: got %#v, expected %#v", msg, - entry.InputMessage, + entry.Message, ) } } else { - if entry.InputMessageType == uint(msgType) { + if entry.MessageType == uint(msgType) { return nil } - return fmt.Errorf("input message is not of expected type: expected %d, got %d", entry.InputMessageType, msgType) + return fmt.Errorf("input message is not of expected type: expected %d, got %d", entry.MessageType, msgType) } return nil } -func (c *Connection) processOutputEntry(entry ConversationEntry) error { +func (c *Connection) processOutputEntry(entry ConversationEntryOutput) error { payloadBuf := bytes.NewBuffer(nil) - for _, msg := range entry.OutputMessages { + for _, msg := range entry.Messages { // Get raw CBOR from message data := msg.Cbor() // If message has no raw CBOR, encode the message diff --git a/entry.go b/entry.go index 55e9016..468f92b 100644 --- a/entry.go +++ b/entry.go @@ -29,41 +29,51 @@ const ( MockKeepAliveCookie uint16 = 999 ) -type EntryType int +type ConversationEntry interface { + isConversationEntry() +} -const ( - EntryTypeNone EntryType = 0 - EntryTypeInput EntryType = 1 - EntryTypeOutput EntryType = 2 - EntryTypeClose EntryType = 3 - EntryTypeSleep EntryType = 4 -) +type conversationEntryBase struct{} + +func (c conversationEntryBase) isConversationEntry() {} + +type ConversationEntryInput struct { + conversationEntryBase + ProtocolId uint16 + IsResponse bool + Message protocol.Message + MessageType uint + MsgFromCborFunc protocol.MessageFromCborFunc +} -type ConversationEntry struct { - Type EntryType - ProtocolId uint16 - IsResponse bool - OutputMessages []protocol.Message - InputMessage protocol.Message - InputMessageType uint - MsgFromCborFunc protocol.MessageFromCborFunc - Duration time.Duration +type ConversationEntryOutput struct { + conversationEntryBase + ProtocolId uint16 + IsResponse bool + Messages []protocol.Message +} + +type ConversationEntryClose struct { + conversationEntryBase +} + +type ConversationEntrySleep struct { + conversationEntryBase + Duration time.Duration } // ConversationEntryHandshakeRequestGeneric is a pre-defined conversation event that matches a generic // handshake request from a client -var ConversationEntryHandshakeRequestGeneric = ConversationEntry{ - Type: EntryTypeInput, - ProtocolId: handshake.ProtocolId, - InputMessageType: handshake.MessageTypeProposeVersions, +var ConversationEntryHandshakeRequestGeneric = ConversationEntryInput{ + ProtocolId: handshake.ProtocolId, + MessageType: handshake.MessageTypeProposeVersions, } // ConversationEntryHandshakeNtCResponse is a pre-defined conversation entry for a server NtC handshake response -var ConversationEntryHandshakeNtCResponse = ConversationEntry{ - Type: EntryTypeOutput, +var ConversationEntryHandshakeNtCResponse = ConversationEntryOutput{ ProtocolId: handshake.ProtocolId, IsResponse: true, - OutputMessages: []protocol.Message{ + Messages: []protocol.Message{ handshake.NewMsgAcceptVersion( MockProtocolVersionNtC, protocol.VersionDataNtC9to14(MockNetworkMagic), @@ -72,11 +82,10 @@ var ConversationEntryHandshakeNtCResponse = ConversationEntry{ } // ConversationEntryHandshakeNtNResponse is a pre-defined conversation entry for a server NtN handshake response -var ConversationEntryHandshakeNtNResponse = ConversationEntry{ - Type: EntryTypeOutput, +var ConversationEntryHandshakeNtNResponse = ConversationEntryOutput{ ProtocolId: handshake.ProtocolId, IsResponse: true, - OutputMessages: []protocol.Message{ + Messages: []protocol.Message{ handshake.NewMsgAcceptVersion( MockProtocolVersionNtN, protocol.VersionDataNtN13andUp{ @@ -92,19 +101,17 @@ var ConversationEntryHandshakeNtNResponse = ConversationEntry{ } // ConversationEntryKeepAliveRequest is a pre-defined conversation entry for a keep-alive request -var ConversationEntryKeepAliveRequest = ConversationEntry{ - Type: EntryTypeInput, +var ConversationEntryKeepAliveRequest = ConversationEntryInput{ ProtocolId: keepalive.ProtocolId, - InputMessage: keepalive.NewMsgKeepAlive(MockKeepAliveCookie), + Message: keepalive.NewMsgKeepAlive(MockKeepAliveCookie), MsgFromCborFunc: keepalive.NewMsgFromCbor, } // ConversationEntryKeepAliveResponse is a pre-defined conversation entry for a keep-alive response -var ConversationEntryKeepAliveResponse = ConversationEntry{ - Type: EntryTypeOutput, +var ConversationEntryKeepAliveResponse = ConversationEntryOutput{ ProtocolId: keepalive.ProtocolId, IsResponse: true, - OutputMessages: []protocol.Message{ + Messages: []protocol.Message{ keepalive.NewMsgKeepAliveResponse(MockKeepAliveCookie), }, } @@ -130,7 +137,5 @@ var ConversationKeepAliveClose = []ConversationEntry{ ConversationEntryHandshakeRequestGeneric, ConversationEntryHandshakeNtNResponse, ConversationEntryKeepAliveRequest, - ConversationEntry{ - Type: EntryTypeClose, - }, + ConversationEntryClose{}, }