Commit 9fe2ef41 authored by Abiola Ibrahim's avatar Abiola Ibrahim Committed by Matt Holt

rewrite: Regular expression support for simple rule (#2082)

* Regexp support for simple rewrite rule

* Add negate option for simplicity

* ascertain explicit regexp char
parent 88edca65
...@@ -63,22 +63,38 @@ type Rule interface { ...@@ -63,22 +63,38 @@ type Rule interface {
// SimpleRule is a simple rewrite rule. // SimpleRule is a simple rewrite rule.
type SimpleRule struct { type SimpleRule struct {
From, To string Regexp *regexp.Regexp
To string
Negate bool
} }
// NewSimpleRule creates a new Simple Rule // NewSimpleRule creates a new Simple Rule
func NewSimpleRule(from, to string) SimpleRule { func NewSimpleRule(from, to string, negate bool) (*SimpleRule, error) {
return SimpleRule{from, to} r, err := regexp.Compile(from)
if err != nil {
return nil, err
}
return &SimpleRule{
Regexp: r,
To: to,
Negate: negate,
}, nil
} }
// BasePath satisfies httpserver.Config // BasePath satisfies httpserver.Config
func (s SimpleRule) BasePath() string { return s.From } func (s SimpleRule) BasePath() string { return "/" }
// Match satisfies httpserver.Config // Match satisfies httpserver.Config
func (s SimpleRule) Match(r *http.Request) bool { return s.From == r.URL.Path } func (s *SimpleRule) Match(r *http.Request) bool {
matches := regexpMatches(s.Regexp, "/", r.URL.Path)
if s.Negate {
return len(matches) == 0
}
return len(matches) > 0
}
// Rewrite rewrites the internal location of the current request. // Rewrite rewrites the internal location of the current request.
func (s SimpleRule) Rewrite(fs http.FileSystem, r *http.Request) Result { func (s *SimpleRule) Rewrite(fs http.FileSystem, r *http.Request) Result {
// attempt rewrite // attempt rewrite
return To(fs, r, s.To, newReplacer(r)) return To(fs, r, s.To, newReplacer(r))
...@@ -165,7 +181,7 @@ func (r ComplexRule) Match(req *http.Request) bool { ...@@ -165,7 +181,7 @@ func (r ComplexRule) Match(req *http.Request) bool {
return true return true
} }
// otherwise validate regex // otherwise validate regex
return r.regexpMatches(req.URL.Path) != nil return regexpMatches(r.Regexp, r.Base, req.URL.Path) != nil
} }
// Rewrite rewrites the internal location of the current request. // Rewrite rewrites the internal location of the current request.
...@@ -174,7 +190,7 @@ func (r ComplexRule) Rewrite(fs http.FileSystem, req *http.Request) (re Result) ...@@ -174,7 +190,7 @@ func (r ComplexRule) Rewrite(fs http.FileSystem, req *http.Request) (re Result)
// validate regexp if present // validate regexp if present
if r.Regexp != nil { if r.Regexp != nil {
matches := r.regexpMatches(req.URL.Path) matches := regexpMatches(r.Regexp, r.Base, req.URL.Path)
switch len(matches) { switch len(matches) {
case 0: case 0:
// no match // no match
...@@ -230,14 +246,14 @@ func (r ComplexRule) matchExt(rPath string) bool { ...@@ -230,14 +246,14 @@ func (r ComplexRule) matchExt(rPath string) bool {
return !mustUse return !mustUse
} }
func (r ComplexRule) regexpMatches(rPath string) []string { func regexpMatches(regexp *regexp.Regexp, base, rPath string) []string {
if r.Regexp != nil { if regexp != nil {
// include trailing slash in regexp if present // include trailing slash in regexp if present
start := len(r.Base) start := len(base)
if strings.HasSuffix(r.Base, "/") { if strings.HasSuffix(base, "/") {
start-- start--
} }
return r.Regexp.FindStringSubmatch(rPath[start:]) return regexp.FindStringSubmatch(rPath[start:])
} }
return nil return nil
} }
......
...@@ -29,9 +29,9 @@ func TestRewrite(t *testing.T) { ...@@ -29,9 +29,9 @@ func TestRewrite(t *testing.T) {
rw := Rewrite{ rw := Rewrite{
Next: httpserver.HandlerFunc(urlPrinter), Next: httpserver.HandlerFunc(urlPrinter),
Rules: []httpserver.HandlerConfig{ Rules: []httpserver.HandlerConfig{
NewSimpleRule("/from", "/to"), newSimpleRule(t, "^/from$", "/to"),
NewSimpleRule("/a", "/b"), newSimpleRule(t, "^/a$", "/b"),
NewSimpleRule("/b", "/b{uri}"), newSimpleRule(t, "^/b$", "/b{uri}"),
}, },
FileSys: http.Dir("."), FileSys: http.Dir("."),
} }
...@@ -131,6 +131,45 @@ func TestRewrite(t *testing.T) { ...@@ -131,6 +131,45 @@ func TestRewrite(t *testing.T) {
} }
} }
// TestWordpress is a test for wordpress usecase.
func TestWordpress(t *testing.T) {
rw := Rewrite{
Next: httpserver.HandlerFunc(urlPrinter),
Rules: []httpserver.HandlerConfig{
// both rules are same, thanks to Go regexp (confusion).
newSimpleRule(t, "^/wp-admin", "{path} {path}/ /index.php?{query}", true),
newSimpleRule(t, "^\\/wp-admin", "{path} {path}/ /index.php?{query}", true),
},
FileSys: http.Dir("."),
}
tests := []struct {
from string
expectedTo string
}{
{"/wp-admin", "/wp-admin"},
{"/wp-admin/login.php", "/wp-admin/login.php"},
{"/not-wp-admin/login.php?not=admin", "/index.php?not=admin"},
{"/loophole", "/index.php"},
{"/user?name=john", "/index.php?name=john"},
}
for i, test := range tests {
req, err := http.NewRequest("GET", test.from, nil)
if err != nil {
t.Fatalf("Test %d: Could not create HTTP request: %v", i, err)
}
ctx := context.WithValue(req.Context(), httpserver.OriginalURLCtxKey, *req.URL)
req = req.WithContext(ctx)
rec := httptest.NewRecorder()
rw.ServeHTTP(rec, req)
if got, want := rec.Body.String(), test.expectedTo; got != want {
t.Errorf("Test %d: Expected URL to be '%s' but was '%s'", i, want, got)
}
}
}
func urlPrinter(w http.ResponseWriter, r *http.Request) (int, error) { func urlPrinter(w http.ResponseWriter, r *http.Request) (int, error) {
fmt.Fprint(w, r.URL.String()) fmt.Fprint(w, r.URL.String())
return 0, nil return 0, nil
......
...@@ -58,6 +58,7 @@ func rewriteParse(c *caddy.Controller) ([]httpserver.HandlerConfig, error) { ...@@ -58,6 +58,7 @@ func rewriteParse(c *caddy.Controller) ([]httpserver.HandlerConfig, error) {
var base = "/" var base = "/"
var pattern, to string var pattern, to string
var ext []string var ext []string
var negate bool
args := c.RemainingArgs() args := c.RemainingArgs()
...@@ -111,7 +112,14 @@ func rewriteParse(c *caddy.Controller) ([]httpserver.HandlerConfig, error) { ...@@ -111,7 +112,14 @@ func rewriteParse(c *caddy.Controller) ([]httpserver.HandlerConfig, error) {
// the only unhandled case is 2 and above // the only unhandled case is 2 and above
default: default:
rule = NewSimpleRule(args[0], strings.Join(args[1:], " ")) if args[0] == "not" {
negate = true
args = args[1:]
}
rule, err = NewSimpleRule(args[0], strings.Join(args[1:], " "), negate)
if err != nil {
return nil, err
}
rules = append(rules, rule) rules = append(rules, rule)
} }
......
...@@ -50,6 +50,19 @@ func TestSetup(t *testing.T) { ...@@ -50,6 +50,19 @@ func TestSetup(t *testing.T) {
} }
} }
// newSimpleRule is convenience test function for SimpleRule.
func newSimpleRule(t *testing.T, from, to string, negate ...bool) Rule {
var n bool
if len(negate) > 0 {
n = negate[0]
}
rule, err := NewSimpleRule(from, to, n)
if err != nil {
t.Fatal(err)
}
return rule
}
func TestRewriteParse(t *testing.T) { func TestRewriteParse(t *testing.T) {
simpleTests := []struct { simpleTests := []struct {
input string input string
...@@ -57,17 +70,20 @@ func TestRewriteParse(t *testing.T) { ...@@ -57,17 +70,20 @@ func TestRewriteParse(t *testing.T) {
expected []Rule expected []Rule
}{ }{
{`rewrite /from /to`, false, []Rule{ {`rewrite /from /to`, false, []Rule{
SimpleRule{From: "/from", To: "/to"}, newSimpleRule(t, "/from", "/to"),
}}, }},
{`rewrite /from /to {`rewrite /from /to
rewrite a b`, false, []Rule{ rewrite a b`, false, []Rule{
SimpleRule{From: "/from", To: "/to"}, newSimpleRule(t, "/from", "/to"),
SimpleRule{From: "a", To: "b"}, newSimpleRule(t, "a", "b"),
}}, }},
{`rewrite a`, true, []Rule{}}, {`rewrite a`, true, []Rule{}},
{`rewrite`, true, []Rule{}}, {`rewrite`, true, []Rule{}},
{`rewrite a b c`, false, []Rule{ {`rewrite a b c`, false, []Rule{
SimpleRule{From: "a", To: "b c"}, newSimpleRule(t, "a", "b c"),
}},
{`rewrite not a b c`, false, []Rule{
newSimpleRule(t, "a", "b c", true),
}}, }},
} }
...@@ -88,17 +104,22 @@ func TestRewriteParse(t *testing.T) { ...@@ -88,17 +104,22 @@ func TestRewriteParse(t *testing.T) {
} }
for j, e := range test.expected { for j, e := range test.expected {
actualRule := actual[j].(SimpleRule) actualRule := actual[j].(*SimpleRule)
expectedRule := e.(SimpleRule) expectedRule := e.(*SimpleRule)
if actualRule.From != expectedRule.From { if actualRule.Regexp.String() != expectedRule.Regexp.String() {
t.Errorf("Test %d, rule %d: Expected From=%s, got %s", t.Errorf("Test %d, rule %d: Expected From=%s, got %s",
i, j, expectedRule.From, actualRule.From) i, j, expectedRule.Regexp.String(), actualRule.Regexp.String())
} }
if actualRule.To != expectedRule.To { if actualRule.To != expectedRule.To {
t.Errorf("Test %d, rule %d: Expected To=%s, got %s", t.Errorf("Test %d, rule %d: Expected To=%s, got %s",
i, j, expectedRule.To, actualRule.To) i, j, expectedRule.Regexp.String(), actualRule.Regexp.String())
}
if actualRule.Negate != expectedRule.Negate {
t.Errorf("Test %d, rule %d: Expected Negate=%v, got %v",
i, j, expectedRule.Negate, actualRule.Negate)
} }
} }
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment