Add Simulcast support for WebRTC
Sean-Der committed Feb 1, 2024
1 parent 9da2483 commit 91055c9
Showing 3 changed files with 272 additions and 136 deletions.
2 changes: 2 additions & 0 deletions pkg/errors/errors.go
@@ -43,6 +43,8 @@ var (
    ErrIngressClosing     = psrpc.NewErrorf(psrpc.Unavailable, "ingress closing")
    ErrMissingStreamKey   = psrpc.NewErrorf(psrpc.InvalidArgument, "missing stream key")
    ErrPrerollBufferReset = psrpc.NewErrorf(psrpc.Internal, "preroll buffer reset")
    ErrInvalidSimulcast   = psrpc.NewErrorf(psrpc.NotAcceptable, "invalid simulcast configuration")
    ErrSimulcastTranscode = psrpc.NewErrorf(psrpc.NotAcceptable, "simulcast is not supported when transcoding")
)

func New(err string) error {
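The two new errors gate simulcast setup before any media flows. A minimal sketch of where they might be raised; the validateSimulcast helper, its arguments, and its placement are hypothetical illustrations, not part of this commit:

package whip

import "github.com/livekit/ingress/pkg/errors"

// validateSimulcast is a hypothetical guard: WHIP simulcast layers are
// passed through to the SDK output as-is, so transcoding cannot be
// combined with simulcast, and at least one layer must be offered.
func validateSimulcast(layerCount int, transcodingRequired bool) error {
    if layerCount == 0 {
        return errors.ErrInvalidSimulcast
    }
    if layerCount > 1 && transcodingRequired {
        return errors.ErrSimulcastTranscode
    }
    return nil
}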
275 changes: 166 additions & 109 deletions pkg/whip/sdk_media_sink.go
@@ -19,6 +19,7 @@ import (
    "context"
    "io"
    "strings"
    "sync"
    "time"

    "github.com/Eyevinn/mp4ff/avc"
@@ -41,168 +42,224 @@ var (
    ErrParamsUnavailable = psrpc.NewErrorf(psrpc.InvalidArgument, "codec parameters unavailable in sample")
)

type SDKMediaSinkTrack struct {
    readySamples chan *sample
    writePLI     func()

    quality       livekit.VideoQuality
    width, height uint

    sink *SDKMediaSink
}

type SDKMediaSink struct {
    logger          logger.Logger
    params          *params.Params
    outputSync      *utils.TrackOutputSynchronizer
    sdkOutput       *lksdk_output.LKSDKOutput
    sinkInitialized bool

    codecParameters webrtc.RTPCodecParameters
    streamKind      types.StreamKind

    tracksLock sync.Mutex
    tracks     []*SDKMediaSinkTrack

    fuse core.Fuse
}

type sample struct {
    s  *media.Sample
    ts time.Duration
}

func NewSDKMediaSink(
    l logger.Logger, p *params.Params, sdkOutput *lksdk_output.LKSDKOutput,
    codecParameters webrtc.RTPCodecParameters, streamKind types.StreamKind,
    outputSync *utils.TrackOutputSynchronizer,
) *SDKMediaSink {
    return &SDKMediaSink{
        logger:          l,
        params:          p,
        outputSync:      outputSync,
        sdkOutput:       sdkOutput,
        fuse:            core.NewFuse(),
        tracks:          []*SDKMediaSinkTrack{},
        streamKind:      streamKind,
        codecParameters: codecParameters,
    }
}

func (sp *SDKMediaSink) AddTrack(quality livekit.VideoQuality) {
    sp.tracksLock.Lock()
    defer sp.tracksLock.Unlock()

    sp.tracks = append(sp.tracks, &SDKMediaSinkTrack{
        readySamples: make(chan *sample, 15),
        sink:         sp,
        quality:      quality,
    })
}

func (sp *SDKMediaSink) SetWritePLI(quality livekit.VideoQuality, writePLI func()) *SDKMediaSinkTrack {
    sp.tracksLock.Lock()
    defer sp.tracksLock.Unlock()

    for i := range sp.tracks {
        if sp.tracks[i].quality == quality {
            sp.tracks[i].writePLI = writePLI
            return sp.tracks[i]
        }
    }

    return nil
}

func (sp *SDKMediaSink) Close() error {
    sp.fuse.Break()
    sp.outputSync.Close()

    return nil
}

func (sp *SDKMediaSink) ensureTracksInitialized(s *media.Sample, t *SDKMediaSinkTrack) (bool, error) {
    sp.tracksLock.Lock()
    defer sp.tracksLock.Unlock()

    if sp.sinkInitialized {
        return sp.sinkInitialized, nil
    }

    if sp.streamKind == types.Audio {
        stereo := strings.Contains(sp.codecParameters.SDPFmtpLine, "sprop-stereo=1")
        audioState := getAudioState(sp.codecParameters.MimeType, stereo, sp.codecParameters.ClockRate)
        sp.params.SetInputAudioState(context.Background(), audioState, true)

        sp.logger.Infow("adding audio track", "stereo", stereo, "codec", sp.codecParameters.MimeType)
        if err := sp.sdkOutput.AddAudioTrack(t, sp.codecParameters.MimeType, false, stereo); err != nil {
            return false, err
        }
        sp.sinkInitialized = true
        return sp.sinkInitialized, nil
    }

    var err error
    t.width, t.height, err = getVideoParams(sp.codecParameters.MimeType, s)
    switch err {
    case nil:
        // continue
    case ErrParamsUnavailable:
        return false, nil
    default:
        return false, err
    }

    layers := []*livekit.VideoLayer{}
    sampleProviders := []lksdk_output.VideoSampleProvider{}

    for _, track := range sp.tracks {
        if track.width != 0 && track.height != 0 {
            layers = append(layers, &livekit.VideoLayer{
                Width:   uint32(track.width),
                Height:  uint32(track.height),
                Quality: track.quality,
            })
            sampleProviders = append(sampleProviders, track)
        }
    }

    if len(layers) == 0 && len(sp.tracks) != 1 {
        return false, nil
    } else if len(layers) != len(sp.tracks) {
        return false, nil
    }

    if len(layers) != 0 {
        videoState := getVideoState(sp.codecParameters.MimeType, uint(layers[0].Width), uint(layers[0].Height))
        sp.params.SetInputVideoState(context.Background(), videoState, true)
    }

    if err := sp.sdkOutput.AddVideoTrack(sampleProviders, layers, sp.codecParameters.MimeType); err != nil {
        return false, err
    }

    for _, l := range layers {
        sp.logger.Infow("adding video track", "width", l.Width, "height", l.Height, "codec", sp.codecParameters.MimeType)
    }
    sp.sinkInitialized = true

    return sp.sinkInitialized, nil
}

func (t *SDKMediaSinkTrack) NextSample(ctx context.Context) (media.Sample, error) {
    for {
        select {
        case <-t.sink.fuse.Watch():
        case <-ctx.Done():
            return media.Sample{}, io.EOF
        case s := <-t.readySamples:
            return *s.s, nil
        }
    }
}
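Each SDKMediaSinkTrack doubles as the lksdk_output.VideoSampleProvider handed to AddVideoTrack, so the SDK output drains one sample queue per layer through NextSample. A simplified, hypothetical consumer loop; the pump function is an illustration, not part of this commit:

func pump(ctx context.Context, t *SDKMediaSinkTrack) {
    for {
        s, err := t.NextSample(ctx)
        if err != nil {
            // io.EOF once the context is cancelled
            return
        }
        _ = s // hand the sample to the published layer's local track
    }
}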

func (t *SDKMediaSinkTrack) PushSample(s *media.Sample, ts time.Duration) error {
    if t.sink.fuse.IsBroken() {
        return io.EOF
    }

    tracksInitialized, err := t.sink.ensureTracksInitialized(s, t)
    if err != nil {
        return err
    } else if !tracksInitialized {
        // Drop the sample
        return nil
    }

    // Synchronize the outputs before the network jitter buffer to avoid old samples stuck
    // in the channel from increasing the whole pipeline delay.
    drop, err := t.sink.outputSync.WaitForMediaTime(ts)
    if err != nil {
        return err
    }
    if drop {
        t.sink.logger.Debugw("dropping sample", "timestamp", ts)
        return nil
    }

    select {
    case <-t.sink.fuse.Watch():
        return io.EOF
    case t.readySamples <- &sample{s, ts}:
    default:
        // drop the sample if the output queue is full. This is needed if we are reconnecting.
    }

    return nil
}

func (t *SDKMediaSinkTrack) Close() error {
    return t.sink.Close()
}

func (t *SDKMediaSinkTrack) OnBind() error {
    t.sink.logger.Infow("media sink bound")
    return nil
}

func (t *SDKMediaSinkTrack) OnUnbind() error {
    t.sink.logger.Infow("media sink unbound")
    return nil
}

sp.logger.Infow("adding video track", "width", w, "height", h, "codec", mimeType)
sp.sdkOutput.AddVideoTrack(s, layers, mimeType)
func (t *SDKMediaSinkTrack) ForceKeyFrame() error {
    if t.writePLI != nil {
        t.writePLI()
    }

    return nil
}
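ForceKeyFrame fans out to the layer's writePLI callback, so a keyframe can be requested for a single simulcast layer. A sketch of what such a callback might look like with pion/webrtc, assuming pc is the *webrtc.PeerConnection, remoteTrack is that layer's *webrtc.TrackRemote, and github.com/pion/rtcp is imported; this wiring is not part of this commit:

writePLI := func() {
    // Send a Picture Loss Indication for the layer's SSRC.
    _ = pc.WriteRTCP([]rtcp.Packet{
        &rtcp.PictureLossIndication{MediaSSRC: uint32(remoteTrack.SSRC())},
    })
}
sink.SetWritePLI(livekit.VideoQuality_HIGH, writePLI)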

func getVideoParams(mimeType string, s *media.Sample) (uint, uint, error) {
    switch strings.ToLower(mimeType) {
    case strings.ToLower(webrtc.MimeTypeH264):
