Commit 25d866fa authored by Jacob Vosmaer (GitLab)

Merge branch 'bjk/grpc_vendor' into 'master'

More vendoring updates

See merge request gitlab-org/gitlab-workhorse!280
parents 2fd82d0f 273a9016
# Changelog

All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).

## [Unreleased]

## [1.2.0](https://github.com/grpc-ecosystem/go-grpc-prometheus/releases/tag/v1.2.0) - 2018-06-04

### Added

* Provide metrics object as `prometheus.Collector`, for conventional metric registration.
* Support non-default/global Prometheus registry.
* Allow configuring counters with `prometheus.CounterOpts`.

### Changed

* Remove usage of deprecated `grpc.Code()`.
* Remove usage of deprecated `grpc.Errorf` and replace with `status.Errorf`.

---

This changelog was started with version `v1.2.0`; for earlier versions refer to the respective [GitHub releases](https://github.com/grpc-ecosystem/go-grpc-prometheus/releases).
...@@ -49,8 +49,8 @@ import "github.com/grpc-ecosystem/go-grpc-prometheus"
...
```go
clientConn, err = grpc.Dial(
	address,
	grpc.WithUnaryInterceptor(grpc_prometheus.UnaryClientInterceptor),
	grpc.WithStreamInterceptor(grpc_prometheus.StreamClientInterceptor)
)
client = pb_testproto.NewTestServiceClient(clientConn)
resp, err := client.PingEmpty(s.ctx, &myservice.Request{Msg: "hello"})
```
...@@ -118,7 +118,7 @@ each of the 20 messages sent back, a counter will be incremented:
```jsoniq
grpc_server_msg_sent_total{grpc_method="PingList",grpc_service="mwitkow.testproto.TestService",grpc_type="server_stream"} 20
```
After the call completes, its status (`OK` or other [gRPC status code](https://github.com/grpc/grpc-go/blob/master/codes/codes.go))
and the relevant call labels increment the `grpc_server_handled_total` counter.
```jsoniq
...@@ -128,8 +128,8 @@ grpc_server_handled_total{grpc_code="OK",grpc_method="PingList",grpc_service="mw
```
## Histograms
[Prometheus histograms](https://prometheus.io/docs/concepts/metric_types/#histogram) are a great way
to measure latency distributions of your RPCs. However, since it is bad practice to have metrics
of [high cardinality](https://prometheus.io/docs/practices/instrumentation/#do-not-overuse-labels),
the latency monitoring metrics are disabled by default. To enable them, call the following
in your server initialization code:
...@@ -137,8 +137,8 @@ in your server initialization code:
```go
grpc_prometheus.EnableHandlingTimeHistogram()
```
After the call completes, its handling time will be recorded in a [Prometheus histogram](https://prometheus.io/docs/concepts/metric_types/#histogram)
variable `grpc_server_handling_seconds`. The histogram variable contains three sub-metrics:
* `grpc_server_handling_seconds_count` - the count of all completed RPCs by status and method
* `grpc_server_handling_seconds_sum` - cumulative time of RPCs by status and method, useful for
...@@ -168,7 +168,7 @@ grpc_server_handling_seconds_count{grpc_code="OK",grpc_method="PingList",grpc_se
## Useful query examples
The Prometheus philosophy is to provide raw metrics to the monitoring system, and
let the aggregations be handled there. The verbosity of the above metrics makes that
flexibility possible. Here are a couple of useful monitoring queries:
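As a hedged illustration of the kind of query this section has in mind (assuming standard PromQL `rate()` and aggregation semantics, and the `grpc_server_handled_total` counter defined above), the per-service unary request rate over the last minute could be computed as:

```jsoniq
sum(rate(grpc_server_handled_total{grpc_type="unary"}[1m])) by (grpc_service)
```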
......
...@@ -7,6 +7,7 @@ import (
	"golang.org/x/net/context"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)
// ClientMetrics represents a collection of metrics to be registered on a
...@@ -25,31 +26,32 @@ type ClientMetrics struct {
// ClientMetrics when not using the default Prometheus metrics registry, for
// example when wanting to control which metrics are added to a registry as
// opposed to automatically adding metrics via init functions.
func NewClientMetrics(counterOpts ...CounterOption) *ClientMetrics {
	opts := counterOptions(counterOpts)
	return &ClientMetrics{
		clientStartedCounter: prom.NewCounterVec(
			opts.apply(prom.CounterOpts{
				Name: "grpc_client_started_total",
				Help: "Total number of RPCs started on the client.",
			}), []string{"grpc_type", "grpc_service", "grpc_method"}),
		clientHandledCounter: prom.NewCounterVec(
			opts.apply(prom.CounterOpts{
				Name: "grpc_client_handled_total",
				Help: "Total number of RPCs completed by the client, regardless of success or failure.",
			}), []string{"grpc_type", "grpc_service", "grpc_method", "grpc_code"}),
		clientStreamMsgReceived: prom.NewCounterVec(
			opts.apply(prom.CounterOpts{
				Name: "grpc_client_msg_received_total",
				Help: "Total number of RPC stream messages received by the client.",
			}), []string{"grpc_type", "grpc_service", "grpc_method"}),
		clientStreamMsgSent: prom.NewCounterVec(
			opts.apply(prom.CounterOpts{
				Name: "grpc_client_msg_sent_total",
				Help: "Total number of gRPC stream messages sent by the client.",
			}), []string{"grpc_type", "grpc_service", "grpc_method"}),
		clientHandledHistogramEnabled: false,
		clientHandledHistogramOpts: prom.HistogramOpts{
...@@ -111,18 +113,20 @@ func (m *ClientMetrics) UnaryClientInterceptor() func(ctx context.Context, metho
		if err != nil {
			monitor.ReceivedMessage()
		}
		st, _ := status.FromError(err)
		monitor.Handled(st.Code())
		return err
	}
}
// StreamClientInterceptor is a gRPC client-side interceptor that provides Prometheus monitoring for Streaming RPCs.
func (m *ClientMetrics) StreamClientInterceptor() func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
	return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
		monitor := newClientReporter(m, clientStreamType(desc), method)
		clientStream, err := streamer(ctx, desc, cc, method, opts...)
		if err != nil {
			st, _ := status.FromError(err)
			monitor.Handled(st.Code())
			return nil, err
		}
		return &monitoredClientStream{clientStream, monitor}, nil
...@@ -159,7 +163,8 @@ func (s *monitoredClientStream) RecvMsg(m interface{}) error {
	} else if err == io.EOF {
		s.monitor.Handled(codes.OK)
	} else {
		st, _ := status.FromError(err)
		s.monitor.Handled(st.Code())
	}
	return err
}
SHELL="/bin/bash"

GOFILES_NOVENDOR = $(shell go list ./... | grep -v /vendor/)

all: vet fmt test

fmt:
	go fmt $(GOFILES_NOVENDOR)

vet:
	go vet $(GOFILES_NOVENDOR)

test: vet
	./scripts/test_all.sh

.PHONY: all vet test
package grpc_prometheus

import (
	prom "github.com/prometheus/client_golang/prometheus"
)

// A CounterOption lets you add options to Counter metrics using With* funcs.
type CounterOption func(*prom.CounterOpts)

type counterOptions []CounterOption

func (co counterOptions) apply(o prom.CounterOpts) prom.CounterOpts {
	for _, f := range co {
		f(&o)
	}
	return o
}

// WithConstLabels allows you to add ConstLabels to Counter metrics.
func WithConstLabels(labels prom.Labels) CounterOption {
	return func(o *prom.CounterOpts) {
		o.ConstLabels = labels
	}
}

// A HistogramOption lets you add options to Histogram metrics using With*
// funcs.
type HistogramOption func(*prom.HistogramOpts)

// WithHistogramBuckets allows you to specify custom bucket ranges for histograms if EnableHandlingTimeHistogram is on.
func WithHistogramBuckets(buckets []float64) HistogramOption {
	return func(o *prom.HistogramOpts) { o.Buckets = buckets }
}

// WithHistogramConstLabels allows you to add custom ConstLabels to
// histogram metrics.
func WithHistogramConstLabels(labels prom.Labels) HistogramOption {
	return func(o *prom.HistogramOpts) {
		o.ConstLabels = labels
	}
}
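The `CounterOption` file above is a standard Go functional-options idiom: `apply` copies the base opts by value, runs each option against the copy, and returns it. A minimal, self-contained sketch of the same pattern (using a local `CounterOpts` stand-in rather than the real `prometheus.CounterOpts`) can be exercised like this:

```go
package main

import "fmt"

// Local stand-in for prom.CounterOpts; illustration only, not the real
// Prometheus type.
type CounterOpts struct {
	Name        string
	ConstLabels map[string]string
}

// Same shape as the CounterOption/counterOptions pattern above.
type CounterOption func(*CounterOpts)

type counterOptions []CounterOption

// apply copies the base opts by value, applies each option to the copy,
// and returns it, leaving the caller's value untouched.
func (co counterOptions) apply(o CounterOpts) CounterOpts {
	for _, f := range co {
		f(&o)
	}
	return o
}

// WithConstLabels mirrors the real option: it sets ConstLabels on the opts.
func WithConstLabels(labels map[string]string) CounterOption {
	return func(o *CounterOpts) { o.ConstLabels = labels }
}

func main() {
	opts := counterOptions{WithConstLabels(map[string]string{"app": "demo"})}
	got := opts.apply(CounterOpts{Name: "grpc_client_started_total"})
	fmt.Println(got.Name, got.ConstLabels["app"]) // prints "grpc_client_started_total demo"
}
```

Because options mutate a copy, the same `counterOptions` slice can safely be applied to every counter a constructor builds, which is exactly how `NewClientMetrics`/`NewServerMetrics` use it.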
...@@ -4,6 +4,7 @@ import (
	prom "github.com/prometheus/client_golang/prometheus"
	"golang.org/x/net/context"
	"google.golang.org/grpc"
	"google.golang.org/grpc/status"
)
// ServerMetrics represents a collection of metrics to be registered on a
...@@ -22,28 +23,29 @@ type ServerMetrics struct {
// ServerMetrics when not using the default Prometheus metrics registry, for
// example when wanting to control which metrics are added to a registry as
// opposed to automatically adding metrics via init functions.
func NewServerMetrics(counterOpts ...CounterOption) *ServerMetrics {
	opts := counterOptions(counterOpts)
	return &ServerMetrics{
		serverStartedCounter: prom.NewCounterVec(
			opts.apply(prom.CounterOpts{
				Name: "grpc_server_started_total",
				Help: "Total number of RPCs started on the server.",
			}), []string{"grpc_type", "grpc_service", "grpc_method"}),
		serverHandledCounter: prom.NewCounterVec(
			opts.apply(prom.CounterOpts{
				Name: "grpc_server_handled_total",
				Help: "Total number of RPCs completed on the server, regardless of success or failure.",
			}), []string{"grpc_type", "grpc_service", "grpc_method", "grpc_code"}),
		serverStreamMsgReceived: prom.NewCounterVec(
			opts.apply(prom.CounterOpts{
				Name: "grpc_server_msg_received_total",
				Help: "Total number of RPC stream messages received on the server.",
			}), []string{"grpc_type", "grpc_service", "grpc_method"}),
		serverStreamMsgSent: prom.NewCounterVec(
			opts.apply(prom.CounterOpts{
				Name: "grpc_server_msg_sent_total",
				Help: "Total number of gRPC stream messages sent by the server.",
			}), []string{"grpc_type", "grpc_service", "grpc_method"}),
		serverHandledHistogramEnabled: false,
		serverHandledHistogramOpts: prom.HistogramOpts{
			Name: "grpc_server_handling_seconds",
...@@ -54,13 +56,6 @@ func NewServerMetrics() *ServerMetrics {
	}
}
// EnableHandlingTimeHistogram enables histograms being registered when
// registering the ServerMetrics on a Prometheus registry. Histograms can be
// expensive on Prometheus servers. It takes options to configure histogram
...@@ -110,7 +105,8 @@ func (m *ServerMetrics) UnaryServerInterceptor() func(ctx context.Context, req i
		monitor := newServerReporter(m, Unary, info.FullMethod)
		monitor.ReceivedMessage()
		resp, err := handler(ctx, req)
		st, _ := status.FromError(err)
		monitor.Handled(st.Code())
		if err == nil {
			monitor.SentMessage()
		}
...@@ -121,9 +117,10 @@ func (m *ServerMetrics) UnaryServerInterceptor() func(ctx context.Context, req i
// StreamServerInterceptor is a gRPC server-side interceptor that provides Prometheus monitoring for Streaming RPCs.
func (m *ServerMetrics) StreamServerInterceptor() func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
	return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		monitor := newServerReporter(m, streamRPCType(info), info.FullMethod)
		err := handler(srv, &monitoredServerStream{ss, monitor})
		st, _ := status.FromError(err)
		monitor.Handled(st.Code())
		return err
	}
}
...@@ -140,27 +137,7 @@ func (m *ServerMetrics) InitializeMetrics(server *grpc.Server) {
	}
}
func streamRPCType(info *grpc.StreamServerInfo) grpcType {
	if info.IsClientStream && !info.IsServerStream {
		return ClientStream
	} else if !info.IsClientStream && info.IsServerStream {
		return ServerStream
	}
......
#!/usr/bin/env bash
set -e

echo "" > coverage.txt
for d in $(go list ./... | grep -v vendor); do
	echo -e "TESTS FOR: \033[0;35m${d}\033[0m"
	go test -race -v -coverprofile=profile.coverage.out -covermode=atomic $d
	if [ -f profile.coverage.out ]; then
		cat profile.coverage.out >> coverage.txt
		rm profile.coverage.out
	fi
	echo ""
done
...@@ -37,13 +37,13 @@ func splitMethodName(fullMethodName string) (string, string) {
}
func typeFromMethodInfo(mInfo *grpc.MethodInfo) grpcType {
	if !mInfo.IsClientStream && !mInfo.IsServerStream {
		return Unary
	}
	if mInfo.IsClientStream && !mInfo.IsServerStream {
		return ClientStream
	}
	if !mInfo.IsClientStream && mInfo.IsServerStream {
		return ServerStream
	}
	return BidiStream
......
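The cleanup above replaces `== true`/`== false` comparisons with direct boolean tests. The same branch logic can be sketched standalone (with local stand-ins for `grpc.MethodInfo` and the package's `grpcType` constants; the names mirror the diff, not the real gRPC API):

```go
package main

import "fmt"

// Local stand-in for grpc.MethodInfo; only the two stream flags matter here.
type MethodInfo struct {
	IsClientStream bool
	IsServerStream bool
}

type grpcType string

const (
	Unary        grpcType = "unary"
	ClientStream grpcType = "client_stream"
	ServerStream grpcType = "server_stream"
	BidiStream   grpcType = "bidi_stream"
)

// Same branch logic as typeFromMethodInfo above, using idiomatic boolean
// tests instead of comparisons against true/false.
func typeFromMethodInfo(mInfo *MethodInfo) grpcType {
	if !mInfo.IsClientStream && !mInfo.IsServerStream {
		return Unary
	}
	if mInfo.IsClientStream && !mInfo.IsServerStream {
		return ClientStream
	}
	if !mInfo.IsClientStream && mInfo.IsServerStream {
		return ServerStream
	}
	return BidiStream
}

func main() {
	fmt.Println(typeFromMethodInfo(&MethodInfo{IsClientStream: true, IsServerStream: true})) // prints "bidi_stream"
}
```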
# 1.0.5
* Fix hooks race (#707)
* Fix panic deadlock (#695)
# 1.0.4
* Fix race when adding hooks (#612)
......
...@@ -220,7 +220,7 @@ Logrus comes with [built-in hooks](hooks/). Add those, or your custom hook, in
```go
import (
	log "github.com/sirupsen/logrus"
	"gopkg.in/gemnasium/logrus-airbrake-hook.v2" // the package is named "airbrake"
	logrus_syslog "github.com/sirupsen/logrus/hooks/syslog"
	"log/syslog"
)
...@@ -241,58 +241,8 @@ func init() {
```
Note: the Syslog hook also supports connecting to a local syslog (e.g. "/dev/log", "/var/run/syslog", or "/var/run/log"). For details, please check the [syslog hook README](hooks/syslog/README.md).
A list of currently known service hooks can be found on this wiki [page](https://github.com/sirupsen/logrus/wiki/Hooks).
#### Level logging
...@@ -370,6 +320,8 @@ The built-in logging formatters are:
  field to `true`. To force no colored output even if there is a TTY set the
  `DisableColors` field to `true`. For Windows, see
  [github.com/mattn/go-colorable](https://github.com/mattn/go-colorable).
  * When colors are enabled, levels are truncated to 4 characters by default. To disable
    truncation set the `DisableLevelTruncation` field to `true`.
  * All options are listed in the [generated docs](https://godoc.org/github.com/sirupsen/logrus#TextFormatter).
* `logrus.JSONFormatter`. Logs fields as JSON.
  * All options are listed in the [generated docs](https://godoc.org/github.com/sirupsen/logrus#JSONFormatter).
...@@ -493,7 +445,7 @@ logrus.RegisterExitHandler(handler)
#### Thread safety
By default, Logger is protected by a mutex for concurrent writes. The mutex is held when calling hooks and writing logs.
If you are sure such locking is not needed, you can call logger.SetNoLock() to disable the locking.
Situations when locking is not needed include:
......
version: "{build}"
platform: x64
clone_folder: c:\gopath\src\github.com\sirupsen\logrus
environment:
  GOPATH: c:\gopath
branches:
  only:
    - master
install:
  - set PATH=%GOPATH%\bin;c:\go\bin;%PATH%
  - go version
build_script:
  - go get -t
  - go test
...@@ -48,7 +48,7 @@ type Entry struct { ...@@ -48,7 +48,7 @@ type Entry struct {
func NewEntry(logger *Logger) *Entry { func NewEntry(logger *Logger) *Entry {
return &Entry{ return &Entry{
Logger: logger, Logger: logger,
// Default is three fields, give a little extra room // Default is five fields, give a little extra room
Data: make(Fields, 5), Data: make(Fields, 5),
} }
} }
@@ -83,49 +83,70 @@ func (entry *Entry) WithFields(fields Fields) *Entry {
 	for k, v := range fields {
 		data[k] = v
 	}
-	return &Entry{Logger: entry.Logger, Data: data}
+	return &Entry{Logger: entry.Logger, Data: data, Time: entry.Time}
+}
+
+// Overrides the time of the Entry.
+func (entry *Entry) WithTime(t time.Time) *Entry {
+	return &Entry{Logger: entry.Logger, Data: entry.Data, Time: t}
 }
 
 // This function is not declared with a pointer value because otherwise
 // race conditions will occur when using multiple goroutines
 func (entry Entry) log(level Level, msg string) {
 	var buffer *bytes.Buffer
-	entry.Time = time.Now()
+
+	// Default to now, but allow users to override if they want.
+	//
+	// We don't have to worry about polluting future calls to Entry#log()
+	// with this assignment because this function is declared with a
+	// non-pointer receiver.
+	if entry.Time.IsZero() {
+		entry.Time = time.Now()
+	}
+
 	entry.Level = level
 	entry.Message = msg
 
-	entry.Logger.mu.Lock()
-	err := entry.Logger.Hooks.Fire(level, &entry)
-	entry.Logger.mu.Unlock()
-	if err != nil {
-		entry.Logger.mu.Lock()
-		fmt.Fprintf(os.Stderr, "Failed to fire hook: %v\n", err)
-		entry.Logger.mu.Unlock()
-	}
+	entry.fireHooks()
+
 	buffer = bufferPool.Get().(*bytes.Buffer)
 	buffer.Reset()
 	defer bufferPool.Put(buffer)
 	entry.Buffer = buffer
-	serialized, err := entry.Logger.Formatter.Format(&entry)
+
+	entry.write()
+
 	entry.Buffer = nil
+
+	// To avoid Entry#log() returning a value that only would make sense for
+	// panic() to use in Entry#Panic(), we avoid the allocation by checking
+	// directly here.
+	if level <= PanicLevel {
+		panic(&entry)
+	}
+}
+
+func (entry *Entry) fireHooks() {
+	entry.Logger.mu.Lock()
+	defer entry.Logger.mu.Unlock()
+	err := entry.Logger.Hooks.Fire(entry.Level, entry)
+	if err != nil {
+		fmt.Fprintf(os.Stderr, "Failed to fire hook: %v\n", err)
+	}
+}
+
+func (entry *Entry) write() {
+	serialized, err := entry.Logger.Formatter.Format(entry)
+	entry.Logger.mu.Lock()
+	defer entry.Logger.mu.Unlock()
 	if err != nil {
-		entry.Logger.mu.Lock()
 		fmt.Fprintf(os.Stderr, "Failed to obtain reader, %v\n", err)
-		entry.Logger.mu.Unlock()
 	} else {
-		entry.Logger.mu.Lock()
 		_, err = entry.Logger.Out.Write(serialized)
 		if err != nil {
 			fmt.Fprintf(os.Stderr, "Failed to write to log, %v\n", err)
 		}
-		entry.Logger.mu.Unlock()
-	}
-
-	// To avoid Entry#log() returning a value that only would make sense for
-	// panic() to use in Entry#Panic(), we avoid the allocation by checking
-	// directly here.
-	if level <= PanicLevel {
-		panic(&entry)
 	}
 }
...
@@ -2,6 +2,7 @@ package logrus
 import (
 	"io"
+	"time"
 )
 
 var (
@@ -15,9 +16,7 @@ func StandardLogger() *Logger {
 // SetOutput sets the standard logger output.
 func SetOutput(out io.Writer) {
-	std.mu.Lock()
-	defer std.mu.Unlock()
-	std.Out = out
+	std.SetOutput(out)
 }
 
 // SetFormatter sets the standard logger formatter.
@@ -72,6 +71,15 @@ func WithFields(fields Fields) *Entry {
 	return std.WithFields(fields)
 }
 
+// WithTime creates an entry from the standard logger and overrides the time of
+// logs generated with it.
+//
+// Note that it doesn't log until you call Debug, Print, Info, Warn, Fatal
+// or Panic on the Entry it returns.
+func WithTime(t time.Time) *Entry {
+	return std.WithTime(t)
+}
+
 // Debug logs a message at level Debug on the standard logger.
 func Debug(args ...interface{}) {
 	std.Debug(args...)
@@ -107,7 +115,7 @@ func Panic(args ...interface{}) {
 	std.Panic(args...)
 }
 
-// Fatal logs a message at level Fatal on the standard logger.
+// Fatal logs a message at level Fatal on the standard logger then the process will exit with status set to 1.
 func Fatal(args ...interface{}) {
 	std.Fatal(args...)
 }
@@ -147,7 +155,7 @@ func Panicf(format string, args ...interface{}) {
 	std.Panicf(format, args...)
 }
 
-// Fatalf logs a message at level Fatal on the standard logger.
+// Fatalf logs a message at level Fatal on the standard logger then the process will exit with status set to 1.
 func Fatalf(format string, args ...interface{}) {
 	std.Fatalf(format, args...)
 }
@@ -187,7 +195,7 @@ func Panicln(args ...interface{}) {
 	std.Panicln(args...)
 }
 
-// Fatalln logs a message at level Fatal on the standard logger.
+// Fatalln logs a message at level Fatal on the standard logger then the process will exit with status set to 1.
 func Fatalln(args ...interface{}) {
 	std.Fatalln(args...)
 }
...
@@ -30,16 +30,22 @@ type Formatter interface {
 //
 // It's not exported because it's still using Data in an opinionated way. It's to
 // avoid code duplication between the two default formatters.
-func prefixFieldClashes(data Fields) {
-	if t, ok := data["time"]; ok {
-		data["fields.time"] = t
+func prefixFieldClashes(data Fields, fieldMap FieldMap) {
+	timeKey := fieldMap.resolve(FieldKeyTime)
+	if t, ok := data[timeKey]; ok {
+		data["fields."+timeKey] = t
+		delete(data, timeKey)
 	}
 
-	if m, ok := data["msg"]; ok {
-		data["fields.msg"] = m
+	msgKey := fieldMap.resolve(FieldKeyMsg)
+	if m, ok := data[msgKey]; ok {
+		data["fields."+msgKey] = m
+		delete(data, msgKey)
 	}
 
-	if l, ok := data["level"]; ok {
-		data["fields.level"] = l
+	levelKey := fieldMap.resolve(FieldKeyLevel)
+	if l, ok := data[levelKey]; ok {
+		data["fields."+levelKey] = l
+		delete(data, levelKey)
 	}
 }
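The FieldMap-aware clash handling above can be sketched in isolation. This standalone snippet re-implements the idea with plain maps; `prefixClashes` and the inline resolve fallback are illustrative names, not the logrus internals:

```go
package main

import "fmt"

// If user-supplied data already contains a key that the formatter wants for
// its own output ("time", "msg", "level", or a remapped name), move the
// user's value under a "fields."-prefixed key and drop the original, so the
// formatter's own value cannot be shadowed.
func prefixClashes(data map[string]interface{}, fieldMap map[string]string) {
	for _, defaultKey := range []string{"time", "msg", "level"} {
		// Models FieldMap.resolve: use the remapped key if one is set,
		// otherwise fall back to the default key.
		key := defaultKey
		if mapped, ok := fieldMap[defaultKey]; ok {
			key = mapped
		}
		if v, ok := data[key]; ok {
			data["fields."+key] = v
			delete(data, key)
		}
	}
}

func main() {
	data := map[string]interface{}{"@message": "user supplied", "count": 1}
	prefixClashes(data, map[string]string{"msg": "@message"})
	fmt.Println(data["fields.@message"], data["count"])
}
```

The `delete` call is the behavioral change in this hunk: previously the clashing key survived under both names.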
...
@@ -33,6 +33,9 @@ type JSONFormatter struct {
 	// DisableTimestamp allows disabling automatic timestamps in output
 	DisableTimestamp bool
 
+	// DataKey allows users to put all the log entry parameters into a nested dictionary at a given key.
+	DataKey string
+
 	// FieldMap allows users to customize the names of keys for default fields.
 	// As an example:
 	// formatter := &JSONFormatter{
@@ -58,7 +61,14 @@ func (f *JSONFormatter) Format(entry *Entry) ([]byte, error) {
 			data[k] = v
 		}
 	}
-	prefixFieldClashes(data)
+
+	if f.DataKey != "" {
+		newData := make(Fields, 4)
+		newData[f.DataKey] = data
+		data = newData
+	}
+
+	prefixFieldClashes(data, f.FieldMap)
+
 	timestampFormat := f.TimestampFormat
 	if timestampFormat == "" {
...
@@ -5,6 +5,7 @@ import (
 	"os"
 	"sync"
 	"sync/atomic"
+	"time"
 )
 
 type Logger struct {
@@ -88,7 +89,7 @@ func (logger *Logger) releaseEntry(entry *Entry) {
 }
 
 // Adds a field to the log entry, note that it doesn't log until you call
-// Debug, Print, Info, Warn, Fatal or Panic. It only creates a log entry.
+// Debug, Print, Info, Warn, Error, Fatal or Panic. It only creates a log entry.
 // If you want multiple fields, use `WithFields`.
 func (logger *Logger) WithField(key string, value interface{}) *Entry {
 	entry := logger.newEntry()
@@ -112,6 +113,13 @@ func (logger *Logger) WithError(err error) *Entry {
 	return entry.WithError(err)
 }
 
+// Overrides the time of the log entry.
+func (logger *Logger) WithTime(t time.Time) *Entry {
+	entry := logger.newEntry()
+	defer logger.releaseEntry(entry)
+	return entry.WithTime(t)
+}
+
 func (logger *Logger) Debugf(format string, args ...interface{}) {
 	if logger.level() >= DebugLevel {
 		entry := logger.newEntry()
@@ -316,6 +324,12 @@ func (logger *Logger) SetLevel(level Level) {
 	atomic.StoreUint32((*uint32)(&logger.Level), uint32(level))
 }
 
+func (logger *Logger) SetOutput(out io.Writer) {
+	logger.mu.Lock()
+	defer logger.mu.Unlock()
+	logger.Out = out
+}
+
 func (logger *Logger) AddHook(hook Hook) {
 	logger.mu.Lock()
 	defer logger.mu.Unlock()
...
 // +build darwin freebsd openbsd netbsd dragonfly
-// +build !appengine
+// +build !appengine,!gopherjs
 
 package logrus
...
-// +build appengine
+// +build appengine gopherjs
 
 package logrus
...
-// +build !appengine
+// +build !appengine,!gopherjs
 
 package logrus
...
@@ -3,7 +3,7 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-// +build !appengine
+// +build !appengine,!gopherjs
 
 package logrus
...
@@ -20,6 +20,7 @@ const (
 var (
 	baseTimestamp time.Time
+	emptyFieldMap FieldMap
 )
 
 func init() {
@@ -50,12 +51,24 @@ type TextFormatter struct {
 	// be desired.
 	DisableSorting bool
 
+	// Disables the truncation of the level text to 4 characters.
+	DisableLevelTruncation bool
+
 	// QuoteEmptyFields will wrap empty fields in quotes if true
 	QuoteEmptyFields bool
 
 	// Whether the logger's out is to a terminal
 	isTerminal bool
 
+	// FieldMap allows users to customize the names of keys for default fields.
+	// As an example:
+	// formatter := &TextFormatter{
+	//     FieldMap: FieldMap{
+	//         FieldKeyTime:  "@timestamp",
+	//         FieldKeyLevel: "@level",
+	//         FieldKeyMsg:   "@message"}}
+	FieldMap FieldMap
+
 	sync.Once
 }
@@ -67,7 +80,8 @@ func (f *TextFormatter) init(entry *Entry) {
 
 // Format renders a single log entry
 func (f *TextFormatter) Format(entry *Entry) ([]byte, error) {
-	var b *bytes.Buffer
+	prefixFieldClashes(entry.Data, f.FieldMap)
+
 	keys := make([]string, 0, len(entry.Data))
 	for k := range entry.Data {
 		keys = append(keys, k)
@@ -76,14 +90,14 @@ func (f *TextFormatter) Format(entry *Entry) ([]byte, error) {
 	if !f.DisableSorting {
 		sort.Strings(keys)
 	}
+
+	var b *bytes.Buffer
 	if entry.Buffer != nil {
 		b = entry.Buffer
 	} else {
 		b = &bytes.Buffer{}
 	}
 
-	prefixFieldClashes(entry.Data)
-
 	f.Do(func() { f.init(entry) })
 
 	isColored := (f.ForceColors || f.isTerminal) && !f.DisableColors
@@ -96,11 +110,11 @@ func (f *TextFormatter) Format(entry *Entry) ([]byte, error) {
 		f.printColored(b, entry, keys, timestampFormat)
 	} else {
 		if !f.DisableTimestamp {
-			f.appendKeyValue(b, "time", entry.Time.Format(timestampFormat))
+			f.appendKeyValue(b, f.FieldMap.resolve(FieldKeyTime), entry.Time.Format(timestampFormat))
 		}
-		f.appendKeyValue(b, "level", entry.Level.String())
+		f.appendKeyValue(b, f.FieldMap.resolve(FieldKeyLevel), entry.Level.String())
 		if entry.Message != "" {
-			f.appendKeyValue(b, "msg", entry.Message)
+			f.appendKeyValue(b, f.FieldMap.resolve(FieldKeyMsg), entry.Message)
 		}
 		for _, key := range keys {
 			f.appendKeyValue(b, key, entry.Data[key])
@@ -124,7 +138,10 @@ func (f *TextFormatter) printColored(b *bytes.Buffer, entry *Entry, keys []strin
 		levelColor = blue
 	}
 
-	levelText := strings.ToUpper(entry.Level.String())[0:4]
+	levelText := strings.ToUpper(entry.Level.String())
+	if !f.DisableLevelTruncation {
+		levelText = levelText[0:4]
+	}
 
 	if f.DisableTimestamp {
 		fmt.Fprintf(b, "\x1b[%dm%s\x1b[0m %-44s ", levelColor, levelText, entry.Message)
...
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.7
// Package ctxhttp provides helper functions for performing context-aware HTTP requests.
package ctxhttp // import "golang.org/x/net/context/ctxhttp"
import (
"io"
"net/http"
"net/url"
"strings"
"golang.org/x/net/context"
)
// Do sends an HTTP request with the provided http.Client and returns
// an HTTP response.
//
// If the client is nil, http.DefaultClient is used.
//
// The provided ctx must be non-nil. If it is canceled or times out,
// ctx.Err() will be returned.
func Do(ctx context.Context, client *http.Client, req *http.Request) (*http.Response, error) {
if client == nil {
client = http.DefaultClient
}
resp, err := client.Do(req.WithContext(ctx))
// If we got an error, and the context has been canceled,
// the context's error is probably more useful.
if err != nil {
select {
case <-ctx.Done():
err = ctx.Err()
default:
}
}
return resp, err
}
// Get issues a GET request via the Do function.
func Get(ctx context.Context, client *http.Client, url string) (*http.Response, error) {
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, err
}
return Do(ctx, client, req)
}
// Head issues a HEAD request via the Do function.
func Head(ctx context.Context, client *http.Client, url string) (*http.Response, error) {
req, err := http.NewRequest("HEAD", url, nil)
if err != nil {
return nil, err
}
return Do(ctx, client, req)
}
// Post issues a POST request via the Do function.
func Post(ctx context.Context, client *http.Client, url string, bodyType string, body io.Reader) (*http.Response, error) {
req, err := http.NewRequest("POST", url, body)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", bodyType)
return Do(ctx, client, req)
}
// PostForm issues a POST request via the Do function.
func PostForm(ctx context.Context, client *http.Client, url string, data url.Values) (*http.Response, error) {
return Post(ctx, client, url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode()))
}
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !go1.7
package ctxhttp // import "golang.org/x/net/context/ctxhttp"
import (
"io"
"net/http"
"net/url"
"strings"
"golang.org/x/net/context"
)
func nop() {}
var (
testHookContextDoneBeforeHeaders = nop
testHookDoReturned = nop
testHookDidBodyClose = nop
)
// Do sends an HTTP request with the provided http.Client and returns an HTTP response.
// If the client is nil, http.DefaultClient is used.
// If the context is canceled or times out, ctx.Err() will be returned.
func Do(ctx context.Context, client *http.Client, req *http.Request) (*http.Response, error) {
if client == nil {
client = http.DefaultClient
}
// TODO(djd): Respect any existing value of req.Cancel.
cancel := make(chan struct{})
req.Cancel = cancel
type responseAndError struct {
resp *http.Response
err error
}
result := make(chan responseAndError, 1)
// Make local copies of test hooks closed over by goroutines below.
// Prevents data races in tests.
testHookDoReturned := testHookDoReturned
testHookDidBodyClose := testHookDidBodyClose
go func() {
resp, err := client.Do(req)
testHookDoReturned()
result <- responseAndError{resp, err}
}()
var resp *http.Response
select {
case <-ctx.Done():
testHookContextDoneBeforeHeaders()
close(cancel)
// Clean up after the goroutine calling client.Do:
go func() {
if r := <-result; r.resp != nil {
testHookDidBodyClose()
r.resp.Body.Close()
}
}()
return nil, ctx.Err()
case r := <-result:
var err error
resp, err = r.resp, r.err
if err != nil {
return resp, err
}
}
c := make(chan struct{})
go func() {
select {
case <-ctx.Done():
close(cancel)
case <-c:
// The response's Body is closed.
}
}()
resp.Body = &notifyingReader{resp.Body, c}
return resp, nil
}
// Get issues a GET request via the Do function.
func Get(ctx context.Context, client *http.Client, url string) (*http.Response, error) {
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, err
}
return Do(ctx, client, req)
}
// Head issues a HEAD request via the Do function.
func Head(ctx context.Context, client *http.Client, url string) (*http.Response, error) {
req, err := http.NewRequest("HEAD", url, nil)
if err != nil {
return nil, err
}
return Do(ctx, client, req)
}
// Post issues a POST request via the Do function.
func Post(ctx context.Context, client *http.Client, url string, bodyType string, body io.Reader) (*http.Response, error) {
req, err := http.NewRequest("POST", url, body)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", bodyType)
return Do(ctx, client, req)
}
// PostForm issues a POST request via the Do function.
func PostForm(ctx context.Context, client *http.Client, url string, data url.Values) (*http.Response, error) {
return Post(ctx, client, url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode()))
}
// notifyingReader is an io.ReadCloser that closes the notify channel after
// Close is called or a Read fails on the underlying ReadCloser.
type notifyingReader struct {
io.ReadCloser
notify chan<- struct{}
}
func (r *notifyingReader) Read(p []byte) (int, error) {
n, err := r.ReadCloser.Read(p)
if err != nil && r.notify != nil {
close(r.notify)
r.notify = nil
}
return n, err
}
func (r *notifyingReader) Close() error {
err := r.ReadCloser.Close()
if r.notify != nil {
close(r.notify)
r.notify = nil
}
return err
}
@@ -27,6 +27,10 @@ How to get your contributions merged smoothly and quickly.
 - Keep your PR up to date with upstream/master (if there are merge conflicts, we can't really merge your change).
 - **All tests need to be passing** before your change can be merged. We recommend you **run tests locally** before creating your PR to catch breakages early on.
+  - `make all` to test everything, OR
+  - `make vet` to catch vet errors
+  - `make test` to run the tests
+  - `make testrace` to run tests in race mode
 - Exceptions to the rules can be made if there's a compelling reason for doing so.
-all: test testrace
+all: vet test testrace
 
 deps:
 	go get -d -v google.golang.org/grpc/...
@@ -22,6 +22,9 @@ proto:
 	fi
 	go generate google.golang.org/grpc/...
 
+vet:
+	./vet.sh
+
 test: testdeps
 	go test -cpu 1,4 -timeout 5m google.golang.org/grpc/...
@@ -39,7 +42,7 @@ clean:
 	updatetestdeps \
 	build \
 	proto \
+	vet \
 	test \
 	testrace \
-	clean \
-	coverage
+	clean
@@ -16,8 +16,7 @@ $ go get -u google.golang.org/grpc
 Prerequisites
 -------------
 
-This requires Go 1.6 or later. Go 1.7 will be required as of the next gRPC-Go
-release (1.8).
+This requires Go 1.6 or later. Go 1.7 will be required soon.
 
 Constraints
 -----------
...
@@ -16,81 +16,23 @@
  *
  */
 
+// See internal/backoff package for the backoff implementation. This file is
+// kept for the exported types and API backward compatibility.
 package grpc
 
 import (
-	"math/rand"
 	"time"
 )
 
 // DefaultBackoffConfig uses values specified for backoff in
 // https://github.com/grpc/grpc/blob/master/doc/connection-backoff.md.
 var DefaultBackoffConfig = BackoffConfig{
 	MaxDelay: 120 * time.Second,
-	baseDelay: 1.0 * time.Second,
-	factor:    1.6,
-	jitter:    0.2,
-}
-
-// backoffStrategy defines the methodology for backing off after a grpc
-// connection failure.
-//
-// This is unexported until the gRPC project decides whether or not to allow
-// alternative backoff strategies. Once a decision is made, this type and its
-// method may be exported.
-type backoffStrategy interface {
-	// backoff returns the amount of time to wait before the next retry given
-	// the number of consecutive failures.
-	backoff(retries int) time.Duration
 }
 
 // BackoffConfig defines the parameters for the default gRPC backoff strategy.
 type BackoffConfig struct {
 	// MaxDelay is the upper bound of backoff delay.
 	MaxDelay time.Duration
-
-	// TODO(stevvooe): The following fields are not exported, as allowing
-	// changes would violate the current gRPC specification for backoff. If
-	// gRPC decides to allow more interesting backoff strategies, these fields
-	// may be opened up in the future.
-
-	// baseDelay is the amount of time to wait before retrying after the first
-	// failure.
-	baseDelay time.Duration
-
-	// factor is applied to the backoff after each retry.
-	factor float64
-
-	// jitter provides a range to randomize backoff delays.
-	jitter float64
-}
-
-func setDefaults(bc *BackoffConfig) {
-	md := bc.MaxDelay
-	*bc = DefaultBackoffConfig
-	if md > 0 {
-		bc.MaxDelay = md
-	}
-}
-
-func (bc BackoffConfig) backoff(retries int) time.Duration {
-	if retries == 0 {
-		return bc.baseDelay
-	}
-	backoff, max := float64(bc.baseDelay), float64(bc.MaxDelay)
-	for backoff < max && retries > 0 {
-		backoff *= bc.factor
-		retries--
-	}
-	if backoff > max {
-		backoff = max
-	}
-	// Randomize backoff delays so that if a cluster of requests start at
-	// the same time, they won't operate in lockstep.
-	backoff *= 1 + bc.jitter*(rand.Float64()*2-1)
-	if backoff < 0 {
-		return 0
-	}
-	return time.Duration(backoff)
-}
@@ -32,7 +32,8 @@ import (
 )
 
 // Address represents a server the client connects to.
-// This is the EXPERIMENTAL API and may be changed or extended in the future.
+//
+// Deprecated: please use package balancer.
 type Address struct {
 	// Addr is the server address on which a connection will be established.
 	Addr string
@@ -42,6 +43,8 @@ type Address struct {
 }
 
 // BalancerConfig specifies the configurations for Balancer.
+//
+// Deprecated: please use package balancer.
 type BalancerConfig struct {
 	// DialCreds is the transport credential the Balancer implementation can
 	// use to dial to a remote load balancer server. The Balancer implementations
@@ -54,7 +57,8 @@ type BalancerConfig struct {
 }
 
 // BalancerGetOptions configures a Get call.
-// This is the EXPERIMENTAL API and may be changed or extended in the future.
+//
+// Deprecated: please use package balancer.
 type BalancerGetOptions struct {
 	// BlockingWait specifies whether Get should block when there is no
 	// connected address.
@@ -62,7 +66,8 @@ type BalancerGetOptions struct {
 }
 
 // Balancer chooses network addresses for RPCs.
-// This is the EXPERIMENTAL API and may be changed or extended in the future.
+//
+// Deprecated: please use package balancer.
 type Balancer interface {
 	// Start does the initialization work to bootstrap a Balancer. For example,
 	// this function may start the name resolution and watch the updates. It will
@@ -135,6 +140,8 @@ func downErrorf(timeout, temporary bool, format string, a ...interface{}) downEr
 // RoundRobin returns a Balancer that selects addresses round-robin. It uses r to watch
 // the name resolution updates and updates the addresses available correspondingly.
+//
+// Deprecated: please use package balancer/roundrobin.
 func RoundRobin(r naming.Resolver) Balancer {
 	return &roundRobin{r: r}
 }
...
@@ -36,9 +36,12 @@ var (
 	m = make(map[string]Builder)
 )
 
-// Register registers the balancer builder to the balancer map.
-// b.Name (lowercased) will be used as the name registered with
-// this builder.
+// Register registers the balancer builder to the balancer map. b.Name
+// (lowercased) will be used as the name registered with this builder.
+//
+// NOTE: this function must only be called during initialization time (i.e. in
+// an init() function), and is not thread-safe. If multiple Balancers are
+// registered with the same name, the one registered last will take effect.
 func Register(b Builder) {
 	m[strings.ToLower(b.Name())] = b
 }
@@ -126,6 +129,8 @@ type BuildOptions struct {
 	// to a remote load balancer server. The Balancer implementations
 	// can ignore this if it doesn't need to talk to remote balancer.
 	Dialer func(context.Context, string) (net.Conn, error)
+	// ChannelzParentID is the entity parent's channelz unique identification number.
+	ChannelzParentID int64
 }
 
 // Builder creates a balancer.
@@ -160,7 +165,7 @@ var (
 )
 
 // Picker is used by gRPC to pick a SubConn to send an RPC.
-// Balancer is expected to generate a new picker from its snapshot everytime its
+// Balancer is expected to generate a new picker from its snapshot every time its
 // internal state has changed.
 //
 // The pickers used by gRPC can be updated by ClientConn.UpdateBalancerState().
@@ -221,3 +226,45 @@ type Balancer interface {
 	// ClientConn.RemoveSubConn for its existing SubConns.
 	Close()
 }
+
+// ConnectivityStateEvaluator takes the connectivity states of multiple SubConns
+// and returns one aggregated connectivity state.
+//
+// It's not thread safe.
+type ConnectivityStateEvaluator struct {
+	numReady            uint64 // Number of addrConns in ready state.
+	numConnecting       uint64 // Number of addrConns in connecting state.
+	numTransientFailure uint64 // Number of addrConns in transientFailure.
+}
+
+// RecordTransition records state change happening in subConn and based on that
+// it evaluates what aggregated state should be.
+//
+//  - If at least one SubConn in Ready, the aggregated state is Ready;
+//  - Else if at least one SubConn in Connecting, the aggregated state is Connecting;
+//  - Else the aggregated state is TransientFailure.
+//
+// Idle and Shutdown are not considered.
+func (cse *ConnectivityStateEvaluator) RecordTransition(oldState, newState connectivity.State) connectivity.State {
+	// Update counters.
+	for idx, state := range []connectivity.State{oldState, newState} {
+		updateVal := 2*uint64(idx) - 1 // -1 for oldState and +1 for new.
+		switch state {
+		case connectivity.Ready:
+			cse.numReady += updateVal
+		case connectivity.Connecting:
+			cse.numConnecting += updateVal
+		case connectivity.TransientFailure:
+			cse.numTransientFailure += updateVal
+		}
+	}
+
+	// Evaluate.
+	if cse.numReady > 0 {
+		return connectivity.Ready
+	}
+	if cse.numConnecting > 0 {
+		return connectivity.Connecting
+	}
+	return connectivity.TransientFailure
+}
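The aggregation rule in RecordTransition can be exercised in isolation. This sketch models connectivity states as plain constants instead of the connectivity package's type, and uses a signed counter rather than the uint64 wrap-around trick; all names here are illustrative:

```go
package main

import "fmt"

// state is a stand-in for connectivity.State.
type state int

const (
	idle state = iota
	connecting
	ready
	transientFailure
)

// evaluator aggregates per-SubConn states: Ready wins over Connecting,
// which wins over TransientFailure. Idle and Shutdown are not counted.
type evaluator struct {
	numReady, numConnecting, numTransientFailure int
}

func (e *evaluator) recordTransition(oldS, newS state) state {
	for idx, s := range []state{oldS, newS} {
		delta := 2*idx - 1 // -1 for the old state, +1 for the new one.
		switch s {
		case ready:
			e.numReady += delta
		case connecting:
			e.numConnecting += delta
		case transientFailure:
			e.numTransientFailure += delta
		}
	}
	if e.numReady > 0 {
		return ready
	}
	if e.numConnecting > 0 {
		return connecting
	}
	return transientFailure
}

func main() {
	var e evaluator
	fmt.Println(e.recordTransition(idle, connecting) == connecting)
	fmt.Println(e.recordTransition(idle, connecting) == connecting)
	fmt.Println(e.recordTransition(connecting, ready) == ready)
	// One SubConn failing does not degrade the aggregate while another is ready.
	fmt.Println(e.recordTransition(connecting, transientFailure) == ready)
}
```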
@@ -146,7 +146,6 @@ func (b *baseBalancer) HandleSubConnStateChange(sc balancer.SubConn, s connectiv
 	}
 
 	b.cc.UpdateBalancerState(b.state, b.picker)
-	return
 }
 
 // Close is a nop because base balancer doesn't have internal state to clean up,
...
...@@ -115,7 +115,7 @@ func newCCBalancerWrapper(cc *ClientConn, b balancer.Builder, bopts balancer.Bui ...@@ -115,7 +115,7 @@ func newCCBalancerWrapper(cc *ClientConn, b balancer.Builder, bopts balancer.Bui
return ccb return ccb
} }
// watcher balancer functions sequencially, so the balancer can be implemeneted // watcher balancer functions sequentially, so the balancer can be implemented
// lock-free. // lock-free.
func (ccb *ccBalancerWrapper) watcher() { func (ccb *ccBalancerWrapper) watcher() {
for { for {
...
...@@ -55,7 +55,7 @@ func (bwb *balancerWrapperBuilder) Build(cc balancer.ClientConn, opts balancer.B
startCh: make(chan struct{}), startCh: make(chan struct{}),
conns: make(map[resolver.Address]balancer.SubConn), conns: make(map[resolver.Address]balancer.SubConn),
connSt: make(map[balancer.SubConn]*scState), connSt: make(map[balancer.SubConn]*scState),
csEvltr: &connectivityStateEvaluator{}, csEvltr: &balancer.ConnectivityStateEvaluator{},
state: connectivity.Idle, state: connectivity.Idle,
} }
cc.UpdateBalancerState(connectivity.Idle, bw) cc.UpdateBalancerState(connectivity.Idle, bw)
...@@ -80,10 +80,6 @@ type balancerWrapper struct {
cc balancer.ClientConn cc balancer.ClientConn
targetAddr string // Target without the scheme. targetAddr string // Target without the scheme.
// To aggregate the connectivity state.
csEvltr *connectivityStateEvaluator
state connectivity.State
mu sync.Mutex mu sync.Mutex
conns map[resolver.Address]balancer.SubConn conns map[resolver.Address]balancer.SubConn
connSt map[balancer.SubConn]*scState connSt map[balancer.SubConn]*scState
...@@ -92,6 +88,10 @@ type balancerWrapper struct {
// - NewSubConn is created, cc wants to notify balancer of state changes; // - NewSubConn is created, cc wants to notify balancer of state changes;
// - Build hasn't returned, cc doesn't have access to balancer. // - Build hasn't returned, cc doesn't have access to balancer.
startCh chan struct{} startCh chan struct{}
// To aggregate the connectivity state.
csEvltr *balancer.ConnectivityStateEvaluator
state connectivity.State
} }
// lbWatcher watches the Notify channel of the balancer and manages // lbWatcher watches the Notify channel of the balancer and manages
...@@ -248,7 +248,7 @@ func (bw *balancerWrapper) HandleSubConnStateChange(sc balancer.SubConn, s conne
scSt.down(errConnClosing) scSt.down(errConnClosing)
} }
} }
sa := bw.csEvltr.recordTransition(oldS, s) sa := bw.csEvltr.RecordTransition(oldS, s)
if bw.state != sa { if bw.state != sa {
bw.state = sa bw.state = sa
} }
...@@ -257,7 +257,6 @@ func (bw *balancerWrapper) HandleSubConnStateChange(sc balancer.SubConn, s conne
// Remove state for this sc. // Remove state for this sc.
delete(bw.connSt, sc) delete(bw.connSt, sc)
} }
return
} }
func (bw *balancerWrapper) HandleResolvedAddrs([]resolver.Address, error) { func (bw *balancerWrapper) HandleResolvedAddrs([]resolver.Address, error) {
...@@ -270,7 +269,6 @@ func (bw *balancerWrapper) HandleResolvedAddrs([]resolver.Address, error) {
} }
// There should be a resolver inside the balancer. // There should be a resolver inside the balancer.
// All updates here, if any, are ignored. // All updates here, if any, are ignored.
return
} }
func (bw *balancerWrapper) Close() { func (bw *balancerWrapper) Close() {
...@@ -282,7 +280,6 @@ func (bw *balancerWrapper) Close() {
close(bw.startCh) close(bw.startCh)
} }
bw.balancer.Close() bw.balancer.Close()
return
} }
// The picker is the balancerWrapper itself. // The picker is the balancerWrapper itself.
...@@ -329,47 +326,3 @@ func (bw *balancerWrapper) Pick(ctx context.Context, opts balancer.PickOptions)
return sc, done, nil return sc, done, nil
} }
// connectivityStateEvaluator gets updated by addrConns when their
// states transition, based on which it evaluates the state of
// ClientConn.
type connectivityStateEvaluator struct {
mu sync.Mutex
numReady uint64 // Number of addrConns in ready state.
numConnecting uint64 // Number of addrConns in connecting state.
numTransientFailure uint64 // Number of addrConns in transientFailure.
}
// recordTransition records state change happening in every subConn and based on
// that it evaluates what aggregated state should be.
// It can only transition between Ready, Connecting and TransientFailure. Other states,
// Idle and Shutdown are transitioned into by ClientConn; in the beginning of the connection
// before any subConn is created ClientConn is in idle state. In the end when ClientConn
// closes it is in Shutdown state.
// TODO Note that in later releases, a ClientConn with no activity will be put into an Idle state.
func (cse *connectivityStateEvaluator) recordTransition(oldState, newState connectivity.State) connectivity.State {
cse.mu.Lock()
defer cse.mu.Unlock()
// Update counters.
for idx, state := range []connectivity.State{oldState, newState} {
updateVal := 2*uint64(idx) - 1 // -1 for oldState and +1 for new.
switch state {
case connectivity.Ready:
cse.numReady += updateVal
case connectivity.Connecting:
cse.numConnecting += updateVal
case connectivity.TransientFailure:
cse.numTransientFailure += updateVal
}
}
// Evaluate.
if cse.numReady > 0 {
return connectivity.Ready
}
if cse.numConnecting > 0 {
return connectivity.Connecting
}
return connectivity.TransientFailure
}
...@@ -19,138 +19,39 @@
package grpc package grpc
import ( import (
"io"
"time"
"golang.org/x/net/context" "golang.org/x/net/context"
"golang.org/x/net/trace"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/encoding"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/stats"
"google.golang.org/grpc/status"
"google.golang.org/grpc/transport"
) )
// recvResponse receives and parses an RPC response.
// On error, it returns the error and indicates whether the call should be retried.
//
// TODO(zhaoq): Check whether the received message sequence is valid.
// TODO ctx is used for stats collection and processing. It is the context passed from the application.
func recvResponse(ctx context.Context, dopts dialOptions, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) (err error) {
// Try to acquire header metadata from the server if there is any.
defer func() {
if err != nil {
if _, ok := err.(transport.ConnectionError); !ok {
t.CloseStream(stream, err)
}
}
}()
c.headerMD, err = stream.Header()
if err != nil {
return
}
p := &parser{r: stream}
var inPayload *stats.InPayload
if dopts.copts.StatsHandler != nil {
inPayload = &stats.InPayload{
Client: true,
}
}
for {
if c.maxReceiveMessageSize == nil {
return status.Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)")
}
// Set dc if it exists and matches the message compression type used,
// otherwise set comp if a registered compressor exists for it.
var comp encoding.Compressor
var dc Decompressor
if rc := stream.RecvCompress(); dopts.dc != nil && dopts.dc.Type() == rc {
dc = dopts.dc
} else if rc != "" && rc != encoding.Identity {
comp = encoding.GetCompressor(rc)
}
if err = recv(p, dopts.codec, stream, dc, reply, *c.maxReceiveMessageSize, inPayload, comp); err != nil {
if err == io.EOF {
break
}
return
}
}
if inPayload != nil && err == io.EOF && stream.Status().Code() == codes.OK {
// TODO in the current implementation, inTrailer may be handled before inPayload in some cases.
// Fix the order if necessary.
dopts.copts.StatsHandler.HandleRPC(ctx, inPayload)
}
c.trailerMD = stream.Trailer()
return nil
}
// sendRequest writes out various information of an RPC such as Context and Message.
func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor, c *callInfo, callHdr *transport.CallHdr, stream *transport.Stream, t transport.ClientTransport, args interface{}, opts *transport.Options) (err error) {
defer func() {
if err != nil {
// If err is connection error, t will be closed, no need to close stream here.
if _, ok := err.(transport.ConnectionError); !ok {
t.CloseStream(stream, err)
}
}
}()
var (
outPayload *stats.OutPayload
)
if dopts.copts.StatsHandler != nil {
outPayload = &stats.OutPayload{
Client: true,
}
}
// Set comp and clear compressor if a registered compressor matches the type
// specified via UseCompressor. (And error if a matching compressor is not
// registered.)
var comp encoding.Compressor
if ct := c.compressorType; ct != "" && ct != encoding.Identity {
compressor = nil // Disable the legacy compressor.
comp = encoding.GetCompressor(ct)
if comp == nil {
return status.Errorf(codes.Internal, "grpc: Compressor is not installed for grpc-encoding %q", ct)
}
}
hdr, data, err := encode(dopts.codec, args, compressor, outPayload, comp)
if err != nil {
return err
}
if c.maxSendMessageSize == nil {
return status.Errorf(codes.Internal, "callInfo maxSendMessageSize field uninitialized(nil)")
}
if len(data) > *c.maxSendMessageSize {
return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(data), *c.maxSendMessageSize)
}
err = t.Write(stream, hdr, data, opts)
if err == nil && outPayload != nil {
outPayload.SentTime = time.Now()
dopts.copts.StatsHandler.HandleRPC(ctx, outPayload)
}
// t.NewStream(...) could lead to an early rejection of the RPC (e.g., the service/method
// does not exist.) so that t.Write could get io.EOF from wait(...). Leave the following
// recvResponse to get the final status.
if err != nil && err != io.EOF {
return err
}
// Sent successfully.
return nil
}
// Invoke sends the RPC request on the wire and returns after response is // Invoke sends the RPC request on the wire and returns after response is
// received. This is typically called by generated code. // received. This is typically called by generated code.
//
// All errors returned by Invoke are compatible with the status package.
func (cc *ClientConn) Invoke(ctx context.Context, method string, args, reply interface{}, opts ...CallOption) error { func (cc *ClientConn) Invoke(ctx context.Context, method string, args, reply interface{}, opts ...CallOption) error {
// allow interceptor to see all applicable call options, which means those
// configured as defaults from dial option as well as per-call options
opts = combine(cc.dopts.callOptions, opts)
if cc.dopts.unaryInt != nil { if cc.dopts.unaryInt != nil {
return cc.dopts.unaryInt(ctx, method, args, reply, cc, invoke, opts...) return cc.dopts.unaryInt(ctx, method, args, reply, cc, invoke, opts...)
} }
return invoke(ctx, method, args, reply, cc, opts...) return invoke(ctx, method, args, reply, cc, opts...)
} }
func combine(o1 []CallOption, o2 []CallOption) []CallOption {
// we don't use append because o1 could have extra capacity whose
// elements would be overwritten, which could cause inadvertent
// sharing (and race conditions) between concurrent calls
if len(o1) == 0 {
return o2
} else if len(o2) == 0 {
return o1
}
ret := make([]CallOption, len(o1)+len(o2))
copy(ret, o1)
copy(ret[len(o1):], o2)
return ret
}
// Invoke sends the RPC request on the wire and returns after response is // Invoke sends the RPC request on the wire and returns after response is
// received. This is typically called by generated code. // received. This is typically called by generated code.
// //
...@@ -159,188 +60,34 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
return cc.Invoke(ctx, method, args, reply, opts...) return cc.Invoke(ctx, method, args, reply, opts...)
} }
func invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (e error) { var unaryStreamDesc = &StreamDesc{ServerStreams: false, ClientStreams: false}
c := defaultCallInfo()
mc := cc.GetMethodConfig(method)
if mc.WaitForReady != nil {
c.failFast = !*mc.WaitForReady
}
if mc.Timeout != nil && *mc.Timeout >= 0 { func invoke(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, opts ...CallOption) error {
var cancel context.CancelFunc // TODO: implement retries in clientStream and make this simply
ctx, cancel = context.WithTimeout(ctx, *mc.Timeout) // newClientStream, SendMsg, RecvMsg.
defer cancel()
}
opts = append(cc.dopts.callOptions, opts...)
for _, o := range opts {
if err := o.before(c); err != nil {
return toRPCErr(err)
}
}
defer func() {
for _, o := range opts {
o.after(c)
}
}()
c.maxSendMessageSize = getMaxSize(mc.MaxReqSize, c.maxSendMessageSize, defaultClientMaxSendMessageSize)
c.maxReceiveMessageSize = getMaxSize(mc.MaxRespSize, c.maxReceiveMessageSize, defaultClientMaxReceiveMessageSize)
if EnableTracing {
c.traceInfo.tr = trace.New("grpc.Sent."+methodFamily(method), method)
defer c.traceInfo.tr.Finish()
c.traceInfo.firstLine.client = true
if deadline, ok := ctx.Deadline(); ok {
c.traceInfo.firstLine.deadline = deadline.Sub(time.Now())
}
c.traceInfo.tr.LazyLog(&c.traceInfo.firstLine, false)
// TODO(dsymonds): Arrange for c.traceInfo.firstLine.remoteAddr to be set.
defer func() {
if e != nil {
c.traceInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{e}}, true)
c.traceInfo.tr.SetError()
}
}()
}
ctx = newContextWithRPCInfo(ctx, c.failFast)
sh := cc.dopts.copts.StatsHandler
if sh != nil {
ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method, FailFast: c.failFast})
begin := &stats.Begin{
Client: true,
BeginTime: time.Now(),
FailFast: c.failFast,
}
sh.HandleRPC(ctx, begin)
defer func() {
end := &stats.End{
Client: true,
EndTime: time.Now(),
Error: e,
}
sh.HandleRPC(ctx, end)
}()
}
topts := &transport.Options{
Last: true,
Delay: false,
}
callHdr := &transport.CallHdr{
Host: cc.authority,
Method: method,
}
if c.creds != nil {
callHdr.Creds = c.creds
}
if c.compressorType != "" {
callHdr.SendCompress = c.compressorType
} else if cc.dopts.cp != nil {
callHdr.SendCompress = cc.dopts.cp.Type()
}
firstAttempt := true firstAttempt := true
for { for {
// Check to make sure the context has expired. This will prevent us from csInt, err := newClientStream(ctx, unaryStreamDesc, cc, method, opts...)
// looping forever if an error occurs for wait-for-ready RPCs where no data
// is sent on the wire.
select {
case <-ctx.Done():
return toRPCErr(ctx.Err())
default:
}
// Record the done handler from Balancer.Get(...). It is called once the
// RPC has completed or failed.
t, done, err := cc.getTransport(ctx, c.failFast)
if err != nil { if err != nil {
return err return err
} }
stream, err := t.NewStream(ctx, callHdr) cs := csInt.(*clientStream)
if err != nil { if err := cs.SendMsg(req); err != nil {
if done != nil { if !cs.c.failFast && cs.attempt.s.Unprocessed() && firstAttempt {
done(balancer.DoneInfo{Err: err}) // TODO: Add a field to header for grpc-transparent-retry-attempts
} firstAttempt = false
// In the event of any error from NewStream, we never attempted to write
// anything to the wire, so we can retry indefinitely for non-fail-fast
// RPCs.
if !c.failFast {
continue continue
} }
return toRPCErr(err) return err
}
if peer, ok := peer.FromContext(stream.Context()); ok {
c.peer = peer
}
if c.traceInfo.tr != nil {
c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true)
}
err = sendRequest(ctx, cc.dopts, cc.dopts.cp, c, callHdr, stream, t, args, topts)
if err != nil {
if done != nil {
done(balancer.DoneInfo{
Err: err,
BytesSent: true,
BytesReceived: stream.BytesReceived(),
})
}
// Retry a non-failfast RPC when
// i) the server started to drain before this RPC was initiated.
// ii) the server refused the stream.
if !c.failFast && stream.Unprocessed() {
// In this case, the server did not receive the data, but we still
// created wire traffic, so we should not retry indefinitely.
if firstAttempt {
// TODO: Add a field to header for grpc-transparent-retry-attempts
firstAttempt = false
continue
}
// Otherwise, give up and return an error anyway.
}
return toRPCErr(err)
}
err = recvResponse(ctx, cc.dopts, t, c, stream, reply)
if err != nil {
if done != nil {
done(balancer.DoneInfo{
Err: err,
BytesSent: true,
BytesReceived: stream.BytesReceived(),
})
}
if !c.failFast && stream.Unprocessed() {
// In these cases, the server did not receive the data, but we still
// created wire traffic, so we should not retry indefinitely.
if firstAttempt {
// TODO: Add a field to header for grpc-transparent-retry-attempts
firstAttempt = false
continue
}
// Otherwise, give up and return an error anyway.
}
return toRPCErr(err)
}
if c.traceInfo.tr != nil {
c.traceInfo.tr.LazyLog(&payload{sent: false, msg: reply}, true)
}
t.CloseStream(stream, nil)
err = stream.Status().Err()
if done != nil {
done(balancer.DoneInfo{
Err: err,
BytesSent: true,
BytesReceived: stream.BytesReceived(),
})
} }
if !c.failFast && stream.Unprocessed() { if err := cs.RecvMsg(reply); err != nil {
// In these cases, the server did not receive the data, but we still if !cs.c.failFast && cs.attempt.s.Unprocessed() && firstAttempt {
// created wire traffic, so we should not retry indefinitely.
if firstAttempt {
// TODO: Add a field to header for grpc-transparent-retry-attempts // TODO: Add a field to header for grpc-transparent-retry-attempts
firstAttempt = false firstAttempt = false
continue continue
} }
return err
} }
return err return nil
} }
} }
...@@ -32,25 +32,36 @@ import (
"golang.org/x/net/trace" "golang.org/x/net/trace"
"google.golang.org/grpc/balancer" "google.golang.org/grpc/balancer"
_ "google.golang.org/grpc/balancer/roundrobin" // To register roundrobin. _ "google.golang.org/grpc/balancer/roundrobin" // To register roundrobin.
"google.golang.org/grpc/codes"
"google.golang.org/grpc/connectivity" "google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/grpclog" "google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/backoff"
"google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver"
_ "google.golang.org/grpc/resolver/dns" // To register dns resolver. _ "google.golang.org/grpc/resolver/dns" // To register dns resolver.
_ "google.golang.org/grpc/resolver/passthrough" // To register passthrough resolver. _ "google.golang.org/grpc/resolver/passthrough" // To register passthrough resolver.
"google.golang.org/grpc/stats" "google.golang.org/grpc/stats"
"google.golang.org/grpc/status"
"google.golang.org/grpc/transport" "google.golang.org/grpc/transport"
) )
const (
// minimum time to give a connection to complete
minConnectTimeout = 20 * time.Second
// must match grpclbName in grpclb/grpclb.go
grpclbName = "grpclb"
)
var ( var (
// ErrClientConnClosing indicates that the operation is illegal because // ErrClientConnClosing indicates that the operation is illegal because
// the ClientConn is closing. // the ClientConn is closing.
ErrClientConnClosing = errors.New("grpc: the client connection is closing") //
// ErrClientConnTimeout indicates that the ClientConn cannot establish the // Deprecated: this error should not be relied upon by users; use the status
// underlying connections within the specified timeout. // code of Canceled instead.
// DEPRECATED: Please use context.DeadlineExceeded instead. ErrClientConnClosing = status.Error(codes.Canceled, "grpc: the client connection is closing")
ErrClientConnTimeout = errors.New("grpc: timed out when dialing")
// errConnDrain indicates that the connection starts to be drained and does not accept any new RPCs. // errConnDrain indicates that the connection starts to be drained and does not accept any new RPCs.
errConnDrain = errors.New("grpc: the connection is drained") errConnDrain = errors.New("grpc: the connection is drained")
// errConnClosing indicates that the connection is closing. // errConnClosing indicates that the connection is closing.
...@@ -59,8 +70,11 @@ var (
errConnUnavailable = errors.New("grpc: the connection is unavailable") errConnUnavailable = errors.New("grpc: the connection is unavailable")
// errBalancerClosed indicates that the balancer is closed. // errBalancerClosed indicates that the balancer is closed.
errBalancerClosed = errors.New("grpc: balancer is closed") errBalancerClosed = errors.New("grpc: balancer is closed")
// minimum time to give a connection to complete // We use an accessor so that minConnectTimeout can be
minConnectTimeout = 20 * time.Second // atomically read and updated while testing.
getMinConnectTimeout = func() time.Duration {
return minConnectTimeout
}
) )
// The following errors are returned from Dial and DialContext // The following errors are returned from Dial and DialContext
...@@ -85,10 +99,9 @@ var (
type dialOptions struct { type dialOptions struct {
unaryInt UnaryClientInterceptor unaryInt UnaryClientInterceptor
streamInt StreamClientInterceptor streamInt StreamClientInterceptor
codec Codec
cp Compressor cp Compressor
dc Decompressor dc Decompressor
bs backoffStrategy bs backoff.Strategy
block bool block bool
insecure bool insecure bool
timeout time.Duration timeout time.Duration
...@@ -99,10 +112,10 @@ type dialOptions struct {
// balancer, and also by WithBalancerName dial option. // balancer, and also by WithBalancerName dial option.
balancerBuilder balancer.Builder balancerBuilder balancer.Builder
// This is to support grpclb. // This is to support grpclb.
resolverBuilder resolver.Builder resolverBuilder resolver.Builder
// Custom user options for resolver.Build. waitForHandshake bool
resolverBuildUserOptions interface{} channelzParentID int64
waitForHandshake bool disableServiceConfig bool
} }
const ( const (
...@@ -110,6 +123,12 @@ const (
defaultClientMaxSendMessageSize = math.MaxInt32 defaultClientMaxSendMessageSize = math.MaxInt32
) )
// RegisterChannelz turns on channelz service.
// This is an EXPERIMENTAL API.
func RegisterChannelz() {
channelz.TurnOn()
}
// DialOption configures how we set up the connection. // DialOption configures how we set up the connection.
type DialOption func(*dialOptions) type DialOption func(*dialOptions)
...@@ -154,7 +173,9 @@ func WithInitialConnWindowSize(s int32) DialOption {
} }
} }
// WithMaxMsgSize returns a DialOption which sets the maximum message size the client can receive. Deprecated: use WithDefaultCallOptions(MaxCallRecvMsgSize(s)) instead. // WithMaxMsgSize returns a DialOption which sets the maximum message size the client can receive.
//
// Deprecated: use WithDefaultCallOptions(MaxCallRecvMsgSize(s)) instead.
func WithMaxMsgSize(s int) DialOption { func WithMaxMsgSize(s int) DialOption {
return WithDefaultCallOptions(MaxCallRecvMsgSize(s)) return WithDefaultCallOptions(MaxCallRecvMsgSize(s))
} }
...@@ -167,10 +188,10 @@ func WithDefaultCallOptions(cos ...CallOption) DialOption {
} }
// WithCodec returns a DialOption which sets a codec for message marshaling and unmarshaling. // WithCodec returns a DialOption which sets a codec for message marshaling and unmarshaling.
//
// Deprecated: use WithDefaultCallOptions(CallCustomCodec(c)) instead.
func WithCodec(c Codec) DialOption { func WithCodec(c Codec) DialOption {
return func(o *dialOptions) { return WithDefaultCallOptions(CallCustomCodec(c))
o.codec = c
}
} }
// WithCompressor returns a DialOption which sets a Compressor to use for // WithCompressor returns a DialOption which sets a Compressor to use for
...@@ -236,16 +257,9 @@ func withResolverBuilder(b resolver.Builder) DialOption {
} }
} }
// WithResolverUserOptions returns a DialOption which sets the UserOptions
// field of resolver's BuildOption.
func WithResolverUserOptions(userOpt interface{}) DialOption {
return func(o *dialOptions) {
o.resolverBuildUserOptions = userOpt
}
}
// WithServiceConfig returns a DialOption which has a channel to read the service configuration. // WithServiceConfig returns a DialOption which has a channel to read the service configuration.
// DEPRECATED: service config should be received through name resolver, as specified here. //
// Deprecated: service config should be received through name resolver, as specified here.
// https://github.com/grpc/grpc/blob/master/doc/service_config.md // https://github.com/grpc/grpc/blob/master/doc/service_config.md
func WithServiceConfig(c <-chan ServiceConfig) DialOption { func WithServiceConfig(c <-chan ServiceConfig) DialOption {
return func(o *dialOptions) { return func(o *dialOptions) {
...@@ -265,17 +279,17 @@ func WithBackoffMaxDelay(md time.Duration) DialOption {
// Use WithBackoffMaxDelay until more parameters on BackoffConfig are opened up // Use WithBackoffMaxDelay until more parameters on BackoffConfig are opened up
// for use. // for use.
func WithBackoffConfig(b BackoffConfig) DialOption { func WithBackoffConfig(b BackoffConfig) DialOption {
// Set defaults to ensure that provided BackoffConfig is valid and
// unexported fields get default values. return withBackoff(backoff.Exponential{
setDefaults(&b) MaxDelay: b.MaxDelay,
return withBackoff(b) })
} }
// withBackoff sets the backoff strategy used for connectRetryNum after a // withBackoff sets the backoff strategy used for connectRetryNum after a
// failed connection attempt. // failed connection attempt.
// //
// This can be exported if arbitrary backoff strategies are allowed by gRPC. // This can be exported if arbitrary backoff strategies are allowed by gRPC.
func withBackoff(bs backoffStrategy) DialOption { func withBackoff(bs backoff.Strategy) DialOption {
return func(o *dialOptions) { return func(o *dialOptions) {
o.bs = bs o.bs = bs
} }
...@@ -316,6 +330,7 @@ func WithPerRPCCredentials(creds credentials.PerRPCCredentials) DialOption {
// WithTimeout returns a DialOption that configures a timeout for dialing a ClientConn // WithTimeout returns a DialOption that configures a timeout for dialing a ClientConn
// initially. This is valid if and only if WithBlock() is present. // initially. This is valid if and only if WithBlock() is present.
//
// Deprecated: use DialContext and context.WithTimeout instead. // Deprecated: use DialContext and context.WithTimeout instead.
func WithTimeout(d time.Duration) DialOption { func WithTimeout(d time.Duration) DialOption {
return func(o *dialOptions) { return func(o *dialOptions) {
...@@ -329,6 +344,11 @@ func withContextDialer(f func(context.Context, string) (net.Conn, error)) DialOp
} }
} }
func init() {
internal.WithContextDialer = withContextDialer
internal.WithResolverBuilder = withResolverBuilder
}
// WithDialer returns a DialOption that specifies a function to use for dialing network addresses. // WithDialer returns a DialOption that specifies a function to use for dialing network addresses.
// If FailOnNonTempDialError() is set to true, and an error is returned by f, gRPC checks the error's // If FailOnNonTempDialError() is set to true, and an error is returned by f, gRPC checks the error's
// Temporary() method to decide if it should try to reconnect to the network address. // Temporary() method to decide if it should try to reconnect to the network address.
...@@ -398,15 +418,44 @@ func WithAuthority(a string) DialOption {
} }
} }
// WithChannelzParentID returns a DialOption that specifies the channelz ID of current ClientConn's
// parent. This function is used in nested channel creation (e.g. grpclb dial).
func WithChannelzParentID(id int64) DialOption {
return func(o *dialOptions) {
o.channelzParentID = id
}
}
// WithDisableServiceConfig returns a DialOption that causes grpc to ignore any
// service config provided by the resolver and provides a hint to the resolver
// to not fetch service configs.
func WithDisableServiceConfig() DialOption {
return func(o *dialOptions) {
o.disableServiceConfig = true
}
}
// Dial creates a client connection to the given target. // Dial creates a client connection to the given target.
func Dial(target string, opts ...DialOption) (*ClientConn, error) { func Dial(target string, opts ...DialOption) (*ClientConn, error) {
return DialContext(context.Background(), target, opts...) return DialContext(context.Background(), target, opts...)
} }
// DialContext creates a client connection to the given target. ctx can be used to // DialContext creates a client connection to the given target. By default, it's
// cancel or expire the pending connection. Once this function returns, the // a non-blocking dial (the function won't wait for connections to be
// cancellation and expiration of ctx will be noop. Users should call ClientConn.Close // established, and connecting happens in the background). To make it a blocking
// to terminate all the pending operations after this function returns. // dial, use WithBlock() dial option.
//
// In the non-blocking case, the ctx does not act against the connection. It
// only controls the setup steps.
//
// In the blocking case, ctx can be used to cancel or expire the pending
// connection. Once this function returns, the cancellation and expiration of
// ctx will be noop. Users should call ClientConn.Close to terminate all the
// pending operations after this function returns.
//
// The target name syntax is defined in
// https://github.com/grpc/grpc/blob/master/doc/naming.md.
// e.g. to use dns resolver, a "dns:///" prefix should be applied to the target.
func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *ClientConn, err error) { func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *ClientConn, err error) {
cc := &ClientConn{ cc := &ClientConn{
target: target, target: target,
...@@ -421,6 +470,14 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
opt(&cc.dopts) opt(&cc.dopts)
} }
if channelz.IsOn() {
if cc.dopts.channelzParentID != 0 {
cc.channelzID = channelz.RegisterChannel(cc, cc.dopts.channelzParentID, target)
} else {
cc.channelzID = channelz.RegisterChannel(cc, 0, target)
}
}
if !cc.dopts.insecure { if !cc.dopts.insecure {
if cc.dopts.copts.TransportCredentials == nil { if cc.dopts.copts.TransportCredentials == nil {
return nil, errNoTransportSecurity return nil, errNoTransportSecurity
...@@ -441,7 +498,8 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
if cc.dopts.copts.Dialer == nil { if cc.dopts.copts.Dialer == nil {
cc.dopts.copts.Dialer = newProxyDialer( cc.dopts.copts.Dialer = newProxyDialer(
func(ctx context.Context, addr string) (net.Conn, error) { func(ctx context.Context, addr string) (net.Conn, error) {
return dialContext(ctx, "tcp", addr) network, addr := parseDialTarget(addr)
return dialContext(ctx, network, addr)
}, },
) )
} }
...@@ -482,14 +540,30 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
default: default:
} }
} }
// Set defaults.
if cc.dopts.codec == nil {
cc.dopts.codec = protoCodec{}
}
if cc.dopts.bs == nil {
cc.dopts.bs = backoff.Exponential{
MaxDelay: DefaultBackoffConfig.MaxDelay,
}
}
if cc.dopts.resolverBuilder == nil {
// Only try to parse target when resolver builder is not already set.
cc.parsedTarget = parseTarget(cc.target)
grpclog.Infof("parsed scheme: %q", cc.parsedTarget.Scheme)
cc.dopts.resolverBuilder = resolver.Get(cc.parsedTarget.Scheme)
if cc.dopts.resolverBuilder == nil {
// If the resolver builder is still nil, the parsed target's scheme is
// not registered. Fall back to the default resolver and set Endpoint to
// the original unparsed target.
grpclog.Infof("scheme %q not registered, fallback to default scheme", cc.parsedTarget.Scheme)
cc.parsedTarget = resolver.Target{
Scheme: resolver.GetDefaultScheme(),
Endpoint: target,
}
cc.dopts.resolverBuilder = resolver.Get(cc.parsedTarget.Scheme)
}
} else {
cc.parsedTarget = resolver.Target{Endpoint: target}
}
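The fallback above only triggers when the target's scheme is unregistered; schemeless targets such as `localhost:50051` parse to an empty scheme and take the default-scheme path. A stdlib-only sketch of the `scheme://authority/endpoint` split from https://github.com/grpc/grpc/blob/master/doc/naming.md (hypothetical `parseTargetSketch`, not this package's `parseTarget`):

```go
package main

import (
	"fmt"
	"strings"
)

// target mirrors resolver.Target for illustration only.
type target struct {
	Scheme    string
	Authority string
	Endpoint  string
}

// parseTargetSketch splits "scheme://authority/endpoint". Targets without a
// scheme (e.g. "localhost:50051") come back with only Endpoint set, which is
// what triggers the default-scheme fallback in DialContext.
func parseTargetSketch(t string) target {
	scheme, rest, ok := splitOnce(t, "://")
	if !ok {
		return target{Endpoint: t}
	}
	authority, endpoint, ok := splitOnce(rest, "/")
	if !ok {
		return target{Scheme: scheme, Endpoint: rest}
	}
	return target{Scheme: scheme, Authority: authority, Endpoint: endpoint}
}

// splitOnce splits s around the first occurrence of sep.
func splitOnce(s, sep string) (string, string, bool) {
	i := strings.Index(s, sep)
	if i < 0 {
		return s, "", false
	}
	return s[:i], s[i+len(sep):], true
}

func main() {
	fmt.Printf("%+v\n", parseTargetSketch("dns:///foo.example.com:443"))
	fmt.Printf("%+v\n", parseTargetSketch("localhost:50051"))
}
```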
creds := cc.dopts.copts.TransportCredentials
if creds != nil && creds.Info().ServerName != "" {
cc.authority = creds.Info().ServerName
...@@ -521,8 +595,9 @@
credsClone = creds.Clone()
}
cc.balancerBuildOpts = balancer.BuildOptions{
DialCreds: credsClone,
Dialer: cc.dopts.copts.Dialer,
ChannelzParentID: cc.channelzID,
}
// Build the resolver.
...@@ -624,6 +699,13 @@ type ClientConn struct {
preBalancerName string // previous balancer name.
curAddresses []resolver.Address
balancerWrapper *ccBalancerWrapper
channelzID int64 // channelz unique identification number
czmu sync.RWMutex
callsStarted int64
callsSucceeded int64
callsFailed int64
lastCallStartedTime time.Time
}
// WaitForStateChange waits until the connectivity.State of ClientConn changes from sourceState or
...@@ -748,6 +830,8 @@
if cc.balancerWrapper != nil {
cc.balancerWrapper.close()
}
// Clear all stickiness state.
cc.blockingpicker.clearStickinessState()
builder := balancer.Get(name)
if builder == nil {
...@@ -787,6 +871,9 @@
cc.mu.Unlock()
return nil, ErrClientConnClosing
}
if channelz.IsOn() {
ac.channelzID = channelz.RegisterSubChannel(ac, cc.channelzID, "")
}
cc.conns[ac] = struct{}{}
cc.mu.Unlock()
return ac, nil
...@@ -805,6 +892,42 @@
ac.tearDown(err)
}
// ChannelzMetric returns ChannelInternalMetric of current ClientConn.
// This is an EXPERIMENTAL API.
func (cc *ClientConn) ChannelzMetric() *channelz.ChannelInternalMetric {
state := cc.GetState()
cc.czmu.RLock()
defer cc.czmu.RUnlock()
return &channelz.ChannelInternalMetric{
State: state,
Target: cc.target,
CallsStarted: cc.callsStarted,
CallsSucceeded: cc.callsSucceeded,
CallsFailed: cc.callsFailed,
LastCallStartedTimestamp: cc.lastCallStartedTime,
}
}
func (cc *ClientConn) incrCallsStarted() {
cc.czmu.Lock()
cc.callsStarted++
// TODO(yuxuanli): will making this a time.Time pointer improve performance?
cc.lastCallStartedTime = time.Now()
cc.czmu.Unlock()
}
func (cc *ClientConn) incrCallsSucceeded() {
cc.czmu.Lock()
cc.callsSucceeded++
cc.czmu.Unlock()
}
func (cc *ClientConn) incrCallsFailed() {
cc.czmu.Lock()
cc.callsFailed++
cc.czmu.Unlock()
}
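These channelz counters follow a small read/write-lock pattern: increments take the write lock, ChannelzMetric-style snapshots take the read lock for a consistent view. A self-contained sketch (hypothetical `callStats` type, mirroring but not reusing the fields above):

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// callStats guards its counters with an RWMutex, like czmu above: many
// concurrent incrementers serialize on the write lock, readers share it.
type callStats struct {
	mu              sync.RWMutex
	started, failed int64
	lastStarted     time.Time
}

func (s *callStats) incrStarted() {
	s.mu.Lock()
	s.started++
	s.lastStarted = time.Now()
	s.mu.Unlock()
}

func (s *callStats) incrFailed() {
	s.mu.Lock()
	s.failed++
	s.mu.Unlock()
}

// snapshot returns both counters under a single read lock.
func (s *callStats) snapshot() (started, failed int64) {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return s.started, s.failed
}

func main() {
	var s callStats
	var wg sync.WaitGroup
	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			s.incrStarted()
		}()
	}
	wg.Wait()
	started, failed := s.snapshot()
	fmt.Println(started, failed) // prints "100 0"
}
```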
// connect starts creating a transport and also starts the transport monitor
// goroutine for this ac.
// It does nothing if the ac is not IDLE.
...@@ -875,7 +998,7 @@
// the corresponding MethodConfig.
// If there isn't an exact match for the input method, we look for the default config
// under the service (i.e. /service/). If there is a default MethodConfig for
// the service, we return it.
// Otherwise, we return an empty MethodConfig.
func (cc *ClientConn) GetMethodConfig(method string) MethodConfig {
// TODO: Avoid the locking here.
...@@ -884,7 +1007,7 @@
m, ok := cc.sc.Methods[method]
if !ok {
i := strings.LastIndex(method, "/")
m = cc.sc.Methods[method[:i+1]]
}
return m
}
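The lookup above falls back from the exact method key to the per-service default stored under `/service/`. A runnable sketch with a stand-in `methodConfig` type (hypothetical; the real `MethodConfig` carries more fields):

```go
package main

import (
	"fmt"
	"strings"
)

// methodConfig stands in for grpc's MethodConfig.
type methodConfig struct{ WaitForReady bool }

// getMethodConfig mirrors GetMethodConfig: exact method first, then the
// per-service default keyed by everything up to and including the last "/".
func getMethodConfig(methods map[string]methodConfig, method string) methodConfig {
	if m, ok := methods[method]; ok {
		return m
	}
	i := strings.LastIndex(method, "/")
	return methods[method[:i+1]] // zero value if there is no default either
}

func main() {
	methods := map[string]methodConfig{
		"/echo.Echo/Ping": {WaitForReady: true},
		"/echo.Echo/":     {}, // service-wide default
	}
	fmt.Println(getMethodConfig(methods, "/echo.Echo/Ping"))
	fmt.Println(getMethodConfig(methods, "/echo.Echo/Other"))
}
```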
...@@ -900,6 +1023,9 @@
// handleServiceConfig parses the service config string in JSON format to Go native
// struct ServiceConfig, and stores both the struct and the JSON string in ClientConn.
func (cc *ClientConn) handleServiceConfig(js string) error {
if cc.dopts.disableServiceConfig {
return nil
}
sc, err := parseServiceConfig(js)
if err != nil {
return err
...@@ -920,14 +1046,26 @@
cc.balancerWrapper.handleResolvedAddrs(cc.curAddresses, nil)
}
}
if envConfigStickinessOn {
var newStickinessMDKey string
if sc.stickinessMetadataKey != nil && *sc.stickinessMetadataKey != "" {
newStickinessMDKey = *sc.stickinessMetadataKey
}
// newStickinessMDKey is "" if one of the following happens:
// - stickinessMetadataKey is set to ""
// - stickinessMetadataKey field doesn't exist in service config
cc.blockingpicker.updateStickinessMDKey(strings.ToLower(newStickinessMDKey))
}
cc.mu.Unlock()
return nil
}
func (cc *ClientConn) resolveNow(o resolver.ResolveNowOption) {
cc.mu.RLock()
r := cc.resolverWrapper
cc.mu.RUnlock()
if r == nil {
return
}
...@@ -936,7 +1074,7 @@
// Close tears down the ClientConn and all underlying connections.
func (cc *ClientConn) Close() error {
defer cc.cancel()
cc.mu.Lock()
if cc.conns == nil {
...@@ -952,16 +1090,22 @@
bWrapper := cc.balancerWrapper
cc.balancerWrapper = nil
cc.mu.Unlock()
cc.blockingpicker.close()
if rWrapper != nil {
rWrapper.close()
}
if bWrapper != nil {
bWrapper.close()
}
for ac := range conns {
ac.tearDown(ErrClientConnClosing)
}
if channelz.IsOn() {
channelz.RemoveEntry(cc.channelzID)
}
return nil
}
...@@ -995,6 +1139,13 @@ type addrConn struct {
// connectDeadline is the time by which all connection
// negotiations must complete.
connectDeadline time.Time
channelzID int64 // channelz unique identification number
czmu sync.RWMutex
callsStarted int64
callsSucceeded int64
callsFailed int64
lastCallStartedTime time.Time
}
// adjustParams updates parameters used to create transports upon
...@@ -1030,7 +1181,7 @@
// resetTransport recreates a transport to the address for ac. The old
// transport will close itself on error or when the clientconn is closed.
// The created transport must receive initial settings frame from the server.
// In case that doesn't happen, transportMonitor will kill the newly created
// transport after connectDeadline has expired.
// In case there was an error on the transport before the settings frame was
// received, resetTransport resumes connecting to backends after the one that
...@@ -1063,9 +1214,9 @@
// This means either a successful HTTP2 connection was established
// or this is the first time this addrConn is trying to establish a
// connection.
backoffFor := ac.dopts.bs.Backoff(connectRetryNum) // time.Duration.
// This will be the duration that dial gets to finish.
dialDuration := getMinConnectTimeout()
if backoffFor > dialDuration {
// Give dial more time as we keep failing to connect.
dialDuration = backoffFor
...@@ -1075,7 +1226,7 @@
connectDeadline = start.Add(dialDuration)
ridx = 0 // Start connecting from the beginning.
} else {
// Continue trying to connect with the same deadlines.
connectRetryNum = ac.connectRetryNum
backoffDeadline = ac.backoffDeadline
connectDeadline = ac.connectDeadline
...@@ -1136,18 +1287,13 @@
// Do not cancel in the success path because of
// this issue in Go1.6: https://github.com/golang/go/issues/15078.
connectCtx, cancel := context.WithDeadline(ac.ctx, connectDeadline)
if channelz.IsOn() {
copts.ChannelzParentID = ac.channelzID
}
newTr, err := transport.NewClientTransport(connectCtx, ac.cc.ctx, target, copts, onPrefaceReceipt)
if err != nil {
cancel()
ac.cc.blockingpicker.updateConnectionError(err)
ac.mu.Lock()
if ac.state == connectivity.Shutdown {
// ac.tearDown(...) has been invoked.
...@@ -1199,6 +1345,10 @@
return true, nil
}
ac.mu.Lock()
if ac.state == connectivity.Shutdown {
ac.mu.Unlock()
return false, errConnClosing
}
ac.state = connectivity.TransientFailure
ac.cc.handleSubConnStateChange(ac.acbw, ac.state)
ac.cc.resolveNow(resolver.ResolveNowOption{})
...@@ -1233,7 +1383,20 @@
// Block until we receive a goaway or an error occurs.
select {
case <-t.GoAway():
done := t.Error()
cleanup := t.Close
// Since this transport will be orphaned (won't have a transportMonitor)
// we need to launch a goroutine to keep track of clientConn.Close()
// happening since it might not be noticed by any other goroutine for a while.
go func() {
<-done
cleanup()
}()
case <-t.Error():
// In case this is triggered because clientConn.Close()
// was called, we want to immediately close the transport
// since no other goroutine might notice it for a while.
t.Close()
case <-cdeadline:
ac.mu.Lock()
// This implies that client received server preface.
...@@ -1377,7 +1540,9 @@
close(ac.ready)
ac.ready = nil
}
if channelz.IsOn() {
channelz.RemoveEntry(ac.channelzID)
}
}
func (ac *addrConn) getState() connectivity.State {
...@@ -1385,3 +1550,53 @@
defer ac.mu.Unlock()
return ac.state
}
func (ac *addrConn) getCurAddr() (ret resolver.Address) {
ac.mu.Lock()
ret = ac.curAddr
ac.mu.Unlock()
return
}
func (ac *addrConn) ChannelzMetric() *channelz.ChannelInternalMetric {
ac.mu.Lock()
addr := ac.curAddr.Addr
ac.mu.Unlock()
state := ac.getState()
ac.czmu.RLock()
defer ac.czmu.RUnlock()
return &channelz.ChannelInternalMetric{
State: state,
Target: addr,
CallsStarted: ac.callsStarted,
CallsSucceeded: ac.callsSucceeded,
CallsFailed: ac.callsFailed,
LastCallStartedTimestamp: ac.lastCallStartedTime,
}
}
func (ac *addrConn) incrCallsStarted() {
ac.czmu.Lock()
ac.callsStarted++
ac.lastCallStartedTime = time.Now()
ac.czmu.Unlock()
}
func (ac *addrConn) incrCallsSucceeded() {
ac.czmu.Lock()
ac.callsSucceeded++
ac.czmu.Unlock()
}
func (ac *addrConn) incrCallsFailed() {
ac.czmu.Lock()
ac.callsFailed++
ac.czmu.Unlock()
}
// ErrClientConnTimeout indicates that the ClientConn cannot establish the
// underlying connections within the specified timeout.
//
// Deprecated: This error is never returned by grpc and should not be
// referenced by users.
var ErrClientConnTimeout = errors.New("grpc: timed out when dialing")
...@@ -19,96 +19,32 @@
package grpc
import (
"google.golang.org/grpc/encoding"
_ "google.golang.org/grpc/encoding/proto" // to register the Codec for "proto"
)
// baseCodec contains the functionality of both Codec and encoding.Codec, but
// omits the name/string, which vary between the two and are not needed for
// anything besides the registry in the encoding package.
type baseCodec interface {
Marshal(v interface{}) ([]byte, error)
Unmarshal(data []byte, v interface{}) error
}
var _ baseCodec = Codec(nil)
var _ baseCodec = encoding.Codec(nil)
// Codec defines the interface gRPC uses to encode and decode messages.
// Note that implementations of this interface must be thread safe;
// a Codec's methods can be called from concurrent goroutines.
//
// Deprecated: use encoding.Codec instead.
type Codec interface {
// Marshal returns the wire format of v.
Marshal(v interface{}) ([]byte, error)
// Unmarshal parses the wire format into v.
Unmarshal(data []byte, v interface{}) error
// String returns the name of the Codec implementation. This is unused by
// gRPC.
String() string
}
// protoCodec is a Codec implementation with protobuf. It is the default codec for gRPC.
type protoCodec struct {
}
type cachedProtoBuffer struct {
lastMarshaledSize uint32
proto.Buffer
}
func capToMaxInt32(val int) uint32 {
if val > math.MaxInt32 {
return uint32(math.MaxInt32)
}
return uint32(val)
}
func (p protoCodec) marshal(v interface{}, cb *cachedProtoBuffer) ([]byte, error) {
protoMsg := v.(proto.Message)
newSlice := make([]byte, 0, cb.lastMarshaledSize)
cb.SetBuf(newSlice)
cb.Reset()
if err := cb.Marshal(protoMsg); err != nil {
return nil, err
}
out := cb.Bytes()
cb.lastMarshaledSize = capToMaxInt32(len(out))
return out, nil
}
func (p protoCodec) Marshal(v interface{}) ([]byte, error) {
if pm, ok := v.(proto.Marshaler); ok {
// object can marshal itself, no need for buffer
return pm.Marshal()
}
cb := protoBufferPool.Get().(*cachedProtoBuffer)
out, err := p.marshal(v, cb)
// put back buffer and lose the ref to the slice
cb.SetBuf(nil)
protoBufferPool.Put(cb)
return out, err
}
func (p protoCodec) Unmarshal(data []byte, v interface{}) error {
protoMsg := v.(proto.Message)
protoMsg.Reset()
if pu, ok := protoMsg.(proto.Unmarshaler); ok {
// object can unmarshal itself, no need for buffer
return pu.Unmarshal(data)
}
cb := protoBufferPool.Get().(*cachedProtoBuffer)
cb.SetBuf(data)
err := cb.Unmarshal(protoMsg)
cb.SetBuf(nil)
protoBufferPool.Put(cb)
return err
}
func (protoCodec) String() string {
return "proto"
}
var protoBufferPool = &sync.Pool{
New: func() interface{} {
return &cachedProtoBuffer{
Buffer: proto.Buffer{},
lastMarshaledSize: 16,
}
},
}
...@@ -19,8 +19,10 @@
// Package codes defines the canonical error codes used by gRPC. It is
// consistent across various languages.
package codes // import "google.golang.org/grpc/codes"
import (
"fmt"
"strconv"
)
// A Code is an unsigned 32-bit error code as defined in the gRPC spec.
...@@ -33,9 +35,9 @@
// Canceled indicates the operation was canceled (typically by the caller).
Canceled Code = 1
// Unknown error. An example of where this error may be returned is
// if a Status value received from another address space belongs to
// an error-space that is not known in this address space. Also
// errors raised by APIs that do not return enough error information
// may be converted to this error.
Unknown Code = 2
...@@ -64,15 +66,11 @@
// PermissionDenied indicates the caller does not have permission to
// execute the specified operation. It must not be used for rejections
// caused by exhausting some resource (use ResourceExhausted
// instead for those errors). It must not be
// used if the caller cannot be identified (use Unauthenticated
// instead for those errors).
PermissionDenied Code = 7
// ResourceExhausted indicates some resource has been exhausted, perhaps
// a per-user quota, or perhaps the entire file system is out of space.
ResourceExhausted Code = 8
...@@ -88,7 +86,7 @@
// (b) Use Aborted if the client should retry at a higher-level
// (e.g., restarting a read-modify-write sequence).
// (c) Use FailedPrecondition if the client should not retry until
// the system state has been explicitly fixed. E.g., if an "rmdir"
// fails because the directory is non-empty, FailedPrecondition
// should be returned since the client should not retry unless
// they have first fixed up the directory by deleting files from it.
...@@ -117,7 +115,7 @@
// file size.
//
// There is a fair bit of overlap between FailedPrecondition and
// OutOfRange. We recommend using OutOfRange (the more specific
// error) when it applies so that callers who are iterating through
// a space can easily look for an OutOfRange error to detect when
// they are done.
...@@ -127,8 +125,8 @@
// supported/enabled in this service.
Unimplemented Code = 12
// Internal errors. Means some invariants expected by the underlying
// system have been broken. If you see one of these errors,
// something is very broken.
Internal Code = 13
...@@ -142,6 +140,12 @@
// DataLoss indicates unrecoverable data loss or corruption.
DataLoss Code = 15
// Unauthenticated indicates the request does not have valid
// authentication credentials for the operation.
Unauthenticated Code = 16
_maxCode = 17
)
var strToCode = map[string]Code{
...@@ -175,6 +179,16 @@ func (c *Code) UnmarshalJSON(b []byte) error {
if c == nil {
return fmt.Errorf("nil receiver passed to UnmarshalJSON")
}
if ci, err := strconv.ParseUint(string(b), 10, 32); err == nil {
if ci >= _maxCode {
return fmt.Errorf("invalid code: %q", ci)
}
*c = Code(ci)
return nil
}
if jc, ok := strToCode[string(b)]; ok {
*c = jc
return nil
...
...@@ -43,8 +43,9 @@ type PerRPCCredentials interface {
// GetRequestMetadata gets the current request metadata, refreshing
// tokens if required. This should be called by the transport layer on
// each request, and the data should be populated in headers or other
// context. If a status code is returned, it will be used as the status
// for the RPC. uri is the URI of the entry point for the request.
// When supported by the underlying implementation, ctx can be used for
// timeout and cancellation.
// TODO(zhaoq): Define the set of the qualified keys instead of leaving
// it as an arbitrary string.
...
...@@ -16,46 +16,103 @@
*
*/
// Package encoding defines the interface for the compressor and codec, and
// functions to register and retrieve compressors and codecs.
//
// This package is EXPERIMENTAL.
package encoding
import (
"io"
"strings"
)
// Identity specifies the optional encoding for uncompressed streams.
// It is intended for grpc internal use only.
const Identity = "identity"
// Compressor is used for compressing and decompressing when sending or
// receiving messages.
type Compressor interface {
// Compress writes the data written to wc to w after compressing it. If an
// error occurs while initializing the compressor, that error is returned
// instead.
Compress(w io.Writer) (io.WriteCloser, error)
// Decompress reads data from r, decompresses it, and provides the
// uncompressed data via the returned io.Reader. If an error occurs while
// initializing the decompressor, that error is returned instead.
Decompress(r io.Reader) (io.Reader, error)
// Name is the name of the compression codec and is used to set the content
// coding header. The result must be static; the result cannot change
// between calls.
Name() string
}
var registeredCompressor = make(map[string]Compressor)
// RegisterCompressor registers the compressor with gRPC by its name. It can
// be activated when sending an RPC via grpc.UseCompressor(). It will be
// automatically accessed when receiving a message based on the content coding
// header. Servers also use it to send a response with the same encoding as
// the request.
//
// NOTE: this function must only be called during initialization time (i.e. in
// an init() function), and is not thread-safe. If multiple Compressors are
// registered with the same name, the one registered last will take effect.
func RegisterCompressor(c Compressor) {
registeredCompressor[c.Name()] = c
}
// GetCompressor returns Compressor for the given compressor name.
func GetCompressor(name string) Compressor {
return registeredCompressor[name]
}
// Codec defines the interface gRPC uses to encode and decode messages. Note
// that implementations of this interface must be thread safe; a Codec's
// methods can be called from concurrent goroutines.
type Codec interface {
// Marshal returns the wire format of v.
Marshal(v interface{}) ([]byte, error)
// Unmarshal parses the wire format into v.
Unmarshal(data []byte, v interface{}) error
// Name returns the name of the Codec implementation. The returned string
// will be used as part of content type in transmission. The result must be
// static; the result cannot change between calls.
Name() string
}
var registeredCodecs = make(map[string]Codec)
// RegisterCodec registers the provided Codec for use with all gRPC clients and
// servers.
//
// The Codec will be stored and looked up by result of its Name() method, which
// should match the content-subtype of the encoding handled by the Codec. This
// is case-insensitive, and is stored and looked up as lowercase. If the
// result of calling Name() is an empty string, RegisterCodec will panic. See
// Content-Type on
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for
// more details.
//
// NOTE: this function must only be called during initialization time (i.e. in
// an init() function), and is not thread-safe. If multiple Codecs are
// registered with the same name, the one registered last will take effect.
func RegisterCodec(codec Codec) {
if codec == nil {
panic("cannot register a nil Codec")
}
contentSubtype := strings.ToLower(codec.Name())
if contentSubtype == "" {
panic("cannot register Codec with empty string result for Name()")
}
registeredCodecs[contentSubtype] = codec
}
// GetCodec gets a registered Codec by content-subtype, or nil if no Codec is
// registered for the content-subtype.
//
// The content-subtype is expected to be lowercase.
func GetCodec(contentSubtype string) Codec {
return registeredCodecs[contentSubtype]
}
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package proto defines the protobuf codec. Importing this package will
// register the codec.
package proto
import (
"math"
"sync"
"github.com/golang/protobuf/proto"
"google.golang.org/grpc/encoding"
)
// Name is the name registered for the proto compressor.
const Name = "proto"
func init() {
encoding.RegisterCodec(codec{})
}
// codec is a Codec implementation with protobuf. It is the default codec for gRPC.
type codec struct{}
type cachedProtoBuffer struct {
lastMarshaledSize uint32
proto.Buffer
}
func capToMaxInt32(val int) uint32 {
if val > math.MaxInt32 {
return uint32(math.MaxInt32)
}
return uint32(val)
}
func marshal(v interface{}, cb *cachedProtoBuffer) ([]byte, error) {
protoMsg := v.(proto.Message)
newSlice := make([]byte, 0, cb.lastMarshaledSize)
cb.SetBuf(newSlice)
cb.Reset()
if err := cb.Marshal(protoMsg); err != nil {
return nil, err
}
out := cb.Bytes()
cb.lastMarshaledSize = capToMaxInt32(len(out))
return out, nil
}
func (codec) Marshal(v interface{}) ([]byte, error) {
if pm, ok := v.(proto.Marshaler); ok {
// object can marshal itself, no need for buffer
return pm.Marshal()
}
cb := protoBufferPool.Get().(*cachedProtoBuffer)
out, err := marshal(v, cb)
// put back buffer and lose the ref to the slice
cb.SetBuf(nil)
protoBufferPool.Put(cb)
return out, err
}
func (codec) Unmarshal(data []byte, v interface{}) error {
protoMsg := v.(proto.Message)
protoMsg.Reset()
if pu, ok := protoMsg.(proto.Unmarshaler); ok {
// object can unmarshal itself, no need for buffer
return pu.Unmarshal(data)
}
cb := protoBufferPool.Get().(*cachedProtoBuffer)
cb.SetBuf(data)
err := cb.Unmarshal(protoMsg)
cb.SetBuf(nil)
protoBufferPool.Put(cb)
return err
}
func (codec) Name() string {
return Name
}
var protoBufferPool = &sync.Pool{
New: func() interface{} {
return &cachedProtoBuffer{
Buffer: proto.Buffer{},
lastMarshaledSize: 16,
}
},
}
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package grpc
import (
"os"
"strings"
)
const (
envConfigPrefix = "GRPC_GO_"
envConfigStickinessStr = envConfigPrefix + "STICKINESS"
)
var (
envConfigStickinessOn bool
)
func init() {
envConfigStickinessOn = strings.EqualFold(os.Getenv(envConfigStickinessStr), "on")
}
...@@ -25,7 +25,6 @@ import (
"io"
"net"
"net/http"
"os"
"golang.org/x/net/context"
"google.golang.org/grpc/codes"
...@@ -48,6 +47,9 @@ func sendHTTPRequest(ctx context.Context, req *http.Request, conn net.Conn) erro
// toRPCErr converts an error into an error from the status package.
func toRPCErr(err error) error {
if err == nil || err == io.EOF {
return err
}
if _, ok := status.FromError(err); ok {
return err
}
...@@ -62,37 +64,7 @@ func toRPCErr(err error) error {
return status.Error(codes.DeadlineExceeded, err.Error())
case context.Canceled:
return status.Error(codes.Canceled, err.Error())
case ErrClientConnClosing:
return status.Error(codes.FailedPrecondition, err.Error())
}
}
return status.Error(codes.Unknown, err.Error())
}
// convertCode converts a standard Go error into its canonical code. Note that
// this is only used to translate the error returned by the server applications.
func convertCode(err error) codes.Code {
switch err {
case nil:
return codes.OK
case io.EOF:
return codes.OutOfRange
case io.ErrClosedPipe, io.ErrNoProgress, io.ErrShortBuffer, io.ErrShortWrite, io.ErrUnexpectedEOF:
return codes.FailedPrecondition
case os.ErrInvalid:
return codes.InvalidArgument
case context.Canceled:
return codes.Canceled
case context.DeadlineExceeded:
return codes.DeadlineExceeded
}
switch {
case os.IsExist(err):
return codes.AlreadyExists
case os.IsNotExist(err):
return codes.NotFound
case os.IsPermission(err):
return codes.PermissionDenied
}
return codes.Unknown
}
...@@ -26,7 -26,6 @@ import (
"io"
"net"
"net/http"
"os"
netctx "golang.org/x/net/context"
"google.golang.org/grpc/codes"
...@@ -49,6 +48,9 @@ func sendHTTPRequest(ctx context.Context, req *http.Request, conn net.Conn) erro
// toRPCErr converts an error into an error from the status package.
func toRPCErr(err error) error {
if err == nil || err == io.EOF {
return err
}
if _, ok := status.FromError(err); ok {
return err
}
...@@ -63,37 +65,7 @@ func toRPCErr(err error) error {
return status.Error(codes.DeadlineExceeded, err.Error())
case context.Canceled, netctx.Canceled:
return status.Error(codes.Canceled, err.Error())
case ErrClientConnClosing:
return status.Error(codes.FailedPrecondition, err.Error())
}
}
return status.Error(codes.Unknown, err.Error())
}
// convertCode converts a standard Go error into its canonical code. Note that
// this is only used to translate the error returned by the server applications.
func convertCode(err error) codes.Code {
switch err {
case nil:
return codes.OK
case io.EOF:
return codes.OutOfRange
case io.ErrClosedPipe, io.ErrNoProgress, io.ErrShortBuffer, io.ErrShortWrite, io.ErrUnexpectedEOF:
return codes.FailedPrecondition
case os.ErrInvalid:
return codes.InvalidArgument
case context.Canceled, netctx.Canceled:
return codes.Canceled
case context.DeadlineExceeded, netctx.DeadlineExceeded:
return codes.DeadlineExceeded
}
switch {
case os.IsExist(err):
return codes.AlreadyExists
case os.IsNotExist(err):
return codes.NotFound
case os.IsPermission(err):
return codes.PermissionDenied
}
return codes.Unknown
}
/*
*
* Copyright 2016 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package grpc
import (
"strconv"
"strings"
"sync"
"time"
"golang.org/x/net/context"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/connectivity"
lbpb "google.golang.org/grpc/grpclb/grpc_lb_v1/messages"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/resolver"
)
const (
lbTokeyKey = "lb-token"
defaultFallbackTimeout = 10 * time.Second
grpclbName = "grpclb"
)
func convertDuration(d *lbpb.Duration) time.Duration {
if d == nil {
return 0
}
return time.Duration(d.Seconds)*time.Second + time.Duration(d.Nanos)*time.Nanosecond
}
// Client API for LoadBalancer service.
// Mostly copied from generated pb.go file.
// To avoid circular dependency.
type loadBalancerClient struct {
cc *ClientConn
}
func (c *loadBalancerClient) BalanceLoad(ctx context.Context, opts ...CallOption) (*balanceLoadClientStream, error) {
desc := &StreamDesc{
StreamName: "BalanceLoad",
ServerStreams: true,
ClientStreams: true,
}
stream, err := NewClientStream(ctx, desc, c.cc, "/grpc.lb.v1.LoadBalancer/BalanceLoad", opts...)
if err != nil {
return nil, err
}
x := &balanceLoadClientStream{stream}
return x, nil
}
type balanceLoadClientStream struct {
ClientStream
}
func (x *balanceLoadClientStream) Send(m *lbpb.LoadBalanceRequest) error {
return x.ClientStream.SendMsg(m)
}
func (x *balanceLoadClientStream) Recv() (*lbpb.LoadBalanceResponse, error) {
m := new(lbpb.LoadBalanceResponse)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
func init() {
balancer.Register(newLBBuilder())
}
// newLBBuilder creates a builder for grpclb.
func newLBBuilder() balancer.Builder {
return NewLBBuilderWithFallbackTimeout(defaultFallbackTimeout)
}
// NewLBBuilderWithFallbackTimeout creates a grpclb builder with the given
// fallbackTimeout. If no response is received from the remote balancer within
// fallbackTimeout, the backend addresses from the resolved address list will be
// used.
//
// Only call this function when a non-default fallback timeout is needed.
func NewLBBuilderWithFallbackTimeout(fallbackTimeout time.Duration) balancer.Builder {
return &lbBuilder{
fallbackTimeout: fallbackTimeout,
}
}
type lbBuilder struct {
fallbackTimeout time.Duration
}
func (b *lbBuilder) Name() string {
return grpclbName
}
func (b *lbBuilder) Build(cc balancer.ClientConn, opt balancer.BuildOptions) balancer.Balancer {
// This generates a manual resolver builder with a random scheme. This
// scheme will be used to dial to remote LB, so we can send filtered address
// updates to remote LB ClientConn using this manual resolver.
scheme := "grpclb_internal_" + strconv.FormatInt(time.Now().UnixNano(), 36)
r := &lbManualResolver{scheme: scheme, ccb: cc}
var target string
targetSplitted := strings.Split(cc.Target(), ":///")
if len(targetSplitted) < 2 {
target = cc.Target()
} else {
target = targetSplitted[1]
}
lb := &lbBalancer{
cc: cc,
target: target,
opt: opt,
fallbackTimeout: b.fallbackTimeout,
doneCh: make(chan struct{}),
manualResolver: r,
csEvltr: &connectivityStateEvaluator{},
subConns: make(map[resolver.Address]balancer.SubConn),
scStates: make(map[balancer.SubConn]connectivity.State),
picker: &errPicker{err: balancer.ErrNoSubConnAvailable},
clientStats: &rpcStats{},
}
return lb
}
type lbBalancer struct {
cc balancer.ClientConn
target string
opt balancer.BuildOptions
fallbackTimeout time.Duration
doneCh chan struct{}
// manualResolver is used in the remote LB ClientConn inside grpclb. When
// resolved address updates are received by grpclb, filtered updates will be
// sent to the remote LB ClientConn through this resolver.
manualResolver *lbManualResolver
// The ClientConn to talk to the remote balancer.
ccRemoteLB *ClientConn
// Support client side load reporting. Each picker gets a reference to this,
// and will update its content.
clientStats *rpcStats
mu sync.Mutex // guards everything following.
// The full server list including drops, used to check if the newly received
// serverList contains anything new. Each generated picker will also have a
// reference to this list to do the first layer pick.
fullServerList []*lbpb.Server
// All backends addresses, with metadata set to nil. This list contains all
// backend addresses in the same order and with the same duplicates as in
// serverlist. When generating picker, a SubConn slice with the same order
// but with only READY SCs will be generated.
backendAddrs []resolver.Address
// Roundrobin functionalities.
csEvltr *connectivityStateEvaluator
state connectivity.State
subConns map[resolver.Address]balancer.SubConn // Used to new/remove SubConn.
scStates map[balancer.SubConn]connectivity.State // Used to filter READY SubConns.
picker balancer.Picker
// Support fallback to resolved backend addresses if there's no response
// from remote balancer within fallbackTimeout.
fallbackTimerExpired bool
serverListReceived bool
// resolvedBackendAddrs is resolvedAddrs minus remote balancers. It's set
// when resolved address updates are received, and read in the goroutine
// handling fallback.
resolvedBackendAddrs []resolver.Address
}
// regeneratePicker takes a snapshot of the balancer, and generates a picker from
// it. The picker
// - always returns ErrTransientFailure if the balancer is in TransientFailure,
// - does two layer roundrobin pick otherwise.
// Caller must hold lb.mu.
func (lb *lbBalancer) regeneratePicker() {
if lb.state == connectivity.TransientFailure {
lb.picker = &errPicker{err: balancer.ErrTransientFailure}
return
}
var readySCs []balancer.SubConn
for _, a := range lb.backendAddrs {
if sc, ok := lb.subConns[a]; ok {
if st, ok := lb.scStates[sc]; ok && st == connectivity.Ready {
readySCs = append(readySCs, sc)
}
}
}
if len(lb.fullServerList) <= 0 {
if len(readySCs) <= 0 {
lb.picker = &errPicker{err: balancer.ErrNoSubConnAvailable}
return
}
lb.picker = &rrPicker{subConns: readySCs}
return
}
lb.picker = &lbPicker{
serverList: lb.fullServerList,
subConns: readySCs,
stats: lb.clientStats,
}
return
}
func (lb *lbBalancer) HandleSubConnStateChange(sc balancer.SubConn, s connectivity.State) {
grpclog.Infof("lbBalancer: handle SubConn state change: %p, %v", sc, s)
lb.mu.Lock()
defer lb.mu.Unlock()
oldS, ok := lb.scStates[sc]
if !ok {
grpclog.Infof("lbBalancer: got state changes for an unknown SubConn: %p, %v", sc, s)
return
}
lb.scStates[sc] = s
switch s {
case connectivity.Idle:
sc.Connect()
case connectivity.Shutdown:
// When an address was removed by the resolver, the balancer called
// RemoveSubConn but kept the sc's state in scStates. Remove that state here.
delete(lb.scStates, sc)
}
oldAggrState := lb.state
lb.state = lb.csEvltr.recordTransition(oldS, s)
// Regenerate picker when one of the following happens:
// - this sc became ready from not-ready
// - this sc became not-ready from ready
// - the aggregated state of balancer became TransientFailure from non-TransientFailure
// - the aggregated state of balancer became non-TransientFailure from TransientFailure
if (oldS == connectivity.Ready) != (s == connectivity.Ready) ||
(lb.state == connectivity.TransientFailure) != (oldAggrState == connectivity.TransientFailure) {
lb.regeneratePicker()
}
lb.cc.UpdateBalancerState(lb.state, lb.picker)
return
}
// fallbackToBackendsAfter blocks for fallbackTimeout and falls back to use
// resolved backends (backends received from resolver, not from remote balancer)
// if no connection to remote balancers was successful.
func (lb *lbBalancer) fallbackToBackendsAfter(fallbackTimeout time.Duration) {
timer := time.NewTimer(fallbackTimeout)
defer timer.Stop()
select {
case <-timer.C:
case <-lb.doneCh:
return
}
lb.mu.Lock()
if lb.serverListReceived {
lb.mu.Unlock()
return
}
lb.fallbackTimerExpired = true
lb.refreshSubConns(lb.resolvedBackendAddrs)
lb.mu.Unlock()
}
// HandleResolvedAddrs sends the updated remoteLB addresses to remoteLB
// clientConn. The remoteLB clientConn will handle creating/removing remoteLB
// connections.
func (lb *lbBalancer) HandleResolvedAddrs(addrs []resolver.Address, err error) {
grpclog.Infof("lbBalancer: handleResolvedResult: %+v", addrs)
if len(addrs) <= 0 {
return
}
var remoteBalancerAddrs, backendAddrs []resolver.Address
for _, a := range addrs {
if a.Type == resolver.GRPCLB {
remoteBalancerAddrs = append(remoteBalancerAddrs, a)
} else {
backendAddrs = append(backendAddrs, a)
}
}
if lb.ccRemoteLB == nil {
if len(remoteBalancerAddrs) <= 0 {
grpclog.Errorf("grpclb: no remote balancer address is available, should never happen")
return
}
// First time receiving resolved addresses, create a cc to remote
// balancers.
lb.dialRemoteLB(remoteBalancerAddrs[0].ServerName)
// Start the fallback goroutine.
go lb.fallbackToBackendsAfter(lb.fallbackTimeout)
}
// cc to remote balancers uses lb.manualResolver. Send the updated remote
// balancer addresses to it through manualResolver.
lb.manualResolver.NewAddress(remoteBalancerAddrs)
lb.mu.Lock()
lb.resolvedBackendAddrs = backendAddrs
// If serverListReceived is true, connection to remote balancer was
// successful and there's no need to do fallback anymore.
// If fallbackTimerExpired is false, fallback hasn't happened yet.
if !lb.serverListReceived && lb.fallbackTimerExpired {
// This means we received a new list of resolved backends, and we are
// still in fallback mode. Need to update the list of backends we are
// using to the new list of backends.
lb.refreshSubConns(lb.resolvedBackendAddrs)
}
lb.mu.Unlock()
}
func (lb *lbBalancer) Close() {
select {
case <-lb.doneCh:
return
default:
}
close(lb.doneCh)
if lb.ccRemoteLB != nil {
lb.ccRemoteLB.Close()
}
}
// Code generated by protoc-gen-go. DO NOT EDIT.
// source: grpc_lb_v1/messages/messages.proto
/*
Package messages is a generated protocol buffer package.
It is generated from these files:
grpc_lb_v1/messages/messages.proto
It has these top-level messages:
Duration
Timestamp
LoadBalanceRequest
InitialLoadBalanceRequest
ClientStats
LoadBalanceResponse
InitialLoadBalanceResponse
ServerList
Server
*/
package messages
import proto "github.com/golang/protobuf/proto"
import fmt "fmt"
import math "math"
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
type Duration struct {
// Signed seconds of the span of time. Must be from -315,576,000,000
// to +315,576,000,000 inclusive.
Seconds int64 `protobuf:"varint,1,opt,name=seconds" json:"seconds,omitempty"`
// Signed fractions of a second at nanosecond resolution of the span
// of time. Durations less than one second are represented with a 0
// `seconds` field and a positive or negative `nanos` field. For durations
// of one second or more, a non-zero value for the `nanos` field must be
// of the same sign as the `seconds` field. Must be from -999,999,999
// to +999,999,999 inclusive.
Nanos int32 `protobuf:"varint,2,opt,name=nanos" json:"nanos,omitempty"`
}
func (m *Duration) Reset() { *m = Duration{} }
func (m *Duration) String() string { return proto.CompactTextString(m) }
func (*Duration) ProtoMessage() {}
func (*Duration) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} }
func (m *Duration) GetSeconds() int64 {
if m != nil {
return m.Seconds
}
return 0
}
func (m *Duration) GetNanos() int32 {
if m != nil {
return m.Nanos
}
return 0
}
type Timestamp struct {
// Represents seconds of UTC time since Unix epoch
// 1970-01-01T00:00:00Z. Must be from 0001-01-01T00:00:00Z to
// 9999-12-31T23:59:59Z inclusive.
Seconds int64 `protobuf:"varint,1,opt,name=seconds" json:"seconds,omitempty"`
// Non-negative fractions of a second at nanosecond resolution. Negative
// second values with fractions must still have non-negative nanos values
// that count forward in time. Must be from 0 to 999,999,999
// inclusive.
Nanos int32 `protobuf:"varint,2,opt,name=nanos" json:"nanos,omitempty"`
}
func (m *Timestamp) Reset() { *m = Timestamp{} }
func (m *Timestamp) String() string { return proto.CompactTextString(m) }
func (*Timestamp) ProtoMessage() {}
func (*Timestamp) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} }
func (m *Timestamp) GetSeconds() int64 {
if m != nil {
return m.Seconds
}
return 0
}
func (m *Timestamp) GetNanos() int32 {
if m != nil {
return m.Nanos
}
return 0
}
type LoadBalanceRequest struct {
// Types that are valid to be assigned to LoadBalanceRequestType:
// *LoadBalanceRequest_InitialRequest
// *LoadBalanceRequest_ClientStats
LoadBalanceRequestType isLoadBalanceRequest_LoadBalanceRequestType `protobuf_oneof:"load_balance_request_type"`
}
func (m *LoadBalanceRequest) Reset() { *m = LoadBalanceRequest{} }
func (m *LoadBalanceRequest) String() string { return proto.CompactTextString(m) }
func (*LoadBalanceRequest) ProtoMessage() {}
func (*LoadBalanceRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{2} }
type isLoadBalanceRequest_LoadBalanceRequestType interface {
isLoadBalanceRequest_LoadBalanceRequestType()
}
type LoadBalanceRequest_InitialRequest struct {
InitialRequest *InitialLoadBalanceRequest `protobuf:"bytes,1,opt,name=initial_request,json=initialRequest,oneof"`
}
type LoadBalanceRequest_ClientStats struct {
ClientStats *ClientStats `protobuf:"bytes,2,opt,name=client_stats,json=clientStats,oneof"`
}
func (*LoadBalanceRequest_InitialRequest) isLoadBalanceRequest_LoadBalanceRequestType() {}
func (*LoadBalanceRequest_ClientStats) isLoadBalanceRequest_LoadBalanceRequestType() {}
func (m *LoadBalanceRequest) GetLoadBalanceRequestType() isLoadBalanceRequest_LoadBalanceRequestType {
if m != nil {
return m.LoadBalanceRequestType
}
return nil
}
func (m *LoadBalanceRequest) GetInitialRequest() *InitialLoadBalanceRequest {
if x, ok := m.GetLoadBalanceRequestType().(*LoadBalanceRequest_InitialRequest); ok {
return x.InitialRequest
}
return nil
}
func (m *LoadBalanceRequest) GetClientStats() *ClientStats {
if x, ok := m.GetLoadBalanceRequestType().(*LoadBalanceRequest_ClientStats); ok {
return x.ClientStats
}
return nil
}
// XXX_OneofFuncs is for the internal use of the proto package.
func (*LoadBalanceRequest) XXX_OneofFuncs() (func(msg proto.Message, b *proto.Buffer) error, func(msg proto.Message, tag, wire int, b *proto.Buffer) (bool, error), func(msg proto.Message) (n int), []interface{}) {
return _LoadBalanceRequest_OneofMarshaler, _LoadBalanceRequest_OneofUnmarshaler, _LoadBalanceRequest_OneofSizer, []interface{}{
(*LoadBalanceRequest_InitialRequest)(nil),
(*LoadBalanceRequest_ClientStats)(nil),
}
}
func _LoadBalanceRequest_OneofMarshaler(msg proto.Message, b *proto.Buffer) error {
m := msg.(*LoadBalanceRequest)
// load_balance_request_type
switch x := m.LoadBalanceRequestType.(type) {
case *LoadBalanceRequest_InitialRequest:
b.EncodeVarint(1<<3 | proto.WireBytes)
if err := b.EncodeMessage(x.InitialRequest); err != nil {
return err
}
case *LoadBalanceRequest_ClientStats:
b.EncodeVarint(2<<3 | proto.WireBytes)
if err := b.EncodeMessage(x.ClientStats); err != nil {
return err
}
case nil:
default:
return fmt.Errorf("LoadBalanceRequest.LoadBalanceRequestType has unexpected type %T", x)
}
return nil
}
func _LoadBalanceRequest_OneofUnmarshaler(msg proto.Message, tag, wire int, b *proto.Buffer) (bool, error) {
m := msg.(*LoadBalanceRequest)
switch tag {
case 1: // load_balance_request_type.initial_request
if wire != proto.WireBytes {
return true, proto.ErrInternalBadWireType
}
msg := new(InitialLoadBalanceRequest)
err := b.DecodeMessage(msg)
m.LoadBalanceRequestType = &LoadBalanceRequest_InitialRequest{msg}
return true, err
case 2: // load_balance_request_type.client_stats
if wire != proto.WireBytes {
return true, proto.ErrInternalBadWireType
}
msg := new(ClientStats)
err := b.DecodeMessage(msg)
m.LoadBalanceRequestType = &LoadBalanceRequest_ClientStats{msg}
return true, err
default:
return false, nil
}
}
func _LoadBalanceRequest_OneofSizer(msg proto.Message) (n int) {
m := msg.(*LoadBalanceRequest)
// load_balance_request_type
switch x := m.LoadBalanceRequestType.(type) {
case *LoadBalanceRequest_InitialRequest:
s := proto.Size(x.InitialRequest)
n += proto.SizeVarint(1<<3 | proto.WireBytes)
n += proto.SizeVarint(uint64(s))
n += s
case *LoadBalanceRequest_ClientStats:
s := proto.Size(x.ClientStats)
n += proto.SizeVarint(2<<3 | proto.WireBytes)
n += proto.SizeVarint(uint64(s))
n += s
case nil:
default:
panic(fmt.Sprintf("proto: unexpected type %T in oneof", x))
}
return n
}
type InitialLoadBalanceRequest struct {
// Name of load balanced service (IE, balancer.service.com)
// length should be less than 256 bytes.
Name string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty"`
}
func (m *InitialLoadBalanceRequest) Reset() { *m = InitialLoadBalanceRequest{} }
func (m *InitialLoadBalanceRequest) String() string { return proto.CompactTextString(m) }
func (*InitialLoadBalanceRequest) ProtoMessage() {}
func (*InitialLoadBalanceRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{3} }
func (m *InitialLoadBalanceRequest) GetName() string {
if m != nil {
return m.Name
}
return ""
}
// Contains client level statistics that are useful to load balancing. Each
// count except the timestamp should be reset to zero after reporting the stats.
type ClientStats struct {
// The timestamp of generating the report.
Timestamp *Timestamp `protobuf:"bytes,1,opt,name=timestamp" json:"timestamp,omitempty"`
// The total number of RPCs that started.
NumCallsStarted int64 `protobuf:"varint,2,opt,name=num_calls_started,json=numCallsStarted" json:"num_calls_started,omitempty"`
// The total number of RPCs that finished.
NumCallsFinished int64 `protobuf:"varint,3,opt,name=num_calls_finished,json=numCallsFinished" json:"num_calls_finished,omitempty"`
// The total number of RPCs that were dropped by the client because of rate
// limiting.
NumCallsFinishedWithDropForRateLimiting int64 `protobuf:"varint,4,opt,name=num_calls_finished_with_drop_for_rate_limiting,json=numCallsFinishedWithDropForRateLimiting" json:"num_calls_finished_with_drop_for_rate_limiting,omitempty"`
// The total number of RPCs that were dropped by the client because of load
// balancing.
NumCallsFinishedWithDropForLoadBalancing int64 `protobuf:"varint,5,opt,name=num_calls_finished_with_drop_for_load_balancing,json=numCallsFinishedWithDropForLoadBalancing" json:"num_calls_finished_with_drop_for_load_balancing,omitempty"`
// The total number of RPCs that failed to reach a server except dropped RPCs.
NumCallsFinishedWithClientFailedToSend int64 `protobuf:"varint,6,opt,name=num_calls_finished_with_client_failed_to_send,json=numCallsFinishedWithClientFailedToSend" json:"num_calls_finished_with_client_failed_to_send,omitempty"`
// The total number of RPCs that finished and are known to have been received
// by a server.
NumCallsFinishedKnownReceived int64 `protobuf:"varint,7,opt,name=num_calls_finished_known_received,json=numCallsFinishedKnownReceived" json:"num_calls_finished_known_received,omitempty"`
}
func (m *ClientStats) Reset() { *m = ClientStats{} }
func (m *ClientStats) String() string { return proto.CompactTextString(m) }
func (*ClientStats) ProtoMessage() {}
func (*ClientStats) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{4} }
func (m *ClientStats) GetTimestamp() *Timestamp {
if m != nil {
return m.Timestamp
}
return nil
}
func (m *ClientStats) GetNumCallsStarted() int64 {
if m != nil {
return m.NumCallsStarted
}
return 0
}
func (m *ClientStats) GetNumCallsFinished() int64 {
if m != nil {
return m.NumCallsFinished
}
return 0
}
func (m *ClientStats) GetNumCallsFinishedWithDropForRateLimiting() int64 {
if m != nil {
return m.NumCallsFinishedWithDropForRateLimiting
}
return 0
}
func (m *ClientStats) GetNumCallsFinishedWithDropForLoadBalancing() int64 {
if m != nil {
return m.NumCallsFinishedWithDropForLoadBalancing
}
return 0
}
func (m *ClientStats) GetNumCallsFinishedWithClientFailedToSend() int64 {
if m != nil {
return m.NumCallsFinishedWithClientFailedToSend
}
return 0
}
func (m *ClientStats) GetNumCallsFinishedKnownReceived() int64 {
if m != nil {
return m.NumCallsFinishedKnownReceived
}
return 0
}
type LoadBalanceResponse struct {
// Types that are valid to be assigned to LoadBalanceResponseType:
// *LoadBalanceResponse_InitialResponse
// *LoadBalanceResponse_ServerList
LoadBalanceResponseType isLoadBalanceResponse_LoadBalanceResponseType `protobuf_oneof:"load_balance_response_type"`
}
func (m *LoadBalanceResponse) Reset() { *m = LoadBalanceResponse{} }
func (m *LoadBalanceResponse) String() string { return proto.CompactTextString(m) }
func (*LoadBalanceResponse) ProtoMessage() {}
func (*LoadBalanceResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{5} }
type isLoadBalanceResponse_LoadBalanceResponseType interface {
isLoadBalanceResponse_LoadBalanceResponseType()
}
type LoadBalanceResponse_InitialResponse struct {
InitialResponse *InitialLoadBalanceResponse `protobuf:"bytes,1,opt,name=initial_response,json=initialResponse,oneof"`
}
type LoadBalanceResponse_ServerList struct {
ServerList *ServerList `protobuf:"bytes,2,opt,name=server_list,json=serverList,oneof"`
}
func (*LoadBalanceResponse_InitialResponse) isLoadBalanceResponse_LoadBalanceResponseType() {}
func (*LoadBalanceResponse_ServerList) isLoadBalanceResponse_LoadBalanceResponseType() {}
func (m *LoadBalanceResponse) GetLoadBalanceResponseType() isLoadBalanceResponse_LoadBalanceResponseType {
if m != nil {
return m.LoadBalanceResponseType
}
return nil
}
func (m *LoadBalanceResponse) GetInitialResponse() *InitialLoadBalanceResponse {
if x, ok := m.GetLoadBalanceResponseType().(*LoadBalanceResponse_InitialResponse); ok {
return x.InitialResponse
}
return nil
}
func (m *LoadBalanceResponse) GetServerList() *ServerList {
if x, ok := m.GetLoadBalanceResponseType().(*LoadBalanceResponse_ServerList); ok {
return x.ServerList
}
return nil
}
// XXX_OneofFuncs is for the internal use of the proto package.
func (*LoadBalanceResponse) XXX_OneofFuncs() (func(msg proto.Message, b *proto.Buffer) error, func(msg proto.Message, tag, wire int, b *proto.Buffer) (bool, error), func(msg proto.Message) (n int), []interface{}) {
return _LoadBalanceResponse_OneofMarshaler, _LoadBalanceResponse_OneofUnmarshaler, _LoadBalanceResponse_OneofSizer, []interface{}{
(*LoadBalanceResponse_InitialResponse)(nil),
(*LoadBalanceResponse_ServerList)(nil),
}
}
func _LoadBalanceResponse_OneofMarshaler(msg proto.Message, b *proto.Buffer) error {
m := msg.(*LoadBalanceResponse)
// load_balance_response_type
switch x := m.LoadBalanceResponseType.(type) {
case *LoadBalanceResponse_InitialResponse:
b.EncodeVarint(1<<3 | proto.WireBytes)
if err := b.EncodeMessage(x.InitialResponse); err != nil {
return err
}
case *LoadBalanceResponse_ServerList:
b.EncodeVarint(2<<3 | proto.WireBytes)
if err := b.EncodeMessage(x.ServerList); err != nil {
return err
}
case nil:
default:
return fmt.Errorf("LoadBalanceResponse.LoadBalanceResponseType has unexpected type %T", x)
}
return nil
}
func _LoadBalanceResponse_OneofUnmarshaler(msg proto.Message, tag, wire int, b *proto.Buffer) (bool, error) {
m := msg.(*LoadBalanceResponse)
switch tag {
case 1: // load_balance_response_type.initial_response
if wire != proto.WireBytes {
return true, proto.ErrInternalBadWireType
}
msg := new(InitialLoadBalanceResponse)
err := b.DecodeMessage(msg)
m.LoadBalanceResponseType = &LoadBalanceResponse_InitialResponse{msg}
return true, err
case 2: // load_balance_response_type.server_list
if wire != proto.WireBytes {
return true, proto.ErrInternalBadWireType
}
msg := new(ServerList)
err := b.DecodeMessage(msg)
m.LoadBalanceResponseType = &LoadBalanceResponse_ServerList{msg}
return true, err
default:
return false, nil
}
}
func _LoadBalanceResponse_OneofSizer(msg proto.Message) (n int) {
m := msg.(*LoadBalanceResponse)
// load_balance_response_type
switch x := m.LoadBalanceResponseType.(type) {
case *LoadBalanceResponse_InitialResponse:
s := proto.Size(x.InitialResponse)
n += proto.SizeVarint(1<<3 | proto.WireBytes)
n += proto.SizeVarint(uint64(s))
n += s
case *LoadBalanceResponse_ServerList:
s := proto.Size(x.ServerList)
n += proto.SizeVarint(2<<3 | proto.WireBytes)
n += proto.SizeVarint(uint64(s))
n += s
case nil:
default:
panic(fmt.Sprintf("proto: unexpected type %T in oneof", x))
}
return n
}
type InitialLoadBalanceResponse struct {
// This is an application layer redirect that indicates the client should use
// the specified server for load balancing. When this field is non-empty in
// the response, the client should open a separate connection to the
// load_balancer_delegate and call the BalanceLoad method. Its length should
// be less than 64 bytes.
LoadBalancerDelegate string `protobuf:"bytes,1,opt,name=load_balancer_delegate,json=loadBalancerDelegate" json:"load_balancer_delegate,omitempty"`
// This interval defines how often the client should send the client stats
// to the load balancer. Stats should only be reported when the duration is
// positive.
ClientStatsReportInterval *Duration `protobuf:"bytes,2,opt,name=client_stats_report_interval,json=clientStatsReportInterval" json:"client_stats_report_interval,omitempty"`
}
func (m *InitialLoadBalanceResponse) Reset() { *m = InitialLoadBalanceResponse{} }
func (m *InitialLoadBalanceResponse) String() string { return proto.CompactTextString(m) }
func (*InitialLoadBalanceResponse) ProtoMessage() {}
func (*InitialLoadBalanceResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{6} }
func (m *InitialLoadBalanceResponse) GetLoadBalancerDelegate() string {
if m != nil {
return m.LoadBalancerDelegate
}
return ""
}
func (m *InitialLoadBalanceResponse) GetClientStatsReportInterval() *Duration {
if m != nil {
return m.ClientStatsReportInterval
}
return nil
}
type ServerList struct {
// Contains a list of servers selected by the load balancer. The list will
// be updated when server resolutions change or as needed to balance load
// across more servers. The client should consume the server list in order
// unless instructed otherwise via the client_config.
Servers []*Server `protobuf:"bytes,1,rep,name=servers" json:"servers,omitempty"`
}
func (m *ServerList) Reset() { *m = ServerList{} }
func (m *ServerList) String() string { return proto.CompactTextString(m) }
func (*ServerList) ProtoMessage() {}
func (*ServerList) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{7} }
func (m *ServerList) GetServers() []*Server {
if m != nil {
return m.Servers
}
return nil
}
// Contains server information. When none of the [drop_for_*] fields are true,
// use the other fields. When drop_for_rate_limiting is true, ignore all other
// fields. Use drop_for_load_balancing only when it is true and
// drop_for_rate_limiting is false.
type Server struct {
// A resolved address for the server, serialized in network-byte-order. It may
// either be an IPv4 or IPv6 address.
IpAddress []byte `protobuf:"bytes,1,opt,name=ip_address,json=ipAddress,proto3" json:"ip_address,omitempty"`
// A resolved port number for the server.
Port int32 `protobuf:"varint,2,opt,name=port" json:"port,omitempty"`
// An opaque but printable token given to the frontend for each pick. All
// frontend requests for that pick must include the token in its initial
// metadata. The token is used by the backend to verify the request and to
// allow the backend to report load to the gRPC LB system.
//
// Its length is variable but less than 50 bytes.
LoadBalanceToken string `protobuf:"bytes,3,opt,name=load_balance_token,json=loadBalanceToken" json:"load_balance_token,omitempty"`
// Indicates whether this particular request should be dropped by the client
// for rate limiting.
DropForRateLimiting bool `protobuf:"varint,4,opt,name=drop_for_rate_limiting,json=dropForRateLimiting" json:"drop_for_rate_limiting,omitempty"`
// Indicates whether this particular request should be dropped by the client
// for load balancing.
DropForLoadBalancing bool `protobuf:"varint,5,opt,name=drop_for_load_balancing,json=dropForLoadBalancing" json:"drop_for_load_balancing,omitempty"`
}
func (m *Server) Reset() { *m = Server{} }
func (m *Server) String() string { return proto.CompactTextString(m) }
func (*Server) ProtoMessage() {}
func (*Server) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{8} }
func (m *Server) GetIpAddress() []byte {
if m != nil {
return m.IpAddress
}
return nil
}
func (m *Server) GetPort() int32 {
if m != nil {
return m.Port
}
return 0
}
func (m *Server) GetLoadBalanceToken() string {
if m != nil {
return m.LoadBalanceToken
}
return ""
}
func (m *Server) GetDropForRateLimiting() bool {
if m != nil {
return m.DropForRateLimiting
}
return false
}
func (m *Server) GetDropForLoadBalancing() bool {
if m != nil {
return m.DropForLoadBalancing
}
return false
}
func init() {
proto.RegisterType((*Duration)(nil), "grpc.lb.v1.Duration")
proto.RegisterType((*Timestamp)(nil), "grpc.lb.v1.Timestamp")
proto.RegisterType((*LoadBalanceRequest)(nil), "grpc.lb.v1.LoadBalanceRequest")
proto.RegisterType((*InitialLoadBalanceRequest)(nil), "grpc.lb.v1.InitialLoadBalanceRequest")
proto.RegisterType((*ClientStats)(nil), "grpc.lb.v1.ClientStats")
proto.RegisterType((*LoadBalanceResponse)(nil), "grpc.lb.v1.LoadBalanceResponse")
proto.RegisterType((*InitialLoadBalanceResponse)(nil), "grpc.lb.v1.InitialLoadBalanceResponse")
proto.RegisterType((*ServerList)(nil), "grpc.lb.v1.ServerList")
proto.RegisterType((*Server)(nil), "grpc.lb.v1.Server")
}
func init() { proto.RegisterFile("grpc_lb_v1/messages/messages.proto", fileDescriptor0) }
var fileDescriptor0 = []byte{
// 709 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x94, 0x55, 0xdd, 0x4e, 0x1b, 0x3b,
0x10, 0x26, 0x27, 0x01, 0x92, 0x09, 0x3a, 0xe4, 0x98, 0x1c, 0x08, 0x14, 0x24, 0xba, 0x52, 0x69,
0x54, 0xd1, 0x20, 0xa0, 0xbd, 0xe8, 0xcf, 0x45, 0x1b, 0x10, 0x0a, 0x2d, 0x17, 0x95, 0x43, 0x55,
0xa9, 0x52, 0x65, 0x39, 0xd9, 0x21, 0x58, 0x6c, 0xec, 0xad, 0xed, 0x04, 0xf5, 0x11, 0xfa, 0x28,
0x7d, 0x8c, 0xaa, 0xcf, 0xd0, 0xf7, 0xa9, 0xd6, 0xbb, 0x9b, 0x5d, 0x20, 0x80, 0x7a, 0x67, 0x8f,
0xbf, 0xf9, 0xbe, 0xf1, 0xac, 0xbf, 0x59, 0xf0, 0x06, 0x3a, 0xec, 0xb3, 0xa0, 0xc7, 0xc6, 0xbb,
0x3b, 0x43, 0x34, 0x86, 0x0f, 0xd0, 0x4c, 0x16, 0xad, 0x50, 0x2b, 0xab, 0x08, 0x44, 0x98, 0x56,
0xd0, 0x6b, 0x8d, 0x77, 0xbd, 0x97, 0x50, 0x3e, 0x1c, 0x69, 0x6e, 0x85, 0x92, 0xa4, 0x01, 0xf3,
0x06, 0xfb, 0x4a, 0xfa, 0xa6, 0x51, 0xd8, 0x2c, 0x34, 0x8b, 0x34, 0xdd, 0x92, 0x3a, 0xcc, 0x4a,
0x2e, 0x95, 0x69, 0xfc, 0xb3, 0x59, 0x68, 0xce, 0xd2, 0x78, 0xe3, 0xbd, 0x82, 0xca, 0xa9, 0x18,
0xa2, 0xb1, 0x7c, 0x18, 0xfe, 0x75, 0xf2, 0xcf, 0x02, 0x90, 0x13, 0xc5, 0xfd, 0x36, 0x0f, 0xb8,
0xec, 0x23, 0xc5, 0xaf, 0x23, 0x34, 0x96, 0x7c, 0x80, 0x45, 0x21, 0x85, 0x15, 0x3c, 0x60, 0x3a,
0x0e, 0x39, 0xba, 0xea, 0xde, 0xa3, 0x56, 0x56, 0x75, 0xeb, 0x38, 0x86, 0xdc, 0xcc, 0xef, 0xcc,
0xd0, 0x7f, 0x93, 0xfc, 0x94, 0xf1, 0x35, 0x2c, 0xf4, 0x03, 0x81, 0xd2, 0x32, 0x63, 0xb9, 0x8d,
0xab, 0xa8, 0xee, 0xad, 0xe4, 0xe9, 0x0e, 0xdc, 0x79, 0x37, 0x3a, 0xee, 0xcc, 0xd0, 0x6a, 0x3f,
0xdb, 0xb6, 0x1f, 0xc0, 0x6a, 0xa0, 0xb8, 0xcf, 0x7a, 0xb1, 0x4c, 0x5a, 0x14, 0xb3, 0xdf, 0x42,
0xf4, 0x76, 0x60, 0xf5, 0xd6, 0x4a, 0x08, 0x81, 0x92, 0xe4, 0x43, 0x74, 0xe5, 0x57, 0xa8, 0x5b,
0x7b, 0xdf, 0x4b, 0x50, 0xcd, 0x89, 0x91, 0x7d, 0xa8, 0xd8, 0xb4, 0x83, 0xc9, 0x3d, 0xff, 0xcf,
0x17, 0x36, 0x69, 0x2f, 0xcd, 0x70, 0xe4, 0x09, 0xfc, 0x27, 0x47, 0x43, 0xd6, 0xe7, 0x41, 0x60,
0xa2, 0x3b, 0x69, 0x8b, 0xbe, 0xbb, 0x55, 0x91, 0x2e, 0xca, 0xd1, 0xf0, 0x20, 0x8a, 0x77, 0xe3,
0x30, 0xd9, 0x06, 0x92, 0x61, 0xcf, 0x84, 0x14, 0xe6, 0x1c, 0xfd, 0x46, 0xd1, 0x81, 0x6b, 0x29,
0xf8, 0x28, 0x89, 0x13, 0x06, 0xad, 0x9b, 0x68, 0x76, 0x29, 0xec, 0x39, 0xf3, 0xb5, 0x0a, 0xd9,
0x99, 0xd2, 0x4c, 0x73, 0x8b, 0x2c, 0x10, 0x43, 0x61, 0x85, 0x1c, 0x34, 0x4a, 0x8e, 0xe9, 0xf1,
0x75, 0xa6, 0x4f, 0xc2, 0x9e, 0x1f, 0x6a, 0x15, 0x1e, 0x29, 0x4d, 0xb9, 0xc5, 0x93, 0x04, 0x4e,
0x38, 0xec, 0xdc, 0x2b, 0x90, 0x6b, 0x77, 0xa4, 0x30, 0xeb, 0x14, 0x9a, 0x77, 0x28, 0x64, 0xbd,
0x8f, 0x24, 0xbe, 0xc0, 0xd3, 0xdb, 0x24, 0x92, 0x67, 0x70, 0xc6, 0x45, 0x80, 0x3e, 0xb3, 0x8a,
0x19, 0x94, 0x7e, 0x63, 0xce, 0x09, 0x6c, 0x4d, 0x13, 0x88, 0x3f, 0xd5, 0x91, 0xc3, 0x9f, 0xaa,
0x2e, 0x4a, 0x9f, 0x74, 0xe0, 0xe1, 0x14, 0xfa, 0x0b, 0xa9, 0x2e, 0x25, 0xd3, 0xd8, 0x47, 0x31,
0x46, 0xbf, 0x31, 0xef, 0x28, 0x37, 0xae, 0x53, 0xbe, 0x8f, 0x50, 0x34, 0x01, 0x79, 0xbf, 0x0a,
0xb0, 0x74, 0xe5, 0xd9, 0x98, 0x50, 0x49, 0x83, 0xa4, 0x0b, 0xb5, 0xcc, 0x01, 0x71, 0x2c, 0x79,
0x1a, 0x5b, 0xf7, 0x59, 0x20, 0x46, 0x77, 0x66, 0xe8, 0xe2, 0xc4, 0x03, 0x09, 0xe9, 0x0b, 0xa8,
0x1a, 0xd4, 0x63, 0xd4, 0x2c, 0x10, 0xc6, 0x26, 0x1e, 0x58, 0xce, 0xf3, 0x75, 0xdd, 0xf1, 0x89,
0x70, 0x1e, 0x02, 0x33, 0xd9, 0xb5, 0xd7, 0x61, 0xed, 0x9a, 0x03, 0x62, 0xce, 0xd8, 0x02, 0x3f,
0x0a, 0xb0, 0x76, 0x7b, 0x29, 0xe4, 0x19, 0x2c, 0xe7, 0x93, 0x35, 0xf3, 0x31, 0xc0, 0x01, 0xb7,
0xa9, 0x2d, 0xea, 0x41, 0x96, 0xa4, 0x0f, 0x93, 0x33, 0xf2, 0x11, 0xd6, 0xf3, 0x96, 0x65, 0x1a,
0x43, 0xa5, 0x2d, 0x13, 0xd2, 0xa2, 0x1e, 0xf3, 0x20, 0x29, 0xbf, 0x9e, 0x2f, 0x3f, 0x1d, 0x62,
0x74, 0x35, 0xe7, 0x5e, 0xea, 0xf2, 0x8e, 0x93, 0x34, 0xef, 0x0d, 0x40, 0x76, 0x4b, 0xb2, 0x1d,
0x0d, 0xac, 0x68, 0x17, 0x0d, 0xac, 0x62, 0xb3, 0xba, 0x47, 0x6e, 0xb6, 0x83, 0xa6, 0x90, 0x77,
0xa5, 0x72, 0xb1, 0x56, 0xf2, 0x7e, 0x17, 0x60, 0x2e, 0x3e, 0x21, 0x1b, 0x00, 0x22, 0x64, 0xdc,
0xf7, 0x35, 0x9a, 0x78, 0xe4, 0x2d, 0xd0, 0x8a, 0x08, 0xdf, 0xc6, 0x81, 0xc8, 0xfd, 0x91, 0x76,
0x32, 0xf3, 0xdc, 0x3a, 0x32, 0xe3, 0x95, 0x4e, 0x5a, 0x75, 0x81, 0xd2, 0x99, 0xb1, 0x42, 0x6b,
0xb9, 0x46, 0x9c, 0x46, 0x71, 0xb2, 0x0f, 0xcb, 0x77, 0x98, 0xae, 0x4c, 0x97, 0xfc, 0x29, 0x06,
0x7b, 0x0e, 0x2b, 0x77, 0x19, 0xa9, 0x4c, 0xeb, 0xfe, 0x14, 0xd3, 0xb4, 0xe1, 0x73, 0x39, 0xfd,
0x47, 0xf4, 0xe6, 0xdc, 0x4f, 0x62, 0xff, 0x4f, 0x00, 0x00, 0x00, 0xff, 0xff, 0xa3, 0x36, 0x86,
0xa6, 0x4a, 0x06, 0x00, 0x00,
}
// Copyright 2016 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package grpc.lb.v1;
option go_package = "google.golang.org/grpc/grpclb/grpc_lb_v1/messages";
message Duration {
// Signed seconds of the span of time. Must be from -315,576,000,000
// to +315,576,000,000 inclusive.
int64 seconds = 1;
// Signed fractions of a second at nanosecond resolution of the span
// of time. Durations less than one second are represented with a 0
// `seconds` field and a positive or negative `nanos` field. For durations
// of one second or more, a non-zero value for the `nanos` field must be
// of the same sign as the `seconds` field. Must be from -999,999,999
// to +999,999,999 inclusive.
int32 nanos = 2;
}
message Timestamp {
// Represents seconds of UTC time since Unix epoch
// 1970-01-01T00:00:00Z. Must be from 0001-01-01T00:00:00Z to
// 9999-12-31T23:59:59Z inclusive.
int64 seconds = 1;
// Non-negative fractions of a second at nanosecond resolution. Negative
// second values with fractions must still have non-negative nanos values
// that count forward in time. Must be from 0 to 999,999,999
// inclusive.
int32 nanos = 2;
}
message LoadBalanceRequest {
oneof load_balance_request_type {
// This message should be sent on the first request to the load balancer.
InitialLoadBalanceRequest initial_request = 1;
// The client stats should be periodically reported to the load balancer
// based on the duration defined in the InitialLoadBalanceResponse.
ClientStats client_stats = 2;
}
}
message InitialLoadBalanceRequest {
// Name of the load balanced service (e.g., balancer.service.com). Its
// length should be less than 256 bytes.
string name = 1;
}
// Contains client level statistics that are useful to load balancing. Each
// count except the timestamp should be reset to zero after reporting the stats.
message ClientStats {
// The timestamp of generating the report.
Timestamp timestamp = 1;
// The total number of RPCs that started.
int64 num_calls_started = 2;
// The total number of RPCs that finished.
int64 num_calls_finished = 3;
// The total number of RPCs that were dropped by the client because of rate
// limiting.
int64 num_calls_finished_with_drop_for_rate_limiting = 4;
// The total number of RPCs that were dropped by the client because of load
// balancing.
int64 num_calls_finished_with_drop_for_load_balancing = 5;
// The total number of RPCs that failed to reach a server except dropped RPCs.
int64 num_calls_finished_with_client_failed_to_send = 6;
// The total number of RPCs that finished and are known to have been received
// by a server.
int64 num_calls_finished_known_received = 7;
}
message LoadBalanceResponse {
oneof load_balance_response_type {
// This message should be sent on the first response to the client.
InitialLoadBalanceResponse initial_response = 1;
// Contains the list of servers selected by the load balancer. The client
// should send requests to these servers in the specified order.
ServerList server_list = 2;
}
}
message InitialLoadBalanceResponse {
// This is an application layer redirect that indicates the client should use
// the specified server for load balancing. When this field is non-empty in
// the response, the client should open a separate connection to the
// load_balancer_delegate and call the BalanceLoad method. Its length should
// be less than 64 bytes.
string load_balancer_delegate = 1;
// This interval defines how often the client should send the client stats
// to the load balancer. Stats should only be reported when the duration is
// positive.
Duration client_stats_report_interval = 2;
}
message ServerList {
// Contains a list of servers selected by the load balancer. The list will
// be updated when server resolutions change or as needed to balance load
// across more servers. The client should consume the server list in order
// unless instructed otherwise via the client_config.
repeated Server servers = 1;
// Was google.protobuf.Duration expiration_interval.
reserved 3;
}
// Contains server information. When none of the [drop_for_*] fields are true,
// use the other fields. When drop_for_rate_limiting is true, ignore all other
// fields. Use drop_for_load_balancing only when it is true and
// drop_for_rate_limiting is false.
message Server {
// A resolved address for the server, serialized in network-byte-order. It may
// either be an IPv4 or IPv6 address.
bytes ip_address = 1;
// A resolved port number for the server.
int32 port = 2;
// An opaque but printable token given to the frontend for each pick. All
// frontend requests for that pick must include the token in its initial
// metadata. The token is used by the backend to verify the request and to
// allow the backend to report load to the gRPC LB system.
//
// Its length is variable but less than 50 bytes.
string load_balance_token = 3;
// Indicates whether this particular request should be dropped by the client
// for rate limiting.
bool drop_for_rate_limiting = 4;
// Indicates whether this particular request should be dropped by the client
// for load balancing.
bool drop_for_load_balancing = 5;
}
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package grpc
import (
"sync"
"sync/atomic"
"golang.org/x/net/context"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/codes"
lbpb "google.golang.org/grpc/grpclb/grpc_lb_v1/messages"
"google.golang.org/grpc/status"
)
type rpcStats struct {
NumCallsStarted int64
NumCallsFinished int64
NumCallsFinishedWithDropForRateLimiting int64
NumCallsFinishedWithDropForLoadBalancing int64
NumCallsFinishedWithClientFailedToSend int64
NumCallsFinishedKnownReceived int64
}
// toClientStats converts rpcStats to lbpb.ClientStats, and clears rpcStats.
func (s *rpcStats) toClientStats() *lbpb.ClientStats {
stats := &lbpb.ClientStats{
NumCallsStarted: atomic.SwapInt64(&s.NumCallsStarted, 0),
NumCallsFinished: atomic.SwapInt64(&s.NumCallsFinished, 0),
NumCallsFinishedWithDropForRateLimiting: atomic.SwapInt64(&s.NumCallsFinishedWithDropForRateLimiting, 0),
NumCallsFinishedWithDropForLoadBalancing: atomic.SwapInt64(&s.NumCallsFinishedWithDropForLoadBalancing, 0),
NumCallsFinishedWithClientFailedToSend: atomic.SwapInt64(&s.NumCallsFinishedWithClientFailedToSend, 0),
NumCallsFinishedKnownReceived: atomic.SwapInt64(&s.NumCallsFinishedKnownReceived, 0),
}
return stats
}
func (s *rpcStats) dropForRateLimiting() {
atomic.AddInt64(&s.NumCallsStarted, 1)
atomic.AddInt64(&s.NumCallsFinishedWithDropForRateLimiting, 1)
atomic.AddInt64(&s.NumCallsFinished, 1)
}
func (s *rpcStats) dropForLoadBalancing() {
atomic.AddInt64(&s.NumCallsStarted, 1)
atomic.AddInt64(&s.NumCallsFinishedWithDropForLoadBalancing, 1)
atomic.AddInt64(&s.NumCallsFinished, 1)
}
func (s *rpcStats) failedToSend() {
atomic.AddInt64(&s.NumCallsStarted, 1)
atomic.AddInt64(&s.NumCallsFinishedWithClientFailedToSend, 1)
atomic.AddInt64(&s.NumCallsFinished, 1)
}
func (s *rpcStats) knownReceived() {
atomic.AddInt64(&s.NumCallsStarted, 1)
atomic.AddInt64(&s.NumCallsFinishedKnownReceived, 1)
atomic.AddInt64(&s.NumCallsFinished, 1)
}
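The `rpcStats` counters above are updated with `atomic.AddInt64` from concurrent pick callbacks, and `toClientStats` drains them with `atomic.SwapInt64` so each event is reported exactly once without holding a lock. A small sketch of that swap-and-reset pattern (simplified types, not the actual `rpcStats` struct):

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// counters imitates the rpcStats pattern: increments are atomic, and a
// report atomically swaps each counter back to zero so no event is
// double-counted across reports.
type counters struct {
	started, finished int64
}

func (c *counters) record() {
	atomic.AddInt64(&c.started, 1)
	atomic.AddInt64(&c.finished, 1)
}

// snapshot returns the current values and resets them, as toClientStats does.
func (c *counters) snapshot() (started, finished int64) {
	return atomic.SwapInt64(&c.started, 0), atomic.SwapInt64(&c.finished, 0)
}

func main() {
	var c counters
	var wg sync.WaitGroup
	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func() { defer wg.Done(); c.record() }()
	}
	wg.Wait()
	s, f := c.snapshot()
	fmt.Println(s, f) // 100 100
	s, f = c.snapshot()
	fmt.Println(s, f) // 0 0 (drained by the first snapshot)
}
```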
type errPicker struct {
// Pick always returns this err.
err error
}
func (p *errPicker) Pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) {
return nil, nil, p.err
}
// rrPicker does roundrobin on subConns. It's typically used when there's no
// response from remote balancer, and grpclb falls back to the resolved
// backends.
//
// It's guaranteed that len(subConns) > 0.
type rrPicker struct {
mu sync.Mutex
subConns []balancer.SubConn // The subConns that were READY when taking the snapshot.
subConnsNext int
}
func (p *rrPicker) Pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) {
p.mu.Lock()
defer p.mu.Unlock()
sc := p.subConns[p.subConnsNext]
p.subConnsNext = (p.subConnsNext + 1) % len(p.subConns)
return sc, nil, nil
}
// lbPicker does two layers of picks:
//
// First layer: roundrobin on all servers in serverList, including drops and backends.
// - If it picks a drop, the RPC will fail as being dropped.
// - If it picks a backend, do a second layer pick to pick the real backend.
//
// Second layer: roundrobin on all READY backends.
//
// It's guaranteed that len(serverList) > 0.
type lbPicker struct {
mu sync.Mutex
serverList []*lbpb.Server
serverListNext int
subConns []balancer.SubConn // The subConns that were READY when taking the snapshot.
subConnsNext int
stats *rpcStats
}
func (p *lbPicker) Pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) {
p.mu.Lock()
defer p.mu.Unlock()
// Layer one roundrobin on serverList.
s := p.serverList[p.serverListNext]
p.serverListNext = (p.serverListNext + 1) % len(p.serverList)
// If it's a drop, return an error and fail the RPC.
if s.DropForRateLimiting {
p.stats.dropForRateLimiting()
return nil, nil, status.Errorf(codes.Unavailable, "request dropped by grpclb")
}
if s.DropForLoadBalancing {
p.stats.dropForLoadBalancing()
return nil, nil, status.Errorf(codes.Unavailable, "request dropped by grpclb")
}
// If not a drop but there's no ready subConns.
if len(p.subConns) <= 0 {
return nil, nil, balancer.ErrNoSubConnAvailable
}
// Return the next ready subConn in the list, also collect rpc stats.
sc := p.subConns[p.subConnsNext]
p.subConnsNext = (p.subConnsNext + 1) % len(p.subConns)
done := func(info balancer.DoneInfo) {
if !info.BytesSent {
p.stats.failedToSend()
} else if info.BytesReceived {
p.stats.knownReceived()
}
}
return sc, done, nil
}
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package grpc
import (
"fmt"
"net"
"reflect"
"time"
"golang.org/x/net/context"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/connectivity"
lbpb "google.golang.org/grpc/grpclb/grpc_lb_v1/messages"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/resolver"
)
// processServerList updates the balancer's internal state, creates/removes
// SubConns, and regenerates the picker using the received serverList.
func (lb *lbBalancer) processServerList(l *lbpb.ServerList) {
grpclog.Infof("lbBalancer: processing server list: %+v", l)
lb.mu.Lock()
defer lb.mu.Unlock()
// Set serverListReceived to true so fallback will not take effect if it has
// not hit timeout.
lb.serverListReceived = true
// If the new server list == old server list, do nothing.
if reflect.DeepEqual(lb.fullServerList, l.Servers) {
grpclog.Infof("lbBalancer: new serverlist same as the previous one, ignoring")
return
}
lb.fullServerList = l.Servers
var backendAddrs []resolver.Address
for _, s := range l.Servers {
if s.DropForLoadBalancing || s.DropForRateLimiting {
continue
}
md := metadata.Pairs(lbTokeyKey, s.LoadBalanceToken)
ip := net.IP(s.IpAddress)
ipStr := ip.String()
if ip.To4() == nil {
// Add square brackets to ipv6 addresses, otherwise net.Dial() and
// net.SplitHostPort() will return too many colons error.
ipStr = fmt.Sprintf("[%s]", ipStr)
}
addr := resolver.Address{
Addr: fmt.Sprintf("%s:%d", ipStr, s.Port),
Metadata: &md,
}
backendAddrs = append(backendAddrs, addr)
}
// Call refreshSubConns to create/remove SubConns.
backendsUpdated := lb.refreshSubConns(backendAddrs)
// If no backend was updated, no SubConn will be created/removed. But since
// the full serverList was different, there might be updates in drops or
// pick weights (different number of duplicates). We need to update the
// picker with the full list.
if !backendsUpdated {
lb.regeneratePicker()
lb.cc.UpdateBalancerState(lb.state, lb.picker)
}
}
// refreshSubConns creates/removes SubConns with backendAddrs. It returns a bool
// indicating whether backendAddrs differ from the cached backendAddrs
// (whether any SubConn was created/removed).
// Caller must hold lb.mu.
func (lb *lbBalancer) refreshSubConns(backendAddrs []resolver.Address) bool {
lb.backendAddrs = nil
var backendsUpdated bool
// addrsSet is the set converted from backendAddrs; it's used for quick
// lookup of an address.
addrsSet := make(map[resolver.Address]struct{})
// Create new SubConns.
for _, addr := range backendAddrs {
addrWithoutMD := addr
addrWithoutMD.Metadata = nil
addrsSet[addrWithoutMD] = struct{}{}
lb.backendAddrs = append(lb.backendAddrs, addrWithoutMD)
if _, ok := lb.subConns[addrWithoutMD]; !ok {
backendsUpdated = true
// Use addrWithMD to create the SubConn.
sc, err := lb.cc.NewSubConn([]resolver.Address{addr}, balancer.NewSubConnOptions{})
if err != nil {
grpclog.Warningf("roundrobinBalancer: failed to create new SubConn: %v", err)
continue
}
lb.subConns[addrWithoutMD] = sc // Use the addr without MD as key for the map.
lb.scStates[sc] = connectivity.Idle
sc.Connect()
}
}
for a, sc := range lb.subConns {
// a was removed by resolver.
if _, ok := addrsSet[a]; !ok {
backendsUpdated = true
lb.cc.RemoveSubConn(sc)
delete(lb.subConns, a)
// Keep the state of this sc in b.scStates until sc's state becomes Shutdown.
// The entry will be deleted in HandleSubConnStateChange.
}
}
return backendsUpdated
}
func (lb *lbBalancer) readServerList(s *balanceLoadClientStream) error {
for {
reply, err := s.Recv()
if err != nil {
return fmt.Errorf("grpclb: failed to recv server list: %v", err)
}
if serverList := reply.GetServerList(); serverList != nil {
lb.processServerList(serverList)
}
}
}
func (lb *lbBalancer) sendLoadReport(s *balanceLoadClientStream, interval time.Duration) {
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
case <-s.Context().Done():
return
}
stats := lb.clientStats.toClientStats()
t := time.Now()
stats.Timestamp = &lbpb.Timestamp{
Seconds: t.Unix(),
Nanos: int32(t.Nanosecond()),
}
if err := s.Send(&lbpb.LoadBalanceRequest{
LoadBalanceRequestType: &lbpb.LoadBalanceRequest_ClientStats{
ClientStats: stats,
},
}); err != nil {
return
}
}
}
func (lb *lbBalancer) callRemoteBalancer() error {
lbClient := &loadBalancerClient{cc: lb.ccRemoteLB}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
stream, err := lbClient.BalanceLoad(ctx, FailFast(false))
if err != nil {
return fmt.Errorf("grpclb: failed to perform RPC to the remote balancer %v", err)
}
// grpclb handshake on the stream.
initReq := &lbpb.LoadBalanceRequest{
LoadBalanceRequestType: &lbpb.LoadBalanceRequest_InitialRequest{
InitialRequest: &lbpb.InitialLoadBalanceRequest{
Name: lb.target,
},
},
}
if err := stream.Send(initReq); err != nil {
return fmt.Errorf("grpclb: failed to send init request: %v", err)
}
reply, err := stream.Recv()
if err != nil {
return fmt.Errorf("grpclb: failed to recv init response: %v", err)
}
initResp := reply.GetInitialResponse()
if initResp == nil {
return fmt.Errorf("grpclb: reply from remote balancer did not include initial response")
}
if initResp.LoadBalancerDelegate != "" {
return fmt.Errorf("grpclb: Delegation is not supported")
}
go func() {
if d := convertDuration(initResp.ClientStatsReportInterval); d > 0 {
lb.sendLoadReport(stream, d)
}
}()
return lb.readServerList(stream)
}
func (lb *lbBalancer) watchRemoteBalancer() {
for {
err := lb.callRemoteBalancer()
select {
case <-lb.doneCh:
return
default:
if err != nil {
grpclog.Error(err)
}
}
}
}
func (lb *lbBalancer) dialRemoteLB(remoteLBName string) {
var dopts []DialOption
if creds := lb.opt.DialCreds; creds != nil {
if err := creds.OverrideServerName(remoteLBName); err == nil {
dopts = append(dopts, WithTransportCredentials(creds))
} else {
grpclog.Warningf("grpclb: failed to override the server name in the credentials: %v, using Insecure", err)
dopts = append(dopts, WithInsecure())
}
} else {
dopts = append(dopts, WithInsecure())
}
if lb.opt.Dialer != nil {
// WithDialer takes a different type of function, so we instead use a
// special DialOption here.
dopts = append(dopts, withContextDialer(lb.opt.Dialer))
}
// Explicitly set pickfirst as the balancer.
dopts = append(dopts, WithBalancerName(PickFirstBalancerName))
dopts = append(dopts, withResolverBuilder(lb.manualResolver))
// Dial using manualResolver.Scheme, which is a random scheme generated
// when grpclb is initialized. The target name is not important.
cc, err := Dial("grpclb:///grpclb.server", dopts...)
if err != nil {
grpclog.Fatalf("failed to dial: %v", err)
}
lb.ccRemoteLB = cc
go lb.watchRemoteBalancer()
}
/*
*
* Copyright 2016 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package grpc
import (
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/resolver"
)
// The parent ClientConn should re-resolve when grpclb loses connection to the
// remote balancer. When the ClientConn inside grpclb gets a TransientFailure,
// it calls lbManualResolver.ResolveNow(), which calls parent ClientConn's
// ResolveNow, and eventually results in re-resolve happening in parent
// ClientConn's resolver (DNS for example).
//
// parent
// ClientConn
// +-----------------------------------------------------------------+
// | parent +---------------------------------+ |
// | DNS ClientConn | grpclb | |
// | resolver balancerWrapper | | |
// | + + | grpclb grpclb | |
// | | | | ManualResolver ClientConn | |
// | | | | + + | |
// | | | | | | Transient | |
// | | | | | | Failure | |
// | | | | | <--------- | | |
// | | | <--------------- | ResolveNow | | |
// | | <--------- | ResolveNow | | | | |
// | | ResolveNow | | | | | |
// | | | | | | | |
// | + + | + + | |
// | +---------------------------------+ |
// +-----------------------------------------------------------------+
// lbManualResolver is used by the ClientConn inside grpclb. It's a manual
// resolver with a special ResolveNow() function.
//
// When ResolveNow() is called, it calls ResolveNow() on the parent ClientConn,
// so when the grpclb client loses contact with remote balancers, the parent
// ClientConn's resolver will re-resolve.
type lbManualResolver struct {
scheme string
ccr resolver.ClientConn
ccb balancer.ClientConn
}
func (r *lbManualResolver) Build(_ resolver.Target, cc resolver.ClientConn, _ resolver.BuildOption) (resolver.Resolver, error) {
r.ccr = cc
return r, nil
}
func (r *lbManualResolver) Scheme() string {
return r.scheme
}
// ResolveNow calls resolveNow on the parent ClientConn.
func (r *lbManualResolver) ResolveNow(o resolver.ResolveNowOption) {
r.ccb.ResolveNow(o)
}
// Close is a noop for Resolver.
func (*lbManualResolver) Close() {}
// NewAddress calls cc.NewAddress.
func (r *lbManualResolver) NewAddress(addrs []resolver.Address) {
r.ccr.NewAddress(addrs)
}
// NewServiceConfig calls cc.NewServiceConfig.
func (r *lbManualResolver) NewServiceConfig(sc string) {
r.ccr.NewServiceConfig(sc)
}
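The forwarding trick above is small but easy to miss: the manual resolver's `ResolveNow` resolves nothing itself, it only pokes the parent connection so the parent's real resolver (DNS, for example) re-resolves. A minimal self-contained sketch, using local stand-in types rather than the real `resolver`/`balancer` interfaces:

```go
package main

import "fmt"

// resolveNower abstracts the one method the forwarding resolver needs from
// its parent; in real gRPC this role is played by balancer.ClientConn.
type resolveNower interface{ ResolveNow() }

// parentConn is a hypothetical stand-in for the parent ClientConn; it just
// counts how often it was asked to re-resolve.
type parentConn struct{ resolves int }

func (p *parentConn) ResolveNow() { p.resolves++ }

// forwardingResolver mirrors lbManualResolver: its ResolveNow does not
// resolve anything itself, it delegates to the parent so the parent's real
// resolver re-resolves.
type forwardingResolver struct{ parent resolveNower }

func (r *forwardingResolver) ResolveNow() { r.parent.ResolveNow() }

func main() {
	p := &parentConn{}
	r := &forwardingResolver{parent: p}
	r.ResolveNow() // e.g. grpclb lost the remote balancer
	r.ResolveNow()
	fmt.Println(p.resolves) // 2
}
```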
...@@ -105,18 +105,21 @@ func Fatalln(args ...interface{}) {
}
// Print prints to the logger. Arguments are handled in the manner of fmt.Print.
//
// Deprecated: use Info.
func Print(args ...interface{}) {
logger.Info(args...)
}
// Printf prints to the logger. Arguments are handled in the manner of fmt.Printf.
//
// Deprecated: use Infof.
func Printf(format string, args ...interface{}) {
logger.Infof(format, args...)
}
// Println prints to the logger. Arguments are handled in the manner of fmt.Println.
//
// Deprecated: use Infoln.
func Println(args ...interface{}) {
logger.Infoln(args...)
...
...@@ -19,6 +19,7 @@
package grpclog
// Logger mimics golang's standard Logger as an interface.
//
// Deprecated: use LoggerV2.
type Logger interface {
Fatal(args ...interface{})
...@@ -31,6 +32,7 @@ type Logger interface {
// SetLogger sets the logger that is used in grpc. Call only from
// init() functions.
//
// Deprecated: use SetLoggerV2.
func SetLogger(l Logger) {
logger = &loggerWrapper{Logger: l}
...
...@@ -48,7 +48,9 @@ type UnaryServerInfo struct {
}
// UnaryHandler defines the handler invoked by UnaryServerInterceptor to complete the normal
// execution of a unary RPC. If a UnaryHandler returns an error, it should be produced by the
// status package, or else gRPC will use codes.Unknown as the status code and err.Error() as
// the status message of the RPC.
type UnaryHandler func(ctx context.Context, req interface{}) (interface{}, error)
// UnaryServerInterceptor provides a hook to intercept the execution of a unary RPC on the server. info
...
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package backoff implements the backoff strategy for gRPC.
//
// This is kept in internal until the gRPC project decides whether or not to
// allow alternative backoff strategies.
package backoff
import (
"time"
"google.golang.org/grpc/internal/grpcrand"
)
// Strategy defines the methodology for backing off after a grpc connection
// failure.
//
type Strategy interface {
// Backoff returns the amount of time to wait before the next retry given
// the number of consecutive failures.
Backoff(retries int) time.Duration
}
const (
// baseDelay is the amount of time to wait before retrying after the first
// failure.
baseDelay = 1.0 * time.Second
// factor is applied to the backoff after each retry.
factor = 1.6
// jitter provides a range to randomize backoff delays.
jitter = 0.2
)
// Exponential implements the exponential backoff algorithm as defined in
// https://github.com/grpc/grpc/blob/master/doc/connection-backoff.md.
type Exponential struct {
// MaxDelay is the upper bound of backoff delay.
MaxDelay time.Duration
}
// Backoff returns the amount of time to wait before the next retry given the
// number of retries.
func (bc Exponential) Backoff(retries int) time.Duration {
if retries == 0 {
return baseDelay
}
backoff, max := float64(baseDelay), float64(bc.MaxDelay)
for backoff < max && retries > 0 {
backoff *= factor
retries--
}
if backoff > max {
backoff = max
}
// Randomize backoff delays so that if a cluster of requests start at
// the same time, they won't operate in lockstep.
backoff *= 1 + jitter*(grpcrand.Float64()*2-1)
if backoff < 0 {
return 0
}
return time.Duration(backoff)
}
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package channelz defines APIs for enabling the channelz service, entry
// registration/deletion, and accessing channelz data. It also defines channelz
// metric struct formats.
//
// All APIs in this package are experimental.
package channelz
import (
"sort"
"sync"
"sync/atomic"
"google.golang.org/grpc/grpclog"
)
var (
db dbWrapper
idGen idGenerator
// EntryPerPage defines the number of channelz entries to be shown on a web page.
EntryPerPage = 50
curState int32
)
// TurnOn turns on channelz data collection.
func TurnOn() {
if !IsOn() {
NewChannelzStorage()
atomic.StoreInt32(&curState, 1)
}
}
// IsOn returns whether channelz data collection is on.
func IsOn() bool {
return atomic.CompareAndSwapInt32(&curState, 1, 1)
}
// dbWrapper wraps a reference to the internal channelz data storage, and
// provides synchronized functions to set and get that reference.
type dbWrapper struct {
mu sync.RWMutex
DB *channelMap
}
func (d *dbWrapper) set(db *channelMap) {
d.mu.Lock()
d.DB = db
d.mu.Unlock()
}
func (d *dbWrapper) get() *channelMap {
d.mu.RLock()
defer d.mu.RUnlock()
return d.DB
}
// NewChannelzStorage initializes channelz data storage and id generator.
//
// Note: This function is exported for testing purposes only. Users should not
// call it in most cases.
func NewChannelzStorage() {
db.set(&channelMap{
topLevelChannels: make(map[int64]struct{}),
channels: make(map[int64]*channel),
listenSockets: make(map[int64]*listenSocket),
normalSockets: make(map[int64]*normalSocket),
servers: make(map[int64]*server),
subChannels: make(map[int64]*subChannel),
})
idGen.reset()
}
// GetTopChannels returns a slice of top channels' ChannelMetric, along with a
// boolean indicating whether there are more top channels to be queried.
//
// The arg id specifies that only top channels with an id at or above it will
// be included in the result. The returned slice is up to EntryPerPage in
// length, and is sorted in ascending id order.
func GetTopChannels(id int64) ([]*ChannelMetric, bool) {
return db.get().GetTopChannels(id)
}
// GetServers returns a slice of servers' ServerMetric, along with a boolean
// indicating whether there are more servers to be queried.
//
// The arg id specifies that only servers with an id at or above it will be
// included in the result. The returned slice is up to EntryPerPage in length,
// and is sorted in ascending id order.
func GetServers(id int64) ([]*ServerMetric, bool) {
return db.get().GetServers(id)
}
// GetServerSockets returns a slice of the server's (identified by id) normal
// sockets' SocketMetric, along with a boolean indicating whether there are
// more sockets to be queried.
//
// The arg startID specifies that only sockets with an id at or above it will
// be included in the result. The returned slice is up to EntryPerPage in
// length, and is sorted in ascending id order.
func GetServerSockets(id int64, startID int64) ([]*SocketMetric, bool) {
return db.get().GetServerSockets(id, startID)
}
// GetChannel returns the ChannelMetric for the channel (identified by id).
func GetChannel(id int64) *ChannelMetric {
return db.get().GetChannel(id)
}
// GetSubChannel returns the SubChannelMetric for the subchannel (identified by id).
func GetSubChannel(id int64) *SubChannelMetric {
return db.get().GetSubChannel(id)
}
// GetSocket returns the SocketInternalMetric for the socket (identified by id).
func GetSocket(id int64) *SocketMetric {
return db.get().GetSocket(id)
}
// RegisterChannel registers the given channel c in the channelz database with
// ref as its reference name, and adds it to the child list of its parent
// (identified by pid). pid = 0 means no parent. It returns the unique channelz
// tracking id assigned to this channel.
func RegisterChannel(c Channel, pid int64, ref string) int64 {
id := idGen.genID()
cn := &channel{
refName: ref,
c: c,
subChans: make(map[int64]string),
nestedChans: make(map[int64]string),
id: id,
pid: pid,
}
if pid == 0 {
db.get().addChannel(id, cn, true, pid, ref)
} else {
db.get().addChannel(id, cn, false, pid, ref)
}
return id
}
// RegisterSubChannel registers the given channel c in the channelz database
// with ref as its reference name, and adds it to the child list of its parent
// (identified by pid). It returns the unique channelz tracking id assigned to
// this subchannel.
func RegisterSubChannel(c Channel, pid int64, ref string) int64 {
if pid == 0 {
grpclog.Error("a SubChannel's parent id cannot be 0")
return 0
}
id := idGen.genID()
sc := &subChannel{
refName: ref,
c: c,
sockets: make(map[int64]string),
id: id,
pid: pid,
}
db.get().addSubChannel(id, sc, pid, ref)
return id
}
// RegisterServer registers the given server s in the channelz database. It
// returns the unique channelz tracking id assigned to this server.
func RegisterServer(s Server, ref string) int64 {
id := idGen.genID()
svr := &server{
refName: ref,
s: s,
sockets: make(map[int64]string),
listenSockets: make(map[int64]string),
id: id,
}
db.get().addServer(id, svr)
return id
}
// RegisterListenSocket registers the given listen socket s in the channelz
// database with ref as its reference name, and adds it to the child list of
// its parent (identified by pid). It returns the unique channelz tracking id
// assigned to this listen socket.
func RegisterListenSocket(s Socket, pid int64, ref string) int64 {
if pid == 0 {
grpclog.Error("a ListenSocket's parent id cannot be 0")
return 0
}
id := idGen.genID()
ls := &listenSocket{refName: ref, s: s, id: id, pid: pid}
db.get().addListenSocket(id, ls, pid, ref)
return id
}
// RegisterNormalSocket registers the given normal socket s in the channelz
// database with ref as its reference name, and adds it to the child list of
// its parent (identified by pid). It returns the unique channelz tracking id
// assigned to this normal socket.
func RegisterNormalSocket(s Socket, pid int64, ref string) int64 {
if pid == 0 {
grpclog.Error("a NormalSocket's parent id cannot be 0")
return 0
}
id := idGen.genID()
ns := &normalSocket{refName: ref, s: s, id: id, pid: pid}
db.get().addNormalSocket(id, ns, pid, ref)
return id
}
// RemoveEntry removes the entry with the given unique channelz tracking id
// from the channelz database.
func RemoveEntry(id int64) {
db.get().removeEntry(id)
}
// channelMap is the storage data structure for channelz.
// Methods of channelMap can be divided into two categories with respect to locking.
// 1. Methods that acquire the global lock.
// 2. Methods that can only be called when the global lock is held.
// A method of the second type must always be called from inside a method of the
// first type.
type channelMap struct {
mu sync.RWMutex
topLevelChannels map[int64]struct{}
servers map[int64]*server
channels map[int64]*channel
subChannels map[int64]*subChannel
listenSockets map[int64]*listenSocket
normalSockets map[int64]*normalSocket
}
func (c *channelMap) addServer(id int64, s *server) {
c.mu.Lock()
s.cm = c
c.servers[id] = s
c.mu.Unlock()
}
func (c *channelMap) addChannel(id int64, cn *channel, isTopChannel bool, pid int64, ref string) {
c.mu.Lock()
cn.cm = c
c.channels[id] = cn
if isTopChannel {
c.topLevelChannels[id] = struct{}{}
} else {
c.findEntry(pid).addChild(id, cn)
}
c.mu.Unlock()
}
func (c *channelMap) addSubChannel(id int64, sc *subChannel, pid int64, ref string) {
c.mu.Lock()
sc.cm = c
c.subChannels[id] = sc
c.findEntry(pid).addChild(id, sc)
c.mu.Unlock()
}
func (c *channelMap) addListenSocket(id int64, ls *listenSocket, pid int64, ref string) {
c.mu.Lock()
ls.cm = c
c.listenSockets[id] = ls
c.findEntry(pid).addChild(id, ls)
c.mu.Unlock()
}
func (c *channelMap) addNormalSocket(id int64, ns *normalSocket, pid int64, ref string) {
c.mu.Lock()
ns.cm = c
c.normalSockets[id] = ns
c.findEntry(pid).addChild(id, ns)
c.mu.Unlock()
}
// removeEntry triggers the removal of an entry, which may not actually delete
// the entry if it must wait for the deletion of its children, and which may
// lead to a chain of entry deletions. For example, deleting the last socket of
// a gracefully shutting down server also deletes the server.
func (c *channelMap) removeEntry(id int64) {
c.mu.Lock()
c.findEntry(id).triggerDelete()
c.mu.Unlock()
}
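The deferred-deletion scheme described above (triggerDelete / deleteSelfIfReady on the entry types below) boils down to: an entry leaves the database only after both "close was requested" and "no children remain" hold, whichever happens second. A minimal single-node sketch of that state machine, without the locking and registry (`node` and its methods are illustrative names):

```go
package main

import "fmt"

// node sketches channelz's two-phase deletion: triggerDelete marks the
// node closed, but it is only actually deleted once its child count
// reaches zero.
type node struct {
	closed   bool
	children int
	deleted  bool
}

// triggerDelete requests deletion; it takes effect only when no children remain.
func (n *node) triggerDelete() { n.closed = true; n.deleteIfReady() }

// deleteChild removes one child and re-checks whether deletion can proceed.
func (n *node) deleteChild() { n.children--; n.deleteIfReady() }

// deleteIfReady deletes the node iff close was requested and it is childless.
func (n *node) deleteIfReady() {
	if n.closed && n.children == 0 {
		n.deleted = true
	}
}

func main() {
	n := &node{children: 2}
	n.triggerDelete()
	fmt.Println(n.deleted) // false: two children still registered
	n.deleteChild()
	n.deleteChild()
	fmt.Println(n.deleted) // true: last child removal completed the delete
}
```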
// c.mu must be held by the caller.
func (c *channelMap) findEntry(id int64) entry {
var v entry
var ok bool
if v, ok = c.channels[id]; ok {
return v
}
if v, ok = c.subChannels[id]; ok {
return v
}
if v, ok = c.servers[id]; ok {
return v
}
if v, ok = c.listenSockets[id]; ok {
return v
}
if v, ok = c.normalSockets[id]; ok {
return v
}
return &dummyEntry{idNotFound: id}
}
// deleteEntry simply deletes an entry from the channelMap. c.mu must be held
// by the caller. Before calling this method, the caller must check that the
// entry is ready to be deleted, i.e. removeEntry() has been called on it and
// no children remain.
// Conditionals are ordered by the expected frequency of deletion of each
// entity type, in order to optimize performance.
func (c *channelMap) deleteEntry(id int64) {
var ok bool
if _, ok = c.normalSockets[id]; ok {
delete(c.normalSockets, id)
return
}
if _, ok = c.subChannels[id]; ok {
delete(c.subChannels, id)
return
}
if _, ok = c.channels[id]; ok {
delete(c.channels, id)
delete(c.topLevelChannels, id)
return
}
if _, ok = c.listenSockets[id]; ok {
delete(c.listenSockets, id)
return
}
if _, ok = c.servers[id]; ok {
delete(c.servers, id)
return
}
}
type int64Slice []int64
func (s int64Slice) Len() int { return len(s) }
func (s int64Slice) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
func (s int64Slice) Less(i, j int) bool { return s[i] < s[j] }
func copyMap(m map[int64]string) map[int64]string {
n := make(map[int64]string)
for k, v := range m {
n[k] = v
}
return n
}
func min(a, b int) int {
if a < b {
return a
}
return b
}
func (c *channelMap) GetTopChannels(id int64) ([]*ChannelMetric, bool) {
c.mu.RLock()
l := len(c.topLevelChannels)
ids := make([]int64, 0, l)
cns := make([]*channel, 0, min(l, EntryPerPage))
for k := range c.topLevelChannels {
ids = append(ids, k)
}
sort.Sort(int64Slice(ids))
idx := sort.Search(len(ids), func(i int) bool { return ids[i] >= id })
count := 0
var end bool
var t []*ChannelMetric
for i, v := range ids[idx:] {
if count == EntryPerPage {
break
}
if cn, ok := c.channels[v]; ok {
cns = append(cns, cn)
t = append(t, &ChannelMetric{
NestedChans: copyMap(cn.nestedChans),
SubChans: copyMap(cn.subChans),
})
count++
}
if i == len(ids[idx:])-1 {
end = true
break
}
}
c.mu.RUnlock()
if count == 0 {
end = true
}
for i, cn := range cns {
t[i].ChannelData = cn.c.ChannelzMetric()
t[i].ID = cn.id
t[i].RefName = cn.refName
}
return t, end
}
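GetTopChannels, GetServers, and GetServerSockets all paginate the same way: copy the map's keys into a slice, sort it, binary-search for the first id at or above the requested start, and return at most EntryPerPage entries plus an "end" flag. A simplified self-contained sketch of that pattern (`paginate` is an illustrative helper, not part of the channelz API, and it omits the re-lookup and metric-copying the real methods do):

```go
package main

import (
	"fmt"
	"sort"
)

// paginate returns up to perPage ids from m that are >= startID, in
// ascending order, plus a flag reporting whether the caller has now seen
// every qualifying id.
func paginate(m map[int64]string, startID int64, perPage int) ([]int64, bool) {
	ids := make([]int64, 0, len(m))
	for k := range m {
		ids = append(ids, k)
	}
	sort.Slice(ids, func(i, j int) bool { return ids[i] < ids[j] })
	// Binary search for the first id at or above startID.
	idx := sort.Search(len(ids), func(i int) bool { return ids[i] >= startID })
	page := ids[idx:]
	end := len(page) <= perPage
	if !end {
		page = page[:perPage]
	}
	return page, end
}

func main() {
	m := map[int64]string{1: "a", 2: "b", 3: "c", 4: "d", 5: "e"}
	fmt.Println(paginate(m, 0, 2)) // [1 2] false
	fmt.Println(paginate(m, 4, 2)) // [4 5] true
}
```

A caller pages through the data by passing the last seen id plus one as the next startID, just as a channelz client does with GetTopChannels.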
func (c *channelMap) GetServers(id int64) ([]*ServerMetric, bool) {
c.mu.RLock()
l := len(c.servers)
ids := make([]int64, 0, l)
ss := make([]*server, 0, min(l, EntryPerPage))
for k := range c.servers {
ids = append(ids, k)
}
sort.Sort(int64Slice(ids))
idx := sort.Search(len(ids), func(i int) bool { return ids[i] >= id })
count := 0
var end bool
var s []*ServerMetric
for i, v := range ids[idx:] {
if count == EntryPerPage {
break
}
if svr, ok := c.servers[v]; ok {
ss = append(ss, svr)
s = append(s, &ServerMetric{
ListenSockets: copyMap(svr.listenSockets),
})
count++
}
if i == len(ids[idx:])-1 {
end = true
break
}
}
c.mu.RUnlock()
if count == 0 {
end = true
}
for i, svr := range ss {
s[i].ServerData = svr.s.ChannelzMetric()
s[i].ID = svr.id
s[i].RefName = svr.refName
}
return s, end
}
func (c *channelMap) GetServerSockets(id int64, startID int64) ([]*SocketMetric, bool) {
var svr *server
var ok bool
c.mu.RLock()
if svr, ok = c.servers[id]; !ok {
// server with id doesn't exist.
c.mu.RUnlock()
return nil, true
}
svrskts := svr.sockets
l := len(svrskts)
ids := make([]int64, 0, l)
sks := make([]*normalSocket, 0, min(l, EntryPerPage))
for k := range svrskts {
ids = append(ids, k)
}
sort.Sort(int64Slice(ids))
idx := sort.Search(len(ids), func(i int) bool { return ids[i] >= startID })
count := 0
var end bool
for i, v := range ids[idx:] {
if count == EntryPerPage {
break
}
if ns, ok := c.normalSockets[v]; ok {
sks = append(sks, ns)
count++
}
if i == len(ids[idx:])-1 {
end = true
break
}
}
c.mu.RUnlock()
if count == 0 {
end = true
}
var s []*SocketMetric
for _, ns := range sks {
sm := &SocketMetric{}
sm.SocketData = ns.s.ChannelzMetric()
sm.ID = ns.id
sm.RefName = ns.refName
s = append(s, sm)
}
return s, end
}
func (c *channelMap) GetChannel(id int64) *ChannelMetric {
cm := &ChannelMetric{}
var cn *channel
var ok bool
c.mu.RLock()
if cn, ok = c.channels[id]; !ok {
// channel with id doesn't exist.
c.mu.RUnlock()
return nil
}
cm.NestedChans = copyMap(cn.nestedChans)
cm.SubChans = copyMap(cn.subChans)
c.mu.RUnlock()
cm.ChannelData = cn.c.ChannelzMetric()
cm.ID = cn.id
cm.RefName = cn.refName
return cm
}
func (c *channelMap) GetSubChannel(id int64) *SubChannelMetric {
cm := &SubChannelMetric{}
var sc *subChannel
var ok bool
c.mu.RLock()
if sc, ok = c.subChannels[id]; !ok {
// subchannel with id doesn't exist.
c.mu.RUnlock()
return nil
}
cm.Sockets = copyMap(sc.sockets)
c.mu.RUnlock()
cm.ChannelData = sc.c.ChannelzMetric()
cm.ID = sc.id
cm.RefName = sc.refName
return cm
}
func (c *channelMap) GetSocket(id int64) *SocketMetric {
sm := &SocketMetric{}
c.mu.RLock()
if ls, ok := c.listenSockets[id]; ok {
c.mu.RUnlock()
sm.SocketData = ls.s.ChannelzMetric()
sm.ID = ls.id
sm.RefName = ls.refName
return sm
}
if ns, ok := c.normalSockets[id]; ok {
c.mu.RUnlock()
sm.SocketData = ns.s.ChannelzMetric()
sm.ID = ns.id
sm.RefName = ns.refName
return sm
}
c.mu.RUnlock()
return nil
}
type idGenerator struct {
id int64
}
func (i *idGenerator) reset() {
atomic.StoreInt64(&i.id, 0)
}
func (i *idGenerator) genID() int64 {
return atomic.AddInt64(&i.id, 1)
}
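idGenerator works because `atomic.AddInt64` both increments and returns the new value in one indivisible step, so concurrent callers can never observe the same id and no mutex is needed. A small sketch of the same pattern (`idGen`/`next` are illustrative names):

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// idGen hands out unique, monotonically increasing ids; atomic.AddInt64
// makes it safe to call next from many goroutines without a mutex.
type idGen struct{ id int64 }

func (g *idGen) next() int64 { return atomic.AddInt64(&g.id, 1) }

func main() {
	var g idGen
	var wg sync.WaitGroup
	ids := make([]int64, 100)
	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			ids[i] = g.next()
		}(i)
	}
	wg.Wait()
	// Count distinct ids: atomicity guarantees no duplicates.
	seen := make(map[int64]bool)
	for _, v := range ids {
		seen[v] = true
	}
	fmt.Println(len(seen)) // 100
}
```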
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package channelz
import (
"net"
"time"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/grpclog"
)
// entry represents a node in the channelz database.
type entry interface {
// addChild adds a child e, whose channelz id is id, to the child list.
addChild(id int64, e entry)
// deleteChild deletes the child with channelz id id from the child list.
deleteChild(id int64)
// triggerDelete tries to delete self from the channelz database. However, if
// the child list is not empty, then deletion from the database is on hold
// until the last child is deleted from the database.
triggerDelete()
// deleteSelfIfReady checks whether triggerDelete() has been called before and
// whether the child list is now empty. If both conditions are met, it deletes
// itself from the database.
deleteSelfIfReady()
}
// dummyEntry is a fake entry to handle entry not found case.
type dummyEntry struct {
idNotFound int64
}
func (d *dummyEntry) addChild(id int64, e entry) {
// Note: It is possible for a normal program to reach here under a race
// condition. For example, there can be a race between the ClientConn.Close()
// info being propagated to addrConn and http2Client: ClientConn.Close()
// cancels the context, which causes http2Client to error out. The error is
// then caught by the transport monitor before addrConn.tearDown() is called
// inside ClientConn.Close(), so the addrConn creates a new transport. When
// the new transport is registered in channelz, its parent addrConn may
// already have been torn down and deleted from channelz tracking, and thus
// we reach the code here.
grpclog.Infof("attempt to add child of type %T with id %d to a parent (id=%d) that doesn't currently exist", e, id, d.idNotFound)
}
func (d *dummyEntry) deleteChild(id int64) {
// It is possible for a normal program to reach here under race condition.
// Refer to the example described in addChild().
grpclog.Infof("attempt to delete child with id %d from a parent (id=%d) that doesn't currently exist", id, d.idNotFound)
}
func (d *dummyEntry) triggerDelete() {
grpclog.Warningf("attempt to delete an entry (id=%d) that doesn't currently exist", d.idNotFound)
}
func (*dummyEntry) deleteSelfIfReady() {
// code should not reach here. deleteSelfIfReady is always called on an existing entry.
}
// ChannelMetric defines the info channelz provides for a specific Channel, which
// includes ChannelInternalMetric and channelz-specific data, such as channelz id,
// child list, etc.
type ChannelMetric struct {
// ID is the channelz id of this channel.
ID int64
// RefName is the human readable reference string of this channel.
RefName string
// ChannelData contains channel internal metric reported by the channel through
// ChannelzMetric().
ChannelData *ChannelInternalMetric
// NestedChans tracks the nested channel type children of this channel in the format of
// a map from nested channel channelz id to corresponding reference string.
NestedChans map[int64]string
// SubChans tracks the subchannel type children of this channel in the format of a
// map from subchannel channelz id to corresponding reference string.
SubChans map[int64]string
// Sockets tracks the socket type children of this channel in the format of a map
// from socket channelz id to corresponding reference string.
// Note that the current grpc implementation doesn't allow a channel to have
// sockets directly; therefore, this field is unused.
Sockets map[int64]string
}
// SubChannelMetric defines the info channelz provides for a specific SubChannel,
// which includes ChannelInternalMetric and channelz-specific data, such as
// channelz id, child list, etc.
type SubChannelMetric struct {
// ID is the channelz id of this subchannel.
ID int64
// RefName is the human readable reference string of this subchannel.
RefName string
// ChannelData contains subchannel internal metric reported by the subchannel
// through ChannelzMetric().
ChannelData *ChannelInternalMetric
// NestedChans tracks the nested channel type children of this subchannel in the format of
// a map from nested channel channelz id to corresponding reference string.
// Note that the current grpc implementation doesn't allow a subchannel to have
// nested channels as children; therefore, this field is unused.
NestedChans map[int64]string
// SubChans tracks the subchannel type children of this subchannel in the format of a
// map from subchannel channelz id to corresponding reference string.
// Note that the current grpc implementation doesn't allow a subchannel to have
// subchannels as children; therefore, this field is unused.
SubChans map[int64]string
// Sockets tracks the socket type children of this subchannel in the format of a map
// from socket channelz id to corresponding reference string.
Sockets map[int64]string
}
// ChannelInternalMetric defines the struct that the implementor of Channel interface
// should return from ChannelzMetric().
type ChannelInternalMetric struct {
// current connectivity state of the channel.
State connectivity.State
// The target this channel originally tried to connect to. May be absent.
Target string
// The number of calls started on the channel.
CallsStarted int64
// The number of calls that have completed with an OK status.
CallsSucceeded int64
// The number of calls that have completed with a non-OK status.
CallsFailed int64
// The last time a call was started on the channel.
LastCallStartedTimestamp time.Time
//TODO: trace
}
// Channel is the interface that should be satisfied in order to be tracked by
// channelz as Channel or SubChannel.
type Channel interface {
ChannelzMetric() *ChannelInternalMetric
}
type channel struct {
refName string
c Channel
closeCalled bool
nestedChans map[int64]string
subChans map[int64]string
id int64
pid int64
cm *channelMap
}
func (c *channel) addChild(id int64, e entry) {
switch v := e.(type) {
case *subChannel:
c.subChans[id] = v.refName
case *channel:
c.nestedChans[id] = v.refName
default:
grpclog.Errorf("cannot add a child (id = %d) of type %T to a channel", id, e)
}
}
func (c *channel) deleteChild(id int64) {
delete(c.subChans, id)
delete(c.nestedChans, id)
c.deleteSelfIfReady()
}
func (c *channel) triggerDelete() {
c.closeCalled = true
c.deleteSelfIfReady()
}
func (c *channel) deleteSelfIfReady() {
if !c.closeCalled || len(c.subChans)+len(c.nestedChans) != 0 {
return
}
c.cm.deleteEntry(c.id)
// not top channel
if c.pid != 0 {
c.cm.findEntry(c.pid).deleteChild(c.id)
}
}
type subChannel struct {
refName string
c Channel
closeCalled bool
sockets map[int64]string
id int64
pid int64
cm *channelMap
}
func (sc *subChannel) addChild(id int64, e entry) {
if v, ok := e.(*normalSocket); ok {
sc.sockets[id] = v.refName
} else {
grpclog.Errorf("cannot add a child (id = %d) of type %T to a subChannel", id, e)
}
}
func (sc *subChannel) deleteChild(id int64) {
delete(sc.sockets, id)
sc.deleteSelfIfReady()
}
func (sc *subChannel) triggerDelete() {
sc.closeCalled = true
sc.deleteSelfIfReady()
}
func (sc *subChannel) deleteSelfIfReady() {
if !sc.closeCalled || len(sc.sockets) != 0 {
return
}
sc.cm.deleteEntry(sc.id)
sc.cm.findEntry(sc.pid).deleteChild(sc.id)
}
// SocketMetric defines the info channelz provides for a specific Socket, which
// includes SocketInternalMetric and channelz-specific data, such as channelz id, etc.
type SocketMetric struct {
// ID is the channelz id of this socket.
ID int64
// RefName is the human readable reference string of this socket.
RefName string
// SocketData contains socket internal metric reported by the socket through
// ChannelzMetric().
SocketData *SocketInternalMetric
}
// SocketInternalMetric defines the struct that the implementor of Socket interface
// should return from ChannelzMetric().
type SocketInternalMetric struct {
// The number of streams that have been started.
StreamsStarted int64
// The number of streams that have ended successfully:
// On client side, receiving frame with eos bit set.
// On server side, sending frame with eos bit set.
StreamsSucceeded int64
// The number of streams that have ended unsuccessfully:
// On client side, termination without receiving frame with eos bit set.
// On server side, termination without sending frame with eos bit set.
StreamsFailed int64
// The number of messages successfully sent on this socket.
MessagesSent int64
// The number of messages successfully received on this socket.
MessagesReceived int64
// The number of keep alives sent. This is typically implemented with HTTP/2
// ping messages.
KeepAlivesSent int64
// The last time a stream was created by this endpoint. Usually unset for
// servers.
LastLocalStreamCreatedTimestamp time.Time
// The last time a stream was created by the remote endpoint. Usually unset
// for clients.
LastRemoteStreamCreatedTimestamp time.Time
// The last time a message was sent by this endpoint.
LastMessageSentTimestamp time.Time
// The last time a message was received by this endpoint.
LastMessageReceivedTimestamp time.Time
// The amount of window granted to the local endpoint by the remote endpoint.
// This may be slightly out of date due to network latency. This does NOT
// include stream level or TCP level flow control info.
LocalFlowControlWindow int64
// The amount of window granted to the remote endpoint by the local endpoint.
// This may be slightly out of date due to network latency. This does NOT
// include stream level or TCP level flow control info.
RemoteFlowControlWindow int64
// The locally bound address.
LocalAddr net.Addr
// The remote bound address. May be absent.
RemoteAddr net.Addr
// Optional, represents the name of the remote endpoint, if different than
// the original target name.
RemoteName string
//TODO: socket options
//TODO: Security
}
// Socket is the interface that should be satisfied in order to be tracked by
// channelz as Socket.
type Socket interface {
ChannelzMetric() *SocketInternalMetric
}
type listenSocket struct {
refName string
s Socket
id int64
pid int64
cm *channelMap
}
func (ls *listenSocket) addChild(id int64, e entry) {
grpclog.Errorf("cannot add a child (id = %d) of type %T to a listen socket", id, e)
}
func (ls *listenSocket) deleteChild(id int64) {
grpclog.Errorf("cannot delete a child (id = %d) from a listen socket", id)
}
func (ls *listenSocket) triggerDelete() {
ls.cm.deleteEntry(ls.id)
ls.cm.findEntry(ls.pid).deleteChild(ls.id)
}
func (ls *listenSocket) deleteSelfIfReady() {
grpclog.Errorf("cannot call deleteSelfIfReady on a listen socket")
}
type normalSocket struct {
refName string
s Socket
id int64
pid int64
cm *channelMap
}
func (ns *normalSocket) addChild(id int64, e entry) {
grpclog.Errorf("cannot add a child (id = %d) of type %T to a normal socket", id, e)
}
func (ns *normalSocket) deleteChild(id int64) {
grpclog.Errorf("cannot delete a child (id = %d) from a normal socket", id)
}
func (ns *normalSocket) triggerDelete() {
ns.cm.deleteEntry(ns.id)
ns.cm.findEntry(ns.pid).deleteChild(ns.id)
}
func (ns *normalSocket) deleteSelfIfReady() {
grpclog.Errorf("cannot call deleteSelfIfReady on a normal socket")
}
// ServerMetric defines the info channelz provides for a specific Server, which
// includes ServerInternalMetric and channelz-specific data, such as channelz id,
// child list, etc.
type ServerMetric struct {
// ID is the channelz id of this server.
ID int64
// RefName is the human readable reference string of this server.
RefName string
// ServerData contains server internal metric reported by the server through
// ChannelzMetric().
ServerData *ServerInternalMetric
// ListenSockets tracks the listener socket type children of this server in the
// format of a map from socket channelz id to corresponding reference string.
ListenSockets map[int64]string
}
// ServerInternalMetric defines the struct that the implementor of Server interface
// should return from ChannelzMetric().
type ServerInternalMetric struct {
// The number of incoming calls started on the server.
CallsStarted int64
// The number of incoming calls that have completed with an OK status.
CallsSucceeded int64
// The number of incoming calls that have completed with a non-OK status.
CallsFailed int64
// The last time a call was started on the server.
LastCallStartedTimestamp time.Time
// TODO: trace
}
// Server is the interface to be satisfied in order to be tracked by channelz as
// Server.
type Server interface {
ChannelzMetric() *ServerInternalMetric
}
type server struct {
refName string
s Server
closeCalled bool
sockets map[int64]string
listenSockets map[int64]string
id int64
cm *channelMap
}
func (s *server) addChild(id int64, e entry) {
switch v := e.(type) {
case *normalSocket:
s.sockets[id] = v.refName
case *listenSocket:
s.listenSockets[id] = v.refName
default:
grpclog.Errorf("cannot add a child (id = %d) of type %T to a server", id, e)
}
}
func (s *server) deleteChild(id int64) {
delete(s.sockets, id)
delete(s.listenSockets, id)
s.deleteSelfIfReady()
}
func (s *server) triggerDelete() {
s.closeCalled = true
s.deleteSelfIfReady()
}
func (s *server) deleteSelfIfReady() {
if !s.closeCalled || len(s.sockets)+len(s.listenSockets) != 0 {
return
}
s.cm.deleteEntry(s.id)
}
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package grpcrand implements math/rand functions in a concurrent-safe way
// with a global random source, independent of math/rand's global source.
package grpcrand
import (
"math/rand"
"sync"
"time"
)
var (
r = rand.New(rand.NewSource(time.Now().UnixNano()))
mu sync.Mutex
)
// Int63n implements rand.Int63n on the grpcrand global source.
func Int63n(n int64) int64 {
mu.Lock()
res := r.Int63n(n)
mu.Unlock()
return res
}
// Intn implements rand.Intn on the grpcrand global source.
func Intn(n int) int {
mu.Lock()
res := r.Intn(n)
mu.Unlock()
return res
}
// Float64 implements rand.Float64 on the grpcrand global source.
func Float64() float64 {
mu.Lock()
res := r.Float64()
mu.Unlock()
return res
}
 *
 */

// Package internal contains gRPC-internal code, to avoid polluting
// the godoc of the top-level grpc package. It must not import any grpc
// symbols to avoid circular dependencies.
package internal

var (
	// TestingUseHandlerImpl enables the http.Handler-based server implementation.
	// It must be called before Serve and requires TLS credentials.
	//
	// The provided grpcServer must be of type *grpc.Server. It is untyped
	// for circular dependency reasons.
	TestingUseHandlerImpl func(grpcServer interface{})

	// WithContextDialer is exported by clientconn.go
	WithContextDialer interface{} // func(context.Context, string) (net.Conn, error) grpc.DialOption

	// WithResolverBuilder is exported by clientconn.go
	WithResolverBuilder interface{} // func (resolver.Builder) grpc.DialOption
)
 */

// Package metadata defines the structure of the metadata supported by the gRPC library.
// Please refer to https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md
// for more information about custom-metadata.
package metadata // import "google.golang.org/grpc/metadata"

import (
	"fmt"
	"strings"

	"golang.org/x/net/context"
)

// DecodeKeyValue returns k, v, nil.
//
// Deprecated: use k and v directly instead.
func DecodeKeyValue(k, v string) (string, string, error) {
	return k, v, nil
}
func (md MD) Copy() MD {
	return Join(md)
}
// Get obtains the values for a given key.
func (md MD) Get(k string) []string {
k = strings.ToLower(k)
return md[k]
}
// Set sets the value of a given key with a slice of values.
func (md MD) Set(k string, vals ...string) {
if len(vals) == 0 {
return
}
k = strings.ToLower(k)
md[k] = vals
}
// Append adds the values to key k, not overwriting what was already stored at that key.
func (md MD) Append(k string, vals ...string) {
if len(vals) == 0 {
return
}
k = strings.ToLower(k)
md[k] = append(md[k], vals...)
}
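The accessors added above all normalize keys to lowercase before touching the underlying map, which is what makes `MD` behave case-insensitively. A minimal standalone sketch of those semantics (this `md` type is a stand-in for illustration, not the real `metadata.MD`):

```go
package main

import (
	"fmt"
	"strings"
)

// md is a simplified stand-in for metadata.MD: a multimap keyed by
// lowercased header names.
type md map[string][]string

// set replaces the values for a key, ignoring empty value lists.
func (m md) set(k string, vals ...string) {
	if len(vals) == 0 {
		return
	}
	m[strings.ToLower(k)] = vals
}

// append adds values without overwriting what is already stored.
func (m md) append(k string, vals ...string) {
	if len(vals) == 0 {
		return
	}
	k = strings.ToLower(k)
	m[k] = append(m[k], vals...)
}

// get looks a key up case-insensitively.
func (m md) get(k string) []string {
	return m[strings.ToLower(k)]
}

func main() {
	m := md{}
	m.set("Trace-ID", "a")
	m.append("trace-id", "b")      // same key despite different case
	fmt.Println(m.get("TRACE-ID")) // [a b]
}
```

Because every entry point lowercases the key, `Set("Trace-ID")` and `Append("trace-id")` address the same map slot, matching HTTP/2's requirement that header names be lowercase on the wire.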
// Join joins any number of mds into a single MD.
// The order of values for each key is determined by the order in which
// the mds containing those values are presented to Join.

// NewIncomingContext creates a new context with incoming md attached.
func NewIncomingContext(ctx context.Context, md MD) context.Context {
	return context.WithValue(ctx, mdIncomingKey{}, md)
}
// NewOutgoingContext creates a new context with outgoing md attached. If used
// in conjunction with AppendToOutgoingContext, NewOutgoingContext will
// overwrite any previously-appended metadata.
func NewOutgoingContext(ctx context.Context, md MD) context.Context {
	return context.WithValue(ctx, mdOutgoingKey{}, rawMD{md: md})
}
// AppendToOutgoingContext returns a new context with the provided kv merged
// with any existing metadata in the context. Please refer to the
// documentation of Pairs for a description of kv.
func AppendToOutgoingContext(ctx context.Context, kv ...string) context.Context {
if len(kv)%2 == 1 {
panic(fmt.Sprintf("metadata: AppendToOutgoingContext got an odd number of input pairs for metadata: %d", len(kv)))
}
md, _ := ctx.Value(mdOutgoingKey{}).(rawMD)
added := make([][]string, len(md.added)+1)
copy(added, md.added)
added[len(added)-1] = make([]string, len(kv))
copy(added[len(added)-1], kv)
return context.WithValue(ctx, mdOutgoingKey{}, rawMD{md: md.md, added: added})
}
// FromIncomingContext returns the incoming metadata in ctx if it exists. The
// returned MD should not be modified. Writing to it may cause races.
// Modification should be made to copies of the returned MD.
func FromIncomingContext(ctx context.Context) (md MD, ok bool) {
	md, ok = ctx.Value(mdIncomingKey{}).(MD)
	return
}
// FromOutgoingContextRaw returns the un-merged, intermediary contents
// of rawMD. Remember to perform strings.ToLower on the keys. The returned
// MD should not be modified. Writing to it may cause races. Modification
// should be made to copies of the returned MD.
//
// This is intended for gRPC-internal use ONLY.
func FromOutgoingContextRaw(ctx context.Context) (MD, [][]string, bool) {
raw, ok := ctx.Value(mdOutgoingKey{}).(rawMD)
if !ok {
return nil, nil, false
}
return raw.md, raw.added, true
}
// FromOutgoingContext returns the outgoing metadata in ctx if it exists. The
// returned MD should not be modified. Writing to it may cause races.
// Modification should be made to copies of the returned MD.
func FromOutgoingContext(ctx context.Context) (MD, bool) {
	raw, ok := ctx.Value(mdOutgoingKey{}).(rawMD)
	if !ok {
		return nil, false
	}

	mds := make([]MD, 0, len(raw.added)+1)
	mds = append(mds, raw.md)
	for _, vv := range raw.added {
		mds = append(mds, Pairs(vv...))
	}
	return Join(mds...), ok
}

type rawMD struct {
	md    MD
	added [][]string
}
type ipWatcher struct {
	updateChan chan *Update
}

// Next returns the address resolution Update for the target. For IP address,
// the resolution is itself, thus polling name server is unnecessary. Therefore,
// Next() will return an Update the first time it is called, and will be blocked
// for all following calls as no Update exists until watcher is closed.
func (i *ipWatcher) Next() ([]*Update, error) {
	u, ok := <-i.updateChan
	if !ok {
......
// Package naming defines the naming API and related data structures for gRPC.
// The interface is EXPERIMENTAL and may be subject to change.
//
// Deprecated: please use package resolver.
package naming

// Operation defines the corresponding operations for a name resolution change.
//
// Deprecated: please use package resolver.
type Operation uint8

const (
	// Add indicates a new address is added.
	Add Operation = iota
	// Delete indicates an existing address is deleted.
	Delete
)

// Update defines a name resolution update. Notice that it is not valid having both
// empty string Addr and nil Metadata in an Update.
//
// Deprecated: please use package resolver.
type Update struct {
	// Op indicates the operation of the update.
	Op Operation
}

// Resolver creates a Watcher for a target to track its resolution changes.
//
// Deprecated: please use package resolver.
type Resolver interface {
	// Resolve creates a Watcher for target.
	Resolve(target string) (Watcher, error)
}

// Watcher watches for the updates on the specified target.
//
// Deprecated: please use package resolver.
type Watcher interface {
	// Next blocks until an update or error happens. It may return one or more
	// updates. The first call should get the full set of the results. It should
......
package grpc

import (
	"io"
	"sync"
	"sync/atomic"

	"golang.org/x/net/context"
	"google.golang.org/grpc/balancer"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/grpclog"
	"google.golang.org/grpc/internal/channelz"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/resolver"
	"google.golang.org/grpc/status"
	"google.golang.org/grpc/transport"
)
type pickerWrapper struct {
	mu         sync.Mutex
	done       bool
	blockingCh chan struct{}
	picker     balancer.Picker

	// The latest connection happened.
	connErrMu sync.Mutex
	connErr   error

	stickinessMDKey atomic.Value
	stickiness      *stickyStore
}

func newPickerWrapper() *pickerWrapper {
	bp := &pickerWrapper{
		blockingCh: make(chan struct{}),
		stickiness: newStickyStore(),
	}
	return bp
}
func (bp *pickerWrapper) updateConnectionError(err error) {
bp.connErrMu.Lock()
bp.connErr = err
bp.connErrMu.Unlock()
}
func (bp *pickerWrapper) connectionError() error {
bp.connErrMu.Lock()
err := bp.connErr
bp.connErrMu.Unlock()
return err
}
func (bp *pickerWrapper) updateStickinessMDKey(newKey string) {
// No need to check ok because mdKey == "" if ok == false.
if oldKey, _ := bp.stickinessMDKey.Load().(string); oldKey != newKey {
bp.stickinessMDKey.Store(newKey)
bp.stickiness.reset(newKey)
}
}
func (bp *pickerWrapper) getStickinessMDKey() string {
// No need to check ok because mdKey == "" if ok == false.
mdKey, _ := bp.stickinessMDKey.Load().(string)
return mdKey
}
func (bp *pickerWrapper) clearStickinessState() {
if oldKey := bp.getStickinessMDKey(); oldKey != "" {
// There's no need to reset store if mdKey was "".
bp.stickiness.reset(oldKey)
}
}
// updatePicker is called by UpdateBalancerState. It unblocks all blocked pick.
func (bp *pickerWrapper) updatePicker(p balancer.Picker) {
	bp.mu.Lock()
	if bp.done {
		bp.mu.Unlock()
		return
	}
	bp.picker = p
	// bp.blockingCh should never be nil.
	close(bp.blockingCh)
	bp.blockingCh = make(chan struct{})
	bp.mu.Unlock()
}
func doneChannelzWrapper(acw *acBalancerWrapper, done func(balancer.DoneInfo)) func(balancer.DoneInfo) {
acw.mu.Lock()
ac := acw.ac
acw.mu.Unlock()
ac.incrCallsStarted()
return func(b balancer.DoneInfo) {
if b.Err != nil && b.Err != io.EOF {
ac.incrCallsFailed()
} else {
ac.incrCallsSucceeded()
}
if done != nil {
done(b)
}
}
}
// pick returns the transport that will be used for the RPC.
// It may block in the following cases:
// - there's no picker
// - the subConn returned by the current picker is not READY
// When one of these situations happens, pick blocks until the picker gets updated.
func (bp *pickerWrapper) pick(ctx context.Context, failfast bool, opts balancer.PickOptions) (transport.ClientTransport, func(balancer.DoneInfo), error) {
mdKey := bp.getStickinessMDKey()
stickyKey, isSticky := stickyKeyFromContext(ctx, mdKey)
// Potential race here: if stickinessMDKey is updated after the above two
// lines, and this pick is a sticky pick, the following put could add an
// entry to sticky store with an outdated sticky key.
//
// The solution: keep the current md key in sticky store, and at the
// beginning of each get/put, check the mdkey against store.curMDKey.
// - Cons: one more string comparing for each get/put.
// - Pros: the string matching happens inside get/put, so the overhead for
// non-sticky RPCs will be minimal.
if isSticky {
if t, ok := bp.stickiness.get(mdKey, stickyKey); ok {
// Done function returned is always nil.
return t, nil, nil
}
}
	var (
		p  balancer.Picker
		ch chan struct{}
			if !failfast {
				continue
			}
			return nil, nil, status.Errorf(codes.Unavailable, "%v, latest connection error: %v", err, bp.connectionError())
		default:
			// err is some other error.
			return nil, nil, toRPCErr(err)
			continue
		}
		if t, ok := acw.getAddrConn().getReadyTransport(); ok {
			if isSticky {
				bp.stickiness.put(mdKey, stickyKey, acw)
			}
			if channelz.IsOn() {
				return t, doneChannelzWrapper(acw, done), nil
			}
			return t, done, nil
		}
		grpclog.Infof("blockingPicker: the picked transport is not ready, loop back to repick")
func (bp *pickerWrapper) close() {
	bp.done = true
	close(bp.blockingCh)
}
const stickinessKeyCountLimit = 1000
type stickyStoreEntry struct {
acw *acBalancerWrapper
addr resolver.Address
}
type stickyStore struct {
mu sync.Mutex
// curMDKey is checked before every get/put to avoid races. The operation will
// abort immediately when the given mdKey is different from the curMDKey.
curMDKey string
store *linkedMap
}
func newStickyStore() *stickyStore {
return &stickyStore{
store: newLinkedMap(),
}
}
// reset clears the map in stickyStore, and sets the curMDKey to newMDKey.
func (ss *stickyStore) reset(newMDKey string) {
ss.mu.Lock()
ss.curMDKey = newMDKey
ss.store.clear()
ss.mu.Unlock()
}
// stickyKey is the key to look up in store. mdKey will be checked against
// curMDKey to avoid races.
func (ss *stickyStore) put(mdKey, stickyKey string, acw *acBalancerWrapper) {
ss.mu.Lock()
defer ss.mu.Unlock()
if mdKey != ss.curMDKey {
return
}
// TODO(stickiness): limit the total number of entries.
ss.store.put(stickyKey, &stickyStoreEntry{
acw: acw,
addr: acw.getAddrConn().getCurAddr(),
})
if ss.store.len() > stickinessKeyCountLimit {
ss.store.removeOldest()
}
}
// stickyKey is the key to look up in store. mdKey will be checked against
// curMDKey to avoid races.
func (ss *stickyStore) get(mdKey, stickyKey string) (transport.ClientTransport, bool) {
ss.mu.Lock()
defer ss.mu.Unlock()
if mdKey != ss.curMDKey {
return nil, false
}
entry, ok := ss.store.get(stickyKey)
if !ok {
return nil, false
}
ac := entry.acw.getAddrConn()
if ac.getCurAddr() != entry.addr {
ss.store.remove(stickyKey)
return nil, false
}
t, ok := ac.getReadyTransport()
if !ok {
ss.store.remove(stickyKey)
return nil, false
}
return t, true
}
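The sticky store above caps itself at `stickinessKeyCountLimit` entries by evicting the oldest key through its linked map. The underlying structure is a map paired with an insertion-order list; a self-contained sketch of that bounded-map idea (this `boundedMap` is illustrative, not the actual `linkedMap` implementation):

```go
package main

import (
	"container/list"
	"fmt"
)

type kv struct {
	key string
	val string
}

// boundedMap pairs a map with an insertion-order list so the oldest
// entry can be evicted once a size limit is exceeded.
type boundedMap struct {
	limit int
	m     map[string]*list.Element
	order *list.List // holds kv values in insertion order
}

func newBoundedMap(limit int) *boundedMap {
	return &boundedMap{limit: limit, m: map[string]*list.Element{}, order: list.New()}
}

func (b *boundedMap) put(k, v string) {
	if e, ok := b.m[k]; ok {
		b.order.Remove(e) // re-inserting a key refreshes its position
	}
	b.m[k] = b.order.PushBack(kv{k, v})
	if b.order.Len() > b.limit {
		oldest := b.order.Front()
		b.order.Remove(oldest)
		delete(b.m, oldest.Value.(kv).key)
	}
}

func (b *boundedMap) get(k string) (string, bool) {
	e, ok := b.m[k]
	if !ok {
		return "", false
	}
	return e.Value.(kv).val, true
}

func main() {
	b := newBoundedMap(2)
	b.put("a", "1")
	b.put("b", "2")
	b.put("c", "3") // evicts "a", the oldest key
	_, ok := b.get("a")
	fmt.Println(ok) // false
}
```

The map gives O(1) lookup while the list gives O(1) access to the oldest entry, which is why the store can enforce its 1000-entry limit cheaply on every put.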
// Get one value from metadata in ctx with key stickinessMDKey.
//
// It returns "", false if stickinessMDKey is an empty string.
func stickyKeyFromContext(ctx context.Context, stickinessMDKey string) (string, bool) {
if stickinessMDKey == "" {
return "", false
}
md, added, ok := metadata.FromOutgoingContextRaw(ctx)
if !ok {
return "", false
}
if vv, ok := md[stickinessMDKey]; ok {
if len(vv) > 0 {
return vv[0], true
}
}
for _, ss := range added {
for i := 0; i < len(ss)-1; i += 2 {
if ss[i] == stickinessMDKey {
return ss[i+1], true
}
}
}
return "", false
}
import (
	"encoding/json"
	"errors"
	"fmt"
	"net"
	"os"
	"strconv"

	"golang.org/x/net/context"
	"google.golang.org/grpc/grpclog"
	"google.golang.org/grpc/internal/grpcrand"
	"google.golang.org/grpc/resolver"
)
	txtAttribute = "grpc_config="
)

var (
	errMissingAddr = errors.New("missing address")
)
// NewBuilder creates a dnsBuilder which is used to build DNS resolvers.
func NewBuilder() resolver.Builder {
	return &dnsBuilder{freq: defaultFreq}
}

// Build creates and starts a DNS resolver that watches the name resolution of the target.
func (b *dnsBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOption) (resolver.Resolver, error) {
	if target.Authority != "" {
		return nil, fmt.Errorf("Default DNS resolver does not support custom DNS server")
	}
	host, port, err := parseTarget(target.Endpoint)
	if err != nil {
		return nil, err
	}
	// DNS address (non-IP).
	ctx, cancel := context.WithCancel(context.Background())
	d := &dnsResolver{
		freq:                 b.freq,
		host:                 host,
		port:                 port,
		ctx:                  ctx,
		cancel:               cancel,
		cc:                   cc,
		t:                    time.NewTimer(0),
		rn:                   make(chan struct{}, 1),
		disableServiceConfig: opts.DisableServiceConfig,
	}
	d.wg.Add(1)
	// If Close() doesn't wait for the watcher() goroutine to finish, the race
	// detector sometimes will warn that lookup (READ of the lookup function
	// pointers) inside the watcher() goroutine has a data race with
	// replaceNetFunc (WRITE of the lookup function pointers).
	wg                   sync.WaitGroup
	disableServiceConfig bool
}
// ResolveNow invokes an immediate resolution of the target that this dnsResolver watches.

		result, sc := d.lookup()
		// Next lookup should happen after an interval defined by d.freq.
		d.t.Reset(d.freq)
		d.cc.NewServiceConfig(sc)
		d.cc.NewAddress(result)
	}
}
	for _, s := range srvs {
		lbAddrs, err := lookupHost(d.ctx, s.Target)
		if err != nil {
			grpclog.Infof("grpc: failed load balancer address dns lookup due to %v.\n", err)
			continue
		}
		for _, a := range lbAddrs {
func (d *dnsResolver) lookupTXT() string {
	ss, err := lookupTXT(d.ctx, d.host)
	if err != nil {
		grpclog.Infof("grpc: failed dns TXT record lookup due to %v.\n", err)
		return ""
	}
	var res string
	}
}

func (d *dnsResolver) lookup() ([]resolver.Address, string) {
	newAddrs := d.lookupSRV()
	// Support fallback to non-balancer address.
	newAddrs = append(newAddrs, d.lookupHost()...)
	if d.disableServiceConfig {
		return newAddrs, ""
	}
	sc := d.lookupTXT()
	return newAddrs, canaryingSC(sc)
}
func chosenByPercentage(a *int) bool {
	if a == nil {
		return true
	}
	return grpcrand.Intn(100)+1 <= *a
}

func canaryingSC(js string) string {
......
// TODO(bar) install dns resolver in init(){}.

// Register registers the resolver builder to the resolver map. b.Scheme will be
// used as the scheme registered with this builder.
//
// NOTE: this function must only be called during initialization time (i.e. in
// an init() function), and is not thread-safe. If multiple Resolvers are
// registered with the same name, the one registered last will take effect.
func Register(b Builder) {
	m[b.Scheme()] = b
}

// Get returns the resolver builder registered with the given scheme.
//
// If no builder is registered with the scheme, nil will be returned.
func Get(scheme string) Builder {
	if b, ok := m[scheme]; ok {
		return b
	}
	return nil
}
func SetDefaultScheme(scheme string) {
	defaultScheme = scheme
}
// GetDefaultScheme gets the default scheme that will be used.
func GetDefaultScheme() string {
return defaultScheme
}
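The registry above is a plain scheme-to-builder map populated from init() functions, and this change removed the silent fallback to the default scheme inside Get. A self-contained sketch of the pattern and the new behavior (the names `register`, `get`, and `dnsBuilder` here are illustrative stand-ins, not the resolver package's API):

```go
package main

import "fmt"

// builder is a stand-in for resolver.Builder: something registered
// under a URI scheme.
type builder interface {
	Scheme() string
}

type dnsBuilder struct{}

func (dnsBuilder) Scheme() string { return "dns" }

var registry = map[string]builder{}

// register mirrors resolver.Register: last registration for a scheme
// wins, and it is meant to be called from init() only (no locking).
func register(b builder) { registry[b.Scheme()] = b }

// get mirrors the new resolver.Get: unknown schemes return nil
// instead of silently falling back to the default scheme.
func get(scheme string) builder { return registry[scheme] }

func main() {
	register(dnsBuilder{})
	fmt.Println(get("dns") != nil) // true
	fmt.Println(get("foo") == nil) // true: no fallback for unknown schemes
}
```

Pushing the default-scheme decision out of Get (the caller now consults GetDefaultScheme explicitly) makes the lookup's behavior predictable regardless of what SetDefaultScheme was called with.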
// AddressType indicates the address type returned by name resolution.
type AddressType uint8
// BuildOption includes additional information for the builder to create
// the resolver.
type BuildOption struct {
	// DisableServiceConfig indicates whether resolver should fetch service config data.
	DisableServiceConfig bool
}
// ClientConn contains the callbacks for resolver to notify any updates
......
// parseTarget splits target into a struct containing scheme, authority and
// endpoint.
//
// If target is not a valid scheme://authority/endpoint, it returns {Endpoint:
// target}.
func parseTarget(target string) (ret resolver.Target) {
	var ok bool
	ret.Scheme, ret.Endpoint, ok = split2(target, "://")
	if !ok {
		return resolver.Target{Endpoint: target}
	}
	ret.Authority, ret.Endpoint, ok = split2(ret.Endpoint, "/")
	if !ok {
		return resolver.Target{Endpoint: target}
	}
	return ret
}
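parseTarget now requires the full `scheme://authority/endpoint` shape and collapses anything else to a bare endpoint. A standalone sketch of that splitting logic (the `parse` and `cut` helpers are simplified stand-ins, not the real functions):

```go
package main

import (
	"fmt"
	"strings"
)

type target struct{ scheme, authority, endpoint string }

// cut is split2 in miniature: split s on the first occurrence of sep.
func cut(s, sep string) (string, string, bool) {
	i := strings.Index(s, sep)
	if i < 0 {
		return "", "", false
	}
	return s[:i], s[i+len(sep):], true
}

// parse mirrors parseTarget above: anything that is not a full
// scheme://authority/endpoint collapses to {endpoint: target}.
func parse(s string) target {
	scheme, rest, ok := cut(s, "://")
	if !ok {
		return target{endpoint: s}
	}
	authority, endpoint, ok := cut(rest, "/")
	if !ok {
		return target{endpoint: s}
	}
	return target{scheme: scheme, authority: authority, endpoint: endpoint}
}

func main() {
	fmt.Println(parse("dns://8.8.8.8/example.com")) // {dns 8.8.8.8 example.com}
	fmt.Println(parse("example.com"))               // endpoint-only fallback
}
```

The second `ok` check is the behavioral change in this hunk: previously `dns://example.com` (no `/` after the authority) kept a partial split, whereas now it falls back to treating the whole string as the endpoint.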
// newCCResolverWrapper parses cc.target for scheme and gets the resolver
// builder for this scheme and builds the resolver. The monitoring goroutine
// for it is not started yet and can be created by calling start().
//
// If withResolverBuilder dial option is set, the specified resolver will be
// used instead.
func newCCResolverWrapper(cc *ClientConn) (*ccResolverWrapper, error) {
	rb := cc.dopts.resolverBuilder
	if rb == nil {
		return nil, fmt.Errorf("could not get resolver for scheme: %q", cc.parsedTarget.Scheme)
	}
	ccr := &ccResolverWrapper{
		cc:     cc,
		addrCh: make(chan []resolver.Address, 1),
		scCh:   make(chan string, 1),
		done:   make(chan struct{}),
	}

	var err error
	ccr.resolver, err = rb.Build(cc.parsedTarget, ccr, resolver.BuildOption{DisableServiceConfig: cc.dopts.disableServiceConfig})
	if err != nil {
		return nil, err
	}
...@@ -96,7 +95,7 @@ func (ccr *ccResolverWrapper) start() { ...@@ -96,7 +95,7 @@ func (ccr *ccResolverWrapper) start() {
go ccr.watcher() go ccr.watcher()
} }
// watcher processes address updates and service config updates sequencially. // watcher processes address updates and service config updates sequentially.
// Otherwise, we need to resolve possible races between address and service // Otherwise, we need to resolve possible races between address and service
// config (e.g. they specify different balancer types). // config (e.g. they specify different balancer types).
func (ccr *ccResolverWrapper) watcher() { func (ccr *ccResolverWrapper) watcher() {
...@@ -149,7 +148,7 @@ func (ccr *ccResolverWrapper) NewAddress(addrs []resolver.Address) { ...@@ -149,7 +148,7 @@ func (ccr *ccResolverWrapper) NewAddress(addrs []resolver.Address) {
} }
// NewServiceConfig is called by the resolver implementation to send service // NewServiceConfig is called by the resolver implementation to send service
// configs to gPRC. // configs to gRPC.
func (ccr *ccResolverWrapper) NewServiceConfig(sc string) { func (ccr *ccResolverWrapper) NewServiceConfig(sc string) {
select { select {
case <-ccr.scCh: case <-ccr.scCh:
......
...@@ -22,9 +22,12 @@ import ( ...@@ -22,9 +22,12 @@ import (
"bytes" "bytes"
"compress/gzip" "compress/gzip"
"encoding/binary" "encoding/binary"
"fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"math" "math"
"net/url"
"strings"
"sync" "sync"
"time" "time"
...@@ -32,6 +35,7 @@ import ( ...@@ -32,6 +35,7 @@ import (
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/encoding" "google.golang.org/grpc/encoding"
"google.golang.org/grpc/encoding/proto"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer" "google.golang.org/grpc/peer"
"google.golang.org/grpc/stats" "google.golang.org/grpc/stats"
...@@ -40,6 +44,8 @@ import ( ...@@ -40,6 +44,8 @@ import (
) )
// Compressor defines the interface gRPC uses to compress a message. // Compressor defines the interface gRPC uses to compress a message.
//
// Deprecated: use package encoding.
type Compressor interface { type Compressor interface {
// Do compresses p into w. // Do compresses p into w.
Do(w io.Writer, p []byte) error Do(w io.Writer, p []byte) error
...@@ -52,14 +58,34 @@ type gzipCompressor struct { ...@@ -52,14 +58,34 @@ type gzipCompressor struct {
} }
// NewGZIPCompressor creates a Compressor based on GZIP. // NewGZIPCompressor creates a Compressor based on GZIP.
//
// Deprecated: use package encoding/gzip.
func NewGZIPCompressor() Compressor { func NewGZIPCompressor() Compressor {
c, _ := NewGZIPCompressorWithLevel(gzip.DefaultCompression)
return c
}
// NewGZIPCompressorWithLevel is like NewGZIPCompressor but specifies the gzip compression level instead
// of assuming DefaultCompression.
//
// The error returned will be nil if the level is valid.
//
// Deprecated: use package encoding/gzip.
func NewGZIPCompressorWithLevel(level int) (Compressor, error) {
if level < gzip.DefaultCompression || level > gzip.BestCompression {
return nil, fmt.Errorf("grpc: invalid compression level: %d", level)
}
return &gzipCompressor{ return &gzipCompressor{
pool: sync.Pool{ pool: sync.Pool{
New: func() interface{} { New: func() interface{} {
return gzip.NewWriter(ioutil.Discard) w, err := gzip.NewWriterLevel(ioutil.Discard, level)
if err != nil {
panic(err)
}
return w
}, },
}, },
} }, nil
} }
func (c *gzipCompressor) Do(w io.Writer, p []byte) error { func (c *gzipCompressor) Do(w io.Writer, p []byte) error {
...@@ -77,6 +103,8 @@ func (c *gzipCompressor) Type() string { ...@@ -77,6 +103,8 @@ func (c *gzipCompressor) Type() string {
} }
// Decompressor defines the interface gRPC uses to decompress a message. // Decompressor defines the interface gRPC uses to decompress a message.
//
// Deprecated: use package encoding.
type Decompressor interface { type Decompressor interface {
// Do reads the data from r and uncompresses it. // Do reads the data from r and uncompresses it.
Do(r io.Reader) ([]byte, error) Do(r io.Reader) ([]byte, error)
...@@ -89,6 +117,8 @@ type gzipDecompressor struct { ...@@ -89,6 +117,8 @@ type gzipDecompressor struct {
} }
// NewGZIPDecompressor creates a Decompressor based on GZIP. // NewGZIPDecompressor creates a Decompressor based on GZIP.
//
// Deprecated: use package encoding/gzip.
func NewGZIPDecompressor() Decompressor { func NewGZIPDecompressor() Decompressor {
return &gzipDecompressor{} return &gzipDecompressor{}
} }
...@@ -125,13 +155,13 @@ func (d *gzipDecompressor) Type() string { ...@@ -125,13 +155,13 @@ func (d *gzipDecompressor) Type() string {
type callInfo struct { type callInfo struct {
compressorType string compressorType string
failFast bool failFast bool
headerMD metadata.MD stream *clientStream
trailerMD metadata.MD
peer *peer.Peer
traceInfo traceInfo // in trace.go traceInfo traceInfo // in trace.go
maxReceiveMessageSize *int maxReceiveMessageSize *int
maxSendMessageSize *int maxSendMessageSize *int
creds credentials.PerRPCCredentials creds credentials.PerRPCCredentials
contentSubtype string
codec baseCodec
} }
func defaultCallInfo() *callInfo { func defaultCallInfo() *callInfo {
...@@ -158,40 +188,66 @@ type EmptyCallOption struct{} ...@@ -158,40 +188,66 @@ type EmptyCallOption struct{}
func (EmptyCallOption) before(*callInfo) error { return nil } func (EmptyCallOption) before(*callInfo) error { return nil }
func (EmptyCallOption) after(*callInfo) {} func (EmptyCallOption) after(*callInfo) {}
type beforeCall func(c *callInfo) error
func (o beforeCall) before(c *callInfo) error { return o(c) }
func (o beforeCall) after(c *callInfo) {}
type afterCall func(c *callInfo)
func (o afterCall) before(c *callInfo) error { return nil }
func (o afterCall) after(c *callInfo) { o(c) }
// Header returns a CallOptions that retrieves the header metadata // Header returns a CallOptions that retrieves the header metadata
// for a unary RPC. // for a unary RPC.
func Header(md *metadata.MD) CallOption { func Header(md *metadata.MD) CallOption {
return afterCall(func(c *callInfo) { return HeaderCallOption{HeaderAddr: md}
*md = c.headerMD }
})
// HeaderCallOption is a CallOption for collecting response header metadata.
// The metadata field will be populated *after* the RPC completes.
// This is an EXPERIMENTAL API.
type HeaderCallOption struct {
HeaderAddr *metadata.MD
}
func (o HeaderCallOption) before(c *callInfo) error { return nil }
func (o HeaderCallOption) after(c *callInfo) {
if c.stream != nil {
*o.HeaderAddr, _ = c.stream.Header()
}
} }
// Trailer returns a CallOptions that retrieves the trailer metadata // Trailer returns a CallOptions that retrieves the trailer metadata
// for a unary RPC. // for a unary RPC.
func Trailer(md *metadata.MD) CallOption { func Trailer(md *metadata.MD) CallOption {
return afterCall(func(c *callInfo) { return TrailerCallOption{TrailerAddr: md}
*md = c.trailerMD
})
} }
// Peer returns a CallOption that retrieves peer information for a // TrailerCallOption is a CallOption for collecting response trailer metadata.
// unary RPC. // The metadata field will be populated *after* the RPC completes.
func Peer(peer *peer.Peer) CallOption { // This is an EXPERIMENTAL API.
return afterCall(func(c *callInfo) { type TrailerCallOption struct {
if c.peer != nil { TrailerAddr *metadata.MD
*peer = *c.peer }
func (o TrailerCallOption) before(c *callInfo) error { return nil }
func (o TrailerCallOption) after(c *callInfo) {
if c.stream != nil {
*o.TrailerAddr = c.stream.Trailer()
}
}
// Peer returns a CallOption that retrieves peer information for a unary RPC.
// The peer field will be populated *after* the RPC completes.
func Peer(p *peer.Peer) CallOption {
return PeerCallOption{PeerAddr: p}
}
// PeerCallOption is a CallOption for collecting the identity of the remote
// peer. The peer field will be populated *after* the RPC completes.
// This is an EXPERIMENTAL API.
type PeerCallOption struct {
PeerAddr *peer.Peer
}
func (o PeerCallOption) before(c *callInfo) error { return nil }
func (o PeerCallOption) after(c *callInfo) {
if c.stream != nil {
if x, ok := peer.FromContext(c.stream.Context()); ok {
*o.PeerAddr = *x
} }
}) }
} }
// FailFast configures the action to take when an RPC is attempted on broken // FailFast configures the action to take when an RPC is attempted on broken
...@@ -205,55 +261,166 @@ func Peer(peer *peer.Peer) CallOption { ...@@ -205,55 +261,166 @@ func Peer(peer *peer.Peer) CallOption {
// //
// By default, RPCs are "Fail Fast". // By default, RPCs are "Fail Fast".
func FailFast(failFast bool) CallOption { func FailFast(failFast bool) CallOption {
return beforeCall(func(c *callInfo) error { return FailFastCallOption{FailFast: failFast}
c.failFast = failFast
return nil
})
} }
// FailFastCallOption is a CallOption for indicating whether an RPC should fail
// fast or not.
// This is an EXPERIMENTAL API.
type FailFastCallOption struct {
FailFast bool
}
func (o FailFastCallOption) before(c *callInfo) error {
c.failFast = o.FailFast
return nil
}
func (o FailFastCallOption) after(c *callInfo) {}
// MaxCallRecvMsgSize returns a CallOption which sets the maximum message size the client can receive. // MaxCallRecvMsgSize returns a CallOption which sets the maximum message size the client can receive.
func MaxCallRecvMsgSize(s int) CallOption { func MaxCallRecvMsgSize(s int) CallOption {
return beforeCall(func(o *callInfo) error { return MaxRecvMsgSizeCallOption{MaxRecvMsgSize: s}
o.maxReceiveMessageSize = &s }
return nil
}) // MaxRecvMsgSizeCallOption is a CallOption that indicates the maximum message
// size the client can receive.
// This is an EXPERIMENTAL API.
type MaxRecvMsgSizeCallOption struct {
MaxRecvMsgSize int
} }
func (o MaxRecvMsgSizeCallOption) before(c *callInfo) error {
c.maxReceiveMessageSize = &o.MaxRecvMsgSize
return nil
}
func (o MaxRecvMsgSizeCallOption) after(c *callInfo) {}
// MaxCallSendMsgSize returns a CallOption which sets the maximum message size the client can send. // MaxCallSendMsgSize returns a CallOption which sets the maximum message size the client can send.
func MaxCallSendMsgSize(s int) CallOption { func MaxCallSendMsgSize(s int) CallOption {
return beforeCall(func(o *callInfo) error { return MaxSendMsgSizeCallOption{MaxSendMsgSize: s}
o.maxSendMessageSize = &s
return nil
})
} }
// MaxSendMsgSizeCallOption is a CallOption that indicates the maximum message
// size the client can send.
// This is an EXPERIMENTAL API.
type MaxSendMsgSizeCallOption struct {
MaxSendMsgSize int
}
func (o MaxSendMsgSizeCallOption) before(c *callInfo) error {
c.maxSendMessageSize = &o.MaxSendMsgSize
return nil
}
func (o MaxSendMsgSizeCallOption) after(c *callInfo) {}
// PerRPCCredentials returns a CallOption that sets credentials.PerRPCCredentials // PerRPCCredentials returns a CallOption that sets credentials.PerRPCCredentials
// for a call. // for a call.
func PerRPCCredentials(creds credentials.PerRPCCredentials) CallOption { func PerRPCCredentials(creds credentials.PerRPCCredentials) CallOption {
return beforeCall(func(c *callInfo) error { return PerRPCCredsCallOption{Creds: creds}
c.creds = creds }
return nil
}) // PerRPCCredsCallOption is a CallOption that indicates the per-RPC
// credentials to use for the call.
// This is an EXPERIMENTAL API.
type PerRPCCredsCallOption struct {
Creds credentials.PerRPCCredentials
} }
func (o PerRPCCredsCallOption) before(c *callInfo) error {
c.creds = o.Creds
return nil
}
func (o PerRPCCredsCallOption) after(c *callInfo) {}
// UseCompressor returns a CallOption which sets the compressor used when // UseCompressor returns a CallOption which sets the compressor used when
// sending the request. If WithCompressor is also set, UseCompressor has // sending the request. If WithCompressor is also set, UseCompressor has
// higher priority. // higher priority.
// //
// This API is EXPERIMENTAL. // This API is EXPERIMENTAL.
func UseCompressor(name string) CallOption { func UseCompressor(name string) CallOption {
return beforeCall(func(c *callInfo) error { return CompressorCallOption{CompressorType: name}
c.compressorType = name }
return nil
}) // CompressorCallOption is a CallOption that indicates the compressor to use.
// This is an EXPERIMENTAL API.
type CompressorCallOption struct {
CompressorType string
}
func (o CompressorCallOption) before(c *callInfo) error {
c.compressorType = o.CompressorType
return nil
}
func (o CompressorCallOption) after(c *callInfo) {}
// CallContentSubtype returns a CallOption that will set the content-subtype
// for a call. For example, if content-subtype is "json", the Content-Type over
// the wire will be "application/grpc+json". The content-subtype is converted
// to lowercase before being included in Content-Type. See Content-Type on
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for
// more details.
//
// If CallCustomCodec is not also used, the content-subtype will be used to
// look up the Codec to use in the registry controlled by RegisterCodec. See
// the documentation on RegisterCodec for details on registration. The lookup
// of content-subtype is case-insensitive. If no such Codec is found, the call
// will result in an error with code codes.Internal.
//
// If CallCustomCodec is also used, that Codec will be used for all request and
// response messages, with the content-subtype set to the given contentSubtype
// here for requests.
func CallContentSubtype(contentSubtype string) CallOption {
return ContentSubtypeCallOption{ContentSubtype: strings.ToLower(contentSubtype)}
}
// ContentSubtypeCallOption is a CallOption that indicates the content-subtype
// used for marshaling messages.
// This is an EXPERIMENTAL API.
type ContentSubtypeCallOption struct {
ContentSubtype string
}
func (o ContentSubtypeCallOption) before(c *callInfo) error {
c.contentSubtype = o.ContentSubtype
return nil
}
func (o ContentSubtypeCallOption) after(c *callInfo) {}
// CallCustomCodec returns a CallOption that will set the given Codec to be
// used for all request and response messages for a call. The result of calling
// String() will be used as the content-subtype in a case-insensitive manner.
//
// See Content-Type on
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for
// more details. Also see the documentation on RegisterCodec and
// CallContentSubtype for more details on the interaction between Codec and
// content-subtype.
//
// This function is provided for advanced users; prefer to use only
// CallContentSubtype to select a registered codec instead.
func CallCustomCodec(codec Codec) CallOption {
return CustomCodecCallOption{Codec: codec}
}
// CustomCodecCallOption is a CallOption that indicates the codec used for
// marshaling messages.
// This is an EXPERIMENTAL API.
type CustomCodecCallOption struct {
Codec Codec
}
func (o CustomCodecCallOption) before(c *callInfo) error {
c.codec = o.Codec
return nil
} }
func (o CustomCodecCallOption) after(c *callInfo) {}
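The exported option structs above all share one two-phase shape: `before` mutates the per-call settings prior to the RPC, `after` extracts results once it completes. A self-contained miniature of that pattern (the tiny interface and names here are illustrative, not the grpc API):

```go
package main

import "fmt"

// callInfo mimics the per-RPC settings mutated by options.
type callInfo struct {
	failFast       bool
	compressorType string
}

// CallOption mirrors the two-phase hook shape used above.
type CallOption interface {
	before(*callInfo) error
	after(*callInfo)
}

// failFastOption is the miniature analogue of FailFastCallOption.
type failFastOption struct{ v bool }

func (o failFastOption) before(c *callInfo) error { c.failFast = o.v; return nil }
func (o failFastOption) after(c *callInfo)        {}

func main() {
	// Defaults first, then each option's before hook is applied in order.
	c := &callInfo{failFast: true}
	opts := []CallOption{failFastOption{v: false}}
	for _, o := range opts {
		if err := o.before(c); err != nil {
			panic(err)
		}
	}
	fmt.Println(c.failFast)
}
```

Exposing the options as concrete structs (rather than the older closure-based `beforeCall`/`afterCall`) is what lets interceptors inspect which options a caller passed.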
// The format of the payload: compressed or not? // The format of the payload: compressed or not?
type payloadFormat uint8 type payloadFormat uint8
const ( const (
compressionNone payloadFormat = iota // no compression compressionNone payloadFormat = 0 // no compression
compressionMade compressionMade payloadFormat = 1 // compressed
) )
// parser reads complete gRPC messages from the underlying reader. // parser reads complete gRPC messages from the underlying reader.
...@@ -263,8 +430,8 @@ type parser struct { ...@@ -263,8 +430,8 @@ type parser struct {
// error types. // error types.
r io.Reader r io.Reader
// The header of a gRPC message. Find more detail // The header of a gRPC message. Find more detail at
// at https://grpc.io/docs/guides/wire.html. // https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md
header [5]byte header [5]byte
} }
...@@ -310,65 +477,82 @@ func (p *parser) recvMsg(maxReceiveMessageSize int) (pf payloadFormat, msg []byt ...@@ -310,65 +477,82 @@ func (p *parser) recvMsg(maxReceiveMessageSize int) (pf payloadFormat, msg []byt
return pf, msg, nil return pf, msg, nil
} }
// encode serializes msg and returns a buffer of message header and a buffer of msg. // encode serializes msg and returns a buffer containing the message, or an
// If msg is nil, it generates the message header and an empty msg buffer. // error if it is too large to be transmitted by grpc. If msg is nil, it
// TODO(ddyihai): eliminate extra Compressor parameter. // generates an empty message.
func encode(c Codec, msg interface{}, cp Compressor, outPayload *stats.OutPayload, compressor encoding.Compressor) ([]byte, []byte, error) { func encode(c baseCodec, msg interface{}) ([]byte, error) {
var ( if msg == nil { // NOTE: typed nils will not be caught by this check
b []byte return nil, nil
cbuf *bytes.Buffer }
) b, err := c.Marshal(msg)
const ( if err != nil {
payloadLen = 1 return nil, status.Errorf(codes.Internal, "grpc: error while marshaling: %v", err.Error())
sizeLen = 4 }
) if uint(len(b)) > math.MaxUint32 {
if msg != nil { return nil, status.Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", len(b))
var err error }
b, err = c.Marshal(msg) return b, nil
if err != nil { }
return nil, nil, status.Errorf(codes.Internal, "grpc: error while marshaling: %v", err.Error())
// compress returns the input bytes compressed by compressor or cp. If both
// compressors are nil, returns nil.
//
// TODO(dfawley): eliminate cp parameter by wrapping Compressor in an encoding.Compressor.
func compress(in []byte, cp Compressor, compressor encoding.Compressor) ([]byte, error) {
if compressor == nil && cp == nil {
return nil, nil
}
wrapErr := func(err error) error {
return status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error())
}
cbuf := &bytes.Buffer{}
if compressor != nil {
z, _ := compressor.Compress(cbuf)
if _, err := z.Write(in); err != nil {
return nil, wrapErr(err)
} }
if outPayload != nil { if err := z.Close(); err != nil {
outPayload.Payload = msg return nil, wrapErr(err)
// TODO truncate large payload.
outPayload.Data = b
outPayload.Length = len(b)
} }
if compressor != nil || cp != nil { } else {
cbuf = new(bytes.Buffer) if err := cp.Do(cbuf, in); err != nil {
// Has compressor, check Compressor is set by UseCompressor first. return nil, wrapErr(err)
if compressor != nil {
z, _ := compressor.Compress(cbuf)
if _, err := z.Write(b); err != nil {
return nil, nil, status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error())
}
z.Close()
} else {
// If Compressor is not set by UseCompressor, use default Compressor
if err := cp.Do(cbuf, b); err != nil {
return nil, nil, status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error())
}
}
b = cbuf.Bytes()
} }
} }
if uint(len(b)) > math.MaxUint32 { return cbuf.Bytes(), nil
return nil, nil, status.Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", len(b)) }
}
const (
payloadLen = 1
sizeLen = 4
headerLen = payloadLen + sizeLen
)
bufHeader := make([]byte, payloadLen+sizeLen) // msgHeader returns a 5-byte header for the message being transmitted and the
if compressor != nil || cp != nil { // payload, which is compData if non-nil or data otherwise.
bufHeader[0] = byte(compressionMade) func msgHeader(data, compData []byte) (hdr []byte, payload []byte) {
hdr = make([]byte, headerLen)
if compData != nil {
hdr[0] = byte(compressionMade)
data = compData
} else { } else {
bufHeader[0] = byte(compressionNone) hdr[0] = byte(compressionNone)
} }
// Write length of b into buf // Write length of payload into buf
binary.BigEndian.PutUint32(bufHeader[payloadLen:], uint32(len(b))) binary.BigEndian.PutUint32(hdr[payloadLen:], uint32(len(data)))
if outPayload != nil { return hdr, data
outPayload.WireLength = payloadLen + sizeLen + len(b) }
func outPayload(client bool, msg interface{}, data, payload []byte, t time.Time) *stats.OutPayload {
return &stats.OutPayload{
Client: client,
Payload: msg,
Data: data,
Length: len(data),
WireLength: len(payload) + headerLen,
SentTime: t,
} }
return bufHeader, b, nil
} }
func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool) *status.Status { func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool) *status.Status {
...@@ -390,7 +574,7 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool ...@@ -390,7 +574,7 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool
// For the two compressor parameters, both should not be set, but if they are, // For the two compressor parameters, both should not be set, but if they are,
// dc takes precedence over compressor. // dc takes precedence over compressor.
// TODO(dfawley): wrap the old compressor/decompressor using the new API? // TODO(dfawley): wrap the old compressor/decompressor using the new API?
func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}, maxReceiveMessageSize int, inPayload *stats.InPayload, compressor encoding.Compressor) error { func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m interface{}, maxReceiveMessageSize int, inPayload *stats.InPayload, compressor encoding.Compressor) error {
pf, d, err := p.recvMsg(maxReceiveMessageSize) pf, d, err := p.recvMsg(maxReceiveMessageSize)
if err != nil { if err != nil {
return err return err
...@@ -485,6 +669,61 @@ func Errorf(c codes.Code, format string, a ...interface{}) error { ...@@ -485,6 +669,61 @@ func Errorf(c codes.Code, format string, a ...interface{}) error {
return status.Errorf(c, format, a...) return status.Errorf(c, format, a...)
} }
// setCallInfoCodec should only be called after CallOptions have been applied.
func setCallInfoCodec(c *callInfo) error {
if c.codec != nil {
// codec was already set by a CallOption; use it.
return nil
}
if c.contentSubtype == "" {
// No codec specified in CallOptions; use proto by default.
c.codec = encoding.GetCodec(proto.Name)
return nil
}
// c.contentSubtype is already lowercased in CallContentSubtype
c.codec = encoding.GetCodec(c.contentSubtype)
if c.codec == nil {
return status.Errorf(codes.Internal, "no codec registered for content-subtype %s", c.contentSubtype)
}
return nil
}
// parseDialTarget returns the network and address to pass to the dialer.
func parseDialTarget(target string) (net string, addr string) {
net = "tcp"
m1 := strings.Index(target, ":")
m2 := strings.Index(target, ":/")
// handle unix:addr which will fail with url.Parse
if m1 >= 0 && m2 < 0 {
if n := target[0:m1]; n == "unix" {
net = n
addr = target[m1+1:]
return net, addr
}
}
if m2 >= 0 {
t, err := url.Parse(target)
if err != nil {
return net, target
}
scheme := t.Scheme
addr = t.Path
if scheme == "unix" {
net = scheme
if addr == "" {
addr = t.Host
}
return net, addr
}
}
return net, target
}
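The dial-target rules above can be exercised directly. This standalone copy of the function demonstrates its three cases: `unix:addr` (which `url.Parse` would mishandle), `unix:///path`, and a plain `host:port` falling back to TCP:

```go
package main

import (
	"fmt"
	"net/url"
	"strings"
)

// parseDialTarget returns the network and address to pass to the
// dialer (copied from the function above for demonstration).
func parseDialTarget(target string) (net string, addr string) {
	net = "tcp"
	m1 := strings.Index(target, ":")
	m2 := strings.Index(target, ":/")
	// handle unix:addr which will fail with url.Parse
	if m1 >= 0 && m2 < 0 {
		if n := target[0:m1]; n == "unix" {
			return n, target[m1+1:]
		}
	}
	if m2 >= 0 {
		t, err := url.Parse(target)
		if err != nil {
			return net, target
		}
		scheme := t.Scheme
		addr = t.Path
		if scheme == "unix" {
			if addr == "" {
				addr = t.Host
			}
			return scheme, addr
		}
	}
	return net, target
}

func main() {
	fmt.Println(parseDialTarget("unix:/tmp/grpc.sock"))
	fmt.Println(parseDialTarget("localhost:50051"))
}
```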
// The SupportPackageIsVersion variables are referenced from generated protocol // The SupportPackageIsVersion variables are referenced from generated protocol
// buffer files to ensure compatibility with the gRPC version used. The latest // buffer files to ensure compatibility with the gRPC version used. The latest
// support package version is 5. // support package version is 5.
...@@ -499,7 +738,4 @@ const ( ...@@ -499,7 +738,4 @@ const (
SupportPackageIsVersion5 = true SupportPackageIsVersion5 = true
) )
// Version is the current grpc version.
const Version = "1.9.1"
const grpcUA = "grpc-go/" + Version const grpcUA = "grpc-go/" + Version
...@@ -37,11 +37,14 @@ import ( ...@@ -37,11 +37,14 @@ import (
"golang.org/x/net/context" "golang.org/x/net/context"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"golang.org/x/net/trace" "golang.org/x/net/trace"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/encoding" "google.golang.org/grpc/encoding"
"google.golang.org/grpc/encoding/proto"
"google.golang.org/grpc/grpclog" "google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal" "google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"google.golang.org/grpc/stats" "google.golang.org/grpc/stats"
...@@ -96,16 +99,24 @@ type Server struct { ...@@ -96,16 +99,24 @@ type Server struct {
m map[string]*service // service name -> service info m map[string]*service // service name -> service info
events trace.EventLog events trace.EventLog
quit chan struct{} quit chan struct{}
done chan struct{} done chan struct{}
quitOnce sync.Once quitOnce sync.Once
doneOnce sync.Once doneOnce sync.Once
serveWG sync.WaitGroup // counts active Serve goroutines for GracefulStop channelzRemoveOnce sync.Once
serveWG sync.WaitGroup // counts active Serve goroutines for GracefulStop
channelzID int64 // channelz unique identification number
czmu sync.RWMutex
callsStarted int64
callsFailed int64
callsSucceeded int64
lastCallStartedTime time.Time
} }
type options struct { type options struct {
creds credentials.TransportCredentials creds credentials.TransportCredentials
codec Codec codec baseCodec
cp Compressor cp Compressor
dc Decompressor dc Decompressor
unaryInt UnaryServerInterceptor unaryInt UnaryServerInterceptor
...@@ -182,6 +193,8 @@ func KeepaliveEnforcementPolicy(kep keepalive.EnforcementPolicy) ServerOption { ...@@ -182,6 +193,8 @@ func KeepaliveEnforcementPolicy(kep keepalive.EnforcementPolicy) ServerOption {
} }
// CustomCodec returns a ServerOption that sets a codec for message marshaling and unmarshaling. // CustomCodec returns a ServerOption that sets a codec for message marshaling and unmarshaling.
//
// This will override any lookups by content-subtype for Codecs registered with RegisterCodec.
func CustomCodec(codec Codec) ServerOption { func CustomCodec(codec Codec) ServerOption {
return func(o *options) { return func(o *options) {
o.codec = codec o.codec = codec
...@@ -213,7 +226,9 @@ func RPCDecompressor(dc Decompressor) ServerOption { ...@@ -213,7 +226,9 @@ func RPCDecompressor(dc Decompressor) ServerOption {
} }
// MaxMsgSize returns a ServerOption to set the max message size in bytes the server can receive. // MaxMsgSize returns a ServerOption to set the max message size in bytes the server can receive.
// If this is not set, gRPC uses the default limit. Deprecated: use MaxRecvMsgSize instead. // If this is not set, gRPC uses the default limit.
//
// Deprecated: use MaxRecvMsgSize instead.
func MaxMsgSize(m int) ServerOption { func MaxMsgSize(m int) ServerOption {
return MaxRecvMsgSize(m) return MaxRecvMsgSize(m)
} }
...@@ -327,10 +342,6 @@ func NewServer(opt ...ServerOption) *Server { ...@@ -327,10 +342,6 @@ func NewServer(opt ...ServerOption) *Server {
for _, o := range opt { for _, o := range opt {
o(&opts) o(&opts)
} }
if opts.codec == nil {
// Set the default codec.
opts.codec = protoCodec{}
}
s := &Server{ s := &Server{
lis: make(map[net.Listener]bool), lis: make(map[net.Listener]bool),
opts: opts, opts: opts,
...@@ -344,6 +355,10 @@ func NewServer(opt ...ServerOption) *Server { ...@@ -344,6 +355,10 @@ func NewServer(opt ...ServerOption) *Server {
_, file, line, _ := runtime.Caller(1) _, file, line, _ := runtime.Caller(1)
s.events = trace.NewEventLog("grpc.Server", fmt.Sprintf("%s:%d", file, line)) s.events = trace.NewEventLog("grpc.Server", fmt.Sprintf("%s:%d", file, line))
} }
if channelz.IsOn() {
s.channelzID = channelz.RegisterServer(s, "")
}
return s return s
} }
...@@ -459,6 +474,25 @@ func (s *Server) useTransportAuthenticator(rawConn net.Conn) (net.Conn, credenti ...@@ -459,6 +474,25 @@ func (s *Server) useTransportAuthenticator(rawConn net.Conn) (net.Conn, credenti
return s.opts.creds.ServerHandshake(rawConn) return s.opts.creds.ServerHandshake(rawConn)
} }
type listenSocket struct {
net.Listener
channelzID int64
}
func (l *listenSocket) ChannelzMetric() *channelz.SocketInternalMetric {
return &channelz.SocketInternalMetric{
LocalAddr: l.Listener.Addr(),
}
}
func (l *listenSocket) Close() error {
err := l.Listener.Close()
if channelz.IsOn() {
channelz.RemoveEntry(l.channelzID)
}
return err
}
// Serve accepts incoming connections on the listener lis, creating a new // Serve accepts incoming connections on the listener lis, creating a new
// ServerTransport and service goroutine for each. The service goroutines // ServerTransport and service goroutine for each. The service goroutines
// read gRPC requests and then call the registered handlers to reply to them. // read gRPC requests and then call the registered handlers to reply to them.
...@@ -487,13 +521,19 @@ func (s *Server) Serve(lis net.Listener) error { ...@@ -487,13 +521,19 @@ func (s *Server) Serve(lis net.Listener) error {
} }
}() }()
s.lis[lis] = true ls := &listenSocket{Listener: lis}
s.lis[ls] = true
if channelz.IsOn() {
ls.channelzID = channelz.RegisterListenSocket(ls, s.channelzID, "")
}
s.mu.Unlock() s.mu.Unlock()
defer func() { defer func() {
s.mu.Lock() s.mu.Lock()
if s.lis != nil && s.lis[lis] { if s.lis != nil && s.lis[ls] {
lis.Close() ls.Close()
delete(s.lis, lis) delete(s.lis, ls)
} }
s.mu.Unlock() s.mu.Unlock()
}() }()
...@@ -615,6 +655,7 @@ func (s *Server) newHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) tr ...@@ -615,6 +655,7 @@ func (s *Server) newHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) tr
InitialConnWindowSize: s.opts.initialConnWindowSize, InitialConnWindowSize: s.opts.initialConnWindowSize,
WriteBufferSize: s.opts.writeBufferSize, WriteBufferSize: s.opts.writeBufferSize,
ReadBufferSize: s.opts.readBufferSize, ReadBufferSize: s.opts.readBufferSize,
ChannelzParentID: s.channelzID,
} }
st, err := transport.NewServerTransport("http2", c, config) st, err := transport.NewServerTransport("http2", c, config)
if err != nil { if err != nil {
...@@ -625,6 +666,7 @@ func (s *Server) newHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) tr ...@@ -625,6 +666,7 @@ func (s *Server) newHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) tr
grpclog.Warningln("grpc: Server.Serve failed to create ServerTransport: ", err) grpclog.Warningln("grpc: Server.Serve failed to create ServerTransport: ", err)
return nil return nil
} }
return st return st
} }
...@@ -695,7 +737,7 @@ func (s *Server) serveUsingHandler(conn net.Conn) { ...@@ -695,7 +737,7 @@ func (s *Server) serveUsingHandler(conn net.Conn) {
// available through grpc-go's HTTP/2 server, and it is currently EXPERIMENTAL // available through grpc-go's HTTP/2 server, and it is currently EXPERIMENTAL
// and subject to change. // and subject to change.
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
st, err := transport.NewServerHandlerTransport(w, r) st, err := transport.NewServerHandlerTransport(w, r, s.opts.statsHandler)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
...@@ -752,39 +794,83 @@ func (s *Server) removeConn(c io.Closer) { ...@@ -752,39 +794,83 @@ func (s *Server) removeConn(c io.Closer) {
} }
} }
+// ChannelzMetric returns ServerInternalMetric of current server.
+// This is an EXPERIMENTAL API.
+func (s *Server) ChannelzMetric() *channelz.ServerInternalMetric {
+	s.czmu.RLock()
+	defer s.czmu.RUnlock()
+	return &channelz.ServerInternalMetric{
+		CallsStarted:             s.callsStarted,
+		CallsSucceeded:           s.callsSucceeded,
+		CallsFailed:              s.callsFailed,
+		LastCallStartedTimestamp: s.lastCallStartedTime,
+	}
+}
+
+func (s *Server) incrCallsStarted() {
+	s.czmu.Lock()
+	s.callsStarted++
+	s.lastCallStartedTime = time.Now()
+	s.czmu.Unlock()
+}
+
+func (s *Server) incrCallsSucceeded() {
+	s.czmu.Lock()
+	s.callsSucceeded++
+	s.czmu.Unlock()
+}
+
+func (s *Server) incrCallsFailed() {
+	s.czmu.Lock()
+	s.callsFailed++
+	s.czmu.Unlock()
+}
+
 func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg interface{}, cp Compressor, opts *transport.Options, comp encoding.Compressor) error {
-	var (
-		outPayload *stats.OutPayload
-	)
-	if s.opts.statsHandler != nil {
-		outPayload = &stats.OutPayload{}
-	}
-	hdr, data, err := encode(s.opts.codec, msg, cp, outPayload, comp)
+	data, err := encode(s.getCodec(stream.ContentSubtype()), msg)
 	if err != nil {
 		grpclog.Errorln("grpc: server failed to encode response: ", err)
 		return err
 	}
-	if len(data) > s.opts.maxSendMessageSize {
-		return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(data), s.opts.maxSendMessageSize)
+	compData, err := compress(data, cp, comp)
+	if err != nil {
+		grpclog.Errorln("grpc: server failed to compress response: ", err)
+		return err
+	}
+	hdr, payload := msgHeader(data, compData)
+	// TODO(dfawley): should we be checking len(data) instead?
+	if len(payload) > s.opts.maxSendMessageSize {
+		return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(payload), s.opts.maxSendMessageSize)
 	}
-	err = t.Write(stream, hdr, data, opts)
-	if err == nil && outPayload != nil {
-		outPayload.SentTime = time.Now()
-		s.opts.statsHandler.HandleRPC(stream.Context(), outPayload)
+	err = t.Write(stream, hdr, payload, opts)
+	if err == nil && s.opts.statsHandler != nil {
+		s.opts.statsHandler.HandleRPC(stream.Context(), outPayload(false, msg, data, payload, time.Now()))
 	}
 	return err
 }

 func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, md *MethodDesc, trInfo *traceInfo) (err error) {
+	if channelz.IsOn() {
+		s.incrCallsStarted()
+		defer func() {
+			if err != nil && err != io.EOF {
+				s.incrCallsFailed()
+			} else {
+				s.incrCallsSucceeded()
+			}
+		}()
+	}
 	sh := s.opts.statsHandler
 	if sh != nil {
+		beginTime := time.Now()
 		begin := &stats.Begin{
-			BeginTime: time.Now(),
+			BeginTime: beginTime,
 		}
 		sh.HandleRPC(stream.Context(), begin)
 		defer func() {
 			end := &stats.End{
-				EndTime: time.Now(),
+				BeginTime: beginTime,
+				EndTime:   time.Now(),
 			}
 			if err != nil && err != io.EOF {
 				end.Error = toRPCErr(err)
@@ -868,6 +954,9 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
 		}
 		return err
 	}
+	if channelz.IsOn() {
+		t.IncrMsgRecv()
+	}
 	if st := checkRecvPayload(pf, stream.RecvCompress(), dc != nil || decomp != nil); st != nil {
 		if e := t.WriteStatus(stream, st); e != nil {
 			grpclog.Warningf("grpc: Server.processUnaryRPC failed to write status %v", e)
@@ -904,7 +993,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
 		// java implementation.
 		return status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", len(req), s.opts.maxReceiveMessageSize)
 	}
-	if err := s.opts.codec.Unmarshal(req, v); err != nil {
+	if err := s.getCodec(stream.ContentSubtype()).Unmarshal(req, v); err != nil {
 		return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err)
 	}
 	if inPayload != nil {
@@ -918,12 +1007,13 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
 		}
 		return nil
 	}
-	reply, appErr := md.Handler(srv.server, stream.Context(), df, s.opts.unaryInt)
+	ctx := NewContextWithServerTransportStream(stream.Context(), stream)
+	reply, appErr := md.Handler(srv.server, ctx, df, s.opts.unaryInt)
 	if appErr != nil {
 		appStatus, ok := status.FromError(appErr)
 		if !ok {
 			// Convert appErr if it is not a grpc status error.
-			appErr = status.Error(convertCode(appErr), appErr.Error())
+			appErr = status.Error(codes.Unknown, appErr.Error())
 			appStatus, _ = status.FromError(appErr)
 		}
 		if trInfo != nil {
@@ -966,6 +1056,9 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
 		}
 		return err
 	}
+	if channelz.IsOn() {
+		t.IncrMsgSent()
+	}
 	if trInfo != nil {
 		trInfo.tr.LazyLog(&payload{sent: true, msg: reply}, true)
 	}
@@ -976,15 +1069,27 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
 }

 func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, sd *StreamDesc, trInfo *traceInfo) (err error) {
+	if channelz.IsOn() {
+		s.incrCallsStarted()
+		defer func() {
+			if err != nil && err != io.EOF {
+				s.incrCallsFailed()
+			} else {
+				s.incrCallsSucceeded()
+			}
+		}()
+	}
 	sh := s.opts.statsHandler
 	if sh != nil {
+		beginTime := time.Now()
 		begin := &stats.Begin{
-			BeginTime: time.Now(),
+			BeginTime: beginTime,
 		}
 		sh.HandleRPC(stream.Context(), begin)
 		defer func() {
 			end := &stats.End{
-				EndTime: time.Now(),
+				BeginTime: beginTime,
+				EndTime:   time.Now(),
 			}
 			if err != nil && err != io.EOF {
 				end.Error = toRPCErr(err)
@@ -992,11 +1097,13 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
 			sh.HandleRPC(stream.Context(), end)
 		}()
 	}
+	ctx := NewContextWithServerTransportStream(stream.Context(), stream)
 	ss := &serverStream{
+		ctx:                   ctx,
 		t:                     t,
 		s:                     stream,
 		p:                     &parser{r: stream},
-		codec:                 s.opts.codec,
+		codec:                 s.getCodec(stream.ContentSubtype()),
 		maxReceiveMessageSize: s.opts.maxReceiveMessageSize,
 		maxSendMessageSize:    s.opts.maxSendMessageSize,
 		trInfo:                trInfo,
@@ -1066,7 +1173,7 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
 		case transport.StreamError:
 			appStatus = status.New(err.Code, err.Desc)
 		default:
-			appStatus = status.New(convertCode(appErr), appErr.Error())
+			appStatus = status.New(codes.Unknown, appErr.Error())
 		}
 		appErr = appStatus.Err()
 	}
@@ -1086,7 +1193,6 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
 		ss.mu.Unlock()
 	}
 	return t.WriteStatus(ss.s, status.New(codes.OK, ""))
 }

 func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Stream, trInfo *traceInfo) {
@@ -1168,6 +1274,42 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str
 	}
 }
// The key to save ServerTransportStream in the context.
type streamKey struct{}
// NewContextWithServerTransportStream creates a new context from ctx and
// attaches stream to it.
//
// This API is EXPERIMENTAL.
func NewContextWithServerTransportStream(ctx context.Context, stream ServerTransportStream) context.Context {
return context.WithValue(ctx, streamKey{}, stream)
}
// ServerTransportStream is a minimal interface that a transport stream must
// implement. This can be used to mock an actual transport stream for tests of
// handler code that use, for example, grpc.SetHeader (which requires some
// stream to be in context).
//
// See also NewContextWithServerTransportStream.
//
// This API is EXPERIMENTAL.
type ServerTransportStream interface {
Method() string
SetHeader(md metadata.MD) error
SendHeader(md metadata.MD) error
SetTrailer(md metadata.MD) error
}
// ServerTransportStreamFromContext returns the ServerTransportStream saved in
// ctx. Returns nil if the given context has no stream associated with it
// (which implies it is not an RPC invocation context).
//
// This API is EXPERIMENTAL.
func ServerTransportStreamFromContext(ctx context.Context) ServerTransportStream {
s, _ := ctx.Value(streamKey{}).(ServerTransportStream)
return s
}
 // Stop stops the gRPC server. It immediately closes all open
 // connections and listeners.
 // It cancels all active RPCs on the server side and the corresponding
@@ -1185,6 +1327,12 @@ func (s *Server) Stop() {
 		})
 	}()
+	s.channelzRemoveOnce.Do(func() {
+		if channelz.IsOn() {
+			channelz.RemoveeEntry(s.channelzID)
+		}
+	})
 	s.mu.Lock()
 	listeners := s.lis
 	s.lis = nil
@@ -1223,11 +1371,17 @@ func (s *Server) GracefulStop() {
 		})
 	}()
+	s.channelzRemoveOnce.Do(func() {
+		if channelz.IsOn() {
+			channelz.RemoveEntry(s.channelzID)
+		}
+	})
 	s.mu.Lock()
 	if s.conns == nil {
 		s.mu.Unlock()
 		return
 	}
 	for lis := range s.lis {
 		lis.Close()
 	}
@@ -1262,6 +1416,22 @@ func init() {
 	}
 }
// contentSubtype must be lowercase
// cannot return nil
func (s *Server) getCodec(contentSubtype string) baseCodec {
if s.opts.codec != nil {
return s.opts.codec
}
if contentSubtype == "" {
return encoding.GetCodec(proto.Name)
}
codec := encoding.GetCodec(contentSubtype)
if codec == nil {
return encoding.GetCodec(proto.Name)
}
return codec
}
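getCodec above resolves a codec in a fixed order: an explicitly configured codec always wins, then the codec registered for the request's content-subtype, then the proto default. The stdlib-only sketch below models that fallback chain with a plain string registry — the registry contents and function names are illustrative, not the real `encoding` package API.

```go
package main

import "fmt"

// registry stands in for the encoding package's codec registry.
var registry = map[string]string{
	"proto": "proto-codec",
	"json":  "json-codec",
}

// getCodec mirrors the fallback order above: explicit override first,
// then the registered codec for the content-subtype, then proto.
func getCodec(override, contentSubtype string) string {
	if override != "" {
		return override
	}
	if contentSubtype == "" {
		return registry["proto"]
	}
	if c, ok := registry[contentSubtype]; ok {
		return c
	}
	// Unknown subtype: fall back rather than fail, since the real
	// function is documented as never returning nil.
	return registry["proto"]
}

func main() {
	fmt.Println(getCodec("custom", "json")) // override wins
	fmt.Println(getCodec("", "json"))       // registered subtype
	fmt.Println(getCodec("", "thrift"))     // unknown subtype falls back
}
```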
 // SetHeader sets the header metadata.
 // When called multiple times, all the provided metadata will be merged.
 // All the metadata will be sent out when one of the following happens:
@@ -1272,8 +1442,8 @@ func SetHeader(ctx context.Context, md metadata.MD) error {
 	if md.Len() == 0 {
 		return nil
 	}
-	stream, ok := transport.StreamFromContext(ctx)
-	if !ok {
+	stream := ServerTransportStreamFromContext(ctx)
+	if stream == nil {
 		return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx)
 	}
 	return stream.SetHeader(md)
@@ -1282,15 +1452,11 @@ func SetHeader(ctx context.Context, md metadata.MD) error {
 // SendHeader sends header metadata. It may be called at most once.
 // The provided md and headers set by SetHeader() will be sent.
 func SendHeader(ctx context.Context, md metadata.MD) error {
-	stream, ok := transport.StreamFromContext(ctx)
-	if !ok {
+	stream := ServerTransportStreamFromContext(ctx)
+	if stream == nil {
 		return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx)
 	}
-	t := stream.ServerTransport()
-	if t == nil {
-		grpclog.Fatalf("grpc: SendHeader: %v has no ServerTransport to send header metadata.", stream)
-	}
-	if err := t.WriteHeader(stream, md); err != nil {
+	if err := stream.SendHeader(md); err != nil {
 		return toRPCErr(err)
 	}
 	return nil
@@ -1302,9 +1468,19 @@ func SetTrailer(ctx context.Context, md metadata.MD) error {
 	if md.Len() == 0 {
 		return nil
 	}
-	stream, ok := transport.StreamFromContext(ctx)
-	if !ok {
+	stream := ServerTransportStreamFromContext(ctx)
+	if stream == nil {
 		return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx)
 	}
 	return stream.SetTrailer(md)
 }
// Method returns the method string for the server context. The returned
// string is in the format of "/service/method".
func Method(ctx context.Context) (string, bool) {
s := ServerTransportStreamFromContext(ctx)
if s == nil {
return "", false
}
return s.Method(), true
}
@@ -32,7 +32,8 @@ const maxInt = int(^uint(0) >> 1)
 // MethodConfig defines the configuration recommended by the service providers for a
 // particular method.
-// DEPRECATED: Users should not use this struct. Service config should be received
+//
+// Deprecated: Users should not use this struct. Service config should be received
 // through name resolver, as specified here
 // https://github.com/grpc/grpc/blob/master/doc/service_config.md
 type MethodConfig struct {
@@ -59,7 +60,8 @@ type MethodConfig struct {
 // ServiceConfig is provided by the service provider and contains parameters for how
 // clients that connect to the service should behave.
-// DEPRECATED: Users should not use this struct. Service config should be received
+//
+// Deprecated: Users should not use this struct. Service config should be received
 // through name resolver, as specified here
 // https://github.com/grpc/grpc/blob/master/doc/service_config.md
 type ServiceConfig struct {
@@ -71,6 +73,8 @@ type ServiceConfig struct {
 	// If there's no exact match, look for the default config for the service (/service/) and use the corresponding MethodConfig if it exists.
 	// Otherwise, the method has no MethodConfig to use.
 	Methods map[string]MethodConfig
+
+	stickinessMetadataKey *string
 }

 func parseDuration(s *string) (*time.Duration, error) {
@@ -144,8 +148,9 @@ type jsonMC struct {
 // TODO(lyuxuan): delete this struct after cleaning up old service config implementation.
 type jsonSC struct {
 	LoadBalancingPolicy *string
-	MethodConfig        *[]jsonMC
+	StickinessMetadataKey *string
+	MethodConfig          *[]jsonMC
 }

 func parseServiceConfig(js string) (ServiceConfig, error) {
@@ -158,6 +163,8 @@ func parseServiceConfig(js string) (ServiceConfig, error) {
 	sc := ServiceConfig{
 		LB:      rsc.LoadBalancingPolicy,
 		Methods: make(map[string]MethodConfig),
+
+		stickinessMetadataKey: rsc.StickinessMetadataKey,
 	}
 	if rsc.MethodConfig == nil {
 		return sc, nil
...
@@ -169,6 +169,8 @@ func (s *OutTrailer) isRPCStats() {}
 type End struct {
 	// Client is true if this End is from client side.
 	Client bool
+	// BeginTime is the time when the RPC began.
+	BeginTime time.Time
 	// EndTime is the time when the RPC ends.
 	EndTime time.Time
 	// Error is the error the RPC ended with. It is an error generated from
...
// +build go1.6,!go1.7
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package status
import (
"golang.org/x/net/context"
"google.golang.org/grpc/codes"
)
// FromContextError converts a context error into a Status. It returns a
// Status with codes.OK if err is nil, or a Status with codes.Unknown if err is
// non-nil and not a context error.
func FromContextError(err error) *Status {
switch err {
case nil:
return New(codes.OK, "")
case context.DeadlineExceeded:
return New(codes.DeadlineExceeded, err.Error())
case context.Canceled:
return New(codes.Canceled, err.Error())
default:
return New(codes.Unknown, err.Error())
}
}
// +build go1.7
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package status
import (
"context"
netctx "golang.org/x/net/context"
"google.golang.org/grpc/codes"
)
// FromContextError converts a context error into a Status. It returns a
// Status with codes.OK if err is nil, or a Status with codes.Unknown if err is
// non-nil and not a context error.
func FromContextError(err error) *Status {
switch err {
case nil:
return New(codes.OK, "")
case context.DeadlineExceeded, netctx.DeadlineExceeded:
return New(codes.DeadlineExceeded, err.Error())
case context.Canceled, netctx.Canceled:
return New(codes.Canceled, err.Error())
default:
return New(codes.Unknown, err.Error())
}
}
@@ -46,7 +46,7 @@ func (se *statusError) Error() string {
 	return fmt.Sprintf("rpc error: code = %s desc = %s", codes.Code(p.GetCode()), p.GetMessage())
 }

-func (se *statusError) status() *Status {
+func (se *statusError) GRPCStatus() *Status {
 	return &Status{s: (*spb.Status)(se)}
 }
@@ -120,15 +120,23 @@ func FromProto(s *spb.Status) *Status {
 }

 // FromError returns a Status representing err if it was produced from this
-// package, otherwise it returns nil, false.
+// package or has a method `GRPCStatus() *Status`. Otherwise, ok is false and a
+// Status is returned with codes.Unknown and the original error message.
 func FromError(err error) (s *Status, ok bool) {
 	if err == nil {
 		return &Status{s: &spb.Status{Code: int32(codes.OK)}}, true
 	}
-	if se, ok := err.(*statusError); ok {
-		return se.status(), true
+	if se, ok := err.(interface{ GRPCStatus() *Status }); ok {
+		return se.GRPCStatus(), true
 	}
-	return nil, false
+	return New(codes.Unknown, err.Error()), false
+}
+
+// Convert is a convenience function which removes the need to handle the
+// boolean return value from FromError.
+func Convert(err error) *Status {
+	s, _ := FromError(err)
+	return s
 }

 // WithDetails returns a new status with the provided details messages appended to the status.
@@ -174,8 +182,8 @@ func Code(err error) codes.Code {
 	if err == nil {
 		return codes.OK
 	}
-	if se, ok := err.(*statusError); ok {
-		return se.status().Code()
+	if se, ok := err.(interface{ GRPCStatus() *Status }); ok {
+		return se.GRPCStatus().Code()
 	}
 	return codes.Unknown
 }
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package grpc
import (
"container/list"
)
type linkedMapKVPair struct {
key string
value *stickyStoreEntry
}
// linkedMap is an implementation of a map that supports removing the oldest
// entry.
//
// linkedMap is NOT thread safe.
//
// It's for use of stickiness only!
type linkedMap struct {
m map[string]*list.Element
l *list.List // Head of the list is the oldest element.
}
// newLinkedMap returns a new LinkedMap.
func newLinkedMap() *linkedMap {
return &linkedMap{
m: make(map[string]*list.Element),
l: list.New(),
}
}
// put adds entry (key, value) to the map. Existing key will be overridden.
func (m *linkedMap) put(key string, value *stickyStoreEntry) {
if oldE, ok := m.m[key]; ok {
// Remove existing entry.
m.l.Remove(oldE)
}
e := m.l.PushBack(&linkedMapKVPair{key: key, value: value})
m.m[key] = e
}
// get returns the value of the given key.
func (m *linkedMap) get(key string) (*stickyStoreEntry, bool) {
e, ok := m.m[key]
if !ok {
return nil, false
}
m.l.MoveToBack(e)
return e.Value.(*linkedMapKVPair).value, true
}
// remove removes key from the map, and returns the value. The map is not
// modified if key is not in the map.
func (m *linkedMap) remove(key string) (*stickyStoreEntry, bool) {
e, ok := m.m[key]
if !ok {
return nil, false
}
delete(m.m, key)
m.l.Remove(e)
return e.Value.(*linkedMapKVPair).value, true
}
// len returns the len of the map.
func (m *linkedMap) len() int {
return len(m.m)
}
// clear removes all elements from the map.
func (m *linkedMap) clear() {
m.m = make(map[string]*list.Element)
m.l = list.New()
}
// removeOldest removes the oldest key from the map.
func (m *linkedMap) removeOldest() {
e := m.l.Front()
m.l.Remove(e)
delete(m.m, e.Value.(*linkedMapKVPair).key)
}
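The linkedMap above combines a hash map with a doubly linked list so lookups are O(1) while the list front always holds the least-recently-used entry. The runnable sketch below reproduces that pattern with string values instead of `*stickyStoreEntry`, so it stands alone on the standard library; `orderedMap` and its method names are illustrative.

```go
package main

import (
	"container/list"
	"fmt"
)

type kv struct {
	key, value string
}

// orderedMap mirrors linkedMap: map for O(1) lookup, list for recency order.
type orderedMap struct {
	m map[string]*list.Element
	l *list.List // front of the list is the oldest entry
}

func newOrderedMap() *orderedMap {
	return &orderedMap{m: make(map[string]*list.Element), l: list.New()}
}

// put inserts or overwrites, always placing the entry at the back (newest).
func (o *orderedMap) put(key, value string) {
	if old, ok := o.m[key]; ok {
		o.l.Remove(old)
	}
	o.m[key] = o.l.PushBack(&kv{key: key, value: value})
}

// get refreshes the entry's position, so recently used keys survive
// removeOldest longer — the LRU behavior stickiness relies on.
func (o *orderedMap) get(key string) (string, bool) {
	e, ok := o.m[key]
	if !ok {
		return "", false
	}
	o.l.MoveToBack(e)
	return e.Value.(*kv).value, true
}

// removeOldest evicts the front of the list, i.e. the least recently used key.
func (o *orderedMap) removeOldest() {
	if e := o.l.Front(); e != nil {
		o.l.Remove(e)
		delete(o.m, e.Value.(*kv).key)
	}
}

func main() {
	m := newOrderedMap()
	m.put("a", "1")
	m.put("b", "2")
	m.get("a") // touch "a" so "b" becomes the oldest
	m.removeOldest()
	_, okA := m.get("a")
	_, okB := m.get("b")
	fmt.Println(okA, okB)
}
```

Note the original is explicitly not thread-safe; callers are expected to hold their own lock, which keeps the data structure itself simple.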
@@ -29,15 +29,18 @@ import (
 	"google.golang.org/grpc/balancer"
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/encoding"
+	"google.golang.org/grpc/internal/channelz"
 	"google.golang.org/grpc/metadata"
+	"google.golang.org/grpc/peer"
 	"google.golang.org/grpc/stats"
 	"google.golang.org/grpc/status"
 	"google.golang.org/grpc/transport"
 )

 // StreamHandler defines the handler called by gRPC server to complete the
-// execution of a streaming RPC.
+// execution of a streaming RPC. If a StreamHandler returns an error, it
+// should be produced by the status package, or else gRPC will use
+// codes.Unknown as the status code and err.Error() as the status message
+// of the RPC.
 type StreamHandler func(srv interface{}, stream ServerStream) error

 // StreamDesc represents a streaming RPC service's method specification.
@@ -51,6 +54,8 @@ type StreamDesc struct {
 }

 // Stream defines the common interface a client or server stream has to satisfy.
+//
+// All errors returned from Stream are compatible with the status package.
 type Stream interface {
 	// Context returns the context for this stream.
 	Context() context.Context
@@ -89,22 +94,40 @@ type ClientStream interface {
 	// Stream.SendMsg() may return a non-nil error when something wrong happens sending
 	// the request. The returned error indicates the status of this sending, not the final
 	// status of the RPC.
-	// Always call Stream.RecvMsg() to get the final status if you care about the status of
-	// the RPC.
+	//
+	// Always call Stream.RecvMsg() to drain the stream and get the final
+	// status, otherwise there could be leaked resources.
 	Stream
 }

 // NewStream creates a new Stream for the client side. This is typically
-// called by generated code.
+// called by generated code. ctx is used for the lifetime of the stream.
+//
+// To ensure resources are not leaked due to the stream returned, one of the following
+// actions must be performed:
+//
+//      1. Call Close on the ClientConn.
+//      2. Cancel the context provided.
+//      3. Call RecvMsg until a non-nil error is returned. A protobuf-generated
+//         client-streaming RPC, for instance, might use the helper function
+//         CloseAndRecv (note that CloseSend does not Recv, therefore is not
+//         guaranteed to release all resources).
+//      4. Receive a non-nil, non-io.EOF error from Header or SendMsg.
+//
+// If none of the above happen, a goroutine and a context will be leaked, and grpc
+// will not call the optionally-configured stats handler with a stats.End message.
 func (cc *ClientConn) NewStream(ctx context.Context, desc *StreamDesc, method string, opts ...CallOption) (ClientStream, error) {
+	// allow interceptor to see all applicable call options, which means those
+	// configured as defaults from dial option as well as per-call options
+	opts = combine(cc.dopts.callOptions, opts)
+
 	if cc.dopts.streamInt != nil {
 		return cc.dopts.streamInt(ctx, desc, cc, method, newClientStream, opts...)
 	}
 	return newClientStream(ctx, desc, cc, method, opts...)
 }

-// NewClientStream creates a new Stream for the client side. This is typically
-// called by generated code.
+// NewClientStream is a wrapper for ClientConn.NewStream.
 //
 // DEPRECATED: Use ClientConn.NewStream instead.
 func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) {
@@ -112,28 +135,37 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
 }
func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) { func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) {
var ( if channelz.IsOn() {
t transport.ClientTransport cc.incrCallsStarted()
s *transport.Stream defer func() {
done func(balancer.DoneInfo) if err != nil {
cancel context.CancelFunc cc.incrCallsFailed()
) }
}()
}
c := defaultCallInfo() c := defaultCallInfo()
mc := cc.GetMethodConfig(method) mc := cc.GetMethodConfig(method)
if mc.WaitForReady != nil { if mc.WaitForReady != nil {
c.failFast = !*mc.WaitForReady c.failFast = !*mc.WaitForReady
} }
	// Possible context leak:
	// The cancel function for the child context we create will only be called
	// when RecvMsg returns a non-nil error, if the ClientConn is closed, or if
	// an error is generated by SendMsg.
	// https://github.com/grpc/grpc-go/issues/1818.
	var cancel context.CancelFunc
	if mc.Timeout != nil && *mc.Timeout >= 0 {
		ctx, cancel = context.WithTimeout(ctx, *mc.Timeout)
	} else {
		ctx, cancel = context.WithCancel(ctx)
	}
	defer func() {
		if err != nil {
			cancel()
		}
	}()

	opts = append(cc.dopts.callOptions, opts...)
	for _, o := range opts {
		if err := o.before(c); err != nil {
			return nil, toRPCErr(err)
@@ -141,6 +173,9 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
	}
	c.maxSendMessageSize = getMaxSize(mc.MaxReqSize, c.maxSendMessageSize, defaultClientMaxSendMessageSize)
	c.maxReceiveMessageSize = getMaxSize(mc.MaxRespSize, c.maxReceiveMessageSize, defaultClientMaxReceiveMessageSize)
	if err := setCallInfoCodec(c); err != nil {
		return nil, err
	}
	callHdr := &transport.CallHdr{
		Host: cc.authority,
@@ -149,7 +184,8 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
		// so we don't flush the header.
		// If it's client streaming, the user may never send a request or send it any
		// time soon, so we ask the transport to flush the header.
		Flush:          desc.ClientStreams,
		ContentSubtype: c.contentSubtype,
	}
	// Set our outgoing compression according to the UseCompressor CallOption, if
@@ -194,11 +230,13 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
	}
	ctx = newContextWithRPCInfo(ctx, c.failFast)
	sh := cc.dopts.copts.StatsHandler
	var beginTime time.Time
	if sh != nil {
		ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method, FailFast: c.failFast})
		beginTime = time.Now()
		begin := &stats.Begin{
			Client:    true,
			BeginTime: beginTime,
			FailFast:  c.failFast,
		}
		sh.HandleRPC(ctx, begin)
@@ -206,14 +244,21 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
			if err != nil {
				// Only handle end stats if err != nil.
				end := &stats.End{
					Client:    true,
					Error:     err,
					BeginTime: beginTime,
					EndTime:   time.Now(),
				}
				sh.HandleRPC(ctx, end)
			}
		}()
	}
var (
t transport.ClientTransport
s *transport.Stream
done func(balancer.DoneInfo)
)
	for {
		// Check to make sure the context has expired. This will prevent us from
		// looping forever if an error occurs for wait-for-ready RPCs where no data
@@ -232,14 +277,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
		s, err = t.NewStream(ctx, callHdr)
		if err != nil {
			if done != nil {
				done(balancer.DoneInfo{Err: err})
				done = nil
			}
			// In the event of any error from NewStream, we never attempted to write
@@ -253,54 +291,44 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
		break
	}
	cs := &clientStream{
		opts:   opts,
		c:      c,
		cc:     cc,
		desc:   desc,
		codec:  c.codec,
		cp:     cp,
		comp:   comp,
		cancel: cancel,
		attempt: &csAttempt{
			t:            t,
			s:            s,
			p:            &parser{r: s},
			done:         done,
			dc:           cc.dopts.dc,
			ctx:          ctx,
			trInfo:       trInfo,
			statsHandler: sh,
			beginTime:    beginTime,
		},
	}
cs.c.stream = cs
cs.attempt.cs = cs
if desc != unaryStreamDesc {
// Listen on cc and stream contexts to cleanup when the user closes the
// ClientConn or cancels the stream context. In all other cases, an error
// should already be injected into the recv buffer by the transport, which
// the client will eventually receive, and then we will cancel the stream's
// context in clientStream.finish.
go func() {
select {
case <-cc.ctx.Done():
cs.finish(ErrClientConnClosing)
case <-ctx.Done():
cs.finish(toRPCErr(ctx.Err()))
}
}()
	}
	return cs, nil
}
@@ -308,265 +336,292 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
type clientStream struct {
	opts []CallOption
	c    *callInfo
	cc   *ClientConn
	desc *StreamDesc

	codec baseCodec
	cp    Compressor
	comp  encoding.Compressor

	cancel context.CancelFunc // cancels all attempts

	sentLast bool // sent an end stream

	mu       sync.Mutex // guards finished
	finished bool       // TODO: replace with atomic cmpxchg or sync.Once?

	attempt *csAttempt // the active client stream attempt
	// TODO(hedging): hedging will have multiple attempts simultaneously.
}

// csAttempt implements a single transport stream attempt within a
// clientStream.
type csAttempt struct {
	cs   *clientStream
	t    transport.ClientTransport
	s    *transport.Stream
	p    *parser
	done func(balancer.DoneInfo)

	dc        Decompressor
	decomp    encoding.Compressor
	decompSet bool

	ctx context.Context // the application's context, wrapped by stats/tracing

	mu sync.Mutex // guards trInfo.tr
	// trInfo.tr is set when created (if EnableTracing is true),
	// and cleared when the finish method is called.
	trInfo traceInfo

	statsHandler stats.Handler
	beginTime    time.Time
}
func (cs *clientStream) Context() context.Context {
	// TODO(retry): commit the current attempt (the context has peer-aware data).
	return cs.attempt.context()
}

func (cs *clientStream) Header() (metadata.MD, error) {
	m, err := cs.attempt.header()
	if err != nil {
		// TODO(retry): maybe retry on error or commit attempt on success.
		err = toRPCErr(err)
		cs.finish(err)
	}
	return m, err
}

func (cs *clientStream) Trailer() metadata.MD {
	// TODO(retry): on error, maybe retry (trailers-only).
	return cs.attempt.trailer()
}
func (cs *clientStream) SendMsg(m interface{}) (err error) {
	// TODO(retry): buffer message for replaying if not committed.
	return cs.attempt.sendMsg(m)
}

func (cs *clientStream) RecvMsg(m interface{}) (err error) {
	// TODO(retry): maybe retry on error or commit attempt on success.
	return cs.attempt.recvMsg(m)
}

func (cs *clientStream) CloseSend() error {
	cs.attempt.closeSend()
	return nil
}

func (cs *clientStream) finish(err error) {
	if err == io.EOF {
		// Ending a stream with EOF indicates a success.
		err = nil
	}
	cs.mu.Lock()
	if cs.finished {
		cs.mu.Unlock()
		return
	}
	cs.finished = true
	cs.mu.Unlock()
	if channelz.IsOn() {
		if err != nil {
			cs.cc.incrCallsFailed()
		} else {
			cs.cc.incrCallsSucceeded()
		}
	}
// TODO(retry): commit current attempt if necessary.
cs.attempt.finish(err)
for _, o := range cs.opts {
o.after(cs.c)
}
cs.cancel()
}
func (a *csAttempt) context() context.Context {
return a.s.Context()
}
func (a *csAttempt) header() (metadata.MD, error) {
return a.s.Header()
}
func (a *csAttempt) trailer() metadata.MD {
return a.s.Trailer()
}
func (a *csAttempt) sendMsg(m interface{}) (err error) {
	// TODO Investigate how to signal the stats handling party.
	// generate error stats if err != nil && err != io.EOF?
	cs := a.cs
	defer func() {
		// For non-client-streaming RPCs, we return nil instead of EOF on success
		// because the generated code requires it.  finish is not called; RecvMsg()
		// will call it with the stream's status independently.
		if err == io.EOF && !cs.desc.ClientStreams {
			err = nil
		}
		if err != nil && err != io.EOF {
			// Call finish on the client stream for errors generated by this SendMsg
			// call, as these indicate problems created by this client.  (Transport
			// errors are converted to an io.EOF error below; the real error will be
			// returned from RecvMsg eventually in that case, or be retried.)
			cs.finish(err)
		}
	}()
	// TODO: Check cs.sentLast and error if we already ended the stream.
	if EnableTracing {
		a.mu.Lock()
		if a.trInfo.tr != nil {
			a.trInfo.tr.LazyLog(&payload{sent: true, msg: m}, true)
		}
		a.mu.Unlock()
	}
	data, err := encode(cs.codec, m)
	if err != nil {
		return err
	}
	compData, err := compress(data, cs.cp, cs.comp)
	if err != nil {
		return err
	}
	hdr, payload := msgHeader(data, compData)
	// TODO(dfawley): should we be checking len(data) instead?
	if len(payload) > *cs.c.maxSendMessageSize {
		return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(payload), *cs.c.maxSendMessageSize)
	}

	if !cs.desc.ClientStreams {
		cs.sentLast = true
	}
	err = a.t.Write(a.s, hdr, payload, &transport.Options{Last: !cs.desc.ClientStreams})
	if err == nil {
		if a.statsHandler != nil {
			a.statsHandler.HandleRPC(a.ctx, outPayload(true, m, data, payload, time.Now()))
		}
		if channelz.IsOn() {
			a.t.IncrMsgSent()
		}
		return nil
	}
	return io.EOF
}
func (a *csAttempt) recvMsg(m interface{}) (err error) {
	cs := a.cs
	defer func() {
		if err != nil || !cs.desc.ServerStreams {
			// err != nil or non-server-streaming indicates end of stream.
			cs.finish(err)
		}
	}()
	var inPayload *stats.InPayload
	if a.statsHandler != nil {
		inPayload = &stats.InPayload{
			Client: true,
		}
	}
	if !a.decompSet {
		// Block until we receive headers containing received message encoding.
		if ct := a.s.RecvCompress(); ct != "" && ct != encoding.Identity {
			if a.dc == nil || a.dc.Type() != ct {
				// No configured decompressor, or it does not match the incoming
				// message encoding; attempt to find a registered compressor that does.
				a.dc = nil
				a.decomp = encoding.GetCompressor(ct)
			}
		} else {
			// No compression is used; disable our decompressor.
			a.dc = nil
		}
		// Only initialize this state once per stream.
		a.decompSet = true
	}
	err = recv(a.p, cs.codec, a.s, a.dc, m, *cs.c.maxReceiveMessageSize, inPayload, a.decomp)
	if err != nil {
		if err == io.EOF {
			if statusErr := a.s.Status().Err(); statusErr != nil {
				return statusErr
			}
			return io.EOF // indicates successful end of stream.
		}
		return toRPCErr(err)
	}
	if EnableTracing {
		a.mu.Lock()
		if a.trInfo.tr != nil {
			a.trInfo.tr.LazyLog(&payload{sent: false, msg: m}, true)
		}
		a.mu.Unlock()
	}
	if inPayload != nil {
		a.statsHandler.HandleRPC(a.ctx, inPayload)
	}
	if channelz.IsOn() {
		a.t.IncrMsgRecv()
	}
	if cs.desc.ServerStreams {
		// Subsequent messages should be received by subsequent RecvMsg calls.
		return nil
	}

	// Special handling for non-server-stream rpcs.
	// This recv expects EOF or errors, so we don't collect inPayload.
	err = recv(a.p, cs.codec, a.s, a.dc, m, *cs.c.maxReceiveMessageSize, nil, a.decomp)
	if err == nil {
		return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
	}
	if err == io.EOF {
		return a.s.Status().Err() // non-server streaming Recv returns nil on success
	}
	return toRPCErr(err)
}
func (a *csAttempt) closeSend() {
	cs := a.cs
	if cs.sentLast {
		return
	}
	cs.sentLast = true
	cs.attempt.t.Write(cs.attempt.s, nil, nil, &transport.Options{Last: true})
	// We ignore errors from Write.  Any error it would return would also be
	// returned by a subsequent RecvMsg call, and the user is supposed to always
	// finish the stream by calling RecvMsg until it returns err != nil.
}
func (a *csAttempt) finish(err error) {
	a.mu.Lock()
	a.t.CloseStream(a.s, err)

	if a.done != nil {
		a.done(balancer.DoneInfo{
			Err:           err,
			BytesSent:     true,
			BytesReceived: a.s.BytesReceived(),
		})
	}
	if a.statsHandler != nil {
		end := &stats.End{
			Client:    true,
			BeginTime: a.beginTime,
			EndTime:   time.Now(),
			Error:     err,
		}
		a.statsHandler.HandleRPC(a.ctx, end)
	}
	if a.trInfo.tr != nil {
		if err == nil {
			a.trInfo.tr.LazyPrintf("RPC: [OK]")
		} else {
			a.trInfo.tr.LazyPrintf("RPC: [%v]", err)
			a.trInfo.tr.SetError()
		}
		a.trInfo.tr.Finish()
		a.trInfo.tr = nil
	}
	a.mu.Unlock()
}
// ServerStream defines the interface a server stream has to satisfy.
@@ -590,10 +645,11 @@ type ServerStream interface {
// serverStream implements a server side Stream.
type serverStream struct {
	ctx   context.Context
	t     transport.ServerTransport
	s     *transport.Stream
	p     *parser
	codec baseCodec

	cp Compressor
	dc Decompressor
@@ -610,7 +666,7 @@ type serverStream struct {
}

func (ss *serverStream) Context() context.Context {
	return ss.ctx
}
func (ss *serverStream) SetHeader(md metadata.MD) error {
@@ -629,7 +685,6 @@ func (ss *serverStream) SetTrailer(md metadata.MD) {
		return
	}
	ss.s.SetTrailer(md)
}
func (ss *serverStream) SendMsg(m interface{}) (err error) {
@@ -650,24 +705,28 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) {
			st, _ := status.FromError(toRPCErr(err))
			ss.t.WriteStatus(ss.s, st)
		}
		if channelz.IsOn() && err == nil {
			ss.t.IncrMsgSent()
		}
	}()
	data, err := encode(ss.codec, m)
	if err != nil {
		return err
	}
	compData, err := compress(data, ss.cp, ss.comp)
	if err != nil {
		return err
	}
	hdr, payload := msgHeader(data, compData)
	// TODO(dfawley): should we be checking len(data) instead?
	if len(payload) > ss.maxSendMessageSize {
		return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(payload), ss.maxSendMessageSize)
	}
	if err := ss.t.Write(ss.s, hdr, payload, &transport.Options{Last: false}); err != nil {
		return toRPCErr(err)
	}
	if ss.statsHandler != nil {
		ss.statsHandler.HandleRPC(ss.s.Context(), outPayload(false, m, data, payload, time.Now()))
	}
	return nil
}
@@ -690,6 +749,9 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) {
			st, _ := status.FromError(toRPCErr(err))
			ss.t.WriteStatus(ss.s, st)
		}
		if channelz.IsOn() && err == nil {
			ss.t.IncrMsgRecv()
		}
	}()
	var inPayload *stats.InPayload
	if ss.statsHandler != nil {
@@ -713,9 +775,5 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) {
// MethodFromServerStream returns the method string for the input stream.
// The returned string is in the format of "/service/method".
func MethodFromServerStream(stream ServerStream) (string, bool) {
	return Method(stream.Context())
}
/*
*
* Copyright 2014 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package transport
import (
"bytes"
"fmt"
"runtime"
"sync"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
)
var updateHeaderTblSize = func(e *hpack.Encoder, v uint32) {
e.SetMaxDynamicTableSizeLimit(v)
}
type itemNode struct {
it interface{}
next *itemNode
}
type itemList struct {
head *itemNode
tail *itemNode
}
func (il *itemList) enqueue(i interface{}) {
n := &itemNode{it: i}
if il.tail == nil {
il.head, il.tail = n, n
return
}
il.tail.next = n
il.tail = n
}
// peek returns the first item in the list without removing it from the
// list.
func (il *itemList) peek() interface{} {
return il.head.it
}
func (il *itemList) dequeue() interface{} {
if il.head == nil {
return nil
}
i := il.head.it
il.head = il.head.next
if il.head == nil {
il.tail = nil
}
return i
}
func (il *itemList) dequeueAll() *itemNode {
h := il.head
il.head, il.tail = nil, nil
return h
}
func (il *itemList) isEmpty() bool {
return il.head == nil
}
// The following defines various control items which could flow through
// the control buffer of transport. They represent different aspects of
// control tasks, e.g., flow control, settings, streaming resetting, etc.
// registerStream is used to register an incoming stream with loopy writer.
type registerStream struct {
streamID uint32
wq *writeQuota
}
// headerFrame is also used to register stream on the client-side.
type headerFrame struct {
streamID uint32
hf []hpack.HeaderField
endStream bool // Valid on server side.
initStream func(uint32) (bool, error) // Used only on the client side.
onWrite func()
wq *writeQuota // write quota for the stream created.
cleanup *cleanupStream // Valid on the server side.
onOrphaned func(error) // Valid on client-side
}
type cleanupStream struct {
streamID uint32
idPtr *uint32
rst bool
rstCode http2.ErrCode
onWrite func()
}
type dataFrame struct {
streamID uint32
endStream bool
h []byte
d []byte
// onEachWrite is called every time
// a part of d is written out.
onEachWrite func()
}
type incomingWindowUpdate struct {
streamID uint32
increment uint32
}
type outgoingWindowUpdate struct {
streamID uint32
increment uint32
}
type incomingSettings struct {
ss []http2.Setting
}
type outgoingSettings struct {
ss []http2.Setting
}
type settingsAck struct {
}
type incomingGoAway struct {
}
type goAway struct {
code http2.ErrCode
debugData []byte
headsUp bool
closeConn bool
}
type ping struct {
ack bool
data [8]byte
}
type outFlowControlSizeRequest struct {
resp chan uint32
}
type outStreamState int
const (
active outStreamState = iota
empty
waitingOnStreamQuota
)
type outStream struct {
id uint32
state outStreamState
itl *itemList
bytesOutStanding int
wq *writeQuota
next *outStream
prev *outStream
}
func (s *outStream) deleteSelf() {
if s.prev != nil {
s.prev.next = s.next
}
if s.next != nil {
s.next.prev = s.prev
}
s.next, s.prev = nil, nil
}
type outStreamList struct {
// Following are sentinel objects that mark the
// beginning and end of the list. They do not
// contain any item lists. All valid objects are
// inserted in between them.
// This is needed so that an outStream object can
// deleteSelf() in O(1) time without knowing which
// list it belongs to.
head *outStream
tail *outStream
}
func newOutStreamList() *outStreamList {
head, tail := new(outStream), new(outStream)
head.next = tail
tail.prev = head
return &outStreamList{
head: head,
tail: tail,
}
}
func (l *outStreamList) enqueue(s *outStream) {
e := l.tail.prev
e.next = s
s.prev = e
s.next = l.tail
l.tail.prev = s
}
// remove from the beginning of the list.
func (l *outStreamList) dequeue() *outStream {
b := l.head.next
if b == l.tail {
return nil
}
b.deleteSelf()
return b
}
type controlBuffer struct {
ch chan struct{}
done <-chan struct{}
mu sync.Mutex
consumerWaiting bool
list *itemList
err error
}
func newControlBuffer(done <-chan struct{}) *controlBuffer {
return &controlBuffer{
ch: make(chan struct{}, 1),
list: &itemList{},
done: done,
}
}
func (c *controlBuffer) put(it interface{}) error {
_, err := c.executeAndPut(nil, it)
return err
}
func (c *controlBuffer) executeAndPut(f func(it interface{}) bool, it interface{}) (bool, error) {
var wakeUp bool
c.mu.Lock()
if c.err != nil {
c.mu.Unlock()
return false, c.err
}
if f != nil {
if !f(it) { // f wasn't successful
c.mu.Unlock()
return false, nil
}
}
if c.consumerWaiting {
wakeUp = true
c.consumerWaiting = false
}
c.list.enqueue(it)
c.mu.Unlock()
if wakeUp {
select {
case c.ch <- struct{}{}:
default:
}
}
return true, nil
}
func (c *controlBuffer) get(block bool) (interface{}, error) {
for {
c.mu.Lock()
if c.err != nil {
c.mu.Unlock()
return nil, c.err
}
if !c.list.isEmpty() {
h := c.list.dequeue()
c.mu.Unlock()
return h, nil
}
if !block {
c.mu.Unlock()
return nil, nil
}
c.consumerWaiting = true
c.mu.Unlock()
select {
case <-c.ch:
case <-c.done:
c.finish()
return nil, ErrConnClosing
}
}
}
func (c *controlBuffer) finish() {
c.mu.Lock()
if c.err != nil {
c.mu.Unlock()
return
}
c.err = ErrConnClosing
// There may be headers for streams in the control buffer.
// These streams need to be cleaned out since the transport
// is still not aware of these yet.
for head := c.list.dequeueAll(); head != nil; head = head.next {
hdr, ok := head.it.(*headerFrame)
if !ok {
continue
}
if hdr.onOrphaned != nil { // It will be nil on the server-side.
hdr.onOrphaned(ErrConnClosing)
}
}
c.mu.Unlock()
}
type side int
const (
clientSide side = iota
serverSide
)
type loopyWriter struct {
side side
cbuf *controlBuffer
sendQuota uint32
oiws uint32 // outbound initial window size.
estdStreams map[uint32]*outStream // Established streams.
activeStreams *outStreamList // Streams that are sending data.
framer *framer
hBuf *bytes.Buffer // The buffer for HPACK encoding.
hEnc *hpack.Encoder // HPACK encoder.
bdpEst *bdpEstimator
draining bool
// Side-specific handlers
ssGoAwayHandler func(*goAway) (bool, error)
}
func newLoopyWriter(s side, fr *framer, cbuf *controlBuffer, bdpEst *bdpEstimator) *loopyWriter {
var buf bytes.Buffer
l := &loopyWriter{
side: s,
cbuf: cbuf,
sendQuota: defaultWindowSize,
oiws: defaultWindowSize,
estdStreams: make(map[uint32]*outStream),
activeStreams: newOutStreamList(),
framer: fr,
hBuf: &buf,
hEnc: hpack.NewEncoder(&buf),
bdpEst: bdpEst,
}
return l
}
const minBatchSize = 1000
// run should be run in a separate goroutine.
func (l *loopyWriter) run() (err error) {
defer func() {
if err == ErrConnClosing {
// Don't log ErrConnClosing as error since it happens
// 1. When the connection is closed by some other known issue.
// 2. User closed the connection.
// 3. A graceful close of connection.
infof("transport: loopyWriter.run returning. %v", err)
err = nil
}
}()
for {
it, err := l.cbuf.get(true)
if err != nil {
return err
}
if err = l.handle(it); err != nil {
return err
}
if _, err = l.processData(); err != nil {
return err
}
gosched := true
hasdata:
for {
it, err := l.cbuf.get(false)
if err != nil {
return err
}
if it != nil {
if err = l.handle(it); err != nil {
return err
}
if _, err = l.processData(); err != nil {
return err
}
continue hasdata
}
isEmpty, err := l.processData()
if err != nil {
return err
}
if !isEmpty {
continue hasdata
}
if gosched {
gosched = false
if l.framer.writer.offset < minBatchSize {
runtime.Gosched()
continue hasdata
}
}
l.framer.writer.Flush()
break hasdata
}
}
}
func (l *loopyWriter) outgoingWindowUpdateHandler(w *outgoingWindowUpdate) error {
return l.framer.fr.WriteWindowUpdate(w.streamID, w.increment)
}
func (l *loopyWriter) incomingWindowUpdateHandler(w *incomingWindowUpdate) error {
// A stream ID of zero carries a connection-level window update.
if w.streamID == 0 {
l.sendQuota += w.increment
return nil
}
// Find the stream and update it.
if str, ok := l.estdStreams[w.streamID]; ok {
str.bytesOutStanding -= int(w.increment)
if strQuota := int(l.oiws) - str.bytesOutStanding; strQuota > 0 && str.state == waitingOnStreamQuota {
str.state = active
l.activeStreams.enqueue(str)
return nil
}
}
return nil
}
func (l *loopyWriter) outgoingSettingsHandler(s *outgoingSettings) error {
return l.framer.fr.WriteSettings(s.ss...)
}
func (l *loopyWriter) incomingSettingsHandler(s *incomingSettings) error {
if err := l.applySettings(s.ss); err != nil {
return err
}
return l.framer.fr.WriteSettingsAck()
}
func (l *loopyWriter) registerStreamHandler(h *registerStream) error {
str := &outStream{
id: h.streamID,
state: empty,
itl: &itemList{},
wq: h.wq,
}
l.estdStreams[h.streamID] = str
return nil
}
func (l *loopyWriter) headerHandler(h *headerFrame) error {
if l.side == serverSide {
str, ok := l.estdStreams[h.streamID]
if !ok {
warningf("transport: loopy doesn't recognize the stream: %d", h.streamID)
return nil
}
// Case 1.A: Server is responding back with headers.
if !h.endStream {
return l.writeHeader(h.streamID, h.endStream, h.hf, h.onWrite)
}
// else: Case 1.B: Server wants to close stream.
if str.state != empty { // either active or waiting on stream quota.
// Add it to str's list of items.
str.itl.enqueue(h)
return nil
}
if err := l.writeHeader(h.streamID, h.endStream, h.hf, h.onWrite); err != nil {
return err
}
return l.cleanupStreamHandler(h.cleanup)
}
// Case 2: Client wants to originate stream.
str := &outStream{
id: h.streamID,
state: empty,
itl: &itemList{},
wq: h.wq,
}
str.itl.enqueue(h)
return l.originateStream(str)
}
func (l *loopyWriter) originateStream(str *outStream) error {
hdr := str.itl.dequeue().(*headerFrame)
sendPing, err := hdr.initStream(str.id)
if err != nil {
if err == ErrConnClosing {
return err
}
// Other errors (e.g. errStreamDrain) need not close the transport.
return nil
}
if err = l.writeHeader(str.id, hdr.endStream, hdr.hf, hdr.onWrite); err != nil {
return err
}
l.estdStreams[str.id] = str
if sendPing {
return l.pingHandler(&ping{data: [8]byte{}})
}
return nil
}
func (l *loopyWriter) writeHeader(streamID uint32, endStream bool, hf []hpack.HeaderField, onWrite func()) error {
if onWrite != nil {
onWrite()
}
l.hBuf.Reset()
for _, f := range hf {
if err := l.hEnc.WriteField(f); err != nil {
warningf("transport: loopyWriter.writeHeader encountered error while encoding headers: %v", err)
}
}
var (
err error
endHeaders, first bool
)
first = true
for !endHeaders {
size := l.hBuf.Len()
if size > http2MaxFrameLen {
size = http2MaxFrameLen
} else {
endHeaders = true
}
if first {
first = false
err = l.framer.fr.WriteHeaders(http2.HeadersFrameParam{
StreamID: streamID,
BlockFragment: l.hBuf.Next(size),
EndStream: endStream,
EndHeaders: endHeaders,
})
} else {
err = l.framer.fr.WriteContinuation(
streamID,
endHeaders,
l.hBuf.Next(size),
)
}
if err != nil {
return err
}
}
return nil
}
func (l *loopyWriter) preprocessData(df *dataFrame) error {
str, ok := l.estdStreams[df.streamID]
if !ok {
return nil
}
// If we got data for a stream, it means the stream was
// already originated and its headers were sent out.
str.itl.enqueue(df)
if str.state == empty {
str.state = active
l.activeStreams.enqueue(str)
}
return nil
}
func (l *loopyWriter) pingHandler(p *ping) error {
if !p.ack {
l.bdpEst.timesnap(p.data)
}
return l.framer.fr.WritePing(p.ack, p.data)
}
func (l *loopyWriter) outFlowControlSizeRequestHandler(o *outFlowControlSizeRequest) error {
o.resp <- l.sendQuota
return nil
}
func (l *loopyWriter) cleanupStreamHandler(c *cleanupStream) error {
c.onWrite()
if str, ok := l.estdStreams[c.streamID]; ok {
// On the server side it could be a trailers-only response or
// a RST_STREAM before stream initialization thus the stream might
// not be established yet.
delete(l.estdStreams, c.streamID)
str.deleteSelf()
}
if c.rst { // If RST_STREAM needs to be sent.
if err := l.framer.fr.WriteRSTStream(c.streamID, c.rstCode); err != nil {
return err
}
}
if l.side == clientSide && l.draining && len(l.estdStreams) == 0 {
return ErrConnClosing
}
return nil
}
func (l *loopyWriter) incomingGoAwayHandler(*incomingGoAway) error {
if l.side == clientSide {
l.draining = true
if len(l.estdStreams) == 0 {
return ErrConnClosing
}
}
return nil
}
func (l *loopyWriter) goAwayHandler(g *goAway) error {
// Handling of an outgoing GoAway is very specific to the side (client or server).
if l.ssGoAwayHandler != nil {
draining, err := l.ssGoAwayHandler(g)
if err != nil {
return err
}
l.draining = draining
}
return nil
}
func (l *loopyWriter) handle(i interface{}) error {
switch i := i.(type) {
case *incomingWindowUpdate:
return l.incomingWindowUpdateHandler(i)
case *outgoingWindowUpdate:
return l.outgoingWindowUpdateHandler(i)
case *incomingSettings:
return l.incomingSettingsHandler(i)
case *outgoingSettings:
return l.outgoingSettingsHandler(i)
case *headerFrame:
return l.headerHandler(i)
case *registerStream:
return l.registerStreamHandler(i)
case *cleanupStream:
return l.cleanupStreamHandler(i)
case *incomingGoAway:
return l.incomingGoAwayHandler(i)
case *dataFrame:
return l.preprocessData(i)
case *ping:
return l.pingHandler(i)
case *goAway:
return l.goAwayHandler(i)
case *outFlowControlSizeRequest:
return l.outFlowControlSizeRequestHandler(i)
default:
return fmt.Errorf("transport: unknown control message type %T", i)
}
}
func (l *loopyWriter) applySettings(ss []http2.Setting) error {
for _, s := range ss {
switch s.ID {
case http2.SettingInitialWindowSize:
o := l.oiws
l.oiws = s.Val
if o < l.oiws {
// If the new limit is greater, make all depleted streams active.
for _, stream := range l.estdStreams {
if stream.state == waitingOnStreamQuota {
stream.state = active
l.activeStreams.enqueue(stream)
}
}
}
case http2.SettingHeaderTableSize:
updateHeaderTblSize(l.hEnc, s.Val)
}
}
return nil
}
func (l *loopyWriter) processData() (bool, error) {
if l.sendQuota == 0 {
return true, nil
}
str := l.activeStreams.dequeue()
if str == nil {
return true, nil
}
dataItem := str.itl.peek().(*dataFrame)
if len(dataItem.h) == 0 && len(dataItem.d) == 0 {
// Client sends out empty data frame with endStream = true
if err := l.framer.fr.WriteData(dataItem.streamID, dataItem.endStream, nil); err != nil {
return false, err
}
str.itl.dequeue()
if str.itl.isEmpty() {
str.state = empty
} else if trailer, ok := str.itl.peek().(*headerFrame); ok { // the next item is trailers.
if err := l.writeHeader(trailer.streamID, trailer.endStream, trailer.hf, trailer.onWrite); err != nil {
return false, err
}
if err := l.cleanupStreamHandler(trailer.cleanup); err != nil {
return false, err
}
} else {
l.activeStreams.enqueue(str)
}
return false, nil
}
var (
idx int
buf []byte
)
if len(dataItem.h) != 0 { // data header has not been written out yet.
buf = dataItem.h
} else {
idx = 1
buf = dataItem.d
}
size := http2MaxFrameLen
if len(buf) < size {
size = len(buf)
}
if strQuota := int(l.oiws) - str.bytesOutStanding; strQuota <= 0 {
str.state = waitingOnStreamQuota
return false, nil
} else if strQuota < size {
size = strQuota
}
if l.sendQuota < uint32(size) {
size = int(l.sendQuota)
}
// Now that outgoing flow controls are checked, we can replenish str's write quota.
str.wq.replenish(size)
var endStream bool
// This is the last data message on this stream and all
// of it can be written out in this pass.
if dataItem.endStream && size == len(buf) {
// buf holds either the data, or the header when the data is empty.
if idx == 1 || len(dataItem.d) == 0 {
endStream = true
}
}
if dataItem.onEachWrite != nil {
dataItem.onEachWrite()
}
if err := l.framer.fr.WriteData(dataItem.streamID, endStream, buf[:size]); err != nil {
return false, err
}
buf = buf[size:]
str.bytesOutStanding += size
l.sendQuota -= uint32(size)
if idx == 0 {
dataItem.h = buf
} else {
dataItem.d = buf
}
if len(dataItem.h) == 0 && len(dataItem.d) == 0 { // All the data from that message was written out.
str.itl.dequeue()
}
if str.itl.isEmpty() {
str.state = empty
} else if trailer, ok := str.itl.peek().(*headerFrame); ok { // The next item is trailers.
if err := l.writeHeader(trailer.streamID, trailer.endStream, trailer.hf, trailer.onWrite); err != nil {
return false, err
}
if err := l.cleanupStreamHandler(trailer.cleanup); err != nil {
return false, err
}
} else if int(l.oiws)-str.bytesOutStanding <= 0 { // Ran out of stream quota.
str.state = waitingOnStreamQuota
} else { // Otherwise add it back to the list of active streams.
l.activeStreams.enqueue(str)
}
return false, nil
}
...@@ -20,13 +20,10 @@ package transport
import ( import (
"fmt" "fmt"
"io"
"math" "math"
"sync" "sync"
"sync/atomic"
"time" "time"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
) )
const ( const (
...@@ -36,202 +33,115 @@ const (
initialWindowSize = defaultWindowSize // for an RPC initialWindowSize = defaultWindowSize // for an RPC
infinity = time.Duration(math.MaxInt64) infinity = time.Duration(math.MaxInt64)
defaultClientKeepaliveTime = infinity defaultClientKeepaliveTime = infinity
defaultClientKeepaliveTimeout = time.Duration(20 * time.Second) defaultClientKeepaliveTimeout = 20 * time.Second
defaultMaxStreamsClient = 100 defaultMaxStreamsClient = 100
defaultMaxConnectionIdle = infinity defaultMaxConnectionIdle = infinity
defaultMaxConnectionAge = infinity defaultMaxConnectionAge = infinity
defaultMaxConnectionAgeGrace = infinity defaultMaxConnectionAgeGrace = infinity
defaultServerKeepaliveTime = time.Duration(2 * time.Hour) defaultServerKeepaliveTime = 2 * time.Hour
defaultServerKeepaliveTimeout = time.Duration(20 * time.Second) defaultServerKeepaliveTimeout = 20 * time.Second
defaultKeepalivePolicyMinTime = time.Duration(5 * time.Minute) defaultKeepalivePolicyMinTime = 5 * time.Minute
// max window limit set by HTTP2 Specs. // max window limit set by HTTP2 Specs.
maxWindowSize = math.MaxInt32 maxWindowSize = math.MaxInt32
// defaultLocalSendQuota sets is default value for number of data // defaultWriteQuota is the default value for number of data
// bytes that each stream can schedule before some of it being // bytes that each stream can schedule before some of it being
// flushed out. // flushed out.
defaultLocalSendQuota = 128 * 1024 defaultWriteQuota = 64 * 1024
) )
// The following defines various control items which could flow through // writeQuota is a soft limit on the amount of data a stream can
// the control buffer of transport. They represent different aspects of // schedule before some of it is written out.
// control tasks, e.g., flow control, settings, streaming resetting, etc. type writeQuota struct {
quota int32
type headerFrame struct { // get waits on read from when quota goes less than or equal to zero.
streamID uint32 // replenish writes on it when quota goes positive again.
hf []hpack.HeaderField ch chan struct{}
endStream bool // done is triggered in error case.
} done <-chan struct{}
// replenish is called by loopyWriter to give quota back to.
func (*headerFrame) item() {} // It is implemented as a field so that it can be updated
// by tests.
type continuationFrame struct { replenish func(n int)
streamID uint32 }
endHeaders bool
headerBlockFragment []byte func newWriteQuota(sz int32, done <-chan struct{}) *writeQuota {
} w := &writeQuota{
quota: sz,
type dataFrame struct { ch: make(chan struct{}, 1),
streamID uint32 done: done,
endStream bool
d []byte
f func()
}
func (*dataFrame) item() {}
func (*continuationFrame) item() {}
type windowUpdate struct {
streamID uint32
increment uint32
}
func (*windowUpdate) item() {}
type settings struct {
ss []http2.Setting
}
func (*settings) item() {}
type settingsAck struct {
}
func (*settingsAck) item() {}
type resetStream struct {
streamID uint32
code http2.ErrCode
}
func (*resetStream) item() {}
type goAway struct {
code http2.ErrCode
debugData []byte
headsUp bool
closeConn bool
}
func (*goAway) item() {}
type flushIO struct {
closeTr bool
}
func (*flushIO) item() {}
type ping struct {
ack bool
data [8]byte
}
func (*ping) item() {}
// quotaPool is a pool which accumulates the quota and sends it to acquire()
// when it is available.
type quotaPool struct {
mu sync.Mutex
c chan struct{}
version uint32
quota int
}
// newQuotaPool creates a quotaPool which has quota q available to consume.
func newQuotaPool(q int) *quotaPool {
qb := &quotaPool{
quota: q,
c: make(chan struct{}, 1),
} }
return qb w.replenish = w.realReplenish
return w
} }
// add cancels the pending quota sent on acquired, incremented by v and sends func (w *writeQuota) get(sz int32) error {
// it back on acquire. for {
func (qb *quotaPool) add(v int) { if atomic.LoadInt32(&w.quota) > 0 {
qb.mu.Lock() atomic.AddInt32(&w.quota, -sz)
defer qb.mu.Unlock() return nil
qb.lockedAdd(v) }
select {
case <-w.ch:
continue
case <-w.done:
return errStreamDone
}
}
} }
func (qb *quotaPool) lockedAdd(v int) { func (w *writeQuota) realReplenish(n int) {
var wakeUp bool sz := int32(n)
if qb.quota <= 0 { a := atomic.AddInt32(&w.quota, sz)
wakeUp = true // Wake up potential waiters. b := a - sz
} if b <= 0 && a > 0 {
qb.quota += v
if wakeUp && qb.quota > 0 {
select { select {
case qb.c <- struct{}{}: case w.ch <- struct{}{}:
default: default:
} }
} }
} }
func (qb *quotaPool) addAndUpdate(v int) { type trInFlow struct {
qb.mu.Lock() limit uint32
qb.lockedAdd(v) unacked uint32
qb.version++ effectiveWindowSize uint32
qb.mu.Unlock()
} }
func (qb *quotaPool) get(v int, wc waiters) (int, uint32, error) { func (f *trInFlow) newLimit(n uint32) uint32 {
qb.mu.Lock() d := n - f.limit
if qb.quota > 0 { f.limit = n
if v > qb.quota { f.updateEffectiveWindowSize()
v = qb.quota return d
} }
qb.quota -= v
ver := qb.version
qb.mu.Unlock()
return v, ver, nil
}
qb.mu.Unlock()
for {
select {
case <-wc.ctx.Done():
return 0, 0, ContextErr(wc.ctx.Err())
case <-wc.tctx.Done():
return 0, 0, ErrConnClosing
case <-wc.done:
return 0, 0, io.EOF
case <-wc.goAway:
return 0, 0, errStreamDrain
case <-qb.c:
qb.mu.Lock()
if qb.quota > 0 {
if v > qb.quota {
v = qb.quota
}
qb.quota -= v
ver := qb.version
if qb.quota > 0 {
select {
case qb.c <- struct{}{}:
default:
}
}
qb.mu.Unlock()
return v, ver, nil
} func (f *trInFlow) onData(n uint32) uint32 {
qb.mu.Unlock() f.unacked += n
} if f.unacked >= f.limit/4 {
w := f.unacked
f.unacked = 0
f.updateEffectiveWindowSize()
return w
} }
f.updateEffectiveWindowSize()
return 0
} }
func (qb *quotaPool) compareAndExecute(version uint32, success, failure func()) bool { func (f *trInFlow) reset() uint32 {
qb.mu.Lock() w := f.unacked
if version == qb.version { f.unacked = 0
success() f.updateEffectiveWindowSize()
qb.mu.Unlock() return w
return true
}
failure()
qb.mu.Unlock()
return false
} }
func (f *trInFlow) updateEffectiveWindowSize() {
atomic.StoreUint32(&f.effectiveWindowSize, f.limit-f.unacked)
}
func (f *trInFlow) getSize() uint32 {
return atomic.LoadUint32(&f.effectiveWindowSize)
}
// TODO(mmukhi): Simplify this code.
// inFlow deals with inbound flow control // inFlow deals with inbound flow control
type inFlow struct { type inFlow struct {
mu sync.Mutex mu sync.Mutex
...@@ -252,9 +162,9 @@ type inFlow struct {
// It assumes that n is always greater than the old limit. // It assumes that n is always greater than the old limit.
func (f *inFlow) newLimit(n uint32) uint32 { func (f *inFlow) newLimit(n uint32) uint32 {
f.mu.Lock() f.mu.Lock()
defer f.mu.Unlock()
d := n - f.limit d := n - f.limit
f.limit = n f.limit = n
f.mu.Unlock()
return d return d
} }
...@@ -263,7 +173,6 @@ func (f *inFlow) maybeAdjust(n uint32) uint32 {
n = uint32(math.MaxInt32) n = uint32(math.MaxInt32)
} }
f.mu.Lock() f.mu.Lock()
defer f.mu.Unlock()
// estSenderQuota is the receiver's view of the maximum number of bytes the sender // estSenderQuota is the receiver's view of the maximum number of bytes the sender
// can send without a window update. // can send without a window update.
estSenderQuota := int32(f.limit - (f.pendingData + f.pendingUpdate)) estSenderQuota := int32(f.limit - (f.pendingData + f.pendingUpdate))
...@@ -275,7 +184,7 @@ func (f *inFlow) maybeAdjust(n uint32) uint32 {
// for this message. Therefore we must send an update over the limit since there's an active read // for this message. Therefore we must send an update over the limit since there's an active read
// request from the application. // request from the application.
if estUntransmittedData > estSenderQuota { if estUntransmittedData > estSenderQuota {
// Sender's window shouldn't go more than 2^31 - 1 as speecified in the HTTP spec. // Sender's window shouldn't go more than 2^31 - 1 as specified in the HTTP spec.
if f.limit+n > maxWindowSize { if f.limit+n > maxWindowSize {
f.delta = maxWindowSize - f.limit f.delta = maxWindowSize - f.limit
} else { } else {
...@@ -284,19 +193,24 @@ func (f *inFlow) maybeAdjust(n uint32) uint32 {
// is padded; We will fallback on the current available window(at least a 1/4th of the limit). // is padded; We will fallback on the current available window(at least a 1/4th of the limit).
f.delta = n f.delta = n
} }
f.mu.Unlock()
return f.delta return f.delta
} }
f.mu.Unlock()
return 0 return 0
} }
// onData is invoked when some data frame is received. It updates pendingData. // onData is invoked when some data frame is received. It updates pendingData.
func (f *inFlow) onData(n uint32) error { func (f *inFlow) onData(n uint32) error {
f.mu.Lock() f.mu.Lock()
defer f.mu.Unlock()
f.pendingData += n f.pendingData += n
if f.pendingData+f.pendingUpdate > f.limit+f.delta { if f.pendingData+f.pendingUpdate > f.limit+f.delta {
return fmt.Errorf("received %d-bytes data exceeding the limit %d bytes", f.pendingData+f.pendingUpdate, f.limit) limit := f.limit
rcvd := f.pendingData + f.pendingUpdate
f.mu.Unlock()
return fmt.Errorf("received %d-bytes data exceeding the limit %d bytes", rcvd, limit)
} }
f.mu.Unlock()
return nil return nil
} }
...@@ -304,8 +218,8 @@ func (f *inFlow) onData(n uint32) error {
// to be sent to the peer. // to be sent to the peer.
func (f *inFlow) onRead(n uint32) uint32 { func (f *inFlow) onRead(n uint32) uint32 {
f.mu.Lock() f.mu.Lock()
defer f.mu.Unlock()
if f.pendingData == 0 { if f.pendingData == 0 {
f.mu.Unlock()
return 0 return 0
} }
f.pendingData -= n f.pendingData -= n
...@@ -320,15 +234,9 @@ func (f *inFlow) onRead(n uint32) uint32 {
if f.pendingUpdate >= f.limit/4 { if f.pendingUpdate >= f.limit/4 {
wu := f.pendingUpdate wu := f.pendingUpdate
f.pendingUpdate = 0 f.pendingUpdate = 0
f.mu.Unlock()
return wu return wu
} }
f.mu.Unlock()
return 0 return 0
} }
func (f *inFlow) resetPendingUpdate() uint32 {
f.mu.Lock()
defer f.mu.Unlock()
n := f.pendingUpdate
f.pendingUpdate = 0
return n
}
...@@ -40,20 +40,24 @@ import (
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer" "google.golang.org/grpc/peer"
"google.golang.org/grpc/stats"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
) )
// NewServerHandlerTransport returns a ServerTransport handling gRPC // NewServerHandlerTransport returns a ServerTransport handling gRPC
// from inside an http.Handler. It requires that the http Server // from inside an http.Handler. It requires that the http Server
// supports HTTP/2. // supports HTTP/2.
func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTransport, error) { func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats stats.Handler) (ServerTransport, error) {
if r.ProtoMajor != 2 { if r.ProtoMajor != 2 {
return nil, errors.New("gRPC requires HTTP/2") return nil, errors.New("gRPC requires HTTP/2")
} }
if r.Method != "POST" { if r.Method != "POST" {
return nil, errors.New("invalid gRPC request method") return nil, errors.New("invalid gRPC request method")
} }
if !validContentType(r.Header.Get("Content-Type")) { contentType := r.Header.Get("Content-Type")
// TODO: do we assume contentType is lowercase? we did before
contentSubtype, validContentType := contentSubtype(contentType)
if !validContentType {
return nil, errors.New("invalid gRPC request content-type") return nil, errors.New("invalid gRPC request content-type")
} }
if _, ok := w.(http.Flusher); !ok { if _, ok := w.(http.Flusher); !ok {
...@@ -64,10 +68,13 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTr ...@@ -64,10 +68,13 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTr
} }
st := &serverHandlerTransport{ st := &serverHandlerTransport{
rw: w, rw: w,
req: r, req: r,
closedCh: make(chan struct{}), closedCh: make(chan struct{}),
writes: make(chan func()), writes: make(chan func()),
contentType: contentType,
contentSubtype: contentSubtype,
stats: stats,
} }
if v := r.Header.Get("grpc-timeout"); v != "" { if v := r.Header.Get("grpc-timeout"); v != "" {
...@@ -79,19 +86,19 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTr
st.timeout = to st.timeout = to
} }
var metakv []string metakv := []string{"content-type", contentType}
if r.Host != "" { if r.Host != "" {
metakv = append(metakv, ":authority", r.Host) metakv = append(metakv, ":authority", r.Host)
} }
for k, vv := range r.Header { for k, vv := range r.Header {
k = strings.ToLower(k) k = strings.ToLower(k)
if isReservedHeader(k) && !isWhitelistedPseudoHeader(k) { if isReservedHeader(k) && !isWhitelistedHeader(k) {
continue continue
} }
for _, v := range vv { for _, v := range vv {
v, err := decodeMetadataHeader(k, v) v, err := decodeMetadataHeader(k, v)
if err != nil { if err != nil {
return nil, streamErrorf(codes.InvalidArgument, "malformed binary metadata: %v", err) return nil, streamErrorf(codes.Internal, "malformed binary metadata: %v", err)
} }
metakv = append(metakv, k, v) metakv = append(metakv, k, v)
} }
...@@ -126,6 +133,14 @@ type serverHandlerTransport struct {
// block concurrent WriteStatus calls // block concurrent WriteStatus calls
// e.g. grpc/(*serverStream).SendMsg/RecvMsg // e.g. grpc/(*serverStream).SendMsg/RecvMsg
writeStatusMu sync.Mutex writeStatusMu sync.Mutex
// we just mirror the request content-type
contentType string
// we store both contentType and contentSubtype so we don't keep recreating them
// TODO make sure this is consistent across handler_server and http2_server
contentSubtype string
stats stats.Handler
} }
func (ht *serverHandlerTransport) Close() error { func (ht *serverHandlerTransport) Close() error {
...@@ -219,6 +234,9 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) erro
}) })
if err == nil { // transport has not been closed if err == nil { // transport has not been closed
if ht.stats != nil {
ht.stats.HandleRPC(s.Context(), &stats.OutTrailer{})
}
ht.Close() ht.Close()
close(ht.writes) close(ht.writes)
} }
...@@ -235,7 +253,7 @@ func (ht *serverHandlerTransport) writeCommonHeaders(s *Stream) {
h := ht.rw.Header() h := ht.rw.Header()
h["Date"] = nil // suppress Date to make tests happy; TODO: restore h["Date"] = nil // suppress Date to make tests happy; TODO: restore
h.Set("Content-Type", "application/grpc") h.Set("Content-Type", ht.contentType)
// Predeclare trailers we'll set later in WriteStatus (after the body). // Predeclare trailers we'll set later in WriteStatus (after the body).
// This is a SHOULD in the HTTP RFC, and the way you add (known) // This is a SHOULD in the HTTP RFC, and the way you add (known)
...@@ -263,7 +281,7 @@ func (ht *serverHandlerTransport) Write(s *Stream, hdr []byte, data []byte, opts
} }
func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error { func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
return ht.do(func() { err := ht.do(func() {
ht.writeCommonHeaders(s) ht.writeCommonHeaders(s)
h := ht.rw.Header() h := ht.rw.Header()
for k, vv := range md { for k, vv := range md {
...@@ -279,6 +297,13 @@ func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
ht.rw.WriteHeader(200) ht.rw.WriteHeader(200)
ht.rw.(http.Flusher).Flush() ht.rw.(http.Flusher).Flush()
}) })
if err == nil {
if ht.stats != nil {
ht.stats.HandleRPC(s.Context(), &stats.OutHeader{})
}
}
return err
} }
func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), traceCtx func(context.Context, string) context.Context) { func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), traceCtx func(context.Context, string) context.Context) {
...@@ -313,13 +338,14 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace
req := ht.req req := ht.req
s := &Stream{ s := &Stream{
id: 0, // irrelevant id: 0, // irrelevant
requestRead: func(int) {}, requestRead: func(int) {},
cancel: cancel, cancel: cancel,
buf: newRecvBuffer(), buf: newRecvBuffer(),
st: ht, st: ht,
method: req.URL.Path, method: req.URL.Path,
recvCompress: req.Header.Get("grpc-encoding"), recvCompress: req.Header.Get("grpc-encoding"),
contentSubtype: ht.contentSubtype,
} }
pr := &peer.Peer{ pr := &peer.Peer{
Addr: ht.RemoteAddr(), Addr: ht.RemoteAddr(),
...@@ -328,10 +354,18 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace
pr.AuthInfo = credentials.TLSInfo{State: *req.TLS} pr.AuthInfo = credentials.TLSInfo{State: *req.TLS}
} }
ctx = metadata.NewIncomingContext(ctx, ht.headerMD) ctx = metadata.NewIncomingContext(ctx, ht.headerMD)
ctx = peer.NewContext(ctx, pr) s.ctx = peer.NewContext(ctx, pr)
s.ctx = newContextWithStream(ctx, s) if ht.stats != nil {
s.ctx = ht.stats.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method})
inHeader := &stats.InHeader{
FullMethod: s.method,
RemoteAddr: ht.RemoteAddr(),
Compression: s.recvCompress,
}
ht.stats.HandleRPC(s.ctx, inHeader)
}
s.trReader = &transportReader{ s.trReader = &transportReader{
reader: &recvBufferReader{ctx: s.ctx, recv: s.buf}, reader: &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf},
windowHandler: func(int) {}, windowHandler: func(int) {},
} }
...@@ -386,6 +420,10 @@ func (ht *serverHandlerTransport) runStream() {
} }
} }
func (ht *serverHandlerTransport) IncrMsgSent() {}
func (ht *serverHandlerTransport) IncrMsgRecv() {}
func (ht *serverHandlerTransport) Drain() { func (ht *serverHandlerTransport) Drain() {
panic("Drain() is not implemented") panic("Drain() is not implemented")
} }
...
...@@ -19,8 +19,6 @@
package transport package transport
import ( import (
"bytes"
"fmt"
"io" "io"
"math" "math"
"net" "net"
...@@ -32,8 +30,10 @@ import (
"golang.org/x/net/context" "golang.org/x/net/context"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"golang.org/x/net/http2/hpack" "golang.org/x/net/http2/hpack"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer" "google.golang.org/grpc/peer"
...@@ -45,14 +45,17 @@ type http2Client struct {
type http2Client struct { type http2Client struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
ctxDone <-chan struct{} // Cache the ctx.Done() chan.
userAgent string userAgent string
md interface{} md interface{}
conn net.Conn // underlying communication channel conn net.Conn // underlying communication channel
loopy *loopyWriter
remoteAddr net.Addr remoteAddr net.Addr
localAddr net.Addr localAddr net.Addr
authInfo credentials.AuthInfo // auth info about the connection authInfo credentials.AuthInfo // auth info about the connection
nextID uint32 // the next stream ID to be used
readerDone chan struct{} // sync point to enable testing.
writerDone chan struct{} // sync point to enable testing.
// goAway is closed to notify the upper layer (i.e., addrConn.transportMonitor) // goAway is closed to notify the upper layer (i.e., addrConn.transportMonitor)
// that the server sent GoAway on this transport. // that the server sent GoAway on this transport.
goAway chan struct{} goAway chan struct{}
...@@ -60,21 +63,10 @@ type http2Client struct {
awakenKeepalive chan struct{} awakenKeepalive chan struct{}
framer *framer framer *framer
hBuf *bytes.Buffer // the buffer for HPACK encoding
hEnc *hpack.Encoder // HPACK encoder
// controlBuf delivers all the control related tasks (e.g., window // controlBuf delivers all the control related tasks (e.g., window
// updates, reset streams, and various settings) to the controller. // updates, reset streams, and various settings) to the controller.
controlBuf *controlBuffer controlBuf *controlBuffer
fc *inFlow fc *trInFlow
// sendQuotaPool provides flow control to outbound message.
sendQuotaPool *quotaPool
// localSendQuota limits the amount of data that can be scheduled
// for writing before it is actually written out.
localSendQuota *quotaPool
// streamsQuota limits the max number of concurrent streams.
streamsQuota *quotaPool
// The scheme used: https if TLS is on, http otherwise. // The scheme used: https if TLS is on, http otherwise.
scheme string scheme string
...@@ -84,33 +76,50 @@ type http2Client struct {
// Boolean to keep track of reading activity on transport. // Boolean to keep track of reading activity on transport.
// 1 is true and 0 is false. // 1 is true and 0 is false.
activity uint32 // Accessed atomically. activity uint32 // Accessed atomically.
kp keepalive.ClientParameters kp keepalive.ClientParameters
keepaliveEnabled bool
statsHandler stats.Handler statsHandler stats.Handler
initialWindowSize int32 initialWindowSize int32
bdpEst *bdpEstimator bdpEst *bdpEstimator
outQuotaVersion uint32
// onSuccess is a callback that client transport calls upon
// receiving server preface to signal that a successful HTTP2
// connection was established.
onSuccess func()
mu sync.Mutex // guard the following variables maxConcurrentStreams uint32
state transportState // the state of underlying connection streamQuota int64
streamsQuotaAvailable chan struct{}
waitingStreams uint32
nextID uint32
mu sync.Mutex // guard the following variables
state transportState
activeStreams map[uint32]*Stream
// The max number of concurrent streams
maxStreams int
// the per-stream outbound flow control window size set by the peer.
streamSendQuota uint32
// prevGoAway ID records the Last-Stream-ID in the previous GOAway frame.
prevGoAwayID uint32
// goAwayReason records the http2.ErrCode and debug data received with the
// GoAway frame.
goAwayReason GoAwayReason
// Fields below are for channelz metric collection.
channelzID int64 // channelz unique identification number
czmu sync.RWMutex
kpCount int64
// The number of streams that have started, including already finished ones.
streamsStarted int64
// The number of streams that have ended successfully by receiving EoS bit set
// frame from server.
streamsSucceeded int64
streamsFailed int64
lastStreamCreated time.Time
msgSent int64
msgRecv int64
lastMsgSent time.Time
lastMsgRecv time.Time
}
func dial(ctx context.Context, fn func(context.Context, string) (net.Conn, error), addr string) (net.Conn, error) {
...@@ -121,18 +130,6 @@ func dial(ctx context.Context, fn func(context.Context, string) (net.Conn, error
}

func isTemporary(err error) bool {
switch err {
case io.EOF:
// Connection closures may be resolved upon retry, and are thus
// treated as temporary.
return true
case context.DeadlineExceeded:
// In Go 1.7, context.DeadlineExceeded implements Timeout(), and this
// special case is not needed. Until then, we need to keep this
// clause.
return true
}
switch err := err.(type) {
case interface {
Temporary() bool
...@@ -145,7 +142,7 @@ func isTemporary(err error) bool {
// temporary.
return err.Timeout()
}
return false return true
}
// newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2 // newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2
...@@ -181,10 +178,7 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne
scheme = "https"
conn, authInfo, err = creds.ClientHandshake(connectCtx, addr.Authority, conn)
if err != nil {
// Credentials handshake errors are typically considered permanent return nil, connectionErrorf(isTemporary(err), err, "transport: authentication handshake failed: %v", err)
// to avoid retrying on e.g. bad certificates.
temp := isTemporary(err)
return nil, connectionErrorf(temp, err, "transport: authentication handshake failed: %v", err)
}
isSecure = true
}
...@@ -202,7 +196,6 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne
icwz = opts.InitialConnWindowSize
dynamicWindow = false
}
var buf bytes.Buffer
writeBufSize := defaultWriteBufSize
if opts.WriteBufferSize > 0 {
writeBufSize = opts.WriteBufferSize
...@@ -212,38 +205,35 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne
readBufSize = opts.ReadBufferSize
}
t := &http2Client{ t := &http2Client{
ctx: ctx, ctx: ctx,
cancel: cancel, ctxDone: ctx.Done(), // Cache Done chan.
userAgent: opts.UserAgent, cancel: cancel,
md: addr.Metadata, userAgent: opts.UserAgent,
conn: conn, md: addr.Metadata,
remoteAddr: conn.RemoteAddr(), conn: conn,
localAddr: conn.LocalAddr(), remoteAddr: conn.RemoteAddr(),
authInfo: authInfo, localAddr: conn.LocalAddr(),
// The client initiated stream id is odd starting from 1. authInfo: authInfo,
nextID: 1, readerDone: make(chan struct{}),
goAway: make(chan struct{}), writerDone: make(chan struct{}),
awakenKeepalive: make(chan struct{}, 1), goAway: make(chan struct{}),
hBuf: &buf, awakenKeepalive: make(chan struct{}, 1),
hEnc: hpack.NewEncoder(&buf), framer: newFramer(conn, writeBufSize, readBufSize),
framer: newFramer(conn, writeBufSize, readBufSize), fc: &trInFlow{limit: uint32(icwz)},
controlBuf: newControlBuffer(), scheme: scheme,
fc: &inFlow{limit: uint32(icwz)}, activeStreams: make(map[uint32]*Stream),
sendQuotaPool: newQuotaPool(defaultWindowSize), isSecure: isSecure,
localSendQuota: newQuotaPool(defaultLocalSendQuota), creds: opts.PerRPCCredentials,
scheme: scheme, kp: kp,
state: reachable, statsHandler: opts.StatsHandler,
activeStreams: make(map[uint32]*Stream), initialWindowSize: initialWindowSize,
isSecure: isSecure, onSuccess: onSuccess,
creds: opts.PerRPCCredentials, nextID: 1,
maxStreams: defaultMaxStreamsClient, maxConcurrentStreams: defaultMaxStreamsClient,
streamsQuota: newQuotaPool(defaultMaxStreamsClient), streamQuota: defaultMaxStreamsClient,
streamSendQuota: defaultWindowSize, streamsQuotaAvailable: make(chan struct{}, 1),
kp: kp, }
statsHandler: opts.StatsHandler, t.controlBuf = newControlBuffer(t.ctxDone)
initialWindowSize: initialWindowSize,
onSuccess: onSuccess,
}
if opts.InitialWindowSize >= defaultWindowSize {
t.initialWindowSize = opts.InitialWindowSize
dynamicWindow = false
...@@ -267,6 +257,13 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne
}
t.statsHandler.HandleConn(t.ctx, connBegin)
}
if channelz.IsOn() {
t.channelzID = channelz.RegisterNormalSocket(t, opts.ChannelzParentID, "")
}
if t.kp.Time != infinity {
t.keepaliveEnabled = true
go t.keepalive()
}
// Start the reader goroutine for incoming message. Each transport has
// a dedicated goroutine which reads HTTP2 frame from network. Then it
// dispatches the frame to the corresponding stream entity.
...@@ -302,29 +299,32 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne
}
t.framer.writer.Flush()
go func() {
loopyWriter(t.ctx, t.controlBuf, t.itemHandler) t.loopy = newLoopyWriter(clientSide, t.framer, t.controlBuf, t.bdpEst)
t.conn.Close() err := t.loopy.run()
if err != nil {
errorf("transport: loopyWriter.run returning. Err: %v", err)
}
// If it's a connection error, let reader goroutine handle it
// since there might be data in the buffers.
if _, ok := err.(net.Error); !ok {
t.conn.Close()
}
close(t.writerDone)
}() }()
if t.kp.Time != infinity {
go t.keepalive()
}
return t, nil
}
func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
// TODO(zhaoq): Handle uint32 overflow of Stream.id.
s := &Stream{ s := &Stream{
id: t.nextID, done: make(chan struct{}),
done: make(chan struct{}), method: callHdr.Method,
goAway: make(chan struct{}), sendCompress: callHdr.SendCompress,
method: callHdr.Method, buf: newRecvBuffer(),
sendCompress: callHdr.SendCompress, headerChan: make(chan struct{}),
buf: newRecvBuffer(), contentSubtype: callHdr.ContentSubtype,
fc: &inFlow{limit: uint32(t.initialWindowSize)}, }
sendQuotaPool: newQuotaPool(int(t.streamSendQuota)), s.wq = newWriteQuota(defaultWriteQuota, s.done)
headerChan: make(chan struct{}),
}
t.nextID += 2
s.requestRead = func(n int) {
t.adjustWindow(s, uint32(n))
}
...@@ -334,26 +334,18 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
s.ctx = ctx
s.trReader = &transportReader{
reader: &recvBufferReader{ reader: &recvBufferReader{
ctx: s.ctx, ctx: s.ctx,
goAway: s.goAway, ctxDone: s.ctx.Done(),
recv: s.buf, recv: s.buf,
},
windowHandler: func(n int) {
t.updateWindow(s, uint32(n))
},
}
s.waiters = waiters{
ctx: s.ctx,
tctx: t.ctx,
done: s.done,
goAway: s.goAway,
}
return s
}
// NewStream creates a stream and registers it into the transport as "active" func (t *http2Client) getPeer() *peer.Peer {
// streams.
func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Stream, err error) {
pr := &peer.Peer{
Addr: t.remoteAddr,
}
...@@ -361,67 +353,17 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
if t.authInfo != nil {
pr.AuthInfo = t.authInfo
}
ctx = peer.NewContext(ctx, pr) return pr
var ( }
authData = make(map[string]string)
audience string func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr) ([]hpack.HeaderField, error) {
) aud := t.createAudience(callHdr)
// Create an audience string only if needed. authData, err := t.getTrAuthData(ctx, aud)
if len(t.creds) > 0 || callHdr.Creds != nil { if err != nil {
// Construct URI required to get auth request metadata. return nil, err
// Omit port if it is the default one.
host := strings.TrimSuffix(callHdr.Host, ":443")
pos := strings.LastIndex(callHdr.Method, "/")
if pos == -1 {
pos = len(callHdr.Method)
}
audience = "https://" + host + callHdr.Method[:pos]
}
for _, c := range t.creds {
data, err := c.GetRequestMetadata(ctx, audience)
if err != nil {
return nil, streamErrorf(codes.Internal, "transport: %v", err)
}
for k, v := range data {
// Capital header names are illegal in HTTP/2.
k = strings.ToLower(k)
authData[k] = v
}
}
callAuthData := map[string]string{}
// Check if credentials.PerRPCCredentials were provided via call options.
// Note: if these credentials are provided both via dial options and call
// options, then both sets of credentials will be applied.
if callCreds := callHdr.Creds; callCreds != nil {
if !t.isSecure && callCreds.RequireTransportSecurity() {
return nil, streamErrorf(codes.Unauthenticated, "transport: cannot send secure credentials on an insecure connection")
}
data, err := callCreds.GetRequestMetadata(ctx, audience)
if err != nil {
return nil, streamErrorf(codes.Internal, "transport: %v", err)
}
for k, v := range data {
// Capital header names are illegal in HTTP/2
k = strings.ToLower(k)
callAuthData[k] = v
}
}
t.mu.Lock()
if t.activeStreams == nil {
t.mu.Unlock()
return nil, ErrConnClosing
}
if t.state == draining {
t.mu.Unlock()
return nil, errStreamDrain
}
if t.state != reachable {
t.mu.Unlock()
return nil, ErrConnClosing
} }
t.mu.Unlock() callAuthData, err := t.getCallAuthData(ctx, aud, callHdr)
// Get a quota of 1 from streamsQuota. if err != nil {
if _, _, err := t.streamsQuota.get(1, waiters{ctx: ctx, tctx: t.ctx}); err != nil {
return nil, err return nil, err
} }
// TODO(mmukhi): Benchmark if the performance gets better if count the metadata and other header fields // TODO(mmukhi): Benchmark if the performance gets better if count the metadata and other header fields
...@@ -434,7 +376,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
headerFields = append(headerFields, hpack.HeaderField{Name: ":scheme", Value: t.scheme})
headerFields = append(headerFields, hpack.HeaderField{Name: ":path", Value: callHdr.Method})
headerFields = append(headerFields, hpack.HeaderField{Name: ":authority", Value: callHdr.Host})
headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: "application/grpc"}) headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: contentType(callHdr.ContentSubtype)})
headerFields = append(headerFields, hpack.HeaderField{Name: "user-agent", Value: t.userAgent})
headerFields = append(headerFields, hpack.HeaderField{Name: "te", Value: "trailers"})
...@@ -459,7 +401,22 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
if b := stats.OutgoingTrace(ctx); b != nil {
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-trace-bin", Value: encodeBinHeader(b)})
}
if md, ok := metadata.FromOutgoingContext(ctx); ok {
if md, added, ok := metadata.FromOutgoingContextRaw(ctx); ok {
var k string
for _, vv := range added {
for i, v := range vv {
if i%2 == 0 {
k = v
continue
}
// HTTP doesn't allow you to set pseudoheaders after non pseudoheaders were set.
if isReservedHeader(k) {
continue
}
headerFields = append(headerFields, hpack.HeaderField{Name: strings.ToLower(k), Value: encodeMetadataHeader(k, v)})
}
}
for k, vv := range md {
// HTTP doesn't allow you to set pseudoheaders after non pseudoheaders were set.
if isReservedHeader(k) {
...@@ -480,38 +437,178 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
}
}
}
t.mu.Lock() return headerFields, nil
if t.state == draining { }
t.mu.Unlock()
t.streamsQuota.add(1) func (t *http2Client) createAudience(callHdr *CallHdr) string {
return nil, errStreamDrain // Create an audience string only if needed.
if len(t.creds) == 0 && callHdr.Creds == nil {
return ""
} }
if t.state != reachable { // Construct URI required to get auth request metadata.
t.mu.Unlock() // Omit port if it is the default one.
return nil, ErrConnClosing host := strings.TrimSuffix(callHdr.Host, ":443")
pos := strings.LastIndex(callHdr.Method, "/")
if pos == -1 {
pos = len(callHdr.Method)
}
return "https://" + host + callHdr.Method[:pos]
}
func (t *http2Client) getTrAuthData(ctx context.Context, audience string) (map[string]string, error) {
authData := map[string]string{}
for _, c := range t.creds {
data, err := c.GetRequestMetadata(ctx, audience)
if err != nil {
if _, ok := status.FromError(err); ok {
return nil, err
}
return nil, streamErrorf(codes.Unauthenticated, "transport: %v", err)
}
for k, v := range data {
// Capital header names are illegal in HTTP/2.
k = strings.ToLower(k)
authData[k] = v
}
}
return authData, nil
}
func (t *http2Client) getCallAuthData(ctx context.Context, audience string, callHdr *CallHdr) (map[string]string, error) {
callAuthData := map[string]string{}
// Check if credentials.PerRPCCredentials were provided via call options.
// Note: if these credentials are provided both via dial options and call
// options, then both sets of credentials will be applied.
if callCreds := callHdr.Creds; callCreds != nil {
if !t.isSecure && callCreds.RequireTransportSecurity() {
return nil, streamErrorf(codes.Unauthenticated, "transport: cannot send secure credentials on an insecure connection")
}
data, err := callCreds.GetRequestMetadata(ctx, audience)
if err != nil {
return nil, streamErrorf(codes.Internal, "transport: %v", err)
}
for k, v := range data {
// Capital header names are illegal in HTTP/2
k = strings.ToLower(k)
callAuthData[k] = v
}
}
return callAuthData, nil
}
// NewStream creates a stream and registers it into the transport as "active"
// streams.
func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Stream, err error) {
ctx = peer.NewContext(ctx, t.getPeer())
headerFields, err := t.createHeaderFields(ctx, callHdr)
if err != nil {
return nil, err
} }
s := t.newStream(ctx, callHdr) s := t.newStream(ctx, callHdr)
t.activeStreams[s.id] = s cleanup := func(err error) {
// If the number of active streams change from 0 to 1, then check if keepalive if s.swapState(streamDone) == streamDone {
// has gone dormant. If so, wake it up. // If it was already done, return.
if len(t.activeStreams) == 1 { return
select {
case t.awakenKeepalive <- struct{}{}:
t.controlBuf.put(&ping{data: [8]byte{}})
// Fill the awakenKeepalive channel again as this channel must be
// kept non-writable except at the point that the keepalive()
// goroutine is waiting either to be awaken or shutdown.
t.awakenKeepalive <- struct{}{}
default:
} }
// The stream was unprocessed by the server.
atomic.StoreUint32(&s.unprocessed, 1)
s.write(recvMsg{err: err})
close(s.done)
// If headerChan isn't closed, then close it.
if atomic.SwapUint32(&s.headerDone, 1) == 0 {
close(s.headerChan)
}
} }
t.controlBuf.put(&headerFrame{ hdr := &headerFrame{
streamID: s.id,
hf: headerFields, hf: headerFields,
endStream: false, endStream: false,
}) initStream: func(id uint32) (bool, error) {
t.mu.Unlock() t.mu.Lock()
if state := t.state; state != reachable {
t.mu.Unlock()
// Do a quick cleanup.
err := error(errStreamDrain)
if state == closing {
err = ErrConnClosing
}
cleanup(err)
return false, err
}
t.activeStreams[id] = s
if channelz.IsOn() {
t.czmu.Lock()
t.streamsStarted++
t.lastStreamCreated = time.Now()
t.czmu.Unlock()
}
var sendPing bool
// If the number of active streams change from 0 to 1, then check if keepalive
// has gone dormant. If so, wake it up.
if len(t.activeStreams) == 1 && t.keepaliveEnabled {
select {
case t.awakenKeepalive <- struct{}{}:
sendPing = true
// Fill the awakenKeepalive channel again as this channel must be
// kept non-writable except at the point that the keepalive()
// goroutine is waiting either to be awaken or shutdown.
t.awakenKeepalive <- struct{}{}
default:
}
}
t.mu.Unlock()
return sendPing, nil
},
onOrphaned: cleanup,
wq: s.wq,
}
firstTry := true
var ch chan struct{}
checkForStreamQuota := func(it interface{}) bool {
if t.streamQuota <= 0 { // Can go negative if server decreases it.
if firstTry {
t.waitingStreams++
}
ch = t.streamsQuotaAvailable
return false
}
if !firstTry {
t.waitingStreams--
}
t.streamQuota--
h := it.(*headerFrame)
h.streamID = t.nextID
t.nextID += 2
s.id = h.streamID
s.fc = &inFlow{limit: uint32(t.initialWindowSize)}
if t.streamQuota > 0 && t.waitingStreams > 0 {
select {
case t.streamsQuotaAvailable <- struct{}{}:
default:
}
}
return true
}
for {
success, err := t.controlBuf.executeAndPut(checkForStreamQuota, hdr)
if err != nil {
return nil, err
}
if success {
break
}
firstTry = false
select {
case <-ch:
case <-s.ctx.Done():
return nil, ContextErr(s.ctx.Err())
case <-t.goAway:
return nil, errStreamDrain
case <-t.ctx.Done():
return nil, ErrConnClosing
}
}
if t.statsHandler != nil {
outHeader := &stats.OutHeader{
Client: true,
...@@ -528,58 +625,72 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
// CloseStream clears the footprint of a stream when the stream is not needed any more.
// This must not be executed in reader's goroutine.
func (t *http2Client) CloseStream(s *Stream, err error) {
t.mu.Lock() var (
if t.activeStreams == nil { rst bool
t.mu.Unlock() rstCode http2.ErrCode
return )
}
if err != nil { if err != nil {
// notify in-flight streams, before the deletion rst = true
s.write(recvMsg{err: err}) rstCode = http2.ErrCodeCancel
} }
delete(t.activeStreams, s.id) t.closeStream(s, err, rst, rstCode, nil, nil, false)
if t.state == draining && len(t.activeStreams) == 0 { }
// The transport is draining and s is the last live stream on t.
t.mu.Unlock() func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2.ErrCode, st *status.Status, mdata map[string][]string, eosReceived bool) {
t.Close() // Set stream status to done.
if s.swapState(streamDone) == streamDone {
// If it was already done, return.
return return
} }
t.mu.Unlock() // status and trailers can be updated here without any synchronization because the stream goroutine will
// rstStream is true in case the stream is being closed at the client-side // only read it after it sees an io.EOF error from read or write and we'll write those errors
// and the server needs to be intimated about it by sending a RST_STREAM // only after updating this.
// frame. s.status = st
// To make sure this frame is written to the wire before the headers of the if len(mdata) > 0 {
// next stream waiting for streamsQuota, we add to streamsQuota pool only s.trailer = mdata
// after having acquired the writableChan to send RST_STREAM out (look at }
// the controller() routine). if err != nil {
var rstStream bool // This will unblock reads eventually.
var rstError http2.ErrCode s.write(recvMsg{err: err})
defer func() {
// In case, the client doesn't have to send RST_STREAM to server
// we can safely add back to streamsQuota pool now.
if !rstStream {
t.streamsQuota.add(1)
return
}
t.controlBuf.put(&resetStream{s.id, rstError})
}()
s.mu.Lock()
rstStream = s.rstStream
rstError = s.rstError
if s.state == streamDone {
s.mu.Unlock()
return
} }
if !s.headerDone { // This will unblock write.
close(s.done)
// If headerChan isn't closed, then close it.
if atomic.SwapUint32(&s.headerDone, 1) == 0 {
close(s.headerChan) close(s.headerChan)
s.headerDone = true
} }
s.state = streamDone cleanup := &cleanupStream{
s.mu.Unlock() streamID: s.id,
if _, ok := err.(StreamError); ok { onWrite: func() {
rstStream = true t.mu.Lock()
rstError = http2.ErrCodeCancel if t.activeStreams != nil {
delete(t.activeStreams, s.id)
}
t.mu.Unlock()
if channelz.IsOn() {
t.czmu.Lock()
if eosReceived {
t.streamsSucceeded++
} else {
t.streamsFailed++
}
t.czmu.Unlock()
}
},
rst: rst,
rstCode: rstCode,
}
addBackStreamQuota := func(interface{}) bool {
t.streamQuota++
if t.streamQuota > 0 && t.waitingStreams > 0 {
select {
case t.streamsQuotaAvailable <- struct{}{}:
default:
}
}
return true
} }
t.controlBuf.executeAndPut(addBackStreamQuota, cleanup)
}
// Close kicks off the shutdown process of the transport. This should be called // Close kicks off the shutdown process of the transport. This should be called
...@@ -587,27 +698,24 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
// accessed any more. // accessed any more.
func (t *http2Client) Close() error {
t.mu.Lock()
// Make sure we only Close once.
if t.state == closing {
t.mu.Unlock()
return nil
}
t.state = closing
t.mu.Unlock()
t.cancel()
err := t.conn.Close()
t.mu.Lock()
streams := t.activeStreams streams := t.activeStreams
t.activeStreams = nil t.activeStreams = nil
t.mu.Unlock() t.mu.Unlock()
t.controlBuf.finish()
t.cancel()
err := t.conn.Close()
if channelz.IsOn() {
channelz.RemoveEntry(t.channelzID)
}
// Notify all active streams.
for _, s := range streams {
s.mu.Lock() t.closeStream(s, ErrConnClosing, false, http2.ErrCodeNo, nil, nil, false)
if !s.headerDone {
close(s.headerChan)
s.headerDone = true
}
s.mu.Unlock()
s.write(recvMsg{err: ErrConnClosing})
}
if t.statsHandler != nil {
connEnd := &stats.ConnEnd{
...@@ -625,8 +733,8 @@ func (t *http2Client) Close() error {
// closing.
func (t *http2Client) GracefulClose() error {
t.mu.Lock()
switch t.state { // Make sure we move to draining only from active.
case closing, draining: if t.state == draining || t.state == closing {
t.mu.Unlock()
return nil
}
...@@ -636,112 +744,41 @@ func (t *http2Client) GracefulClose() error {
if active == 0 {
return t.Close()
}
t.controlBuf.put(&incomingGoAway{})
return nil
}
// Write formats the data into HTTP2 data frame(s) and sends it out. The caller
// should proceed only if Write returns nil.
func (t *http2Client) Write(s *Stream, hdr []byte, data []byte, opts *Options) error {
select { if opts.Last {
case <-s.ctx.Done(): // If it's the last message, update stream state.
return ContextErr(s.ctx.Err()) if !s.compareAndSwapState(streamActive, streamWriteDone) {
case <-t.ctx.Done(): return errStreamDone
return ErrConnClosing
default:
}
if hdr == nil && data == nil && opts.Last {
// stream.CloseSend uses this to send an empty frame with endStream=True
t.controlBuf.put(&dataFrame{streamID: s.id, endStream: true, f: func() {}})
return nil
}
// Add data to header frame so that we can equally distribute data across frames.
emptyLen := http2MaxFrameLen - len(hdr)
if emptyLen > len(data) {
emptyLen = len(data)
}
hdr = append(hdr, data[:emptyLen]...)
data = data[emptyLen:]
var (
streamQuota int
streamQuotaVer uint32
err error
)
for idx, r := range [][]byte{hdr, data} {
for len(r) > 0 {
size := http2MaxFrameLen
if size > len(r) {
size = len(r)
}
if streamQuota == 0 { // Used up all the locally cached stream quota.
// Get all the stream quota there is.
streamQuota, streamQuotaVer, err = s.sendQuotaPool.get(math.MaxInt32, s.waiters)
if err != nil {
return err
}
}
if size > streamQuota {
size = streamQuota
}
// Get size worth quota from transport.
tq, _, err := t.sendQuotaPool.get(size, s.waiters)
if err != nil {
return err
}
if tq < size {
size = tq
}
ltq, _, err := t.localSendQuota.get(size, s.waiters)
if err != nil {
return err
}
// even if ltq is smaller than size we don't adjust size since
// ltq is only a soft limit.
streamQuota -= size
p := r[:size]
var endStream bool
// See if this is the last frame to be written.
if opts.Last {
if len(r)-size == 0 { // No more data in r after this iteration.
if idx == 0 { // We're writing data header.
if len(data) == 0 { // There's no data to follow.
endStream = true
}
} else { // We're writing data.
endStream = true
}
}
}
success := func() {
ltq := ltq
t.controlBuf.put(&dataFrame{streamID: s.id, endStream: endStream, d: p, f: func() { t.localSendQuota.add(ltq) }})
r = r[size:]
}
failure := func() { // The stream quota version must have changed.
// Our streamQuota cache is invalidated now, so give it back.
s.sendQuotaPool.lockedAdd(streamQuota + size)
}
if !s.sendQuotaPool.compareAndExecute(streamQuotaVer, success, failure) {
// Couldn't send this chunk out.
t.sendQuotaPool.add(size)
t.localSendQuota.add(ltq)
streamQuota = 0
}
} }
} else if s.getState() != streamActive {
return errStreamDone
} }
if streamQuota > 0 { // Add the left over quota back to stream. df := &dataFrame{
s.sendQuotaPool.add(streamQuota) streamID: s.id,
} endStream: opts.Last,
if !opts.Last { }
return nil if hdr != nil || data != nil { // If it's not an empty data frame.
} // Add some data to grpc message header so that we can equally
s.mu.Lock() // distribute bytes across frames.
if s.state != streamDone { emptyLen := http2MaxFrameLen - len(hdr)
s.state = streamWriteDone if emptyLen > len(data) {
emptyLen = len(data)
}
hdr = append(hdr, data[:emptyLen]...)
data = data[emptyLen:]
df.h, df.d = hdr, data
// TODO(mmukhi): The above logic in this if can be moved to loopyWriter's data handler.
if err := s.wq.get(int32(len(hdr) + len(data))); err != nil {
return err
}
} }
s.mu.Unlock() return t.controlBuf.put(df)
return nil
}
func (t *http2Client) getStream(f http2.Frame) (*Stream, bool) {
...@@ -755,34 +792,17 @@ func (t *http2Client) getStream(f http2.Frame) (*Stream, bool) {
// of stream if the application is requesting data larger in size than
// the window.
func (t *http2Client) adjustWindow(s *Stream, n uint32) {
s.mu.Lock()
defer s.mu.Unlock()
if s.state == streamDone {
return
}
if w := s.fc.maybeAdjust(n); w > 0 { if w := s.fc.maybeAdjust(n); w > 0 {
// Piggyback connection's window update along. t.controlBuf.put(&outgoingWindowUpdate{streamID: s.id, increment: w})
if cw := t.fc.resetPendingUpdate(); cw > 0 {
t.controlBuf.put(&windowUpdate{0, cw})
}
t.controlBuf.put(&windowUpdate{s.id, w})
}
}
// updateWindow adjusts the inbound quota for the stream and the transport. // updateWindow adjusts the inbound quota for the stream.
// Window updates will deliver to the controller for sending when // Window updates will be sent out when the cumulative quota
// the cumulative quota exceeds the corresponding threshold. // exceeds the corresponding threshold.
func (t *http2Client) updateWindow(s *Stream, n uint32) {
s.mu.Lock()
defer s.mu.Unlock()
if s.state == streamDone {
return
}
if w := s.fc.onRead(n); w > 0 { if w := s.fc.onRead(n); w > 0 {
if cw := t.fc.resetPendingUpdate(); cw > 0 { t.controlBuf.put(&outgoingWindowUpdate{streamID: s.id, increment: w})
t.controlBuf.put(&windowUpdate{0, cw})
}
t.controlBuf.put(&windowUpdate{s.id, w})
} }
} }
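The new `updateWindow` path batches WINDOW_UPDATE frames: consumed bytes accumulate in the flow-control accounting and an update is only emitted once the pending amount crosses a threshold, cutting frame chatter. A minimal sketch of that idea follows; `inFlow` here is a simplified stand-in for the transport's real type, and the limit/4 threshold is an illustrative assumption.

```go
package main

import "fmt"

// inFlow is a simplified stand-in for the transport's inbound
// flow-control accounting; the real type tracks more state.
type inFlow struct {
	limit      uint32
	pendingUpd uint32
}

// onRead records n bytes consumed by the application and returns the
// window increment to send, or 0 while below the batching threshold.
func (f *inFlow) onRead(n uint32) uint32 {
	f.pendingUpd += n
	if f.pendingUpd >= f.limit/4 {
		w := f.pendingUpd
		f.pendingUpd = 0
		return w
	}
	return 0
}

func main() {
	f := &inFlow{limit: 64 * 1024}
	fmt.Println(f.onRead(1024))  // small read: no update yet
	fmt.Println(f.onRead(20000)) // crosses limit/4: flush pending bytes
}
```

When the returned increment is non-zero, the transport queues an `outgoingWindowUpdate` on the control buffer, as the diff above does.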
@@ -794,14 +814,17 @@ func (t *http2Client) updateFlowControl(n uint32) {
for _, s := range t.activeStreams { for _, s := range t.activeStreams {
s.fc.newLimit(n) s.fc.newLimit(n)
} }
t.initialWindowSize = int32(n)
t.mu.Unlock() t.mu.Unlock()
t.controlBuf.put(&windowUpdate{0, t.fc.newLimit(n)}) updateIWS := func(interface{}) bool {
t.controlBuf.put(&settings{ t.initialWindowSize = int32(n)
return true
}
t.controlBuf.executeAndPut(updateIWS, &outgoingWindowUpdate{streamID: 0, increment: t.fc.newLimit(n)})
t.controlBuf.put(&outgoingSettings{
ss: []http2.Setting{ ss: []http2.Setting{
{ {
ID: http2.SettingInitialWindowSize, ID: http2.SettingInitialWindowSize,
Val: uint32(n), Val: n,
}, },
}, },
}) })
@@ -811,7 +834,7 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
size := f.Header().Length size := f.Header().Length
var sendBDPPing bool var sendBDPPing bool
if t.bdpEst != nil { if t.bdpEst != nil {
sendBDPPing = t.bdpEst.add(uint32(size)) sendBDPPing = t.bdpEst.add(size)
} }
// Decouple connection's flow control from application's read.
// An update on connection's flow control should not depend on
@@ -822,21 +845,24 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
// active(fast) streams from starving in presence of slow or
// inactive streams.
//
// Furthermore, if a bdpPing is being sent out we can piggyback if w := t.fc.onData(size); w > 0 {
// connection's window update for the bytes we just received. t.controlBuf.put(&outgoingWindowUpdate{
streamID: 0,
increment: w,
})
}
if sendBDPPing { if sendBDPPing {
if size != 0 { // Could've been an empty data frame. // Avoid excessive ping detection (e.g. in an L7 proxy)
t.controlBuf.put(&windowUpdate{0, uint32(size)}) // by sending a window update prior to the BDP ping.
if w := t.fc.reset(); w > 0 {
t.controlBuf.put(&outgoingWindowUpdate{
streamID: 0,
increment: w,
})
} }
t.controlBuf.put(bdpPing) t.controlBuf.put(bdpPing)
} else {
if err := t.fc.onData(uint32(size)); err != nil {
t.Close()
return
}
if w := t.fc.onRead(uint32(size)); w > 0 {
t.controlBuf.put(&windowUpdate{0, w})
}
} }
// Select the right stream to dispatch. // Select the right stream to dispatch.
s, ok := t.getStream(f) s, ok := t.getStream(f)
@@ -844,25 +870,15 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
return return
} }
if size > 0 { if size > 0 {
s.mu.Lock() if err := s.fc.onData(size); err != nil {
if s.state == streamDone { t.closeStream(s, io.EOF, true, http2.ErrCodeFlowControl, status.New(codes.Internal, err.Error()), nil, false)
s.mu.Unlock()
return
}
if err := s.fc.onData(uint32(size)); err != nil {
s.rstStream = true
s.rstError = http2.ErrCodeFlowControl
s.finish(status.New(codes.Internal, err.Error()))
s.mu.Unlock()
s.write(recvMsg{err: io.EOF})
return return
} }
if f.Header().Flags.Has(http2.FlagDataPadded) { if f.Header().Flags.Has(http2.FlagDataPadded) {
if w := s.fc.onRead(uint32(size) - uint32(len(f.Data()))); w > 0 { if w := s.fc.onRead(size - uint32(len(f.Data()))); w > 0 {
t.controlBuf.put(&windowUpdate{s.id, w}) t.controlBuf.put(&outgoingWindowUpdate{s.id, w})
} }
} }
s.mu.Unlock()
// TODO(bradfitz, zhaoq): A copy is required here because there is no
// guarantee f.Data() is consumed before the arrival of next frame.
// Can this copy be eliminated?
@@ -875,14 +891,7 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
// The server has closed the stream without sending trailers. Record that
// the read direction is closed, and set the status appropriately.
if f.FrameHeader.Flags.Has(http2.FlagDataEndStream) { if f.FrameHeader.Flags.Has(http2.FlagDataEndStream) {
s.mu.Lock() t.closeStream(s, io.EOF, false, http2.ErrCodeNo, status.New(codes.Internal, "server closed the stream without sending trailers"), nil, true)
if s.state == streamDone {
s.mu.Unlock()
return
}
s.finish(status.New(codes.Internal, "server closed the stream without sending trailers"))
s.mu.Unlock()
s.write(recvMsg{err: io.EOF})
} }
} }
@@ -891,73 +900,55 @@ func (t *http2Client) handleRSTStream(f *http2.RSTStreamFrame) {
if !ok { if !ok {
return return
} }
s.mu.Lock() if f.ErrCode == http2.ErrCodeRefusedStream {
if s.state == streamDone {
s.mu.Unlock()
return
}
if !s.headerDone {
close(s.headerChan)
s.headerDone = true
}
code := http2.ErrCode(f.ErrCode)
if code == http2.ErrCodeRefusedStream {
// The stream was unprocessed by the server.
s.unprocessed = true atomic.StoreUint32(&s.unprocessed, 1)
} }
statusCode, ok := http2ErrConvTab[code] statusCode, ok := http2ErrConvTab[f.ErrCode]
if !ok { if !ok {
warningf("transport: http2Client.handleRSTStream found no mapped gRPC status for the received http2 error %v", f.ErrCode) warningf("transport: http2Client.handleRSTStream found no mapped gRPC status for the received http2 error %v", f.ErrCode)
statusCode = codes.Unknown statusCode = codes.Unknown
} }
s.finish(status.Newf(statusCode, "stream terminated by RST_STREAM with error code: %v", f.ErrCode)) t.closeStream(s, io.EOF, false, http2.ErrCodeNo, status.Newf(statusCode, "stream terminated by RST_STREAM with error code: %v", f.ErrCode), nil, false)
s.mu.Unlock()
s.write(recvMsg{err: io.EOF})
} }
func (t *http2Client) handleSettings(f *http2.SettingsFrame, isFirst bool) { func (t *http2Client) handleSettings(f *http2.SettingsFrame, isFirst bool) {
if f.IsAck() { if f.IsAck() {
return return
} }
var rs []http2.Setting var maxStreams *uint32
var ps []http2.Setting var ss []http2.Setting
isMaxConcurrentStreamsMissing := true
f.ForeachSetting(func(s http2.Setting) error { f.ForeachSetting(func(s http2.Setting) error {
if s.ID == http2.SettingMaxConcurrentStreams { if s.ID == http2.SettingMaxConcurrentStreams {
isMaxConcurrentStreamsMissing = false maxStreams = new(uint32)
} *maxStreams = s.Val
if t.isRestrictive(s) { return nil
rs = append(rs, s)
} else {
ps = append(ps, s)
} }
ss = append(ss, s)
return nil return nil
}) })
if isFirst && isMaxConcurrentStreamsMissing { if isFirst && maxStreams == nil {
// This means server is imposing no limits on maxStreams = new(uint32)
// maximum number of concurrent streams initiated by client. *maxStreams = math.MaxUint32
// So we must remove our self-imposed limit.
ps = append(ps, http2.Setting{
ID: http2.SettingMaxConcurrentStreams,
Val: math.MaxUint32,
})
} }
t.applySettings(rs) sf := &incomingSettings{
t.controlBuf.put(&settingsAck{}) ss: ss,
t.applySettings(ps) }
} if maxStreams == nil {
t.controlBuf.put(sf)
func (t *http2Client) isRestrictive(s http2.Setting) bool { return
switch s.ID { }
case http2.SettingMaxConcurrentStreams: updateStreamQuota := func(interface{}) bool {
return int(s.Val) < t.maxStreams delta := int64(*maxStreams) - int64(t.maxConcurrentStreams)
case http2.SettingInitialWindowSize: t.maxConcurrentStreams = *maxStreams
// Note: we don't acquire a lock here to read streamSendQuota t.streamQuota += delta
// because the same goroutine updates it later. if delta > 0 && t.waitingStreams > 0 {
return s.Val < t.streamSendQuota close(t.streamsQuotaAvailable) // wake all of them up.
} t.streamsQuotaAvailable = make(chan struct{}, 1)
return false }
return true
}
t.controlBuf.executeAndPut(updateStreamQuota, sf)
} }
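The `updateStreamQuota` closure above applies the delta between the old and new MAX_CONCURRENT_STREAMS to the available quota, and wakes every blocked RPC at once by closing and replacing a broadcast channel. A self-contained sketch of that wake-all pattern (field and method names are illustrative, not gRPC's actual API):

```go
package main

import "fmt"

// streamQuota mimics the stream-quota bookkeeping: apply the delta from a
// SETTINGS change, and if quota grew while RPCs were waiting, wake them
// all by closing the broadcast channel and installing a fresh one.
type streamQuota struct {
	max       uint32
	quota     int64
	waiting   int
	available chan struct{}
}

func (q *streamQuota) apply(newMax uint32) {
	delta := int64(newMax) - int64(q.max)
	q.max = newMax
	q.quota += delta
	if delta > 0 && q.waiting > 0 {
		close(q.available) // wake all of them up
		q.available = make(chan struct{}, 1)
	}
}

func main() {
	q := &streamQuota{max: 100, quota: 2, waiting: 3, available: make(chan struct{}, 1)}
	q.apply(150)
	fmt.Println(q.quota) // grew by the delta of 50
}
```

Closing a channel is Go's cheapest broadcast: every goroutine selecting on it unblocks simultaneously, and replacing the channel rearms the signal for the next SETTINGS change.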
func (t *http2Client) handlePing(f *http2.PingFrame) { func (t *http2Client) handlePing(f *http2.PingFrame) {
@@ -975,7 +966,7 @@ func (t *http2Client) handlePing(f *http2.PingFrame) {
func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
t.mu.Lock() t.mu.Lock()
if t.state != reachable && t.state != draining { if t.state == closing {
t.mu.Unlock() t.mu.Unlock()
return return
} }
@@ -1010,6 +1001,7 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
t.setGoAwayReason(f) t.setGoAwayReason(f)
close(t.goAway) close(t.goAway)
t.state = draining t.state = draining
t.controlBuf.put(&incomingGoAway{})
} }
// All streams with IDs greater than the GoAwayId
// and smaller than the previous GoAway ID should be killed.
@@ -1020,11 +1012,8 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
for streamID, stream := range t.activeStreams { for streamID, stream := range t.activeStreams {
if streamID > id && streamID <= upperLimit { if streamID > id && streamID <= upperLimit {
// The stream was unprocessed by the server.
stream.mu.Lock() atomic.StoreUint32(&stream.unprocessed, 1)
stream.unprocessed = true t.closeStream(stream, errStreamDrain, false, http2.ErrCodeNo, statusGoAway, nil, false)
stream.finish(statusGoAway)
stream.mu.Unlock()
close(stream.goAway)
} }
} }
t.prevGoAwayID = id t.prevGoAwayID = id
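The GOAWAY loop above kills exactly the streams whose IDs fall in the half-open range between this GOAWAY's last-stream ID and the previous upper limit: those were opened after the server stopped accepting work, so they were never processed and are safe to retry. The range test in isolation:

```go
package main

import "fmt"

// unprocessed reports whether a stream ID falls in the kill range
// (goAwayID, upperLimit]: opened after the server stopped accepting work,
// so it was never processed and is safe to retry on another connection.
func unprocessed(streamID, goAwayID, upperLimit uint32) bool {
	return streamID > goAwayID && streamID <= upperLimit
}

func main() {
	fmt.Println(unprocessed(7, 5, 9))  // opened too late: retryable
	fmt.Println(unprocessed(3, 5, 9))  // already processed by the server
	fmt.Println(unprocessed(11, 5, 9)) // outside this GOAWAY's range
}
```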
@@ -1056,15 +1045,10 @@ func (t *http2Client) GetGoAwayReason() GoAwayReason {
} }
func (t *http2Client) handleWindowUpdate(f *http2.WindowUpdateFrame) { func (t *http2Client) handleWindowUpdate(f *http2.WindowUpdateFrame) {
id := f.Header().StreamID t.controlBuf.put(&incomingWindowUpdate{
incr := f.Increment streamID: f.Header().StreamID,
if id == 0 { increment: f.Increment,
t.sendQuotaPool.add(int(incr)) })
return
}
if s, ok := t.getStream(f); ok {
s.sendQuotaPool.add(int(incr))
}
} }
// operateHeaders takes action on the decoded headers.
@@ -1073,18 +1057,10 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
if !ok { if !ok {
return return
} }
s.mu.Lock() atomic.StoreUint32(&s.bytesReceived, 1)
s.bytesReceived = true
s.mu.Unlock()
var state decodeState var state decodeState
if err := state.decodeResponseHeader(frame); err != nil { if err := state.decodeResponseHeader(frame); err != nil {
s.mu.Lock() t.closeStream(s, err, true, http2.ErrCodeProtocol, nil, nil, false)
if !s.headerDone {
close(s.headerChan)
s.headerDone = true
}
s.mu.Unlock()
s.write(recvMsg{err: err})
// Something wrong. Stops reading even when there is remaining.
return return
} }
@@ -1108,39 +1084,25 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
} }
} }
}() }()
// If headers haven't been received yet.
s.mu.Lock() if atomic.SwapUint32(&s.headerDone, 1) == 0 {
if !endStream { if !endStream {
s.recvCompress = state.encoding // Headers frame is not actually a trailers-only frame.
} isHeader = true
if !s.headerDone { // These values can be set without any synchronization because
if !endStream && len(state.mdata) > 0 { // stream goroutine will read it only after seeing a closed
s.header = state.mdata // headerChan which we'll close after setting this.
s.recvCompress = state.encoding
if len(state.mdata) > 0 {
s.header = state.mdata
}
} }
close(s.headerChan) close(s.headerChan)
s.headerDone = true
isHeader = true
} }
if !endStream || s.state == streamDone { if !endStream {
s.mu.Unlock()
return return
} }
if len(state.mdata) > 0 { t.closeStream(s, io.EOF, false, http2.ErrCodeNo, state.status(), state.mdata, true)
s.trailer = state.mdata
}
s.finish(state.status())
s.mu.Unlock()
s.write(recvMsg{err: io.EOF})
}
func handleMalformedHTTP2(s *Stream, err error) {
s.mu.Lock()
if !s.headerDone {
close(s.headerChan)
s.headerDone = true
}
s.mu.Unlock()
s.write(recvMsg{err: err})
} }
// reader runs as a separate goroutine in charge of reading data from network
@@ -1150,13 +1112,16 @@ func handleMalformedHTTP2(s *Stream, err error) {
// optimal.
// TODO(zhaoq): Check the validity of the incoming frame sequence.
func (t *http2Client) reader() { func (t *http2Client) reader() {
defer close(t.readerDone)
// Check the validity of server preface.
frame, err := t.framer.fr.ReadFrame() frame, err := t.framer.fr.ReadFrame()
if err != nil { if err != nil {
t.Close() t.Close()
return return
} }
atomic.CompareAndSwapUint32(&t.activity, 0, 1) if t.keepaliveEnabled {
atomic.CompareAndSwapUint32(&t.activity, 0, 1)
}
sf, ok := frame.(*http2.SettingsFrame) sf, ok := frame.(*http2.SettingsFrame)
if !ok { if !ok {
t.Close() t.Close()
@@ -1168,7 +1133,9 @@ func (t *http2Client) reader() {
// loop to keep reading incoming messages on this transport.
for { for {
frame, err := t.framer.fr.ReadFrame() frame, err := t.framer.fr.ReadFrame()
atomic.CompareAndSwapUint32(&t.activity, 0, 1) if t.keepaliveEnabled {
atomic.CompareAndSwapUint32(&t.activity, 0, 1)
}
if err != nil { if err != nil {
// Abort an active stream if the http2.Framer returns a
// http2.StreamError. This can happen only if the server's response
@@ -1179,7 +1146,7 @@ func (t *http2Client) reader() {
t.mu.Unlock() t.mu.Unlock()
if s != nil { if s != nil {
// use error detail to provide better err message
handleMalformedHTTP2(s, streamErrorf(http2ErrConvTab[se.Code], "%v", t.framer.fr.ErrorDetail())) t.closeStream(s, streamErrorf(http2ErrConvTab[se.Code], "%v", t.framer.fr.ErrorDetail()), true, http2.ErrCodeProtocol, nil, nil, false)
} }
continue continue
} else { } else {
@@ -1209,109 +1176,6 @@
} }
} }
func (t *http2Client) applySettings(ss []http2.Setting) {
for _, s := range ss {
switch s.ID {
case http2.SettingMaxConcurrentStreams:
// TODO(zhaoq): This is a hack to avoid significant refactoring of the
// code to deal with the unrealistic int32 overflow. Probably will try
// to find a better way to handle this later.
if s.Val > math.MaxInt32 {
s.Val = math.MaxInt32
}
ms := t.maxStreams
t.maxStreams = int(s.Val)
t.streamsQuota.add(int(s.Val) - ms)
case http2.SettingInitialWindowSize:
t.mu.Lock()
for _, stream := range t.activeStreams {
// Adjust the sending quota for each stream.
stream.sendQuotaPool.addAndUpdate(int(s.Val) - int(t.streamSendQuota))
}
t.streamSendQuota = s.Val
t.mu.Unlock()
}
}
}
// TODO(mmukhi): A lot of this code(and code in other places in the tranpsort layer)
// is duplicated between the client and the server.
// The transport layer needs to be refactored to take care of this.
func (t *http2Client) itemHandler(i item) (err error) {
defer func() {
if err != nil {
errorf(" error in itemHandler: %v", err)
}
}()
switch i := i.(type) {
case *dataFrame:
if err := t.framer.fr.WriteData(i.streamID, i.endStream, i.d); err != nil {
return err
}
i.f()
return nil
case *headerFrame:
t.hBuf.Reset()
for _, f := range i.hf {
t.hEnc.WriteField(f)
}
endHeaders := false
first := true
for !endHeaders {
size := t.hBuf.Len()
if size > http2MaxFrameLen {
size = http2MaxFrameLen
} else {
endHeaders = true
}
if first {
first = false
err = t.framer.fr.WriteHeaders(http2.HeadersFrameParam{
StreamID: i.streamID,
BlockFragment: t.hBuf.Next(size),
EndStream: i.endStream,
EndHeaders: endHeaders,
})
} else {
err = t.framer.fr.WriteContinuation(
i.streamID,
endHeaders,
t.hBuf.Next(size),
)
}
if err != nil {
return err
}
}
return nil
case *windowUpdate:
return t.framer.fr.WriteWindowUpdate(i.streamID, i.increment)
case *settings:
return t.framer.fr.WriteSettings(i.ss...)
case *settingsAck:
return t.framer.fr.WriteSettingsAck()
case *resetStream:
// If the server needs to be to intimated about stream closing,
// then we need to make sure the RST_STREAM frame is written to
// the wire before the headers of the next stream waiting on
// streamQuota. We ensure this by adding to the streamsQuota pool
// only after having acquired the writableChan to send RST_STREAM.
err := t.framer.fr.WriteRSTStream(i.streamID, i.code)
t.streamsQuota.add(1)
return err
case *flushIO:
return t.framer.writer.Flush()
case *ping:
if !i.ack {
t.bdpEst.timesnap(i.data)
}
return t.framer.fr.WritePing(i.ack, i.data)
default:
errorf("transport: http2Client.controller got unexpected item type %v", i)
return fmt.Errorf("transport: http2Client.controller got unexpected item type %v", i)
}
}
// keepalive, running in a separate goroutine, makes sure the connection is alive by sending pings.
func (t *http2Client) keepalive() { func (t *http2Client) keepalive() {
p := &ping{data: [8]byte{}} p := &ping{data: [8]byte{}}
@@ -1338,6 +1202,11 @@ func (t *http2Client) keepalive() {
} }
} else { } else {
t.mu.Unlock() t.mu.Unlock()
if channelz.IsOn() {
t.czmu.Lock()
t.kpCount++
t.czmu.Unlock()
}
// Send ping.
t.controlBuf.put(p) t.controlBuf.put(p)
} }
@@ -1374,3 +1243,56 @@ func (t *http2Client) Error() <-chan struct{} {
func (t *http2Client) GoAway() <-chan struct{} { func (t *http2Client) GoAway() <-chan struct{} {
return t.goAway return t.goAway
} }
func (t *http2Client) ChannelzMetric() *channelz.SocketInternalMetric {
t.czmu.RLock()
s := channelz.SocketInternalMetric{
StreamsStarted: t.streamsStarted,
StreamsSucceeded: t.streamsSucceeded,
StreamsFailed: t.streamsFailed,
MessagesSent: t.msgSent,
MessagesReceived: t.msgRecv,
KeepAlivesSent: t.kpCount,
LastLocalStreamCreatedTimestamp: t.lastStreamCreated,
LastMessageSentTimestamp: t.lastMsgSent,
LastMessageReceivedTimestamp: t.lastMsgRecv,
LocalFlowControlWindow: int64(t.fc.getSize()),
//socket options
LocalAddr: t.localAddr,
RemoteAddr: t.remoteAddr,
// Security
// RemoteName :
}
t.czmu.RUnlock()
s.RemoteFlowControlWindow = t.getOutFlowWindow()
return &s
}
func (t *http2Client) IncrMsgSent() {
t.czmu.Lock()
t.msgSent++
t.lastMsgSent = time.Now()
t.czmu.Unlock()
}
func (t *http2Client) IncrMsgRecv() {
t.czmu.Lock()
t.msgRecv++
t.lastMsgRecv = time.Now()
t.czmu.Unlock()
}
func (t *http2Client) getOutFlowWindow() int64 {
resp := make(chan uint32, 1)
timer := time.NewTimer(time.Second)
defer timer.Stop()
t.controlBuf.put(&outFlowControlSizeRequest{resp})
select {
case sz := <-resp:
return int64(sz)
case <-t.ctxDone:
return -1
case <-timer.C:
return -2
}
}
@@ -24,7 +24,6 @@ import (
"fmt" "fmt"
"io" "io"
"math" "math"
"math/rand"
"net" "net"
"strconv" "strconv"
"sync" "sync"
@@ -35,8 +34,12 @@ import (
"golang.org/x/net/context" "golang.org/x/net/context"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"golang.org/x/net/http2/hpack" "golang.org/x/net/http2/hpack"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/internal/grpcrand"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer" "google.golang.org/grpc/peer"
@@ -52,28 +55,25 @@ var ErrIllegalHeaderWrite = errors.New("transport: the stream is done or WriteHe
// http2Server implements the ServerTransport interface with HTTP2.
type http2Server struct { type http2Server struct {
ctx context.Context ctx context.Context
ctxDone <-chan struct{} // Cache the context.Done() chan
cancel context.CancelFunc cancel context.CancelFunc
conn net.Conn conn net.Conn
loopy *loopyWriter
readerDone chan struct{} // sync point to enable testing.
writerDone chan struct{} // sync point to enable testing.
remoteAddr net.Addr remoteAddr net.Addr
localAddr net.Addr localAddr net.Addr
maxStreamID uint32 // max stream ID ever seen maxStreamID uint32 // max stream ID ever seen
authInfo credentials.AuthInfo // auth info about the connection authInfo credentials.AuthInfo // auth info about the connection
inTapHandle tap.ServerInHandle inTapHandle tap.ServerInHandle
framer *framer framer *framer
hBuf *bytes.Buffer // the buffer for HPACK encoding
hEnc *hpack.Encoder // HPACK encoder
// The max number of concurrent streams.
maxStreams uint32 maxStreams uint32
// controlBuf delivers all the control related tasks (e.g., window
// updates, reset streams, and various settings) to the controller.
controlBuf *controlBuffer controlBuf *controlBuffer
fc *inFlow fc *trInFlow
// sendQuotaPool provides flow control to outbound message. stats stats.Handler
sendQuotaPool *quotaPool
// localSendQuota limits the amount of data that can be scheduled
// for writing before it is actually written out.
localSendQuota *quotaPool
stats stats.Handler
// Flag to keep track of reading activity on transport.
// 1 is true and 0 is false.
activity uint32 // Accessed atomically. activity uint32 // Accessed atomically.
@@ -104,13 +104,27 @@ type http2Server struct {
drainChan chan struct{} drainChan chan struct{}
state transportState state transportState
activeStreams map[uint32]*Stream activeStreams map[uint32]*Stream
// the per-stream outbound flow control window size set by the peer.
streamSendQuota uint32
// idle is the time instant when the connection went idle.
// This is either the beginning of the connection or when the number of
// RPCs go down to 0.
// When the connection is busy, this value is set to 0.
idle time.Time idle time.Time
// Fields below are for channelz metric collection.
channelzID int64 // channelz unique identification number
czmu sync.RWMutex
kpCount int64
// The number of streams that have started, including already finished ones.
streamsStarted int64
// The number of streams that have ended successfully by sending frame with
// EoS bit set.
streamsSucceeded int64
streamsFailed int64
lastStreamCreated time.Time
msgSent int64
msgRecv int64
lastMsgSent time.Time
lastMsgRecv time.Time
} }
// newHTTP2Server constructs a ServerTransport based on HTTP2. ConnectionError is
@@ -185,33 +199,30 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err
if kep.MinTime == 0 { if kep.MinTime == 0 {
kep.MinTime = defaultKeepalivePolicyMinTime kep.MinTime = defaultKeepalivePolicyMinTime
} }
var buf bytes.Buffer
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
t := &http2Server{ t := &http2Server{
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
ctxDone: ctx.Done(),
conn: conn, conn: conn,
remoteAddr: conn.RemoteAddr(), remoteAddr: conn.RemoteAddr(),
localAddr: conn.LocalAddr(), localAddr: conn.LocalAddr(),
authInfo: config.AuthInfo, authInfo: config.AuthInfo,
framer: framer, framer: framer,
hBuf: &buf, readerDone: make(chan struct{}),
hEnc: hpack.NewEncoder(&buf), writerDone: make(chan struct{}),
maxStreams: maxStreams, maxStreams: maxStreams,
inTapHandle: config.InTapHandle, inTapHandle: config.InTapHandle,
controlBuf: newControlBuffer(), fc: &trInFlow{limit: uint32(icwz)},
fc: &inFlow{limit: uint32(icwz)},
sendQuotaPool: newQuotaPool(defaultWindowSize),
localSendQuota: newQuotaPool(defaultLocalSendQuota),
state: reachable, state: reachable,
activeStreams: make(map[uint32]*Stream), activeStreams: make(map[uint32]*Stream),
streamSendQuota: defaultWindowSize,
stats: config.StatsHandler, stats: config.StatsHandler,
kp: kp, kp: kp,
idle: time.Now(), idle: time.Now(),
kep: kep, kep: kep,
initialWindowSize: iwz, initialWindowSize: iwz,
} }
t.controlBuf = newControlBuffer(t.ctxDone)
if dynamicWindow { if dynamicWindow {
t.bdpEst = &bdpEstimator{ t.bdpEst = &bdpEstimator{
bdp: initialWindowSize, bdp: initialWindowSize,
@@ -226,6 +237,9 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err
connBegin := &stats.ConnBegin{} connBegin := &stats.ConnBegin{}
t.stats.HandleConn(t.ctx, connBegin) t.stats.HandleConn(t.ctx, connBegin)
} }
if channelz.IsOn() {
t.channelzID = channelz.RegisterNormalSocket(t, config.ChannelzParentID, "")
}
t.framer.writer.Flush() t.framer.writer.Flush()
defer func() { defer func() {
@@ -258,8 +272,13 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err
t.handleSettings(sf) t.handleSettings(sf)
go func() { go func() {
loopyWriter(t.ctx, t.controlBuf, t.itemHandler) t.loopy = newLoopyWriter(serverSide, t.framer, t.controlBuf, t.bdpEst)
t.loopy.ssGoAwayHandler = t.outgoingGoAwayHandler
if err := t.loopy.run(); err != nil {
errorf("transport: loopyWriter.run returning. Err: %v", err)
}
t.conn.Close() t.conn.Close()
close(t.writerDone)
}() }()
go t.keepalive() go t.keepalive()
return t, nil return t, nil
@@ -268,12 +287,16 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err
// operateHeader takes action on the decoded headers.
func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream), traceCtx func(context.Context, string) context.Context) (close bool) { func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream), traceCtx func(context.Context, string) context.Context) (close bool) {
streamID := frame.Header().StreamID streamID := frame.Header().StreamID
var state decodeState var state decodeState
for _, hf := range frame.Fields { for _, hf := range frame.Fields {
if err := state.processHeaderField(hf); err != nil { if err := state.processHeaderField(hf); err != nil {
if se, ok := err.(StreamError); ok { if se, ok := err.(StreamError); ok {
t.controlBuf.put(&resetStream{streamID, statusCodeConvTab[se.Code]}) t.controlBuf.put(&cleanupStream{
streamID: streamID,
rst: true,
rstCode: statusCodeConvTab[se.Code],
onWrite: func() {},
})
} }
return return
} }
@@ -281,14 +304,14 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
buf := newRecvBuffer() buf := newRecvBuffer()
s := &Stream{ s := &Stream{
id: streamID, id: streamID,
st: t, st: t,
buf: buf, buf: buf,
fc: &inFlow{limit: uint32(t.initialWindowSize)}, fc: &inFlow{limit: uint32(t.initialWindowSize)},
recvCompress: state.encoding, recvCompress: state.encoding,
method: state.method, method: state.method,
contentSubtype: state.contentSubtype,
} }
if frame.StreamEnded() { if frame.StreamEnded() {
// s is just created by the caller. No lock needed.
s.state = streamReadDone s.state = streamReadDone
@@ -306,10 +329,6 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
pr.AuthInfo = t.authInfo pr.AuthInfo = t.authInfo
} }
s.ctx = peer.NewContext(s.ctx, pr) s.ctx = peer.NewContext(s.ctx, pr)
// Cache the current stream to the context so that the server application
// can find out. Required when the server wants to send some metadata
// back to the client (unary call only).
s.ctx = newContextWithStream(s.ctx, s)
// Attach the received metadata to the context.
if len(state.mdata) > 0 { if len(state.mdata) > 0 {
s.ctx = metadata.NewIncomingContext(s.ctx, state.mdata) s.ctx = metadata.NewIncomingContext(s.ctx, state.mdata)
@@ -328,7 +347,12 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
s.ctx, err = t.inTapHandle(s.ctx, info) s.ctx, err = t.inTapHandle(s.ctx, info)
if err != nil { if err != nil {
warningf("transport: http2Server.operateHeaders got an error from InTapHandle: %v", err) warningf("transport: http2Server.operateHeaders got an error from InTapHandle: %v", err)
t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream}) t.controlBuf.put(&cleanupStream{
streamID: s.id,
rst: true,
rstCode: http2.ErrCodeRefusedStream,
onWrite: func() {},
})
return return
} }
} }
@@ -339,7 +363,12 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
} }
if uint32(len(t.activeStreams)) >= t.maxStreams { if uint32(len(t.activeStreams)) >= t.maxStreams {
t.mu.Unlock() t.mu.Unlock()
t.controlBuf.put(&resetStream{streamID, http2.ErrCodeRefusedStream}) t.controlBuf.put(&cleanupStream{
streamID: streamID,
rst: true,
rstCode: http2.ErrCodeRefusedStream,
onWrite: func() {},
})
return return
} }
if streamID%2 != 1 || streamID <= t.maxStreamID { if streamID%2 != 1 || streamID <= t.maxStreamID {
@@ -349,12 +378,17 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
return true return true
} }
t.maxStreamID = streamID t.maxStreamID = streamID
s.sendQuotaPool = newQuotaPool(int(t.streamSendQuota))
t.activeStreams[streamID] = s t.activeStreams[streamID] = s
if len(t.activeStreams) == 1 { if len(t.activeStreams) == 1 {
t.idle = time.Time{} t.idle = time.Time{}
} }
t.mu.Unlock() t.mu.Unlock()
if channelz.IsOn() {
t.czmu.Lock()
t.streamsStarted++
t.lastStreamCreated = time.Now()
t.czmu.Unlock()
}
s.requestRead = func(n int) { s.requestRead = func(n int) {
t.adjustWindow(s, uint32(n)) t.adjustWindow(s, uint32(n))
} }
...@@ -370,19 +404,23 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( ...@@ -370,19 +404,23 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
} }
t.stats.HandleRPC(s.ctx, inHeader) t.stats.HandleRPC(s.ctx, inHeader)
} }
s.ctxDone = s.ctx.Done()
s.wq = newWriteQuota(defaultWriteQuota, s.ctxDone)
s.trReader = &transportReader{ s.trReader = &transportReader{
reader: &recvBufferReader{ reader: &recvBufferReader{
ctx: s.ctx, ctx: s.ctx,
recv: s.buf, ctxDone: s.ctxDone,
recv: s.buf,
}, },
windowHandler: func(n int) { windowHandler: func(n int) {
t.updateWindow(s, uint32(n)) t.updateWindow(s, uint32(n))
}, },
} }
s.waiters = waiters{ // Register the stream with loopy.
ctx: s.ctx, t.controlBuf.put(&registerStream{
tctx: t.ctx, streamID: s.id,
} wq: s.wq,
})
handle(s) handle(s)
return return
} }
...@@ -391,18 +429,26 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( ...@@ -391,18 +429,26 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
// typically run in a separate goroutine. // typically run in a separate goroutine.
// traceCtx attaches trace to ctx and returns the new context. // traceCtx attaches trace to ctx and returns the new context.
func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context.Context, string) context.Context) { func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context.Context, string) context.Context) {
defer close(t.readerDone)
for { for {
frame, err := t.framer.fr.ReadFrame() frame, err := t.framer.fr.ReadFrame()
atomic.StoreUint32(&t.activity, 1) atomic.StoreUint32(&t.activity, 1)
if err != nil { if err != nil {
if se, ok := err.(http2.StreamError); ok { if se, ok := err.(http2.StreamError); ok {
warningf("transport: http2Server.HandleStreams encountered http2.StreamError: %v", se)
t.mu.Lock() t.mu.Lock()
s := t.activeStreams[se.StreamID] s := t.activeStreams[se.StreamID]
t.mu.Unlock() t.mu.Unlock()
if s != nil { if s != nil {
t.closeStream(s) t.closeStream(s, true, se.Code, nil, false)
} else {
t.controlBuf.put(&cleanupStream{
streamID: se.StreamID,
rst: true,
rstCode: se.Code,
onWrite: func() {},
})
} }
t.controlBuf.put(&resetStream{se.StreamID, se.Code})
continue continue
} }
if err == io.EOF || err == io.ErrUnexpectedEOF { if err == io.EOF || err == io.ErrUnexpectedEOF {
...@@ -456,33 +502,20 @@ func (t *http2Server) getStream(f http2.Frame) (*Stream, bool) { ...@@ -456,33 +502,20 @@ func (t *http2Server) getStream(f http2.Frame) (*Stream, bool) {
// of stream if the application is requesting data larger in size than // of stream if the application is requesting data larger in size than
// the window. // the window.
func (t *http2Server) adjustWindow(s *Stream, n uint32) { func (t *http2Server) adjustWindow(s *Stream, n uint32) {
s.mu.Lock()
defer s.mu.Unlock()
if s.state == streamDone {
return
}
if w := s.fc.maybeAdjust(n); w > 0 { if w := s.fc.maybeAdjust(n); w > 0 {
if cw := t.fc.resetPendingUpdate(); cw > 0 { t.controlBuf.put(&outgoingWindowUpdate{streamID: s.id, increment: w})
t.controlBuf.put(&windowUpdate{0, cw})
}
t.controlBuf.put(&windowUpdate{s.id, w})
} }
} }
// updateWindow adjusts the inbound quota for the stream and the transport. // updateWindow adjusts the inbound quota for the stream and the transport.
// Window updates will deliver to the controller for sending when // Window updates will deliver to the controller for sending when
// the cumulative quota exceeds the corresponding threshold. // the cumulative quota exceeds the corresponding threshold.
func (t *http2Server) updateWindow(s *Stream, n uint32) { func (t *http2Server) updateWindow(s *Stream, n uint32) {
s.mu.Lock()
defer s.mu.Unlock()
if s.state == streamDone {
return
}
if w := s.fc.onRead(n); w > 0 { if w := s.fc.onRead(n); w > 0 {
if cw := t.fc.resetPendingUpdate(); cw > 0 { t.controlBuf.put(&outgoingWindowUpdate{streamID: s.id,
t.controlBuf.put(&windowUpdate{0, cw}) increment: w,
} })
t.controlBuf.put(&windowUpdate{s.id, w})
} }
} }
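The `onRead`-and-threshold idea behind `updateWindow` can be sketched in isolation: bytes handed to the application accumulate, and a `WINDOW_UPDATE` increment is only emitted once a sizeable share of the window has been consumed, batching many small reads into few frames. The type below is a simplified illustration, not grpc-go's actual `inFlow`; the field names and the one-quarter threshold are assumptions.

```go
package main

import "fmt"

// inFlow is a hypothetical, simplified inbound flow-control accounter:
// pendingUpdate tracks bytes read by the application but not yet returned
// to the peer as window credit.
type inFlow struct {
	limit         uint32 // window size granted to the peer
	pendingUpdate uint32 // consumed bytes not yet announced
}

// onRead records n consumed bytes and returns the window-update increment
// to send, or 0 while the (assumed) one-quarter threshold is not reached.
func (f *inFlow) onRead(n uint32) uint32 {
	f.pendingUpdate += n
	if f.pendingUpdate >= f.limit/4 {
		w := f.pendingUpdate
		f.pendingUpdate = 0
		return w
	}
	return 0
}

func main() {
	f := &inFlow{limit: 65535}
	fmt.Println(f.onRead(1000))  // below threshold: no update yet
	fmt.Println(f.onRead(20000)) // crosses limit/4: flush all accumulated bytes
}
```

Batching like this is why the caller only enqueues an `outgoingWindowUpdate` when `w > 0`.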
@@ -496,12 +529,15 @@ func (t *http2Server) updateFlowControl(n uint32) {
 	}
 	t.initialWindowSize = int32(n)
 	t.mu.Unlock()
-	t.controlBuf.put(&windowUpdate{0, t.fc.newLimit(n)})
-	t.controlBuf.put(&settings{
+	t.controlBuf.put(&outgoingWindowUpdate{
+		streamID:  0,
+		increment: t.fc.newLimit(n),
+	})
+	t.controlBuf.put(&outgoingSettings{
 		ss: []http2.Setting{
 			{
 				ID:  http2.SettingInitialWindowSize,
-				Val: uint32(n),
+				Val: n,
 			},
 		},
 	})
@@ -512,7 +548,7 @@ func (t *http2Server) handleData(f *http2.DataFrame) {
 	size := f.Header().Length
 	var sendBDPPing bool
 	if t.bdpEst != nil {
-		sendBDPPing = t.bdpEst.add(uint32(size))
+		sendBDPPing = t.bdpEst.add(size)
 	}
 	// Decouple connection's flow control from application's read.
 	// An update on connection's flow control should not depend on
@@ -522,23 +558,22 @@ func (t *http2Server) handleData(f *http2.DataFrame) {
 	// Decoupling the connection flow control will prevent other
 	// active(fast) streams from starving in presence of slow or
 	// inactive streams.
-	//
-	// Furthermore, if a bdpPing is being sent out we can piggyback
-	// connection's window update for the bytes we just received.
+	if w := t.fc.onData(size); w > 0 {
+		t.controlBuf.put(&outgoingWindowUpdate{
+			streamID:  0,
+			increment: w,
+		})
+	}
 	if sendBDPPing {
-		if size != 0 { // Could be an empty frame.
-			t.controlBuf.put(&windowUpdate{0, uint32(size)})
+		// Avoid excessive ping detection (e.g. in an L7 proxy)
+		// by sending a window update prior to the BDP ping.
+		if w := t.fc.reset(); w > 0 {
+			t.controlBuf.put(&outgoingWindowUpdate{
+				streamID:  0,
+				increment: w,
+			})
 		}
 		t.controlBuf.put(bdpPing)
-	} else {
-		if err := t.fc.onData(uint32(size)); err != nil {
-			errorf("transport: http2Server %v", err)
-			t.Close()
-			return
-		}
-		if w := t.fc.onRead(uint32(size)); w > 0 {
-			t.controlBuf.put(&windowUpdate{0, w})
-		}
 	}
 	// Select the right stream to dispatch.
 	s, ok := t.getStream(f)
@@ -546,23 +581,15 @@ func (t *http2Server) handleData(f *http2.DataFrame) {
 		return
 	}
 	if size > 0 {
-		s.mu.Lock()
-		if s.state == streamDone {
-			s.mu.Unlock()
-			return
-		}
-		if err := s.fc.onData(uint32(size)); err != nil {
-			s.mu.Unlock()
-			t.closeStream(s)
-			t.controlBuf.put(&resetStream{s.id, http2.ErrCodeFlowControl})
+		if err := s.fc.onData(size); err != nil {
+			t.closeStream(s, true, http2.ErrCodeFlowControl, nil, false)
 			return
 		}
 		if f.Header().Flags.Has(http2.FlagDataPadded) {
-			if w := s.fc.onRead(uint32(size) - uint32(len(f.Data()))); w > 0 {
-				t.controlBuf.put(&windowUpdate{s.id, w})
+			if w := s.fc.onRead(size - uint32(len(f.Data()))); w > 0 {
+				t.controlBuf.put(&outgoingWindowUpdate{s.id, w})
 			}
 		}
-		s.mu.Unlock()
 		// TODO(bradfitz, zhaoq): A copy is required here because there is no
 		// guarantee f.Data() is consumed before the arrival of next frame.
 		// Can this copy be eliminated?
@@ -574,11 +601,7 @@ func (t *http2Server) handleData(f *http2.DataFrame) {
 	}
 	if f.Header().Flags.Has(http2.FlagDataEndStream) {
 		// Received the end of stream from the client.
-		s.mu.Lock()
-		if s.state != streamDone {
-			s.state = streamReadDone
-		}
-		s.mu.Unlock()
+		s.compareAndSwapState(streamActive, streamReadDone)
 		s.write(recvMsg{err: io.EOF})
 	}
 }
@@ -588,50 +611,21 @@ func (t *http2Server) handleRSTStream(f *http2.RSTStreamFrame) {
 	if !ok {
 		return
 	}
-	t.closeStream(s)
+	t.closeStream(s, false, 0, nil, false)
 }
 
 func (t *http2Server) handleSettings(f *http2.SettingsFrame) {
 	if f.IsAck() {
 		return
 	}
-	var rs []http2.Setting
-	var ps []http2.Setting
+	var ss []http2.Setting
 	f.ForeachSetting(func(s http2.Setting) error {
-		if t.isRestrictive(s) {
-			rs = append(rs, s)
-		} else {
-			ps = append(ps, s)
-		}
+		ss = append(ss, s)
 		return nil
 	})
-	t.applySettings(rs)
-	t.controlBuf.put(&settingsAck{})
-	t.applySettings(ps)
-}
-
-func (t *http2Server) isRestrictive(s http2.Setting) bool {
-	switch s.ID {
-	case http2.SettingInitialWindowSize:
-		// Note: we don't acquire a lock here to read streamSendQuota
-		// because the same goroutine updates it later.
-		return s.Val < t.streamSendQuota
-	}
-	return false
-}
-
-func (t *http2Server) applySettings(ss []http2.Setting) {
-	for _, s := range ss {
-		if s.ID == http2.SettingInitialWindowSize {
-			t.mu.Lock()
-			for _, stream := range t.activeStreams {
-				stream.sendQuotaPool.addAndUpdate(int(s.Val) - int(t.streamSendQuota))
-			}
-			t.streamSendQuota = s.Val
-			t.mu.Unlock()
-		}
-	}
+	t.controlBuf.put(&incomingSettings{
+		ss: ss,
+	})
 }
 
 const (
@@ -690,33 +684,31 @@ func (t *http2Server) handlePing(f *http2.PingFrame) {
 }
 
 func (t *http2Server) handleWindowUpdate(f *http2.WindowUpdateFrame) {
-	id := f.Header().StreamID
-	incr := f.Increment
-	if id == 0 {
-		t.sendQuotaPool.add(int(incr))
-		return
-	}
-	if s, ok := t.getStream(f); ok {
-		s.sendQuotaPool.add(int(incr))
+	t.controlBuf.put(&incomingWindowUpdate{
+		streamID:  f.Header().StreamID,
+		increment: f.Increment,
+	})
+}
+
+func appendHeaderFieldsFromMD(headerFields []hpack.HeaderField, md metadata.MD) []hpack.HeaderField {
+	for k, vv := range md {
+		if isReservedHeader(k) {
+			// Clients don't tolerate reading restricted headers after some non restricted ones were sent.
+			continue
+		}
+		for _, v := range vv {
+			headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
+		}
 	}
+	return headerFields
 }
 
 // WriteHeader sends the header metadata md back to the client.
 func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
-	select {
-	case <-s.ctx.Done():
-		return ContextErr(s.ctx.Err())
-	case <-t.ctx.Done():
-		return ErrConnClosing
-	default:
-	}
-	s.mu.Lock()
-	if s.headerOk || s.state == streamDone {
-		s.mu.Unlock()
+	if s.updateHeaderSent() || s.getState() == streamDone {
 		return ErrIllegalHeaderWrite
 	}
-	s.headerOk = true
+	s.hdrMu.Lock()
 	if md.Len() > 0 {
 		if s.header.Len() > 0 {
 			s.header = metadata.Join(s.header, md)
@@ -724,37 +716,35 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
 			s.header = md
 		}
 	}
-	md = s.header
-	s.mu.Unlock()
+	t.writeHeaderLocked(s)
+	s.hdrMu.Unlock()
+	return nil
+}
+
+func (t *http2Server) writeHeaderLocked(s *Stream) {
 	// TODO(mmukhi): Benchmark if the performance gets better if count the metadata and other header fields
 	// first and create a slice of that exact size.
 	headerFields := make([]hpack.HeaderField, 0, 2) // at least :status, content-type will be there if none else.
 	headerFields = append(headerFields, hpack.HeaderField{Name: ":status", Value: "200"})
-	headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
+	headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: contentType(s.contentSubtype)})
 	if s.sendCompress != "" {
 		headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-encoding", Value: s.sendCompress})
 	}
-	for k, vv := range md {
-		if isReservedHeader(k) {
-			// Clients don't tolerate reading restricted headers after some non restricted ones were sent.
-			continue
-		}
-		for _, v := range vv {
-			headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
-		}
-	}
+	headerFields = appendHeaderFieldsFromMD(headerFields, s.header)
 	t.controlBuf.put(&headerFrame{
 		streamID:  s.id,
 		hf:        headerFields,
 		endStream: false,
+		onWrite: func() {
+			atomic.StoreUint32(&t.resetPingStrikes, 1)
+		},
 	})
 	if t.stats != nil {
-		outHeader := &stats.OutHeader{
-			//WireLength: // TODO(mmukhi): Revisit this later, if needed.
-		}
+		// Note: WireLength is not set in outHeader.
+		// TODO(mmukhi): Revisit this later, if needed.
+		outHeader := &stats.OutHeader{}
 		t.stats.HandleRPC(s.Context(), outHeader)
 	}
-	return nil
 }
 
 // WriteStatus sends stream status to the client and terminates the stream.
@@ -762,37 +752,20 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
 // TODO(zhaoq): Now it indicates the end of entire stream. Revisit if early
 // OK is adopted.
 func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error {
-	select {
-	case <-t.ctx.Done():
-		return ErrConnClosing
-	default:
-	}
-
-	var headersSent, hasHeader bool
-	s.mu.Lock()
-	if s.state == streamDone {
-		s.mu.Unlock()
+	if s.getState() == streamDone {
 		return nil
 	}
-	if s.headerOk {
-		headersSent = true
-	}
-	if s.header.Len() > 0 {
-		hasHeader = true
-	}
-	s.mu.Unlock()
-	if !headersSent && hasHeader {
-		t.WriteHeader(s, nil)
-		headersSent = true
-	}
+	s.hdrMu.Lock()
 	// TODO(mmukhi): Benchmark if the performance gets better if count the metadata and other header fields
 	// first and create a slice of that exact size.
 	headerFields := make([]hpack.HeaderField, 0, 2) // grpc-status and grpc-message will be there if none else.
-	if !headersSent {
-		headerFields = append(headerFields, hpack.HeaderField{Name: ":status", Value: "200"})
-		headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
+	if !s.updateHeaderSent() { // No headers have been sent.
+		if len(s.header) > 0 { // Send a separate header frame.
+			t.writeHeaderLocked(s)
+		} else { // Send a trailer only response.
+			headerFields = append(headerFields, hpack.HeaderField{Name: ":status", Value: "200"})
+			headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: contentType(s.contentSubtype)})
+		}
 	}
 	headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-status", Value: strconv.Itoa(int(st.Code()))})
 	headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-message", Value: encodeGrpcMessage(st.Message())})
@@ -801,129 +774,75 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error {
 		stBytes, err := proto.Marshal(p)
 		if err != nil {
 			// TODO: return error instead, when callers are able to handle it.
-			panic(err)
+			grpclog.Errorf("transport: failed to marshal rpc status: %v, error: %v", p, err)
+		} else {
+			headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-status-details-bin", Value: encodeBinHeader(stBytes)})
 		}
-		headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-status-details-bin", Value: encodeBinHeader(stBytes)})
 	}
 
 	// Attach the trailer metadata.
-	for k, vv := range s.trailer {
-		// Clients don't tolerate reading restricted headers after some non restricted ones were sent.
-		if isReservedHeader(k) {
-			continue
-		}
-		for _, v := range vv {
-			headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
-		}
-	}
-	t.controlBuf.put(&headerFrame{
+	headerFields = appendHeaderFieldsFromMD(headerFields, s.trailer)
+	trailingHeader := &headerFrame{
 		streamID:  s.id,
 		hf:        headerFields,
 		endStream: true,
-	})
+		onWrite: func() {
			atomic.StoreUint32(&t.resetPingStrikes, 1)
		},
+	}
+	s.hdrMu.Unlock()
+	t.closeStream(s, false, 0, trailingHeader, true)
 	if t.stats != nil {
 		t.stats.HandleRPC(s.Context(), &stats.OutTrailer{})
 	}
-	t.closeStream(s)
 	return nil
 }
 // Write converts the data into HTTP2 data frame and sends it out. Non-nil error
 // is returned if it fails (e.g., framing error, transport error).
 func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) error {
-	select {
-	case <-s.ctx.Done():
-		return ContextErr(s.ctx.Err())
-	case <-t.ctx.Done():
-		return ErrConnClosing
-	default:
-	}
-
-	var writeHeaderFrame bool
-	s.mu.Lock()
-	if s.state == streamDone {
-		s.mu.Unlock()
-		return streamErrorf(codes.Unknown, "the stream has been done")
-	}
-	if !s.headerOk {
-		writeHeaderFrame = true
-	}
-	s.mu.Unlock()
-	if writeHeaderFrame {
-		t.WriteHeader(s, nil)
+	if !s.isHeaderSent() { // Headers haven't been written yet.
+		if err := t.WriteHeader(s, nil); err != nil {
+			// TODO(mmukhi, dfawley): Make sure this is the right code to return.
+			return streamErrorf(codes.Internal, "transport: %v", err)
+		}
+	} else {
+		// Writing headers checks for this condition.
+		if s.getState() == streamDone {
+			// TODO(mmukhi, dfawley): Should the server write also return io.EOF?
+			s.cancel()
+			select {
+			case <-t.ctx.Done():
+				return ErrConnClosing
+			default:
+			}
+			return ContextErr(s.ctx.Err())
+		}
 	}
-	// Add data to header frame so that we can equally distribute data across frames.
+	// Add some data to header frame so that we can equally distribute bytes across frames.
 	emptyLen := http2MaxFrameLen - len(hdr)
 	if emptyLen > len(data) {
 		emptyLen = len(data)
 	}
 	hdr = append(hdr, data[:emptyLen]...)
 	data = data[emptyLen:]
-	var (
-		streamQuota    int
-		streamQuotaVer uint32
-		err            error
-	)
-	for _, r := range [][]byte{hdr, data} {
-		for len(r) > 0 {
-			size := http2MaxFrameLen
-			if size > len(r) {
-				size = len(r)
-			}
-			if streamQuota == 0 { // Used up all the locally cached stream quota.
-				// Get all the stream quota there is.
-				streamQuota, streamQuotaVer, err = s.sendQuotaPool.get(math.MaxInt32, s.waiters)
-				if err != nil {
-					return err
-				}
-			}
-			if size > streamQuota {
-				size = streamQuota
-			}
-			// Get size worth quota from transport.
-			tq, _, err := t.sendQuotaPool.get(size, s.waiters)
-			if err != nil {
-				return err
-			}
-			if tq < size {
-				size = tq
-			}
-			ltq, _, err := t.localSendQuota.get(size, s.waiters)
-			if err != nil {
-				return err
-			}
-			// even if ltq is smaller than size we don't adjust size since,
-			// ltq is only a soft limit.
-			streamQuota -= size
-			p := r[:size]
-			// Reset ping strikes when sending data since this might cause
-			// the peer to send ping.
+	df := &dataFrame{
+		streamID: s.id,
+		h:        hdr,
+		d:        data,
+		onEachWrite: func() {
 			atomic.StoreUint32(&t.resetPingStrikes, 1)
-			success := func() {
-				ltq := ltq
-				t.controlBuf.put(&dataFrame{streamID: s.id, endStream: false, d: p, f: func() {
-					t.localSendQuota.add(ltq)
-				}})
-				r = r[size:]
-			}
-			failure := func() { // The stream quota version must have changed.
-				// Our streamQuota cache is invalidated now, so give it back.
-				s.sendQuotaPool.lockedAdd(streamQuota + size)
-			}
-			if !s.sendQuotaPool.compareAndExecute(streamQuotaVer, success, failure) {
-				// Couldn't send this chunk out.
-				t.sendQuotaPool.add(size)
-				t.localSendQuota.add(ltq)
-				streamQuota = 0
-			}
-		}
-	}
-	if streamQuota > 0 {
-		// Add the left over quota back to stream.
-		s.sendQuotaPool.add(streamQuota)
+		},
+	}
+	if err := s.wq.get(int32(len(hdr) + len(data))); err != nil {
+		select {
+		case <-t.ctx.Done():
+			return ErrConnClosing
+		default:
+		}
+		return ContextErr(s.ctx.Err())
 	}
-	return nil
+	return t.controlBuf.put(df)
 }
 
 // keepalive running in a separate goroutine does the following:
@@ -968,7 +887,7 @@ func (t *http2Server) keepalive() {
 		// The connection has been idle for a duration of keepalive.MaxConnectionIdle or more.
 		// Gracefully close the connection.
 		t.drain(http2.ErrCodeNo, []byte{})
-		// Reseting the timer so that the clean-up doesn't deadlock.
+		// Resetting the timer so that the clean-up doesn't deadlock.
 		maxIdle.Reset(infinity)
 		return
 	}
@@ -980,7 +899,7 @@ func (t *http2Server) keepalive() {
 	case <-maxAge.C:
 		// Close the connection after grace period.
 		t.Close()
-		// Reseting the timer so that the clean-up doesn't deadlock.
+		// Resetting the timer so that the clean-up doesn't deadlock.
 		maxAge.Reset(infinity)
 	case <-t.ctx.Done():
 	}
@@ -993,11 +912,16 @@ func (t *http2Server) keepalive() {
 	}
 	if pingSent {
 		t.Close()
-		// Reseting the timer so that the clean-up doesn't deadlock.
+		// Resetting the timer so that the clean-up doesn't deadlock.
 		keepalive.Reset(infinity)
 		return
 	}
 	pingSent = true
+	if channelz.IsOn() {
+		t.czmu.Lock()
+		t.kpCount++
+		t.czmu.Unlock()
+	}
 	t.controlBuf.put(p)
 	keepalive.Reset(t.kp.Timeout)
 	case <-t.ctx.Done():
@@ -1006,133 +930,6 @@ func (t *http2Server) keepalive() {
 	}
 }
-var goAwayPing = &ping{data: [8]byte{1, 6, 1, 8, 0, 3, 3, 9}}
-
-// TODO(mmukhi): A lot of this code (and code in other places in the transport layer)
-// is duplicated between the client and the server.
-// The transport layer needs to be refactored to take care of this.
-func (t *http2Server) itemHandler(i item) error {
-	switch i := i.(type) {
-	case *dataFrame:
-		if err := t.framer.fr.WriteData(i.streamID, i.endStream, i.d); err != nil {
-			return err
-		}
-		i.f()
-		return nil
-	case *headerFrame:
-		t.hBuf.Reset()
-		for _, f := range i.hf {
-			t.hEnc.WriteField(f)
-		}
-		first := true
-		endHeaders := false
-		for !endHeaders {
-			size := t.hBuf.Len()
-			if size > http2MaxFrameLen {
-				size = http2MaxFrameLen
-			} else {
-				endHeaders = true
-			}
-			var err error
-			if first {
-				first = false
-				err = t.framer.fr.WriteHeaders(http2.HeadersFrameParam{
-					StreamID:      i.streamID,
-					BlockFragment: t.hBuf.Next(size),
-					EndStream:     i.endStream,
-					EndHeaders:    endHeaders,
-				})
-			} else {
-				err = t.framer.fr.WriteContinuation(
-					i.streamID,
-					endHeaders,
-					t.hBuf.Next(size),
-				)
-			}
-			if err != nil {
-				return err
-			}
-		}
-		atomic.StoreUint32(&t.resetPingStrikes, 1)
-		return nil
-	case *windowUpdate:
-		return t.framer.fr.WriteWindowUpdate(i.streamID, i.increment)
-	case *settings:
-		return t.framer.fr.WriteSettings(i.ss...)
-	case *settingsAck:
-		return t.framer.fr.WriteSettingsAck()
-	case *resetStream:
-		return t.framer.fr.WriteRSTStream(i.streamID, i.code)
-	case *goAway:
-		t.mu.Lock()
-		if t.state == closing {
-			t.mu.Unlock()
-			// The transport is closing.
-			return fmt.Errorf("transport: Connection closing")
-		}
-		sid := t.maxStreamID
-		if !i.headsUp {
-			// Stop accepting more streams now.
-			t.state = draining
-			if len(t.activeStreams) == 0 {
-				i.closeConn = true
-			}
-			t.mu.Unlock()
-			if err := t.framer.fr.WriteGoAway(sid, i.code, i.debugData); err != nil {
-				return err
-			}
-			if i.closeConn {
-				// Abruptly close the connection following the GoAway (via
-				// loopywriter). But flush out what's inside the buffer first.
-				t.controlBuf.put(&flushIO{closeTr: true})
-			}
-			return nil
-		}
-		t.mu.Unlock()
-		// For a graceful close, send out a GoAway with stream ID of MaxUInt32,
-		// Follow that with a ping and wait for the ack to come back or a timer
-		// to expire. During this time accept new streams since they might have
-		// originated before the GoAway reaches the client.
-		// After getting the ack or timer expiration send out another GoAway this
-		// time with an ID of the max stream server intends to process.
-		if err := t.framer.fr.WriteGoAway(math.MaxUint32, http2.ErrCodeNo, []byte{}); err != nil {
-			return err
-		}
-		if err := t.framer.fr.WritePing(false, goAwayPing.data); err != nil {
-			return err
-		}
-		go func() {
-			timer := time.NewTimer(time.Minute)
-			defer timer.Stop()
-			select {
-			case <-t.drainChan:
-			case <-timer.C:
-			case <-t.ctx.Done():
-				return
-			}
-			t.controlBuf.put(&goAway{code: i.code, debugData: i.debugData})
-		}()
-		return nil
-	case *flushIO:
-		if err := t.framer.writer.Flush(); err != nil {
-			return err
-		}
-		if i.closeTr {
-			return ErrConnClosing
-		}
-		return nil
-	case *ping:
-		if !i.ack {
-			t.bdpEst.timesnap(i.data)
-		}
-		return t.framer.fr.WritePing(i.ack, i.data)
-	default:
-		err := status.Errorf(codes.Internal, "transport: http2Server.controller got unexpected item type %t", i)
-		errorf("%v", err)
-		return err
-	}
-}
 // Close starts shutting down the http2Server transport.
 // TODO(zhaoq): Now the destruction is not blocked on any pending streams. This
 // could cause some resource issue. Revisit this later.
@@ -1146,8 +943,12 @@ func (t *http2Server) Close() error {
 	streams := t.activeStreams
 	t.activeStreams = nil
 	t.mu.Unlock()
+	t.controlBuf.finish()
 	t.cancel()
 	err := t.conn.Close()
+	if channelz.IsOn() {
+		channelz.RemoveEntry(t.channelzID)
+	}
 	// Cancel all active streams.
 	for _, s := range streams {
 		s.cancel()
@@ -1161,27 +962,45 @@ func (t *http2Server) Close() error {
 // closeStream clears the footprint of a stream when the stream is not needed
 // any more.
-func (t *http2Server) closeStream(s *Stream) {
-	t.mu.Lock()
-	delete(t.activeStreams, s.id)
-	if len(t.activeStreams) == 0 {
-		t.idle = time.Now()
-	}
-	if t.state == draining && len(t.activeStreams) == 0 {
-		defer t.controlBuf.put(&flushIO{closeTr: true})
+func (t *http2Server) closeStream(s *Stream, rst bool, rstCode http2.ErrCode, hdr *headerFrame, eosReceived bool) {
+	if s.swapState(streamDone) == streamDone {
+		// If the stream was already done, return.
+		return
 	}
-	t.mu.Unlock()
 	// In case stream sending and receiving are invoked in separate
 	// goroutines (e.g., bi-directional streaming), cancel needs to be
 	// called to interrupt the potential blocking on other goroutines.
 	s.cancel()
-	s.mu.Lock()
-	if s.state == streamDone {
-		s.mu.Unlock()
-		return
+	cleanup := &cleanupStream{
+		streamID: s.id,
+		rst:      rst,
+		rstCode:  rstCode,
+		onWrite: func() {
+			t.mu.Lock()
+			if t.activeStreams != nil {
+				delete(t.activeStreams, s.id)
+				if len(t.activeStreams) == 0 {
+					t.idle = time.Now()
+				}
+			}
+			t.mu.Unlock()
+			if channelz.IsOn() {
+				t.czmu.Lock()
+				if eosReceived {
+					t.streamsSucceeded++
+				} else {
+					t.streamsFailed++
+				}
+				t.czmu.Unlock()
+			}
+		},
+	}
+	if hdr != nil {
+		hdr.cleanup = cleanup
+		t.controlBuf.put(hdr)
+	} else {
+		t.controlBuf.put(cleanup)
 	}
-	s.state = streamDone
-	s.mu.Unlock()
 }
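The new `closeStream` no longer mutates transport bookkeeping inline; it enqueues a `cleanupStream` item whose `onWrite` callback runs on the writer goroutine, so map updates are serialized with the frame writes themselves. A bare-bones sketch of that pattern, with the control buffer reduced to a buffered channel (the real loopy writer is considerably more involved):

```go
package main

import "fmt"

// cleanupItem mimics the shape of the cleanupStream control item above:
// an identifier plus a callback the writer loop runs when it dequeues it.
type cleanupItem struct {
	streamID uint32
	onWrite  func()
}

func main() {
	controlBuf := make(chan cleanupItem, 16)
	active := map[uint32]bool{1: true, 3: true}

	// Producer side: closing a stream just enqueues the bookkeeping work.
	controlBuf <- cleanupItem{streamID: 3, onWrite: func() { delete(active, 3) }}
	close(controlBuf)

	// Writer loop: executes callbacks in order, adjacent to the actual I/O,
	// so no other goroutine ever touches the active-streams map.
	for item := range controlBuf {
		item.onWrite()
		fmt.Println("cleaned up stream", item.streamID)
	}
	fmt.Println("remaining streams:", len(active))
}
```

Handing `trailingHeader` a `cleanup` field (as `WriteStatus` does) is the same idea: the trailer frame and the stream teardown are dequeued as one unit, so cleanup can never overtake the final write.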
 func (t *http2Server) RemoteAddr() net.Addr {
@@ -1202,7 +1021,115 @@ func (t *http2Server) drain(code http2.ErrCode, debugData []byte) {
 	t.controlBuf.put(&goAway{code: code, debugData: debugData, headsUp: true})
 }
 
-var rgen = rand.New(rand.NewSource(time.Now().UnixNano()))
+var goAwayPing = &ping{data: [8]byte{1, 6, 1, 8, 0, 3, 3, 9}}
+
+// Handles outgoing GoAway and returns true if loopy needs to put itself
+// in draining mode.
+func (t *http2Server) outgoingGoAwayHandler(g *goAway) (bool, error) {
+	t.mu.Lock()
+	if t.state == closing { // TODO(mmukhi): This seems unnecessary.
+		t.mu.Unlock()
+		// The transport is closing.
+		return false, ErrConnClosing
+	}
+	sid := t.maxStreamID
+	if !g.headsUp {
+		// Stop accepting more streams now.
+		t.state = draining
+		if len(t.activeStreams) == 0 {
+			g.closeConn = true
+		}
+		t.mu.Unlock()
+		if err := t.framer.fr.WriteGoAway(sid, g.code, g.debugData); err != nil {
+			return false, err
+		}
+		if g.closeConn {
+			// Abruptly close the connection following the GoAway (via
+			// loopywriter). But flush out what's inside the buffer first.
+			t.framer.writer.Flush()
+			return false, fmt.Errorf("transport: Connection closing")
+		}
+		return true, nil
+	}
+	t.mu.Unlock()
+	// For a graceful close, send out a GoAway with stream ID of MaxUInt32,
+	// Follow that with a ping and wait for the ack to come back or a timer
+	// to expire. During this time accept new streams since they might have
+	// originated before the GoAway reaches the client.
+	// After getting the ack or timer expiration send out another GoAway this
+	// time with an ID of the max stream server intends to process.
+	if err := t.framer.fr.WriteGoAway(math.MaxUint32, http2.ErrCodeNo, []byte{}); err != nil {
+		return false, err
+	}
+	if err := t.framer.fr.WritePing(false, goAwayPing.data); err != nil {
+		return false, err
+	}
+	go func() {
+		timer := time.NewTimer(time.Minute)
+		defer timer.Stop()
+		select {
+		case <-t.drainChan:
+		case <-timer.C:
+		case <-t.ctx.Done():
+			return
+		}
+		t.controlBuf.put(&goAway{code: g.code, debugData: g.debugData})
+	}()
+	return false, nil
+}
+
+func (t *http2Server) ChannelzMetric() *channelz.SocketInternalMetric {
+	t.czmu.RLock()
+	s := channelz.SocketInternalMetric{
+		StreamsStarted:                   t.streamsStarted,
+		StreamsSucceeded:                 t.streamsSucceeded,
+		StreamsFailed:                    t.streamsFailed,
+		MessagesSent:                     t.msgSent,
+		MessagesReceived:                 t.msgRecv,
+		KeepAlivesSent:                   t.kpCount,
+		LastRemoteStreamCreatedTimestamp: t.lastStreamCreated,
+		LastMessageSentTimestamp:         t.lastMsgSent,
+		LastMessageReceivedTimestamp:     t.lastMsgRecv,
+		LocalFlowControlWindow:           int64(t.fc.getSize()),
+		// socket options
+		LocalAddr:  t.localAddr,
+		RemoteAddr: t.remoteAddr,
+		// Security
+		// RemoteName :
+	}
+	t.czmu.RUnlock()
+	s.RemoteFlowControlWindow = t.getOutFlowWindow()
+	return &s
+}
+
+func (t *http2Server) IncrMsgSent() {
+	t.czmu.Lock()
+	t.msgSent++
+	t.lastMsgSent = time.Now()
+	t.czmu.Unlock()
+}
+
+func (t *http2Server) IncrMsgRecv() {
+	t.czmu.Lock()
+	t.msgRecv++
+	t.lastMsgRecv = time.Now()
+	t.czmu.Unlock()
+}
+
+func (t *http2Server) getOutFlowWindow() int64 {
+	resp := make(chan uint32)
+	timer := time.NewTimer(time.Second)
+	defer timer.Stop()
+	t.controlBuf.put(&outFlowControlSizeRequest{resp})
+	select {
+	case sz := <-resp:
+		return int64(sz)
+	case <-t.ctxDone:
+		return -1
+	case <-timer.C:
+		return -2
+	}
+}
func getJitter(v time.Duration) time.Duration {
if v == infinity {
@@ -1210,6 +1137,6 @@ func getJitter(v time.Duration) time.Duration {
}
// Generate a jitter between +/- 10% of the value.
r := int64(v / 10)
j := grpcrand.Int63n(2*r) - r
return time.Duration(j)
}
@@ -23,12 +23,12 @@ import (
"bytes"
"encoding/base64"
"fmt"
"io"
"net"
"net/http"
"strconv"
"strings"
"time"
"unicode/utf8"
"github.com/golang/protobuf/proto"
"golang.org/x/net/http2"
@@ -46,6 +46,12 @@ const (
// http2IOBufSize specifies the buffer size for sending frames.
defaultWriteBufSize = 32 * 1024
defaultReadBufSize = 32 * 1024
// baseContentType is the base content-type for gRPC. This is a valid
// content-type on its own, but can also include a content-subtype such as
// "proto" as a suffix after "+" or ";". See
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests
// for more details.
baseContentType = "application/grpc"
)
var (
@@ -64,7 +70,7 @@ var (
http2.ErrCodeConnect: codes.Internal,
http2.ErrCodeEnhanceYourCalm: codes.ResourceExhausted,
http2.ErrCodeInadequateSecurity: codes.PermissionDenied,
http2.ErrCodeHTTP11Required: codes.Internal,
}
statusCodeConvTab = map[codes.Code]http2.ErrCode{
codes.Internal: http2.ErrCodeInternal,
@@ -111,9 +117,10 @@ type decodeState struct {
timeout time.Duration
method string
// key-value metadata map from the peer.
mdata map[string][]string
statsTags []byte
statsTrace []byte
contentSubtype string
}
// isReservedHeader checks whether hdr belongs to HTTP2 headers
@@ -125,6 +132,7 @@ func isReservedHeader(hdr string) bool {
}
switch hdr {
case "content-type",
"user-agent",
"grpc-message-type",
"grpc-encoding",
"grpc-message",
@@ -138,28 +146,55 @@ func isReservedHeader(hdr string) bool {
}
}
// isWhitelistedHeader checks whether hdr should be propagated
// into metadata visible to users.
func isWhitelistedHeader(hdr string) bool {
switch hdr {
case ":authority", "user-agent":
return true
default:
return false
}
}
// contentSubtype returns the content-subtype for the given content-type. The
// given content-type must be a valid content-type that starts with
// "application/grpc". A content-subtype will follow "application/grpc" after a
// "+" or ";". See
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for
// more details.
//
// If contentType is not a valid content-type for gRPC, the boolean
// will be false, otherwise true. If content-type == "application/grpc",
// "application/grpc+", or "application/grpc;", the boolean will be true,
// but no content-subtype will be returned.
//
// contentType is assumed to be lowercase already.
func contentSubtype(contentType string) (string, bool) {
if contentType == baseContentType {
return "", true
}
if !strings.HasPrefix(contentType, baseContentType) {
return "", false
}
// guaranteed since != baseContentType and has baseContentType prefix
switch contentType[len(baseContentType)] {
case '+', ';':
// this will return true for "application/grpc+" or "application/grpc;"
// which the previous validContentType function tested to be valid, so we
// just say that no content-subtype is specified in this case
return contentType[len(baseContentType)+1:], true
default:
return "", false
}
}
// contentSubtype is assumed to be lowercase
func contentType(contentSubtype string) string {
if contentSubtype == "" {
return baseContentType
}
return baseContentType + "+" + contentSubtype
}
func (d *decodeState) status() *status.Status {
@@ -228,9 +263,9 @@ func (d *decodeState) decodeResponseHeader(frame *http2.MetaHeadersFrame) error
// gRPC status doesn't exist and http status is OK.
// Set rawStatusCode to be unknown and return nil error.
// So that, if the stream has ended this Unknown status
// will be propagated to the user.
// Otherwise, it will be ignored. In which case, status from
// a later trailer, that has StreamEnded flag set, is propagated.
code := int(codes.Unknown)
d.rawStatusCode = &code
return nil
@@ -247,9 +282,16 @@ func (d *decodeState) addMetadata(k, v string) {
func (d *decodeState) processHeaderField(f hpack.HeaderField) error {
switch f.Name {
case "content-type":
contentSubtype, validContentType := contentSubtype(f.Value)
if !validContentType {
return streamErrorf(codes.Internal, "transport: received the unexpected content-type %q", f.Value)
}
d.contentSubtype = contentSubtype
// TODO: do we want to propagate the whole content-type in the metadata,
// or come up with a way to just propagate the content-subtype if it was set?
// ie {"content-type": "application/grpc+proto"} or {"content-subtype": "proto"}
// in the metadata?
d.addMetadata(f.Name, f.Value)
case "grpc-encoding":
d.encoding = f.Value
case "grpc-status":
@@ -299,7 +341,7 @@ func (d *decodeState) processHeaderField(f hpack.HeaderField) error {
d.statsTrace = v
d.addMetadata(f.Name, string(v))
default:
if isReservedHeader(f.Name) && !isWhitelistedHeader(f.Name) {
break
}
v, err := decodeMetadataHeader(f.Name, f.Value)
@@ -307,7 +349,7 @@ func (d *decodeState) processHeaderField(f hpack.HeaderField) error {
errorf("Failed to decode metadata header (%q, %q): %v", f.Name, f.Value, err)
return nil
}
d.addMetadata(f.Name, v)
}
return nil
}
@@ -396,16 +438,17 @@ func decodeTimeout(s string) (time.Duration, error) {
const (
spaceByte = ' '
tildeByte = '~'
percentByte = '%'
)
// encodeGrpcMessage is used to encode status code in header field
// "grpc-message". It does percent encoding and also replaces invalid utf-8
// characters with Unicode replacement character.
//
// It checks to see if each individual byte in msg is an allowable byte, and
// then either percent encoding or passing it through. When percent encoding,
// the byte is converted into hexadecimal notation with a '%' prepended.
func encodeGrpcMessage(msg string) string {
if msg == "" {
return ""
@@ -413,7 +456,7 @@ func encodeGrpcMessage(msg string) string {
lenMsg := len(msg)
for i := 0; i < lenMsg; i++ {
c := msg[i]
if !(c >= spaceByte && c <= tildeByte && c != percentByte) {
return encodeGrpcMessageUnchecked(msg)
}
}
@@ -422,14 +465,26 @@ func encodeGrpcMessage(msg string) string {
func encodeGrpcMessageUnchecked(msg string) string {
var buf bytes.Buffer
for len(msg) > 0 {
r, size := utf8.DecodeRuneInString(msg)
for _, b := range []byte(string(r)) {
if size > 1 {
// If size > 1, r is not ascii. Always do percent encoding.
buf.WriteString(fmt.Sprintf("%%%02X", b))
continue
}
// The for loop is necessary even if size == 1. r could be
// utf8.RuneError.
//
// fmt.Sprintf("%%%02X", utf8.RuneError) gives "%FFFD".
if b >= spaceByte && b <= tildeByte && b != percentByte {
buf.WriteByte(b)
} else {
buf.WriteString(fmt.Sprintf("%%%02X", b))
}
}
msg = msg[size:]
}
return buf.String()
}
@@ -468,19 +523,67 @@ func decodeGrpcMessageUnchecked(msg string) string {
return buf.String()
}
type bufWriter struct {
buf []byte
offset int
batchSize int
conn net.Conn
err error
onFlush func()
}
func newBufWriter(conn net.Conn, batchSize int) *bufWriter {
return &bufWriter{
buf: make([]byte, batchSize*2),
batchSize: batchSize,
conn: conn,
}
}
func (w *bufWriter) Write(b []byte) (n int, err error) {
if w.err != nil {
return 0, w.err
}
for len(b) > 0 {
nn := copy(w.buf[w.offset:], b)
b = b[nn:]
w.offset += nn
n += nn
if w.offset >= w.batchSize {
err = w.Flush()
}
}
return n, err
}
func (w *bufWriter) Flush() error {
if w.err != nil {
return w.err
}
if w.offset == 0 {
return nil
}
if w.onFlush != nil {
w.onFlush()
}
_, w.err = w.conn.Write(w.buf[:w.offset])
w.offset = 0
return w.err
}
type framer struct {
writer *bufWriter
fr *http2.Framer
}
func newFramer(conn net.Conn, writeBufferSize, readBufferSize int) *framer {
r := bufio.NewReaderSize(conn, readBufferSize)
w := newBufWriter(conn, writeBufferSize)
f := &framer{
writer: w,
fr: http2.NewFramer(w, r),
}
// Opt-in to Frame reuse API on framer to reduce garbage.
// Frames aren't safe to read from after a subsequent call to ReadFrame.
f.fr.SetReuseFrames()
...
@@ -19,16 +19,17 @@
// Package transport defines and implements message oriented communication
// channel to complete various transactions (e.g., an RPC). It is meant for
// grpc-internal usage and is not intended to be imported directly by users.
package transport // externally used as import "google.golang.org/grpc/transport"
import (
"errors"
"fmt"
"io"
"net"
"sync"
"sync/atomic"
"golang.org/x/net/context"
"golang.org/x/net/http2"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive"
@@ -57,6 +58,7 @@ type recvBuffer struct {
c chan recvMsg
mu sync.Mutex
backlog []recvMsg
err error
}
func newRecvBuffer() *recvBuffer {
@@ -68,6 +70,13 @@ func newRecvBuffer() *recvBuffer {
func (b *recvBuffer) put(r recvMsg) {
b.mu.Lock()
if b.err != nil {
b.mu.Unlock()
// An error had occurred earlier, don't accept more
// data or errors.
return
}
b.err = r.err
if len(b.backlog) == 0 {
select {
case b.c <- r:
@@ -101,14 +110,15 @@ func (b *recvBuffer) get() <-chan recvMsg {
return b.c
}
// recvBufferReader implements io.Reader interface to read the data from
// recvBuffer.
type recvBufferReader struct {
ctx context.Context
ctxDone <-chan struct{} // cache of ctx.Done() (for performance).
recv *recvBuffer
last []byte // Stores the remaining data in the previous calls.
err error
}
// Read reads the next len(p) bytes from last. If last is drained, it tries to
@@ -130,10 +140,8 @@ func (r *recvBufferReader) read(p []byte) (n int, err error) {
return copied, nil
}
select {
case <-r.ctxDone:
return 0, ContextErr(r.ctx.Err())
case m := <-r.recv.get():
r.recv.load()
if m.err != nil {
@@ -145,61 +153,7 @@ func (r *recvBufferReader) read(p []byte) (n int, err error) {
}
}
// All items in and out of a controlBuffer should be the same type.
type item interface {
item()
}
// controlBuffer is an unbounded channel of item.
type controlBuffer struct {
c chan item
mu sync.Mutex
backlog []item
}
func newControlBuffer() *controlBuffer {
b := &controlBuffer{
c: make(chan item, 1),
}
return b
}
func (b *controlBuffer) put(r item) {
b.mu.Lock()
if len(b.backlog) == 0 {
select {
case b.c <- r:
b.mu.Unlock()
return
default:
}
}
b.backlog = append(b.backlog, r)
b.mu.Unlock()
}
func (b *controlBuffer) load() {
b.mu.Lock()
if len(b.backlog) > 0 {
select {
case b.c <- b.backlog[0]:
b.backlog[0] = nil
b.backlog = b.backlog[1:]
default:
}
}
b.mu.Unlock()
}
// get returns the channel that receives an item in the buffer.
//
// Upon receipt of an item, the caller should call load to send another
// item onto the channel if there is any.
func (b *controlBuffer) get() <-chan item {
return b.c
}
type streamState uint32
const (
streamActive streamState = iota
@@ -214,8 +168,8 @@ type Stream struct {
st ServerTransport // nil for client side Stream
ctx context.Context // the associated context of the stream
cancel context.CancelFunc // always nil for client side Stream
done chan struct{} // closed at the end of stream to unblock writers. On the client side.
ctxDone <-chan struct{} // same as done chan but for server side. Cache of ctx.Done() (for performance)
method string // the associated RPC method of the stream
recvCompress string
sendCompress string
@@ -223,29 +177,58 @@ type Stream struct {
trReader io.Reader
fc *inFlow
recvQuota uint32
wq *writeQuota
// Callback to state application's intentions to read data. This
// is used to adjust flow control, if needed.
requestRead func(int)
headerChan chan struct{} // closed to indicate the end of header metadata.
headerDone uint32 // set when headerChan is closed. Used to avoid closing headerChan multiple times.
// hdrMu protects header and trailer metadata on the server-side.
hdrMu sync.Mutex
header metadata.MD // the received header metadata.
trailer metadata.MD // the key-value map of trailer metadata.
// On the server-side, headerSent is atomically set to 1 when the headers are sent out.
headerSent uint32
state streamState
// On client-side it is the status error received from the server.
// On server-side it is unused.
status *status.Status
bytesReceived uint32 // indicates whether any bytes have been received on this stream
unprocessed uint32 // set if the server sends a refused stream or GOAWAY including this stream
// contentSubtype is the content-subtype for requests.
// this must be lowercase or the behavior is undefined.
contentSubtype string
}
// isHeaderSent is only valid on the server-side.
func (s *Stream) isHeaderSent() bool {
return atomic.LoadUint32(&s.headerSent) == 1
}
// updateHeaderSent updates headerSent and returns true
// if it was already set. It is valid only on server-side.
func (s *Stream) updateHeaderSent() bool {
return atomic.SwapUint32(&s.headerSent, 1) == 1
}
func (s *Stream) swapState(st streamState) streamState {
return streamState(atomic.SwapUint32((*uint32)(&s.state), uint32(st)))
}
func (s *Stream) compareAndSwapState(oldState, newState streamState) bool {
return atomic.CompareAndSwapUint32((*uint32)(&s.state), uint32(oldState), uint32(newState))
}
func (s *Stream) getState() streamState {
return streamState(atomic.LoadUint32((*uint32)(&s.state)))
}
func (s *Stream) waitOnHeader() error {
@@ -254,12 +237,9 @@ func (s *Stream) waitOnHeader() error {
// only after having received headers.
return nil
}
select {
case <-s.ctx.Done():
return ContextErr(s.ctx.Err())
case <-s.headerChan:
return nil
}
@@ -285,12 +265,6 @@ func (s *Stream) Done() <-chan struct{} {
return s.done
}
// GoAway returns a channel which is closed when the server sent GoAways signal
// before this stream was initiated.
func (s *Stream) GoAway() <-chan struct{} {
return s.goAway
}
// Header acquires the key-value pairs of header metadata once it
// is available. It blocks until i) the metadata is ready or ii) there is no
// header metadata or iii) the stream is canceled/expired.
@@ -299,6 +273,9 @@ func (s *Stream) Header() (metadata.MD, error) {
// Even if the stream is closed, header is returned if available.
select {
case <-s.headerChan:
if s.header == nil {
return nil, nil
}
return s.header.Copy(), nil
default:
}
@@ -308,10 +285,10 @@ func (s *Stream) Header() (metadata.MD, error) {
// Trailer returns the cached trailer metadata. Note that if it is not called
// after the entire stream is done, it could return an empty MD. Client
// side only.
// It can be safely read only after the stream has ended, that is, either read
// or write has returned io.EOF.
func (s *Stream) Trailer() metadata.MD {
c := s.trailer.Copy()
return c
}
@@ -321,6 +298,15 @@ func (s *Stream) ServerTransport() ServerTransport {
return s.st
}
// ContentSubtype returns the content-subtype for a request. For example, a
// content-subtype of "proto" will result in a content-type of
// "application/grpc+proto". This will always be lowercase. See
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for
// more details.
func (s *Stream) ContentSubtype() string {
return s.contentSubtype
}
// Context returns the context of the stream.
func (s *Stream) Context() context.Context {
return s.ctx
@@ -332,36 +318,49 @@ func (s *Stream) Method() string {
}
// Status returns the status received from the server.
// Status can be read safely only after the stream has ended,
// that is, read or write has returned io.EOF.
func (s *Stream) Status() *status.Status {
return s.status
}
// SetHeader sets the header metadata. This can be called multiple times.
// Server side only.
// This should not be called in parallel to other data writes.
func (s *Stream) SetHeader(md metadata.MD) error {
if md.Len() == 0 {
return nil
}
if s.isHeaderSent() || s.getState() == streamDone {
return ErrIllegalHeaderWrite
}
s.hdrMu.Lock()
s.header = metadata.Join(s.header, md)
s.hdrMu.Unlock()
return nil
}
// SendHeader sends the given header metadata. The given metadata is
// combined with any metadata set by previous calls to SetHeader and
// then written to the transport stream.
func (s *Stream) SendHeader(md metadata.MD) error {
t := s.ServerTransport()
return t.WriteHeader(s, md)
}
// SetTrailer sets the trailer metadata which will be sent with the RPC status
// by the server. This can be called multiple times. Server side only.
// This should not be called in parallel to other data writes.
func (s *Stream) SetTrailer(md metadata.MD) error {
if md.Len() == 0 {
return nil
}
if s.getState() == streamDone {
return ErrIllegalHeaderWrite
}
s.hdrMu.Lock()
s.trailer = metadata.Join(s.trailer, md)
s.hdrMu.Unlock()
return nil
}
@@ -401,29 +400,15 @@ func (t *transportReader) Read(p []byte) (n int, err error) {
return
}
// finish sets the stream's state and status, and closes the done channel.
// s.mu must be held by the caller. st must always be non-nil.
func (s *Stream) finish(st *status.Status) {
s.status = st
s.state = streamDone
close(s.done)
}
// BytesReceived indicates whether any bytes have been received on this stream.
func (s *Stream) BytesReceived() bool {
return atomic.LoadUint32(&s.bytesReceived) == 1
}
// Unprocessed indicates whether the server did not process this stream --
// i.e. it sent a refused stream or GOAWAY including this stream ID.
func (s *Stream) Unprocessed() bool {
return atomic.LoadUint32(&s.unprocessed) == 1
}
// GoString is implemented by Stream so context.String() won't
@@ -432,21 +417,6 @@ func (s *Stream) GoString() string {
return fmt.Sprintf("<stream: %p, %v>", s, s.method)
}
// The key to save transport.Stream in the context.
type streamKey struct{}
// newContextWithStream creates a new context from ctx and attaches stream
// to it.
func newContextWithStream(ctx context.Context, stream *Stream) context.Context {
return context.WithValue(ctx, streamKey{}, stream)
}
// StreamFromContext returns the stream saved in ctx.
func StreamFromContext(ctx context.Context) (s *Stream, ok bool) {
s, ok = ctx.Value(streamKey{}).(*Stream)
return
}
// state of transport
type transportState int
@@ -468,6 +438,7 @@ type ServerConfig struct {
InitialConnWindowSize int32
WriteBufferSize int
ReadBufferSize int
ChannelzParentID int64
}
// NewServerTransport creates a ServerTransport with conn or non-nil error
@@ -503,6 +474,8 @@ type ConnectOptions struct {
WriteBufferSize int
// ReadBufferSize sets the size of read buffer, which in turn determines how much data can be read at most for one read syscall.
ReadBufferSize int
// ChannelzParentID sets the addrConn id which initiated the creation of this client transport.
ChannelzParentID int64
}
// TargetInfo contains the information of the target such as network address and metadata.
@@ -553,6 +526,14 @@ type CallHdr struct {
// for performance purposes.
// If it's false, new stream will never be flushed.
Flush bool
// ContentSubtype specifies the content-subtype for a request. For example, a
// content-subtype of "proto" will result in a content-type of
// "application/grpc+proto". The value of ContentSubtype must be all
// lowercase, otherwise the behavior is undefined. See
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests
// for more details.
ContentSubtype string
}
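The ContentSubtype field above maps directly onto the wire-level content-type header described in the linked PROTOCOL-HTTP2 spec. A minimal illustration of that mapping (`contentType` is a hypothetical helper for this sketch, not the transport's internal API):

```go
package main

import "fmt"

// contentType derives the HTTP/2 content-type header value from a
// content-subtype, per the gRPC PROTOCOL-HTTP2 spec: an empty subtype
// yields "application/grpc", otherwise "application/grpc+<subtype>".
// The subtype is assumed to already be lowercase, as the field doc requires.
func contentType(contentSubtype string) string {
	if contentSubtype == "" {
		return "application/grpc"
	}
	return "application/grpc+" + contentSubtype
}

func main() {
	fmt.Println(contentType(""))      // application/grpc
	fmt.Println(contentType("proto")) // application/grpc+proto
}
```

A subtype of "json" would likewise produce "application/grpc+json"; the transport leaves the value otherwise uninterpreted.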
// ClientTransport is the common interface for all gRPC client-side transport
@@ -594,6 +575,12 @@ type ClientTransport interface {
// GetGoAwayReason returns the reason why GoAway frame was received.
GetGoAwayReason() GoAwayReason
// IncrMsgSent increments the number of messages sent through this transport.
IncrMsgSent()
// IncrMsgRecv increments the number of messages received through this transport.
IncrMsgRecv()
}
// ServerTransport is the common interface for all gRPC server-side transport
@@ -627,6 +614,12 @@ type ServerTransport interface {
// Drain notifies the client this ServerTransport stops accepting new RPCs.
Drain()
// IncrMsgSent increments the number of messages sent through this transport.
IncrMsgSent()
// IncrMsgRecv increments the number of messages received through this transport.
IncrMsgRecv()
}
// streamErrorf creates a StreamError with the specified error code and description.
@@ -676,13 +669,16 @@ func (e ConnectionError) Origin() error {
var (
// ErrConnClosing indicates that the transport is closing.
ErrConnClosing = connectionErrorf(true, nil, "transport is closing")
// errStreamDrain indicates that the stream is rejected because the
// connection is draining. This could be caused by goaway or balancer
// removing the address.
errStreamDrain = streamErrorf(codes.Unavailable, "the connection is draining")
// errStreamDone is returned from write at the client side to indicate an
// application-layer error.
errStreamDone = errors.New("the stream is done")
// StatusGoAway indicates that the server sent a GOAWAY that included this
// stream's ID in unprocessed RPCs.
statusGoAway = status.New(codes.Unavailable, "the stream is rejected because server is draining the connection")
)
// TODO: See if we can replace StreamError with status package errors.
@@ -697,15 +693,6 @@ func (e StreamError) Error() string {
return fmt.Sprintf("stream error: code = %s desc = %q", e.Code, e.Desc)
}
// waiters are passed to quotaPool get methods to
// wait on in addition to waiting on quota.
type waiters struct {
ctx context.Context
tctx context.Context
done chan struct{}
goAway chan struct{}
}
// GoAwayReason contains the reason for the GoAway frame received.
type GoAwayReason uint8
@@ -719,39 +706,3 @@ const (
// "too_many_pings".
GoAwayTooManyPings GoAwayReason = 2
)
// loopyWriter is run in a separate goroutine. It is the single code path that will
// write data on the wire.
func loopyWriter(ctx context.Context, cbuf *controlBuffer, handler func(item) error) {
for {
select {
case i := <-cbuf.get():
cbuf.load()
if err := handler(i); err != nil {
errorf("transport: Error while handling item. Err: %v", err)
return
}
case <-ctx.Done():
return
}
hasData:
for {
select {
case i := <-cbuf.get():
cbuf.load()
if err := handler(i); err != nil {
errorf("transport: Error while handling item. Err: %v", err)
return
}
case <-ctx.Done():
return
default:
if err := handler(&flushIO{}); err != nil {
errorf("transport: Error while flushing. Err: %v", err)
return
}
break hasData
}
}
}
}
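The loopyWriter removed above follows a common Go batching pattern: handle queued items as they arrive, and flush the accumulated writes once the queue is momentarily empty. A standalone sketch of that pattern, under the assumption that `drainAndFlush`, `handler`, and `flush` are illustrative stand-ins rather than the real transport types:

```go
package main

import "fmt"

// drainAndFlush handles every item currently buffered on ch, then calls
// flush exactly once when the queue is empty, mirroring the batching idea
// in the removed loopyWriter (which flushed in the select's default case).
func drainAndFlush(ch chan string, handler func(string), flush func()) {
	for {
		select {
		case item := <-ch:
			handler(item)
		default:
			flush()
			return
		}
	}
}

func main() {
	ch := make(chan string, 3)
	ch <- "a"
	ch <- "b"
	drainAndFlush(ch, func(s string) { fmt.Print(s) }, func() { fmt.Println("|flush") })
	// prints "ab|flush"
}
```

Batching writes this way amortizes the cost of a flush syscall over all items that were queued while the writer was busy.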
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package grpc
// Version is the current grpc version.
const Version = "1.13.0"
#!/bin/bash
if [[ `uname -a` = *"Darwin"* ]]; then
echo "It seems you are running on Mac. This script does not work on Mac. See https://github.com/grpc/grpc-go/issues/2047"
exit 1
fi
set -ex # Exit on error; debugging enabled.
set -o pipefail # Fail a pipe if any sub-command fails.
@@ -49,6 +54,8 @@ if git status --porcelain | read; then
fi
git ls-files "*.go" | xargs grep -L "\(Copyright [0-9]\{4,\} gRPC authors\)\|DO NOT EDIT" 2>&1 | tee /dev/stderr | (! read)
git ls-files "*.go" | xargs grep -l '"unsafe"' 2>&1 | (! grep -v '_test.go') | tee /dev/stderr | (! read)
git ls-files "*.go" | xargs grep -l '"math/rand"' 2>&1 | (! grep -v '^examples\|^stress\|grpcrand') | tee /dev/stderr | (! read)
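The grep pipelines above all end in `(! read)`, an idiom that makes the pipeline fail exactly when the preceding command printed anything. A small demonstration of the idiom in isolation:

```shell
#!/bin/bash
# `read` succeeds only if it consumes a line, so `(! read)` succeeds only
# on empty input. Piping lint output through it turns "any findings" into
# a non-zero exit status, which fails the script under `set -e`.
printf '' | (! read) && echo "no findings: pass"
printf 'bad.go\n' | (! read) || echo "findings present: fail"
```

The `tee /dev/stderr` in the real script additionally echoes the findings so the CI log shows which files triggered the failure.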
gofmt -s -d -l . 2>&1 | tee /dev/stderr | (! read)
goimports -l . 2>&1 | tee /dev/stderr | (! read)
golint ./... 2>&1 | (grep -vE "(_mock|\.pb)\.go:" || true) | tee /dev/stderr | (! read)
@@ -80,5 +87,8 @@ google.golang.org/grpc/transport/transport_test.go:SA2002
google.golang.org/grpc/benchmark/benchmain/main.go:SA1019
google.golang.org/grpc/stats/stats_test.go:SA1019
google.golang.org/grpc/test/end2end_test.go:SA1019
google.golang.org/grpc/balancer_test.go:SA1019
google.golang.org/grpc/balancer.go:SA1019
google.golang.org/grpc/clientconn_test.go:SA1019
' ./...
misspell -error .
@@ -91,10 +91,12 @@
"revision": "e8f0f8aaa98dfb6586cbdf2978d511e3199a960a"
},
{
"checksumSHA1": "3iVD2sJv4uYnA8YgkR8yzZiUF7o=",
"path": "github.com/grpc-ecosystem/go-grpc-prometheus",
"revision": "c225b8c3b01faf2899099b768856a9e916e5087b",
"revisionTime": "2018-06-04T12:28:56Z",
"version": "v1.2.0",
"versionExact": "v1.2.0"
},
{
"checksumSHA1": "mb0MqzDyYEQMgh8+qwVm1RV4cxc=",
@@ -195,10 +197,12 @@
"revisionTime": "2018-07-25T12:39:19Z"
},
{
"checksumSHA1": "97d1x3kkzlOLFrtQ/NXL8j5sqz8=",
"path": "github.com/rafaeljusto/redigomock",
"revision": "7ae0511314e9946bb0c87d6d485169ab2467a290",
"revisionTime": "2017-07-20T13:15:24Z",
"version": "v2.1",
"versionExact": "v2.1"
},
{
"checksumSHA1": "bQ+Wb430AXpen54AYtrR1Igfh18=",
@@ -207,12 +211,12 @@
"revisionTime": "2016-09-10T04:38:05Z"
},
{
"checksumSHA1": "vRcu8DLpEnhOuaZ/M8iGl2CRG8Y=",
"path": "github.com/sirupsen/logrus",
"revision": "3e01752db0189b9157070a0e1668a620f9a85da2",
"revisionTime": "2018-07-21T07:00:01Z",
"version": "v1.0.6",
"versionExact": "v1.0.6"
},
{
"checksumSHA1": "hIEmcd7hIDqO/xWSp1rJJHd0TpE=",
@@ -269,12 +273,6 @@
"path": "golang.org/x/net/context",
"revision": "f2499483f923065a842d38eb4c7f1927e6fc6e6d"
},
{
"checksumSHA1": "WHc3uByvGaMcnSoI21fhzYgbOgg=",
"path": "golang.org/x/net/context/ctxhttp",
"revision": "a6577fac2d73be281a500b310739095313165611",
"revisionTime": "2017-03-08T20:54:49Z"
},
{
"checksumSHA1": "SHTyxlWxNjRwA7o3AiBM87PawSA=",
"path": "golang.org/x/net/http2",
@@ -325,180 +323,204 @@
"revisionTime": "2017-10-02T23:26:14Z"
},
{
"checksumSHA1": "S+GGH5m4Njqwmndux5BlOjTmx8I=",
"path": "google.golang.org/grpc",
"revision": "168a6198bcb0ef175f7dacec0b8691fc141dc9b8",
"revisionTime": "2018-06-19T22:19:05Z",
"version": "v1.13.0",
"versionExact": "v1.13.0"
},
{
"checksumSHA1": "xX1+b0/gjwxrjocYH5W/LyQPjs4=",
"path": "google.golang.org/grpc/balancer",
"revision": "168a6198bcb0ef175f7dacec0b8691fc141dc9b8",
"revisionTime": "2018-06-19T22:19:05Z",
"version": "v1.13.0",
"versionExact": "v1.13.0"
},
{
"checksumSHA1": "lw+L836hLeH8+//le+C+ycddCCU=",
"path": "google.golang.org/grpc/balancer/base",
"revision": "168a6198bcb0ef175f7dacec0b8691fc141dc9b8",
"revisionTime": "2018-06-19T22:19:05Z",
"version": "v1.13.0",
"versionExact": "v1.13.0"
},
{
"checksumSHA1": "DJ1AtOk4Pu7bqtUMob95Hw8HPNw=",
"path": "google.golang.org/grpc/balancer/roundrobin",
"revision": "168a6198bcb0ef175f7dacec0b8691fc141dc9b8",
"revisionTime": "2018-06-19T22:19:05Z",
"version": "v1.13.0",
"versionExact": "v1.13.0"
},
{
"checksumSHA1": "R3tuACGAPyK4lr+oSNt1saUzC0M=",
"path": "google.golang.org/grpc/codes",
"revision": "168a6198bcb0ef175f7dacec0b8691fc141dc9b8",
"revisionTime": "2018-06-19T22:19:05Z",
"version": "v1.13.0",
"versionExact": "v1.13.0"
},
{
"checksumSHA1": "XH2WYcDNwVO47zYShREJjcYXm0Y=",
"path": "google.golang.org/grpc/connectivity",
"revision": "168a6198bcb0ef175f7dacec0b8691fc141dc9b8",
"revisionTime": "2018-06-19T22:19:05Z",
"version": "v1.13.0",
"versionExact": "v1.13.0"
},
{
"checksumSHA1": "KthiDKNPHMeIu967enqtE4NaZzI=",
"path": "google.golang.org/grpc/credentials",
"revision": "168a6198bcb0ef175f7dacec0b8691fc141dc9b8",
"revisionTime": "2018-06-19T22:19:05Z",
"version": "v1.13.0",
"versionExact": "v1.13.0"
},
{
"checksumSHA1": "cfLb+pzWB+Glwp82rgfcEST1mv8=",
"path": "google.golang.org/grpc/encoding",
"revision": "168a6198bcb0ef175f7dacec0b8691fc141dc9b8",
"revisionTime": "2018-06-19T22:19:05Z",
"version": "v1.13.0",
"versionExact": "v1.13.0"
},
{
"checksumSHA1": "LKKkn7EYA+Do9Qwb2/SUKLFNxoo=",
"path": "google.golang.org/grpc/encoding/proto",
"revision": "168a6198bcb0ef175f7dacec0b8691fc141dc9b8",
"revisionTime": "2018-06-19T22:19:05Z",
"version": "v1.13.0",
"versionExact": "v1.13.0"
},
{
"checksumSHA1": "ZPPSFisPDz2ANO4FBZIft+fRxyk=",
"path": "google.golang.org/grpc/grpclog",
"revision": "168a6198bcb0ef175f7dacec0b8691fc141dc9b8",
"revisionTime": "2018-06-19T22:19:05Z",
"version": "v1.13.0",
"versionExact": "v1.13.0"
},
{
"checksumSHA1": "cSdzm5GhbalJbWUNrN8pRdW0uks=",
"path": "google.golang.org/grpc/internal",
"revision": "168a6198bcb0ef175f7dacec0b8691fc141dc9b8",
"revisionTime": "2018-06-19T22:19:05Z",
"version": "v1.13.0",
"versionExact": "v1.13.0"
},
{
"checksumSHA1": "uDJA7QK2iGnEwbd9TPqkLaM+xuU=",
"path": "google.golang.org/grpc/internal/backoff",
"revision": "168a6198bcb0ef175f7dacec0b8691fc141dc9b8",
"revisionTime": "2018-06-19T22:19:05Z",
"version": "v1.13.0",
"versionExact": "v1.13.0"
},
{
"checksumSHA1": "DpRAlo/UzTvErgcJ9SUQ+lmTxws=",
"path": "google.golang.org/grpc/internal/channelz",
"revision": "168a6198bcb0ef175f7dacec0b8691fc141dc9b8",
"revisionTime": "2018-06-19T22:19:05Z",
"version": "v1.13.0",
"versionExact": "v1.13.0"
},
{
"checksumSHA1": "70gndc/uHwyAl3D45zqp7vyHWlo=",
"path": "google.golang.org/grpc/internal/grpcrand",
"revision": "168a6198bcb0ef175f7dacec0b8691fc141dc9b8",
"revisionTime": "2018-06-19T22:19:05Z",
"version": "v1.13.0",
"versionExact": "v1.13.0"
},
{
"checksumSHA1": "hcuHgKp8W0wIzoCnNfKI8NUss5o=",
"path": "google.golang.org/grpc/keepalive",
"revision": "168a6198bcb0ef175f7dacec0b8691fc141dc9b8",
"revisionTime": "2018-06-19T22:19:05Z",
"version": "v1.13.0",
"versionExact": "v1.13.0"
},
{
"checksumSHA1": "OjIAi5AzqlQ7kLtdAyjvdgMf6hc=",
"path": "google.golang.org/grpc/metadata",
"revision": "168a6198bcb0ef175f7dacec0b8691fc141dc9b8",
"revisionTime": "2018-06-19T22:19:05Z",
"version": "v1.13.0",
"versionExact": "v1.13.0"
},
{
"checksumSHA1": "VvGBoawND0urmYDy11FT+U1IHtU=",
"path": "google.golang.org/grpc/naming",
"revision": "168a6198bcb0ef175f7dacec0b8691fc141dc9b8",
"revisionTime": "2018-06-19T22:19:05Z",
"version": "v1.13.0",
"versionExact": "v1.13.0"
},
{
"checksumSHA1": "n5EgDdBqFMa2KQFhtl+FF/4gIFo=",
"path": "google.golang.org/grpc/peer",
"revision": "168a6198bcb0ef175f7dacec0b8691fc141dc9b8",
"revisionTime": "2018-06-19T22:19:05Z",
"version": "v1.13.0",
"versionExact": "v1.13.0"
},
{
"checksumSHA1": "QOKwFz4Zdfxfjs8czgCCtzM5bk4=",
"path": "google.golang.org/grpc/resolver",
"revision": "168a6198bcb0ef175f7dacec0b8691fc141dc9b8",
"revisionTime": "2018-06-19T22:19:05Z",
"version": "v1.13.0",
"versionExact": "v1.13.0"
},
{
"checksumSHA1": "30RAjcyNXLww43ikGOpiy3jg8WY=",
"path": "google.golang.org/grpc/resolver/dns",
"revision": "168a6198bcb0ef175f7dacec0b8691fc141dc9b8",
"revisionTime": "2018-06-19T22:19:05Z",
"version": "v1.13.0",
"versionExact": "v1.13.0"
},
{
"checksumSHA1": "zs9M4xE8Lyg4wvuYvR00XoBxmuw=",
"path": "google.golang.org/grpc/resolver/passthrough",
"revision": "168a6198bcb0ef175f7dacec0b8691fc141dc9b8",
"revisionTime": "2018-06-19T22:19:05Z",
"version": "v1.13.0",
"versionExact": "v1.13.0"
},
{
"checksumSHA1": "YclPgme2gT3S0hTkHVdE1zAxJdo=",
"path": "google.golang.org/grpc/stats",
"revision": "168a6198bcb0ef175f7dacec0b8691fc141dc9b8",
"revisionTime": "2018-06-19T22:19:05Z",
"version": "v1.13.0",
"versionExact": "v1.13.0"
},
{
"checksumSHA1": "t/NhHuykWsxY0gEBd2WIv5RVBK8=",
"path": "google.golang.org/grpc/status",
"revision": "168a6198bcb0ef175f7dacec0b8691fc141dc9b8",
"revisionTime": "2018-06-19T22:19:05Z",
"version": "v1.13.0",
"versionExact": "v1.13.0"
},
{
"checksumSHA1": "qvArRhlrww5WvRmbyMF2mUfbJew=",
"path": "google.golang.org/grpc/tap",
"revision": "168a6198bcb0ef175f7dacec0b8691fc141dc9b8",
"revisionTime": "2018-06-19T22:19:05Z",
"version": "v1.13.0",
"versionExact": "v1.13.0"
},
{
"checksumSHA1": "FmV+Y3VY7iRchu5m38iQTPMNAKc=",
"path": "google.golang.org/grpc/transport",
"revision": "168a6198bcb0ef175f7dacec0b8691fc141dc9b8",
"revisionTime": "2018-06-19T22:19:05Z",
"version": "v1.13.0",
"versionExact": "v1.13.0"
}
],
"rootPath": "gitlab.com/gitlab-org/gitlab-workhorse"