Commit 72fcb566 authored by Mitchell Hashimoto's avatar Mitchell Hashimoto

packer/rpc: better close states

parent af22b35a
......@@ -21,7 +21,7 @@ type MuxConn struct {
curId uint32
rwc io.ReadWriteCloser
streams map[uint32]*Stream
mu sync.Mutex
mu sync.RWMutex
wlock sync.Mutex
}
......@@ -48,8 +48,8 @@ func NewMuxConn(rwc io.ReadWriteCloser) *MuxConn {
// Close closes the underlying io.ReadWriteCloser. This will also close
// all streams that are open.
func (m *MuxConn) Close() error {
m.mu.Lock()
defer m.mu.Unlock()
m.mu.RLock()
defer m.mu.RUnlock()
// Close all the streams
for _, w := range m.streams {
......@@ -94,12 +94,17 @@ func (m *MuxConn) Accept(id uint32) (io.ReadWriteCloser, error) {
switch stream.state {
case streamStateListen:
stream.mu.Unlock()
case streamStateClosed:
// This can happen if it becomes established, some data is sent,
// and it closed all within the time period we wait above.
// This case will be fixed when we have edge-triggered checks.
fallthrough
case streamStateEstablished:
stream.mu.Unlock()
break ACCEPT_ESTABLISH_LOOP
default:
defer stream.mu.Unlock()
return nil, fmt.Errorf("Stream went to bad state: %d", stream.state)
return nil, fmt.Errorf("Stream %d went to bad state: %d", id, stream.state)
}
}
}
......@@ -140,12 +145,17 @@ func (m *MuxConn) Dial(id uint32) (io.ReadWriteCloser, error) {
switch stream.state {
case streamStateSynSent:
stream.mu.Unlock()
case streamStateClosed:
// This can happen if it becomes established, some data is sent,
// and it closed all within the time period we wait above.
// This case will be fixed when we have edge-triggered checks.
fallthrough
case streamStateEstablished:
stream.mu.Unlock()
return stream, nil
default:
defer stream.mu.Unlock()
return nil, fmt.Errorf("Stream went to bad state: %d", stream.state)
return nil, fmt.Errorf("Stream %d went to bad state: %d", id, stream.state)
}
}
}
......@@ -166,9 +176,21 @@ func (m *MuxConn) NextId() uint32 {
}
func (m *MuxConn) openStream(id uint32) (*Stream, error) {
// First grab a read-lock if we have the stream already we can
// cheaply return it.
m.mu.RLock()
if stream, ok := m.streams[id]; ok {
m.mu.RUnlock()
return stream, nil
}
// Now acquire a full blown write lock so we can create the stream
m.mu.RUnlock()
m.mu.Lock()
defer m.mu.Unlock()
// We have to check this again because there is a time period
// above where we couldn't lost this lock.
if stream, ok := m.streams[id]; ok {
return stream, nil
}
......@@ -182,7 +204,6 @@ func (m *MuxConn) openStream(id uint32) (*Stream, error) {
id: id,
mux: m,
reader: dataR,
writer: dataW,
writeCh: writeCh,
}
stream.setState(streamStateClosed)
......@@ -190,8 +211,16 @@ func (m *MuxConn) openStream(id uint32) (*Stream, error) {
// Start the goroutine that will read from the queue and write
// data out.
go func() {
defer dataW.Close()
for {
data := <-writeCh
if data == nil {
// A nil is a tombstone letting us know we're done
// accepting data.
return
}
if _, err := dataW.Write(data); err != nil {
return
}
......@@ -237,12 +266,16 @@ func (m *MuxConn) loop() {
return
}
log.Printf("[DEBUG] Stream %d received packet %d", id, packetType)
switch packetType {
case muxPacketAck:
stream.mu.Lock()
if stream.state == streamStateSynSent {
switch stream.state {
case streamStateSynSent:
stream.setState(streamStateEstablished)
} else {
case streamStateFinWait1:
stream.remoteClose()
default:
log.Printf("[ERR] Ack received for stream in state: %d", stream.state)
}
stream.mu.Unlock()
......@@ -259,13 +292,23 @@ func (m *MuxConn) loop() {
stream.mu.Unlock()
case muxPacketFin:
stream.mu.Lock()
stream.setState(streamStateClosed)
stream.writer.Close()
switch stream.state {
case streamStateEstablished:
m.write(id, muxPacketAck, nil)
fallthrough
case streamStateFinWait1:
stream.remoteClose()
// Remove this stream from being active so that it
// can be re-used
m.mu.Lock()
delete(m.streams, stream.id)
m.mu.Unlock()
default:
log.Printf("[ERR] Fin received for stream %d in state: %d", id, stream.state)
}
stream.mu.Unlock()
m.mu.Lock()
delete(m.streams, stream.id)
m.mu.Unlock()
case muxPacketData:
stream.mu.Lock()
if stream.state == streamStateEstablished {
......@@ -306,7 +349,6 @@ type Stream struct {
id uint32
mux *MuxConn
reader io.Reader
writer io.WriteCloser
state streamState
stateUpdated time.Time
mu sync.Mutex
......@@ -321,23 +363,37 @@ const (
streamStateSynRecv
streamStateSynSent
streamStateEstablished
streamStateFinWait
streamStateFinWait1
)
func (s *Stream) Close() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.state != streamStateEstablished {
s.mu.Unlock()
return fmt.Errorf("Stream in bad state: %d", s.state)
}
if _, err := s.mux.write(s.id, muxPacketFin, nil); err != nil {
return err
}
s.setState(streamStateFinWait1)
s.mu.Unlock()
for {
time.Sleep(50 * time.Millisecond)
s.mu.Lock()
switch s.state {
case streamStateFinWait1:
s.mu.Unlock()
case streamStateClosed:
s.mu.Unlock()
return nil
default:
defer s.mu.Unlock()
return fmt.Errorf("Stream %d went to bad state: %d", s.id, s.state)
}
}
s.setState(streamStateClosed)
s.writer.Close()
return nil
}
......@@ -346,9 +402,22 @@ func (s *Stream) Read(p []byte) (int, error) {
}
func (s *Stream) Write(p []byte) (int, error) {
s.mu.Lock()
state := s.state
s.mu.Unlock()
if state != streamStateEstablished {
return 0, fmt.Errorf("Stream in bad state to send: %d", state)
}
return s.mux.write(s.id, muxPacketData, p)
}
func (s *Stream) remoteClose() {
s.setState(streamStateClosed)
s.writeCh <- nil
}
func (s *Stream) setState(state streamState) {
s.state = state
s.stateUpdated = time.Now().UTC()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment