diff --git a/pkg/errors/errors.go b/pkg/errors/errors.go
index 7e10bcf..b6ea0d0 100644
--- a/pkg/errors/errors.go
+++ b/pkg/errors/errors.go
@@ -43,6 +43,8 @@ var (
 	ErrIngressClosing     = psrpc.NewErrorf(psrpc.Unavailable, "ingress closing")
 	ErrMissingStreamKey   = psrpc.NewErrorf(psrpc.InvalidArgument, "missing stream key")
 	ErrPrerollBufferReset = psrpc.NewErrorf(psrpc.Internal, "preroll buffer reset")
+	ErrInvalidSimulcast   = psrpc.NewErrorf(psrpc.NotAcceptable, "invalid simulcast configuration")
+	ErrSimulcastTranscode = psrpc.NewErrorf(psrpc.NotAcceptable, "simulcast is not supported when transcoding")
 )
 
 func New(err string) error {
diff --git a/pkg/whip/sdk_media_sink.go b/pkg/whip/sdk_media_sink.go
index 92ae8de..864d2da 100644
--- a/pkg/whip/sdk_media_sink.go
+++ b/pkg/whip/sdk_media_sink.go
@@ -19,6 +19,7 @@ import (
 	"context"
 	"io"
 	"strings"
+	"sync"
 	"sync/atomic"
 	"time"
 
@@ -43,18 +44,31 @@ var (
 	ErrParamsUnavailable = psrpc.NewErrorf(psrpc.InvalidArgument, "codec parameters unavailable in sample")
 )
 
+type SDKMediaSinkTrack struct {
+	readySamples chan *sample
+	writePLI     func()
+
+	quality       livekit.VideoQuality
+	width, height uint
+
+	sink *SDKMediaSink
+}
+
 type SDKMediaSink struct {
 	logger     logger.Logger
 	params     *params.Params
-	writePLI   func()
-	track      *webrtc.TrackRemote
 	outputSync *utils.TrackOutputSynchronizer
 	trackStatsGatherer atomic.Pointer[stats.MediaTrackStatGatherer]
 	sdkOutput  *lksdk_output.LKSDKOutput
+	sinkInitialized bool
 
-	readySamples     chan *sample
-	fuse             core.Fuse
-	trackInitialized bool
+	codecParameters webrtc.RTPCodecParameters
+	streamKind      types.StreamKind
+
+	tracksLock sync.Mutex
+	tracks     []*SDKMediaSinkTrack
+
+	fuse core.Fuse
 }
 
 type sample struct {
@@ -62,80 +76,54 @@ type sample struct {
 	s *media.Sample
 	ts time.Duration
 }
 
-func NewSDKMediaSink(l logger.Logger, p *params.Params, sdkOutput *lksdk_output.LKSDKOutput, track *webrtc.TrackRemote, outputSync *utils.TrackOutputSynchronizer, writePLI func()) *SDKMediaSink {
-	s := &SDKMediaSink{
-		logger:       l,
-		params:       p,
-		writePLI:     writePLI,
-		track:        track,
-		outputSync:   outputSync,
-		sdkOutput:    sdkOutput,
-		readySamples: make(chan *sample, 15),
-		fuse:         core.NewFuse(),
+func NewSDKMediaSink(
+	l logger.Logger, p *params.Params, sdkOutput *lksdk_output.LKSDKOutput,
+	codecParameters webrtc.RTPCodecParameters, streamKind types.StreamKind,
+	outputSync *utils.TrackOutputSynchronizer,
+) *SDKMediaSink {
+	return &SDKMediaSink{
+		logger:          l,
+		params:          p,
+		outputSync:      outputSync,
+		sdkOutput:       sdkOutput,
+		fuse:            core.NewFuse(),
+		tracks:          []*SDKMediaSinkTrack{},
+		streamKind:      streamKind,
+		codecParameters: codecParameters,
 	}
-
-	return s
 }
 
-func (sp *SDKMediaSink) PushSample(s *media.Sample, ts time.Duration) error {
-	if sp.fuse.IsBroken() {
-		return io.EOF
-	}
+func (sp *SDKMediaSink) AddTrack(quality livekit.VideoQuality) {
+	sp.tracksLock.Lock()
+	defer sp.tracksLock.Unlock()
 
-	err := sp.ensureTrackInitialized(s)
-	if err != nil {
-		return err
-	}
-	if !sp.trackInitialized {
-		// Drop the sample
-		return nil
-	}
-
-	// Synchronize the outputs before the network jitter buffer to avoid old samples stuck
-	// in the channel from increasing the whole pipeline delay.
-	drop, err := sp.outputSync.WaitForMediaTime(ts)
-	if err != nil {
-		return err
-	}
-	if drop {
-		sp.logger.Debugw("dropping sample", "timestamp", ts)
-		return nil
-	}
-
-	select {
-	case <-sp.fuse.Watch():
-		return io.EOF
-	case sp.readySamples <- &sample{s, ts}:
-	default:
-		// drop the sample if the output queue is full. This is needed if we are reconnecting.
-	}
-
-	return nil
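+	// One SDKMediaSinkTrack is registered per published layer: one per rid
+	// when simulcasting, a single HIGH quality track otherwise.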
+	sp.tracks = append(sp.tracks, &SDKMediaSinkTrack{
+		readySamples: make(chan *sample, 15),
+		sink:         sp,
+		quality:      quality,
+	})
 }
 
-func (sp *SDKMediaSink) NextSample(ctx context.Context) (media.Sample, error) {
-	for {
-		select {
-		case <-sp.fuse.Watch():
-		case <-ctx.Done():
-			return media.Sample{}, io.EOF
-		case s := <-sp.readySamples:
-			g := sp.trackStatsGatherer.Load()
-			if g != nil {
-				g.MediaReceived(int64(len(s.s.Data)))
-			}
+func (sp *SDKMediaSink) SetWritePLI(quality livekit.VideoQuality, writePLI func()) *SDKMediaSinkTrack {
+	sp.tracksLock.Lock()
+	defer sp.tracksLock.Unlock()
 
-			return *s.s, nil
+	for i := range sp.tracks {
+		if sp.tracks[i].quality == quality {
+			sp.tracks[i].writePLI = writePLI
+			return sp.tracks[i]
 		}
 	}
+
+	return nil
 }
 
-func (sp *SDKMediaSink) SetStatsGatherer(st *stats.LocalMediaStatsGatherer) {
+func (t *SDKMediaSinkTrack) SetStatsGatherer(st *stats.LocalMediaStatsGatherer) {
 	var path string
 
-	switch sp.track.Kind() {
-	case webrtc.RTPCodecTypeAudio:
+	switch t.sink.streamKind {
+	case types.Audio:
 		path = stats.OutputAudio
-	case webrtc.RTPCodecTypeVideo:
+	case types.Video:
 		path = stats.OutputVideo
 	default:
 		path = "output.unknown"
@@ -143,88 +131,173 @@ func (sp *SDKMediaSink) SetStatsGatherer(st *stats.LocalMediaStatsGatherer) {
 
 	g := st.RegisterTrackStats(path)
 
-	sp.trackStatsGatherer.Store(g)
+	t.sink.trackStatsGatherer.Store(g)
 }
 
-func (sp *SDKMediaSink) OnBind() error {
-	sp.logger.Infow("media sink bound")
+func (sp *SDKMediaSink) Close() error {
+	sp.fuse.Break()
+	sp.outputSync.Close()
 
 	return nil
 }
 
-func (sp *SDKMediaSink) OnUnbind() error {
-	sp.logger.Infow("media sink unbound")
+func (sp *SDKMediaSink) ensureAudioTracksInitialized(s *media.Sample, t *SDKMediaSinkTrack) (bool, error) {
+	stereo := strings.Contains(sp.codecParameters.SDPFmtpLine, "sprop-stereo=1")
+	audioState := getAudioState(sp.codecParameters.MimeType, stereo, sp.codecParameters.ClockRate)
+	sp.params.SetInputAudioState(context.Background(), audioState, true)
 
-	return nil
+	sp.logger.Infow("adding audio track", "stereo", stereo, "codec", sp.codecParameters.MimeType)
+	if err := sp.sdkOutput.AddAudioTrack(t, sp.codecParameters.MimeType, false, stereo); err != nil {
+		return false, err
+	}
+	sp.sinkInitialized = true
+
+	return sp.sinkInitialized, nil
 }
 
-func (sp *SDKMediaSink) ForceKeyFrame() error {
-	if sp.writePLI != nil {
-		sp.writePLI()
+func (sp *SDKMediaSink) ensureVideoTracksInitialized(s *media.Sample, t *SDKMediaSinkTrack) (bool, error) {
+	var err error
+
+	t.width, t.height, err = getVideoParams(sp.codecParameters.MimeType, s)
+	switch err {
+	case nil:
+		// continue
+	case ErrParamsUnavailable:
+		return false, nil
+	default:
+		return false, err
 	}
 
-	return nil
-}
+	layers := []*livekit.VideoLayer{}
+	sampleProviders := []lksdk_output.VideoSampleProvider{}
+
+	for _, track := range sp.tracks {
+		if track.width != 0 && track.height != 0 {
+			layers = append(layers, &livekit.VideoLayer{
+				Width:   uint32(track.width),
+				Height:  uint32(track.height),
+				Quality: track.quality,
+			})
+			sampleProviders = append(sampleProviders, track)
+		}
+	}
 
-func (sp *SDKMediaSink) SetWriter(w io.WriteCloser) error {
-	return psrpc.Unimplemented
-}
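+	// Only publish once every expected layer has reported its dimensions:
+	// all rids when simulcasting, the single track otherwise.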
+	// Simulcast
+	if len(sp.tracks) > 1 {
+		if len(layers) != len(sp.tracks) {
+			return false, nil
+		}
+	} else {
+		// Non-simulcast
+		if len(layers) != 1 {
+			return false, nil
+		}
-func (sp *SDKMediaSink) Close() error {
-	sp.fuse.Break()
-	sp.outputSync.Close()
+	}
+
+	if len(layers) != 0 {
+		videoState := getVideoState(sp.codecParameters.MimeType, uint(layers[0].Width), uint(layers[0].Height))
+		sp.params.SetInputVideoState(context.Background(), videoState, true)
+	}
+
+	if err := sp.sdkOutput.AddVideoTrack(sampleProviders, layers, sp.codecParameters.MimeType); err != nil {
+		return false, err
+	}
+
+	for _, l := range layers {
+		sp.logger.Infow("adding video track", "width", l.Width, "height", l.Height, "codec", sp.codecParameters.MimeType)
+	}
+	sp.sinkInitialized = true
+
+	return sp.sinkInitialized, nil
 }
 
-func (sp *SDKMediaSink) ensureTrackInitialized(s *media.Sample) error {
-	if sp.trackInitialized {
-		return nil
+func (sp *SDKMediaSink) ensureTracksInitialized(s *media.Sample, t *SDKMediaSinkTrack) (bool, error) {
+	sp.tracksLock.Lock()
+	defer sp.tracksLock.Unlock()
+
+	if sp.sinkInitialized {
+		return sp.sinkInitialized, nil
 	}
 
-	kind := streamKindFromCodecType(sp.track.Kind())
-	mimeType := sp.track.Codec().MimeType
+	if sp.streamKind == types.Audio {
+		return sp.ensureAudioTracksInitialized(s, t)
+	}
 
-	switch kind {
-	case types.Audio:
-		stereo := parseAudioFmtp(sp.track.Codec().SDPFmtpLine)
-		audioState := getAudioState(sp.track.Codec().MimeType, stereo, sp.track.Codec().ClockRate)
-		sp.params.SetInputAudioState(context.Background(), audioState, true)
+	return sp.ensureVideoTracksInitialized(s, t)
+}
 
-		sp.logger.Infow("adding audio track", "stereo", stereo, "codec", mimeType)
-		sp.sdkOutput.AddAudioTrack(sp, mimeType, false, stereo)
-	case types.Video:
-		w, h, err := getVideoParams(mimeType, s)
-		switch err {
-		case nil:
-			// continue
-		case ErrParamsUnavailable:
-			return nil
-		default:
-			return err
-		}
+func (t *SDKMediaSinkTrack) NextSample(ctx context.Context) (media.Sample, error) {
+	for {
+		select {
+		case <-t.sink.fuse.Watch():
+		case <-ctx.Done():
+			return media.Sample{}, io.EOF
+		case s := <-t.readySamples:
+			g := t.sink.trackStatsGatherer.Load()
+			if g != nil {
+				g.MediaReceived(int64(len(s.s.Data)))
+			}
 
-		layers := []*livekit.VideoLayer{
-			&livekit.VideoLayer{Width: uint32(w), Height: uint32(h), Quality: livekit.VideoQuality_HIGH},
-		}
-		s := []lksdk_output.VideoSampleProvider{
-			sp,
+			return *s.s, nil
 		}
+	}
+}
 
-		videoState := getVideoState(sp.track.Codec().MimeType, w, h)
-		sp.params.SetInputVideoState(context.Background(), videoState, true)
+func (t *SDKMediaSinkTrack) PushSample(s *media.Sample, ts time.Duration) error {
+	if t.sink.fuse.IsBroken() {
+		return io.EOF
+	}
+
+	tracksInitialized, err := t.sink.ensureTracksInitialized(s, t)
+	if err != nil {
+		return err
+	} else if !tracksInitialized {
+		// Drop the sample
+		return nil
+	}
+
+	// Synchronize the outputs before the network jitter buffer to avoid old samples stuck
+	// in the channel from increasing the whole pipeline delay.
+	drop, err := t.sink.outputSync.WaitForMediaTime(ts)
+	if err != nil {
+		return err
+	}
+	if drop {
+		t.sink.logger.Debugw("dropping sample", "timestamp", ts)
+		return nil
+	}
 
-		sp.logger.Infow("adding video track", "width", w, "height", h, "codec", mimeType)
-		sp.sdkOutput.AddVideoTrack(s, layers, mimeType)
+	select {
+	case <-t.sink.fuse.Watch():
+		return io.EOF
+	case t.readySamples <- &sample{s, ts}:
+	default:
+		// Drop the sample if the output queue is full. This is needed while we are reconnecting.
 	}
-	sp.trackInitialized = true
+
+	return nil
+}
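+
+// Close closes the shared sink, which stops every layer of this output.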
+func (t *SDKMediaSinkTrack) Close() error {
+	return t.sink.Close()
+}
+
+func (t *SDKMediaSinkTrack) OnBind() error {
+	t.sink.logger.Infow("media sink bound")
+
+	return nil
+}
+
+func (t *SDKMediaSinkTrack) OnUnbind() error {
+	t.sink.logger.Infow("media sink unbound")
 
 	return nil
 }
 
-func parseAudioFmtp(audioFmtp string) bool {
-	return strings.Index(audioFmtp, "sprop-stereo=1") >= 0
+func (t *SDKMediaSinkTrack) ForceKeyFrame() error {
+	if t.writePLI != nil {
+		t.writePLI()
+	}
+
+	return nil
 }
 
 func getVideoParams(mimeType string, s *media.Sample) (uint, uint, error) {
diff --git a/pkg/whip/whip_handler.go b/pkg/whip/whip_handler.go
index 3daee11..f551b14 100644
--- a/pkg/whip/whip_handler.go
+++ b/pkg/whip/whip_handler.go
@@ -17,12 +17,14 @@ package whip
 import (
 	"context"
 	"io"
+	"strings"
 	"sync"
 	"time"
 
 	"github.com/pion/dtls/v2/pkg/crypto/elliptic"
 	"github.com/pion/interceptor"
 	"github.com/pion/rtcp"
+	"github.com/pion/sdp/v3"
 	"github.com/pion/webrtc/v3"
 	google_protobuf2 "google.golang.org/protobuf/types/known/emptypb"
 
@@ -64,11 +66,16 @@ type whipHandler struct {
 	result    chan error
 	closeOnce sync.Once
 
-	trackLock     sync.Mutex
-	tracks        map[string]*webrtc.TrackRemote
-	trackHandlers map[types.StreamKind]*whipTrackHandler
+	trackLock       sync.Mutex
+	simulcastLayers []string
+	tracks          map[string]*webrtc.TrackRemote
+	trackHandlers   []*whipTrackHandler
+	trackAddedChan  chan *webrtc.TrackRemote
+
+	trackSDKMediaSinkLock sync.Mutex
+	trackSDKMediaSink     map[types.StreamKind]*SDKMediaSink
 	trackRelayMediaSink map[types.StreamKind]*RelayMediaSink // only for transcoding mode
-
-	trackAddedChan chan *webrtc.TrackRemote
 }
 
@@ -81,8 +88,9 @@ func NewWHIPHandler(webRTCConfig *rtcconfig.WebRTCConfig) *whipHandler {
 		outputSync:          utils.NewOutputSynchronizer(),
 		result:              make(chan error, 1),
 		tracks:              make(map[string]*webrtc.TrackRemote),
-		trackHandlers:       make(map[types.StreamKind]*whipTrackHandler),
+		trackHandlers:       []*whipTrackHandler{},
 		trackRelayMediaSink: make(map[types.StreamKind]*RelayMediaSink),
+		trackSDKMediaSink:   make(map[types.StreamKind]*SDKMediaSink),
 	}
 }
 
@@ -94,22 +102,25 @@ func (h *whipHandler) Init(ctx context.Context, p *params.Params, sdpOffer strin
 
 	h.updateSettings()
 
+	offer := &webrtc.SessionDescription{
+		Type: webrtc.SDPTypeOffer,
+		SDP:  sdpOffer,
+	}
+	h.expectedTrackCount, err = h.validateOfferAndGetExpectedTrackCount(offer)
+	if err != nil {
+		return "", err
+	}
+
 	if p.BypassTranscoding {
 		h.sdkOutput, err = lksdk_output.NewLKSDKOutput(ctx, p)
 		if err != nil {
 			return "", err
 		}
+	} else if len(h.simulcastLayers) != 0 {
+		return "", errors.ErrSimulcastTranscode
 	}
 
-	offer := &webrtc.SessionDescription{
-		Type: webrtc.SDPTypeOffer,
-		SDP:  sdpOffer,
-	}
-	h.expectedTrackCount, err = validateOfferAndGetExpectedTrackCount(offer)
 	h.trackAddedChan = make(chan *webrtc.TrackRemote, h.expectedTrackCount)
-	if err != nil {
-		return "", err
-	}
 
 	m, err := newMediaEngine()
 	if err != nil {
@@ -343,9 +354,6 @@ func (h *whipHandler) getSDPAnswer(ctx context.Context, offer *webrtc.SessionDes
 	if err != nil {
 		return "", err
 	}
-	if len(parsedAnswer.MediaDescriptions) != h.expectedTrackCount {
-		return "", errors.ErrUnsupportedDecodeFormat
-	}
 	for _, m := range parsedAnswer.MediaDescriptions {
 		// Pion puts a media description with fmt = 0 and no attributes for unsupported codecs
 		if len(m.Attributes) == 0 {
@@ -371,7 +379,21 @@ func (h *whipHandler) addTrack(track *webrtc.TrackRemote, receiver *webrtc.RTPRe
 
 	sync := h.sync.AddTrack(track, whipIdentity)
 
-	mediaSink, err := h.newMediaSink(track)
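+	// Map the track's RID to a layer quality: rid order in the offer's
+	// simulcast attribute is HIGH, MEDIUM, LOW.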
+	trackQuality := livekit.VideoQuality_HIGH
+	if track.RID() != "" {
+		for i, expectedRid := range h.simulcastLayers {
+			if expectedRid == track.RID() {
+				switch i {
+				case 1:
+					trackQuality = livekit.VideoQuality_MEDIUM
+				case 2:
+					trackQuality = livekit.VideoQuality_LOW
+				}
+			}
+		}
+	}
+
+	mediaSink, err := h.getMediaSink(track, trackQuality)
 	if err != nil {
 		logger.Warnw("failed creating whip media handler", err)
 		return
@@ -382,7 +404,7 @@ func (h *whipHandler) addTrack(track *webrtc.TrackRemote, receiver *webrtc.RTPRe
 		logger.Warnw("failed creating whip track handler", err)
 		return
 	}
-	h.trackHandlers[kind] = th
+	h.trackHandlers = append(h.trackHandlers, th)
 
 	select {
 	case h.trackAddedChan <- track:
@@ -391,12 +413,29 @@ func (h *whipHandler) addTrack(track *webrtc.TrackRemote, receiver *webrtc.RTPRe
 	}
 }
 
-func (h *whipHandler) newMediaSink(track *webrtc.TrackRemote) (MediaSink, error) {
+func (h *whipHandler) getMediaSink(track *webrtc.TrackRemote, trackQuality livekit.VideoQuality) (MediaSink, error) {
 	kind := streamKindFromCodecType(track.Kind())
 	if h.sdkOutput != nil {
-		// pasthrough
-		return NewSDKMediaSink(h.logger, h.params, h.sdkOutput, track, h.outputSync.AddTrack(), func() {
+		h.trackSDKMediaSinkLock.Lock()
+		defer h.trackSDKMediaSinkLock.Unlock()
+
+		if _, ok := h.trackSDKMediaSink[kind]; !ok {
+			h.trackSDKMediaSink[kind] = NewSDKMediaSink(h.logger, h.params, h.sdkOutput, track.Codec(), streamKindFromCodecType(track.Kind()), h.outputSync.AddTrack())
+
+			layers := []livekit.VideoQuality{livekit.VideoQuality_HIGH}
+			if kind == types.Video && len(h.simulcastLayers) == 3 {
+				layers = []livekit.VideoQuality{livekit.VideoQuality_HIGH, livekit.VideoQuality_MEDIUM, livekit.VideoQuality_LOW}
+			} else if kind == types.Video && len(h.simulcastLayers) == 2 {
+				layers = []livekit.VideoQuality{livekit.VideoQuality_HIGH, livekit.VideoQuality_MEDIUM}
+			}
+
+			for _, layer := range layers {
+				h.trackSDKMediaSink[kind].AddTrack(layer)
+			}
+		}
+
+		return h.trackSDKMediaSink[kind].SetWritePLI(trackQuality, func() {
 			h.writePLI(track.SSRC())
 		}), nil
 	} else {
@@ -429,22 +468,54 @@ func streamKindFromCodecType(typ webrtc.RTPCodecType) types.StreamKind {
 	}
 }
 
-func validateOfferAndGetExpectedTrackCount(offer *webrtc.SessionDescription) (int, error) {
+func (h *whipHandler) validateOfferAndGetExpectedTrackCount(offer *webrtc.SessionDescription) (int, error) {
 	parsed, err := offer.Unmarshal()
 	if err != nil {
 		return 0, err
 	}
 
-	mediaTypes := make(map[string]struct{})
+	audioCount, videoCount := 0, 0
+
 	for _, m := range parsed.MediaDescriptions {
-		if _, ok := mediaTypes[m.MediaName.Media]; ok {
+		if types.StreamKind(m.MediaName.Media) == types.Audio {
 			// Duplicate track for a given type. Forbidden by the RFC
-			return 0, errors.ErrDuplicateTrack
+			if audioCount != 0 {
+				return 0, errors.ErrDuplicateTrack
+			}
+
+			audioCount++
+		} else if types.StreamKind(m.MediaName.Media) == types.Video {
+			// Duplicate track for a given type. Forbidden by the RFC
+			if videoCount != 0 {
+				return 0, errors.ErrDuplicateTrack
+			}
+
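+			// Parse the simulcast attribute (RFC 8853), e.g.
+			// "a=simulcast:send hi;mid;low": only the send direction with
+			// two or three rid layers is accepted.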
+			for _, a := range m.Attributes {
+				if a.Key == "simulcast" {
+					spaceSplit := strings.Split(a.Value, " ")
+					if len(spaceSplit) != 2 || spaceSplit[0] != "send" {
+						return 0, errors.ErrInvalidSimulcast
+					}
+
+					layersSplit := strings.Split(spaceSplit[1], ";")
+					if len(layersSplit) != 2 && len(layersSplit) != 3 {
+						return 0, errors.ErrInvalidSimulcast
+					}
+
+					h.simulcastLayers = layersSplit
+					videoCount += len(h.simulcastLayers)
+				}
+			}
+
+			// No simulcast
+			if videoCount == 0 {
+				videoCount++
+			}
 		}
-		mediaTypes[m.MediaName.Media] = struct{}{}
 	}
 
-	return len(parsed.MediaDescriptions), nil
+	return audioCount + videoCount, nil
 }
 
 func newMediaEngine() (*webrtc.MediaEngine, error) {
@@ -482,6 +553,14 @@ func newMediaEngine() (*webrtc.MediaEngine, error) {
 		}
 	}
 
+	// The MID and RID (RTP stream ID) header extensions are required to
+	// demultiplex the incoming simulcast layers.
+	if err := m.RegisterHeaderExtension(webrtc.RTPHeaderExtensionCapability{URI: sdp.SDESMidURI}, webrtc.RTPCodecTypeVideo); err != nil {
+		return nil, err
+	}
+
+	if err := m.RegisterHeaderExtension(webrtc.RTPHeaderExtensionCapability{URI: sdp.SDESRTPStreamIDURI}, webrtc.RTPCodecTypeVideo); err != nil {
+		return nil, err
+	}
+
 	return m, nil
 }