From 63fd264043ed6913c5ae785783dde687f1ec34e2 Mon Sep 17 00:00:00 2001
From: Mohammad Gufran <mohammad.gufran@kayako.com>
Date: Mon, 6 Nov 2017 11:31:10 +0530
Subject: [PATCH] proxy: Add SRV support for proxy upstream (#1915)

* Simplify parseUpstream function

* Add SRV support for proxy upstream
---
 caddyhttp/proxy/proxy.go             |   3 +-
 caddyhttp/proxy/reverseproxy.go      |  46 +++++-
 caddyhttp/proxy/reverseproxy_test.go |  94 +++++++++++
 caddyhttp/proxy/upstream.go          | 237 ++++++++++++++++++---------
 caddyhttp/proxy/upstream_test.go     | 220 ++++++++++++++++++++++++-
 5 files changed, 518 insertions(+), 82 deletions(-)
 create mode 100644 caddyhttp/proxy/reverseproxy_test.go

diff --git a/caddyhttp/proxy/proxy.go b/caddyhttp/proxy/proxy.go
index 3d88af21..e66ad776 100644
--- a/caddyhttp/proxy/proxy.go
+++ b/caddyhttp/proxy/proxy.go
@@ -82,7 +82,8 @@ type UpstreamHost struct {
 	// This is an int32 so that we can use atomic operations to do concurrent
 	// reads & writes to this value.  The default value of 0 indicates that it
 	// is healthy and any non-zero value indicates unhealthy.
-	Unhealthy int32
+	Unhealthy         int32
+	HealthCheckResult atomic.Value
 }
 
 // Down checks whether the upstream host is down or not.
diff --git a/caddyhttp/proxy/reverseproxy.go b/caddyhttp/proxy/reverseproxy.go
index 78fbb6bd..2fac5aab 100644
--- a/caddyhttp/proxy/reverseproxy.go
+++ b/caddyhttp/proxy/reverseproxy.go
@@ -26,7 +26,9 @@
 package proxy
 
 import (
+	"context"
 	"crypto/tls"
+	"fmt"
 	"io"
 	"net"
 	"net/http"
@@ -91,6 +93,8 @@ type ReverseProxy struct {
 	// response body.
 	// If zero, no periodic flushing is done.
 	FlushInterval time.Duration
+
+	srvResolver srvResolver
 }
 
 // Though the relevant directive prefix is just "unix:", url.Parse
@@ -105,6 +109,23 @@ func socketDial(hostName string) func(network, addr string) (conn net.Conn, err
 	}
 }
 
+func (rp *ReverseProxy) srvDialerFunc(locator string) func(network, addr string) (conn net.Conn, err error) {
+	service := locator
+	if strings.HasPrefix(locator, "srv://") {
+		service = locator[6:]
+	} else if strings.HasPrefix(locator, "srv+https://") {
+		service = locator[12:]
+	}
+
+	return func(network, addr string) (conn net.Conn, err error) {
+		_, addrs, err := rp.srvResolver.LookupSRV(context.Background(), "", "", service)
+		if err != nil {
+			return nil, err
+		}
+		return net.Dial("tcp", fmt.Sprintf("%s:%d", addrs[0].Target, addrs[0].Port))
+	}
+}
+
 func singleJoiningSlash(a, b string) string {
 	aslash := strings.HasSuffix(a, "/")
 	bslash := strings.HasPrefix(b, "/")
@@ -131,6 +152,12 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
 			// scheme and host have to be faked
 			req.URL.Scheme = "http"
 			req.URL.Host = "socket"
+		} else if target.Scheme == "srv" {
+			req.URL.Scheme = "http"
+			req.URL.Host = target.Host
+		} else if target.Scheme == "srv+https" {
+			req.URL.Scheme = "https"
+			req.URL.Host = target.Host
 		} else {
 			req.URL.Scheme = target.Scheme
 			req.URL.Host = target.Host
@@ -199,7 +226,12 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
 		}
 	}
 
-	rp := &ReverseProxy{Director: director, FlushInterval: 250 * time.Millisecond} // flushing good for streaming & server-sent events
+	rp := &ReverseProxy{
+		Director:      director,
+		FlushInterval: 250 * time.Millisecond, // flushing good for streaming & server-sent events
+		srvResolver:   net.DefaultResolver,
+	}
+
 	if target.Scheme == "unix" {
 		rp.Transport = &http.Transport{
 			Dial: socketDial(target.String()),
@@ -210,13 +242,15 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
 				HandshakeTimeout: defaultCryptoHandshakeTimeout,
 			},
 		}
-	} else if keepalive != http.DefaultMaxIdleConnsPerHost {
-		// if keepalive is equal to the default,
-		// just use default transport, to avoid creating
-		// a brand new transport
+	} else if keepalive != http.DefaultMaxIdleConnsPerHost || strings.HasPrefix(target.Scheme, "srv") {
+		dialFunc := defaultDialer.Dial
+		if strings.HasPrefix(target.Scheme, "srv") {
+			dialFunc = rp.srvDialerFunc(target.String())
+		}
+
 		transport := &http.Transport{
 			Proxy:                 http.ProxyFromEnvironment,
-			Dial:                  defaultDialer.Dial,
+			Dial:                  dialFunc,
 			TLSHandshakeTimeout:   defaultCryptoHandshakeTimeout,
 			ExpectContinueTimeout: 1 * time.Second,
 		}
diff --git a/caddyhttp/proxy/reverseproxy_test.go b/caddyhttp/proxy/reverseproxy_test.go
new file mode 100644
index 00000000..2d1d80df
--- /dev/null
+++ b/caddyhttp/proxy/reverseproxy_test.go
@@ -0,0 +1,94 @@
+// Copyright 2015 Light Code Labs, LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proxy
+
+import (
+	"net"
+	"net/http"
+	"net/http/httptest"
+	"net/url"
+	"strconv"
+	"testing"
+)
+
+const (
+	expectedResponse = "response from request proxied to upstream"
+	expectedStatus   = http.StatusOK
+)
+
+var upstreamHost *httptest.Server
+
+func setupTest() {
+	upstreamHost = httptest.NewServer(http.HandlerFunc(
+		func(w http.ResponseWriter, r *http.Request) {
+			if r.URL.Path == "/test-path" {
+				w.WriteHeader(expectedStatus)
+				w.Write([]byte(expectedResponse))
+			} else {
+				w.WriteHeader(404)
+				w.Write([]byte("Not found"))
+			}
+		}))
+}
+
+func tearDownTest() {
+	upstreamHost.Close()
+}
+
+func TestSingleSRVHostReverseProxy(t *testing.T) {
+	setupTest()
+	defer tearDownTest()
+
+	target, err := url.Parse("srv://test.upstream.service")
+	if err != nil {
+		t.Errorf("Failed to parse target URL. %s", err.Error())
+	}
+
+	upstream, err := url.Parse(upstreamHost.URL)
+	if err != nil {
+		t.Errorf("Failed to parse test server URL [%s]. %s", upstreamHost.URL, err.Error())
+	}
+	pp, err := strconv.Atoi(upstream.Port())
+	if err != nil {
+		t.Errorf("Failed to parse upstream server port [%s]. %s", upstream.Port(), err.Error())
+	}
+	port := uint16(pp)
+
+	rp := NewSingleHostReverseProxy(target, "", http.DefaultMaxIdleConnsPerHost)
+	rp.srvResolver = testResolver{
+		result: []*net.SRV{
+			{Target: upstream.Hostname(), Port: port, Priority: 1, Weight: 1},
+		},
+	}
+
+	resp := httptest.NewRecorder()
+	req, err := http.NewRequest("GET", "http://test.host/test-path", nil)
+	if err != nil {
+		t.Errorf("Failed to create new request. %s", err.Error())
+	}
+
+	err = rp.ServeHTTP(resp, req, nil)
+	if err != nil {
+		t.Errorf("Failed to perform reverse proxy to upstream host. %s", err.Error())
+	}
+
+	if resp.Body.String() != expectedResponse {
+		t.Errorf("Unexpected proxy response received. Expected: '%s', Got: '%s'", expectedResponse, resp.Body.String())
+	}
+
+	if resp.Code != expectedStatus {
+		t.Errorf("Unexpected proxy status. Expected: '%d', Got: '%d'", expectedStatus, resp.Code)
+	}
+}
diff --git a/caddyhttp/proxy/upstream.go b/caddyhttp/proxy/upstream.go
index bab8b462..ae15a6dc 100644
--- a/caddyhttp/proxy/upstream.go
+++ b/caddyhttp/proxy/upstream.go
@@ -16,6 +16,7 @@ package proxy
 
 import (
 	"bytes"
+	"context"
 	"fmt"
 	"io"
 	"io/ioutil"
@@ -65,6 +66,11 @@ type staticUpstream struct {
 	IgnoredSubPaths    []string
 	insecureSkipVerify bool
 	MaxFails           int32
+	resolver           srvResolver
+}
+
+type srvResolver interface {
+	LookupSRV(context.Context, string, string, string) (string, []*net.SRV, error)
 }
 
 // NewStaticUpstreams parses the configuration input and sets up
@@ -86,6 +92,7 @@ func NewStaticUpstreams(c caddyfile.Dispenser, host string) ([]Upstream, error)
 			TryInterval:       250 * time.Millisecond,
 			MaxConns:          0,
 			KeepAlive:         http.DefaultMaxIdleConnsPerHost,
+			resolver:          net.DefaultResolver,
 		}
 
 		if !c.Args(&upstream.from) {
@@ -93,7 +100,21 @@ func NewStaticUpstreams(c caddyfile.Dispenser, host string) ([]Upstream, error)
 		}
 
 		var to []string
+		hasSrv := false
+
 		for _, t := range c.RemainingArgs() {
+			if len(to) > 0 && hasSrv {
+				return upstreams, c.Err("only one upstream is supported when using SRV locator")
+			}
+
+			if strings.HasPrefix(t, "srv://") || strings.HasPrefix(t, "srv+https://") {
+				if len(to) > 0 {
+					return upstreams, c.Err("service locator upstreams can not be mixed with host names")
+				}
+
+				hasSrv = true
+			}
+
 			parsed, err := parseUpstream(t)
 			if err != nil {
 				return upstreams, err
@@ -107,13 +128,18 @@ func NewStaticUpstreams(c caddyfile.Dispenser, host string) ([]Upstream, error)
 				if !c.NextArg() {
 					return upstreams, c.ArgErr()
 				}
+
+				if hasSrv {
+					return upstreams, c.Err("upstream directive is not supported when backend is service locator")
+				}
+
 				parsed, err := parseUpstream(c.Val())
 				if err != nil {
 					return upstreams, err
 				}
 				to = append(to, parsed...)
 			default:
-				if err := parseBlock(&c, upstream); err != nil {
+				if err := parseBlock(&c, upstream, hasSrv); err != nil {
 					return upstreams, err
 				}
 			}
@@ -165,7 +191,9 @@ func (u *staticUpstream) From() string {
 func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) {
 	if !strings.HasPrefix(host, "http") &&
 		!strings.HasPrefix(host, "unix:") &&
-		!strings.HasPrefix(host, "quic:") {
+		!strings.HasPrefix(host, "quic:") &&
+		!strings.HasPrefix(host, "srv://") &&
+		!strings.HasPrefix(host, "srv+https://") {
 		host = "http://" + host
 	}
 	uh := &UpstreamHost{
@@ -189,6 +217,7 @@ func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) {
 		}(u),
 		WithoutPathPrefix: u.WithoutPathPrefix,
 		MaxConns:          u.MaxConns,
+		HealthCheckResult: atomic.Value{},
 	}
 
 	baseURL, err := url.Parse(uh.Name)
@@ -205,50 +234,65 @@ func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) {
 }
 
 func parseUpstream(u string) ([]string, error) {
-	if !strings.HasPrefix(u, "unix:") {
-		colonIdx := strings.LastIndex(u, ":")
-		protoIdx := strings.Index(u, "://")
-
-		if colonIdx != -1 && colonIdx != protoIdx {
-			us := u[:colonIdx]
-			ue := ""
-			portsEnd := len(u)
-			if nextSlash := strings.Index(u[colonIdx:], "/"); nextSlash != -1 {
-				portsEnd = colonIdx + nextSlash
-				ue = u[portsEnd:]
-			}
-			ports := u[len(us)+1 : portsEnd]
+	if strings.HasPrefix(u, "unix:") {
+		return []string{u}, nil
+	}
 
-			if separators := strings.Count(ports, "-"); separators == 1 {
-				portsStr := strings.Split(ports, "-")
-				pIni, err := strconv.Atoi(portsStr[0])
-				if err != nil {
-					return nil, err
-				}
+	isSrv := strings.HasPrefix(u, "srv://") || strings.HasPrefix(u, "srv+https://")
+	colonIdx := strings.LastIndex(u, ":")
+	protoIdx := strings.Index(u, "://")
 
-				pEnd, err := strconv.Atoi(portsStr[1])
-				if err != nil {
-					return nil, err
-				}
+	if colonIdx == -1 || colonIdx == protoIdx {
+		return []string{u}, nil
+	}
 
-				if pEnd <= pIni {
-					return nil, fmt.Errorf("port range [%s] is invalid", ports)
-				}
+	if isSrv {
+		return nil, fmt.Errorf("service locator %s can not have port specified", u)
+	}
 
-				hosts := []string{}
-				for p := pIni; p <= pEnd; p++ {
-					hosts = append(hosts, fmt.Sprintf("%s:%d%s", us, p, ue))
-				}
-				return hosts, nil
-			}
-		}
+	us := u[:colonIdx]
+	ue := ""
+	portsEnd := len(u)
+	if nextSlash := strings.Index(u[colonIdx:], "/"); nextSlash != -1 {
+		portsEnd = colonIdx + nextSlash
+		ue = u[portsEnd:]
+	}
+
+	ports := u[len(us)+1 : portsEnd]
+	separators := strings.Count(ports, "-")
+
+	if separators == 0 {
+		return []string{u}, nil
+	}
+
+	if separators > 1 {
+		return nil, fmt.Errorf("port range [%s] has %d separators", ports, separators)
+	}
+
+	portsStr := strings.Split(ports, "-")
+	pIni, err := strconv.Atoi(portsStr[0])
+	if err != nil {
+		return nil, err
 	}
 
-	return []string{u}, nil
+	pEnd, err := strconv.Atoi(portsStr[1])
+	if err != nil {
+		return nil, err
+	}
 
+	if pEnd <= pIni {
+		return nil, fmt.Errorf("port range [%s] is invalid", ports)
+	}
+
+	hosts := []string{}
+	for p := pIni; p <= pEnd; p++ {
+		hosts = append(hosts, fmt.Sprintf("%s:%d%s", us, p, ue))
+	}
+
+	return hosts, nil
 }
 
-func parseBlock(c *caddyfile.Dispenser, u *staticUpstream) error {
+func parseBlock(c *caddyfile.Dispenser, u *staticUpstream, hasSrv bool) error {
 	switch c.Val() {
 	case "policy":
 		if !c.NextArg() {
@@ -348,6 +392,11 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream) error {
 		if !c.NextArg() {
 			return c.ArgErr()
 		}
+
+		if hasSrv {
+			return c.Err("health_check_port directive is not allowed when upstream is SRV locator")
+		}
+
 		port := c.Val()
 		n, err := strconv.Atoi(port)
 		if err != nil {
@@ -420,54 +469,94 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream) error {
 	return nil
 }
 
+func (u *staticUpstream) resolveHost(h string) ([]string, bool, error) {
+	names := []string{}
+	proto := "http"
+	if !strings.HasPrefix(h, "srv://") && !strings.HasPrefix(h, "srv+https://") {
+		return []string{h}, false, nil
+	}
+
+	if strings.HasPrefix(h, "srv+https://") {
+		proto = "https"
+	}
+
+	_, addrs, err := u.resolver.LookupSRV(context.Background(), "", "", h)
+	if err != nil {
+		return names, true, err
+	}
+
+	for _, addr := range addrs {
+		names = append(names, fmt.Sprintf("%s://%s:%d", proto, addr.Target, addr.Port))
+	}
+
+	return names, true, nil
+}
+
 func (u *staticUpstream) healthCheck() {
 	for _, host := range u.Hosts {
-		hostURL := host.Name
-		if u.HealthCheck.Port != "" {
-			hostURL = replacePort(host.Name, u.HealthCheck.Port)
+		candidates, isSrv, err := u.resolveHost(host.Name)
+		if err != nil {
+			host.HealthCheckResult.Store(err.Error())
+			atomic.StoreInt32(&host.Unhealthy, 1)
+			continue
 		}
-		hostURL += u.HealthCheck.Path
 
-		unhealthy := func() bool {
-			// set up request, needed to be able to modify headers
-			// possible errors are bad HTTP methods or un-parsable urls
-			req, err := http.NewRequest("GET", hostURL, nil)
-			if err != nil {
-				return true
+		unhealthyCount := 0
+		for _, addr := range candidates {
+			hostURL := addr
+			if !isSrv && u.HealthCheck.Port != "" {
+				hostURL = replacePort(hostURL, u.HealthCheck.Port)
 			}
-			// set host for request going upstream
-			if u.HealthCheck.Host != "" {
-				req.Host = u.HealthCheck.Host
-			}
-			r, err := u.HealthCheck.Client.Do(req)
-			if err != nil {
+			hostURL += u.HealthCheck.Path
+
+			unhealthy := func() bool {
+				// set up request, needed to be able to modify headers
+				// possible errors are bad HTTP methods or un-parsable urls
+				req, err := http.NewRequest("GET", hostURL, nil)
+				if err != nil {
+					return true
+				}
+				// set host for request going upstream
+				if u.HealthCheck.Host != "" {
+					req.Host = u.HealthCheck.Host
+				}
+				r, err := u.HealthCheck.Client.Do(req)
+				if err != nil {
+					return true
+				}
+				defer func() {
+					io.Copy(ioutil.Discard, r.Body)
+					r.Body.Close()
+				}()
+				if r.StatusCode < 200 || r.StatusCode >= 400 {
+					return true
+				}
+				if u.HealthCheck.ContentString == "" { // don't check for content string
+					return false
+				}
+				// TODO ReadAll will be replaced if deemed necessary
+				//      See https://github.com/mholt/caddy/pull/1691
+				buf, err := ioutil.ReadAll(r.Body)
+				if err != nil {
+					return true
+				}
+				if bytes.Contains(buf, []byte(u.HealthCheck.ContentString)) {
+					return false
+				}
 				return true
-			}
-			defer func() {
-				io.Copy(ioutil.Discard, r.Body)
-				r.Body.Close()
 			}()
-			if r.StatusCode < 200 || r.StatusCode >= 400 {
-				return true
-			}
-			if u.HealthCheck.ContentString == "" { // don't check for content string
-				return false
-			}
-			// TODO ReadAll will be replaced if deemed necessary
-			//      See https://github.com/mholt/caddy/pull/1691
-			buf, err := ioutil.ReadAll(r.Body)
-			if err != nil {
-				return true
-			}
-			if bytes.Contains(buf, []byte(u.HealthCheck.ContentString)) {
-				return false
+
+			if unhealthy {
+				unhealthyCount++
 			}
-			return true
-		}()
-		if unhealthy {
+		}
+
+		if unhealthyCount == len(candidates) {
 			atomic.StoreInt32(&host.Unhealthy, 1)
+			host.HealthCheckResult.Store("Failed")
 		} else {
 			atomic.StoreInt32(&host.Unhealthy, 0)
+			host.HealthCheckResult.Store("OK")
 		}
 	}
 }
diff --git a/caddyhttp/proxy/upstream_test.go b/caddyhttp/proxy/upstream_test.go
index ce662d19..23fd4831 100644
--- a/caddyhttp/proxy/upstream_test.go
+++ b/caddyhttp/proxy/upstream_test.go
@@ -15,10 +15,15 @@
 package proxy
 
 import (
+	"context"
+	"errors"
 	"fmt"
 	"net"
 	"net/http"
 	"net/http/httptest"
+	"net/url"
+	"reflect"
+	"strconv"
 	"strings"
 	"sync/atomic"
 	"testing"
@@ -187,7 +192,7 @@ func TestParseBlockHealthCheck(t *testing.T) {
 		u := staticUpstream{}
 		c := caddyfile.NewDispenser("Testfile", strings.NewReader(test.config))
 		for c.Next() {
-			parseBlock(&c, &u)
+			parseBlock(&c, &u, false)
 		}
 		if u.HealthCheck.Interval.String() != test.interval {
 			t.Errorf(
@@ -551,3 +556,216 @@ func TestQuicHost(t *testing.T) {
 		}
 	}
 }
+
+func TestParseSRVBlock(t *testing.T) {
+	tests := []struct {
+		config    string
+		shouldErr bool
+	}{
+		{"proxy / srv://bogus.service", false},
+		{"proxy / srv://bogus.service:80", true},
+		{"proxy / srv://bogus.service srv://bogus.service.fallback", true},
+		{"proxy / srv://bogus.service http://bogus.service.fallback", true},
+		{"proxy / http://bogus.service srv://bogus.service.fallback", true},
+		{"proxy / srv://bogus.service bogus.service.fallback", true},
+		{`proxy / srv://bogus.service {
+		    upstream srv://bogus.service
+		 }`, true},
+		{"proxy / srv+https://bogus.service", false},
+		{"proxy / srv+https://bogus.service:80", true},
+		{"proxy / srv+https://bogus.service srv://bogus.service.fallback", true},
+		{"proxy / srv+https://bogus.service http://bogus.service.fallback", true},
+		{"proxy / http://bogus.service srv+https://bogus.service.fallback", true},
+		{"proxy / srv+https://bogus.service bogus.service.fallback", true},
+		{`proxy / srv+https://bogus.service {
+		    upstream srv://bogus.service
+		 }`, true},
+		{`proxy / srv+https://bogus.service {
+			health_check_port 96
+		 }`, true},
+	}
+
+	for i, test := range tests {
+		_, err := NewStaticUpstreams(caddyfile.NewDispenser("Testfile", strings.NewReader(test.config)), "")
+		if err == nil && test.shouldErr {
+			t.Errorf("Case %d - Expected an error. got nothing", i)
+		}
+
+		if err != nil && !test.shouldErr {
+			t.Errorf("Case %d - Expected no error. got %s", i, err.Error())
+		}
+	}
+}
+
+type testResolver struct {
+	errOn  string
+	result []*net.SRV
+}
+
+func (r testResolver) LookupSRV(ctx context.Context, _, _, service string) (string, []*net.SRV, error) {
+	if service == r.errOn {
+		return "", nil, errors.New("an error occurred")
+	}
+
+	return "", r.result, nil
+}
+
+func TestResolveHost(t *testing.T) {
+	upstream := &staticUpstream{
+		resolver: testResolver{
+			errOn: "srv://problematic.service.name",
+			result: []*net.SRV{
+				{Target: "target-1.fqdn", Port: 85, Priority: 1, Weight: 1},
+				{Target: "target-2.fqdn", Port: 33, Priority: 1, Weight: 1},
+				{Target: "target-3.fqdn", Port: 94, Priority: 1, Weight: 1},
+			},
+		},
+	}
+
+	tests := []struct {
+		host      string
+		expect    []string
+		isSrv     bool
+		shouldErr bool
+	}{
+		// Static DNS records
+		{"http://subdomain.domain.service",
+			[]string{"http://subdomain.domain.service"},
+			false,
+			false},
+		{"https://subdomain.domain.service",
+			[]string{"https://subdomain.domain.service"},
+			false,
+			false},
+		{"http://subdomain.domain.service:76",
+			[]string{"http://subdomain.domain.service:76"},
+			false,
+			false},
+		{"https://subdomain.domain.service:65",
+			[]string{"https://subdomain.domain.service:65"},
+			false,
+			false},
+
+		// SRV lookups
+		{"srv://service.name", []string{
+			"http://target-1.fqdn:85",
+			"http://target-2.fqdn:33",
+			"http://target-3.fqdn:94",
+		}, true, false},
+		{"srv+https://service.name", []string{
+			"https://target-1.fqdn:85",
+			"https://target-2.fqdn:33",
+			"https://target-3.fqdn:94",
+		}, true, false},
+		{"srv://problematic.service.name", []string{}, true, true},
+	}
+
+	for i, test := range tests {
+		results, isSrv, err := upstream.resolveHost(test.host)
+		if err == nil && test.shouldErr {
+			t.Errorf("Test %d - expected an error, got none", i)
+		}
+
+		if err != nil && !test.shouldErr {
+			t.Errorf("Test %d - unexpected error %s", i, err.Error())
+		}
+
+		if test.isSrv && !isSrv {
+			t.Errorf("Test %d - expecting resolution to be SRV lookup but it isn't", i)
+		}
+
+		if isSrv && !test.isSrv {
+			t.Errorf("Test %d - expecting resolution to be normal lookup, got SRV", i)
+		}
+
+		if !reflect.DeepEqual(results, test.expect) {
+			t.Errorf("Test %d - resolution result %#v does not match expected value %#v", i, results, test.expect)
+		}
+	}
+}
+
+func TestSRVHealthCheck(t *testing.T) {
+	serverURL, err := url.Parse(workableServer.URL)
+	if err != nil {
+		t.Errorf("Failed to parse test server URL: %s", err.Error())
+	}
+
+	pp, err := strconv.Atoi(serverURL.Port())
+	if err != nil {
+		t.Errorf("Failed to parse test server port [%s]: %s", serverURL.Port(), err.Error())
+	}
+
+	port := uint16(pp)
+
+	allGoodResolver := testResolver{
+		result: []*net.SRV{
+			{Target: serverURL.Hostname(), Port: port, Priority: 1, Weight: 1},
+		},
+	}
+
+	partialFailureResolver := testResolver{
+		result: []*net.SRV{
+			{Target: serverURL.Hostname(), Port: port, Priority: 1, Weight: 1},
+			{Target: "target-2.fqdn", Port: 33, Priority: 1, Weight: 1},
+			{Target: "target-3.fqdn", Port: 94, Priority: 1, Weight: 1},
+		},
+	}
+
+	fullFailureResolver := testResolver{
+		result: []*net.SRV{
+			{Target: "target-1.fqdn", Port: 876, Priority: 1, Weight: 1},
+			{Target: "target-2.fqdn", Port: 33, Priority: 1, Weight: 1},
+			{Target: "target-3.fqdn", Port: 94, Priority: 1, Weight: 1},
+		},
+	}
+
+	resolutionErrorResolver := testResolver{
+		errOn:  "srv://tag.service.consul",
+		result: []*net.SRV{},
+	}
+
+	upstream := &staticUpstream{
+		Hosts: []*UpstreamHost{
+			{Name: "srv://tag.service.consul"},
+		},
+		FailTimeout: 10 * time.Second,
+		MaxFails:    1,
+	}
+
+	tests := []struct {
+		resolver   testResolver
+		shouldFail bool
+		shouldErr  bool
+	}{
+		{allGoodResolver, false, false},
+		{partialFailureResolver, false, false},
+		{fullFailureResolver, true, false},
+		{resolutionErrorResolver, true, true},
+	}
+
+	for i, test := range tests {
+		upstream.resolver = test.resolver
+		upstream.healthCheck()
+		if upstream.Hosts[0].Down() && !test.shouldFail {
+			t.Errorf("Test %d - expected all healthchecks to pass, all failing", i)
+		}
+
+		if test.shouldFail && !upstream.Hosts[0].Down() {
+			t.Errorf("Test %d - expected all healthchecks to fail, all passing", i)
+		}
+
+		status := fmt.Sprintf("%s", upstream.Hosts[0].HealthCheckResult.Load())
+
+		if test.shouldFail && !test.shouldErr && status != "Failed" {
+			t.Errorf("Test %d - Expected health check result to be 'Failed', got '%s'", i, status)
+		}
+
+		if !test.shouldFail && status != "OK" {
+			t.Errorf("Test %d - Expected health check result to be 'OK', got '%s'", i, status)
+		}
+
+		if test.shouldErr && status != "an error occurred" {
+			t.Errorf("Test %d - Expected health check result to be 'an error occured', got '%s'", i, status)
+		}
+	}
+}
-- 
2.30.9