Commit b25baa62 authored by Mitchell Hashimoto's avatar Mitchell Hashimoto

packer/rpc: fix a blocking issue

parent 57bde34c
...@@ -400,8 +400,6 @@ func (m *MuxConn) loop() { ...@@ -400,8 +400,6 @@ func (m *MuxConn) loop() {
stream.mu.Unlock() stream.mu.Unlock()
case muxPacketData: case muxPacketData:
unlocked := false
stream.mu.Lock() stream.mu.Lock()
switch stream.state { switch stream.state {
case streamStateFinWait1: case streamStateFinWait1:
...@@ -409,26 +407,15 @@ func (m *MuxConn) loop() { ...@@ -409,26 +407,15 @@ func (m *MuxConn) loop() {
case streamStateFinWait2: case streamStateFinWait2:
fallthrough fallthrough
case streamStateEstablished: case streamStateEstablished:
if len(data) > 0 { if len(data) > 0 && stream.writeCh != nil {
// Get a reference to the write channel while we have //log.Printf("[TRACE] %p: Stream %d (%s) WRITE-START", m, id, from)
// the lock because otherwise the field might change. stream.writeCh <- data
// We unlock early here because the write might block //log.Printf("[TRACE] %p: Stream %d (%s) WRITE-END", m, id, from)
// for a long time.
writeCh := stream.writeCh
stream.mu.Unlock()
unlocked = true
// Blocked write, this provides some backpressure on
// the connection if there is a lot of data incoming.
writeCh <- data
} }
default: default:
log.Printf("[ERR] Data received for stream in state: %d", stream.state) log.Printf("[ERR] Data received for stream in state: %d", stream.state)
} }
stream.mu.Unlock()
if !unlocked {
stream.mu.Unlock()
}
} }
} }
} }
...@@ -516,6 +503,7 @@ func newStream(from muxPacketFrom, id uint32, m *MuxConn) *Stream { ...@@ -516,6 +503,7 @@ func newStream(from muxPacketFrom, id uint32, m *MuxConn) *Stream {
go func() { go func() {
defer dataW.Close() defer dataW.Close()
drain := false
for { for {
data := <-writeCh data := <-writeCh
if data == nil { if data == nil {
...@@ -524,8 +512,14 @@ func newStream(from muxPacketFrom, id uint32, m *MuxConn) *Stream { ...@@ -524,8 +512,14 @@ func newStream(from muxPacketFrom, id uint32, m *MuxConn) *Stream {
return return
} }
if drain {
// We're draining, meaning we're just waiting for the
// write channel to close.
continue
}
if _, err := dataW.Write(data); err != nil { if _, err := dataW.Write(data); err != nil {
return drain = true
} }
} }
}() }()
...@@ -568,7 +562,10 @@ func (s *Stream) Write(p []byte) (int, error) { ...@@ -568,7 +562,10 @@ func (s *Stream) Write(p []byte) (int, error) {
} }
func (s *Stream) closeWriter() { func (s *Stream) closeWriter() {
s.writeCh <- nil if s.writeCh != nil {
s.writeCh <- nil
s.writeCh = nil
}
} }
func (s *Stream) setState(state streamState) { func (s *Stream) setState(state streamState) {
...@@ -594,6 +591,7 @@ func (s *Stream) waitState(target streamState) error { ...@@ -594,6 +591,7 @@ func (s *Stream) waitState(target streamState) error {
delete(s.stateChange, stateCh) delete(s.stateChange, stateCh)
}() }()
//log.Printf("[TRACE] %p: Stream %d (%s) waiting for state: %d", s.mux, s.id, s.from, target)
state := <-stateCh state := <-stateCh
if state == target { if state == target {
return nil return nil
......
...@@ -76,6 +76,7 @@ func TestMuxConn(t *testing.T) { ...@@ -76,6 +76,7 @@ func TestMuxConn(t *testing.T) {
go func() { go func() {
defer wg.Done() defer wg.Done()
defer s1.Close()
data := readStream(t, s1) data := readStream(t, s1)
if data != "another" { if data != "another" {
t.Fatalf("bad: %#v", data) t.Fatalf("bad: %#v", data)
...@@ -84,6 +85,7 @@ func TestMuxConn(t *testing.T) { ...@@ -84,6 +85,7 @@ func TestMuxConn(t *testing.T) {
go func() { go func() {
defer wg.Done() defer wg.Done()
defer s0.Close()
data := readStream(t, s0) data := readStream(t, s0)
if data != "hello" { if data != "hello" {
t.Fatalf("bad: %#v", data) t.Fatalf("bad: %#v", data)
...@@ -110,6 +112,9 @@ func TestMuxConn(t *testing.T) { ...@@ -110,6 +112,9 @@ func TestMuxConn(t *testing.T) {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
s0.Close()
s1.Close()
// Wait for the server to be done // Wait for the server to be done
<-doneCh <-doneCh
} }
...@@ -131,18 +136,20 @@ func TestMuxConn_lotsOfData(t *testing.T) { ...@@ -131,18 +136,20 @@ func TestMuxConn_lotsOfData(t *testing.T) {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
var wg sync.WaitGroup var data [1024]byte
wg.Add(1) for {
n, err := s0.Read(data[:])
if err == io.EOF {
break
}
go func() { dataString := string(data[0:n])
defer wg.Done() if dataString != "hello" {
data := readStream(t, s0) t.Fatalf("bad: %#v", dataString)
if data != "hello" {
t.Fatalf("bad: %#v", data)
} }
}() }
wg.Wait() s0.Close()
}() }()
s0, err := client.Dial(0) s0, err := client.Dial(0)
...@@ -156,6 +163,10 @@ func TestMuxConn_lotsOfData(t *testing.T) { ...@@ -156,6 +163,10 @@ func TestMuxConn_lotsOfData(t *testing.T) {
} }
} }
if err := s0.Close(); err != nil {
t.Fatalf("err: %s", err)
}
// Wait for the server to be done // Wait for the server to be done
<-doneCh <-doneCh
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment