Commit 263258a0 authored by Juliusz Chroboczek's avatar Juliusz Chroboczek

Implement renegotiation of down streams.

We used to destroy and recreate down streams whenever something changed,
which turned out to be racy.  We now properly implement renegotiation,
as well as atomic replacement of a stream by another one.
parent 368da133
...@@ -17,7 +17,6 @@ type Up interface { ...@@ -17,7 +17,6 @@ type Up interface {
DelLocal(Down) bool DelLocal(Down) bool
Id() string Id() string
User() (string, string) User() (string, string)
Codecs() []webrtc.RTPCodecCapability
} }
// Type UpTrack represents a track in the client to server direction. // Type UpTrack represents a track in the client to server direction.
......
...@@ -78,6 +78,7 @@ type downTrackAtomics struct { ...@@ -78,6 +78,7 @@ type downTrackAtomics struct {
type rtpDownTrack struct { type rtpDownTrack struct {
track *webrtc.TrackLocalStaticRTP track *webrtc.TrackLocalStaticRTP
sender *webrtc.RTPSender
remote conn.UpTrack remote conn.UpTrack
ssrc webrtc.SSRC ssrc webrtc.SSRC
maxBitrate *bitrate maxBitrate *bitrate
...@@ -156,7 +157,7 @@ func (down *rtpDownConnection) getTracks() []*rtpDownTrack { ...@@ -156,7 +157,7 @@ func (down *rtpDownConnection) getTracks() []*rtpDownTrack {
} }
func newDownConn(c group.Client, id string, remote conn.Up) (*rtpDownConnection, error) { func newDownConn(c group.Client, id string, remote conn.Up) (*rtpDownConnection, error) {
api := group.APIFromCodecs(remote.Codecs()) api := c.Group().API()
pc, err := api.NewPeerConnection(*ice.ICEConfiguration()) pc, err := api.NewPeerConnection(*ice.ICEConfiguration())
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -366,17 +367,6 @@ func (up *rtpUpConnection) User() (string, string) { ...@@ -366,17 +367,6 @@ func (up *rtpUpConnection) User() (string, string) {
return up.userId, up.username return up.userId, up.username
} }
func (up *rtpUpConnection) Codecs() []webrtc.RTPCodecCapability {
up.mu.Lock()
defer up.mu.Unlock()
codecs := make([]webrtc.RTPCodecCapability, len(up.tracks))
for i := range up.tracks {
codecs[i] = up.tracks[i].Codec()
}
return codecs
}
func (up *rtpUpConnection) AddLocal(local conn.Down) error { func (up *rtpUpConnection) AddLocal(local conn.Down) error {
up.mu.Lock() up.mu.Lock()
defer up.mu.Unlock() defer up.mu.Unlock()
......
...@@ -315,55 +315,50 @@ func getConn(c *webClient, id string) iceConnection { ...@@ -315,55 +315,50 @@ func getConn(c *webClient, id string) iceConnection {
return nil return nil
} }
func addDownConn(c *webClient, id string, remote conn.Up) (*rtpDownConnection, error) { func addDownConn(c *webClient, remote conn.Up) (*rtpDownConnection, bool, error) {
conn, err := newDownConn(c, id, remote) id := remote.Id()
if err != nil {
return nil, err
}
err = addDownConnHelper(c, conn, remote)
if err != nil {
conn.pc.Close()
return nil, err
}
return conn, err
}
func addDownConnHelper(c *webClient, conn *rtpDownConnection, remote conn.Up) error {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
if c.up != nil && c.up[conn.id] != nil { if c.up != nil && c.up[id] != nil {
return errors.New("Adding duplicate connection") return nil, false, errors.New("adding duplicate connection")
} }
if c.down == nil { if c.down == nil {
c.down = make(map[string]*rtpDownConnection) c.down = make(map[string]*rtpDownConnection)
} }
old := c.down[conn.id] if down := c.down[id]; down != nil {
if old != nil { return down, false, nil
// Avoid calling Close under a lock
go old.pc.Close()
} }
conn.pc.OnICECandidate(func(candidate *webrtc.ICECandidate) { down, err := newDownConn(c, id, remote)
sendICE(c, conn.id, candidate) if err != nil {
return nil, false, err
}
down.pc.OnICECandidate(func(candidate *webrtc.ICECandidate) {
sendICE(c, down.id, candidate)
}) })
conn.pc.OnICEConnectionStateChange(func(state webrtc.ICEConnectionState) { down.pc.OnICEConnectionStateChange(func(state webrtc.ICEConnectionState) {
if state == webrtc.ICEConnectionStateFailed { if state == webrtc.ICEConnectionStateFailed {
c.action(connectionFailedAction{id: conn.id}) c.action(connectionFailedAction{id: down.id})
} }
}) })
err := remote.AddLocal(conn) err = remote.AddLocal(down)
if err != nil { if err != nil {
return err down.pc.Close()
return nil, false, err
} }
c.down[conn.id] = conn c.down[down.id] = down
return nil
go rtcpDownSender(down)
return down, true, nil
} }
func delDownConn(c *webClient, id string) error { func delDownConn(c *webClient, id string) error {
...@@ -397,46 +392,40 @@ func delDownConnHelper(c *webClient, id string) *rtpDownConnection { ...@@ -397,46 +392,40 @@ func delDownConnHelper(c *webClient, id string) *rtpDownConnection {
return conn return conn
} }
func addDownTrack(c *webClient, conn *rtpDownConnection, remoteTrack conn.UpTrack, remoteConn conn.Up) (*webrtc.RTPSender, error) { var errUnexpectedTrackType = errors.New("unexpected track type, this shouldn't happen")
rt, ok := remoteTrack.(*rtpUpTrack)
if !ok {
return nil, errors.New("unexpected up track type")
}
conn.mu.Lock()
defer conn.mu.Unlock()
remoteSSRC := rt.track.SSRC() func addDownTrackUnlocked(conn *rtpDownConnection, remoteTrack *rtpUpTrack, remoteConn conn.Up) error {
for _, t := range conn.tracks { for _, t := range conn.tracks {
tt, ok := t.remote.(*rtpUpTrack) tt, ok := t.remote.(*rtpUpTrack)
if !ok { if !ok {
return nil, errors.New("unexpected up track type") return errUnexpectedTrackType
} }
if tt.track.SSRC() == remoteSSRC { if tt == remoteTrack {
return nil, os.ErrExist return os.ErrExist
} }
} }
local, err := webrtc.NewTrackLocalStaticRTP( local, err := webrtc.NewTrackLocalStaticRTP(
remoteTrack.Codec(), remoteTrack.Codec(),
rt.track.ID(), rt.track.StreamID(), remoteTrack.track.ID(), remoteTrack.track.StreamID(),
) )
if err != nil { if err != nil {
return nil, err return err
} }
sender, err := conn.pc.AddTrack(local) sender, err := conn.pc.AddTrack(local)
if err != nil { if err != nil {
return nil, err return err
} }
parms := sender.GetParameters() parms := sender.GetParameters()
if len(parms.Encodings) != 1 { if len(parms.Encodings) != 1 {
return nil, errors.New("got multiple encodings") return errors.New("got multiple encodings")
} }
track := &rtpDownTrack{ track := &rtpDownTrack{
track: local, track: local,
sender: sender,
ssrc: parms.Encodings[0].SSRC, ssrc: parms.Encodings[0].SSRC,
remote: remoteTrack, remote: remoteTrack,
maxBitrate: new(bitrate), maxBitrate: new(bitrate),
...@@ -449,7 +438,79 @@ func addDownTrack(c *webClient, conn *rtpDownConnection, remoteTrack conn.UpTrac ...@@ -449,7 +438,79 @@ func addDownTrack(c *webClient, conn *rtpDownConnection, remoteTrack conn.UpTrac
go rtcpDownListener(conn, track, sender) go rtcpDownListener(conn, track, sender)
return sender, nil return nil
}
func delDownTrackUnlocked(conn *rtpDownConnection, track *rtpDownTrack) error {
for i := range conn.tracks {
if conn.tracks[i] == track {
track.remote.DelLocal(track)
conn.tracks =
append(conn.tracks[:i], conn.tracks[i+1:]...)
return conn.pc.RemoveTrack(track.sender)
}
}
return os.ErrNotExist
}
func replaceTracks(conn *rtpDownConnection, remote []conn.UpTrack, remoteConn conn.Up) error {
conn.mu.Lock()
defer conn.mu.Unlock()
var add []*rtpUpTrack
var del []*rtpDownTrack
outer:
for _, rtrack := range remote {
rt, ok := rtrack.(*rtpUpTrack)
if !ok {
return errUnexpectedTrackType
}
for _, track := range conn.tracks {
rt2, ok := track.remote.(*rtpUpTrack)
if !ok {
return errUnexpectedTrackType
}
if rt == rt2 {
continue outer
}
}
add = append(add, rt)
}
outer2:
for _, track := range conn.tracks {
rt, ok := track.remote.(*rtpUpTrack)
if !ok {
return errUnexpectedTrackType
}
for _, rtrack := range remote {
rt2, ok := rtrack.(*rtpUpTrack)
if !ok {
return errUnexpectedTrackType
}
if rt == rt2 {
continue outer2
}
}
del = append(del, track)
}
for _, t := range del {
err := delDownTrackUnlocked(conn, t)
if err != nil {
return err
}
}
for _, rt := range add {
err := addDownTrackUnlocked(conn, rt, remoteConn)
if err != nil {
return err
}
}
return nil
} }
func negotiate(c *webClient, down *rtpDownConnection, restartIce bool, replace string) error { func negotiate(c *webClient, down *rtpDownConnection, restartIce bool, replace string) error {
...@@ -618,20 +679,6 @@ func gotICE(c *webClient, candidate *webrtc.ICECandidateInit, id string) error { ...@@ -618,20 +679,6 @@ func gotICE(c *webClient, candidate *webrtc.ICECandidateInit, id string) error {
} }
func (c *webClient) setRequested(requested map[string]uint32) error { func (c *webClient) setRequested(requested map[string]uint32) error {
if c.down != nil {
down := make([]string, 0, len(c.down))
for id := range c.down {
down = append(down, id)
}
for _, id := range down {
c.write(clientMessage{
Type: "close",
Id: id,
})
delDownConn(c, id)
}
}
c.requested = requested c.requested = requested
go pushConns(c, c.group) go pushConns(c, c.group)
...@@ -652,40 +699,6 @@ func (c *webClient) isRequested(label string) bool { ...@@ -652,40 +699,6 @@ func (c *webClient) isRequested(label string) bool {
return c.requested[label] != 0 return c.requested[label] != 0
} }
func addDownConnTracks(c *webClient, remote conn.Up, tracks []conn.UpTrack) (*rtpDownConnection, error) {
requested := false
for _, t := range tracks {
if c.isRequested(t.Label()) {
requested = true
break
}
}
if !requested {
delDownConn(c, remote.Id())
return nil, nil
}
down, err := addDownConn(c, remote.Id(), remote)
if err != nil {
return nil, err
}
for _, t := range tracks {
if !c.isRequested(t.Label()) {
continue
}
_, err = addDownTrack(c, down, t, remote)
if err != nil {
delDownConn(c, down.id)
return nil, err
}
}
go rtcpDownSender(down)
return down, nil
}
func (c *webClient) PushConn(g *group.Group, id string, up conn.Up, tracks []conn.UpTrack, replace string) error { func (c *webClient) PushConn(g *group.Group, id string, up conn.Up, tracks []conn.UpTrack, replace string) error {
err := c.action(pushConnAction{g, id, up, tracks, replace}) err := c.action(pushConnAction{g, id, up, tracks, replace})
if err != nil { if err != nil {
...@@ -823,45 +836,44 @@ func clientLoop(c *webClient, ws *websocket.Conn) error { ...@@ -823,45 +836,44 @@ func clientLoop(c *webClient, ws *websocket.Conn) error {
} }
if a.conn == nil { if a.conn == nil {
if a.replace != "" { if a.replace != "" {
err := delDownConn( closeDownConnection(
c, a.replace, c, a.replace, "",
) )
if err == nil {
c.write(clientMessage{
Type: "close",
Id: a.replace,
})
}
} }
err := delDownConn(c, a.id) closeDownConnection(c, a.id, "")
if err == nil { continue
c.write(clientMessage{ }
Type: "close", tracks := make([]conn.UpTrack, 0, len(a.tracks))
Id: a.id, for _, t := range a.tracks {
}) if c.isRequested(t.Label()) {
tracks = append(tracks, t)
} }
}
if len(tracks) == 0 {
closeDownConnection(c, a.id, "")
continue continue
} }
down, err := addDownConnTracks(
c, a.conn, a.tracks, down, _, err := addDownConn(c, a.conn)
)
if err != nil { if err != nil {
return err return err
} }
if down != nil { err = replaceTracks(down, tracks, a.conn)
err = negotiate( if err != nil {
c, down, false, a.replace, return err
) }
if err != nil { err = negotiate(
log.Printf( c, down, false, a.replace,
"Negotiation failed: %v", )
err) if err != nil {
delDownConn(c, down.id) log.Printf(
c.error(group.UserError( "Negotiation failed: %v",
"Negotiation failed", err)
)) delDownConn(c, down.id)
continue c.error(group.UserError(
} "Negotiation failed",
))
continue
} }
case pushConnsAction: case pushConnsAction:
g := c.group g := c.group
...@@ -1011,7 +1023,11 @@ func leaveGroup(c *webClient) { ...@@ -1011,7 +1023,11 @@ func leaveGroup(c *webClient) {
c.group = nil c.group = nil
} }
func failDownConnection(c *webClient, id string, message string) error { func closeDownConnection(c *webClient, id string, message string) error {
err := delDownConn(c, id)
if err != nil && !os.IsNotExist(err) {
log.Printf("Close down connection: %v", err)
}
if id != "" { if id != "" {
err := c.write(clientMessage{ err := c.write(clientMessage{
Type: "close", Type: "close",
...@@ -1207,7 +1223,7 @@ func handleClientMessage(c *webClient, m clientMessage) error { ...@@ -1207,7 +1223,7 @@ func handleClientMessage(c *webClient, m clientMessage) error {
if err != ErrUnknownId { if err != ErrUnknownId {
message = "negotiation failed" message = "negotiation failed"
} }
return failDownConnection(c, m.Id, message) return closeDownConnection(c, m.Id, message)
} }
down := getDownConn(c, m.Id) down := getDownConn(c, m.Id)
if down.negotiationNeeded > negotiationUnneeded { if down.negotiationNeeded > negotiationUnneeded {
...@@ -1217,7 +1233,7 @@ func handleClientMessage(c *webClient, m clientMessage) error { ...@@ -1217,7 +1233,7 @@ func handleClientMessage(c *webClient, m clientMessage) error {
"", "",
) )
if err != nil { if err != nil {
return failDownConnection( return closeDownConnection(
c, m.Id, "negotiation failed", c, m.Id, "negotiation failed",
) )
} }
...@@ -1227,7 +1243,7 @@ func handleClientMessage(c *webClient, m clientMessage) error { ...@@ -1227,7 +1243,7 @@ func handleClientMessage(c *webClient, m clientMessage) error {
if down != nil { if down != nil {
err := negotiate(c, down, true, "") err := negotiate(c, down, true, "")
if err != nil { if err != nil {
return failDownConnection( return closeDownConnection(
c, m.Id, "renegotiation failed", c, m.Id, "renegotiation failed",
) )
} }
...@@ -1242,14 +1258,7 @@ func handleClientMessage(c *webClient, m clientMessage) error { ...@@ -1242,14 +1258,7 @@ func handleClientMessage(c *webClient, m clientMessage) error {
return nil return nil
} }
case "abort": case "abort":
err := delDownConn(c, m.Id) return closeDownConnection(c, m.Id, "")
if err != nil {
log.Printf("Abort: %v", err)
}
c.write(clientMessage{
Type: "close",
Id: m.Id,
})
case "ice": case "ice":
if m.Candidate == nil { if m.Candidate == nil {
return group.ProtocolError("null candidate") return group.ProtocolError("null candidate")
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment