Commit 463c9d9d authored by Augusto Roman, committed by Matt Holt

Fix data race for max connection limiting in proxy directive. (#1438)

* Fix data race for max connection limiting in proxy directive.

The Conns and Unhealthy fields are updated concurrently across all active
requests.  Because of this, they must use atomic operations for reads and
writes.

Prior to this change, Conns was incremented atomically, but read unsafely.
Unhealthy was updated & read unsafely.  The new test
TestReverseProxyMaxConnLimit exposes this race when run with -race.

Switching to atomic operations makes the race detector happy (see the sketch below).

* oops, remove leftover dead code.
parent 1bd9e9e5
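
For context, here is a minimal, self-contained sketch of the pattern the commit message describes; the host type and main function below are illustrative stand-ins, not Caddy's actual UpstreamHost or proxy code. Every read and write of the shared fields goes through sync/atomic, so there is no unsynchronized access for the race detector to flag.

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// host is a simplified stand-in for UpstreamHost, keeping only the fields
// that are shared across goroutines.
type host struct {
	// Conns is read & written concurrently; it stays the first field so it
	// is 64-bit aligned on 32-bit platforms, as sync/atomic requires.
	Conns     int64
	MaxConns  int64
	Unhealthy int32 // 0 = healthy, any non-zero value = unhealthy
}

// full is the safe equivalent of a plain h.Conns >= h.MaxConns check:
// the counter is loaded atomically instead of being read directly.
func (h *host) full() bool {
	return h.MaxConns > 0 && atomic.LoadInt64(&h.Conns) >= h.MaxConns
}

func main() {
	h := &host{MaxConns: 2}

	var wg sync.WaitGroup
	for i := 0; i < 8; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			if h.full() || atomic.LoadInt32(&h.Unhealthy) != 0 {
				return // reject: host is saturated or marked down
			}
			atomic.AddInt64(&h.Conns, 1)        // connection starts
			defer atomic.AddInt64(&h.Conns, -1) // connection ends
			// ... proxy the request here ...
		}()
	}

	// A health checker running concurrently flips the flag atomically.
	atomic.StoreInt32(&h.Unhealthy, 1)
	wg.Wait()
	fmt.Println("open connections:", atomic.LoadInt64(&h.Conns))
}
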
@@ -60,13 +60,13 @@ func TestRoundRobinPolicy(t *testing.T) {
t.Error("Expected third round robin host to be first host in the pool.")
}
// mark host as down
pool[1].Unhealthy = true
pool[1].Unhealthy = 1
h = rrPolicy.Select(pool, request)
if h != pool[2] {
t.Error("Expected to skip down host.")
}
// mark host as up
pool[1].Unhealthy = false
pool[1].Unhealthy = 0
h = rrPolicy.Select(pool, request)
if h == pool[2] {
@@ -161,7 +161,7 @@ func TestIPHashPolicy(t *testing.T) {
// we should get a healthy host if the original host is unhealthy and a
// healthy host is available
request.RemoteAddr = "172.0.0.1"
pool[1].Unhealthy = true
pool[1].Unhealthy = 1
h = ipHash.Select(pool, request)
if h != pool[2] {
t.Error("Expected ip hash policy host to be the third host.")
@@ -172,10 +172,10 @@ func TestIPHashPolicy(t *testing.T) {
if h != pool[2] {
t.Error("Expected ip hash policy host to be the third host.")
}
pool[1].Unhealthy = false
pool[1].Unhealthy = 0
request.RemoteAddr = "172.0.0.3"
pool[2].Unhealthy = true
pool[2].Unhealthy = 1
h = ipHash.Select(pool, request)
if h != pool[0] {
t.Error("Expected ip hash policy host to be the first host.")
@@ -219,8 +219,8 @@ func TestIPHashPolicy(t *testing.T) {
}
// We should get nil when there are no healthy hosts
pool[0].Unhealthy = true
pool[1].Unhealthy = true
pool[0].Unhealthy = 1
pool[1].Unhealthy = 1
h = ipHash.Select(pool, request)
if h != nil {
t.Error("Expected ip hash policy host to be nil.")
@@ -49,6 +49,8 @@ type UpstreamHostDownFunc func(*UpstreamHost) bool
// UpstreamHost represents a single proxy upstream
type UpstreamHost struct {
// This field is read & written to concurrently, so all access must use
// atomic operations.
Conns int64 // must be first field to be 64-bit aligned on 32-bit systems
MaxConns int64
Name string // hostname of this upstream host
@@ -59,7 +61,10 @@ type UpstreamHost struct {
WithoutPathPrefix string
ReverseProxy *ReverseProxy
Fails int32
Unhealthy bool
// This is an int32 so that we can use atomic operations to do concurrent
// reads & writes to this value. The default value of 0 indicates that it
// is healthy and any non-zero value indicates unhealthy.
Unhealthy int32
}
// Down checks whether the upstream host is down or not.
@@ -68,14 +73,14 @@ type UpstreamHost struct {
func (uh *UpstreamHost) Down() bool {
if uh.CheckDown == nil {
// Default settings
return uh.Unhealthy || uh.Fails > 0
return atomic.LoadInt32(&uh.Unhealthy) != 0 || atomic.LoadInt32(&uh.Fails) > 0
}
return uh.CheckDown(uh)
}
// Full checks whether the upstream host has reached its maximum connections
func (uh *UpstreamHost) Full() bool {
return uh.MaxConns > 0 && uh.Conns >= uh.MaxConns
return uh.MaxConns > 0 && atomic.LoadInt64(&uh.Conns) >= uh.MaxConns
}
// Available checks whether the upstream host is available for proxying to
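
The "must be first field" note above reflects a documented limitation of sync/atomic: on 32-bit platforms, 64-bit atomic operations require 64-bit alignment, and only the first word of an allocated struct can be relied on to have it. A small illustrative sketch of the two layouts (the type names are hypothetical, not from the Caddy source):

package main

import (
	"fmt"
	"sync/atomic"
	"unsafe"
)

// aligned mirrors the layout the diff enforces: the int64 counter is the
// first field, so the sync/atomic alignment guarantee applies even on
// 32-bit platforms.
type aligned struct {
	Conns     int64
	Unhealthy int32
}

// misaligned is the layout being avoided: on a 32-bit platform an int64
// that follows an int32 can start on a 4-byte boundary, and a 64-bit
// atomic operation on it can panic at runtime.
type misaligned struct {
	Unhealthy int32
	Conns     int64
}

func main() {
	var a aligned
	atomic.AddInt64(&a.Conns, 1) // safe: first word of the struct is 64-bit aligned

	var m misaligned
	// Prints 8 when built for 64-bit targets, typically 4 on 32-bit ones.
	fmt.Println("offset of misaligned.Conns:", unsafe.Offsetof(m.Conns))
}
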
@@ -19,6 +19,7 @@ import (
"reflect"
"runtime"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
@@ -143,6 +144,74 @@ func TestReverseProxyInsecureSkipVerify(t *testing.T) {
}
}
// This test will fail when using the race detector without atomic reads &
// writes of UpstreamHost.Conns and UpstreamHost.Unhealthy.
func TestReverseProxyMaxConnLimit(t *testing.T) {
log.SetOutput(ioutil.Discard)
defer log.SetOutput(os.Stderr)
const MaxTestConns = 2
connReceived := make(chan bool, MaxTestConns)
connContinue := make(chan bool)
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
connReceived <- true
<-connContinue
}))
defer backend.Close()
su, err := NewStaticUpstreams(caddyfile.NewDispenser("Testfile", strings.NewReader(`
proxy / `+backend.URL+` {
max_conns `+fmt.Sprint(MaxTestConns)+`
}
`)))
if err != nil {
t.Fatal(err)
}
// set up proxy
p := &Proxy{
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
Upstreams: su,
}
var jobs sync.WaitGroup
for i := 0; i < MaxTestConns; i++ {
jobs.Add(1)
go func(i int) {
defer jobs.Done()
w := httptest.NewRecorder()
code, err := p.ServeHTTP(w, httptest.NewRequest("GET", "/", nil))
if err != nil {
t.Errorf("Request %d failed: %v", i, err)
} else if code != 0 {
t.Errorf("Bad return code for request %d: %d", i, code)
} else if w.Code != 200 {
t.Errorf("Bad statuc code for request %d: %d", i, w.Code)
}
}(i)
}
// Wait for all the requests to hit the backend.
for i := 0; i < MaxTestConns; i++ {
<-connReceived
}
// Now we should have MaxTestConns requests connected and sitting on the backend
// server. Verify that the next request is rejected.
w := httptest.NewRecorder()
code, err := p.ServeHTTP(w, httptest.NewRequest("GET", "/", nil))
if code != http.StatusBadGateway {
t.Errorf("Expected request to be rejected, but got: %d [%v]\nStatus code: %d",
code, err, w.Code)
}
// Now let all the requests complete and verify the status codes for those:
close(connContinue)
// Wait for the initial requests to finish and check their results.
jobs.Wait()
}
func TestWebSocketReverseProxyNonHijackerPanic(t *testing.T) {
// Capture the expected panic
defer func() {
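
As the comment on TestReverseProxyMaxConnLimit notes, the data race only surfaces under the race detector; running something like "go test -race -run TestReverseProxyMaxConnLimit ./caddyhttp/proxy" (the package path is assumed from the repository layout at the time) fails without the atomic reads and writes and passes with them in place.
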
@@ -9,6 +9,7 @@ import (
"path"
"strconv"
"strings"
"sync/atomic"
"time"
"github.com/mholt/caddy/caddyfile"
@@ -128,15 +129,15 @@ func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) {
Conns: 0,
Fails: 0,
FailTimeout: u.FailTimeout,
Unhealthy: false,
Unhealthy: 0,
UpstreamHeaders: u.upstreamHeaders,
DownstreamHeaders: u.downstreamHeaders,
CheckDown: func(u *staticUpstream) UpstreamHostDownFunc {
return func(uh *UpstreamHost) bool {
if uh.Unhealthy {
if atomic.LoadInt32(&uh.Unhealthy) != 0 {
return true
}
if uh.Fails >= u.MaxFails {
if atomic.LoadInt32(&uh.Fails) >= u.MaxFails {
return true
}
return false
@@ -355,12 +356,18 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream) error {
func (u *staticUpstream) healthCheck() {
for _, host := range u.Hosts {
hostURL := host.Name + u.HealthCheck.Path
var unhealthy bool
if r, err := u.HealthCheck.Client.Get(hostURL); err == nil {
io.Copy(ioutil.Discard, r.Body)
r.Body.Close()
host.Unhealthy = r.StatusCode < 200 || r.StatusCode >= 400
unhealthy = r.StatusCode < 200 || r.StatusCode >= 400
} else {
host.Unhealthy = true
unhealthy = true
}
if unhealthy {
atomic.StoreInt32(&host.Unhealthy, 1)
} else {
atomic.StoreInt32(&host.Unhealthy, 0)
}
}
}
@@ -36,12 +36,12 @@ func TestNewHost(t *testing.T) {
t.Error("Expected new host not to be down.")
}
// mark Unhealthy
uh.Unhealthy = true
uh.Unhealthy = 1
if !uh.CheckDown(uh) {
t.Error("Expected unhealthy host to be down.")
}
// mark with Fails
uh.Unhealthy = false
uh.Unhealthy = 0
uh.Fails = 1
if !uh.CheckDown(uh) {
t.Error("Expected failed host to be down.")
@@ -74,13 +74,13 @@ func TestSelect(t *testing.T) {
MaxFails: 1,
}
r, _ := http.NewRequest("GET", "/", nil)
upstream.Hosts[0].Unhealthy = true
upstream.Hosts[1].Unhealthy = true
upstream.Hosts[2].Unhealthy = true
upstream.Hosts[0].Unhealthy = 1
upstream.Hosts[1].Unhealthy = 1
upstream.Hosts[2].Unhealthy = 1
if h := upstream.Select(r); h != nil {
t.Error("Expected select to return nil as all host are down")
}
upstream.Hosts[2].Unhealthy = false
upstream.Hosts[2].Unhealthy = 0
if h := upstream.Select(r); h == nil {
t.Error("Expected select to not return nil")
}