Commit 6490ff62 authored by Matthew Holt's avatar Matthew Holt

Adjust proxy headers properly (fixes #916)

parent 57710e8b
...@@ -84,7 +84,7 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { ...@@ -84,7 +84,7 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
} }
// this replacer is used to fill in header field values // this replacer is used to fill in header field values
var replacer httpserver.Replacer replacer := httpserver.NewReplacer(r, nil, "")
// outreq is the request that makes a roundtrip to the backend // outreq is the request that makes a roundtrip to the backend
outreq := createUpstreamRequest(r) outreq := createUpstreamRequest(r)
...@@ -119,16 +119,10 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { ...@@ -119,16 +119,10 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
// set headers for request going upstream // set headers for request going upstream
if host.UpstreamHeaders != nil { if host.UpstreamHeaders != nil {
if replacer == nil {
replacer = httpserver.NewReplacer(r, nil, "")
}
if v, ok := host.UpstreamHeaders["Host"]; ok {
outreq.Host = replacer.Replace(v[len(v)-1])
}
// modify headers for request that will be sent to the upstream host // modify headers for request that will be sent to the upstream host
upHeaders := createHeadersByRules(host.UpstreamHeaders, r.Header, replacer) mutateHeadersByRules(outreq.Header, host.UpstreamHeaders, replacer)
for k, v := range upHeaders { if hostHeaders, ok := outreq.Header["Host"]; ok && len(hostHeaders) > 0 {
outreq.Header[k] = v outreq.Host = hostHeaders[len(hostHeaders)-1]
} }
} }
...@@ -136,9 +130,6 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { ...@@ -136,9 +130,6 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
// headers coming back downstream // headers coming back downstream
var downHeaderUpdateFn respUpdateFn var downHeaderUpdateFn respUpdateFn
if host.DownstreamHeaders != nil { if host.DownstreamHeaders != nil {
if replacer == nil {
replacer = httpserver.NewReplacer(r, nil, "")
}
downHeaderUpdateFn = createRespHeaderUpdateFn(host.DownstreamHeaders, replacer) downHeaderUpdateFn = createRespHeaderUpdateFn(host.DownstreamHeaders, replacer)
} }
...@@ -185,6 +176,8 @@ func (p Proxy) match(r *http.Request) Upstream { ...@@ -185,6 +176,8 @@ func (p Proxy) match(r *http.Request) Upstream {
// createUpstremRequest shallow-copies r into a new request // createUpstremRequest shallow-copies r into a new request
// that can be sent upstream. // that can be sent upstream.
//
// Derived from reverseproxy.go in the standard Go httputil package.
func createUpstreamRequest(r *http.Request) *http.Request { func createUpstreamRequest(r *http.Request) *http.Request {
outreq := new(http.Request) outreq := new(http.Request)
*outreq = *r // includes shallow copies of maps, but okay *outreq = *r // includes shallow copies of maps, but okay
...@@ -199,10 +192,14 @@ func createUpstreamRequest(r *http.Request) *http.Request { ...@@ -199,10 +192,14 @@ func createUpstreamRequest(r *http.Request) *http.Request {
// connection, regardless of what the client sent to us. This // connection, regardless of what the client sent to us. This
// is modifying the same underlying map from r (shallow // is modifying the same underlying map from r (shallow
// copied above) so we only copy it if necessary. // copied above) so we only copy it if necessary.
var copiedHeaders bool
for _, h := range hopHeaders { for _, h := range hopHeaders {
if outreq.Header.Get(h) != "" { if outreq.Header.Get(h) != "" {
if !copiedHeaders {
outreq.Header = make(http.Header) outreq.Header = make(http.Header)
copyHeader(outreq.Header, r.Header) copyHeader(outreq.Header, r.Header)
copiedHeaders = true
}
outreq.Header.Del(h) outreq.Header.Del(h)
} }
} }
...@@ -222,45 +219,20 @@ func createUpstreamRequest(r *http.Request) *http.Request { ...@@ -222,45 +219,20 @@ func createUpstreamRequest(r *http.Request) *http.Request {
func createRespHeaderUpdateFn(rules http.Header, replacer httpserver.Replacer) respUpdateFn { func createRespHeaderUpdateFn(rules http.Header, replacer httpserver.Replacer) respUpdateFn {
return func(resp *http.Response) { return func(resp *http.Response) {
newHeaders := createHeadersByRules(rules, resp.Header, replacer) mutateHeadersByRules(resp.Header, rules, replacer)
for h, v := range newHeaders {
resp.Header[h] = v
}
} }
} }
func createHeadersByRules(rules http.Header, base http.Header, repl httpserver.Replacer) http.Header { func mutateHeadersByRules(headers, rules http.Header, repl httpserver.Replacer) {
newHeaders := make(http.Header) for ruleField, ruleValues := range rules {
for header, values := range rules { if strings.HasPrefix(ruleField, "+") {
if strings.HasPrefix(header, "+") { for _, ruleValue := range ruleValues {
header = strings.TrimLeft(header, "+") headers.Add(strings.TrimPrefix(ruleField, "+"), repl.Replace(ruleValue))
add(newHeaders, header, base[header])
applyEach(values, repl.Replace)
add(newHeaders, header, values)
} else if strings.HasPrefix(header, "-") {
base.Del(strings.TrimLeft(header, "-"))
} else if _, ok := base[header]; ok {
applyEach(values, repl.Replace)
for _, v := range values {
newHeaders.Set(header, v)
}
} else {
applyEach(values, repl.Replace)
add(newHeaders, header, values)
add(newHeaders, header, base[header])
}
} }
return newHeaders } else if strings.HasPrefix(ruleField, "-") {
} headers.Del(strings.TrimPrefix(ruleField, "-"))
} else if len(ruleValues) > 0 {
func applyEach(values []string, mapFn func(string) string) { headers.Set(ruleField, repl.Replace(ruleValues[len(ruleValues)-1]))
for i, v := range values {
values[i] = mapFn(v)
} }
}
func add(base http.Header, header string, values []string) {
for _, v := range values {
base.Add(header, v)
} }
} }
...@@ -177,10 +177,11 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, r ...@@ -177,10 +177,11 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, r
res, err := transport.RoundTrip(outreq) res, err := transport.RoundTrip(outreq)
if err != nil { if err != nil {
return err return err
} else if respUpdateFn != nil {
respUpdateFn(res)
} }
if respUpdateFn != nil {
respUpdateFn(res)
}
if res.StatusCode == http.StatusSwitchingProtocols && strings.ToLower(res.Header.Get("Upgrade")) == "websocket" { if res.StatusCode == http.StatusSwitchingProtocols && strings.ToLower(res.Header.Get("Upgrade")) == "websocket" {
res.Body.Close() res.Body.Close()
hj, ok := rw.(http.Hijacker) hj, ok := rw.(http.Hijacker)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment