From 63fd264043ed6913c5ae785783dde687f1ec34e2 Mon Sep 17 00:00:00 2001 From: Mohammad Gufran <mohammad.gufran@kayako.com> Date: Mon, 6 Nov 2017 11:31:10 +0530 Subject: [PATCH] proxy: Add SRV support for proxy upstream (#1915) * Simplify parseUpstream function * Add SRV support for proxy upstream --- caddyhttp/proxy/proxy.go | 3 +- caddyhttp/proxy/reverseproxy.go | 46 +++++- caddyhttp/proxy/reverseproxy_test.go | 94 +++++++++++ caddyhttp/proxy/upstream.go | 237 ++++++++++++++++++--------- caddyhttp/proxy/upstream_test.go | 220 ++++++++++++++++++++++++- 5 files changed, 518 insertions(+), 82 deletions(-) create mode 100644 caddyhttp/proxy/reverseproxy_test.go diff --git a/caddyhttp/proxy/proxy.go b/caddyhttp/proxy/proxy.go index 3d88af21..e66ad776 100644 --- a/caddyhttp/proxy/proxy.go +++ b/caddyhttp/proxy/proxy.go @@ -82,7 +82,8 @@ type UpstreamHost struct { // This is an int32 so that we can use atomic operations to do concurrent // reads & writes to this value. The default value of 0 indicates that it // is healthy and any non-zero value indicates unhealthy. - Unhealthy int32 + Unhealthy int32 + HealthCheckResult atomic.Value } // Down checks whether the upstream host is down or not. diff --git a/caddyhttp/proxy/reverseproxy.go b/caddyhttp/proxy/reverseproxy.go index 78fbb6bd..2fac5aab 100644 --- a/caddyhttp/proxy/reverseproxy.go +++ b/caddyhttp/proxy/reverseproxy.go @@ -26,7 +26,9 @@ package proxy import ( + "context" "crypto/tls" + "fmt" "io" "net" "net/http" @@ -91,6 +93,8 @@ type ReverseProxy struct { // response body. // If zero, no periodic flushing is done. FlushInterval time.Duration + + srvResolver srvResolver } // Though the relevant directive prefix is just "unix:", url.Parse @@ -105,6 +109,23 @@ func socketDial(hostName string) func(network, addr string) (conn net.Conn, err } } +func (rp *ReverseProxy) srvDialerFunc(locator string) func(network, addr string) (conn net.Conn, err error) { + service := locator + if strings.HasPrefix(locator, "srv://") { + service = locator[6:] + } else if strings.HasPrefix(locator, "srv+https://") { + service = locator[12:] + } + + return func(network, addr string) (conn net.Conn, err error) { + _, addrs, err := rp.srvResolver.LookupSRV(context.Background(), "", "", service) + if err != nil { + return nil, err + } + return net.Dial("tcp", fmt.Sprintf("%s:%d", addrs[0].Target, addrs[0].Port)) + } +} + func singleJoiningSlash(a, b string) string { aslash := strings.HasSuffix(a, "/") bslash := strings.HasPrefix(b, "/") @@ -131,6 +152,12 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) * // scheme and host have to be faked req.URL.Scheme = "http" req.URL.Host = "socket" + } else if target.Scheme == "srv" { + req.URL.Scheme = "http" + req.URL.Host = target.Host + } else if target.Scheme == "srv+https" { + req.URL.Scheme = "https" + req.URL.Host = target.Host } else { req.URL.Scheme = target.Scheme req.URL.Host = target.Host @@ -199,7 +226,12 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) * } } - rp := &ReverseProxy{Director: director, FlushInterval: 250 * time.Millisecond} // flushing good for streaming & server-sent events + rp := &ReverseProxy{ + Director: director, + FlushInterval: 250 * time.Millisecond, // flushing good for streaming & server-sent events + srvResolver: net.DefaultResolver, + } + if target.Scheme == "unix" { rp.Transport = &http.Transport{ Dial: socketDial(target.String()), @@ -210,13 +242,15 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) * HandshakeTimeout: defaultCryptoHandshakeTimeout, }, } - } else if keepalive != http.DefaultMaxIdleConnsPerHost { - // if keepalive is equal to the default, - // just use default transport, to avoid creating - // a brand new transport + } else if keepalive != http.DefaultMaxIdleConnsPerHost || strings.HasPrefix(target.Scheme, "srv") { + dialFunc := defaultDialer.Dial + if strings.HasPrefix(target.Scheme, "srv") { + dialFunc = rp.srvDialerFunc(target.String()) + } + transport := &http.Transport{ Proxy: http.ProxyFromEnvironment, - Dial: defaultDialer.Dial, + Dial: dialFunc, TLSHandshakeTimeout: defaultCryptoHandshakeTimeout, ExpectContinueTimeout: 1 * time.Second, } diff --git a/caddyhttp/proxy/reverseproxy_test.go b/caddyhttp/proxy/reverseproxy_test.go new file mode 100644 index 00000000..2d1d80df --- /dev/null +++ b/caddyhttp/proxy/reverseproxy_test.go @@ -0,0 +1,94 @@ +// Copyright 2015 Light Code Labs, LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proxy + +import ( + "net" + "net/http" + "net/http/httptest" + "net/url" + "strconv" + "testing" +) + +const ( + expectedResponse = "response from request proxied to upstream" + expectedStatus = http.StatusOK +) + +var upstreamHost *httptest.Server + +func setupTest() { + upstreamHost = httptest.NewServer(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/test-path" { + w.WriteHeader(expectedStatus) + w.Write([]byte(expectedResponse)) + } else { + w.WriteHeader(404) + w.Write([]byte("Not found")) + } + })) +} + +func tearDownTest() { + upstreamHost.Close() +} + +func TestSingleSRVHostReverseProxy(t *testing.T) { + setupTest() + defer tearDownTest() + + target, err := url.Parse("srv://test.upstream.service") + if err != nil { + t.Errorf("Failed to parse target URL. %s", err.Error()) + } + + upstream, err := url.Parse(upstreamHost.URL) + if err != nil { + t.Errorf("Failed to parse test server URL [%s]. %s", upstreamHost.URL, err.Error()) + } + pp, err := strconv.Atoi(upstream.Port()) + if err != nil { + t.Errorf("Failed to parse upstream server port [%s]. %s", upstream.Port(), err.Error()) + } + port := uint16(pp) + + rp := NewSingleHostReverseProxy(target, "", http.DefaultMaxIdleConnsPerHost) + rp.srvResolver = testResolver{ + result: []*net.SRV{ + {Target: upstream.Hostname(), Port: port, Priority: 1, Weight: 1}, + }, + } + + resp := httptest.NewRecorder() + req, err := http.NewRequest("GET", "http://test.host/test-path", nil) + if err != nil { + t.Errorf("Failed to create new request. %s", err.Error()) + } + + err = rp.ServeHTTP(resp, req, nil) + if err != nil { + t.Errorf("Failed to perform reverse proxy to upstream host. %s", err.Error()) + } + + if resp.Body.String() != expectedResponse { + t.Errorf("Unexpected proxy response received. Expected: '%s', Got: '%s'", expectedResponse, resp.Body.String()) + } + + if resp.Code != expectedStatus { + t.Errorf("Unexpected proxy status. Expected: '%d', Got: '%d'", expectedStatus, resp.Code) + } +} diff --git a/caddyhttp/proxy/upstream.go b/caddyhttp/proxy/upstream.go index bab8b462..ae15a6dc 100644 --- a/caddyhttp/proxy/upstream.go +++ b/caddyhttp/proxy/upstream.go @@ -16,6 +16,7 @@ package proxy import ( "bytes" + "context" "fmt" "io" "io/ioutil" @@ -65,6 +66,11 @@ type staticUpstream struct { IgnoredSubPaths []string insecureSkipVerify bool MaxFails int32 + resolver srvResolver +} + +type srvResolver interface { + LookupSRV(context.Context, string, string, string) (string, []*net.SRV, error) } // NewStaticUpstreams parses the configuration input and sets up @@ -86,6 +92,7 @@ func NewStaticUpstreams(c caddyfile.Dispenser, host string) ([]Upstream, error) TryInterval: 250 * time.Millisecond, MaxConns: 0, KeepAlive: http.DefaultMaxIdleConnsPerHost, + resolver: net.DefaultResolver, } if !c.Args(&upstream.from) { @@ -93,7 +100,21 @@ func NewStaticUpstreams(c caddyfile.Dispenser, host string) ([]Upstream, error) } var to []string + hasSrv := false + for _, t := range c.RemainingArgs() { + if len(to) > 0 && hasSrv { + return upstreams, c.Err("only one upstream is supported when using SRV locator") + } + + if strings.HasPrefix(t, "srv://") || strings.HasPrefix(t, "srv+https://") { + if len(to) > 0 { + return upstreams, c.Err("service locator upstreams can not be mixed with host names") + } + + hasSrv = true + } + parsed, err := parseUpstream(t) if err != nil { return upstreams, err @@ -107,13 +128,18 @@ func NewStaticUpstreams(c caddyfile.Dispenser, host string) ([]Upstream, error) if !c.NextArg() { return upstreams, c.ArgErr() } + + if hasSrv { + return upstreams, c.Err("upstream directive is not supported when backend is service locator") + } + parsed, err := parseUpstream(c.Val()) if err != nil { return upstreams, err } to = append(to, parsed...) default: - if err := parseBlock(&c, upstream); err != nil { + if err := parseBlock(&c, upstream, hasSrv); err != nil { return upstreams, err } } @@ -165,7 +191,9 @@ func (u *staticUpstream) From() string { func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) { if !strings.HasPrefix(host, "http") && !strings.HasPrefix(host, "unix:") && - !strings.HasPrefix(host, "quic:") { + !strings.HasPrefix(host, "quic:") && + !strings.HasPrefix(host, "srv://") && + !strings.HasPrefix(host, "srv+https://") { host = "http://" + host } uh := &UpstreamHost{ @@ -189,6 +217,7 @@ func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) { }(u), WithoutPathPrefix: u.WithoutPathPrefix, MaxConns: u.MaxConns, + HealthCheckResult: atomic.Value{}, } baseURL, err := url.Parse(uh.Name) @@ -205,50 +234,65 @@ func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) { } func parseUpstream(u string) ([]string, error) { - if !strings.HasPrefix(u, "unix:") { - colonIdx := strings.LastIndex(u, ":") - protoIdx := strings.Index(u, "://") - - if colonIdx != -1 && colonIdx != protoIdx { - us := u[:colonIdx] - ue := "" - portsEnd := len(u) - if nextSlash := strings.Index(u[colonIdx:], "/"); nextSlash != -1 { - portsEnd = colonIdx + nextSlash - ue = u[portsEnd:] - } - ports := u[len(us)+1 : portsEnd] + if strings.HasPrefix(u, "unix:") { + return []string{u}, nil + } - if separators := strings.Count(ports, "-"); separators == 1 { - portsStr := strings.Split(ports, "-") - pIni, err := strconv.Atoi(portsStr[0]) - if err != nil { - return nil, err - } + isSrv := strings.HasPrefix(u, "srv://") || strings.HasPrefix(u, "srv+https://") + colonIdx := strings.LastIndex(u, ":") + protoIdx := strings.Index(u, "://") - pEnd, err := strconv.Atoi(portsStr[1]) - if err != nil { - return nil, err - } + if colonIdx == -1 || colonIdx == protoIdx { + return []string{u}, nil + } - if pEnd <= pIni { - return nil, fmt.Errorf("port range [%s] is invalid", ports) - } + if isSrv { + return nil, fmt.Errorf("service locator %s can not have port specified", u) + } - hosts := []string{} - for p := pIni; p <= pEnd; p++ { - hosts = append(hosts, fmt.Sprintf("%s:%d%s", us, p, ue)) - } - return hosts, nil - } - } + us := u[:colonIdx] + ue := "" + portsEnd := len(u) + if nextSlash := strings.Index(u[colonIdx:], "/"); nextSlash != -1 { + portsEnd = colonIdx + nextSlash + ue = u[portsEnd:] + } + + ports := u[len(us)+1 : portsEnd] + separators := strings.Count(ports, "-") + + if separators == 0 { + return []string{u}, nil + } + + if separators > 1 { + return nil, fmt.Errorf("port range [%s] has %d separators", ports, separators) + } + + portsStr := strings.Split(ports, "-") + pIni, err := strconv.Atoi(portsStr[0]) + if err != nil { + return nil, err } - return []string{u}, nil + pEnd, err := strconv.Atoi(portsStr[1]) + if err != nil { + return nil, err + } + if pEnd <= pIni { + return nil, fmt.Errorf("port range [%s] is invalid", ports) + } + + hosts := []string{} + for p := pIni; p <= pEnd; p++ { + hosts = append(hosts, fmt.Sprintf("%s:%d%s", us, p, ue)) + } + + return hosts, nil } -func parseBlock(c *caddyfile.Dispenser, u *staticUpstream) error { +func parseBlock(c *caddyfile.Dispenser, u *staticUpstream, hasSrv bool) error { switch c.Val() { case "policy": if !c.NextArg() { @@ -348,6 +392,11 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream) error { if !c.NextArg() { return c.ArgErr() } + + if hasSrv { + return c.Err("health_check_port directive is not allowed when upstream is SRV locator") + } + port := c.Val() n, err := strconv.Atoi(port) if err != nil { @@ -420,54 +469,94 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream) error { return nil } +func (u *staticUpstream) resolveHost(h string) ([]string, bool, error) { + names := []string{} + proto := "http" + if !strings.HasPrefix(h, "srv://") && !strings.HasPrefix(h, "srv+https://") { + return []string{h}, false, nil + } + + if strings.HasPrefix(h, "srv+https://") { + proto = "https" + } + + _, addrs, err := u.resolver.LookupSRV(context.Background(), "", "", h) + if err != nil { + return names, true, err + } + + for _, addr := range addrs { + names = append(names, fmt.Sprintf("%s://%s:%d", proto, addr.Target, addr.Port)) + } + + return names, true, nil +} + func (u *staticUpstream) healthCheck() { for _, host := range u.Hosts { - hostURL := host.Name - if u.HealthCheck.Port != "" { - hostURL = replacePort(host.Name, u.HealthCheck.Port) + candidates, isSrv, err := u.resolveHost(host.Name) + if err != nil { + host.HealthCheckResult.Store(err.Error()) + atomic.StoreInt32(&host.Unhealthy, 1) + continue } - hostURL += u.HealthCheck.Path - unhealthy := func() bool { - // set up request, needed to be able to modify headers - // possible errors are bad HTTP methods or un-parsable urls - req, err := http.NewRequest("GET", hostURL, nil) - if err != nil { - return true + unhealthyCount := 0 + for _, addr := range candidates { + hostURL := addr + if !isSrv && u.HealthCheck.Port != "" { + hostURL = replacePort(hostURL, u.HealthCheck.Port) } - // set host for request going upstream - if u.HealthCheck.Host != "" { - req.Host = u.HealthCheck.Host - } - r, err := u.HealthCheck.Client.Do(req) - if err != nil { + hostURL += u.HealthCheck.Path + + unhealthy := func() bool { + // set up request, needed to be able to modify headers + // possible errors are bad HTTP methods or un-parsable urls + req, err := http.NewRequest("GET", hostURL, nil) + if err != nil { + return true + } + // set host for request going upstream + if u.HealthCheck.Host != "" { + req.Host = u.HealthCheck.Host + } + r, err := u.HealthCheck.Client.Do(req) + if err != nil { + return true + } + defer func() { + io.Copy(ioutil.Discard, r.Body) + r.Body.Close() + }() + if r.StatusCode < 200 || r.StatusCode >= 400 { + return true + } + if u.HealthCheck.ContentString == "" { // don't check for content string + return false + } + // TODO ReadAll will be replaced if deemed necessary + // See https://github.com/mholt/caddy/pull/1691 + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return true + } + if bytes.Contains(buf, []byte(u.HealthCheck.ContentString)) { + return false + } return true - } - defer func() { - io.Copy(ioutil.Discard, r.Body) - r.Body.Close() }() - if r.StatusCode < 200 || r.StatusCode >= 400 { - return true - } - if u.HealthCheck.ContentString == "" { // don't check for content string - return false - } - // TODO ReadAll will be replaced if deemed necessary - // See https://github.com/mholt/caddy/pull/1691 - buf, err := ioutil.ReadAll(r.Body) - if err != nil { - return true - } - if bytes.Contains(buf, []byte(u.HealthCheck.ContentString)) { - return false + + if unhealthy { + unhealthyCount++ } - return true - }() - if unhealthy { + } + + if unhealthyCount == len(candidates) { atomic.StoreInt32(&host.Unhealthy, 1) + host.HealthCheckResult.Store("Failed") } else { atomic.StoreInt32(&host.Unhealthy, 0) + host.HealthCheckResult.Store("OK") } } } diff --git a/caddyhttp/proxy/upstream_test.go b/caddyhttp/proxy/upstream_test.go index ce662d19..23fd4831 100644 --- a/caddyhttp/proxy/upstream_test.go +++ b/caddyhttp/proxy/upstream_test.go @@ -15,10 +15,15 @@ package proxy import ( + "context" + "errors" "fmt" "net" "net/http" "net/http/httptest" + "net/url" + "reflect" + "strconv" "strings" "sync/atomic" "testing" @@ -187,7 +192,7 @@ func TestParseBlockHealthCheck(t *testing.T) { u := staticUpstream{} c := caddyfile.NewDispenser("Testfile", strings.NewReader(test.config)) for c.Next() { - parseBlock(&c, &u) + parseBlock(&c, &u, false) } if u.HealthCheck.Interval.String() != test.interval { t.Errorf( @@ -551,3 +556,216 @@ func TestQuicHost(t *testing.T) { } } } + +func TestParseSRVBlock(t *testing.T) { + tests := []struct { + config string + shouldErr bool + }{ + {"proxy / srv://bogus.service", false}, + {"proxy / srv://bogus.service:80", true}, + {"proxy / srv://bogus.service srv://bogus.service.fallback", true}, + {"proxy / srv://bogus.service http://bogus.service.fallback", true}, + {"proxy / http://bogus.service srv://bogus.service.fallback", true}, + {"proxy / srv://bogus.service bogus.service.fallback", true}, + {`proxy / srv://bogus.service { + upstream srv://bogus.service + }`, true}, + {"proxy / srv+https://bogus.service", false}, + {"proxy / srv+https://bogus.service:80", true}, + {"proxy / srv+https://bogus.service srv://bogus.service.fallback", true}, + {"proxy / srv+https://bogus.service http://bogus.service.fallback", true}, + {"proxy / http://bogus.service srv+https://bogus.service.fallback", true}, + {"proxy / srv+https://bogus.service bogus.service.fallback", true}, + {`proxy / srv+https://bogus.service { + upstream srv://bogus.service + }`, true}, + {`proxy / srv+https://bogus.service { + health_check_port 96 + }`, true}, + } + + for i, test := range tests { + _, err := NewStaticUpstreams(caddyfile.NewDispenser("Testfile", strings.NewReader(test.config)), "") + if err == nil && test.shouldErr { + t.Errorf("Case %d - Expected an error. got nothing", i) + } + + if err != nil && !test.shouldErr { + t.Errorf("Case %d - Expected no error. got %s", i, err.Error()) + } + } +} + +type testResolver struct { + errOn string + result []*net.SRV +} + +func (r testResolver) LookupSRV(ctx context.Context, _, _, service string) (string, []*net.SRV, error) { + if service == r.errOn { + return "", nil, errors.New("an error occurred") + } + + return "", r.result, nil +} + +func TestResolveHost(t *testing.T) { + upstream := &staticUpstream{ + resolver: testResolver{ + errOn: "srv://problematic.service.name", + result: []*net.SRV{ + {Target: "target-1.fqdn", Port: 85, Priority: 1, Weight: 1}, + {Target: "target-2.fqdn", Port: 33, Priority: 1, Weight: 1}, + {Target: "target-3.fqdn", Port: 94, Priority: 1, Weight: 1}, + }, + }, + } + + tests := []struct { + host string + expect []string + isSrv bool + shouldErr bool + }{ + // Static DNS records + {"http://subdomain.domain.service", + []string{"http://subdomain.domain.service"}, + false, + false}, + {"https://subdomain.domain.service", + []string{"https://subdomain.domain.service"}, + false, + false}, + {"http://subdomain.domain.service:76", + []string{"http://subdomain.domain.service:76"}, + false, + false}, + {"https://subdomain.domain.service:65", + []string{"https://subdomain.domain.service:65"}, + false, + false}, + + // SRV lookups + {"srv://service.name", []string{ + "http://target-1.fqdn:85", + "http://target-2.fqdn:33", + "http://target-3.fqdn:94", + }, true, false}, + {"srv+https://service.name", []string{ + "https://target-1.fqdn:85", + "https://target-2.fqdn:33", + "https://target-3.fqdn:94", + }, true, false}, + {"srv://problematic.service.name", []string{}, true, true}, + } + + for i, test := range tests { + results, isSrv, err := upstream.resolveHost(test.host) + if err == nil && test.shouldErr { + t.Errorf("Test %d - expected an error, got none", i) + } + + if err != nil && !test.shouldErr { + t.Errorf("Test %d - unexpected error %s", i, err.Error()) + } + + if test.isSrv && !isSrv { + t.Errorf("Test %d - expecting resolution to be SRV lookup but it isn't", i) + } + + if isSrv && !test.isSrv { + t.Errorf("Test %d - expecting resolution to be normal lookup, got SRV", i) + } + + if !reflect.DeepEqual(results, test.expect) { + t.Errorf("Test %d - resolution result %#v does not match expected value %#v", i, results, test.expect) + } + } +} + +func TestSRVHealthCheck(t *testing.T) { + serverURL, err := url.Parse(workableServer.URL) + if err != nil { + t.Errorf("Failed to parse test server URL: %s", err.Error()) + } + + pp, err := strconv.Atoi(serverURL.Port()) + if err != nil { + t.Errorf("Failed to parse test server port [%s]: %s", serverURL.Port(), err.Error()) + } + + port := uint16(pp) + + allGoodResolver := testResolver{ + result: []*net.SRV{ + {Target: serverURL.Hostname(), Port: port, Priority: 1, Weight: 1}, + }, + } + + partialFailureResolver := testResolver{ + result: []*net.SRV{ + {Target: serverURL.Hostname(), Port: port, Priority: 1, Weight: 1}, + {Target: "target-2.fqdn", Port: 33, Priority: 1, Weight: 1}, + {Target: "target-3.fqdn", Port: 94, Priority: 1, Weight: 1}, + }, + } + + fullFailureResolver := testResolver{ + result: []*net.SRV{ + {Target: "target-1.fqdn", Port: 876, Priority: 1, Weight: 1}, + {Target: "target-2.fqdn", Port: 33, Priority: 1, Weight: 1}, + {Target: "target-3.fqdn", Port: 94, Priority: 1, Weight: 1}, + }, + } + + resolutionErrorResolver := testResolver{ + errOn: "srv://tag.service.consul", + result: []*net.SRV{}, + } + + upstream := &staticUpstream{ + Hosts: []*UpstreamHost{ + {Name: "srv://tag.service.consul"}, + }, + FailTimeout: 10 * time.Second, + MaxFails: 1, + } + + tests := []struct { + resolver testResolver + shouldFail bool + shouldErr bool + }{ + {allGoodResolver, false, false}, + {partialFailureResolver, false, false}, + {fullFailureResolver, true, false}, + {resolutionErrorResolver, true, true}, + } + + for i, test := range tests { + upstream.resolver = test.resolver + upstream.healthCheck() + if upstream.Hosts[0].Down() && !test.shouldFail { + t.Errorf("Test %d - expected all healthchecks to pass, all failing", i) + } + + if test.shouldFail && !upstream.Hosts[0].Down() { + t.Errorf("Test %d - expected all healthchecks to fail, all passing", i) + } + + status := fmt.Sprintf("%s", upstream.Hosts[0].HealthCheckResult.Load()) + + if test.shouldFail && !test.shouldErr && status != "Failed" { + t.Errorf("Test %d - Expected health check result to be 'Failed', got '%s'", i, status) + } + + if !test.shouldFail && status != "OK" { + t.Errorf("Test %d - Expected health check result to be 'OK', got '%s'", i, status) + } + + if test.shouldErr && status != "an error occurred" { + t.Errorf("Test %d - Expected health check result to be 'an error occured', got '%s'", i, status) + } + } +} -- 2.30.9