diff --git a/config.go b/config.go index 2fb24195..54a86c0e 100644 --- a/config.go +++ b/config.go @@ -45,6 +45,10 @@ type Config struct { // Servers will assert that clients send one of these profiles and will respond as needed SRTPProtectionProfiles []SRTPProtectionProfile + // SRTPMasterKeyIdentifier value (if any) is sent via the use_srtp + // extension for Clients and Servers + SRTPMasterKeyIdentifier []byte + // ClientAuth determines the server's policy for // TLS Client Authentication. The default is NoClientCert. ClientAuth ClientAuthType diff --git a/conn.go b/conn.go index ce323473..459cdcf5 100644 --- a/conn.go +++ b/conn.go @@ -164,6 +164,7 @@ func createConn(nextConn net.PacketConn, rAddr net.Addr, config *Config, isClien localSignatureSchemes: signatureSchemes, extendedMasterSecret: config.ExtendedMasterSecret, localSRTPProtectionProfiles: config.SRTPProtectionProfiles, + localSRTPMasterKeyIdentifier: config.SRTPMasterKeyIdentifier, serverName: serverName, supportedProtocols: config.SupportedProtocols, clientAuth: config.ClientAuth, @@ -426,6 +427,15 @@ func (c *Conn) SelectedSRTPProtectionProfile() (SRTPProtectionProfile, bool) { return profile, true } +// RemoteSRTPMasterKeyIdentifier returns the MasterKeyIdentifier value from the use_srtp +func (c *Conn) RemoteSRTPMasterKeyIdentifier() ([]byte, bool) { + if profile := c.state.getSRTPProtectionProfile(); profile == 0 { + return nil, false + } + + return c.state.remoteSRTPMasterKeyIdentifier, true +} + func (c *Conn) writePackets(ctx context.Context, pkts []*packet) error { c.lock.Lock() defer c.lock.Unlock() diff --git a/conn_test.go b/conn_test.go index 7192178a..c66e0381 100644 --- a/conn_test.go +++ b/conn_test.go @@ -866,12 +866,14 @@ func TestSRTPConfiguration(t *testing.T) { defer report() for _, test := range []struct { - Name string - ClientSRTP []SRTPProtectionProfile - ServerSRTP []SRTPProtectionProfile - ExpectedProfile SRTPProtectionProfile - WantClientError error - WantServerError error + Name string + ClientSRTP []SRTPProtectionProfile + ServerSRTP []SRTPProtectionProfile + ClientSRTPMasterKeyIdentifier []byte + ServerSRTPMasterKeyIdentifier []byte + ExpectedProfile SRTPProtectionProfile + WantClientError error + WantServerError error }{ { Name: "No SRTP in use", @@ -882,12 +884,14 @@ func TestSRTPConfiguration(t *testing.T) { WantServerError: nil, }, { - Name: "SRTP both ends", - ClientSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}, - ServerSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}, - ExpectedProfile: SRTP_AES128_CM_HMAC_SHA1_80, - WantClientError: nil, - WantServerError: nil, + Name: "SRTP both ends", + ClientSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}, + ServerSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}, + ExpectedProfile: SRTP_AES128_CM_HMAC_SHA1_80, + ClientSRTPMasterKeyIdentifier: []byte("ClientSRTPMKI"), + ServerSRTPMasterKeyIdentifier: []byte("ServerSRTPMKI"), + WantClientError: nil, + WantServerError: nil, }, { Name: "SRTP client only", @@ -933,11 +937,11 @@ func TestSRTPConfiguration(t *testing.T) { c := make(chan result) go func() { - client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{SRTPProtectionProfiles: test.ClientSRTP}, true) + client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{SRTPProtectionProfiles: test.ClientSRTP, SRTPMasterKeyIdentifier: test.ServerSRTPMasterKeyIdentifier}, true) c <- result{client, err} }() - server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{SRTPProtectionProfiles: test.ServerSRTP}, true) + server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{SRTPProtectionProfiles: test.ServerSRTP, SRTPMasterKeyIdentifier: test.ClientSRTPMasterKeyIdentifier}, true) if !errors.Is(err, test.WantServerError) { t.Errorf("TestSRTPConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantServerError, err) } @@ -969,6 +973,16 @@ func TestSRTPConfiguration(t *testing.T) { if actualServerSRTP != test.ExpectedProfile { t.Errorf("TestSRTPConfiguration: Server SRTPProtectionProfile Mismatch '%s': expected(%v) actual(%v)", test.Name, test.ExpectedProfile, actualServerSRTP) } + + actualServerMKI, _ := server.RemoteSRTPMasterKeyIdentifier() + if !bytes.Equal(actualServerMKI, test.ServerSRTPMasterKeyIdentifier) { + t.Errorf("TestSRTPConfiguration: Server SRTPMKI Mismatch '%s': expected(%v) actual(%v)", test.Name, test.ServerSRTPMasterKeyIdentifier, actualServerMKI) + } + + actualClientMKI, _ := res.c.RemoteSRTPMasterKeyIdentifier() + if !bytes.Equal(actualClientMKI, test.ClientSRTPMasterKeyIdentifier) { + t.Errorf("TestSRTPConfiguration: Client SRTPMKI Mismatch '%s': expected(%v) actual(%v)", test.Name, test.ClientSRTPMasterKeyIdentifier, actualClientMKI) + } } } diff --git a/flight0handler.go b/flight0handler.go index 77d40327..7bb528f1 100644 --- a/flight0handler.go +++ b/flight0handler.go @@ -67,6 +67,7 @@ func flight0Parse(_ context.Context, _ flightConn, state *State, cache *handshak return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errServerNoMatchingSRTPProfile } state.setSRTPProtectionProfile(profile) + state.remoteSRTPMasterKeyIdentifier = e.MasterKeyIdentifier case *extension.UseExtendedMasterSecret: if cfg.extendedMasterSecret != DisableExtendedMasterSecret { state.extendedMasterSecret = true diff --git a/flight1handler.go b/flight1handler.go index 69019c91..60215c08 100644 --- a/flight1handler.go +++ b/flight1handler.go @@ -91,7 +91,8 @@ func flight1Generate(c flightConn, state *State, _ *handshakeCache, cfg *handsha if len(cfg.localSRTPProtectionProfiles) > 0 { extensions = append(extensions, &extension.UseSRTP{ - ProtectionProfiles: cfg.localSRTPProtectionProfiles, + ProtectionProfiles: cfg.localSRTPProtectionProfiles, + MasterKeyIdentifier: cfg.localSRTPMasterKeyIdentifier, }) } diff --git a/flight3handler.go b/flight3handler.go index 39d9380c..f27c01a7 100644 --- a/flight3handler.go +++ b/flight3handler.go @@ -57,6 +57,7 @@ func flight3Parse(ctx context.Context, c flightConn, state *State, cache *handsh return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, errClientNoMatchingSRTPProfile } state.setSRTPProtectionProfile(profile) + state.remoteSRTPMasterKeyIdentifier = e.MasterKeyIdentifier case *extension.UseExtendedMasterSecret: if cfg.extendedMasterSecret != DisableExtendedMasterSecret { state.extendedMasterSecret = true diff --git a/flight4bhandler.go b/flight4bhandler.go index 71e7044c..d87a1fee 100644 --- a/flight4bhandler.go +++ b/flight4bhandler.go @@ -61,7 +61,8 @@ func flight4bGenerate(_ flightConn, state *State, cache *handshakeCache, cfg *ha } if state.getSRTPProtectionProfile() != 0 { extensions = append(extensions, &extension.UseSRTP{ - ProtectionProfiles: []SRTPProtectionProfile{state.getSRTPProtectionProfile()}, + ProtectionProfiles: []SRTPProtectionProfile{state.getSRTPProtectionProfile()}, + MasterKeyIdentifier: cfg.localSRTPMasterKeyIdentifier, }) } diff --git a/flight4handler.go b/flight4handler.go index af8cf6b4..5f867688 100644 --- a/flight4handler.go +++ b/flight4handler.go @@ -230,7 +230,8 @@ func flight4Generate(_ flightConn, state *State, _ *handshakeCache, cfg *handsha } if state.getSRTPProtectionProfile() != 0 { extensions = append(extensions, &extension.UseSRTP{ - ProtectionProfiles: []SRTPProtectionProfile{state.getSRTPProtectionProfile()}, + ProtectionProfiles: []SRTPProtectionProfile{state.getSRTPProtectionProfile()}, + MasterKeyIdentifier: cfg.localSRTPMasterKeyIdentifier, }) } if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate { diff --git a/handshaker.go b/handshaker.go index 0c2a4364..946cf4bc 100644 --- a/handshaker.go +++ b/handshaker.go @@ -93,30 +93,31 @@ type handshakeFSM struct { } type handshakeConfig struct { - localPSKCallback PSKCallback - localPSKIdentityHint []byte - localCipherSuites []CipherSuite // Available CipherSuites - localSignatureSchemes []signaturehash.Algorithm // Available signature schemes - extendedMasterSecret ExtendedMasterSecretType // Policy for the Extended Master Support extension - localSRTPProtectionProfiles []SRTPProtectionProfile // Available SRTPProtectionProfiles, if empty no SRTP support - serverName string - supportedProtocols []string - clientAuth ClientAuthType // If we are a client should we request a client certificate - localCertificates []tls.Certificate - nameToCertificate map[string]*tls.Certificate - insecureSkipVerify bool - verifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error - verifyConnection func(*State) error - sessionStore SessionStore - rootCAs *x509.CertPool - clientCAs *x509.CertPool - initialRetransmitInterval time.Duration - disableRetransmitBackoff bool - customCipherSuites func() []CipherSuite - ellipticCurves []elliptic.Curve - insecureSkipHelloVerify bool - connectionIDGenerator func() []byte - helloRandomBytesGenerator func() [handshake.RandomBytesLength]byte + localPSKCallback PSKCallback + localPSKIdentityHint []byte + localCipherSuites []CipherSuite // Available CipherSuites + localSignatureSchemes []signaturehash.Algorithm // Available signature schemes + extendedMasterSecret ExtendedMasterSecretType // Policy for the Extended Master Support extension + localSRTPProtectionProfiles []SRTPProtectionProfile // Available SRTPProtectionProfiles, if empty no SRTP support + localSRTPMasterKeyIdentifier []byte + serverName string + supportedProtocols []string + clientAuth ClientAuthType // If we are a client should we request a client certificate + localCertificates []tls.Certificate + nameToCertificate map[string]*tls.Certificate + insecureSkipVerify bool + verifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error + verifyConnection func(*State) error + sessionStore SessionStore + rootCAs *x509.CertPool + clientCAs *x509.CertPool + initialRetransmitInterval time.Duration + disableRetransmitBackoff bool + customCipherSuites func() []CipherSuite + ellipticCurves []elliptic.Curve + insecureSkipHelloVerify bool + connectionIDGenerator func() []byte + helloRandomBytesGenerator func() [handshake.RandomBytesLength]byte onFlightState func(flightVal, handshakeState) log logging.LeveledLogger diff --git a/pkg/protocol/extension/errors.go b/pkg/protocol/extension/errors.go index 6b4c9229..5999c96f 100644 --- a/pkg/protocol/extension/errors.go +++ b/pkg/protocol/extension/errors.go @@ -11,11 +11,12 @@ import ( var ( // ErrALPNInvalidFormat is raised when the ALPN format is invalid - ErrALPNInvalidFormat = &protocol.FatalError{Err: errors.New("invalid alpn format")} //nolint:goerr113 - errALPNNoAppProto = &protocol.FatalError{Err: errors.New("no application protocol")} //nolint:goerr113 - errBufferTooSmall = &protocol.TemporaryError{Err: errors.New("buffer is too small")} //nolint:goerr113 - errInvalidExtensionType = &protocol.FatalError{Err: errors.New("invalid extension type")} //nolint:goerr113 - errInvalidSNIFormat = &protocol.FatalError{Err: errors.New("invalid server name format")} //nolint:goerr113 - errInvalidCIDFormat = &protocol.FatalError{Err: errors.New("invalid connection ID format")} //nolint:goerr113 - errLengthMismatch = &protocol.InternalError{Err: errors.New("data length and declared length do not match")} //nolint:goerr113 + ErrALPNInvalidFormat = &protocol.FatalError{Err: errors.New("invalid alpn format")} //nolint:goerr113 + errALPNNoAppProto = &protocol.FatalError{Err: errors.New("no application protocol")} //nolint:goerr113 + errBufferTooSmall = &protocol.TemporaryError{Err: errors.New("buffer is too small")} //nolint:goerr113 + errInvalidExtensionType = &protocol.FatalError{Err: errors.New("invalid extension type")} //nolint:goerr113 + errInvalidSNIFormat = &protocol.FatalError{Err: errors.New("invalid server name format")} //nolint:goerr113 + errInvalidCIDFormat = &protocol.FatalError{Err: errors.New("invalid connection ID format")} //nolint:goerr113 + errLengthMismatch = &protocol.InternalError{Err: errors.New("data length and declared length do not match")} //nolint:goerr113 + errMasterKeyIdentifierTooLarge = &protocol.FatalError{Err: errors.New("master key identifier is over 255 bytes")} //nolint:goerr113 ) diff --git a/pkg/protocol/extension/use_srtp.go b/pkg/protocol/extension/use_srtp.go index ea9f1087..6d5f54b2 100644 --- a/pkg/protocol/extension/use_srtp.go +++ b/pkg/protocol/extension/use_srtp.go @@ -3,7 +3,9 @@ package extension -import "encoding/binary" +import ( + "encoding/binary" +) const ( useSRTPHeaderSize = 6 @@ -14,7 +16,8 @@ const ( // // https://tools.ietf.org/html/rfc8422 type UseSRTP struct { - ProtectionProfiles []SRTPProtectionProfile + ProtectionProfiles []SRTPProtectionProfile + MasterKeyIdentifier []byte } // TypeValue returns the extension TypeValue @@ -27,15 +30,20 @@ func (u *UseSRTP) Marshal() ([]byte, error) { out := make([]byte, useSRTPHeaderSize) binary.BigEndian.PutUint16(out, uint16(u.TypeValue())) - binary.BigEndian.PutUint16(out[2:], uint16(2+(len(u.ProtectionProfiles)*2)+ /* MKI Length */ 1)) + binary.BigEndian.PutUint16(out[2:], uint16(2+(len(u.ProtectionProfiles)*2)+ /* MKI Length */ 1+len(u.MasterKeyIdentifier))) binary.BigEndian.PutUint16(out[4:], uint16(len(u.ProtectionProfiles)*2)) for _, v := range u.ProtectionProfiles { out = append(out, []byte{0x00, 0x00}...) binary.BigEndian.PutUint16(out[len(out)-2:], uint16(v)) } + if len(u.MasterKeyIdentifier) > 255 { + return nil, errMasterKeyIdentifierTooLarge + } + + out = append(out, byte(len(u.MasterKeyIdentifier))) + out = append(out, u.MasterKeyIdentifier...) - out = append(out, 0x00) /* MKI Length */ return out, nil } @@ -48,7 +56,8 @@ func (u *UseSRTP) Unmarshal(data []byte) error { } profileCount := int(binary.BigEndian.Uint16(data[4:]) / 2) - if supportedGroupsHeaderSize+(profileCount*2) > len(data) { + masterKeyIdentifierIndex := supportedGroupsHeaderSize + (profileCount * 2) + if masterKeyIdentifierIndex+1 > len(data) { return errLengthMismatch } @@ -58,5 +67,13 @@ func (u *UseSRTP) Unmarshal(data []byte) error { u.ProtectionProfiles = append(u.ProtectionProfiles, supportedProfile) } } + + masterKeyIdentifierLen := int(data[masterKeyIdentifierIndex]) + if masterKeyIdentifierIndex+masterKeyIdentifierLen >= len(data) { + return errLengthMismatch + } + + u.MasterKeyIdentifier = append([]byte{}, data[masterKeyIdentifierIndex+1:masterKeyIdentifierIndex+1+masterKeyIdentifierLen]...) + return nil } diff --git a/pkg/protocol/extension/use_srtp_test.go b/pkg/protocol/extension/use_srtp_test.go index 25b7b9e1..36e284cd 100644 --- a/pkg/protocol/extension/use_srtp_test.go +++ b/pkg/protocol/extension/use_srtp_test.go @@ -4,20 +4,72 @@ package extension import ( + "errors" "reflect" "testing" ) func TestExtensionUseSRTP(t *testing.T) { - rawUseSRTP := []byte{0x00, 0x0e, 0x00, 0x05, 0x00, 0x02, 0x00, 0x01, 0x00} - parsedUseSRTP := &UseSRTP{ - ProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}, - } - - raw, err := parsedUseSRTP.Marshal() - if err != nil { - t.Error(err) - } else if !reflect.DeepEqual(raw, rawUseSRTP) { - t.Errorf("extensionUseSRTP marshal: got %#v, want %#v", raw, rawUseSRTP) - } + t.Run("No MasterKeyIdentifier", func(t *testing.T) { + rawUseSRTP := []byte{0x00, 0x0e, 0x00, 0x05, 0x00, 0x02, 0x00, 0x01, 0x00} + parsedUseSRTP := &UseSRTP{ + ProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}, + MasterKeyIdentifier: []byte{}, + } + + marshaled, err := parsedUseSRTP.Marshal() + if err != nil { + t.Error(err) + } else if !reflect.DeepEqual(marshaled, rawUseSRTP) { + t.Errorf("extensionUseSRTP marshal: got %#v, want %#v", marshaled, rawUseSRTP) + } + + unmarshaled := &UseSRTP{} + if err := unmarshaled.Unmarshal(rawUseSRTP); err != nil { + t.Error(err) + } else if !reflect.DeepEqual(unmarshaled, parsedUseSRTP) { + t.Errorf("extensionUseSRTP unmarshal: got %#v, want %#v", unmarshaled, parsedUseSRTP) + } + }) + + t.Run("With MasterKeyIdentifier", func(t *testing.T) { + rawUseSRTP := []byte{0x00, 0x0e, 0x00, 0x0a, 0x00, 0x02, 0x00, 0x01, 0x05, 0xA, 0xB, 0xC, 0xD, 0xE} + parsedUseSRTP := &UseSRTP{ + ProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}, + MasterKeyIdentifier: []byte{0xA, 0xB, 0xC, 0xD, 0xE}, + } + + marshaled, err := parsedUseSRTP.Marshal() + if err != nil { + t.Error(err) + } else if !reflect.DeepEqual(marshaled, rawUseSRTP) { + t.Errorf("extensionUseSRTP marshal: got %#v, want %#v", marshaled, rawUseSRTP) + } + + unmarshaled := &UseSRTP{} + if err := unmarshaled.Unmarshal(rawUseSRTP); err != nil { + t.Error(err) + } else if !reflect.DeepEqual(unmarshaled, parsedUseSRTP) { + t.Errorf("extensionUseSRTP unmarshal: got %#v, want %#v", unmarshaled, parsedUseSRTP) + } + }) + + t.Run("Invalid Lengths", func(t *testing.T) { + unmarshaled := &UseSRTP{} + + if err := unmarshaled.Unmarshal([]byte{0x00, 0x0e, 0x00, 0x05, 0x00, 0x04, 0x00, 0x01, 0x00}); !errors.Is(errLengthMismatch, err) { + t.Error(err) + } + + if err := unmarshaled.Unmarshal([]byte{0x00, 0x0e, 0x00, 0x0a, 0x00, 0x02, 0x00, 0x01, 0x01}); !errors.Is(errLengthMismatch, err) { + t.Error(err) + } + + if _, err := (&UseSRTP{ + ProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}, + MasterKeyIdentifier: make([]byte, 500), + }).Marshal(); !errors.Is(errMasterKeyIdentifierTooLarge, err) { + panic(err) + } + }) } diff --git a/state.go b/state.go index 27b1ebb3..f2d6df6f 100644 --- a/state.go +++ b/state.go @@ -24,10 +24,12 @@ type State struct { cipherSuite CipherSuite // nil if a cipherSuite hasn't been chosen CipherSuiteID CipherSuiteID - srtpProtectionProfile atomic.Value // Negotiated SRTPProtectionProfile - PeerCertificates [][]byte - IdentityHint []byte - SessionID []byte + srtpProtectionProfile atomic.Value // Negotiated SRTPProtectionProfile + remoteSRTPMasterKeyIdentifier []byte + + PeerCertificates [][]byte + IdentityHint []byte + SessionID []byte // Connection Identifiers must be negotiated afresh on session resumption. // https://datatracker.ietf.org/doc/html/rfc9146#name-the-connection_id-extension