Commit 263258a0 authored by Juliusz Chroboczek's avatar Juliusz Chroboczek

Implement renegotiation of down streams.

We used to destroy and recreate down streams whenever something changed,
which turned out to be racy.  We now properly implement renegotiation,
as well as atomic replacement of a stream by another one.
parent 368da133
......@@ -17,7 +17,6 @@ type Up interface {
DelLocal(Down) bool
Id() string
User() (string, string)
Codecs() []webrtc.RTPCodecCapability
}
// Type UpTrack represents a track in the client to server direction.
......
......@@ -78,6 +78,7 @@ type downTrackAtomics struct {
type rtpDownTrack struct {
track *webrtc.TrackLocalStaticRTP
sender *webrtc.RTPSender
remote conn.UpTrack
ssrc webrtc.SSRC
maxBitrate *bitrate
......@@ -156,7 +157,7 @@ func (down *rtpDownConnection) getTracks() []*rtpDownTrack {
}
func newDownConn(c group.Client, id string, remote conn.Up) (*rtpDownConnection, error) {
api := group.APIFromCodecs(remote.Codecs())
api := c.Group().API()
pc, err := api.NewPeerConnection(*ice.ICEConfiguration())
if err != nil {
return nil, err
......@@ -366,17 +367,6 @@ func (up *rtpUpConnection) User() (string, string) {
return up.userId, up.username
}
func (up *rtpUpConnection) Codecs() []webrtc.RTPCodecCapability {
up.mu.Lock()
defer up.mu.Unlock()
codecs := make([]webrtc.RTPCodecCapability, len(up.tracks))
for i := range up.tracks {
codecs[i] = up.tracks[i].Codec()
}
return codecs
}
func (up *rtpUpConnection) AddLocal(local conn.Down) error {
up.mu.Lock()
defer up.mu.Unlock()
......
......@@ -315,55 +315,50 @@ func getConn(c *webClient, id string) iceConnection {
return nil
}
func addDownConn(c *webClient, id string, remote conn.Up) (*rtpDownConnection, error) {
conn, err := newDownConn(c, id, remote)
if err != nil {
return nil, err
}
err = addDownConnHelper(c, conn, remote)
if err != nil {
conn.pc.Close()
return nil, err
}
return conn, err
}
func addDownConn(c *webClient, remote conn.Up) (*rtpDownConnection, bool, error) {
id := remote.Id()
func addDownConnHelper(c *webClient, conn *rtpDownConnection, remote conn.Up) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.up != nil && c.up[conn.id] != nil {
return errors.New("Adding duplicate connection")
if c.up != nil && c.up[id] != nil {
return nil, false, errors.New("adding duplicate connection")
}
if c.down == nil {
c.down = make(map[string]*rtpDownConnection)
}
old := c.down[conn.id]
if old != nil {
// Avoid calling Close under a lock
go old.pc.Close()
if down := c.down[id]; down != nil {
return down, false, nil
}
conn.pc.OnICECandidate(func(candidate *webrtc.ICECandidate) {
sendICE(c, conn.id, candidate)
down, err := newDownConn(c, id, remote)
if err != nil {
return nil, false, err
}
down.pc.OnICECandidate(func(candidate *webrtc.ICECandidate) {
sendICE(c, down.id, candidate)
})
conn.pc.OnICEConnectionStateChange(func(state webrtc.ICEConnectionState) {
down.pc.OnICEConnectionStateChange(func(state webrtc.ICEConnectionState) {
if state == webrtc.ICEConnectionStateFailed {
c.action(connectionFailedAction{id: conn.id})
c.action(connectionFailedAction{id: down.id})
}
})
err := remote.AddLocal(conn)
err = remote.AddLocal(down)
if err != nil {
return err
down.pc.Close()
return nil, false, err
}
c.down[conn.id] = conn
return nil
c.down[down.id] = down
go rtcpDownSender(down)
return down, true, nil
}
func delDownConn(c *webClient, id string) error {
......@@ -397,46 +392,40 @@ func delDownConnHelper(c *webClient, id string) *rtpDownConnection {
return conn
}
func addDownTrack(c *webClient, conn *rtpDownConnection, remoteTrack conn.UpTrack, remoteConn conn.Up) (*webrtc.RTPSender, error) {
rt, ok := remoteTrack.(*rtpUpTrack)
if !ok {
return nil, errors.New("unexpected up track type")
}
conn.mu.Lock()
defer conn.mu.Unlock()
var errUnexpectedTrackType = errors.New("unexpected track type, this shouldn't happen")
remoteSSRC := rt.track.SSRC()
func addDownTrackUnlocked(conn *rtpDownConnection, remoteTrack *rtpUpTrack, remoteConn conn.Up) error {
for _, t := range conn.tracks {
tt, ok := t.remote.(*rtpUpTrack)
if !ok {
return nil, errors.New("unexpected up track type")
return errUnexpectedTrackType
}
if tt.track.SSRC() == remoteSSRC {
return nil, os.ErrExist
if tt == remoteTrack {
return os.ErrExist
}
}
local, err := webrtc.NewTrackLocalStaticRTP(
remoteTrack.Codec(),
rt.track.ID(), rt.track.StreamID(),
remoteTrack.track.ID(), remoteTrack.track.StreamID(),
)
if err != nil {
return nil, err
return err
}
sender, err := conn.pc.AddTrack(local)
if err != nil {
return nil, err
return err
}
parms := sender.GetParameters()
if len(parms.Encodings) != 1 {
return nil, errors.New("got multiple encodings")
return errors.New("got multiple encodings")
}
track := &rtpDownTrack{
track: local,
sender: sender,
ssrc: parms.Encodings[0].SSRC,
remote: remoteTrack,
maxBitrate: new(bitrate),
......@@ -449,7 +438,79 @@ func addDownTrack(c *webClient, conn *rtpDownConnection, remoteTrack conn.UpTrac
go rtcpDownListener(conn, track, sender)
return sender, nil
return nil
}
func delDownTrackUnlocked(conn *rtpDownConnection, track *rtpDownTrack) error {
for i := range conn.tracks {
if conn.tracks[i] == track {
track.remote.DelLocal(track)
conn.tracks =
append(conn.tracks[:i], conn.tracks[i+1:]...)
return conn.pc.RemoveTrack(track.sender)
}
}
return os.ErrNotExist
}
func replaceTracks(conn *rtpDownConnection, remote []conn.UpTrack, remoteConn conn.Up) error {
conn.mu.Lock()
defer conn.mu.Unlock()
var add []*rtpUpTrack
var del []*rtpDownTrack
outer:
for _, rtrack := range remote {
rt, ok := rtrack.(*rtpUpTrack)
if !ok {
return errUnexpectedTrackType
}
for _, track := range conn.tracks {
rt2, ok := track.remote.(*rtpUpTrack)
if !ok {
return errUnexpectedTrackType
}
if rt == rt2 {
continue outer
}
}
add = append(add, rt)
}
outer2:
for _, track := range conn.tracks {
rt, ok := track.remote.(*rtpUpTrack)
if !ok {
return errUnexpectedTrackType
}
for _, rtrack := range remote {
rt2, ok := rtrack.(*rtpUpTrack)
if !ok {
return errUnexpectedTrackType
}
if rt == rt2 {
continue outer2
}
}
del = append(del, track)
}
for _, t := range del {
err := delDownTrackUnlocked(conn, t)
if err != nil {
return err
}
}
for _, rt := range add {
err := addDownTrackUnlocked(conn, rt, remoteConn)
if err != nil {
return err
}
}
return nil
}
func negotiate(c *webClient, down *rtpDownConnection, restartIce bool, replace string) error {
......@@ -618,20 +679,6 @@ func gotICE(c *webClient, candidate *webrtc.ICECandidateInit, id string) error {
}
func (c *webClient) setRequested(requested map[string]uint32) error {
if c.down != nil {
down := make([]string, 0, len(c.down))
for id := range c.down {
down = append(down, id)
}
for _, id := range down {
c.write(clientMessage{
Type: "close",
Id: id,
})
delDownConn(c, id)
}
}
c.requested = requested
go pushConns(c, c.group)
......@@ -652,40 +699,6 @@ func (c *webClient) isRequested(label string) bool {
return c.requested[label] != 0
}
func addDownConnTracks(c *webClient, remote conn.Up, tracks []conn.UpTrack) (*rtpDownConnection, error) {
requested := false
for _, t := range tracks {
if c.isRequested(t.Label()) {
requested = true
break
}
}
if !requested {
delDownConn(c, remote.Id())
return nil, nil
}
down, err := addDownConn(c, remote.Id(), remote)
if err != nil {
return nil, err
}
for _, t := range tracks {
if !c.isRequested(t.Label()) {
continue
}
_, err = addDownTrack(c, down, t, remote)
if err != nil {
delDownConn(c, down.id)
return nil, err
}
}
go rtcpDownSender(down)
return down, nil
}
func (c *webClient) PushConn(g *group.Group, id string, up conn.Up, tracks []conn.UpTrack, replace string) error {
err := c.action(pushConnAction{g, id, up, tracks, replace})
if err != nil {
......@@ -823,45 +836,44 @@ func clientLoop(c *webClient, ws *websocket.Conn) error {
}
if a.conn == nil {
if a.replace != "" {
err := delDownConn(
c, a.replace,
closeDownConnection(
c, a.replace, "",
)
if err == nil {
c.write(clientMessage{
Type: "close",
Id: a.replace,
})
}
}
err := delDownConn(c, a.id)
if err == nil {
c.write(clientMessage{
Type: "close",
Id: a.id,
})
closeDownConnection(c, a.id, "")
continue
}
tracks := make([]conn.UpTrack, 0, len(a.tracks))
for _, t := range a.tracks {
if c.isRequested(t.Label()) {
tracks = append(tracks, t)
}
}
if len(tracks) == 0 {
closeDownConnection(c, a.id, "")
continue
}
down, err := addDownConnTracks(
c, a.conn, a.tracks,
)
down, _, err := addDownConn(c, a.conn)
if err != nil {
return err
}
if down != nil {
err = negotiate(
c, down, false, a.replace,
)
if err != nil {
log.Printf(
"Negotiation failed: %v",
err)
delDownConn(c, down.id)
c.error(group.UserError(
"Negotiation failed",
))
continue
}
err = replaceTracks(down, tracks, a.conn)
if err != nil {
return err
}
err = negotiate(
c, down, false, a.replace,
)
if err != nil {
log.Printf(
"Negotiation failed: %v",
err)
delDownConn(c, down.id)
c.error(group.UserError(
"Negotiation failed",
))
continue
}
case pushConnsAction:
g := c.group
......@@ -1011,7 +1023,11 @@ func leaveGroup(c *webClient) {
c.group = nil
}
func failDownConnection(c *webClient, id string, message string) error {
func closeDownConnection(c *webClient, id string, message string) error {
err := delDownConn(c, id)
if err != nil && !os.IsNotExist(err) {
log.Printf("Close down connection: %v", err)
}
if id != "" {
err := c.write(clientMessage{
Type: "close",
......@@ -1207,7 +1223,7 @@ func handleClientMessage(c *webClient, m clientMessage) error {
if err != ErrUnknownId {
message = "negotiation failed"
}
return failDownConnection(c, m.Id, message)
return closeDownConnection(c, m.Id, message)
}
down := getDownConn(c, m.Id)
if down.negotiationNeeded > negotiationUnneeded {
......@@ -1217,7 +1233,7 @@ func handleClientMessage(c *webClient, m clientMessage) error {
"",
)
if err != nil {
return failDownConnection(
return closeDownConnection(
c, m.Id, "negotiation failed",
)
}
......@@ -1227,7 +1243,7 @@ func handleClientMessage(c *webClient, m clientMessage) error {
if down != nil {
err := negotiate(c, down, true, "")
if err != nil {
return failDownConnection(
return closeDownConnection(
c, m.Id, "renegotiation failed",
)
}
......@@ -1242,14 +1258,7 @@ func handleClientMessage(c *webClient, m clientMessage) error {
return nil
}
case "abort":
err := delDownConn(c, m.Id)
if err != nil {
log.Printf("Abort: %v", err)
}
c.write(clientMessage{
Type: "close",
Id: m.Id,
})
return closeDownConnection(c, m.Id, "")
case "ice":
if m.Candidate == nil {
return group.ProtocolError("null candidate")
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment