Skip to content

Commit

Permalink
Add support for MKI in use_srtp
Browse files Browse the repository at this point in the history
Resolves #650
  • Loading branch information
Sean-Der committed Jul 30, 2024
1 parent 7139e0e commit 7ab74fb
Show file tree
Hide file tree
Showing 13 changed files with 174 additions and 68 deletions.
4 changes: 4 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ type Config struct {
// Servers will assert that clients send one of these profiles and will respond as needed
SRTPProtectionProfiles []SRTPProtectionProfile

// SRTPMasterKeyIdentifier value (if any) is sent via the use_srtp
// extension for Clients and Servers
SRTPMasterKeyIdentifier []byte

// ClientAuth determines the server's policy for
// TLS Client Authentication. The default is NoClientCert.
ClientAuth ClientAuthType
Expand Down
10 changes: 10 additions & 0 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ func createConn(nextConn net.PacketConn, rAddr net.Addr, config *Config, isClien
localSignatureSchemes: signatureSchemes,
extendedMasterSecret: config.ExtendedMasterSecret,
localSRTPProtectionProfiles: config.SRTPProtectionProfiles,
localSRTPMasterKeyIdentifier: config.SRTPMasterKeyIdentifier,
serverName: serverName,
supportedProtocols: config.SupportedProtocols,
clientAuth: config.ClientAuth,
Expand Down Expand Up @@ -426,6 +427,15 @@ func (c *Conn) SelectedSRTPProtectionProfile() (SRTPProtectionProfile, bool) {
return profile, true
}

// RemoteSRTPMasterKeyIdentifier returns the MasterKeyIdentifier value from the use_srtp
func (c *Conn) RemoteSRTPMasterKeyIdentifier() ([]byte, bool) {
if profile := c.state.getSRTPProtectionProfile(); profile == 0 {
return nil, false
}

return c.state.remoteSRTPMasterKeyIdentifier, true
}

func (c *Conn) writePackets(ctx context.Context, pkts []*packet) error {
c.lock.Lock()
defer c.lock.Unlock()
Expand Down
42 changes: 28 additions & 14 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -866,12 +866,14 @@ func TestSRTPConfiguration(t *testing.T) {
defer report()

for _, test := range []struct {
Name string
ClientSRTP []SRTPProtectionProfile
ServerSRTP []SRTPProtectionProfile
ExpectedProfile SRTPProtectionProfile
WantClientError error
WantServerError error
Name string
ClientSRTP []SRTPProtectionProfile
ServerSRTP []SRTPProtectionProfile
ClientSRTPMasterKeyIdentifier []byte
ServerSRTPMasterKeyIdentifier []byte
ExpectedProfile SRTPProtectionProfile
WantClientError error
WantServerError error
}{
{
Name: "No SRTP in use",
Expand All @@ -882,12 +884,14 @@ func TestSRTPConfiguration(t *testing.T) {
WantServerError: nil,
},
{
Name: "SRTP both ends",
ClientSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80},
ServerSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80},
ExpectedProfile: SRTP_AES128_CM_HMAC_SHA1_80,
WantClientError: nil,
WantServerError: nil,
Name: "SRTP both ends",
ClientSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80},
ServerSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80},
ExpectedProfile: SRTP_AES128_CM_HMAC_SHA1_80,
ClientSRTPMasterKeyIdentifier: []byte("ClientSRTPMKI"),
ServerSRTPMasterKeyIdentifier: []byte("ServerSRTPMKI"),
WantClientError: nil,
WantServerError: nil,
},
{
Name: "SRTP client only",
Expand Down Expand Up @@ -933,11 +937,11 @@ func TestSRTPConfiguration(t *testing.T) {
c := make(chan result)

go func() {
client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{SRTPProtectionProfiles: test.ClientSRTP}, true)
client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{SRTPProtectionProfiles: test.ClientSRTP, SRTPMasterKeyIdentifier: test.ServerSRTPMasterKeyIdentifier}, true)
c <- result{client, err}
}()

server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{SRTPProtectionProfiles: test.ServerSRTP}, true)
server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{SRTPProtectionProfiles: test.ServerSRTP, SRTPMasterKeyIdentifier: test.ClientSRTPMasterKeyIdentifier}, true)
if !errors.Is(err, test.WantServerError) {
t.Errorf("TestSRTPConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantServerError, err)
}
Expand Down Expand Up @@ -969,6 +973,16 @@ func TestSRTPConfiguration(t *testing.T) {
if actualServerSRTP != test.ExpectedProfile {
t.Errorf("TestSRTPConfiguration: Server SRTPProtectionProfile Mismatch '%s': expected(%v) actual(%v)", test.Name, test.ExpectedProfile, actualServerSRTP)
}

actualServerMKI, _ := server.RemoteSRTPMasterKeyIdentifier()
if !bytes.Equal(actualServerMKI, test.ServerSRTPMasterKeyIdentifier) {
t.Errorf("TestSRTPConfiguration: Server SRTPMKI Mismatch '%s': expected(%v) actual(%v)", test.Name, test.ServerSRTPMasterKeyIdentifier, actualServerMKI)
}

actualClientMKI, _ := res.c.RemoteSRTPMasterKeyIdentifier()
if !bytes.Equal(actualClientMKI, test.ClientSRTPMasterKeyIdentifier) {
t.Errorf("TestSRTPConfiguration: Client SRTPMKI Mismatch '%s': expected(%v) actual(%v)", test.Name, test.ClientSRTPMasterKeyIdentifier, actualClientMKI)
}
}
}

