Commit c972ea39 authored by Mateusz Gajewski's avatar Mateusz Gajewski Committed by Matt Holt

Fastcgi upstreams (#1264)

* Make fastcgi load balanceable too

* Address one more corner case - invalid configuration fastcgi /

* After review fixes

* Simplify conditions

* Error message

* New fastcgi syntax

* golint will be happy

* Change syntax
parent 12fd3499
package fastcgi
import "sync"
import (
"errors"
"sync"
"sync/atomic"
)
type dialer interface {
Dial() (*FCGIClient, error)
Close(*FCGIClient) error
Dial() (Client, error)
Close(Client) error
}
// basicDialer is a basic dialer that wraps default fcgi functions.
......@@ -12,8 +16,8 @@ type basicDialer struct {
network, address string
}
func (b basicDialer) Dial() (*FCGIClient, error) { return Dial(b.network, b.address) }
func (b basicDialer) Close(c *FCGIClient) error { return c.Close() }
func (b basicDialer) Dial() (Client, error) { return Dial(b.network, b.address) }
func (b basicDialer) Close(c Client) error { return c.Close() }
// persistentDialer keeps a pool of fcgi connections.
// connections are not closed after use, rather added back to the pool for reuse.
......@@ -21,11 +25,11 @@ type persistentDialer struct {
size int
network string
address string
pool []*FCGIClient
pool []Client
sync.Mutex
}
func (p *persistentDialer) Dial() (*FCGIClient, error) {
func (p *persistentDialer) Dial() (Client, error) {
p.Lock()
// connection is available, return first one.
if len(p.pool) > 0 {
......@@ -42,7 +46,7 @@ func (p *persistentDialer) Dial() (*FCGIClient, error) {
return Dial(p.network, p.address)
}
func (p *persistentDialer) Close(client *FCGIClient) error {
func (p *persistentDialer) Close(client Client) error {
p.Lock()
if len(p.pool) < p.size {
// pool is not full yet, add connection for reuse
......@@ -57,3 +61,35 @@ func (p *persistentDialer) Close(client *FCGIClient) error {
// otherwise, close the connection.
return client.Close()
}
type loadBalancingDialer struct {
dialers []dialer
current int64
}
func (m *loadBalancingDialer) Dial() (Client, error) {
nextDialerIndex := atomic.AddInt64(&m.current, 1) % int64(len(m.dialers))
currentDialer := m.dialers[nextDialerIndex]
client, err := currentDialer.Dial()
if err != nil {
return nil, err
}
return &dialerAwareClient{Client: client, dialer: currentDialer}, nil
}
func (m *loadBalancingDialer) Close(c Client) error {
// Close the client according to dialer behaviour
if da, ok := c.(*dialerAwareClient); ok {
return da.dialer.Close(c)
}
return errors.New("Cannot close client")
}
type dialerAwareClient struct {
Client
dialer dialer
}
package fastcgi
import (
"errors"
"testing"
)
func TestLoadbalancingDialer(t *testing.T) {
// given
runs := 100
mockDialer1 := new(mockDialer)
mockDialer2 := new(mockDialer)
dialer := &loadBalancingDialer{dialers: []dialer{mockDialer1, mockDialer2}}
// when
for i := 0; i < runs; i++ {
client, err := dialer.Dial()
dialer.Close(client)
if err != nil {
t.Errorf("Expected error to be nil")
}
}
// then
if mockDialer1.dialCalled != mockDialer2.dialCalled && mockDialer1.dialCalled != 50 {
t.Errorf("Expected dialer to call Dial() on multiple backend dialers %d times [actual: %d, %d]", 50, mockDialer1.dialCalled, mockDialer2.dialCalled)
}
if mockDialer1.closeCalled != mockDialer2.closeCalled && mockDialer1.closeCalled != 50 {
t.Errorf("Expected dialer to call Close() on multiple backend dialers %d times [actual: %d, %d]", 50, mockDialer1.closeCalled, mockDialer2.closeCalled)
}
}
func TestLoadBalancingDialerShouldReturnDialerAwareClient(t *testing.T) {
// given
mockDialer1 := new(mockDialer)
dialer := &loadBalancingDialer{dialers: []dialer{mockDialer1}}
// when
client, err := dialer.Dial()
// then
if err != nil {
t.Errorf("Expected error to be nil")
}
if awareClient, ok := client.(*dialerAwareClient); !ok {
t.Error("Expected dialer to wrap client")
} else {
if awareClient.dialer != mockDialer1 {
t.Error("Expected wrapped client to have reference to dialer")
}
}
}
func TestLoadBalancingDialerShouldUnderlyingReturnDialerError(t *testing.T) {
// given
mockDialer1 := new(errorReturningDialer)
dialer := &loadBalancingDialer{dialers: []dialer{mockDialer1}}
// when
_, err := dialer.Dial()
// then
if err.Error() != "Error during dial" {
t.Errorf("Expected 'Error during dial', got: '%s'", err.Error())
}
}
func TestLoadBalancingDialerShouldCloseClient(t *testing.T) {
// given
mockDialer1 := new(mockDialer)
mockDialer2 := new(mockDialer)
dialer := &loadBalancingDialer{dialers: []dialer{mockDialer1, mockDialer2}}
client, _ := dialer.Dial()
// when
err := dialer.Close(client)
// then
if err != nil {
t.Error("Expected error not to occur")
}
// load balancing starts from index 1
if mockDialer2.client != client {
t.Errorf("Expected Close() to be called on referenced dialer")
}
}
type mockDialer struct {
dialCalled int
closeCalled int
client Client
}
type mockClient struct {
Client
}
func (m *mockDialer) Dial() (Client, error) {
m.dialCalled++
return mockClient{Client: &FCGIClient{}}, nil
}
func (m *mockDialer) Close(c Client) error {
m.client = c
m.closeCalled++
return nil
}
type errorReturningDialer struct {
client Client
}
func (m *errorReturningDialer) Dial() (Client, error) {
return mockClient{Client: &FCGIClient{}}, errors.New("Error during dial")
}
func (m *errorReturningDialer) Close(c Client) error {
m.client = c
return errors.New("Error during close")
}
......@@ -111,9 +111,9 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error)
defer rule.dialer.Close(fcgiBackend)
// Log any stderr output from upstream
if fcgiBackend.stderr.Len() != 0 {
if stderr := fcgiBackend.StdErr(); stderr.Len() != 0 {
// Remove trailing newline, error logger already does this.
err = LogError(strings.TrimSuffix(fcgiBackend.stderr.String(), "\n"))
err = LogError(strings.TrimSuffix(stderr.String(), "\n"))
}
// Normally we would return the status code if it is an error status (>= 400),
......
......@@ -106,6 +106,16 @@ const (
maxPad = 255
)
// Client interface
type Client interface {
Get(pair map[string]string) (response *http.Response, err error)
Head(pair map[string]string) (response *http.Response, err error)
Options(pairs map[string]string) (response *http.Response, err error)
Post(pairs map[string]string, method string, bodyType string, body io.Reader, contentLength int) (response *http.Response, err error)
Close() error
StdErr() bytes.Buffer
}
type header struct {
Version uint8
Type uint8
......@@ -197,22 +207,29 @@ func (c *FCGIClient) Close() error {
return c.rwc.Close()
}
func (c *FCGIClient) writeRecord(recType uint8, content []byte) (err error) {
func (c *FCGIClient) writeRecord(recType uint8, content []byte) error {
c.mutex.Lock()
defer c.mutex.Unlock()
c.buf.Reset()
c.h.init(recType, c.reqID, len(content))
if err := binary.Write(&c.buf, binary.BigEndian, c.h); err != nil {
return err
}
if _, err := c.buf.Write(content); err != nil {
return err
}
if _, err := c.buf.Write(pad[:c.h.PaddingLength]); err != nil {
return err
}
_, err = c.rwc.Write(c.buf.Bytes())
return err
if _, err := c.rwc.Write(c.buf.Bytes()); err != nil {
return err
}
return nil
}
func (c *FCGIClient) writeBeginRequest(role uint16, flags uint8) error {
......@@ -360,6 +377,11 @@ func (w *streamReader) Read(p []byte) (n int, err error) {
return
}
// StdErr returns stderr stream
func (c *FCGIClient) StdErr() bytes.Buffer {
return c.stderr
}
// Do made the request and returns a io.Reader that translates the data read
// from fcgi responder out of fcgi packet before returning it.
func (c *FCGIClient) Do(p map[string]string, req io.Reader) (r io.Reader, err error) {
......
......@@ -5,6 +5,7 @@ import (
"net/http"
"path/filepath"
"strconv"
"strings"
"github.com/mholt/caddy"
"github.com/mholt/caddy/caddyhttp/httpserver"
......@@ -55,26 +56,21 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) {
args := c.RemainingArgs()
switch len(args) {
case 0:
if len(args) < 2 || len(args) > 3 {
return rules, c.ArgErr()
case 1:
rule.Path = "/"
rule.Address = args[0]
case 2:
rule.Path = args[0]
rule.Address = args[1]
case 3:
rule.Path = args[0]
rule.Address = args[1]
err := fastcgiPreset(args[2], &rule)
if err != nil {
return rules, c.Err("Invalid fastcgi rule preset '" + args[2] + "'")
}
rule.Path = args[0]
upstreams := []string{args[1]}
if len(args) == 3 {
if err := fastcgiPreset(args[2], &rule); err != nil {
return rules, err
}
}
network, address := parseAddress(rule.Address)
rule.dialer = basicDialer{network: network, address: address}
var dialers []dialer
var poolSize = -1
for c.NextBlock() {
switch c.Val() {
......@@ -94,6 +90,15 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) {
return rules, c.ArgErr()
}
rule.IndexFiles = args
case "upstream":
args := c.RemainingArgs()
if len(args) != 1 {
return rules, c.ArgErr()
}
upstreams = append(upstreams, args[0])
case "env":
envArgs := c.RemainingArgs()
if len(envArgs) < 2 {
......@@ -106,6 +111,7 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) {
return rules, c.ArgErr()
}
rule.IgnoredSubPaths = ignoredPaths
case "pool":
if !c.NextArg() {
return rules, c.ArgErr()
......@@ -115,13 +121,24 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) {
return rules, err
}
if pool >= 0 {
rule.dialer = &persistentDialer{size: pool, network: network, address: address}
poolSize = pool
} else {
return rules, c.Errf("positive integer expected, found %d", pool)
}
}
}
for _, rawAddress := range upstreams {
network, address := parseAddress(rawAddress)
if poolSize >= 0 {
dialers = append(dialers, &persistentDialer{size: poolSize, network: network, address: address})
} else {
dialers = append(dialers, basicDialer{network: network, address: address})
}
}
rule.dialer = &loadBalancingDialer{dialers: dialers}
rule.Address = strings.Join(upstreams, ",")
rules = append(rules, rule)
}
......
......@@ -76,9 +76,31 @@ func TestFastcgiParse(t *testing.T) {
Address: "127.0.0.1:9000",
Ext: ".php",
SplitPath: ".php",
dialer: basicDialer{network: "tcp", address: "127.0.0.1:9000"},
dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: "tcp", address: "127.0.0.1:9000"}}},
IndexFiles: []string{"index.php"},
}}},
{`fastcgi /blog 127.0.0.1:9000 php {
upstream 127.0.0.1:9001
}`,
false, []Rule{{
Path: "/blog",
Address: "127.0.0.1:9000,127.0.0.1:9001",
Ext: ".php",
SplitPath: ".php",
dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: "tcp", address: "127.0.0.1:9000"}, basicDialer{network: "tcp", address: "127.0.0.1:9001"}}},
IndexFiles: []string{"index.php"},
}}},
{`fastcgi /blog 127.0.0.1:9000 {
upstream 127.0.0.1:9001
}`,
false, []Rule{{
Path: "/blog",
Address: "127.0.0.1:9000,127.0.0.1:9001",
Ext: "",
SplitPath: "",
dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: "tcp", address: "127.0.0.1:9000"}, basicDialer{network: "tcp", address: "127.0.0.1:9001"}}},
IndexFiles: []string{},
}}},
{`fastcgi / ` + defaultAddress + ` {
split .html
}`,
......@@ -87,7 +109,7 @@ func TestFastcgiParse(t *testing.T) {
Address: defaultAddress,
Ext: "",
SplitPath: ".html",
dialer: basicDialer{network: network, address: address},
dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: network, address: address}}},
IndexFiles: []string{},
}}},
{`fastcgi / ` + defaultAddress + ` {
......@@ -99,7 +121,7 @@ func TestFastcgiParse(t *testing.T) {
Address: "127.0.0.1:9001",
Ext: "",
SplitPath: ".html",
dialer: basicDialer{network: network, address: address},
dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: network, address: address}}},
IndexFiles: []string{},
IgnoredSubPaths: []string{"/admin", "/user"},
}}},
......@@ -111,18 +133,19 @@ func TestFastcgiParse(t *testing.T) {
Address: defaultAddress,
Ext: "",
SplitPath: "",
dialer: &persistentDialer{size: 0, network: network, address: address},
dialer: &loadBalancingDialer{dialers: []dialer{&persistentDialer{size: 0, network: network, address: address}}},
IndexFiles: []string{},
}}},
{`fastcgi / ` + defaultAddress + ` {
{`fastcgi / 127.0.0.1:8080 {
upstream 127.0.0.1:9000
pool 5
}`,
false, []Rule{{
Path: "/",
Address: defaultAddress,
Address: "127.0.0.1:8080,127.0.0.1:9000",
Ext: "",
SplitPath: "",
dialer: &persistentDialer{size: 5, network: network, address: address},
dialer: &loadBalancingDialer{dialers: []dialer{&persistentDialer{size: 5, network: "tcp", address: "127.0.0.1:8080"}, &persistentDialer{size: 5, network: "tcp", address: "127.0.0.1:9000"}}},
IndexFiles: []string{},
}}},
{`fastcgi / ` + defaultAddress + ` {
......@@ -133,9 +156,14 @@ func TestFastcgiParse(t *testing.T) {
Address: defaultAddress,
Ext: "",
SplitPath: ".php",
dialer: basicDialer{network: network, address: address},
dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: network, address: address}}},
IndexFiles: []string{},
}}},
{`fastcgi / {
}`,
true, []Rule{},
},
}
for i, test := range tests {
actualFastcgiConfigs, err := fastcgiParse(caddy.NewTestController("http", test.inputFastcgiConfig))
......@@ -175,20 +203,7 @@ func TestFastcgiParse(t *testing.T) {
t.Errorf("Test %d expected %dth FastCGI dialer to be of type %T, but got %T",
i, j, test.expectedFastcgiConfig[j].dialer, actualFastcgiConfig.dialer)
} else {
equal := true
switch actual := actualFastcgiConfig.dialer.(type) {
case basicDialer:
equal = actualFastcgiConfig.dialer == test.expectedFastcgiConfig[j].dialer
case *persistentDialer:
if expected, ok := test.expectedFastcgiConfig[j].dialer.(*persistentDialer); ok {
equal = actual.Equals(expected)
} else {
equal = false
}
default:
t.Errorf("Unkonw dialer type %T", actualFastcgiConfig.dialer)
}
if !equal {
if !areDialersEqual(actualFastcgiConfig.dialer, test.expectedFastcgiConfig[j].dialer, t) {
t.Errorf("Test %d expected %dth FastCGI dialer to be %v, but got %v",
i, j, test.expectedFastcgiConfig[j].dialer, actualFastcgiConfig.dialer)
}
......@@ -205,5 +220,31 @@ func TestFastcgiParse(t *testing.T) {
}
}
}
}
func areDialersEqual(current, expected dialer, t *testing.T) bool {
switch actual := current.(type) {
case *loadBalancingDialer:
if expected, ok := expected.(*loadBalancingDialer); ok {
for i := 0; i < len(actual.dialers); i++ {
if !areDialersEqual(actual.dialers[i], expected.dialers[i], t) {
return false
}
}
return true
}
case basicDialer:
return current == expected
case *persistentDialer:
if expected, ok := expected.(*persistentDialer); ok {
return actual.Equals(expected)
}
default:
t.Errorf("Unknown dialer type %T", current)
}
return false
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment