Commit 416af05a authored by Matthew Holt's avatar Matthew Holt

Migrating more middleware packages

parent 2f92443d
// Package basicauth implements HTTP Basic Authentication.
package basicauth
import (
"bufio"
"crypto/subtle"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"sync"
"github.com/jimstudt/http-authentication/basic"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
// BasicAuth is middleware to protect resources with a username and password.
// Note that HTTP Basic Authentication is not secure by itself and should
// not be used to protect important assets without HTTPS. Even then, the
// security of HTTP Basic Auth is disputed. Use discretion when deciding
// what to protect with BasicAuth.
type BasicAuth struct {
Next httpserver.Handler
SiteRoot string
Rules []Rule
}
// ServeHTTP implements the httpserver.Handler interface.
func (a BasicAuth) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
var hasAuth bool
var isAuthenticated bool
for _, rule := range a.Rules {
for _, res := range rule.Resources {
if !httpserver.Path(r.URL.Path).Matches(res) {
continue
}
// Path matches; parse auth header
username, password, ok := r.BasicAuth()
hasAuth = true
// Check credentials
if !ok ||
username != rule.Username ||
!rule.Password(password) {
//subtle.ConstantTimeCompare([]byte(password), []byte(rule.Password)) != 1 {
continue
}
// Flag set only on successful authentication
isAuthenticated = true
}
}
if hasAuth {
if !isAuthenticated {
w.Header().Set("WWW-Authenticate", "Basic")
return http.StatusUnauthorized, nil
}
// "It's an older code, sir, but it checks out. I was about to clear them."
return a.Next.ServeHTTP(w, r)
}
// Pass-thru when no paths match
return a.Next.ServeHTTP(w, r)
}
// Rule represents a BasicAuth rule. A username and password
// combination protect the associated resources, which are
// file or directory paths.
type Rule struct {
Username string
Password func(string) bool
Resources []string
}
// PasswordMatcher determines whether a password matches a rule.
type PasswordMatcher func(pw string) bool
var (
htpasswords map[string]map[string]PasswordMatcher
htpasswordsMu sync.Mutex
)
// GetHtpasswdMatcher matches password rules.
func GetHtpasswdMatcher(filename, username, siteRoot string) (PasswordMatcher, error) {
filename = filepath.Join(siteRoot, filename)
htpasswordsMu.Lock()
if htpasswords == nil {
htpasswords = make(map[string]map[string]PasswordMatcher)
}
pm := htpasswords[filename]
if pm == nil {
fh, err := os.Open(filename)
if err != nil {
return nil, fmt.Errorf("open %q: %v", filename, err)
}
defer fh.Close()
pm = make(map[string]PasswordMatcher)
if err = parseHtpasswd(pm, fh); err != nil {
return nil, fmt.Errorf("parsing htpasswd %q: %v", fh.Name(), err)
}
htpasswords[filename] = pm
}
htpasswordsMu.Unlock()
if pm[username] == nil {
return nil, fmt.Errorf("username %q not found in %q", username, filename)
}
return pm[username], nil
}
func parseHtpasswd(pm map[string]PasswordMatcher, r io.Reader) error {
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" || strings.IndexByte(line, '#') == 0 {
continue
}
i := strings.IndexByte(line, ':')
if i <= 0 {
return fmt.Errorf("malformed line, no color: %q", line)
}
user, encoded := line[:i], line[i+1:]
for _, p := range basic.DefaultSystems {
matcher, err := p(encoded)
if err != nil {
return err
}
if matcher != nil {
pm[user] = matcher.MatchesPassword
break
}
}
}
return scanner.Err()
}
// PlainMatcher returns a PasswordMatcher that does a constant-time
// byte-wise comparison.
func PlainMatcher(passw string) PasswordMatcher {
return func(pw string) bool {
return subtle.ConstantTimeCompare([]byte(pw), []byte(passw)) == 1
}
}
package basicauth
import (
"encoding/base64"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
func TestBasicAuth(t *testing.T) {
rw := BasicAuth{
Next: httpserver.HandlerFunc(contentHandler),
Rules: []Rule{
{Username: "test", Password: PlainMatcher("ttest"), Resources: []string{"/testing"}},
},
}
tests := []struct {
from string
result int
cred string
}{
{"/testing", http.StatusUnauthorized, "ttest:test"},
{"/testing", http.StatusOK, "test:ttest"},
{"/testing", http.StatusUnauthorized, ""},
}
for i, test := range tests {
req, err := http.NewRequest("GET", test.from, nil)
if err != nil {
t.Fatalf("Test %d: Could not create HTTP request %v", i, err)
}
auth := "Basic " + base64.StdEncoding.EncodeToString([]byte(test.cred))
req.Header.Set("Authorization", auth)
rec := httptest.NewRecorder()
result, err := rw.ServeHTTP(rec, req)
if err != nil {
t.Fatalf("Test %d: Could not ServeHTTP %v", i, err)
}
if result != test.result {
t.Errorf("Test %d: Expected Header '%d' but was '%d'",
i, test.result, result)
}
if result == http.StatusUnauthorized {
headers := rec.Header()
if val, ok := headers["Www-Authenticate"]; ok {
if val[0] != "Basic" {
t.Errorf("Test %d, Www-Authenticate should be %s provided %s", i, "Basic", val[0])
}
} else {
t.Errorf("Test %d, should provide a header Www-Authenticate", i)
}
}
}
}
func TestMultipleOverlappingRules(t *testing.T) {
rw := BasicAuth{
Next: httpserver.HandlerFunc(contentHandler),
Rules: []Rule{
{Username: "t", Password: PlainMatcher("p1"), Resources: []string{"/t"}},
{Username: "t1", Password: PlainMatcher("p2"), Resources: []string{"/t/t"}},
},
}
tests := []struct {
from string
result int
cred string
}{
{"/t", http.StatusOK, "t:p1"},
{"/t/t", http.StatusOK, "t:p1"},
{"/t/t", http.StatusOK, "t1:p2"},
{"/a", http.StatusOK, "t1:p2"},
{"/t/t", http.StatusUnauthorized, "t1:p3"},
{"/t", http.StatusUnauthorized, "t1:p2"},
}
for i, test := range tests {
req, err := http.NewRequest("GET", test.from, nil)
if err != nil {
t.Fatalf("Test %d: Could not create HTTP request %v", i, err)
}
auth := "Basic " + base64.StdEncoding.EncodeToString([]byte(test.cred))
req.Header.Set("Authorization", auth)
rec := httptest.NewRecorder()
result, err := rw.ServeHTTP(rec, req)
if err != nil {
t.Fatalf("Test %d: Could not ServeHTTP %v", i, err)
}
if result != test.result {
t.Errorf("Test %d: Expected Header '%d' but was '%d'",
i, test.result, result)
}
}
}
func contentHandler(w http.ResponseWriter, r *http.Request) (int, error) {
fmt.Fprintf(w, r.URL.String())
return http.StatusOK, nil
}
func TestHtpasswd(t *testing.T) {
htpasswdPasswd := "IedFOuGmTpT8"
htpasswdFile := `sha1:{SHA}dcAUljwz99qFjYR0YLTXx0RqLww=
md5:$apr1$l42y8rex$pOA2VJ0x/0TwaFeAF9nX61`
htfh, err := ioutil.TempFile("", "basicauth-")
if err != nil {
t.Skipf("Error creating temp file (%v), will skip htpassword test")
return
}
defer os.Remove(htfh.Name())
if _, err = htfh.Write([]byte(htpasswdFile)); err != nil {
t.Fatalf("write htpasswd file %q: %v", htfh.Name(), err)
}
htfh.Close()
for i, username := range []string{"sha1", "md5"} {
rule := Rule{Username: username, Resources: []string{"/testing"}}
siteRoot := filepath.Dir(htfh.Name())
filename := filepath.Base(htfh.Name())
if rule.Password, err = GetHtpasswdMatcher(filename, rule.Username, siteRoot); err != nil {
t.Fatalf("GetHtpasswdMatcher(%q, %q): %v", htfh.Name(), rule.Username, err)
}
t.Logf("%d. username=%q", i, rule.Username)
if !rule.Password(htpasswdPasswd) || rule.Password(htpasswdPasswd+"!") {
t.Errorf("%d (%s) password does not match.", i, rule.Username)
}
}
}
package basicauth
import (
"strings"
"github.com/mholt/caddy"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
func init() {
caddy.RegisterPlugin(caddy.Plugin{
Name: "basicauth",
ServerType: "http",
Action: setup,
})
}
// setup configures a new BasicAuth middleware instance.
func setup(c *caddy.Controller) error {
cfg := httpserver.GetConfig(c.Key)
root := cfg.Root
rules, err := basicAuthParse(c)
if err != nil {
return err
}
basic := BasicAuth{Rules: rules}
cfg.AddMiddleware(func(next httpserver.Handler) httpserver.Handler {
basic.Next = next
basic.SiteRoot = root
return basic
})
return nil
}
func basicAuthParse(c *caddy.Controller) ([]Rule, error) {
var rules []Rule
cfg := httpserver.GetConfig(c.Key)
var err error
for c.Next() {
var rule Rule
args := c.RemainingArgs()
switch len(args) {
case 2:
rule.Username = args[0]
if rule.Password, err = passwordMatcher(rule.Username, args[1], cfg.Root); err != nil {
return rules, c.Errf("Get password matcher from %s: %v", c.Val(), err)
}
for c.NextBlock() {
rule.Resources = append(rule.Resources, c.Val())
if c.NextArg() {
return rules, c.Errf("Expecting only one resource per line (extra '%s')", c.Val())
}
}
case 3:
rule.Resources = append(rule.Resources, args[0])
rule.Username = args[1]
if rule.Password, err = passwordMatcher(rule.Username, args[2], cfg.Root); err != nil {
return rules, c.Errf("Get password matcher from %s: %v", c.Val(), err)
}
default:
return rules, c.ArgErr()
}
rules = append(rules, rule)
}
return rules, nil
}
func passwordMatcher(username, passw, siteRoot string) (PasswordMatcher, error) {
if !strings.HasPrefix(passw, "htpasswd=") {
return PlainMatcher(passw), nil
}
return GetHtpasswdMatcher(passw[9:], username, siteRoot)
}
package basicauth
import (
"fmt"
"io/ioutil"
"os"
"strings"
"testing"
"github.com/mholt/caddy"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
func TestSetup(t *testing.T) {
err := setup(caddy.NewTestController(`basicauth user pwd`))
if err != nil {
t.Errorf("Expected no errors, but got: %v", err)
}
mids := httpserver.GetConfig("").Middleware()
if len(mids) == 0 {
t.Fatal("Expected middleware, got 0 instead")
}
handler := mids[0](httpserver.EmptyNext)
myHandler, ok := handler.(BasicAuth)
if !ok {
t.Fatalf("Expected handler to be type BasicAuth, got: %#v", handler)
}
if !httpserver.SameNext(myHandler.Next, httpserver.EmptyNext) {
t.Error("'Next' field of handler was not set properly")
}
}
func TestBasicAuthParse(t *testing.T) {
htpasswdPasswd := "IedFOuGmTpT8"
htpasswdFile := `sha1:{SHA}dcAUljwz99qFjYR0YLTXx0RqLww=
md5:$apr1$l42y8rex$pOA2VJ0x/0TwaFeAF9nX61`
var skipHtpassword bool
htfh, err := ioutil.TempFile(".", "basicauth-")
if err != nil {
t.Logf("Error creating temp file (%v), will skip htpassword test", err)
skipHtpassword = true
} else {
if _, err = htfh.Write([]byte(htpasswdFile)); err != nil {
t.Fatalf("write htpasswd file %q: %v", htfh.Name(), err)
}
htfh.Close()
defer os.Remove(htfh.Name())
}
tests := []struct {
input string
shouldErr bool
password string
expected []Rule
}{
{`basicauth user pwd`, false, "pwd", []Rule{
{Username: "user"},
}},
{`basicauth user pwd {
}`, false, "pwd", []Rule{
{Username: "user"},
}},
{`basicauth user pwd {
/resource1
/resource2
}`, false, "pwd", []Rule{
{Username: "user", Resources: []string{"/resource1", "/resource2"}},
}},
{`basicauth /resource user pwd`, false, "pwd", []Rule{
{Username: "user", Resources: []string{"/resource"}},
}},
{`basicauth /res1 user1 pwd1
basicauth /res2 user2 pwd2`, false, "pwd", []Rule{
{Username: "user1", Resources: []string{"/res1"}},
{Username: "user2", Resources: []string{"/res2"}},
}},
{`basicauth user`, true, "", []Rule{}},
{`basicauth`, true, "", []Rule{}},
{`basicauth /resource user pwd asdf`, true, "", []Rule{}},
{`basicauth sha1 htpasswd=` + htfh.Name(), false, htpasswdPasswd, []Rule{
{Username: "sha1"},
}},
}
for i, test := range tests {
actual, err := basicAuthParse(caddy.NewTestController(test.input))
if err == nil && test.shouldErr {
t.Errorf("Test %d didn't error, but it should have", i)
} else if err != nil && !test.shouldErr {
t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err)
}
if len(actual) != len(test.expected) {
t.Fatalf("Test %d expected %d rules, but got %d",
i, len(test.expected), len(actual))
}
for j, expectedRule := range test.expected {
actualRule := actual[j]
if actualRule.Username != expectedRule.Username {
t.Errorf("Test %d, rule %d: Expected username '%s', got '%s'",
i, j, expectedRule.Username, actualRule.Username)
}
if strings.Contains(test.input, "htpasswd=") && skipHtpassword {
continue
}
pwd := test.password
if len(actual) > 1 {
pwd = fmt.Sprintf("%s%d", pwd, j+1)
}
if !actualRule.Password(pwd) || actualRule.Password(test.password+"!") {
t.Errorf("Test %d, rule %d: Expected password '%v', got '%v'",
i, j, test.password, actualRule.Password(""))
}
expectedRes := fmt.Sprintf("%v", expectedRule.Resources)
actualRes := fmt.Sprintf("%v", actualRule.Resources)
if actualRes != expectedRes {
t.Errorf("Test %d, rule %d: Expected resource list %s, but got %s",
i, j, expectedRes, actualRes)
}
}
}
}
package expvar
import (
"expvar"
"fmt"
"net/http"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
// ExpVar is a simple struct to hold expvar's configuration
type ExpVar struct {
Next httpserver.Handler
Resource Resource
}
// ServeHTTP handles requests to expvar's configured entry point with
// expvar, or passes all other requests up the chain.
func (e ExpVar) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
if httpserver.Path(r.URL.Path).Matches(string(e.Resource)) {
expvarHandler(w, r)
return 0, nil
}
return e.Next.ServeHTTP(w, r)
}
// expvarHandler returns a JSON object will all the published variables.
//
// This is lifted straight from the expvar package.
func expvarHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
fmt.Fprintf(w, "{\n")
first := true
expvar.Do(func(kv expvar.KeyValue) {
if !first {
fmt.Fprintf(w, ",\n")
}
first = false
fmt.Fprintf(w, "%q: %s", kv.Key, kv.Value)
})
fmt.Fprintf(w, "\n}\n")
}
// Resource contains the path to the expvar entry point
type Resource string
package expvar
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
func TestExpVar(t *testing.T) {
rw := ExpVar{
Next: httpserver.HandlerFunc(contentHandler),
Resource: "/d/v",
}
tests := []struct {
from string
result int
}{
{"/d/v", 0},
{"/x/y", http.StatusOK},
}
for i, test := range tests {
req, err := http.NewRequest("GET", test.from, nil)
if err != nil {
t.Fatalf("Test %d: Could not create HTTP request %v", i, err)
}
rec := httptest.NewRecorder()
result, err := rw.ServeHTTP(rec, req)
if err != nil {
t.Fatalf("Test %d: Could not ServeHTTP %v", i, err)
}
if result != test.result {
t.Errorf("Test %d: Expected Header '%d' but was '%d'",
i, test.result, result)
}
}
}
func contentHandler(w http.ResponseWriter, r *http.Request) (int, error) {
fmt.Fprintf(w, r.URL.String())
return http.StatusOK, nil
}
package expvar
import (
"expvar"
"runtime"
"sync"
"github.com/mholt/caddy"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
func init() {
caddy.RegisterPlugin(caddy.Plugin{
Name: "expvar",
ServerType: "http",
Action: setup,
})
}
// setup configures a new ExpVar middleware instance.
func setup(c *caddy.Controller) error {
resource, err := expVarParse(c)
if err != nil {
return err
}
// publish any extra information/metrics we may want to capture
publishExtraVars()
ev := ExpVar{Resource: resource}
httpserver.GetConfig(c.Key).AddMiddleware(func(next httpserver.Handler) httpserver.Handler {
ev.Next = next
return ev
})
return nil
}
func expVarParse(c *caddy.Controller) (Resource, error) {
var resource Resource
var err error
for c.Next() {
args := c.RemainingArgs()
switch len(args) {
case 0:
resource = Resource(defaultExpvarPath)
case 1:
resource = Resource(args[0])
default:
return resource, c.ArgErr()
}
}
return resource, err
}
func publishExtraVars() {
// By using sync.Once instead of an init() function, we don't clutter
// the app's expvar export unnecessarily, or risk colliding with it.
publishOnce.Do(func() {
expvar.Publish("Goroutines", expvar.Func(func() interface{} {
return runtime.NumGoroutine()
}))
})
}
var publishOnce sync.Once // publishing variables should only be done once
var defaultExpvarPath = "/debug/vars"
package expvar
import (
"testing"
"github.com/mholt/caddy"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
func TestSetup(t *testing.T) {
err := setup(caddy.NewTestController(`expvar`))
if err != nil {
t.Errorf("Expected no errors, got: %v", err)
}
mids := httpserver.GetConfig("").Middleware()
if len(mids) == 0 {
t.Fatal("Expected middleware, got 0 instead")
}
err = setup(caddy.NewTestController(`expvar /d/v`))
if err != nil {
t.Errorf("Expected no errors, got: %v", err)
}
mids = httpserver.GetConfig("").Middleware()
if len(mids) == 0 {
t.Fatal("Expected middleware, got 0 instead")
}
handler := mids[1](httpserver.EmptyNext)
myHandler, ok := handler.(ExpVar)
if !ok {
t.Fatalf("Expected handler to be type ExpVar, got: %#v", handler)
}
if myHandler.Resource != "/d/v" {
t.Errorf("Expected /d/v as expvar resource")
}
if !httpserver.SameNext(myHandler.Next, httpserver.EmptyNext) {
t.Error("'Next' field of handler was not set properly")
}
}
// Package extensions contains middleware for clean URLs.
//
// The root path of the site is passed in as well as possible extensions
// to try internally for paths requested that don't match an existing
// resource. The first path+ext combination that matches a valid file
// will be used.
package extensions
import (
"net/http"
"os"
"path"
"strings"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
// Ext can assume an extension from clean URLs.
// It tries extensions in the order listed in Extensions.
type Ext struct {
// Next handler in the chain
Next httpserver.Handler
// Path to ther root of the site
Root string
// List of extensions to try
Extensions []string
}
// ServeHTTP implements the httpserver.Handler interface.
func (e Ext) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
urlpath := strings.TrimSuffix(r.URL.Path, "/")
if path.Ext(urlpath) == "" && len(r.URL.Path) > 0 && r.URL.Path[len(r.URL.Path)-1] != '/' {
for _, ext := range e.Extensions {
if resourceExists(e.Root, urlpath+ext) {
r.URL.Path = urlpath + ext
break
}
}
}
return e.Next.ServeHTTP(w, r)
}
// resourceExists returns true if the file specified at
// root + path exists; false otherwise.
func resourceExists(root, path string) bool {
_, err := os.Stat(root + path)
// technically we should use os.IsNotExist(err)
// but we don't handle any other kinds of errors anyway
return err == nil
}
package extensions
import (
"github.com/mholt/caddy"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
func init() {
caddy.RegisterPlugin(caddy.Plugin{
Name: "ext",
ServerType: "http",
Action: setup,
})
}
// setup configures a new instance of 'extensions' middleware for clean URLs.
func setup(c *caddy.Controller) error {
cfg := httpserver.GetConfig(c.Key)
root := cfg.Root
exts, err := extParse(c)
if err != nil {
return err
}
httpserver.GetConfig(c.Key).AddMiddleware(func(next httpserver.Handler) httpserver.Handler {
return Ext{
Next: next,
Extensions: exts,
Root: root,
}
})
return nil
}
// extParse sets up an instance of extension middleware
// from a middleware controller and returns a list of extensions.
func extParse(c *caddy.Controller) ([]string, error) {
var exts []string
for c.Next() {
// At least one extension is required
if !c.NextArg() {
return exts, c.ArgErr()
}
exts = append(exts, c.Val())
// Tack on any other extensions that may have been listed
exts = append(exts, c.RemainingArgs()...)
}
return exts, nil
}
package extensions
import (
"testing"
"github.com/mholt/caddy"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
func TestSetup(t *testing.T) {
err := setup(caddy.NewTestController(`ext .html .htm .php`))
if err != nil {
t.Fatalf("Expected no errors, got: %v", err)
}
mids := httpserver.GetConfig("").Middleware()
if len(mids) == 0 {
t.Fatal("Expected middleware, had 0 instead")
}
handler := mids[0](httpserver.EmptyNext)
myHandler, ok := handler.(Ext)
if !ok {
t.Fatalf("Expected handler to be type Ext, got: %#v", handler)
}
if myHandler.Extensions[0] != ".html" {
t.Errorf("Expected .html in the list of Extensions")
}
if myHandler.Extensions[1] != ".htm" {
t.Errorf("Expected .htm in the list of Extensions")
}
if myHandler.Extensions[2] != ".php" {
t.Errorf("Expected .php in the list of Extensions")
}
if !httpserver.SameNext(myHandler.Next, httpserver.EmptyNext) {
t.Error("'Next' field of handler was not set properly")
}
}
func TestExtParse(t *testing.T) {
tests := []struct {
inputExts string
shouldErr bool
expectedExts []string
}{
{`ext .html .htm .php`, false, []string{".html", ".htm", ".php"}},
{`ext .php .html .xml`, false, []string{".php", ".html", ".xml"}},
{`ext .txt .php .xml`, false, []string{".txt", ".php", ".xml"}},
}
for i, test := range tests {
actualExts, err := extParse(caddy.NewTestController(test.inputExts))
if err == nil && test.shouldErr {
t.Errorf("Test %d didn't error, but it should have", i)
} else if err != nil && !test.shouldErr {
t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err)
}
if len(actualExts) != len(test.expectedExts) {
t.Fatalf("Test %d expected %d rules, but got %d",
i, len(test.expectedExts), len(actualExts))
}
for j, actualExt := range actualExts {
if actualExt != test.expectedExts[j] {
t.Fatalf("Test %d expected %dth extension to be %s , but got %s",
i, j, test.expectedExts[j], actualExt)
}
}
}
}
// Package header provides middleware that appends headers to
// requests based on a set of configuration rules that define
// which routes receive which headers.
package header
import (
"net/http"
"strings"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
// Headers is middleware that adds headers to the responses
// for requests matching a certain path.
type Headers struct {
Next httpserver.Handler
Rules []Rule
}
// ServeHTTP implements the httpserver.Handler interface and serves requests,
// setting headers on the response according to the configured rules.
func (h Headers) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
replacer := httpserver.NewReplacer(r, nil, "")
for _, rule := range h.Rules {
if httpserver.Path(r.URL.Path).Matches(rule.Path) {
for _, header := range rule.Headers {
if strings.HasPrefix(header.Name, "-") {
w.Header().Del(strings.TrimLeft(header.Name, "-"))
} else {
w.Header().Set(header.Name, replacer.Replace(header.Value))
}
}
}
}
return h.Next.ServeHTTP(w, r)
}
type (
// Rule groups a slice of HTTP headers by a URL pattern.
// TODO: use http.Header type instead?
Rule struct {
Path string
Headers []Header
}
// Header represents a single HTTP header, simply a name and value.
Header struct {
Name string
Value string
}
)
package header
import (
"net/http"
"net/http/httptest"
"os"
"testing"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
func TestHeader(t *testing.T) {
hostname, err := os.Hostname()
if err != nil {
t.Fatalf("Could not determine hostname: %v", err)
}
for i, test := range []struct {
from string
name string
value string
}{
{"/a", "Foo", "Bar"},
{"/a", "Bar", ""},
{"/a", "Baz", ""},
{"/a", "ServerName", hostname},
{"/b", "Foo", ""},
{"/b", "Bar", "Removed in /a"},
} {
he := Headers{
Next: httpserver.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
return 0, nil
}),
Rules: []Rule{
{Path: "/a", Headers: []Header{
{Name: "Foo", Value: "Bar"},
{Name: "ServerName", Value: "{hostname}"},
{Name: "-Bar"},
}},
},
}
req, err := http.NewRequest("GET", test.from, nil)
if err != nil {
t.Fatalf("Test %d: Could not create HTTP request: %v", i, err)
}
rec := httptest.NewRecorder()
rec.Header().Set("Bar", "Removed in /a")
he.ServeHTTP(rec, req)
if got := rec.Header().Get(test.name); got != test.value {
t.Errorf("Test %d: Expected %s header to be %q but was %q",
i, test.name, test.value, got)
}
}
}
package header
import (
"github.com/mholt/caddy"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
func init() {
caddy.RegisterPlugin(caddy.Plugin{
Name: "header",
ServerType: "http",
Action: setup,
})
}
// setup configures a new Headers middleware instance.
func setup(c *caddy.Controller) error {
rules, err := headersParse(c)
if err != nil {
return err
}
httpserver.GetConfig(c.Key).AddMiddleware(func(next httpserver.Handler) httpserver.Handler {
return Headers{Next: next, Rules: rules}
})
return nil
}
func headersParse(c *caddy.Controller) ([]Rule, error) {
var rules []Rule
for c.NextLine() {
var head Rule
var isNewPattern bool
if !c.NextArg() {
return rules, c.ArgErr()
}
pattern := c.Val()
// See if we already have a definition for this Path pattern...
for _, h := range rules {
if h.Path == pattern {
head = h
break
}
}
// ...otherwise, this is a new pattern
if head.Path == "" {
head.Path = pattern
isNewPattern = true
}
for c.NextBlock() {
// A block of headers was opened...
h := Header{Name: c.Val()}
if c.NextArg() {
h.Value = c.Val()
}
head.Headers = append(head.Headers, h)
}
if c.NextArg() {
// ... or single header was defined as an argument instead.
h := Header{Name: c.Val()}
h.Value = c.Val()
if c.NextArg() {
h.Value = c.Val()
}
head.Headers = append(head.Headers, h)
}
if isNewPattern {
rules = append(rules, head)
} else {
for i := 0; i < len(rules); i++ {
if rules[i].Path == pattern {
rules[i] = head
break
}
}
}
}
return rules, nil
}
package header
import (
"fmt"
"testing"
"github.com/mholt/caddy"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
func TestSetup(t *testing.T) {
err := setup(caddy.NewTestController(`header / Foo Bar`))
if err != nil {
t.Errorf("Expected no errors, but got: %v", err)
}
mids := httpserver.GetConfig("").Middleware()
if len(mids) == 0 {
t.Fatal("Expected middleware, had 0 instead")
}
handler := mids[0](httpserver.EmptyNext)
myHandler, ok := handler.(Headers)
if !ok {
t.Fatalf("Expected handler to be type Headers, got: %#v", handler)
}
if !httpserver.SameNext(myHandler.Next, httpserver.EmptyNext) {
t.Error("'Next' field of handler was not set properly")
}
}
func TestHeadersParse(t *testing.T) {
tests := []struct {
input string
shouldErr bool
expected []Rule
}{
{`header /foo Foo "Bar Baz"`,
false, []Rule{
{Path: "/foo", Headers: []Header{
{Name: "Foo", Value: "Bar Baz"},
}},
}},
{`header /bar { Foo "Bar Baz" Baz Qux }`,
false, []Rule{
{Path: "/bar", Headers: []Header{
{Name: "Foo", Value: "Bar Baz"},
{Name: "Baz", Value: "Qux"},
}},
}},
}
for i, test := range tests {
actual, err := headersParse(caddy.NewTestController(test.input))
if err == nil && test.shouldErr {
t.Errorf("Test %d didn't error, but it should have", i)
} else if err != nil && !test.shouldErr {
t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err)
}
if len(actual) != len(test.expected) {
t.Fatalf("Test %d expected %d rules, but got %d",
i, len(test.expected), len(actual))
}
for j, expectedRule := range test.expected {
actualRule := actual[j]
if actualRule.Path != expectedRule.Path {
t.Errorf("Test %d, rule %d: Expected path %s, but got %s",
i, j, expectedRule.Path, actualRule.Path)
}
expectedHeaders := fmt.Sprintf("%v", expectedRule.Headers)
actualHeaders := fmt.Sprintf("%v", actualRule.Headers)
if actualHeaders != expectedHeaders {
t.Errorf("Test %d, rule %d: Expected headers %s, but got %s",
i, j, expectedHeaders, actualHeaders)
}
}
}
}
// Package internalsrv provides a simple middleware that (a) prevents access
// to internal locations and (b) allows to return files from internal location
// by setting a special header, e.g. in a proxy response.
//
// The package is named internalsrv so as not to conflict with Go tooling
// convention which treats folders called "internal" differently.
package internalsrv
import (
"net/http"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
// Internal middleware protects internal locations from external requests -
// but allows access from the inside by using a special HTTP header.
type Internal struct {
Next httpserver.Handler
Paths []string
}
const (
redirectHeader string = "X-Accel-Redirect"
maxRedirectCount int = 10
)
func isInternalRedirect(w http.ResponseWriter) bool {
return w.Header().Get(redirectHeader) != ""
}
// ServeHTTP implements the httpserver.Handler interface.
func (i Internal) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
// Internal location requested? -> Not found.
for _, prefix := range i.Paths {
if httpserver.Path(r.URL.Path).Matches(prefix) {
return http.StatusNotFound, nil
}
}
// Use internal response writer to ignore responses that will be
// redirected to internal locations
iw := internalResponseWriter{ResponseWriter: w}
status, err := i.Next.ServeHTTP(iw, r)
for c := 0; c < maxRedirectCount && isInternalRedirect(iw); c++ {
// Redirect - adapt request URL path and send it again
// "down the chain"
r.URL.Path = iw.Header().Get(redirectHeader)
iw.ClearHeader()
status, err = i.Next.ServeHTTP(iw, r)
}
if isInternalRedirect(iw) {
// Too many redirect cycles
iw.ClearHeader()
return http.StatusInternalServerError, nil
}
return status, err
}
// internalResponseWriter wraps the underlying http.ResponseWriter and ignores
// calls to Write and WriteHeader if the response should be redirected to an
// internal location.
type internalResponseWriter struct {
http.ResponseWriter
}
// ClearHeader removes all header fields that are already set.
func (w internalResponseWriter) ClearHeader() {
for k := range w.Header() {
w.Header().Del(k)
}
}
// WriteHeader ignores the call if the response should be redirected to an
// internal location.
func (w internalResponseWriter) WriteHeader(code int) {
if !isInternalRedirect(w) {
w.ResponseWriter.WriteHeader(code)
}
}
// Write ignores the call if the response should be redirected to an internal
// location.
func (w internalResponseWriter) Write(b []byte) (int, error) {
if isInternalRedirect(w) {
return 0, nil
}
return w.ResponseWriter.Write(b)
}
package internalsrv
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
func TestInternal(t *testing.T) {
im := Internal{
Next: httpserver.HandlerFunc(internalTestHandlerFunc),
Paths: []string{"/internal"},
}
tests := []struct {
url string
expectedCode int
expectedBody string
}{
{"/internal", http.StatusNotFound, ""},
{"/public", 0, "/public"},
{"/public/internal", 0, "/public/internal"},
{"/redirect", 0, "/internal"},
{"/cycle", http.StatusInternalServerError, ""},
}
for i, test := range tests {
req, err := http.NewRequest("GET", test.url, nil)
if err != nil {
t.Fatalf("Test %d: Could not create HTTP request: %v", i, err)
}
rec := httptest.NewRecorder()
code, _ := im.ServeHTTP(rec, req)
if code != test.expectedCode {
t.Errorf("Test %d: Expected status code %d for %s, but got %d",
i, test.expectedCode, test.url, code)
}
if rec.Body.String() != test.expectedBody {
t.Errorf("Test %d: Expected body '%s' for %s, but got '%s'",
i, test.expectedBody, test.url, rec.Body.String())
}
}
}
func internalTestHandlerFunc(w http.ResponseWriter, r *http.Request) (int, error) {
switch r.URL.Path {
case "/redirect":
w.Header().Set("X-Accel-Redirect", "/internal")
case "/cycle":
w.Header().Set("X-Accel-Redirect", "/cycle")
}
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, r.URL.String())
return 0, nil
}
package internalsrv
import (
"github.com/mholt/caddy"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
func init() {
caddy.RegisterPlugin(caddy.Plugin{
Name: "internal",
ServerType: "http",
Action: setup,
})
}
// Internal configures a new Internal middleware instance.
func setup(c *caddy.Controller) error {
paths, err := internalParse(c)
if err != nil {
return err
}
httpserver.GetConfig(c.Key).AddMiddleware(func(next httpserver.Handler) httpserver.Handler {
return Internal{Next: next, Paths: paths}
})
return nil
}
func internalParse(c *caddy.Controller) ([]string, error) {
var paths []string
for c.Next() {
if !c.NextArg() {
return paths, c.ArgErr()
}
paths = append(paths, c.Val())
}
return paths, nil
}
package internalsrv
import (
"testing"
"github.com/mholt/caddy"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
func TestSetup(t *testing.T) {
err := setup(caddy.NewTestController(`internal /internal`))
if err != nil {
t.Errorf("Expected no errors, got: %v", err)
}
mids := httpserver.GetConfig("").Middleware()
if len(mids) == 0 {
t.Fatal("Expected middleware, got 0 instead")
}
handler := mids[0](httpserver.EmptyNext)
myHandler, ok := handler.(Internal)
if !ok {
t.Fatalf("Expected handler to be type Internal, got: %#v", handler)
}
if myHandler.Paths[0] != "/internal" {
t.Errorf("Expected internal in the list of internal Paths")
}
if !httpserver.SameNext(myHandler.Next, httpserver.EmptyNext) {
t.Error("'Next' field of handler was not set properly")
}
}
func TestInternalParse(t *testing.T) {
tests := []struct {
inputInternalPaths string
shouldErr bool
expectedInternalPaths []string
}{
{`internal /internal`, false, []string{"/internal"}},
{`internal /internal1
internal /internal2`, false, []string{"/internal1", "/internal2"}},
}
for i, test := range tests {
actualInternalPaths, err := internalParse(caddy.NewTestController(test.inputInternalPaths))
if err == nil && test.shouldErr {
t.Errorf("Test %d didn't error, but it should have", i)
} else if err != nil && !test.shouldErr {
t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err)
}
if len(actualInternalPaths) != len(test.expectedInternalPaths) {
t.Fatalf("Test %d expected %d InternalPaths, but got %d",
i, len(test.expectedInternalPaths), len(actualInternalPaths))
}
for j, actualInternalPath := range actualInternalPaths {
if actualInternalPath != test.expectedInternalPaths[j] {
t.Fatalf("Test %d expected %dth Internal Path to be %s , but got %s",
i, j, test.expectedInternalPaths[j], actualInternalPath)
}
}
}
}
package mime
import (
"net/http"
"path"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
// Config represent a mime config. Map from extension to mime-type.
// Note, this should be safe with concurrent read access, as this is
// not modified concurrently.
type Config map[string]string
// Mime sets Content-Type header of requests based on configurations.
type Mime struct {
Next httpserver.Handler
Configs Config
}
// ServeHTTP implements the httpserver.Handler interface.
func (e Mime) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
// Get a clean /-path, grab the extension
ext := path.Ext(path.Clean(r.URL.Path))
if contentType, ok := e.Configs[ext]; ok {
w.Header().Set("Content-Type", contentType)
}
return e.Next.ServeHTTP(w, r)
}
package mime
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
func TestMimeHandler(t *testing.T) {
mimes := Config{
".html": "text/html",
".txt": "text/plain",
".swf": "application/x-shockwave-flash",
}
m := Mime{Configs: mimes}
w := httptest.NewRecorder()
exts := []string{
".html", ".txt", ".swf",
}
for _, e := range exts {
url := "/file" + e
r, err := http.NewRequest("GET", url, nil)
if err != nil {
t.Error(err)
}
m.Next = nextFunc(true, mimes[e])
_, err = m.ServeHTTP(w, r)
if err != nil {
t.Error(err)
}
}
w = httptest.NewRecorder()
exts = []string{
".htm1", ".abc", ".mdx",
}
for _, e := range exts {
url := "/file" + e
r, err := http.NewRequest("GET", url, nil)
if err != nil {
t.Error(err)
}
m.Next = nextFunc(false, "")
_, err = m.ServeHTTP(w, r)
if err != nil {
t.Error(err)
}
}
}
func nextFunc(shouldMime bool, contentType string) httpserver.Handler {
return httpserver.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
if shouldMime {
if w.Header().Get("Content-Type") != contentType {
return 0, fmt.Errorf("expected Content-Type: %v, found %v", contentType, r.Header.Get("Content-Type"))
}
return 0, nil
}
if w.Header().Get("Content-Type") != "" {
return 0, fmt.Errorf("Content-Type header not expected")
}
return 0, nil
})
}
package mime
import (
"fmt"
"strings"
"github.com/mholt/caddy"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
func init() {
caddy.RegisterPlugin(caddy.Plugin{
Name: "mime",
ServerType: "http",
Action: setup,
})
}
// setup configures a new mime middleware instance.
func setup(c *caddy.Controller) error {
configs, err := mimeParse(c)
if err != nil {
return err
}
httpserver.GetConfig(c.Key).AddMiddleware(func(next httpserver.Handler) httpserver.Handler {
return Mime{Next: next, Configs: configs}
})
return nil
}
func mimeParse(c *caddy.Controller) (Config, error) {
configs := Config{}
for c.Next() {
// At least one extension is required
args := c.RemainingArgs()
switch len(args) {
case 2:
if err := validateExt(configs, args[0]); err != nil {
return configs, err
}
configs[args[0]] = args[1]
case 1:
return configs, c.ArgErr()
case 0:
for c.NextBlock() {
ext := c.Val()
if err := validateExt(configs, ext); err != nil {
return configs, err
}
if !c.NextArg() {
return configs, c.ArgErr()
}
configs[ext] = c.Val()
}
}
}
return configs, nil
}
// validateExt checks for valid file name extension.
func validateExt(configs Config, ext string) error {
if !strings.HasPrefix(ext, ".") {
return fmt.Errorf(`mime: invalid extension "%v" (must start with dot)`, ext)
}
if _, ok := configs[ext]; ok {
return fmt.Errorf(`mime: duplicate extension "%v" found`, ext)
}
return nil
}
package mime
import (
"testing"
"github.com/mholt/caddy"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
func TestSetup(t *testing.T) {
err := setup(caddy.NewTestController(`mime .txt text/plain`))
if err != nil {
t.Errorf("Expected no errors, but got: %v", err)
}
mids := httpserver.GetConfig("").Middleware()
if len(mids) == 0 {
t.Fatal("Expected middleware, but had 0 instead")
}
handler := mids[0](httpserver.EmptyNext)
myHandler, ok := handler.(Mime)
if !ok {
t.Fatalf("Expected handler to be type Mime, got: %#v", handler)
}
if !httpserver.SameNext(myHandler.Next, httpserver.EmptyNext) {
t.Error("'Next' field of handler was not set properly")
}
tests := []struct {
input string
shouldErr bool
}{
{`mime {`, true},
{`mime {}`, true},
{`mime a b`, true},
{`mime a {`, true},
{`mime { txt f } `, true},
{`mime { html } `, true},
{`mime {
.html text/html
.txt text/plain
} `, false},
{`mime {
.foo text/foo
.bar text/bar
.foo text/foobar
} `, true},
{`mime { .html text/html } `, false},
{`mime { .html
} `, true},
{`mime .txt text/plain`, false},
}
for i, test := range tests {
m, err := mimeParse(caddy.NewTestController(test.input))
if test.shouldErr && err == nil {
t.Errorf("Test %v: Expected error but found nil %v", i, m)
} else if !test.shouldErr && err != nil {
t.Errorf("Test %v: Expected no error but found error: %v", i, err)
}
}
}
package pprof
import (
"net/http"
pp "net/http/pprof"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
// BasePath is the base path to match for all pprof requests.
const BasePath = "/debug/pprof"
// Handler is a simple struct whose ServeHTTP will delegate pprof
// endpoints to their equivalent net/http/pprof handlers.
type Handler struct {
Next httpserver.Handler
Mux *http.ServeMux
}
// ServeHTTP handles requests to BasePath with pprof, or passes
// all other requests up the chain.
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
if httpserver.Path(r.URL.Path).Matches(BasePath) {
h.Mux.ServeHTTP(w, r)
return 0, nil
}
return h.Next.ServeHTTP(w, r)
}
// NewMux returns a new http.ServeMux that routes pprof requests.
// It pretty much copies what the std lib pprof does on init:
// https://golang.org/src/net/http/pprof/pprof.go#L67
func NewMux() *http.ServeMux {
mux := http.NewServeMux()
mux.HandleFunc(BasePath+"/", pp.Index)
mux.HandleFunc(BasePath+"/cmdline", pp.Cmdline)
mux.HandleFunc(BasePath+"/profile", pp.Profile)
mux.HandleFunc(BasePath+"/symbol", pp.Symbol)
mux.HandleFunc(BasePath+"/trace", pp.Trace)
return mux
}
package pprof
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
func TestServeHTTP(t *testing.T) {
h := Handler{
Next: httpserver.HandlerFunc(nextHandler),
Mux: NewMux(),
}
w := httptest.NewRecorder()
r, err := http.NewRequest("GET", "/debug/pprof", nil)
if err != nil {
t.Fatal(err)
}
status, err := h.ServeHTTP(w, r)
if status != 0 {
t.Errorf("Expected status %d but got %d", 0, status)
}
if err != nil {
t.Errorf("Expected nil error, but got: %v", err)
}
if w.Body.String() == "content" {
t.Errorf("Expected pprof to handle request, but it didn't")
}
w = httptest.NewRecorder()
r, err = http.NewRequest("GET", "/foo", nil)
if err != nil {
t.Fatal(err)
}
status, err = h.ServeHTTP(w, r)
if status != http.StatusNotFound {
t.Errorf("Test two: Expected status %d but got %d", http.StatusNotFound, status)
}
if err != nil {
t.Errorf("Test two: Expected nil error, but got: %v", err)
}
if w.Body.String() != "content" {
t.Errorf("Expected pprof to pass the request thru, but it didn't; got: %s", w.Body.String())
}
}
func nextHandler(w http.ResponseWriter, r *http.Request) (int, error) {
fmt.Fprintf(w, "content")
return http.StatusNotFound, nil
}
package pprof
import (
"github.com/mholt/caddy"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
func init() {
caddy.RegisterPlugin(caddy.Plugin{
Name: "pprof",
ServerType: "http",
Action: setup,
})
}
// setup returns a new instance of a pprof handler. It accepts no arguments or options.
func setup(c *caddy.Controller) error {
found := false
for c.Next() {
if found {
return c.Err("pprof can only be specified once")
}
if len(c.RemainingArgs()) != 0 {
return c.ArgErr()
}
if c.NextBlock() {
return c.ArgErr()
}
found = true
}
httpserver.GetConfig(c.Key).AddMiddleware(func(next httpserver.Handler) httpserver.Handler {
return &Handler{Next: next, Mux: NewMux()}
})
return nil
}
package pprof
import (
"testing"
"github.com/mholt/caddy"
)
func TestSetup(t *testing.T) {
tests := []struct {
input string
shouldErr bool
}{
{`pprof`, false},
{`pprof {}`, true},
{`pprof /foo`, true},
{`pprof {
a b
}`, true},
{`pprof
pprof`, true},
}
for i, test := range tests {
err := setup(caddy.NewTestController(test.input))
if test.shouldErr && err == nil {
t.Errorf("Test %v: Expected error but found nil", i)
} else if !test.shouldErr && err != nil {
t.Errorf("Test %v: Expected no error but found error: %v", i, err)
}
}
}
package proxy
import (
"math/rand"
"sync/atomic"
)
// HostPool is a collection of UpstreamHosts.
type HostPool []*UpstreamHost
// Policy decides how a host will be selected from a pool.
type Policy interface {
Select(pool HostPool) *UpstreamHost
}
func init() {
RegisterPolicy("random", func() Policy { return &Random{} })
RegisterPolicy("least_conn", func() Policy { return &LeastConn{} })
RegisterPolicy("round_robin", func() Policy { return &RoundRobin{} })
}
// Random is a policy that selects up hosts from a pool at random.
type Random struct{}
// Select selects an up host at random from the specified pool.
func (r *Random) Select(pool HostPool) *UpstreamHost {
// instead of just generating a random index
// this is done to prevent selecting a unavailable host
var randHost *UpstreamHost
count := 0
for _, host := range pool {
if !host.Available() {
continue
}
count++
if count == 1 {
randHost = host
} else {
r := rand.Int() % count
if r == (count - 1) {
randHost = host
}
}
}
return randHost
}
// LeastConn is a policy that selects the host with the least connections.
type LeastConn struct{}
// Select selects the up host with the least number of connections in the
// pool. If more than one host has the same least number of connections,
// one of the hosts is chosen at random.
func (r *LeastConn) Select(pool HostPool) *UpstreamHost {
var bestHost *UpstreamHost
count := 0
leastConn := int64(1<<63 - 1)
for _, host := range pool {
if !host.Available() {
continue
}
hostConns := host.Conns
if hostConns < leastConn {
bestHost = host
leastConn = hostConns
count = 1
} else if hostConns == leastConn {
// randomly select host among hosts with least connections
count++
if count == 1 {
bestHost = host
} else {
r := rand.Int() % count
if r == (count - 1) {
bestHost = host
}
}
}
}
return bestHost
}
// RoundRobin is a policy that selects hosts based on round robin ordering.
type RoundRobin struct {
Robin uint32
}
// Select selects an up host from the pool using a round robin ordering scheme.
func (r *RoundRobin) Select(pool HostPool) *UpstreamHost {
poolLen := uint32(len(pool))
selection := atomic.AddUint32(&r.Robin, 1) % poolLen
host := pool[selection]
// if the currently selected host is not available, just ffwd to up host
for i := uint32(1); !host.Available() && i < poolLen; i++ {
host = pool[(selection+i)%poolLen]
}
if !host.Available() {
return nil
}
return host
}
package proxy
import (
"net/http"
"net/http/httptest"
"os"
"testing"
)
var workableServer *httptest.Server
func TestMain(m *testing.M) {
workableServer = httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
// do nothing
}))
r := m.Run()
workableServer.Close()
os.Exit(r)
}
type customPolicy struct{}
func (r *customPolicy) Select(pool HostPool) *UpstreamHost {
return pool[0]
}
func testPool() HostPool {
pool := []*UpstreamHost{
{
Name: workableServer.URL, // this should resolve (healthcheck test)
},
{
Name: "http://shouldnot.resolve", // this shouldn't
},
{
Name: "http://C",
},
}
return HostPool(pool)
}
func TestRoundRobinPolicy(t *testing.T) {
pool := testPool()
rrPolicy := &RoundRobin{}
h := rrPolicy.Select(pool)
// First selected host is 1, because counter starts at 0
// and increments before host is selected
if h != pool[1] {
t.Error("Expected first round robin host to be second host in the pool.")
}
h = rrPolicy.Select(pool)
if h != pool[2] {
t.Error("Expected second round robin host to be third host in the pool.")
}
h = rrPolicy.Select(pool)
if h != pool[0] {
t.Error("Expected third round robin host to be first host in the pool.")
}
// mark host as down
pool[1].Unhealthy = true
h = rrPolicy.Select(pool)
if h != pool[2] {
t.Error("Expected to skip down host.")
}
// mark host as full
pool[2].Conns = 1
pool[2].MaxConns = 1
h = rrPolicy.Select(pool)
if h != pool[0] {
t.Error("Expected to skip full host.")
}
}
func TestLeastConnPolicy(t *testing.T) {
pool := testPool()
lcPolicy := &LeastConn{}
pool[0].Conns = 10
pool[1].Conns = 10
h := lcPolicy.Select(pool)
if h != pool[2] {
t.Error("Expected least connection host to be third host.")
}
pool[2].Conns = 100
h = lcPolicy.Select(pool)
if h != pool[0] && h != pool[1] {
t.Error("Expected least connection host to be first or second host.")
}
}
func TestCustomPolicy(t *testing.T) {
pool := testPool()
customPolicy := &customPolicy{}
h := customPolicy.Select(pool)
if h != pool[0] {
t.Error("Expected custom policy host to be the first host.")
}
}
// Package proxy is middleware that proxies HTTP requests.
package proxy
import (
"errors"
"net"
"net/http"
"net/url"
"strings"
"sync/atomic"
"time"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
var errUnreachable = errors.New("unreachable backend")
// Proxy represents a middleware instance that can proxy requests.
type Proxy struct {
Next httpserver.Handler
Upstreams []Upstream
}
// Upstream manages a pool of proxy upstream hosts. Select should return a
// suitable upstream host, or nil if no such hosts are available.
type Upstream interface {
// The path this upstream host should be routed on
From() string
// Selects an upstream host to be routed to.
Select() *UpstreamHost
// Checks if subpath is not an ignored path
AllowedPath(string) bool
}
// UpstreamHostDownFunc can be used to customize how Down behaves.
type UpstreamHostDownFunc func(*UpstreamHost) bool
// UpstreamHost represents a single proxy upstream
type UpstreamHost struct {
Conns int64 // must be first field to be 64-bit aligned on 32-bit systems
Name string // hostname of this upstream host
ReverseProxy *ReverseProxy
Fails int32
FailTimeout time.Duration
Unhealthy bool
UpstreamHeaders http.Header
DownstreamHeaders http.Header
CheckDown UpstreamHostDownFunc
WithoutPathPrefix string
MaxConns int64
}
// Down checks whether the upstream host is down or not.
// Down will try to use uh.CheckDown first, and will fall
// back to some default criteria if necessary.
func (uh *UpstreamHost) Down() bool {
if uh.CheckDown == nil {
// Default settings
return uh.Unhealthy || uh.Fails > 0
}
return uh.CheckDown(uh)
}
// Full checks whether the upstream host has reached its maximum connections
func (uh *UpstreamHost) Full() bool {
return uh.MaxConns > 0 && uh.Conns >= uh.MaxConns
}
// Available checks whether the upstream host is available for proxying to
func (uh *UpstreamHost) Available() bool {
return !uh.Down() && !uh.Full()
}
// tryDuration is how long to try upstream hosts; failures result in
// immediate retries until this duration ends or we get a nil host.
var tryDuration = 60 * time.Second
// ServeHTTP satisfies the httpserver.Handler interface.
func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
for _, upstream := range p.Upstreams {
if !httpserver.Path(r.URL.Path).Matches(upstream.From()) ||
!upstream.AllowedPath(r.URL.Path) {
continue
}
var replacer httpserver.Replacer
start := time.Now()
outreq := createUpstreamRequest(r)
// Since Select() should give us "up" hosts, keep retrying
// hosts until timeout (or until we get a nil host).
for time.Now().Sub(start) < tryDuration {
host := upstream.Select()
if host == nil {
return http.StatusBadGateway, errUnreachable
}
if rr, ok := w.(*httpserver.ResponseRecorder); ok && rr.Replacer != nil {
rr.Replacer.Set("upstream", host.Name)
}
outreq.Host = host.Name
if host.UpstreamHeaders != nil {
if replacer == nil {
rHost := r.Host
replacer = httpserver.NewReplacer(r, nil, "")
outreq.Host = rHost
}
if v, ok := host.UpstreamHeaders["Host"]; ok {
outreq.Host = replacer.Replace(v[len(v)-1])
}
// Modify headers for request that will be sent to the upstream host
upHeaders := createHeadersByRules(host.UpstreamHeaders, r.Header, replacer)
for k, v := range upHeaders {
outreq.Header[k] = v
}
}
var downHeaderUpdateFn respUpdateFn
if host.DownstreamHeaders != nil {
if replacer == nil {
rHost := r.Host
replacer = httpserver.NewReplacer(r, nil, "")
outreq.Host = rHost
}
//Creates a function that is used to update headers the response received by the reverse proxy
downHeaderUpdateFn = createRespHeaderUpdateFn(host.DownstreamHeaders, replacer)
}
proxy := host.ReverseProxy
if baseURL, err := url.Parse(host.Name); err == nil {
r.Host = baseURL.Host
if proxy == nil {
proxy = NewSingleHostReverseProxy(baseURL, host.WithoutPathPrefix)
}
} else if proxy == nil {
return http.StatusInternalServerError, err
}
atomic.AddInt64(&host.Conns, 1)
backendErr := proxy.ServeHTTP(w, outreq, downHeaderUpdateFn)
atomic.AddInt64(&host.Conns, -1)
if backendErr == nil {
return 0, nil
}
timeout := host.FailTimeout
if timeout == 0 {
timeout = 10 * time.Second
}
atomic.AddInt32(&host.Fails, 1)
go func(host *UpstreamHost, timeout time.Duration) {
time.Sleep(timeout)
atomic.AddInt32(&host.Fails, -1)
}(host, timeout)
}
return http.StatusBadGateway, errUnreachable
}
return p.Next.ServeHTTP(w, r)
}
// createUpstremRequest shallow-copies r into a new request
// that can be sent upstream.
func createUpstreamRequest(r *http.Request) *http.Request {
outreq := new(http.Request)
*outreq = *r // includes shallow copies of maps, but okay
// Restore URL Path if it has been modified
if outreq.URL.RawPath != "" {
outreq.URL.Opaque = outreq.URL.RawPath
}
// Remove hop-by-hop headers to the backend. Especially
// important is "Connection" because we want a persistent
// connection, regardless of what the client sent to us. This
// is modifying the same underlying map from r (shallow
// copied above) so we only copy it if necessary.
for _, h := range hopHeaders {
if outreq.Header.Get(h) != "" {
outreq.Header = make(http.Header)
copyHeader(outreq.Header, r.Header)
outreq.Header.Del(h)
}
}
if clientIP, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
// If we aren't the first proxy, retain prior
// X-Forwarded-For information as a comma+space
// separated list and fold multiple headers into one.
if prior, ok := outreq.Header["X-Forwarded-For"]; ok {
clientIP = strings.Join(prior, ", ") + ", " + clientIP
}
outreq.Header.Set("X-Forwarded-For", clientIP)
}
return outreq
}
func createRespHeaderUpdateFn(rules http.Header, replacer httpserver.Replacer) respUpdateFn {
return func(resp *http.Response) {
newHeaders := createHeadersByRules(rules, resp.Header, replacer)
for h, v := range newHeaders {
resp.Header[h] = v
}
}
}
func createHeadersByRules(rules http.Header, base http.Header, repl httpserver.Replacer) http.Header {
newHeaders := make(http.Header)
for header, values := range rules {
if strings.HasPrefix(header, "+") {
header = strings.TrimLeft(header, "+")
add(newHeaders, header, base[header])
applyEach(values, repl.Replace)
add(newHeaders, header, values)
} else if strings.HasPrefix(header, "-") {
base.Del(strings.TrimLeft(header, "-"))
} else if _, ok := base[header]; ok {
applyEach(values, repl.Replace)
for _, v := range values {
newHeaders.Set(header, v)
}
} else {
applyEach(values, repl.Replace)
add(newHeaders, header, values)
add(newHeaders, header, base[header])
}
}
return newHeaders
}
func applyEach(values []string, mapFn func(string) string) {
for i, v := range values {
values[i] = mapFn(v)
}
}
func add(base http.Header, header string, values []string) {
for _, v := range values {
base.Add(header, v)
}
}
package proxy
import (
"bufio"
"bytes"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"runtime"
"strings"
"testing"
"time"
"github.com/mholt/caddy/caddyhttp/httpserver"
"golang.org/x/net/websocket"
)
func init() {
tryDuration = 50 * time.Millisecond // prevent tests from hanging
}
func TestReverseProxy(t *testing.T) {
log.SetOutput(ioutil.Discard)
defer log.SetOutput(os.Stderr)
var requestReceived bool
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestReceived = true
w.Write([]byte("Hello, client"))
}))
defer backend.Close()
// set up proxy
p := &Proxy{
Upstreams: []Upstream{newFakeUpstream(backend.URL, false)},
}
// create request and response recorder
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
w := httptest.NewRecorder()
p.ServeHTTP(w, r)
if !requestReceived {
t.Error("Expected backend to receive request, but it didn't")
}
// Make sure {upstream} placeholder is set
rr := httpserver.NewResponseRecorder(httptest.NewRecorder())
rr.Replacer = httpserver.NewReplacer(r, rr, "-")
p.ServeHTTP(rr, r)
if got, want := rr.Replacer.Replace("{upstream}"), backend.URL; got != want {
t.Errorf("Expected custom placeholder {upstream} to be set (%s), but it wasn't; got: %s", want, got)
}
}
func TestReverseProxyInsecureSkipVerify(t *testing.T) {
log.SetOutput(ioutil.Discard)
defer log.SetOutput(os.Stderr)
var requestReceived bool
backend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestReceived = true
w.Write([]byte("Hello, client"))
}))
defer backend.Close()
// set up proxy
p := &Proxy{
Upstreams: []Upstream{newFakeUpstream(backend.URL, true)},
}
// create request and response recorder
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
w := httptest.NewRecorder()
p.ServeHTTP(w, r)
if !requestReceived {
t.Error("Even with insecure HTTPS, expected backend to receive request, but it didn't")
}
}
func TestWebSocketReverseProxyServeHTTPHandler(t *testing.T) {
// No-op websocket backend simply allows the WS connection to be
// accepted then it will be immediately closed. Perfect for testing.
wsNop := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) {}))
defer wsNop.Close()
// Get proxy to use for the test
p := newWebSocketTestProxy(wsNop.URL)
// Create client request
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
r.Header = http.Header{
"Connection": {"Upgrade"},
"Upgrade": {"websocket"},
"Origin": {wsNop.URL},
"Sec-WebSocket-Key": {"x3JJHMbDL1EzLkh9GBhXDw=="},
"Sec-WebSocket-Version": {"13"},
}
// Capture the request
w := &recorderHijacker{httptest.NewRecorder(), new(fakeConn)}
// Booya! Do the test.
p.ServeHTTP(w, r)
// Make sure the backend accepted the WS connection.
// Mostly interested in the Upgrade and Connection response headers
// and the 101 status code.
expected := []byte("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: HSmrc0sMlYUkAGmm5OPpG2HaGWk=\r\n\r\n")
actual := w.fakeConn.writeBuf.Bytes()
if !bytes.Equal(actual, expected) {
t.Errorf("Expected backend to accept response:\n'%s'\nActually got:\n'%s'", expected, actual)
}
}
func TestWebSocketReverseProxyFromWSClient(t *testing.T) {
// Echo server allows us to test that socket bytes are properly
// being proxied.
wsEcho := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) {
io.Copy(ws, ws)
}))
defer wsEcho.Close()
// Get proxy to use for the test
p := newWebSocketTestProxy(wsEcho.URL)
// This is a full end-end test, so the proxy handler
// has to be part of a server listening on a port. Our
// WS client will connect to this test server, not
// the echo client directly.
echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
p.ServeHTTP(w, r)
}))
defer echoProxy.Close()
// Set up WebSocket client
url := strings.Replace(echoProxy.URL, "http://", "ws://", 1)
ws, err := websocket.Dial(url, "", echoProxy.URL)
if err != nil {
t.Fatal(err)
}
defer ws.Close()
// Send test message
trialMsg := "Is it working?"
websocket.Message.Send(ws, trialMsg)
// It should be echoed back to us
var actualMsg string
websocket.Message.Receive(ws, &actualMsg)
if actualMsg != trialMsg {
t.Errorf("Expected '%s' but got '%s' instead", trialMsg, actualMsg)
}
}
func TestUnixSocketProxy(t *testing.T) {
if runtime.GOOS == "windows" {
return
}
trialMsg := "Is it working?"
var proxySuccess bool
// This is our fake "application" we want to proxy to
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Request was proxied when this is called
proxySuccess = true
fmt.Fprint(w, trialMsg)
}))
// Get absolute path for unix: socket
socketPath, err := filepath.Abs("./test_socket")
if err != nil {
t.Fatalf("Unable to get absolute path: %v", err)
}
// Change httptest.Server listener to listen to unix: socket
ln, err := net.Listen("unix", socketPath)
if err != nil {
t.Fatalf("Unable to listen: %v", err)
}
ts.Listener = ln
ts.Start()
defer ts.Close()
url := strings.Replace(ts.URL, "http://", "unix:", 1)
p := newWebSocketTestProxy(url)
echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
p.ServeHTTP(w, r)
}))
defer echoProxy.Close()
res, err := http.Get(echoProxy.URL)
if err != nil {
t.Fatalf("Unable to GET: %v", err)
}
greeting, err := ioutil.ReadAll(res.Body)
res.Body.Close()
if err != nil {
t.Fatalf("Unable to GET: %v", err)
}
actualMsg := fmt.Sprintf("%s", greeting)
if !proxySuccess {
t.Errorf("Expected request to be proxied, but it wasn't")
}
if actualMsg != trialMsg {
t.Errorf("Expected '%s' but got '%s' instead", trialMsg, actualMsg)
}
}
func GetHTTPProxy(messageFormat string, prefix string) (*Proxy, *httptest.Server) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, messageFormat, r.URL.String())
}))
return newPrefixedWebSocketTestProxy(ts.URL, prefix), ts
}
func GetSocketProxy(messageFormat string, prefix string) (*Proxy, *httptest.Server, error) {
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, messageFormat, r.URL.String())
}))
socketPath, err := filepath.Abs("./test_socket")
if err != nil {
return nil, nil, fmt.Errorf("Unable to get absolute path: %v", err)
}
ln, err := net.Listen("unix", socketPath)
if err != nil {
return nil, nil, fmt.Errorf("Unable to listen: %v", err)
}
ts.Listener = ln
ts.Start()
tsURL := strings.Replace(ts.URL, "http://", "unix:", 1)
return newPrefixedWebSocketTestProxy(tsURL, prefix), ts, nil
}
func GetTestServerMessage(p *Proxy, ts *httptest.Server, path string) (string, error) {
echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
p.ServeHTTP(w, r)
}))
// *httptest.Server is passed so it can be `defer`red properly
defer ts.Close()
defer echoProxy.Close()
res, err := http.Get(echoProxy.URL + path)
if err != nil {
return "", fmt.Errorf("Unable to GET: %v", err)
}
greeting, err := ioutil.ReadAll(res.Body)
res.Body.Close()
if err != nil {
return "", fmt.Errorf("Unable to read body: %v", err)
}
return fmt.Sprintf("%s", greeting), nil
}
func TestUnixSocketProxyPaths(t *testing.T) {
greeting := "Hello route %s"
tests := []struct {
url string
prefix string
expected string
}{
{"", "", fmt.Sprintf(greeting, "/")},
{"/hello", "", fmt.Sprintf(greeting, "/hello")},
{"/foo/bar", "", fmt.Sprintf(greeting, "/foo/bar")},
{"/foo?bar", "", fmt.Sprintf(greeting, "/foo?bar")},
{"/greet?name=john", "", fmt.Sprintf(greeting, "/greet?name=john")},
{"/world?wonderful&colorful", "", fmt.Sprintf(greeting, "/world?wonderful&colorful")},
{"/proxy/hello", "/proxy", fmt.Sprintf(greeting, "/hello")},
{"/proxy/foo/bar", "/proxy", fmt.Sprintf(greeting, "/foo/bar")},
{"/proxy/?foo=bar", "/proxy", fmt.Sprintf(greeting, "/?foo=bar")},
{"/queues/%2F/fetchtasks", "", fmt.Sprintf(greeting, "/queues/%2F/fetchtasks")},
{"/queues/%2F/fetchtasks?foo=bar", "", fmt.Sprintf(greeting, "/queues/%2F/fetchtasks?foo=bar")},
}
for _, test := range tests {
p, ts := GetHTTPProxy(greeting, test.prefix)
actualMsg, err := GetTestServerMessage(p, ts, test.url)
if err != nil {
t.Fatalf("Getting server message failed - %v", err)
}
if actualMsg != test.expected {
t.Errorf("Expected '%s' but got '%s' instead", test.expected, actualMsg)
}
}
if runtime.GOOS == "windows" {
return
}
for _, test := range tests {
p, ts, err := GetSocketProxy(greeting, test.prefix)
if err != nil {
t.Fatalf("Getting socket proxy failed - %v", err)
}
actualMsg, err := GetTestServerMessage(p, ts, test.url)
if err != nil {
t.Fatalf("Getting server message failed - %v", err)
}
if actualMsg != test.expected {
t.Errorf("Expected '%s' but got '%s' instead", test.expected, actualMsg)
}
}
}
func TestUpstreamHeadersUpdate(t *testing.T) {
log.SetOutput(ioutil.Discard)
defer log.SetOutput(os.Stderr)
var actualHeaders http.Header
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Hello, client"))
actualHeaders = r.Header
}))
defer backend.Close()
upstream := newFakeUpstream(backend.URL, false)
upstream.host.UpstreamHeaders = http.Header{
"Connection": {"{>Connection}"},
"Upgrade": {"{>Upgrade}"},
"+Merge-Me": {"Merge-Value"},
"+Add-Me": {"Add-Value"},
"-Remove-Me": {""},
"Replace-Me": {"{hostname}"},
}
// set up proxy
p := &Proxy{
Upstreams: []Upstream{upstream},
}
// create request and response recorder
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
w := httptest.NewRecorder()
//add initial headers
r.Header.Add("Merge-Me", "Initial")
r.Header.Add("Remove-Me", "Remove-Value")
r.Header.Add("Replace-Me", "Replace-Value")
p.ServeHTTP(w, r)
replacer := httpserver.NewReplacer(r, nil, "")
headerKey := "Merge-Me"
values, ok := actualHeaders[headerKey]
if !ok {
t.Errorf("Request sent to upstream backend does not contain expected %v header. Expected header to be added", headerKey)
} else if len(values) < 2 && (values[0] != "Initial" || values[1] != replacer.Replace("{hostname}")) {
t.Errorf("Values for proxy header `+Merge-Me` should be merged. Got %v", values)
}
headerKey = "Add-Me"
if _, ok := actualHeaders[headerKey]; !ok {
t.Errorf("Request sent to upstream backend does not contain expected %v header", headerKey)
}
headerKey = "Remove-Me"
if _, ok := actualHeaders[headerKey]; ok {
t.Errorf("Request sent to upstream backend should not contain %v header", headerKey)
}
headerKey = "Replace-Me"
headerValue := replacer.Replace("{hostname}")
value, ok := actualHeaders[headerKey]
if !ok {
t.Errorf("Request sent to upstream backend should not remove %v header", headerKey)
} else if len(value) > 0 && headerValue != value[0] {
t.Errorf("Request sent to upstream backend should replace value of %v header with %v. Instead value was %v", headerKey, headerValue, value)
}
}
func TestDownstreamHeadersUpdate(t *testing.T) {
log.SetOutput(ioutil.Discard)
defer log.SetOutput(os.Stderr)
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Merge-Me", "Initial")
w.Header().Add("Remove-Me", "Remove-Value")
w.Header().Add("Replace-Me", "Replace-Value")
w.Write([]byte("Hello, client"))
}))
defer backend.Close()
upstream := newFakeUpstream(backend.URL, false)
upstream.host.DownstreamHeaders = http.Header{
"+Merge-Me": {"Merge-Value"},
"+Add-Me": {"Add-Value"},
"-Remove-Me": {""},
"Replace-Me": {"{hostname}"},
}
// set up proxy
p := &Proxy{
Upstreams: []Upstream{upstream},
}
// create request and response recorder
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
w := httptest.NewRecorder()
p.ServeHTTP(w, r)
replacer := httpserver.NewReplacer(r, nil, "")
actualHeaders := w.Header()
headerKey := "Merge-Me"
values, ok := actualHeaders[headerKey]
if !ok {
t.Errorf("Downstream response does not contain expected %v header. Expected header should be added", headerKey)
} else if len(values) < 2 && (values[0] != "Initial" || values[1] != replacer.Replace("{hostname}")) {
t.Errorf("Values for header `+Merge-Me` should be merged. Got %v", values)
}
headerKey = "Add-Me"
if _, ok := actualHeaders[headerKey]; !ok {
t.Errorf("Downstream response does not contain expected %v header", headerKey)
}
headerKey = "Remove-Me"
if _, ok := actualHeaders[headerKey]; ok {
t.Errorf("Downstream response should not contain %v header received from upstream", headerKey)
}
headerKey = "Replace-Me"
headerValue := replacer.Replace("{hostname}")
value, ok := actualHeaders[headerKey]
if !ok {
t.Errorf("Downstream response should contain %v header and not remove it", headerKey)
} else if len(value) > 0 && headerValue != value[0] {
t.Errorf("Downstream response should have header %v with value %v. Instead value was %v", headerKey, headerValue, value)
}
}
func newFakeUpstream(name string, insecure bool) *fakeUpstream {
uri, _ := url.Parse(name)
u := &fakeUpstream{
name: name,
host: &UpstreamHost{
Name: name,
ReverseProxy: NewSingleHostReverseProxy(uri, ""),
},
}
if insecure {
u.host.ReverseProxy.Transport = InsecureTransport
}
return u
}
type fakeUpstream struct {
name string
host *UpstreamHost
}
func (u *fakeUpstream) From() string {
return "/"
}
func (u *fakeUpstream) Select() *UpstreamHost {
return u.host
}
func (u *fakeUpstream) AllowedPath(requestPath string) bool {
return true
}
// newWebSocketTestProxy returns a test proxy that will
// redirect to the specified backendAddr. The function
// also sets up the rules/environment for testing WebSocket
// proxy.
func newWebSocketTestProxy(backendAddr string) *Proxy {
return &Proxy{
Upstreams: []Upstream{&fakeWsUpstream{name: backendAddr, without: ""}},
}
}
func newPrefixedWebSocketTestProxy(backendAddr string, prefix string) *Proxy {
return &Proxy{
Upstreams: []Upstream{&fakeWsUpstream{name: backendAddr, without: prefix}},
}
}
type fakeWsUpstream struct {
name string
without string
}
func (u *fakeWsUpstream) From() string {
return "/"
}
func (u *fakeWsUpstream) Select() *UpstreamHost {
uri, _ := url.Parse(u.name)
return &UpstreamHost{
Name: u.name,
ReverseProxy: NewSingleHostReverseProxy(uri, u.without),
UpstreamHeaders: http.Header{
"Connection": {"{>Connection}"},
"Upgrade": {"{>Upgrade}"}},
}
}
func (u *fakeWsUpstream) AllowedPath(requestPath string) bool {
return true
}
// recorderHijacker is a ResponseRecorder that can
// be hijacked.
type recorderHijacker struct {
*httptest.ResponseRecorder
fakeConn *fakeConn
}
func (rh *recorderHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return rh.fakeConn, nil, nil
}
type fakeConn struct {
readBuf bytes.Buffer
writeBuf bytes.Buffer
}
func (c *fakeConn) LocalAddr() net.Addr { return nil }
func (c *fakeConn) RemoteAddr() net.Addr { return nil }
func (c *fakeConn) SetDeadline(t time.Time) error { return nil }
func (c *fakeConn) SetReadDeadline(t time.Time) error { return nil }
func (c *fakeConn) SetWriteDeadline(t time.Time) error { return nil }
func (c *fakeConn) Close() error { return nil }
func (c *fakeConn) Read(b []byte) (int, error) { return c.readBuf.Read(b) }
func (c *fakeConn) Write(b []byte) (int, error) { return c.writeBuf.Write(b) }
// This file is adapted from code in the net/http/httputil
// package of the Go standard library, which is by the
// Go Authors, and bears this copyright and license info:
//
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//
// This file has been modified from the standard lib to
// meet the needs of the application.
package proxy
import (
"crypto/tls"
"io"
"net"
"net/http"
"net/url"
"strings"
"sync"
"time"
)
// onExitFlushLoop is a callback set by tests to detect the state of the
// flushLoop() goroutine.
var onExitFlushLoop func()
// ReverseProxy is an HTTP Handler that takes an incoming request and
// sends it to another server, proxying the response back to the
// client.
type ReverseProxy struct {
// Director must be a function which modifies
// the request into a new request to be sent
// using Transport. Its response is then copied
// back to the original client unmodified.
Director func(*http.Request)
// The transport used to perform proxy requests.
// If nil, http.DefaultTransport is used.
Transport http.RoundTripper
// FlushInterval specifies the flush interval
// to flush to the client while copying the
// response body.
// If zero, no periodic flushing is done.
FlushInterval time.Duration
}
func singleJoiningSlash(a, b string) string {
aslash := strings.HasSuffix(a, "/")
bslash := strings.HasPrefix(b, "/")
switch {
case aslash && bslash:
return a + b[1:]
case !aslash && !bslash:
return a + "/" + b
}
return a + b
}
// Though the relevant directive prefix is just "unix:", url.Parse
// will - assuming the regular URL scheme - add additional slashes
// as if "unix" was a request protocol.
// What we need is just the path, so if "unix:/var/run/www.socket"
// was the proxy directive, the parsed hostName would be
// "unix:///var/run/www.socket", hence the ambiguous trimming.
func socketDial(hostName string) func(network, addr string) (conn net.Conn, err error) {
return func(network, addr string) (conn net.Conn, err error) {
return net.Dial("unix", hostName[len("unix://"):])
}
}
// NewSingleHostReverseProxy returns a new ReverseProxy that rewrites
// URLs to the scheme, host, and base path provided in target. If the
// target's path is "/base" and the incoming request was for "/dir",
// the target request will be for /base/dir.
// Without logic: target's path is "/", incoming is "/api/messages",
// without is "/api", then the target request will be for /messages.
func NewSingleHostReverseProxy(target *url.URL, without string) *ReverseProxy {
targetQuery := target.RawQuery
director := func(req *http.Request) {
if target.Scheme == "unix" {
// to make Dial work with unix URL,
// scheme and host have to be faked
req.URL.Scheme = "http"
req.URL.Host = "socket"
} else {
req.URL.Scheme = target.Scheme
req.URL.Host = target.Host
}
req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
if targetQuery == "" || req.URL.RawQuery == "" {
req.URL.RawQuery = targetQuery + req.URL.RawQuery
} else {
req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
}
// Trims the path of the socket from the URL path.
// This is done because req.URL passed to your proxied service
// will have the full path of the socket file prefixed to it.
// Calling /test on a server that proxies requests to
// unix:/var/run/www.socket will thus set the requested path
// to /var/run/www.socket/test, rendering paths useless.
if target.Scheme == "unix" {
// See comment on socketDial for the trim
socketPrefix := target.String()[len("unix://"):]
req.URL.Path = strings.TrimPrefix(req.URL.Path, socketPrefix)
}
// We are then safe to remove the `without` prefix.
if without != "" {
req.URL.Path = strings.TrimPrefix(req.URL.Path, without)
}
}
rp := &ReverseProxy{Director: director, FlushInterval: 250 * time.Millisecond} // flushing good for streaming & server-sent events
if target.Scheme == "unix" {
rp.Transport = &http.Transport{
Dial: socketDial(target.String()),
}
}
return rp
}
func copyHeader(dst, src http.Header) {
for k, vv := range src {
for _, v := range vv {
dst.Add(k, v)
}
}
}
// Hop-by-hop headers. These are removed when sent to the backend.
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
var hopHeaders = []string{
"Connection",
"Keep-Alive",
"Proxy-Authenticate",
"Proxy-Authorization",
"Te", // canonicalized version of "TE"
"Trailers",
"Transfer-Encoding",
"Upgrade",
}
// InsecureTransport is used to facilitate HTTPS proxying
// when it is OK for upstream to be using a bad certificate,
// since this transport skips verification.
var InsecureTransport http.RoundTripper = &http.Transport{
Proxy: http.ProxyFromEnvironment,
Dial: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).Dial,
TLSHandshakeTimeout: 10 * time.Second,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
type respUpdateFn func(resp *http.Response)
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, respUpdateFn respUpdateFn) error {
transport := p.Transport
if transport == nil {
transport = http.DefaultTransport
}
p.Director(outreq)
outreq.Proto = "HTTP/1.1"
outreq.ProtoMajor = 1
outreq.ProtoMinor = 1
outreq.Close = false
res, err := transport.RoundTrip(outreq)
if err != nil {
return err
} else if respUpdateFn != nil {
respUpdateFn(res)
}
if res.StatusCode == http.StatusSwitchingProtocols && strings.ToLower(res.Header.Get("Upgrade")) == "websocket" {
res.Body.Close()
hj, ok := rw.(http.Hijacker)
if !ok {
return nil
}
conn, _, err := hj.Hijack()
if err != nil {
return err
}
defer conn.Close()
backendConn, err := net.Dial("tcp", outreq.URL.Host)
if err != nil {
return err
}
defer backendConn.Close()
outreq.Write(backendConn)
go func() {
io.Copy(backendConn, conn) // write tcp stream to backend.
}()
io.Copy(conn, backendConn) // read tcp stream from backend.
} else {
defer res.Body.Close()
for _, h := range hopHeaders {
res.Header.Del(h)
}
copyHeader(rw.Header(), res.Header)
rw.WriteHeader(res.StatusCode)
p.copyResponse(rw, res.Body)
}
return nil
}
func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
if p.FlushInterval != 0 {
if wf, ok := dst.(writeFlusher); ok {
mlw := &maxLatencyWriter{
dst: wf,
latency: p.FlushInterval,
done: make(chan bool),
}
go mlw.flushLoop()
defer mlw.stop()
dst = mlw
}
}
io.Copy(dst, src)
}
type writeFlusher interface {
io.Writer
http.Flusher
}
type maxLatencyWriter struct {
dst writeFlusher
latency time.Duration
lk sync.Mutex // protects Write + Flush
done chan bool
}
func (m *maxLatencyWriter) Write(p []byte) (int, error) {
m.lk.Lock()
defer m.lk.Unlock()
return m.dst.Write(p)
}
func (m *maxLatencyWriter) flushLoop() {
t := time.NewTicker(m.latency)
defer t.Stop()
for {
select {
case <-m.done:
if onExitFlushLoop != nil {
onExitFlushLoop()
}
return
case <-t.C:
m.lk.Lock()
m.dst.Flush()
m.lk.Unlock()
}
}
}
func (m *maxLatencyWriter) stop() { m.done <- true }
package proxy
import (
"github.com/mholt/caddy"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
func init() {
caddy.RegisterPlugin(caddy.Plugin{
Name: "proxy",
ServerType: "http",
Action: setup,
})
}
// setup configures a new Proxy middleware instance.
func setup(c *caddy.Controller) error {
upstreams, err := NewStaticUpstreams(c.Dispenser)
if err != nil {
return err
}
httpserver.GetConfig(c.Key).AddMiddleware(func(next httpserver.Handler) httpserver.Handler {
return Proxy{Next: next, Upstreams: upstreams}
})
return nil
}
package proxy
import (
"reflect"
"testing"
"github.com/mholt/caddy"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
func TestSetup(t *testing.T) {
for i, test := range []struct {
input string
shouldErr bool
expectedHosts map[string]struct{}
}{
// test #0 test usual to destination still works normally
{
"proxy / localhost:80",
false,
map[string]struct{}{
"http://localhost:80": {},
},
},
// test #1 test usual to destination with port range
{
"proxy / localhost:8080-8082",
false,
map[string]struct{}{
"http://localhost:8080": {},
"http://localhost:8081": {},
"http://localhost:8082": {},
},
},
// test #2 test upstream directive
{
"proxy / {\n upstream localhost:8080\n}",
false,
map[string]struct{}{
"http://localhost:8080": {},
},
},
// test #3 test upstream directive with port range
{
"proxy / {\n upstream localhost:8080-8081\n}",
false,
map[string]struct{}{
"http://localhost:8080": {},
"http://localhost:8081": {},
},
},
// test #4 test to destination with upstream directive
{
"proxy / localhost:8080 {\n upstream localhost:8081-8082\n}",
false,
map[string]struct{}{
"http://localhost:8080": {},
"http://localhost:8081": {},
"http://localhost:8082": {},
},
},
// test #5 test with unix sockets
{
"proxy / localhost:8080 {\n upstream unix:/var/foo\n}",
false,
map[string]struct{}{
"http://localhost:8080": {},
"unix:/var/foo": {},
},
},
// test #6 test fail on malformed port range
{
"proxy / localhost:8090-8080",
true,
nil,
},
// test #7 test fail on malformed port range 2
{
"proxy / {\n upstream localhost:80-A\n}",
true,
nil,
},
// test #8 test upstreams without ports work correctly
{
"proxy / http://localhost {\n upstream testendpoint\n}",
false,
map[string]struct{}{
"http://localhost": {},
"http://testendpoint": {},
},
},
// test #9 test several upstream directives
{
"proxy / localhost:8080 {\n upstream localhost:8081-8082\n upstream localhost:8083-8085\n}",
false,
map[string]struct{}{
"http://localhost:8080": {},
"http://localhost:8081": {},
"http://localhost:8082": {},
"http://localhost:8083": {},
"http://localhost:8084": {},
"http://localhost:8085": {},
},
},
} {
err := setup(caddy.NewTestController(test.input))
if err != nil && !test.shouldErr {
t.Errorf("Test case #%d received an error of %v", i, err)
} else if test.shouldErr {
continue
}
mids := httpserver.GetConfig("").Middleware()
mid := mids[len(mids)-1]
upstreams := mid(nil).(Proxy).Upstreams
for _, upstream := range upstreams {
val := reflect.ValueOf(upstream).Elem()
hosts := val.FieldByName("Hosts").Interface().(HostPool)
if len(hosts) != len(test.expectedHosts) {
t.Errorf("Test case #%d expected %d hosts but received %d", i, len(test.expectedHosts), len(hosts))
} else {
for _, host := range hosts {
if _, found := test.expectedHosts[host.Name]; !found {
t.Errorf("Test case #%d has an unexpected host %s", i, host.Name)
}
}
}
}
}
}
package proxy
import (
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"path"
"strconv"
"strings"
"time"
"github.com/mholt/caddy/caddyfile"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
var (
supportedPolicies = make(map[string]func() Policy)
)
type staticUpstream struct {
from string
upstreamHeaders http.Header
downstreamHeaders http.Header
Hosts HostPool
Policy Policy
insecureSkipVerify bool
FailTimeout time.Duration
MaxFails int32
MaxConns int64
HealthCheck struct {
Path string
Interval time.Duration
}
WithoutPathPrefix string
IgnoredSubPaths []string
}
// NewStaticUpstreams parses the configuration input and sets up
// static upstreams for the proxy middleware.
func NewStaticUpstreams(c caddyfile.Dispenser) ([]Upstream, error) {
var upstreams []Upstream
for c.Next() {
upstream := &staticUpstream{
from: "",
upstreamHeaders: make(http.Header),
downstreamHeaders: make(http.Header),
Hosts: nil,
Policy: &Random{},
FailTimeout: 10 * time.Second,
MaxFails: 1,
MaxConns: 0,
}
if !c.Args(&upstream.from) {
return upstreams, c.ArgErr()
}
var to []string
for _, t := range c.RemainingArgs() {
parsed, err := parseUpstream(t)
if err != nil {
return upstreams, err
}
to = append(to, parsed...)
}
for c.NextBlock() {
switch c.Val() {
case "upstream":
if !c.NextArg() {
return upstreams, c.ArgErr()
}
parsed, err := parseUpstream(c.Val())
if err != nil {
return upstreams, err
}
to = append(to, parsed...)
default:
if err := parseBlock(&c, upstream); err != nil {
return upstreams, err
}
}
}
if len(to) == 0 {
return upstreams, c.ArgErr()
}
upstream.Hosts = make([]*UpstreamHost, len(to))
for i, host := range to {
uh, err := upstream.NewHost(host)
if err != nil {
return upstreams, err
}
upstream.Hosts[i] = uh
}
if upstream.HealthCheck.Path != "" {
go upstream.HealthCheckWorker(nil)
}
upstreams = append(upstreams, upstream)
}
return upstreams, nil
}
// RegisterPolicy adds a custom policy to the proxy.
func RegisterPolicy(name string, policy func() Policy) {
supportedPolicies[name] = policy
}
func (u *staticUpstream) From() string {
return u.from
}
func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) {
if !strings.HasPrefix(host, "http") &&
!strings.HasPrefix(host, "unix:") {
host = "http://" + host
}
uh := &UpstreamHost{
Name: host,
Conns: 0,
Fails: 0,
FailTimeout: u.FailTimeout,
Unhealthy: false,
UpstreamHeaders: u.upstreamHeaders,
DownstreamHeaders: u.downstreamHeaders,
CheckDown: func(u *staticUpstream) UpstreamHostDownFunc {
return func(uh *UpstreamHost) bool {
if uh.Unhealthy {
return true
}
if uh.Fails >= u.MaxFails &&
u.MaxFails != 0 {
return true
}
return false
}
}(u),
WithoutPathPrefix: u.WithoutPathPrefix,
MaxConns: u.MaxConns,
}
baseURL, err := url.Parse(uh.Name)
if err != nil {
return nil, err
}
uh.ReverseProxy = NewSingleHostReverseProxy(baseURL, uh.WithoutPathPrefix)
if u.insecureSkipVerify {
uh.ReverseProxy.Transport = InsecureTransport
}
return uh, nil
}
func parseUpstream(u string) ([]string, error) {
if !strings.HasPrefix(u, "unix:") {
colonIdx := strings.LastIndex(u, ":")
protoIdx := strings.Index(u, "://")
if colonIdx != -1 && colonIdx != protoIdx {
us := u[:colonIdx]
ports := u[len(us)+1:]
if separators := strings.Count(ports, "-"); separators > 1 {
return nil, fmt.Errorf("port range [%s] is invalid", ports)
} else if separators == 1 {
portsStr := strings.Split(ports, "-")
pIni, err := strconv.Atoi(portsStr[0])
if err != nil {
return nil, err
}
pEnd, err := strconv.Atoi(portsStr[1])
if err != nil {
return nil, err
}
if pEnd <= pIni {
return nil, fmt.Errorf("port range [%s] is invalid", ports)
}
hosts := []string{}
for p := pIni; p <= pEnd; p++ {
hosts = append(hosts, fmt.Sprintf("%s:%d", us, p))
}
return hosts, nil
}
}
}
return []string{u}, nil
}
func parseBlock(c *caddyfile.Dispenser, u *staticUpstream) error {
switch c.Val() {
case "policy":
if !c.NextArg() {
return c.ArgErr()
}
policyCreateFunc, ok := supportedPolicies[c.Val()]
if !ok {
return c.ArgErr()
}
u.Policy = policyCreateFunc()
case "fail_timeout":
if !c.NextArg() {
return c.ArgErr()
}
dur, err := time.ParseDuration(c.Val())
if err != nil {
return err
}
u.FailTimeout = dur
case "max_fails":
if !c.NextArg() {
return c.ArgErr()
}
n, err := strconv.Atoi(c.Val())
if err != nil {
return err
}
u.MaxFails = int32(n)
case "max_conns":
if !c.NextArg() {
return c.ArgErr()
}
n, err := strconv.ParseInt(c.Val(), 10, 64)
if err != nil {
return err
}
u.MaxConns = n
case "health_check":
if !c.NextArg() {
return c.ArgErr()
}
u.HealthCheck.Path = c.Val()
u.HealthCheck.Interval = 30 * time.Second
if c.NextArg() {
dur, err := time.ParseDuration(c.Val())
if err != nil {
return err
}
u.HealthCheck.Interval = dur
}
case "header_upstream":
fallthrough
case "proxy_header":
var header, value string
if !c.Args(&header, &value) {
return c.ArgErr()
}
u.upstreamHeaders.Add(header, value)
case "header_downstream":
var header, value string
if !c.Args(&header, &value) {
return c.ArgErr()
}
u.downstreamHeaders.Add(header, value)
case "websocket":
u.upstreamHeaders.Add("Connection", "{>Connection}")
u.upstreamHeaders.Add("Upgrade", "{>Upgrade}")
case "without":
if !c.NextArg() {
return c.ArgErr()
}
u.WithoutPathPrefix = c.Val()
case "except":
ignoredPaths := c.RemainingArgs()
if len(ignoredPaths) == 0 {
return c.ArgErr()
}
u.IgnoredSubPaths = ignoredPaths
case "insecure_skip_verify":
u.insecureSkipVerify = true
default:
return c.Errf("unknown property '%s'", c.Val())
}
return nil
}
func (u *staticUpstream) healthCheck() {
for _, host := range u.Hosts {
hostURL := host.Name + u.HealthCheck.Path
if r, err := http.Get(hostURL); err == nil {
io.Copy(ioutil.Discard, r.Body)
r.Body.Close()
host.Unhealthy = r.StatusCode < 200 || r.StatusCode >= 400
} else {
host.Unhealthy = true
}
}
}
func (u *staticUpstream) HealthCheckWorker(stop chan struct{}) {
ticker := time.NewTicker(u.HealthCheck.Interval)
u.healthCheck()
for {
select {
case <-ticker.C:
u.healthCheck()
case <-stop:
// TODO: the library should provide a stop channel and global
// waitgroup to allow goroutines started by plugins a chance
// to clean themselves up.
}
}
}
func (u *staticUpstream) Select() *UpstreamHost {
pool := u.Hosts
if len(pool) == 1 {
if !pool[0].Available() {
return nil
}
return pool[0]
}
allUnavailable := true
for _, host := range pool {
if host.Available() {
allUnavailable = false
break
}
}
if allUnavailable {
return nil
}
if u.Policy == nil {
return (&Random{}).Select(pool)
}
return u.Policy.Select(pool)
}
func (u *staticUpstream) AllowedPath(requestPath string) bool {
for _, ignoredSubPath := range u.IgnoredSubPaths {
if httpserver.Path(path.Clean(requestPath)).Matches(path.Join(u.From(), ignoredSubPath)) {
return false
}
}
return true
}
package proxy
import (
"testing"
"time"
)
func TestNewHost(t *testing.T) {
upstream := &staticUpstream{
FailTimeout: 10 * time.Second,
MaxConns: 1,
MaxFails: 1,
}
uh, err := upstream.NewHost("example.com")
if err != nil {
t.Error("Expected no error")
}
if uh.Name != "http://example.com" {
t.Error("Expected default schema to be added to Name.")
}
if uh.FailTimeout != upstream.FailTimeout {
t.Error("Expected default FailTimeout to be set.")
}
if uh.MaxConns != upstream.MaxConns {
t.Error("Expected default MaxConns to be set.")
}
if uh.CheckDown == nil {
t.Error("Expected default CheckDown to be set.")
}
if uh.CheckDown(uh) {
t.Error("Expected new host not to be down.")
}
// mark Unhealthy
uh.Unhealthy = true
if !uh.CheckDown(uh) {
t.Error("Expected unhealthy host to be down.")
}
// mark with Fails
uh.Unhealthy = false
uh.Fails = 1
if !uh.CheckDown(uh) {
t.Error("Expected failed host to be down.")
}
}
func TestHealthCheck(t *testing.T) {
upstream := &staticUpstream{
from: "",
Hosts: testPool(),
Policy: &Random{},
FailTimeout: 10 * time.Second,
MaxFails: 1,
}
upstream.healthCheck()
if upstream.Hosts[0].Down() {
t.Error("Expected first host in testpool to not fail healthcheck.")
}
if !upstream.Hosts[1].Down() {
t.Error("Expected second host in testpool to fail healthcheck.")
}
}
func TestSelect(t *testing.T) {
upstream := &staticUpstream{
from: "",
Hosts: testPool()[:3],
Policy: &Random{},
FailTimeout: 10 * time.Second,
MaxFails: 1,
}
upstream.Hosts[0].Unhealthy = true
upstream.Hosts[1].Unhealthy = true
upstream.Hosts[2].Unhealthy = true
if h := upstream.Select(); h != nil {
t.Error("Expected select to return nil as all host are down")
}
upstream.Hosts[2].Unhealthy = false
if h := upstream.Select(); h == nil {
t.Error("Expected select to not return nil")
}
upstream.Hosts[0].Conns = 1
upstream.Hosts[0].MaxConns = 1
upstream.Hosts[1].Conns = 1
upstream.Hosts[1].MaxConns = 1
upstream.Hosts[2].Conns = 1
upstream.Hosts[2].MaxConns = 1
if h := upstream.Select(); h != nil {
t.Error("Expected select to return nil as all hosts are full")
}
upstream.Hosts[2].Conns = 0
if h := upstream.Select(); h == nil {
t.Error("Expected select to not return nil")
}
}
func TestRegisterPolicy(t *testing.T) {
name := "custom"
customPolicy := &customPolicy{}
RegisterPolicy(name, func() Policy { return customPolicy })
if _, ok := supportedPolicies[name]; !ok {
t.Error("Expected supportedPolicies to have a custom policy.")
}
}
func TestAllowedPaths(t *testing.T) {
upstream := &staticUpstream{
from: "/proxy",
IgnoredSubPaths: []string{"/download", "/static"},
}
tests := []struct {
url string
expected bool
}{
{"/proxy", true},
{"/proxy/dl", true},
{"/proxy/download", false},
{"/proxy/download/static", false},
{"/proxy/static", false},
{"/proxy/static/download", false},
{"/proxy/something/download", true},
{"/proxy/something/static", true},
{"/proxy//static", false},
{"/proxy//static//download", false},
{"/proxy//download", false},
}
for i, test := range tests {
allowed := upstream.AllowedPath(test.url)
if test.expected != allowed {
t.Errorf("Test %d: expected %v found %v", i+1, test.expected, allowed)
}
}
}
// Package redirect is middleware for redirecting certain requests
// to other locations.
package redirect
import (
"fmt"
"html"
"net/http"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
// Redirect is middleware to respond with HTTP redirects
type Redirect struct {
Next httpserver.Handler
Rules []Rule
}
// ServeHTTP implements the httpserver.Handler interface.
func (rd Redirect) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
for _, rule := range rd.Rules {
if (rule.FromPath == "/" || r.URL.Path == rule.FromPath) && schemeMatches(rule, r) {
to := httpserver.NewReplacer(r, nil, "").Replace(rule.To)
if rule.Meta {
safeTo := html.EscapeString(to)
fmt.Fprintf(w, metaRedir, safeTo, safeTo)
} else {
http.Redirect(w, r, to, rule.Code)
}
return 0, nil
}
}
return rd.Next.ServeHTTP(w, r)
}
func schemeMatches(rule Rule, req *http.Request) bool {
return (rule.FromScheme == "https" && req.TLS != nil) ||
(rule.FromScheme != "https" && req.TLS == nil)
}
// Rule describes an HTTP redirect rule.
type Rule struct {
FromScheme, FromPath, To string
Code int
Meta bool
}
// Script tag comes first since that will better imitate a redirect in the browser's
// history, but the meta tag is a fallback for most non-JS clients.
const metaRedir = `<!DOCTYPE html>
<html>
<head>
<script>window.location.replace("%s");</script>
<meta http-equiv="refresh" content="0; URL='%s'">
</head>
<body>Redirecting...</body>
</html>
`
package redirect
import (
"bytes"
"crypto/tls"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
func TestRedirect(t *testing.T) {
for i, test := range []struct {
from string
expectedLocation string
expectedCode int
}{
{"http://localhost/from", "/to", http.StatusMovedPermanently},
{"http://localhost/a", "/b", http.StatusTemporaryRedirect},
{"http://localhost/aa", "", http.StatusOK},
{"http://localhost/", "", http.StatusOK},
{"http://localhost/a?foo=bar", "/b", http.StatusTemporaryRedirect},
{"http://localhost/asdf?foo=bar", "", http.StatusOK},
{"http://localhost/foo#bar", "", http.StatusOK},
{"http://localhost/a#foo", "/b", http.StatusTemporaryRedirect},
// The scheme checks that were added to this package don't actually
// help with redirects because of Caddy's design: a redirect middleware
// for http will always be different than the redirect middleware for
// https because they have to be on different listeners. These tests
// just go to show extra bulletproofing, I guess.
{"http://localhost/scheme", "https://localhost/scheme", http.StatusMovedPermanently},
{"https://localhost/scheme", "", http.StatusOK},
{"https://localhost/scheme2", "http://localhost/scheme2", http.StatusMovedPermanently},
{"http://localhost/scheme2", "", http.StatusOK},
{"http://localhost/scheme3", "https://localhost/scheme3", http.StatusMovedPermanently},
{"https://localhost/scheme3", "", http.StatusOK},
} {
var nextCalled bool
re := Redirect{
Next: httpserver.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
nextCalled = true
return 0, nil
}),
Rules: []Rule{
{FromPath: "/from", To: "/to", Code: http.StatusMovedPermanently},
{FromPath: "/a", To: "/b", Code: http.StatusTemporaryRedirect},
// These http and https schemes would never actually be mixed in the same
// redirect rule with Caddy because http and https schemes have different listeners,
// so they don't share a redirect rule. So although these tests prove something
// impossible with Caddy, it's extra bulletproofing at very little cost.
{FromScheme: "http", FromPath: "/scheme", To: "https://localhost/scheme", Code: http.StatusMovedPermanently},
{FromScheme: "https", FromPath: "/scheme2", To: "http://localhost/scheme2", Code: http.StatusMovedPermanently},
{FromScheme: "", FromPath: "/scheme3", To: "https://localhost/scheme3", Code: http.StatusMovedPermanently},
},
}
req, err := http.NewRequest("GET", test.from, nil)
if err != nil {
t.Fatalf("Test %d: Could not create HTTP request: %v", i, err)
}
if strings.HasPrefix(test.from, "https://") {
req.TLS = new(tls.ConnectionState) // faux HTTPS
}
rec := httptest.NewRecorder()
re.ServeHTTP(rec, req)
if rec.Header().Get("Location") != test.expectedLocation {
t.Errorf("Test %d: Expected Location header to be %q but was %q",
i, test.expectedLocation, rec.Header().Get("Location"))
}
if rec.Code != test.expectedCode {
t.Errorf("Test %d: Expected status code to be %d but was %d",
i, test.expectedCode, rec.Code)
}
if nextCalled && test.expectedLocation != "" {
t.Errorf("Test %d: Next handler was unexpectedly called", i)
}
}
}
func TestParametersRedirect(t *testing.T) {
re := Redirect{
Rules: []Rule{
{FromPath: "/", Meta: false, To: "http://example.com{uri}"},
},
}
req, err := http.NewRequest("GET", "/a?b=c", nil)
if err != nil {
t.Fatalf("Test: Could not create HTTP request: %v", err)
}
rec := httptest.NewRecorder()
re.ServeHTTP(rec, req)
if rec.Header().Get("Location") != "http://example.com/a?b=c" {
t.Fatalf("Test: expected location header %q but was %q", "http://example.com/a?b=c", rec.Header().Get("Location"))
}
re = Redirect{
Rules: []Rule{
{FromPath: "/", Meta: false, To: "http://example.com/a{path}?b=c&{query}"},
},
}
req, err = http.NewRequest("GET", "/d?e=f", nil)
if err != nil {
t.Fatalf("Test: Could not create HTTP request: %v", err)
}
re.ServeHTTP(rec, req)
if "http://example.com/a/d?b=c&e=f" != rec.Header().Get("Location") {
t.Fatalf("Test: expected location header %q but was %q", "http://example.com/a/d?b=c&e=f", rec.Header().Get("Location"))
}
}
func TestMetaRedirect(t *testing.T) {
re := Redirect{
Rules: []Rule{
{FromPath: "/whatever", Meta: true, To: "/something"},
{FromPath: "/", Meta: true, To: "https://example.com/"},
},
}
for i, test := range re.Rules {
req, err := http.NewRequest("GET", test.FromPath, nil)
if err != nil {
t.Fatalf("Test %d: Could not create HTTP request: %v", i, err)
}
rec := httptest.NewRecorder()
re.ServeHTTP(rec, req)
body, err := ioutil.ReadAll(rec.Body)
if err != nil {
t.Fatalf("Test %d: Could not read HTTP response body: %v", i, err)
}
expectedSnippet := `<meta http-equiv="refresh" content="0; URL='` + test.To + `'">`
if !bytes.Contains(body, []byte(expectedSnippet)) {
t.Errorf("Test %d: Expected Response Body to contain %q but was %q",
i, expectedSnippet, body)
}
}
}
package redirect
import (
"net/http"
"github.com/mholt/caddy"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
func init() {
caddy.RegisterPlugin(caddy.Plugin{
Name: "redir",
ServerType: "http",
Action: setup,
})
}
// setup configures a new Redirect middleware instance.
func setup(c *caddy.Controller) error {
rules, err := redirParse(c)
if err != nil {
return err
}
httpserver.GetConfig(c.Key).AddMiddleware(func(next httpserver.Handler) httpserver.Handler {
return Redirect{Next: next, Rules: rules}
})
return nil
}
func redirParse(c *caddy.Controller) ([]Rule, error) {
var redirects []Rule
cfg := httpserver.GetConfig(c.Key)
// setRedirCode sets the redirect code for rule if it can, or returns an error
setRedirCode := func(code string, rule *Rule) error {
if code == "meta" {
rule.Meta = true
} else if codeNumber, ok := httpRedirs[code]; ok {
rule.Code = codeNumber
} else {
return c.Errf("Invalid redirect code '%v'", code)
}
return nil
}
// checkAndSaveRule checks the rule for validity (except the redir code)
// and saves it if it's valid, or returns an error.
checkAndSaveRule := func(rule Rule) error {
if rule.FromPath == rule.To {
return c.Err("'from' and 'to' values of redirect rule cannot be the same")
}
for _, otherRule := range redirects {
if otherRule.FromPath == rule.FromPath {
return c.Errf("rule with duplicate 'from' value: %s -> %s", otherRule.FromPath, otherRule.To)
}
}
redirects = append(redirects, rule)
return nil
}
for c.Next() {
args := c.RemainingArgs()
var hadOptionalBlock bool
for c.NextBlock() {
hadOptionalBlock = true
var rule Rule
if cfg.TLS.Enabled {
rule.FromScheme = "https"
} else {
rule.FromScheme = "http"
}
// Set initial redirect code
// BUG: If the code is specified for a whole block and that code is invalid,
// the line number will appear on the first line inside the block, even if that
// line overwrites the block-level code with a valid redirect code. The program
// still functions correctly, but the line number in the error reporting is
// misleading to the user.
if len(args) == 1 {
err := setRedirCode(args[0], &rule)
if err != nil {
return redirects, err
}
} else {
rule.Code = http.StatusMovedPermanently // default code
}
// RemainingArgs only gets the values after the current token, but in our
// case we want to include the current token to get an accurate count.
insideArgs := append([]string{c.Val()}, c.RemainingArgs()...)
switch len(insideArgs) {
case 1:
// To specified (catch-all redirect)
// Not sure why user is doing this in a table, as it causes all other redirects to be ignored.
// As such, this feature remains undocumented.
rule.FromPath = "/"
rule.To = insideArgs[0]
case 2:
// From and To specified
rule.FromPath = insideArgs[0]
rule.To = insideArgs[1]
case 3:
// From, To, and Code specified
rule.FromPath = insideArgs[0]
rule.To = insideArgs[1]
err := setRedirCode(insideArgs[2], &rule)
if err != nil {
return redirects, err
}
default:
return redirects, c.ArgErr()
}
err := checkAndSaveRule(rule)
if err != nil {
return redirects, err
}
}
if !hadOptionalBlock {
var rule Rule
if cfg.TLS.Enabled {
rule.FromScheme = "https"
} else {
rule.FromScheme = "http"
}
rule.Code = http.StatusMovedPermanently // default
switch len(args) {
case 1:
// To specified (catch-all redirect)
rule.FromPath = "/"
rule.To = args[0]
case 2:
// To and Code specified (catch-all redirect)
rule.FromPath = "/"
rule.To = args[0]
err := setRedirCode(args[1], &rule)
if err != nil {
return redirects, err
}
case 3:
// From, To, and Code specified
rule.FromPath = args[0]
rule.To = args[1]
err := setRedirCode(args[2], &rule)
if err != nil {
return redirects, err
}
default:
return redirects, c.ArgErr()
}
err := checkAndSaveRule(rule)
if err != nil {
return redirects, err
}
}
}
return redirects, nil
}
// httpRedirs is a list of supported HTTP redirect codes.
var httpRedirs = map[string]int{
"300": http.StatusMultipleChoices,
"301": http.StatusMovedPermanently,
"302": http.StatusFound, // (NOT CORRECT for "Temporary Redirect", see 307)
"303": http.StatusSeeOther,
"304": http.StatusNotModified,
"305": http.StatusUseProxy,
"307": http.StatusTemporaryRedirect,
"308": 308, // Permanent Redirect
}
package redirect
import (
"testing"
"github.com/mholt/caddy"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
func TestSetup(t *testing.T) {
for j, test := range []struct {
input string
shouldErr bool
expectedRules []Rule
}{
// test case #0 tests the recognition of a valid HTTP status code defined outside of block statement
{"redir 300 {\n/ /foo\n}", false, []Rule{{FromPath: "/", To: "/foo", Code: 300}}},
// test case #1 tests the recognition of an invalid HTTP status code defined outside of block statement
{"redir 9000 {\n/ /foo\n}", true, []Rule{{}}},
// test case #2 tests the detection of a valid HTTP status code outside of a block statement being overriden by an invalid HTTP status code inside statement of a block statement
{"redir 300 {\n/ /foo 9000\n}", true, []Rule{{}}},
// test case #3 tests the detection of an invalid HTTP status code outside of a block statement being overriden by a valid HTTP status code inside statement of a block statement
{"redir 9000 {\n/ /foo 300\n}", true, []Rule{{}}},
// test case #4 tests the recognition of a TO redirection in a block statement.The HTTP status code is set to the default of 301 - MovedPermanently
{"redir 302 {\n/foo\n}", false, []Rule{{FromPath: "/", To: "/foo", Code: 302}}},
// test case #5 tests the recognition of a TO and From redirection in a block statement
{"redir {\n/bar /foo 303\n}", false, []Rule{{FromPath: "/bar", To: "/foo", Code: 303}}},
// test case #6 tests the recognition of a TO redirection in a non-block statement. The HTTP status code is set to the default of 301 - MovedPermanently
{"redir /foo", false, []Rule{{FromPath: "/", To: "/foo", Code: 301}}},
// test case #7 tests the recognition of a TO and From redirection in a non-block statement
{"redir /bar /foo 303", false, []Rule{{FromPath: "/bar", To: "/foo", Code: 303}}},
// test case #8 tests the recognition of multiple redirections
{"redir {\n / /foo 304 \n} \n redir {\n /bar /foobar 305 \n}", false, []Rule{{FromPath: "/", To: "/foo", Code: 304}, {FromPath: "/bar", To: "/foobar", Code: 305}}},
// test case #9 tests the detection of duplicate redirections
{"redir {\n /bar /foo 304 \n} redir {\n /bar /foo 304 \n}", true, []Rule{{}}},
} {
err := setup(caddy.NewTestController(test.input))
if err != nil && !test.shouldErr {
t.Errorf("Test case #%d recieved an error of %v", j, err)
} else if test.shouldErr {
continue
}
mids := httpserver.GetConfig("").Middleware()
recievedRules := mids[len(mids)-1](nil).(Redirect).Rules
for i, recievedRule := range recievedRules {
if recievedRule.FromPath != test.expectedRules[i].FromPath {
t.Errorf("Test case #%d.%d expected a from path of %s, but recieved a from path of %s", j, i, test.expectedRules[i].FromPath, recievedRule.FromPath)
}
if recievedRule.To != test.expectedRules[i].To {
t.Errorf("Test case #%d.%d expected a TO path of %s, but recieved a TO path of %s", j, i, test.expectedRules[i].To, recievedRule.To)
}
if recievedRule.Code != test.expectedRules[i].Code {
t.Errorf("Test case #%d.%d expected a HTTP status code of %d, but recieved a code of %d", j, i, test.expectedRules[i].Code, recievedRule.Code)
}
}
}
}
package rewrite
import (
"fmt"
"net/http"
"regexp"
"strings"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
// Operators
const (
Is = "is"
Not = "not"
Has = "has"
NotHas = "not_has"
StartsWith = "starts_with"
EndsWith = "ends_with"
Match = "match"
NotMatch = "not_match"
)
func operatorError(operator string) error {
return fmt.Errorf("Invalid operator %v", operator)
}
func newReplacer(r *http.Request) httpserver.Replacer {
return httpserver.NewReplacer(r, nil, "")
}
// condition is a rewrite condition.
type condition func(string, string) bool
var conditions = map[string]condition{
Is: isFunc,
Not: notFunc,
Has: hasFunc,
NotHas: notHasFunc,
StartsWith: startsWithFunc,
EndsWith: endsWithFunc,
Match: matchFunc,
NotMatch: notMatchFunc,
}
// isFunc is condition for Is operator.
// It checks for equality.
func isFunc(a, b string) bool {
return a == b
}
// notFunc is condition for Not operator.
// It checks for inequality.
func notFunc(a, b string) bool {
return a != b
}
// hasFunc is condition for Has operator.
// It checks if b is a substring of a.
func hasFunc(a, b string) bool {
return strings.Contains(a, b)
}
// notHasFunc is condition for NotHas operator.
// It checks if b is not a substring of a.
func notHasFunc(a, b string) bool {
return !strings.Contains(a, b)
}
// startsWithFunc is condition for StartsWith operator.
// It checks if b is a prefix of a.
func startsWithFunc(a, b string) bool {
return strings.HasPrefix(a, b)
}
// endsWithFunc is condition for EndsWith operator.
// It checks if b is a suffix of a.
func endsWithFunc(a, b string) bool {
return strings.HasSuffix(a, b)
}
// matchFunc is condition for Match operator.
// It does regexp matching of a against pattern in b
// and returns if they match.
func matchFunc(a, b string) bool {
matched, _ := regexp.MatchString(b, a)
return matched
}
// notMatchFunc is condition for NotMatch operator.
// It does regexp matching of a against pattern in b
// and returns if they do not match.
func notMatchFunc(a, b string) bool {
matched, _ := regexp.MatchString(b, a)
return !matched
}
// If is statement for a rewrite condition.
type If struct {
A string
Operator string
B string
}
// True returns true if the condition is true and false otherwise.
// If r is not nil, it replaces placeholders before comparison.
func (i If) True(r *http.Request) bool {
if c, ok := conditions[i.Operator]; ok {
a, b := i.A, i.B
if r != nil {
replacer := newReplacer(r)
a = replacer.Replace(i.A)
b = replacer.Replace(i.B)
}
return c(a, b)
}
return false
}
// NewIf creates a new If condition.
func NewIf(a, operator, b string) (If, error) {
if _, ok := conditions[operator]; !ok {
return If{}, operatorError(operator)
}
return If{
A: a,
Operator: operator,
B: b,
}, nil
}
package rewrite
import (
"net/http"
"strings"
"testing"
)
func TestConditions(t *testing.T) {
tests := []struct {
condition string
isTrue bool
}{
{"a is b", false},
{"a is a", true},
{"a not b", true},
{"a not a", false},
{"a has a", true},
{"a has b", false},
{"ba has b", true},
{"bab has b", true},
{"bab has bb", false},
{"a not_has a", false},
{"a not_has b", true},
{"ba not_has b", false},
{"bab not_has b", false},
{"bab not_has bb", true},
{"bab starts_with bb", false},
{"bab starts_with ba", true},
{"bab starts_with bab", true},
{"bab ends_with bb", false},
{"bab ends_with bab", true},
{"bab ends_with ab", true},
{"a match *", false},
{"a match a", true},
{"a match .*", true},
{"a match a.*", true},
{"a match b.*", false},
{"ba match b.*", true},
{"ba match b[a-z]", true},
{"b0 match b[a-z]", false},
{"b0a match b[a-z]", false},
{"b0a match b[a-z]+", false},
{"b0a match b[a-z0-9]+", true},
{"a not_match *", true},
{"a not_match a", false},
{"a not_match .*", false},
{"a not_match a.*", false},
{"a not_match b.*", true},
{"ba not_match b.*", false},
{"ba not_match b[a-z]", false},
{"b0 not_match b[a-z]", true},
{"b0a not_match b[a-z]", true},
{"b0a not_match b[a-z]+", true},
{"b0a not_match b[a-z0-9]+", false},
}
for i, test := range tests {
str := strings.Fields(test.condition)
ifCond, err := NewIf(str[0], str[1], str[2])
if err != nil {
t.Error(err)
}
isTrue := ifCond.True(nil)
if isTrue != test.isTrue {
t.Errorf("Test %v: expected %v found %v", i, test.isTrue, isTrue)
}
}
invalidOperators := []string{"ss", "and", "if"}
for _, op := range invalidOperators {
_, err := NewIf("a", op, "b")
if err == nil {
t.Errorf("Invalid operator %v used, expected error.", op)
}
}
replaceTests := []struct {
url string
condition string
isTrue bool
}{
{"/home", "{uri} match /home", true},
{"/hom", "{uri} match /home", false},
{"/hom", "{uri} starts_with /home", false},
{"/hom", "{uri} starts_with /h", true},
{"/home/.hiddenfile", `{uri} match \/\.(.*)`, true},
{"/home/.hiddendir/afile", `{uri} match \/\.(.*)`, true},
}
for i, test := range replaceTests {
r, err := http.NewRequest("GET", test.url, nil)
if err != nil {
t.Error(err)
}
str := strings.Fields(test.condition)
ifCond, err := NewIf(str[0], str[1], str[2])
if err != nil {
t.Error(err)
}
isTrue := ifCond.True(r)
if isTrue != test.isTrue {
t.Errorf("Test %v: expected %v found %v", i, test.isTrue, isTrue)
}
}
}
// Package rewrite is middleware for rewriting requests internally to
// a different path.
package rewrite
import (
"fmt"
"net/http"
"net/url"
"path"
"path/filepath"
"regexp"
"strings"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
// Result is the result of a rewrite
type Result int
const (
// RewriteIgnored is returned when rewrite is not done on request.
RewriteIgnored Result = iota
// RewriteDone is returned when rewrite is done on request.
RewriteDone
// RewriteStatus is returned when rewrite is not needed and status code should be set
// for the request.
RewriteStatus
)
// Rewrite is middleware to rewrite request locations internally before being handled.
type Rewrite struct {
Next httpserver.Handler
FileSys http.FileSystem
Rules []Rule
}
// ServeHTTP implements the httpserver.Handler interface.
func (rw Rewrite) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
outer:
for _, rule := range rw.Rules {
switch result := rule.Rewrite(rw.FileSys, r); result {
case RewriteDone:
break outer
case RewriteIgnored:
break
case RewriteStatus:
// only valid for complex rules.
if cRule, ok := rule.(*ComplexRule); ok && cRule.Status != 0 {
return cRule.Status, nil
}
}
}
return rw.Next.ServeHTTP(w, r)
}
// Rule describes an internal location rewrite rule.
type Rule interface {
// Rewrite rewrites the internal location of the current request.
Rewrite(http.FileSystem, *http.Request) Result
}
// SimpleRule is a simple rewrite rule.
type SimpleRule struct {
From, To string
}
// NewSimpleRule creates a new Simple Rule
func NewSimpleRule(from, to string) SimpleRule {
return SimpleRule{from, to}
}
// Rewrite rewrites the internal location of the current request.
func (s SimpleRule) Rewrite(fs http.FileSystem, r *http.Request) Result {
if s.From == r.URL.Path {
// take note of this rewrite for internal use by fastcgi
// all we need is the URI, not full URL
r.Header.Set(headerFieldName, r.URL.RequestURI())
// attempt rewrite
return To(fs, r, s.To, newReplacer(r))
}
return RewriteIgnored
}
// ComplexRule is a rewrite rule based on a regular expression
type ComplexRule struct {
// Path base. Request to this path and subpaths will be rewritten
Base string
// Path to rewrite to
To string
// If set, neither performs rewrite nor proceeds
// with request. Only returns code.
Status int
// Extensions to filter by
Exts []string
// Rewrite conditions
Ifs []If
*regexp.Regexp
}
// NewComplexRule creates a new RegexpRule. It returns an error if regexp
// pattern (pattern) or extensions (ext) are invalid.
func NewComplexRule(base, pattern, to string, status int, ext []string, ifs []If) (*ComplexRule, error) {
// validate regexp if present
var r *regexp.Regexp
if pattern != "" {
var err error
r, err = regexp.Compile(pattern)
if err != nil {
return nil, err
}
}
// validate extensions if present
for _, v := range ext {
if len(v) < 2 || (len(v) < 3 && v[0] == '!') {
// check if no extension is specified
if v != "/" && v != "!/" {
return nil, fmt.Errorf("invalid extension %v", v)
}
}
}
return &ComplexRule{
Base: base,
To: to,
Status: status,
Exts: ext,
Ifs: ifs,
Regexp: r,
}, nil
}
// Rewrite rewrites the internal location of the current request.
func (r *ComplexRule) Rewrite(fs http.FileSystem, req *http.Request) (re Result) {
rPath := req.URL.Path
replacer := newReplacer(req)
// validate base
if !httpserver.Path(rPath).Matches(r.Base) {
return
}
// validate extensions
if !r.matchExt(rPath) {
return
}
// validate regexp if present
if r.Regexp != nil {
// include trailing slash in regexp if present
start := len(r.Base)
if strings.HasSuffix(r.Base, "/") {
start--
}
matches := r.FindStringSubmatch(rPath[start:])
switch len(matches) {
case 0:
// no match
return
default:
// set regexp match variables {1}, {2} ...
// url escaped values of ? and #.
q, f := url.QueryEscape("?"), url.QueryEscape("#")
for i := 1; i < len(matches); i++ {
// Special case of unescaped # and ? by stdlib regexp.
// Reverse the unescape.
if strings.ContainsAny(matches[i], "?#") {
matches[i] = strings.NewReplacer("?", q, "#", f).Replace(matches[i])
}
replacer.Set(fmt.Sprint(i), matches[i])
}
}
}
// validate rewrite conditions
for _, i := range r.Ifs {
if !i.True(req) {
return
}
}
// if status is present, stop rewrite and return it.
if r.Status != 0 {
return RewriteStatus
}
// attempt rewrite
return To(fs, req, r.To, replacer)
}
// matchExt matches rPath against registered file extensions.
// Returns true if a match is found and false otherwise.
func (r *ComplexRule) matchExt(rPath string) bool {
f := filepath.Base(rPath)
ext := path.Ext(f)
if ext == "" {
ext = "/"
}
mustUse := false
for _, v := range r.Exts {
use := true
if v[0] == '!' {
use = false
v = v[1:]
}
if use {
mustUse = true
}
if ext == v {
return use
}
}
if mustUse {
return false
}
return true
}
// When a rewrite is performed, this header is added to the request
// and is for internal use only, specifically the fastcgi middleware.
// It contains the original request URI before the rewrite.
const headerFieldName = "Caddy-Rewrite-Original-URI"
package rewrite
import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
func TestRewrite(t *testing.T) {
rw := Rewrite{
Next: httpserver.HandlerFunc(urlPrinter),
Rules: []Rule{
NewSimpleRule("/from", "/to"),
NewSimpleRule("/a", "/b"),
NewSimpleRule("/b", "/b{uri}"),
},
FileSys: http.Dir("."),
}
regexps := [][]string{
{"/reg/", ".*", "/to", ""},
{"/r/", "[a-z]+", "/toaz", "!.html|"},
{"/url/", "a([a-z0-9]*)s([A-Z]{2})", "/to/{path}", ""},
{"/ab/", "ab", "/ab?{query}", ".txt|"},
{"/ab/", "ab", "/ab?type=html&{query}", ".html|"},
{"/abc/", "ab", "/abc/{file}", ".html|"},
{"/abcd/", "ab", "/a/{dir}/{file}", ".html|"},
{"/abcde/", "ab", "/a#{fragment}", ".html|"},
{"/ab/", `.*\.jpg`, "/ajpg", ""},
{"/reggrp", `/ad/([0-9]+)([a-z]*)`, "/a{1}/{2}", ""},
{"/reg2grp", `(.*)`, "/{1}", ""},
{"/reg3grp", `(.*)/(.*)/(.*)`, "/{1}{2}{3}", ""},
{"/hashtest", "(.*)", "/{1}", ""},
}
for _, regexpRule := range regexps {
var ext []string
if s := strings.Split(regexpRule[3], "|"); len(s) > 1 {
ext = s[:len(s)-1]
}
rule, err := NewComplexRule(regexpRule[0], regexpRule[1], regexpRule[2], 0, ext, nil)
if err != nil {
t.Fatal(err)
}
rw.Rules = append(rw.Rules, rule)
}
tests := []struct {
from string
expectedTo string
}{
{"/from", "/to"},
{"/a", "/b"},
{"/b", "/b/b"},
{"/aa", "/aa"},
{"/", "/"},
{"/a?foo=bar", "/b?foo=bar"},
{"/asdf?foo=bar", "/asdf?foo=bar"},
{"/foo#bar", "/foo#bar"},
{"/a#foo", "/b#foo"},
{"/reg/foo", "/to"},
{"/re", "/re"},
{"/r/", "/r/"},
{"/r/123", "/r/123"},
{"/r/a123", "/toaz"},
{"/r/abcz", "/toaz"},
{"/r/z", "/toaz"},
{"/r/z.html", "/r/z.html"},
{"/r/z.js", "/toaz"},
{"/url/asAB", "/to/url/asAB"},
{"/url/aBsAB", "/url/aBsAB"},
{"/url/a00sAB", "/to/url/a00sAB"},
{"/url/a0z0sAB", "/to/url/a0z0sAB"},
{"/ab/aa", "/ab/aa"},
{"/ab/ab", "/ab/ab"},
{"/ab/ab.txt", "/ab"},
{"/ab/ab.txt?name=name", "/ab?name=name"},
{"/ab/ab.html?name=name", "/ab?type=html&name=name"},
{"/abc/ab.html", "/abc/ab.html"},
{"/abcd/abcd.html", "/a/abcd/abcd.html"},
{"/abcde/abcde.html", "/a"},
{"/abcde/abcde.html#1234", "/a#1234"},
{"/ab/ab.jpg", "/ajpg"},
{"/reggrp/ad/12", "/a12"},
{"/reggrp/ad/124a", "/a124/a"},
{"/reggrp/ad/124abc", "/a124/abc"},
{"/reg2grp/ad/124abc", "/ad/124abc"},
{"/reg3grp/ad/aa/66", "/adaa66"},
{"/reg3grp/ad612/n1n/ab", "/ad612n1nab"},
{"/hashtest/a%20%23%20test", "/a%20%23%20test"},
{"/hashtest/a%20%3F%20test", "/a%20%3F%20test"},
{"/hashtest/a%20%3F%23test", "/a%20%3F%23test"},
}
for i, test := range tests {
req, err := http.NewRequest("GET", test.from, nil)
if err != nil {
t.Fatalf("Test %d: Could not create HTTP request: %v", i, err)
}
rec := httptest.NewRecorder()
rw.ServeHTTP(rec, req)
if rec.Body.String() != test.expectedTo {
t.Errorf("Test %d: Expected URL to be '%s' but was '%s'",
i, test.expectedTo, rec.Body.String())
}
}
statusTests := []struct {
status int
base string
to string
regexp string
statusExpected bool
}{
{400, "/status", "", "", true},
{400, "/ignore", "", "", false},
{400, "/", "", "^/ignore", false},
{400, "/", "", "(.*)", true},
{400, "/status", "", "", true},
}
for i, s := range statusTests {
urlPath := fmt.Sprintf("/status%d", i)
rule, err := NewComplexRule(s.base, s.regexp, s.to, s.status, nil, nil)
if err != nil {
t.Fatalf("Test %d: No error expected for rule but found %v", i, err)
}
rw.Rules = []Rule{rule}
req, err := http.NewRequest("GET", urlPath, nil)
if err != nil {
t.Fatalf("Test %d: Could not create HTTP request: %v", i, err)
}
rec := httptest.NewRecorder()
code, err := rw.ServeHTTP(rec, req)
if err != nil {
t.Fatalf("Test %d: No error expected for handler but found %v", i, err)
}
if s.statusExpected {
if rec.Body.String() != "" {
t.Errorf("Test %d: Expected empty body but found %s", i, rec.Body.String())
}
if code != s.status {
t.Errorf("Test %d: Expected status code %d found %d", i, s.status, code)
}
} else {
if code != 0 {
t.Errorf("Test %d: Expected no status code found %d", i, code)
}
}
}
}
func urlPrinter(w http.ResponseWriter, r *http.Request) (int, error) {
fmt.Fprint(w, r.URL.String())
return 0, nil
}
package rewrite
import (
"net/http"
"strconv"
"strings"
"github.com/mholt/caddy"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
func init() {
caddy.RegisterPlugin(caddy.Plugin{
Name: "rewrite",
ServerType: "http",
Action: setup,
})
}
// setup configures a new Rewrite middleware instance.
func setup(c *caddy.Controller) error {
rewrites, err := rewriteParse(c)
if err != nil {
return err
}
cfg := httpserver.GetConfig(c.Key)
cfg.AddMiddleware(func(next httpserver.Handler) httpserver.Handler {
return Rewrite{
Next: next,
FileSys: http.Dir(cfg.Root),
Rules: rewrites,
}
})
return nil
}
func rewriteParse(c *caddy.Controller) ([]Rule, error) {
var simpleRules []Rule
var regexpRules []Rule
for c.Next() {
var rule Rule
var err error
var base = "/"
var pattern, to string
var status int
var ext []string
args := c.RemainingArgs()
var ifs []If
switch len(args) {
case 1:
base = args[0]
fallthrough
case 0:
for c.NextBlock() {
switch c.Val() {
case "r", "regexp":
if !c.NextArg() {
return nil, c.ArgErr()
}
pattern = c.Val()
case "to":
args1 := c.RemainingArgs()
if len(args1) == 0 {
return nil, c.ArgErr()
}
to = strings.Join(args1, " ")
case "ext":
args1 := c.RemainingArgs()
if len(args1) == 0 {
return nil, c.ArgErr()
}
ext = args1
case "if":
args1 := c.RemainingArgs()
if len(args1) != 3 {
return nil, c.ArgErr()
}
ifCond, err := NewIf(args1[0], args1[1], args1[2])
if err != nil {
return nil, err
}
ifs = append(ifs, ifCond)
case "status":
if !c.NextArg() {
return nil, c.ArgErr()
}
status, _ = strconv.Atoi(c.Val())
if status < 200 || (status > 299 && status < 400) || status > 499 {
return nil, c.Err("status must be 2xx or 4xx")
}
default:
return nil, c.ArgErr()
}
}
// ensure to or status is specified
if to == "" && status == 0 {
return nil, c.ArgErr()
}
if rule, err = NewComplexRule(base, pattern, to, status, ext, ifs); err != nil {
return nil, err
}
regexpRules = append(regexpRules, rule)
// the only unhandled case is 2 and above
default:
rule = NewSimpleRule(args[0], strings.Join(args[1:], " "))
simpleRules = append(simpleRules, rule)
}
}
// put simple rules in front to avoid regexp computation for them
return append(simpleRules, regexpRules...), nil
}
package rewrite
import (
"fmt"
"regexp"
"testing"
"github.com/mholt/caddy"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
func TestSetup(t *testing.T) {
err := setup(caddy.NewTestController(`rewrite /from /to`))
if err != nil {
t.Errorf("Expected no errors, but got: %v", err)
}
mids := httpserver.GetConfig("").Middleware()
if len(mids) == 0 {
t.Fatal("Expected middleware, had 0 instead")
}
handler := mids[0](httpserver.EmptyNext)
myHandler, ok := handler.(Rewrite)
if !ok {
t.Fatalf("Expected handler to be type Rewrite, got: %#v", handler)
}
if !httpserver.SameNext(myHandler.Next, httpserver.EmptyNext) {
t.Error("'Next' field of handler was not set properly")
}
if len(myHandler.Rules) != 1 {
t.Errorf("Expected handler to have %d rule, has %d instead", 1, len(myHandler.Rules))
}
}
func TestRewriteParse(t *testing.T) {
simpleTests := []struct {
input string
shouldErr bool
expected []Rule
}{
{`rewrite /from /to`, false, []Rule{
SimpleRule{From: "/from", To: "/to"},
}},
{`rewrite /from /to
rewrite a b`, false, []Rule{
SimpleRule{From: "/from", To: "/to"},
SimpleRule{From: "a", To: "b"},
}},
{`rewrite a`, true, []Rule{}},
{`rewrite`, true, []Rule{}},
{`rewrite a b c`, false, []Rule{
SimpleRule{From: "a", To: "b c"},
}},
}
for i, test := range simpleTests {
actual, err := rewriteParse(caddy.NewTestController(test.input))
if err == nil && test.shouldErr {
t.Errorf("Test %d didn't error, but it should have", i)
} else if err != nil && !test.shouldErr {
t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err)
} else if err != nil && test.shouldErr {
continue
}
if len(actual) != len(test.expected) {
t.Fatalf("Test %d expected %d rules, but got %d",
i, len(test.expected), len(actual))
}
for j, e := range test.expected {
actualRule := actual[j].(SimpleRule)
expectedRule := e.(SimpleRule)
if actualRule.From != expectedRule.From {
t.Errorf("Test %d, rule %d: Expected From=%s, got %s",
i, j, expectedRule.From, actualRule.From)
}
if actualRule.To != expectedRule.To {
t.Errorf("Test %d, rule %d: Expected To=%s, got %s",
i, j, expectedRule.To, actualRule.To)
}
}
}
regexpTests := []struct {
input string
shouldErr bool
expected []Rule
}{
{`rewrite {
r .*
to /to /index.php?
}`, false, []Rule{
&ComplexRule{Base: "/", To: "/to /index.php?", Regexp: regexp.MustCompile(".*")},
}},
{`rewrite {
regexp .*
to /to
ext / html txt
}`, false, []Rule{
&ComplexRule{Base: "/", To: "/to", Exts: []string{"/", "html", "txt"}, Regexp: regexp.MustCompile(".*")},
}},
{`rewrite /path {
r rr
to /dest
}
rewrite / {
regexp [a-z]+
to /to /to2
}
`, false, []Rule{
&ComplexRule{Base: "/path", To: "/dest", Regexp: regexp.MustCompile("rr")},
&ComplexRule{Base: "/", To: "/to /to2", Regexp: regexp.MustCompile("[a-z]+")},
}},
{`rewrite {
r .*
}`, true, []Rule{
&ComplexRule{},
}},
{`rewrite {
}`, true, []Rule{
&ComplexRule{},
}},
{`rewrite /`, true, []Rule{
&ComplexRule{},
}},
{`rewrite {
to /to
if {path} is a
}`, false, []Rule{
&ComplexRule{Base: "/", To: "/to", Ifs: []If{{A: "{path}", Operator: "is", B: "a"}}},
}},
{`rewrite {
status 500
}`, true, []Rule{
&ComplexRule{},
}},
{`rewrite {
status 400
}`, false, []Rule{
&ComplexRule{Base: "/", Status: 400},
}},
{`rewrite {
to /to
status 400
}`, false, []Rule{
&ComplexRule{Base: "/", To: "/to", Status: 400},
}},
{`rewrite {
status 399
}`, true, []Rule{
&ComplexRule{},
}},
{`rewrite {
status 200
}`, false, []Rule{
&ComplexRule{Base: "/", Status: 200},
}},
{`rewrite {
to /to
status 200
}`, false, []Rule{
&ComplexRule{Base: "/", To: "/to", Status: 200},
}},
{`rewrite {
status 199
}`, true, []Rule{
&ComplexRule{},
}},
{`rewrite {
status 0
}`, true, []Rule{
&ComplexRule{},
}},
{`rewrite {
to /to
status 0
}`, true, []Rule{
&ComplexRule{},
}},
}
for i, test := range regexpTests {
actual, err := rewriteParse(caddy.NewTestController(test.input))
if err == nil && test.shouldErr {
t.Errorf("Test %d didn't error, but it should have", i)
} else if err != nil && !test.shouldErr {
t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err)
} else if err != nil && test.shouldErr {
continue
}
if len(actual) != len(test.expected) {
t.Fatalf("Test %d expected %d rules, but got %d",
i, len(test.expected), len(actual))
}
for j, e := range test.expected {
actualRule := actual[j].(*ComplexRule)
expectedRule := e.(*ComplexRule)
if actualRule.Base != expectedRule.Base {
t.Errorf("Test %d, rule %d: Expected Base=%s, got %s",
i, j, expectedRule.Base, actualRule.Base)
}
if actualRule.To != expectedRule.To {
t.Errorf("Test %d, rule %d: Expected To=%s, got %s",
i, j, expectedRule.To, actualRule.To)
}
if fmt.Sprint(actualRule.Exts) != fmt.Sprint(expectedRule.Exts) {
t.Errorf("Test %d, rule %d: Expected Ext=%v, got %v",
i, j, expectedRule.To, actualRule.To)
}
if actualRule.Regexp != nil {
if actualRule.String() != expectedRule.String() {
t.Errorf("Test %d, rule %d: Expected Pattern=%s, got %s",
i, j, expectedRule.String(), actualRule.String())
}
}
if fmt.Sprint(actualRule.Ifs) != fmt.Sprint(expectedRule.Ifs) {
t.Errorf("Test %d, rule %d: Expected Pattern=%s, got %s",
i, j, fmt.Sprint(expectedRule.Ifs), fmt.Sprint(actualRule.Ifs))
}
}
}
}
empty
\ No newline at end of file
package rewrite
import (
"log"
"net/http"
"net/url"
"path"
"strings"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
// To attempts rewrite. It attempts to rewrite to first valid path
// or the last path if none of the paths are valid.
// Returns true if rewrite is successful and false otherwise.
func To(fs http.FileSystem, r *http.Request, to string, replacer httpserver.Replacer) Result {
tos := strings.Fields(to)
// try each rewrite paths
t := ""
for _, v := range tos {
t = path.Clean(replacer.Replace(v))
// add trailing slash for directories, if present
if strings.HasSuffix(v, "/") && !strings.HasSuffix(t, "/") {
t += "/"
}
// validate file
if isValidFile(fs, t) {
break
}
}
// validate resulting path
u, err := url.Parse(t)
if err != nil {
// Let the user know we got here. Rewrite is expected but
// the resulting url is invalid.
log.Printf("[ERROR] rewrite: resulting path '%v' is invalid. error: %v", t, err)
return RewriteIgnored
}
// take note of this rewrite for internal use by fastcgi
// all we need is the URI, not full URL
r.Header.Set(headerFieldName, r.URL.RequestURI())
// perform rewrite
r.URL.Path = u.Path
if u.RawQuery != "" {
// overwrite query string if present
r.URL.RawQuery = u.RawQuery
}
if u.Fragment != "" {
// overwrite fragment if present
r.URL.Fragment = u.Fragment
}
return RewriteDone
}
// isValidFile checks if file exists on the filesystem.
// if file ends with `/`, it is validated as a directory.
func isValidFile(fs http.FileSystem, file string) bool {
if fs == nil {
return false
}
f, err := fs.Open(file)
if err != nil {
return false
}
defer f.Close()
stat, err := f.Stat()
if err != nil {
return false
}
// directory
if strings.HasSuffix(file, "/") {
return stat.IsDir()
}
// file
return !stat.IsDir()
}
package rewrite
import (
"net/http"
"net/url"
"testing"
)
func TestTo(t *testing.T) {
fs := http.Dir("testdata")
tests := []struct {
url string
to string
expected string
}{
{"/", "/somefiles", "/somefiles"},
{"/somefiles", "/somefiles /index.php{uri}", "/index.php/somefiles"},
{"/somefiles", "/testfile /index.php{uri}", "/testfile"},
{"/somefiles", "/testfile/ /index.php{uri}", "/index.php/somefiles"},
{"/somefiles", "/somefiles /index.php{uri}", "/index.php/somefiles"},
{"/?a=b", "/somefiles /index.php?{query}", "/index.php?a=b"},
{"/?a=b", "/testfile /index.php?{query}", "/testfile?a=b"},
{"/?a=b", "/testdir /index.php?{query}", "/index.php?a=b"},
{"/?a=b", "/testdir/ /index.php?{query}", "/testdir/?a=b"},
}
uri := func(r *url.URL) string {
uri := r.Path
if r.RawQuery != "" {
uri += "?" + r.RawQuery
}
return uri
}
for i, test := range tests {
r, err := http.NewRequest("GET", test.url, nil)
if err != nil {
t.Error(err)
}
To(fs, r, test.to, newReplacer(r))
if uri(r.URL) != test.expected {
t.Errorf("Test %v: expected %v found %v", i, test.expected, uri(r.URL))
}
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment