Commit 0e7bf0b3 authored by Juliusz Chroboczek's avatar Juliusz Chroboczek

Carry group around when pushing connections.

This avoids a race condition if the group changes before the connections
are pushed.
parent b134bfcf
...@@ -87,7 +87,11 @@ func (client *Client) Kick(id, user, message string) error { ...@@ -87,7 +87,11 @@ func (client *Client) Kick(id, user, message string) error {
return err return err
} }
func (client *Client) PushConn(id string, up conn.Up, tracks []conn.UpTrack, label string) error { func (client *Client) PushConn(g *group.Group, id string, up conn.Up, tracks []conn.UpTrack, label string) error {
if client.group != g {
return nil
}
client.mu.Lock() client.mu.Lock()
defer client.mu.Unlock() defer client.mu.Unlock()
......
...@@ -97,7 +97,7 @@ type Client interface { ...@@ -97,7 +97,7 @@ type Client interface {
Challengeable Challengeable
SetPermissions(ClientPermissions) SetPermissions(ClientPermissions)
OverridePermissions(*Group) bool OverridePermissions(*Group) bool
PushConn(id string, conn conn.Up, tracks []conn.UpTrack, label string) error PushConn(g *Group, id string, conn conn.Up, tracks []conn.UpTrack, label string) error
PushClient(id, username string, add bool) error PushClient(id, username string, add bool) error
} }
......
...@@ -450,7 +450,7 @@ func newUpConn(c group.Client, id string) (*rtpUpConnection, error) { ...@@ -450,7 +450,7 @@ func newUpConn(c group.Client, id string) (*rtpUpConnection, error) {
if complete { if complete {
clients := c.Group().GetClients(c) clients := c.Group().GetClients(c)
for _, cc := range clients { for _, cc := range clients {
cc.PushConn(up.id, up, tracks, up.label) cc.PushConn(c.Group(), up.id, up, tracks, up.label)
} }
go rtcpUpSender(up) go rtcpUpSender(up)
} }
......
...@@ -258,7 +258,7 @@ func delUpConn(c *webClient, id string) bool { ...@@ -258,7 +258,7 @@ func delUpConn(c *webClient, id string) bool {
if g != nil { if g != nil {
go func(clients []group.Client) { go func(clients []group.Client) {
for _, c := range clients { for _, c := range clients {
err := c.PushConn(conn.id, nil, nil, "") err := c.PushConn(g, conn.id, nil, nil, "")
if err != nil { if err != nil {
log.Printf("PushConn: %v", err) log.Printf("PushConn: %v", err)
} }
...@@ -582,21 +582,16 @@ func (c *webClient) setRequested(requested map[string]uint32) error { ...@@ -582,21 +582,16 @@ func (c *webClient) setRequested(requested map[string]uint32) error {
c.requested = requested c.requested = requested
go pushConns(c) go pushConns(c, c.group)
return nil return nil
} }
func pushConns(c group.Client) { func pushConns(c group.Client, g *group.Group) {
group := c.Group() clients := g.GetClients(c)
if group == nil {
log.Printf("Pushing connections to unjoined client")
return
}
clients := group.GetClients(c)
for _, cc := range clients { for _, cc := range clients {
ccc, ok := cc.(*webClient) ccc, ok := cc.(*webClient)
if ok { if ok {
ccc.action(pushConnsAction{c}) ccc.action(pushConnsAction{g, c})
} }
} }
} }
...@@ -638,8 +633,8 @@ func addDownConnTracks(c *webClient, remote conn.Up, tracks []conn.UpTrack) (*rt ...@@ -638,8 +633,8 @@ func addDownConnTracks(c *webClient, remote conn.Up, tracks []conn.UpTrack) (*rt
return down, nil return down, nil
} }
func (c *webClient) PushConn(id string, up conn.Up, tracks []conn.UpTrack, label string) error { func (c *webClient) PushConn(g *group.Group, id string, up conn.Up, tracks []conn.UpTrack, label string) error {
err := c.action(pushConnAction{id, up, tracks}) err := c.action(pushConnAction{g, id, up, tracks})
if err != nil { if err != nil {
return err return err
} }
...@@ -709,6 +704,7 @@ func StartClient(conn *websocket.Conn) error { ...@@ -709,6 +704,7 @@ func StartClient(conn *websocket.Conn) error {
} }
type pushConnAction struct { type pushConnAction struct {
group *group.Group
id string id string
conn conn.Up conn conn.Up
tracks []conn.UpTrack tracks []conn.UpTrack
...@@ -720,7 +716,8 @@ type addLabelAction struct { ...@@ -720,7 +716,8 @@ type addLabelAction struct {
} }
type pushConnsAction struct { type pushConnsAction struct {
c group.Client group *group.Group
client group.Client
} }
type connectionFailedAction struct { type connectionFailedAction struct {
...@@ -736,24 +733,10 @@ type kickAction struct { ...@@ -736,24 +733,10 @@ type kickAction struct {
} }
func clientLoop(c *webClient, ws *websocket.Conn) error { func clientLoop(c *webClient, ws *websocket.Conn) error {
defer func() {
if c.group != nil {
group.DelClient(c)
c.group = nil
}
}()
read := make(chan interface{}, 1) read := make(chan interface{}, 1)
go clientReader(ws, read, c.done) go clientReader(ws, read, c.done)
defer func() { defer leaveGroup(c)
c.setRequested(map[string]uint32{})
if c.up != nil {
for id := range c.up {
delUpConn(c, id)
}
}
}()
readTime := time.Now() readTime := time.Now()
...@@ -779,6 +762,10 @@ func clientLoop(c *webClient, ws *websocket.Conn) error { ...@@ -779,6 +762,10 @@ func clientLoop(c *webClient, ws *websocket.Conn) error {
case a := <-c.actionCh: case a := <-c.actionCh:
switch a := a.(type) { switch a := a.(type) {
case pushConnAction: case pushConnAction:
g := c.group
if g == nil || a.group != g {
return nil
}
if a.conn == nil { if a.conn == nil {
found := delDownConn(c, a.id) found := delDownConn(c, a.id)
if found { if found {
...@@ -821,6 +808,10 @@ func clientLoop(c *webClient, ws *websocket.Conn) error { ...@@ -821,6 +808,10 @@ func clientLoop(c *webClient, ws *websocket.Conn) error {
Value: &label, Value: &label,
}) })
case pushConnsAction: case pushConnsAction:
g := c.group
if g == nil || a.group != g {
return nil
}
for _, u := range c.up { for _, u := range c.up {
if !u.complete() { if !u.complete() {
continue continue
...@@ -831,8 +822,8 @@ func clientLoop(c *webClient, ws *websocket.Conn) error { ...@@ -831,8 +822,8 @@ func clientLoop(c *webClient, ws *websocket.Conn) error {
ts[i] = t ts[i] = t
} }
go func() { go func() {
err := a.c.PushConn( err := a.client.PushConn(
u.id, u, ts, u.label, g, u.id, u, ts, u.label,
) )
if err != nil { if err != nil {
log.Printf( log.Printf(
...@@ -855,6 +846,7 @@ func clientLoop(c *webClient, ws *websocket.Conn) error { ...@@ -855,6 +846,7 @@ func clientLoop(c *webClient, ws *websocket.Conn) error {
tracks[i] = t.remote tracks[i] = t.remote
} }
go c.PushConn( go c.PushConn(
c.group,
down.remote.Id(), down.remote, down.remote.Id(), down.remote,
tracks, down.remote.Label(), tracks, down.remote.Label(),
) )
...@@ -935,6 +927,24 @@ func failUpConnection(c *webClient, id string, message string) error { ...@@ -935,6 +927,24 @@ func failUpConnection(c *webClient, id string, message string) error {
return nil return nil
} }
func leaveGroup(c *webClient) {
if c.group == nil {
return
}
c.setRequested(map[string]uint32{})
if c.up != nil {
for id := range c.up {
delUpConn(c, id)
}
}
group.DelClient(c)
c.permissions = group.ClientPermissions{}
c.group = nil
}
func failDownConnection(c *webClient, id string, message string) error { func failDownConnection(c *webClient, id string, message string) error {
if id != "" { if id != "" {
err := c.write(clientMessage{ err := c.write(clientMessage{
...@@ -1009,8 +1019,7 @@ func handleClientMessage(c *webClient, m clientMessage) error { ...@@ -1009,8 +1019,7 @@ func handleClientMessage(c *webClient, m clientMessage) error {
if c.group == nil || c.group.Name() != m.Group { if c.group == nil || c.group.Name() != m.Group {
return group.ProtocolError("you are not joined") return group.ProtocolError("you are not joined")
} }
c.group = nil leaveGroup(c)
c.permissions = group.ClientPermissions{}
perms := c.permissions perms := c.permissions
return c.write(clientMessage{ return c.write(clientMessage{
Type: "joined", Type: "joined",
...@@ -1245,7 +1254,7 @@ func handleClientMessage(c *webClient, m clientMessage) error { ...@@ -1245,7 +1254,7 @@ func handleClientMessage(c *webClient, m clientMessage) error {
disk.Close() disk.Close()
return c.error(err) return c.error(err)
} }
go pushConns(disk) go pushConns(disk, c.group)
case "unrecord": case "unrecord":
if !c.permissions.Record { if !c.permissions.Record {
return c.error(group.UserError("not authorised")) return c.error(group.UserError("not authorised"))
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment