Commit c52e1f4c authored by Juliusz Chroboczek's avatar Juliusz Chroboczek

Move keyframe handling to the sender side.

This is simpler and gets rid of ErrKeyframeNeeded.
parent b2ea8e85
...@@ -9,7 +9,6 @@ import ( ...@@ -9,7 +9,6 @@ import (
) )
var ErrConnectionClosed = errors.New("connection is closed") var ErrConnectionClosed = errors.New("connection is closed")
var ErrKeyframeNeeded = errors.New("keyframe needed")
// Type Up represents a connection in the client to server direction. // Type Up represents a connection in the client to server direction.
type Up interface { type Up interface {
...@@ -30,6 +29,7 @@ type UpTrack interface { ...@@ -30,6 +29,7 @@ type UpTrack interface {
// get a recent packet. Returns 0 if the packet is not in cache. // get a recent packet. Returns 0 if the packet is not in cache.
GetRTP(seqno uint16, result []byte) uint16 GetRTP(seqno uint16, result []byte) uint16
Nack(conn Up, seqnos []uint16) error Nack(conn Up, seqnos []uint16) error
RequestKeyframe() error
} }
// Type Down represents a connection in the server to client direction. // Type Down represents a connection in the server to client direction.
......
...@@ -458,7 +458,7 @@ func (t *diskTrack) WriteRTP(packet *rtp.Packet) error { ...@@ -458,7 +458,7 @@ func (t *diskTrack) WriteRTP(packet *rtp.Packet) error {
sample, ts := t.builder.PopWithTimestamp() sample, ts := t.builder.PopWithTimestamp()
if sample == nil { if sample == nil {
if kfNeeded { if kfNeeded {
return conn.ErrKeyframeNeeded t.remote.RequestKeyframe()
} }
return nil return nil
} }
...@@ -506,7 +506,7 @@ func (t *diskTrack) WriteRTP(packet *rtp.Packet) error { ...@@ -506,7 +506,7 @@ func (t *diskTrack) WriteRTP(packet *rtp.Packet) error {
if t.writer == nil { if t.writer == nil {
if !keyframe { if !keyframe {
return conn.ErrKeyframeNeeded t.remote.RequestKeyframe()
} }
return nil return nil
} }
......
...@@ -235,7 +235,7 @@ type rtpUpTrack struct { ...@@ -235,7 +235,7 @@ type rtpUpTrack struct {
atomics *upTrackAtomics atomics *upTrackAtomics
cname atomic.Value cname atomic.Value
localCh chan localTrackAction localCh chan trackAction
readerDone chan struct{} readerDone chan struct{}
mu sync.Mutex mu sync.Mutex
...@@ -246,14 +246,20 @@ type rtpUpTrack struct { ...@@ -246,14 +246,20 @@ type rtpUpTrack struct {
bufferedNACKs []uint16 bufferedNACKs []uint16
} }
type localTrackAction struct { const (
add bool trackActionAdd = iota
trackActionDel
trackActionKeyframe
)
type trackAction struct {
action int
track conn.DownTrack track conn.DownTrack
} }
func (up *rtpUpTrack) notifyLocal(add bool, track conn.DownTrack) { func (up *rtpUpTrack) action(action int, track conn.DownTrack) {
select { select {
case up.localCh <- localTrackAction{add, track}: case up.localCh <- trackAction{action, track}:
case <-up.readerDone: case <-up.readerDone:
} }
} }
...@@ -271,7 +277,12 @@ func (up *rtpUpTrack) AddLocal(local conn.DownTrack) error { ...@@ -271,7 +277,12 @@ func (up *rtpUpTrack) AddLocal(local conn.DownTrack) error {
// do this asynchronously, to avoid deadlocks when multiple // do this asynchronously, to avoid deadlocks when multiple
// clients call this simultaneously. // clients call this simultaneously.
go up.notifyLocal(true, local) go up.action(trackActionAdd, local)
return nil
}
func (up *rtpUpTrack) RequestKeyframe() error {
go up.action(trackActionKeyframe, nil)
return nil return nil
} }
...@@ -283,7 +294,7 @@ func (up *rtpUpTrack) DelLocal(local conn.DownTrack) bool { ...@@ -283,7 +294,7 @@ func (up *rtpUpTrack) DelLocal(local conn.DownTrack) bool {
up.local = append(up.local[:i], up.local[i+1:]...) up.local = append(up.local[:i], up.local[i+1:]...)
// do this asynchronously, to avoid deadlocking when // do this asynchronously, to avoid deadlocking when
// multiple clients call this simultaneously. // multiple clients call this simultaneously.
go up.notifyLocal(false, l) go up.action(trackActionDel, l)
return true return true
} }
} }
...@@ -489,7 +500,7 @@ func newUpConn(c group.Client, id string, label string, offer string) (*rtpUpCon ...@@ -489,7 +500,7 @@ func newUpConn(c group.Client, id string, label string, offer string) (*rtpUpCon
rate: estimator.New(time.Second), rate: estimator.New(time.Second),
jitter: jitter.New(remote.Codec().ClockRate), jitter: jitter.New(remote.Codec().ClockRate),
atomics: &upTrackAtomics{}, atomics: &upTrackAtomics{},
localCh: make(chan localTrackAction, 2), localCh: make(chan trackAction, 2),
readerDone: make(chan struct{}), readerDone: make(chan struct{}),
} }
...@@ -977,7 +988,6 @@ func (track *rtpDownTrack) updateRate(loss uint8, now uint64) { ...@@ -977,7 +988,6 @@ func (track *rtpDownTrack) updateRate(loss uint8, now uint64) {
} }
func rtcpDownListener(conn *rtpDownConnection, track *rtpDownTrack, s *webrtc.RTPSender) { func rtcpDownListener(conn *rtpDownConnection, track *rtpDownTrack, s *webrtc.RTPSender) {
var gotFir bool
lastFirSeqno := uint8(0) lastFirSeqno := uint8(0)
buf := make([]byte, 1500) buf := make([]byte, 1500)
...@@ -1001,18 +1011,7 @@ func rtcpDownListener(conn *rtpDownConnection, track *rtpDownTrack, s *webrtc.RT ...@@ -1001,18 +1011,7 @@ func rtcpDownListener(conn *rtpDownConnection, track *rtpDownTrack, s *webrtc.RT
for _, p := range ps { for _, p := range ps {
switch p := p.(type) { switch p := p.(type) {
case *rtcp.PictureLossIndication: case *rtcp.PictureLossIndication:
remote, ok := conn.remote.(*rtpUpConnection) track.remote.RequestKeyframe()
if !ok {
continue
}
rt, ok := track.remote.(*rtpUpTrack)
if !ok {
continue
}
err := remote.sendPLI(rt)
if err != nil && err != ErrRateLimited {
log.Printf("sendPLI: %v", err)
}
case *rtcp.FullIntraRequest: case *rtcp.FullIntraRequest:
found := false found := false
var seqno uint8 var seqno uint8
...@@ -1028,29 +1027,8 @@ func rtcpDownListener(conn *rtpDownConnection, track *rtpDownTrack, s *webrtc.RT ...@@ -1028,29 +1027,8 @@ func rtcpDownListener(conn *rtpDownConnection, track *rtpDownTrack, s *webrtc.RT
continue continue
} }
increment := true if seqno != lastFirSeqno {
if gotFir { track.remote.RequestKeyframe()
increment = seqno != lastFirSeqno
}
gotFir = true
lastFirSeqno = seqno
remote, ok := conn.remote.(*rtpUpConnection)
if !ok {
continue
}
rt, ok := track.remote.(*rtpUpTrack)
if !ok {
continue
}
err := remote.sendFIR(rt, increment)
if err == ErrUnsupportedFeedback {
err := remote.sendPLI(rt)
if err != nil && err != ErrRateLimited {
log.Printf("sendPLI: %v", err)
}
} else if err != nil && err != ErrRateLimited {
log.Printf("sendFIR: %v", err)
} }
case *rtcp.ReceiverEstimatedMaximumBitrate: case *rtcp.ReceiverEstimatedMaximumBitrate:
track.maxREMBBitrate.Set(p.Bitrate, jiffies) track.maxREMBBitrate.Set(p.Bitrate, jiffies)
......
...@@ -21,6 +21,7 @@ func readLoop(conn *rtpUpConnection, track *rtpUpTrack) { ...@@ -21,6 +21,7 @@ func readLoop(conn *rtpUpConnection, track *rtpUpTrack) {
isvideo := track.track.Kind() == webrtc.RTPCodecTypeVideo isvideo := track.track.Kind() == webrtc.RTPCodecTypeVideo
codec := track.track.Codec() codec := track.track.Codec()
sendNACK := track.hasRtcpFb("nack", "") sendNACK := track.hasRtcpFb("nack", "")
var kfNeeded bool
buf := make([]byte, packetcache.BufSize) buf := make([]byte, packetcache.BufSize)
var packet rtp.Packet var packet rtp.Packet
for { for {
...@@ -41,8 +42,10 @@ func readLoop(conn *rtpUpConnection, track *rtpUpTrack) { ...@@ -41,8 +42,10 @@ func readLoop(conn *rtpUpConnection, track *rtpUpTrack) {
track.jitter.Accumulate(packet.Timestamp) track.jitter.Accumulate(packet.Timestamp)
kf, _ := isKeyframe(codec.MimeType, &packet) kf, kfKnown := isKeyframe(codec.MimeType, &packet)
if kf || !kfKnown {
kfNeeded = false
}
if packet.Extension { if packet.Extension {
packet.Extension = false packet.Extension = false
packet.Extensions = nil packet.Extensions = nil
...@@ -102,11 +105,29 @@ func readLoop(conn *rtpUpConnection, track *rtpUpTrack) { ...@@ -102,11 +105,29 @@ func readLoop(conn *rtpUpConnection, track *rtpUpTrack) {
select { select {
case action := <-track.localCh: case action := <-track.localCh:
err := writers.add(action.track, action.add) switch action.action {
case trackActionAdd, trackActionDel:
err := writers.add(
action.track,
action.action == trackActionAdd,
)
if err != nil { if err != nil {
log.Printf("add/remove track: %v", err) log.Printf("add/remove track: %v", err)
} }
case trackActionKeyframe:
kfNeeded = true
default:
log.Printf("Unknown action %v", action.action)
}
default: default:
} }
if kfNeeded {
err := conn.sendPLI(track)
if err != nil && err != ErrRateLimited {
log.Printf("sendPLI: %v", err)
kfNeeded = false
}
}
} }
} }
...@@ -4,7 +4,6 @@ import ( ...@@ -4,7 +4,6 @@ import (
"errors" "errors"
"log" "log"
"sort" "sort"
"strings"
"time" "time"
"github.com/pion/rtp" "github.com/pion/rtp"
...@@ -223,33 +222,22 @@ func sendKeyframe(kf []uint16, track conn.DownTrack, cache *packetcache.Cache) { ...@@ -223,33 +222,22 @@ func sendKeyframe(kf []uint16, track conn.DownTrack, cache *packetcache.Cache) {
return return
} }
err = track.WriteRTP(&packet) err = track.WriteRTP(&packet)
if err != nil && err != conn.ErrKeyframeNeeded { if err != nil {
return return
} }
track.Accumulate(uint32(bytes)) track.Accumulate(uint32(bytes))
} }
} }
const (
kfUnneeded = iota
kfNeededPLI
kfNeededFIR
kfNeededNewFIR
)
// rtpWriterLoop is the main loop of an rtpWriter. // rtpWriterLoop is the main loop of an rtpWriter.
func rtpWriterLoop(writer *rtpWriter, up *rtpUpConnection, track *rtpUpTrack) { func rtpWriterLoop(writer *rtpWriter, up *rtpUpConnection, track *rtpUpTrack) {
defer close(writer.done) defer close(writer.done)
codec := track.track.Codec()
buf := make([]byte, packetcache.BufSize) buf := make([]byte, packetcache.BufSize)
var packet rtp.Packet var packet rtp.Packet
local := make([]conn.DownTrack, 0) local := make([]conn.DownTrack, 0)
kfNeeded := kfUnneeded
for { for {
select { select {
case action := <-writer.action: case action := <-writer.action:
...@@ -277,8 +265,7 @@ func rtpWriterLoop(writer *rtpWriter, up *rtpUpConnection, track *rtpUpTrack) { ...@@ -277,8 +265,7 @@ func rtpWriterLoop(writer *rtpWriter, up *rtpUpConnection, track *rtpUpTrack) {
found, _, lts := track.cache.Last() found, _, lts := track.cache.Last()
kts, _, kf := track.cache.Keyframe() kts, _, kf := track.cache.Keyframe()
if strings.ToLower(codec.MimeType) == "video/vp8" && if found && len(kf) > 0 {
found && len(kf) > 0 {
if ((lts-kts)&0x80000000) != 0 || if ((lts-kts)&0x80000000) != 0 ||
lts-kts < 2*90000 { lts-kts < 2*90000 {
// we got a recent keyframe // we got a recent keyframe
...@@ -288,8 +275,7 @@ func rtpWriterLoop(writer *rtpWriter, up *rtpUpConnection, track *rtpUpTrack) { ...@@ -288,8 +275,7 @@ func rtpWriterLoop(writer *rtpWriter, up *rtpUpConnection, track *rtpUpTrack) {
track.cache, track.cache,
) )
} else { } else {
// Request a new keyframe track.RequestKeyframe()
kfNeeded = kfNeededNewFIR
} }
} else { } else {
// no keyframe yet, one should // no keyframe yet, one should
...@@ -333,44 +319,10 @@ func rtpWriterLoop(writer *rtpWriter, up *rtpUpConnection, track *rtpUpTrack) { ...@@ -333,44 +319,10 @@ func rtpWriterLoop(writer *rtpWriter, up *rtpUpConnection, track *rtpUpTrack) {
for _, l := range local { for _, l := range local {
err := l.WriteRTP(&packet) err := l.WriteRTP(&packet)
if err != nil { if err != nil {
if err == conn.ErrKeyframeNeeded {
kfNeeded = kfNeededPLI
} else {
continue continue
} }
}
l.Accumulate(uint32(bytes)) l.Accumulate(uint32(bytes))
} }
if kfNeeded > kfUnneeded {
kf, kfKnown :=
isKeyframe(codec.MimeType, &packet)
if kf {
kfNeeded = kfUnneeded
}
if kfNeeded >= kfNeededFIR {
err := up.sendFIR(
track,
kfNeeded >= kfNeededNewFIR,
)
if err == ErrUnsupportedFeedback {
kfNeeded = kfNeededPLI
} else {
kfNeeded = kfNeededFIR
}
}
if kfNeeded == kfNeededPLI {
up.sendPLI(track)
}
if !kfKnown {
// we cannot detect keyframes for
// this codec, reset our state
kfNeeded = kfUnneeded
}
}
} }
} }
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment