Commit cde8d49e authored by Juliusz Chroboczek's avatar Juliusz Chroboczek

Handle NACKs locally.

parent f01674fe
...@@ -16,6 +16,8 @@ import ( ...@@ -16,6 +16,8 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"sfu/packetlist"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/pion/rtcp" "github.com/pion/rtcp"
"github.com/pion/rtp" "github.com/pion/rtp"
...@@ -269,8 +271,10 @@ func addUpConn(c *client, id string) (*upConnection, error) { ...@@ -269,8 +271,10 @@ func addUpConn(c *client, id string) (*upConnection, error) {
c.mu.Unlock() c.mu.Unlock()
return return
} }
list := packetlist.New(32)
track := &upTrack{ track := &upTrack{
track: remote, track: remote,
list: list,
maxBitrate: ^uint64(0), maxBitrate: ^uint64(0),
} }
u.tracks = append(u.tracks, track) u.tracks = append(u.tracks, track)
...@@ -286,13 +290,13 @@ func addUpConn(c *client, id string) (*upConnection, error) { ...@@ -286,13 +290,13 @@ func addUpConn(c *client, id string) (*upConnection, error) {
} }
go func() { go func() {
buf := make([]byte, 1500) buf := make([]byte, packetlist.BufSize)
var packet rtp.Packet var packet rtp.Packet
var local []*downTrack var local []*downTrack
var localTime time.Time var localTime time.Time
for { for {
now := time.Now() now := time.Now()
if now.Sub(localTime) > time.Second / 2 { if now.Sub(localTime) > time.Second/2 {
local = track.getLocal() local = track.getLocal()
localTime = now localTime = now
} }
...@@ -311,6 +315,8 @@ func addUpConn(c *client, id string) (*upConnection, error) { ...@@ -311,6 +315,8 @@ func addUpConn(c *client, id string) (*upConnection, error) {
continue continue
} }
list.Store(packet.SequenceNumber, buf[:i])
for _, l := range local { for _, l := range local {
if l.muted() { if l.muted() {
continue continue
...@@ -523,6 +529,8 @@ func rtcpListener(g *group, conn *downConnection, track *downTrack, s *webrtc.RT ...@@ -523,6 +529,8 @@ func rtcpListener(g *group, conn *downConnection, track *downTrack, s *webrtc.RT
uint64(ms), uint64(ms),
) )
case *rtcp.ReceiverReport: case *rtcp.ReceiverReport:
case *rtcp.TransportLayerNack:
sendRecovery(p, track)
default: default:
log.Printf("RTCP: %T", p) log.Printf("RTCP: %T", p)
} }
...@@ -592,6 +600,25 @@ func sendREMB(pc *webrtc.PeerConnection, ssrc uint32, bitrate uint64) error { ...@@ -592,6 +600,25 @@ func sendREMB(pc *webrtc.PeerConnection, ssrc uint32, bitrate uint64) error {
}) })
} }
func sendRecovery(p *rtcp.TransportLayerNack, track *downTrack) {
var packet rtp.Packet
for _, nack := range p.Nacks {
for _, seqno := range nack.PacketList() {
raw := track.remote.list.Get(seqno)
if raw != nil {
err := packet.Unmarshal(raw)
if err != nil {
continue
}
err = track.track.WriteRTP(&packet)
if err != nil {
log.Printf("%v", err)
}
}
}
}
}
func countMediaStreams(data string) (int, error) { func countMediaStreams(data string) (int, error) {
desc := sdp.NewJSEPSessionDescription(false) desc := sdp.NewJSEPSessionDescription(false)
err := desc.Unmarshal(data) err := desc.Unmarshal(data)
......
...@@ -15,11 +15,14 @@ import ( ...@@ -15,11 +15,14 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"sfu/packetlist"
"github.com/pion/webrtc/v2" "github.com/pion/webrtc/v2"
) )
type upTrack struct { type upTrack struct {
track *webrtc.Track track *webrtc.Track
list *packetlist.List
maxBitrate uint64 maxBitrate uint64
mu sync.Mutex mu sync.Mutex
...@@ -172,6 +175,7 @@ func addGroup(name string, desc *groupDescription) (*group, error) { ...@@ -172,6 +175,7 @@ func addGroup(name string, desc *groupDescription) (*group, error) {
webrtc.DefaultPayloadTypeVP8, 90000, webrtc.DefaultPayloadTypeVP8, 90000,
[]webrtc.RTCPFeedback{ []webrtc.RTCPFeedback{
{"goog-remb", ""}, {"goog-remb", ""},
{"nack", ""},
{"nack", "pli"}, {"nack", "pli"},
}, },
"", "",
......
package packetlist
import (
"sync"
)
const BufSize = 1500
type entry struct {
seqno uint16
length int
buf [BufSize]byte
}
type List struct {
mu sync.Mutex
tail int
entries []entry
}
func New(capacity int) *List {
return &List{
entries: make([]entry, capacity),
}
}
func (list *List) Store(seqno uint16, buf []byte) {
list.mu.Lock()
defer list.mu.Unlock()
list.entries[list.tail].seqno = seqno
copy(list.entries[list.tail].buf[:], buf)
list.entries[list.tail].length = len(buf)
list.tail = (list.tail + 1) % len(list.entries)
}
func (list *List) Get(seqno uint16) []byte {
list.mu.Lock()
defer list.mu.Unlock()
for _, entry := range list.entries {
if entry.length == 0 || entry.seqno != seqno {
continue
}
buf := make([]byte, entry.length)
copy(buf, entry.buf[:entry.length])
return buf
}
return nil
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment