Commit 5a6b7656 authored by ericdreeves's avatar ericdreeves Committed by Matt Holt

Add connect_timeout and read_timeout to fastcgi. (#1257)

parent 8acf0432
......@@ -4,6 +4,7 @@ import (
"errors"
"sync"
"sync/atomic"
"time"
)
type dialer interface {
......@@ -13,10 +14,12 @@ type dialer interface {
// basicDialer is a basic dialer that wraps default fcgi functions.
type basicDialer struct {
network, address string
network string
address string
timeout time.Duration
}
func (b basicDialer) Dial() (Client, error) { return Dial(b.network, b.address) }
func (b basicDialer) Dial() (Client, error) { return Dial(b.network, b.address, b.timeout) }
func (b basicDialer) Close(c Client) error { return c.Close() }
// persistentDialer keeps a pool of fcgi connections.
......@@ -25,6 +28,7 @@ type persistentDialer struct {
size int
network string
address string
timeout time.Duration
pool []Client
sync.Mutex
}
......@@ -43,7 +47,7 @@ func (p *persistentDialer) Dial() (Client, error) {
p.Unlock()
// no connection available, create new one
return Dial(p.network, p.address)
return Dial(p.network, p.address, p.timeout)
}
func (p *persistentDialer) Close(client Client) error {
......
......@@ -12,6 +12,7 @@ import (
"path/filepath"
"strconv"
"strings"
"time"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
......@@ -81,6 +82,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error)
if err != nil {
return http.StatusBadGateway, err
}
fcgiBackend.SetReadTimeout(rule.ReadTimeout)
var resp *http.Response
contentLength, _ := strconv.Atoi(r.Header.Get("Content-Length"))
......@@ -301,6 +303,9 @@ type Rule struct {
// Ignored paths
IgnoredSubPaths []string
// The duration used to set a deadline when reading from the FastCGI server.
ReadTimeout time.Duration
// FCGI dialer
dialer dialer
}
......
......@@ -10,6 +10,7 @@ import (
"strings"
"sync"
"testing"
"time"
)
func TestServeHTTP(t *testing.T) {
......@@ -28,8 +29,14 @@ func TestServeHTTP(t *testing.T) {
network, address := parseAddress(listener.Addr().String())
handler := Handler{
Next: nil,
Rules: []Rule{{Path: "/", Address: listener.Addr().String(), dialer: basicDialer{network, address}}},
Next: nil,
Rules: []Rule{
{
Path: "/",
Address: listener.Addr().String(),
dialer: basicDialer{network: network, address: address},
},
},
}
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
......@@ -318,3 +325,39 @@ func TestBuildEnv(t *testing.T) {
}
}
func TestReadTimeout(t *testing.T) {
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Unable to create listener for test: %v", err)
}
defer listener.Close()
go fcgi.Serve(listener, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(time.Second * 1)
}))
network, address := parseAddress(listener.Addr().String())
handler := Handler{
Next: nil,
Rules: []Rule{
{
Path: "/",
Address: listener.Addr().String(),
dialer: basicDialer{network: network, address: address},
ReadTimeout: time.Millisecond * 100,
},
},
}
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatalf("Unable to create request: %v", err)
}
w := httptest.NewRecorder()
_, err = handler.ServeHTTP(w, r)
if err == nil {
t.Error("Expected i/o timeout error but had none")
} else if err, ok := err.(net.Error); !ok || !err.Timeout() {
t.Errorf("Expected i/o timeout error, got: '%s'", err.Error())
}
}
......@@ -15,6 +15,7 @@ import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"io/ioutil"
"mime/multipart"
......@@ -28,6 +29,7 @@ import (
"strconv"
"strings"
"sync"
"time"
)
// FCGIListenSockFileno describes listen socket file number.
......@@ -114,6 +116,8 @@ type Client interface {
Post(pairs map[string]string, method string, bodyType string, body io.Reader, contentLength int) (response *http.Response, err error)
Close() error
StdErr() bytes.Buffer
ReadTimeout() time.Duration
SetReadTimeout(time.Duration) error
}
type header struct {
......@@ -169,13 +173,14 @@ func (rec *record) read(r io.Reader) (buf []byte, err error) {
// FCGIClient implements a FastCGI client, which is a standard for
// interfacing external applications with Web servers.
type FCGIClient struct {
mutex sync.Mutex
rwc io.ReadWriteCloser
h header
buf bytes.Buffer
stderr bytes.Buffer
keepAlive bool
reqID uint16
mutex sync.Mutex
rwc io.ReadWriteCloser
h header
buf bytes.Buffer
stderr bytes.Buffer
keepAlive bool
reqID uint16
readTimeout time.Duration
}
// DialWithDialer connects to the fcgi responder at the specified network address, using custom net.Dialer.
......@@ -198,8 +203,8 @@ func DialWithDialer(network, address string, dialer net.Dialer) (fcgi *FCGIClien
// Dial connects to the fcgi responder at the specified network address, using default net.Dialer.
// See func net.Dial for a description of the network and address parameters.
func Dial(network, address string) (fcgi *FCGIClient, err error) {
return DialWithDialer(network, address, net.Dialer{})
func Dial(network string, address string, timeout time.Duration) (fcgi *FCGIClient, err error) {
return DialWithDialer(network, address, net.Dialer{Timeout: timeout})
}
// Close closes fcgi connnection.
......@@ -350,6 +355,15 @@ func (w *streamReader) Read(p []byte) (n int, err error) {
for {
rec := &record{}
var buf []byte
if readTimeout := w.c.ReadTimeout(); readTimeout > 0 {
conn, ok := w.c.rwc.(net.Conn)
if ok {
conn.SetReadDeadline(time.Now().Add(readTimeout))
} else {
err = fmt.Errorf("Could not set Client ReadTimeout")
return
}
}
buf, err = rec.read(w.c.rwc)
if err == errInvalidHeaderVersion {
continue
......@@ -559,6 +573,17 @@ func (c *FCGIClient) PostFile(p map[string]string, data url.Values, file map[str
return c.Post(p, "POST", bodyType, buf, buf.Len())
}
// ReadTimeout returns the read timeout for future calls that read from the
// fcgi responder.
func (c *FCGIClient) ReadTimeout() time.Duration { return c.readTimeout }
// SetReadTimeout sets the read timeout for future calls that read from the
// fcgi responder. A zero value for t means no timeout will be set.
func (c *FCGIClient) SetReadTimeout(t time.Duration) error {
c.readTimeout = t
return nil
}
// Checks whether chunked is part of the encodings stack
func chunked(te []string) bool { return len(te) > 0 && te[0] == "chunked" }
......
......@@ -103,7 +103,7 @@ func (s FastCGIServer) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
}
func sendFcgi(reqType int, fcgiParams map[string]string, data []byte, posts map[string]string, files map[string]string) (content []byte) {
fcgi, err := Dial("tcp", ipPort)
fcgi, err := Dial("tcp", ipPort, 0)
if err != nil {
log.Println("err:", err)
return
......
......@@ -6,6 +6,7 @@ import (
"path/filepath"
"strconv"
"strings"
"time"
"github.com/mholt/caddy"
"github.com/mholt/caddy/caddyhttp/httpserver"
......@@ -69,6 +70,9 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) {
}
}
var err error
var pool int
var timeout time.Duration
var dialers []dialer
var poolSize = -1
......@@ -116,7 +120,7 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) {
if !c.NextArg() {
return rules, c.ArgErr()
}
pool, err := strconv.Atoi(c.Val())
pool, err = strconv.Atoi(c.Val())
if err != nil {
return rules, err
}
......@@ -125,15 +129,32 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) {
} else {
return rules, c.Errf("positive integer expected, found %d", pool)
}
case "connect_timeout":
if !c.NextArg() {
return rules, c.ArgErr()
}
timeout, err = time.ParseDuration(c.Val())
if err != nil {
return rules, err
}
case "read_timeout":
if !c.NextArg() {
return rules, c.ArgErr()
}
readTimeout, err := time.ParseDuration(c.Val())
if err != nil {
return rules, err
}
rule.ReadTimeout = readTimeout
}
}
for _, rawAddress := range upstreams {
network, address := parseAddress(rawAddress)
if poolSize >= 0 {
dialers = append(dialers, &persistentDialer{size: poolSize, network: network, address: address})
dialers = append(dialers, &persistentDialer{size: poolSize, network: network, address: address, timeout: timeout})
} else {
dialers = append(dialers, basicDialer{network: network, address: address})
dialers = append(dialers, basicDialer{network: network, address: address, timeout: timeout})
}
}
......
......@@ -4,6 +4,7 @@ import (
"fmt"
"reflect"
"testing"
"time"
"github.com/mholt/caddy"
"github.com/mholt/caddy/caddyhttp/httpserver"
......@@ -159,6 +160,29 @@ func TestFastcgiParse(t *testing.T) {
dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: network, address: address}}},
IndexFiles: []string{},
}}},
{`fastcgi / ` + defaultAddress + ` {
connect_timeout 5s
}`,
false, []Rule{{
Path: "/",
Address: defaultAddress,
Ext: "",
SplitPath: "",
dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: network, address: address, timeout: 5 * time.Second}}},
IndexFiles: []string{},
}}},
{`fastcgi / ` + defaultAddress + ` {
read_timeout 5s
}`,
false, []Rule{{
Path: "/",
Address: defaultAddress,
Ext: "",
SplitPath: "",
dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: network, address: address}}},
IndexFiles: []string{},
ReadTimeout: 5 * time.Second,
}}},
{`fastcgi / {
}`,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment