Commit 752b47cf authored by Rob Pike's avatar Rob Pike

netchan: improve closing and shutdown. there's still more to do.

Fixes #805.

R=rsc
CC=golang-dev
https://golang.org/cl/1400041
parent 63f01491
...@@ -15,11 +15,12 @@ ...@@ -15,11 +15,12 @@
use the channels in the usual way. use the channels in the usual way.
Networked channels are not synchronized; they always behave Networked channels are not synchronized; they always behave
as if there is a buffer of at least one element between the as if they are buffered channels of at least one element.
two machines.
*/ */
package netchan package netchan
// BUG: can't use range clause to receive when using ImportNValues with N non-zero.
import ( import (
"log" "log"
"net" "net"
...@@ -143,6 +144,10 @@ func (client *expClient) serveRecv(hdr header, count int) { ...@@ -143,6 +144,10 @@ func (client *expClient) serveRecv(hdr header, count int) {
} }
for { for {
val := ech.ch.Recv() val := ech.ch.Recv()
if ech.ch.Closed() {
client.sendError(&hdr, os.EOF.String())
break
}
if err := client.encode(&hdr, payData, val.Interface()); err != nil { if err := client.encode(&hdr, payData, val.Interface()); err != nil {
log.Stderr("error encoding client response:", err) log.Stderr("error encoding client response:", err)
client.sendError(&hdr, err.String()) client.sendError(&hdr, err.String())
......
...@@ -49,6 +49,17 @@ func NewImporter(network, remoteaddr string) (*Importer, os.Error) { ...@@ -49,6 +49,17 @@ func NewImporter(network, remoteaddr string) (*Importer, os.Error) {
return imp, nil return imp, nil
} }
// shutdown closes all channels for which we are receiving data from the remote side.
func (imp *Importer) shutdown() {
imp.chanLock.Lock()
for _, ich := range imp.chans {
if ich.dir == Recv {
ich.ch.Close()
}
}
imp.chanLock.Unlock()
}
// Handle the data from a single imported data stream, which will // Handle the data from a single imported data stream, which will
// have the form // have the form
// (response, data)* // (response, data)*
...@@ -60,6 +71,7 @@ func (imp *Importer) run() { ...@@ -60,6 +71,7 @@ func (imp *Importer) run() {
for { for {
if e := imp.decode(hdr); e != nil { if e := imp.decode(hdr); e != nil {
log.Stderr("importer header:", e) log.Stderr("importer header:", e)
imp.shutdown()
return return
} }
switch hdr.payloadType { switch hdr.payloadType {
...@@ -72,7 +84,7 @@ func (imp *Importer) run() { ...@@ -72,7 +84,7 @@ func (imp *Importer) run() {
} }
if err.error != "" { if err.error != "" {
log.Stderr("importer response error:", err.error) log.Stderr("importer response error:", err.error)
// TODO: tear down connection imp.shutdown()
return return
} }
default: default:
......
...@@ -11,17 +11,19 @@ type value struct { ...@@ -11,17 +11,19 @@ type value struct {
s string s string
} }
const count = 10 const count = 10 // number of items in most tests
const closeCount = 5 // number of items when sender closes early
func exportSend(exp *Exporter, t *testing.T) { func exportSend(exp *Exporter, n int, t *testing.T) {
ch := make(chan value) ch := make(chan value)
err := exp.Export("exportedSend", ch, Send, new(value)) err := exp.Export("exportedSend", ch, Send, new(value))
if err != nil { if err != nil {
t.Fatal("exportSend:", err) t.Fatal("exportSend:", err)
} }
for i := 0; i < count; i++ { for i := 0; i < n; i++ {
ch <- value{23 + i, "hello"} ch <- value{23 + i, "hello"}
} }
close(ch)
} }
func exportReceive(exp *Exporter, t *testing.T) { func exportReceive(exp *Exporter, t *testing.T) {
...@@ -46,6 +48,12 @@ func importReceive(imp *Importer, t *testing.T) { ...@@ -46,6 +48,12 @@ func importReceive(imp *Importer, t *testing.T) {
} }
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
v := <-ch v := <-ch
if closed(ch) {
if i != closeCount {
t.Errorf("expected close at %d; got one at %d\n", count/2, i)
}
break
}
if v.i != 23+i || v.s != "hello" { if v.i != 23+i || v.s != "hello" {
t.Errorf("importReceive: bad value: expected %d, hello; got %+v", 23+i, v) t.Errorf("importReceive: bad value: expected %d, hello; got %+v", 23+i, v)
} }
...@@ -72,7 +80,7 @@ func TestExportSendImportReceive(t *testing.T) { ...@@ -72,7 +80,7 @@ func TestExportSendImportReceive(t *testing.T) {
if err != nil { if err != nil {
t.Fatal("new importer:", err) t.Fatal("new importer:", err)
} }
go exportSend(exp, t) go exportSend(exp, count, t)
importReceive(imp, t) importReceive(imp, t)
} }
...@@ -88,3 +96,16 @@ func TestExportReceiveImportSend(t *testing.T) { ...@@ -88,3 +96,16 @@ func TestExportReceiveImportSend(t *testing.T) {
go importSend(imp, t) go importSend(imp, t)
exportReceive(exp, t) exportReceive(exp, t)
} }
func TestClosingExportSendImportReceive(t *testing.T) {
exp, err := NewExporter("tcp", ":0")
if err != nil {
t.Fatal("new exporter:", err)
}
imp, err := NewImporter("tcp", exp.Addr().String())
if err != nil {
t.Fatal("new importer:", err)
}
go exportSend(exp, closeCount, t)
importReceive(imp, t)
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment