Commit 4babe4b2 authored by Leonard Hecker's avatar Leonard Hecker

proxy: Added support for HTTP trailers

parent 533039e6
...@@ -247,12 +247,28 @@ func createUpstreamRequest(r *http.Request) *http.Request { ...@@ -247,12 +247,28 @@ func createUpstreamRequest(r *http.Request) *http.Request {
outreq.URL.Opaque = outreq.URL.RawPath outreq.URL.Opaque = outreq.URL.RawPath
} }
// We are modifying the same underlying map from req (shallow
// copied above) so we only copy it if necessary.
copiedHeaders := false
// Remove hop-by-hop headers listed in the "Connection" header.
// See RFC 2616, section 14.10.
if c := outreq.Header.Get("Connection"); c != "" {
for _, f := range strings.Split(c, ",") {
if f = strings.TrimSpace(f); f != "" {
if !copiedHeaders {
outreq.Header = make(http.Header)
copyHeader(outreq.Header, r.Header)
copiedHeaders = true
}
outreq.Header.Del(f)
}
}
}
// Remove hop-by-hop headers to the backend. Especially // Remove hop-by-hop headers to the backend. Especially
// important is "Connection" because we want a persistent // important is "Connection" because we want a persistent
// connection, regardless of what the client sent to us. This // connection, regardless of what the client sent to us.
// is modifying the same underlying map from r (shallow
// copied above) so we only copy it if necessary.
var copiedHeaders bool
for _, h := range hopHeaders { for _, h := range hopHeaders {
if outreq.Header.Get(h) != "" { if outreq.Header.Get(h) != "" {
if !copiedHeaders { if !copiedHeaders {
......
...@@ -42,10 +42,32 @@ func TestReverseProxy(t *testing.T) { ...@@ -42,10 +42,32 @@ func TestReverseProxy(t *testing.T) {
log.SetOutput(ioutil.Discard) log.SetOutput(ioutil.Discard)
defer log.SetOutput(os.Stderr) defer log.SetOutput(os.Stderr)
verifyHeaders := func(headers http.Header, trailers http.Header) {
if headers.Get("X-Header") != "header-value" {
t.Error("Expected header 'X-Header' to be proxied properly")
}
if trailers == nil {
t.Error("Expected to receive trailers")
}
if trailers.Get("X-Trailer") != "trailer-value" {
t.Error("Expected header 'X-Trailer' to be proxied properly")
}
}
var requestReceived bool var requestReceived bool
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// read the body (even if it's empty) to make Go parse trailers
io.Copy(ioutil.Discard, r.Body)
verifyHeaders(r.Header, r.Trailer)
requestReceived = true requestReceived = true
w.Header().Set("Trailer", "X-Trailer")
w.Header().Set("X-Header", "header-value")
w.WriteHeader(http.StatusOK)
w.Write([]byte("Hello, client")) w.Write([]byte("Hello, client"))
w.Header().Set("X-Trailer", "trailer-value")
})) }))
defer backend.Close() defer backend.Close()
...@@ -59,12 +81,21 @@ func TestReverseProxy(t *testing.T) { ...@@ -59,12 +81,21 @@ func TestReverseProxy(t *testing.T) {
r := httptest.NewRequest("GET", "/", nil) r := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
r.ContentLength = -1 // force chunked encoding (required for trailers)
r.Header.Set("X-Header", "header-value")
r.Trailer = map[string][]string{
"X-Trailer": {"trailer-value"},
}
p.ServeHTTP(w, r) p.ServeHTTP(w, r)
if !requestReceived { if !requestReceived {
t.Error("Expected backend to receive request, but it didn't") t.Error("Expected backend to receive request, but it didn't")
} }
res := w.Result()
verifyHeaders(res.Header, res.Trailer)
// Make sure {upstream} placeholder is set // Make sure {upstream} placeholder is set
rr := httpserver.NewResponseRecorder(httptest.NewRecorder()) rr := httpserver.NewResponseRecorder(httptest.NewRecorder())
rr.Replacer = httpserver.NewReplacer(r, rr, "-") rr.Replacer = httpserver.NewReplacer(r, rr, "-")
......
...@@ -211,10 +211,27 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, ...@@ -211,10 +211,27 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request,
return err return err
} }
isWebsocket := res.StatusCode == http.StatusSwitchingProtocols && strings.ToLower(res.Header.Get("Upgrade")) == "websocket"
// Remove hop-by-hop headers listed in the
// "Connection" header of the response.
if c := res.Header.Get("Connection"); c != "" {
for _, f := range strings.Split(c, ",") {
if f = strings.TrimSpace(f); f != "" {
res.Header.Del(f)
}
}
}
for _, h := range hopHeaders {
res.Header.Del(h)
}
if respUpdateFn != nil { if respUpdateFn != nil {
respUpdateFn(res) respUpdateFn(res)
} }
if res.StatusCode == http.StatusSwitchingProtocols && strings.ToLower(res.Header.Get("Upgrade")) == "websocket" {
if isWebsocket {
res.Body.Close() res.Body.Close()
hj, ok := rw.(http.Hijacker) hj, ok := rw.(http.Hijacker)
if !ok { if !ok {
...@@ -246,13 +263,30 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, ...@@ -246,13 +263,30 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request,
go pooledIoCopy(backendConn, conn) // write tcp stream to backend go pooledIoCopy(backendConn, conn) // write tcp stream to backend
pooledIoCopy(conn, backendConn) // read tcp stream from backend pooledIoCopy(conn, backendConn) // read tcp stream from backend
} else { } else {
defer res.Body.Close()
for _, h := range hopHeaders {
res.Header.Del(h)
}
copyHeader(rw.Header(), res.Header) copyHeader(rw.Header(), res.Header)
// The "Trailer" header isn't included in the Transport's response,
// at least for *http.Transport. Build it up from Trailer.
if len(res.Trailer) > 0 {
trailerKeys := make([]string, 0, len(res.Trailer))
for k := range res.Trailer {
trailerKeys = append(trailerKeys, k)
}
rw.Header().Add("Trailer", strings.Join(trailerKeys, ", "))
}
rw.WriteHeader(res.StatusCode) rw.WriteHeader(res.StatusCode)
if len(res.Trailer) > 0 {
// Force chunking if we saw a response trailer.
// This prevents net/http from calculating the length for short
// bodies and adding a Content-Length.
if fl, ok := rw.(http.Flusher); ok {
fl.Flush()
}
}
rp.copyResponse(rw, res.Body) rp.copyResponse(rw, res.Body)
res.Body.Close() // close now, instead of defer, to populate res.Trailer
copyHeader(rw.Header(), res.Trailer)
} }
return nil return nil
...@@ -305,16 +339,17 @@ func copyHeader(dst, src http.Header) { ...@@ -305,16 +339,17 @@ func copyHeader(dst, src http.Header) {
// Hop-by-hop headers. These are removed when sent to the backend. // Hop-by-hop headers. These are removed when sent to the backend.
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html // http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
var hopHeaders = []string{ var hopHeaders = []string{
"Alt-Svc",
"Alternate-Protocol",
"Connection", "Connection",
"Keep-Alive", "Keep-Alive",
"Proxy-Authenticate", "Proxy-Authenticate",
"Proxy-Authorization", "Proxy-Authorization",
"Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
"Te", // canonicalized version of "TE" "Te", // canonicalized version of "TE"
"Trailers", "Trailer", // not Trailers per URL above; http://www.rfc-editor.org/errata_search.php?eid=4522
"Transfer-Encoding", "Transfer-Encoding",
"Upgrade", "Upgrade",
"Alternate-Protocol",
"Alt-Svc",
} }
type respUpdateFn func(resp *http.Response) type respUpdateFn func(resp *http.Response)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment