Commit 8071cdf7 authored by Rob Pike's avatar Rob Pike

handle errors better:

	1) terminate outstanding calls on the client when we see EOF from server
	2) allow data to drain on server before closing the connection

R=rsc
DELTA=41  (23 added, 4 deleted, 14 changed)
OCL=31687
CL=31689
parent aa1e8064
...@@ -30,7 +30,7 @@ type Call struct { ...@@ -30,7 +30,7 @@ type Call struct {
// Client represents an RPC Client. // Client represents an RPC Client.
type Client struct { type Client struct {
sync.Mutex; // protects pending, seq sync.Mutex; // protects pending, seq
closed bool; shutdown os.Error; // non-nil if the client is shut down
sending sync.Mutex; sending sync.Mutex;
seq uint64; seq uint64;
conn io.ReadWriteCloser; conn io.ReadWriteCloser;
...@@ -42,6 +42,12 @@ type Client struct { ...@@ -42,6 +42,12 @@ type Client struct {
func (client *Client) send(c *Call) { func (client *Client) send(c *Call) {
// Register this call. // Register this call.
client.Lock(); client.Lock();
if client.shutdown != nil {
client.Unlock();
c.Error = client.shutdown;
doNotBlock := c.Done <- c;
return;
}
c.seq = client.seq; c.seq = client.seq;
client.seq++; client.seq++;
client.pending[c.seq] = c; client.pending[c.seq] = c;
...@@ -66,10 +72,7 @@ func (client *Client) serve() { ...@@ -66,10 +72,7 @@ func (client *Client) serve() {
response := new(Response); response := new(Response);
err = client.dec.Decode(response); err = client.dec.Decode(response);
if err != nil { if err != nil {
if err == os.EOF { break
break;
}
break;
} }
seq := response.Seq; seq := response.Seq;
client.Lock(); client.Lock();
...@@ -82,7 +85,14 @@ func (client *Client) serve() { ...@@ -82,7 +85,14 @@ func (client *Client) serve() {
// sure the channel has enough buffer space. See comment in Go(). // sure the channel has enough buffer space. See comment in Go().
doNotBlock := c.Done <- c; doNotBlock := c.Done <- c;
} }
client.closed = true; // Terminate pending calls.
client.Lock();
client.shutdown = err;
for seq, call := range client.pending {
call.Error = err;
doNotBlock := call.Done <- call;
}
client.Unlock();
log.Stderr("client protocol error:", err); log.Stderr("client protocol error:", err);
} }
...@@ -144,8 +154,8 @@ func (client *Client) Go(serviceMethod string, args interface{}, reply interface ...@@ -144,8 +154,8 @@ func (client *Client) Go(serviceMethod string, args interface{}, reply interface
// RPCs that will be using that channel. // RPCs that will be using that channel.
} }
c.Done = done; c.Done = done;
if client.closed { if client.shutdown != nil {
c.Error = os.EOF; c.Error = client.shutdown;
doNotBlock := c.Done <- c; doNotBlock := c.Done <- c;
return c; return c;
} }
...@@ -155,8 +165,8 @@ func (client *Client) Go(serviceMethod string, args interface{}, reply interface ...@@ -155,8 +165,8 @@ func (client *Client) Go(serviceMethod string, args interface{}, reply interface
// Call invokes the named function, waits for it to complete, and returns its error status. // Call invokes the named function, waits for it to complete, and returns its error status.
func (client *Client) Call(serviceMethod string, args interface{}, reply interface{}) os.Error { func (client *Client) Call(serviceMethod string, args interface{}, reply interface{}) os.Error {
if client.closed { if client.shutdown != nil {
return os.EOF return client.shutdown
} }
call := <-client.Go(serviceMethod, args, reply, nil).Done; call := <-client.Go(serviceMethod, args, reply, nil).Done;
return call.Error; return call.Error;
......
...@@ -14,6 +14,7 @@ import ( ...@@ -14,6 +14,7 @@ import (
"reflect"; "reflect";
"strings"; "strings";
"sync"; "sync";
"time"; // See TODO in serve()
"unicode"; "unicode";
"utf8"; "utf8";
) )
...@@ -148,13 +149,13 @@ func _new(t *reflect.PtrType) *reflect.PtrValue { ...@@ -148,13 +149,13 @@ func _new(t *reflect.PtrType) *reflect.PtrValue {
return v; return v;
} }
func (s *service) sendResponse(sending *sync.Mutex, req *Request, reply interface{}, enc *gob.Encoder, errmsg string) { func sendResponse(sending *sync.Mutex, req *Request, reply interface{}, enc *gob.Encoder, errmsg string) {
resp := new(Response); resp := new(Response);
// Encode the response header // Encode the response header
sending.Lock();
resp.ServiceMethod = req.ServiceMethod; resp.ServiceMethod = req.ServiceMethod;
resp.Error = errmsg; resp.Error = errmsg;
resp.Seq = req.Seq; resp.Seq = req.Seq;
sending.Lock();
enc.Encode(resp); enc.Encode(resp);
// Encode the reply value. // Encode the reply value.
enc.Encode(reply); enc.Encode(reply);
...@@ -170,7 +171,7 @@ func (s *service) call(sending *sync.Mutex, function *reflect.FuncValue, req *Re ...@@ -170,7 +171,7 @@ func (s *service) call(sending *sync.Mutex, function *reflect.FuncValue, req *Re
if errInter != nil { if errInter != nil {
errmsg = errInter.(os.Error).String(); errmsg = errInter.(os.Error).String();
} }
s.sendResponse(sending, req, replyv.Interface(), enc, errmsg); sendResponse(sending, req, replyv.Interface(), enc, errmsg);
} }
func (server *serverType) serve(conn io.ReadWriteCloser) { func (server *serverType) serve(conn io.ReadWriteCloser) {
...@@ -182,25 +183,27 @@ func (server *serverType) serve(conn io.ReadWriteCloser) { ...@@ -182,25 +183,27 @@ func (server *serverType) serve(conn io.ReadWriteCloser) {
req := new(Request); req := new(Request);
err := dec.Decode(req); err := dec.Decode(req);
if err != nil { if err != nil {
log.Stderr("rpc: server cannot decode request:", err); s := "rpc: server cannot decode request: " + err.String();
sendResponse(sending, req, invalidRequest, enc, s);
break; break;
} }
serviceMethod := strings.Split(req.ServiceMethod, ".", 0); serviceMethod := strings.Split(req.ServiceMethod, ".", 0);
if len(serviceMethod) != 2 { if len(serviceMethod) != 2 {
log.Stderr("rpc: service/Method request ill-formed:", req.ServiceMethod); s := "rpc: service/method request ill:formed: " + req.ServiceMethod;
sendResponse(sending, req, invalidRequest, enc, s);
break; break;
} }
// Look up the request. // Look up the request.
service, ok := server.serviceMap[serviceMethod[0]]; service, ok := server.serviceMap[serviceMethod[0]];
if !ok { if !ok {
s := "rpc: can't find service " + req.ServiceMethod; s := "rpc: can't find service " + req.ServiceMethod;
service.sendResponse(sending, req, invalidRequest, enc, s); sendResponse(sending, req, invalidRequest, enc, s);
break; break;
} }
mtype, ok := service.method[serviceMethod[1]]; mtype, ok := service.method[serviceMethod[1]];
if !ok { if !ok {
s := "rpc: can't find method " + req.ServiceMethod; s := "rpc: can't find method " + req.ServiceMethod;
service.sendResponse(sending, req, invalidRequest, enc, s); sendResponse(sending, req, invalidRequest, enc, s);
break; break;
} }
method := mtype.method; method := mtype.method;
...@@ -210,11 +213,17 @@ func (server *serverType) serve(conn io.ReadWriteCloser) { ...@@ -210,11 +213,17 @@ func (server *serverType) serve(conn io.ReadWriteCloser) {
err = dec.Decode(argv.Interface()); err = dec.Decode(argv.Interface());
if err != nil { if err != nil {
log.Stderr("tearing down connection:", err); log.Stderr("tearing down connection:", err);
service.sendResponse(sending, req, replyv.Interface(), enc, err.String()); sendResponse(sending, req, replyv.Interface(), enc, err.String());
break; break;
} }
go service.call(sending, method.Func, req, argv, replyv, enc); go service.call(sending, method.Func, req, argv, replyv, enc);
} }
// TODO(r): Gobs cannot handle unexpected data yet. Once they can, we can
// ignore it and the connection can persist. For now, though, bad data
// ruins the connection and we must shut down. The sleep is necessary to
// guarantee all the data gets out before we close the connection, so the
// client can see the error description.
time.Sleep(2e9);
conn.Close(); conn.Close();
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment