Commit a2f46a98 authored by Mitchell Hashimoto's avatar Mitchell Hashimoto

packer/rpc: MuxConn implements three-way handshake

parent 311fb206
......@@ -253,9 +253,10 @@ func (c *CommunicatorServer) Download(args *CommunicatorDownloadArgs, reply *int
}
func serveSingleCopy(name string, mux *MuxConn, id uint32, dst io.Writer, src io.Reader) {
log.Printf("[DEBUG] %s: Connecting to stream %d", name, id)
conn, err := mux.Accept(id)
if err != nil {
log.Printf("'%s' accept error: %s", name, err)
log.Printf("[ERR] '%s' accept error: %s", name, err)
return
}
......@@ -271,8 +272,8 @@ func serveSingleCopy(name string, mux *MuxConn, id uint32, dst io.Writer, src io
}
written, err := io.Copy(dst, src)
log.Printf("%d bytes written for '%s'", written, name)
log.Printf("[INFO] %d bytes written for '%s'", written, name)
if err != nil {
log.Printf("'%s' copy error: %s", name, err)
log.Printf("[ERR] '%s' copy error: %s", name, err)
}
}
......@@ -33,6 +33,7 @@ type muxPacketType byte
const (
muxPacketSyn muxPacketType = iota
muxPacketSynAck
muxPacketAck
muxPacketFin
muxPacketData
......@@ -77,49 +78,27 @@ func (m *MuxConn) Accept(id uint32) (io.ReadWriteCloser, error) {
// If the stream isn't closed, then it is already open somehow
stream.mu.Lock()
defer stream.mu.Unlock()
if stream.state != streamStateSynRecv && stream.state != streamStateClosed {
stream.mu.Unlock()
return nil, fmt.Errorf("Stream %d already open in bad state: %d", id, stream.state)
}
if stream.state == streamStateSynRecv {
// Fast track establishing since we already got the syn
stream.setState(streamStateEstablished)
stream.mu.Unlock()
}
if stream.state != streamStateEstablished {
// Go into the listening state
if stream.state == streamStateClosed {
// Go into the listening state and wait for a syn
stream.setState(streamStateListen)
if err := stream.waitState(streamStateSynRecv); err != nil {
return nil, err
}
}
// Register a state change listener to wait for changes
stateCh := make(chan streamState, 10)
stream.registerStateListener(stateCh)
defer func() {
stream.mu.Lock()
defer stream.mu.Unlock()
stream.deregisterStateListener(stateCh)
}()
stream.mu.Unlock()
// Wait for the connection to establish
ACCEPT_ESTABLISH_LOOP:
for {
state := <-stateCh
switch state {
case streamStateListen:
case streamStateEstablished:
break ACCEPT_ESTABLISH_LOOP
default:
defer stream.mu.Unlock()
return nil, fmt.Errorf("Stream %d went to bad state: %d", id, stream.state)
}
if stream.state == streamStateSynRecv {
// Send a syn-ack
if _, err := m.write(stream.id, muxPacketSynAck, nil); err != nil {
return nil, err
}
}
// Send the ack down
if _, err := m.write(stream.id, muxPacketAck, nil); err != nil {
if err := stream.waitState(streamStateEstablished); err != nil {
return nil, err
}
......@@ -136,8 +115,8 @@ func (m *MuxConn) Dial(id uint32) (io.ReadWriteCloser, error) {
// If the stream isn't closed, then it is already open somehow
stream.mu.Lock()
defer stream.mu.Unlock()
if stream.state != streamStateClosed {
stream.mu.Unlock()
return nil, fmt.Errorf("Stream %d already open in bad state: %d", id, stream.state)
}
......@@ -147,28 +126,12 @@ func (m *MuxConn) Dial(id uint32) (io.ReadWriteCloser, error) {
}
stream.setState(streamStateSynSent)
// Register a state change listener to wait for changes
stateCh := make(chan streamState, 10)
stream.registerStateListener(stateCh)
defer func() {
stream.mu.Lock()
defer stream.mu.Unlock()
stream.deregisterStateListener(stateCh)
}()
stream.mu.Unlock()
for {
state := <-stateCh
switch state {
case streamStateSynSent:
case streamStateEstablished:
return stream, nil
default:
defer stream.mu.Unlock()
return nil, fmt.Errorf("Stream %d went to bad state: %d", id, stream.state)
}
if err := stream.waitState(streamStateEstablished); err != nil {
return nil, err
}
m.write(id, muxPacketAck, nil)
return stream, nil
}
// NextId returns the next available stream ID that isn't currently
......@@ -247,6 +210,7 @@ func (m *MuxConn) loop() {
// Force close every stream that we know about when we exit so
// that they all read EOF and don't block forever.
defer func() {
log.Printf("[INFO] Mux connection loop exiting")
m.mu.Lock()
defer m.mu.Unlock()
for _, w := range m.streams {
......@@ -288,12 +252,23 @@ func (m *MuxConn) loop() {
return
}
//log.Printf("[DEBUG] Stream %d received packet %d", id, packetType)
log.Printf("[TRACE] Stream %d received packet %d", id, packetType)
switch packetType {
case muxPacketSyn:
stream.mu.Lock()
switch stream.state {
case streamStateClosed:
fallthrough
case streamStateListen:
stream.setState(streamStateSynRecv)
default:
log.Printf("[ERR] Syn received for stream in state: %d", stream.state)
}
stream.mu.Unlock()
case muxPacketAck:
stream.mu.Lock()
switch stream.state {
case streamStateSynSent:
case streamStateSynRecv:
stream.setState(streamStateEstablished)
case streamStateFinWait1:
stream.setState(streamStateFinWait2)
......@@ -301,15 +276,13 @@ func (m *MuxConn) loop() {
log.Printf("[ERR] Ack received for stream in state: %d", stream.state)
}
stream.mu.Unlock()
case muxPacketSyn:
case muxPacketSynAck:
stream.mu.Lock()
switch stream.state {
case streamStateClosed:
stream.setState(streamStateSynRecv)
case streamStateListen:
case streamStateSynSent:
stream.setState(streamStateEstablished)
default:
log.Printf("[ERR] Syn received for stream in state: %d", stream.state)
log.Printf("[ERR] SynAck received for stream in state: %d", stream.state)
}
stream.mu.Unlock()
case muxPacketFin:
......@@ -451,6 +424,7 @@ func (s *Stream) deregisterStateListener(ch chan<- streamState) {
}
func (s *Stream) setState(state streamState) {
log.Printf("[TRACE] Stream %d went to state %d", s.id, state)
s.state = state
s.stateUpdated = time.Now().UTC()
for ch, _ := range s.stateChange {
......@@ -460,3 +434,22 @@ func (s *Stream) setState(state streamState) {
}
}
}
func (s *Stream) waitState(target streamState) error {
// Register a state change listener to wait for changes
stateCh := make(chan streamState, 10)
s.registerStateListener(stateCh)
s.mu.Unlock()
defer func() {
s.mu.Lock()
s.deregisterStateListener(stateCh)
}()
state := <-stateCh
if state == target {
return nil
} else {
return fmt.Errorf("Stream %d went to bad state: %d", s.id, state)
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment