Commit 463c9d9d authored by Augusto Roman's avatar Augusto Roman Committed by Matt Holt

Fix data race for max connection limiting in proxy directive. (#1438)

* Fix data race for max connection limiting in proxy directive.

The Conns and Unhealthy fields are updated concurrently across all active
requests.  Because of this, they must use atomic operations for reads and
writes.

Prior to this change, Conns was incremented atomically, but read unsafely.
Unhealthly was updated & read unsafely.  The new test
TestReverseProxyMaxConnLimit exposes this race when run with -race.

Switching to atomic operations makes the race detector happy.

* oops, remove leftover dead code.
parent 1bd9e9e5
...@@ -60,13 +60,13 @@ func TestRoundRobinPolicy(t *testing.T) { ...@@ -60,13 +60,13 @@ func TestRoundRobinPolicy(t *testing.T) {
t.Error("Expected third round robin host to be first host in the pool.") t.Error("Expected third round robin host to be first host in the pool.")
} }
// mark host as down // mark host as down
pool[1].Unhealthy = true pool[1].Unhealthy = 1
h = rrPolicy.Select(pool, request) h = rrPolicy.Select(pool, request)
if h != pool[2] { if h != pool[2] {
t.Error("Expected to skip down host.") t.Error("Expected to skip down host.")
} }
// mark host as up // mark host as up
pool[1].Unhealthy = false pool[1].Unhealthy = 0
h = rrPolicy.Select(pool, request) h = rrPolicy.Select(pool, request)
if h == pool[2] { if h == pool[2] {
...@@ -161,7 +161,7 @@ func TestIPHashPolicy(t *testing.T) { ...@@ -161,7 +161,7 @@ func TestIPHashPolicy(t *testing.T) {
// we should get a healthy host if the original host is unhealthy and a // we should get a healthy host if the original host is unhealthy and a
// healthy host is available // healthy host is available
request.RemoteAddr = "172.0.0.1" request.RemoteAddr = "172.0.0.1"
pool[1].Unhealthy = true pool[1].Unhealthy = 1
h = ipHash.Select(pool, request) h = ipHash.Select(pool, request)
if h != pool[2] { if h != pool[2] {
t.Error("Expected ip hash policy host to be the third host.") t.Error("Expected ip hash policy host to be the third host.")
...@@ -172,10 +172,10 @@ func TestIPHashPolicy(t *testing.T) { ...@@ -172,10 +172,10 @@ func TestIPHashPolicy(t *testing.T) {
if h != pool[2] { if h != pool[2] {
t.Error("Expected ip hash policy host to be the third host.") t.Error("Expected ip hash policy host to be the third host.")
} }
pool[1].Unhealthy = false pool[1].Unhealthy = 0
request.RemoteAddr = "172.0.0.3" request.RemoteAddr = "172.0.0.3"
pool[2].Unhealthy = true pool[2].Unhealthy = 1
h = ipHash.Select(pool, request) h = ipHash.Select(pool, request)
if h != pool[0] { if h != pool[0] {
t.Error("Expected ip hash policy host to be the first host.") t.Error("Expected ip hash policy host to be the first host.")
...@@ -219,8 +219,8 @@ func TestIPHashPolicy(t *testing.T) { ...@@ -219,8 +219,8 @@ func TestIPHashPolicy(t *testing.T) {
} }
// We should get nil when there are no healthy hosts // We should get nil when there are no healthy hosts
pool[0].Unhealthy = true pool[0].Unhealthy = 1
pool[1].Unhealthy = true pool[1].Unhealthy = 1
h = ipHash.Select(pool, request) h = ipHash.Select(pool, request)
if h != nil { if h != nil {
t.Error("Expected ip hash policy host to be nil.") t.Error("Expected ip hash policy host to be nil.")
......
...@@ -49,6 +49,8 @@ type UpstreamHostDownFunc func(*UpstreamHost) bool ...@@ -49,6 +49,8 @@ type UpstreamHostDownFunc func(*UpstreamHost) bool
// UpstreamHost represents a single proxy upstream // UpstreamHost represents a single proxy upstream
type UpstreamHost struct { type UpstreamHost struct {
// This field is read & written to concurrently, so all access must use
// atomic operations.
Conns int64 // must be first field to be 64-bit aligned on 32-bit systems Conns int64 // must be first field to be 64-bit aligned on 32-bit systems
MaxConns int64 MaxConns int64
Name string // hostname of this upstream host Name string // hostname of this upstream host
...@@ -59,7 +61,10 @@ type UpstreamHost struct { ...@@ -59,7 +61,10 @@ type UpstreamHost struct {
WithoutPathPrefix string WithoutPathPrefix string
ReverseProxy *ReverseProxy ReverseProxy *ReverseProxy
Fails int32 Fails int32
Unhealthy bool // This is an int32 so that we can use atomic operations to do concurrent
// reads & writes to this value. The default value of 0 indicates that it
// is healthy and any non-zero value indicates unhealthy.
Unhealthy int32
} }
// Down checks whether the upstream host is down or not. // Down checks whether the upstream host is down or not.
...@@ -68,14 +73,14 @@ type UpstreamHost struct { ...@@ -68,14 +73,14 @@ type UpstreamHost struct {
func (uh *UpstreamHost) Down() bool { func (uh *UpstreamHost) Down() bool {
if uh.CheckDown == nil { if uh.CheckDown == nil {
// Default settings // Default settings
return uh.Unhealthy || uh.Fails > 0 return atomic.LoadInt32(&uh.Unhealthy) != 0 || atomic.LoadInt32(&uh.Fails) > 0
} }
return uh.CheckDown(uh) return uh.CheckDown(uh)
} }
// Full checks whether the upstream host has reached its maximum connections // Full checks whether the upstream host has reached its maximum connections
func (uh *UpstreamHost) Full() bool { func (uh *UpstreamHost) Full() bool {
return uh.MaxConns > 0 && uh.Conns >= uh.MaxConns return uh.MaxConns > 0 && atomic.LoadInt64(&uh.Conns) >= uh.MaxConns
} }
// Available checks whether the upstream host is available for proxying to // Available checks whether the upstream host is available for proxying to
......
...@@ -19,6 +19,7 @@ import ( ...@@ -19,6 +19,7 @@ import (
"reflect" "reflect"
"runtime" "runtime"
"strings" "strings"
"sync"
"sync/atomic" "sync/atomic"
"testing" "testing"
"time" "time"
...@@ -143,6 +144,74 @@ func TestReverseProxyInsecureSkipVerify(t *testing.T) { ...@@ -143,6 +144,74 @@ func TestReverseProxyInsecureSkipVerify(t *testing.T) {
} }
} }
// This test will fail when using the race detector without atomic reads &
// writes of UpstreamHost.Conns and UpstreamHost.Unhealthy.
func TestReverseProxyMaxConnLimit(t *testing.T) {
log.SetOutput(ioutil.Discard)
defer log.SetOutput(os.Stderr)
const MaxTestConns = 2
connReceived := make(chan bool, MaxTestConns)
connContinue := make(chan bool)
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
connReceived <- true
<-connContinue
}))
defer backend.Close()
su, err := NewStaticUpstreams(caddyfile.NewDispenser("Testfile", strings.NewReader(`
proxy / `+backend.URL+` {
max_conns `+fmt.Sprint(MaxTestConns)+`
}
`)))
if err != nil {
t.Fatal(err)
}
// set up proxy
p := &Proxy{
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
Upstreams: su,
}
var jobs sync.WaitGroup
for i := 0; i < MaxTestConns; i++ {
jobs.Add(1)
go func(i int) {
defer jobs.Done()
w := httptest.NewRecorder()
code, err := p.ServeHTTP(w, httptest.NewRequest("GET", "/", nil))
if err != nil {
t.Errorf("Request %d failed: %v", i, err)
} else if code != 0 {
t.Errorf("Bad return code for request %d: %d", i, code)
} else if w.Code != 200 {
t.Errorf("Bad statuc code for request %d: %d", i, w.Code)
}
}(i)
}
// Wait for all the requests to hit the backend.
for i := 0; i < MaxTestConns; i++ {
<-connReceived
}
// Now we should have MaxTestConns requests connected and sitting on the backend
// server. Verify that the next request is rejected.
w := httptest.NewRecorder()
code, err := p.ServeHTTP(w, httptest.NewRequest("GET", "/", nil))
if code != http.StatusBadGateway {
t.Errorf("Expected request to be rejected, but got: %d [%v]\nStatus code: %d",
code, err, w.Code)
}
// Now let all the requests complete and verify the status codes for those:
close(connContinue)
// Wait for the initial requests to finish and check their results.
jobs.Wait()
}
func TestWebSocketReverseProxyNonHijackerPanic(t *testing.T) { func TestWebSocketReverseProxyNonHijackerPanic(t *testing.T) {
// Capture the expected panic // Capture the expected panic
defer func() { defer func() {
......
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
"path" "path"
"strconv" "strconv"
"strings" "strings"
"sync/atomic"
"time" "time"
"github.com/mholt/caddy/caddyfile" "github.com/mholt/caddy/caddyfile"
...@@ -128,15 +129,15 @@ func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) { ...@@ -128,15 +129,15 @@ func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) {
Conns: 0, Conns: 0,
Fails: 0, Fails: 0,
FailTimeout: u.FailTimeout, FailTimeout: u.FailTimeout,
Unhealthy: false, Unhealthy: 0,
UpstreamHeaders: u.upstreamHeaders, UpstreamHeaders: u.upstreamHeaders,
DownstreamHeaders: u.downstreamHeaders, DownstreamHeaders: u.downstreamHeaders,
CheckDown: func(u *staticUpstream) UpstreamHostDownFunc { CheckDown: func(u *staticUpstream) UpstreamHostDownFunc {
return func(uh *UpstreamHost) bool { return func(uh *UpstreamHost) bool {
if uh.Unhealthy { if atomic.LoadInt32(&uh.Unhealthy) != 0 {
return true return true
} }
if uh.Fails >= u.MaxFails { if atomic.LoadInt32(&uh.Fails) >= u.MaxFails {
return true return true
} }
return false return false
...@@ -355,12 +356,18 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream) error { ...@@ -355,12 +356,18 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream) error {
func (u *staticUpstream) healthCheck() { func (u *staticUpstream) healthCheck() {
for _, host := range u.Hosts { for _, host := range u.Hosts {
hostURL := host.Name + u.HealthCheck.Path hostURL := host.Name + u.HealthCheck.Path
var unhealthy bool
if r, err := u.HealthCheck.Client.Get(hostURL); err == nil { if r, err := u.HealthCheck.Client.Get(hostURL); err == nil {
io.Copy(ioutil.Discard, r.Body) io.Copy(ioutil.Discard, r.Body)
r.Body.Close() r.Body.Close()
host.Unhealthy = r.StatusCode < 200 || r.StatusCode >= 400 unhealthy = r.StatusCode < 200 || r.StatusCode >= 400
} else { } else {
host.Unhealthy = true unhealthy = true
}
if unhealthy {
atomic.StoreInt32(&host.Unhealthy, 1)
} else {
atomic.StoreInt32(&host.Unhealthy, 0)
} }
} }
} }
......
...@@ -36,12 +36,12 @@ func TestNewHost(t *testing.T) { ...@@ -36,12 +36,12 @@ func TestNewHost(t *testing.T) {
t.Error("Expected new host not to be down.") t.Error("Expected new host not to be down.")
} }
// mark Unhealthy // mark Unhealthy
uh.Unhealthy = true uh.Unhealthy = 1
if !uh.CheckDown(uh) { if !uh.CheckDown(uh) {
t.Error("Expected unhealthy host to be down.") t.Error("Expected unhealthy host to be down.")
} }
// mark with Fails // mark with Fails
uh.Unhealthy = false uh.Unhealthy = 0
uh.Fails = 1 uh.Fails = 1
if !uh.CheckDown(uh) { if !uh.CheckDown(uh) {
t.Error("Expected failed host to be down.") t.Error("Expected failed host to be down.")
...@@ -74,13 +74,13 @@ func TestSelect(t *testing.T) { ...@@ -74,13 +74,13 @@ func TestSelect(t *testing.T) {
MaxFails: 1, MaxFails: 1,
} }
r, _ := http.NewRequest("GET", "/", nil) r, _ := http.NewRequest("GET", "/", nil)
upstream.Hosts[0].Unhealthy = true upstream.Hosts[0].Unhealthy = 1
upstream.Hosts[1].Unhealthy = true upstream.Hosts[1].Unhealthy = 1
upstream.Hosts[2].Unhealthy = true upstream.Hosts[2].Unhealthy = 1
if h := upstream.Select(r); h != nil { if h := upstream.Select(r); h != nil {
t.Error("Expected select to return nil as all host are down") t.Error("Expected select to return nil as all host are down")
} }
upstream.Hosts[2].Unhealthy = false upstream.Hosts[2].Unhealthy = 0
if h := upstream.Select(r); h == nil { if h := upstream.Select(r); h == nil {
t.Error("Expected select to not return nil") t.Error("Expected select to not return nil")
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment