Commit 4da03a3c authored by Juliusz Chroboczek's avatar Juliusz Chroboczek

Make rate estimator estimate packet rates too.

parent 19a65318
...@@ -8,12 +8,14 @@ import ( ...@@ -8,12 +8,14 @@ import (
type Estimator struct { type Estimator struct {
interval time.Duration interval time.Duration
count uint32 bytes uint32
packets uint32
mu sync.Mutex mu sync.Mutex
totalBytes uint32 totalBytes uint32
totalPackets uint32 totalPackets uint32
rate uint32 rate uint32
packetRate uint32
time time.Time time time.Time
} }
...@@ -26,30 +28,38 @@ func New(interval time.Duration) *Estimator { ...@@ -26,30 +28,38 @@ func New(interval time.Duration) *Estimator {
func (e *Estimator) swap(now time.Time) { func (e *Estimator) swap(now time.Time) {
interval := now.Sub(e.time) interval := now.Sub(e.time)
count := atomic.SwapUint32(&e.count, 0) bytes := atomic.SwapUint32(&e.bytes, 0)
packets := atomic.SwapUint32(&e.packets, 0)
atomic.AddUint32(&e.totalBytes, bytes)
atomic.AddUint32(&e.totalPackets, packets)
if interval < time.Millisecond { if interval < time.Millisecond {
e.rate = 0 e.rate = 0
e.packetRate = 0
} else { } else {
e.rate = uint32(uint64(count*1000) / uint64(interval/time.Millisecond)) e.rate = uint32(uint64(bytes*1000) /
uint64(interval/time.Millisecond))
e.packetRate = uint32(uint64(packets*1000) /
uint64(interval/time.Millisecond))
} }
e.time = now e.time = now
} }
func (e *Estimator) Accumulate(count uint32) { func (e *Estimator) Accumulate(count uint32) {
atomic.AddUint32(&e.totalBytes, count) atomic.AddUint32(&e.bytes, count)
atomic.AddUint32(&e.totalPackets, 1) atomic.AddUint32(&e.packets, 1)
atomic.AddUint32(&e.count, count)
} }
func (e *Estimator) estimate(now time.Time) uint32 { func (e *Estimator) estimate(now time.Time) (uint32, uint32) {
if now.Sub(e.time) > e.interval { if now.Sub(e.time) > e.interval {
e.swap(now) e.swap(now)
} }
return e.rate return e.rate, e.packetRate
} }
func (e *Estimator) Estimate() uint32 { func (e *Estimator) Estimate() (uint32, uint32) {
now := time.Now() now := time.Now()
e.mu.Lock() e.mu.Lock()
...@@ -58,7 +68,7 @@ func (e *Estimator) Estimate() uint32 { ...@@ -58,7 +68,7 @@ func (e *Estimator) Estimate() uint32 {
} }
func (e *Estimator) Totals() (uint32, uint32) { func (e *Estimator) Totals() (uint32, uint32) {
b := atomic.LoadUint32(&e.totalBytes) b := atomic.LoadUint32(&e.totalBytes) + atomic.LoadUint32(&e.bytes)
p := atomic.LoadUint32(&e.totalPackets) p := atomic.LoadUint32(&e.totalPackets) + atomic.LoadUint32(&e.packets)
return p, b return p, b
} }
...@@ -13,11 +13,14 @@ func TestEstimator(t *testing.T) { ...@@ -13,11 +13,14 @@ func TestEstimator(t *testing.T) {
e.Accumulate(42) e.Accumulate(42)
e.Accumulate(128) e.Accumulate(128)
e.estimate(now.Add(time.Second)) e.estimate(now.Add(time.Second))
rate := e.estimate(now.Add(time.Second + time.Millisecond)) rate, packetRate := e.estimate(now.Add(time.Second + time.Millisecond))
if rate != 42+128 { if rate != 42+128 {
t.Errorf("Expected %v, got %v", 42+128, rate) t.Errorf("Expected %v, got %v", 42+128, rate)
} }
if packetRate != 2 {
t.Errorf("Expected 2, got %v", packetRate)
}
totalP, totalB := e.Totals() totalP, totalB := e.Totals()
if totalP != 2 { if totalP != 2 {
...@@ -26,4 +29,15 @@ func TestEstimator(t *testing.T) { ...@@ -26,4 +29,15 @@ func TestEstimator(t *testing.T) {
if totalB != 42+128 { if totalB != 42+128 {
t.Errorf("Expected %v, got %v", 42+128, totalB) t.Errorf("Expected %v, got %v", 42+128, totalB)
} }
e.Accumulate(12)
totalP, totalB = e.Totals()
if totalP != 3 {
t.Errorf("Expected 2, got %v", totalP)
}
if totalB != 42+128+12 {
t.Errorf("Expected %v, got %v", 42+128, totalB)
}
} }
...@@ -573,8 +573,9 @@ func getClientStats(c *webClient) clientStats { ...@@ -573,8 +573,9 @@ func getClientStats(c *webClient) clientStats {
loss := uint8(lost * 100 / expected) loss := uint8(lost * 100 / expected)
jitter := time.Duration(t.jitter.Jitter()) * jitter := time.Duration(t.jitter.Jitter()) *
(time.Second / time.Duration(t.jitter.HZ())) (time.Second / time.Duration(t.jitter.HZ()))
rate, _ := t.rate.Estimate()
conns.tracks = append(conns.tracks, trackStats{ conns.tracks = append(conns.tracks, trackStats{
bitrate: uint64(t.rate.Estimate()) * 8, bitrate: uint64(rate) * 8,
maxBitrate: atomic.LoadUint64(&t.maxBitrate), maxBitrate: atomic.LoadUint64(&t.maxBitrate),
loss: loss, loss: loss,
jitter: jitter, jitter: jitter,
...@@ -590,13 +591,14 @@ func getClientStats(c *webClient) clientStats { ...@@ -590,13 +591,14 @@ func getClientStats(c *webClient) clientStats {
conns := connStats{id: down.id} conns := connStats{id: down.id}
for _, t := range down.tracks { for _, t := range down.tracks {
jiffies := rtptime.Jiffies() jiffies := rtptime.Jiffies()
rate, _ := t.rate.Estimate()
rtt := rtptime.ToDuration(atomic.LoadUint64(&t.rtt), rtt := rtptime.ToDuration(atomic.LoadUint64(&t.rtt),
rtptime.JiffiesPerSec) rtptime.JiffiesPerSec)
loss, jitter := t.stats.Get(jiffies) loss, jitter := t.stats.Get(jiffies)
j := time.Duration(jitter) * time.Second / j := time.Duration(jitter) * time.Second /
time.Duration(t.track.Codec().ClockRate) time.Duration(t.track.Codec().ClockRate)
conns.tracks = append(conns.tracks, trackStats{ conns.tracks = append(conns.tracks, trackStats{
bitrate: uint64(t.rate.Estimate()) * 8, bitrate: uint64(rate) * 8,
maxBitrate: t.GetMaxBitrate(jiffies), maxBitrate: t.GetMaxBitrate(jiffies),
loss: uint8(uint32(loss) * 100 / 256), loss: uint8(uint32(loss) * 100 / 256),
rtt: rtt, rtt: rtt,
......
...@@ -847,7 +847,8 @@ func (track *rtpDownTrack) updateRate(loss uint8, now uint64) { ...@@ -847,7 +847,8 @@ func (track *rtpDownTrack) updateRate(loss uint8, now uint64) {
if loss < 5 { if loss < 5 {
// if our actual rate is low, then we're not probing the // if our actual rate is low, then we're not probing the
// bottleneck // bottleneck
actual := 8 * uint64(track.rate.Estimate()) r, _ := track.rate.Estimate()
actual := 8 * uint64(r)
if actual >= (rate*7)/8 { if actual >= (rate*7)/8 {
// loss < 0.02, multiply by 1.05 // loss < 0.02, multiply by 1.05
rate = rate * 269 / 256 rate = rate * 269 / 256
...@@ -937,7 +938,7 @@ func rtcpDownListener(conn *rtpDownConnection, track *rtpDownTrack, s *webrtc.RT ...@@ -937,7 +938,7 @@ func rtcpDownListener(conn *rtpDownConnection, track *rtpDownTrack, s *webrtc.RT
} }
case *rtcp.TransportLayerNack: case *rtcp.TransportLayerNack:
maxBitrate := track.GetMaxBitrate(jiffies) maxBitrate := track.GetMaxBitrate(jiffies)
bitrate := track.rate.Estimate() bitrate, _ := track.rate.Estimate()
if uint64(bitrate)*7/8 < maxBitrate { if uint64(bitrate)*7/8 < maxBitrate {
sendRecovery(p, track) sendRecovery(p, track)
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment