Commit 6aba4a31 authored by Abiola Ibrahim's avatar Abiola Ibrahim Committed by Matt Holt

fastcgi: Revert persistent connections (#1739)

* Revert fastcgi to emove persistent connections.

* Fix linting errors

* reintroduce timeout tests

* check for non-zero timeout

* ensure resp is not nil
parent 56153e0b
package fastcgi
import (
"errors"
"sync"
"sync/atomic"
"time"
)
type dialer interface {
Dial() (Client, error)
Close(Client) error
}
// basicDialer is a basic dialer that wraps default fcgi functions.
type basicDialer struct {
network string
address string
timeout time.Duration
}
func (b basicDialer) Dial() (Client, error) {
return DialTimeout(b.network, b.address, b.timeout)
}
func (b basicDialer) Close(c Client) error { return c.Close() }
// persistentDialer keeps a pool of fcgi connections.
// connections are not closed after use, rather added back to the pool for reuse.
type persistentDialer struct {
size int
network string
address string
timeout time.Duration
pool []Client
sync.Mutex
}
func (p *persistentDialer) Dial() (Client, error) {
p.Lock()
// connection is available, return first one.
if len(p.pool) > 0 {
client := p.pool[0]
p.pool = p.pool[1:]
p.Unlock()
return client, nil
}
p.Unlock()
// no connection available, create new one
return DialTimeout(p.network, p.address, p.timeout)
}
func (p *persistentDialer) Close(client Client) error {
p.Lock()
if len(p.pool) < p.size {
// pool is not full yet, add connection for reuse
p.pool = append(p.pool, client)
p.Unlock()
return nil
}
p.Unlock()
// otherwise, close the connection.
return client.Close()
}
type loadBalancingDialer struct {
current int64
dialers []dialer
}
func (m *loadBalancingDialer) Dial() (Client, error) {
nextDialerIndex := atomic.AddInt64(&m.current, 1) % int64(len(m.dialers))
currentDialer := m.dialers[nextDialerIndex]
client, err := currentDialer.Dial()
if err != nil {
return nil, err
}
return &dialerAwareClient{Client: client, dialer: currentDialer}, nil
}
func (m *loadBalancingDialer) Close(c Client) error {
// Close the client according to dialer behaviour
if da, ok := c.(*dialerAwareClient); ok {
return da.dialer.Close(c)
}
return errors.New("Cannot close client")
}
type dialerAwareClient struct {
Client
dialer dialer
}
package fastcgi
import (
"errors"
"testing"
)
func TestLoadbalancingDialer(t *testing.T) {
// given
runs := 100
mockDialer1 := new(mockDialer)
mockDialer2 := new(mockDialer)
dialer := &loadBalancingDialer{dialers: []dialer{mockDialer1, mockDialer2}}
// when
for i := 0; i < runs; i++ {
client, err := dialer.Dial()
dialer.Close(client)
if err != nil {
t.Errorf("Expected error to be nil")
}
}
// then
if mockDialer1.dialCalled != mockDialer2.dialCalled && mockDialer1.dialCalled != 50 {
t.Errorf("Expected dialer to call Dial() on multiple backend dialers %d times [actual: %d, %d]", 50, mockDialer1.dialCalled, mockDialer2.dialCalled)
}
if mockDialer1.closeCalled != mockDialer2.closeCalled && mockDialer1.closeCalled != 50 {
t.Errorf("Expected dialer to call Close() on multiple backend dialers %d times [actual: %d, %d]", 50, mockDialer1.closeCalled, mockDialer2.closeCalled)
}
}
func TestLoadBalancingDialerShouldReturnDialerAwareClient(t *testing.T) {
// given
mockDialer1 := new(mockDialer)
dialer := &loadBalancingDialer{dialers: []dialer{mockDialer1}}
// when
client, err := dialer.Dial()
// then
if err != nil {
t.Errorf("Expected error to be nil")
}
if awareClient, ok := client.(*dialerAwareClient); !ok {
t.Error("Expected dialer to wrap client")
} else {
if awareClient.dialer != mockDialer1 {
t.Error("Expected wrapped client to have reference to dialer")
}
}
}
func TestLoadBalancingDialerShouldUnderlyingReturnDialerError(t *testing.T) {
// given
mockDialer1 := new(errorReturningDialer)
dialer := &loadBalancingDialer{dialers: []dialer{mockDialer1}}
// when
_, err := dialer.Dial()
// then
if err.Error() != "Error during dial" {
t.Errorf("Expected 'Error during dial', got: '%s'", err.Error())
}
}
func TestLoadBalancingDialerShouldCloseClient(t *testing.T) {
// given
mockDialer1 := new(mockDialer)
mockDialer2 := new(mockDialer)
dialer := &loadBalancingDialer{dialers: []dialer{mockDialer1, mockDialer2}}
client, _ := dialer.Dial()
// when
err := dialer.Close(client)
// then
if err != nil {
t.Error("Expected error not to occur")
}
// load balancing starts from index 1
if mockDialer2.client != client {
t.Errorf("Expected Close() to be called on referenced dialer")
}
}
type mockDialer struct {
dialCalled int
closeCalled int
client Client
}
type mockClient struct {
Client
}
func (m *mockDialer) Dial() (Client, error) {
m.dialCalled++
return mockClient{Client: &FCGIClient{}}, nil
}
func (m *mockDialer) Close(c Client) error {
m.client = c
m.closeCalled++
return nil
}
type errorReturningDialer struct {
client Client
}
func (m *errorReturningDialer) Dial() (Client, error) {
return mockClient{Client: &FCGIClient{}}, errors.New("Error during dial")
}
func (m *errorReturningDialer) Close(c Client) error {
m.client = c
return errors.New("Error during close")
}
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
package fastcgi package fastcgi
import ( import (
"context"
"errors" "errors"
"io" "io"
"net" "net"
...@@ -14,6 +15,7 @@ import ( ...@@ -14,6 +15,7 @@ import (
"path/filepath" "path/filepath"
"strconv" "strconv"
"strings" "strings"
"sync/atomic"
"time" "time"
"github.com/mholt/caddy/caddyhttp/httpserver" "github.com/mholt/caddy/caddyhttp/httpserver"
...@@ -90,16 +92,28 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) ...@@ -90,16 +92,28 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error)
} }
// Connect to FastCGI gateway // Connect to FastCGI gateway
fcgiBackend, err := rule.dialer.Dial() network, address := parseAddress(rule.Address())
if err != nil {
if err, ok := err.(net.Error); ok && err.Timeout() { ctx := context.Background()
return http.StatusGatewayTimeout, err if rule.ConnectTimeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, rule.ConnectTimeout)
defer cancel()
} }
fcgiBackend, err := DialContext(ctx, network, address)
if err != nil {
return http.StatusBadGateway, err return http.StatusBadGateway, err
} }
defer fcgiBackend.Close() defer fcgiBackend.Close()
fcgiBackend.SetReadTimeout(rule.ReadTimeout)
fcgiBackend.SetSendTimeout(rule.SendTimeout) // read/write timeouts
if err := fcgiBackend.SetReadTimeout(rule.ReadTimeout); err != nil {
return http.StatusInternalServerError, err
}
if err := fcgiBackend.SetSendTimeout(rule.SendTimeout); err != nil {
return http.StatusInternalServerError, err
}
var resp *http.Response var resp *http.Response
...@@ -121,6 +135,10 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) ...@@ -121,6 +135,10 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error)
resp, err = fcgiBackend.Post(env, r.Method, r.Header.Get("Content-Type"), r.Body, contentLength) resp, err = fcgiBackend.Post(env, r.Method, r.Header.Get("Content-Type"), r.Body, contentLength)
} }
if resp != nil && resp.Body != nil {
defer resp.Body.Close()
}
if err != nil { if err != nil {
if err, ok := err.(net.Error); ok && err.Timeout() { if err, ok := err.(net.Error); ok && err.Timeout() {
return http.StatusGatewayTimeout, err return http.StatusGatewayTimeout, err
...@@ -139,9 +157,9 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) ...@@ -139,9 +157,9 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error)
} }
// Log any stderr output from upstream // Log any stderr output from upstream
if stderr := fcgiBackend.StdErr(); stderr.Len() != 0 { if fcgiBackend.stderr.Len() != 0 {
// Remove trailing newline, error logger already does this. // Remove trailing newline, error logger already does this.
err = LogError(strings.TrimSuffix(stderr.String(), "\n")) err = LogError(strings.TrimSuffix(fcgiBackend.stderr.String(), "\n"))
} }
// Normally we would return the status code if it is an error status (>= 400), // Normally we would return the status code if it is an error status (>= 400),
...@@ -303,8 +321,8 @@ type Rule struct { ...@@ -303,8 +321,8 @@ type Rule struct {
// The base path to match. Required. // The base path to match. Required.
Path string Path string
// The address of the FastCGI server. Required. // upstream load balancer
Address string balancer
// Always process files with this extension with fastcgi. // Always process files with this extension with fastcgi.
Ext string Ext string
...@@ -329,14 +347,32 @@ type Rule struct { ...@@ -329,14 +347,32 @@ type Rule struct {
// Ignored paths // Ignored paths
IgnoredSubPaths []string IgnoredSubPaths []string
// The duration used to set a deadline when connecting to an upstream.
ConnectTimeout time.Duration
// The duration used to set a deadline when reading from the FastCGI server. // The duration used to set a deadline when reading from the FastCGI server.
ReadTimeout time.Duration ReadTimeout time.Duration
// The duration used to set a deadline when sending to the FastCGI server. // The duration used to set a deadline when sending to the FastCGI server.
SendTimeout time.Duration SendTimeout time.Duration
}
// balancer is a fastcgi upstream load balancer.
type balancer interface {
// Address picks an upstream address from the
// underlying load balancer.
Address() string
}
// roundRobin is a round robin balancer for fastcgi upstreams.
type roundRobin struct {
addresses []string
index int64
}
// FCGI dialer func (r *roundRobin) Address() string {
dialer dialer index := atomic.AddInt64(&r.index, 1) % int64(len(r.addresses))
return r.addresses[index]
} }
// canSplit checks if path can split into two based on rule.SplitPath. // canSplit checks if path can split into two based on rule.SplitPath.
......
...@@ -29,16 +29,9 @@ func TestServeHTTP(t *testing.T) { ...@@ -29,16 +29,9 @@ func TestServeHTTP(t *testing.T) {
w.Write([]byte(body)) w.Write([]byte(body))
})) }))
network, address := parseAddress(listener.Addr().String())
handler := Handler{ handler := Handler{
Next: nil, Next: nil,
Rules: []Rule{ Rules: []Rule{{Path: "/", balancer: address(listener.Addr().String())}},
{
Path: "/",
Address: listener.Addr().String(),
dialer: basicDialer{network: network, address: address},
},
},
} }
r, err := http.NewRequest("GET", "/", nil) r, err := http.NewRequest("GET", "/", nil)
if err != nil { if err != nil {
...@@ -62,122 +55,25 @@ func TestServeHTTP(t *testing.T) { ...@@ -62,122 +55,25 @@ func TestServeHTTP(t *testing.T) {
} }
} }
// connectionCounter in fact is a listener with an added counter to keep track
// of the number of accepted connections.
type connectionCounter struct {
net.Listener
sync.Mutex
counter int
}
func (l *connectionCounter) Accept() (net.Conn, error) {
l.Lock()
l.counter++
l.Unlock()
return l.Listener.Accept()
}
// TestPersistent ensures that persistent
// as well as the non-persistent fastCGI servers
// send the answers corresnponding to the correct request.
// It also checks the number of tcp connections used.
func TestPersistent(t *testing.T) {
numberOfRequests := 32
for _, poolsize := range []int{0, 1, 5, numberOfRequests} {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Unable to create listener for test: %v", err)
}
listener := &connectionCounter{l, *new(sync.Mutex), 0}
// this fcgi server replies with the request URL
go fcgi.Serve(listener, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body := "This answers a request to " + r.URL.Path
bodyLenStr := strconv.Itoa(len(body))
w.Header().Set("Content-Length", bodyLenStr)
w.Write([]byte(body))
}))
network, address := parseAddress(listener.Addr().String())
handler := Handler{
Next: nil,
Rules: []Rule{{Path: "/", Address: listener.Addr().String(), dialer: &persistentDialer{size: poolsize, network: network, address: address}}},
}
var semaphore sync.WaitGroup
serialMutex := new(sync.Mutex)
serialCounter := 0
parallelCounter := 0
// make some serial followed by some
// parallel requests to challenge the handler
for _, serialize := range []bool{true, false, false, false} {
if serialize {
serialCounter++
} else {
parallelCounter++
}
semaphore.Add(numberOfRequests)
for i := 0; i < numberOfRequests; i++ {
go func(i int, serialize bool) {
defer semaphore.Done()
if serialize {
serialMutex.Lock()
defer serialMutex.Unlock()
}
r, err := http.NewRequest("GET", "/"+strconv.Itoa(i), nil)
if err != nil {
t.Errorf("Unable to create request: %v", err)
}
ctx := context.WithValue(r.Context(), httpserver.OriginalURLCtxKey, *r.URL)
r = r.WithContext(ctx)
w := httptest.NewRecorder()
status, err := handler.ServeHTTP(w, r)
if status != 0 {
t.Errorf("Handler(pool: %v) return status %v", poolsize, status)
}
if err != nil {
t.Errorf("Handler(pool: %v) Error: %v", poolsize, err)
}
want := "This answers a request to /" + strconv.Itoa(i)
if got := w.Body.String(); got != want {
t.Errorf("Expected response from handler(pool: %v) to be '%s', got: '%s'", poolsize, want, got)
}
}(i, serialize)
} //next request
semaphore.Wait()
} // next set of requests (serial/parallel)
listener.Close()
t.Logf("The pool: %v test used %v tcp connections to answer %v * %v serial and %v * %v parallel requests.", poolsize, listener.counter, serialCounter, numberOfRequests, parallelCounter, numberOfRequests)
} // next handler (persistent/non-persistent)
}
func TestRuleParseAddress(t *testing.T) { func TestRuleParseAddress(t *testing.T) {
getClientTestTable := []struct { getClientTestTable := []struct {
rule *Rule rule *Rule
expectednetwork string expectednetwork string
expectedaddress string expectedaddress string
}{ }{
{&Rule{Address: "tcp://172.17.0.1:9000"}, "tcp", "172.17.0.1:9000"}, {&Rule{balancer: address("tcp://172.17.0.1:9000")}, "tcp", "172.17.0.1:9000"},
{&Rule{Address: "fastcgi://localhost:9000"}, "tcp", "localhost:9000"}, {&Rule{balancer: address("fastcgi://localhost:9000")}, "tcp", "localhost:9000"},
{&Rule{Address: "172.17.0.15"}, "tcp", "172.17.0.15"}, {&Rule{balancer: address("172.17.0.15")}, "tcp", "172.17.0.15"},
{&Rule{Address: "/my/unix/socket"}, "unix", "/my/unix/socket"}, {&Rule{balancer: address("/my/unix/socket")}, "unix", "/my/unix/socket"},
{&Rule{Address: "unix:/second/unix/socket"}, "unix", "/second/unix/socket"}, {&Rule{balancer: address("unix:/second/unix/socket")}, "unix", "/second/unix/socket"},
} }
for _, entry := range getClientTestTable { for _, entry := range getClientTestTable {
if actualnetwork, _ := parseAddress(entry.rule.Address); actualnetwork != entry.expectednetwork { if actualnetwork, _ := parseAddress(entry.rule.Address()); actualnetwork != entry.expectednetwork {
t.Errorf("Unexpected network for address string %v. Got %v, expected %v", entry.rule.Address, actualnetwork, entry.expectednetwork) t.Errorf("Unexpected network for address string %v. Got %v, expected %v", entry.rule.Address(), actualnetwork, entry.expectednetwork)
} }
if _, actualaddress := parseAddress(entry.rule.Address); actualaddress != entry.expectedaddress { if _, actualaddress := parseAddress(entry.rule.Address()); actualaddress != entry.expectedaddress {
t.Errorf("Unexpected parsed address for address string %v. Got %v, expected %v", entry.rule.Address, actualaddress, entry.expectedaddress) t.Errorf("Unexpected parsed address for address string %v. Got %v, expected %v", entry.rule.Address(), actualaddress, entry.expectedaddress)
} }
} }
} }
...@@ -332,14 +228,12 @@ func TestReadTimeout(t *testing.T) { ...@@ -332,14 +228,12 @@ func TestReadTimeout(t *testing.T) {
} }
defer listener.Close() defer listener.Close()
network, address := parseAddress(listener.Addr().String())
handler := Handler{ handler := Handler{
Next: nil, Next: nil,
Rules: []Rule{ Rules: []Rule{
{ {
Path: "/", Path: "/",
Address: listener.Addr().String(), balancer: address(listener.Addr().String()),
dialer: basicDialer{network: network, address: address},
ReadTimeout: test.readTimeout, ReadTimeout: test.readTimeout,
}, },
}, },
...@@ -394,14 +288,12 @@ func TestSendTimeout(t *testing.T) { ...@@ -394,14 +288,12 @@ func TestSendTimeout(t *testing.T) {
} }
defer listener.Close() defer listener.Close()
network, address := parseAddress(listener.Addr().String())
handler := Handler{ handler := Handler{
Next: nil, Next: nil,
Rules: []Rule{ Rules: []Rule{
{ {
Path: "/", Path: "/",
Address: listener.Addr().String(), balancer: address(listener.Addr().String()),
dialer: basicDialer{network: network, address: address},
SendTimeout: test.sendTimeout, SendTimeout: test.sendTimeout,
}, },
}, },
...@@ -434,3 +326,28 @@ func TestSendTimeout(t *testing.T) { ...@@ -434,3 +326,28 @@ func TestSendTimeout(t *testing.T) {
} }
} }
} }
func TestBalancer(t *testing.T) {
tests := [][]string{
{"localhost", "host.local"},
{"localhost"},
{"localhost", "host.local", "example.com"},
{"localhost", "host.local", "example.com", "127.0.0.1"},
}
for i, test := range tests {
b := address(test...)
for _, host := range test {
a := b.Address()
if a != host {
t.Errorf("Test %d: expected %s, found %s", i, host, a)
}
}
}
}
func address(addresses ...string) balancer {
return &roundRobin{
addresses: addresses,
index: -1,
}
}
...@@ -13,6 +13,7 @@ package fastcgi ...@@ -13,6 +13,7 @@ package fastcgi
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"context"
"encoding/binary" "encoding/binary"
"errors" "errors"
"io" "io"
...@@ -107,18 +108,6 @@ const ( ...@@ -107,18 +108,6 @@ const (
maxPad = 255 maxPad = 255
) )
// Client interface
type Client interface {
Get(pair map[string]string) (response *http.Response, err error)
Head(pair map[string]string) (response *http.Response, err error)
Options(pairs map[string]string) (response *http.Response, err error)
Post(pairs map[string]string, method string, bodyType string, body io.Reader, contentLength int64) (response *http.Response, err error)
Close() error
StdErr() bytes.Buffer
SetReadTimeout(time.Duration) error
SetSendTimeout(time.Duration) error
}
type header struct { type header struct {
Version uint8 Version uint8
Type uint8 Type uint8
...@@ -150,7 +139,7 @@ func (rec *record) read(r io.Reader) (buf []byte, err error) { ...@@ -150,7 +139,7 @@ func (rec *record) read(r io.Reader) (buf []byte, err error) {
return return
} }
if rec.h.Version != 1 { if rec.h.Version != 1 {
err = errInvalidHeaderVersion err = errors.New("fcgi: invalid header version")
return return
} }
if rec.h.Type == EndRequest { if rec.h.Type == EndRequest {
...@@ -173,7 +162,7 @@ func (rec *record) read(r io.Reader) (buf []byte, err error) { ...@@ -173,7 +162,7 @@ func (rec *record) read(r io.Reader) (buf []byte, err error) {
// interfacing external applications with Web servers. // interfacing external applications with Web servers.
type FCGIClient struct { type FCGIClient struct {
mutex sync.Mutex mutex sync.Mutex
conn net.Conn rwc io.ReadWriteCloser
h header h header
buf bytes.Buffer buf bytes.Buffer
stderr bytes.Buffer stderr bytes.Buffer
...@@ -183,53 +172,57 @@ type FCGIClient struct { ...@@ -183,53 +172,57 @@ type FCGIClient struct {
sendTimeout time.Duration sendTimeout time.Duration
} }
// DialTimeout connects to the fcgi responder at the specified network address, using default net.Dialer. // DialWithDialerContext connects to the fcgi responder at the specified network address, using custom net.Dialer
// and a context.
// See func net.Dial for a description of the network and address parameters. // See func net.Dial for a description of the network and address parameters.
func DialTimeout(network string, address string, timeout time.Duration) (fcgi *FCGIClient, err error) { func DialWithDialerContext(ctx context.Context, network, address string, dialer net.Dialer) (fcgi *FCGIClient, err error) {
conn, err := net.DialTimeout(network, address, timeout) var conn net.Conn
conn, err = dialer.DialContext(ctx, network, address)
if err != nil { if err != nil {
return return
} }
fcgi = &FCGIClient{conn: conn, keepAlive: false, reqID: 1} fcgi = &FCGIClient{
rwc: conn,
keepAlive: false,
reqID: 1,
}
return
}
// DialContext is like Dial but passes ctx to dialer.Dial.
func DialContext(ctx context.Context, network, address string) (fcgi *FCGIClient, err error) {
return DialWithDialerContext(ctx, network, address, net.Dialer{})
}
return fcgi, nil // Dial connects to the fcgi responder at the specified network address, using default net.Dialer.
// See func net.Dial for a description of the network and address parameters.
func Dial(network, address string) (fcgi *FCGIClient, err error) {
return DialContext(context.Background(), network, address)
} }
// Close closes fcgi connnection. // Close closes fcgi connnection
func (c *FCGIClient) Close() error { func (c *FCGIClient) Close() {
return c.conn.Close() c.rwc.Close()
} }
func (c *FCGIClient) writeRecord(recType uint8, content []byte) error { func (c *FCGIClient) writeRecord(recType uint8, content []byte) (err error) {
c.mutex.Lock() c.mutex.Lock()
defer c.mutex.Unlock() defer c.mutex.Unlock()
c.buf.Reset() c.buf.Reset()
c.h.init(recType, c.reqID, len(content)) c.h.init(recType, c.reqID, len(content))
if err := binary.Write(&c.buf, binary.BigEndian, c.h); err != nil { if err := binary.Write(&c.buf, binary.BigEndian, c.h); err != nil {
return err return err
} }
if _, err := c.buf.Write(content); err != nil { if _, err := c.buf.Write(content); err != nil {
return err return err
} }
if _, err := c.buf.Write(pad[:c.h.PaddingLength]); err != nil { if _, err := c.buf.Write(pad[:c.h.PaddingLength]); err != nil {
return err return err
} }
_, err = c.rwc.Write(c.buf.Bytes())
if c.sendTimeout != 0 {
if err := c.conn.SetWriteDeadline(time.Now().Add(c.sendTimeout)); err != nil {
return err
}
}
if _, err := c.conn.Write(c.buf.Bytes()); err != nil {
return err return err
}
return nil
} }
func (c *FCGIClient) writeBeginRequest(role uint16, flags uint8) error { func (c *FCGIClient) writeBeginRequest(role uint16, flags uint8) error {
...@@ -345,14 +338,13 @@ func (w *streamReader) Read(p []byte) (n int, err error) { ...@@ -345,14 +338,13 @@ func (w *streamReader) Read(p []byte) (n int, err error) {
if len(p) > 0 { if len(p) > 0 {
if len(w.buf) == 0 { if len(w.buf) == 0 {
// filter outputs for error log // filter outputs for error log
for { for {
rec := &record{} rec := &record{}
var buf []byte var buf []byte
buf, err = rec.read(w.c.conn) buf, err = rec.read(w.c.rwc)
if err == errInvalidHeaderVersion { if err != nil {
continue
} else if err != nil {
return return
} }
// standard error output // standard error output
...@@ -376,15 +368,10 @@ func (w *streamReader) Read(p []byte) (n int, err error) { ...@@ -376,15 +368,10 @@ func (w *streamReader) Read(p []byte) (n int, err error) {
return return
} }
// StdErr returns stderr stream
func (c *FCGIClient) StdErr() bytes.Buffer {
return c.stderr
}
// Do made the request and returns a io.Reader that translates the data read // Do made the request and returns a io.Reader that translates the data read
// from fcgi responder out of fcgi packet before returning it. // from fcgi responder out of fcgi packet before returning it.
func (c *FCGIClient) Do(p map[string]string, req io.Reader) (r io.Reader, err error) { func (c *FCGIClient) Do(p map[string]string, req io.Reader) (r io.Reader, err error) {
err = c.writeBeginRequest(uint16(Responder), FCGIKeepConn) err = c.writeBeginRequest(uint16(Responder), 0)
if err != nil { if err != nil {
return return
} }
...@@ -407,11 +394,11 @@ func (c *FCGIClient) Do(p map[string]string, req io.Reader) (r io.Reader, err er ...@@ -407,11 +394,11 @@ func (c *FCGIClient) Do(p map[string]string, req io.Reader) (r io.Reader, err er
// clientCloser is a io.ReadCloser. It wraps a io.Reader with a Closer // clientCloser is a io.ReadCloser. It wraps a io.Reader with a Closer
// that closes FCGIClient connection. // that closes FCGIClient connection.
type clientCloser struct { type clientCloser struct {
f *FCGIClient *FCGIClient
io.Reader io.Reader
} }
func (c clientCloser) Close() error { return c.f.Close() } func (f clientCloser) Close() error { return f.rwc.Close() }
// Request returns a HTTP Response with Header and Body // Request returns a HTTP Response with Header and Body
// from fcgi responder // from fcgi responder
...@@ -425,12 +412,6 @@ func (c *FCGIClient) Request(p map[string]string, req io.Reader) (resp *http.Res ...@@ -425,12 +412,6 @@ func (c *FCGIClient) Request(p map[string]string, req io.Reader) (resp *http.Res
tp := textproto.NewReader(rb) tp := textproto.NewReader(rb)
resp = new(http.Response) resp = new(http.Response)
if c.readTimeout != 0 {
if err = c.conn.SetReadDeadline(time.Now().Add(c.readTimeout)); err != nil {
return
}
}
// Parse the response headers. // Parse the response headers.
mimeHeader, err := tp.ReadMIMEHeader() mimeHeader, err := tp.ReadMIMEHeader()
if err != nil && err != io.EOF { if err != nil && err != io.EOF {
...@@ -566,18 +547,20 @@ func (c *FCGIClient) PostFile(p map[string]string, data url.Values, file map[str ...@@ -566,18 +547,20 @@ func (c *FCGIClient) PostFile(p map[string]string, data url.Values, file map[str
// SetReadTimeout sets the read timeout for future calls that read from the // SetReadTimeout sets the read timeout for future calls that read from the
// fcgi responder. A zero value for t means no timeout will be set. // fcgi responder. A zero value for t means no timeout will be set.
func (c *FCGIClient) SetReadTimeout(t time.Duration) error { func (c *FCGIClient) SetReadTimeout(t time.Duration) error {
c.readTimeout = t if conn, ok := c.rwc.(net.Conn); ok && t != 0 {
return conn.SetReadDeadline(time.Now().Add(t))
}
return nil return nil
} }
// SetSendTimeout sets the read timeout for future calls that send data to // SetSendTimeout sets the read timeout for future calls that send data to
// the fcgi responder. A zero value for t means no timeout will be set. // the fcgi responder. A zero value for t means no timeout will be set.
func (c *FCGIClient) SetSendTimeout(t time.Duration) error { func (c *FCGIClient) SetSendTimeout(t time.Duration) error {
c.sendTimeout = t if conn, ok := c.rwc.(net.Conn); ok && t != 0 {
return conn.SetWriteDeadline(time.Now().Add(t))
}
return nil return nil
} }
// Checks whether chunked is part of the encodings stack // Checks whether chunked is part of the encodings stack
func chunked(te []string) bool { return len(te) > 0 && te[0] == "chunked" } func chunked(te []string) bool { return len(te) > 0 && te[0] == "chunked" }
var errInvalidHeaderVersion = errors.New("fcgi: invalid header version")
...@@ -103,7 +103,7 @@ func (s FastCGIServer) ServeHTTP(resp http.ResponseWriter, req *http.Request) { ...@@ -103,7 +103,7 @@ func (s FastCGIServer) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
} }
func sendFcgi(reqType int, fcgiParams map[string]string, data []byte, posts map[string]string, files map[string]string) (content []byte) { func sendFcgi(reqType int, fcgiParams map[string]string, data []byte, posts map[string]string, files map[string]string) (content []byte) {
fcgi, err := DialTimeout("tcp", ipPort, 0) fcgi, err := Dial("tcp", ipPort)
if err != nil { if err != nil {
log.Println("err:", err) log.Println("err:", err)
return return
...@@ -155,7 +155,7 @@ func sendFcgi(reqType int, fcgiParams map[string]string, data []byte, posts map[ ...@@ -155,7 +155,7 @@ func sendFcgi(reqType int, fcgiParams map[string]string, data []byte, posts map[
fcgi.Close() fcgi.Close()
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
if bytes.Contains(content, []byte("FAILED")) { if bytes.Index(content, []byte("FAILED")) >= 0 {
globalt.Error("Server return failed message") globalt.Error("Server return failed message")
} }
......
...@@ -4,8 +4,6 @@ import ( ...@@ -4,8 +4,6 @@ import (
"errors" "errors"
"net/http" "net/http"
"path/filepath" "path/filepath"
"strconv"
"strings"
"time" "time"
"github.com/mholt/caddy" "github.com/mholt/caddy"
...@@ -63,8 +61,6 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) { ...@@ -63,8 +61,6 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) {
rule := Rule{ rule := Rule{
Root: absRoot, Root: absRoot,
Path: args[0], Path: args[0],
ReadTimeout: 60 * time.Second,
SendTimeout: 60 * time.Second,
} }
upstreams := []string{args[1]} upstreams := []string{args[1]}
...@@ -75,10 +71,6 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) { ...@@ -75,10 +71,6 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) {
} }
var err error var err error
var pool int
var connectTimeout = 60 * time.Second
var dialers []dialer
var poolSize = -1
for c.NextBlock() { for c.NextBlock() {
switch c.Val() { switch c.Val() {
...@@ -126,24 +118,11 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) { ...@@ -126,24 +118,11 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) {
} }
rule.IgnoredSubPaths = ignoredPaths rule.IgnoredSubPaths = ignoredPaths
case "pool":
if !c.NextArg() {
return rules, c.ArgErr()
}
pool, err = strconv.Atoi(c.Val())
if err != nil {
return rules, err
}
if pool >= 0 {
poolSize = pool
} else {
return rules, c.Errf("positive integer expected, found %d", pool)
}
case "connect_timeout": case "connect_timeout":
if !c.NextArg() { if !c.NextArg() {
return rules, c.ArgErr() return rules, c.ArgErr()
} }
connectTimeout, err = time.ParseDuration(c.Val()) rule.ConnectTimeout, err = time.ParseDuration(c.Val())
if err != nil { if err != nil {
return rules, err return rules, err
} }
...@@ -168,29 +147,10 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) { ...@@ -168,29 +147,10 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) {
} }
} }
for _, rawAddress := range upstreams { rule.balancer = &roundRobin{addresses: upstreams, index: -1}
network, address := parseAddress(rawAddress)
if poolSize >= 0 {
dialers = append(dialers, &persistentDialer{
size: poolSize,
network: network,
address: address,
timeout: connectTimeout,
})
} else {
dialers = append(dialers, basicDialer{
network: network,
address: address,
timeout: connectTimeout,
})
}
}
rule.dialer = &loadBalancingDialer{dialers: dialers}
rule.Address = strings.Join(upstreams, ",")
rules = append(rules, rule) rules = append(rules, rule)
} }
return rules, nil return rules, nil
} }
......
...@@ -2,10 +2,7 @@ package fastcgi ...@@ -2,10 +2,7 @@ package fastcgi
import ( import (
"fmt" "fmt"
"os"
"reflect"
"testing" "testing"
"time"
"github.com/mholt/caddy" "github.com/mholt/caddy"
"github.com/mholt/caddy/caddyhttp/httpserver" "github.com/mholt/caddy/caddyhttp/httpserver"
...@@ -32,45 +29,13 @@ func TestSetup(t *testing.T) { ...@@ -32,45 +29,13 @@ func TestSetup(t *testing.T) {
if myHandler.Rules[0].Path != "/" { if myHandler.Rules[0].Path != "/" {
t.Errorf("Expected / as the Path") t.Errorf("Expected / as the Path")
} }
if myHandler.Rules[0].Address != "127.0.0.1:9000" { if myHandler.Rules[0].Address() != "127.0.0.1:9000" {
t.Errorf("Expected 127.0.0.1:9000 as the Address") t.Errorf("Expected 127.0.0.1:9000 as the Address")
} }
} }
func (p *persistentDialer) Equals(q *persistentDialer) bool {
if p.size != q.size {
return false
}
if p.network != q.network {
return false
}
if p.address != q.address {
return false
}
if len(p.pool) != len(q.pool) {
return false
}
for i, client := range p.pool {
if client != q.pool[i] {
return false
}
}
// ignore mutex state
return true
}
func TestFastcgiParse(t *testing.T) { func TestFastcgiParse(t *testing.T) {
rootPath, err := os.Getwd()
if err != nil {
t.Errorf("Can't determine current working directory; got '%v'", err)
}
defaultAddress := "127.0.0.1:9001"
network, address := parseAddress(defaultAddress)
t.Logf("Address '%v' was parsed to network '%v' and address '%v'", defaultAddress, network, address)
tests := []struct { tests := []struct {
inputFastcgiConfig string inputFastcgiConfig string
shouldErr bool shouldErr bool
...@@ -79,193 +44,34 @@ func TestFastcgiParse(t *testing.T) { ...@@ -79,193 +44,34 @@ func TestFastcgiParse(t *testing.T) {
{`fastcgi /blog 127.0.0.1:9000 php`, {`fastcgi /blog 127.0.0.1:9000 php`,
false, []Rule{{ false, []Rule{{
Root: rootPath,
Path: "/blog",
Address: "127.0.0.1:9000",
Ext: ".php",
SplitPath: ".php",
dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: "tcp", address: "127.0.0.1:9000", timeout: 60 * time.Second}}},
IndexFiles: []string{"index.php"},
ReadTimeout: 60 * time.Second,
SendTimeout: 60 * time.Second,
}}},
{`fastcgi /blog 127.0.0.1:9000 php {
root /tmp
}`,
false, []Rule{{
Root: "/tmp",
Path: "/blog",
Address: "127.0.0.1:9000",
Ext: ".php",
SplitPath: ".php",
dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: "tcp", address: "127.0.0.1:9000", timeout: 60 * time.Second}}},
IndexFiles: []string{"index.php"},
ReadTimeout: 60 * time.Second,
SendTimeout: 60 * time.Second,
}}},
{`fastcgi /blog 127.0.0.1:9000 php {
upstream 127.0.0.1:9001
}`,
false, []Rule{{
Root: rootPath,
Path: "/blog", Path: "/blog",
Address: "127.0.0.1:9000,127.0.0.1:9001", balancer: &roundRobin{addresses: []string{"127.0.0.1:9000"}},
Ext: ".php", Ext: ".php",
SplitPath: ".php", SplitPath: ".php",
dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: "tcp", address: "127.0.0.1:9000", timeout: 60 * time.Second}, basicDialer{network: "tcp", address: "127.0.0.1:9001", timeout: 60 * time.Second}}},
IndexFiles: []string{"index.php"}, IndexFiles: []string{"index.php"},
ReadTimeout: 60 * time.Second,
SendTimeout: 60 * time.Second,
}}},
{`fastcgi /blog 127.0.0.1:9000 {
upstream 127.0.0.1:9001
}`,
false, []Rule{{
Root: rootPath,
Path: "/blog",
Address: "127.0.0.1:9000,127.0.0.1:9001",
Ext: "",
SplitPath: "",
dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: "tcp", address: "127.0.0.1:9000", timeout: 60 * time.Second}, basicDialer{network: "tcp", address: "127.0.0.1:9001", timeout: 60 * time.Second}}},
IndexFiles: []string{},
ReadTimeout: 60 * time.Second,
SendTimeout: 60 * time.Second,
}}}, }}},
{`fastcgi / ` + defaultAddress + ` { {`fastcgi / 127.0.0.1:9001 {
split .html split .html
}`, }`,
false, []Rule{{ false, []Rule{{
Root: rootPath,
Path: "/", Path: "/",
Address: defaultAddress, balancer: &roundRobin{addresses: []string{"127.0.0.1:9001"}},
Ext: "", Ext: "",
SplitPath: ".html", SplitPath: ".html",
dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: network, address: address, timeout: 60 * time.Second}}},
IndexFiles: []string{}, IndexFiles: []string{},
ReadTimeout: 60 * time.Second,
SendTimeout: 60 * time.Second,
}}}, }}},
{`fastcgi / ` + defaultAddress + ` { {`fastcgi / 127.0.0.1:9001 {
split .html split .html
except /admin /user except /admin /user
}`, }`,
false, []Rule{{ false, []Rule{{
Root: rootPath,
Path: "/", Path: "/",
Address: "127.0.0.1:9001", balancer: &roundRobin{addresses: []string{"127.0.0.1:9001"}},
Ext: "", Ext: "",
SplitPath: ".html", SplitPath: ".html",
dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: network, address: address, timeout: 60 * time.Second}}},
IndexFiles: []string{}, IndexFiles: []string{},
IgnoredSubPaths: []string{"/admin", "/user"}, IgnoredSubPaths: []string{"/admin", "/user"},
ReadTimeout: 60 * time.Second,
SendTimeout: 60 * time.Second,
}}}, }}},
{`fastcgi / ` + defaultAddress + ` {
pool 0
}`,
false, []Rule{{
Root: rootPath,
Path: "/",
Address: defaultAddress,
Ext: "",
SplitPath: "",
dialer: &loadBalancingDialer{dialers: []dialer{&persistentDialer{size: 0, network: network, address: address, timeout: 60 * time.Second}}},
IndexFiles: []string{},
ReadTimeout: 60 * time.Second,
SendTimeout: 60 * time.Second,
}}},
{`fastcgi / 127.0.0.1:8080 {
upstream 127.0.0.1:9000
pool 5
}`,
false, []Rule{{
Root: rootPath,
Path: "/",
Address: "127.0.0.1:8080,127.0.0.1:9000",
Ext: "",
SplitPath: "",
dialer: &loadBalancingDialer{dialers: []dialer{&persistentDialer{size: 5, network: "tcp", address: "127.0.0.1:8080", timeout: 60 * time.Second}, &persistentDialer{size: 5, network: "tcp", address: "127.0.0.1:9000", timeout: 60 * time.Second}}},
IndexFiles: []string{},
ReadTimeout: 60 * time.Second,
SendTimeout: 60 * time.Second,
}}},
{`fastcgi / ` + defaultAddress + ` {
split .php
}`,
false, []Rule{{
Root: rootPath,
Path: "/",
Address: defaultAddress,
Ext: "",
SplitPath: ".php",
dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: network, address: address, timeout: 60 * time.Second}}},
IndexFiles: []string{},
ReadTimeout: 60 * time.Second,
SendTimeout: 60 * time.Second,
}}},
{`fastcgi / ` + defaultAddress + ` {
connect_timeout 5s
}`,
false, []Rule{{
Root: rootPath,
Path: "/",
Address: defaultAddress,
Ext: "",
SplitPath: "",
dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: network, address: address, timeout: 5 * time.Second}}},
IndexFiles: []string{},
ReadTimeout: 60 * time.Second,
SendTimeout: 60 * time.Second,
}}},
{
`fastcgi / ` + defaultAddress + ` { connect_timeout BADVALUE }`,
true,
[]Rule{},
},
{`fastcgi / ` + defaultAddress + ` {
read_timeout 5s
}`,
false, []Rule{{
Root: rootPath,
Path: "/",
Address: defaultAddress,
Ext: "",
SplitPath: "",
dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: network, address: address, timeout: 60 * time.Second}}},
IndexFiles: []string{},
ReadTimeout: 5 * time.Second,
SendTimeout: 60 * time.Second,
}}},
{
`fastcgi / ` + defaultAddress + ` { read_timeout BADVALUE }`,
true,
[]Rule{},
},
{`fastcgi / ` + defaultAddress + ` {
send_timeout 5s
}`,
false, []Rule{{
Root: rootPath,
Path: "/",
Address: defaultAddress,
Ext: "",
SplitPath: "",
dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: network, address: address, timeout: 60 * time.Second}}},
IndexFiles: []string{},
ReadTimeout: 60 * time.Second,
SendTimeout: 5 * time.Second,
}}},
{
`fastcgi / ` + defaultAddress + ` { send_timeout BADVALUE }`,
true,
[]Rule{},
},
{`fastcgi / {
}`,
true, []Rule{},
},
} }
for i, test := range tests { for i, test := range tests {
actualFastcgiConfigs, err := fastcgiParse(caddy.NewTestController("http", test.inputFastcgiConfig)) actualFastcgiConfigs, err := fastcgiParse(caddy.NewTestController("http", test.inputFastcgiConfig))
...@@ -281,19 +87,14 @@ func TestFastcgiParse(t *testing.T) { ...@@ -281,19 +87,14 @@ func TestFastcgiParse(t *testing.T) {
} }
for j, actualFastcgiConfig := range actualFastcgiConfigs { for j, actualFastcgiConfig := range actualFastcgiConfigs {
if actualFastcgiConfig.Root != test.expectedFastcgiConfig[j].Root {
t.Errorf("Test %d expected %dth FastCGI Root to be %s , but got %s",
i, j, test.expectedFastcgiConfig[j].Root, actualFastcgiConfig.Root)
}
if actualFastcgiConfig.Path != test.expectedFastcgiConfig[j].Path { if actualFastcgiConfig.Path != test.expectedFastcgiConfig[j].Path {
t.Errorf("Test %d expected %dth FastCGI Path to be %s , but got %s", t.Errorf("Test %d expected %dth FastCGI Path to be %s , but got %s",
i, j, test.expectedFastcgiConfig[j].Path, actualFastcgiConfig.Path) i, j, test.expectedFastcgiConfig[j].Path, actualFastcgiConfig.Path)
} }
if actualFastcgiConfig.Address != test.expectedFastcgiConfig[j].Address { if actualFastcgiConfig.Address() != test.expectedFastcgiConfig[j].Address() {
t.Errorf("Test %d expected %dth FastCGI Address to be %s , but got %s", t.Errorf("Test %d expected %dth FastCGI Address to be %s , but got %s",
i, j, test.expectedFastcgiConfig[j].Address, actualFastcgiConfig.Address) i, j, test.expectedFastcgiConfig[j].Address(), actualFastcgiConfig.Address())
} }
if actualFastcgiConfig.Ext != test.expectedFastcgiConfig[j].Ext { if actualFastcgiConfig.Ext != test.expectedFastcgiConfig[j].Ext {
...@@ -306,16 +107,6 @@ func TestFastcgiParse(t *testing.T) { ...@@ -306,16 +107,6 @@ func TestFastcgiParse(t *testing.T) {
i, j, test.expectedFastcgiConfig[j].SplitPath, actualFastcgiConfig.SplitPath) i, j, test.expectedFastcgiConfig[j].SplitPath, actualFastcgiConfig.SplitPath)
} }
if reflect.TypeOf(actualFastcgiConfig.dialer) != reflect.TypeOf(test.expectedFastcgiConfig[j].dialer) {
t.Errorf("Test %d expected %dth FastCGI dialer to be of type %T, but got %T",
i, j, test.expectedFastcgiConfig[j].dialer, actualFastcgiConfig.dialer)
} else {
if !areDialersEqual(actualFastcgiConfig.dialer, test.expectedFastcgiConfig[j].dialer, t) {
t.Errorf("Test %d expected %dth FastCGI dialer to be %v, but got %v",
i, j, test.expectedFastcgiConfig[j].dialer, actualFastcgiConfig.dialer)
}
}
if fmt.Sprint(actualFastcgiConfig.IndexFiles) != fmt.Sprint(test.expectedFastcgiConfig[j].IndexFiles) { if fmt.Sprint(actualFastcgiConfig.IndexFiles) != fmt.Sprint(test.expectedFastcgiConfig[j].IndexFiles) {
t.Errorf("Test %d expected %dth FastCGI IndexFiles to be %s , but got %s", t.Errorf("Test %d expected %dth FastCGI IndexFiles to be %s , but got %s",
i, j, test.expectedFastcgiConfig[j].IndexFiles, actualFastcgiConfig.IndexFiles) i, j, test.expectedFastcgiConfig[j].IndexFiles, actualFastcgiConfig.IndexFiles)
...@@ -325,43 +116,7 @@ func TestFastcgiParse(t *testing.T) { ...@@ -325,43 +116,7 @@ func TestFastcgiParse(t *testing.T) {
t.Errorf("Test %d expected %dth FastCGI IgnoredSubPaths to be %s , but got %s", t.Errorf("Test %d expected %dth FastCGI IgnoredSubPaths to be %s , but got %s",
i, j, test.expectedFastcgiConfig[j].IgnoredSubPaths, actualFastcgiConfig.IgnoredSubPaths) i, j, test.expectedFastcgiConfig[j].IgnoredSubPaths, actualFastcgiConfig.IgnoredSubPaths)
} }
if fmt.Sprint(actualFastcgiConfig.ReadTimeout) != fmt.Sprint(test.expectedFastcgiConfig[j].ReadTimeout) {
t.Errorf("Test %d expected %dth FastCGI ReadTimeout to be %s , but got %s",
i, j, test.expectedFastcgiConfig[j].ReadTimeout, actualFastcgiConfig.ReadTimeout)
}
if fmt.Sprint(actualFastcgiConfig.SendTimeout) != fmt.Sprint(test.expectedFastcgiConfig[j].SendTimeout) {
t.Errorf("Test %d expected %dth FastCGI SendTimeout to be %s , but got %s",
i, j, test.expectedFastcgiConfig[j].SendTimeout, actualFastcgiConfig.SendTimeout)
}
}
}
}
func areDialersEqual(current, expected dialer, t *testing.T) bool {
switch actual := current.(type) {
case *loadBalancingDialer:
if expected, ok := expected.(*loadBalancingDialer); ok {
for i := 0; i < len(actual.dialers); i++ {
if !areDialersEqual(actual.dialers[i], expected.dialers[i], t) {
return false
} }
} }
return true
}
case basicDialer:
return current == expected
case *persistentDialer:
if expected, ok := expected.(*persistentDialer); ok {
return actual.Equals(expected)
}
default:
t.Errorf("Unknown dialer type %T", current)
}
return false
} }
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment