Commit a1a8d0f6 authored by Matthew Holt's avatar Matthew Holt

Merge branch 'master' of github.com:mholt/caddy

parents 5d813a1b 04bee0f3
package httpserver
import (
"math/rand"
"path"
"strings"
"time"
)
// CleanMaskedPath prevents one or more of the path cleanup operations:
// - collapse multiple slashes into one
// - eliminate "/." (current directory)
// - eliminate "<parent_directory>/.."
// by masking certain patterns in the path with a temporary random string.
// This could be helpful when certain patterns in the path are desired to be preserved
// that would otherwise be changed by path.Clean().
// One such use case is the presence of the double slashes as protocol separator
// (e.g., /api/endpoint/http://example.com).
// This is a common pattern in many applications to allow passing URIs as path argument.
func CleanMaskedPath(reqPath string, masks ...string) string {
var replacerVal string
maskMap := make(map[string]string)
// Iterate over supplied masks and create temporary replacement strings
// only for the masks that are present in the path, then replace all occurrences
for _, mask := range masks {
if strings.Index(reqPath, mask) >= 0 {
replacerVal = "/_caddy" + generateRandomString() + "__"
maskMap[mask] = replacerVal
reqPath = strings.Replace(reqPath, mask, replacerVal, -1)
}
}
reqPath = path.Clean(reqPath)
// Revert the replaced masks after path cleanup
for mask, replacerVal := range maskMap {
reqPath = strings.Replace(reqPath, replacerVal, mask, -1)
}
return reqPath
}
// CleanPath calls CleanMaskedPath() with the default mask of "://"
// to preserve double slashes of protocols
// such as "http://", "https://", and "ftp://" etc.
func CleanPath(reqPath string) string {
return CleanMaskedPath(reqPath, "://")
}
// An efficient and fast method for random string generation.
// Inspired by http://stackoverflow.com/a/31832326.
const randomStringLength = 4
const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
const (
letterIdxBits = 6
letterIdxMask = 1<<letterIdxBits - 1
letterIdxMax = 63 / letterIdxBits
)
var src = rand.NewSource(time.Now().UnixNano())
func generateRandomString() string {
b := make([]byte, randomStringLength)
for i, cache, remain := randomStringLength-1, src.Int63(), letterIdxMax; i >= 0; {
if remain == 0 {
cache, remain = src.Int63(), letterIdxMax
}
if idx := int(cache & letterIdxMask); idx < len(letterBytes) {
b[i] = letterBytes[idx]
i--
}
cache >>= letterIdxBits
remain--
}
return string(b)
}
package httpserver
import (
"path"
"testing"
)
var paths = map[string]map[string]string{
"/../a/b/../././/c": {
"preserve_all": "/../a/b/../././/c",
"preserve_protocol": "/a/c",
"preserve_slashes": "/a//c",
"preserve_dots": "/../a/b/../././c",
"clean_all": "/a/c",
},
"/path/https://www.google.com": {
"preserve_all": "/path/https://www.google.com",
"preserve_protocol": "/path/https://www.google.com",
"preserve_slashes": "/path/https://www.google.com",
"preserve_dots": "/path/https:/www.google.com",
"clean_all": "/path/https:/www.google.com",
},
"/a/b/../././/c/http://example.com/foo//bar/../blah": {
"preserve_all": "/a/b/../././/c/http://example.com/foo//bar/../blah",
"preserve_protocol": "/a/c/http://example.com/foo/blah",
"preserve_slashes": "/a//c/http://example.com/foo/blah",
"preserve_dots": "/a/b/../././c/http:/example.com/foo/bar/../blah",
"clean_all": "/a/c/http:/example.com/foo/blah",
},
}
func assertEqual(t *testing.T, expected, received string) {
if expected != received {
t.Errorf("\tExpected: %s\n\t\t\tRecieved: %s", expected, received)
}
}
func maskedTestRunner(t *testing.T, variation string, masks ...string) {
for reqPath, transformation := range paths {
assertEqual(t, transformation[variation], CleanMaskedPath(reqPath, masks...))
}
}
// No need to test the built-in path.Clean() function.
// However, it could be useful to cross-examine the test dataset.
func TestPathClean(t *testing.T) {
for reqPath, transformation := range paths {
assertEqual(t, transformation["clean_all"], path.Clean(reqPath))
}
}
func TestCleanAll(t *testing.T) {
maskedTestRunner(t, "clean_all")
}
func TestPreserveAll(t *testing.T) {
maskedTestRunner(t, "preserve_all", "//", "/..", "/.")
}
func TestPreserveProtocol(t *testing.T) {
maskedTestRunner(t, "preserve_protocol", "://")
}
func TestPreserveSlashes(t *testing.T) {
maskedTestRunner(t, "preserve_slashes", "//")
}
func TestPreserveDots(t *testing.T) {
maskedTestRunner(t, "preserve_dots", "/..", "/.")
}
func TestDefaultMask(t *testing.T) {
for reqPath, transformation := range paths {
assertEqual(t, transformation["preserve_protocol"], CleanPath(reqPath))
}
}
func maskedBenchmarkRunner(b *testing.B, masks ...string) {
for n := 0; n < b.N; n++ {
for reqPath := range paths {
CleanMaskedPath(reqPath, masks...)
}
}
}
func BenchmarkPathClean(b *testing.B) {
for n := 0; n < b.N; n++ {
for reqPath := range paths {
path.Clean(reqPath)
}
}
}
func BenchmarkCleanAll(b *testing.B) {
maskedBenchmarkRunner(b)
}
func BenchmarkPreserveAll(b *testing.B) {
maskedBenchmarkRunner(b, "//", "/..", "/.")
}
func BenchmarkPreserveProtocol(b *testing.B) {
maskedBenchmarkRunner(b, "://")
}
func BenchmarkPreserveSlashes(b *testing.B) {
maskedBenchmarkRunner(b, "//")
}
func BenchmarkPreserveDots(b *testing.B) {
maskedBenchmarkRunner(b, "/..", "/.")
}
func BenchmarkDefaultMask(b *testing.B) {
for n := 0; n < b.N; n++ {
for reqPath := range paths {
CleanPath(reqPath)
}
}
}
......@@ -9,7 +9,6 @@ import (
"net"
"net/http"
"os"
"path"
"runtime"
"strings"
"sync"
......@@ -351,7 +350,7 @@ func sanitizePath(r *http.Request) {
if r.URL.Path == "/" {
return
}
cleanedPath := path.Clean(r.URL.Path)
cleanedPath := CleanPath(r.URL.Path)
if cleanedPath == "." {
r.URL.Path = "/"
} else {
......
......@@ -247,12 +247,28 @@ func createUpstreamRequest(r *http.Request) *http.Request {
outreq.URL.Opaque = outreq.URL.RawPath
}
// We are modifying the same underlying map from req (shallow
// copied above) so we only copy it if necessary.
copiedHeaders := false
// Remove hop-by-hop headers listed in the "Connection" header.
// See RFC 2616, section 14.10.
if c := outreq.Header.Get("Connection"); c != "" {
for _, f := range strings.Split(c, ",") {
if f = strings.TrimSpace(f); f != "" {
if !copiedHeaders {
outreq.Header = make(http.Header)
copyHeader(outreq.Header, r.Header)
copiedHeaders = true
}
outreq.Header.Del(f)
}
}
}
// Remove hop-by-hop headers to the backend. Especially
// important is "Connection" because we want a persistent
// connection, regardless of what the client sent to us. This
// is modifying the same underlying map from r (shallow
// copied above) so we only copy it if necessary.
var copiedHeaders bool
// connection, regardless of what the client sent to us.
for _, h := range hopHeaders {
if outreq.Header.Get(h) != "" {
if !copiedHeaders {
......
......@@ -42,10 +42,32 @@ func TestReverseProxy(t *testing.T) {
log.SetOutput(ioutil.Discard)
defer log.SetOutput(os.Stderr)
verifyHeaders := func(headers http.Header, trailers http.Header) {
if headers.Get("X-Header") != "header-value" {
t.Error("Expected header 'X-Header' to be proxied properly")
}
if trailers == nil {
t.Error("Expected to receive trailers")
}
if trailers.Get("X-Trailer") != "trailer-value" {
t.Error("Expected header 'X-Trailer' to be proxied properly")
}
}
var requestReceived bool
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// read the body (even if it's empty) to make Go parse trailers
io.Copy(ioutil.Discard, r.Body)
verifyHeaders(r.Header, r.Trailer)
requestReceived = true
w.Header().Set("Trailer", "X-Trailer")
w.Header().Set("X-Header", "header-value")
w.WriteHeader(http.StatusOK)
w.Write([]byte("Hello, client"))
w.Header().Set("X-Trailer", "trailer-value")
}))
defer backend.Close()
......@@ -59,12 +81,21 @@ func TestReverseProxy(t *testing.T) {
r := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
r.ContentLength = -1 // force chunked encoding (required for trailers)
r.Header.Set("X-Header", "header-value")
r.Trailer = map[string][]string{
"X-Trailer": {"trailer-value"},
}
p.ServeHTTP(w, r)
if !requestReceived {
t.Error("Expected backend to receive request, but it didn't")
}
res := w.Result()
verifyHeaders(res.Header, res.Trailer)
// Make sure {upstream} placeholder is set
rr := httpserver.NewResponseRecorder(httptest.NewRecorder())
rr.Replacer = httpserver.NewReplacer(r, rr, "-")
......@@ -123,7 +154,7 @@ func TestWebSocketReverseProxyNonHijackerPanic(t *testing.T) {
defer wsNop.Close()
// Get proxy to use for the test
p := newWebSocketTestProxy(wsNop.URL)
p := newWebSocketTestProxy(wsNop.URL, false)
// Create client request
r := httptest.NewRequest("GET", "/", nil)
......@@ -148,7 +179,7 @@ func TestWebSocketReverseProxyServeHTTPHandler(t *testing.T) {
defer wsNop.Close()
// Get proxy to use for the test
p := newWebSocketTestProxy(wsNop.URL)
p := newWebSocketTestProxy(wsNop.URL, false)
// Create client request
r := httptest.NewRequest("GET", "/", nil)
......@@ -189,7 +220,7 @@ func TestWebSocketReverseProxyFromWSClient(t *testing.T) {
defer wsEcho.Close()
// Get proxy to use for the test
p := newWebSocketTestProxy(wsEcho.URL)
p := newWebSocketTestProxy(wsEcho.URL, false)
// This is a full end-end test, so the proxy handler
// has to be part of a server listening on a port. Our
......@@ -228,6 +259,52 @@ func TestWebSocketReverseProxyFromWSClient(t *testing.T) {
}
}
func TestWebSocketReverseProxyFromWSSClient(t *testing.T) {
wsEcho := newTLSServer(websocket.Handler(func(ws *websocket.Conn) {
io.Copy(ws, ws)
}))
defer wsEcho.Close()
p := newWebSocketTestProxy(wsEcho.URL, true)
echoProxy := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
p.ServeHTTP(w, r)
}))
defer echoProxy.Close()
// Set up WebSocket client
url := strings.Replace(echoProxy.URL, "https://", "wss://", 1)
wsCfg, err := websocket.NewConfig(url, echoProxy.URL)
if err != nil {
t.Fatal(err)
}
wsCfg.TlsConfig = &tls.Config{InsecureSkipVerify: true}
ws, err := websocket.DialConfig(wsCfg)
if err != nil {
t.Fatal(err)
}
defer ws.Close()
// Send test message
trialMsg := "Is it working?"
if sendErr := websocket.Message.Send(ws, trialMsg); sendErr != nil {
t.Fatal(sendErr)
}
// It should be echoed back to us
var actualMsg string
if rcvErr := websocket.Message.Receive(ws, &actualMsg); rcvErr != nil {
t.Fatal(rcvErr)
}
if actualMsg != trialMsg {
t.Errorf("Expected '%s' but got '%s' instead", trialMsg, actualMsg)
}
}
func TestUnixSocketProxy(t *testing.T) {
if runtime.GOOS == "windows" {
return
......@@ -264,7 +341,7 @@ func TestUnixSocketProxy(t *testing.T) {
defer ts.Close()
url := strings.Replace(ts.URL, "http://", "unix:", 1)
p := newWebSocketTestProxy(url)
p := newWebSocketTestProxy(url, false)
echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
p.ServeHTTP(w, r)
......@@ -982,10 +1059,14 @@ func (u *fakeUpstream) GetTryInterval() time.Duration { return 250 * time.
// redirect to the specified backendAddr. The function
// also sets up the rules/environment for testing WebSocket
// proxy.
func newWebSocketTestProxy(backendAddr string) *Proxy {
func newWebSocketTestProxy(backendAddr string, insecure bool) *Proxy {
return &Proxy{
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
Upstreams: []Upstream{&fakeWsUpstream{name: backendAddr, without: ""}},
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
Upstreams: []Upstream{&fakeWsUpstream{
name: backendAddr,
without: "",
insecure: insecure,
}},
}
}
......@@ -997,8 +1078,9 @@ func newPrefixedWebSocketTestProxy(backendAddr string, prefix string) *Proxy {
}
type fakeWsUpstream struct {
name string
without string
name string
without string
insecure bool
}
func (u *fakeWsUpstream) From() string {
......@@ -1007,13 +1089,17 @@ func (u *fakeWsUpstream) From() string {
func (u *fakeWsUpstream) Select(r *http.Request) *UpstreamHost {
uri, _ := url.Parse(u.name)
return &UpstreamHost{
host := &UpstreamHost{
Name: u.name,
ReverseProxy: NewSingleHostReverseProxy(uri, u.without, http.DefaultMaxIdleConnsPerHost),
UpstreamHeaders: http.Header{
"Connection": {"{>Connection}"},
"Upgrade": {"{>Upgrade}"}},
}
if u.insecure {
host.ReverseProxy.UseInsecureTransport()
}
return host
}
func (u *fakeWsUpstream) AllowedPath(requestPath string) bool { return true }
......
This diff is collapsed.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment