Commit a6c8f8e0 authored by Juliusz Chroboczek's avatar Juliusz Chroboczek

Handle NACKs locally.

parent d23cac10
...@@ -16,6 +16,8 @@ import ( ...@@ -16,6 +16,8 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"sfu/packetlist"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/pion/rtcp" "github.com/pion/rtcp"
"github.com/pion/rtp" "github.com/pion/rtp"
...@@ -269,8 +271,10 @@ func addUpConn(c *client, id string) (*upConnection, error) { ...@@ -269,8 +271,10 @@ func addUpConn(c *client, id string) (*upConnection, error) {
c.mu.Unlock() c.mu.Unlock()
return return
} }
list := packetlist.New(32)
track := &upTrack{ track := &upTrack{
track: remote, track: remote,
list: list,
maxBitrate: ^uint64(0), maxBitrate: ^uint64(0),
} }
u.tracks = append(u.tracks, track) u.tracks = append(u.tracks, track)
...@@ -286,7 +290,7 @@ func addUpConn(c *client, id string) (*upConnection, error) { ...@@ -286,7 +290,7 @@ func addUpConn(c *client, id string) (*upConnection, error) {
} }
go func() { go func() {
buf := make([]byte, 1500) buf := make([]byte, packetlist.BufSize)
var packet rtp.Packet var packet rtp.Packet
var local []*downTrack var local []*downTrack
var localTime time.Time var localTime time.Time
...@@ -311,6 +315,8 @@ func addUpConn(c *client, id string) (*upConnection, error) { ...@@ -311,6 +315,8 @@ func addUpConn(c *client, id string) (*upConnection, error) {
continue continue
} }
list.Store(packet.SequenceNumber, buf[:i])
for _, l := range local { for _, l := range local {
if l.muted() { if l.muted() {
continue continue
...@@ -523,6 +529,8 @@ func rtcpListener(g *group, conn *downConnection, track *downTrack, s *webrtc.RT ...@@ -523,6 +529,8 @@ func rtcpListener(g *group, conn *downConnection, track *downTrack, s *webrtc.RT
uint64(ms), uint64(ms),
) )
case *rtcp.ReceiverReport: case *rtcp.ReceiverReport:
case *rtcp.TransportLayerNack:
sendRecovery(p, track)
default: default:
log.Printf("RTCP: %T", p) log.Printf("RTCP: %T", p)
} }
...@@ -592,6 +600,25 @@ func sendREMB(pc *webrtc.PeerConnection, ssrc uint32, bitrate uint64) error { ...@@ -592,6 +600,25 @@ func sendREMB(pc *webrtc.PeerConnection, ssrc uint32, bitrate uint64) error {
}) })
} }
func sendRecovery(p *rtcp.TransportLayerNack, track *downTrack) {
var packet rtp.Packet
for _, nack := range p.Nacks {
for _, seqno := range nack.PacketList() {
raw := track.remote.list.Get(seqno)
if raw != nil {
err := packet.Unmarshal(raw)
if err != nil {
continue
}
err = track.track.WriteRTP(&packet)
if err != nil {
log.Printf("%v", err)
}
}
}
}
}
func countMediaStreams(data string) (int, error) { func countMediaStreams(data string) (int, error) {
desc := sdp.NewJSEPSessionDescription(false) desc := sdp.NewJSEPSessionDescription(false)
err := desc.Unmarshal(data) err := desc.Unmarshal(data)
......
...@@ -15,11 +15,14 @@ import ( ...@@ -15,11 +15,14 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"sfu/packetlist"
"github.com/pion/webrtc/v2" "github.com/pion/webrtc/v2"
) )
type upTrack struct { type upTrack struct {
track *webrtc.Track track *webrtc.Track
list *packetlist.List
maxBitrate uint64 maxBitrate uint64
mu sync.Mutex mu sync.Mutex
...@@ -172,6 +175,7 @@ func addGroup(name string, desc *groupDescription) (*group, error) { ...@@ -172,6 +175,7 @@ func addGroup(name string, desc *groupDescription) (*group, error) {
webrtc.DefaultPayloadTypeVP8, 90000, webrtc.DefaultPayloadTypeVP8, 90000,
[]webrtc.RTCPFeedback{ []webrtc.RTCPFeedback{
{"goog-remb", ""}, {"goog-remb", ""},
{"nack", ""},
{"nack", "pli"}, {"nack", "pli"},
}, },
"", "",
......
package packetlist
import (
"sync"
)
const BufSize = 1500
type entry struct {
seqno uint16
length int
buf [BufSize]byte
}
type List struct {
mu sync.Mutex
tail int
entries []entry
}
func New(capacity int) *List {
return &List{
entries: make([]entry, capacity),
}
}
func (list *List) Store(seqno uint16, buf []byte) {
list.mu.Lock()
defer list.mu.Unlock()
list.entries[list.tail].seqno = seqno
copy(list.entries[list.tail].buf[:], buf)
list.entries[list.tail].length = len(buf)
list.tail = (list.tail + 1) % len(list.entries)
}
func (list *List) Get(seqno uint16) []byte {
list.mu.Lock()
defer list.mu.Unlock()
for i := range list.entries {
if list.entries[i].length == 0 ||
list.entries[i].seqno != seqno {
continue
}
buf := make([]byte, list.entries[i].length)
copy(buf, list.entries[i].buf[:])
return buf
}
return nil
}
package packetlist
import (
"bytes"
"math/rand"
"testing"
)
func randomBuf() []byte {
length := rand.Int31n(BufSize-1) + 1
buf := make([]byte, length)
rand.Read(buf)
return buf
}
func TestList(t *testing.T) {
buf1 := randomBuf()
buf2 := randomBuf()
list := New(16)
list.Store(13, buf1)
list.Store(17, buf2)
if bytes.Compare(list.Get(13), buf1) != 0 {
t.Errorf("Couldn't get 13")
}
if bytes.Compare(list.Get(17), buf2) != 0 {
t.Errorf("Couldn't get 17")
}
if list.Get(42) != nil {
t.Errorf("Creation ex nihilo")
}
}
func TestOverflow(t *testing.T) {
list := New(16)
for i := 0; i < 32; i++ {
list.Store(uint16(i), []byte{uint8(i)})
}
for i := 0; i < 32; i++ {
buf := list.Get(uint16(i))
if i < 16 {
if buf != nil {
t.Errorf("Creation ex nihilo: %v", i)
}
} else {
L if len(buf) != 1 || buf[0] != uint8(i) {
t.Errorf("Expected [%v], got %v", i, buf)
}
}
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment