Commit 269a8b5f authored by Matthew Holt's avatar Matthew Holt

Merge branch 'master' into diagnostics

# Conflicts:
#	plugins.go
#	vendor/manifest
parents 5820356c 12014922
...@@ -16,4 +16,6 @@ Caddyfile ...@@ -16,4 +16,6 @@ Caddyfile
og_static/ og_static/
.vscode/ .vscode/
\ No newline at end of file
*.bat
\ No newline at end of file
<p align="center"> <p align="center">
<a href="https://caddyserver.com"><img src="https://cloud.githubusercontent.com/assets/1128849/25305033/12916fce-2731-11e7-86ec-580d4d31cb16.png" alt="Caddy" width="400"></a> <a href="https://caddyserver.com"><img src="https://user-images.githubusercontent.com/1128849/36137292-bebc223a-1051-11e8-9a81-4ea9054c96ac.png" alt="Caddy" width="400"></a>
</p> </p>
<h3 align="center">Every Site on HTTPS <!-- Serve Confidently --></h3> <h3 align="center">Every Site on HTTPS <!-- Serve Confidently --></h3>
<p align="center">Caddy is a general-purpose HTTP/2 web server that serves HTTPS by default.</p> <p align="center">Caddy is a general-purpose HTTP/2 web server that serves HTTPS by default.</p>
...@@ -59,7 +59,7 @@ customize your build in the browser ...@@ -59,7 +59,7 @@ customize your build in the browser
pre-built, vanilla binaries pre-built, vanilla binaries
## Build ## Build
To build from source you need **[Git](https://git-scm.com/downloads)** and **[Go](https://golang.org/doc/install)** (1.8 or newer). Follow these instruction for fast building: To build from source you need **[Git](https://git-scm.com/downloads)** and **[Go](https://golang.org/doc/install)** (1.9 or newer). Follow these instruction for fast building:
- Get source `go get github.com/mholt/caddy/caddy` and then run `go get github.com/caddyserver/builds` - Get source `go get github.com/mholt/caddy/caddy` and then run `go get github.com/caddyserver/builds`
- Now `cd` to `$GOPATH/src/github.com/mholt/caddy/caddy` and run `go run build.go` - Now `cd` to `$GOPATH/src/github.com/mholt/caddy/caddy` and run `go run build.go`
......
...@@ -78,8 +78,18 @@ var ( ...@@ -78,8 +78,18 @@ var (
mu sync.Mutex mu sync.Mutex
) )
func init() {
OnProcessExit = append(OnProcessExit, func() {
if PidFile != "" {
os.Remove(PidFile)
}
})
}
// Instance contains the state of servers created as a result of // Instance contains the state of servers created as a result of
// calling Start and can be used to access or control those servers. // calling Start and can be used to access or control those servers.
// It is literally an instance of a server type. Instance values
// should NOT be copied. Use *Instance for safety.
type Instance struct { type Instance struct {
// serverType is the name of the instance's server type // serverType is the name of the instance's server type
serverType string serverType string
...@@ -90,10 +100,11 @@ type Instance struct { ...@@ -90,10 +100,11 @@ type Instance struct {
// wg is used to wait for all servers to shut down // wg is used to wait for all servers to shut down
wg *sync.WaitGroup wg *sync.WaitGroup
// context is the context created for this instance. // context is the context created for this instance,
// used to coordinate the setting up of the server type
context Context context Context
// servers is the list of servers with their listeners. // servers is the list of servers with their listeners
servers []ServerListener servers []ServerListener
// these callbacks execute when certain events occur // these callbacks execute when certain events occur
...@@ -102,6 +113,18 @@ type Instance struct { ...@@ -102,6 +113,18 @@ type Instance struct {
onRestart []func() error // before restart commences onRestart []func() error // before restart commences
onShutdown []func() error // stopping, even as part of a restart onShutdown []func() error // stopping, even as part of a restart
onFinalShutdown []func() error // stopping, not as part of a restart onFinalShutdown []func() error // stopping, not as part of a restart
// storing values on an instance is preferable to
// global state because these will get garbage-
// collected after in-process reloads when the
// old instances are destroyed; use StorageMu
// to access this value safely
Storage map[interface{}]interface{}
StorageMu sync.RWMutex
}
func Instances() []*Instance {
return instances
} }
// Servers returns the ServerListeners in i. // Servers returns the ServerListeners in i.
...@@ -197,7 +220,7 @@ func (i *Instance) Restart(newCaddyfile Input) (*Instance, error) { ...@@ -197,7 +220,7 @@ func (i *Instance) Restart(newCaddyfile Input) (*Instance, error) {
} }
// create new instance; if the restart fails, it is simply discarded // create new instance; if the restart fails, it is simply discarded
newInst := &Instance{serverType: newCaddyfile.ServerType(), wg: i.wg} newInst := &Instance{serverType: newCaddyfile.ServerType(), wg: i.wg, Storage: make(map[interface{}]interface{})}
// attempt to start new instance // attempt to start new instance
err := startWithListenerFds(newCaddyfile, newInst, restartFds) err := startWithListenerFds(newCaddyfile, newInst, restartFds)
...@@ -456,7 +479,7 @@ func (i *Instance) Caddyfile() Input { ...@@ -456,7 +479,7 @@ func (i *Instance) Caddyfile() Input {
// //
// This function blocks until all the servers are listening. // This function blocks until all the servers are listening.
func Start(cdyfile Input) (*Instance, error) { func Start(cdyfile Input) (*Instance, error) {
inst := &Instance{serverType: cdyfile.ServerType(), wg: new(sync.WaitGroup)} inst := &Instance{serverType: cdyfile.ServerType(), wg: new(sync.WaitGroup), Storage: make(map[interface{}]interface{})}
err := startWithListenerFds(cdyfile, inst, nil) err := startWithListenerFds(cdyfile, inst, nil)
if err != nil { if err != nil {
return inst, err return inst, err
...@@ -469,11 +492,34 @@ func Start(cdyfile Input) (*Instance, error) { ...@@ -469,11 +492,34 @@ func Start(cdyfile Input) (*Instance, error) {
} }
func startWithListenerFds(cdyfile Input, inst *Instance, restartFds map[string]restartTriple) error { func startWithListenerFds(cdyfile Input, inst *Instance, restartFds map[string]restartTriple) error {
// save this instance in the list now so that
// plugins can access it if need be, for example
// the caddytls package, so it can perform cert
// renewals while starting up; we just have to
// remove the instance from the list later if
// it fails
instancesMu.Lock()
instances = append(instances, inst)
instancesMu.Unlock()
var err error
defer func() {
if err != nil {
instancesMu.Lock()
for i, otherInst := range instances {
if otherInst == inst {
instances = append(instances[:i], instances[i+1:]...)
break
}
}
instancesMu.Unlock()
}
}()
if cdyfile == nil { if cdyfile == nil {
cdyfile = CaddyfileInput{} cdyfile = CaddyfileInput{}
} }
err := ValidateAndExecuteDirectives(cdyfile, inst, false) err = ValidateAndExecuteDirectives(cdyfile, inst, false)
if err != nil { if err != nil {
return err return err
} }
...@@ -505,10 +551,6 @@ func startWithListenerFds(cdyfile Input, inst *Instance, restartFds map[string]r ...@@ -505,10 +551,6 @@ func startWithListenerFds(cdyfile Input, inst *Instance, restartFds map[string]r
return err return err
} }
instancesMu.Lock()
instances = append(instances, inst)
instancesMu.Unlock()
// run any AfterStartup callbacks if this is not // run any AfterStartup callbacks if this is not
// part of a restart; then show file descriptor notice // part of a restart; then show file descriptor notice
if restartFds == nil { if restartFds == nil {
...@@ -547,7 +589,7 @@ func startWithListenerFds(cdyfile Input, inst *Instance, restartFds map[string]r ...@@ -547,7 +589,7 @@ func startWithListenerFds(cdyfile Input, inst *Instance, restartFds map[string]r
func ValidateAndExecuteDirectives(cdyfile Input, inst *Instance, justValidate bool) error { func ValidateAndExecuteDirectives(cdyfile Input, inst *Instance, justValidate bool) error {
// If parsing only inst will be nil, create an instance for this function call only. // If parsing only inst will be nil, create an instance for this function call only.
if justValidate { if justValidate {
inst = &Instance{serverType: cdyfile.ServerType(), wg: new(sync.WaitGroup)} inst = &Instance{serverType: cdyfile.ServerType(), wg: new(sync.WaitGroup), Storage: make(map[interface{}]interface{})}
} }
stypeName := cdyfile.ServerType() stypeName := cdyfile.ServerType()
...@@ -564,14 +606,14 @@ func ValidateAndExecuteDirectives(cdyfile Input, inst *Instance, justValidate bo ...@@ -564,14 +606,14 @@ func ValidateAndExecuteDirectives(cdyfile Input, inst *Instance, justValidate bo
return err return err
} }
inst.context = stype.NewContext() inst.context = stype.NewContext(inst)
if inst.context == nil { if inst.context == nil {
return fmt.Errorf("server type %s produced a nil Context", stypeName) return fmt.Errorf("server type %s produced a nil Context", stypeName)
} }
sblocks, err = inst.context.InspectServerBlocks(cdyfile.Path(), sblocks) sblocks, err = inst.context.InspectServerBlocks(cdyfile.Path(), sblocks)
if err != nil { if err != nil {
return err return fmt.Errorf("error inspecting server blocks: %v", err)
} }
diagnostics.Set("num_server_blocks", len(sblocks)) diagnostics.Set("num_server_blocks", len(sblocks))
......
...@@ -148,7 +148,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) ...@@ -148,7 +148,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error)
case "HEAD": case "HEAD":
resp, err = fcgiBackend.Head(env) resp, err = fcgiBackend.Head(env)
case "GET": case "GET":
resp, err = fcgiBackend.Get(env) resp, err = fcgiBackend.Get(env, r.Body, contentLength)
case "OPTIONS": case "OPTIONS":
resp, err = fcgiBackend.Options(env) resp, err = fcgiBackend.Options(env)
default: default:
......
...@@ -460,12 +460,12 @@ func (c *FCGIClient) Request(p map[string]string, req io.Reader) (resp *http.Res ...@@ -460,12 +460,12 @@ func (c *FCGIClient) Request(p map[string]string, req io.Reader) (resp *http.Res
} }
// Get issues a GET request to the fcgi responder. // Get issues a GET request to the fcgi responder.
func (c *FCGIClient) Get(p map[string]string) (resp *http.Response, err error) { func (c *FCGIClient) Get(p map[string]string, body io.Reader, l int64) (resp *http.Response, err error) {
p["REQUEST_METHOD"] = "GET" p["REQUEST_METHOD"] = "GET"
p["CONTENT_LENGTH"] = "0" p["CONTENT_LENGTH"] = strconv.FormatInt(l, 10)
return c.Request(p, nil) return c.Request(p, body)
} }
// Head issues a HEAD request to the fcgi responder. // Head issues a HEAD request to the fcgi responder.
......
...@@ -140,7 +140,8 @@ func sendFcgi(reqType int, fcgiParams map[string]string, data []byte, posts map[ ...@@ -140,7 +140,8 @@ func sendFcgi(reqType int, fcgiParams map[string]string, data []byte, posts map[
} }
resp, err = fcgi.PostForm(fcgiParams, values) resp, err = fcgi.PostForm(fcgiParams, values)
} else { } else {
resp, err = fcgi.Get(fcgiParams) rd := bytes.NewReader(data)
resp, err = fcgi.Get(fcgiParams, rd, int64(rd.Len()))
} }
default: default:
......
...@@ -27,7 +27,7 @@ func activateHTTPS(cctx caddy.Context) error { ...@@ -27,7 +27,7 @@ func activateHTTPS(cctx caddy.Context) error {
operatorPresent := !caddy.Started() operatorPresent := !caddy.Started()
if !caddy.Quiet && operatorPresent { if !caddy.Quiet && operatorPresent {
fmt.Print("Activating privacy features...") fmt.Print("Activating privacy features... ")
} }
ctx := cctx.(*httpContext) ctx := cctx.(*httpContext)
...@@ -69,7 +69,7 @@ func activateHTTPS(cctx caddy.Context) error { ...@@ -69,7 +69,7 @@ func activateHTTPS(cctx caddy.Context) error {
} }
if !caddy.Quiet && operatorPresent { if !caddy.Quiet && operatorPresent {
fmt.Println(" done.") fmt.Println("done.")
} }
return nil return nil
...@@ -160,23 +160,37 @@ func hostHasOtherPort(allConfigs []*SiteConfig, thisConfigIdx int, otherPort str ...@@ -160,23 +160,37 @@ func hostHasOtherPort(allConfigs []*SiteConfig, thisConfigIdx int, otherPort str
// to listen on HTTPPort. The TLS field of cfg must not be nil. // to listen on HTTPPort. The TLS field of cfg must not be nil.
func redirPlaintextHost(cfg *SiteConfig) *SiteConfig { func redirPlaintextHost(cfg *SiteConfig) *SiteConfig {
redirPort := cfg.Addr.Port redirPort := cfg.Addr.Port
if redirPort == DefaultHTTPSPort { if redirPort == HTTPSPort {
redirPort = "" // default port is redundant // By default, HTTPSPort should be DefaultHTTPSPort,
// which of course doesn't need to be explicitly stated
// in the Location header. Even if HTTPSPort is changed
// so that it is no longer DefaultHTTPSPort, we shouldn't
// append it to the URL in the Location because changing
// the HTTPS port is assumed to be an internal-only change
// (in other words, we assume port forwarding is going on);
// but redirects go back to a presumably-external client.
// (If redirect clients are also internal, that is more
// advanced, and the user should configure HTTP->HTTPS
// redirects themselves.)
redirPort = ""
} }
redirMiddleware := func(next Handler) Handler { redirMiddleware := func(next Handler) Handler {
return HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { return HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
// Construct the URL to which to redirect. Note that the Host in a request might // Construct the URL to which to redirect. Note that the Host in a
// contain a port, but we just need the hostname; we'll set the port if needed. // request might contain a port, but we just need the hostname from
// it; and we'll set the port if needed.
toURL := "https://" toURL := "https://"
requestHost, _, err := net.SplitHostPort(r.Host) requestHost, _, err := net.SplitHostPort(r.Host)
if err != nil { if err != nil {
requestHost = r.Host // Host did not contain a port; great requestHost = r.Host // Host did not contain a port, so use the whole value
} }
if redirPort == "" { if redirPort == "" {
toURL += requestHost toURL += requestHost
} else { } else {
toURL += net.JoinHostPort(requestHost, redirPort) toURL += net.JoinHostPort(requestHost, redirPort)
} }
toURL += r.URL.RequestURI() toURL += r.URL.RequestURI()
w.Header().Set("Connection", "close") w.Header().Set("Connection", "close")
...@@ -184,9 +198,11 @@ func redirPlaintextHost(cfg *SiteConfig) *SiteConfig { ...@@ -184,9 +198,11 @@ func redirPlaintextHost(cfg *SiteConfig) *SiteConfig {
return 0, nil return 0, nil
}) })
} }
host := cfg.Addr.Host host := cfg.Addr.Host
port := HTTPPort port := HTTPPort
addr := net.JoinHostPort(host, port) addr := net.JoinHostPort(host, port)
return &SiteConfig{ return &SiteConfig{
Addr: Address{Original: addr, Host: host, Port: port}, Addr: Address{Original: addr, Host: host, Port: port},
ListenHost: cfg.ListenHost, ListenHost: cfg.ListenHost,
......
...@@ -53,7 +53,7 @@ func TestRedirPlaintextHost(t *testing.T) { ...@@ -53,7 +53,7 @@ func TestRedirPlaintextHost(t *testing.T) {
}, },
{ {
Host: "foohost", Host: "foohost",
Port: "443", // since this is the default HTTPS port, should not be included in Location value Port: HTTPSPort, // since this is the 'default' HTTPS port, should not be included in Location value
}, },
{ {
Host: "*.example.com", Host: "*.example.com",
......
...@@ -91,11 +91,13 @@ func hideCaddyfile(cctx caddy.Context) error { ...@@ -91,11 +91,13 @@ func hideCaddyfile(cctx caddy.Context) error {
return nil return nil
} }
func newContext() caddy.Context { func newContext(inst *caddy.Instance) caddy.Context {
return &httpContext{keysToSiteConfigs: make(map[string]*SiteConfig)} return &httpContext{instance: inst, keysToSiteConfigs: make(map[string]*SiteConfig)}
} }
type httpContext struct { type httpContext struct {
instance *caddy.Instance
// keysToSiteConfigs maps an address at the top of a // keysToSiteConfigs maps an address at the top of a
// server block (a "key") to its SiteConfig. Not all // server block (a "key") to its SiteConfig. Not all
// SiteConfigs will be represented here, only ones // SiteConfigs will be represented here, only ones
...@@ -115,12 +117,14 @@ func (h *httpContext) saveConfig(key string, cfg *SiteConfig) { ...@@ -115,12 +117,14 @@ func (h *httpContext) saveConfig(key string, cfg *SiteConfig) {
// executing directives and otherwise prepares the directives to // executing directives and otherwise prepares the directives to
// be parsed and executed. // be parsed and executed.
func (h *httpContext) InspectServerBlocks(sourceFile string, serverBlocks []caddyfile.ServerBlock) ([]caddyfile.ServerBlock, error) { func (h *httpContext) InspectServerBlocks(sourceFile string, serverBlocks []caddyfile.ServerBlock) ([]caddyfile.ServerBlock, error) {
siteAddrs := make(map[string]string)
// For each address in each server block, make a new config // For each address in each server block, make a new config
for _, sb := range serverBlocks { for _, sb := range serverBlocks {
for _, key := range sb.Keys { for _, key := range sb.Keys {
key = strings.ToLower(key) key = strings.ToLower(key)
if _, dup := h.keysToSiteConfigs[key]; dup { if _, dup := h.keysToSiteConfigs[key]; dup {
return serverBlocks, fmt.Errorf("duplicate site address: %s", key) return serverBlocks, fmt.Errorf("duplicate site key: %s", key)
} }
addr, err := standardizeAddress(key) addr, err := standardizeAddress(key)
if err != nil { if err != nil {
...@@ -136,6 +140,23 @@ func (h *httpContext) InspectServerBlocks(sourceFile string, serverBlocks []cadd ...@@ -136,6 +140,23 @@ func (h *httpContext) InspectServerBlocks(sourceFile string, serverBlocks []cadd
addr.Port = Port addr.Port = Port
} }
// Make sure the adjusted site address is distinct
addrCopy := addr // make copy so we don't disturb the original, carefully-parsed address struct
if addrCopy.Port == "" && Port == DefaultPort {
addrCopy.Port = Port
}
addrStr := strings.ToLower(addrCopy.String())
if otherSiteKey, dup := siteAddrs[addrStr]; dup {
err := fmt.Errorf("duplicate site address: %s", addrStr)
if (addrCopy.Host == Host && Host != DefaultHost) ||
(addrCopy.Port == Port && Port != DefaultPort) {
err = fmt.Errorf("site defined as %s is a duplicate of %s because of modified "+
"default host and/or port values (usually via -host or -port flags)", key, otherSiteKey)
}
return serverBlocks, err
}
siteAddrs[addrStr] = key
// If default HTTP or HTTPS ports have been customized, // If default HTTP or HTTPS ports have been customized,
// make sure the ACME challenge ports match // make sure the ACME challenge ports match
var altHTTPPort, altTLSSNIPort string var altHTTPPort, altTLSSNIPort string
...@@ -146,15 +167,19 @@ func (h *httpContext) InspectServerBlocks(sourceFile string, serverBlocks []cadd ...@@ -146,15 +167,19 @@ func (h *httpContext) InspectServerBlocks(sourceFile string, serverBlocks []cadd
altTLSSNIPort = HTTPSPort altTLSSNIPort = HTTPSPort
} }
// Make our caddytls.Config, which has a pointer to the
// instance's certificate cache and enough information
// to use automatic HTTPS when the time comes
caddytlsConfig := caddytls.NewConfig(h.instance)
caddytlsConfig.Hostname = addr.Host
caddytlsConfig.AltHTTPPort = altHTTPPort
caddytlsConfig.AltTLSSNIPort = altTLSSNIPort
// Save the config to our master list, and key it for lookups // Save the config to our master list, and key it for lookups
cfg := &SiteConfig{ cfg := &SiteConfig{
Addr: addr, Addr: addr,
Root: Root, Root: Root,
TLS: &caddytls.Config{ TLS: caddytlsConfig,
Hostname: addr.Host,
AltHTTPPort: altHTTPPort,
AltTLSSNIPort: altTLSSNIPort,
},
originCaddyfile: sourceFile, originCaddyfile: sourceFile,
IndexPages: staticfiles.DefaultIndexPages, IndexPages: staticfiles.DefaultIndexPages,
} }
......
...@@ -137,7 +137,7 @@ func TestAddressString(t *testing.T) { ...@@ -137,7 +137,7 @@ func TestAddressString(t *testing.T) {
func TestInspectServerBlocksWithCustomDefaultPort(t *testing.T) { func TestInspectServerBlocksWithCustomDefaultPort(t *testing.T) {
Port = "9999" Port = "9999"
filename := "Testfile" filename := "Testfile"
ctx := newContext().(*httpContext) ctx := newContext(&caddy.Instance{Storage: make(map[interface{}]interface{})}).(*httpContext)
input := strings.NewReader(`localhost`) input := strings.NewReader(`localhost`)
sblocks, err := caddyfile.Parse(filename, input, nil) sblocks, err := caddyfile.Parse(filename, input, nil)
if err != nil { if err != nil {
...@@ -153,9 +153,26 @@ func TestInspectServerBlocksWithCustomDefaultPort(t *testing.T) { ...@@ -153,9 +153,26 @@ func TestInspectServerBlocksWithCustomDefaultPort(t *testing.T) {
} }
} }
// See discussion on PR #2015
func TestInspectServerBlocksWithAdjustedAddress(t *testing.T) {
Port = DefaultPort
Host = "example.com"
filename := "Testfile"
ctx := newContext(&caddy.Instance{Storage: make(map[interface{}]interface{})}).(*httpContext)
input := strings.NewReader("example.com {\n}\n:2015 {\n}")
sblocks, err := caddyfile.Parse(filename, input, nil)
if err != nil {
t.Fatalf("Expected no error setting up test, got: %v", err)
}
_, err = ctx.InspectServerBlocks(filename, sblocks)
if err == nil {
t.Fatalf("Expected an error because site definitions should overlap, got: %v", err)
}
}
func TestInspectServerBlocksCaseInsensitiveKey(t *testing.T) { func TestInspectServerBlocksCaseInsensitiveKey(t *testing.T) {
filename := "Testfile" filename := "Testfile"
ctx := newContext().(*httpContext) ctx := newContext(&caddy.Instance{Storage: make(map[interface{}]interface{})}).(*httpContext)
input := strings.NewReader("localhost {\n}\nLOCALHOST {\n}") input := strings.NewReader("localhost {\n}\nLOCALHOST {\n}")
sblocks, err := caddyfile.Parse(filename, input, nil) sblocks, err := caddyfile.Parse(filename, input, nil)
if err != nil { if err != nil {
...@@ -207,7 +224,7 @@ func TestDirectivesList(t *testing.T) { ...@@ -207,7 +224,7 @@ func TestDirectivesList(t *testing.T) {
} }
func TestContextSaveConfig(t *testing.T) { func TestContextSaveConfig(t *testing.T) {
ctx := newContext().(*httpContext) ctx := newContext(&caddy.Instance{Storage: make(map[interface{}]interface{})}).(*httpContext)
ctx.saveConfig("foo", new(SiteConfig)) ctx.saveConfig("foo", new(SiteConfig))
if _, ok := ctx.keysToSiteConfigs["foo"]; !ok { if _, ok := ctx.keysToSiteConfigs["foo"]; !ok {
t.Error("Expected config to be saved, but it wasn't") t.Error("Expected config to be saved, but it wasn't")
...@@ -226,7 +243,7 @@ func TestContextSaveConfig(t *testing.T) { ...@@ -226,7 +243,7 @@ func TestContextSaveConfig(t *testing.T) {
// Test to make sure we are correctly hiding the Caddyfile // Test to make sure we are correctly hiding the Caddyfile
func TestHideCaddyfile(t *testing.T) { func TestHideCaddyfile(t *testing.T) {
ctx := newContext().(*httpContext) ctx := newContext(&caddy.Instance{Storage: make(map[interface{}]interface{})}).(*httpContext)
ctx.saveConfig("test", &SiteConfig{ ctx.saveConfig("test", &SiteConfig{
Root: Root, Root: Root,
originCaddyfile: "Testfile", originCaddyfile: "Testfile",
......
...@@ -392,7 +392,7 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) (int, error) ...@@ -392,7 +392,7 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) (int, error)
if vhost == nil { if vhost == nil {
// check for ACME challenge even if vhost is nil; // check for ACME challenge even if vhost is nil;
// could be a new host coming online soon // could be a new host coming online soon
if caddytls.HTTPChallengeHandler(w, r, "localhost", caddytls.DefaultHTTPAlternatePort) { if caddytls.HTTPChallengeHandler(w, r, "localhost") {
return 0, nil return 0, nil
} }
// otherwise, log the error and write a message to the client // otherwise, log the error and write a message to the client
...@@ -408,7 +408,7 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) (int, error) ...@@ -408,7 +408,7 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) (int, error)
// we still check for ACME challenge if the vhost exists, // we still check for ACME challenge if the vhost exists,
// because we must apply its HTTP challenge config settings // because we must apply its HTTP challenge config settings
if s.proxyHTTPChallenge(vhost, w, r) { if caddytls.HTTPChallengeHandler(w, r, vhost.ListenHost) {
return 0, nil return 0, nil
} }
...@@ -416,31 +416,25 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) (int, error) ...@@ -416,31 +416,25 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) (int, error)
// the URL path, so a request to example.com/foo/blog on the site // the URL path, so a request to example.com/foo/blog on the site
// defined as example.com/foo appears as /blog instead of /foo/blog. // defined as example.com/foo appears as /blog instead of /foo/blog.
if pathPrefix != "/" { if pathPrefix != "/" {
r.URL.Path = strings.TrimPrefix(r.URL.Path, pathPrefix) r.URL = trimPathPrefix(r.URL, pathPrefix)
if !strings.HasPrefix(r.URL.Path, "/") {
r.URL.Path = "/" + r.URL.Path
}
} }
return vhost.middlewareChain.ServeHTTP(w, r) return vhost.middlewareChain.ServeHTTP(w, r)
} }
// proxyHTTPChallenge solves the ACME HTTP challenge if r is the HTTP func trimPathPrefix(u *url.URL, prefix string) *url.URL {
// request for the challenge. If it is, and if the request has been // We need to use URL.EscapedPath() when trimming the pathPrefix as
// fulfilled (response written), true is returned; false otherwise. // URL.Path is ambiguous about / or %2f - see docs. See #1927
// If you don't have a vhost, just call the challenge handler directly. trimmed := strings.TrimPrefix(u.EscapedPath(), prefix)
func (s *Server) proxyHTTPChallenge(vhost *SiteConfig, w http.ResponseWriter, r *http.Request) bool { if !strings.HasPrefix(trimmed, "/") {
if vhost.Addr.Port != caddytls.HTTPChallengePort { trimmed = "/" + trimmed
return false
}
if vhost.TLS != nil && vhost.TLS.Manual {
return false
} }
altPort := caddytls.DefaultHTTPAlternatePort trimmedURL, err := url.Parse(trimmed)
if vhost.TLS != nil && vhost.TLS.AltHTTPPort != "" { if err != nil {
altPort = vhost.TLS.AltHTTPPort log.Printf("[ERROR] Unable to parse trimmed URL %s: %v", trimmed, err)
return u
} }
return caddytls.HTTPChallengeHandler(w, r, vhost.ListenHost, altPort) return trimmedURL
} }
// Address returns the address s was assigned to listen on. // Address returns the address s was assigned to listen on.
......
...@@ -16,6 +16,7 @@ package httpserver ...@@ -16,6 +16,7 @@ package httpserver
import ( import (
"net/http" "net/http"
"net/url"
"testing" "testing"
"time" "time"
) )
...@@ -126,6 +127,94 @@ func TestMakeHTTPServerWithTimeouts(t *testing.T) { ...@@ -126,6 +127,94 @@ func TestMakeHTTPServerWithTimeouts(t *testing.T) {
} }
} }
func TestTrimPathPrefix(t *testing.T) {
for i, pt := range []struct {
path string
prefix string
expected string
shouldFail bool
}{
{
path: "/my/path",
prefix: "/my",
expected: "/path",
shouldFail: false,
},
{
path: "/my/%2f/path",
prefix: "/my",
expected: "/%2f/path",
shouldFail: false,
},
{
path: "/my/path",
prefix: "/my/",
expected: "/path",
shouldFail: false,
},
{
path: "/my///path",
prefix: "/my",
expected: "/path",
shouldFail: true,
},
{
path: "/my///path",
prefix: "/my",
expected: "///path",
shouldFail: false,
},
{
path: "/my/path///slash",
prefix: "/my",
expected: "/path///slash",
shouldFail: false,
},
{
path: "/my/%2f/path/%2f",
prefix: "/my",
expected: "/%2f/path/%2f",
shouldFail: false,
}, {
path: "/my/%20/path",
prefix: "/my",
expected: "/%20/path",
shouldFail: false,
}, {
path: "/path",
prefix: "",
expected: "/path",
shouldFail: false,
}, {
path: "/path/my/",
prefix: "/my",
expected: "/path/my/",
shouldFail: false,
}, {
path: "",
prefix: "/my",
expected: "/",
shouldFail: false,
}, {
path: "/apath",
prefix: "",
expected: "/apath",
shouldFail: false,
},
} {
u, _ := url.Parse(pt.path)
if got, want := trimPathPrefix(u, pt.prefix), pt.expected; got.EscapedPath() != want {
if !pt.shouldFail {
t.Errorf("Test %d: Expected='%s', but was '%s' ", i, want, got.EscapedPath())
}
} else if pt.shouldFail {
t.Errorf("SHOULDFAIL Test %d: Expected='%s', and was '%s' but should fail", i, want, got.EscapedPath())
}
}
}
func TestMakeHTTPServerWithHeaderLimit(t *testing.T) { func TestMakeHTTPServerWithHeaderLimit(t *testing.T) {
for name, c := range map[string]struct { for name, c := range map[string]struct {
group []*SiteConfig group []*SiteConfig
......
...@@ -16,6 +16,7 @@ package requestid ...@@ -16,6 +16,7 @@ package requestid
import ( import (
"context" "context"
"log"
"net/http" "net/http"
"github.com/google/uuid" "github.com/google/uuid"
...@@ -24,12 +25,29 @@ import ( ...@@ -24,12 +25,29 @@ import (
// Handler is a middleware handler // Handler is a middleware handler
type Handler struct { type Handler struct {
Next httpserver.Handler Next httpserver.Handler
HeaderName string // (optional) header from which to read an existing ID
} }
func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
reqid := uuid.New().String() var reqid uuid.UUID
c := context.WithValue(r.Context(), httpserver.RequestIDCtxKey, reqid)
uuidFromHeader := r.Header.Get(h.HeaderName)
if h.HeaderName != "" && uuidFromHeader != "" {
// use the ID in the header field if it exists
var err error
reqid, err = uuid.Parse(uuidFromHeader)
if err != nil {
log.Printf("[NOTICE] Parsing request ID from %s header: %v", h.HeaderName, err)
reqid = uuid.New()
}
} else {
// otherwise, create a new one
reqid = uuid.New()
}
// set the request ID on the context
c := context.WithValue(r.Context(), httpserver.RequestIDCtxKey, reqid.String())
r = r.WithContext(c) r = r.WithContext(c)
return h.Next.ServeHTTP(w, r) return h.Next.ServeHTTP(w, r)
......
...@@ -15,34 +15,53 @@ ...@@ -15,34 +15,53 @@
package requestid package requestid
import ( import (
"context"
"net/http" "net/http"
"net/http/httptest"
"testing" "testing"
"github.com/google/uuid"
"github.com/mholt/caddy/caddyhttp/httpserver" "github.com/mholt/caddy/caddyhttp/httpserver"
) )
func TestRequestID(t *testing.T) { func TestRequestIDHandler(t *testing.T) {
request, err := http.NewRequest("GET", "http://localhost/", nil) handler := Handler{
Next: httpserver.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
value, _ := r.Context().Value(httpserver.RequestIDCtxKey).(string)
if value == "" {
t.Error("Request ID should not be empty")
}
return 0, nil
}),
}
req, err := http.NewRequest("GET", "http://localhost/", nil)
if err != nil { if err != nil {
t.Fatal("Could not create HTTP request:", err) t.Fatal("Could not create HTTP request:", err)
} }
rec := httptest.NewRecorder()
reqid := uuid.New().String() handler.ServeHTTP(rec, req)
}
c := context.WithValue(request.Context(), httpserver.RequestIDCtxKey, reqid)
request = request.WithContext(c)
// See caddyhttp/replacer.go
value, _ := request.Context().Value(httpserver.RequestIDCtxKey).(string)
if value == "" { func TestRequestIDFromHeader(t *testing.T) {
t.Fatal("Request ID should not be empty") headerName := "X-Request-ID"
headerValue := "71a75329-d9f9-4d25-957e-e689a7b68d78"
handler := Handler{
Next: httpserver.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
value, _ := r.Context().Value(httpserver.RequestIDCtxKey).(string)
if value != headerValue {
t.Errorf("Request ID should be '%s' but got '%s'", headerValue, value)
}
return 0, nil
}),
HeaderName: headerName,
} }
if value != reqid { req, err := http.NewRequest("GET", "http://localhost/", nil)
t.Fatal("Request ID does not match") if err != nil {
t.Fatal("Could not create HTTP request:", err)
} }
req.Header.Set(headerName, headerValue)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
} }
...@@ -27,14 +27,19 @@ func init() { ...@@ -27,14 +27,19 @@ func init() {
} }
func setup(c *caddy.Controller) error { func setup(c *caddy.Controller) error {
var headerName string
for c.Next() { for c.Next() {
if c.NextArg() { if c.NextArg() {
return c.ArgErr() //no arg expected. headerName = c.Val()
}
if c.NextArg() {
return c.ArgErr()
} }
} }
httpserver.GetConfig(c).AddMiddleware(func(next httpserver.Handler) httpserver.Handler { httpserver.GetConfig(c).AddMiddleware(func(next httpserver.Handler) httpserver.Handler {
return Handler{Next: next} return Handler{Next: next, HeaderName: headerName}
}) })
return nil return nil
......
...@@ -45,7 +45,15 @@ func TestSetup(t *testing.T) { ...@@ -45,7 +45,15 @@ func TestSetup(t *testing.T) {
} }
func TestSetupWithArg(t *testing.T) { func TestSetupWithArg(t *testing.T) {
c := caddy.NewTestController("http", `requestid abc`) c := caddy.NewTestController("http", `requestid X-Request-ID`)
err := setup(c)
if err != nil {
t.Errorf("Expected no error, got: %v", err)
}
}
func TestSetupWithTooManyArgs(t *testing.T) {
c := caddy.NewTestController("http", `requestid foo bar`)
err := setup(c) err := setup(c)
if err == nil { if err == nil {
t.Errorf("Expected an error, got: %v", err) t.Errorf("Expected an error, got: %v", err)
......
...@@ -107,6 +107,10 @@ func (fs FileServer) serveFile(w http.ResponseWriter, r *http.Request) (int, err ...@@ -107,6 +107,10 @@ func (fs FileServer) serveFile(w http.ResponseWriter, r *http.Request) (int, err
if d.IsDir() { if d.IsDir() {
// ensure there is a trailing slash // ensure there is a trailing slash
if urlCopy.Path[len(urlCopy.Path)-1] != '/' { if urlCopy.Path[len(urlCopy.Path)-1] != '/' {
for strings.HasPrefix(urlCopy.Path, "//") {
// prevent path-based open redirects
urlCopy.Path = strings.TrimPrefix(urlCopy.Path, "/")
}
urlCopy.Path += "/" urlCopy.Path += "/"
http.Redirect(w, r, urlCopy.String(), http.StatusMovedPermanently) http.Redirect(w, r, urlCopy.String(), http.StatusMovedPermanently)
return http.StatusMovedPermanently, nil return http.StatusMovedPermanently, nil
...@@ -131,6 +135,10 @@ func (fs FileServer) serveFile(w http.ResponseWriter, r *http.Request) (int, err ...@@ -131,6 +135,10 @@ func (fs FileServer) serveFile(w http.ResponseWriter, r *http.Request) (int, err
} }
if redir { if redir {
for strings.HasPrefix(urlCopy.Path, "//") {
// prevent path-based open redirects
urlCopy.Path = strings.TrimPrefix(urlCopy.Path, "/")
}
http.Redirect(w, r, urlCopy.String(), http.StatusMovedPermanently) http.Redirect(w, r, urlCopy.String(), http.StatusMovedPermanently)
return http.StatusMovedPermanently, nil return http.StatusMovedPermanently, nil
} }
......
...@@ -77,9 +77,9 @@ func TestServeHTTP(t *testing.T) { ...@@ -77,9 +77,9 @@ func TestServeHTTP(t *testing.T) {
{ {
url: "https://foo/dirwithindex/", url: "https://foo/dirwithindex/",
expectedStatus: http.StatusOK, expectedStatus: http.StatusOK,
expectedBodyContent: testFiles[webrootDirwithindexIndeHTML], expectedBodyContent: testFiles[webrootDirwithindexIndexHTML],
expectedEtag: `"2n9cw"`, expectedEtag: `"2n9cw"`,
expectedContentLength: strconv.Itoa(len(testFiles[webrootDirwithindexIndeHTML])), expectedContentLength: strconv.Itoa(len(testFiles[webrootDirwithindexIndexHTML])),
}, },
// Test 4 - access folder with index file without trailing slash // Test 4 - access folder with index file without trailing slash
{ {
...@@ -235,16 +235,38 @@ func TestServeHTTP(t *testing.T) { ...@@ -235,16 +235,38 @@ func TestServeHTTP(t *testing.T) {
expectedBodyContent: movedPermanently, expectedBodyContent: movedPermanently,
}, },
{ {
// Test 27 - Check etag
url: "https://foo/notindex.html", url: "https://foo/notindex.html",
expectedStatus: http.StatusOK, expectedStatus: http.StatusOK,
expectedBodyContent: testFiles[webrootNotIndexHTML], expectedBodyContent: testFiles[webrootNotIndexHTML],
expectedEtag: `"2n9cm"`, expectedEtag: `"2n9cm"`,
expectedContentLength: strconv.Itoa(len(testFiles[webrootNotIndexHTML])), expectedContentLength: strconv.Itoa(len(testFiles[webrootNotIndexHTML])),
}, },
{
// Test 28 - Prevent path-based open redirects (directory)
url: "https://foo//example.com%2f..",
expectedStatus: http.StatusMovedPermanently,
expectedLocation: "https://foo/example.com/../",
expectedBodyContent: movedPermanently,
},
{
// Test 29 - Prevent path-based open redirects (file)
url: "https://foo//example.com%2f../dirwithindex/index.html",
expectedStatus: http.StatusMovedPermanently,
expectedLocation: "https://foo/example.com/../dirwithindex/",
expectedBodyContent: movedPermanently,
},
{
// Test 29 - Prevent path-based open redirects (extra leading slashes)
url: "https://foo///example.com%2f..",
expectedStatus: http.StatusMovedPermanently,
expectedLocation: "https://foo/example.com/../",
expectedBodyContent: movedPermanently,
},
} }
for i, test := range tests { for i, test := range tests {
// set up response writer and rewuest // set up response writer and request
responseRecorder := httptest.NewRecorder() responseRecorder := httptest.NewRecorder()
request, err := http.NewRequest("GET", test.url, nil) request, err := http.NewRequest("GET", test.url, nil)
if err != nil { if err != nil {
...@@ -518,7 +540,7 @@ var ( ...@@ -518,7 +540,7 @@ var (
webrootNotIndexHTML = filepath.Join(webrootName, "notindex.html") webrootNotIndexHTML = filepath.Join(webrootName, "notindex.html")
webrootDirFile2HTML = filepath.Join(webrootName, "dir", "file2.html") webrootDirFile2HTML = filepath.Join(webrootName, "dir", "file2.html")
webrootDirHiddenHTML = filepath.Join(webrootName, "dir", "hidden.html") webrootDirHiddenHTML = filepath.Join(webrootName, "dir", "hidden.html")
webrootDirwithindexIndeHTML = filepath.Join(webrootName, "dirwithindex", "index.html") webrootDirwithindexIndexHTML = filepath.Join(webrootName, "dirwithindex", "index.html")
webrootSubGzippedHTML = filepath.Join(webrootName, "sub", "gzipped.html") webrootSubGzippedHTML = filepath.Join(webrootName, "sub", "gzipped.html")
webrootSubGzippedHTMLGz = filepath.Join(webrootName, "sub", "gzipped.html.gz") webrootSubGzippedHTMLGz = filepath.Join(webrootName, "sub", "gzipped.html.gz")
webrootSubGzippedHTMLBr = filepath.Join(webrootName, "sub", "gzipped.html.br") webrootSubGzippedHTMLBr = filepath.Join(webrootName, "sub", "gzipped.html.br")
...@@ -544,7 +566,7 @@ var testFiles = map[string]string{ ...@@ -544,7 +566,7 @@ var testFiles = map[string]string{
webrootFile1HTML: "<h1>file1.html</h1>", webrootFile1HTML: "<h1>file1.html</h1>",
webrootNotIndexHTML: "<h1>notindex.html</h1>", webrootNotIndexHTML: "<h1>notindex.html</h1>",
webrootDirFile2HTML: "<h1>dir/file2.html</h1>", webrootDirFile2HTML: "<h1>dir/file2.html</h1>",
webrootDirwithindexIndeHTML: "<h1>dirwithindex/index.html</h1>", webrootDirwithindexIndexHTML: "<h1>dirwithindex/index.html</h1>",
webrootDirHiddenHTML: "<h1>dir/hidden.html</h1>", webrootDirHiddenHTML: "<h1>dir/hidden.html</h1>",
webrootSubGzippedHTML: "<h1>gzipped.html</h1>", webrootSubGzippedHTML: "<h1>gzipped.html</h1>",
webrootSubGzippedHTMLGz: "1.gzipped.html.gz", webrootSubGzippedHTMLGz: "1.gzipped.html.gz",
......
This diff is collapsed.
...@@ -17,57 +17,71 @@ package caddytls ...@@ -17,57 +17,71 @@ package caddytls
import "testing" import "testing"
func TestUnexportedGetCertificate(t *testing.T) { func TestUnexportedGetCertificate(t *testing.T) {
defer func() { certCache = make(map[string]Certificate) }() certCache := &certificateCache{cache: make(map[string]Certificate)}
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
// When cache is empty // When cache is empty
if _, matched, defaulted := getCertificate("example.com"); matched || defaulted { if _, matched, defaulted := cfg.getCertificate("example.com"); matched || defaulted {
t.Errorf("Got a certificate when cache was empty; matched=%v, defaulted=%v", matched, defaulted) t.Errorf("Got a certificate when cache was empty; matched=%v, defaulted=%v", matched, defaulted)
} }
// When cache has one certificate in it (also is default) // When cache has one certificate in it
defaultCert := Certificate{Names: []string{"example.com", ""}} firstCert := Certificate{Names: []string{"example.com"}}
certCache[""] = defaultCert certCache.cache["0xdeadbeef"] = firstCert
certCache["example.com"] = defaultCert cfg.Certificates["example.com"] = "0xdeadbeef"
if cert, matched, defaulted := getCertificate("Example.com"); !matched || defaulted || cert.Names[0] != "example.com" { if cert, matched, defaulted := cfg.getCertificate("Example.com"); !matched || defaulted || cert.Names[0] != "example.com" {
t.Errorf("Didn't get a cert for 'Example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted) t.Errorf("Didn't get a cert for 'Example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
} }
if cert, matched, defaulted := getCertificate(""); !matched || defaulted || cert.Names[0] != "example.com" { if cert, matched, defaulted := cfg.getCertificate("example.com"); !matched || defaulted || cert.Names[0] != "example.com" {
t.Errorf("Didn't get a cert for '' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted) t.Errorf("Didn't get a cert for 'example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
} }
// When retrieving wildcard certificate // When retrieving wildcard certificate
certCache["*.example.com"] = Certificate{Names: []string{"*.example.com"}} certCache.cache["0xb01dface"] = Certificate{Names: []string{"*.example.com"}}
if cert, matched, defaulted := getCertificate("sub.example.com"); !matched || defaulted || cert.Names[0] != "*.example.com" { cfg.Certificates["*.example.com"] = "0xb01dface"
if cert, matched, defaulted := cfg.getCertificate("sub.example.com"); !matched || defaulted || cert.Names[0] != "*.example.com" {
t.Errorf("Didn't get wildcard cert for 'sub.example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted) t.Errorf("Didn't get wildcard cert for 'sub.example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
} }
// When no certificate matches, the default is returned // When no certificate matches and SNI is provided, return no certificate (should be TLS alert)
if cert, matched, defaulted := getCertificate("nomatch"); matched || !defaulted { if cert, matched, defaulted := cfg.getCertificate("nomatch"); matched || defaulted {
t.Errorf("Expected matched=false, defaulted=false; but got matched=%v, defaulted=%v (cert: %v)", matched, defaulted, cert)
}
// When no certificate matches and SNI is NOT provided, a random is returned
if cert, matched, defaulted := cfg.getCertificate(""); matched || !defaulted {
t.Errorf("Expected matched=false, defaulted=true; but got matched=%v, defaulted=%v (cert: %v)", matched, defaulted, cert) t.Errorf("Expected matched=false, defaulted=true; but got matched=%v, defaulted=%v (cert: %v)", matched, defaulted, cert)
} else if cert.Names[0] != "example.com" {
t.Errorf("Expected default cert, got: %v", cert)
} }
} }
func TestCacheCertificate(t *testing.T) { func TestCacheCertificate(t *testing.T) {
defer func() { certCache = make(map[string]Certificate) }() certCache := &certificateCache{cache: make(map[string]Certificate)}
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
cacheCertificate(Certificate{Names: []string{"example.com", "sub.example.com"}}) cfg.cacheCertificate(Certificate{Names: []string{"example.com", "sub.example.com"}, Hash: "foobar"})
if _, ok := certCache["example.com"]; !ok { if len(certCache.cache) != 1 {
t.Error("Expected first cert to be cached by key 'example.com', but it wasn't") t.Errorf("Expected length of certificate cache to be 1")
}
if _, ok := certCache.cache["foobar"]; !ok {
t.Error("Expected first cert to be cached by key 'foobar', but it wasn't")
} }
if _, ok := certCache["sub.example.com"]; !ok { if _, ok := cfg.Certificates["example.com"]; !ok {
t.Error("Expected first cert to be cached by key 'sub.example.com', but it wasn't") t.Error("Expected first cert to be keyed by 'example.com', but it wasn't")
} }
if cert, ok := certCache[""]; !ok || cert.Names[2] != "" { if _, ok := cfg.Certificates["sub.example.com"]; !ok {
t.Error("Expected first cert to be cached additionally as the default certificate with empty name added, but it wasn't") t.Error("Expected first cert to be keyed by 'sub.example.com', but it wasn't")
} }
cacheCertificate(Certificate{Names: []string{"example2.com"}}) // different config, but using same cache; and has cert with overlapping name,
if _, ok := certCache["example2.com"]; !ok { // but different hash
t.Error("Expected second cert to be cached by key 'exmaple2.com', but it wasn't") cfg2 := &Config{Certificates: make(map[string]string), certCache: certCache}
cfg2.cacheCertificate(Certificate{Names: []string{"example.com"}, Hash: "barbaz"})
if _, ok := certCache.cache["barbaz"]; !ok {
t.Error("Expected second cert to be cached by key 'barbaz.com', but it wasn't")
} }
if cert, ok := certCache[""]; ok && cert.Names[0] == "example2.com" { if hash, ok := cfg2.Certificates["example.com"]; !ok {
t.Error("Expected second cert to NOT be cached as default, but it was") t.Error("Expected second cert to be keyed by 'example.com', but it wasn't")
} else if hash != "barbaz" {
t.Errorf("Expected second cert to map to 'barbaz' but it was %s instead", hash)
} }
} }
...@@ -40,7 +40,7 @@ type ACMEClient struct { ...@@ -40,7 +40,7 @@ type ACMEClient struct {
AllowPrompts bool AllowPrompts bool
config *Config config *Config
acmeClient *acme.Client acmeClient *acme.Client
locker Locker storage Storage
} }
// newACMEClient creates a new ACMEClient given an email and whether // newACMEClient creates a new ACMEClient given an email and whether
...@@ -122,10 +122,7 @@ var newACMEClient = func(config *Config, allowPrompts bool) (*ACMEClient, error) ...@@ -122,10 +122,7 @@ var newACMEClient = func(config *Config, allowPrompts bool) (*ACMEClient, error)
AllowPrompts: allowPrompts, AllowPrompts: allowPrompts,
config: config, config: config,
acmeClient: client, acmeClient: client,
locker: &syncLock{ storage: storage,
nameLocks: make(map[string]*sync.WaitGroup),
nameLocksMu: sync.Mutex{},
},
} }
if config.DNSProvider == "" { if config.DNSProvider == "" {
...@@ -161,7 +158,7 @@ var newACMEClient = func(config *Config, allowPrompts bool) (*ACMEClient, error) ...@@ -161,7 +158,7 @@ var newACMEClient = func(config *Config, allowPrompts bool) (*ACMEClient, error)
// See if TLS challenge needs to be handled by our own facilities // See if TLS challenge needs to be handled by our own facilities
if caddy.HasListenerWithAddress(net.JoinHostPort(config.ListenHost, useTLSSNIPort)) { if caddy.HasListenerWithAddress(net.JoinHostPort(config.ListenHost, useTLSSNIPort)) {
c.acmeClient.SetChallengeProvider(acme.TLSSNI01, tlsSniSolver{}) c.acmeClient.SetChallengeProvider(acme.TLSSNI01, tlsSNISolver{certCache: config.certCache})
} }
// Disable any challenges that should not be used // Disable any challenges that should not be used
...@@ -210,13 +207,7 @@ var newACMEClient = func(config *Config, allowPrompts bool) (*ACMEClient, error) ...@@ -210,13 +207,7 @@ var newACMEClient = func(config *Config, allowPrompts bool) (*ACMEClient, error)
// Callers who have access to a Config value should use the ObtainCert // Callers who have access to a Config value should use the ObtainCert
// method on that instead of this lower-level method. // method on that instead of this lower-level method.
func (c *ACMEClient) Obtain(name string) error { func (c *ACMEClient) Obtain(name string) error {
// Get access to ACME storage waiter, err := c.storage.TryLock(name)
storage, err := c.config.StorageFor(c.config.CAUrl)
if err != nil {
return err
}
waiter, err := c.locker.TryLock(name)
if err != nil { if err != nil {
return err return err
} }
...@@ -226,7 +217,7 @@ func (c *ACMEClient) Obtain(name string) error { ...@@ -226,7 +217,7 @@ func (c *ACMEClient) Obtain(name string) error {
return nil // we assume the process with the lock succeeded, rather than hammering this execution path again return nil // we assume the process with the lock succeeded, rather than hammering this execution path again
} }
defer func() { defer func() {
if err := c.locker.Unlock(name); err != nil { if err := c.storage.Unlock(name); err != nil {
log.Printf("[ERROR] Unable to unlock obtain call for %s: %v", name, err) log.Printf("[ERROR] Unable to unlock obtain call for %s: %v", name, err)
} }
}() }()
...@@ -269,7 +260,7 @@ Attempts: ...@@ -269,7 +260,7 @@ Attempts:
} }
// Success - immediately save the certificate resource // Success - immediately save the certificate resource
err = saveCertResource(storage, certificate) err = saveCertResource(c.storage, certificate)
if err != nil { if err != nil {
return fmt.Errorf("error saving assets for %v: %v", name, err) return fmt.Errorf("error saving assets for %v: %v", name, err)
} }
...@@ -282,35 +273,30 @@ Attempts: ...@@ -282,35 +273,30 @@ Attempts:
return nil return nil
} }
// Renew renews the managed certificate for name. This function is // Renew renews the managed certificate for name. It puts the renewed
// safe for concurrent use. // certificate into storage (not the cache). This function is safe for
// concurrent use.
// //
// Callers who have access to a Config value should use the RenewCert // Callers who have access to a Config value should use the RenewCert
// method on that instead of this lower-level method. // method on that instead of this lower-level method.
func (c *ACMEClient) Renew(name string) error { func (c *ACMEClient) Renew(name string) error {
// Get access to ACME storage waiter, err := c.storage.TryLock(name)
storage, err := c.config.StorageFor(c.config.CAUrl)
if err != nil {
return err
}
waiter, err := c.locker.TryLock(name)
if err != nil { if err != nil {
return err return err
} }
if waiter != nil { if waiter != nil {
log.Printf("[INFO] Certificate for %s is already being renewed elsewhere and stored; waiting", name) log.Printf("[INFO] Certificate for %s is already being renewed elsewhere and stored; waiting", name)
waiter.Wait() waiter.Wait()
return nil // we assume the process with the lock succeeded, rather than hammering this execution path again return nil // assume that the worker that renewed the cert succeeded; avoid hammering this path over and over
} }
defer func() { defer func() {
if err := c.locker.Unlock(name); err != nil { if err := c.storage.Unlock(name); err != nil {
log.Printf("[ERROR] Unable to unlock renew call for %s: %v", name, err) log.Printf("[ERROR] Unable to unlock renew call for %s: %v", name, err)
} }
}() }()
// Prepare for renewal (load PEM cert, key, and meta) // Prepare for renewal (load PEM cert, key, and meta)
siteData, err := storage.LoadSite(name) siteData, err := c.storage.LoadSite(name)
if err != nil { if err != nil {
return err return err
} }
...@@ -357,18 +343,13 @@ func (c *ACMEClient) Renew(name string) error { ...@@ -357,18 +343,13 @@ func (c *ACMEClient) Renew(name string) error {
go diagnostics.Increment("acme_certificates_obtained") go diagnostics.Increment("acme_certificates_obtained")
go diagnostics.Increment("acme_certificates_renewed") go diagnostics.Increment("acme_certificates_renewed")
return saveCertResource(storage, newCertMeta) return saveCertResource(c.storage, newCertMeta)
} }
// Revoke revokes the certificate for name and deltes // Revoke revokes the certificate for name and deletes
// it from storage. // it from storage.
func (c *ACMEClient) Revoke(name string) error { func (c *ACMEClient) Revoke(name string) error {
storage, err := c.config.StorageFor(c.config.CAUrl) siteExists, err := c.storage.SiteExists(name)
if err != nil {
return err
}
siteExists, err := storage.SiteExists(name)
if err != nil { if err != nil {
return err return err
} }
...@@ -377,7 +358,7 @@ func (c *ACMEClient) Revoke(name string) error { ...@@ -377,7 +358,7 @@ func (c *ACMEClient) Revoke(name string) error {
return errors.New("no certificate and key for " + name) return errors.New("no certificate and key for " + name)
} }
siteData, err := storage.LoadSite(name) siteData, err := c.storage.LoadSite(name)
if err != nil { if err != nil {
return err return err
} }
...@@ -387,7 +368,7 @@ func (c *ACMEClient) Revoke(name string) error { ...@@ -387,7 +368,7 @@ func (c *ACMEClient) Revoke(name string) error {
return err return err
} }
err = storage.DeleteSite(name) err = c.storage.DeleteSite(name)
if err != nil { if err != nil {
return errors.New("certificate revoked, but unable to delete certificate file: " + err.Error()) return errors.New("certificate revoked, but unable to delete certificate file: " + err.Error())
} }
......
...@@ -93,16 +93,17 @@ type Config struct { ...@@ -93,16 +93,17 @@ type Config struct {
// an ACME challenge // an ACME challenge
ListenHost string ListenHost string
// The alternate port (ONLY port, not host) // The alternate port (ONLY port, not host) to
// to use for the ACME HTTP challenge; this // use for the ACME HTTP challenge; if non-empty,
// port will be used if we proxy challenges // this port will be used instead of
// coming in on port 80 to this alternate port // HTTPChallengePort to spin up a listener for
// the HTTP challenge
AltHTTPPort string AltHTTPPort string
// The alternate port (ONLY port, not host) // The alternate port (ONLY port, not host)
// to use for the ACME TLS-SNI challenge. // to use for the ACME TLS-SNI challenge.
// The system must forward the standard port // The system must forward TLSSNIChallengePort
// for the TLS-SNI challenge to this port. // to this port for challenge to succeed
AltTLSSNIPort string AltTLSSNIPort string
// The string identifier of the DNS provider // The string identifier of the DNS provider
...@@ -134,7 +135,12 @@ type Config struct { ...@@ -134,7 +135,12 @@ type Config struct {
// Protocol Negotiation (ALPN). // Protocol Negotiation (ALPN).
ALPN []string ALPN []string
tlsConfig *tls.Config // the final tls.Config created with buildStandardTLSConfig() // The map of hostname to certificate hash. This is used to complete
// handshakes and serve the right certificate given the SNI.
Certificates map[string]string
certCache *certificateCache // pointer to the Instance's certificate store
tlsConfig *tls.Config // the final tls.Config created with buildStandardTLSConfig()
} }
// OnDemandState contains some state relevant for providing // OnDemandState contains some state relevant for providing
...@@ -155,6 +161,25 @@ type OnDemandState struct { ...@@ -155,6 +161,25 @@ type OnDemandState struct {
AskURL *url.URL AskURL *url.URL
} }
// NewConfig returns a new Config with a pointer to the instance's
// certificate cache. You will usually need to set Other fields on
// the returned Config for successful practical use.
func NewConfig(inst *caddy.Instance) *Config {
inst.StorageMu.RLock()
certCache, ok := inst.Storage[CertCacheInstStorageKey].(*certificateCache)
inst.StorageMu.RUnlock()
if !ok || certCache == nil {
certCache = &certificateCache{cache: make(map[string]Certificate)}
inst.StorageMu.Lock()
inst.Storage[CertCacheInstStorageKey] = certCache
inst.StorageMu.Unlock()
}
cfg := new(Config)
cfg.Certificates = make(map[string]string)
cfg.certCache = certCache
return cfg
}
// ObtainCert obtains a certificate for name using c, as long // ObtainCert obtains a certificate for name using c, as long
// as a certificate does not already exist in storage for that // as a certificate does not already exist in storage for that
// name. The name must qualify and c must be flagged as Managed. // name. The name must qualify and c must be flagged as Managed.
...@@ -330,7 +355,9 @@ func (c *Config) buildStandardTLSConfig() error { ...@@ -330,7 +355,9 @@ func (c *Config) buildStandardTLSConfig() error {
// MakeTLSConfig makes a tls.Config from configs. The returned // MakeTLSConfig makes a tls.Config from configs. The returned
// tls.Config is programmed to load the matching caddytls.Config // tls.Config is programmed to load the matching caddytls.Config
// based on the hostname in SNI, but that's all. // based on the hostname in SNI, but that's all. This is used
// to create a single TLS configuration for a listener (a group
// of sites).
func MakeTLSConfig(configs []*Config) (*tls.Config, error) { func MakeTLSConfig(configs []*Config) (*tls.Config, error) {
if len(configs) == 0 { if len(configs) == 0 {
return nil, nil return nil, nil
...@@ -358,15 +385,28 @@ func MakeTLSConfig(configs []*Config) (*tls.Config, error) { ...@@ -358,15 +385,28 @@ func MakeTLSConfig(configs []*Config) (*tls.Config, error) {
configs[i-1].Hostname, lastConfProto, cfg.Hostname, thisConfProto) configs[i-1].Hostname, lastConfProto, cfg.Hostname, thisConfProto)
} }
// convert each caddytls.Config into a tls.Config // convert this caddytls.Config into a tls.Config
if err := cfg.buildStandardTLSConfig(); err != nil { if err := cfg.buildStandardTLSConfig(); err != nil {
return nil, err return nil, err
} }
// Key this config by its hostname (overwriting // if an existing config with this hostname was already
// configs with the same hostname pattern); during // configured, then they must be identical (or at least
// TLS handshakes, configs are loaded based on // compatible), otherwise that is a configuration error
// the hostname pattern, according to client's SNI. if otherConfig, ok := configMap[cfg.Hostname]; ok {
if err := assertConfigsCompatible(cfg, otherConfig); err != nil {
return nil, fmt.Errorf("incompabile TLS configurations for the same SNI "+
"name (%s) on the same listener: %v",
cfg.Hostname, err)
}
}
// key this config by its hostname (overwrites
// configs with the same hostname pattern; should
// be OK since we already asserted they are roughly
// the same); during TLS handshakes, configs are
// loaded based on the hostname pattern, according
// to client's SNI
configMap[cfg.Hostname] = cfg configMap[cfg.Hostname] = cfg
} }
...@@ -383,6 +423,63 @@ func MakeTLSConfig(configs []*Config) (*tls.Config, error) { ...@@ -383,6 +423,63 @@ func MakeTLSConfig(configs []*Config) (*tls.Config, error) {
}, nil }, nil
} }
// assertConfigsCompatible returns an error if the two Configs
// do not have the same (or roughly compatible) configurations.
// If one of the tlsConfig pointers on either Config is nil,
// an error will be returned. If both are nil, no error.
func assertConfigsCompatible(cfg1, cfg2 *Config) error {
c1, c2 := cfg1.tlsConfig, cfg2.tlsConfig
if (c1 == nil && c2 != nil) || (c1 != nil && c2 == nil) {
return fmt.Errorf("one config is not made")
}
if c1 == nil && c2 == nil {
return nil
}
if len(c1.CipherSuites) != len(c2.CipherSuites) {
return fmt.Errorf("different number of allowed cipher suites")
}
for i, ciph := range c1.CipherSuites {
if c2.CipherSuites[i] != ciph {
return fmt.Errorf("different cipher suites or different order")
}
}
if len(c1.CurvePreferences) != len(c2.CurvePreferences) {
return fmt.Errorf("different number of allowed cipher suites")
}
for i, curve := range c1.CurvePreferences {
if c2.CurvePreferences[i] != curve {
return fmt.Errorf("different curve preferences or different order")
}
}
if len(c1.NextProtos) != len(c2.NextProtos) {
return fmt.Errorf("different number of ALPN (NextProtos) values")
}
for i, proto := range c1.NextProtos {
if c2.NextProtos[i] != proto {
return fmt.Errorf("different ALPN (NextProtos) values or different order")
}
}
if c1.PreferServerCipherSuites != c2.PreferServerCipherSuites {
return fmt.Errorf("one prefers server cipher suites, the other does not")
}
if c1.MinVersion != c2.MinVersion {
return fmt.Errorf("minimum TLS version mismatch")
}
if c1.MaxVersion != c2.MaxVersion {
return fmt.Errorf("maximum TLS version mismatch")
}
if c1.ClientAuth != c2.ClientAuth {
return fmt.Errorf("client authentication policy mismatch")
}
return nil
}
// ConfigGetter gets a Config keyed by key. // ConfigGetter gets a Config keyed by key.
type ConfigGetter func(c *caddy.Controller) *Config type ConfigGetter func(c *caddy.Controller) *Config
...@@ -522,7 +619,7 @@ var supportedCurvesMap = map[string]tls.CurveID{ ...@@ -522,7 +619,7 @@ var supportedCurvesMap = map[string]tls.CurveID{
"P521": tls.CurveP521, "P521": tls.CurveP521,
} }
// List of all the curves we want to use by default // List of all the curves we want to use by default.
// //
// This list should only include curves which are fast by design (e.g. X25519) // This list should only include curves which are fast by design (e.g. X25519)
// and those for which an optimized assembly implementation exists (e.g. P256). // and those for which an optimized assembly implementation exists (e.g. P256).
...@@ -548,4 +645,8 @@ const ( ...@@ -548,4 +645,8 @@ const (
// be capable of proxying or forwarding the request to this // be capable of proxying or forwarding the request to this
// alternate port. // alternate port.
DefaultHTTPAlternatePort = "5033" DefaultHTTPAlternatePort = "5033"
// CertCacheInstStorageKey is the name of the key for
// accessing the certificate storage on the *caddy.Instance.
CertCacheInstStorageKey = "tls_cert_cache"
) )
...@@ -237,15 +237,17 @@ func makeSelfSignedCert(config *Config) error { ...@@ -237,15 +237,17 @@ func makeSelfSignedCert(config *Config) error {
return fmt.Errorf("could not create certificate: %v", err) return fmt.Errorf("could not create certificate: %v", err)
} }
cacheCertificate(Certificate{ chain := [][]byte{derBytes}
config.cacheCertificate(Certificate{
Certificate: tls.Certificate{ Certificate: tls.Certificate{
Certificate: [][]byte{derBytes}, Certificate: chain,
PrivateKey: privKey, PrivateKey: privKey,
Leaf: cert, Leaf: cert,
}, },
Names: cert.DNSNames, Names: cert.DNSNames,
NotAfter: cert.NotAfter, NotAfter: cert.NotAfter,
Config: config, Hash: hashCertificateChain(chain),
}) })
return nil return nil
......
...@@ -38,9 +38,9 @@ var storageBasePath = filepath.Join(caddy.AssetsPath(), "acme") ...@@ -38,9 +38,9 @@ var storageBasePath = filepath.Join(caddy.AssetsPath(), "acme")
// Storage instance backed by the local disk. The resulting Storage // Storage instance backed by the local disk. The resulting Storage
// instance is guaranteed to be non-nil if there is no error. // instance is guaranteed to be non-nil if there is no error.
func NewFileStorage(caURL *url.URL) (Storage, error) { func NewFileStorage(caURL *url.URL) (Storage, error) {
return &FileStorage{ storage := &FileStorage{Path: filepath.Join(storageBasePath, caURL.Host)}
Path: filepath.Join(storageBasePath, caURL.Host), storage.Locker = &fileStorageLock{caURL: caURL.Host, storage: storage}
}, nil return storage, nil
} }
// FileStorage facilitates forming file paths derived from a root // FileStorage facilitates forming file paths derived from a root
...@@ -48,6 +48,7 @@ func NewFileStorage(caURL *url.URL) (Storage, error) { ...@@ -48,6 +48,7 @@ func NewFileStorage(caURL *url.URL) (Storage, error) {
// cross-platform way or persisting ACME assets on the file system. // cross-platform way or persisting ACME assets on the file system.
type FileStorage struct { type FileStorage struct {
Path string Path string
Locker
} }
// sites gets the directory that stores site certificate and keys. // sites gets the directory that stores site certificate and keys.
......
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package caddytls
import (
"fmt"
"os"
"sync"
"time"
"github.com/mholt/caddy"
)
func init() {
// be sure to remove lock files when exiting the process!
caddy.OnProcessExit = append(caddy.OnProcessExit, func() {
fileStorageNameLocksMu.Lock()
defer fileStorageNameLocksMu.Unlock()
for key, fw := range fileStorageNameLocks {
os.Remove(fw.filename)
delete(fileStorageNameLocks, key)
}
})
}
// fileStorageLock facilitates ACME-related locking by using
// the associated FileStorage, so multiple processes can coordinate
// renewals on the certificates on a shared file system.
type fileStorageLock struct {
caURL string
storage *FileStorage
}
// TryLock attempts to get a lock for name, otherwise it returns
// a Waiter value to wait until the other process is finished.
func (s *fileStorageLock) TryLock(name string) (Waiter, error) {
fileStorageNameLocksMu.Lock()
defer fileStorageNameLocksMu.Unlock()
// see if lock already exists within this process
fw, ok := fileStorageNameLocks[s.caURL+name]
if ok {
// lock already created within process, let caller wait on it
return fw, nil
}
// attempt to persist lock to disk by creating lock file
fw = &fileWaiter{
filename: s.storage.siteCertFile(name) + ".lock",
wg: new(sync.WaitGroup),
}
// parent dir must exist
if err := os.MkdirAll(s.storage.site(name), 0700); err != nil {
return nil, err
}
lf, err := os.OpenFile(fw.filename, os.O_CREATE|os.O_EXCL, 0644)
if err != nil {
if os.IsExist(err) {
// another process has the lock; use it to wait
return fw, nil
}
// otherwise, this was some unexpected error
return nil, err
}
lf.Close()
// looks like we get the lock
fw.wg.Add(1)
fileStorageNameLocks[s.caURL+name] = fw
return nil, nil
}
// Unlock unlocks name.
func (s *fileStorageLock) Unlock(name string) error {
fileStorageNameLocksMu.Lock()
defer fileStorageNameLocksMu.Unlock()
fw, ok := fileStorageNameLocks[s.caURL+name]
if !ok {
return fmt.Errorf("FileStorage: no lock to release for %s", name)
}
os.Remove(fw.filename)
fw.wg.Done()
delete(fileStorageNameLocks, s.caURL+name)
return nil
}
// fileWaiter waits for a file to disappear; it polls
// the file system to check for the existence of a file.
// It also has a WaitGroup which will be faster than
// polling, for when locking need only happen within this
// process.
type fileWaiter struct {
filename string
wg *sync.WaitGroup
}
// Wait waits until the lock is released.
func (fw *fileWaiter) Wait() {
start := time.Now()
fw.wg.Wait()
for time.Since(start) < 1*time.Hour {
_, err := os.Stat(fw.filename)
if os.IsNotExist(err) {
return
}
time.Sleep(1 * time.Second)
}
}
var fileStorageNameLocks = make(map[string]*fileWaiter) // keyed by CA + name
var fileStorageNameLocksMu sync.Mutex
var _ Locker = &fileStorageLock{}
var _ Waiter = &fileWaiter{}
...@@ -61,15 +61,15 @@ func (cg configGroup) getConfig(name string) *Config { ...@@ -61,15 +61,15 @@ func (cg configGroup) getConfig(name string) *Config {
} }
} }
// as a fallback, try a config that serves all names // try a config that serves all names (this
// is basically the same as a config defined
// for "*" -- I think -- but the above loop
// doesn't try an empty string)
if config, ok := cg[""]; ok { if config, ok := cg[""]; ok {
return config return config
} }
// as a last resort, use a random config // no matches, so just serve up a random config
// (even if the config isn't for that hostname,
// it should help us serve clients without SNI
// or at least defer TLS alerts to the cert)
for _, config := range cg { for _, config := range cg {
return config return config
} }
...@@ -121,6 +121,86 @@ func (cfg *Config) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certif ...@@ -121,6 +121,86 @@ func (cfg *Config) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certif
return &cert.Certificate, err return &cert.Certificate, err
} }
// getCertificate gets a certificate that matches name (a server name)
// from the in-memory cache, according to the lookup table associated with
// cfg. The lookup then points to a certificate in the Instance certificate
// cache.
//
// If there is no exact match for name, it will be checked against names of
// the form '*.example.com' (wildcard certificates) according to RFC 6125.
// If a match is found, matched will be true. If no matches are found, matched
// will be false and a "default" certificate will be returned with defaulted
// set to true. If defaulted is false, then no certificates were available.
//
// The logic in this function is adapted from the Go standard library,
// which is by the Go Authors.
//
// This function is safe for concurrent use.
func (cfg *Config) getCertificate(name string) (cert Certificate, matched, defaulted bool) {
var certKey string
var ok bool
// Not going to trim trailing dots here since RFC 3546 says,
// "The hostname is represented ... without a trailing dot."
// Just normalize to lowercase.
name = strings.ToLower(name)
cfg.certCache.RLock()
defer cfg.certCache.RUnlock()
// exact match? great, let's use it
if certKey, ok = cfg.Certificates[name]; ok {
cert = cfg.certCache.cache[certKey]
matched = true
return
}
// try replacing labels in the name with wildcards until we get a match
labels := strings.Split(name, ".")
for i := range labels {
labels[i] = "*"
candidate := strings.Join(labels, ".")
if certKey, ok = cfg.Certificates[candidate]; ok {
cert = cfg.certCache.cache[certKey]
matched = true
return
}
}
// check the certCache directly to see if the SNI name is
// already the key of the certificate it wants! this is vital
// for supporting the TLS-SNI challenge, since the tlsSNISolver
// just puts the temporary certificate in the instance cache,
// with no regard for configs; this also means that the SNI
// can contain the hash of a specific cert (chain) it wants
// and we will still be able to serve it up
// (this behavior, by the way, could be controversial as to
// whether it complies with RFC 6066 about SNI, but I think
// it does soooo...)
// NOTE/TODO: TLS-SNI challenge is changing, as of Jan. 2018
// but what will be different, if it ever returns, is unclear
if directCert, ok := cfg.certCache.cache[name]; ok {
cert = directCert
matched = true
return
}
// if nothing matches and SNI was not provided, use a random
// certificate; at least there's a chance this older client
// can connect, and in the future we won't need this provision
// (if SNI is present, it's probably best to just raise a TLS
// alert by not serving a certificate)
if name == "" {
for _, certKey := range cfg.Certificates {
defaulted = true
cert = cfg.certCache.cache[certKey]
return
}
}
return
}
// getCertDuringHandshake will get a certificate for name. It first tries // getCertDuringHandshake will get a certificate for name. It first tries
// the in-memory cache. If no certificate for name is in the cache, the // the in-memory cache. If no certificate for name is in the cache, the
// config most closely corresponding to name will be loaded. If that config // config most closely corresponding to name will be loaded. If that config
...@@ -134,7 +214,7 @@ func (cfg *Config) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certif ...@@ -134,7 +214,7 @@ func (cfg *Config) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certif
// This function is safe for concurrent use. // This function is safe for concurrent use.
func (cfg *Config) getCertDuringHandshake(name string, loadIfNecessary, obtainIfNecessary bool) (Certificate, error) { func (cfg *Config) getCertDuringHandshake(name string, loadIfNecessary, obtainIfNecessary bool) (Certificate, error) {
// First check our in-memory cache to see if we've already loaded it // First check our in-memory cache to see if we've already loaded it
cert, matched, defaulted := getCertificate(name) cert, matched, defaulted := cfg.getCertificate(name)
if matched { if matched {
return cert, nil return cert, nil
} }
...@@ -277,7 +357,7 @@ func (cfg *Config) obtainOnDemandCertificate(name string) (Certificate, error) { ...@@ -277,7 +357,7 @@ func (cfg *Config) obtainOnDemandCertificate(name string) (Certificate, error) {
obtainCertWaitChans[name] = wait obtainCertWaitChans[name] = wait
obtainCertWaitChansMu.Unlock() obtainCertWaitChansMu.Unlock()
// do the obtain // obtain the certificate
log.Printf("[INFO] Obtaining new certificate for %s", name) log.Printf("[INFO] Obtaining new certificate for %s", name)
err := cfg.ObtainCert(name, false) err := cfg.ObtainCert(name, false)
...@@ -336,9 +416,9 @@ func (cfg *Config) handshakeMaintenance(name string, cert Certificate) (Certific ...@@ -336,9 +416,9 @@ func (cfg *Config) handshakeMaintenance(name string, cert Certificate) (Certific
// quite common considering not all certs have issuer URLs that support it. // quite common considering not all certs have issuer URLs that support it.
log.Printf("[ERROR] Getting OCSP for %s: %v", name, err) log.Printf("[ERROR] Getting OCSP for %s: %v", name, err)
} }
certCacheMu.Lock() cfg.certCache.Lock()
certCache[name] = cert cfg.certCache.cache[cert.Hash] = cert
certCacheMu.Unlock() cfg.certCache.Unlock()
} }
} }
...@@ -367,29 +447,22 @@ func (cfg *Config) renewDynamicCertificate(name string, currentCert Certificate) ...@@ -367,29 +447,22 @@ func (cfg *Config) renewDynamicCertificate(name string, currentCert Certificate)
obtainCertWaitChans[name] = wait obtainCertWaitChans[name] = wait
obtainCertWaitChansMu.Unlock() obtainCertWaitChansMu.Unlock()
// do the renew and reload the certificate // renew and reload the certificate
log.Printf("[INFO] Renewing certificate for %s", name) log.Printf("[INFO] Renewing certificate for %s", name)
err := cfg.RenewCert(name, false) err := cfg.RenewCert(name, false)
if err == nil { if err == nil {
// immediately flush this certificate from the cache so
// the name doesn't overlap when we try to replace it,
// which would fail, because overlapping existing cert
// names isn't allowed
certCacheMu.Lock()
for _, certName := range currentCert.Names {
delete(certCache, certName)
}
certCacheMu.Unlock()
// even though the recursive nature of the dynamic cert loading // even though the recursive nature of the dynamic cert loading
// would just call this function anyway, we do it here to // would just call this function anyway, we do it here to
// make the replacement as atomic as possible. (TODO: similar // make the replacement as atomic as possible.
// to the note in maintain.go, it'd be nice if the clearing of newCert, err := currentCert.configs[0].CacheManagedCertificate(name)
// the cache entries above and this load function were truly
// atomic...)
_, err := currentCert.Config.CacheManagedCertificate(name)
if err != nil { if err != nil {
log.Printf("[ERROR] loading renewed certificate: %v", err) log.Printf("[ERROR] loading renewed certificate for %s: %v", name, err)
} else {
// replace the old certificate with the new one
err = cfg.certCache.replaceCertificate(currentCert, newCert)
if err != nil {
log.Printf("[ERROR] Replacing certificate for %s: %v", name, err)
}
} }
} }
......
...@@ -21,9 +21,8 @@ import ( ...@@ -21,9 +21,8 @@ import (
) )
func TestGetCertificate(t *testing.T) { func TestGetCertificate(t *testing.T) {
defer func() { certCache = make(map[string]Certificate) }() certCache := &certificateCache{cache: make(map[string]Certificate)}
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
cfg := new(Config)
hello := &tls.ClientHelloInfo{ServerName: "example.com"} hello := &tls.ClientHelloInfo{ServerName: "example.com"}
helloSub := &tls.ClientHelloInfo{ServerName: "sub.example.com"} helloSub := &tls.ClientHelloInfo{ServerName: "sub.example.com"}
...@@ -38,33 +37,40 @@ func TestGetCertificate(t *testing.T) { ...@@ -38,33 +37,40 @@ func TestGetCertificate(t *testing.T) {
t.Errorf("GetCertificate should return error when cache is empty even if server name is blank, got: %v", cert) t.Errorf("GetCertificate should return error when cache is empty even if server name is blank, got: %v", cert)
} }
// When cache has one certificate in it (also is default) // When cache has one certificate in it
defaultCert := Certificate{Names: []string{"example.com", ""}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"example.com"}}}} firstCert := Certificate{Names: []string{"example.com"}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"example.com"}}}}
certCache[""] = defaultCert cfg.cacheCertificate(firstCert)
certCache["example.com"] = defaultCert
if cert, err := cfg.GetCertificate(hello); err != nil { if cert, err := cfg.GetCertificate(hello); err != nil {
t.Errorf("Got an error but shouldn't have, when cert exists in cache: %v", err) t.Errorf("Got an error but shouldn't have, when cert exists in cache: %v", err)
} else if cert.Leaf.DNSNames[0] != "example.com" { } else if cert.Leaf.DNSNames[0] != "example.com" {
t.Errorf("Got wrong certificate with exact match; expected 'example.com', got: %v", cert) t.Errorf("Got wrong certificate with exact match; expected 'example.com', got: %v", cert)
} }
if cert, err := cfg.GetCertificate(helloNoSNI); err != nil { if _, err := cfg.GetCertificate(helloNoSNI); err != nil {
t.Errorf("Got an error with no SNI but shouldn't have, when cert exists in cache: %v", err) t.Errorf("Got an error with no SNI but shouldn't have, when cert exists in cache: %v", err)
} else if cert.Leaf.DNSNames[0] != "example.com" {
t.Errorf("Got wrong certificate for no SNI; expected 'example.com' as default, got: %v", cert)
} }
// When retrieving wildcard certificate // When retrieving wildcard certificate
certCache["*.example.com"] = Certificate{Names: []string{"*.example.com"}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"*.example.com"}}}} wildcardCert := Certificate{
Names: []string{"*.example.com"},
Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"*.example.com"}}},
Hash: "(don't overwrite the first one)",
}
cfg.cacheCertificate(wildcardCert)
if cert, err := cfg.GetCertificate(helloSub); err != nil { if cert, err := cfg.GetCertificate(helloSub); err != nil {
t.Errorf("Didn't get wildcard cert, got: cert=%v, err=%v ", cert, err) t.Errorf("Didn't get wildcard cert, got: cert=%v, err=%v ", cert, err)
} else if cert.Leaf.DNSNames[0] != "*.example.com" { } else if cert.Leaf.DNSNames[0] != "*.example.com" {
t.Errorf("Got wrong certificate, expected wildcard: %v", cert) t.Errorf("Got wrong certificate, expected wildcard: %v", cert)
} }
// When no certificate matches, the default is returned // When cache is NOT empty but there's no SNI
if cert, err := cfg.GetCertificate(helloNoMatch); err != nil { if cert, err := cfg.GetCertificate(helloNoSNI); err != nil {
t.Errorf("Expected default certificate with no error when no matches, got err: %v", err) t.Errorf("Expected random certificate with no error when no SNI, got err: %v", err)
} else if cert.Leaf.DNSNames[0] != "example.com" { } else if cert == nil || len(cert.Leaf.DNSNames) == 0 {
t.Errorf("Expected default cert with no matches, got: %v", cert) t.Errorf("Expected random cert with no matches, got: %v", cert)
}
// When no certificate matches, raise an alert
if _, err := cfg.GetCertificate(helloNoMatch); err == nil {
t.Errorf("Expected an error when no certificate matched the SNI, got: %v", err)
} }
} }
...@@ -27,10 +27,11 @@ import ( ...@@ -27,10 +27,11 @@ import (
const challengeBasePath = "/.well-known/acme-challenge" const challengeBasePath = "/.well-known/acme-challenge"
// HTTPChallengeHandler proxies challenge requests to ACME client if the // HTTPChallengeHandler proxies challenge requests to ACME client if the
// request path starts with challengeBasePath. It returns true if it // request path starts with challengeBasePath, if the HTTP challenge is not
// handled the request and no more needs to be done; it returns false // disabled, and if we are known to be obtaining a certificate for the name.
// if this call was a no-op and the request still needs handling. // It returns true if it handled the request and no more needs to be done;
func HTTPChallengeHandler(w http.ResponseWriter, r *http.Request, listenHost, altPort string) bool { // it returns false if this call was a no-op and the request still needs handling.
func HTTPChallengeHandler(w http.ResponseWriter, r *http.Request, listenHost string) bool {
if !strings.HasPrefix(r.URL.Path, challengeBasePath) { if !strings.HasPrefix(r.URL.Path, challengeBasePath) {
return false return false
} }
...@@ -50,7 +51,11 @@ func HTTPChallengeHandler(w http.ResponseWriter, r *http.Request, listenHost, al ...@@ -50,7 +51,11 @@ func HTTPChallengeHandler(w http.ResponseWriter, r *http.Request, listenHost, al
listenHost = "localhost" listenHost = "localhost"
} }
upstream, err := url.Parse(fmt.Sprintf("%s://%s:%s", scheme, listenHost, altPort)) // always proxy to the DefaultHTTPAlternatePort because obviously the
// ACME challenge request already got into one of our HTTP handlers, so
// it means we must have started a HTTP listener on the alternate
// port instead; which is only accessible via listenHost
upstream, err := url.Parse(fmt.Sprintf("%s://%s:%s", scheme, listenHost, DefaultHTTPAlternatePort))
if err != nil { if err != nil {
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
log.Printf("[ERROR] ACME proxy handler: %v", err) log.Printf("[ERROR] ACME proxy handler: %v", err)
......
...@@ -39,7 +39,7 @@ func TestHTTPChallengeHandlerNoOp(t *testing.T) { ...@@ -39,7 +39,7 @@ func TestHTTPChallengeHandlerNoOp(t *testing.T) {
t.Fatalf("Could not craft request, got error: %v", err) t.Fatalf("Could not craft request, got error: %v", err)
} }
rw := httptest.NewRecorder() rw := httptest.NewRecorder()
if HTTPChallengeHandler(rw, req, "", DefaultHTTPAlternatePort) { if HTTPChallengeHandler(rw, req, "") {
t.Errorf("Got true with this URL, but shouldn't have: %s", url) t.Errorf("Got true with this URL, but shouldn't have: %s", url)
} }
} }
...@@ -76,7 +76,7 @@ func TestHTTPChallengeHandlerSuccess(t *testing.T) { ...@@ -76,7 +76,7 @@ func TestHTTPChallengeHandlerSuccess(t *testing.T) {
} }
rw := httptest.NewRecorder() rw := httptest.NewRecorder()
HTTPChallengeHandler(rw, req, "", DefaultHTTPAlternatePort) HTTPChallengeHandler(rw, req, "")
if !proxySuccess { if !proxySuccess {
t.Fatal("Expected request to be proxied, but it wasn't") t.Fatal("Expected request to be proxied, but it wasn't")
......
This diff is collapsed.
...@@ -38,6 +38,7 @@ func init() { ...@@ -38,6 +38,7 @@ func init() {
// are specified by the user in the config file. All the automatic HTTPS // are specified by the user in the config file. All the automatic HTTPS
// stuff comes later outside of this function. // stuff comes later outside of this function.
func setupTLS(c *caddy.Controller) error { func setupTLS(c *caddy.Controller) error {
// obtain the configGetter, which loads the config we're, uh, configuring
configGetter, ok := configGetters[c.ServerType()] configGetter, ok := configGetters[c.ServerType()]
if !ok { if !ok {
return fmt.Errorf("no caddytls.ConfigGetter for %s server type; must call RegisterConfigGetter", c.ServerType()) return fmt.Errorf("no caddytls.ConfigGetter for %s server type; must call RegisterConfigGetter", c.ServerType())
...@@ -47,6 +48,14 @@ func setupTLS(c *caddy.Controller) error { ...@@ -47,6 +48,14 @@ func setupTLS(c *caddy.Controller) error {
return fmt.Errorf("no caddytls.Config to set up for %s", c.Key) return fmt.Errorf("no caddytls.Config to set up for %s", c.Key)
} }
// the certificate cache is tied to the current caddy.Instance; get a pointer to it
certCache, ok := c.Get(CertCacheInstStorageKey).(*certificateCache)
if !ok || certCache == nil {
certCache = &certificateCache{cache: make(map[string]Certificate)}
c.Set(CertCacheInstStorageKey, certCache)
}
config.certCache = certCache
config.Enabled = true config.Enabled = true
for c.Next() { for c.Next() {
...@@ -237,7 +246,7 @@ func setupTLS(c *caddy.Controller) error { ...@@ -237,7 +246,7 @@ func setupTLS(c *caddy.Controller) error {
// load a single certificate and key, if specified // load a single certificate and key, if specified
if certificateFile != "" && keyFile != "" { if certificateFile != "" && keyFile != "" {
err := cacheUnmanagedCertificatePEMFile(certificateFile, keyFile) err := config.cacheUnmanagedCertificatePEMFile(certificateFile, keyFile)
if err != nil { if err != nil {
return c.Errf("Unable to load certificate and key files for '%s': %v", c.Key, err) return c.Errf("Unable to load certificate and key files for '%s': %v", c.Key, err)
} }
...@@ -246,7 +255,7 @@ func setupTLS(c *caddy.Controller) error { ...@@ -246,7 +255,7 @@ func setupTLS(c *caddy.Controller) error {
// load a directory of certificates, if specified // load a directory of certificates, if specified
if loadDir != "" { if loadDir != "" {
err := loadCertsInDir(c, loadDir) err := loadCertsInDir(config, c, loadDir)
if err != nil { if err != nil {
return err return err
} }
...@@ -273,7 +282,7 @@ func setupTLS(c *caddy.Controller) error { ...@@ -273,7 +282,7 @@ func setupTLS(c *caddy.Controller) error {
// https://cbonte.github.io/haproxy-dconv/configuration-1.5.html#5.1-crt // https://cbonte.github.io/haproxy-dconv/configuration-1.5.html#5.1-crt
// //
// This function may write to the log as it walks the directory tree. // This function may write to the log as it walks the directory tree.
func loadCertsInDir(c *caddy.Controller, dir string) error { func loadCertsInDir(cfg *Config, c *caddy.Controller, dir string) error {
return filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { return filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
if err != nil { if err != nil {
log.Printf("[WARNING] Unable to traverse into %s; skipping", path) log.Printf("[WARNING] Unable to traverse into %s; skipping", path)
...@@ -336,7 +345,7 @@ func loadCertsInDir(c *caddy.Controller, dir string) error { ...@@ -336,7 +345,7 @@ func loadCertsInDir(c *caddy.Controller, dir string) error {
return c.Errf("%s: no private key block found", path) return c.Errf("%s: no private key block found", path)
} }
err = cacheUnmanagedCertificatePEMBytes(certPEMBytes, keyPEMBytes) err = cfg.cacheUnmanagedCertificatePEMBytes(certPEMBytes, keyPEMBytes)
if err != nil { if err != nil {
return c.Errf("%s: failed to load cert and key for '%s': %v", path, c.Key, err) return c.Errf("%s: failed to load cert and key for '%s': %v", path, c.Key, err)
} }
......
...@@ -46,9 +46,12 @@ func TestMain(m *testing.M) { ...@@ -46,9 +46,12 @@ func TestMain(m *testing.M) {
} }
func TestSetupParseBasic(t *testing.T) { func TestSetupParseBasic(t *testing.T) {
cfg := new(Config) certCache := &certificateCache{cache: make(map[string]Certificate)}
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg }) RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
c := caddy.NewTestController("", `tls `+certFile+` `+keyFile+``) c := caddy.NewTestController("", `tls `+certFile+` `+keyFile+``)
c.Set(CertCacheInstStorageKey, certCache)
err := setupTLS(c) err := setupTLS(c)
if err != nil { if err != nil {
...@@ -124,9 +127,12 @@ func TestSetupParseWithOptionalParams(t *testing.T) { ...@@ -124,9 +127,12 @@ func TestSetupParseWithOptionalParams(t *testing.T) {
must_staple must_staple
alpn http/1.1 alpn http/1.1
}` }`
cfg := new(Config) certCache := &certificateCache{cache: make(map[string]Certificate)}
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg }) RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
c := caddy.NewTestController("", params) c := caddy.NewTestController("", params)
c.Set(CertCacheInstStorageKey, certCache)
err := setupTLS(c) err := setupTLS(c)
if err != nil { if err != nil {
...@@ -158,9 +164,11 @@ func TestSetupDefaultWithOptionalParams(t *testing.T) { ...@@ -158,9 +164,11 @@ func TestSetupDefaultWithOptionalParams(t *testing.T) {
params := `tls { params := `tls {
ciphers RSA-3DES-EDE-CBC-SHA ciphers RSA-3DES-EDE-CBC-SHA
}` }`
cfg := new(Config) certCache := &certificateCache{cache: make(map[string]Certificate)}
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg }) RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
c := caddy.NewTestController("", params) c := caddy.NewTestController("", params)
c.Set(CertCacheInstStorageKey, certCache)
err := setupTLS(c) err := setupTLS(c)
if err != nil { if err != nil {
...@@ -176,9 +184,12 @@ func TestSetupParseWithWrongOptionalParams(t *testing.T) { ...@@ -176,9 +184,12 @@ func TestSetupParseWithWrongOptionalParams(t *testing.T) {
params := `tls ` + certFile + ` ` + keyFile + ` { params := `tls ` + certFile + ` ` + keyFile + ` {
protocols ssl tls protocols ssl tls
}` }`
cfg := new(Config) certCache := &certificateCache{cache: make(map[string]Certificate)}
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg }) RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
c := caddy.NewTestController("", params) c := caddy.NewTestController("", params)
c.Set(CertCacheInstStorageKey, certCache)
err := setupTLS(c) err := setupTLS(c)
if err == nil { if err == nil {
t.Errorf("Expected errors, but no error returned") t.Errorf("Expected errors, but no error returned")
...@@ -191,6 +202,7 @@ func TestSetupParseWithWrongOptionalParams(t *testing.T) { ...@@ -191,6 +202,7 @@ func TestSetupParseWithWrongOptionalParams(t *testing.T) {
cfg = new(Config) cfg = new(Config)
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg }) RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
c = caddy.NewTestController("", params) c = caddy.NewTestController("", params)
c.Set(CertCacheInstStorageKey, certCache)
err = setupTLS(c) err = setupTLS(c)
if err == nil { if err == nil {
t.Error("Expected errors, but no error returned") t.Error("Expected errors, but no error returned")
...@@ -215,6 +227,7 @@ func TestSetupParseWithWrongOptionalParams(t *testing.T) { ...@@ -215,6 +227,7 @@ func TestSetupParseWithWrongOptionalParams(t *testing.T) {
cfg = new(Config) cfg = new(Config)
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg }) RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
c = caddy.NewTestController("", params) c = caddy.NewTestController("", params)
c.Set(CertCacheInstStorageKey, certCache)
err = setupTLS(c) err = setupTLS(c)
if err == nil { if err == nil {
t.Error("Expected errors, but no error returned") t.Error("Expected errors, but no error returned")
...@@ -226,7 +239,8 @@ func TestSetupParseWithClientAuth(t *testing.T) { ...@@ -226,7 +239,8 @@ func TestSetupParseWithClientAuth(t *testing.T) {
params := `tls ` + certFile + ` ` + keyFile + ` { params := `tls ` + certFile + ` ` + keyFile + ` {
clients clients
}` }`
cfg := new(Config) certCache := &certificateCache{cache: make(map[string]Certificate)}
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg }) RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
c := caddy.NewTestController("", params) c := caddy.NewTestController("", params)
err := setupTLS(c) err := setupTLS(c)
...@@ -259,9 +273,11 @@ func TestSetupParseWithClientAuth(t *testing.T) { ...@@ -259,9 +273,11 @@ func TestSetupParseWithClientAuth(t *testing.T) {
clients verify_if_given clients verify_if_given
}`, tls.VerifyClientCertIfGiven, true, noCAs}, }`, tls.VerifyClientCertIfGiven, true, noCAs},
} { } {
cfg := new(Config) certCache := &certificateCache{cache: make(map[string]Certificate)}
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg }) RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
c := caddy.NewTestController("", caseData.params) c := caddy.NewTestController("", caseData.params)
c.Set(CertCacheInstStorageKey, certCache)
err := setupTLS(c) err := setupTLS(c)
if caseData.expectedErr { if caseData.expectedErr {
if err == nil { if err == nil {
...@@ -311,9 +327,11 @@ func TestSetupParseWithCAUrl(t *testing.T) { ...@@ -311,9 +327,11 @@ func TestSetupParseWithCAUrl(t *testing.T) {
ca 1 2 ca 1 2
}`, true, ""}, }`, true, ""},
} { } {
cfg := new(Config) certCache := &certificateCache{cache: make(map[string]Certificate)}
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg }) RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
c := caddy.NewTestController("", caseData.params) c := caddy.NewTestController("", caseData.params)
c.Set(CertCacheInstStorageKey, certCache)
err := setupTLS(c) err := setupTLS(c)
if caseData.expectedErr { if caseData.expectedErr {
if err == nil { if err == nil {
...@@ -335,9 +353,11 @@ func TestSetupParseWithKeyType(t *testing.T) { ...@@ -335,9 +353,11 @@ func TestSetupParseWithKeyType(t *testing.T) {
params := `tls { params := `tls {
key_type p384 key_type p384
}` }`
cfg := new(Config) certCache := &certificateCache{cache: make(map[string]Certificate)}
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg }) RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
c := caddy.NewTestController("", params) c := caddy.NewTestController("", params)
c.Set(CertCacheInstStorageKey, certCache)
err := setupTLS(c) err := setupTLS(c)
if err != nil { if err != nil {
...@@ -353,9 +373,11 @@ func TestSetupParseWithCurves(t *testing.T) { ...@@ -353,9 +373,11 @@ func TestSetupParseWithCurves(t *testing.T) {
params := `tls { params := `tls {
curves x25519 p256 p384 p521 curves x25519 p256 p384 p521
}` }`
cfg := new(Config) certCache := &certificateCache{cache: make(map[string]Certificate)}
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg }) RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
c := caddy.NewTestController("", params) c := caddy.NewTestController("", params)
c.Set(CertCacheInstStorageKey, certCache)
err := setupTLS(c) err := setupTLS(c)
if err != nil { if err != nil {
...@@ -380,9 +402,11 @@ func TestSetupParseWithOneTLSProtocol(t *testing.T) { ...@@ -380,9 +402,11 @@ func TestSetupParseWithOneTLSProtocol(t *testing.T) {
params := `tls { params := `tls {
protocols tls1.2 protocols tls1.2
}` }`
cfg := new(Config) certCache := &certificateCache{cache: make(map[string]Certificate)}
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg }) RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
c := caddy.NewTestController("", params) c := caddy.NewTestController("", params)
c.Set(CertCacheInstStorageKey, certCache)
err := setupTLS(c) err := setupTLS(c)
if err != nil { if err != nil {
......
...@@ -107,6 +107,10 @@ type Storage interface { ...@@ -107,6 +107,10 @@ type Storage interface {
// in StoreUser. The result is an empty string if there are no // in StoreUser. The result is an empty string if there are no
// persisted users in storage. // persisted users in storage.
MostRecentUserEmail() string MostRecentUserEmail() string
// Locker is necessary because synchronizing certificate maintenance
// depends on how storage is implemented.
Locker
} }
// ErrNotExist is returned by Storage implementations when // ErrNotExist is returned by Storage implementations when
......
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package caddytls
import (
"fmt"
"sync"
)
var _ Locker = &syncLock{}
type syncLock struct {
nameLocks map[string]*sync.WaitGroup
nameLocksMu sync.Mutex
}
// TryLock attempts to get a lock for name, otherwise it returns
// a Waiter value to wait until the other process is finished.
func (s *syncLock) TryLock(name string) (Waiter, error) {
s.nameLocksMu.Lock()
defer s.nameLocksMu.Unlock()
wg, ok := s.nameLocks[name]
if ok {
// lock already obtained, let caller wait on it
return wg, nil
}
// caller gets lock
wg = new(sync.WaitGroup)
wg.Add(1)
s.nameLocks[name] = wg
return nil, nil
}
// Unlock unlocks name.
func (s *syncLock) Unlock(name string) error {
s.nameLocksMu.Lock()
defer s.nameLocksMu.Unlock()
wg, ok := s.nameLocks[name]
if !ok {
return fmt.Errorf("FileStorage: no lock to release for %s", name)
}
wg.Done()
delete(s.nameLocks, name)
return nil
}
...@@ -88,30 +88,38 @@ func Revoke(host string) error { ...@@ -88,30 +88,38 @@ func Revoke(host string) error {
return client.Revoke(host) return client.Revoke(host)
} }
// tlsSniSolver is a type that can solve tls-sni challenges using // tlsSNISolver is a type that can solve TLS-SNI challenges using
// an existing listener and our custom, in-memory certificate cache. // an existing listener and our custom, in-memory certificate cache.
type tlsSniSolver struct{} type tlsSNISolver struct {
certCache *certificateCache
}
// Present adds the challenge certificate to the cache. // Present adds the challenge certificate to the cache.
func (s tlsSniSolver) Present(domain, token, keyAuth string) error { func (s tlsSNISolver) Present(domain, token, keyAuth string) error {
cert, acmeDomain, err := acme.TLSSNI01ChallengeCert(keyAuth) cert, acmeDomain, err := acme.TLSSNI01ChallengeCert(keyAuth)
if err != nil { if err != nil {
return err return err
} }
cacheCertificate(Certificate{ certHash := hashCertificateChain(cert.Certificate)
s.certCache.Lock()
s.certCache.cache[acmeDomain] = Certificate{
Certificate: cert, Certificate: cert,
Names: []string{acmeDomain}, Names: []string{acmeDomain},
}) Hash: certHash, // perhaps not necesssary
}
s.certCache.Unlock()
return nil return nil
} }
// CleanUp removes the challenge certificate from the cache. // CleanUp removes the challenge certificate from the cache.
func (s tlsSniSolver) CleanUp(domain, token, keyAuth string) error { func (s tlsSNISolver) CleanUp(domain, token, keyAuth string) error {
_, acmeDomain, err := acme.TLSSNI01ChallengeCert(keyAuth) _, acmeDomain, err := acme.TLSSNI01ChallengeCert(keyAuth)
if err != nil { if err != nil {
return err return err
} }
uncacheCertificate(acmeDomain) s.certCache.Lock()
delete(s.certCache.cache, acmeDomain)
s.certCache.Unlock()
return nil return nil
} }
......
...@@ -103,6 +103,20 @@ func (c *Controller) Context() Context { ...@@ -103,6 +103,20 @@ func (c *Controller) Context() Context {
return c.instance.context return c.instance.context
} }
// Get safely gets a value from the Instance's storage.
func (c *Controller) Get(key interface{}) interface{} {
c.instance.StorageMu.RLock()
defer c.instance.StorageMu.RUnlock()
return c.instance.Storage[key]
}
// Set safely sets a value on the Instance's storage.
func (c *Controller) Set(key, val interface{}) {
c.instance.StorageMu.Lock()
c.instance.Storage[key] = val
c.instance.StorageMu.Unlock()
}
// NewTestController creates a new Controller for // NewTestController creates a new Controller for
// the server type and input specified. The filename // the server type and input specified. The filename
// is "Testfile". If the server type is not empty and // is "Testfile". If the server type is not empty and
...@@ -113,12 +127,12 @@ func (c *Controller) Context() Context { ...@@ -113,12 +127,12 @@ func (c *Controller) Context() Context {
// Used only for testing, but exported so plugins can // Used only for testing, but exported so plugins can
// use this for convenience. // use this for convenience.
func NewTestController(serverType, input string) *Controller { func NewTestController(serverType, input string) *Controller {
var ctx Context testInst := &Instance{serverType: serverType, Storage: make(map[interface{}]interface{})}
if stype, err := getServerType(serverType); err == nil { if stype, err := getServerType(serverType); err == nil {
ctx = stype.NewContext() testInst.context = stype.NewContext(testInst)
} }
return &Controller{ return &Controller{
instance: &Instance{serverType: serverType, context: ctx}, instance: testInst,
Dispenser: caddyfile.NewDispenser("Testfile", strings.NewReader(input)), Dispenser: caddyfile.NewDispenser("Testfile", strings.NewReader(input)),
OncePerServerBlock: func(f func() error) error { return f() }, OncePerServerBlock: func(f func() error) error { return f() },
} }
......
...@@ -91,6 +91,7 @@ Install the systemd service unit configuration file, reload the systemd daemon, ...@@ -91,6 +91,7 @@ Install the systemd service unit configuration file, reload the systemd daemon,
and start caddy: and start caddy:
```bash ```bash
wget https://raw.githubusercontent.com/mholt/caddy/master/dist/init/linux-systemd/caddy.service
sudo cp caddy.service /etc/systemd/system/ sudo cp caddy.service /etc/systemd/system/
sudo chown root:root /etc/systemd/system/caddy.service sudo chown root:root /etc/systemd/system/caddy.service
sudo chmod 644 /etc/systemd/system/caddy.service sudo chmod 644 /etc/systemd/system/caddy.service
......
...@@ -30,8 +30,8 @@ LimitNPROC=512 ...@@ -30,8 +30,8 @@ LimitNPROC=512
; Use private /tmp and /var/tmp, which are discarded after caddy stops. ; Use private /tmp and /var/tmp, which are discarded after caddy stops.
PrivateTmp=true PrivateTmp=true
; Use a minimal /dev ; Use a minimal /dev (May bring additional security if switched to 'true', but it may not work on Raspberry Pi's or other devices, so it has been disabled in this dist.)
PrivateDevices=true PrivateDevices=false
; Hide /home, /root, and /run/user. Nobody will steal your SSH-keys. ; Hide /home, /root, and /run/user. Nobody will steal your SSH-keys.
ProtectHome=true ProtectHome=true
; Make /usr, /boot, /etc and possibly some more folders read-only. ; Make /usr, /boot, /etc and possibly some more folders read-only.
......
...@@ -19,6 +19,7 @@ import ( ...@@ -19,6 +19,7 @@ import (
"log" "log"
"net" "net"
"sort" "sort"
"sync"
"github.com/mholt/caddy/caddyfile" "github.com/mholt/caddy/caddyfile"
) )
...@@ -38,7 +39,7 @@ var ( ...@@ -38,7 +39,7 @@ var (
// eventHooks is a map of hook name to Hook. All hooks plugins // eventHooks is a map of hook name to Hook. All hooks plugins
// must have a name. // must have a name.
eventHooks = make(map[string]EventHook) eventHooks = sync.Map{}
// parsingCallbacks maps server type to map of directive // parsingCallbacks maps server type to map of directive
// to list of callback functions. These aren't really // to list of callback functions. These aren't really
...@@ -98,11 +99,15 @@ func ListPlugins() map[string][]string { ...@@ -98,11 +99,15 @@ func ListPlugins() map[string][]string {
p["caddyfile_loaders"] = append(p["caddyfile_loaders"], defaultCaddyfileLoader.name) p["caddyfile_loaders"] = append(p["caddyfile_loaders"], defaultCaddyfileLoader.name)
} }
// event hook plugins // List the event hook plugins
if len(eventHooks) > 0 { hooks := ""
for name := range eventHooks { eventHooks.Range(func(k, _ interface{}) bool {
p["event_hooks"] = append(p["event_hooks"], name) hooks += " hook." + k.(string) + "\n"
} return true
})
if hooks != "" {
str += "\nEvent hook plugins:\n"
str += hooks
} }
// alphabetize the rest of the plugins // alphabetize the rest of the plugins
...@@ -220,7 +225,7 @@ type ServerType struct { ...@@ -220,7 +225,7 @@ type ServerType struct {
// startup phases before this one. It's a way to keep // startup phases before this one. It's a way to keep
// each set of server instances separate and to reduce // each set of server instances separate and to reduce
// the amount of global state you need. // the amount of global state you need.
NewContext func() Context NewContext func(inst *Instance) Context
} }
// Plugin is a type which holds information about a plugin. // Plugin is a type which holds information about a plugin.
...@@ -277,23 +282,23 @@ func RegisterEventHook(name string, hook EventHook) { ...@@ -277,23 +282,23 @@ func RegisterEventHook(name string, hook EventHook) {
if name == "" { if name == "" {
panic("event hook must have a name") panic("event hook must have a name")
} }
if _, dup := eventHooks[name]; dup { _, dup := eventHooks.LoadOrStore(name, hook)
if dup {
panic("hook named " + name + " already registered") panic("hook named " + name + " already registered")
} }
eventHooks[name] = hook
} }
// EmitEvent executes the different hooks passing the EventType as an // EmitEvent executes the different hooks passing the EventType as an
// argument. This is a blocking function. Hook developers should // argument. This is a blocking function. Hook developers should
// use 'go' keyword if they don't want to block Caddy. // use 'go' keyword if they don't want to block Caddy.
func EmitEvent(event EventName, info interface{}) { func EmitEvent(event EventName, info interface{}) {
for name, hook := range eventHooks { eventHooks.Range(func(k, v interface{}) bool {
err := hook(event, info) err := v.(EventHook)(event, info)
if err != nil { if err != nil {
log.Printf("error on '%s' hook: %v", name, err) log.Printf("error on '%s' hook: %v", k.(string), err)
} }
} return true
})
} }
// ParsingCallback is a function that is called after // ParsingCallback is a function that is called after
...@@ -412,6 +417,14 @@ func loadCaddyfileInput(serverType string) (Input, error) { ...@@ -412,6 +417,14 @@ func loadCaddyfileInput(serverType string) (Input, error) {
return caddyfileToUse, nil return caddyfileToUse, nil
} }
// OnProcessExit is a list of functions to run when the process
// exits -- they are ONLY for cleanup and should not block,
// return errors, or do anything fancy. They will be run with
// every signal, even if "shutdown callbacks" are not executed.
// This variable must only be modified in the main goroutine
// from init() functions.
var OnProcessExit []func()
// caddyfileLoader pairs the name of a loader to the loader. // caddyfileLoader pairs the name of a loader to the loader.
type caddyfileLoader struct { type caddyfileLoader struct {
name string name string
......
...@@ -44,16 +44,17 @@ func trapSignalsCrossPlatform() { ...@@ -44,16 +44,17 @@ func trapSignalsCrossPlatform() {
if i > 0 { if i > 0 {
log.Println("[INFO] SIGINT: Force quit") log.Println("[INFO] SIGINT: Force quit")
if PidFile != "" { for _, f := range OnProcessExit {
os.Remove(PidFile) f() // important cleanup actions only
} }
os.Exit(2) os.Exit(2)
} }
log.Println("[INFO] SIGINT: Shutting down") log.Println("[INFO] SIGINT: Shutting down")
if PidFile != "" { // important cleanup actions before shutdown callbacks
os.Remove(PidFile) for _, f := range OnProcessExit {
f()
} }
go func() { go func() {
......
...@@ -33,22 +33,22 @@ func trapSignalsPosix() { ...@@ -33,22 +33,22 @@ func trapSignalsPosix() {
switch sig { switch sig {
case syscall.SIGQUIT: case syscall.SIGQUIT:
log.Println("[INFO] SIGQUIT: Quitting process immediately") log.Println("[INFO] SIGQUIT: Quitting process immediately")
if PidFile != "" { for _, f := range OnProcessExit {
os.Remove(PidFile) f() // only perform important cleanup actions
} }
os.Exit(0) os.Exit(0)
case syscall.SIGTERM: case syscall.SIGTERM:
log.Println("[INFO] SIGTERM: Shutting down servers then terminating") log.Println("[INFO] SIGTERM: Shutting down servers then terminating")
exitCode := executeShutdownCallbacks("SIGTERM") exitCode := executeShutdownCallbacks("SIGTERM")
for _, f := range OnProcessExit {
f() // only perform important cleanup actions
}
err := Stop() err := Stop()
if err != nil { if err != nil {
log.Printf("[ERROR] SIGTERM stop: %v", err) log.Printf("[ERROR] SIGTERM stop: %v", err)
exitCode = 3 exitCode = 3
} }
if PidFile != "" {
os.Remove(PidFile)
}
os.Exit(exitCode) os.Exit(exitCode)
case syscall.SIGUSR1: case syscall.SIGUSR1:
......
The MIT License (MIT)
Copyright (c) 2016 Richard Barnes
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mint
import "strconv"
type Alert uint8
const (
// alert level
AlertLevelWarning = 1
AlertLevelError = 2
)
const (
AlertCloseNotify Alert = 0
AlertUnexpectedMessage Alert = 10
AlertBadRecordMAC Alert = 20
AlertDecryptionFailed Alert = 21
AlertRecordOverflow Alert = 22
AlertDecompressionFailure Alert = 30
AlertHandshakeFailure Alert = 40
AlertBadCertificate Alert = 42
AlertUnsupportedCertificate Alert = 43
AlertCertificateRevoked Alert = 44
AlertCertificateExpired Alert = 45
AlertCertificateUnknown Alert = 46
AlertIllegalParameter Alert = 47
AlertUnknownCA Alert = 48
AlertAccessDenied Alert = 49
AlertDecodeError Alert = 50
AlertDecryptError Alert = 51
AlertProtocolVersion Alert = 70
AlertInsufficientSecurity Alert = 71
AlertInternalError Alert = 80
AlertInappropriateFallback Alert = 86
AlertUserCanceled Alert = 90
AlertNoRenegotiation Alert = 100
AlertMissingExtension Alert = 109
AlertUnsupportedExtension Alert = 110
AlertCertificateUnobtainable Alert = 111
AlertUnrecognizedName Alert = 112
AlertBadCertificateStatsResponse Alert = 113
AlertBadCertificateHashValue Alert = 114
AlertUnknownPSKIdentity Alert = 115
AlertNoApplicationProtocol Alert = 120
AlertWouldBlock Alert = 254
AlertNoAlert Alert = 255
)
var alertText = map[Alert]string{
AlertCloseNotify: "close notify",
AlertUnexpectedMessage: "unexpected message",
AlertBadRecordMAC: "bad record MAC",
AlertDecryptionFailed: "decryption failed",
AlertRecordOverflow: "record overflow",
AlertDecompressionFailure: "decompression failure",
AlertHandshakeFailure: "handshake failure",
AlertBadCertificate: "bad certificate",
AlertUnsupportedCertificate: "unsupported certificate",
AlertCertificateRevoked: "revoked certificate",
AlertCertificateExpired: "expired certificate",
AlertCertificateUnknown: "unknown certificate",
AlertIllegalParameter: "illegal parameter",
AlertUnknownCA: "unknown certificate authority",
AlertAccessDenied: "access denied",
AlertDecodeError: "error decoding message",
AlertDecryptError: "error decrypting message",
AlertProtocolVersion: "protocol version not supported",
AlertInsufficientSecurity: "insufficient security level",
AlertInternalError: "internal error",
AlertInappropriateFallback: "inappropriate fallback",
AlertUserCanceled: "user canceled",
AlertMissingExtension: "missing extension",
AlertUnsupportedExtension: "unsupported extension",
AlertCertificateUnobtainable: "certificate unobtainable",
AlertUnrecognizedName: "unrecognized name",
AlertBadCertificateStatsResponse: "bad certificate status response",
AlertBadCertificateHashValue: "bad certificate hash value",
AlertUnknownPSKIdentity: "unknown PSK identity",
AlertNoApplicationProtocol: "no application protocol",
AlertNoRenegotiation: "no renegotiation",
AlertWouldBlock: "would have blocked",
AlertNoAlert: "no alert",
}
func (e Alert) String() string {
s, ok := alertText[e]
if ok {
return s
}
return "alert(" + strconv.Itoa(int(e)) + ")"
}
func (e Alert) Error() string {
return e.String()
}
package main
import (
"flag"
"fmt"
"io/ioutil"
"net"
"net/http"
"os"
"github.com/bifurcation/mint"
)
var url string
func main() {
url := flag.String("url", "https://localhost:4430", "URL to send request")
flag.Parse()
mintdial := func(network, addr string) (net.Conn, error) {
return mint.Dial(network, addr, nil)
}
tr := &http.Transport{
DialTLS: mintdial,
DisableCompression: true,
}
client := &http.Client{Transport: tr}
response, err := client.Get(*url)
if err != nil {
fmt.Println("err:", err)
return
}
defer response.Body.Close()
contents, err := ioutil.ReadAll(response.Body)
if err != nil {
fmt.Printf("%s", err)
os.Exit(1)
}
fmt.Printf("%s\n", string(contents))
}
package main
import (
"flag"
"fmt"
"github.com/bifurcation/mint"
)
var addr string
func main() {
flag.StringVar(&addr, "addr", "localhost:4430", "port")
flag.Parse()
conn, err := mint.Dial("tcp", addr, nil)
if err != nil {
fmt.Println("TLS handshake failed:", err)
return
}
request := "GET / HTTP/1.0\r\n\r\n"
conn.Write([]byte(request))
response := ""
buffer := make([]byte, 1024)
var read int
for err == nil {
read, err = conn.Read(buffer)
fmt.Println(" ~~ read: ", read)
response += string(buffer)
}
fmt.Println("err:", err)
fmt.Println("Received from server:")
fmt.Println(response)
}
package main
import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"flag"
"fmt"
"io/ioutil"
"log"
"net/http"
"github.com/bifurcation/mint"
"golang.org/x/net/http2"
)
var (
port string
serverName string
certFile string
keyFile string
responseFile string
h2 bool
sendTickets bool
)
type responder []byte
func (rsp responder) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.Write(rsp)
}
// ParsePrivateKeyDER parses a PKCS #1, PKCS #8, or elliptic curve
// PEM-encoded private key.
// XXX: Inlined from github.com/cloudflare/cfssl because of build issues with that module
func ParsePrivateKeyPEM(keyPEM []byte) (key crypto.Signer, err error) {
keyDER, _ := pem.Decode(keyPEM)
if keyDER == nil {
return nil, err
}
generalKey, err := x509.ParsePKCS8PrivateKey(keyDER.Bytes)
if err != nil {
generalKey, err = x509.ParsePKCS1PrivateKey(keyDER.Bytes)
if err != nil {
generalKey, err = x509.ParseECPrivateKey(keyDER.Bytes)
if err != nil {
// We don't include the actual error into
// the final error. The reason might be
// we don't want to leak any info about
// the private key.
return nil, fmt.Errorf("No successful private key decoder")
}
}
}
switch generalKey.(type) {
case *rsa.PrivateKey:
return generalKey.(*rsa.PrivateKey), nil
case *ecdsa.PrivateKey:
return generalKey.(*ecdsa.PrivateKey), nil
}
// should never reach here
return nil, fmt.Errorf("Should be unreachable")
}
// ParseOneCertificateFromPEM attempts to parse one PEM encoded certificate object,
// either a raw x509 certificate or a PKCS #7 structure possibly containing
// multiple certificates, from the top of certsPEM, which itself may
// contain multiple PEM encoded certificate objects.
// XXX: Inlined from github.com/cloudflare/cfssl because of build issues with that module
func ParseOneCertificateFromPEM(certsPEM []byte) ([]*x509.Certificate, []byte, error) {
block, rest := pem.Decode(certsPEM)
if block == nil {
return nil, rest, nil
}
cert, err := x509.ParseCertificate(block.Bytes)
var certs = []*x509.Certificate{cert}
return certs, rest, err
}
// ParseCertificatesPEM parses a sequence of PEM-encoded certificate and returns them,
// can handle PEM encoded PKCS #7 structures.
// XXX: Inlined from github.com/cloudflare/cfssl because of build issues with that module
func ParseCertificatesPEM(certsPEM []byte) ([]*x509.Certificate, error) {
var certs []*x509.Certificate
var err error
certsPEM = bytes.TrimSpace(certsPEM)
for len(certsPEM) > 0 {
var cert []*x509.Certificate
cert, certsPEM, err = ParseOneCertificateFromPEM(certsPEM)
if err != nil {
return nil, err
} else if cert == nil {
break
}
certs = append(certs, cert...)
}
if len(certsPEM) > 0 {
return nil, fmt.Errorf("Trailing PEM data")
}
return certs, nil
}
func main() {
flag.StringVar(&port, "port", "4430", "port")
flag.StringVar(&serverName, "host", "example.com", "hostname")
flag.StringVar(&certFile, "cert", "", "certificate chain in PEM or DER")
flag.StringVar(&keyFile, "key", "", "private key in PEM format")
flag.StringVar(&responseFile, "response", "", "file to serve")
flag.BoolVar(&h2, "h2", false, "whether to use HTTP/2 (exclusively)")
flag.BoolVar(&sendTickets, "tickets", true, "whether to send session tickets")
flag.Parse()
var certChain []*x509.Certificate
var priv crypto.Signer
var response []byte
var err error
// Load the key and certificate chain
if certFile != "" {
certs, err := ioutil.ReadFile(certFile)
if err != nil {
log.Fatalf("Error: %v", err)
} else {
certChain, err = ParseCertificatesPEM(certs)
if err != nil {
certChain, err = x509.ParseCertificates(certs)
if err != nil {
log.Fatalf("Error parsing certificates: %v", err)
}
}
}
}
if keyFile != "" {
keyPEM, err := ioutil.ReadFile(keyFile)
if err != nil {
log.Fatalf("Error: %v", err)
} else {
priv, err = ParsePrivateKeyPEM(keyPEM)
if priv == nil || err != nil {
log.Fatalf("Error parsing private key: %v", err)
}
}
}
if err != nil {
log.Fatalf("Error: %v", err)
}
// Load response file
if responseFile != "" {
log.Printf("Loading response file: %v", responseFile)
response, err = ioutil.ReadFile(responseFile)
if err != nil {
log.Fatalf("Error: %v", err)
}
} else {
response = []byte("Welcome to the TLS 1.3 zone!")
}
handler := responder(response)
config := mint.Config{
SendSessionTickets: true,
ServerName: serverName,
NextProtos: []string{"http/1.1"},
}
if h2 {
config.NextProtos = []string{"h2"}
}
config.SendSessionTickets = sendTickets
if certChain != nil && priv != nil {
log.Printf("Loading cert: %v key: %v", certFile, keyFile)
config.Certificates = []*mint.Certificate{
{
Chain: certChain,
PrivateKey: priv,
},
}
}
config.Init(false)
service := "0.0.0.0:" + port
srv := &http.Server{Handler: handler}
log.Printf("Listening on port %v", port)
// Need the inner loop here because the h1 server errors on a dropped connection
// Need the outer loop here because the h2 server is per-connection
for {
listener, err := mint.Listen("tcp", service, &config)
if err != nil {
log.Printf("Listen Error: %v", err)
continue
}
if !h2 {
alert := srv.Serve(listener)
if alert != mint.AlertNoAlert {
log.Printf("Serve Error: %v", err)
}
} else {
srv2 := new(http2.Server)
opts := &http2.ServeConnOpts{
Handler: handler,
BaseConfig: srv,
}
for {
conn, err := listener.Accept()
if err != nil {
log.Printf("Accept error: %v", err)
continue
}
go srv2.ServeConn(conn, opts)
}
}
}
}
package main
import (
"flag"
"log"
"net"
"github.com/bifurcation/mint"
)
var port string
func main() {
var config mint.Config
config.SendSessionTickets = true
config.ServerName = "localhost"
config.Init(false)
flag.StringVar(&port, "port", "4430", "port")
flag.Parse()
service := "0.0.0.0:" + port
listener, err := mint.Listen("tcp", service, &config)
if err != nil {
log.Fatalf("server: listen: %s", err)
}
log.Print("server: listening")
for {
conn, err := listener.Accept()
if err != nil {
log.Printf("server: accept: %s", err)
break
}
defer conn.Close()
log.Printf("server: accepted from %s", conn.RemoteAddr())
go handleClient(conn)
}
}
func handleClient(conn net.Conn) {
defer conn.Close()
buf := make([]byte, 10)
for {
log.Print("server: conn: waiting")
n, err := conn.Read(buf)
if err != nil {
if err != nil {
log.Printf("server: conn: read: %s", err)
}
break
}
n, err = conn.Write([]byte("hello world"))
log.Printf("server: conn: wrote %d bytes", n)
if err != nil {
log.Printf("server: write: %s", err)
break
}
break
}
log.Println("server: conn: closed")
}
This diff is collapsed.
package mint
import (
"fmt"
"strconv"
)
var (
supportedVersion uint16 = 0x7f15 // draft-21
// Flags for some minor compat issues
allowWrongVersionNumber = true
allowPKCS1 = true
)
// enum {...} ContentType;
type RecordType byte
const (
RecordTypeAlert RecordType = 21
RecordTypeHandshake RecordType = 22
RecordTypeApplicationData RecordType = 23
)
// enum {...} HandshakeType;
type HandshakeType byte
const (
// Omitted: *_RESERVED
HandshakeTypeClientHello HandshakeType = 1
HandshakeTypeServerHello HandshakeType = 2
HandshakeTypeNewSessionTicket HandshakeType = 4
HandshakeTypeEndOfEarlyData HandshakeType = 5
HandshakeTypeHelloRetryRequest HandshakeType = 6
HandshakeTypeEncryptedExtensions HandshakeType = 8
HandshakeTypeCertificate HandshakeType = 11
HandshakeTypeCertificateRequest HandshakeType = 13
HandshakeTypeCertificateVerify HandshakeType = 15
HandshakeTypeServerConfiguration HandshakeType = 17
HandshakeTypeFinished HandshakeType = 20
HandshakeTypeKeyUpdate HandshakeType = 24
HandshakeTypeMessageHash HandshakeType = 254
)
// uint8 CipherSuite[2];
type CipherSuite uint16
const (
// XXX: Actually TLS_NULL_WITH_NULL_NULL, but we need a way to label the zero
// value for this type so that we can detect when a field is set.
CIPHER_SUITE_UNKNOWN CipherSuite = 0x0000
TLS_AES_128_GCM_SHA256 CipherSuite = 0x1301
TLS_AES_256_GCM_SHA384 CipherSuite = 0x1302
TLS_CHACHA20_POLY1305_SHA256 CipherSuite = 0x1303
TLS_AES_128_CCM_SHA256 CipherSuite = 0x1304
TLS_AES_256_CCM_8_SHA256 CipherSuite = 0x1305
)
func (c CipherSuite) String() string {
switch c {
case CIPHER_SUITE_UNKNOWN:
return "unknown"
case TLS_AES_128_GCM_SHA256:
return "TLS_AES_128_GCM_SHA256"
case TLS_AES_256_GCM_SHA384:
return "TLS_AES_256_GCM_SHA384"
case TLS_CHACHA20_POLY1305_SHA256:
return "TLS_CHACHA20_POLY1305_SHA256"
case TLS_AES_128_CCM_SHA256:
return "TLS_AES_128_CCM_SHA256"
case TLS_AES_256_CCM_8_SHA256:
return "TLS_AES_256_CCM_8_SHA256"
}
// cannot use %x here, since it calls String(), leading to infinite recursion
return fmt.Sprintf("invalid CipherSuite value: 0x%s", strconv.FormatUint(uint64(c), 16))
}
// enum {...} SignatureScheme
type SignatureScheme uint16
const (
// RSASSA-PKCS1-v1_5 algorithms
RSA_PKCS1_SHA1 SignatureScheme = 0x0201
RSA_PKCS1_SHA256 SignatureScheme = 0x0401
RSA_PKCS1_SHA384 SignatureScheme = 0x0501
RSA_PKCS1_SHA512 SignatureScheme = 0x0601
// ECDSA algorithms
ECDSA_P256_SHA256 SignatureScheme = 0x0403
ECDSA_P384_SHA384 SignatureScheme = 0x0503
ECDSA_P521_SHA512 SignatureScheme = 0x0603
// RSASSA-PSS algorithms
RSA_PSS_SHA256 SignatureScheme = 0x0804
RSA_PSS_SHA384 SignatureScheme = 0x0805
RSA_PSS_SHA512 SignatureScheme = 0x0806
// EdDSA algorithms
Ed25519 SignatureScheme = 0x0807
Ed448 SignatureScheme = 0x0808
)
// enum {...} ExtensionType
type ExtensionType uint16
const (
ExtensionTypeServerName ExtensionType = 0
ExtensionTypeSupportedGroups ExtensionType = 10
ExtensionTypeSignatureAlgorithms ExtensionType = 13
ExtensionTypeALPN ExtensionType = 16
ExtensionTypeKeyShare ExtensionType = 40
ExtensionTypePreSharedKey ExtensionType = 41
ExtensionTypeEarlyData ExtensionType = 42
ExtensionTypeSupportedVersions ExtensionType = 43
ExtensionTypeCookie ExtensionType = 44
ExtensionTypePSKKeyExchangeModes ExtensionType = 45
ExtensionTypeTicketEarlyDataInfo ExtensionType = 46
)
// enum {...} NamedGroup
type NamedGroup uint16
const (
// Elliptic Curve Groups.
P256 NamedGroup = 23
P384 NamedGroup = 24
P521 NamedGroup = 25
// ECDH functions.
X25519 NamedGroup = 29
X448 NamedGroup = 30
// Finite field groups.
FFDHE2048 NamedGroup = 256
FFDHE3072 NamedGroup = 257
FFDHE4096 NamedGroup = 258
FFDHE6144 NamedGroup = 259
FFDHE8192 NamedGroup = 260
)
// enum {...} PskKeyExchangeMode;
type PSKKeyExchangeMode uint8
const (
PSKModeKE PSKKeyExchangeMode = 0
PSKModeDHEKE PSKKeyExchangeMode = 1
)
// enum {
// update_not_requested(0), update_requested(1), (255)
// } KeyUpdateRequest;
type KeyUpdateRequest uint8
const (
KeyUpdateNotRequested KeyUpdateRequest = 0
KeyUpdateRequested KeyUpdateRequest = 1
)
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
package mint
import (
"encoding/hex"
"math/big"
)
var (
finiteFieldPrime2048hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
"886B423861285C97FFFFFFFFFFFFFFFF"
finiteFieldPrime2048bytes, _ = hex.DecodeString(finiteFieldPrime2048hex)
finiteFieldPrime2048 = big.NewInt(0).SetBytes(finiteFieldPrime2048bytes)
finiteFieldPrime3072hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
"886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" +
"61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" +
"AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" +
"64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" +
"ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" +
"3C1B20EE3FD59D7C25E41D2B66C62E37FFFFFFFFFFFFFFFF"
finiteFieldPrime3072bytes, _ = hex.DecodeString(finiteFieldPrime3072hex)
finiteFieldPrime3072 = big.NewInt(0).SetBytes(finiteFieldPrime3072bytes)
finiteFieldPrime4096hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
"886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" +
"61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" +
"AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" +
"64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" +
"ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" +
"3C1B20EE3FD59D7C25E41D2B669E1EF16E6F52C3164DF4FB" +
"7930E9E4E58857B6AC7D5F42D69F6D187763CF1D55034004" +
"87F55BA57E31CC7A7135C886EFB4318AED6A1E012D9E6832" +
"A907600A918130C46DC778F971AD0038092999A333CB8B7A" +
"1A1DB93D7140003C2A4ECEA9F98D0ACC0A8291CDCEC97DCF" +
"8EC9B55A7F88A46B4DB5A851F44182E1C68A007E5E655F6A" +
"FFFFFFFFFFFFFFFF"
finiteFieldPrime4096bytes, _ = hex.DecodeString(finiteFieldPrime4096hex)
finiteFieldPrime4096 = big.NewInt(0).SetBytes(finiteFieldPrime4096bytes)
finiteFieldPrime6144hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
"886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" +
"61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" +
"AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" +
"64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" +
"ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" +
"3C1B20EE3FD59D7C25E41D2B669E1EF16E6F52C3164DF4FB" +
"7930E9E4E58857B6AC7D5F42D69F6D187763CF1D55034004" +
"87F55BA57E31CC7A7135C886EFB4318AED6A1E012D9E6832" +
"A907600A918130C46DC778F971AD0038092999A333CB8B7A" +
"1A1DB93D7140003C2A4ECEA9F98D0ACC0A8291CDCEC97DCF" +
"8EC9B55A7F88A46B4DB5A851F44182E1C68A007E5E0DD902" +
"0BFD64B645036C7A4E677D2C38532A3A23BA4442CAF53EA6" +
"3BB454329B7624C8917BDD64B1C0FD4CB38E8C334C701C3A" +
"CDAD0657FCCFEC719B1F5C3E4E46041F388147FB4CFDB477" +
"A52471F7A9A96910B855322EDB6340D8A00EF092350511E3" +
"0ABEC1FFF9E3A26E7FB29F8C183023C3587E38DA0077D9B4" +
"763E4E4B94B2BBC194C6651E77CAF992EEAAC0232A281BF6" +
"B3A739C1226116820AE8DB5847A67CBEF9C9091B462D538C" +
"D72B03746AE77F5E62292C311562A846505DC82DB854338A" +
"E49F5235C95B91178CCF2DD5CACEF403EC9D1810C6272B04" +
"5B3B71F9DC6B80D63FDD4A8E9ADB1E6962A69526D43161C1" +
"A41D570D7938DAD4A40E329CD0E40E65FFFFFFFFFFFFFFFF"
finiteFieldPrime6144bytes, _ = hex.DecodeString(finiteFieldPrime6144hex)
finiteFieldPrime6144 = big.NewInt(0).SetBytes(finiteFieldPrime6144bytes)
finiteFieldPrime8192hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
"886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" +
"61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" +
"AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" +
"64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" +
"ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" +
"3C1B20EE3FD59D7C25E41D2B669E1EF16E6F52C3164DF4FB" +
"7930E9E4E58857B6AC7D5F42D69F6D187763CF1D55034004" +
"87F55BA57E31CC7A7135C886EFB4318AED6A1E012D9E6832" +
"A907600A918130C46DC778F971AD0038092999A333CB8B7A" +
"1A1DB93D7140003C2A4ECEA9F98D0ACC0A8291CDCEC97DCF" +
"8EC9B55A7F88A46B4DB5A851F44182E1C68A007E5E0DD902" +
"0BFD64B645036C7A4E677D2C38532A3A23BA4442CAF53EA6" +
"3BB454329B7624C8917BDD64B1C0FD4CB38E8C334C701C3A" +
"CDAD0657FCCFEC719B1F5C3E4E46041F388147FB4CFDB477" +
"A52471F7A9A96910B855322EDB6340D8A00EF092350511E3" +
"0ABEC1FFF9E3A26E7FB29F8C183023C3587E38DA0077D9B4" +
"763E4E4B94B2BBC194C6651E77CAF992EEAAC0232A281BF6" +
"B3A739C1226116820AE8DB5847A67CBEF9C9091B462D538C" +
"D72B03746AE77F5E62292C311562A846505DC82DB854338A" +
"E49F5235C95B91178CCF2DD5CACEF403EC9D1810C6272B04" +
"5B3B71F9DC6B80D63FDD4A8E9ADB1E6962A69526D43161C1" +
"A41D570D7938DAD4A40E329CCFF46AAA36AD004CF600C838" +
"1E425A31D951AE64FDB23FCEC9509D43687FEB69EDD1CC5E" +
"0B8CC3BDF64B10EF86B63142A3AB8829555B2F747C932665" +
"CB2C0F1CC01BD70229388839D2AF05E454504AC78B758282" +
"2846C0BA35C35F5C59160CC046FD8251541FC68C9C86B022" +
"BB7099876A460E7451A8A93109703FEE1C217E6C3826E52C" +
"51AA691E0E423CFC99E9E31650C1217B624816CDAD9A95F9" +
"D5B8019488D9C0A0A1FE3075A577E23183F81D4A3F2FA457" +
"1EFC8CE0BA8A4FE8B6855DFE72B0A66EDED2FBABFBE58A30" +
"FAFABE1C5D71A87E2F741EF8C1FE86FEA6BBFDE530677F0D" +
"97D11D49F7A8443D0822E506A9F4614E011E2A94838FF88C" +
"D68C8BB7C5C6424CFFFFFFFFFFFFFFFF"
finiteFieldPrime8192bytes, _ = hex.DecodeString(finiteFieldPrime8192hex)
finiteFieldPrime8192 = big.NewInt(0).SetBytes(finiteFieldPrime8192bytes)
)
// Read a generic "framed" packet consisting of a header and a
// This is used for both TLS Records and TLS Handshake Messages
package mint
type framing interface {
headerLen() int
defaultReadLen() int
frameLen(hdr []byte) (int, error)
}
const (
kFrameReaderHdr = 0
kFrameReaderBody = 1
)
type frameNextAction func(f *frameReader) error
type frameReader struct {
details framing
state uint8
header []byte
body []byte
working []byte
writeOffset int
remainder []byte
}
func newFrameReader(d framing) *frameReader {
hdr := make([]byte, d.headerLen())
return &frameReader{
d,
kFrameReaderHdr,
hdr,
nil,
hdr,
0,
nil,
}
}
func dup(a []byte) []byte {
r := make([]byte, len(a))
copy(r, a)
return r
}
func (f *frameReader) needed() int {
tmp := (len(f.working) - f.writeOffset) - len(f.remainder)
if tmp < 0 {
return 0
}
return tmp
}
func (f *frameReader) addChunk(in []byte) {
// Append to the buffer.
logf(logTypeFrameReader, "Appending %v", len(in))
f.remainder = append(f.remainder, in...)
}
func (f *frameReader) process() (hdr []byte, body []byte, err error) {
for f.needed() == 0 {
logf(logTypeFrameReader, "%v bytes needed for next block", len(f.working)-f.writeOffset)
// Fill out our working block
copied := copy(f.working[f.writeOffset:], f.remainder)
f.remainder = f.remainder[copied:]
f.writeOffset += copied
if f.writeOffset < len(f.working) {
logf(logTypeFrameReader, "Read would have blocked 1")
return nil, nil, WouldBlock
}
// Reset the write offset, because we are now full.
f.writeOffset = 0
// We have read a full frame
if f.state == kFrameReaderBody {
logf(logTypeFrameReader, "Returning frame hdr=%#x len=%d buffered=%d", f.header, len(f.body), len(f.remainder))
f.state = kFrameReaderHdr
f.working = f.header
return dup(f.header), dup(f.body), nil
}
// We have read the header
bodyLen, err := f.details.frameLen(f.header)
if err != nil {
return nil, nil, err
}
logf(logTypeFrameReader, "Processed header, body len = %v", bodyLen)
f.body = make([]byte, bodyLen)
f.working = f.body
f.writeOffset = 0
f.state = kFrameReaderBody
}
logf(logTypeFrameReader, "Read would have blocked 2")
return nil, nil, WouldBlock
}
This diff is collapsed.
This diff is collapsed.
package mint
import (
"fmt"
"log"
"os"
"strings"
)
// We use this environment variable to control logging. It should be a
// comma-separated list of log tags (see below) or "*" to enable all logging.
const logConfigVar = "MINT_LOG"
// Pre-defined log types
const (
logTypeCrypto = "crypto"
logTypeHandshake = "handshake"
logTypeNegotiation = "negotiation"
logTypeIO = "io"
logTypeFrameReader = "frame"
logTypeVerbose = "verbose"
)
var (
logFunction = log.Printf
logAll = false
logSettings = map[string]bool{}
)
func init() {
parseLogEnv(os.Environ())
}
func parseLogEnv(env []string) {
for _, stmt := range env {
if strings.HasPrefix(stmt, logConfigVar+"=") {
val := stmt[len(logConfigVar)+1:]
if val == "*" {
logAll = true
} else {
for _, t := range strings.Split(val, ",") {
logSettings[t] = true
}
}
}
}
}
func logf(tag string, format string, args ...interface{}) {
if logAll || logSettings[tag] {
fullFormat := fmt.Sprintf("[%s] %s", tag, format)
logFunction(fullFormat, args...)
}
}
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
package syntax
import (
"strconv"
"strings"
)
// `tls:"head=2,min=2,max=255"`
type tagOptions map[string]uint
// parseTag parses a struct field's "tls" tag as a comma-separated list of
// name=value pairs, where the values MUST be unsigned integers
func parseTag(tag string) tagOptions {
opts := tagOptions{}
for _, token := range strings.Split(tag, ",") {
if strings.Index(token, "=") == -1 {
continue
}
parts := strings.Split(token, "=")
if len(parts[0]) == 0 {
continue
}
if val, err := strconv.Atoi(parts[1]); err == nil && val >= 0 {
opts[parts[0]] = uint(val)
}
}
return opts
}
This diff is collapsed.
package ackhandler
import (
"time"
"github.com/lucas-clemente/quic-go/frames"
"github.com/lucas-clemente/quic-go/protocol"
)
// SentPacketHandler handles ACKs received for outgoing packets
type SentPacketHandler interface {
// SentPacket may modify the packet
SentPacket(packet *Packet) error
ReceivedAck(ackFrame *frames.AckFrame, withPacketNumber protocol.PacketNumber, recvTime time.Time) error
SendingAllowed() bool
GetStopWaitingFrame(force bool) *frames.StopWaitingFrame
DequeuePacketForRetransmission() (packet *Packet)
GetLeastUnacked() protocol.PacketNumber
GetAlarmTimeout() time.Time
OnAlarm()
}
// ReceivedPacketHandler handles ACKs needed to send for incoming packets
type ReceivedPacketHandler interface {
ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error
SetLowerLimit(protocol.PacketNumber)
GetAlarmTimeout() time.Time
GetAckFrame() *frames.AckFrame
}
...@@ -3,7 +3,7 @@ package quic ...@@ -3,7 +3,7 @@ package quic
import ( import (
"sync" "sync"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
) )
var bufferPool sync.Pool var bufferPool sync.Pool
......
package crypto
import (
"encoding/binary"
"github.com/lucas-clemente/quic-go/protocol"
)
func makeNonce(iv []byte, packetNumber protocol.PacketNumber) []byte {
res := make([]byte, 12)
copy(res[0:4], iv)
binary.LittleEndian.PutUint64(res[4:12], uint64(packetNumber))
return res
}
package quic
import (
"io"
"github.com/lucas-clemente/quic-go/internal/flowcontrol"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
)
type cryptoStreamI interface {
StreamID() protocol.StreamID
io.Reader
io.Writer
handleStreamFrame(*wire.StreamFrame) error
popStreamFrame(protocol.ByteCount) (*wire.StreamFrame, bool)
closeForShutdown(error)
setReadOffset(protocol.ByteCount)
// methods needed for flow control
getWindowUpdate() protocol.ByteCount
handleMaxStreamDataFrame(*wire.MaxStreamDataFrame)
}
type cryptoStream struct {
*stream
}
var _ cryptoStreamI = &cryptoStream{}
func newCryptoStream(sender streamSender, flowController flowcontrol.StreamFlowController, version protocol.VersionNumber) cryptoStreamI {
str := newStream(version.CryptoStreamID(), sender, flowController, version)
return &cryptoStream{str}
}
// SetReadOffset sets the read offset.
// It is only needed for the crypto stream.
// It must not be called concurrently with any other stream methods, especially Read and Write.
func (s *cryptoStream) setReadOffset(offset protocol.ByteCount) {
s.receiveStream.readOffset = offset
s.receiveStream.frameQueue.readPosition = offset
}
package frames
import "github.com/lucas-clemente/quic-go/protocol"
// AckRange is an ACK range
type AckRange struct {
FirstPacketNumber protocol.PacketNumber
LastPacketNumber protocol.PacketNumber
}
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment