Commit 0a0d2cc1 authored by ericdreeves's avatar ericdreeves Committed by Matt Holt

Use RequestURI when redirecting to canonical path. (#1331)

* Use RequestURI when redirecting to canonical path.

Caddy may trim a request's URL path when it starts with the path that's
associated with the virtual host. This change uses the path from the request's
RequestURI when performing a redirect.

Fix issue #1327.

* Rename redirurl to redirURL.

* Redirect to the full URL.

The scheme and host from the virtual host's site configuration is used
in order to redirect to the full URL.

* Add comment and remove redundant check.

* Store the original URL path in request context.

By storing the original URL path as a value in the request context,
middlewares can access both it and the sanitized path. The default
default FileServer handler will use the original URL on redirects.

* Replace contextKey type with CtxKey.

In addition to moving the CtxKey definition to the caddy package, this
change updates the CtxKey references in the httpserver, fastcgi, and
basicauth packages.

* httpserver: Fix reference to CtxKey
parent 50749b4e
...@@ -869,3 +869,11 @@ var ( ...@@ -869,3 +869,11 @@ var (
// by default if no other file is specified. // by default if no other file is specified.
DefaultConfigFile = "Caddyfile" DefaultConfigFile = "Caddyfile"
) )
// CtxKey is a value for use with context.WithValue.
type CtxKey string
// URLPathCtxKey is a context key. It can be used in HTTP handlers with
// context.WithValue to access the original request URI that accompanied the
// server request. The associated value will be of type string.
const URLPathCtxKey CtxKey = "url_path"
...@@ -19,6 +19,7 @@ import ( ...@@ -19,6 +19,7 @@ import (
"sync" "sync"
"github.com/jimstudt/http-authentication/basic" "github.com/jimstudt/http-authentication/basic"
"github.com/mholt/caddy"
"github.com/mholt/caddy/caddyhttp/httpserver" "github.com/mholt/caddy/caddyhttp/httpserver"
) )
...@@ -65,7 +66,7 @@ func (a BasicAuth) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error ...@@ -65,7 +66,7 @@ func (a BasicAuth) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error
// let upstream middleware (e.g. fastcgi and cgi) know about authenticated // let upstream middleware (e.g. fastcgi and cgi) know about authenticated
// user; this replaces the request with a wrapped instance // user; this replaces the request with a wrapped instance
r = r.WithContext(context.WithValue(r.Context(), r = r.WithContext(context.WithValue(r.Context(),
httpserver.CtxKey("remote_user"), username)) caddy.CtxKey("remote_user"), username))
} }
} }
......
...@@ -10,6 +10,7 @@ import ( ...@@ -10,6 +10,7 @@ import (
"path/filepath" "path/filepath"
"testing" "testing"
"github.com/mholt/caddy"
"github.com/mholt/caddy/caddyhttp/httpserver" "github.com/mholt/caddy/caddyhttp/httpserver"
) )
...@@ -18,7 +19,7 @@ func TestBasicAuth(t *testing.T) { ...@@ -18,7 +19,7 @@ func TestBasicAuth(t *testing.T) {
// This handler is registered for tests in which the only authorized user is // This handler is registered for tests in which the only authorized user is
// "okuser" // "okuser"
upstreamHandler := func(w http.ResponseWriter, r *http.Request) (int, error) { upstreamHandler := func(w http.ResponseWriter, r *http.Request) (int, error) {
remoteUser, _ := r.Context().Value(httpserver.CtxKey("remote_user")).(string) remoteUser, _ := r.Context().Value(caddy.CtxKey("remote_user")).(string)
if remoteUser != "okuser" { if remoteUser != "okuser" {
t.Errorf("Test %d: expecting remote user 'okuser', got '%s'", i, remoteUser) t.Errorf("Test %d: expecting remote user 'okuser', got '%s'", i, remoteUser)
} }
......
...@@ -15,6 +15,7 @@ import ( ...@@ -15,6 +15,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/mholt/caddy"
"github.com/mholt/caddy/caddyhttp/httpserver" "github.com/mholt/caddy/caddyhttp/httpserver"
) )
...@@ -222,7 +223,7 @@ func (h Handler) buildEnv(r *http.Request, rule Rule, fpath string) (map[string] ...@@ -222,7 +223,7 @@ func (h Handler) buildEnv(r *http.Request, rule Rule, fpath string) (map[string]
// Retrieve name of remote user that was set by some downstream middleware, // Retrieve name of remote user that was set by some downstream middleware,
// possibly basicauth. // possibly basicauth.
remoteUser, _ := r.Context().Value(httpserver.CtxKey("remote_user")).(string) // Blank if not set remoteUser, _ := r.Context().Value(caddy.CtxKey("remote_user")).(string) // Blank if not set
// Some variables are unused but cleared explicitly to prevent // Some variables are unused but cleared explicitly to prevent
// the parent environment from interfering. // the parent environment from interfering.
......
...@@ -14,6 +14,7 @@ import ( ...@@ -14,6 +14,7 @@ import (
"os" "os"
"github.com/mholt/caddy"
"github.com/russross/blackfriday" "github.com/russross/blackfriday"
) )
...@@ -325,10 +326,8 @@ func (c Context) Files(name string) ([]string, error) { ...@@ -325,10 +326,8 @@ func (c Context) Files(name string) ([]string, error) {
// IsMITM returns true if it seems likely that the TLS connection // IsMITM returns true if it seems likely that the TLS connection
// is being intercepted. // is being intercepted.
func (c Context) IsMITM() bool { func (c Context) IsMITM() bool {
if val, ok := c.Req.Context().Value(CtxKey("mitm")).(bool); ok { if val, ok := c.Req.Context().Value(caddy.CtxKey("mitm")).(bool); ok {
return val return val
} }
return false return false
} }
type CtxKey string
...@@ -9,6 +9,8 @@ import ( ...@@ -9,6 +9,8 @@ import (
"net/http" "net/http"
"strings" "strings"
"sync" "sync"
"github.com/mholt/caddy"
) )
// tlsHandler is a http.Handler that will inject a value // tlsHandler is a http.Handler that will inject a value
...@@ -72,7 +74,7 @@ func (h *tlsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { ...@@ -72,7 +74,7 @@ func (h *tlsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
if checked { if checked {
r = r.WithContext(context.WithValue(r.Context(), CtxKey("mitm"), mitm)) r = r.WithContext(context.WithValue(r.Context(), caddy.CtxKey("mitm"), mitm))
} }
if mitm && h.closeOnMITM { if mitm && h.closeOnMITM {
......
...@@ -7,6 +7,8 @@ import ( ...@@ -7,6 +7,8 @@ import (
"net/http/httptest" "net/http/httptest"
"reflect" "reflect"
"testing" "testing"
"github.com/mholt/caddy"
) )
func TestParseClientHello(t *testing.T) { func TestParseClientHello(t *testing.T) {
...@@ -285,7 +287,7 @@ func TestHeuristicFunctionsAndHandler(t *testing.T) { ...@@ -285,7 +287,7 @@ func TestHeuristicFunctionsAndHandler(t *testing.T) {
want := ch.interception want := ch.interception
handler := &tlsHandler{ handler := &tlsHandler{
next: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { next: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
got, checked = r.Context().Value(CtxKey("mitm")).(bool) got, checked = r.Context().Value(caddy.CtxKey("mitm")).(bool)
}), }),
listener: newTLSListener(nil, nil), listener: newTLSListener(nil, nil),
} }
......
...@@ -13,6 +13,8 @@ import ( ...@@ -13,6 +13,8 @@ import (
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/mholt/caddy"
) )
// requestReplacer is a strings.Replacer which is used to // requestReplacer is a strings.Replacer which is used to
...@@ -299,7 +301,7 @@ func (r *replacer) getSubstitution(key string) string { ...@@ -299,7 +301,7 @@ func (r *replacer) getSubstitution(key string) string {
} }
return requestReplacer.Replace(r.requestBody.String()) return requestReplacer.Replace(r.requestBody.String())
case "{mitm}": case "{mitm}":
if val, ok := r.request.Context().Value(CtxKey("mitm")).(bool); ok { if val, ok := r.request.Context().Value(caddy.CtxKey("mitm")).(bool); ok {
if val { if val {
return "likely" return "likely"
} else { } else {
......
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
"log" "log"
"net" "net"
"net/http" "net/http"
"net/url"
"os" "os"
"runtime" "runtime"
"strings" "strings"
...@@ -284,6 +285,8 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { ...@@ -284,6 +285,8 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}() }()
w.Header().Set("Server", "Caddy") w.Header().Set("Server", "Caddy")
c := context.WithValue(r.Context(), caddy.URLPathCtxKey, r.URL.Path)
r = r.WithContext(c)
sanitizePath(r) sanitizePath(r)
...@@ -340,6 +343,14 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) (int, error) ...@@ -340,6 +343,14 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) (int, error)
} }
} }
// URL fields other than Path and RawQuery will be empty for most server
// requests. Hence, the request URL is updated with the scheme and host
// from the virtual host's site address.
if vhostURL, err := url.Parse(vhost.Addr.String()); err == nil {
r.URL.Scheme = vhostURL.Scheme
r.URL.Host = vhostURL.Host
}
// Apply the path-based request body size limit // Apply the path-based request body size limit
// The error returned by MaxBytesReader is meant to be handled // The error returned by MaxBytesReader is meant to be handled
// by whichever middleware/plugin that receives it when calling // by whichever middleware/plugin that receives it when calling
...@@ -398,10 +409,10 @@ func (s *Server) Stop() error { ...@@ -398,10 +409,10 @@ func (s *Server) Stop() error {
return nil return nil
} }
// sanitizePath collapses any ./ ../ /// madness // sanitizePath collapses any ./ ../ /// madness which helps prevent
// which helps prevent path traversal attacks. // path traversal attacks. Note to middleware: use the value within the
// Note to middleware: use URL.RawPath If you need // request's context at key caddy.URLPathContextKey to access the
// the "original" URL.Path value. // "original" URL.Path value.
func sanitizePath(r *http.Request) { func sanitizePath(r *http.Request) {
if r.URL.Path == "/" { if r.URL.Path == "/" {
return return
......
...@@ -3,12 +3,14 @@ package staticfiles ...@@ -3,12 +3,14 @@ package staticfiles
import ( import (
"math/rand" "math/rand"
"net/http" "net/http"
"net/url"
"os" "os"
"path"
"path/filepath" "path/filepath"
"runtime" "runtime"
"strconv" "strconv"
"strings" "strings"
"github.com/mholt/caddy"
) )
// FileServer implements a production-ready file server // FileServer implements a production-ready file server
...@@ -90,17 +92,34 @@ func (fs FileServer) serveFile(w http.ResponseWriter, r *http.Request, name stri ...@@ -90,17 +92,34 @@ func (fs FileServer) serveFile(w http.ResponseWriter, r *http.Request, name stri
} }
// redirect to canonical path // redirect to canonical path
url := r.URL.Path
if d.IsDir() { if d.IsDir() {
// Ensure / at end of directory url // Ensure / at end of directory url. If the original URL path is
if !strings.HasSuffix(url, "/") { // used then ensure / exists as well.
Redirect(w, r, path.Base(url)+"/", http.StatusMovedPermanently) if !strings.HasSuffix(r.URL.Path, "/") {
toURL, _ := url.Parse(r.URL.String())
path, ok := r.Context().Value(caddy.URLPathCtxKey).(string)
if ok && !strings.HasSuffix(path, "/") {
toURL.Path = path
}
toURL.Path += "/"
http.Redirect(w, r, toURL.String(), http.StatusMovedPermanently)
return http.StatusMovedPermanently, nil return http.StatusMovedPermanently, nil
} }
} else { } else {
// Ensure no / at end of file url // Ensure no / at end of file url. If the original URL path is
if strings.HasSuffix(url, "/") { // used then ensure no / exists as well.
Redirect(w, r, "../"+path.Base(url), http.StatusMovedPermanently) if strings.HasSuffix(r.URL.Path, "/") {
toURL, _ := url.Parse(r.URL.String())
path, ok := r.Context().Value(caddy.URLPathCtxKey).(string)
if ok && strings.HasSuffix(path, "/") {
toURL.Path = path
}
toURL.Path = strings.TrimSuffix(toURL.Path, "/")
http.Redirect(w, r, toURL.String(), http.StatusMovedPermanently)
return http.StatusMovedPermanently, nil return http.StatusMovedPermanently, nil
} }
} }
......
package staticfiles package staticfiles
import ( import (
"context"
"errors" "errors"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
...@@ -10,6 +11,8 @@ import ( ...@@ -10,6 +11,8 @@ import (
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/mholt/caddy"
) )
var ( var (
...@@ -30,6 +33,7 @@ var ( ...@@ -30,6 +33,7 @@ var (
webrootSubBrotliHtml = filepath.Join("webroot", "sub", "brotli.html") webrootSubBrotliHtml = filepath.Join("webroot", "sub", "brotli.html")
webrootSubBrotliHtmlGz = filepath.Join("webroot", "sub", "brotli.html.gz") webrootSubBrotliHtmlGz = filepath.Join("webroot", "sub", "brotli.html.gz")
webrootSubBrotliHtmlBr = filepath.Join("webroot", "sub", "brotli.html.br") webrootSubBrotliHtmlBr = filepath.Join("webroot", "sub", "brotli.html.br")
webrootSubBarDirWithIndexIndexHTML = filepath.Join("webroot", "bar", "dirwithindex", "index.html")
) )
// testFiles is a map with relative paths to test files as keys and file content as values. // testFiles is a map with relative paths to test files as keys and file content as values.
...@@ -55,6 +59,7 @@ var testFiles = map[string]string{ ...@@ -55,6 +59,7 @@ var testFiles = map[string]string{
webrootSubBrotliHtml: "3.brotli.html", webrootSubBrotliHtml: "3.brotli.html",
webrootSubBrotliHtmlGz: "4.brotli.html.gz", webrootSubBrotliHtmlGz: "4.brotli.html.gz",
webrootSubBrotliHtmlBr: "5.brotli.html.br", webrootSubBrotliHtmlBr: "5.brotli.html.br",
webrootSubBarDirWithIndexIndexHTML: "<h1>bar/dirwithindex/index.html</h1>",
} }
// TestServeHTTP covers positive scenarios when serving files. // TestServeHTTP covers positive scenarios when serving files.
...@@ -72,8 +77,9 @@ func TestServeHTTP(t *testing.T) { ...@@ -72,8 +77,9 @@ func TestServeHTTP(t *testing.T) {
tests := []struct { tests := []struct {
url string url string
cleanedPath string
acceptEncoding string acceptEncoding string
expectedLocation string
expectedStatus int expectedStatus int
expectedBodyContent string expectedBodyContent string
expectedEtag string expectedEtag string
...@@ -108,6 +114,7 @@ func TestServeHTTP(t *testing.T) { ...@@ -108,6 +114,7 @@ func TestServeHTTP(t *testing.T) {
{ {
url: "https://foo/dirwithindex", url: "https://foo/dirwithindex",
expectedStatus: http.StatusMovedPermanently, expectedStatus: http.StatusMovedPermanently,
expectedLocation: "https://foo/dirwithindex/",
expectedBodyContent: movedPermanently, expectedBodyContent: movedPermanently,
}, },
// Test 5 - access folder without index file // Test 5 - access folder without index file
...@@ -119,12 +126,14 @@ func TestServeHTTP(t *testing.T) { ...@@ -119,12 +126,14 @@ func TestServeHTTP(t *testing.T) {
{ {
url: "https://foo/dir", url: "https://foo/dir",
expectedStatus: http.StatusMovedPermanently, expectedStatus: http.StatusMovedPermanently,
expectedLocation: "https://foo/dir/",
expectedBodyContent: movedPermanently, expectedBodyContent: movedPermanently,
}, },
// Test 7 - access file with trailing slash // Test 7 - access file with trailing slash
{ {
url: "https://foo/file1.html/", url: "https://foo/file1.html/",
expectedStatus: http.StatusMovedPermanently, expectedStatus: http.StatusMovedPermanently,
expectedLocation: "https://foo/file1.html",
expectedBodyContent: movedPermanently, expectedBodyContent: movedPermanently,
}, },
// Test 8 - access not existing path // Test 8 - access not existing path
...@@ -148,6 +157,7 @@ func TestServeHTTP(t *testing.T) { ...@@ -148,6 +157,7 @@ func TestServeHTTP(t *testing.T) {
{ {
url: "https://foo/dir?param1=val", url: "https://foo/dir?param1=val",
expectedStatus: http.StatusMovedPermanently, expectedStatus: http.StatusMovedPermanently,
expectedLocation: "https://foo/dir/?param1=val",
expectedBodyContent: movedPermanently, expectedBodyContent: movedPermanently,
}, },
// Test 12 - attempt to bypass hidden file // Test 12 - attempt to bypass hidden file
...@@ -216,11 +226,39 @@ func TestServeHTTP(t *testing.T) { ...@@ -216,11 +226,39 @@ func TestServeHTTP(t *testing.T) {
url: "https://foo/file1.html/other", url: "https://foo/file1.html/other",
expectedStatus: http.StatusNotFound, expectedStatus: http.StatusNotFound,
}, },
// Test 20 - access folder with index file without trailing slash, with
// cleaned path
{
url: "https://foo/bar/dirwithindex",
cleanedPath: "/dirwithindex",
expectedStatus: http.StatusMovedPermanently,
expectedLocation: "https://foo/bar/dirwithindex/",
expectedBodyContent: movedPermanently,
},
// Test 21 - access folder with index file without trailing slash, with
// cleaned path and query params
{
url: "https://foo/bar/dirwithindex?param1=val",
cleanedPath: "/dirwithindex",
expectedStatus: http.StatusMovedPermanently,
expectedLocation: "https://foo/bar/dirwithindex/?param1=val",
expectedBodyContent: movedPermanently,
},
// Test 22 - access file with trailing slash with cleaned path
{
url: "https://foo/bar/file1.html/",
cleanedPath: "file1.html/",
expectedStatus: http.StatusMovedPermanently,
expectedLocation: "https://foo/bar/file1.html",
expectedBodyContent: movedPermanently,
},
} }
for i, test := range tests { for i, test := range tests {
responseRecorder := httptest.NewRecorder() responseRecorder := httptest.NewRecorder()
request, err := http.NewRequest("GET", test.url, nil) request, err := http.NewRequest("GET", test.url, nil)
ctx := context.WithValue(request.Context(), caddy.URLPathCtxKey, request.URL.Path)
request = request.WithContext(ctx)
request.Header.Add("Accept-Encoding", test.acceptEncoding) request.Header.Add("Accept-Encoding", test.acceptEncoding)
...@@ -231,6 +269,12 @@ func TestServeHTTP(t *testing.T) { ...@@ -231,6 +269,12 @@ func TestServeHTTP(t *testing.T) {
if u, _ := url.Parse(test.url); u.RawPath != "" { if u, _ := url.Parse(test.url); u.RawPath != "" {
request.URL.Path = u.RawPath request.URL.Path = u.RawPath
} }
// Caddy may trim a request's URL path. Overwrite the path with
// the cleanedPath to test redirects when the path has been
// modified.
if test.cleanedPath != "" {
request.URL.Path = test.cleanedPath
}
status, err := fileserver.ServeHTTP(responseRecorder, request) status, err := fileserver.ServeHTTP(responseRecorder, request)
etag := responseRecorder.Header().Get("Etag") etag := responseRecorder.Header().Get("Etag")
body := responseRecorder.Body.String() body := responseRecorder.Body.String()
...@@ -266,6 +310,13 @@ func TestServeHTTP(t *testing.T) { ...@@ -266,6 +310,13 @@ func TestServeHTTP(t *testing.T) {
if !strings.Contains(body, test.expectedBodyContent) { if !strings.Contains(body, test.expectedBodyContent) {
t.Errorf("Test %d: Expected body to contain %q, found %q", i, test.expectedBodyContent, body) t.Errorf("Test %d: Expected body to contain %q, found %q", i, test.expectedBodyContent, body)
} }
if test.expectedLocation != "" {
l := responseRecorder.Header().Get("Location")
if test.expectedLocation != l {
t.Errorf("Test %d: Expected Location header %q, found %q", i, test.expectedLocation, l)
}
}
} }
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment