Commit d54b921c authored by Rob Pike's avatar Rob Pike

netchan: use acknowledgements on export send.

Also add exporter.Drain() to wait for completion.
This makes it possible for an Exporter to fire off a message
and wait (by calling Drain) for the message to be received,
even if a client has yet to call to retrieve it.

Once this design is settled, I'll do the same for import send.

Testing strategies welcome.  I have some working stand-alone
tests.

R=rsc
CC=golang-dev
https://golang.org/cl/2137041
parent 6405ab0f
...@@ -10,6 +10,7 @@ import ( ...@@ -10,6 +10,7 @@ import (
"os" "os"
"reflect" "reflect"
"sync" "sync"
"time"
) )
// The direction of a connection from the client's perspective. // The direction of a connection from the client's perspective.
...@@ -25,6 +26,7 @@ const ( ...@@ -25,6 +26,7 @@ const (
payRequest = iota // request structure follows payRequest = iota // request structure follows
payError // error structure follows payError // error structure follows
payData // user payload follows payData // user payload follows
payAck // acknowledgement; no payload
) )
// A header is sent as a prefix to every transmission. It will be followed by // A header is sent as a prefix to every transmission. It will be followed by
...@@ -32,13 +34,14 @@ const ( ...@@ -32,13 +34,14 @@ const (
type header struct { type header struct {
name string name string
payloadType int payloadType int
seqNum int64
} }
// Sent with a header once per channel from importer to exporter to report // Sent with a header once per channel from importer to exporter to report
// that it wants to bind to a channel with the specified direction for count // that it wants to bind to a channel with the specified direction for count
// messages. If count is zero, it means unlimited. // messages. If count is zero, it means unlimited.
type request struct { type request struct {
count int count int64
dir Dir dir Dir
} }
...@@ -47,6 +50,27 @@ type error struct { ...@@ -47,6 +50,27 @@ type error struct {
error string error string
} }
// Used to unify management of acknowledgements for import and export.
type unackedCounter interface {
unackedCount() int64
ack() int64
seq() int64
}
// A channel and its direction.
type chanDir struct {
ch *reflect.ChanValue
dir Dir
}
// clientSet contains the objects and methods needed for tracking
// clients of an exporter and draining outstanding messages.
type clientSet struct {
mu sync.Mutex // protects access to channel and client maps
chans map[string]*chanDir
clients map[unackedCounter]bool
}
// Mutex-protected encoder and decoder pair. // Mutex-protected encoder and decoder pair.
type encDec struct { type encDec struct {
decLock sync.Mutex decLock sync.Mutex
...@@ -79,10 +103,78 @@ func (ed *encDec) encode(hdr *header, payloadType int, payload interface{}) os.E ...@@ -79,10 +103,78 @@ func (ed *encDec) encode(hdr *header, payloadType int, payload interface{}) os.E
hdr.payloadType = payloadType hdr.payloadType = payloadType
err := ed.enc.Encode(hdr) err := ed.enc.Encode(hdr)
if err == nil { if err == nil {
err = ed.enc.Encode(payload) if payload != nil {
} else { err = ed.enc.Encode(payload)
}
}
if err != nil {
// TODO: tear down connection if there is an error? // TODO: tear down connection if there is an error?
} }
ed.encLock.Unlock() ed.encLock.Unlock()
return err return err
} }
// See the comment for Exporter.Drain.
func (cs *clientSet) drain(timeout int64) os.Error {
startTime := time.Nanoseconds()
for {
pending := false
cs.mu.Lock()
// Any messages waiting for a client?
for _, chDir := range cs.chans {
if chDir.ch.Len() > 0 {
pending = true
}
}
// Any unacknowledged messages?
for client := range cs.clients {
n := client.unackedCount()
if n > 0 { // Check for > rather than != just to be safe.
pending = true
break
}
}
cs.mu.Unlock()
if !pending {
break
}
if timeout > 0 && time.Nanoseconds()-startTime >= timeout {
return os.ErrorString("timeout")
}
time.Sleep(100 * 1e6) // 100 milliseconds
}
return nil
}
// See the comment for Exporter.Sync.
func (cs *clientSet) sync(timeout int64) os.Error {
startTime := time.Nanoseconds()
// seq remembers the clients and their seqNum at point of entry.
seq := make(map[unackedCounter]int64)
for client := range cs.clients {
seq[client] = client.seq()
}
for {
pending := false
cs.mu.Lock()
// Any unacknowledged messages? Look only at clients that existed
// when we started and are still in this client set.
for client := range seq {
if _, ok := cs.clients[client]; ok {
if client.ack() < seq[client] {
pending = true
break
}
}
}
cs.mu.Unlock()
if !pending {
break
}
if timeout > 0 && time.Nanoseconds()-startTime >= timeout {
return os.ErrorString("timeout")
}
time.Sleep(100 * 1e6) // 100 milliseconds
}
return nil
}
...@@ -31,59 +31,47 @@ import ( ...@@ -31,59 +31,47 @@ import (
// Export // Export
// A channel and its associated information: a direction plus
// a handy marshaling place for its data.
type exportChan struct {
ch *reflect.ChanValue
dir Dir
}
// An Exporter allows a set of channels to be published on a single // An Exporter allows a set of channels to be published on a single
// network port. A single machine may have multiple Exporters // network port. A single machine may have multiple Exporters
// but they must use different ports. // but they must use different ports.
type Exporter struct { type Exporter struct {
*clientSet
listener net.Listener listener net.Listener
chanLock sync.Mutex // protects access to channel map
chans map[string]*exportChan
} }
type expClient struct { type expClient struct {
*encDec *encDec
exp *Exporter exp *Exporter
mu sync.Mutex // protects remaining fields
errored bool // client has been sent an error
seqNum int64 // sequences messages sent to client; has value of highest sent
ackNum int64 // highest sequence number acknowledged
} }
func newClient(exp *Exporter, conn net.Conn) *expClient { func newClient(exp *Exporter, conn net.Conn) *expClient {
client := new(expClient) client := new(expClient)
client.exp = exp client.exp = exp
client.encDec = newEncDec(conn) client.encDec = newEncDec(conn)
client.seqNum = 0
client.ackNum = 0
return client return client
} }
// Wait for incoming connections, start a new runner for each
func (exp *Exporter) listen() {
for {
conn, err := exp.listener.Accept()
if err != nil {
log.Stderr("exporter.listen:", err)
break
}
client := newClient(exp, conn)
go client.run()
}
}
func (client *expClient) sendError(hdr *header, err string) { func (client *expClient) sendError(hdr *header, err string) {
error := &error{err} error := &error{err}
log.Stderr("export:", error.error) log.Stderr("export:", error.error)
client.encode(hdr, payError, error) // ignore any encode error, hope client gets it client.encode(hdr, payError, error) // ignore any encode error, hope client gets it
client.mu.Lock()
client.errored = true
client.mu.Unlock()
} }
func (client *expClient) getChan(hdr *header, dir Dir) *exportChan { func (client *expClient) getChan(hdr *header, dir Dir) *chanDir {
exp := client.exp exp := client.exp
exp.chanLock.Lock() exp.mu.Lock()
ech, ok := exp.chans[hdr.name] ech, ok := exp.chans[hdr.name]
exp.chanLock.Unlock() exp.mu.Unlock()
if !ok { if !ok {
client.sendError(hdr, "no such channel: "+hdr.name) client.sendError(hdr, "no such channel: "+hdr.name)
return nil return nil
...@@ -95,9 +83,10 @@ func (client *expClient) getChan(hdr *header, dir Dir) *exportChan { ...@@ -95,9 +83,10 @@ func (client *expClient) getChan(hdr *header, dir Dir) *exportChan {
return ech return ech
} }
// Manage sends and receives for a single client. For each (client Recv) request, // The function run manages sends and receives for a single client. For each
// this will launch a serveRecv goroutine to deliver the data for that channel, // (client Recv) request, this will launch a serveRecv goroutine to deliver
// while (client Send) requests are handled as data arrives from the client. // the data for that channel, while (client Send) requests are handled as
// data arrives from the client.
func (client *expClient) run() { func (client *expClient) run() {
hdr := new(header) hdr := new(header)
hdrValue := reflect.NewValue(hdr) hdrValue := reflect.NewValue(hdr)
...@@ -107,15 +96,13 @@ func (client *expClient) run() { ...@@ -107,15 +96,13 @@ func (client *expClient) run() {
for { for {
if err := client.decode(hdrValue); err != nil { if err := client.decode(hdrValue); err != nil {
log.Stderr("error decoding client header:", err) log.Stderr("error decoding client header:", err)
// TODO: tear down connection break
return
} }
switch hdr.payloadType { switch hdr.payloadType {
case payRequest: case payRequest:
if err := client.decode(reqValue); err != nil { if err := client.decode(reqValue); err != nil {
log.Stderr("error decoding client request:", err) log.Stderr("error decoding client request:", err)
// TODO: tear down connection break
return
} }
switch req.dir { switch req.dir {
case Recv: case Recv:
...@@ -132,13 +119,27 @@ func (client *expClient) run() { ...@@ -132,13 +119,27 @@ func (client *expClient) run() {
} }
case payData: case payData:
client.serveSend(*hdr) client.serveSend(*hdr)
case payAck:
client.mu.Lock()
if client.ackNum != hdr.seqNum-1 {
// Since the sequence number is incremented and the message is sent
// in a single instance of locking client.mu, the messages are guaranteed
// to be sent in order. Therefore receipt of acknowledgement N means
// all messages <=N have been seen by the recipient. We check anyway.
log.Stderr("netchan export: sequence out of order:", client.ackNum, hdr.seqNum)
}
if client.ackNum < hdr.seqNum { // If there has been an error, don't back up the count.
client.ackNum = hdr.seqNum
}
client.mu.Unlock()
} }
} }
client.exp.delClient(client)
} }
// Send all the data on a single channel to a client asking for a Recv. // Send all the data on a single channel to a client asking for a Recv.
// The header is passed by value to avoid issues of overwriting. // The header is passed by value to avoid issues of overwriting.
func (client *expClient) serveRecv(hdr header, count int) { func (client *expClient) serveRecv(hdr header, count int64) {
ech := client.getChan(&hdr, Send) ech := client.getChan(&hdr, Send)
if ech == nil { if ech == nil {
return return
...@@ -149,7 +150,16 @@ func (client *expClient) serveRecv(hdr header, count int) { ...@@ -149,7 +150,16 @@ func (client *expClient) serveRecv(hdr header, count int) {
client.sendError(&hdr, os.EOF.String()) client.sendError(&hdr, os.EOF.String())
break break
} }
if err := client.encode(&hdr, payData, val.Interface()); err != nil { // We hold the lock during transmission to guarantee messages are
// sent in sequence number order. Also, we increment first so the
// value of client.seqNum is the value of the highest used sequence
// number, not one beyond.
client.mu.Lock()
client.seqNum++
hdr.seqNum = client.seqNum
err := client.encode(&hdr, payData, val.Interface())
client.mu.Unlock()
if err != nil {
log.Stderr("error encoding client response:", err) log.Stderr("error encoding client response:", err)
client.sendError(&hdr, err.String()) client.sendError(&hdr, err.String())
break break
...@@ -180,6 +190,40 @@ func (client *expClient) serveSend(hdr header) { ...@@ -180,6 +190,40 @@ func (client *expClient) serveSend(hdr header) {
// TODO count // TODO count
} }
func (client *expClient) unackedCount() int64 {
client.mu.Lock()
n := client.seqNum - client.ackNum
client.mu.Unlock()
return n
}
func (client *expClient) seq() int64 {
client.mu.Lock()
n := client.seqNum
client.mu.Unlock()
return n
}
func (client *expClient) ack() int64 {
client.mu.Lock()
n := client.seqNum
client.mu.Unlock()
return n
}
// Wait for incoming connections, start a new runner for each
func (exp *Exporter) listen() {
for {
conn, err := exp.listener.Accept()
if err != nil {
log.Stderr("exporter.listen:", err)
break
}
client := exp.addClient(conn)
go client.run()
}
}
// NewExporter creates a new Exporter to export channels // NewExporter creates a new Exporter to export channels
// on the network and local address defined as in net.Listen. // on the network and local address defined as in net.Listen.
func NewExporter(network, localaddr string) (*Exporter, os.Error) { func NewExporter(network, localaddr string) (*Exporter, os.Error) {
...@@ -189,12 +233,52 @@ func NewExporter(network, localaddr string) (*Exporter, os.Error) { ...@@ -189,12 +233,52 @@ func NewExporter(network, localaddr string) (*Exporter, os.Error) {
} }
e := &Exporter{ e := &Exporter{
listener: listener, listener: listener,
chans: make(map[string]*exportChan), clientSet: &clientSet{
chans: make(map[string]*chanDir),
clients: make(map[unackedCounter]bool),
},
} }
go e.listen() go e.listen()
return e, nil return e, nil
} }
// addClient creates a new expClient and records its existence
func (exp *Exporter) addClient(conn net.Conn) *expClient {
client := newClient(exp, conn)
exp.clients[client] = true
exp.mu.Unlock()
return client
}
// delClient forgets the client existed
func (exp *Exporter) delClient(client *expClient) {
exp.mu.Lock()
exp.clients[client] = false, false
exp.mu.Unlock()
}
// Drain waits until all messages sent from this exporter/importer, including
// those not yet sent to any client and possibly including those sent while
// Drain was executing, have been received by the importer. In short, it
// waits until all the exporter's messages have been received by a client.
// If the timeout (measured in nanoseconds) is positive and Drain takes
// longer than that to complete, an error is returned.
func (exp *Exporter) Drain(timeout int64) os.Error {
// This wrapper function is here so the method's comment will appear in godoc.
return exp.clientSet.drain(timeout)
}
// Sync waits until all clients of the exporter have received the messages
// that were sent at the time Sync was invoked. Unlike Drain, it does not
// wait for messages sent while it is running or messages that have not been
// dispatched to any client. If the timeout (measured in nanoseconds) is
// positive and Sync takes longer than that to complete, an error is
// returned.
func (exp *Exporter) Sync(timeout int64) os.Error {
// This wrapper function is here so the method's comment will appear in godoc.
return exp.clientSet.sync(timeout)
}
// Addr returns the Exporter's local network address. // Addr returns the Exporter's local network address.
func (exp *Exporter) Addr() net.Addr { return exp.listener.Addr() } func (exp *Exporter) Addr() net.Addr { return exp.listener.Addr() }
...@@ -230,12 +314,12 @@ func (exp *Exporter) Export(name string, chT interface{}, dir Dir) os.Error { ...@@ -230,12 +314,12 @@ func (exp *Exporter) Export(name string, chT interface{}, dir Dir) os.Error {
if err != nil { if err != nil {
return err return err
} }
exp.chanLock.Lock() exp.mu.Lock()
defer exp.chanLock.Unlock() defer exp.mu.Unlock()
_, present := exp.chans[name] _, present := exp.chans[name]
if present { if present {
return os.ErrorString("channel name already being exported:" + name) return os.ErrorString("channel name already being exported:" + name)
} }
exp.chans[name] = &exportChan{ch, dir} exp.chans[name] = &chanDir{ch, dir}
return nil return nil
} }
...@@ -14,13 +14,6 @@ import ( ...@@ -14,13 +14,6 @@ import (
// Import // Import
// A channel and its associated information: a template value and direction,
// plus a handy marshaling place for its data.
type importChan struct {
ch *reflect.ChanValue
dir Dir
}
// An Importer allows a set of channels to be imported from a single // An Importer allows a set of channels to be imported from a single
// remote machine/network port. A machine may have multiple // remote machine/network port. A machine may have multiple
// importers, even from the same machine/network port. // importers, even from the same machine/network port.
...@@ -28,7 +21,7 @@ type Importer struct { ...@@ -28,7 +21,7 @@ type Importer struct {
*encDec *encDec
conn net.Conn conn net.Conn
chanLock sync.Mutex // protects access to channel map chanLock sync.Mutex // protects access to channel map
chans map[string]*importChan chans map[string]*chanDir
} }
// NewImporter creates a new Importer object to import channels // NewImporter creates a new Importer object to import channels
...@@ -43,7 +36,7 @@ func NewImporter(network, remoteaddr string) (*Importer, os.Error) { ...@@ -43,7 +36,7 @@ func NewImporter(network, remoteaddr string) (*Importer, os.Error) {
imp := new(Importer) imp := new(Importer)
imp.encDec = newEncDec(conn) imp.encDec = newEncDec(conn)
imp.conn = conn imp.conn = conn
imp.chans = make(map[string]*importChan) imp.chans = make(map[string]*chanDir)
go imp.run() go imp.run()
return imp, nil return imp, nil
} }
...@@ -67,6 +60,7 @@ func (imp *Importer) run() { ...@@ -67,6 +60,7 @@ func (imp *Importer) run() {
// Loop on responses; requests are sent by ImportNValues() // Loop on responses; requests are sent by ImportNValues()
hdr := new(header) hdr := new(header)
hdrValue := reflect.NewValue(hdr) hdrValue := reflect.NewValue(hdr)
ackHdr := new(header)
err := new(error) err := new(error)
errValue := reflect.NewValue(err) errValue := reflect.NewValue(err)
for { for {
...@@ -103,6 +97,10 @@ func (imp *Importer) run() { ...@@ -103,6 +97,10 @@ func (imp *Importer) run() {
log.Stderr("cannot happen: receive from non-Recv channel") log.Stderr("cannot happen: receive from non-Recv channel")
return return
} }
// Acknowledge receipt
ackHdr.name = hdr.name
ackHdr.seqNum = hdr.seqNum
imp.encode(ackHdr, payAck, nil)
// Create a new value for each received item. // Create a new value for each received item.
value := reflect.MakeZero(ich.ch.Type().(*reflect.ChanType).Elem()) value := reflect.MakeZero(ich.ch.Type().(*reflect.ChanType).Elem())
if e := imp.decode(value); e != nil { if e := imp.decode(value); e != nil {
...@@ -144,14 +142,10 @@ func (imp *Importer) ImportNValues(name string, chT interface{}, dir Dir, n int) ...@@ -144,14 +142,10 @@ func (imp *Importer) ImportNValues(name string, chT interface{}, dir Dir, n int)
if present { if present {
return os.ErrorString("channel name already being imported:" + name) return os.ErrorString("channel name already being imported:" + name)
} }
imp.chans[name] = &importChan{ch, dir} imp.chans[name] = &chanDir{ch, dir}
// Tell the other side about this channel. // Tell the other side about this channel.
hdr := new(header) hdr := &header{name: name, payloadType: payRequest}
hdr.name = name req := &request{count: int64(n), dir: dir}
hdr.payloadType = payRequest
req := new(request)
req.dir = dir
req.count = n
if err := imp.encode(hdr, payRequest, req); err != nil { if err := imp.encode(hdr, payRequest, req); err != nil {
log.Stderr("importer request encode:", err) log.Stderr("importer request encode:", err)
return err return err
......
...@@ -37,7 +37,7 @@ func exportReceive(exp *Exporter, t *testing.T) { ...@@ -37,7 +37,7 @@ func exportReceive(exp *Exporter, t *testing.T) {
} }
} }
func importReceive(imp *Importer, t *testing.T) { func importReceive(imp *Importer, t *testing.T, done chan bool) {
ch := make(chan int) ch := make(chan int)
err := imp.ImportNValues("exportedSend", ch, Recv, count) err := imp.ImportNValues("exportedSend", ch, Recv, count)
if err != nil { if err != nil {
...@@ -55,6 +55,9 @@ func importReceive(imp *Importer, t *testing.T) { ...@@ -55,6 +55,9 @@ func importReceive(imp *Importer, t *testing.T) {
t.Errorf("importReceive: bad value: expected %d; got %+d", 23+i, v) t.Errorf("importReceive: bad value: expected %d; got %+d", 23+i, v)
} }
} }
if done != nil {
done <- true
}
} }
func importSend(imp *Importer, t *testing.T) { func importSend(imp *Importer, t *testing.T) {
...@@ -78,7 +81,7 @@ func TestExportSendImportReceive(t *testing.T) { ...@@ -78,7 +81,7 @@ func TestExportSendImportReceive(t *testing.T) {
t.Fatal("new importer:", err) t.Fatal("new importer:", err)
} }
exportSend(exp, count, t) exportSend(exp, count, t)
importReceive(imp, t) importReceive(imp, t, nil)
} }
func TestExportReceiveImportSend(t *testing.T) { func TestExportReceiveImportSend(t *testing.T) {
...@@ -104,5 +107,39 @@ func TestClosingExportSendImportReceive(t *testing.T) { ...@@ -104,5 +107,39 @@ func TestClosingExportSendImportReceive(t *testing.T) {
t.Fatal("new importer:", err) t.Fatal("new importer:", err)
} }
exportSend(exp, closeCount, t) exportSend(exp, closeCount, t)
importReceive(imp, t) importReceive(imp, t, nil)
}
// Not a great test but it does at least invoke Drain.
func TestExportDrain(t *testing.T) {
exp, err := NewExporter("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal("new exporter:", err)
}
imp, err := NewImporter("tcp", exp.Addr().String())
if err != nil {
t.Fatal("new importer:", err)
}
done := make(chan bool)
go exportSend(exp, closeCount, t)
go importReceive(imp, t, done)
exp.Drain(0)
<-done
}
// Not a great test but it does at least invoke Sync.
func TestExportSync(t *testing.T) {
exp, err := NewExporter("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal("new exporter:", err)
}
imp, err := NewImporter("tcp", exp.Addr().String())
if err != nil {
t.Fatal("new importer:", err)
}
done := make(chan bool)
go importReceive(imp, t, done)
exportSend(exp, closeCount, t)
exp.Sync(0)
<-done
} }
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment