Commit 37e3cf68 authored by Matt Holt's avatar Matt Holt

Merge pull request #343 from abiosoft/master

proxy: 'except' property to ignore subpaths
parents dd119e04 7949388d
...@@ -26,6 +26,8 @@ type Upstream interface { ...@@ -26,6 +26,8 @@ type Upstream interface {
From() string From() string
// Selects an upstream host to be routed to. // Selects an upstream host to be routed to.
Select() *UpstreamHost Select() *UpstreamHost
// Checks if subpath is not an ignored path
IsAllowedPath(string) bool
} }
// UpstreamHostDownFunc can be used to customize how Down behaves. // UpstreamHostDownFunc can be used to customize how Down behaves.
...@@ -59,7 +61,7 @@ func (uh *UpstreamHost) Down() bool { ...@@ -59,7 +61,7 @@ func (uh *UpstreamHost) Down() bool {
func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
for _, upstream := range p.Upstreams { for _, upstream := range p.Upstreams {
if middleware.Path(r.URL.Path).Matches(upstream.From()) { if middleware.Path(r.URL.Path).Matches(upstream.From()) && upstream.IsAllowedPath(r.URL.Path) {
var replacer middleware.Replacer var replacer middleware.Replacer
start := time.Now() start := time.Now()
requestHost := r.Host requestHost := r.Host
......
...@@ -125,6 +125,10 @@ func (u *fakeUpstream) Select() *UpstreamHost { ...@@ -125,6 +125,10 @@ func (u *fakeUpstream) Select() *UpstreamHost {
} }
} }
func (u *fakeUpstream) IsAllowedPath(requestPath string) bool {
return true
}
// recorderHijacker is a ResponseRecorder that can // recorderHijacker is a ResponseRecorder that can
// be hijacked. // be hijacked.
type recorderHijacker struct { type recorderHijacker struct {
......
...@@ -5,11 +5,13 @@ import ( ...@@ -5,11 +5,13 @@ import (
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/url" "net/url"
"path"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/mholt/caddy/caddy/parse" "github.com/mholt/caddy/caddy/parse"
"github.com/mholt/caddy/middleware"
) )
var ( var (
...@@ -29,6 +31,7 @@ type staticUpstream struct { ...@@ -29,6 +31,7 @@ type staticUpstream struct {
Interval time.Duration Interval time.Duration
} }
WithoutPathPrefix string WithoutPathPrefix string
IgnoredSubPaths []string
} }
// NewStaticUpstreams parses the configuration input and sets up // NewStaticUpstreams parses the configuration input and sets up
...@@ -165,6 +168,12 @@ func parseBlock(c *parse.Dispenser, u *staticUpstream) error { ...@@ -165,6 +168,12 @@ func parseBlock(c *parse.Dispenser, u *staticUpstream) error {
return c.ArgErr() return c.ArgErr()
} }
u.WithoutPathPrefix = c.Val() u.WithoutPathPrefix = c.Val()
case "except":
ignoredPaths := c.RemainingArgs()
if len(ignoredPaths) == 0 {
return c.ArgErr()
}
u.IgnoredSubPaths = ignoredPaths
default: default:
return c.Errf("unknown property '%s'", c.Val()) return c.Errf("unknown property '%s'", c.Val())
} }
...@@ -223,3 +232,12 @@ func (u *staticUpstream) Select() *UpstreamHost { ...@@ -223,3 +232,12 @@ func (u *staticUpstream) Select() *UpstreamHost {
} }
return u.Policy.Select(pool) return u.Policy.Select(pool)
} }
func (u *staticUpstream) IsAllowedPath(requestPath string) bool {
for _, ignoredSubPath := range u.IgnoredSubPaths {
if middleware.Path(path.Clean(requestPath)).Matches(path.Join(u.From(), ignoredSubPath)) {
return false
}
}
return true
}
...@@ -51,3 +51,33 @@ func TestRegisterPolicy(t *testing.T) { ...@@ -51,3 +51,33 @@ func TestRegisterPolicy(t *testing.T) {
} }
} }
func TestAllowedPaths(t *testing.T) {
upstream := &staticUpstream{
from: "/proxy",
IgnoredSubPaths: []string{"/download", "/static"},
}
tests := []struct {
url string
expected bool
}{
{"/proxy", true},
{"/proxy/dl", true},
{"/proxy/download", false},
{"/proxy/download/static", false},
{"/proxy/static", false},
{"/proxy/static/download", false},
{"/proxy/something/download", true},
{"/proxy/something/static", true},
{"/proxy//static", false},
{"/proxy//static//download", false},
{"/proxy//download", false},
}
for i, test := range tests {
isAllowed := upstream.IsAllowedPath(test.url)
if test.expected != isAllowed {
t.Errorf("Test %d: expected %v found %v", i+1, test.expected, isAllowed)
}
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment