From 5a6b765673b91d1a1efc9779c23efcb69a774271 Mon Sep 17 00:00:00 2001
From: ericdreeves <ericdreeves@users.noreply.github.com>
Date: Sat, 19 Nov 2016 10:05:29 -0600
Subject: [PATCH] Add connect_timeout and read_timeout to fastcgi. (#1257)

---
 caddyhttp/fastcgi/dialer.go          | 10 ++++--
 caddyhttp/fastcgi/fastcgi.go         |  5 +++
 caddyhttp/fastcgi/fastcgi_test.go    | 47 ++++++++++++++++++++++++++--
 caddyhttp/fastcgi/fcgiclient.go      | 43 +++++++++++++++++++------
 caddyhttp/fastcgi/fcgiclient_test.go |  2 +-
 caddyhttp/fastcgi/setup.go           | 27 ++++++++++++++--
 caddyhttp/fastcgi/setup_test.go      | 24 ++++++++++++++
 7 files changed, 140 insertions(+), 18 deletions(-)

diff --git a/caddyhttp/fastcgi/dialer.go b/caddyhttp/fastcgi/dialer.go
index 0afd8c0..be3cfe5 100644
--- a/caddyhttp/fastcgi/dialer.go
+++ b/caddyhttp/fastcgi/dialer.go
@@ -4,6 +4,7 @@ import (
 	"errors"
 	"sync"
 	"sync/atomic"
+	"time"
 )
 
 type dialer interface {
@@ -13,10 +14,12 @@ type dialer interface {
 
 // basicDialer is a basic dialer that wraps default fcgi functions.
 type basicDialer struct {
-	network, address string
+	network string
+	address string
+	timeout time.Duration
 }
 
-func (b basicDialer) Dial() (Client, error) { return Dial(b.network, b.address) }
+func (b basicDialer) Dial() (Client, error) { return Dial(b.network, b.address, b.timeout) }
 func (b basicDialer) Close(c Client) error  { return c.Close() }
 
 // persistentDialer keeps a pool of fcgi connections.
@@ -25,6 +28,7 @@ type persistentDialer struct {
 	size    int
 	network string
 	address string
+	timeout time.Duration
 	pool    []Client
 	sync.Mutex
 }
@@ -43,7 +47,7 @@ func (p *persistentDialer) Dial() (Client, error) {
 	p.Unlock()
 
 	// no connection available, create new one
-	return Dial(p.network, p.address)
+	return Dial(p.network, p.address, p.timeout)
 }
 
 func (p *persistentDialer) Close(client Client) error {
diff --git a/caddyhttp/fastcgi/fastcgi.go b/caddyhttp/fastcgi/fastcgi.go
index 9041794..7499959 100644
--- a/caddyhttp/fastcgi/fastcgi.go
+++ b/caddyhttp/fastcgi/fastcgi.go
@@ -12,6 +12,7 @@ import (
 	"path/filepath"
 	"strconv"
 	"strings"
+	"time"
 
 	"github.com/mholt/caddy/caddyhttp/httpserver"
 )
@@ -81,6 +82,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error)
 			if err != nil {
 				return http.StatusBadGateway, err
 			}
+			fcgiBackend.SetReadTimeout(rule.ReadTimeout)
 
 			var resp *http.Response
 			contentLength, _ := strconv.Atoi(r.Header.Get("Content-Length"))
@@ -301,6 +303,9 @@ type Rule struct {
 	// Ignored paths
 	IgnoredSubPaths []string
 
+	// The duration used to set a deadline when reading from the FastCGI server.
+	ReadTimeout time.Duration
+
 	// FCGI dialer
 	dialer dialer
 }
diff --git a/caddyhttp/fastcgi/fastcgi_test.go b/caddyhttp/fastcgi/fastcgi_test.go
index 4b22ea9..2190bf2 100644
--- a/caddyhttp/fastcgi/fastcgi_test.go
+++ b/caddyhttp/fastcgi/fastcgi_test.go
@@ -10,6 +10,7 @@ import (
 	"strings"
 	"sync"
 	"testing"
+	"time"
 )
 
 func TestServeHTTP(t *testing.T) {
@@ -28,8 +29,14 @@ func TestServeHTTP(t *testing.T) {
 
 	network, address := parseAddress(listener.Addr().String())
 	handler := Handler{
-		Next:  nil,
-		Rules: []Rule{{Path: "/", Address: listener.Addr().String(), dialer: basicDialer{network, address}}},
+		Next: nil,
+		Rules: []Rule{
+			{
+				Path:    "/",
+				Address: listener.Addr().String(),
+				dialer:  basicDialer{network: network, address: address},
+			},
+		},
 	}
 	r, err := http.NewRequest("GET", "/", nil)
 	if err != nil {
@@ -318,3 +325,39 @@ func TestBuildEnv(t *testing.T) {
 	}
 
 }
+
+func TestReadTimeout(t *testing.T) {
+	listener, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatalf("Unable to create listener for test: %v", err)
+	}
+	defer listener.Close()
+	go fcgi.Serve(listener, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		time.Sleep(time.Second * 1)
+	}))
+
+	network, address := parseAddress(listener.Addr().String())
+	handler := Handler{
+		Next: nil,
+		Rules: []Rule{
+			{
+				Path:        "/",
+				Address:     listener.Addr().String(),
+				dialer:      basicDialer{network: network, address: address},
+				ReadTimeout: time.Millisecond * 100,
+			},
+		},
+	}
+	r, err := http.NewRequest("GET", "/", nil)
+	if err != nil {
+		t.Fatalf("Unable to create request: %v", err)
+	}
+	w := httptest.NewRecorder()
+
+	_, err = handler.ServeHTTP(w, r)
+	if err == nil {
+		t.Error("Expected i/o timeout error but had none")
+	} else if err, ok := err.(net.Error); !ok || !err.Timeout() {
+		t.Errorf("Expected i/o timeout error, got: '%s'", err.Error())
+	}
+}
diff --git a/caddyhttp/fastcgi/fcgiclient.go b/caddyhttp/fastcgi/fcgiclient.go
index 925a068..95d391d 100644
--- a/caddyhttp/fastcgi/fcgiclient.go
+++ b/caddyhttp/fastcgi/fcgiclient.go
@@ -15,6 +15,7 @@ import (
 	"bytes"
 	"encoding/binary"
 	"errors"
+	"fmt"
 	"io"
 	"io/ioutil"
 	"mime/multipart"
@@ -28,6 +29,7 @@ import (
 	"strconv"
 	"strings"
 	"sync"
+	"time"
 )
 
 // FCGIListenSockFileno describes listen socket file number.
@@ -114,6 +116,8 @@ type Client interface {
 	Post(pairs map[string]string, method string, bodyType string, body io.Reader, contentLength int) (response *http.Response, err error)
 	Close() error
 	StdErr() bytes.Buffer
+	ReadTimeout() time.Duration
+	SetReadTimeout(time.Duration) error
 }
 
 type header struct {
@@ -169,13 +173,14 @@ func (rec *record) read(r io.Reader) (buf []byte, err error) {
 // FCGIClient implements a FastCGI client, which is a standard for
 // interfacing external applications with Web servers.
 type FCGIClient struct {
-	mutex     sync.Mutex
-	rwc       io.ReadWriteCloser
-	h         header
-	buf       bytes.Buffer
-	stderr    bytes.Buffer
-	keepAlive bool
-	reqID     uint16
+	mutex       sync.Mutex
+	rwc         io.ReadWriteCloser
+	h           header
+	buf         bytes.Buffer
+	stderr      bytes.Buffer
+	keepAlive   bool
+	reqID       uint16
+	readTimeout time.Duration
 }
 
 // DialWithDialer connects to the fcgi responder at the specified network address, using custom net.Dialer.
@@ -198,8 +203,8 @@ func DialWithDialer(network, address string, dialer net.Dialer) (fcgi *FCGIClien
 
 // Dial connects to the fcgi responder at the specified network address, using default net.Dialer.
 // See func net.Dial for a description of the network and address parameters.
-func Dial(network, address string) (fcgi *FCGIClient, err error) {
-	return DialWithDialer(network, address, net.Dialer{})
+func Dial(network string, address string, timeout time.Duration) (fcgi *FCGIClient, err error) {
+	return DialWithDialer(network, address, net.Dialer{Timeout: timeout})
 }
 
 // Close closes fcgi connnection.
@@ -350,6 +355,15 @@ func (w *streamReader) Read(p []byte) (n int, err error) {
 			for {
 				rec := &record{}
 				var buf []byte
+				if readTimeout := w.c.ReadTimeout(); readTimeout > 0 {
+					conn, ok := w.c.rwc.(net.Conn)
+					if ok {
+						conn.SetReadDeadline(time.Now().Add(readTimeout))
+					} else {
+						err = fmt.Errorf("Could not set Client ReadTimeout")
+						return
+					}
+				}
 				buf, err = rec.read(w.c.rwc)
 				if err == errInvalidHeaderVersion {
 					continue
@@ -559,6 +573,17 @@ func (c *FCGIClient) PostFile(p map[string]string, data url.Values, file map[str
 	return c.Post(p, "POST", bodyType, buf, buf.Len())
 }
 
+// ReadTimeout returns the read timeout for future calls that read from the
+// fcgi responder.
+func (c *FCGIClient) ReadTimeout() time.Duration { return c.readTimeout }
+
+// SetReadTimeout sets the read timeout for future calls that read from the
+// fcgi responder. A zero value for t means no timeout will be set.
+func (c *FCGIClient) SetReadTimeout(t time.Duration) error {
+	c.readTimeout = t
+	return nil
+}
+
 // Checks whether chunked is part of the encodings stack
 func chunked(te []string) bool { return len(te) > 0 && te[0] == "chunked" }
 
diff --git a/caddyhttp/fastcgi/fcgiclient_test.go b/caddyhttp/fastcgi/fcgiclient_test.go
index 0048c25..bc2a224 100644
--- a/caddyhttp/fastcgi/fcgiclient_test.go
+++ b/caddyhttp/fastcgi/fcgiclient_test.go
@@ -103,7 +103,7 @@ func (s FastCGIServer) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
 }
 
 func sendFcgi(reqType int, fcgiParams map[string]string, data []byte, posts map[string]string, files map[string]string) (content []byte) {
-	fcgi, err := Dial("tcp", ipPort)
+	fcgi, err := Dial("tcp", ipPort, 0)
 	if err != nil {
 		log.Println("err:", err)
 		return
diff --git a/caddyhttp/fastcgi/setup.go b/caddyhttp/fastcgi/setup.go
index b4444f2..f37bc83 100644
--- a/caddyhttp/fastcgi/setup.go
+++ b/caddyhttp/fastcgi/setup.go
@@ -6,6 +6,7 @@ import (
 	"path/filepath"
 	"strconv"
 	"strings"
+	"time"
 
 	"github.com/mholt/caddy"
 	"github.com/mholt/caddy/caddyhttp/httpserver"
@@ -69,6 +70,9 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) {
 			}
 		}
 
+		var err error
+		var pool int
+		var timeout time.Duration
 		var dialers []dialer
 		var poolSize = -1
 
@@ -116,7 +120,7 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) {
 				if !c.NextArg() {
 					return rules, c.ArgErr()
 				}
-				pool, err := strconv.Atoi(c.Val())
+				pool, err = strconv.Atoi(c.Val())
 				if err != nil {
 					return rules, err
 				}
@@ -125,15 +129,32 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) {
 				} else {
 					return rules, c.Errf("positive integer expected, found %d", pool)
 				}
+			case "connect_timeout":
+				if !c.NextArg() {
+					return rules, c.ArgErr()
+				}
+				timeout, err = time.ParseDuration(c.Val())
+				if err != nil {
+					return rules, err
+				}
+			case "read_timeout":
+				if !c.NextArg() {
+					return rules, c.ArgErr()
+				}
+				readTimeout, err := time.ParseDuration(c.Val())
+				if err != nil {
+					return rules, err
+				}
+				rule.ReadTimeout = readTimeout
 			}
 		}
 
 		for _, rawAddress := range upstreams {
 			network, address := parseAddress(rawAddress)
 			if poolSize >= 0 {
-				dialers = append(dialers, &persistentDialer{size: poolSize, network: network, address: address})
+				dialers = append(dialers, &persistentDialer{size: poolSize, network: network, address: address, timeout: timeout})
 			} else {
-				dialers = append(dialers, basicDialer{network: network, address: address})
+				dialers = append(dialers, basicDialer{network: network, address: address, timeout: timeout})
 			}
 		}
 
diff --git a/caddyhttp/fastcgi/setup_test.go b/caddyhttp/fastcgi/setup_test.go
index 95d2d76..c5e0fd8 100644
--- a/caddyhttp/fastcgi/setup_test.go
+++ b/caddyhttp/fastcgi/setup_test.go
@@ -4,6 +4,7 @@ import (
 	"fmt"
 	"reflect"
 	"testing"
+	"time"
 
 	"github.com/mholt/caddy"
 	"github.com/mholt/caddy/caddyhttp/httpserver"
@@ -159,6 +160,29 @@ func TestFastcgiParse(t *testing.T) {
 				dialer:     &loadBalancingDialer{dialers: []dialer{basicDialer{network: network, address: address}}},
 				IndexFiles: []string{},
 			}}},
+		{`fastcgi / ` + defaultAddress + ` {
+	              connect_timeout 5s
+	              }`,
+			false, []Rule{{
+				Path:       "/",
+				Address:    defaultAddress,
+				Ext:        "",
+				SplitPath:  "",
+				dialer:     &loadBalancingDialer{dialers: []dialer{basicDialer{network: network, address: address, timeout: 5 * time.Second}}},
+				IndexFiles: []string{},
+			}}},
+		{`fastcgi / ` + defaultAddress + ` {
+	              read_timeout 5s
+	              }`,
+			false, []Rule{{
+				Path:        "/",
+				Address:     defaultAddress,
+				Ext:         "",
+				SplitPath:   "",
+				dialer:      &loadBalancingDialer{dialers: []dialer{basicDialer{network: network, address: address}}},
+				IndexFiles:  []string{},
+				ReadTimeout: 5 * time.Second,
+			}}},
 		{`fastcgi / {
 
 		              }`,
-- 
2.30.9