Expand Down
1 change: 1 addition & 0 deletions flight0handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ func flight0Parse(_ context.Context, _ flightConn, state *State, cache *handshak
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errServerNoMatchingSRTPProfile
}
state.setSRTPProtectionProfile(profile)
state.remoteSRTPMasterKeyIdentifier = e.MasterKeyIdentifier
case *extension.UseExtendedMasterSecret:
if cfg.extendedMasterSecret != DisableExtendedMasterSecret {
state.extendedMasterSecret = true
Expand Down
3 changes: 2 additions & 1 deletion flight1handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ func flight1Generate(c flightConn, state *State, _ *handshakeCache, cfg *handsha

if len(cfg.localSRTPProtectionProfiles) > 0 {
extensions = append(extensions, &extension.UseSRTP{
ProtectionProfiles: cfg.localSRTPProtectionProfiles,
ProtectionProfiles: cfg.localSRTPProtectionProfiles,
MasterKeyIdentifier: cfg.localSRTPMasterKeyIdentifier,
})
}

Expand Down
1 change: 1 addition & 0 deletions flight3handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ func flight3Parse(ctx context.Context, c flightConn, state *State, cache *handsh
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, errClientNoMatchingSRTPProfile
}
state.setSRTPProtectionProfile(profile)
state.remoteSRTPMasterKeyIdentifier = e.MasterKeyIdentifier
case *extension.UseExtendedMasterSecret:
if cfg.extendedMasterSecret != DisableExtendedMasterSecret {
state.extendedMasterSecret = true
Expand Down
3 changes: 2 additions & 1 deletion flight4bhandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ func flight4bGenerate(_ flightConn, state *State, cache *handshakeCache, cfg *ha
}
if state.getSRTPProtectionProfile() != 0 {
extensions = append(extensions, &extension.UseSRTP{
ProtectionProfiles: []SRTPProtectionProfile{state.getSRTPProtectionProfile()},
ProtectionProfiles: []SRTPProtectionProfile{state.getSRTPProtectionProfile()},
MasterKeyIdentifier: cfg.localSRTPMasterKeyIdentifier,
})
}

Expand Down
3 changes: 2 additions & 1 deletion flight4handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,8 @@ func flight4Generate(_ flightConn, state *State, _ *handshakeCache, cfg *handsha
}
if state.getSRTPProtectionProfile() != 0 {
extensions = append(extensions, &extension.UseSRTP{
ProtectionProfiles: []SRTPProtectionProfile{state.getSRTPProtectionProfile()},
ProtectionProfiles: []SRTPProtectionProfile{state.getSRTPProtectionProfile()},
MasterKeyIdentifier: cfg.localSRTPMasterKeyIdentifier,
})
}
if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate {
Expand Down
49 changes: 25 additions & 24 deletions handshaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,30 +93,31 @@ type handshakeFSM struct {
}

type handshakeConfig struct {
localPSKCallback PSKCallback
localPSKIdentityHint []byte
localCipherSuites []CipherSuite // Available CipherSuites
localSignatureSchemes []signaturehash.Algorithm // Available signature schemes
extendedMasterSecret ExtendedMasterSecretType // Policy for the Extended Master Support extension
localSRTPProtectionProfiles []SRTPProtectionProfile // Available SRTPProtectionProfiles, if empty no SRTP support
serverName string
supportedProtocols []string
clientAuth ClientAuthType // If we are a client should we request a client certificate
localCertificates []tls.Certificate
nameToCertificate map[string]*tls.Certificate
insecureSkipVerify bool
verifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error
verifyConnection func(*State) error
sessionStore SessionStore
rootCAs *x509.CertPool
clientCAs *x509.CertPool
initialRetransmitInterval time.Duration
disableRetransmitBackoff bool
customCipherSuites func() []CipherSuite
ellipticCurves []elliptic.Curve
insecureSkipHelloVerify bool
connectionIDGenerator func() []byte
helloRandomBytesGenerator func() [handshake.RandomBytesLength]byte
localPSKCallback PSKCallback
localPSKIdentityHint []byte
localCipherSuites []CipherSuite // Available CipherSuites
localSignatureSchemes []signaturehash.Algorithm // Available signature schemes
extendedMasterSecret ExtendedMasterSecretType // Policy for the Extended Master Support extension
localSRTPProtectionProfiles []SRTPProtectionProfile // Available SRTPProtectionProfiles, if empty no SRTP support
localSRTPMasterKeyIdentifier []byte
serverName string
supportedProtocols []string
clientAuth ClientAuthType // If we are a client should we request a client certificate
localCertificates []tls.Certificate
nameToCertificate map[string]*tls.Certificate
insecureSkipVerify bool
verifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error
verifyConnection func(*State) error
sessionStore SessionStore
rootCAs *x509.CertPool
clientCAs *x509.CertPool
initialRetransmitInterval time.Duration
disableRetransmitBackoff bool
customCipherSuites func() []CipherSuite
ellipticCurves []elliptic.Curve
insecureSkipHelloVerify bool
connectionIDGenerator func() []byte
helloRandomBytesGenerator func() [handshake.RandomBytesLength]byte

onFlightState func(flightVal, handshakeState)
log logging.LeveledLogger
Expand Down
15 changes: 8 additions & 7 deletions pkg/protocol/extension/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@ import (

var (
// ErrALPNInvalidFormat is raised when the ALPN format is invalid
ErrALPNInvalidFormat = &protocol.FatalError{Err: errors.New("invalid alpn format")} //nolint:goerr113
errALPNNoAppProto = &protocol.FatalError{Err: errors.New("no application protocol")} //nolint:goerr113
errBufferTooSmall = &protocol.TemporaryError{Err: errors.New("buffer is too small")} //nolint:goerr113
errInvalidExtensionType = &protocol.FatalError{Err: errors.New("invalid extension type")} //nolint:goerr113
errInvalidSNIFormat = &protocol.FatalError{Err: errors.New("invalid server name format")} //nolint:goerr113
errInvalidCIDFormat = &protocol.FatalError{Err: errors.New("invalid connection ID format")} //nolint:goerr113
errLengthMismatch = &protocol.InternalError{Err: errors.New("data length and declared length do not match")} //nolint:goerr113
ErrALPNInvalidFormat = &protocol.FatalError{Err: errors.New("invalid alpn format")} //nolint:goerr113
errALPNNoAppProto = &protocol.FatalError{Err: errors.New("no application protocol")} //nolint:goerr113
errBufferTooSmall = &protocol.TemporaryError{Err: errors.New("buffer is too small")} //nolint:goerr113
errInvalidExtensionType = &protocol.FatalError{Err: errors.New("invalid extension type")} //nolint:goerr113
errInvalidSNIFormat = &protocol.FatalError{Err: errors.New("invalid server name format")} //nolint:goerr113
errInvalidCIDFormat = &protocol.FatalError{Err: errors.New("invalid connection ID format")} //nolint:goerr113
errLengthMismatch = &protocol.InternalError{Err: errors.New("data length and declared length do not match")} //nolint:goerr113
errMasterKeyIdentifierTooLarge = &protocol.FatalError{Err: errors.New("master key identifier is over 255 bytes")} //nolint:goerr113
)
27 changes: 22 additions & 5 deletions pkg/protocol/extension/use_srtp.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

package extension

import "encoding/binary"
import (
"encoding/binary"
)

const (
useSRTPHeaderSize = 6
Expand All @@ -14,7 +16,8 @@ const (
//
// https://tools.ietf.org/html/rfc8422
type UseSRTP struct {
ProtectionProfiles []SRTPProtectionProfile
ProtectionProfiles []SRTPProtectionProfile
MasterKeyIdentifier []byte
}

// TypeValue returns the extension TypeValue
Expand All @@ -27,15 +30,20 @@ func (u *UseSRTP) Marshal() ([]byte, error) {
out := make([]byte, useSRTPHeaderSize)

binary.BigEndian.PutUint16(out, uint16(u.TypeValue()))
binary.BigEndian.PutUint16(out[2:], uint16(2+(len(u.ProtectionProfiles)*2)+ /* MKI Length */ 1))
binary.BigEndian.PutUint16(out[2:], uint16(2+(len(u.ProtectionProfiles)*2)+ /* MKI Length */ 1+len(u.MasterKeyIdentifier)))
binary.BigEndian.PutUint16(out[4:], uint16(len(u.ProtectionProfiles)*2))

for _, v := range u.ProtectionProfiles {
out = append(out, []byte{0x00, 0x00}...)
binary.BigEndian.PutUint16(out[len(out)-2:], uint16(v))
}
if len(u.MasterKeyIdentifier) > 255 {
return nil, errMasterKeyIdentifierTooLarge
}

out = append(out, byte(len(u.MasterKeyIdentifier)))
out = append(out, u.MasterKeyIdentifier...)

out = append(out, 0x00) /* MKI Length */
return out, nil
}

Expand All @@ -48,7 +56,8 @@ func (u *UseSRTP) Unmarshal(data []byte) error {
}

profileCount := int(binary.BigEndian.Uint16(data[4:]) / 2)
if supportedGroupsHeaderSize+(profileCount*2) > len(data) {
masterKeyIdentifierIndex := supportedGroupsHeaderSize + (profileCount * 2)
if masterKeyIdentifierIndex+1 > len(data) {
return errLengthMismatch
}

Expand All @@ -58,5 +67,13 @@ func (u *UseSRTP) Unmarshal(data []byte) error {
u.ProtectionProfiles = append(u.ProtectionProfiles, supportedProfile)
}
}

masterKeyIdentifierLen := int(data[masterKeyIdentifierIndex])
if masterKeyIdentifierIndex+masterKeyIdentifierLen >= len(data) {
return errLengthMismatch
}

u.MasterKeyIdentifier = append([]byte{}, data[masterKeyIdentifierIndex+1:masterKeyIdentifierIndex+1+masterKeyIdentifierLen]...)

return nil
}
74 changes: 63 additions & 11 deletions pkg/protocol/extension/use_srtp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,72 @@
package extension

import (
"errors"
"reflect"
"testing"
)

func TestExtensionUseSRTP(t *testing.T) {
rawUseSRTP := []byte{0x00, 0x0e, 0x00, 0x05, 0x00, 0x02, 0x00, 0x01, 0x00}
parsedUseSRTP := &UseSRTP{
ProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80},
}

raw, err := parsedUseSRTP.Marshal()
if err != nil {
t.Error(err)
} else if !reflect.DeepEqual(raw, rawUseSRTP) {
t.Errorf("extensionUseSRTP marshal: got %#v, want %#v", raw, rawUseSRTP)
}
t.Run("No MasterKeyIdentifier", func(t *testing.T) {
rawUseSRTP := []byte{0x00, 0x0e, 0x00, 0x05, 0x00, 0x02, 0x00, 0x01, 0x00}
parsedUseSRTP := &UseSRTP{
ProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80},
MasterKeyIdentifier: []byte{},
}

marshaled, err := parsedUseSRTP.Marshal()
if err != nil {
t.Error(err)
} else if !reflect.DeepEqual(marshaled, rawUseSRTP) {
t.Errorf("extensionUseSRTP marshal: got %#v, want %#v", marshaled, rawUseSRTP)
}

unmarshaled := &UseSRTP{}
if err := unmarshaled.Unmarshal(rawUseSRTP); err != nil {
t.Error(err)
} else if !reflect.DeepEqual(unmarshaled, parsedUseSRTP) {
t.Errorf("extensionUseSRTP unmarshal: got %#v, want %#v", unmarshaled, parsedUseSRTP)
}
})

t.Run("With MasterKeyIdentifier", func(t *testing.T) {
rawUseSRTP := []byte{0x00, 0x0e, 0x00, 0x0a, 0x00, 0x02, 0x00, 0x01, 0x05, 0xA, 0xB, 0xC, 0xD, 0xE}
parsedUseSRTP := &UseSRTP{
ProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80},
MasterKeyIdentifier: []byte{0xA, 0xB, 0xC, 0xD, 0xE},
}

marshaled, err := parsedUseSRTP.Marshal()
if err != nil {
t.Error(err)
} else if !reflect.DeepEqual(marshaled, rawUseSRTP) {
t.Errorf("extensionUseSRTP marshal: got %#v, want %#v", marshaled, rawUseSRTP)
}

unmarshaled := &UseSRTP{}
if err := unmarshaled.Unmarshal(rawUseSRTP); err != nil {
t.Error(err)
} else if !reflect.DeepEqual(unmarshaled, parsedUseSRTP) {
t.Errorf("extensionUseSRTP unmarshal: got %#v, want %#v", unmarshaled, parsedUseSRTP)
}
})

t.Run("Invalid Lengths", func(t *testing.T) {
unmarshaled := &UseSRTP{}

if err := unmarshaled.Unmarshal([]byte{0x00, 0x0e, 0x00, 0x05, 0x00, 0x04, 0x00, 0x01, 0x00}); !errors.Is(errLengthMismatch, err) {
t.Error(err)
}

if err := unmarshaled.Unmarshal([]byte{0x00, 0x0e, 0x00, 0x0a, 0x00, 0x02, 0x00, 0x01, 0x01}); !errors.Is(errLengthMismatch, err) {
t.Error(err)
}

if _, err := (&UseSRTP{
ProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80},
MasterKeyIdentifier: make([]byte, 500),
}).Marshal(); !errors.Is(errMasterKeyIdentifierTooLarge, err) {
panic(err)
}
})
}
10 changes: 6 additions & 4 deletions state.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@ type State struct {
cipherSuite CipherSuite // nil if a cipherSuite hasn't been chosen
CipherSuiteID CipherSuiteID

srtpProtectionProfile atomic.Value // Negotiated SRTPProtectionProfile
PeerCertificates [][]byte
IdentityHint []byte
SessionID []byte
srtpProtectionProfile atomic.Value // Negotiated SRTPProtectionProfile
remoteSRTPMasterKeyIdentifier []byte

PeerCertificates [][]byte
IdentityHint []byte
SessionID []byte

// Connection Identifiers must be negotiated afresh on session resumption.
// https://datatracker.ietf.org/doc/html/rfc9146#name-the-connection_id-extension
Expand Down

0 comments on commit 7ab74fb

Please sign in to comment.