Commit 1a963919 authored by Roger Peppe's avatar Roger Peppe Committed by Rob Pike

netchan: allow use of arbitrary connections.

R=r, r2, rsc
CC=golang-dev
https://golang.org/cl/4119055
parent a54cbcec
...@@ -6,7 +6,7 @@ package netchan ...@@ -6,7 +6,7 @@ package netchan
import ( import (
"gob" "gob"
"net" "io"
"os" "os"
"reflect" "reflect"
"sync" "sync"
...@@ -93,7 +93,7 @@ type encDec struct { ...@@ -93,7 +93,7 @@ type encDec struct {
enc *gob.Encoder enc *gob.Encoder
} }
func newEncDec(conn net.Conn) *encDec { func newEncDec(conn io.ReadWriter) *encDec {
return &encDec{ return &encDec{
dec: gob.NewDecoder(conn), dec: gob.NewDecoder(conn),
enc: gob.NewEncoder(conn), enc: gob.NewEncoder(conn),
...@@ -199,9 +199,10 @@ func (cs *clientSet) sync(timeout int64) os.Error { ...@@ -199,9 +199,10 @@ func (cs *clientSet) sync(timeout int64) os.Error {
// are delivered into the local channel. // are delivered into the local channel.
type netChan struct { type netChan struct {
*chanDir *chanDir
name string name string
id int id int
size int // buffer size of channel. size int // buffer size of channel.
closed bool
// sender-specific state // sender-specific state
ackCh chan bool // buffered with space for all the acks we need ackCh chan bool // buffered with space for all the acks we need
...@@ -227,6 +228,9 @@ func newNetChan(name string, id int, ch *chanDir, ed *encDec, size int, count in ...@@ -227,6 +228,9 @@ func newNetChan(name string, id int, ch *chanDir, ed *encDec, size int, count in
// Close the channel. // Close the channel.
func (nch *netChan) close() { func (nch *netChan) close() {
if nch.closed {
return
}
if nch.dir == Recv { if nch.dir == Recv {
if nch.sendCh != nil { if nch.sendCh != nil {
// If the sender goroutine is active, close the channel to it. // If the sender goroutine is active, close the channel to it.
...@@ -239,6 +243,7 @@ func (nch *netChan) close() { ...@@ -239,6 +243,7 @@ func (nch *netChan) close() {
nch.ch.Close() nch.ch.Close()
close(nch.ackCh) close(nch.ackCh)
} }
nch.closed = true
} }
// Send message from remote side to local receiver. // Send message from remote side to local receiver.
......
...@@ -23,6 +23,7 @@ package netchan ...@@ -23,6 +23,7 @@ package netchan
import ( import (
"log" "log"
"io"
"net" "net"
"os" "os"
"reflect" "reflect"
...@@ -43,7 +44,6 @@ func expLog(args ...interface{}) { ...@@ -43,7 +44,6 @@ func expLog(args ...interface{}) {
// but they must use different ports. // but they must use different ports.
type Exporter struct { type Exporter struct {
*clientSet *clientSet
listener net.Listener
} }
type expClient struct { type expClient struct {
...@@ -57,7 +57,7 @@ type expClient struct { ...@@ -57,7 +57,7 @@ type expClient struct {
seqLock sync.Mutex // guarantees messages are in sequence, only locked under mu seqLock sync.Mutex // guarantees messages are in sequence, only locked under mu
} }
func newClient(exp *Exporter, conn net.Conn) *expClient { func newClient(exp *Exporter, conn io.ReadWriter) *expClient {
client := new(expClient) client := new(expClient)
client.exp = exp client.exp = exp
client.encDec = newEncDec(conn) client.encDec = newEncDec(conn)
...@@ -260,39 +260,50 @@ func (client *expClient) ack() int64 { ...@@ -260,39 +260,50 @@ func (client *expClient) ack() int64 {
return n return n
} }
// Wait for incoming connections, start a new runner for each // Serve waits for incoming connections on the listener
func (exp *Exporter) listen() { // and serves the Exporter's channels on each.
// It blocks until the listener is closed.
func (exp *Exporter) Serve(listener net.Listener) {
for { for {
conn, err := exp.listener.Accept() conn, err := listener.Accept()
if err != nil { if err != nil {
expLog("listen:", err) expLog("listen:", err)
break break
} }
client := exp.addClient(conn) go exp.ServeConn(conn)
go client.run()
} }
} }
// NewExporter creates a new Exporter to export channels // ServeConn exports the Exporter's channels on conn.
// on the network and local address defined as in net.Listen. // It blocks until the connection is terminated.
func NewExporter(network, localaddr string) (*Exporter, os.Error) { func (exp *Exporter) ServeConn(conn io.ReadWriter) {
listener, err := net.Listen(network, localaddr) exp.addClient(conn).run()
if err != nil { }
return nil, err
} // NewExporter creates a new Exporter that exports a set of channels.
func NewExporter() *Exporter {
e := &Exporter{ e := &Exporter{
listener: listener,
clientSet: &clientSet{ clientSet: &clientSet{
names: make(map[string]*chanDir), names: make(map[string]*chanDir),
clients: make(map[unackedCounter]bool), clients: make(map[unackedCounter]bool),
}, },
} }
go e.listen() return e
return e, nil }
// ListenAndServe exports the exporter's channels through the
// given network and local address defined as in net.Listen.
func (exp *Exporter) ListenAndServe(network, localaddr string) os.Error {
listener, err := net.Listen(network, localaddr)
if err != nil {
return err
}
go exp.Serve(listener)
return nil
} }
// addClient creates a new expClient and records its existence // addClient creates a new expClient and records its existence
func (exp *Exporter) addClient(conn net.Conn) *expClient { func (exp *Exporter) addClient(conn io.ReadWriter) *expClient {
client := newClient(exp, conn) client := newClient(exp, conn)
exp.mu.Lock() exp.mu.Lock()
exp.clients[client] = true exp.clients[client] = true
...@@ -329,9 +340,6 @@ func (exp *Exporter) Sync(timeout int64) os.Error { ...@@ -329,9 +340,6 @@ func (exp *Exporter) Sync(timeout int64) os.Error {
return exp.clientSet.sync(timeout) return exp.clientSet.sync(timeout)
} }
// Addr returns the Exporter's local network address.
func (exp *Exporter) Addr() net.Addr { return exp.listener.Addr() }
func checkChan(chT interface{}, dir Dir) (*reflect.ChanValue, os.Error) { func checkChan(chT interface{}, dir Dir) (*reflect.ChanValue, os.Error) {
chanType, ok := reflect.Typeof(chT).(*reflect.ChanType) chanType, ok := reflect.Typeof(chT).(*reflect.ChanType)
if !ok { if !ok {
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
package netchan package netchan
import ( import (
"io"
"log" "log"
"net" "net"
"os" "os"
...@@ -25,7 +26,6 @@ func impLog(args ...interface{}) { ...@@ -25,7 +26,6 @@ func impLog(args ...interface{}) {
// importers, even from the same machine/network port. // importers, even from the same machine/network port.
type Importer struct { type Importer struct {
*encDec *encDec
conn net.Conn
chanLock sync.Mutex // protects access to channel map chanLock sync.Mutex // protects access to channel map
names map[string]*netChan names map[string]*netChan
chans map[int]*netChan chans map[int]*netChan
...@@ -33,23 +33,26 @@ type Importer struct { ...@@ -33,23 +33,26 @@ type Importer struct {
maxId int maxId int
} }
// NewImporter creates a new Importer object to import channels // NewImporter creates a new Importer object to import a set of channels
// from an Exporter at the network and remote address as defined in net.Dial. // from the given connection. The Exporter must be available and serving when
// The Exporter must be available and serving when the Importer is // the Importer is created.
// created. func NewImporter(conn io.ReadWriter) *Importer {
func NewImporter(network, remoteaddr string) (*Importer, os.Error) {
conn, err := net.Dial(network, "", remoteaddr)
if err != nil {
return nil, err
}
imp := new(Importer) imp := new(Importer)
imp.encDec = newEncDec(conn) imp.encDec = newEncDec(conn)
imp.conn = conn
imp.chans = make(map[int]*netChan) imp.chans = make(map[int]*netChan)
imp.names = make(map[string]*netChan) imp.names = make(map[string]*netChan)
imp.errors = make(chan os.Error, 10) imp.errors = make(chan os.Error, 10)
go imp.run() go imp.run()
return imp, nil return imp
}
// Import imports a set of channels from the given network and address.
func Import(network, remoteaddr string) (*Importer, os.Error) {
conn, err := net.Dial(network, "", remoteaddr)
if err != nil {
return nil, err
}
return NewImporter(conn), nil
} }
// shutdown closes all channels for which we are receiving data from the remote side. // shutdown closes all channels for which we are receiving data from the remote side.
...@@ -231,15 +234,13 @@ func (imp *Importer) ImportNValues(name string, chT interface{}, dir Dir, size, ...@@ -231,15 +234,13 @@ func (imp *Importer) ImportNValues(name string, chT interface{}, dir Dir, size,
// the channel. Messages in flight for the channel may be dropped. // the channel. Messages in flight for the channel may be dropped.
func (imp *Importer) Hangup(name string) os.Error { func (imp *Importer) Hangup(name string) os.Error {
imp.chanLock.Lock() imp.chanLock.Lock()
nc, ok := imp.names[name] defer imp.chanLock.Unlock()
if ok { nc := imp.names[name]
imp.names[name] = nil, false if nc == nil {
imp.chans[nc.id] = nil, false
}
imp.chanLock.Unlock()
if !ok {
return os.ErrorString("netchan import: hangup: no such channel: " + name) return os.ErrorString("netchan import: hangup: no such channel: " + name)
} }
imp.names[name] = nil, false
imp.chans[nc.id] = nil, false
nc.close() nc.close()
return nil return nil
} }
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
package netchan package netchan
import ( import (
"net"
"strings" "strings"
"testing" "testing"
"time" "time"
...@@ -94,27 +95,13 @@ func importReceive(imp *Importer, t *testing.T, done chan bool) { ...@@ -94,27 +95,13 @@ func importReceive(imp *Importer, t *testing.T, done chan bool) {
} }
func TestExportSendImportReceive(t *testing.T) { func TestExportSendImportReceive(t *testing.T) {
exp, err := NewExporter("tcp", "127.0.0.1:0") exp, imp := pair(t)
if err != nil {
t.Fatal("new exporter:", err)
}
imp, err := NewImporter("tcp", exp.Addr().String())
if err != nil {
t.Fatal("new importer:", err)
}
exportSend(exp, count, t, nil) exportSend(exp, count, t, nil)
importReceive(imp, t, nil) importReceive(imp, t, nil)
} }
func TestExportReceiveImportSend(t *testing.T) { func TestExportReceiveImportSend(t *testing.T) {
exp, err := NewExporter("tcp", "127.0.0.1:0") exp, imp := pair(t)
if err != nil {
t.Fatal("new exporter:", err)
}
imp, err := NewImporter("tcp", exp.Addr().String())
if err != nil {
t.Fatal("new importer:", err)
}
expDone := make(chan bool) expDone := make(chan bool)
done := make(chan bool) done := make(chan bool)
go func() { go func() {
...@@ -127,27 +114,13 @@ func TestExportReceiveImportSend(t *testing.T) { ...@@ -127,27 +114,13 @@ func TestExportReceiveImportSend(t *testing.T) {
} }
func TestClosingExportSendImportReceive(t *testing.T) { func TestClosingExportSendImportReceive(t *testing.T) {
exp, err := NewExporter("tcp", "127.0.0.1:0") exp, imp := pair(t)
if err != nil {
t.Fatal("new exporter:", err)
}
imp, err := NewImporter("tcp", exp.Addr().String())
if err != nil {
t.Fatal("new importer:", err)
}
exportSend(exp, closeCount, t, nil) exportSend(exp, closeCount, t, nil)
importReceive(imp, t, nil) importReceive(imp, t, nil)
} }
func TestClosingImportSendExportReceive(t *testing.T) { func TestClosingImportSendExportReceive(t *testing.T) {
exp, err := NewExporter("tcp", "127.0.0.1:0") exp, imp := pair(t)
if err != nil {
t.Fatal("new exporter:", err)
}
imp, err := NewImporter("tcp", exp.Addr().String())
if err != nil {
t.Fatal("new importer:", err)
}
expDone := make(chan bool) expDone := make(chan bool)
done := make(chan bool) done := make(chan bool)
go func() { go func() {
...@@ -160,17 +133,10 @@ func TestClosingImportSendExportReceive(t *testing.T) { ...@@ -160,17 +133,10 @@ func TestClosingImportSendExportReceive(t *testing.T) {
} }
func TestErrorForIllegalChannel(t *testing.T) { func TestErrorForIllegalChannel(t *testing.T) {
exp, err := NewExporter("tcp", "127.0.0.1:0") exp, imp := pair(t)
if err != nil {
t.Fatal("new exporter:", err)
}
imp, err := NewImporter("tcp", exp.Addr().String())
if err != nil {
t.Fatal("new importer:", err)
}
// Now export a channel. // Now export a channel.
ch := make(chan int, 1) ch := make(chan int, 1)
err = exp.Export("aChannel", ch, Send) err := exp.Export("aChannel", ch, Send)
if err != nil { if err != nil {
t.Fatal("export:", err) t.Fatal("export:", err)
} }
...@@ -200,14 +166,7 @@ func TestErrorForIllegalChannel(t *testing.T) { ...@@ -200,14 +166,7 @@ func TestErrorForIllegalChannel(t *testing.T) {
// Not a great test but it does at least invoke Drain. // Not a great test but it does at least invoke Drain.
func TestExportDrain(t *testing.T) { func TestExportDrain(t *testing.T) {
exp, err := NewExporter("tcp", "127.0.0.1:0") exp, imp := pair(t)
if err != nil {
t.Fatal("new exporter:", err)
}
imp, err := NewImporter("tcp", exp.Addr().String())
if err != nil {
t.Fatal("new importer:", err)
}
done := make(chan bool) done := make(chan bool)
go func() { go func() {
exportSend(exp, closeCount, t, nil) exportSend(exp, closeCount, t, nil)
...@@ -221,14 +180,7 @@ func TestExportDrain(t *testing.T) { ...@@ -221,14 +180,7 @@ func TestExportDrain(t *testing.T) {
// Not a great test but it does at least invoke Sync. // Not a great test but it does at least invoke Sync.
func TestExportSync(t *testing.T) { func TestExportSync(t *testing.T) {
exp, err := NewExporter("tcp", "127.0.0.1:0") exp, imp := pair(t)
if err != nil {
t.Fatal("new exporter:", err)
}
imp, err := NewImporter("tcp", exp.Addr().String())
if err != nil {
t.Fatal("new importer:", err)
}
done := make(chan bool) done := make(chan bool)
exportSend(exp, closeCount, t, nil) exportSend(exp, closeCount, t, nil)
go importReceive(imp, t, done) go importReceive(imp, t, done)
...@@ -239,16 +191,9 @@ func TestExportSync(t *testing.T) { ...@@ -239,16 +191,9 @@ func TestExportSync(t *testing.T) {
// Test hanging up the send side of an export. // Test hanging up the send side of an export.
// TODO: test hanging up the receive side of an export. // TODO: test hanging up the receive side of an export.
func TestExportHangup(t *testing.T) { func TestExportHangup(t *testing.T) {
exp, err := NewExporter("tcp", "127.0.0.1:0") exp, imp := pair(t)
if err != nil {
t.Fatal("new exporter:", err)
}
imp, err := NewImporter("tcp", exp.Addr().String())
if err != nil {
t.Fatal("new importer:", err)
}
ech := make(chan int) ech := make(chan int)
err = exp.Export("exportedSend", ech, Send) err := exp.Export("exportedSend", ech, Send)
if err != nil { if err != nil {
t.Fatal("export:", err) t.Fatal("export:", err)
} }
...@@ -276,16 +221,9 @@ func TestExportHangup(t *testing.T) { ...@@ -276,16 +221,9 @@ func TestExportHangup(t *testing.T) {
// Test hanging up the send side of an import. // Test hanging up the send side of an import.
// TODO: test hanging up the receive side of an import. // TODO: test hanging up the receive side of an import.
func TestImportHangup(t *testing.T) { func TestImportHangup(t *testing.T) {
exp, err := NewExporter("tcp", "127.0.0.1:0") exp, imp := pair(t)
if err != nil {
t.Fatal("new exporter:", err)
}
imp, err := NewImporter("tcp", exp.Addr().String())
if err != nil {
t.Fatal("new importer:", err)
}
ech := make(chan int) ech := make(chan int)
err = exp.Export("exportedRecv", ech, Recv) err := exp.Export("exportedRecv", ech, Recv)
if err != nil { if err != nil {
t.Fatal("export:", err) t.Fatal("export:", err)
} }
...@@ -343,14 +281,7 @@ func exportLoopback(exp *Exporter, t *testing.T) { ...@@ -343,14 +281,7 @@ func exportLoopback(exp *Exporter, t *testing.T) {
// This test checks that channel operations can proceed // This test checks that channel operations can proceed
// even when other concurrent operations are blocked. // even when other concurrent operations are blocked.
func TestIndependentSends(t *testing.T) { func TestIndependentSends(t *testing.T) {
exp, err := NewExporter("tcp", "127.0.0.1:0") exp, imp := pair(t)
if err != nil {
t.Fatal("new exporter:", err)
}
imp, err := NewImporter("tcp", exp.Addr().String())
if err != nil {
t.Fatal("new importer:", err)
}
exportLoopback(exp, t) exportLoopback(exp, t)
...@@ -377,23 +308,8 @@ type value struct { ...@@ -377,23 +308,8 @@ type value struct {
} }
func TestCrossConnect(t *testing.T) { func TestCrossConnect(t *testing.T) {
e1, err := NewExporter("tcp", "127.0.0.1:0") e1, i1 := pair(t)
if err != nil { e2, i2 := pair(t)
t.Fatal("new exporter:", err)
}
i1, err := NewImporter("tcp", e1.Addr().String())
if err != nil {
t.Fatal("new importer:", err)
}
e2, err := NewExporter("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal("new exporter:", err)
}
i2, err := NewImporter("tcp", e2.Addr().String())
if err != nil {
t.Fatal("new importer:", err)
}
crossExport(e1, e2, t) crossExport(e1, e2, t)
crossImport(i1, i2, t) crossImport(i1, i2, t)
...@@ -452,20 +368,13 @@ const flowCount = 100 ...@@ -452,20 +368,13 @@ const flowCount = 100
// test flow control from exporter to importer. // test flow control from exporter to importer.
func TestExportFlowControl(t *testing.T) { func TestExportFlowControl(t *testing.T) {
exp, err := NewExporter("tcp", "127.0.0.1:0") exp, imp := pair(t)
if err != nil {
t.Fatal("new exporter:", err)
}
imp, err := NewImporter("tcp", exp.Addr().String())
if err != nil {
t.Fatal("new importer:", err)
}
sendDone := make(chan bool, 1) sendDone := make(chan bool, 1)
exportSend(exp, flowCount, t, sendDone) exportSend(exp, flowCount, t, sendDone)
ch := make(chan int) ch := make(chan int)
err = imp.ImportNValues("exportedSend", ch, Recv, 20, -1) err := imp.ImportNValues("exportedSend", ch, Recv, 20, -1)
if err != nil { if err != nil {
t.Fatal("importReceive:", err) t.Fatal("importReceive:", err)
} }
...@@ -475,17 +384,10 @@ func TestExportFlowControl(t *testing.T) { ...@@ -475,17 +384,10 @@ func TestExportFlowControl(t *testing.T) {
// test flow control from importer to exporter. // test flow control from importer to exporter.
func TestImportFlowControl(t *testing.T) { func TestImportFlowControl(t *testing.T) {
exp, err := NewExporter("tcp", "127.0.0.1:0") exp, imp := pair(t)
if err != nil {
t.Fatal("new exporter:", err)
}
imp, err := NewImporter("tcp", exp.Addr().String())
if err != nil {
t.Fatal("new importer:", err)
}
ch := make(chan int) ch := make(chan int)
err = exp.Export("exportedRecv", ch, Recv) err := exp.Export("exportedRecv", ch, Recv)
if err != nil { if err != nil {
t.Fatal("importReceive:", err) t.Fatal("importReceive:", err)
} }
...@@ -513,3 +415,11 @@ func testFlow(sendDone chan bool, ch <-chan int, N int, t *testing.T) { ...@@ -513,3 +415,11 @@ func testFlow(sendDone chan bool, ch <-chan int, N int, t *testing.T) {
t.Fatalf("expected %d values; got %d", N, n) t.Fatalf("expected %d values; got %d", N, n)
} }
} }
func pair(t *testing.T) (*Exporter, *Importer) {
c0, c1 := net.Pipe()
exp := NewExporter()
go exp.ServeConn(c0)
imp := NewImporter(c1)
return exp, imp
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment