Commit ae645ef2 authored by Tw's avatar Tw

Introduce `limits` middleware

1. Replace original `maxrequestbody` directive.
2. Add request header limit.

fix issue #1587
Signed-off-by: default avatarTw <tw19881113@gmail.com>
parent 90efff68
......@@ -16,9 +16,9 @@ import (
_ "github.com/mholt/caddy/caddyhttp/header"
_ "github.com/mholt/caddy/caddyhttp/index"
_ "github.com/mholt/caddy/caddyhttp/internalsrv"
_ "github.com/mholt/caddy/caddyhttp/limits"
_ "github.com/mholt/caddy/caddyhttp/log"
_ "github.com/mholt/caddy/caddyhttp/markdown"
_ "github.com/mholt/caddy/caddyhttp/maxrequestbody"
_ "github.com/mholt/caddy/caddyhttp/mime"
_ "github.com/mholt/caddy/caddyhttp/pprof"
_ "github.com/mholt/caddy/caddyhttp/proxy"
......
......@@ -436,7 +436,7 @@ var directives = []string{
"root",
"index",
"bind",
"maxrequestbody", // TODO: 'limits'
"limits",
"timeouts",
"tls",
......
......@@ -302,7 +302,7 @@ func (r *replacer) getSubstitution(key string) string {
}
_, err := ioutil.ReadAll(r.request.Body)
if err != nil {
if _, ok := err.(MaxBytesExceeded); ok {
if err == MaxBytesExceededErr {
return r.emptyValue
}
}
......
......@@ -4,8 +4,8 @@ package httpserver
import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"log"
"net"
"net/http"
......@@ -66,6 +66,7 @@ func NewServer(addr string, group []*SiteConfig) (*Server, error) {
sites: group,
connTimeout: GracefulTimeout,
}
s.Server = makeHTTPServerWithHeaderLimit(s.Server, group)
s.Server.Handler = s // this is weird, but whatever
// extract TLS settings from each site config to build
......@@ -127,6 +128,32 @@ func NewServer(addr string, group []*SiteConfig) (*Server, error) {
return s, nil
}
// makeHTTPServerWithHeaderLimit apply minimum header limit within a group to given http.Server
func makeHTTPServerWithHeaderLimit(s *http.Server, group []*SiteConfig) *http.Server {
var min int64
for _, cfg := range group {
limit := cfg.Limits.MaxRequestHeaderSize
if limit == 0 {
continue
}
// not set yet
if min == 0 {
min = limit
}
// find a better one
if limit < min {
min = limit
}
}
if min > 0 {
s.MaxHeaderBytes = int(min)
}
return s
}
// makeHTTPServerWithTimeouts makes an http.Server from the group of
// configs in a way that configures timeouts (or, if not set, it uses
// the default timeouts) by combining the configuration of each
......@@ -359,20 +386,6 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) (int, error)
}
}
// Apply the path-based request body size limit
// The error returned by MaxBytesReader is meant to be handled
// by whichever middleware/plugin that receives it when calling
// .Read() or a similar method on the request body
// TODO: Make this middleware instead?
if r.Body != nil {
for _, pathlimit := range vhost.MaxRequestBodySizes {
if Path(r.URL.Path).Matches(pathlimit.Path) {
r.Body = MaxBytesReader(w, r.Body, pathlimit.Limit)
break
}
}
}
return vhost.middlewareChain.ServeHTTP(w, r)
}
......@@ -465,73 +478,9 @@ func (ln tcpKeepAliveListener) File() (*os.File, error) {
return ln.TCPListener.File()
}
// MaxBytesExceeded is the error type returned by MaxBytesReader
// MaxBytesExceeded is the error returned by MaxBytesReader
// when the request body exceeds the limit imposed
type MaxBytesExceeded struct{}
func (err MaxBytesExceeded) Error() string {
return "http: request body too large"
}
// MaxBytesReader and its associated methods are borrowed from the
// Go Standard library (comments intact). The only difference is that
// it returns a MaxBytesExceeded error instead of a generic error message
// when the request body has exceeded the requested limit
func MaxBytesReader(w http.ResponseWriter, r io.ReadCloser, n int64) io.ReadCloser {
return &maxBytesReader{w: w, r: r, n: n}
}
type maxBytesReader struct {
w http.ResponseWriter
r io.ReadCloser // underlying reader
n int64 // max bytes remaining
err error // sticky error
}
func (l *maxBytesReader) Read(p []byte) (n int, err error) {
if l.err != nil {
return 0, l.err
}
if len(p) == 0 {
return 0, nil
}
// If they asked for a 32KB byte read but only 5 bytes are
// remaining, no need to read 32KB. 6 bytes will answer the
// question of the whether we hit the limit or go past it.
if int64(len(p)) > l.n+1 {
p = p[:l.n+1]
}
n, err = l.r.Read(p)
if int64(n) <= l.n {
l.n -= int64(n)
l.err = err
return n, err
}
n = int(l.n)
l.n = 0
// The server code and client code both use
// maxBytesReader. This "requestTooLarge" check is
// only used by the server code. To prevent binaries
// which only using the HTTP Client code (such as
// cmd/go) from also linking in the HTTP server, don't
// use a static type assertion to the server
// "*response" type. Check this interface instead:
type requestTooLarger interface {
requestTooLarge()
}
if res, ok := l.w.(requestTooLarger); ok {
res.requestTooLarge()
}
l.err = MaxBytesExceeded{}
return n, l.err
}
func (l *maxBytesReader) Close() error {
return l.r.Close()
}
var MaxBytesExceededErr = errors.New("http: request body too large")
// DefaultErrorFunc responds to an HTTP request with a simple description
// of the specified HTTP status code.
......
......@@ -15,7 +15,7 @@ func TestAddress(t *testing.T) {
}
}
func TestMakeHTTPServer(t *testing.T) {
func TestMakeHTTPServerWithTimeouts(t *testing.T) {
for i, tc := range []struct {
group []*SiteConfig
expected Timeouts
......@@ -111,3 +111,36 @@ func TestMakeHTTPServer(t *testing.T) {
}
}
}
func TestMakeHTTPServerWithHeaderLimit(t *testing.T) {
for name, c := range map[string]struct {
group []*SiteConfig
expect int
}{
"disable": {
group: []*SiteConfig{{}},
expect: 0,
},
"oneSite": {
group: []*SiteConfig{{Limits: Limits{
MaxRequestHeaderSize: 100,
}}},
expect: 100,
},
"multiSites": {
group: []*SiteConfig{
{Limits: Limits{MaxRequestHeaderSize: 100}},
{Limits: Limits{MaxRequestHeaderSize: 50}},
},
expect: 50,
},
} {
c := c
t.Run(name, func(t *testing.T) {
actual := makeHTTPServerWithHeaderLimit(&http.Server{}, c.group)
if got := actual.MaxHeaderBytes; got != c.expect {
t.Errorf("Expect %d, but got %d", c.expect, got)
}
})
}
}
......@@ -38,8 +38,8 @@ type SiteConfig struct {
// for a request.
HiddenFiles []string
// Max amount of bytes a request can send on a given path
MaxRequestBodySizes []PathLimit
// Max request's header/body size
Limits Limits
// The path to the Caddyfile used to generate this site config
originCaddyfile string
......@@ -71,6 +71,12 @@ type Timeouts struct {
IdleTimeoutSet bool
}
// Limits specify size limit of request's header and body.
type Limits struct {
MaxRequestHeaderSize int64
MaxRequestBodySizes []PathLimit
}
// PathLimit is a mapping from a site's path to its corresponding
// maximum request body size (in bytes)
type PathLimit struct {
......
package limits
import (
"io"
"net/http"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
// Limit is a middleware to control request body size
type Limit struct {
Next httpserver.Handler
BodyLimits []httpserver.PathLimit
}
func (l Limit) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
if r.Body == nil {
return l.Next.ServeHTTP(w, r)
}
// apply the path-based request body size limit.
for _, bl := range l.BodyLimits {
if httpserver.Path(r.URL.Path).Matches(bl.Path) {
r.Body = MaxBytesReader(w, r.Body, bl.Limit)
break
}
}
return l.Next.ServeHTTP(w, r)
}
// MaxBytesReader and its associated methods are borrowed from the
// Go Standard library (comments intact). The only difference is that
// it returns a MaxBytesExceeded error instead of a generic error message
// when the request body has exceeded the requested limit
func MaxBytesReader(w http.ResponseWriter, r io.ReadCloser, n int64) io.ReadCloser {
return &maxBytesReader{w: w, r: r, n: n}
}
type maxBytesReader struct {
w http.ResponseWriter
r io.ReadCloser // underlying reader
n int64 // max bytes remaining
err error // sticky error
}
func (l *maxBytesReader) Read(p []byte) (n int, err error) {
if l.err != nil {
return 0, l.err
}
if len(p) == 0 {
return 0, nil
}
// If they asked for a 32KB byte read but only 5 bytes are
// remaining, no need to read 32KB. 6 bytes will answer the
// question of the whether we hit the limit or go past it.
if int64(len(p)) > l.n+1 {
p = p[:l.n+1]
}
n, err = l.r.Read(p)
if int64(n) <= l.n {
l.n -= int64(n)
l.err = err
return n, err
}
n = int(l.n)
l.n = 0
// The server code and client code both use
// maxBytesReader. This "requestTooLarge" check is
// only used by the server code. To prevent binaries
// which only using the HTTP Client code (such as
// cmd/go) from also linking in the HTTP server, don't
// use a static type assertion to the server
// "*response" type. Check this interface instead:
type requestTooLarger interface {
requestTooLarge()
}
if res, ok := l.w.(requestTooLarger); ok {
res.requestTooLarge()
}
l.err = httpserver.MaxBytesExceededErr
return n, l.err
}
func (l *maxBytesReader) Close() error {
return l.r.Close()
}
package limits
import (
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
func TestBodySizeLimit(t *testing.T) {
var (
gotContent []byte
gotError error
expectContent = "hello"
)
l := Limit{
Next: httpserver.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
gotContent, gotError = ioutil.ReadAll(r.Body)
return 0, nil
}),
BodyLimits: []httpserver.PathLimit{{Path: "/", Limit: int64(len(expectContent))}},
}
r := httptest.NewRequest("GET", "/", strings.NewReader(expectContent+expectContent))
l.ServeHTTP(httptest.NewRecorder(), r)
if got := string(gotContent); got != expectContent {
t.Errorf("expected content[%s], got[%s]", expectContent, got)
}
if gotError != httpserver.MaxBytesExceededErr {
t.Errorf("expect error %v, got %v", httpserver.MaxBytesExceededErr, gotError)
}
}
package maxrequestbody
package limits
import (
"errors"
......@@ -12,13 +12,13 @@ import (
const (
serverType = "http"
pluginName = "maxrequestbody"
pluginName = "limits"
)
func init() {
caddy.RegisterPlugin(pluginName, caddy.Plugin{
ServerType: serverType,
Action: setupMaxRequestBody,
Action: setupLimits,
})
}
......@@ -28,56 +28,97 @@ type pathLimitUnparsed struct {
Limit string
}
func setupMaxRequestBody(c *caddy.Controller) error {
func setupLimits(c *caddy.Controller) error {
bls, err := parseLimits(c)
if err != nil {
return err
}
httpserver.GetConfig(c).AddMiddleware(func(next httpserver.Handler) httpserver.Handler {
return Limit{Next: next, BodyLimits: bls}
})
return nil
}
func parseLimits(c *caddy.Controller) ([]httpserver.PathLimit, error) {
config := httpserver.GetConfig(c)
if !c.Next() {
return c.ArgErr()
return nil, c.ArgErr()
}
args := c.RemainingArgs()
argList := []pathLimitUnparsed{}
headerLimit := ""
switch len(args) {
case 0:
// Format: { <path> <limit> ... }
// Format: limits {
// header <limit>
// body <path> <limit>
// body <limit>
// ...
// }
for c.NextBlock() {
path := c.Val()
if !c.NextArg() {
// Uneven pairing of path/limit
return c.ArgErr()
kind := c.Val()
pathOrLimit := c.RemainingArgs()
switch kind {
case "header":
if len(pathOrLimit) != 1 {
return nil, c.ArgErr()
}
headerLimit = pathOrLimit[0]
case "body":
if len(pathOrLimit) == 1 {
argList = append(argList, pathLimitUnparsed{
Path: path,
Limit: c.Val(),
Path: "/",
Limit: pathOrLimit[0],
})
break
}
if len(pathOrLimit) == 2 {
argList = append(argList, pathLimitUnparsed{
Path: pathOrLimit[0],
Limit: pathOrLimit[1],
})
break
}
fallthrough
default:
return nil, c.ArgErr()
}
}
case 1:
// Format: <limit>
// Format: limits <limit>
headerLimit = args[0]
argList = []pathLimitUnparsed{{
Path: "/",
Limit: args[0],
}}
case 2:
// Format: <path> <limit>
argList = []pathLimitUnparsed{{
Path: args[0],
Limit: args[1],
}}
default:
return c.ArgErr()
return nil, c.ArgErr()
}
if headerLimit != "" {
size := parseSize(headerLimit)
if size < 1 { // also disallow size = 0
return nil, c.ArgErr()
}
config.Limits.MaxRequestHeaderSize = size
}
if len(argList) > 0 {
pathLimit, err := parseArguments(argList)
if err != nil {
return c.ArgErr()
return nil, c.ArgErr()
}
SortPathLimits(pathLimit)
config.Limits.MaxRequestBodySizes = pathLimit
}
config.MaxRequestBodySizes = pathLimit
return nil
return config.Limits.MaxRequestBodySizes, nil
}
func parseArguments(args []pathLimitUnparsed) ([]httpserver.PathLimit, error) {
......
package maxrequestbody
package limits
import (
"reflect"
......@@ -14,32 +14,98 @@ const (
GB = 1024 * 1024 * 1024
)
func TestSetupMaxRequestBody(t *testing.T) {
cases := []struct {
func TestParseLimits(t *testing.T) {
for name, c := range map[string]struct {
input string
hasError bool
shouldErr bool
expect httpserver.Limits
}{
// Format: { <path> <limit> ... }
{input: "maxrequestbody / 20MB", hasError: false},
// Format: <limit>
{input: "maxrequestbody 999KB", hasError: false},
// Format: { <path> <limit> ... }
{input: "maxrequestbody { /images 50MB /upload 10MB\n/test 10KB }", hasError: false},
// Wrong formats
{input: "maxrequestbody typo { /images 50MB }", hasError: true},
{input: "maxrequestbody 999MB /home 20KB", hasError: true},
}
for caseNum, c := range cases {
"catchAll": {
input: `limits 2kb`,
expect: httpserver.Limits{
MaxRequestHeaderSize: 2 * KB,
MaxRequestBodySizes: []httpserver.PathLimit{{Path: "/", Limit: 2 * KB}},
},
},
"onlyHeader": {
input: `limits {
header 2kb
}`,
expect: httpserver.Limits{
MaxRequestHeaderSize: 2 * KB,
},
},
"onlyBody": {
input: `limits {
body 2kb
}`,
expect: httpserver.Limits{
MaxRequestBodySizes: []httpserver.PathLimit{{Path: "/", Limit: 2 * KB}},
},
},
"onlyBodyWithPath": {
input: `limits {
body /test 2kb
}`,
expect: httpserver.Limits{
MaxRequestBodySizes: []httpserver.PathLimit{{Path: "/test", Limit: 2 * KB}},
},
},
"mixture": {
input: `limits {
header 1kb
body 2kb
body /bar 3kb
}`,
expect: httpserver.Limits{
MaxRequestHeaderSize: 1 * KB,
MaxRequestBodySizes: []httpserver.PathLimit{
{Path: "/bar", Limit: 3 * KB},
{Path: "/", Limit: 2 * KB},
},
},
},
"invalidFormat": {
input: `limits a b`,
shouldErr: true,
},
"invalidHeaderFormat": {
input: `limits {
header / 100
}`,
shouldErr: true,
},
"invalidBodyFormat": {
input: `limits {
body / 100 200
}`,
shouldErr: true,
},
"invalidKind": {
input: `limits {
head 100
}`,
shouldErr: true,
},
"invalidLimitSize": {
input: `limits 10bk`,
shouldErr: true,
},
} {
c := c
t.Run(name, func(t *testing.T) {
controller := caddy.NewTestController("", c.input)
err := setupMaxRequestBody(controller)
if c.hasError && (err == nil) {
t.Errorf("Expecting error for case %v but none encountered", caseNum)
_, err := parseLimits(controller)
if c.shouldErr && err == nil {
t.Error("failed to get expected error")
}
if !c.hasError && (err != nil) {
t.Errorf("Expecting no error for case %v but encountered %v", caseNum, err)
if !c.shouldErr && err != nil {
t.Errorf("got unexpected error: %v", err)
}
if got := httpserver.GetConfig(controller).Limits; !reflect.DeepEqual(got, c.expect) {
t.Errorf("expect %#v, but got %#v", c.expect, got)
}
})
}
}
......
......@@ -228,7 +228,7 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
return 0, nil
}
if _, ok := backendErr.(httpserver.MaxBytesExceeded); ok {
if backendErr == httpserver.MaxBytesExceededErr {
return http.StatusRequestEntityTooLarge, backendErr
}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment