Commit 078c9915 authored by Martin Redmond's avatar Martin Redmond Committed by Matt Holt

proxy: custom upstream health check by body string, closes #324 (#1691)

parent bf7b2548
package proxy package proxy
import ( import (
"bytes"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
...@@ -38,12 +39,13 @@ type staticUpstream struct { ...@@ -38,12 +39,13 @@ type staticUpstream struct {
TryInterval time.Duration TryInterval time.Duration
MaxConns int64 MaxConns int64
HealthCheck struct { HealthCheck struct {
Client http.Client Client http.Client
Path string Path string
Interval time.Duration Interval time.Duration
Timeout time.Duration Timeout time.Duration
Host string Host string
Port string Port string
ContentString string
} }
WithoutPathPrefix string WithoutPathPrefix string
IgnoredSubPaths []string IgnoredSubPaths []string
...@@ -337,6 +339,11 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream) error { ...@@ -337,6 +339,11 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream) error {
return c.Errf("invalid health_check_port '%s'", port) return c.Errf("invalid health_check_port '%s'", port)
} }
u.HealthCheck.Port = port u.HealthCheck.Port = port
case "health_check_contains":
if !c.NextArg() {
return c.ArgErr()
}
u.HealthCheck.ContentString = c.Val()
case "header_upstream": case "header_upstream":
var header, value string var header, value string
if !c.Args(&header, &value) { if !c.Args(&header, &value) {
...@@ -402,27 +409,42 @@ func (u *staticUpstream) healthCheck() { ...@@ -402,27 +409,42 @@ func (u *staticUpstream) healthCheck() {
} }
hostURL += u.HealthCheck.Path hostURL += u.HealthCheck.Path
var unhealthy bool unhealthy := func() bool {
// set up request, needed to be able to modify headers
// set up request, needed to be able to modify headers // possible errors are bad HTTP methods or un-parsable urls
// possible errors are bad HTTP methods or un-parsable urls req, err := http.NewRequest("GET", hostURL, nil)
req, err := http.NewRequest("GET", hostURL, nil) if err != nil {
if err != nil { return true
unhealthy = true }
} else {
// set host for request going upstream // set host for request going upstream
if u.HealthCheck.Host != "" { if u.HealthCheck.Host != "" {
req.Host = u.HealthCheck.Host req.Host = u.HealthCheck.Host
} }
r, err := u.HealthCheck.Client.Do(req)
if r, err := u.HealthCheck.Client.Do(req); err == nil { if err != nil {
return true
}
defer func() {
io.Copy(ioutil.Discard, r.Body) io.Copy(ioutil.Discard, r.Body)
r.Body.Close() r.Body.Close()
unhealthy = r.StatusCode < 200 || r.StatusCode >= 400 }()
} else { if r.StatusCode < 200 || r.StatusCode >= 400 {
unhealthy = true return true
} }
} if u.HealthCheck.ContentString == "" { // don't check for content string
return false
}
// TODO ReadAll will be replaced if deemed necessary
// See https://github.com/mholt/caddy/pull/1691
buf, err := ioutil.ReadAll(r.Body)
if err != nil {
return true
}
if bytes.Contains(buf, []byte(u.HealthCheck.ContentString)) {
return false
}
return true
}()
if unhealthy { if unhealthy {
atomic.StoreInt32(&host.Unhealthy, 1) atomic.StoreInt32(&host.Unhealthy, 1)
} else { } else {
......
...@@ -448,3 +448,56 @@ func TestHealthCheckPort(t *testing.T) { ...@@ -448,3 +448,56 @@ func TestHealthCheckPort(t *testing.T) {
}) })
} }
func TestHealthCheckContentString(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "blablabla good blablabla")
r.Body.Close()
}))
_, port, err := net.SplitHostPort(server.Listener.Addr().String())
if err != nil {
t.Fatal(err)
}
defer server.Close()
tests := []struct {
config string
shouldContain bool
}{
{"proxy / localhost:" + port +
" { health_check /testhealth " +
" health_check_contains good\n}",
true,
},
{"proxy / localhost:" + port + " {\n health_check /testhealth health_check_port " + port +
" \n health_check_contains bad\n}",
false,
},
}
for i, test := range tests {
u, err := NewStaticUpstreams(caddyfile.NewDispenser("Testfile", strings.NewReader(test.config)), "")
if err != nil {
t.Error("Expected no error. Test %d Got:", i, err.Error())
}
for _, upstream := range u {
staticUpstream, ok := upstream.(*staticUpstream)
if !ok {
t.Errorf("Type mismatch: %#v", upstream)
continue
}
staticUpstream.healthCheck()
for _, host := range staticUpstream.Hosts {
if test.shouldContain && atomic.LoadInt32(&host.Unhealthy) == 0 {
// healthcheck url was hit and the required test string was found
continue
}
if !test.shouldContain && atomic.LoadInt32(&host.Unhealthy) != 0 {
// healthcheck url was hit and the required string was not found
continue
}
t.Errorf("Health check bad response")
}
upstream.Stop()
}
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment