Commit 00fbfafe authored by Juliusz Chroboczek's avatar Juliusz Chroboczek

Make unbounded channels explicit.

We used to have unbounded channels embedded within rtpconn
and webClient.  Make the structure explicit and testable.
parent dcde4562
......@@ -23,6 +23,7 @@ import (
"github.com/jech/galene/packetcache"
"github.com/jech/galene/packetmap"
"github.com/jech/galene/rtptime"
"github.com/jech/galene/unbounded"
)
type bitrate struct {
......@@ -403,7 +404,7 @@ type rtpUpTrack struct {
jitter *jitter.Estimator
cname atomic.Value
actionCh chan struct{}
actions *unbounded.Channel[trackAction]
readerDone chan struct{}
mu sync.Mutex
......@@ -412,7 +413,6 @@ type rtpUpTrack struct {
srRTPTime uint32
local []conn.DownTrack
bufferedNACKs []uint16
actions []trackAction
}
const (
......@@ -427,17 +427,7 @@ type trackAction struct {
}
func (up *rtpUpTrack) action(action int, track conn.DownTrack) {
up.mu.Lock()
empty := len(up.actions) == 0
up.actions = append(up.actions, trackAction{action, track})
up.mu.Unlock()
if empty {
select {
case up.actionCh <- struct{}{}:
default:
}
}
up.actions.Put(trackAction{action, track})
}
func (up *rtpUpTrack) AddLocal(local conn.DownTrack) error {
......@@ -682,7 +672,7 @@ func newUpConn(c group.Client, id string, label string, offer string) (*rtpUpCon
cache: packetcache.New(minPacketCache(remote)),
rate: estimator.New(time.Second),
jitter: jitter.New(remote.Codec().ClockRate),
actionCh: make(chan struct{}, 1),
actions: unbounded.New[trackAction](),
readerDone: make(chan struct{}),
}
......@@ -923,7 +913,7 @@ func maxUpBitrate(t *rtpUpTrack) uint64 {
// assume that lower spatial layers take up 1/5 of
// the throughput
if maxsid > 0 {
maxrate = sadd(maxrate, maxrate / 4)
maxrate = sadd(maxrate, maxrate/4)
}
// assume that each layer takes two times less
// throughput than the higher one. Then we've
......@@ -1003,7 +993,7 @@ func sendUpRTCP(up *rtpUpConnection) error {
}
ssrcs = append(ssrcs, uint32(t.track.SSRC()))
if t.Kind() == webrtc.RTPCodecTypeAudio {
rate = sadd(rate, 100 * 1024)
rate = sadd(rate, 100*1024)
} else if t.Label() == "l" {
rate = sadd(rate, group.LowBitrate)
} else {
......
......@@ -31,11 +31,8 @@ func readLoop(track *rtpUpTrack) {
for {
select {
case <-track.actionCh:
track.mu.Lock()
actions := track.actions
track.actions = nil
track.mu.Unlock()
case <-track.actions.Ch:
actions := track.actions.Get()
for _, action := range actions {
switch action.action {
case trackActionAdd, trackActionDel:
......
......@@ -20,6 +20,7 @@ import (
"github.com/jech/galene/group"
"github.com/jech/galene/ice"
"github.com/jech/galene/token"
"github.com/jech/galene/unbounded"
)
func errorToWSCloseMessage(id string, err error) (*clientMessage, []byte) {
......@@ -65,16 +66,11 @@ type webClient struct {
done chan struct{}
writeCh chan interface{}
writerDone chan struct{}
actionCh chan struct{}
actions *unbounded.Channel[any]
mu sync.Mutex
down map[string]*rtpDownConnection
up map[string]*rtpUpConnection
// action may be called with the group mutex taken, and therefore
// actions needs to use its own mutex.
actionMu sync.Mutex
actions []interface{}
}
func (c *webClient) Group() *group.Group {
......@@ -106,9 +102,10 @@ func (c *webClient) SetPermissions(perms []string) {
}
func (c *webClient) PushClient(group, kind, id string, username string, perms []string, data map[string]interface{}) error {
return c.action(pushClientAction{
c.action(pushClientAction{
group, kind, id, username, perms, data,
})
return nil
}
type clientMessage struct {
......@@ -733,7 +730,8 @@ func (c *webClient) setRequestedStream(down *rtpDownConnection, requested []stri
}
func (c *webClient) RequestConns(target group.Client, g *group.Group, id string) error {
return c.action(requestConnsAction{g, target, id})
c.action(requestConnsAction{g, target, id})
return nil
}
func requestConns(target group.Client, g *group.Group, id string) {
......@@ -804,10 +802,7 @@ func requestedTracks(c *webClient, requested []string, tracks []conn.UpTrack) ([
}
func (c *webClient) PushConn(g *group.Group, id string, up conn.Up, tracks []conn.UpTrack, replace string) error {
err := c.action(pushConnAction{g, id, up, tracks, replace})
if err != nil {
return err
}
c.action(pushConnAction{g, id, up, tracks, replace})
return nil
}
......@@ -854,9 +849,9 @@ func StartClient(conn *websocket.Conn) (err error) {
}
c := &webClient{
id: m.Id,
actionCh: make(chan struct{}, 1),
done: make(chan struct{}),
id: m.Id,
actions: unbounded.New[any](),
done: make(chan struct{}),
}
defer close(c.done)
......@@ -996,11 +991,8 @@ func clientLoop(c *webClient, ws *websocket.Conn, versionError bool) error {
case error:
return m
}
case <-c.actionCh:
c.actionMu.Lock()
actions := c.actions
c.actions = nil
c.actionMu.Unlock()
case <-c.actions.Ch:
actions := c.actions.Get()
for _, a := range actions {
err := handleAction(c, a)
if err != nil {
......@@ -1090,7 +1082,7 @@ func pushDownConn(c *webClient, id string, up conn.Up, tracks []conn.UpTrack, re
return nil
}
func handleAction(c *webClient, a interface{}) error {
func handleAction(c *webClient, a any) error {
switch a := a.(type) {
case pushConnAction:
if c.group == nil || c.group != a.group {
......@@ -1353,15 +1345,18 @@ func setPermissions(g *group.Group, id string, perm string) error {
default:
return group.UserError("unknown permission")
}
return c.action(permissionsChangedAction{})
c.action(permissionsChangedAction{})
return nil
}
func (c *webClient) Kick(id string, user *string, message string) error {
return c.action(kickAction{id, user, message})
c.action(kickAction{id, user, message})
return nil
}
func (c *webClient) Joined(group, kind string) error {
return c.action(joinedAction{group, kind})
c.action(joinedAction{group, kind})
return nil
}
func kickClient(g *group.Group, id string, user *string, dest string, message string) error {
......@@ -2087,22 +2082,8 @@ func (c *webClient) Warn(oponly bool, message string) error {
var ErrClientDead = errors.New("client is dead")
func (c *webClient) action(a interface{}) error {
c.actionMu.Lock()
empty := len(c.actions) == 0
c.actions = append(c.actions, a)
c.actionMu.Unlock()
if empty {
select {
case c.actionCh <- struct{}{}:
return nil
case <-c.done:
return ErrClientDead
default:
}
}
return nil
func (c *webClient) action(a interface{}) {
c.actions.Put(a)
}
func (c *webClient) write(m clientMessage) error {
......
package unbounded
import (
"sync"
)
// Type Channel implements an unbounded channel
type Channel[T any] struct {
// Ch triggers whenever the channel becomes non-empty
Ch chan struct{}
mu sync.Mutex
queue []T
}
// New creates a new unbounded channel
func New[T any]() *Channel[T] {
return &Channel[T]{
Ch: make(chan struct{}, 1),
}
}
// Put inserts a new element into ch.
// If ch was previously empty, it triggers ch.Ch.
func (ch *Channel[T]) Put(v T) {
ch.mu.Lock()
empty := len(ch.queue) == 0
ch.queue = append(ch.queue, v)
ch.mu.Unlock()
if empty {
select {
case ch.Ch <- struct{}{}:
default:
}
}
}
// Get removes all the elements of ch.
// It is usually called when ch.Ch triggers, but may be called at any time.
func (ch *Channel[T]) Get() []T {
ch.mu.Lock()
defer ch.mu.Unlock()
queue := ch.queue
ch.queue = nil
return queue
}
package unbounded
import (
"testing"
"time"
)
func TestUnbounded(t *testing.T) {
ch := New[int]()
go func() {
for i := 0; i < 1000; i++ {
ch.Put(i)
}
}()
n := 0
for n < 1000 {
<-ch.Ch
vs := ch.Get()
for _, v := range vs {
if n != v {
t.Errorf("Expected %v, got %v", n, v)
}
n++
}
}
go func() {
for i := 0; i < 1000; i++ {
ch.Put(i)
time.Sleep(time.Microsecond)
}
}()
n = 0
for n < 1000 {
<-ch.Ch
vs := ch.Get()
for _, v := range vs {
if n != v {
t.Errorf("Expected %v, got %v", n, v)
}
n++
}
}
vs := ch.Get()
if len(vs) != 0 {
t.Errorf("Channel is not empty (%v)", len(vs))
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment