Commit 73794f2a authored by Matt Holt's avatar Matt Holt Committed by GitHub

tls: Refactor internals related to TLS configurations (#1466)

* tls: Refactor TLS config innards with a few minor syntax changes

muststaple -> must_staple
"http2 off" -> "alpn" with list of ALPN values

* Fix typo

* Fix QUIC handler

* Inline struct field assignments
parent 4b877eeb
...@@ -79,7 +79,7 @@ func enableAutoHTTPS(configs []*SiteConfig, loadCertificates bool) error { ...@@ -79,7 +79,7 @@ func enableAutoHTTPS(configs []*SiteConfig, loadCertificates bool) error {
cfg.TLS.Enabled = true cfg.TLS.Enabled = true
cfg.Addr.Scheme = "https" cfg.Addr.Scheme = "https"
if loadCertificates && caddytls.HostQualifies(cfg.Addr.Host) { if loadCertificates && caddytls.HostQualifies(cfg.Addr.Host) {
_, err := caddytls.CacheManagedCertificate(cfg.Addr.Host, cfg.TLS) _, err := cfg.TLS.CacheManagedCertificate(cfg.Addr.Host)
if err != nil { if err != nil {
return err return err
} }
......
...@@ -35,6 +35,11 @@ type tlsHandler struct { ...@@ -35,6 +35,11 @@ type tlsHandler struct {
// Halderman, et. al. in "The Security Impact of HTTPS Interception" (NDSS '17): // Halderman, et. al. in "The Security Impact of HTTPS Interception" (NDSS '17):
// https://jhalderm.com/pub/papers/interception-ndss17.pdf // https://jhalderm.com/pub/papers/interception-ndss17.pdf
func (h *tlsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *tlsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if h.listener == nil {
h.next.ServeHTTP(w, r)
return
}
h.listener.helloInfosMu.RLock() h.listener.helloInfosMu.RLock()
info := h.listener.helloInfos[r.RemoteAddr] info := h.listener.helloInfos[r.RemoteAddr]
h.listener.helloInfosMu.RUnlock() h.listener.helloInfosMu.RUnlock()
...@@ -78,63 +83,62 @@ func (h *tlsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { ...@@ -78,63 +83,62 @@ func (h *tlsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.next.ServeHTTP(w, r) h.next.ServeHTTP(w, r)
} }
// clientHelloConn reads the ClientHello
// and stores it in the attached listener.
type clientHelloConn struct { type clientHelloConn struct {
net.Conn net.Conn
readHello bool
listener *tlsHelloListener listener *tlsHelloListener
readHello bool // whether ClientHello has been read
buf *bytes.Buffer
} }
// Read reads from c.Conn (by letting the standard library
// do the reading off the wire), with the exception of
// getting a copy of the ClientHello so it can parse it.
func (c *clientHelloConn) Read(b []byte) (n int, err error) { func (c *clientHelloConn) Read(b []byte) (n int, err error) {
if !c.readHello { // if we've already read the ClientHello, pass thru
// Read the header bytes. if c.readHello {
hdr := make([]byte, 5) return c.Conn.Read(b)
n, err := io.ReadFull(c.Conn, hdr) }
if err != nil {
return n, err
}
// Get the length of the ClientHello message and read it as well.
length := uint16(hdr[3])<<8 | uint16(hdr[4])
hello := make([]byte, int(length))
n, err = io.ReadFull(c.Conn, hello)
if err != nil {
return n, err
}
// Parse the ClientHello and store it in the map.
rawParsed := parseRawClientHello(hello)
c.listener.helloInfosMu.Lock()
c.listener.helloInfos[c.Conn.RemoteAddr().String()] = rawParsed
c.listener.helloInfosMu.Unlock()
// Since we buffered the header and ClientHello, pretend we were // we let the standard lib read off the wire for us, and
// never here by lining up the buffered values to be read with a // tee that into our buffer so we can read the ClientHello
// custom connection type, followed by the rest of the actual tee := io.TeeReader(c.Conn, c.buf)
// underlying connection. n, err = tee.Read(b)
mr := io.MultiReader(bytes.NewReader(hdr), bytes.NewReader(hello), c.Conn) if err != nil {
mc := multiConn{Conn: c.Conn, reader: mr} return
}
if c.buf.Len() < 5 {
return // need to read more bytes for header
}
c.Conn = mc // read the header bytes
hdr := make([]byte, 5)
_, err = io.ReadFull(c.buf, hdr)
if err != nil {
return // this would be highly unusual and sad
}
c.readHello = true // get length of the ClientHello message and read it
length := int(uint16(hdr[3])<<8 | uint16(hdr[4]))
if c.buf.Len() < length {
return // need to read more bytes
} }
return c.Conn.Read(b) hello := make([]byte, length)
} _, err = io.ReadFull(c.buf, hello)
if err != nil {
return
}
c.buf = nil // buffer no longer needed
// multiConn is a net.Conn that reads from the // parse the ClientHello and store it in the map
// given reader instead of the wire directly. This rawParsed := parseRawClientHello(hello)
// is useful when some of the connection has already c.listener.helloInfosMu.Lock()
// been read (like the TLS Client Hello) and the c.listener.helloInfos[c.Conn.RemoteAddr().String()] = rawParsed
// reader is a io.MultiReader that starts with c.listener.helloInfosMu.Unlock()
// the contents of the buffer.
type multiConn struct {
net.Conn
reader io.Reader
}
// Read reads from mc.reader. c.readHello = true
func (mc multiConn) Read(b []byte) (n int, err error) { return
return mc.reader.Read(b)
} }
// parseRawClientHello parses data which contains the raw // parseRawClientHello parses data which contains the raw
...@@ -279,7 +283,7 @@ func (l *tlsHelloListener) Accept() (net.Conn, error) { ...@@ -279,7 +283,7 @@ func (l *tlsHelloListener) Accept() (net.Conn, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
helloConn := &clientHelloConn{Conn: conn, listener: l} helloConn := &clientHelloConn{Conn: conn, listener: l, buf: new(bytes.Buffer)}
return tls.Server(helloConn, l.config), nil return tls.Server(helloConn, l.config), nil
} }
......
...@@ -84,7 +84,7 @@ func TestHeuristicFunctions(t *testing.T) { ...@@ -84,7 +84,7 @@ func TestHeuristicFunctions(t *testing.T) {
// clientHello pairs a User-Agent string to its ClientHello message. // clientHello pairs a User-Agent string to its ClientHello message.
type clientHello struct { type clientHello struct {
userAgent string userAgent string
helloHex string helloHex string // do NOT include the header, just the ClientHello message
} }
// clientHellos groups samples of true (real) ClientHellos by the // clientHellos groups samples of true (real) ClientHellos by the
...@@ -158,7 +158,12 @@ func TestHeuristicFunctions(t *testing.T) { ...@@ -158,7 +158,12 @@ func TestHeuristicFunctions(t *testing.T) {
}, },
{ {
// IE 11 on Windows 7, this connection was intercepted by Blue Coat // IE 11 on Windows 7, this connection was intercepted by Blue Coat
helloHex: "010000b1030358a3f3bae627f464da8cb35976b88e9119640032d41e62a107d608ed8d3e62b9000034c028c027c014c013009f009e009d009cc02cc02bc024c023c00ac009003d003c0035002f006a004000380032000a0013000500040100005400000014001200000f66696e6572706978656c732e636f6d000500050100000000000a00080006001700180019000b00020100000d0014001206010603040105010201040305030203020200170000ff01000100", helloHex: `010000b1030358a3f3bae627f464da8cb35976b88e9119640032d41e62a107d608ed8d3e62b9000034c028c027c014c013009f009e009d009cc02cc02bc024c023c00ac009003d003c0035002f006a004000380032000a0013000500040100005400000014001200000f66696e6572706978656c732e636f6d000500050100000000000a00080006001700180019000b00020100000d0014001206010603040105010201040305030203020200170000ff01000100`,
},
{
// Firefox 51.0.1 being intercepted by burp 1.7.17
userAgent: "(TODO)",
helloHex: `010000d8030358a92f4daca95acc2f6a10a9c50d736135eae39406d3090238464540d482677600003ac023c027003cc025c02900670040c009c013002fc004c00e00330032c02bc02f009cc02dc031009e00a2c008c012000ac003c00d0016001300ff01000075000a0034003200170001000300130015000600070009000a0018000b000c0019000d000e000f001000110002001200040005001400080016000b00020100000d00180016060306010503050104030401040202030201020201010000001700150000126a61677561722e6b796877616e612e6f7267`,
}, },
}, },
} }
......
...@@ -31,40 +31,47 @@ type Server struct { ...@@ -31,40 +31,47 @@ type Server struct {
connTimeout time.Duration // max time to wait for a connection before force stop connTimeout time.Duration // max time to wait for a connection before force stop
tlsGovChan chan struct{} // close to stop the TLS maintenance goroutine tlsGovChan chan struct{} // close to stop the TLS maintenance goroutine
vhosts *vhostTrie vhosts *vhostTrie
tlsConfig caddytls.ConfigGroup
} }
// ensure it satisfies the interface // ensure it satisfies the interface
var _ caddy.GracefulServer = new(Server) var _ caddy.GracefulServer = new(Server)
var defaultALPN = []string{"h2", "http/1.1"}
// makeTLSConfig extracts TLS settings from each site config to
// build a tls.Config usable in Caddy HTTP servers. The returned
// config will be nil if TLS is disabled for these sites.
func makeTLSConfig(group []*SiteConfig) (*tls.Config, error) {
var tlsConfigs []*caddytls.Config
for i := range group {
if HTTP2 && len(group[i].TLS.ALPN) == 0 {
// if no application-level protocol was configured up to now,
// default to HTTP/2, then HTTP/1.1 if necessary
group[i].TLS.ALPN = defaultALPN
}
tlsConfigs = append(tlsConfigs, group[i].TLS)
}
return caddytls.MakeTLSConfig(tlsConfigs)
}
// NewServer creates a new Server instance that will listen on addr // NewServer creates a new Server instance that will listen on addr
// and will serve the sites configured in group. // and will serve the sites configured in group.
func NewServer(addr string, group []*SiteConfig) (*Server, error) { func NewServer(addr string, group []*SiteConfig) (*Server, error) {
s := &Server{ s := &Server{
Server: makeHTTPServer(addr, group), Server: makeHTTPServerWithTimeouts(addr, group),
vhosts: newVHostTrie(), vhosts: newVHostTrie(),
sites: group, sites: group,
connTimeout: GracefulTimeout, connTimeout: GracefulTimeout,
} }
s.Server.Handler = s // this is weird, but whatever s.Server.Handler = s // this is weird, but whatever
tlsh := &tlsHandler{next: s.Server.Handler}
s.Server.ConnState = func(c net.Conn, cs http.ConnState) {
// when a connection closes or is hijacked, delete its entry
// in the map, because we are done with it.
if tlsh.listener != nil {
if cs == http.StateHijacked || cs == http.StateClosed {
tlsh.listener.helloInfosMu.Lock()
delete(tlsh.listener.helloInfos, c.RemoteAddr().String())
tlsh.listener.helloInfosMu.Unlock()
}
}
}
// Disable HTTP/2 if desired // extract TLS settings from each site config to build
if !HTTP2 { // a tls.Config, which will not be nil if TLS is enabled
s.Server.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler)) tlsConfig, err := makeTLSConfig(group)
if err != nil {
return nil, err
} }
s.Server.TLSConfig = tlsConfig
// Enable QUIC if desired // Enable QUIC if desired
if QUIC { if QUIC {
...@@ -72,43 +79,38 @@ func NewServer(addr string, group []*SiteConfig) (*Server, error) { ...@@ -72,43 +79,38 @@ func NewServer(addr string, group []*SiteConfig) (*Server, error) {
s.Server.Handler = s.wrapWithSvcHeaders(s.Server.Handler) s.Server.Handler = s.wrapWithSvcHeaders(s.Server.Handler)
} }
// Set up TLS configuration // if TLS is enabled, make sure we prepare the Server accordingly
tlsConfigs := make(caddytls.ConfigGroup) if s.Server.TLSConfig != nil {
var allConfigs []*caddytls.Config // wrap the HTTP handler with a handler that does MITM detection
tlsh := &tlsHandler{next: s.Server.Handler}
for _, site := range group { s.Server.Handler = tlsh // this needs to be the "outer" handler when Serve() is called, for type assertion
if err := site.TLS.Build(tlsConfigs); err != nil { // when Serve() creates the TLS listener later, that listener should
return nil, err // be adding a reference the ClientHello info to a map; this callback
// will be sure to clear out that entry when the connection closes.
s.Server.ConnState = func(c net.Conn, cs http.ConnState) {
// when a connection closes or is hijacked, delete its entry
// in the map, because we are done with it.
if tlsh.listener != nil {
if cs == http.StateHijacked || cs == http.StateClosed {
tlsh.listener.helloInfosMu.Lock()
delete(tlsh.listener.helloInfos, c.RemoteAddr().String())
tlsh.listener.helloInfosMu.Unlock()
}
}
} }
tlsConfigs[site.TLS.Hostname] = site.TLS // As of Go 1.7, if the Server's TLSConfig is not nil, HTTP/2 is enabled only
allConfigs = append(allConfigs, site.TLS) // if TLSConfig.NextProtos includes the string "h2"
} if HTTP2 && len(s.Server.TLSConfig.NextProtos) == 0 {
// some experimenting shows that this NextProtos must have at least
// Check if configs are valid // one value that overlaps with the NextProtos of any other tls.Config
if err := caddytls.CheckConfigs(allConfigs); err != nil { // that is returned from GetConfigForClient; if there is no overlap,
return nil, err // the connection will fail (as of Go 1.8, Feb. 2017).
} s.Server.TLSConfig.NextProtos = defaultALPN
s.tlsConfig = tlsConfigs
if caddytls.HasTLSEnabled(allConfigs) {
s.Server.TLSConfig = &tls.Config{
GetConfigForClient: s.tlsConfig.GetConfigForClient,
GetCertificate: s.tlsConfig.GetCertificate,
} }
} }
// As of Go 1.7, HTTP/2 is enabled only if NextProtos includes the string "h2"
if HTTP2 && s.Server.TLSConfig != nil && len(s.Server.TLSConfig.NextProtos) == 0 {
s.Server.TLSConfig.NextProtos = []string{"h2"}
}
if s.Server.TLSConfig != nil {
s.Server.Handler = tlsh
}
// Compile custom middleware for every site (enables virtual hosting) // Compile custom middleware for every site (enables virtual hosting)
for _, site := range group { for _, site := range group {
stack := Handler(staticfiles.FileServer{Root: http.Dir(site.Root), Hide: site.HiddenFiles}) stack := Handler(staticfiles.FileServer{Root: http.Dir(site.Root), Hide: site.HiddenFiles})
...@@ -122,6 +124,61 @@ func NewServer(addr string, group []*SiteConfig) (*Server, error) { ...@@ -122,6 +124,61 @@ func NewServer(addr string, group []*SiteConfig) (*Server, error) {
return s, nil return s, nil
} }
// makeHTTPServerWithTimeouts makes an http.Server from the group of
// configs in a way that configures timeouts (or, if not set, it uses
// the default timeouts) by combining the configuration of each
// SiteConfig in the group. (Timeouts are important for mitigating
// slowloris attacks.)
func makeHTTPServerWithTimeouts(addr string, group []*SiteConfig) *http.Server {
// find the minimum duration configured for each timeout
var min Timeouts
for _, cfg := range group {
if cfg.Timeouts.ReadTimeoutSet &&
(!min.ReadTimeoutSet || cfg.Timeouts.ReadTimeout < min.ReadTimeout) {
min.ReadTimeoutSet = true
min.ReadTimeout = cfg.Timeouts.ReadTimeout
}
if cfg.Timeouts.ReadHeaderTimeoutSet &&
(!min.ReadHeaderTimeoutSet || cfg.Timeouts.ReadHeaderTimeout < min.ReadHeaderTimeout) {
min.ReadHeaderTimeoutSet = true
min.ReadHeaderTimeout = cfg.Timeouts.ReadHeaderTimeout
}
if cfg.Timeouts.WriteTimeoutSet &&
(!min.WriteTimeoutSet || cfg.Timeouts.WriteTimeout < min.WriteTimeout) {
min.WriteTimeoutSet = true
min.WriteTimeout = cfg.Timeouts.WriteTimeout
}
if cfg.Timeouts.IdleTimeoutSet &&
(!min.IdleTimeoutSet || cfg.Timeouts.IdleTimeout < min.IdleTimeout) {
min.IdleTimeoutSet = true
min.IdleTimeout = cfg.Timeouts.IdleTimeout
}
}
// for the values that were not set, use defaults
if !min.ReadTimeoutSet {
min.ReadTimeout = defaultTimeouts.ReadTimeout
}
if !min.ReadHeaderTimeoutSet {
min.ReadHeaderTimeout = defaultTimeouts.ReadHeaderTimeout
}
if !min.WriteTimeoutSet {
min.WriteTimeout = defaultTimeouts.WriteTimeout
}
if !min.IdleTimeoutSet {
min.IdleTimeout = defaultTimeouts.IdleTimeout
}
// set the final values on the server and return it
return &http.Server{
Addr: addr,
ReadTimeout: min.ReadTimeout,
ReadHeaderTimeout: min.ReadHeaderTimeout,
WriteTimeout: min.WriteTimeout,
IdleTimeout: min.IdleTimeout,
}
}
func (s *Server) wrapWithSvcHeaders(previousHandler http.Handler) http.HandlerFunc { func (s *Server) wrapWithSvcHeaders(previousHandler http.Handler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
s.quicServer.SetQuicHeaders(w.Header()) s.quicServer.SetQuicHeaders(w.Header())
...@@ -390,62 +447,6 @@ var defaultTimeouts = Timeouts{ ...@@ -390,62 +447,6 @@ var defaultTimeouts = Timeouts{
IdleTimeout: 2 * time.Minute, IdleTimeout: 2 * time.Minute,
} }
// makeHTTPServer makes an http.Server from the group of configs
// in a way that configures timeouts (or, if not set, it uses the
// default timeouts) and other http.Server properties by combining
// the configuration of each SiteConfig in the group. (Timeouts
// are important for mitigating slowloris attacks.)
func makeHTTPServer(addr string, group []*SiteConfig) *http.Server {
s := &http.Server{Addr: addr}
// find the minimum duration configured for each timeout
var min Timeouts
for _, cfg := range group {
if cfg.Timeouts.ReadTimeoutSet &&
(!min.ReadTimeoutSet || cfg.Timeouts.ReadTimeout < min.ReadTimeout) {
min.ReadTimeoutSet = true
min.ReadTimeout = cfg.Timeouts.ReadTimeout
}
if cfg.Timeouts.ReadHeaderTimeoutSet &&
(!min.ReadHeaderTimeoutSet || cfg.Timeouts.ReadHeaderTimeout < min.ReadHeaderTimeout) {
min.ReadHeaderTimeoutSet = true
min.ReadHeaderTimeout = cfg.Timeouts.ReadHeaderTimeout
}
if cfg.Timeouts.WriteTimeoutSet &&
(!min.WriteTimeoutSet || cfg.Timeouts.WriteTimeout < min.WriteTimeout) {
min.WriteTimeoutSet = true
min.WriteTimeout = cfg.Timeouts.WriteTimeout
}
if cfg.Timeouts.IdleTimeoutSet &&
(!min.IdleTimeoutSet || cfg.Timeouts.IdleTimeout < min.IdleTimeout) {
min.IdleTimeoutSet = true
min.IdleTimeout = cfg.Timeouts.IdleTimeout
}
}
// for the values that were not set, use defaults
if !min.ReadTimeoutSet {
min.ReadTimeout = defaultTimeouts.ReadTimeout
}
if !min.ReadHeaderTimeoutSet {
min.ReadHeaderTimeout = defaultTimeouts.ReadHeaderTimeout
}
if !min.WriteTimeoutSet {
min.WriteTimeout = defaultTimeouts.WriteTimeout
}
if !min.IdleTimeoutSet {
min.IdleTimeout = defaultTimeouts.IdleTimeout
}
// set the final values on the server
s.ReadTimeout = min.ReadTimeout
s.ReadHeaderTimeout = min.ReadHeaderTimeout
s.WriteTimeout = min.WriteTimeout
s.IdleTimeout = min.IdleTimeout
return s
}
// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted // tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
// connections. It's used by ListenAndServe and ListenAndServeTLS so // connections. It's used by ListenAndServe and ListenAndServeTLS so
// dead TCP connections (e.g. closing laptop mid-download) eventually // dead TCP connections (e.g. closing laptop mid-download) eventually
......
...@@ -92,7 +92,7 @@ func TestMakeHTTPServer(t *testing.T) { ...@@ -92,7 +92,7 @@ func TestMakeHTTPServer(t *testing.T) {
}, },
}, },
} { } {
actual := makeHTTPServer("127.0.0.1:9005", tc.group) actual := makeHTTPServerWithTimeouts("127.0.0.1:9005", tc.group)
if got, want := actual.Addr, "127.0.0.1:9005"; got != want { if got, want := actual.Addr, "127.0.0.1:9005"; got != want {
t.Errorf("Test %d: Expected Addr=%s, but was %s", i, want, got) t.Errorf("Test %d: Expected Addr=%s, but was %s", i, want, got)
......
...@@ -89,8 +89,8 @@ func getCertificate(name string) (cert Certificate, matched, defaulted bool) { ...@@ -89,8 +89,8 @@ func getCertificate(name string) (cert Certificate, matched, defaulted bool) {
// cache, flagging it as Managed and, if onDemand is true, as "OnDemand" // cache, flagging it as Managed and, if onDemand is true, as "OnDemand"
// (meaning that it was obtained or loaded during a TLS handshake). // (meaning that it was obtained or loaded during a TLS handshake).
// //
// This function is safe for concurrent use. // This method is safe for concurrent use.
func CacheManagedCertificate(domain string, cfg *Config) (Certificate, error) { func (cfg *Config) CacheManagedCertificate(domain string) (Certificate, error) {
storage, err := cfg.StorageFor(cfg.CAUrl) storage, err := cfg.StorageFor(cfg.CAUrl)
if err != nil { if err != nil {
return Certificate{}, err return Certificate{}, err
......
...@@ -109,11 +109,11 @@ type Config struct { ...@@ -109,11 +109,11 @@ type Config struct {
// Add the must staple TLS extension to the CSR generated by lego/acme // Add the must staple TLS extension to the CSR generated by lego/acme
MustStaple bool MustStaple bool
// Disables HTTP2 completely // The list of protocols to choose from for Application Layer
DisableHTTP2 bool // Protocol Negotiation (ALPN).
ALPN []string
// Holds final tls.Config tlsConfig *tls.Config // the final tls.Config created with buildStandardTLSConfig()
tlsConfig *tls.Config
} }
// OnDemandState contains some state relevant for providing // OnDemandState contains some state relevant for providing
...@@ -223,33 +223,20 @@ func (c *Config) StorageFor(caURL string) (Storage, error) { ...@@ -223,33 +223,20 @@ func (c *Config) StorageFor(caURL string) (Storage, error) {
return s, nil return s, nil
} }
func (cfg *Config) Build(group ConfigGroup) error { // buildStandardTLSConfig converts cfg (*caddytls.Config) to a *tls.Config
config, err := cfg.build() // and stores it in cfg so it can be used in servers. If TLS is disabled,
// no tls.Config is created.
if err != nil { func (cfg *Config) buildStandardTLSConfig() error {
return err if !cfg.Enabled {
} return nil
if config != nil {
cfg.tlsConfig = config
cfg.tlsConfig.GetCertificate = group.GetCertificate
} }
return nil
}
func (cfg *Config) build() (*tls.Config, error) {
config := new(tls.Config) config := new(tls.Config)
if !cfg.Enabled {
return nil, nil
}
ciphersAdded := make(map[uint16]struct{}) ciphersAdded := make(map[uint16]struct{})
curvesAdded := make(map[tls.CurveID]struct{}) curvesAdded := make(map[tls.CurveID]struct{})
// Add cipher suites // add cipher suites
for _, ciph := range cfg.Ciphers { for _, ciph := range cfg.Ciphers {
if _, ok := ciphersAdded[ciph]; !ok { if _, ok := ciphersAdded[ciph]; !ok {
ciphersAdded[ciph] = struct{}{} ciphersAdded[ciph] = struct{}{}
...@@ -259,7 +246,7 @@ func (cfg *Config) build() (*tls.Config, error) { ...@@ -259,7 +246,7 @@ func (cfg *Config) build() (*tls.Config, error) {
config.PreferServerCipherSuites = cfg.PreferServerCipherSuites config.PreferServerCipherSuites = cfg.PreferServerCipherSuites
// Union curves // add curve preferences
for _, curv := range cfg.CurvePreferences { for _, curv := range cfg.CurvePreferences {
if _, ok := curvesAdded[curv]; !ok { if _, ok := curvesAdded[curv]; !ok {
curvesAdded[curv] = struct{}{} curvesAdded[curv] = struct{}{}
...@@ -270,8 +257,10 @@ func (cfg *Config) build() (*tls.Config, error) { ...@@ -270,8 +257,10 @@ func (cfg *Config) build() (*tls.Config, error) {
config.MinVersion = cfg.ProtocolMinVersion config.MinVersion = cfg.ProtocolMinVersion
config.MaxVersion = cfg.ProtocolMaxVersion config.MaxVersion = cfg.ProtocolMaxVersion
config.ClientAuth = cfg.ClientAuth config.ClientAuth = cfg.ClientAuth
config.NextProtos = cfg.ALPN
config.GetCertificate = cfg.GetCertificate
// Set up client authentication if enabled // set up client authentication if enabled
if config.ClientAuth != tls.NoClientCert { if config.ClientAuth != tls.NoClientCert {
pool := x509.NewCertPool() pool := x509.NewCertPool()
clientCertsAdded := make(map[string]struct{}) clientCertsAdded := make(map[string]struct{})
...@@ -286,45 +275,51 @@ func (cfg *Config) build() (*tls.Config, error) { ...@@ -286,45 +275,51 @@ func (cfg *Config) build() (*tls.Config, error) {
// Any client with a certificate from this CA will be allowed to connect // Any client with a certificate from this CA will be allowed to connect
caCrt, err := ioutil.ReadFile(caFile) caCrt, err := ioutil.ReadFile(caFile)
if err != nil { if err != nil {
return nil, err return err
} }
if !pool.AppendCertsFromPEM(caCrt) { if !pool.AppendCertsFromPEM(caCrt) {
return nil, fmt.Errorf("error loading client certificate '%s': no certificates were successfully parsed", caFile) return fmt.Errorf("error loading client certificate '%s': no certificates were successfully parsed", caFile)
} }
} }
config.ClientCAs = pool config.ClientCAs = pool
} }
// Default cipher suites // default cipher suites
if len(config.CipherSuites) == 0 { if len(config.CipherSuites) == 0 {
config.CipherSuites = defaultCiphers config.CipherSuites = defaultCiphers
} }
// For security, ensure TLS_FALLBACK_SCSV is always included first // for security, ensure TLS_FALLBACK_SCSV is always included first
if len(config.CipherSuites) == 0 || config.CipherSuites[0] != tls.TLS_FALLBACK_SCSV { if len(config.CipherSuites) == 0 || config.CipherSuites[0] != tls.TLS_FALLBACK_SCSV {
config.CipherSuites = append([]uint16{tls.TLS_FALLBACK_SCSV}, config.CipherSuites...) config.CipherSuites = append([]uint16{tls.TLS_FALLBACK_SCSV}, config.CipherSuites...)
} }
if cfg.DisableHTTP2 { // store the resulting new tls.Config
config.NextProtos = []string{} cfg.tlsConfig = config
} else {
config.NextProtos = []string{"h2"}
}
return config, nil return nil
} }
// CheckConfigs checks if multiple TLS configs does not collide with each other // MakeTLSConfig makes a tls.Config from configs. The returned
func CheckConfigs(configs []*Config) error { // tls.Config is programmed to load the matching caddytls.Config
// based on the hostname in SNI, but that's all.
func MakeTLSConfig(configs []*Config) (*tls.Config, error) {
if len(configs) == 0 { if len(configs) == 0 {
return nil return nil, nil
} }
configMap := make(configGroup)
for i, cfg := range configs { for i, cfg := range configs {
if cfg == nil {
// avoid nil pointer dereference below this loop
configs[i] = new(Config)
continue
}
// Can't serve TLS and not-TLS on same port // can't serve TLS and non-TLS on same port
if i > 0 && cfg.Enabled != configs[i-1].Enabled { if i > 0 && cfg.Enabled != configs[i-1].Enabled {
thisConfProto, lastConfProto := "not TLS", "not TLS" thisConfProto, lastConfProto := "not TLS", "not TLS"
if cfg.Enabled { if cfg.Enabled {
...@@ -333,26 +328,33 @@ func CheckConfigs(configs []*Config) error { ...@@ -333,26 +328,33 @@ func CheckConfigs(configs []*Config) error {
if configs[i-1].Enabled { if configs[i-1].Enabled {
lastConfProto = "TLS" lastConfProto = "TLS"
} }
return fmt.Errorf("cannot multiplex %s (%s) and %s (%s) on same listener", return nil, fmt.Errorf("cannot multiplex %s (%s) and %s (%s) on same listener",
configs[i-1].Hostname, lastConfProto, cfg.Hostname, thisConfProto) configs[i-1].Hostname, lastConfProto, cfg.Hostname, thisConfProto)
} }
if !cfg.Enabled { // convert each caddytls.Config into a tls.Config
continue if err := cfg.buildStandardTLSConfig(); err != nil {
return nil, err
} }
}
return nil // Key this config by its hostname (overwriting
} // configs with the same hostname pattern); during
// TLS handshakes, configs are loaded based on
// the hostname pattern, according to client's SNI.
configMap[cfg.Hostname] = cfg
}
func HasTLSEnabled(configs []*Config) bool { // Is TLS disabled? By now, we know that all
for _, config := range configs { // configs agree whether it is or not, so we
if config.Enabled { // can just look at the first one. If so,
return true // we're done here.
} if len(configs) == 0 || !configs[0].Enabled {
return nil, nil
} }
return false return &tls.Config{
GetConfigForClient: configMap.GetConfigForClient,
}, nil
} }
// ConfigGetter gets a Config keyed by key. // ConfigGetter gets a Config keyed by key.
......
...@@ -8,50 +8,50 @@ import ( ...@@ -8,50 +8,50 @@ import (
"testing" "testing"
) )
func TestMakeTLSConfigProtocolVersions(t *testing.T) { func TestConvertTLSConfigProtocolVersions(t *testing.T) {
// same min and max protocol versions // same min and max protocol versions
config := Config{ config := &Config{
Enabled: true, Enabled: true,
ProtocolMinVersion: tls.VersionTLS12, ProtocolMinVersion: tls.VersionTLS12,
ProtocolMaxVersion: tls.VersionTLS12, ProtocolMaxVersion: tls.VersionTLS12,
} }
result, err := config.build() err := config.buildStandardTLSConfig()
if err != nil { if err != nil {
t.Fatalf("Did not expect an error, but got %v", err) t.Fatalf("Did not expect an error, but got %v", err)
} }
if got, want := result.MinVersion, uint16(tls.VersionTLS12); got != want { if got, want := config.tlsConfig.MinVersion, uint16(tls.VersionTLS12); got != want {
t.Errorf("Expected min version to be %x, got %x", want, got) t.Errorf("Expected min version to be %x, got %x", want, got)
} }
if got, want := result.MaxVersion, uint16(tls.VersionTLS12); got != want { if got, want := config.tlsConfig.MaxVersion, uint16(tls.VersionTLS12); got != want {
t.Errorf("Expected max version to be %x, got %x", want, got) t.Errorf("Expected max version to be %x, got %x", want, got)
} }
} }
func TestMakeTLSConfigPreferServerCipherSuites(t *testing.T) { func TestConvertTLSConfigPreferServerCipherSuites(t *testing.T) {
// prefer server cipher suites // prefer server cipher suites
config := Config{Enabled: true, PreferServerCipherSuites: true} config := Config{Enabled: true, PreferServerCipherSuites: true}
result, err := config.build() err := config.buildStandardTLSConfig()
if err != nil { if err != nil {
t.Fatalf("Did not expect an error, but got %v", err) t.Fatalf("Did not expect an error, but got %v", err)
} }
if got, want := result.PreferServerCipherSuites, true; got != want { if got, want := config.tlsConfig.PreferServerCipherSuites, true; got != want {
t.Errorf("Expected PreferServerCipherSuites==%v but got %v", want, got) t.Errorf("Expected PreferServerCipherSuites==%v but got %v", want, got)
} }
} }
func TestMakeTLSConfigTLSEnabledDisabled(t *testing.T) { func TestMakeTLSConfigTLSEnabledDisabledError(t *testing.T) {
// verify handling when Enabled is true and false // verify handling when Enabled is true and false
configs := []*Config{ configs := []*Config{
{Enabled: true}, {Enabled: true},
{Enabled: false}, {Enabled: false},
} }
err := CheckConfigs(configs) _, err := MakeTLSConfig(configs)
if err == nil { if err == nil {
t.Fatalf("Expected an error, but got %v", err) t.Fatalf("Expected an error, but got %v", err)
} }
} }
func TestMakeTLSConfigCipherSuites(t *testing.T) { func TestConvertTLSConfigCipherSuites(t *testing.T) {
// ensure cipher suites are unioned and // ensure cipher suites are unioned and
// that TLS_FALLBACK_SCSV is prepended // that TLS_FALLBACK_SCSV is prepended
configs := []*Config{ configs := []*Config{
...@@ -67,10 +67,13 @@ func TestMakeTLSConfigCipherSuites(t *testing.T) { ...@@ -67,10 +67,13 @@ func TestMakeTLSConfigCipherSuites(t *testing.T) {
} }
for i, config := range configs { for i, config := range configs {
cfg, _ := config.build() err := config.buildStandardTLSConfig()
if err != nil {
if !reflect.DeepEqual(cfg.CipherSuites, expectedCiphers[i]) { t.Errorf("Test %d: Expected no error, got: %v", i, err)
t.Errorf("Expected ciphers %v but got %v", expectedCiphers[i], cfg.CipherSuites) }
if !reflect.DeepEqual(config.tlsConfig.CipherSuites, expectedCiphers[i]) {
t.Errorf("Test %d: Expected ciphers %v but got %v",
i, expectedCiphers[i], config.tlsConfig.CipherSuites)
} }
} }
......
...@@ -13,18 +13,19 @@ import ( ...@@ -13,18 +13,19 @@ import (
// configGroup is a type that keys configs by their hostname // configGroup is a type that keys configs by their hostname
// (hostnames can have wildcard characters; use the getConfig // (hostnames can have wildcard characters; use the getConfig
// method to get a config by matching its hostname). Its // method to get a config by matching its hostname).
// GetCertificate function can be used with tls.Config. type configGroup map[string]*Config
type ConfigGroup map[string]*Config
// getConfig gets the config by the first key match for name. // getConfig gets the config by the first key match for name.
// In other words, "sub.foo.bar" will get the config for "*.foo.bar" // In other words, "sub.foo.bar" will get the config for "*.foo.bar"
// if that is the closest match. This function MAY return nil // if that is the closest match. If no match is found, the first
// if no match is found. // (random) config will be loaded, which will defer any TLS alerts
// to the certificate validation (this may or may not be ideal;
// let's talk about it if this becomes problematic).
// //
// This function follows nearly the same logic to lookup // This function follows nearly the same logic to lookup
// a hostname as the getCertificate function uses. // a hostname as the getCertificate function uses.
func (cg ConfigGroup) getConfig(name string) *Config { func (cg configGroup) getConfig(name string) *Config {
name = strings.ToLower(name) name = strings.ToLower(name)
// exact match? great, let's use it // exact match? great, let's use it
...@@ -42,14 +43,36 @@ func (cg ConfigGroup) getConfig(name string) *Config { ...@@ -42,14 +43,36 @@ func (cg ConfigGroup) getConfig(name string) *Config {
} }
} }
// as last resort, try a config that serves all names // as a fallback, try a config that serves all names
if config, ok := cg[""]; ok { if config, ok := cg[""]; ok {
return config return config
} }
// as a last resort, use a random config
// (even if the config isn't for that hostname,
// it should help us serve clients without SNI
// or at least defer TLS alerts to the cert)
for _, config := range cg {
return config
}
return nil return nil
} }
// GetConfigForClient gets a TLS configuration satisfying clientHello.
// In getting the configuration, it abides the rules and settings
// defined in the Config that matches clientHello.ServerName. If no
// tls.Config is set on the matching Config, a nil value is returned.
//
// This method is safe for use as a tls.Config.GetConfigForClient callback.
func (cg configGroup) GetConfigForClient(clientHello *tls.ClientHelloInfo) (*tls.Config, error) {
config := cg.getConfig(clientHello.ServerName)
if config != nil {
return config.tlsConfig, nil
}
return nil, nil
}
// GetCertificate gets a certificate to satisfy clientHello. In getting // GetCertificate gets a certificate to satisfy clientHello. In getting
// the certificate, it abides the rules and settings defined in the // the certificate, it abides the rules and settings defined in the
// Config that matches clientHello.ServerName. It first checks the in- // Config that matches clientHello.ServerName. It first checks the in-
...@@ -58,27 +81,11 @@ func (cg ConfigGroup) getConfig(name string) *Config { ...@@ -58,27 +81,11 @@ func (cg ConfigGroup) getConfig(name string) *Config {
// via ACME. // via ACME.
// //
// This method is safe for use as a tls.Config.GetCertificate callback. // This method is safe for use as a tls.Config.GetCertificate callback.
func (cg ConfigGroup) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { func (cfg *Config) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
cert, err := cg.getCertDuringHandshake(strings.ToLower(clientHello.ServerName), true, true) cert, err := cfg.getCertDuringHandshake(strings.ToLower(clientHello.ServerName), true, true)
return &cert.Certificate, err return &cert.Certificate, err
} }
// GetConfigForClient gets a TLS configuration satisfying clientHello. In getting
// the configuration, it abides the rules and settings defined in the
// Config that matches clientHello.ServerName.
//
// This method is safe for use as a tls.Config.GetConfigForClient callback.
func (cg ConfigGroup) GetConfigForClient(clientHello *tls.ClientHelloInfo) (*tls.Config, error) {
config := cg.getConfig(clientHello.ServerName)
if config != nil {
return config.tlsConfig, nil
}
return nil, nil
}
// getCertDuringHandshake will get a certificate for name. It first tries // getCertDuringHandshake will get a certificate for name. It first tries
// the in-memory cache. If no certificate for name is in the cache, the // the in-memory cache. If no certificate for name is in the cache, the
// config most closely corresponding to name will be loaded. If that config // config most closely corresponding to name will be loaded. If that config
...@@ -90,21 +97,20 @@ func (cg ConfigGroup) GetConfigForClient(clientHello *tls.ClientHelloInfo) (*tls ...@@ -90,21 +97,20 @@ func (cg ConfigGroup) GetConfigForClient(clientHello *tls.ClientHelloInfo) (*tls
// certificate is available. // certificate is available.
// //
// This function is safe for concurrent use. // This function is safe for concurrent use.
func (cg ConfigGroup) getCertDuringHandshake(name string, loadIfNecessary, obtainIfNecessary bool) (Certificate, error) { func (cfg *Config) getCertDuringHandshake(name string, loadIfNecessary, obtainIfNecessary bool) (Certificate, error) {
// First check our in-memory cache to see if we've already loaded it // First check our in-memory cache to see if we've already loaded it
cert, matched, defaulted := getCertificate(name) cert, matched, defaulted := getCertificate(name)
if matched { if matched {
return cert, nil return cert, nil
} }
// Get the relevant TLS config for this name. If OnDemand is enabled, // If OnDemand is enabled, then we might be able to load or
// then we might be able to load or obtain a needed certificate. // obtain a needed certificate
cfg := cg.getConfig(name) if cfg.OnDemand && loadIfNecessary {
if cfg != nil && cfg.OnDemand && loadIfNecessary {
// Then check to see if we have one on disk // Then check to see if we have one on disk
loadedCert, err := CacheManagedCertificate(name, cfg) loadedCert, err := cfg.CacheManagedCertificate(name)
if err == nil { if err == nil {
loadedCert, err = cg.handshakeMaintenance(name, loadedCert) loadedCert, err = cfg.handshakeMaintenance(name, loadedCert)
if err != nil { if err != nil {
log.Printf("[ERROR] Maintaining newly-loaded certificate for %s: %v", name, err) log.Printf("[ERROR] Maintaining newly-loaded certificate for %s: %v", name, err)
} }
...@@ -116,7 +122,7 @@ func (cg ConfigGroup) getCertDuringHandshake(name string, loadIfNecessary, obtai ...@@ -116,7 +122,7 @@ func (cg ConfigGroup) getCertDuringHandshake(name string, loadIfNecessary, obtai
name = strings.ToLower(name) name = strings.ToLower(name)
// Make sure aren't over any applicable limits // Make sure aren't over any applicable limits
err := cg.checkLimitsForObtainingNewCerts(name, cfg) err := cfg.checkLimitsForObtainingNewCerts(name)
if err != nil { if err != nil {
return Certificate{}, err return Certificate{}, err
} }
...@@ -127,7 +133,7 @@ func (cg ConfigGroup) getCertDuringHandshake(name string, loadIfNecessary, obtai ...@@ -127,7 +133,7 @@ func (cg ConfigGroup) getCertDuringHandshake(name string, loadIfNecessary, obtai
} }
// Obtain certificate from the CA // Obtain certificate from the CA
return cg.obtainOnDemandCertificate(name, cfg) return cfg.obtainOnDemandCertificate(name)
} }
} }
...@@ -143,7 +149,7 @@ func (cg ConfigGroup) getCertDuringHandshake(name string, loadIfNecessary, obtai ...@@ -143,7 +149,7 @@ func (cg ConfigGroup) getCertDuringHandshake(name string, loadIfNecessary, obtai
// now according to mitigating factors we keep track of and preferences the // now according to mitigating factors we keep track of and preferences the
// user has set. If a non-nil error is returned, do not issue a new certificate // user has set. If a non-nil error is returned, do not issue a new certificate
// for name. // for name.
func (cg ConfigGroup) checkLimitsForObtainingNewCerts(name string, cfg *Config) error { func (cfg *Config) checkLimitsForObtainingNewCerts(name string) error {
// User can set hard limit for number of certs for the process to issue // User can set hard limit for number of certs for the process to issue
if cfg.OnDemandState.MaxObtain > 0 && if cfg.OnDemandState.MaxObtain > 0 &&
atomic.LoadInt32(&cfg.OnDemandState.ObtainedCount) >= cfg.OnDemandState.MaxObtain { atomic.LoadInt32(&cfg.OnDemandState.ObtainedCount) >= cfg.OnDemandState.MaxObtain {
...@@ -167,7 +173,7 @@ func (cg ConfigGroup) checkLimitsForObtainingNewCerts(name string, cfg *Config) ...@@ -167,7 +173,7 @@ func (cg ConfigGroup) checkLimitsForObtainingNewCerts(name string, cfg *Config)
return fmt.Errorf("%s: throttled; last certificate was obtained %v ago", name, since) return fmt.Errorf("%s: throttled; last certificate was obtained %v ago", name, since)
} }
// 👍Good to go // Good to go 👍
return nil return nil
} }
...@@ -176,7 +182,7 @@ func (cg ConfigGroup) checkLimitsForObtainingNewCerts(name string, cfg *Config) ...@@ -176,7 +182,7 @@ func (cg ConfigGroup) checkLimitsForObtainingNewCerts(name string, cfg *Config)
// name, it will wait and use what the other goroutine obtained. // name, it will wait and use what the other goroutine obtained.
// //
// This function is safe for use by multiple concurrent goroutines. // This function is safe for use by multiple concurrent goroutines.
func (cg ConfigGroup) obtainOnDemandCertificate(name string, cfg *Config) (Certificate, error) { func (cfg *Config) obtainOnDemandCertificate(name string) (Certificate, error) {
// We must protect this process from happening concurrently, so synchronize. // We must protect this process from happening concurrently, so synchronize.
obtainCertWaitChansMu.Lock() obtainCertWaitChansMu.Lock()
wait, ok := obtainCertWaitChans[name] wait, ok := obtainCertWaitChans[name]
...@@ -185,7 +191,7 @@ func (cg ConfigGroup) obtainOnDemandCertificate(name string, cfg *Config) (Certi ...@@ -185,7 +191,7 @@ func (cg ConfigGroup) obtainOnDemandCertificate(name string, cfg *Config) (Certi
// wait for it to finish obtaining the cert and then we'll use it. // wait for it to finish obtaining the cert and then we'll use it.
obtainCertWaitChansMu.Unlock() obtainCertWaitChansMu.Unlock()
<-wait <-wait
return cg.getCertDuringHandshake(name, true, false) return cfg.getCertDuringHandshake(name, true, false)
} }
// looks like it's up to us to do all the work and obtain the cert. // looks like it's up to us to do all the work and obtain the cert.
...@@ -228,19 +234,19 @@ func (cg ConfigGroup) obtainOnDemandCertificate(name string, cfg *Config) (Certi ...@@ -228,19 +234,19 @@ func (cg ConfigGroup) obtainOnDemandCertificate(name string, cfg *Config) (Certi
lastIssueTimeMu.Unlock() lastIssueTimeMu.Unlock()
// certificate is already on disk; now just start over to load it and serve it // certificate is already on disk; now just start over to load it and serve it
return cg.getCertDuringHandshake(name, true, false) return cfg.getCertDuringHandshake(name, true, false)
} }
// handshakeMaintenance performs a check on cert for expiration and OCSP // handshakeMaintenance performs a check on cert for expiration and OCSP
// validity. // validity.
// //
// This function is safe for use by multiple concurrent goroutines. // This function is safe for use by multiple concurrent goroutines.
func (cg ConfigGroup) handshakeMaintenance(name string, cert Certificate) (Certificate, error) { func (cfg *Config) handshakeMaintenance(name string, cert Certificate) (Certificate, error) {
// Check cert expiration // Check cert expiration
timeLeft := cert.NotAfter.Sub(time.Now().UTC()) timeLeft := cert.NotAfter.Sub(time.Now().UTC())
if timeLeft < RenewDurationBefore { if timeLeft < RenewDurationBefore {
log.Printf("[INFO] Certificate for %v expires in %v; attempting renewal", cert.Names, timeLeft) log.Printf("[INFO] Certificate for %v expires in %v; attempting renewal", cert.Names, timeLeft)
return cg.renewDynamicCertificate(name, cert.Config) return cfg.renewDynamicCertificate(name)
} }
// Check OCSP staple validity // Check OCSP staple validity
...@@ -268,7 +274,7 @@ func (cg ConfigGroup) handshakeMaintenance(name string, cert Certificate) (Certi ...@@ -268,7 +274,7 @@ func (cg ConfigGroup) handshakeMaintenance(name string, cert Certificate) (Certi
// usable. name should already be lower-cased before calling this function. // usable. name should already be lower-cased before calling this function.
// //
// This function is safe for use by multiple concurrent goroutines. // This function is safe for use by multiple concurrent goroutines.
func (cg ConfigGroup) renewDynamicCertificate(name string, cfg *Config) (Certificate, error) { func (cfg *Config) renewDynamicCertificate(name string) (Certificate, error) {
obtainCertWaitChansMu.Lock() obtainCertWaitChansMu.Lock()
wait, ok := obtainCertWaitChans[name] wait, ok := obtainCertWaitChans[name]
if ok { if ok {
...@@ -276,7 +282,7 @@ func (cg ConfigGroup) renewDynamicCertificate(name string, cfg *Config) (Certifi ...@@ -276,7 +282,7 @@ func (cg ConfigGroup) renewDynamicCertificate(name string, cfg *Config) (Certifi
// wait for it to finish, then we'll use the new one. // wait for it to finish, then we'll use the new one.
obtainCertWaitChansMu.Unlock() obtainCertWaitChansMu.Unlock()
<-wait <-wait
return cg.getCertDuringHandshake(name, true, false) return cfg.getCertDuringHandshake(name, true, false)
} }
// looks like it's up to us to do all the work and renew the cert // looks like it's up to us to do all the work and renew the cert
...@@ -300,7 +306,7 @@ func (cg ConfigGroup) renewDynamicCertificate(name string, cfg *Config) (Certifi ...@@ -300,7 +306,7 @@ func (cg ConfigGroup) renewDynamicCertificate(name string, cfg *Config) (Certifi
return Certificate{}, err return Certificate{}, err
} }
return cg.getCertDuringHandshake(name, true, false) return cfg.getCertDuringHandshake(name, true, false)
} }
// obtainCertWaitChans is used to coordinate obtaining certs for each hostname. // obtainCertWaitChans is used to coordinate obtaining certs for each hostname.
......
...@@ -9,7 +9,7 @@ import ( ...@@ -9,7 +9,7 @@ import (
func TestGetCertificate(t *testing.T) { func TestGetCertificate(t *testing.T) {
defer func() { certCache = make(map[string]Certificate) }() defer func() { certCache = make(map[string]Certificate) }()
cg := make(ConfigGroup) cfg := new(Config)
hello := &tls.ClientHelloInfo{ServerName: "example.com"} hello := &tls.ClientHelloInfo{ServerName: "example.com"}
helloSub := &tls.ClientHelloInfo{ServerName: "sub.example.com"} helloSub := &tls.ClientHelloInfo{ServerName: "sub.example.com"}
...@@ -17,10 +17,10 @@ func TestGetCertificate(t *testing.T) { ...@@ -17,10 +17,10 @@ func TestGetCertificate(t *testing.T) {
helloNoMatch := &tls.ClientHelloInfo{ServerName: "nomatch"} helloNoMatch := &tls.ClientHelloInfo{ServerName: "nomatch"}
// When cache is empty // When cache is empty
if cert, err := cg.GetCertificate(hello); err == nil { if cert, err := cfg.GetCertificate(hello); err == nil {
t.Errorf("GetCertificate should return error when cache is empty, got: %v", cert) t.Errorf("GetCertificate should return error when cache is empty, got: %v", cert)
} }
if cert, err := cg.GetCertificate(helloNoSNI); err == nil { if cert, err := cfg.GetCertificate(helloNoSNI); err == nil {
t.Errorf("GetCertificate should return error when cache is empty even if server name is blank, got: %v", cert) t.Errorf("GetCertificate should return error when cache is empty even if server name is blank, got: %v", cert)
} }
...@@ -28,12 +28,12 @@ func TestGetCertificate(t *testing.T) { ...@@ -28,12 +28,12 @@ func TestGetCertificate(t *testing.T) {
defaultCert := Certificate{Names: []string{"example.com", ""}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"example.com"}}}} defaultCert := Certificate{Names: []string{"example.com", ""}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"example.com"}}}}
certCache[""] = defaultCert certCache[""] = defaultCert
certCache["example.com"] = defaultCert certCache["example.com"] = defaultCert
if cert, err := cg.GetCertificate(hello); err != nil { if cert, err := cfg.GetCertificate(hello); err != nil {
t.Errorf("Got an error but shouldn't have, when cert exists in cache: %v", err) t.Errorf("Got an error but shouldn't have, when cert exists in cache: %v", err)
} else if cert.Leaf.DNSNames[0] != "example.com" { } else if cert.Leaf.DNSNames[0] != "example.com" {
t.Errorf("Got wrong certificate with exact match; expected 'example.com', got: %v", cert) t.Errorf("Got wrong certificate with exact match; expected 'example.com', got: %v", cert)
} }
if cert, err := cg.GetCertificate(helloNoSNI); err != nil { if cert, err := cfg.GetCertificate(helloNoSNI); err != nil {
t.Errorf("Got an error with no SNI but shouldn't have, when cert exists in cache: %v", err) t.Errorf("Got an error with no SNI but shouldn't have, when cert exists in cache: %v", err)
} else if cert.Leaf.DNSNames[0] != "example.com" { } else if cert.Leaf.DNSNames[0] != "example.com" {
t.Errorf("Got wrong certificate for no SNI; expected 'example.com' as default, got: %v", cert) t.Errorf("Got wrong certificate for no SNI; expected 'example.com' as default, got: %v", cert)
...@@ -41,14 +41,14 @@ func TestGetCertificate(t *testing.T) { ...@@ -41,14 +41,14 @@ func TestGetCertificate(t *testing.T) {
// When retrieving wildcard certificate // When retrieving wildcard certificate
certCache["*.example.com"] = Certificate{Names: []string{"*.example.com"}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"*.example.com"}}}} certCache["*.example.com"] = Certificate{Names: []string{"*.example.com"}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"*.example.com"}}}}
if cert, err := cg.GetCertificate(helloSub); err != nil { if cert, err := cfg.GetCertificate(helloSub); err != nil {
t.Errorf("Didn't get wildcard cert, got: cert=%v, err=%v ", cert, err) t.Errorf("Didn't get wildcard cert, got: cert=%v, err=%v ", cert, err)
} else if cert.Leaf.DNSNames[0] != "*.example.com" { } else if cert.Leaf.DNSNames[0] != "*.example.com" {
t.Errorf("Got wrong certificate, expected wildcard: %v", cert) t.Errorf("Got wrong certificate, expected wildcard: %v", cert)
} }
// When no certificate matches, the default is returned // When no certificate matches, the default is returned
if cert, err := cg.GetCertificate(helloNoMatch); err != nil { if cert, err := cfg.GetCertificate(helloNoMatch); err != nil {
t.Errorf("Expected default certificate with no error when no matches, got err: %v", err) t.Errorf("Expected default certificate with no error when no matches, got err: %v", err)
} else if cert.Leaf.DNSNames[0] != "example.com" { } else if cert.Leaf.DNSNames[0] != "example.com" {
t.Errorf("Expected default cert with no matches, got: %v", cert) t.Errorf("Expected default cert with no matches, got: %v", cert)
......
...@@ -152,7 +152,7 @@ func RenewManagedCertificates(allowPrompts bool) (err error) { ...@@ -152,7 +152,7 @@ func RenewManagedCertificates(allowPrompts bool) (err error) {
delete(certCache, "") delete(certCache, "")
certCacheMu.Unlock() certCacheMu.Unlock()
} }
_, err := CacheManagedCertificate(cert.Names[0], cert.Config) _, err := cert.Config.CacheManagedCertificate(cert.Names[0])
if err != nil { if err != nil {
if allowPrompts { if allowPrompts {
return err // operator is present, so report error immediately return err // operator is present, so report error immediately
......
...@@ -164,21 +164,15 @@ func setupTLS(c *caddy.Controller) error { ...@@ -164,21 +164,15 @@ func setupTLS(c *caddy.Controller) error {
return c.Errf("Unsupported Storage provider '%s'", args[0]) return c.Errf("Unsupported Storage provider '%s'", args[0])
} }
config.StorageProvider = args[0] config.StorageProvider = args[0]
case "alpn":
case "http2":
args := c.RemainingArgs() args := c.RemainingArgs()
if len(args) != 1 { if len(args) == 0 {
return c.ArgErr() return c.ArgErr()
} }
for _, arg := range args {
switch args[0] { config.ALPN = append(config.ALPN, arg)
case "off":
config.DisableHTTP2 = true
default:
c.ArgErr()
} }
case "must_staple":
case "muststaple":
config.MustStaple = true config.MustStaple = true
default: default:
return c.Errf("Unknown keyword '%s'", c.Val()) return c.Errf("Unknown keyword '%s'", c.Val())
......
...@@ -91,8 +91,8 @@ func TestSetupParseBasic(t *testing.T) { ...@@ -91,8 +91,8 @@ func TestSetupParseBasic(t *testing.T) {
t.Error("Expected PreferServerCipherSuites = true, but was false") t.Error("Expected PreferServerCipherSuites = true, but was false")
} }
if cfg.DisableHTTP2 { if len(cfg.ALPN) != 0 {
t.Error("Expected HTTP2 to be enabled by default") t.Error("Expected ALPN empty by default")
} }
// Ensure curve count is correct // Ensure curve count is correct
...@@ -121,8 +121,8 @@ func TestSetupParseWithOptionalParams(t *testing.T) { ...@@ -121,8 +121,8 @@ func TestSetupParseWithOptionalParams(t *testing.T) {
params := `tls ` + certFile + ` ` + keyFile + ` { params := `tls ` + certFile + ` ` + keyFile + ` {
protocols tls1.0 tls1.2 protocols tls1.0 tls1.2
ciphers RSA-AES256-CBC-SHA ECDHE-RSA-AES128-GCM-SHA256 ECDHE-ECDSA-AES256-GCM-SHA384 ciphers RSA-AES256-CBC-SHA ECDHE-RSA-AES128-GCM-SHA256 ECDHE-ECDSA-AES256-GCM-SHA384
muststaple must_staple
http2 off alpn http/1.1
}` }`
cfg := new(Config) cfg := new(Config)
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg }) RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
...@@ -149,8 +149,8 @@ func TestSetupParseWithOptionalParams(t *testing.T) { ...@@ -149,8 +149,8 @@ func TestSetupParseWithOptionalParams(t *testing.T) {
t.Error("Expected must staple to be true") t.Error("Expected must staple to be true")
} }
if !cfg.DisableHTTP2 { if len(cfg.ALPN) != 1 || cfg.ALPN[0] != "http/1.1" {
t.Error("Expected HTTP2 to be disabled") t.Errorf("Expected ALPN to contain only 'http/1.1' but got: %v", cfg.ALPN)
} }
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment