Commit 5c683108 authored by Mitchell Hashimoto's avatar Mitchell Hashimoto

packer/rpc: close the streams when the underlying rwc closes

parent fe46093b
......@@ -13,14 +13,14 @@ import (
// to actually act as a server as well.
//
// MuxConn works using a fairly dumb multiplexing technique of simply
// prefixing each message with whether it is on stream 0 (the original)
// or stream 1 (the client "server").
// prefixing each message with what stream it is on along with the length
// of the data.
//
// This can likely be abstracted to N streams, but by choosing only two
// we decided to cut a lot of corners and make this easily usable for Packer.
type MuxConn struct {
rwc io.ReadWriteCloser
streams map[byte]io.Writer
streams map[byte]io.WriteCloser
mu sync.RWMutex
wlock sync.Mutex
}
......@@ -28,7 +28,7 @@ type MuxConn struct {
func NewMuxConn(rwc io.ReadWriteCloser) *MuxConn {
m := &MuxConn{
rwc: rwc,
streams: make(map[byte]io.Writer),
streams: make(map[byte]io.WriteCloser),
}
go m.loop()
......@@ -36,6 +36,21 @@ func NewMuxConn(rwc io.ReadWriteCloser) *MuxConn {
return m
}
// Close closes the underlying io.ReadWriteCloser. This will also close
// all streams that are open.
func (m *MuxConn) Close() error {
m.mu.Lock()
defer m.mu.Unlock()
// Close all the streams
for _, w := range m.streams {
w.Close()
}
m.streams = make(map[byte]io.WriteCloser)
return m.rwc.Close()
}
// Stream returns a io.ReadWriteCloser that will only read/write to the
// given stream ID. No handshake is done so if the remote end does not
// have a stream open with the same ID, then the messages will simply
......@@ -67,6 +82,8 @@ func (m *MuxConn) Stream(id byte) (io.ReadWriteCloser, error) {
}
func (m *MuxConn) loop() {
defer m.Close()
for {
var id byte
var length int32
......
......@@ -17,17 +17,14 @@ func readStream(t *testing.T, s io.Reader) string {
return string(data[0:n])
}
func TestMuxConn(t *testing.T) {
func testMux(t *testing.T) (client *MuxConn, server *MuxConn) {
l, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatalf("err: %s", err)
}
// When the server is done
// Server side
doneCh := make(chan struct{})
readyCh := make(chan struct{})
// The server side
go func() {
defer close(doneCh)
conn, err := l.Accept()
......@@ -35,15 +32,42 @@ func TestMuxConn(t *testing.T) {
if err != nil {
t.Fatalf("err: %s", err)
}
defer conn.Close()
mux := NewMuxConn(conn)
s0, err := mux.Stream(0)
server = NewMuxConn(conn)
}()
// Client side
conn, err := net.Dial("tcp", l.Addr().String())
if err != nil {
t.Fatalf("err: %s", err)
}
client = NewMuxConn(conn)
// Wait for the server
<-doneCh
return
}
func TestMuxConn(t *testing.T) {
client, server := testMux(t)
defer client.Close()
defer server.Close()
// When the server is done
doneCh := make(chan struct{})
readyCh := make(chan struct{})
// The server side
go func() {
defer close(doneCh)
s0, err := server.Stream(0)
if err != nil {
t.Fatalf("err: %s", err)
}
s1, err := mux.Stream(1)
s1, err := server.Stream(1)
if err != nil {
t.Fatalf("err: %s", err)
}
......@@ -72,20 +96,12 @@ func TestMuxConn(t *testing.T) {
wg.Wait()
}()
// Client side
conn, err := net.Dial("tcp", l.Addr().String())
s0, err := client.Stream(0)
if err != nil {
t.Fatalf("err: %s", err)
}
defer conn.Close()
mux := NewMuxConn(conn)
s0, err := mux.Stream(0)
if err != nil {
t.Fatalf("err: %s", err)
}
s1, err := mux.Stream(1)
s1, err := client.Stream(1)
if err != nil {
t.Fatalf("err: %s", err)
}
......@@ -103,3 +119,47 @@ func TestMuxConn(t *testing.T) {
// Wait for the server to be done
<-doneCh
}
func TestMuxConn_clientClosesStreams(t *testing.T) {
client, server := testMux(t)
defer client.Close()
defer server.Close()
s0, err := client.Stream(0)
if err != nil {
t.Fatalf("err: %s", err)
}
if err := client.Close(); err != nil {
t.Fatalf("err: %s", err)
}
// This should block forever since we never write onto this stream.
var data [1024]byte
_, err = s0.Read(data[:])
if err != io.EOF {
t.Fatalf("err: %s", err)
}
}
func TestMuxConn_serverClosesStreams(t *testing.T) {
client, server := testMux(t)
defer client.Close()
defer server.Close()
s0, err := client.Stream(0)
if err != nil {
t.Fatalf("err: %s", err)
}
if err := server.Close(); err != nil {
t.Fatalf("err: %s", err)
}
// This should block forever since we never write onto this stream.
var data [1024]byte
_, err = s0.Read(data[:])
if err != io.EOF {
t.Fatalf("err: %s", err)
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment