Commit 216a6172 authored by Matthew Holt's avatar Matthew Holt

tls: Some bug fixes, basic rate limiting, max_certs setting

parent d25a3e95
......@@ -26,6 +26,7 @@ import (
"path"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/mholt/caddy/caddy/https"
......@@ -317,6 +318,7 @@ func LoadCaddyfile(loader func() (Input, error)) (cdyfile Input, err error) {
return nil, err
}
cdyfile = loadedGob.Caddyfile
atomic.StoreInt32(https.OnDemandIssuedCount, loadedGob.OnDemandTLSCertsIssued)
}
// Try user's loader
......
......@@ -63,10 +63,12 @@ var signalParentOnce sync.Once
// caddyfileGob maps bind address to index of the file descriptor
// in the Files array passed to the child process. It also contains
// the caddyfile contents. Used only during graceful restarts.
// the caddyfile contents and other state needed by the new process.
// Used only during graceful restarts where a new process is spawned.
type caddyfileGob struct {
ListenerFds map[string]uintptr
Caddyfile Input
ListenerFds map[string]uintptr
Caddyfile Input
OnDemandTLSCertsIssued int32
}
// IsRestart returns whether this process is, according
......
......@@ -3,7 +3,6 @@ package https
import (
"crypto/tls"
"log"
"net"
"net/http"
"net/http/httputil"
"net/url"
......@@ -23,21 +22,16 @@ func RequestCallback(w http.ResponseWriter, r *http.Request) bool {
scheme = "https"
}
hostname, _, err := net.SplitHostPort(r.Host)
if err != nil {
hostname = r.Host
}
upstream, err := url.Parse(scheme + "://" + hostname + ":" + AlternatePort)
upstream, err := url.Parse(scheme + "://localhost:" + AlternatePort)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
log.Printf("[ERROR] letsencrypt handler: %v", err)
log.Printf("[ERROR] ACME proxy handler: %v", err)
return true
}
proxy := httputil.NewSingleHostReverseProxy(upstream)
proxy.Transport = &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // client would use self-signed cert
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // solver uses self-signed certs
}
proxy.ServeHTTP(w, r)
......
......@@ -7,7 +7,9 @@ import (
"errors"
"fmt"
"log"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/mholt/caddy/server"
......@@ -15,11 +17,12 @@ import (
)
// GetCertificate gets a certificate to satisfy clientHello as long as
// the certificate is already cached in memory.
// the certificate is already cached in memory. It will not be loaded
// from disk or obtained from the CA during the handshake.
//
// This function is safe for use as a tls.Config.GetCertificate callback.
func GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
cert, err := getCertDuringHandshake(clientHello.ServerName, false)
cert, err := getCertDuringHandshake(clientHello.ServerName, false, false)
return cert.Certificate, err
}
......@@ -31,45 +34,60 @@ func GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error)
//
// This function is safe for use as a tls.Config.GetCertificate callback.
func GetOrObtainCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
cert, err := getCertDuringHandshake(clientHello.ServerName, true)
cert, err := getCertDuringHandshake(clientHello.ServerName, true, true)
return cert.Certificate, err
}
// getCertDuringHandshake will get a certificate for name. It first tries
// the in-memory cache, then, if obtainIfNecessary is true, it goes to disk,
// then asks the CA for a certificate if necessary.
// the in-memory cache. If no certificate for name is in the cach and if
// loadIfNecessary == true, it goes to disk to load it into the cache and
// serve it. If it's not on disk and if obtainIfNecessary == true, the
// certificate will be obtained from the CA, cached, and served. If
// obtainIfNecessary is true, then loadIfNecessary must also be set to true.
//
// This function is safe for concurrent use.
func getCertDuringHandshake(name string, obtainIfNecessary bool) (Certificate, error) {
func getCertDuringHandshake(name string, loadIfNecessary, obtainIfNecessary bool) (Certificate, error) {
// First check our in-memory cache to see if we've already loaded it
cert, ok := getCertificate(name)
if ok {
return cert, nil
}
if obtainIfNecessary {
// TODO: Mitigate abuse!
if loadIfNecessary {
var err error
// Then check to see if we have one on disk
cert, err := cacheManagedCertificate(name, true)
if err != nil {
return cert, err
} else if cert.Certificate != nil {
cert, err := handshakeMaintenance(name, cert)
cert, err = cacheManagedCertificate(name, true)
if err == nil {
cert, err = handshakeMaintenance(name, cert)
if err != nil {
log.Printf("[ERROR] Maintaining newly-loaded certificate for %s: %v", name, err)
}
return cert, err
return cert, nil
}
// Only option left is to get one from LE, but the name has to qualify first
if !HostQualifies(name) {
return cert, errors.New("hostname '" + name + "' does not qualify for certificate")
}
if obtainIfNecessary {
name = strings.ToLower(name)
// Make sure aren't over any applicable limits
if onDemandMaxIssue > 0 && atomic.LoadInt32(OnDemandIssuedCount) >= onDemandMaxIssue {
return Certificate{}, fmt.Errorf("%s: maximum certificates issued (%d)", name, onDemandMaxIssue)
}
failedIssuanceMu.RLock()
when, ok := failedIssuance[name]
failedIssuanceMu.RUnlock()
if ok {
return Certificate{}, fmt.Errorf("%s: throttled; refusing to issue cert since last attempt on %s failed", name, when.String())
}
// Only option left is to get one from LE, but the name has to qualify first
if !HostQualifies(name) {
return cert, errors.New("hostname '" + name + "' does not qualify for certificate")
}
// By this point, we need to obtain one from the CA.
return obtainOnDemandCertificate(name)
// By this point, we need to obtain one from the CA.
return obtainOnDemandCertificate(name)
}
}
return Certificate{}, nil
......@@ -89,7 +107,7 @@ func obtainOnDemandCertificate(name string) (Certificate, error) {
// wait for it to finish obtaining the cert and then we'll use it.
obtainCertWaitChansMu.Unlock()
<-wait
return getCertDuringHandshake(name, false) // passing in true might result in infinite loop if obtain failed
return getCertDuringHandshake(name, true, false)
}
// looks like it's up to us to do all the work and obtain the cert
......@@ -115,11 +133,24 @@ func obtainOnDemandCertificate(name string) (Certificate, error) {
client.Configure("") // TODO: which BindHost?
err = client.Obtain([]string{name})
if err != nil {
// Failed to solve challenge, so don't allow another on-demand
// issue for this name to be attempted for a little while.
failedIssuanceMu.Lock()
failedIssuance[name] = time.Now()
go func(name string) {
time.Sleep(5 * time.Minute)
failedIssuanceMu.Lock()
delete(failedIssuance, name)
failedIssuanceMu.Unlock()
}(name)
failedIssuanceMu.Unlock()
return Certificate{}, err
}
atomic.AddInt32(OnDemandIssuedCount, 1)
// The certificate is on disk; now just start over to load it and serve it
return getCertDuringHandshake(name, false) // pass in false as a fail-safe from infinite-looping
return getCertDuringHandshake(name, true, false)
}
// handshakeMaintenance performs a check on cert for expiration and OCSP
......@@ -127,12 +158,6 @@ func obtainOnDemandCertificate(name string) (Certificate, error) {
//
// This function is safe for use by multiple concurrent goroutines.
func handshakeMaintenance(name string, cert Certificate) (Certificate, error) {
// fmt.Println("ON-DEMAND CERT?", cert.OnDemand)
// if !cert.OnDemand {
// return cert, nil
// }
fmt.Println("Checking expiration of cert; on-demand:", cert.OnDemand)
// Check cert expiration
timeLeft := cert.NotAfter.Sub(time.Now().UTC())
if timeLeft < renewDurationBefore {
......@@ -173,7 +198,7 @@ func renewDynamicCertificate(name string) (Certificate, error) {
// wait for it to finish, then we'll use the new one.
obtainCertWaitChansMu.Unlock()
<-wait
return getCertDuringHandshake(name, false)
return getCertDuringHandshake(name, true, false)
}
// looks like it's up to us to do all the work and renew the cert
......@@ -201,7 +226,7 @@ func renewDynamicCertificate(name string) (Certificate, error) {
return Certificate{}, err
}
return getCertDuringHandshake(name, false)
return getCertDuringHandshake(name, true, false)
}
// stapleOCSP staples OCSP information to cert for hostname name.
......@@ -235,3 +260,20 @@ func stapleOCSP(cert *Certificate, pemBundle []byte) error {
// obtainCertWaitChans is used to coordinate obtaining certs for each hostname.
var obtainCertWaitChans = make(map[string]chan struct{})
var obtainCertWaitChansMu sync.Mutex
// OnDemandIssuedCount is the number of certificates that have been issued
// on-demand by this process. It is only safe to modify this count atomically.
// If it reaches max_certs, on-demand issuances will fail.
var OnDemandIssuedCount = new(int32)
// onDemandMaxIssue is set based on max_certs in tls config. It specifies the
// maximum number of certificates that can be issued.
// TODO: This applies globally, but we should probably make a server-specific
// way to keep track of these limits and counts...
var onDemandMaxIssue int32
// failedIssuance is a set of names that we recently failed to get a
// certificate for from the ACME CA. They are removed after some time.
// When a name is in this map, do not issue a certificate for it.
var failedIssuance = make(map[string]time.Time)
var failedIssuanceMu sync.RWMutex
......@@ -8,6 +8,7 @@ import (
"log"
"os"
"path/filepath"
"strconv"
"strings"
"github.com/mholt/caddy/caddy/setup"
......@@ -27,7 +28,7 @@ func Setup(c *setup.Controller) (middleware.Middleware, error) {
}
for c.Next() {
var certificateFile, keyFile, loadDir string
var certificateFile, keyFile, loadDir, maxCerts string
args := c.RemainingArgs()
switch len(args) {
......@@ -80,6 +81,8 @@ func Setup(c *setup.Controller) (middleware.Middleware, error) {
case "load":
c.Args(&loadDir)
c.TLS.Manual = true
case "max_certs":
c.Args(&maxCerts)
default:
return nil, c.Errf("Unknown keyword '%s'", c.Val())
}
......@@ -90,6 +93,20 @@ func Setup(c *setup.Controller) (middleware.Middleware, error) {
return nil, c.ArgErr()
}
if c.TLS.Manual && maxCerts != "" {
return nil, c.Err("Cannot limit certificate count (max_certs) for manual TLS configurations")
}
if maxCerts != "" {
maxCertsNum, err := strconv.Atoi(maxCerts)
if err != nil || maxCertsNum < 0 {
return nil, c.Err("max_certs must be a positive integer")
}
if onDemandMaxIssue == 0 || int32(maxCertsNum) < onDemandMaxIssue { // keep the minimum; TODO: This is global; should be per-server or per-vhost...
onDemandMaxIssue = int32(maxCertsNum)
}
}
// don't load certificates unless we're supposed to
if !c.TLS.Enabled || !c.TLS.Manual {
continue
......
......@@ -11,6 +11,7 @@ import (
"os"
"os/exec"
"path"
"sync/atomic"
"github.com/mholt/caddy/caddy/https"
)
......@@ -55,8 +56,9 @@ func Restart(newCaddyfile Input) error {
// Prepare our payload to the child process
cdyfileGob := caddyfileGob{
ListenerFds: make(map[string]uintptr),
Caddyfile: newCaddyfile,
ListenerFds: make(map[string]uintptr),
Caddyfile: newCaddyfile,
OnDemandTLSCertsIssued: atomic.LoadInt32(https.OnDemandIssuedCount),
}
// Prepare a pipe to the fork's stdin so it can get the Caddyfile
......
......@@ -118,7 +118,7 @@ md5:$apr1$l42y8rex$pOA2VJ0x/0TwaFeAF9nX61`
}
if !actualRule.Password(pwd) || actualRule.Password(test.password+"!") {
t.Errorf("Test %d, rule %d: Expected password '%v', got '%v'",
i, j, test.password, actualRule.Password)
i, j, test.password, actualRule.Password(""))
}
expectedRes := fmt.Sprintf("%v", expectedRule.Resources)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment