Commit 649060ab authored by Jacob Vosmaer

Downgrade grpc to get Go 1.5 compatibility

parent e1e34707
# gRPC-Go
[![Build Status](https://travis-ci.org/grpc/grpc-go.svg)](https://travis-ci.org/grpc/grpc-go) [![GoDoc](https://godoc.org/google.golang.org/grpc?status.svg)](https://godoc.org/google.golang.org/grpc)
@@ -16,7 +16,23 @@ $ go get google.golang.org/grpc
Prerequisites
-------------
-This requires Go 1.6 or later.
+This requires Go 1.5 or later.
+A note on the version used: significant performance improvements in benchmarks
+of grpc-go have been seen by upgrading the go version from 1.5 to the latest
+1.7.1.
+From https://golang.org/doc/install, one way to install the latest version of go is:
+```
+$ GO_VERSION=1.7.1
+$ OS=linux
+$ ARCH=amd64
+$ curl -O https://storage.googleapis.com/golang/go${GO_VERSION}.${OS}-${ARCH}.tar.gz
+$ sudo tar -C /usr/local -xzf go$GO_VERSION.$OS-$ARCH.tar.gz
+$ # Put go on the PATH, keep the usual installation dir
+$ sudo ln -s /usr/local/go/bin/go /usr/bin/go
+$ rm go$GO_VERSION.$OS-$ARCH.tar.gz
+```
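For reference, a quick way to confirm the toolchain after the steps above — assuming the /usr/local/go install location and the /usr/bin/go symlink shown in that snippet — is:
```
$ go version                 # expect go1.5 or newer (go1.7.1 if installed as above)
$ readlink /usr/bin/go       # /usr/local/go/bin/go if the symlink step was used
$ go get google.golang.org/grpc
```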
Constraints
-----------
......
@@ -36,14 +36,13 @@ package grpc
import (
"bytes"
"io"
+"math"
"time"
"golang.org/x/net/context"
"golang.org/x/net/trace"
"google.golang.org/grpc/codes"
-"google.golang.org/grpc/peer"
"google.golang.org/grpc/stats"
-"google.golang.org/grpc/status"
"google.golang.org/grpc/transport"
)
@@ -73,22 +72,19 @@ func recvResponse(ctx context.Context, dopts dialOptions, t transport.ClientTran
}
}
for {
-if err = recv(p, dopts.codec, stream, dopts.dc, reply, dopts.maxMsgSize, inPayload); err != nil {
+if err = recv(p, dopts.codec, stream, dopts.dc, reply, math.MaxInt32, inPayload); err != nil {
if err == io.EOF {
break
}
return
}
}
-if inPayload != nil && err == io.EOF && stream.Status().Code() == codes.OK {
+if inPayload != nil && err == io.EOF && stream.StatusCode() == codes.OK {
// TODO in the current implementation, inTrailer may be handled before inPayload in some cases.
// Fix the order if necessary.
dopts.copts.StatsHandler.HandleRPC(ctx, inPayload)
}
c.trailerMD = stream.Trailer()
-if peer, ok := peer.FromContext(stream.Context()); ok {
-c.peer = peer
-}
return nil
}
@@ -231,7 +227,7 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
t, put, err = cc.getTransport(ctx, gopts)
if err != nil {
// TODO(zhaoq): Probably revisit the error handling.
-if _, ok := status.FromError(err); ok {
+if _, ok := err.(*rpcError); ok {
return err
}
if err == errConnClosing || err == errConnUnavailable {
@@ -285,6 +281,6 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
put()
put = nil
}
-return stream.Status().Err()
+return Errorf(stream.StatusCode(), "%s", stream.StatusDesc())
}
}
...@@ -36,8 +36,8 @@ package grpc ...@@ -36,8 +36,8 @@ package grpc
import ( import (
"errors" "errors"
"fmt" "fmt"
"math"
"net" "net"
"strings"
"sync" "sync"
"time" "time"
...@@ -45,7 +45,6 @@ import ( ...@@ -45,7 +45,6 @@ import (
"golang.org/x/net/trace" "golang.org/x/net/trace"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/grpclog" "google.golang.org/grpc/grpclog"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/stats" "google.golang.org/grpc/stats"
"google.golang.org/grpc/transport" "google.golang.org/grpc/transport"
) )
...@@ -79,6 +78,7 @@ var ( ...@@ -79,6 +78,7 @@ var (
errConnClosing = errors.New("grpc: the connection is closing") errConnClosing = errors.New("grpc: the connection is closing")
// errConnUnavailable indicates that the connection is unavailable. // errConnUnavailable indicates that the connection is unavailable.
errConnUnavailable = errors.New("grpc: the connection is unavailable") errConnUnavailable = errors.New("grpc: the connection is unavailable")
errNoAddr = errors.New("grpc: there is no address available to dial")
// minimum time to give a connection to complete // minimum time to give a connection to complete
minConnectTimeout = 20 * time.Second minConnectTimeout = 20 * time.Second
) )
...@@ -86,33 +86,23 @@ var ( ...@@ -86,33 +86,23 @@ var (
// dialOptions configure a Dial call. dialOptions are set by the DialOption // dialOptions configure a Dial call. dialOptions are set by the DialOption
// values passed to Dial. // values passed to Dial.
type dialOptions struct { type dialOptions struct {
unaryInt UnaryClientInterceptor unaryInt UnaryClientInterceptor
streamInt StreamClientInterceptor streamInt StreamClientInterceptor
codec Codec codec Codec
cp Compressor cp Compressor
dc Decompressor dc Decompressor
bs backoffStrategy bs backoffStrategy
balancer Balancer balancer Balancer
block bool block bool
insecure bool insecure bool
timeout time.Duration timeout time.Duration
scChan <-chan ServiceConfig scChan <-chan ServiceConfig
copts transport.ConnectOptions copts transport.ConnectOptions
maxMsgSize int }
}
const defaultClientMaxMsgSize = math.MaxInt32
// DialOption configures how we set up the connection. // DialOption configures how we set up the connection.
type DialOption func(*dialOptions) type DialOption func(*dialOptions)
// WithMaxMsgSize returns a DialOption which sets the maximum message size the client can receive.
func WithMaxMsgSize(s int) DialOption {
return func(o *dialOptions) {
o.maxMsgSize = s
}
}
// WithCodec returns a DialOption which sets a codec for message marshaling and unmarshaling. // WithCodec returns a DialOption which sets a codec for message marshaling and unmarshaling.
func WithCodec(c Codec) DialOption { func WithCodec(c Codec) DialOption {
return func(o *dialOptions) { return func(o *dialOptions) {
...@@ -259,13 +249,6 @@ func WithUserAgent(s string) DialOption { ...@@ -259,13 +249,6 @@ func WithUserAgent(s string) DialOption {
} }
} }
// WithKeepaliveParams returns a DialOption that specifies keepalive paramaters for the client transport.
func WithKeepaliveParams(kp keepalive.ClientParameters) DialOption {
return func(o *dialOptions) {
o.copts.KeepaliveParams = kp
}
}
// WithUnaryInterceptor returns a DialOption that specifies the interceptor for unary RPCs. // WithUnaryInterceptor returns a DialOption that specifies the interceptor for unary RPCs.
func WithUnaryInterceptor(f UnaryClientInterceptor) DialOption { func WithUnaryInterceptor(f UnaryClientInterceptor) DialOption {
return func(o *dialOptions) { return func(o *dialOptions) {
...@@ -280,15 +263,6 @@ func WithStreamInterceptor(f StreamClientInterceptor) DialOption { ...@@ -280,15 +263,6 @@ func WithStreamInterceptor(f StreamClientInterceptor) DialOption {
} }
} }
// WithAuthority returns a DialOption that specifies the value to be used as
// the :authority pseudo-header. This value only works with WithInsecure and
// has no effect if TransportCredentials are present.
func WithAuthority(a string) DialOption {
return func(o *dialOptions) {
o.copts.Authority = a
}
}
// Dial creates a client connection to the given target. // Dial creates a client connection to the given target.
func Dial(target string, opts ...DialOption) (*ClientConn, error) { func Dial(target string, opts ...DialOption) (*ClientConn, error) {
return DialContext(context.Background(), target, opts...) return DialContext(context.Background(), target, opts...)
...@@ -305,19 +279,9 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * ...@@ -305,19 +279,9 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
conns: make(map[Address]*addrConn), conns: make(map[Address]*addrConn),
} }
cc.ctx, cc.cancel = context.WithCancel(context.Background()) cc.ctx, cc.cancel = context.WithCancel(context.Background())
cc.dopts.maxMsgSize = defaultClientMaxMsgSize
for _, opt := range opts { for _, opt := range opts {
opt(&cc.dopts) opt(&cc.dopts)
} }
cc.mkp = cc.dopts.copts.KeepaliveParams
grpcUA := "grpc-go/" + Version
if cc.dopts.copts.UserAgent != "" {
cc.dopts.copts.UserAgent += " " + grpcUA
} else {
cc.dopts.copts.UserAgent = grpcUA
}
if cc.dopts.timeout > 0 { if cc.dopts.timeout > 0 {
var cancel context.CancelFunc var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, cc.dopts.timeout) ctx, cancel = context.WithTimeout(ctx, cc.dopts.timeout)
...@@ -357,18 +321,24 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * ...@@ -357,18 +321,24 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
creds := cc.dopts.copts.TransportCredentials creds := cc.dopts.copts.TransportCredentials
if creds != nil && creds.Info().ServerName != "" { if creds != nil && creds.Info().ServerName != "" {
cc.authority = creds.Info().ServerName cc.authority = creds.Info().ServerName
} else if cc.dopts.insecure && cc.dopts.copts.Authority != "" {
cc.authority = cc.dopts.copts.Authority
} else { } else {
cc.authority = target colonPos := strings.LastIndex(target, ":")
if colonPos == -1 {
colonPos = len(target)
}
cc.authority = target[:colonPos]
} }
var ok bool
waitC := make(chan error, 1) waitC := make(chan error, 1)
go func() { go func() {
defer close(waitC) var addrs []Address
if cc.dopts.balancer == nil && cc.sc.LB != nil { if cc.dopts.balancer == nil && cc.sc.LB != nil {
cc.dopts.balancer = cc.sc.LB cc.dopts.balancer = cc.sc.LB
} }
if cc.dopts.balancer != nil { if cc.dopts.balancer == nil {
// Connect to target directly if balancer is nil.
addrs = append(addrs, Address{Addr: target})
} else {
var credsClone credentials.TransportCredentials var credsClone credentials.TransportCredentials
if creds != nil { if creds != nil {
credsClone = creds.Clone() credsClone = creds.Clone()
...@@ -381,22 +351,24 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * ...@@ -381,22 +351,24 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
return return
} }
ch := cc.dopts.balancer.Notify() ch := cc.dopts.balancer.Notify()
if ch != nil { if ch == nil {
if cc.dopts.block { // There is no name resolver installed.
doneChan := make(chan struct{}) addrs = append(addrs, Address{Addr: target})
go cc.lbWatcher(doneChan) } else {
<-doneChan addrs, ok = <-ch
} else { if !ok || len(addrs) == 0 {
go cc.lbWatcher(nil) waitC <- errNoAddr
return
} }
return
} }
} }
// No balancer, or no resolver within the balancer. Connect directly. for _, a := range addrs {
if err := cc.resetAddrConn(Address{Addr: target}, cc.dopts.block, nil); err != nil { if err := cc.resetAddrConn(a, false, nil); err != nil {
waitC <- err waitC <- err
return return
}
} }
close(waitC)
}() }()
select { select {
case <-ctx.Done(): case <-ctx.Done():
...@@ -407,10 +379,15 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * ...@@ -407,10 +379,15 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
} }
} }
// If balancer is nil or balancer.Notify() is nil, ok will be false here.
// The lbWatcher goroutine will not be created.
if ok {
go cc.lbWatcher()
}
if cc.dopts.scChan != nil { if cc.dopts.scChan != nil {
go cc.scWatcher() go cc.scWatcher()
} }
return cc, nil return cc, nil
} }
...@@ -459,14 +436,9 @@ type ClientConn struct { ...@@ -459,14 +436,9 @@ type ClientConn struct {
mu sync.RWMutex mu sync.RWMutex
sc ServiceConfig sc ServiceConfig
conns map[Address]*addrConn conns map[Address]*addrConn
// Keepalive parameter can be udated if a GoAway is received.
mkp keepalive.ClientParameters
} }
// lbWatcher watches the Notify channel of the balancer in cc and manages func (cc *ClientConn) lbWatcher() {
// connections accordingly. If doneChan is not nil, it is closed after the
// first successfull connection is made.
func (cc *ClientConn) lbWatcher(doneChan chan struct{}) {
for addrs := range cc.dopts.balancer.Notify() { for addrs := range cc.dopts.balancer.Notify() {
var ( var (
add []Address // Addresses need to setup connections. add []Address // Addresses need to setup connections.
...@@ -493,15 +465,7 @@ func (cc *ClientConn) lbWatcher(doneChan chan struct{}) { ...@@ -493,15 +465,7 @@ func (cc *ClientConn) lbWatcher(doneChan chan struct{}) {
} }
cc.mu.Unlock() cc.mu.Unlock()
for _, a := range add { for _, a := range add {
if doneChan != nil { cc.resetAddrConn(a, true, nil)
err := cc.resetAddrConn(a, true, nil)
if err == nil {
close(doneChan)
doneChan = nil
}
} else {
cc.resetAddrConn(a, false, nil)
}
} }
for _, c := range del { for _, c := range del {
c.tearDown(errConnDrain) c.tearDown(errConnDrain)
...@@ -530,15 +494,12 @@ func (cc *ClientConn) scWatcher() { ...@@ -530,15 +494,12 @@ func (cc *ClientConn) scWatcher() {
// resetAddrConn creates an addrConn for addr and adds it to cc.conns. // resetAddrConn creates an addrConn for addr and adds it to cc.conns.
// If there is an old addrConn for addr, it will be torn down, using tearDownErr as the reason. // If there is an old addrConn for addr, it will be torn down, using tearDownErr as the reason.
// If tearDownErr is nil, errConnDrain will be used instead. // If tearDownErr is nil, errConnDrain will be used instead.
func (cc *ClientConn) resetAddrConn(addr Address, block bool, tearDownErr error) error { func (cc *ClientConn) resetAddrConn(addr Address, skipWait bool, tearDownErr error) error {
ac := &addrConn{ ac := &addrConn{
cc: cc, cc: cc,
addr: addr, addr: addr,
dopts: cc.dopts, dopts: cc.dopts,
} }
cc.mu.RLock()
ac.dopts.copts.KeepaliveParams = cc.mkp
cc.mu.RUnlock()
ac.ctx, ac.cancel = context.WithCancel(cc.ctx) ac.ctx, ac.cancel = context.WithCancel(cc.ctx)
ac.stateCV = sync.NewCond(&ac.mu) ac.stateCV = sync.NewCond(&ac.mu)
if EnableTracing { if EnableTracing {
...@@ -583,7 +544,8 @@ func (cc *ClientConn) resetAddrConn(addr Address, block bool, tearDownErr error) ...@@ -583,7 +544,8 @@ func (cc *ClientConn) resetAddrConn(addr Address, block bool, tearDownErr error)
stale.tearDown(tearDownErr) stale.tearDown(tearDownErr)
} }
} }
if block { // skipWait may overwrite the decision in ac.dopts.block.
if ac.dopts.block && !skipWait {
if err := ac.resetTransport(false); err != nil { if err := ac.resetTransport(false); err != nil {
if err != errConnClosing { if err != errConnClosing {
// Tear down ac and delete it from cc.conns. // Tear down ac and delete it from cc.conns.
...@@ -720,20 +682,6 @@ type addrConn struct { ...@@ -720,20 +682,6 @@ type addrConn struct {
tearDownErr error tearDownErr error
} }
// adjustParams updates parameters used to create transports upon
// receiving a GoAway.
func (ac *addrConn) adjustParams(r transport.GoAwayReason) {
switch r {
case transport.TooManyPings:
v := 2 * ac.dopts.copts.KeepaliveParams.Time
ac.cc.mu.Lock()
if v > ac.cc.mkp.Time {
ac.cc.mkp.Time = v
}
ac.cc.mu.Unlock()
}
}
// printf records an event in ac's event log, unless ac has been closed. // printf records an event in ac's event log, unless ac has been closed.
// REQUIRES ac.mu is held. // REQUIRES ac.mu is held.
func (ac *addrConn) printf(format string, a ...interface{}) { func (ac *addrConn) printf(format string, a ...interface{}) {
...@@ -818,8 +766,6 @@ func (ac *addrConn) resetTransport(closeTransport bool) error { ...@@ -818,8 +766,6 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
Metadata: ac.addr.Metadata, Metadata: ac.addr.Metadata,
} }
newTransport, err := transport.NewClientTransport(ctx, sinfo, ac.dopts.copts) newTransport, err := transport.NewClientTransport(ctx, sinfo, ac.dopts.copts)
// Don't call cancel in success path due to a race in Go 1.6:
// https://github.com/golang/go/issues/15078.
if err != nil { if err != nil {
cancel() cancel()
...@@ -890,7 +836,6 @@ func (ac *addrConn) transportMonitor() { ...@@ -890,7 +836,6 @@ func (ac *addrConn) transportMonitor() {
} }
return return
case <-t.GoAway(): case <-t.GoAway():
ac.adjustParams(t.GetGoAwayReason())
// If GoAway happens without any network I/O error, ac is closed without shutting down the // If GoAway happens without any network I/O error, ac is closed without shutting down the
// underlying transport (the transport will be closed when all the pending RPCs finished or // underlying transport (the transport will be closed when all the pending RPCs finished or
// failed.). // failed.).
...@@ -899,9 +844,9 @@ func (ac *addrConn) transportMonitor() { ...@@ -899,9 +844,9 @@ func (ac *addrConn) transportMonitor() {
// In both cases, a new ac is created. // In both cases, a new ac is created.
select { select {
case <-t.Error(): case <-t.Error():
ac.cc.resetAddrConn(ac.addr, false, errNetworkIO) ac.cc.resetAddrConn(ac.addr, true, errNetworkIO)
default: default:
ac.cc.resetAddrConn(ac.addr, false, errConnDrain) ac.cc.resetAddrConn(ac.addr, true, errConnDrain)
} }
return return
case <-t.Error(): case <-t.Error():
...@@ -910,8 +855,7 @@ func (ac *addrConn) transportMonitor() { ...@@ -910,8 +855,7 @@ func (ac *addrConn) transportMonitor() {
t.Close() t.Close()
return return
case <-t.GoAway(): case <-t.GoAway():
ac.adjustParams(t.GetGoAwayReason()) ac.cc.resetAddrConn(ac.addr, true, errNetworkIO)
ac.cc.resetAddrConn(ac.addr, false, errNetworkIO)
return return
default: default:
} }
......
@@ -40,7 +40,7 @@ import (
// UnaryInvoker is called by UnaryClientInterceptor to complete RPCs.
type UnaryInvoker func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, opts ...CallOption) error
-// UnaryClientInterceptor intercepts the execution of a unary RPC on the client. invoker is the handler to complete the RPC
+// UnaryClientInterceptor intercepts the execution of a unary RPC on the client. inovker is the handler to complete the RPC
// and it is the responsibility of the interceptor to call it.
// This is the EXPERIMENTAL API.
type UnaryClientInterceptor func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error
......
...@@ -37,6 +37,7 @@ import ( ...@@ -37,6 +37,7 @@ import (
"bytes" "bytes"
"compress/gzip" "compress/gzip"
"encoding/binary" "encoding/binary"
"fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"math" "math"
...@@ -47,9 +48,7 @@ import ( ...@@ -47,9 +48,7 @@ import (
"golang.org/x/net/context" "golang.org/x/net/context"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/stats" "google.golang.org/grpc/stats"
"google.golang.org/grpc/status"
"google.golang.org/grpc/transport" "google.golang.org/grpc/transport"
) )
...@@ -141,7 +140,6 @@ type callInfo struct { ...@@ -141,7 +140,6 @@ type callInfo struct {
failFast bool failFast bool
headerMD metadata.MD headerMD metadata.MD
trailerMD metadata.MD trailerMD metadata.MD
peer *peer.Peer
traceInfo traceInfo // in trace.go traceInfo traceInfo // in trace.go
} }
...@@ -185,22 +183,12 @@ func Trailer(md *metadata.MD) CallOption { ...@@ -185,22 +183,12 @@ func Trailer(md *metadata.MD) CallOption {
}) })
} }
// Peer returns a CallOption that retrieves peer information for a
// unary RPC.
func Peer(peer *peer.Peer) CallOption {
return afterCall(func(c *callInfo) {
if c.peer != nil {
*peer = *c.peer
}
})
}
// FailFast configures the action to take when an RPC is attempted on broken // FailFast configures the action to take when an RPC is attempted on broken
// connections or unreachable servers. If failfast is true, the RPC will fail // connections or unreachable servers. If failfast is true, the RPC will fail
// immediately. Otherwise, the RPC client will block the call until a // immediately. Otherwise, the RPC client will block the call until a
// connection is available (or the call is canceled or times out) and will retry // connection is available (or the call is canceled or times out) and will retry
// the call if it fails due to a transient error. Please refer to // the call if it fails due to a transient error. Please refer to
// https://github.com/grpc/grpc/blob/master/doc/fail_fast.md. Note: failFast is default to true. // https://github.com/grpc/grpc/blob/master/doc/fail_fast.md
func FailFast(failFast bool) CallOption { func FailFast(failFast bool) CallOption {
return beforeCall(func(c *callInfo) error { return beforeCall(func(c *callInfo) error {
c.failFast = failFast c.failFast = failFast
...@@ -372,57 +360,88 @@ func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{ ...@@ -372,57 +360,88 @@ func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{
return nil return nil
} }
// rpcError defines the status from an RPC.
type rpcError struct {
code codes.Code
desc string
}
func (e *rpcError) Error() string {
return fmt.Sprintf("rpc error: code = %d desc = %s", e.code, e.desc)
}
// Code returns the error code for err if it was produced by the rpc system. // Code returns the error code for err if it was produced by the rpc system.
// Otherwise, it returns codes.Unknown. // Otherwise, it returns codes.Unknown.
//
// Deprecated; use status.FromError and Code method instead.
func Code(err error) codes.Code { func Code(err error) codes.Code {
if s, ok := status.FromError(err); ok { if err == nil {
return s.Code() return codes.OK
}
if e, ok := err.(*rpcError); ok {
return e.code
} }
return codes.Unknown return codes.Unknown
} }
// ErrorDesc returns the error description of err if it was produced by the rpc system. // ErrorDesc returns the error description of err if it was produced by the rpc system.
// Otherwise, it returns err.Error() or empty string when err is nil. // Otherwise, it returns err.Error() or empty string when err is nil.
//
// Deprecated; use status.FromError and Message method instead.
func ErrorDesc(err error) string { func ErrorDesc(err error) string {
if s, ok := status.FromError(err); ok { if err == nil {
return s.Message() return ""
}
if e, ok := err.(*rpcError); ok {
return e.desc
} }
return err.Error() return err.Error()
} }
// Errorf returns an error containing an error code and a description; // Errorf returns an error containing an error code and a description;
// Errorf returns nil if c is OK. // Errorf returns nil if c is OK.
//
// Deprecated; use status.Errorf instead.
func Errorf(c codes.Code, format string, a ...interface{}) error { func Errorf(c codes.Code, format string, a ...interface{}) error {
return status.Errorf(c, format, a...) if c == codes.OK {
return nil
}
return &rpcError{
code: c,
desc: fmt.Sprintf(format, a...),
}
} }
// toRPCErr converts an error into an error from the status package. // toRPCErr converts an error into a rpcError.
func toRPCErr(err error) error { func toRPCErr(err error) error {
if _, ok := status.FromError(err); ok {
return err
}
switch e := err.(type) { switch e := err.(type) {
case *rpcError:
return err
case transport.StreamError: case transport.StreamError:
return status.Error(e.Code, e.Desc) return &rpcError{
code: e.Code,
desc: e.Desc,
}
case transport.ConnectionError: case transport.ConnectionError:
return status.Error(codes.Internal, e.Desc) return &rpcError{
code: codes.Internal,
desc: e.Desc,
}
default: default:
switch err { switch err {
case context.DeadlineExceeded: case context.DeadlineExceeded:
return status.Error(codes.DeadlineExceeded, err.Error()) return &rpcError{
code: codes.DeadlineExceeded,
desc: err.Error(),
}
case context.Canceled: case context.Canceled:
return status.Error(codes.Canceled, err.Error()) return &rpcError{
code: codes.Canceled,
desc: err.Error(),
}
case ErrClientConnClosing: case ErrClientConnClosing:
return status.Error(codes.FailedPrecondition, err.Error()) return &rpcError{
code: codes.FailedPrecondition,
desc: err.Error(),
}
} }
} }
return status.Error(codes.Unknown, err.Error()) return Errorf(codes.Unknown, "%v", err)
} }
// convertCode converts a standard Go error into its canonical code. Note that // convertCode converts a standard Go error into its canonical code. Note that
...@@ -467,17 +486,17 @@ type MethodConfig struct { ...@@ -467,17 +486,17 @@ type MethodConfig struct {
// then the other will be used. If neither is set, then the RPC has no deadline. // then the other will be used. If neither is set, then the RPC has no deadline.
Timeout time.Duration Timeout time.Duration
// MaxReqSize is the maximum allowed payload size for an individual request in a // MaxReqSize is the maximum allowed payload size for an individual request in a
// stream (client->server) in bytes. The size which is measured is the serialized // stream (client->server) in bytes. The size which is measured is the serialized,
// payload after per-message compression (but before stream compression) in bytes. // uncompressed payload in bytes. The actual value used is the minumum of the value
// The actual value used is the minumum of the value specified here and the value set // specified here and the value set by the application via the gRPC client API. If
// by the application via the gRPC client API. If either one is not set, then the other // either one is not set, then the other will be used. If neither is set, then the
// will be used. If neither is set, then the built-in default is used. // built-in default is used.
// TODO: support this. // TODO: support this.
MaxReqSize uint32 MaxReqSize uint64
// MaxRespSize is the maximum allowed payload size for an individual response in a // MaxRespSize is the maximum allowed payload size for an individual response in a
// stream (server->client) in bytes. // stream (server->client) in bytes.
// TODO: support this. // TODO: support this.
MaxRespSize uint32 MaxRespSize uint64
} }
// ServiceConfig is provided by the service provider and contains parameters for how // ServiceConfig is provided by the service provider and contains parameters for how
...@@ -498,6 +517,3 @@ type ServiceConfig struct { ...@@ -498,6 +517,3 @@ type ServiceConfig struct {
// requires a synchronised update of grpc-go and protoc-gen-go. This constant // requires a synchronised update of grpc-go and protoc-gen-go. This constant
// should not be referenced from any other code. // should not be referenced from any other code.
const SupportPackageIsVersion4 = true const SupportPackageIsVersion4 = true
// Version is the current grpc version.
const Version = "1.3.0-dev"
...@@ -53,10 +53,8 @@ import ( ...@@ -53,10 +53,8 @@ import (
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/grpclog" "google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal" "google.golang.org/grpc/internal"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"google.golang.org/grpc/stats" "google.golang.org/grpc/stats"
"google.golang.org/grpc/status"
"google.golang.org/grpc/tap" "google.golang.org/grpc/tap"
"google.golang.org/grpc/transport" "google.golang.org/grpc/transport"
) )
...@@ -118,9 +116,6 @@ type options struct { ...@@ -118,9 +116,6 @@ type options struct {
statsHandler stats.Handler statsHandler stats.Handler
maxConcurrentStreams uint32 maxConcurrentStreams uint32
useHandlerImpl bool // use http.Handler-based server useHandlerImpl bool // use http.Handler-based server
unknownStreamDesc *StreamDesc
keepaliveParams keepalive.ServerParameters
keepalivePolicy keepalive.EnforcementPolicy
} }
var defaultMaxMsgSize = 1024 * 1024 * 4 // use 4MB as the default message size limit var defaultMaxMsgSize = 1024 * 1024 * 4 // use 4MB as the default message size limit
...@@ -128,20 +123,6 @@ var defaultMaxMsgSize = 1024 * 1024 * 4 // use 4MB as the default message size l ...@@ -128,20 +123,6 @@ var defaultMaxMsgSize = 1024 * 1024 * 4 // use 4MB as the default message size l
// A ServerOption sets options. // A ServerOption sets options.
type ServerOption func(*options) type ServerOption func(*options)
// KeepaliveParams returns a ServerOption that sets keepalive and max-age parameters for the server.
func KeepaliveParams(kp keepalive.ServerParameters) ServerOption {
return func(o *options) {
o.keepaliveParams = kp
}
}
// KeepaliveEnforcementPolicy returns a ServerOption that sets keepalive enforcement policy for the server.
func KeepaliveEnforcementPolicy(kep keepalive.EnforcementPolicy) ServerOption {
return func(o *options) {
o.keepalivePolicy = kep
}
}
// CustomCodec returns a ServerOption that sets a codec for message marshaling and unmarshaling. // CustomCodec returns a ServerOption that sets a codec for message marshaling and unmarshaling.
func CustomCodec(codec Codec) ServerOption { func CustomCodec(codec Codec) ServerOption {
return func(o *options) { return func(o *options) {
...@@ -227,24 +208,6 @@ func StatsHandler(h stats.Handler) ServerOption { ...@@ -227,24 +208,6 @@ func StatsHandler(h stats.Handler) ServerOption {
} }
} }
// UnknownServiceHandler returns a ServerOption that allows for adding a custom
// unknown service handler. The provided method is a bidi-streaming RPC service
// handler that will be invoked instead of returning the the "unimplemented" gRPC
// error whenever a request is received for an unregistered service or method.
// The handling function has full access to the Context of the request and the
// stream, and the invocation passes through interceptors.
func UnknownServiceHandler(streamHandler StreamHandler) ServerOption {
return func(o *options) {
o.unknownStreamDesc = &StreamDesc{
StreamName: "unknown_service_handler",
Handler: streamHandler,
// We need to assume that the users of the streamHandler will want to use both.
ClientStreams: true,
ServerStreams: true,
}
}
}
// NewServer creates a gRPC server which has no service registered and has not // NewServer creates a gRPC server which has no service registered and has not
// started to accept requests yet. // started to accept requests yet.
func NewServer(opt ...ServerOption) *Server { func NewServer(opt ...ServerOption) *Server {
...@@ -483,12 +446,10 @@ func (s *Server) handleRawConn(rawConn net.Conn) { ...@@ -483,12 +446,10 @@ func (s *Server) handleRawConn(rawConn net.Conn) {
// transport.NewServerTransport). // transport.NewServerTransport).
func (s *Server) serveHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) { func (s *Server) serveHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) {
config := &transport.ServerConfig{ config := &transport.ServerConfig{
MaxStreams: s.opts.maxConcurrentStreams, MaxStreams: s.opts.maxConcurrentStreams,
AuthInfo: authInfo, AuthInfo: authInfo,
InTapHandle: s.opts.inTapHandle, InTapHandle: s.opts.inTapHandle,
StatsHandler: s.opts.statsHandler, StatsHandler: s.opts.statsHandler,
KeepaliveParams: s.opts.keepaliveParams,
KeepalivePolicy: s.opts.keepalivePolicy,
} }
st, err := transport.NewServerTransport("http2", c, config) st, err := transport.NewServerTransport("http2", c, config)
if err != nil { if err != nil {
...@@ -672,7 +633,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. ...@@ -672,7 +633,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
stream.SetSendCompress(s.opts.cp.Type()) stream.SetSendCompress(s.opts.cp.Type())
} }
p := &parser{r: stream} p := &parser{r: stream}
for { // TODO: delete for {
pf, req, err := p.recvMsg(s.opts.maxMsgSize) pf, req, err := p.recvMsg(s.opts.maxMsgSize)
if err == io.EOF { if err == io.EOF {
// The entire stream is done (for unary RPC only). // The entire stream is done (for unary RPC only).
...@@ -682,37 +643,36 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. ...@@ -682,37 +643,36 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
err = Errorf(codes.Internal, io.ErrUnexpectedEOF.Error()) err = Errorf(codes.Internal, io.ErrUnexpectedEOF.Error())
} }
if err != nil { if err != nil {
if st, ok := status.FromError(err); ok { switch err := err.(type) {
if e := t.WriteStatus(stream, st); e != nil { case *rpcError:
if e := t.WriteStatus(stream, err.code, err.desc); e != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e) grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e)
} }
} else { case transport.ConnectionError:
switch st := err.(type) { // Nothing to do here.
case transport.ConnectionError: case transport.StreamError:
// Nothing to do here. if e := t.WriteStatus(stream, err.Code, err.Desc); e != nil {
case transport.StreamError: grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e)
if e := t.WriteStatus(stream, status.New(st.Code, st.Desc)); e != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e)
}
default:
panic(fmt.Sprintf("grpc: Unexpected error (%T) from recvMsg: %v", st, st))
} }
default:
panic(fmt.Sprintf("grpc: Unexpected error (%T) from recvMsg: %v", err, err))
} }
return err return err
} }
if err := checkRecvPayload(pf, stream.RecvCompress(), s.opts.dc); err != nil { if err := checkRecvPayload(pf, stream.RecvCompress(), s.opts.dc); err != nil {
if st, ok := status.FromError(err); ok { switch err := err.(type) {
if e := t.WriteStatus(stream, st); e != nil { case *rpcError:
if e := t.WriteStatus(stream, err.code, err.desc); e != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e) grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e)
} }
return err return err
default:
if e := t.WriteStatus(stream, codes.Internal, err.Error()); e != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e)
}
// TODO checkRecvPayload always return RPC error. Add a return here if necessary.
} }
if e := t.WriteStatus(stream, status.New(codes.Internal, err.Error())); e != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e)
}
// TODO checkRecvPayload always return RPC error. Add a return here if necessary.
} }
var inPayload *stats.InPayload var inPayload *stats.InPayload
if sh != nil { if sh != nil {
...@@ -720,6 +680,8 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. ...@@ -720,6 +680,8 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
RecvTime: time.Now(), RecvTime: time.Now(),
} }
} }
statusCode := codes.OK
statusDesc := ""
df := func(v interface{}) error { df := func(v interface{}) error {
if inPayload != nil { if inPayload != nil {
inPayload.WireLength = len(req) inPayload.WireLength = len(req)
...@@ -728,16 +690,20 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. ...@@ -728,16 +690,20 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
var err error var err error
req, err = s.opts.dc.Do(bytes.NewReader(req)) req, err = s.opts.dc.Do(bytes.NewReader(req))
if err != nil { if err != nil {
if err := t.WriteStatus(stream, codes.Internal, err.Error()); err != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err)
}
return Errorf(codes.Internal, err.Error()) return Errorf(codes.Internal, err.Error())
} }
} }
if len(req) > s.opts.maxMsgSize { if len(req) > s.opts.maxMsgSize {
// TODO: Revisit the error code. Currently keep it consistent with // TODO: Revisit the error code. Currently keep it consistent with
// java implementation. // java implementation.
return status.Errorf(codes.Internal, "grpc: server received a message of %d bytes exceeding %d limit", len(req), s.opts.maxMsgSize) statusCode = codes.Internal
statusDesc = fmt.Sprintf("grpc: server received a message of %d bytes exceeding %d limit", len(req), s.opts.maxMsgSize)
} }
if err := s.opts.codec.Unmarshal(req, v); err != nil { if err := s.opts.codec.Unmarshal(req, v); err != nil {
return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err) return err
} }
if inPayload != nil { if inPayload != nil {
inPayload.Payload = v inPayload.Payload = v
...@@ -752,20 +718,21 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. ...@@ -752,20 +718,21 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
} }
reply, appErr := md.Handler(srv.server, stream.Context(), df, s.opts.unaryInt) reply, appErr := md.Handler(srv.server, stream.Context(), df, s.opts.unaryInt)
if appErr != nil { if appErr != nil {
appStatus, ok := status.FromError(appErr) if err, ok := appErr.(*rpcError); ok {
if !ok { statusCode = err.code
// Convert appErr if it is not a grpc status error. statusDesc = err.desc
appErr = status.Error(convertCode(appErr), appErr.Error()) } else {
appStatus, _ = status.FromError(appErr) statusCode = convertCode(appErr)
statusDesc = appErr.Error()
} }
if trInfo != nil { if trInfo != nil && statusCode != codes.OK {
trInfo.tr.LazyLog(stringer(appStatus.Message()), true) trInfo.tr.LazyLog(stringer(statusDesc), true)
trInfo.tr.SetError() trInfo.tr.SetError()
} }
if e := t.WriteStatus(stream, appStatus); e != nil { if err := t.WriteStatus(stream, statusCode, statusDesc); err != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status: %v", e) grpclog.Printf("grpc: Server.processUnaryRPC failed to write status: %v", err)
} }
return appErr return Errorf(statusCode, statusDesc)
} }
if trInfo != nil { if trInfo != nil {
trInfo.tr.LazyLog(stringer("OK"), false) trInfo.tr.LazyLog(stringer("OK"), false)
...@@ -775,35 +742,26 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. ...@@ -775,35 +742,26 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
Delay: false, Delay: false,
} }
if err := s.sendResponse(t, stream, reply, s.opts.cp, opts); err != nil { if err := s.sendResponse(t, stream, reply, s.opts.cp, opts); err != nil {
if err == io.EOF { switch err := err.(type) {
// The entire stream is done (for unary RPC only). case transport.ConnectionError:
return err // Nothing to do here.
} case transport.StreamError:
if s, ok := status.FromError(err); ok { statusCode = err.Code
if e := t.WriteStatus(stream, s); e != nil { statusDesc = err.Desc
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status: %v", e) default:
} statusCode = codes.Unknown
} else { statusDesc = err.Error()
switch st := err.(type) {
case transport.ConnectionError:
// Nothing to do here.
case transport.StreamError:
if e := t.WriteStatus(stream, status.New(st.Code, st.Desc)); e != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e)
}
default:
panic(fmt.Sprintf("grpc: Unexpected error (%T) from sendResponse: %v", st, st))
}
} }
return err return err
} }
if trInfo != nil { if trInfo != nil {
trInfo.tr.LazyLog(&payload{sent: true, msg: reply}, true) trInfo.tr.LazyLog(&payload{sent: true, msg: reply}, true)
} }
// TODO: Should we be logging if writing status failed here, like above? errWrite := t.WriteStatus(stream, statusCode, statusDesc)
// Should the logging be in WriteStatus? Should we ignore the WriteStatus if statusCode != codes.OK {
// error or allow the stats handler to see it? return Errorf(statusCode, statusDesc)
return t.WriteStatus(stream, status.New(codes.OK, "")) }
return errWrite
} }
} }
...@@ -857,47 +815,43 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp ...@@ -857,47 +815,43 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
}() }()
} }
var appErr error var appErr error
var server interface{}
if srv != nil {
server = srv.server
}
if s.opts.streamInt == nil { if s.opts.streamInt == nil {
appErr = sd.Handler(server, ss) appErr = sd.Handler(srv.server, ss)
} else { } else {
info := &StreamServerInfo{ info := &StreamServerInfo{
FullMethod: stream.Method(), FullMethod: stream.Method(),
IsClientStream: sd.ClientStreams, IsClientStream: sd.ClientStreams,
IsServerStream: sd.ServerStreams, IsServerStream: sd.ServerStreams,
} }
appErr = s.opts.streamInt(server, ss, info, sd.Handler) appErr = s.opts.streamInt(srv.server, ss, info, sd.Handler)
} }
if appErr != nil { if appErr != nil {
appStatus, ok := status.FromError(appErr) if err, ok := appErr.(*rpcError); ok {
if !ok { ss.statusCode = err.code
switch err := appErr.(type) { ss.statusDesc = err.desc
case transport.StreamError: } else if err, ok := appErr.(transport.StreamError); ok {
appStatus = status.New(err.Code, err.Desc) ss.statusCode = err.Code
default: ss.statusDesc = err.Desc
appStatus = status.New(convertCode(appErr), appErr.Error()) } else {
} ss.statusCode = convertCode(appErr)
appErr = appStatus.Err() ss.statusDesc = appErr.Error()
}
if trInfo != nil {
ss.mu.Lock()
ss.trInfo.tr.LazyLog(stringer(appStatus.Message()), true)
ss.trInfo.tr.SetError()
ss.mu.Unlock()
} }
t.WriteStatus(ss.s, appStatus)
// TODO: Should we log an error from WriteStatus here and below?
return appErr
} }
if trInfo != nil { if trInfo != nil {
ss.mu.Lock() ss.mu.Lock()
ss.trInfo.tr.LazyLog(stringer("OK"), false) if ss.statusCode != codes.OK {
ss.trInfo.tr.LazyLog(stringer(ss.statusDesc), true)
ss.trInfo.tr.SetError()
} else {
ss.trInfo.tr.LazyLog(stringer("OK"), false)
}
ss.mu.Unlock() ss.mu.Unlock()
} }
return t.WriteStatus(ss.s, status.New(codes.OK, "")) errWrite := t.WriteStatus(ss.s, ss.statusCode, ss.statusDesc)
if ss.statusCode != codes.OK {
return Errorf(ss.statusCode, ss.statusDesc)
}
return errWrite
} }
...@@ -913,7 +867,7 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str ...@@ -913,7 +867,7 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str
trInfo.tr.SetError() trInfo.tr.SetError()
} }
errDesc := fmt.Sprintf("malformed method name: %q", stream.Method()) errDesc := fmt.Sprintf("malformed method name: %q", stream.Method())
if err := t.WriteStatus(stream, status.New(codes.InvalidArgument, errDesc)); err != nil { if err := t.WriteStatus(stream, codes.InvalidArgument, errDesc); err != nil {
if trInfo != nil { if trInfo != nil {
trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true) trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
trInfo.tr.SetError() trInfo.tr.SetError()
...@@ -929,16 +883,12 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str ...@@ -929,16 +883,12 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str
method := sm[pos+1:] method := sm[pos+1:]
srv, ok := s.m[service] srv, ok := s.m[service]
if !ok { if !ok {
if unknownDesc := s.opts.unknownStreamDesc; unknownDesc != nil {
s.processStreamingRPC(t, stream, nil, unknownDesc, trInfo)
return
}
if trInfo != nil { if trInfo != nil {
trInfo.tr.LazyLog(&fmtStringer{"Unknown service %v", []interface{}{service}}, true) trInfo.tr.LazyLog(&fmtStringer{"Unknown service %v", []interface{}{service}}, true)
trInfo.tr.SetError() trInfo.tr.SetError()
} }
errDesc := fmt.Sprintf("unknown service %v", service) errDesc := fmt.Sprintf("unknown service %v", service)
if err := t.WriteStatus(stream, status.New(codes.Unimplemented, errDesc)); err != nil { if err := t.WriteStatus(stream, codes.Unimplemented, errDesc); err != nil {
if trInfo != nil { if trInfo != nil {
trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true) trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
trInfo.tr.SetError() trInfo.tr.SetError()
...@@ -963,12 +913,8 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str ...@@ -963,12 +913,8 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str
trInfo.tr.LazyLog(&fmtStringer{"Unknown method %v", []interface{}{method}}, true) trInfo.tr.LazyLog(&fmtStringer{"Unknown method %v", []interface{}{method}}, true)
trInfo.tr.SetError() trInfo.tr.SetError()
} }
if unknownDesc := s.opts.unknownStreamDesc; unknownDesc != nil {
s.processStreamingRPC(t, stream, nil, unknownDesc, trInfo)
return
}
errDesc := fmt.Sprintf("unknown method %v", method) errDesc := fmt.Sprintf("unknown method %v", method)
if err := t.WriteStatus(stream, status.New(codes.Unimplemented, errDesc)); err != nil { if err := t.WriteStatus(stream, codes.Unimplemented, errDesc); err != nil {
if trInfo != nil { if trInfo != nil {
trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true) trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
trInfo.tr.SetError() trInfo.tr.SetError()
......
...@@ -37,6 +37,7 @@ import ( ...@@ -37,6 +37,7 @@ import (
"bytes" "bytes"
"errors" "errors"
"io" "io"
"math"
"sync" "sync"
"time" "time"
...@@ -45,7 +46,6 @@ import ( ...@@ -45,7 +46,6 @@ import (
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"google.golang.org/grpc/stats" "google.golang.org/grpc/stats"
"google.golang.org/grpc/status"
"google.golang.org/grpc/transport" "google.golang.org/grpc/transport"
) )
...@@ -178,7 +178,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth ...@@ -178,7 +178,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
t, put, err = cc.getTransport(ctx, gopts) t, put, err = cc.getTransport(ctx, gopts)
if err != nil { if err != nil {
// TODO(zhaoq): Probably revisit the error handling. // TODO(zhaoq): Probably revisit the error handling.
if _, ok := status.FromError(err); ok { if _, ok := err.(*rpcError); ok {
return nil, err return nil, err
} }
if err == errConnClosing || err == errConnUnavailable { if err == errConnClosing || err == errConnUnavailable {
...@@ -208,14 +208,13 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth ...@@ -208,14 +208,13 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
break break
} }
cs := &clientStream{ cs := &clientStream{
opts: opts, opts: opts,
c: c, c: c,
desc: desc, desc: desc,
codec: cc.dopts.codec, codec: cc.dopts.codec,
cp: cc.dopts.cp, cp: cc.dopts.cp,
dc: cc.dopts.dc, dc: cc.dopts.dc,
maxMsgSize: cc.dopts.maxMsgSize, cancel: cancel,
cancel: cancel,
put: put, put: put,
t: t, t: t,
...@@ -240,7 +239,11 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth ...@@ -240,7 +239,11 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
case <-s.Done(): case <-s.Done():
// TODO: The trace of the RPC is terminated here when there is no pending // TODO: The trace of the RPC is terminated here when there is no pending
// I/O, which is probably not the optimal solution. // I/O, which is probably not the optimal solution.
cs.finish(s.Status().Err()) if s.StatusCode() == codes.OK {
cs.finish(nil)
} else {
cs.finish(Errorf(s.StatusCode(), "%s", s.StatusDesc()))
}
cs.closeTransportStream(nil) cs.closeTransportStream(nil)
case <-s.GoAway(): case <-s.GoAway():
cs.finish(errConnDrain) cs.finish(errConnDrain)
...@@ -256,18 +259,17 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth ...@@ -256,18 +259,17 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
// clientStream implements a client side Stream. // clientStream implements a client side Stream.
type clientStream struct { type clientStream struct {
opts []CallOption opts []CallOption
c callInfo c callInfo
t transport.ClientTransport t transport.ClientTransport
s *transport.Stream s *transport.Stream
p *parser p *parser
desc *StreamDesc desc *StreamDesc
codec Codec codec Codec
cp Compressor cp Compressor
cbuf *bytes.Buffer cbuf *bytes.Buffer
dc Decompressor dc Decompressor
maxMsgSize int cancel context.CancelFunc
cancel context.CancelFunc
tracing bool // set to EnableTracing when the clientStream is created. tracing bool // set to EnableTracing when the clientStream is created.
...@@ -380,7 +382,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) { ...@@ -380,7 +382,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
Client: true, Client: true,
} }
} }
err = recv(cs.p, cs.codec, cs.s, cs.dc, m, cs.maxMsgSize, inPayload) err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32, inPayload)
defer func() { defer func() {
// err != nil indicates the termination of the stream. // err != nil indicates the termination of the stream.
if err != nil { if err != nil {
...@@ -403,17 +405,17 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) { ...@@ -403,17 +405,17 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
} }
// Special handling for client streaming rpc. // Special handling for client streaming rpc.
// This recv expects EOF or errors, so we don't collect inPayload. // This recv expects EOF or errors, so we don't collect inPayload.
err = recv(cs.p, cs.codec, cs.s, cs.dc, m, cs.maxMsgSize, nil) err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32, nil)
cs.closeTransportStream(err) cs.closeTransportStream(err)
if err == nil { if err == nil {
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>")) return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
} }
if err == io.EOF { if err == io.EOF {
if se := cs.s.Status().Err(); se != nil { if cs.s.StatusCode() == codes.OK {
return se cs.finish(err)
return nil
} }
cs.finish(err) return Errorf(cs.s.StatusCode(), "%s", cs.s.StatusDesc())
return nil
} }
return toRPCErr(err) return toRPCErr(err)
} }
...@@ -421,11 +423,11 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) { ...@@ -421,11 +423,11 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
cs.closeTransportStream(err) cs.closeTransportStream(err)
} }
if err == io.EOF { if err == io.EOF {
if statusErr := cs.s.Status().Err(); statusErr != nil { if cs.s.StatusCode() == codes.OK {
return statusErr // Returns io.EOF to indicate the end of the stream.
return
} }
// Returns io.EOF to indicate the end of the stream. return Errorf(cs.s.StatusCode(), "%s", cs.s.StatusDesc())
return
} }
return toRPCErr(err) return toRPCErr(err)
} }
...@@ -517,6 +519,8 @@ type serverStream struct { ...@@ -517,6 +519,8 @@ type serverStream struct {
dc Decompressor dc Decompressor
cbuf *bytes.Buffer cbuf *bytes.Buffer
maxMsgSize int maxMsgSize int
statusCode codes.Code
statusDesc string
trInfo *traceInfo trInfo *traceInfo
statsHandler stats.Handler statsHandler stats.Handler
......
@@ -35,9 +35,7 @@ package transport
import (
"fmt"
-"math"
"sync"
-"time"
"golang.org/x/net/http2"
)
@@ -46,18 +44,8 @@ const (
// The default value of flow control window size in HTTP2 spec.
defaultWindowSize = 65535
// The initial window size for flow control.
initialWindowSize = defaultWindowSize // for an RPC
initialConnWindowSize = defaultWindowSize * 16 // for a connection
-infinity = time.Duration(math.MaxInt64)
-defaultClientKeepaliveTime = infinity
-defaultClientKeepaliveTimeout = time.Duration(20 * time.Second)
-defaultMaxStreamsClient = 100
-defaultMaxConnectionIdle = infinity
-defaultMaxConnectionAge = infinity
-defaultMaxConnectionAgeGrace = infinity
-defaultServerKeepaliveTime = time.Duration(2 * time.Hour)
-defaultServerKeepaliveTimeout = time.Duration(20 * time.Second)
-defaultKeepalivePolicyMinTime = time.Duration(5 * time.Minute)
)
// The following defines various control items which could flow through
@@ -85,8 +73,6 @@ type resetStream struct {
func (*resetStream) item() {}
type goAway struct {
-code http2.ErrCode
-debugData []byte
}
func (*goAway) item() {}
......
@@ -53,7 +53,6 @@ import (
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
-"google.golang.org/grpc/status"
)
// NewServerHandlerTransport returns a ServerTransport handling gRPC
@@ -183,7 +182,7 @@ func (ht *serverHandlerTransport) do(fn func()) error {
}
}
-func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) error {
+func (ht *serverHandlerTransport) WriteStatus(s *Stream, statusCode codes.Code, statusDesc string) error {
err := ht.do(func() {
ht.writeCommonHeaders(s)
@@ -193,13 +192,10 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) erro
ht.rw.(http.Flusher).Flush()
h := ht.rw.Header()
-h.Set("Grpc-Status", fmt.Sprintf("%d", st.Code()))
-if m := st.Message(); m != "" {
-h.Set("Grpc-Message", encodeGrpcMessage(m))
+h.Set("Grpc-Status", fmt.Sprintf("%d", statusCode))
+if statusDesc != "" {
+h.Set("Grpc-Message", encodeGrpcMessage(statusDesc))
}
-// TODO: Support Grpc-Status-Details-Bin
if md := s.Trailer(); len(md) > 0 {
for k, vv := range md {
// Clients don't tolerate reading restricted headers after some non restricted ones were sent.
@@ -238,7 +234,6 @@ func (ht *serverHandlerTransport) writeCommonHeaders(s *Stream) {
// and https://golang.org/pkg/net/http/#example_ResponseWriter_trailers
h.Add("Trailer", "Grpc-Status")
h.Add("Trailer", "Grpc-Message")
-// TODO: Support Grpc-Status-Details-Bin
if s.sendCompress != "" {
h.Set("Grpc-Encoding", s.sendCompress)
@@ -319,7 +314,7 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace
if req.TLS != nil {
pr.AuthInfo = credentials.TLSInfo{State: *req.TLS}
}
-ctx = metadata.NewIncomingContext(ctx, ht.headerMD)
+ctx = metadata.NewContext(ctx, ht.headerMD)
ctx = peer.NewContext(ctx, pr)
s.ctx = newContextWithStream(ctx, s)
s.dec = &recvBufferReader{ctx: s.ctx, recv: s.buf}
......
@@ -35,12 +35,12 @@ package transport
import (
"bytes"
+"fmt"
"io"
"math"
"net"
"strings"
"sync"
-"sync/atomic"
"time"
"golang.org/x/net/context"
@@ -49,11 +49,9 @@ import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/grpclog"
-"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/stats"
-"google.golang.org/grpc/status"
)
// http2Client implements the ClientTransport interface with HTTP2.
@@ -82,8 +80,6 @@ type http2Client struct {
// goAway is closed to notify the upper layer (i.e., addrConn.transportMonitor)
// that the server sent GoAway on this transport.
goAway chan struct{}
-// awakenKeepalive is used to wake up keepalive when after it has gone dormant.
-awakenKeepalive chan struct{}
framer *framer
hBuf *bytes.Buffer // the buffer for HPACK encoding
@@ -103,11 +99,6 @@ type http2Client struct {
creds []credentials.PerRPCCredentials
-// Boolean to keep track of reading activity on transport.
-// 1 is true and 0 is false.
-activity uint32 // Accessed atomically.
-kp keepalive.ClientParameters
statsHandler stats.Handler
mu sync.Mutex // guard the following variables
@@ -121,9 +112,6 @@ type http2Client struct {
goAwayID uint32
// prevGoAway ID records the Last-Stream-ID in the previous GOAway frame.
prevGoAwayID uint32
-// goAwayReason records the http2.ErrCode and debug data received with the
-// GoAway frame.
-goAwayReason GoAwayReason
}
func dial(ctx context.Context, fn func(context.Context, string) (net.Conn, error), addr string) (net.Conn, error) {
@@ -190,19 +178,15 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) (
return nil, connectionErrorf(temp, err, "transport: %v", err)
}
}
-kp := opts.KeepaliveParams
-// Validate keepalive parameters.
-if kp.Time == 0 {
-kp.Time = defaultClientKeepaliveTime
-}
-if kp.Timeout == 0 {
-kp.Timeout = defaultClientKeepaliveTimeout
-}
+ua := primaryUA
+if opts.UserAgent != "" {
+ua = opts.UserAgent + " " + ua
+}
var buf bytes.Buffer
t := &http2Client{
ctx: ctx,
target: addr.Addr,
-userAgent: opts.UserAgent,
+userAgent: ua,
md: addr.Metadata,
conn: conn,
remoteAddr: conn.RemoteAddr(),
@@ -214,7 +198,6 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) (
shutdownChan: make(chan struct{}),
errorChan: make(chan struct{}), errorChan: make(chan struct{}),
goAway: make(chan struct{}), goAway: make(chan struct{}),
awakenKeepalive: make(chan struct{}, 1),
framer: newFramer(conn), framer: newFramer(conn),
hBuf: &buf, hBuf: &buf,
hEnc: hpack.NewEncoder(&buf), hEnc: hpack.NewEncoder(&buf),
...@@ -225,15 +208,10 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) ( ...@@ -225,15 +208,10 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) (
state: reachable, state: reachable,
activeStreams: make(map[uint32]*Stream), activeStreams: make(map[uint32]*Stream),
creds: opts.PerRPCCredentials, creds: opts.PerRPCCredentials,
maxStreams: defaultMaxStreamsClient, maxStreams: math.MaxInt32,
streamsQuota: newQuotaPool(defaultMaxStreamsClient),
streamSendQuota: defaultWindowSize, streamSendQuota: defaultWindowSize,
kp: kp,
statsHandler: opts.StatsHandler, statsHandler: opts.StatsHandler,
} }
// Make sure awakenKeepalive can't be written upon.
// keepalive routine will make it writable, if need be.
t.awakenKeepalive <- struct{}{}
if t.statsHandler != nil { if t.statsHandler != nil {
t.ctx = t.statsHandler.TagConn(t.ctx, &stats.ConnTagInfo{ t.ctx = t.statsHandler.TagConn(t.ctx, &stats.ConnTagInfo{
RemoteAddr: t.remoteAddr, RemoteAddr: t.remoteAddr,
...@@ -278,9 +256,6 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) ( ...@@ -278,9 +256,6 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) (
} }
} }
go t.controller() go t.controller()
if t.kp.Time != infinity {
go t.keepalive()
}
t.writableChan <- 0 t.writableChan <- 0
return t, nil return t, nil
} }
...@@ -314,7 +289,7 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream { ...@@ -314,7 +289,7 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
return s return s
} }
// NewStream creates a stream and registers it into the transport as "active" // NewStream creates a stream and register it into the transport as "active"
// streams. // streams.
func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Stream, err error) { func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Stream, err error) {
pr := &peer.Peer{ pr := &peer.Peer{
...@@ -362,18 +337,21 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea ...@@ -362,18 +337,21 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
t.mu.Unlock() t.mu.Unlock()
return nil, ErrConnClosing return nil, ErrConnClosing
} }
checkStreamsQuota := t.streamsQuota != nil
t.mu.Unlock() t.mu.Unlock()
sq, err := wait(ctx, nil, nil, t.shutdownChan, t.streamsQuota.acquire()) if checkStreamsQuota {
if err != nil { sq, err := wait(ctx, nil, nil, t.shutdownChan, t.streamsQuota.acquire())
return nil, err if err != nil {
} return nil, err
// Returns the quota balance back. }
if sq > 1 { // Returns the quota balance back.
t.streamsQuota.add(sq - 1) if sq > 1 {
t.streamsQuota.add(sq - 1)
}
} }
if _, err := wait(ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil { if _, err := wait(ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil {
// Return the quota back now because there is no stream returned to the caller. // Return the quota back now because there is no stream returned to the caller.
if _, ok := err.(StreamError); ok { if _, ok := err.(StreamError); ok && checkStreamsQuota {
t.streamsQuota.add(1) t.streamsQuota.add(1)
} }
return nil, err return nil, err
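The hunks above make the MAX_CONCURRENT_STREAMS quota optional: NewStream only acquires from t.streamsQuota once a SETTINGS frame has created the pool (see the applySettings hunk further down), takes one slot, and returns the excess. A minimal sketch of that channel-based quota-pool pattern, simplified and not the exact grpc-go implementation:

```
package main

import (
	"fmt"
	"sync"
)

// quotaPool hands out its entire balance through a 1-buffered channel;
// callers take what they need and add() the remainder back.
type quotaPool struct {
	c chan int

	mu    sync.Mutex
	quota int // balance not currently sitting in the channel (may be negative)
}

func newQuotaPool(q int) *quotaPool {
	qp := &quotaPool{c: make(chan int, 1)}
	if q > 0 {
		qp.c <- q
	} else {
		qp.quota = q
	}
	return qp
}

// add folds v into the balance and republishes it when it is positive.
func (qp *quotaPool) add(v int) {
	qp.mu.Lock()
	defer qp.mu.Unlock()
	select {
	case n := <-qp.c:
		qp.quota += n
	default:
	}
	qp.quota += v
	if qp.quota <= 0 {
		return
	}
	select {
	case qp.c <- qp.quota:
		qp.quota = 0
	default:
	}
}

// acquire returns the channel the whole balance is delivered on.
func (qp *quotaPool) acquire() <-chan int {
	return qp.c
}

func main() {
	qp := newQuotaPool(3)
	sq := <-qp.acquire() // receive the whole balance (3)
	if sq > 1 {
		qp.add(sq - 1) // keep one slot for this stream, return the rest
	}
	fmt.Println("took 1 slot, pool now holds", <-qp.acquire())
}
```

CloseStream later add()s the slot back, and applySettings passes the delta between the new and old MAX_CONCURRENT_STREAMS, which is why the pool has to tolerate negative adds.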
...@@ -381,7 +359,9 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea ...@@ -381,7 +359,9 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
t.mu.Lock() t.mu.Lock()
if t.state == draining { if t.state == draining {
t.mu.Unlock() t.mu.Unlock()
t.streamsQuota.add(1) if checkStreamsQuota {
t.streamsQuota.add(1)
}
// Need to make t writable again so that the rpc in flight can still proceed. // Need to make t writable again so that the rpc in flight can still proceed.
t.writableChan <- 0 t.writableChan <- 0
return nil, ErrStreamDrain return nil, ErrStreamDrain
...@@ -393,17 +373,17 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea ...@@ -393,17 +373,17 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
s := t.newStream(ctx, callHdr) s := t.newStream(ctx, callHdr)
s.clientStatsCtx = userCtx s.clientStatsCtx = userCtx
t.activeStreams[s.id] = s t.activeStreams[s.id] = s
// If the number of active streams change from 0 to 1, then check if keepalive
// has gone dormant. If so, wake it up.
if len(t.activeStreams) == 1 {
select {
case t.awakenKeepalive <- struct{}{}:
t.framer.writePing(false, false, [8]byte{})
default:
}
}
// This stream is not counted when applySetings(...) initialize t.streamsQuota.
// Reset t.streamsQuota to the right value.
var reset bool
if !checkStreamsQuota && t.streamsQuota != nil {
reset = true
}
t.mu.Unlock() t.mu.Unlock()
if reset {
t.streamsQuota.add(-1)
}
// HPACK encodes various headers. Note that once WriteField(...) is // HPACK encodes various headers. Note that once WriteField(...) is
// called, the corresponding headers/continuation frame has to be sent // called, the corresponding headers/continuation frame has to be sent
...@@ -435,7 +415,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea ...@@ -435,7 +415,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
hasMD bool hasMD bool
endHeaders bool endHeaders bool
) )
if md, ok := metadata.FromOutgoingContext(ctx); ok { if md, ok := metadata.FromContext(ctx); ok {
hasMD = true hasMD = true
for k, v := range md { for k, v := range md {
// HTTP doesn't allow you to set pseudoheaders after non pseudoheaders were set. // HTTP doesn't allow you to set pseudoheaders after non pseudoheaders were set.
...@@ -511,11 +491,15 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea ...@@ -511,11 +491,15 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
// CloseStream clears the footprint of a stream when the stream is not needed any more. // CloseStream clears the footprint of a stream when the stream is not needed any more.
// This must not be executed in reader's goroutine. // This must not be executed in reader's goroutine.
func (t *http2Client) CloseStream(s *Stream, err error) { func (t *http2Client) CloseStream(s *Stream, err error) {
var updateStreams bool
t.mu.Lock() t.mu.Lock()
if t.activeStreams == nil { if t.activeStreams == nil {
t.mu.Unlock() t.mu.Unlock()
return return
} }
if t.streamsQuota != nil {
updateStreams = true
}
delete(t.activeStreams, s.id) delete(t.activeStreams, s.id)
if t.state == draining && len(t.activeStreams) == 0 { if t.state == draining && len(t.activeStreams) == 0 {
// The transport is draining and s is the last live stream on t. // The transport is draining and s is the last live stream on t.
...@@ -524,27 +508,10 @@ func (t *http2Client) CloseStream(s *Stream, err error) { ...@@ -524,27 +508,10 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
return return
} }
t.mu.Unlock() t.mu.Unlock()
// rstStream is true in case the stream is being closed at the client-side if updateStreams {
// and the server needs to be intimated about it by sending a RST_STREAM t.streamsQuota.add(1)
// frame. }
// To make sure this frame is written to the wire before the headers of the
// next stream waiting for streamsQuota, we add to streamsQuota pool only
// after having acquired the writableChan to send RST_STREAM out (look at
// the controller() routine).
var rstStream bool
var rstError http2.ErrCode
defer func() {
// In case, the client doesn't have to send RST_STREAM to server
// we can safely add back to streamsQuota pool now.
if !rstStream {
t.streamsQuota.add(1)
return
}
t.controlBuf.put(&resetStream{s.id, rstError})
}()
s.mu.Lock() s.mu.Lock()
rstStream = s.rstStream
rstError = s.rstError
if q := s.fc.resetPendingData(); q > 0 { if q := s.fc.resetPendingData(); q > 0 {
if n := t.fc.onRead(q); n > 0 { if n := t.fc.onRead(q); n > 0 {
t.controlBuf.put(&windowUpdate{0, n}) t.controlBuf.put(&windowUpdate{0, n})
...@@ -560,9 +527,8 @@ func (t *http2Client) CloseStream(s *Stream, err error) { ...@@ -560,9 +527,8 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
} }
s.state = streamDone s.state = streamDone
s.mu.Unlock() s.mu.Unlock()
if _, ok := err.(StreamError); ok { if se, ok := err.(StreamError); ok && se.Code != codes.DeadlineExceeded {
rstStream = true t.controlBuf.put(&resetStream{s.id, http2.ErrCodeCancel})
rstError = http2.ErrCodeCancel
} }
} }
...@@ -776,7 +742,7 @@ func (t *http2Client) updateWindow(s *Stream, n uint32) { ...@@ -776,7 +742,7 @@ func (t *http2Client) updateWindow(s *Stream, n uint32) {
} }
func (t *http2Client) handleData(f *http2.DataFrame) { func (t *http2Client) handleData(f *http2.DataFrame) {
size := f.Header().Length size := len(f.Data())
if err := t.fc.onData(uint32(size)); err != nil { if err := t.fc.onData(uint32(size)); err != nil {
t.notifyError(connectionErrorf(true, err, "%v", err)) t.notifyError(connectionErrorf(true, err, "%v", err))
return return
...@@ -790,11 +756,6 @@ func (t *http2Client) handleData(f *http2.DataFrame) { ...@@ -790,11 +756,6 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
return return
} }
if size > 0 { if size > 0 {
if f.Header().Flags.Has(http2.FlagDataPadded) {
if w := t.fc.onRead(uint32(size) - uint32(len(f.Data()))); w > 0 {
t.controlBuf.put(&windowUpdate{0, w})
}
}
s.mu.Lock() s.mu.Lock()
if s.state == streamDone { if s.state == streamDone {
s.mu.Unlock() s.mu.Unlock()
...@@ -805,27 +766,22 @@ func (t *http2Client) handleData(f *http2.DataFrame) { ...@@ -805,27 +766,22 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
return return
} }
if err := s.fc.onData(uint32(size)); err != nil { if err := s.fc.onData(uint32(size)); err != nil {
s.rstStream = true s.state = streamDone
s.rstError = http2.ErrCodeFlowControl s.statusCode = codes.Internal
s.finish(status.New(codes.Internal, err.Error())) s.statusDesc = err.Error()
close(s.done)
s.mu.Unlock() s.mu.Unlock()
s.write(recvMsg{err: io.EOF}) s.write(recvMsg{err: io.EOF})
t.controlBuf.put(&resetStream{s.id, http2.ErrCodeFlowControl})
return return
} }
if f.Header().Flags.Has(http2.FlagDataPadded) {
if w := s.fc.onRead(uint32(size) - uint32(len(f.Data()))); w > 0 {
t.controlBuf.put(&windowUpdate{s.id, w})
}
}
s.mu.Unlock() s.mu.Unlock()
// TODO(bradfitz, zhaoq): A copy is required here because there is no // TODO(bradfitz, zhaoq): A copy is required here because there is no
// guarantee f.Data() is consumed before the arrival of next frame. // guarantee f.Data() is consumed before the arrival of next frame.
// Can this copy be eliminated? // Can this copy be eliminated?
if len(f.Data()) > 0 { data := make([]byte, size)
data := make([]byte, len(f.Data())) copy(data, f.Data())
copy(data, f.Data()) s.write(recvMsg{data: data})
s.write(recvMsg{data: data})
}
} }
// The server has closed the stream without sending trailers. Record that // The server has closed the stream without sending trailers. Record that
// the read direction is closed, and set the status appropriately. // the read direction is closed, and set the status appropriately.
...@@ -835,7 +791,10 @@ func (t *http2Client) handleData(f *http2.DataFrame) { ...@@ -835,7 +791,10 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
s.mu.Unlock() s.mu.Unlock()
return return
} }
s.finish(status.New(codes.Internal, "server closed the stream without sending trailers")) s.state = streamDone
s.statusCode = codes.Internal
s.statusDesc = "server closed the stream without sending trailers"
close(s.done)
s.mu.Unlock() s.mu.Unlock()
s.write(recvMsg{err: io.EOF}) s.write(recvMsg{err: io.EOF})
} }
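handleData above charges every DATA frame against both the connection-level (t.fc) and stream-level (s.fc) receive windows via onData, and onRead later releases consumed bytes, possibly producing a WINDOW_UPDATE. A rough sketch of that accounting, assuming a quarter-window threshold for sending updates (the type is illustrative, not the transport package's inFlow):

```
package main

import (
	"errors"
	"fmt"
)

// inFlow tracks how much of an HTTP/2 receive window is in use.
type inFlow struct {
	limit         uint32 // the advertised window size
	pendingData   uint32 // bytes received but not yet consumed by the application
	pendingUpdate uint32 // bytes consumed but not yet returned via WINDOW_UPDATE
}

// onData is called when a DATA frame arrives; it fails if the sender
// overruns the window, which the handlers above treat as an error.
func (f *inFlow) onData(n uint32) error {
	if f.pendingData+f.pendingUpdate+n > f.limit {
		return errors.New("received data exceeding the flow control window")
	}
	f.pendingData += n
	return nil
}

// onRead is called when the application consumes n bytes; once a quarter of
// the window has been consumed it returns the increment to advertise back.
func (f *inFlow) onRead(n uint32) uint32 {
	if f.pendingData == 0 {
		return 0
	}
	f.pendingData -= n
	f.pendingUpdate += n
	if f.pendingUpdate >= f.limit/4 {
		w := f.pendingUpdate
		f.pendingUpdate = 0
		return w
	}
	return 0
}

func main() {
	fc := &inFlow{limit: 64 * 1024}
	_ = fc.onData(20 * 1024)          // a 20KB DATA frame arrives
	fmt.Println(fc.onRead(20 * 1024)) // 20480: past limit/4, so send a WINDOW_UPDATE
}
```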

...@@ -851,16 +810,18 @@ func (t *http2Client) handleRSTStream(f *http2.RSTStreamFrame) { ...@@ -851,16 +810,18 @@ func (t *http2Client) handleRSTStream(f *http2.RSTStreamFrame) {
s.mu.Unlock() s.mu.Unlock()
return return
} }
s.state = streamDone
if !s.headerDone { if !s.headerDone {
close(s.headerChan) close(s.headerChan)
s.headerDone = true s.headerDone = true
} }
statusCode, ok := http2ErrConvTab[http2.ErrCode(f.ErrCode)] s.statusCode, ok = http2ErrConvTab[http2.ErrCode(f.ErrCode)]
if !ok { if !ok {
grpclog.Println("transport: http2Client.handleRSTStream found no mapped gRPC status for the received http2 error ", f.ErrCode) grpclog.Println("transport: http2Client.handleRSTStream found no mapped gRPC status for the received http2 error ", f.ErrCode)
statusCode = codes.Unknown s.statusCode = codes.Unknown
} }
s.finish(status.Newf(statusCode, "stream terminated by RST_STREAM with error code: %d", f.ErrCode)) s.statusDesc = fmt.Sprintf("stream terminated by RST_STREAM with error code: %d", f.ErrCode)
close(s.done)
s.mu.Unlock() s.mu.Unlock()
s.write(recvMsg{err: io.EOF}) s.write(recvMsg{err: io.EOF})
} }
...@@ -888,9 +849,6 @@ func (t *http2Client) handlePing(f *http2.PingFrame) { ...@@ -888,9 +849,6 @@ func (t *http2Client) handlePing(f *http2.PingFrame) {
} }
func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
if f.ErrCode == http2.ErrCodeEnhanceYourCalm {
grpclog.Printf("Client received GoAway with http2.ErrCodeEnhanceYourCalm.")
}
t.mu.Lock() t.mu.Lock()
if t.state == reachable || t.state == draining { if t.state == reachable || t.state == draining {
if f.LastStreamID > 0 && f.LastStreamID%2 != 1 { if f.LastStreamID > 0 && f.LastStreamID%2 != 1 {
...@@ -912,7 +870,6 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { ...@@ -912,7 +870,6 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
t.mu.Unlock() t.mu.Unlock()
return return
default: default:
t.setGoAwayReason(f)
} }
t.goAwayID = f.LastStreamID t.goAwayID = f.LastStreamID
close(t.goAway) close(t.goAway)
...@@ -920,26 +877,6 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { ...@@ -920,26 +877,6 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
t.mu.Unlock() t.mu.Unlock()
} }
// setGoAwayReason sets the value of t.goAwayReason based
// on the GoAway frame received.
// It expects a lock on transport's mutex to be held by
// the caller.
func (t *http2Client) setGoAwayReason(f *http2.GoAwayFrame) {
t.goAwayReason = NoReason
switch f.ErrCode {
case http2.ErrCodeEnhanceYourCalm:
if string(f.DebugData()) == "too_many_pings" {
t.goAwayReason = TooManyPings
}
}
}
func (t *http2Client) GetGoAwayReason() GoAwayReason {
t.mu.Lock()
defer t.mu.Unlock()
return t.goAwayReason
}
func (t *http2Client) handleWindowUpdate(f *http2.WindowUpdateFrame) { func (t *http2Client) handleWindowUpdate(f *http2.WindowUpdateFrame) {
id := f.Header().StreamID id := f.Header().StreamID
incr := f.Increment incr := f.Increment
...@@ -960,17 +897,18 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { ...@@ -960,17 +897,18 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
} }
var state decodeState var state decodeState
for _, hf := range frame.Fields { for _, hf := range frame.Fields {
if err := state.processHeaderField(hf); err != nil { state.processHeaderField(hf)
s.mu.Lock() }
if !s.headerDone { if state.err != nil {
close(s.headerChan) s.mu.Lock()
s.headerDone = true if !s.headerDone {
} close(s.headerChan)
s.mu.Unlock() s.headerDone = true
s.write(recvMsg{err: err})
// Something wrong. Stops reading even when there is remaining.
return
} }
s.mu.Unlock()
s.write(recvMsg{err: state.err})
// Something wrong. Stops reading even when there is remaining.
return
} }
endStream := frame.StreamEnded() endStream := frame.StreamEnded()
...@@ -1013,7 +951,10 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { ...@@ -1013,7 +951,10 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
if len(state.mdata) > 0 { if len(state.mdata) > 0 {
s.trailer = state.mdata s.trailer = state.mdata
} }
s.finish(state.status()) s.statusCode = state.statusCode
s.statusDesc = state.statusDesc
close(s.done)
s.state = streamDone
s.mu.Unlock() s.mu.Unlock()
s.write(recvMsg{err: io.EOF}) s.write(recvMsg{err: io.EOF})
} }
...@@ -1041,7 +982,6 @@ func (t *http2Client) reader() { ...@@ -1041,7 +982,6 @@ func (t *http2Client) reader() {
t.notifyError(err) t.notifyError(err)
return return
} }
atomic.CompareAndSwapUint32(&t.activity, 0, 1)
sf, ok := frame.(*http2.SettingsFrame) sf, ok := frame.(*http2.SettingsFrame)
if !ok { if !ok {
t.notifyError(err) t.notifyError(err)
...@@ -1052,7 +992,6 @@ func (t *http2Client) reader() { ...@@ -1052,7 +992,6 @@ func (t *http2Client) reader() {
// loop to keep reading incoming messages on this transport. // loop to keep reading incoming messages on this transport.
for { for {
frame, err := t.framer.readFrame() frame, err := t.framer.readFrame()
atomic.CompareAndSwapUint32(&t.activity, 0, 1)
if err != nil { if err != nil {
// Abort an active stream if the http2.Framer returns a // Abort an active stream if the http2.Framer returns a
// http2.StreamError. This can happen only if the server's response // http2.StreamError. This can happen only if the server's response
...@@ -1104,10 +1043,16 @@ func (t *http2Client) applySettings(ss []http2.Setting) { ...@@ -1104,10 +1043,16 @@ func (t *http2Client) applySettings(ss []http2.Setting) {
s.Val = math.MaxInt32 s.Val = math.MaxInt32
} }
t.mu.Lock() t.mu.Lock()
reset := t.streamsQuota != nil
if !reset {
t.streamsQuota = newQuotaPool(int(s.Val) - len(t.activeStreams))
}
ms := t.maxStreams ms := t.maxStreams
t.maxStreams = int(s.Val) t.maxStreams = int(s.Val)
t.mu.Unlock() t.mu.Unlock()
t.streamsQuota.add(int(s.Val) - ms) if reset {
t.streamsQuota.add(int(s.Val) - ms)
}
case http2.SettingInitialWindowSize: case http2.SettingInitialWindowSize:
t.mu.Lock() t.mu.Lock()
for _, stream := range t.activeStreams { for _, stream := range t.activeStreams {
...@@ -1140,12 +1085,6 @@ func (t *http2Client) controller() { ...@@ -1140,12 +1085,6 @@ func (t *http2Client) controller() {
t.framer.writeSettings(true, i.ss...) t.framer.writeSettings(true, i.ss...)
} }
case *resetStream: case *resetStream:
// If the server needs to be intimated about stream closing,
// then we need to make sure the RST_STREAM frame is written to
// the wire before the headers of the next stream waiting on
// streamQuota. We ensure this by adding to the streamsQuota pool
// only after having acquired the writableChan to send RST_STREAM.
t.streamsQuota.add(1)
t.framer.writeRSTStream(true, i.streamID, i.code) t.framer.writeRSTStream(true, i.streamID, i.code)
case *flushIO: case *flushIO:
t.framer.flushWrite() t.framer.flushWrite()
...@@ -1165,61 +1104,6 @@ func (t *http2Client) controller() { ...@@ -1165,61 +1104,6 @@ func (t *http2Client) controller() {
} }
} }
// keepalive running in a separate goroutune makes sure the connection is alive by sending pings.
func (t *http2Client) keepalive() {
p := &ping{data: [8]byte{}}
timer := time.NewTimer(t.kp.Time)
for {
select {
case <-timer.C:
if atomic.CompareAndSwapUint32(&t.activity, 1, 0) {
timer.Reset(t.kp.Time)
continue
}
// Check if keepalive should go dormant.
t.mu.Lock()
if len(t.activeStreams) < 1 && !t.kp.PermitWithoutStream {
// Make awakenKeepalive writable.
<-t.awakenKeepalive
t.mu.Unlock()
select {
case <-t.awakenKeepalive:
// If the control gets here a ping has been sent
// need to reset the timer with keepalive.Timeout.
case <-t.shutdownChan:
return
}
} else {
t.mu.Unlock()
// Send ping.
t.controlBuf.put(p)
}
// By the time control gets here a ping has been sent one way or the other.
timer.Reset(t.kp.Timeout)
select {
case <-timer.C:
if atomic.CompareAndSwapUint32(&t.activity, 1, 0) {
timer.Reset(t.kp.Time)
continue
}
t.Close()
return
case <-t.shutdownChan:
if !timer.Stop() {
<-timer.C
}
return
}
case <-t.shutdownChan:
if !timer.Stop() {
<-timer.C
}
return
}
}
}
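The removed keepalive() goroutine above drives pings off two timers: one that fires after kp.Time of silence and one that gives the peer kp.Timeout to answer. A stripped-down sketch of that loop, with the dormancy handling around awakenKeepalive omitted (conn, pings, and the durations are stand-ins, not the grpc-go types):

```
package main

import (
	"fmt"
	"sync/atomic"
	"time"
)

type conn struct {
	activity uint32        // set to 1 by the reader on every frame, cleared here
	shutdown chan struct{} // closed when the transport is torn down
	pings    chan struct{} // stand-in for controlBuf.put(&ping{})
}

func (c *conn) keepalive(interval, timeout time.Duration) {
	timer := time.NewTimer(interval)
	defer timer.Stop()
	for {
		select {
		case <-timer.C:
			if atomic.CompareAndSwapUint32(&c.activity, 1, 0) {
				timer.Reset(interval) // traffic seen recently, no ping needed
				continue
			}
			c.pings <- struct{}{} // the connection is quiet: send a ping
			timer.Reset(timeout)
			select {
			case <-timer.C:
				if atomic.CompareAndSwapUint32(&c.activity, 1, 0) {
					timer.Reset(interval)
					continue
				}
				fmt.Println("no frames after ping: closing transport")
				return
			case <-c.shutdown:
				return
			}
		case <-c.shutdown:
			return
		}
	}
}

func main() {
	c := &conn{shutdown: make(chan struct{}), pings: make(chan struct{}, 1)}
	go c.keepalive(50*time.Millisecond, 50*time.Millisecond)
	time.Sleep(200 * time.Millisecond) // no reader activity, so expect a ping and then a close
}
```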
func (t *http2Client) Error() <-chan struct{} { func (t *http2Client) Error() <-chan struct{} {
return t.errorChan return t.errorChan
} }
......
...@@ -38,25 +38,19 @@ import ( ...@@ -38,25 +38,19 @@ import (
"errors" "errors"
"io" "io"
"math" "math"
"math/rand"
"net" "net"
"strconv" "strconv"
"sync" "sync"
"sync/atomic"
"time"
"github.com/golang/protobuf/proto"
"golang.org/x/net/context" "golang.org/x/net/context"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"golang.org/x/net/http2/hpack" "golang.org/x/net/http2/hpack"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/grpclog" "google.golang.org/grpc/grpclog"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer" "google.golang.org/grpc/peer"
"google.golang.org/grpc/stats" "google.golang.org/grpc/stats"
"google.golang.org/grpc/status"
"google.golang.org/grpc/tap" "google.golang.org/grpc/tap"
) )
...@@ -96,33 +90,11 @@ type http2Server struct { ...@@ -96,33 +90,11 @@ type http2Server struct {
stats stats.Handler stats stats.Handler
// Flag to keep track of reading activity on transport.
// 1 is true and 0 is false.
activity uint32 // Accessed atomically.
// Keepalive and max-age parameters for the server.
kp keepalive.ServerParameters
// Keepalive enforcement policy.
kep keepalive.EnforcementPolicy
// The time instance last ping was received.
lastPingAt time.Time
// Number of times the client has violated keepalive ping policy so far.
pingStrikes uint8
// Flag to signify that number of ping strikes should be reset to 0.
// This is set whenever data or header frames are sent.
// 1 means yes.
resetPingStrikes uint32 // Accessed atomically.
mu sync.Mutex // guard the following mu sync.Mutex // guard the following
state transportState state transportState
activeStreams map[uint32]*Stream activeStreams map[uint32]*Stream
// the per-stream outbound flow control window size set by the peer. // the per-stream outbound flow control window size set by the peer.
streamSendQuota uint32 streamSendQuota uint32
// idle is the time instant when the connection went idle.
// This is either the beginning of the connection or when the number of
// RPCs go down to 0.
// When the connection is busy, this value is set to 0.
idle time.Time
} }
// newHTTP2Server constructs a ServerTransport based on HTTP2. ConnectionError is // newHTTP2Server constructs a ServerTransport based on HTTP2. ConnectionError is
...@@ -156,28 +128,6 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err ...@@ -156,28 +128,6 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err
return nil, connectionErrorf(true, err, "transport: %v", err) return nil, connectionErrorf(true, err, "transport: %v", err)
} }
} }
kp := config.KeepaliveParams
if kp.MaxConnectionIdle == 0 {
kp.MaxConnectionIdle = defaultMaxConnectionIdle
}
if kp.MaxConnectionAge == 0 {
kp.MaxConnectionAge = defaultMaxConnectionAge
}
// Add a jitter to MaxConnectionAge.
kp.MaxConnectionAge += getJitter(kp.MaxConnectionAge)
if kp.MaxConnectionAgeGrace == 0 {
kp.MaxConnectionAgeGrace = defaultMaxConnectionAgeGrace
}
if kp.Time == 0 {
kp.Time = defaultServerKeepaliveTime
}
if kp.Timeout == 0 {
kp.Timeout = defaultServerKeepaliveTimeout
}
kep := config.KeepalivePolicy
if kep.MinTime == 0 {
kep.MinTime = defaultKeepalivePolicyMinTime
}
var buf bytes.Buffer var buf bytes.Buffer
t := &http2Server{ t := &http2Server{
ctx: context.Background(), ctx: context.Background(),
...@@ -199,9 +149,6 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err ...@@ -199,9 +149,6 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err
activeStreams: make(map[uint32]*Stream), activeStreams: make(map[uint32]*Stream),
streamSendQuota: defaultWindowSize, streamSendQuota: defaultWindowSize,
stats: config.StatsHandler, stats: config.StatsHandler,
kp: kp,
idle: time.Now(),
kep: kep,
} }
if t.stats != nil { if t.stats != nil {
t.ctx = t.stats.TagConn(t.ctx, &stats.ConnTagInfo{ t.ctx = t.stats.TagConn(t.ctx, &stats.ConnTagInfo{
...@@ -212,7 +159,6 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err ...@@ -212,7 +159,6 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err
t.stats.HandleConn(t.ctx, connBegin) t.stats.HandleConn(t.ctx, connBegin)
} }
go t.controller() go t.controller()
go t.keepalive()
t.writableChan <- 0 t.writableChan <- 0
return t, nil return t, nil
} }
...@@ -229,12 +175,13 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( ...@@ -229,12 +175,13 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
var state decodeState var state decodeState
for _, hf := range frame.Fields { for _, hf := range frame.Fields {
if err := state.processHeaderField(hf); err != nil { state.processHeaderField(hf)
if se, ok := err.(StreamError); ok { }
t.controlBuf.put(&resetStream{s.id, statusCodeConvTab[se.Code]}) if err := state.err; err != nil {
} if se, ok := err.(StreamError); ok {
return t.controlBuf.put(&resetStream{s.id, statusCodeConvTab[se.Code]})
} }
return
} }
if frame.StreamEnded() { if frame.StreamEnded() {
...@@ -261,7 +208,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( ...@@ -261,7 +208,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
s.ctx = newContextWithStream(s.ctx, s) s.ctx = newContextWithStream(s.ctx, s)
// Attach the received metadata to the context. // Attach the received metadata to the context.
if len(state.mdata) > 0 { if len(state.mdata) > 0 {
s.ctx = metadata.NewIncomingContext(s.ctx, state.mdata) s.ctx = metadata.NewContext(s.ctx, state.mdata)
} }
s.dec = &recvBufferReader{ s.dec = &recvBufferReader{
...@@ -301,9 +248,6 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( ...@@ -301,9 +248,6 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
t.maxStreamID = s.id t.maxStreamID = s.id
s.sendQuotaPool = newQuotaPool(int(t.streamSendQuota)) s.sendQuotaPool = newQuotaPool(int(t.streamSendQuota))
t.activeStreams[s.id] = s t.activeStreams[s.id] = s
if len(t.activeStreams) == 1 {
t.idle = time.Time{}
}
t.mu.Unlock() t.mu.Unlock()
s.windowHandler = func(n int) { s.windowHandler = func(n int) {
t.updateWindow(s, uint32(n)) t.updateWindow(s, uint32(n))
...@@ -351,7 +295,6 @@ func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context. ...@@ -351,7 +295,6 @@ func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context.
t.Close() t.Close()
return return
} }
atomic.StoreUint32(&t.activity, 1)
sf, ok := frame.(*http2.SettingsFrame) sf, ok := frame.(*http2.SettingsFrame)
if !ok { if !ok {
grpclog.Printf("transport: http2Server.HandleStreams saw invalid preface type %T from client", frame) grpclog.Printf("transport: http2Server.HandleStreams saw invalid preface type %T from client", frame)
...@@ -362,7 +305,6 @@ func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context. ...@@ -362,7 +305,6 @@ func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context.
for { for {
frame, err := t.framer.readFrame() frame, err := t.framer.readFrame()
atomic.StoreUint32(&t.activity, 1)
if err != nil { if err != nil {
if se, ok := err.(http2.StreamError); ok { if se, ok := err.(http2.StreamError); ok {
t.mu.Lock() t.mu.Lock()
...@@ -439,7 +381,7 @@ func (t *http2Server) updateWindow(s *Stream, n uint32) { ...@@ -439,7 +381,7 @@ func (t *http2Server) updateWindow(s *Stream, n uint32) {
} }
func (t *http2Server) handleData(f *http2.DataFrame) { func (t *http2Server) handleData(f *http2.DataFrame) {
size := f.Header().Length size := len(f.Data())
if err := t.fc.onData(uint32(size)); err != nil { if err := t.fc.onData(uint32(size)); err != nil {
grpclog.Printf("transport: http2Server %v", err) grpclog.Printf("transport: http2Server %v", err)
t.Close() t.Close()
...@@ -454,11 +396,6 @@ func (t *http2Server) handleData(f *http2.DataFrame) { ...@@ -454,11 +396,6 @@ func (t *http2Server) handleData(f *http2.DataFrame) {
return return
} }
if size > 0 { if size > 0 {
if f.Header().Flags.Has(http2.FlagDataPadded) {
if w := t.fc.onRead(uint32(size) - uint32(len(f.Data()))); w > 0 {
t.controlBuf.put(&windowUpdate{0, w})
}
}
s.mu.Lock() s.mu.Lock()
if s.state == streamDone { if s.state == streamDone {
s.mu.Unlock() s.mu.Unlock()
...@@ -474,20 +411,13 @@ func (t *http2Server) handleData(f *http2.DataFrame) { ...@@ -474,20 +411,13 @@ func (t *http2Server) handleData(f *http2.DataFrame) {
t.controlBuf.put(&resetStream{s.id, http2.ErrCodeFlowControl}) t.controlBuf.put(&resetStream{s.id, http2.ErrCodeFlowControl})
return return
} }
if f.Header().Flags.Has(http2.FlagDataPadded) {
if w := s.fc.onRead(uint32(size) - uint32(len(f.Data()))); w > 0 {
t.controlBuf.put(&windowUpdate{s.id, w})
}
}
s.mu.Unlock() s.mu.Unlock()
// TODO(bradfitz, zhaoq): A copy is required here because there is no // TODO(bradfitz, zhaoq): A copy is required here because there is no
// guarantee f.Data() is consumed before the arrival of next frame. // guarantee f.Data() is consumed before the arrival of next frame.
// Can this copy be eliminated? // Can this copy be eliminated?
if len(f.Data()) > 0 { data := make([]byte, size)
data := make([]byte, len(f.Data())) copy(data, f.Data())
copy(data, f.Data()) s.write(recvMsg{data: data})
s.write(recvMsg{data: data})
}
} }
if f.Header().Flags.Has(http2.FlagDataEndStream) { if f.Header().Flags.Has(http2.FlagDataEndStream) {
// Received the end of stream from the client. // Received the end of stream from the client.
...@@ -521,11 +451,6 @@ func (t *http2Server) handleSettings(f *http2.SettingsFrame) { ...@@ -521,11 +451,6 @@ func (t *http2Server) handleSettings(f *http2.SettingsFrame) {
t.controlBuf.put(&settings{ack: true, ss: ss}) t.controlBuf.put(&settings{ack: true, ss: ss})
} }
const (
maxPingStrikes = 2
defaultPingTimeout = 2 * time.Hour
)
func (t *http2Server) handlePing(f *http2.PingFrame) { func (t *http2Server) handlePing(f *http2.PingFrame) {
if f.IsAck() { // Do nothing. if f.IsAck() { // Do nothing.
return return
...@@ -533,38 +458,6 @@ func (t *http2Server) handlePing(f *http2.PingFrame) { ...@@ -533,38 +458,6 @@ func (t *http2Server) handlePing(f *http2.PingFrame) {
pingAck := &ping{ack: true} pingAck := &ping{ack: true}
copy(pingAck.data[:], f.Data[:]) copy(pingAck.data[:], f.Data[:])
t.controlBuf.put(pingAck) t.controlBuf.put(pingAck)
now := time.Now()
defer func() {
t.lastPingAt = now
}()
// A reset ping strikes means that we don't need to check for policy
// violation for this ping and the pingStrikes counter should be set
// to 0.
if atomic.CompareAndSwapUint32(&t.resetPingStrikes, 1, 0) {
t.pingStrikes = 0
return
}
t.mu.Lock()
ns := len(t.activeStreams)
t.mu.Unlock()
if ns < 1 && !t.kep.PermitWithoutStream {
// Keepalive shouldn't be active; thus, this new ping should
// have come after at least defaultPingTimeout.
if t.lastPingAt.Add(defaultPingTimeout).After(now) {
t.pingStrikes++
}
} else {
// Check if keepalive policy is respected.
if t.lastPingAt.Add(t.kep.MinTime).After(now) {
t.pingStrikes++
}
}
if t.pingStrikes > maxPingStrikes {
// Send goaway and close the connection.
t.controlBuf.put(&goAway{code: http2.ErrCodeEnhanceYourCalm, debugData: []byte("too_many_pings")})
}
} }
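The removed handlePing logic above enforces the keepalive policy by counting "ping strikes": a ping that arrives sooner than the policy's MinTime after the previous one earns a strike, and exceeding maxPingStrikes triggers a GOAWAY(ENHANCE_YOUR_CALM, "too_many_pings"). A small sketch of that bookkeeping, with the resetPingStrikes and active-stream branches left out (names and the 5-minute MinTime are illustrative):

```
package main

import (
	"fmt"
	"time"
)

const maxPingStrikes = 2

// pingPolicy tracks when the last ping arrived and how many violations
// have been seen so far.
type pingPolicy struct {
	minTime     time.Duration
	lastPingAt  time.Time
	pingStrikes int
}

// onPing records a ping at time now and reports whether the connection
// should be shut down for sending pings too aggressively.
func (p *pingPolicy) onPing(now time.Time) bool {
	defer func() { p.lastPingAt = now }()
	if !p.lastPingAt.IsZero() && p.lastPingAt.Add(p.minTime).After(now) {
		p.pingStrikes++
	}
	return p.pingStrikes > maxPingStrikes
}

func main() {
	p := &pingPolicy{minTime: 5 * time.Minute}
	t := time.Now()
	for i := 0; i < 4; i++ {
		t = t.Add(time.Second) // pings every second, far below minTime
		if p.onPing(t) {
			fmt.Println("too_many_pings: sending GOAWAY and closing")
			return
		}
	}
}
```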
func (t *http2Server) handleWindowUpdate(f *http2.WindowUpdateFrame) { func (t *http2Server) handleWindowUpdate(f *http2.WindowUpdateFrame) {
...@@ -583,13 +476,6 @@ func (t *http2Server) writeHeaders(s *Stream, b *bytes.Buffer, endStream bool) e ...@@ -583,13 +476,6 @@ func (t *http2Server) writeHeaders(s *Stream, b *bytes.Buffer, endStream bool) e
first := true first := true
endHeaders := false endHeaders := false
var err error var err error
defer func() {
if err == nil {
// Reset ping strikes when sending headers since that might cause the
// peer to send ping.
atomic.StoreUint32(&t.resetPingStrikes, 1)
}
}()
// Sends the headers in a single batch. // Sends the headers in a single batch.
for !endHeaders { for !endHeaders {
size := t.hBuf.Len() size := t.hBuf.Len()
...@@ -671,7 +557,7 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { ...@@ -671,7 +557,7 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
// There is no further I/O operations being able to perform on this stream. // There is no further I/O operations being able to perform on this stream.
// TODO(zhaoq): Now it indicates the end of entire stream. Revisit if early // TODO(zhaoq): Now it indicates the end of entire stream. Revisit if early
// OK is adopted. // OK is adopted.
func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error { func (t *http2Server) WriteStatus(s *Stream, statusCode codes.Code, statusDesc string) error {
var headersSent, hasHeader bool var headersSent, hasHeader bool
s.mu.Lock() s.mu.Lock()
if s.state == streamDone { if s.state == streamDone {
...@@ -702,24 +588,9 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error { ...@@ -702,24 +588,9 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error {
t.hEnc.WriteField( t.hEnc.WriteField(
hpack.HeaderField{ hpack.HeaderField{
Name: "grpc-status", Name: "grpc-status",
Value: strconv.Itoa(int(st.Code())), Value: strconv.Itoa(int(statusCode)),
}) })
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-message", Value: encodeGrpcMessage(st.Message())}) t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-message", Value: encodeGrpcMessage(statusDesc)})
if p := st.Proto(); p != nil && len(p.Details) > 0 {
stBytes, err := proto.Marshal(p)
if err != nil {
// TODO: return error instead, when callers are able to handle it.
panic(err)
}
for k, v := range metadata.New(map[string]string{"grpc-status-details-bin": (string)(stBytes)}) {
for _, entry := range v {
t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: entry})
}
}
}
// Attach the trailer metadata. // Attach the trailer metadata.
for k, v := range s.trailer { for k, v := range s.trailer {
// Clients don't tolerate reading restricted headers after some non restricted ones were sent. // Clients don't tolerate reading restricted headers after some non restricted ones were sent.
...@@ -748,7 +619,7 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error { ...@@ -748,7 +619,7 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error {
// Write converts the data into HTTP2 data frame and sends it out. Non-nil error // Write converts the data into HTTP2 data frame and sends it out. Non-nil error
// is returned if it fails (e.g., framing error, transport error). // is returned if it fails (e.g., framing error, transport error).
func (t *http2Server) Write(s *Stream, data []byte, opts *Options) (err error) { func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
// TODO(zhaoq): Support multi-writers for a single stream. // TODO(zhaoq): Support multi-writers for a single stream.
var writeHeaderFrame bool var writeHeaderFrame bool
s.mu.Lock() s.mu.Lock()
...@@ -763,13 +634,6 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) (err error) { ...@@ -763,13 +634,6 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) (err error) {
if writeHeaderFrame { if writeHeaderFrame {
t.WriteHeader(s, nil) t.WriteHeader(s, nil)
} }
defer func() {
if err == nil {
// Reset ping strikes when sending data since this might cause
// the peer to send ping.
atomic.StoreUint32(&t.resetPingStrikes, 1)
}
}()
r := bytes.NewBuffer(data) r := bytes.NewBuffer(data)
for { for {
if r.Len() == 0 { if r.Len() == 0 {
...@@ -859,91 +723,6 @@ func (t *http2Server) applySettings(ss []http2.Setting) { ...@@ -859,91 +723,6 @@ func (t *http2Server) applySettings(ss []http2.Setting) {
} }
} }
// keepalive running in a separate goroutine does the following:
// 1. Gracefully closes an idle connection after a duration of keepalive.MaxConnectionIdle.
// 2. Gracefully closes any connection after a duration of keepalive.MaxConnectionAge.
// 3. Forcibly closes a connection after an additive period of keepalive.MaxConnectionAgeGrace over keepalive.MaxConnectionAge.
// 4. Makes sure a connection is alive by sending pings with a frequency of keepalive.Time and closes a non-responsive connection
// after an additional duration of keepalive.Timeout.
func (t *http2Server) keepalive() {
p := &ping{}
var pingSent bool
maxIdle := time.NewTimer(t.kp.MaxConnectionIdle)
maxAge := time.NewTimer(t.kp.MaxConnectionAge)
keepalive := time.NewTimer(t.kp.Time)
// NOTE: All exit paths of this function should reset their
// respective timers. A failure to do so will cause the
// following clean-up to deadlock and eventually leak.
defer func() {
if !maxIdle.Stop() {
<-maxIdle.C
}
if !maxAge.Stop() {
<-maxAge.C
}
if !keepalive.Stop() {
<-keepalive.C
}
}()
for {
select {
case <-maxIdle.C:
t.mu.Lock()
idle := t.idle
if idle.IsZero() { // The connection is non-idle.
t.mu.Unlock()
maxIdle.Reset(t.kp.MaxConnectionIdle)
continue
}
val := t.kp.MaxConnectionIdle - time.Since(idle)
if val <= 0 {
// The connection has been idle for a duration of keepalive.MaxConnectionIdle or more.
// Gracefully close the connection.
t.state = draining
t.mu.Unlock()
t.Drain()
// Resetting the timer so that the clean-up doesn't deadlock.
maxIdle.Reset(infinity)
return
}
t.mu.Unlock()
maxIdle.Reset(val)
case <-maxAge.C:
t.mu.Lock()
t.state = draining
t.mu.Unlock()
t.Drain()
maxAge.Reset(t.kp.MaxConnectionAgeGrace)
select {
case <-maxAge.C:
// Close the connection after grace period.
t.Close()
// Resetting the timer so that the clean-up doesn't deadlock.
maxAge.Reset(infinity)
case <-t.shutdownChan:
}
return
case <-keepalive.C:
if atomic.CompareAndSwapUint32(&t.activity, 1, 0) {
pingSent = false
keepalive.Reset(t.kp.Time)
continue
}
if pingSent {
t.Close()
// Resetting the timer so that the clean-up doesn't deadlock.
keepalive.Reset(infinity)
return
}
pingSent = true
t.controlBuf.put(p)
keepalive.Reset(t.kp.Timeout)
case <-t.shutdownChan:
return
}
}
}
// controller running in a separate goroutine takes charge of sending control // controller running in a separate goroutine takes charge of sending control
// frames (e.g., window update, reset stream, setting, etc.) to the server. // frames (e.g., window update, reset stream, setting, etc.) to the server.
func (t *http2Server) controller() { func (t *http2Server) controller() {
...@@ -975,10 +754,7 @@ func (t *http2Server) controller() { ...@@ -975,10 +754,7 @@ func (t *http2Server) controller() {
sid := t.maxStreamID sid := t.maxStreamID
t.state = draining t.state = draining
t.mu.Unlock() t.mu.Unlock()
t.framer.writeGoAway(true, sid, i.code, i.debugData) t.framer.writeGoAway(true, sid, http2.ErrCodeNo, nil)
if i.code == http2.ErrCodeEnhanceYourCalm {
t.Close()
}
case *flushIO: case *flushIO:
t.framer.flushWrite() t.framer.flushWrite()
case *ping: case *ping:
...@@ -1028,9 +804,6 @@ func (t *http2Server) Close() (err error) { ...@@ -1028,9 +804,6 @@ func (t *http2Server) Close() (err error) {
func (t *http2Server) closeStream(s *Stream) { func (t *http2Server) closeStream(s *Stream) {
t.mu.Lock() t.mu.Lock()
delete(t.activeStreams, s.id) delete(t.activeStreams, s.id)
if len(t.activeStreams) == 0 {
t.idle = time.Now()
}
if t.state == draining && len(t.activeStreams) == 0 { if t.state == draining && len(t.activeStreams) == 0 {
defer t.Close() defer t.Close()
} }
...@@ -1058,17 +831,5 @@ func (t *http2Server) RemoteAddr() net.Addr { ...@@ -1058,17 +831,5 @@ func (t *http2Server) RemoteAddr() net.Addr {
} }
func (t *http2Server) Drain() { func (t *http2Server) Drain() {
t.controlBuf.put(&goAway{code: http2.ErrCodeNo}) t.controlBuf.put(&goAway{})
}
var rgen = rand.New(rand.NewSource(time.Now().UnixNano()))
func getJitter(v time.Duration) time.Duration {
if v == infinity {
return 0
}
// Generate a jitter between +/- 10% of the value.
r := int64(v / 10)
j := rgen.Int63n(2*r) - r
return time.Duration(j)
} }
...@@ -44,17 +44,16 @@ import ( ...@@ -44,17 +44,16 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/golang/protobuf/proto"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"golang.org/x/net/http2/hpack" "golang.org/x/net/http2/hpack"
spb "google.golang.org/genproto/googleapis/rpc/status"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/grpclog" "google.golang.org/grpc/grpclog"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
) )
const ( const (
// The primary user agent
primaryUA = "grpc-go/1.0"
// http2MaxFrameLen specifies the max length of a HTTP2 frame. // http2MaxFrameLen specifies the max length of a HTTP2 frame.
http2MaxFrameLen = 16384 // 16KB frame http2MaxFrameLen = 16384 // 16KB frame
// http://http2.github.io/http2-spec/#SettingValues // http://http2.github.io/http2-spec/#SettingValues
...@@ -93,15 +92,13 @@ var ( ...@@ -93,15 +92,13 @@ var (
// Records the states during HPACK decoding. Must be reset once the // Records the states during HPACK decoding. Must be reset once the
// decoding of the entire headers is finished. // decoding of the entire headers is finished.
type decodeState struct { type decodeState struct {
err error // first error encountered decoding
encoding string encoding string
// statusGen caches the stream status received from the trailer the server // statusCode caches the stream status received from the trailer
// sent. Client side only. Do not access directly. After all trailers are // the server sent. Client side only.
// parsed, use the status method to retrieve the status. statusCode codes.Code
statusGen *status.Status statusDesc string
// rawStatusCode and rawStatusMsg are set from the raw trailer fields and are not
// intended for direct access outside of parsing.
rawStatusCode int32
rawStatusMsg string
// Server side only fields. // Server side only fields.
timeoutSet bool timeoutSet bool
timeout time.Duration timeout time.Duration
...@@ -124,7 +121,6 @@ func isReservedHeader(hdr string) bool { ...@@ -124,7 +121,6 @@ func isReservedHeader(hdr string) bool {
"grpc-message", "grpc-message",
"grpc-status", "grpc-status",
"grpc-timeout", "grpc-timeout",
"grpc-status-details-bin",
"te": "te":
return true return true
default: default:
...@@ -143,6 +139,12 @@ func isWhitelistedPseudoHeader(hdr string) bool { ...@@ -143,6 +139,12 @@ func isWhitelistedPseudoHeader(hdr string) bool {
} }
} }
func (d *decodeState) setErr(err error) {
if d.err == nil {
d.err = err
}
}
func validContentType(t string) bool { func validContentType(t string) bool {
e := "application/grpc" e := "application/grpc"
if !strings.HasPrefix(t, e) { if !strings.HasPrefix(t, e) {
...@@ -156,62 +158,56 @@ func validContentType(t string) bool { ...@@ -156,62 +158,56 @@ func validContentType(t string) bool {
return true return true
} }
func (d *decodeState) status() *status.Status { func (d *decodeState) processHeaderField(f hpack.HeaderField) {
if d.statusGen == nil {
// No status-details were provided; generate status using code/msg.
d.statusGen = status.New(codes.Code(d.rawStatusCode), d.rawStatusMsg)
}
return d.statusGen
}
func (d *decodeState) processHeaderField(f hpack.HeaderField) error {
switch f.Name { switch f.Name {
case "content-type": case "content-type":
if !validContentType(f.Value) { if !validContentType(f.Value) {
return streamErrorf(codes.FailedPrecondition, "transport: received the unexpected content-type %q", f.Value) d.setErr(streamErrorf(codes.FailedPrecondition, "transport: received the unexpected content-type %q", f.Value))
return
} }
case "grpc-encoding": case "grpc-encoding":
d.encoding = f.Value d.encoding = f.Value
case "grpc-status": case "grpc-status":
code, err := strconv.Atoi(f.Value) code, err := strconv.Atoi(f.Value)
if err != nil { if err != nil {
return streamErrorf(codes.Internal, "transport: malformed grpc-status: %v", err) d.setErr(streamErrorf(codes.Internal, "transport: malformed grpc-status: %v", err))
return
} }
d.rawStatusCode = int32(code) d.statusCode = codes.Code(code)
case "grpc-message": case "grpc-message":
d.rawStatusMsg = decodeGrpcMessage(f.Value) d.statusDesc = decodeGrpcMessage(f.Value)
case "grpc-status-details-bin":
_, v, err := metadata.DecodeKeyValue("grpc-status-details-bin", f.Value)
if err != nil {
return streamErrorf(codes.Internal, "transport: malformed grpc-status-details-bin: %v", err)
}
s := &spb.Status{}
if err := proto.Unmarshal([]byte(v), s); err != nil {
return streamErrorf(codes.Internal, "transport: malformed grpc-status-details-bin: %v", err)
}
d.statusGen = status.FromProto(s)
case "grpc-timeout": case "grpc-timeout":
d.timeoutSet = true d.timeoutSet = true
var err error var err error
if d.timeout, err = decodeTimeout(f.Value); err != nil { d.timeout, err = decodeTimeout(f.Value)
return streamErrorf(codes.Internal, "transport: malformed time-out: %v", err) if err != nil {
d.setErr(streamErrorf(codes.Internal, "transport: malformed time-out: %v", err))
return
} }
case ":path": case ":path":
d.method = f.Value d.method = f.Value
default: default:
if !isReservedHeader(f.Name) || isWhitelistedPseudoHeader(f.Name) { if !isReservedHeader(f.Name) || isWhitelistedPseudoHeader(f.Name) {
if f.Name == "user-agent" {
i := strings.LastIndex(f.Value, " ")
if i == -1 {
// There is no application user agent string being set.
return
}
// Extract the application user agent string.
f.Value = f.Value[:i]
}
if d.mdata == nil { if d.mdata == nil {
d.mdata = make(map[string][]string) d.mdata = make(map[string][]string)
} }
k, v, err := metadata.DecodeKeyValue(f.Name, f.Value) k, v, err := metadata.DecodeKeyValue(f.Name, f.Value)
if err != nil { if err != nil {
grpclog.Printf("Failed to decode (%q, %q): %v", f.Name, f.Value, err) grpclog.Printf("Failed to decode (%q, %q): %v", f.Name, f.Value, err)
return nil return
} }
d.mdata[k] = append(d.mdata[k], v) d.mdata[k] = append(d.mdata[k], v)
} }
} }
return nil
} }
type timeoutUnit uint8 type timeoutUnit uint8
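processHeaderField above hands the grpc-timeout header to decodeTimeout; on the wire the value is a digit string followed by a single unit letter. A small stand-alone parser, assuming the standard unit letters (an illustration, not the transport package's decodeTimeout):

```
package main

import (
	"fmt"
	"strconv"
	"time"
)

// parseGrpcTimeout decodes values such as "100m" (100ms) or "2S" (2s).
// Units: H hours, M minutes, S seconds, m milliseconds, u microseconds, n nanoseconds.
func parseGrpcTimeout(s string) (time.Duration, error) {
	if len(s) < 2 {
		return 0, fmt.Errorf("grpc-timeout %q is too short", s)
	}
	units := map[byte]time.Duration{
		'H': time.Hour, 'M': time.Minute, 'S': time.Second,
		'm': time.Millisecond, 'u': time.Microsecond, 'n': time.Nanosecond,
	}
	unit, ok := units[s[len(s)-1]]
	if !ok {
		return 0, fmt.Errorf("grpc-timeout %q has an unknown unit", s)
	}
	v, err := strconv.ParseInt(s[:len(s)-1], 10, 64)
	if err != nil {
		return 0, fmt.Errorf("grpc-timeout %q has a bad value: %v", s, err)
	}
	return time.Duration(v) * unit, nil
}

func main() {
	for _, s := range []string{"100m", "2S", "1H"} {
		d, err := parseGrpcTimeout(s)
		fmt.Println(s, d, err)
	}
}
```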
...@@ -383,9 +379,6 @@ func newFramer(conn net.Conn) *framer { ...@@ -383,9 +379,6 @@ func newFramer(conn net.Conn) *framer {
writer: bufio.NewWriterSize(conn, http2IOBufSize), writer: bufio.NewWriterSize(conn, http2IOBufSize),
} }
f.fr = http2.NewFramer(f.writer, f.reader) f.fr = http2.NewFramer(f.writer, f.reader)
// Opt-in to Frame reuse API on framer to reduce garbage.
// Frames aren't safe to read from after a subsequent call to ReadFrame.
f.fr.SetReuseFrames()
f.fr.ReadMetaHeaders = hpack.NewDecoder(http2InitHeaderTableSize, nil) f.fr.ReadMetaHeaders = hpack.NewDecoder(http2InitHeaderTableSize, nil)
return f return f
} }
......
// +build !go1.6
/*
* Copyright 2016, Google Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
*/
package transport
import (
"net"
"time"
"golang.org/x/net/context"
)
// dialContext connects to the address on the named network.
func dialContext(ctx context.Context, network, address string) (net.Conn, error) {
var dialer net.Dialer
if deadline, ok := ctx.Deadline(); ok {
dialer.Timeout = deadline.Sub(time.Now())
}
return dialer.Dial(network, address)
}
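The pre-Go-1.6 dialContext above approximates context support by folding the deadline into net.Dialer.Timeout, so a cancellation that happens mid-dial is not honoured. A hypothetical call site under the same approach (the wrapper name, address, and timeout are illustrative):

```
package main

import (
	"fmt"
	"net"
	"time"

	"golang.org/x/net/context"
)

// dialWithDeadline follows the same approach as dialContext above: convert
// the context deadline into a dialer timeout before dialing.
func dialWithDeadline(ctx context.Context, network, address string) (net.Conn, error) {
	var dialer net.Dialer
	if deadline, ok := ctx.Deadline(); ok {
		dialer.Timeout = deadline.Sub(time.Now())
	}
	return dialer.Dial(network, address)
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()
	conn, err := dialWithDeadline(ctx, "tcp", "localhost:50051") // illustrative address
	if err != nil {
		fmt.Println("dial failed or timed out:", err)
		return
	}
	defer conn.Close()
	fmt.Println("connected to", conn.RemoteAddr())
}
```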
...@@ -45,13 +45,10 @@ import ( ...@@ -45,13 +45,10 @@ import (
"sync" "sync"
"golang.org/x/net/context" "golang.org/x/net/context"
"golang.org/x/net/http2"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"google.golang.org/grpc/stats" "google.golang.org/grpc/stats"
"google.golang.org/grpc/status"
"google.golang.org/grpc/tap" "google.golang.org/grpc/tap"
) )
...@@ -213,13 +210,9 @@ type Stream struct { ...@@ -213,13 +210,9 @@ type Stream struct {
// true iff headerChan is closed. Used to avoid closing headerChan // true iff headerChan is closed. Used to avoid closing headerChan
// multiple times. // multiple times.
headerDone bool headerDone bool
// the status error received from the server. // the status received from the server.
status *status.Status statusCode codes.Code
// rstStream indicates whether a RST_STREAM frame needs to be sent statusDesc string
// to the server to signify that this stream is closing.
rstStream bool
// rstError is the error that needs to be sent along with the RST_STREAM frame.
rstError http2.ErrCode
} }
// RecvCompress returns the compression algorithm applied to the inbound // RecvCompress returns the compression algorithm applied to the inbound
...@@ -284,9 +277,14 @@ func (s *Stream) Method() string { ...@@ -284,9 +277,14 @@ func (s *Stream) Method() string {
return s.method return s.method
} }
// Status returns the status received from the server. // StatusCode returns statusCode received from the server.
func (s *Stream) Status() *status.Status { func (s *Stream) StatusCode() codes.Code {
return s.status return s.statusCode
}
// StatusDesc returns statusDesc received from the server.
func (s *Stream) StatusDesc() string {
return s.statusDesc
} }
// SetHeader sets the header metadata. This can be called multiple times. // SetHeader sets the header metadata. This can be called multiple times.
...@@ -333,20 +331,6 @@ func (s *Stream) Read(p []byte) (n int, err error) { ...@@ -333,20 +331,6 @@ func (s *Stream) Read(p []byte) (n int, err error) {
return return
} }
// finish sets the stream's state and status, and closes the done channel.
// s.mu must be held by the caller. st must always be non-nil.
func (s *Stream) finish(st *status.Status) {
s.status = st
s.state = streamDone
close(s.done)
}
// GoString is implemented by Stream so context.String() won't
// race when printing %#v.
func (s *Stream) GoString() string {
return fmt.Sprintf("<stream: %p, %v>", s, s.method)
}
// The key to save transport.Stream in the context.
type streamKey struct{}
@@ -374,12 +358,10 @@ const (

// ServerConfig consists of all the configurations to establish a server transport.
type ServerConfig struct {
    MaxStreams uint32
    AuthInfo credentials.AuthInfo
    InTapHandle tap.ServerInHandle
    StatsHandler stats.Handler
-   KeepaliveParams keepalive.ServerParameters
-   KeepalivePolicy keepalive.EnforcementPolicy
}

// NewServerTransport creates a ServerTransport with conn or non-nil error
@@ -392,9 +374,6 @@ func NewServerTransport(protocol string, conn net.Conn, config *ServerConfig) (S
type ConnectOptions struct {
    // UserAgent is the application user agent.
    UserAgent string
-   // Authority is the :authority pseudo-header to use. This field has no effect if
-   // TransportCredentials is set.
-   Authority string
    // Dialer specifies how to dial a network address.
    Dialer func(context.Context, string) (net.Conn, error)
    // FailOnNonTempDialError specifies if gRPC fails on non-temporary dial errors.
@@ -403,8 +382,6 @@ type ConnectOptions struct {
    PerRPCCredentials []credentials.PerRPCCredentials
    // TransportCredentials stores the Authenticator required to setup a client connection.
    TransportCredentials credentials.TransportCredentials
-   // KeepaliveParams stores the keepalive parameters.
-   KeepaliveParams keepalive.ClientParameters
    // StatsHandler stores the handler for stats.
    StatsHandler stats.Handler
}
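Taken together, the two struct hunks above mean that at the pinned revision neither ServerConfig nor ConnectOptions carries keepalive settings, and ConnectOptions has no Authority override. The sketch below shows the client-side fields that do remain; it is illustrative only, the dial timeout and user-agent string are invented, and the context deadline is deliberately ignored for brevity.

```go
package example

import (
    "net"
    "time"

    "golang.org/x/net/context"
    "google.golang.org/grpc/credentials"
    "google.golang.org/grpc/transport"
)

// dialOpts sketches the ConnectOptions fields still present at the pinned
// revision; KeepaliveParams and Authority no longer exist, so keepalive and
// :authority overrides cannot be configured at this layer.
func dialOpts(creds credentials.TransportCredentials) transport.ConnectOptions {
    return transport.ConnectOptions{
        UserAgent: "example-client/0.1", // illustrative value
        Dialer: func(ctx context.Context, addr string) (net.Conn, error) {
            // Plain timeout dial; ctx's deadline is ignored to keep the sketch short.
            return net.DialTimeout("tcp", addr, 5*time.Second)
        },
        TransportCredentials: creds,
    }
}
```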
@@ -493,9 +470,6 @@ type ClientTransport interface {
    // receives the draining signal from the server (e.g., GOAWAY frame in
    // HTTP/2).
    GoAway() <-chan struct{}
-
-   // GetGoAwayReason returns the reason why GoAway frame was received.
-   GetGoAwayReason() GoAwayReason
}

// ServerTransport is the common interface for all gRPC server-side transport
@@ -515,9 +489,10 @@ type ServerTransport interface {
    // Write may not be called on all streams.
    Write(s *Stream, data []byte, opts *Options) error

-   // WriteStatus sends the status of a stream to the client. WriteStatus is
-   // the final call made on a stream and always occurs.
-   WriteStatus(s *Stream, st *status.Status) error
+   // WriteStatus sends the status of a stream to the client.
+   // WriteStatus is the final call made on a stream and always
+   // occurs.
+   WriteStatus(s *Stream, statusCode codes.Code, statusDesc string) error

    // Close tears down the transport. Once it is called, the transport
    // should not be accessed any more. All the pending streams and their
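The WriteStatus signature change is the part of this file most likely to ripple into server-side call sites: with the pinned revision the code and description are passed separately instead of being wrapped in a *status.Status. A hedged sketch of such a call site follows; the function name and message text are invented, and the comment shows the equivalent call at the newer revision.

```go
package example

import (
    "google.golang.org/grpc/codes"
    "google.golang.org/grpc/transport"
)

// finishWithError sketches how a server-side caller reports an RPC failure
// at the pinned revision: code and description go directly to WriteStatus.
// At the newer revision the same call would be
//	t.WriteStatus(s, status.New(codes.NotFound, "no such key"))
func finishWithError(t transport.ServerTransport, s *transport.Stream) error {
    return t.WriteStatus(s, codes.NotFound, "no such key")
}
```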
@@ -583,8 +558,6 @@ var (
    ErrStreamDrain = streamErrorf(codes.Unavailable, "the server stops accepting new RPCs")
)

-// TODO: See if we can replace StreamError with status package errors.
-
// StreamError is an error that only affects one stream within a connection.
type StreamError struct {
    Code codes.Code
@@ -592,7 +565,7 @@ type StreamError struct {
}

func (e StreamError) Error() string {
-   return fmt.Sprintf("stream error: code = %s desc = %q", e.Code, e.Desc)
+   return fmt.Sprintf("stream error: code = %d desc = %q", e.Code, e.Desc)
}

// ContextErr converts the error from context package into a StreamError.
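The only behavioural difference in the Error() hunk is the format verb: the newer revision prints the code by name (codes.Code has a generated String method), while the pinned revision prints its numeric value. A small illustration, with the code and description chosen arbitrarily:

```go
package example

import (
    "fmt"

    "google.golang.org/grpc/codes"
)

// formatStreamError mirrors the pinned revision's formatting. With
// code == codes.Unavailable (numeric value 14) the two verbs in the diff give:
//
//	%s (newer revision):  stream error: code = Unavailable desc = "the server stops accepting new RPCs"
//	%d (pinned revision): stream error: code = 14 desc = "the server stops accepting new RPCs"
func formatStreamError(code codes.Code, desc string) string {
    return fmt.Sprintf("stream error: code = %d desc = %q", code, desc)
}
```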
@@ -633,16 +606,3 @@ func wait(ctx context.Context, done, goAway, closing <-chan struct{}, proceed <-
        return i, nil
    }
}
-
-// GoAwayReason contains the reason for the GoAway frame received.
-type GoAwayReason uint8
-
-const (
-   // Invalid indicates that no GoAway frame is received.
-   Invalid GoAwayReason = 0
-   // NoReason is the default value when GoAway frame is received.
-   NoReason GoAwayReason = 1
-   // TooManyPings indicates that a GoAway frame with ErrCodeEnhanceYourCalm
-   // was recieved and that the debug data said "too_many_pings".
-   TooManyPings GoAwayReason = 2
-)
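GetGoAwayReason and the GoAwayReason enumeration removed above exist only at the newer revision, so any caller that distinguished keepalive-policy GOAWAYs has to be dropped or guarded when pinning to the older one. The sketch below illustrates the kind of call site affected; it compiles only against the newer (pre-downgrade) revision, and the backoff hook is invented.

```go
package example

import "google.golang.org/grpc/transport"

// watchGoAway only builds against the newer revision: the pinned revision
// removes both GetGoAwayReason and the GoAwayReason type, so a call site
// like this disappears as part of the downgrade.
func watchGoAway(ct transport.ClientTransport, backoff func()) {
    <-ct.GoAway() // blocks until the server signals draining
    if ct.GetGoAwayReason() == transport.TooManyPings {
        backoff() // the server sent "too_many_pings"; slow down before reconnecting
    }
}
```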
@@ -132,10 +132,10 @@
        "revisionTime": "2017-04-04T13:20:09Z"
    },
    {
-       "checksumSHA1": "tidJMmntKTZuU196aiLojkULL+g=",
+       "checksumSHA1": "epHwh7hDQSYzDowPIbw8vnLzPS0=",
        "path": "google.golang.org/grpc",
-       "revision": "6d158dbf32084eac5fc0b9ea6f1feed214290ec6",
-       "revisionTime": "2017-04-12T06:39:30Z"
+       "revision": "50955793b0183f9de69bd78e2ec251cf20aab121",
+       "revisionTime": "2017-01-11T19:10:52Z"
    },
    {
        "checksumSHA1": "08icuA15HRkdYCt6H+Cs90RPQsY=",
@@ -204,10 +204,10 @@
        "revisionTime": "2017-04-12T06:39:30Z"
    },
    {
-       "checksumSHA1": "WMlN+OrgFM70j2/AoMh6DM6NtK8=",
+       "checksumSHA1": "yHpUeGwKoqqwd3cbEp3lkcnvft0=",
        "path": "google.golang.org/grpc/transport",
-       "revision": "6d158dbf32084eac5fc0b9ea6f1feed214290ec6",
-       "revisionTime": "2017-04-12T06:39:30Z"
+       "revision": "50955793b0183f9de69bd78e2ec251cf20aab121",
+       "revisionTime": "2017-01-11T19:10:52Z"
    },
    {
        "checksumSHA1": "fALlQNY1fM99NesfLJ50KguWsio=",
...