Commit 72fcb566 authored by Mitchell Hashimoto's avatar Mitchell Hashimoto

packer/rpc: better close states

parent af22b35a
...@@ -21,7 +21,7 @@ type MuxConn struct { ...@@ -21,7 +21,7 @@ type MuxConn struct {
curId uint32 curId uint32
rwc io.ReadWriteCloser rwc io.ReadWriteCloser
streams map[uint32]*Stream streams map[uint32]*Stream
mu sync.Mutex mu sync.RWMutex
wlock sync.Mutex wlock sync.Mutex
} }
...@@ -48,8 +48,8 @@ func NewMuxConn(rwc io.ReadWriteCloser) *MuxConn { ...@@ -48,8 +48,8 @@ func NewMuxConn(rwc io.ReadWriteCloser) *MuxConn {
// Close closes the underlying io.ReadWriteCloser. This will also close // Close closes the underlying io.ReadWriteCloser. This will also close
// all streams that are open. // all streams that are open.
func (m *MuxConn) Close() error { func (m *MuxConn) Close() error {
m.mu.Lock() m.mu.RLock()
defer m.mu.Unlock() defer m.mu.RUnlock()
// Close all the streams // Close all the streams
for _, w := range m.streams { for _, w := range m.streams {
...@@ -94,12 +94,17 @@ func (m *MuxConn) Accept(id uint32) (io.ReadWriteCloser, error) { ...@@ -94,12 +94,17 @@ func (m *MuxConn) Accept(id uint32) (io.ReadWriteCloser, error) {
switch stream.state { switch stream.state {
case streamStateListen: case streamStateListen:
stream.mu.Unlock() stream.mu.Unlock()
case streamStateClosed:
// This can happen if it becomes established, some data is sent,
// and it closed all within the time period we wait above.
// This case will be fixed when we have edge-triggered checks.
fallthrough
case streamStateEstablished: case streamStateEstablished:
stream.mu.Unlock() stream.mu.Unlock()
break ACCEPT_ESTABLISH_LOOP break ACCEPT_ESTABLISH_LOOP
default: default:
defer stream.mu.Unlock() defer stream.mu.Unlock()
return nil, fmt.Errorf("Stream went to bad state: %d", stream.state) return nil, fmt.Errorf("Stream %d went to bad state: %d", id, stream.state)
} }
} }
} }
...@@ -140,12 +145,17 @@ func (m *MuxConn) Dial(id uint32) (io.ReadWriteCloser, error) { ...@@ -140,12 +145,17 @@ func (m *MuxConn) Dial(id uint32) (io.ReadWriteCloser, error) {
switch stream.state { switch stream.state {
case streamStateSynSent: case streamStateSynSent:
stream.mu.Unlock() stream.mu.Unlock()
case streamStateClosed:
// This can happen if it becomes established, some data is sent,
// and it closed all within the time period we wait above.
// This case will be fixed when we have edge-triggered checks.
fallthrough
case streamStateEstablished: case streamStateEstablished:
stream.mu.Unlock() stream.mu.Unlock()
return stream, nil return stream, nil
default: default:
defer stream.mu.Unlock() defer stream.mu.Unlock()
return nil, fmt.Errorf("Stream went to bad state: %d", stream.state) return nil, fmt.Errorf("Stream %d went to bad state: %d", id, stream.state)
} }
} }
} }
...@@ -166,9 +176,21 @@ func (m *MuxConn) NextId() uint32 { ...@@ -166,9 +176,21 @@ func (m *MuxConn) NextId() uint32 {
} }
func (m *MuxConn) openStream(id uint32) (*Stream, error) { func (m *MuxConn) openStream(id uint32) (*Stream, error) {
// First grab a read-lock if we have the stream already we can
// cheaply return it.
m.mu.RLock()
if stream, ok := m.streams[id]; ok {
m.mu.RUnlock()
return stream, nil
}
// Now acquire a full blown write lock so we can create the stream
m.mu.RUnlock()
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
// We have to check this again because there is a time period
// above where we couldn't lost this lock.
if stream, ok := m.streams[id]; ok { if stream, ok := m.streams[id]; ok {
return stream, nil return stream, nil
} }
...@@ -182,7 +204,6 @@ func (m *MuxConn) openStream(id uint32) (*Stream, error) { ...@@ -182,7 +204,6 @@ func (m *MuxConn) openStream(id uint32) (*Stream, error) {
id: id, id: id,
mux: m, mux: m,
reader: dataR, reader: dataR,
writer: dataW,
writeCh: writeCh, writeCh: writeCh,
} }
stream.setState(streamStateClosed) stream.setState(streamStateClosed)
...@@ -190,8 +211,16 @@ func (m *MuxConn) openStream(id uint32) (*Stream, error) { ...@@ -190,8 +211,16 @@ func (m *MuxConn) openStream(id uint32) (*Stream, error) {
// Start the goroutine that will read from the queue and write // Start the goroutine that will read from the queue and write
// data out. // data out.
go func() { go func() {
defer dataW.Close()
for { for {
data := <-writeCh data := <-writeCh
if data == nil {
// A nil is a tombstone letting us know we're done
// accepting data.
return
}
if _, err := dataW.Write(data); err != nil { if _, err := dataW.Write(data); err != nil {
return return
} }
...@@ -237,12 +266,16 @@ func (m *MuxConn) loop() { ...@@ -237,12 +266,16 @@ func (m *MuxConn) loop() {
return return
} }
log.Printf("[DEBUG] Stream %d received packet %d", id, packetType)
switch packetType { switch packetType {
case muxPacketAck: case muxPacketAck:
stream.mu.Lock() stream.mu.Lock()
if stream.state == streamStateSynSent { switch stream.state {
case streamStateSynSent:
stream.setState(streamStateEstablished) stream.setState(streamStateEstablished)
} else { case streamStateFinWait1:
stream.remoteClose()
default:
log.Printf("[ERR] Ack received for stream in state: %d", stream.state) log.Printf("[ERR] Ack received for stream in state: %d", stream.state)
} }
stream.mu.Unlock() stream.mu.Unlock()
...@@ -259,13 +292,23 @@ func (m *MuxConn) loop() { ...@@ -259,13 +292,23 @@ func (m *MuxConn) loop() {
stream.mu.Unlock() stream.mu.Unlock()
case muxPacketFin: case muxPacketFin:
stream.mu.Lock() stream.mu.Lock()
stream.setState(streamStateClosed) switch stream.state {
stream.writer.Close() case streamStateEstablished:
m.write(id, muxPacketAck, nil)
fallthrough
case streamStateFinWait1:
stream.remoteClose()
// Remove this stream from being active so that it
// can be re-used
m.mu.Lock()
delete(m.streams, stream.id)
m.mu.Unlock()
default:
log.Printf("[ERR] Fin received for stream %d in state: %d", id, stream.state)
}
stream.mu.Unlock() stream.mu.Unlock()
m.mu.Lock()
delete(m.streams, stream.id)
m.mu.Unlock()
case muxPacketData: case muxPacketData:
stream.mu.Lock() stream.mu.Lock()
if stream.state == streamStateEstablished { if stream.state == streamStateEstablished {
...@@ -306,7 +349,6 @@ type Stream struct { ...@@ -306,7 +349,6 @@ type Stream struct {
id uint32 id uint32
mux *MuxConn mux *MuxConn
reader io.Reader reader io.Reader
writer io.WriteCloser
state streamState state streamState
stateUpdated time.Time stateUpdated time.Time
mu sync.Mutex mu sync.Mutex
...@@ -321,23 +363,37 @@ const ( ...@@ -321,23 +363,37 @@ const (
streamStateSynRecv streamStateSynRecv
streamStateSynSent streamStateSynSent
streamStateEstablished streamStateEstablished
streamStateFinWait streamStateFinWait1
) )
func (s *Stream) Close() error { func (s *Stream) Close() error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock()
if s.state != streamStateEstablished { if s.state != streamStateEstablished {
s.mu.Unlock()
return fmt.Errorf("Stream in bad state: %d", s.state) return fmt.Errorf("Stream in bad state: %d", s.state)
} }
if _, err := s.mux.write(s.id, muxPacketFin, nil); err != nil { if _, err := s.mux.write(s.id, muxPacketFin, nil); err != nil {
return err return err
} }
s.setState(streamStateFinWait1)
s.mu.Unlock()
for {
time.Sleep(50 * time.Millisecond)
s.mu.Lock()
switch s.state {
case streamStateFinWait1:
s.mu.Unlock()
case streamStateClosed:
s.mu.Unlock()
return nil
default:
defer s.mu.Unlock()
return fmt.Errorf("Stream %d went to bad state: %d", s.id, s.state)
}
}
s.setState(streamStateClosed)
s.writer.Close()
return nil return nil
} }
...@@ -346,9 +402,22 @@ func (s *Stream) Read(p []byte) (int, error) { ...@@ -346,9 +402,22 @@ func (s *Stream) Read(p []byte) (int, error) {
} }
func (s *Stream) Write(p []byte) (int, error) { func (s *Stream) Write(p []byte) (int, error) {
s.mu.Lock()
state := s.state
s.mu.Unlock()
if state != streamStateEstablished {
return 0, fmt.Errorf("Stream in bad state to send: %d", state)
}
return s.mux.write(s.id, muxPacketData, p) return s.mux.write(s.id, muxPacketData, p)
} }
func (s *Stream) remoteClose() {
s.setState(streamStateClosed)
s.writeCh <- nil
}
func (s *Stream) setState(state streamState) { func (s *Stream) setState(state streamState) {
s.state = state s.state = state
s.stateUpdated = time.Now().UTC() s.stateUpdated = time.Now().UTC()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment