Commit bec978fd authored by Mitchell Hashimoto's avatar Mitchell Hashimoto

packer/rpc: Clean up old streams [GH-708]

parent b1f07dcb
......@@ -27,6 +27,7 @@ type MuxConn struct {
streams map[uint32]*Stream
mu sync.RWMutex
wlock sync.Mutex
doneCh chan struct{}
}
type muxPacketType byte
......@@ -44,8 +45,10 @@ func NewMuxConn(rwc io.ReadWriteCloser) *MuxConn {
m := &MuxConn{
rwc: rwc,
streams: make(map[uint32]*Stream),
doneCh: make(chan struct{}),
}
go m.cleaner()
go m.loop()
return m
......@@ -211,18 +214,45 @@ func (m *MuxConn) openStream(id uint32) (*Stream, error) {
return m.streams[id], nil
}
func (m *MuxConn) cleaner() {
for {
done := false
select {
case <-time.After(500 * time.Millisecond):
case <-m.doneCh:
done = true
}
m.mu.Lock()
for id, s := range m.streams {
s.mu.Lock()
if s.state == streamStateClosed {
delete(m.streams, id)
}
s.mu.Unlock()
}
if done {
for _, s := range m.streams {
s.mu.Lock()
s.closeWriter()
s.mu.Unlock()
}
}
m.mu.Unlock()
if done {
return
}
}
}
func (m *MuxConn) loop() {
// Force close every stream that we know about when we exit so
// that they all read EOF and don't block forever.
defer func() {
log.Printf("[INFO] Mux connection loop exiting")
m.mu.Lock()
defer m.mu.Unlock()
for _, w := range m.streams {
w.mu.Lock()
w.remoteClose()
w.mu.Unlock()
}
close(m.doneCh)
}()
var id uint32
......@@ -277,6 +307,11 @@ func (m *MuxConn) loop() {
stream.setState(streamStateEstablished)
case streamStateFinWait1:
stream.setState(streamStateFinWait2)
case streamStateLastAck:
stream.closeWriter()
fallthrough
case streamStateClosing:
stream.setState(streamStateClosed)
default:
log.Printf("[ERR] Ack received for stream in state: %d", stream.state)
}
......@@ -294,20 +329,17 @@ func (m *MuxConn) loop() {
stream.mu.Lock()
switch stream.state {
case streamStateEstablished:
stream.closeWriter()
stream.setState(streamStateCloseWait)
m.write(id, muxPacketAck, nil)
// Close the writer on our end since we won't receive any
// more data.
stream.writeCh <- nil
case streamStateFinWait1:
fallthrough
case streamStateFinWait2:
stream.remoteClose()
m.mu.Lock()
delete(m.streams, stream.id)
m.mu.Unlock()
stream.closeWriter()
stream.setState(streamStateClosed)
m.write(id, muxPacketAck, nil)
case streamStateFinWait1:
stream.closeWriter()
stream.setState(streamStateClosing)
m.write(id, muxPacketAck, nil)
default:
log.Printf("[ERR] Fin received for stream %d in state: %d", id, stream.state)
}
......@@ -377,6 +409,8 @@ const (
streamStateFinWait1
streamStateFinWait2
streamStateCloseWait
streamStateClosing
streamStateLastAck
)
func (s *Stream) Close() error {
......@@ -390,7 +424,7 @@ func (s *Stream) Close() error {
if s.state == streamStateEstablished {
s.setState(streamStateFinWait1)
} else {
s.remoteClose()
s.setState(streamStateLastAck)
}
s.mux.write(s.id, muxPacketFin, nil)
......@@ -413,8 +447,7 @@ func (s *Stream) Write(p []byte) (int, error) {
return s.mux.write(s.id, muxPacketData, p)
}
func (s *Stream) remoteClose() {
s.setState(streamStateClosed)
func (s *Stream) closeWriter() {
s.writeCh <- nil
}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment