Commit 88d3dcae authored by Kris Hamoud's avatar Kris Hamoud

added ip_hash load balancing

updated tests

fixed comment format

fixed formatting, minor logic fix

added newline to EOF

updated logic, fixed tests

added comment

updated formatting

updated test output

fixed typo
parent 72af3f82
package proxy
import (
"hash/fnv"
"math"
"math/rand"
"net"
"net/http"
"sync"
)
......@@ -11,20 +14,21 @@ type HostPool []*UpstreamHost
// Policy decides how a host will be selected from a pool.
type Policy interface {
Select(pool HostPool) *UpstreamHost
Select(pool HostPool, r *http.Request) *UpstreamHost
}
func init() {
RegisterPolicy("random", func() Policy { return &Random{} })
RegisterPolicy("least_conn", func() Policy { return &LeastConn{} })
RegisterPolicy("round_robin", func() Policy { return &RoundRobin{} })
RegisterPolicy("ip_hash", func() Policy { return &IPHash{} })
}
// Random is a policy that selects up hosts from a pool at random.
type Random struct{}
// Select selects an up host at random from the specified pool.
func (r *Random) Select(pool HostPool) *UpstreamHost {
func (r *Random) Select(pool HostPool, request *http.Request) *UpstreamHost {
// Because the number of available hosts isn't known
// up front, the host is selected via reservoir sampling
......@@ -53,7 +57,7 @@ type LeastConn struct{}
// Select selects the up host with the least number of connections in the
// pool. If more than one host has the same least number of connections,
// one of the hosts is chosen at random.
func (r *LeastConn) Select(pool HostPool) *UpstreamHost {
func (r *LeastConn) Select(pool HostPool, request *http.Request) *UpstreamHost {
var bestHost *UpstreamHost
count := 0
leastConn := int64(math.MaxInt64)
......@@ -86,7 +90,7 @@ type RoundRobin struct {
}
// Select selects an up host from the pool using a round robin ordering scheme.
func (r *RoundRobin) Select(pool HostPool) *UpstreamHost {
func (r *RoundRobin) Select(pool HostPool, request *http.Request) *UpstreamHost {
poolLen := uint32(len(pool))
r.mutex.Lock()
defer r.mutex.Unlock()
......@@ -100,3 +104,35 @@ func (r *RoundRobin) Select(pool HostPool) *UpstreamHost {
}
return nil
}
// IPHash is a policy that selects hosts based on hashing the request ip
type IPHash struct{}
func hash(s string) uint32 {
h := fnv.New32a()
h.Write([]byte(s))
return h.Sum32()
}
// Select selects an up host from the pool using a round robin ordering scheme.
func (r *IPHash) Select(pool HostPool, request *http.Request) *UpstreamHost {
poolLen := uint32(len(pool))
clientIP, _, err := net.SplitHostPort(request.RemoteAddr)
if err != nil {
clientIP = request.RemoteAddr
}
hash := hash(clientIP)
for {
if poolLen == 0 {
break
}
index := hash % poolLen
host := pool[index]
if host.Available() {
return host
}
pool = append(pool[:index], pool[index+1:]...)
poolLen--
}
return nil
}
......@@ -21,7 +21,7 @@ func TestMain(m *testing.M) {
type customPolicy struct{}
func (r *customPolicy) Select(pool HostPool) *UpstreamHost {
func (r *customPolicy) Select(pool HostPool, request *http.Request) *UpstreamHost {
return pool[0]
}
......@@ -43,37 +43,39 @@ func testPool() HostPool {
func TestRoundRobinPolicy(t *testing.T) {
pool := testPool()
rrPolicy := &RoundRobin{}
h := rrPolicy.Select(pool)
request, _ := http.NewRequest("GET", "/", nil)
h := rrPolicy.Select(pool, request)
// First selected host is 1, because counter starts at 0
// and increments before host is selected
if h != pool[1] {
t.Error("Expected first round robin host to be second host in the pool.")
}
h = rrPolicy.Select(pool)
h = rrPolicy.Select(pool, request)
if h != pool[2] {
t.Error("Expected second round robin host to be third host in the pool.")
}
h = rrPolicy.Select(pool)
h = rrPolicy.Select(pool, request)
if h != pool[0] {
t.Error("Expected third round robin host to be first host in the pool.")
}
// mark host as down
pool[1].Unhealthy = true
h = rrPolicy.Select(pool)
h = rrPolicy.Select(pool, request)
if h != pool[2] {
t.Error("Expected to skip down host.")
}
// mark host as up
pool[1].Unhealthy = false
h = rrPolicy.Select(pool)
h = rrPolicy.Select(pool, request)
if h == pool[2] {
t.Error("Expected to balance evenly among healthy hosts")
}
// mark host as full
pool[1].Conns = 1
pool[1].MaxConns = 1
h = rrPolicy.Select(pool)
h = rrPolicy.Select(pool, request)
if h != pool[2] {
t.Error("Expected to skip full host.")
}
......@@ -82,14 +84,16 @@ func TestRoundRobinPolicy(t *testing.T) {
func TestLeastConnPolicy(t *testing.T) {
pool := testPool()
lcPolicy := &LeastConn{}
request, _ := http.NewRequest("GET", "/", nil)
pool[0].Conns = 10
pool[1].Conns = 10
h := lcPolicy.Select(pool)
h := lcPolicy.Select(pool, request)
if h != pool[2] {
t.Error("Expected least connection host to be third host.")
}
pool[2].Conns = 100
h = lcPolicy.Select(pool)
h = lcPolicy.Select(pool, request)
if h != pool[0] && h != pool[1] {
t.Error("Expected least connection host to be first or second host.")
}
......@@ -98,8 +102,127 @@ func TestLeastConnPolicy(t *testing.T) {
func TestCustomPolicy(t *testing.T) {
pool := testPool()
customPolicy := &customPolicy{}
h := customPolicy.Select(pool)
request, _ := http.NewRequest("GET", "/", nil)
h := customPolicy.Select(pool, request)
if h != pool[0] {
t.Error("Expected custom policy host to be the first host.")
}
}
func TestIPHashPolicy(t *testing.T) {
pool := testPool()
ipHash := &IPHash{}
request, _ := http.NewRequest("GET", "/", nil)
// We should be able to predict where every request is routed.
request.RemoteAddr = "172.0.0.1:80"
h := ipHash.Select(pool, request)
if h != pool[1] {
t.Error("Expected ip hash policy host to be the second host.")
}
request.RemoteAddr = "172.0.0.2:80"
h = ipHash.Select(pool, request)
if h != pool[1] {
t.Error("Expected ip hash policy host to be the second host.")
}
request.RemoteAddr = "172.0.0.3:80"
h = ipHash.Select(pool, request)
if h != pool[2] {
t.Error("Expected ip hash policy host to be the third host.")
}
request.RemoteAddr = "172.0.0.4:80"
h = ipHash.Select(pool, request)
if h != pool[1] {
t.Error("Expected ip hash policy host to be the second host.")
}
// we should get the same results without a port
request.RemoteAddr = "172.0.0.1"
h = ipHash.Select(pool, request)
if h != pool[1] {
t.Error("Expected ip hash policy host to be the second host.")
}
request.RemoteAddr = "172.0.0.2"
h = ipHash.Select(pool, request)
if h != pool[1] {
t.Error("Expected ip hash policy host to be the second host.")
}
request.RemoteAddr = "172.0.0.3"
h = ipHash.Select(pool, request)
if h != pool[2] {
t.Error("Expected ip hash policy host to be the third host.")
}
request.RemoteAddr = "172.0.0.4"
h = ipHash.Select(pool, request)
if h != pool[1] {
t.Error("Expected ip hash policy host to be the second host.")
}
// we should get a healthy host if the original host is unhealthy and a
// healthy host is available
request.RemoteAddr = "172.0.0.1"
pool[1].Unhealthy = true
h = ipHash.Select(pool, request)
if h != pool[0] {
t.Error("Expected ip hash policy host to be the first host.")
}
request.RemoteAddr = "172.0.0.2"
h = ipHash.Select(pool, request)
if h != pool[1] {
t.Error("Expected ip hash policy host to be the second host.")
}
pool[1].Unhealthy = false
request.RemoteAddr = "172.0.0.3"
pool[2].Unhealthy = true
h = ipHash.Select(pool, request)
if h != pool[0] {
t.Error("Expected ip hash policy host to be the first host.")
}
request.RemoteAddr = "172.0.0.4"
h = ipHash.Select(pool, request)
if h != pool[0] {
t.Error("Expected ip hash policy host to be the first host.")
}
// We should be able to resize the host pool and still be able to predict
// where a request will be routed with the same IP's used above
pool = []*UpstreamHost{
{
Name: workableServer.URL, // this should resolve (healthcheck test)
},
{
Name: "http://localhost:99998", // this shouldn't
},
}
pool = HostPool(pool)
request.RemoteAddr = "172.0.0.1:80"
h = ipHash.Select(pool, request)
if h != pool[0] {
t.Error("Expected ip hash policy host to be the first host.")
}
request.RemoteAddr = "172.0.0.2:80"
h = ipHash.Select(pool, request)
if h != pool[1] {
t.Error("Expected ip hash policy host to be the second host.")
}
request.RemoteAddr = "172.0.0.3:80"
h = ipHash.Select(pool, request)
if h != pool[0] {
t.Error("Expected ip hash policy host to be the first host.")
}
request.RemoteAddr = "172.0.0.4:80"
h = ipHash.Select(pool, request)
if h != pool[1] {
t.Error("Expected ip hash policy host to be the second host.")
}
// We should get nil when there are no healthy hosts
pool[0].Unhealthy = true
pool[1].Unhealthy = true
h = ipHash.Select(pool, request)
if h != nil {
t.Error("Expected ip hash policy host to be nil.")
}
}
......@@ -27,7 +27,7 @@ type Upstream interface {
// The path this upstream host should be routed on
From() string
// Selects an upstream host to be routed to.
Select() *UpstreamHost
Select(*http.Request) *UpstreamHost
// Checks if subpath is not an ignored path
AllowedPath(string) bool
}
......@@ -93,7 +93,7 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
// hosts until timeout (or until we get a nil host).
start := time.Now()
for time.Now().Sub(start) < tryDuration {
host := upstream.Select()
host := upstream.Select(r)
if host == nil {
return http.StatusBadGateway, errUnreachable
}
......
......@@ -736,7 +736,7 @@ func (u *fakeUpstream) From() string {
return u.from
}
func (u *fakeUpstream) Select() *UpstreamHost {
func (u *fakeUpstream) Select(r *http.Request) *UpstreamHost {
if u.host == nil {
uri, err := url.Parse(u.name)
if err != nil {
......@@ -781,7 +781,7 @@ func (u *fakeWsUpstream) From() string {
return "/"
}
func (u *fakeWsUpstream) Select() *UpstreamHost {
func (u *fakeWsUpstream) Select(r *http.Request) *UpstreamHost {
uri, _ := url.Parse(u.name)
return &UpstreamHost{
Name: u.name,
......
......@@ -346,7 +346,7 @@ func (u *staticUpstream) HealthCheckWorker(stop chan struct{}) {
}
}
func (u *staticUpstream) Select() *UpstreamHost {
func (u *staticUpstream) Select(r *http.Request) *UpstreamHost {
pool := u.Hosts
if len(pool) == 1 {
if !pool[0].Available() {
......@@ -364,11 +364,10 @@ func (u *staticUpstream) Select() *UpstreamHost {
if allUnavailable {
return nil
}
if u.Policy == nil {
return (&Random{}).Select(pool)
return (&Random{}).Select(pool, r)
}
return u.Policy.Select(pool)
return u.Policy.Select(pool, r)
}
func (u *staticUpstream) AllowedPath(requestPath string) bool {
......
package proxy
import (
"github.com/mholt/caddy/caddyfile"
"net/http"
"strings"
"testing"
"time"
"github.com/mholt/caddy/caddyfile"
)
func TestNewHost(t *testing.T) {
......@@ -72,14 +72,15 @@ func TestSelect(t *testing.T) {
FailTimeout: 10 * time.Second,
MaxFails: 1,
}
r, _ := http.NewRequest("GET", "/", nil)
upstream.Hosts[0].Unhealthy = true
upstream.Hosts[1].Unhealthy = true
upstream.Hosts[2].Unhealthy = true
if h := upstream.Select(); h != nil {
if h := upstream.Select(r); h != nil {
t.Error("Expected select to return nil as all host are down")
}
upstream.Hosts[2].Unhealthy = false
if h := upstream.Select(); h == nil {
if h := upstream.Select(r); h == nil {
t.Error("Expected select to not return nil")
}
upstream.Hosts[0].Conns = 1
......@@ -88,11 +89,11 @@ func TestSelect(t *testing.T) {
upstream.Hosts[1].MaxConns = 1
upstream.Hosts[2].Conns = 1
upstream.Hosts[2].MaxConns = 1
if h := upstream.Select(); h != nil {
if h := upstream.Select(r); h != nil {
t.Error("Expected select to return nil as all hosts are full")
}
upstream.Hosts[2].Conns = 0
if h := upstream.Select(); h == nil {
if h := upstream.Select(r); h == nil {
t.Error("Expected select to not return nil")
}
}
......@@ -188,6 +189,7 @@ func TestParseBlockHealthCheck(t *testing.T) {
}
func TestParseBlock(t *testing.T) {
r, _ := http.NewRequest("GET", "/", nil)
tests := []struct {
config string
}{
......@@ -207,7 +209,7 @@ func TestParseBlock(t *testing.T) {
t.Error("Expected no error. Got:", err.Error())
}
for _, upstream := range upstreams {
headers := upstream.Select().UpstreamHeaders
headers := upstream.Select(r).UpstreamHeaders
if _, ok := headers["Host"]; !ok {
t.Errorf("Test %d: Could not find the Host header", i+1)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment