Commit 264e5b79 authored by Nimi Wariboko Jr's avatar Nimi Wariboko Jr

Use the provided Replacer tools in order to proxy string interpolation.

parent a28d5585
...@@ -4,11 +4,8 @@ package proxy ...@@ -4,11 +4,8 @@ package proxy
import ( import (
"errors" "errors"
"github.com/mholt/caddy/middleware" "github.com/mholt/caddy/middleware"
"net"
"net/http" "net/http"
"net/url" "net/url"
"regexp"
"strings"
"sync/atomic" "sync/atomic"
"time" "time"
) )
...@@ -34,6 +31,7 @@ type UpstreamHostDownFunc func(*UpstreamHost) bool ...@@ -34,6 +31,7 @@ type UpstreamHostDownFunc func(*UpstreamHost) bool
// An UpstreamHost represents a single proxy upstream // An UpstreamHost represents a single proxy upstream
type UpstreamHost struct { type UpstreamHost struct {
// The hostname of this upstream host
Name string Name string
ReverseProxy *ReverseProxy ReverseProxy *ReverseProxy
Conns int64 Conns int64
...@@ -52,81 +50,27 @@ func (uh *UpstreamHost) Down() bool { ...@@ -52,81 +50,27 @@ func (uh *UpstreamHost) Down() bool {
return uh.CheckDown(uh) return uh.CheckDown(uh)
} }
//https://github.com/mgutz/str
var tRe = regexp.MustCompile(`([\-\[\]()*\s])`)
var tRe2 = regexp.MustCompile(`\$`)
var openDelim = tRe2.ReplaceAllString(tRe.ReplaceAllString("{{", "\\$1"), "\\$")
var closDelim = tRe2.ReplaceAllString(tRe.ReplaceAllString("}}", "\\$1"), "\\$")
var templateDelim = regexp.MustCompile(openDelim + `(.+?)` + closDelim)
type requestVars struct {
Host string
RemoteIp string
Scheme string
Upstream string
UpstreamHost string
}
func templateWithDelimiters(s string, vars requestVars) string {
matches := templateDelim.FindAllStringSubmatch(s, -1)
for _, submatches := range matches {
match := submatches[0]
key := submatches[1]
found := true
repl := ""
switch key {
case "http_host":
repl = vars.Host
case "remote_addr":
repl = vars.RemoteIp
case "scheme":
repl = vars.Scheme
case "upstream":
repl = vars.Upstream
case "upstream_host":
repl = vars.UpstreamHost
default:
found = false
}
if found {
s = strings.Replace(s, match, repl, -1)
}
}
return s
}
// ServeHTTP satisfies the middleware.Handler interface. // ServeHTTP satisfies the middleware.Handler interface.
func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
for _, upstream := range p.Upstreams { for _, upstream := range p.Upstreams {
if middleware.Path(r.URL.Path).Matches(upstream.From()) { if middleware.Path(r.URL.Path).Matches(upstream.From()) {
vars := requestVars{ var replacer middleware.Replacer
Host: r.Host, start := time.Now()
Scheme: "http", requestHost := r.Host
}
if clientIP, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
vars.RemoteIp = clientIP
}
if fFor := r.Header.Get("X-Forwarded-For"); fFor != "" {
vars.RemoteIp = fFor
}
if r.TLS != nil {
vars.Scheme = "https"
}
// Since Select() should give us "up" hosts, keep retrying // Since Select() should give us "up" hosts, keep retrying
// hosts until timeout (or until we get a nil host). // hosts until timeout (or until we get a nil host).
start := time.Now()
for time.Now().Sub(start) < (60 * time.Second) { for time.Now().Sub(start) < (60 * time.Second) {
host := upstream.Select() host := upstream.Select()
if host == nil { if host == nil {
return http.StatusBadGateway, errUnreachable return http.StatusBadGateway, errUnreachable
} }
proxy := host.ReverseProxy proxy := host.ReverseProxy
vars.Upstream = host.Name
r.Host = host.Name r.Host = host.Name
if baseUrl, err := url.Parse(host.Name); err == nil { if baseUrl, err := url.Parse(host.Name); err == nil {
vars.UpstreamHost = baseUrl.Host r.Host = baseUrl.Host
if proxy == nil { if proxy == nil {
proxy = NewSingleHostReverseProxy(baseUrl) proxy = NewSingleHostReverseProxy(baseUrl)
} }
...@@ -136,12 +80,18 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { ...@@ -136,12 +80,18 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
var extraHeaders http.Header var extraHeaders http.Header
if host.ExtraHeaders != nil { if host.ExtraHeaders != nil {
extraHeaders = make(http.Header) extraHeaders = make(http.Header)
if replacer == nil {
rHost := r.Host
r.Host = requestHost
replacer = middleware.NewReplacer(r, nil)
r.Host = rHost
}
for header, values := range host.ExtraHeaders { for header, values := range host.ExtraHeaders {
for _, value := range values { for _, value := range values {
extraHeaders.Add(header, extraHeaders.Add(header,
templateWithDelimiters(value, vars)) replacer.Replace(value))
if header == "Host" { if header == "Host" {
r.Host = templateWithDelimiters(value, vars) r.Host = replacer.Replace(value)
} }
} }
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment