Commit d5cc10f7 authored by Toby Allen's avatar Toby Allen Committed by Matt Holt

Added Const for use of CtxKeys (#1511)

* Added Const for CtxKeys

* Move CtxKey Const declarations

* Fixed tests

* fix test
parent 96bfb9f3
...@@ -870,16 +870,5 @@ var ( ...@@ -870,16 +870,5 @@ var (
DefaultConfigFile = "Caddyfile" DefaultConfigFile = "Caddyfile"
) )
// CtxKey is a value for use with context.WithValue. // CtxKey is a value type for use with context.WithValue.
// TODO: Ideally CtxKey and consts will be moved to httpserver package.
// currently blocked by circular import with staticfiles.
type CtxKey string type CtxKey string
// URLPathCtxKey is a context key. It can be used in HTTP handlers with
// context.WithValue to access the original request URI that accompanied the
// server request. The associated value will be of type string.
const URLPathCtxKey CtxKey = "url_path"
// URIxRewriteCtxKey is a context key used to store original unrewritten
// URI in context.WithValue
const URIxRewriteCtxKey CtxKey = "caddy_rewrite_original_uri"
...@@ -19,7 +19,6 @@ import ( ...@@ -19,7 +19,6 @@ import (
"sync" "sync"
"github.com/jimstudt/http-authentication/basic" "github.com/jimstudt/http-authentication/basic"
"github.com/mholt/caddy"
"github.com/mholt/caddy/caddyhttp/httpserver" "github.com/mholt/caddy/caddyhttp/httpserver"
) )
...@@ -62,8 +61,10 @@ func (a BasicAuth) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error ...@@ -62,8 +61,10 @@ func (a BasicAuth) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error
// by this point, authentication was successful // by this point, authentication was successful
isAuthenticated = true isAuthenticated = true
// let upstream middleware (e.g. fastcgi and cgi) know about authenticated user // let upstream middleware (e.g. fastcgi and cgi) know about authenticated
r = r.WithContext(context.WithValue(r.Context(), caddy.CtxKey("remote_user"), username)) // user; this replaces the request with a wrapped instance
r = r.WithContext(context.WithValue(r.Context(),
httpserver.RemoteUserCtxKey, username))
} }
} }
......
...@@ -10,7 +10,6 @@ import ( ...@@ -10,7 +10,6 @@ import (
"path/filepath" "path/filepath"
"testing" "testing"
"github.com/mholt/caddy"
"github.com/mholt/caddy/caddyhttp/httpserver" "github.com/mholt/caddy/caddyhttp/httpserver"
) )
...@@ -19,7 +18,7 @@ func TestBasicAuth(t *testing.T) { ...@@ -19,7 +18,7 @@ func TestBasicAuth(t *testing.T) {
// This handler is registered for tests in which the only authorized user is // This handler is registered for tests in which the only authorized user is
// "okuser" // "okuser"
upstreamHandler := func(w http.ResponseWriter, r *http.Request) (int, error) { upstreamHandler := func(w http.ResponseWriter, r *http.Request) (int, error) {
remoteUser, _ := r.Context().Value(caddy.CtxKey("remote_user")).(string) remoteUser, _ := r.Context().Value(httpserver.RemoteUserCtxKey).(string)
if remoteUser != "okuser" { if remoteUser != "okuser" {
t.Errorf("Test %d: expecting remote user 'okuser', got '%s'", i, remoteUser) t.Errorf("Test %d: expecting remote user 'okuser', got '%s'", i, remoteUser)
} }
......
...@@ -15,7 +15,6 @@ import ( ...@@ -15,7 +15,6 @@ import (
"strings" "strings"
"time" "time"
"github.com/mholt/caddy"
"github.com/mholt/caddy/caddyhttp/httpserver" "github.com/mholt/caddy/caddyhttp/httpserver"
) )
...@@ -219,13 +218,13 @@ func (h Handler) buildEnv(r *http.Request, rule Rule, fpath string) (map[string] ...@@ -219,13 +218,13 @@ func (h Handler) buildEnv(r *http.Request, rule Rule, fpath string) (map[string]
// If it was rewritten, there will be a context value with the original URL, // If it was rewritten, there will be a context value with the original URL,
// which is needed to get the correct RequestURI value for PHP apps. // which is needed to get the correct RequestURI value for PHP apps.
reqURI := r.URL.RequestURI() reqURI := r.URL.RequestURI()
if origURI, _ := r.Context().Value(caddy.URIxRewriteCtxKey).(string); origURI != "" { if origURI, _ := r.Context().Value(httpserver.URIxRewriteCtxKey).(string); origURI != "" {
reqURI = origURI reqURI = origURI
} }
// Retrieve name of remote user that was set by some downstream middleware, // Retrieve name of remote user that was set by some downstream middleware,
// possibly basicauth. // possibly basicauth.
remoteUser, _ := r.Context().Value(caddy.CtxKey("remote_user")).(string) // Blank if not set remoteUser, _ := r.Context().Value(httpserver.RemoteUserCtxKey).(string) // Blank if not set
// Some variables are unused but cleared explicitly to prevent // Some variables are unused but cleared explicitly to prevent
// the parent environment from interfering. // the parent environment from interfering.
......
...@@ -14,7 +14,6 @@ import ( ...@@ -14,7 +14,6 @@ import (
"os" "os"
"github.com/mholt/caddy"
"github.com/russross/blackfriday" "github.com/russross/blackfriday"
) )
...@@ -349,7 +348,7 @@ func (c Context) Files(name string) ([]string, error) { ...@@ -349,7 +348,7 @@ func (c Context) Files(name string) ([]string, error) {
// IsMITM returns true if it seems likely that the TLS connection // IsMITM returns true if it seems likely that the TLS connection
// is being intercepted. // is being intercepted.
func (c Context) IsMITM() bool { func (c Context) IsMITM() bool {
if val, ok := c.Req.Context().Value(caddy.CtxKey("mitm")).(bool); ok { if val, ok := c.Req.Context().Value(MitmCtxKey).(bool); ok {
return val return val
} }
return false return false
......
...@@ -197,3 +197,16 @@ var EmptyNext = HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, e ...@@ -197,3 +197,16 @@ var EmptyNext = HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, e
func SameNext(next1, next2 Handler) bool { func SameNext(next1, next2 Handler) bool {
return fmt.Sprintf("%v", next1) == fmt.Sprintf("%v", next2) return fmt.Sprintf("%v", next1) == fmt.Sprintf("%v", next2)
} }
// Context key constants
const (
// URIxRewriteCtxKey is a context key used to store original unrewritten
// URI in context.WithValue
URIxRewriteCtxKey caddy.CtxKey = "caddy_rewrite_original_uri"
// RemoteUserCtxKey is a context key used to store remote user for request
RemoteUserCtxKey caddy.CtxKey = "remote_user"
// MitmCtxKey stores Mitm result
MitmCtxKey caddy.CtxKey = "mitm"
)
...@@ -9,8 +9,6 @@ import ( ...@@ -9,8 +9,6 @@ import (
"net/http" "net/http"
"strings" "strings"
"sync" "sync"
"github.com/mholt/caddy"
) )
// tlsHandler is a http.Handler that will inject a value // tlsHandler is a http.Handler that will inject a value
...@@ -74,7 +72,7 @@ func (h *tlsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { ...@@ -74,7 +72,7 @@ func (h *tlsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
if checked { if checked {
r = r.WithContext(context.WithValue(r.Context(), caddy.CtxKey("mitm"), mitm)) r = r.WithContext(context.WithValue(r.Context(), MitmCtxKey, mitm))
} }
if mitm && h.closeOnMITM { if mitm && h.closeOnMITM {
......
...@@ -7,8 +7,6 @@ import ( ...@@ -7,8 +7,6 @@ import (
"net/http/httptest" "net/http/httptest"
"reflect" "reflect"
"testing" "testing"
"github.com/mholt/caddy"
) )
func TestParseClientHello(t *testing.T) { func TestParseClientHello(t *testing.T) {
...@@ -287,7 +285,7 @@ func TestHeuristicFunctionsAndHandler(t *testing.T) { ...@@ -287,7 +285,7 @@ func TestHeuristicFunctionsAndHandler(t *testing.T) {
want := ch.interception want := ch.interception
handler := &tlsHandler{ handler := &tlsHandler{
next: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { next: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
got, checked = r.Context().Value(caddy.CtxKey("mitm")).(bool) got, checked = r.Context().Value(MitmCtxKey).(bool)
}), }),
listener: newTLSListener(nil, nil), listener: newTLSListener(nil, nil),
} }
......
...@@ -241,7 +241,7 @@ func (r *replacer) getSubstitution(key string) string { ...@@ -241,7 +241,7 @@ func (r *replacer) getSubstitution(key string) string {
// if a rewrite has happened, the original URI should be used as the path // if a rewrite has happened, the original URI should be used as the path
// rather than the rewritten URI // rather than the rewritten URI
var path string var path string
origpath, _ := r.request.Context().Value(caddy.URIxRewriteCtxKey).(string) origpath, _ := r.request.Context().Value(URIxRewriteCtxKey).(string)
if origpath == "" { if origpath == "" {
path = r.request.URL.Path path = r.request.URL.Path
} else { } else {
...@@ -251,7 +251,7 @@ func (r *replacer) getSubstitution(key string) string { ...@@ -251,7 +251,7 @@ func (r *replacer) getSubstitution(key string) string {
return path return path
case "{path_escaped}": case "{path_escaped}":
var path string var path string
origpath, _ := r.request.Context().Value(caddy.URIxRewriteCtxKey).(string) origpath, _ := r.request.Context().Value(URIxRewriteCtxKey).(string)
if origpath == "" { if origpath == "" {
path = r.request.URL.Path path = r.request.URL.Path
} else { } else {
...@@ -284,13 +284,13 @@ func (r *replacer) getSubstitution(key string) string { ...@@ -284,13 +284,13 @@ func (r *replacer) getSubstitution(key string) string {
} }
return port return port
case "{uri}": case "{uri}":
uri, _ := r.request.Context().Value(caddy.URIxRewriteCtxKey).(string) uri, _ := r.request.Context().Value(URIxRewriteCtxKey).(string)
if uri == "" { if uri == "" {
uri = r.request.URL.RequestURI() uri = r.request.URL.RequestURI()
} }
return uri return uri
case "{uri_escaped}": case "{uri_escaped}":
uri, _ := r.request.Context().Value(caddy.URIxRewriteCtxKey).(string) uri, _ := r.request.Context().Value(URIxRewriteCtxKey).(string)
if uri == "" { if uri == "" {
uri = r.request.URL.RequestURI() uri = r.request.URL.RequestURI()
} }
......
...@@ -8,8 +8,6 @@ import ( ...@@ -8,8 +8,6 @@ import (
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/mholt/caddy"
) )
func TestNewReplacer(t *testing.T) { func TestNewReplacer(t *testing.T) {
...@@ -164,7 +162,7 @@ func TestPathRewrite(t *testing.T) { ...@@ -164,7 +162,7 @@ func TestPathRewrite(t *testing.T) {
t.Fatalf("Request Formation Failed: %s\n", err.Error()) t.Fatalf("Request Formation Failed: %s\n", err.Error())
} }
ctx := context.WithValue(request.Context(), caddy.URIxRewriteCtxKey, "a/custom/path.php?key=value") ctx := context.WithValue(request.Context(), URIxRewriteCtxKey, "a/custom/path.php?key=value")
request = request.WithContext(ctx) request = request.WithContext(ctx)
repl := NewReplacer(request, recordRequest, "") repl := NewReplacer(request, recordRequest, "")
......
...@@ -292,7 +292,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { ...@@ -292,7 +292,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}() }()
w.Header().Set("Server", "Caddy") w.Header().Set("Server", "Caddy")
c := context.WithValue(r.Context(), caddy.URLPathCtxKey, r.URL.Path) c := context.WithValue(r.Context(), staticfiles.URLPathCtxKey, r.URL.Path)
r = r.WithContext(c) r = r.WithContext(c)
sanitizePath(r) sanitizePath(r)
......
...@@ -8,7 +8,6 @@ import ( ...@@ -8,7 +8,6 @@ import (
"path" "path"
"strings" "strings"
"github.com/mholt/caddy"
"github.com/mholt/caddy/caddyhttp/httpserver" "github.com/mholt/caddy/caddyhttp/httpserver"
) )
...@@ -51,7 +50,7 @@ func To(fs http.FileSystem, r *http.Request, to string, replacer httpserver.Repl ...@@ -51,7 +50,7 @@ func To(fs http.FileSystem, r *http.Request, to string, replacer httpserver.Repl
// take note of this rewrite for internal use by fastcgi // take note of this rewrite for internal use by fastcgi
// all we need is the URI, not full URL // all we need is the URI, not full URL
*r = *r.WithContext(context.WithValue(r.Context(), caddy.URIxRewriteCtxKey, r.URL.RequestURI())) *r = *r.WithContext(context.WithValue(r.Context(), httpserver.URIxRewriteCtxKey, r.URL.RequestURI()))
// perform rewrite // perform rewrite
r.URL.Path = u.Path r.URL.Path = u.Path
......
...@@ -98,7 +98,7 @@ func (fs FileServer) serveFile(w http.ResponseWriter, r *http.Request, name stri ...@@ -98,7 +98,7 @@ func (fs FileServer) serveFile(w http.ResponseWriter, r *http.Request, name stri
if !strings.HasSuffix(r.URL.Path, "/") { if !strings.HasSuffix(r.URL.Path, "/") {
toURL, _ := url.Parse(r.URL.String()) toURL, _ := url.Parse(r.URL.String())
path, ok := r.Context().Value(caddy.URLPathCtxKey).(string) path, ok := r.Context().Value(URLPathCtxKey).(string)
if ok && !strings.HasSuffix(path, "/") { if ok && !strings.HasSuffix(path, "/") {
toURL.Path = path toURL.Path = path
} }
...@@ -113,7 +113,7 @@ func (fs FileServer) serveFile(w http.ResponseWriter, r *http.Request, name stri ...@@ -113,7 +113,7 @@ func (fs FileServer) serveFile(w http.ResponseWriter, r *http.Request, name stri
if strings.HasSuffix(r.URL.Path, "/") { if strings.HasSuffix(r.URL.Path, "/") {
toURL, _ := url.Parse(r.URL.String()) toURL, _ := url.Parse(r.URL.String())
path, ok := r.Context().Value(caddy.URLPathCtxKey).(string) path, ok := r.Context().Value(URLPathCtxKey).(string)
if ok && strings.HasSuffix(path, "/") { if ok && strings.HasSuffix(path, "/") {
toURL.Path = path toURL.Path = path
} }
...@@ -300,3 +300,8 @@ func mapFSRootOpenErr(originalErr error) error { ...@@ -300,3 +300,8 @@ func mapFSRootOpenErr(originalErr error) error {
} }
return originalErr return originalErr
} }
// URLPathCtxKey is a context key. It can be used in HTTP handlers with
// context.WithValue to access the original request URI that accompanied the
// server request. The associated value will be of type string.
const URLPathCtxKey caddy.CtxKey = "url_path"
...@@ -12,8 +12,6 @@ import ( ...@@ -12,8 +12,6 @@ import (
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/mholt/caddy"
) )
var ( var (
...@@ -265,7 +263,7 @@ func TestServeHTTP(t *testing.T) { ...@@ -265,7 +263,7 @@ func TestServeHTTP(t *testing.T) {
for i, test := range tests { for i, test := range tests {
responseRecorder := httptest.NewRecorder() responseRecorder := httptest.NewRecorder()
request, err := http.NewRequest("GET", test.url, nil) request, err := http.NewRequest("GET", test.url, nil)
ctx := context.WithValue(request.Context(), caddy.URLPathCtxKey, request.URL.Path) ctx := context.WithValue(request.Context(), URLPathCtxKey, request.URL.Path)
request = request.WithContext(ctx) request = request.WithContext(ctx)
request.Header.Add("Accept-Encoding", test.acceptEncoding) request.Header.Add("Accept-Encoding", test.acceptEncoding)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment