Commit 4c700efb authored by Matt Holt's avatar Matt Holt Committed by GitHub

Merge pull request #1751 from zikes/header_policy

proxy: add Header load balancing policy
parents 9ad96b33 95366e41
...@@ -18,12 +18,13 @@ type Policy interface { ...@@ -18,12 +18,13 @@ type Policy interface {
} }
func init() { func init() {
RegisterPolicy("random", func() Policy { return &Random{} }) RegisterPolicy("random", func(arg string) Policy { return &Random{} })
RegisterPolicy("least_conn", func() Policy { return &LeastConn{} }) RegisterPolicy("least_conn", func(arg string) Policy { return &LeastConn{} })
RegisterPolicy("round_robin", func() Policy { return &RoundRobin{} }) RegisterPolicy("round_robin", func(arg string) Policy { return &RoundRobin{} })
RegisterPolicy("ip_hash", func() Policy { return &IPHash{} }) RegisterPolicy("ip_hash", func(arg string) Policy { return &IPHash{} })
RegisterPolicy("first", func() Policy { return &First{} }) RegisterPolicy("first", func(arg string) Policy { return &First{} })
RegisterPolicy("uri_hash", func() Policy { return &URIHash{} }) RegisterPolicy("uri_hash", func(arg string) Policy { return &URIHash{} })
RegisterPolicy("header", func(arg string) Policy { return &Header{arg} })
} }
// Random is a policy that selects up hosts from a pool at random. // Random is a policy that selects up hosts from a pool at random.
...@@ -160,3 +161,22 @@ func (r *First) Select(pool HostPool, request *http.Request) *UpstreamHost { ...@@ -160,3 +161,22 @@ func (r *First) Select(pool HostPool, request *http.Request) *UpstreamHost {
} }
return nil return nil
} }
// Header is a policy that selects based on a hash of the given header
type Header struct {
// The name of the request header, the value of which will determine
// how the request is routed
Name string
}
// Select selects the host based on hashing the header value
func (r *Header) Select(pool HostPool, request *http.Request) *UpstreamHost {
if r.Name == "" {
return nil
}
val := request.Header.Get(r.Name)
if val == "" {
return nil
}
return hostByHashing(pool, val)
}
...@@ -302,3 +302,42 @@ func TestUriPolicy(t *testing.T) { ...@@ -302,3 +302,42 @@ func TestUriPolicy(t *testing.T) {
t.Error("Expected uri policy policy host to be nil.") t.Error("Expected uri policy policy host to be nil.")
} }
} }
func TestHeaderPolicy(t *testing.T) {
pool := testPool()
tests := []struct {
Policy *Header
RequestHeaderName string
RequestHeaderValue string
NilHost bool
HostIndex int
}{
{&Header{""}, "", "", true, 0},
{&Header{""}, "Affinity", "somevalue", true, 0},
{&Header{""}, "Affinity", "", true, 0},
{&Header{"Affinity"}, "", "", true, 0},
{&Header{"Affinity"}, "Affinity", "somevalue", false, 1},
{&Header{"Affinity"}, "Affinity", "somevalue2", false, 0},
{&Header{"Affinity"}, "Affinity", "somevalue3", false, 2},
{&Header{"Affinity"}, "Affinity", "", true, 0},
}
for idx, test := range tests {
request, _ := http.NewRequest("GET", "/", nil)
if test.RequestHeaderName != "" {
request.Header.Add(test.RequestHeaderName, test.RequestHeaderValue)
}
host := test.Policy.Select(pool, request)
if test.NilHost && host != nil {
t.Errorf("%d: Expected host to be nil", idx)
}
if !test.NilHost && host == nil {
t.Errorf("%d: Did not expect host to be nil", idx)
}
if !test.NilHost && host != pool[test.HostIndex] {
t.Errorf("%d: Expected Header policy to be host %d", idx, test.HostIndex)
}
}
}
...@@ -22,7 +22,7 @@ import ( ...@@ -22,7 +22,7 @@ import (
) )
var ( var (
supportedPolicies = make(map[string]func() Policy) supportedPolicies = make(map[string]func(string) Policy)
) )
type staticUpstream struct { type staticUpstream struct {
...@@ -243,7 +243,11 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream) error { ...@@ -243,7 +243,11 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream) error {
if !ok { if !ok {
return c.ArgErr() return c.ArgErr()
} }
u.Policy = policyCreateFunc() arg := ""
if c.NextArg() {
arg = c.Val()
}
u.Policy = policyCreateFunc(arg)
case "fail_timeout": case "fail_timeout":
if !c.NextArg() { if !c.NextArg() {
return c.ArgErr() return c.ArgErr()
...@@ -523,7 +527,7 @@ func (u *staticUpstream) Stop() error { ...@@ -523,7 +527,7 @@ func (u *staticUpstream) Stop() error {
} }
// RegisterPolicy adds a custom policy to the proxy. // RegisterPolicy adds a custom policy to the proxy.
func RegisterPolicy(name string, policy func() Policy) { func RegisterPolicy(name string, policy func(string) Policy) {
supportedPolicies[name] = policy supportedPolicies[name] = policy
} }
......
...@@ -106,7 +106,7 @@ func TestSelect(t *testing.T) { ...@@ -106,7 +106,7 @@ func TestSelect(t *testing.T) {
func TestRegisterPolicy(t *testing.T) { func TestRegisterPolicy(t *testing.T) {
name := "custom" name := "custom"
customPolicy := &customPolicy{} customPolicy := &customPolicy{}
RegisterPolicy(name, func() Policy { return customPolicy }) RegisterPolicy(name, func(string) Policy { return customPolicy })
if _, ok := supportedPolicies[name]; !ok { if _, ok := supportedPolicies[name]; !ok {
t.Error("Expected supportedPolicies to have a custom policy.") t.Error("Expected supportedPolicies to have a custom policy.")
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment