Commit a2f46a98 authored by Mitchell Hashimoto's avatar Mitchell Hashimoto

packer/rpc: MuxConn implements three-way handshake

parent 311fb206
...@@ -253,9 +253,10 @@ func (c *CommunicatorServer) Download(args *CommunicatorDownloadArgs, reply *int ...@@ -253,9 +253,10 @@ func (c *CommunicatorServer) Download(args *CommunicatorDownloadArgs, reply *int
} }
func serveSingleCopy(name string, mux *MuxConn, id uint32, dst io.Writer, src io.Reader) { func serveSingleCopy(name string, mux *MuxConn, id uint32, dst io.Writer, src io.Reader) {
log.Printf("[DEBUG] %s: Connecting to stream %d", name, id)
conn, err := mux.Accept(id) conn, err := mux.Accept(id)
if err != nil { if err != nil {
log.Printf("'%s' accept error: %s", name, err) log.Printf("[ERR] '%s' accept error: %s", name, err)
return return
} }
...@@ -271,8 +272,8 @@ func serveSingleCopy(name string, mux *MuxConn, id uint32, dst io.Writer, src io ...@@ -271,8 +272,8 @@ func serveSingleCopy(name string, mux *MuxConn, id uint32, dst io.Writer, src io
} }
written, err := io.Copy(dst, src) written, err := io.Copy(dst, src)
log.Printf("%d bytes written for '%s'", written, name) log.Printf("[INFO] %d bytes written for '%s'", written, name)
if err != nil { if err != nil {
log.Printf("'%s' copy error: %s", name, err) log.Printf("[ERR] '%s' copy error: %s", name, err)
} }
} }
...@@ -33,6 +33,7 @@ type muxPacketType byte ...@@ -33,6 +33,7 @@ type muxPacketType byte
const ( const (
muxPacketSyn muxPacketType = iota muxPacketSyn muxPacketType = iota
muxPacketSynAck
muxPacketAck muxPacketAck
muxPacketFin muxPacketFin
muxPacketData muxPacketData
...@@ -77,49 +78,27 @@ func (m *MuxConn) Accept(id uint32) (io.ReadWriteCloser, error) { ...@@ -77,49 +78,27 @@ func (m *MuxConn) Accept(id uint32) (io.ReadWriteCloser, error) {
// If the stream isn't closed, then it is already open somehow // If the stream isn't closed, then it is already open somehow
stream.mu.Lock() stream.mu.Lock()
defer stream.mu.Unlock()
if stream.state != streamStateSynRecv && stream.state != streamStateClosed { if stream.state != streamStateSynRecv && stream.state != streamStateClosed {
stream.mu.Unlock()
return nil, fmt.Errorf("Stream %d already open in bad state: %d", id, stream.state) return nil, fmt.Errorf("Stream %d already open in bad state: %d", id, stream.state)
} }
if stream.state == streamStateSynRecv { if stream.state == streamStateClosed {
// Fast track establishing since we already got the syn // Go into the listening state and wait for a syn
stream.setState(streamStateEstablished)
stream.mu.Unlock()
}
if stream.state != streamStateEstablished {
// Go into the listening state
stream.setState(streamStateListen) stream.setState(streamStateListen)
if err := stream.waitState(streamStateSynRecv); err != nil {
// Register a state change listener to wait for changes return nil, err
stateCh := make(chan streamState, 10)
stream.registerStateListener(stateCh)
defer func() {
stream.mu.Lock()
defer stream.mu.Unlock()
stream.deregisterStateListener(stateCh)
}()
stream.mu.Unlock()
// Wait for the connection to establish
ACCEPT_ESTABLISH_LOOP:
for {
state := <-stateCh
switch state {
case streamStateListen:
case streamStateEstablished:
break ACCEPT_ESTABLISH_LOOP
default:
defer stream.mu.Unlock()
return nil, fmt.Errorf("Stream %d went to bad state: %d", id, stream.state)
} }
} }
if stream.state == streamStateSynRecv {
// Send a syn-ack
if _, err := m.write(stream.id, muxPacketSynAck, nil); err != nil {
return nil, err
}
} }
// Send the ack down if err := stream.waitState(streamStateEstablished); err != nil {
if _, err := m.write(stream.id, muxPacketAck, nil); err != nil {
return nil, err return nil, err
} }
...@@ -136,8 +115,8 @@ func (m *MuxConn) Dial(id uint32) (io.ReadWriteCloser, error) { ...@@ -136,8 +115,8 @@ func (m *MuxConn) Dial(id uint32) (io.ReadWriteCloser, error) {
// If the stream isn't closed, then it is already open somehow // If the stream isn't closed, then it is already open somehow
stream.mu.Lock() stream.mu.Lock()
defer stream.mu.Unlock()
if stream.state != streamStateClosed { if stream.state != streamStateClosed {
stream.mu.Unlock()
return nil, fmt.Errorf("Stream %d already open in bad state: %d", id, stream.state) return nil, fmt.Errorf("Stream %d already open in bad state: %d", id, stream.state)
} }
...@@ -147,28 +126,12 @@ func (m *MuxConn) Dial(id uint32) (io.ReadWriteCloser, error) { ...@@ -147,28 +126,12 @@ func (m *MuxConn) Dial(id uint32) (io.ReadWriteCloser, error) {
} }
stream.setState(streamStateSynSent) stream.setState(streamStateSynSent)
// Register a state change listener to wait for changes if err := stream.waitState(streamStateEstablished); err != nil {
stateCh := make(chan streamState, 10) return nil, err
stream.registerStateListener(stateCh) }
defer func() {
stream.mu.Lock()
defer stream.mu.Unlock()
stream.deregisterStateListener(stateCh)
}()
stream.mu.Unlock()
for { m.write(id, muxPacketAck, nil)
state := <-stateCh
switch state {
case streamStateSynSent:
case streamStateEstablished:
return stream, nil return stream, nil
default:
defer stream.mu.Unlock()
return nil, fmt.Errorf("Stream %d went to bad state: %d", id, stream.state)
}
}
} }
// NextId returns the next available stream ID that isn't currently // NextId returns the next available stream ID that isn't currently
...@@ -247,6 +210,7 @@ func (m *MuxConn) loop() { ...@@ -247,6 +210,7 @@ func (m *MuxConn) loop() {
// Force close every stream that we know about when we exit so // Force close every stream that we know about when we exit so
// that they all read EOF and don't block forever. // that they all read EOF and don't block forever.
defer func() { defer func() {
log.Printf("[INFO] Mux connection loop exiting")
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
for _, w := range m.streams { for _, w := range m.streams {
...@@ -288,12 +252,23 @@ func (m *MuxConn) loop() { ...@@ -288,12 +252,23 @@ func (m *MuxConn) loop() {
return return
} }
//log.Printf("[DEBUG] Stream %d received packet %d", id, packetType) log.Printf("[TRACE] Stream %d received packet %d", id, packetType)
switch packetType { switch packetType {
case muxPacketSyn:
stream.mu.Lock()
switch stream.state {
case streamStateClosed:
fallthrough
case streamStateListen:
stream.setState(streamStateSynRecv)
default:
log.Printf("[ERR] Syn received for stream in state: %d", stream.state)
}
stream.mu.Unlock()
case muxPacketAck: case muxPacketAck:
stream.mu.Lock() stream.mu.Lock()
switch stream.state { switch stream.state {
case streamStateSynSent: case streamStateSynRecv:
stream.setState(streamStateEstablished) stream.setState(streamStateEstablished)
case streamStateFinWait1: case streamStateFinWait1:
stream.setState(streamStateFinWait2) stream.setState(streamStateFinWait2)
...@@ -301,15 +276,13 @@ func (m *MuxConn) loop() { ...@@ -301,15 +276,13 @@ func (m *MuxConn) loop() {
log.Printf("[ERR] Ack received for stream in state: %d", stream.state) log.Printf("[ERR] Ack received for stream in state: %d", stream.state)
} }
stream.mu.Unlock() stream.mu.Unlock()
case muxPacketSyn: case muxPacketSynAck:
stream.mu.Lock() stream.mu.Lock()
switch stream.state { switch stream.state {
case streamStateClosed: case streamStateSynSent:
stream.setState(streamStateSynRecv)
case streamStateListen:
stream.setState(streamStateEstablished) stream.setState(streamStateEstablished)
default: default:
log.Printf("[ERR] Syn received for stream in state: %d", stream.state) log.Printf("[ERR] SynAck received for stream in state: %d", stream.state)
} }
stream.mu.Unlock() stream.mu.Unlock()
case muxPacketFin: case muxPacketFin:
...@@ -451,6 +424,7 @@ func (s *Stream) deregisterStateListener(ch chan<- streamState) { ...@@ -451,6 +424,7 @@ func (s *Stream) deregisterStateListener(ch chan<- streamState) {
} }
func (s *Stream) setState(state streamState) { func (s *Stream) setState(state streamState) {
log.Printf("[TRACE] Stream %d went to state %d", s.id, state)
s.state = state s.state = state
s.stateUpdated = time.Now().UTC() s.stateUpdated = time.Now().UTC()
for ch, _ := range s.stateChange { for ch, _ := range s.stateChange {
...@@ -460,3 +434,22 @@ func (s *Stream) setState(state streamState) { ...@@ -460,3 +434,22 @@ func (s *Stream) setState(state streamState) {
} }
} }
} }
func (s *Stream) waitState(target streamState) error {
// Register a state change listener to wait for changes
stateCh := make(chan streamState, 10)
s.registerStateListener(stateCh)
s.mu.Unlock()
defer func() {
s.mu.Lock()
s.deregisterStateListener(stateCh)
}()
state := <-stateCh
if state == target {
return nil
} else {
return fmt.Errorf("Stream %d went to bad state: %d", s.id, state)
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment