Commit 73794f2a authored by Matt Holt's avatar Matt Holt Committed by GitHub

tls: Refactor internals related to TLS configurations (#1466)

* tls: Refactor TLS config innards with a few minor syntax changes

muststaple -> must_staple
"http2 off" -> "alpn" with list of ALPN values

* Fix typo

* Fix QUIC handler

* Inline struct field assignments
parent 4b877eeb
......@@ -79,7 +79,7 @@ func enableAutoHTTPS(configs []*SiteConfig, loadCertificates bool) error {
cfg.TLS.Enabled = true
cfg.Addr.Scheme = "https"
if loadCertificates && caddytls.HostQualifies(cfg.Addr.Host) {
_, err := caddytls.CacheManagedCertificate(cfg.Addr.Host, cfg.TLS)
_, err := cfg.TLS.CacheManagedCertificate(cfg.Addr.Host)
if err != nil {
return err
}
......
......@@ -35,6 +35,11 @@ type tlsHandler struct {
// Halderman, et. al. in "The Security Impact of HTTPS Interception" (NDSS '17):
// https://jhalderm.com/pub/papers/interception-ndss17.pdf
func (h *tlsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if h.listener == nil {
h.next.ServeHTTP(w, r)
return
}
h.listener.helloInfosMu.RLock()
info := h.listener.helloInfos[r.RemoteAddr]
h.listener.helloInfosMu.RUnlock()
......@@ -78,63 +83,62 @@ func (h *tlsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.next.ServeHTTP(w, r)
}
// clientHelloConn reads the ClientHello
// and stores it in the attached listener.
type clientHelloConn struct {
net.Conn
readHello bool
listener *tlsHelloListener
readHello bool // whether ClientHello has been read
buf *bytes.Buffer
}
// Read reads from c.Conn (by letting the standard library
// do the reading off the wire), with the exception of
// getting a copy of the ClientHello so it can parse it.
func (c *clientHelloConn) Read(b []byte) (n int, err error) {
if !c.readHello {
// Read the header bytes.
hdr := make([]byte, 5)
n, err := io.ReadFull(c.Conn, hdr)
if err != nil {
return n, err
}
// Get the length of the ClientHello message and read it as well.
length := uint16(hdr[3])<<8 | uint16(hdr[4])
hello := make([]byte, int(length))
n, err = io.ReadFull(c.Conn, hello)
if err != nil {
return n, err
}
// Parse the ClientHello and store it in the map.
rawParsed := parseRawClientHello(hello)
c.listener.helloInfosMu.Lock()
c.listener.helloInfos[c.Conn.RemoteAddr().String()] = rawParsed
c.listener.helloInfosMu.Unlock()
// if we've already read the ClientHello, pass thru
if c.readHello {
return c.Conn.Read(b)
}
// Since we buffered the header and ClientHello, pretend we were
// never here by lining up the buffered values to be read with a
// custom connection type, followed by the rest of the actual
// underlying connection.
mr := io.MultiReader(bytes.NewReader(hdr), bytes.NewReader(hello), c.Conn)
mc := multiConn{Conn: c.Conn, reader: mr}
// we let the standard lib read off the wire for us, and
// tee that into our buffer so we can read the ClientHello
tee := io.TeeReader(c.Conn, c.buf)
n, err = tee.Read(b)
if err != nil {
return
}
if c.buf.Len() < 5 {
return // need to read more bytes for header
}
c.Conn = mc
// read the header bytes
hdr := make([]byte, 5)
_, err = io.ReadFull(c.buf, hdr)
if err != nil {
return // this would be highly unusual and sad
}
c.readHello = true
// get length of the ClientHello message and read it
length := int(uint16(hdr[3])<<8 | uint16(hdr[4]))
if c.buf.Len() < length {
return // need to read more bytes
}
return c.Conn.Read(b)
}
hello := make([]byte, length)
_, err = io.ReadFull(c.buf, hello)
if err != nil {
return
}
c.buf = nil // buffer no longer needed
// multiConn is a net.Conn that reads from the
// given reader instead of the wire directly. This
// is useful when some of the connection has already
// been read (like the TLS Client Hello) and the
// reader is a io.MultiReader that starts with
// the contents of the buffer.
type multiConn struct {
net.Conn
reader io.Reader
}
// parse the ClientHello and store it in the map
rawParsed := parseRawClientHello(hello)
c.listener.helloInfosMu.Lock()
c.listener.helloInfos[c.Conn.RemoteAddr().String()] = rawParsed
c.listener.helloInfosMu.Unlock()
// Read reads from mc.reader.
func (mc multiConn) Read(b []byte) (n int, err error) {
return mc.reader.Read(b)
c.readHello = true
return
}
// parseRawClientHello parses data which contains the raw
......@@ -279,7 +283,7 @@ func (l *tlsHelloListener) Accept() (net.Conn, error) {
if err != nil {
return nil, err
}
helloConn := &clientHelloConn{Conn: conn, listener: l}
helloConn := &clientHelloConn{Conn: conn, listener: l, buf: new(bytes.Buffer)}
return tls.Server(helloConn, l.config), nil
}
......
......@@ -84,7 +84,7 @@ func TestHeuristicFunctions(t *testing.T) {
// clientHello pairs a User-Agent string to its ClientHello message.
type clientHello struct {
userAgent string
helloHex string
helloHex string // do NOT include the header, just the ClientHello message
}
// clientHellos groups samples of true (real) ClientHellos by the
......@@ -158,7 +158,12 @@ func TestHeuristicFunctions(t *testing.T) {
},
{
// IE 11 on Windows 7, this connection was intercepted by Blue Coat
helloHex: "010000b1030358a3f3bae627f464da8cb35976b88e9119640032d41e62a107d608ed8d3e62b9000034c028c027c014c013009f009e009d009cc02cc02bc024c023c00ac009003d003c0035002f006a004000380032000a0013000500040100005400000014001200000f66696e6572706978656c732e636f6d000500050100000000000a00080006001700180019000b00020100000d0014001206010603040105010201040305030203020200170000ff01000100",
helloHex: `010000b1030358a3f3bae627f464da8cb35976b88e9119640032d41e62a107d608ed8d3e62b9000034c028c027c014c013009f009e009d009cc02cc02bc024c023c00ac009003d003c0035002f006a004000380032000a0013000500040100005400000014001200000f66696e6572706978656c732e636f6d000500050100000000000a00080006001700180019000b00020100000d0014001206010603040105010201040305030203020200170000ff01000100`,
},
{
// Firefox 51.0.1 being intercepted by burp 1.7.17
userAgent: "(TODO)",
helloHex: `010000d8030358a92f4daca95acc2f6a10a9c50d736135eae39406d3090238464540d482677600003ac023c027003cc025c02900670040c009c013002fc004c00e00330032c02bc02f009cc02dc031009e00a2c008c012000ac003c00d0016001300ff01000075000a0034003200170001000300130015000600070009000a0018000b000c0019000d000e000f001000110002001200040005001400080016000b00020100000d00180016060306010503050104030401040202030201020201010000001700150000126a61677561722e6b796877616e612e6f7267`,
},
},
}
......
......@@ -31,40 +31,47 @@ type Server struct {
connTimeout time.Duration // max time to wait for a connection before force stop
tlsGovChan chan struct{} // close to stop the TLS maintenance goroutine
vhosts *vhostTrie
tlsConfig caddytls.ConfigGroup
}
// ensure it satisfies the interface
var _ caddy.GracefulServer = new(Server)
var defaultALPN = []string{"h2", "http/1.1"}
// makeTLSConfig extracts TLS settings from each site config to
// build a tls.Config usable in Caddy HTTP servers. The returned
// config will be nil if TLS is disabled for these sites.
func makeTLSConfig(group []*SiteConfig) (*tls.Config, error) {
var tlsConfigs []*caddytls.Config
for i := range group {
if HTTP2 && len(group[i].TLS.ALPN) == 0 {
// if no application-level protocol was configured up to now,
// default to HTTP/2, then HTTP/1.1 if necessary
group[i].TLS.ALPN = defaultALPN
}
tlsConfigs = append(tlsConfigs, group[i].TLS)
}
return caddytls.MakeTLSConfig(tlsConfigs)
}
// NewServer creates a new Server instance that will listen on addr
// and will serve the sites configured in group.
func NewServer(addr string, group []*SiteConfig) (*Server, error) {
s := &Server{
Server: makeHTTPServer(addr, group),
Server: makeHTTPServerWithTimeouts(addr, group),
vhosts: newVHostTrie(),
sites: group,
connTimeout: GracefulTimeout,
}
s.Server.Handler = s // this is weird, but whatever
tlsh := &tlsHandler{next: s.Server.Handler}
s.Server.ConnState = func(c net.Conn, cs http.ConnState) {
// when a connection closes or is hijacked, delete its entry
// in the map, because we are done with it.
if tlsh.listener != nil {
if cs == http.StateHijacked || cs == http.StateClosed {
tlsh.listener.helloInfosMu.Lock()
delete(tlsh.listener.helloInfos, c.RemoteAddr().String())
tlsh.listener.helloInfosMu.Unlock()
}
}
}
// Disable HTTP/2 if desired
if !HTTP2 {
s.Server.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler))
// extract TLS settings from each site config to build
// a tls.Config, which will not be nil if TLS is enabled
tlsConfig, err := makeTLSConfig(group)
if err != nil {
return nil, err
}
s.Server.TLSConfig = tlsConfig
// Enable QUIC if desired
if QUIC {
......@@ -72,43 +79,38 @@ func NewServer(addr string, group []*SiteConfig) (*Server, error) {
s.Server.Handler = s.wrapWithSvcHeaders(s.Server.Handler)
}
// Set up TLS configuration
tlsConfigs := make(caddytls.ConfigGroup)
var allConfigs []*caddytls.Config
for _, site := range group {
if err := site.TLS.Build(tlsConfigs); err != nil {
return nil, err
// if TLS is enabled, make sure we prepare the Server accordingly
if s.Server.TLSConfig != nil {
// wrap the HTTP handler with a handler that does MITM detection
tlsh := &tlsHandler{next: s.Server.Handler}
s.Server.Handler = tlsh // this needs to be the "outer" handler when Serve() is called, for type assertion
// when Serve() creates the TLS listener later, that listener should
// be adding a reference the ClientHello info to a map; this callback
// will be sure to clear out that entry when the connection closes.
s.Server.ConnState = func(c net.Conn, cs http.ConnState) {
// when a connection closes or is hijacked, delete its entry
// in the map, because we are done with it.
if tlsh.listener != nil {
if cs == http.StateHijacked || cs == http.StateClosed {
tlsh.listener.helloInfosMu.Lock()
delete(tlsh.listener.helloInfos, c.RemoteAddr().String())
tlsh.listener.helloInfosMu.Unlock()
}
}
}
tlsConfigs[site.TLS.Hostname] = site.TLS
allConfigs = append(allConfigs, site.TLS)
}
// Check if configs are valid
if err := caddytls.CheckConfigs(allConfigs); err != nil {
return nil, err
}
s.tlsConfig = tlsConfigs
if caddytls.HasTLSEnabled(allConfigs) {
s.Server.TLSConfig = &tls.Config{
GetConfigForClient: s.tlsConfig.GetConfigForClient,
GetCertificate: s.tlsConfig.GetCertificate,
// As of Go 1.7, if the Server's TLSConfig is not nil, HTTP/2 is enabled only
// if TLSConfig.NextProtos includes the string "h2"
if HTTP2 && len(s.Server.TLSConfig.NextProtos) == 0 {
// some experimenting shows that this NextProtos must have at least
// one value that overlaps with the NextProtos of any other tls.Config
// that is returned from GetConfigForClient; if there is no overlap,
// the connection will fail (as of Go 1.8, Feb. 2017).
s.Server.TLSConfig.NextProtos = defaultALPN
}
}
// As of Go 1.7, HTTP/2 is enabled only if NextProtos includes the string "h2"
if HTTP2 && s.Server.TLSConfig != nil && len(s.Server.TLSConfig.NextProtos) == 0 {
s.Server.TLSConfig.NextProtos = []string{"h2"}
}
if s.Server.TLSConfig != nil {
s.Server.Handler = tlsh
}
// Compile custom middleware for every site (enables virtual hosting)
for _, site := range group {
stack := Handler(staticfiles.FileServer{Root: http.Dir(site.Root), Hide: site.HiddenFiles})
......@@ -122,6 +124,61 @@ func NewServer(addr string, group []*SiteConfig) (*Server, error) {
return s, nil
}
// makeHTTPServerWithTimeouts makes an http.Server from the group of
// configs in a way that configures timeouts (or, if not set, it uses
// the default timeouts) by combining the configuration of each
// SiteConfig in the group. (Timeouts are important for mitigating
// slowloris attacks.)
func makeHTTPServerWithTimeouts(addr string, group []*SiteConfig) *http.Server {
// find the minimum duration configured for each timeout
var min Timeouts
for _, cfg := range group {
if cfg.Timeouts.ReadTimeoutSet &&
(!min.ReadTimeoutSet || cfg.Timeouts.ReadTimeout < min.ReadTimeout) {
min.ReadTimeoutSet = true
min.ReadTimeout = cfg.Timeouts.ReadTimeout
}
if cfg.Timeouts.ReadHeaderTimeoutSet &&
(!min.ReadHeaderTimeoutSet || cfg.Timeouts.ReadHeaderTimeout < min.ReadHeaderTimeout) {
min.ReadHeaderTimeoutSet = true
min.ReadHeaderTimeout = cfg.Timeouts.ReadHeaderTimeout
}
if cfg.Timeouts.WriteTimeoutSet &&
(!min.WriteTimeoutSet || cfg.Timeouts.WriteTimeout < min.WriteTimeout) {
min.WriteTimeoutSet = true
min.WriteTimeout = cfg.Timeouts.WriteTimeout
}
if cfg.Timeouts.IdleTimeoutSet &&
(!min.IdleTimeoutSet || cfg.Timeouts.IdleTimeout < min.IdleTimeout) {
min.IdleTimeoutSet = true
min.IdleTimeout = cfg.Timeouts.IdleTimeout
}
}
// for the values that were not set, use defaults
if !min.ReadTimeoutSet {
min.ReadTimeout = defaultTimeouts.ReadTimeout
}
if !min.ReadHeaderTimeoutSet {
min.ReadHeaderTimeout = defaultTimeouts.ReadHeaderTimeout
}
if !min.WriteTimeoutSet {
min.WriteTimeout = defaultTimeouts.WriteTimeout
}
if !min.IdleTimeoutSet {
min.IdleTimeout = defaultTimeouts.IdleTimeout
}
// set the final values on the server and return it
return &http.Server{
Addr: addr,
ReadTimeout: min.ReadTimeout,
ReadHeaderTimeout: min.ReadHeaderTimeout,
WriteTimeout: min.WriteTimeout,
IdleTimeout: min.IdleTimeout,
}
}
func (s *Server) wrapWithSvcHeaders(previousHandler http.Handler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
s.quicServer.SetQuicHeaders(w.Header())
......@@ -390,62 +447,6 @@ var defaultTimeouts = Timeouts{
IdleTimeout: 2 * time.Minute,
}
// makeHTTPServer makes an http.Server from the group of configs
// in a way that configures timeouts (or, if not set, it uses the
// default timeouts) and other http.Server properties by combining
// the configuration of each SiteConfig in the group. (Timeouts
// are important for mitigating slowloris attacks.)
func makeHTTPServer(addr string, group []*SiteConfig) *http.Server {
s := &http.Server{Addr: addr}
// find the minimum duration configured for each timeout
var min Timeouts
for _, cfg := range group {
if cfg.Timeouts.ReadTimeoutSet &&
(!min.ReadTimeoutSet || cfg.Timeouts.ReadTimeout < min.ReadTimeout) {
min.ReadTimeoutSet = true
min.ReadTimeout = cfg.Timeouts.ReadTimeout
}
if cfg.Timeouts.ReadHeaderTimeoutSet &&
(!min.ReadHeaderTimeoutSet || cfg.Timeouts.ReadHeaderTimeout < min.ReadHeaderTimeout) {
min.ReadHeaderTimeoutSet = true
min.ReadHeaderTimeout = cfg.Timeouts.ReadHeaderTimeout
}
if cfg.Timeouts.WriteTimeoutSet &&
(!min.WriteTimeoutSet || cfg.Timeouts.WriteTimeout < min.WriteTimeout) {
min.WriteTimeoutSet = true
min.WriteTimeout = cfg.Timeouts.WriteTimeout
}
if cfg.Timeouts.IdleTimeoutSet &&
(!min.IdleTimeoutSet || cfg.Timeouts.IdleTimeout < min.IdleTimeout) {
min.IdleTimeoutSet = true
min.IdleTimeout = cfg.Timeouts.IdleTimeout
}
}
// for the values that were not set, use defaults
if !min.ReadTimeoutSet {
min.ReadTimeout = defaultTimeouts.ReadTimeout
}
if !min.ReadHeaderTimeoutSet {
min.ReadHeaderTimeout = defaultTimeouts.ReadHeaderTimeout
}
if !min.WriteTimeoutSet {
min.WriteTimeout = defaultTimeouts.WriteTimeout
}
if !min.IdleTimeoutSet {
min.IdleTimeout = defaultTimeouts.IdleTimeout
}
// set the final values on the server
s.ReadTimeout = min.ReadTimeout
s.ReadHeaderTimeout = min.ReadHeaderTimeout
s.WriteTimeout = min.WriteTimeout
s.IdleTimeout = min.IdleTimeout
return s
}
// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
// connections. It's used by ListenAndServe and ListenAndServeTLS so
// dead TCP connections (e.g. closing laptop mid-download) eventually
......
......@@ -92,7 +92,7 @@ func TestMakeHTTPServer(t *testing.T) {
},
},
} {
actual := makeHTTPServer("127.0.0.1:9005", tc.group)
actual := makeHTTPServerWithTimeouts("127.0.0.1:9005", tc.group)
if got, want := actual.Addr, "127.0.0.1:9005"; got != want {
t.Errorf("Test %d: Expected Addr=%s, but was %s", i, want, got)
......
......@@ -89,8 +89,8 @@ func getCertificate(name string) (cert Certificate, matched, defaulted bool) {
// cache, flagging it as Managed and, if onDemand is true, as "OnDemand"
// (meaning that it was obtained or loaded during a TLS handshake).
//
// This function is safe for concurrent use.
func CacheManagedCertificate(domain string, cfg *Config) (Certificate, error) {
// This method is safe for concurrent use.
func (cfg *Config) CacheManagedCertificate(domain string) (Certificate, error) {
storage, err := cfg.StorageFor(cfg.CAUrl)
if err != nil {
return Certificate{}, err
......
......@@ -109,11 +109,11 @@ type Config struct {
// Add the must staple TLS extension to the CSR generated by lego/acme
MustStaple bool
// Disables HTTP2 completely
DisableHTTP2 bool
// The list of protocols to choose from for Application Layer
// Protocol Negotiation (ALPN).
ALPN []string
// Holds final tls.Config
tlsConfig *tls.Config
tlsConfig *tls.Config // the final tls.Config created with buildStandardTLSConfig()
}
// OnDemandState contains some state relevant for providing
......@@ -223,33 +223,20 @@ func (c *Config) StorageFor(caURL string) (Storage, error) {
return s, nil
}
func (cfg *Config) Build(group ConfigGroup) error {
config, err := cfg.build()
if err != nil {
return err
}
if config != nil {
cfg.tlsConfig = config
cfg.tlsConfig.GetCertificate = group.GetCertificate
// buildStandardTLSConfig converts cfg (*caddytls.Config) to a *tls.Config
// and stores it in cfg so it can be used in servers. If TLS is disabled,
// no tls.Config is created.
func (cfg *Config) buildStandardTLSConfig() error {
if !cfg.Enabled {
return nil
}
return nil
}
func (cfg *Config) build() (*tls.Config, error) {
config := new(tls.Config)
if !cfg.Enabled {
return nil, nil
}
ciphersAdded := make(map[uint16]struct{})
curvesAdded := make(map[tls.CurveID]struct{})
// Add cipher suites
// add cipher suites
for _, ciph := range cfg.Ciphers {
if _, ok := ciphersAdded[ciph]; !ok {
ciphersAdded[ciph] = struct{}{}
......@@ -259,7 +246,7 @@ func (cfg *Config) build() (*tls.Config, error) {
config.PreferServerCipherSuites = cfg.PreferServerCipherSuites
// Union curves
// add curve preferences
for _, curv := range cfg.CurvePreferences {
if _, ok := curvesAdded[curv]; !ok {
curvesAdded[curv] = struct{}{}
......@@ -270,8 +257,10 @@ func (cfg *Config) build() (*tls.Config, error) {
config.MinVersion = cfg.ProtocolMinVersion
config.MaxVersion = cfg.ProtocolMaxVersion
config.ClientAuth = cfg.ClientAuth
config.NextProtos = cfg.ALPN
config.GetCertificate = cfg.GetCertificate
// Set up client authentication if enabled
// set up client authentication if enabled
if config.ClientAuth != tls.NoClientCert {
pool := x509.NewCertPool()
clientCertsAdded := make(map[string]struct{})
......@@ -286,45 +275,51 @@ func (cfg *Config) build() (*tls.Config, error) {
// Any client with a certificate from this CA will be allowed to connect
caCrt, err := ioutil.ReadFile(caFile)
if err != nil {
return nil, err
return err
}
if !pool.AppendCertsFromPEM(caCrt) {
return nil, fmt.Errorf("error loading client certificate '%s': no certificates were successfully parsed", caFile)
return fmt.Errorf("error loading client certificate '%s': no certificates were successfully parsed", caFile)
}
}
config.ClientCAs = pool
}
// Default cipher suites
// default cipher suites
if len(config.CipherSuites) == 0 {
config.CipherSuites = defaultCiphers
}
// For security, ensure TLS_FALLBACK_SCSV is always included first
// for security, ensure TLS_FALLBACK_SCSV is always included first
if len(config.CipherSuites) == 0 || config.CipherSuites[0] != tls.TLS_FALLBACK_SCSV {
config.CipherSuites = append([]uint16{tls.TLS_FALLBACK_SCSV}, config.CipherSuites...)
}
if cfg.DisableHTTP2 {
config.NextProtos = []string{}
} else {
config.NextProtos = []string{"h2"}
}
// store the resulting new tls.Config
cfg.tlsConfig = config
return config, nil
return nil
}
// CheckConfigs checks if multiple TLS configs does not collide with each other
func CheckConfigs(configs []*Config) error {
// MakeTLSConfig makes a tls.Config from configs. The returned
// tls.Config is programmed to load the matching caddytls.Config
// based on the hostname in SNI, but that's all.
func MakeTLSConfig(configs []*Config) (*tls.Config, error) {
if len(configs) == 0 {
return nil
return nil, nil
}
configMap := make(configGroup)
for i, cfg := range configs {
if cfg == nil {
// avoid nil pointer dereference below this loop
configs[i] = new(Config)
continue
}
// Can't serve TLS and not-TLS on same port
// can't serve TLS and non-TLS on same port
if i > 0 && cfg.Enabled != configs[i-1].Enabled {
thisConfProto, lastConfProto := "not TLS", "not TLS"
if cfg.Enabled {
......@@ -333,26 +328,33 @@ func CheckConfigs(configs []*Config) error {
if configs[i-1].Enabled {
lastConfProto = "TLS"
}
return fmt.Errorf("cannot multiplex %s (%s) and %s (%s) on same listener",
return nil, fmt.Errorf("cannot multiplex %s (%s) and %s (%s) on same listener",
configs[i-1].Hostname, lastConfProto, cfg.Hostname, thisConfProto)
}
if !cfg.Enabled {
continue
// convert each caddytls.Config into a tls.Config
if err := cfg.buildStandardTLSConfig(); err != nil {
return nil, err
}
}
return nil
}
// Key this config by its hostname (overwriting
// configs with the same hostname pattern); during
// TLS handshakes, configs are loaded based on
// the hostname pattern, according to client's SNI.
configMap[cfg.Hostname] = cfg
}
func HasTLSEnabled(configs []*Config) bool {
for _, config := range configs {
if config.Enabled {
return true
}
// Is TLS disabled? By now, we know that all
// configs agree whether it is or not, so we
// can just look at the first one. If so,
// we're done here.
if len(configs) == 0 || !configs[0].Enabled {
return nil, nil
}
return false
return &tls.Config{
GetConfigForClient: configMap.GetConfigForClient,
}, nil
}
// ConfigGetter gets a Config keyed by key.
......
......@@ -8,50 +8,50 @@ import (
"testing"
)
func TestMakeTLSConfigProtocolVersions(t *testing.T) {
func TestConvertTLSConfigProtocolVersions(t *testing.T) {
// same min and max protocol versions
config := Config{
config := &Config{
Enabled: true,
ProtocolMinVersion: tls.VersionTLS12,
ProtocolMaxVersion: tls.VersionTLS12,
}
result, err := config.build()
err := config.buildStandardTLSConfig()
if err != nil {
t.Fatalf("Did not expect an error, but got %v", err)
}
if got, want := result.MinVersion, uint16(tls.VersionTLS12); got != want {
if got, want := config.tlsConfig.MinVersion, uint16(tls.VersionTLS12); got != want {
t.Errorf("Expected min version to be %x, got %x", want, got)
}
if got, want := result.MaxVersion, uint16(tls.VersionTLS12); got != want {
if got, want := config.tlsConfig.MaxVersion, uint16(tls.VersionTLS12); got != want {
t.Errorf("Expected max version to be %x, got %x", want, got)
}
}
func TestMakeTLSConfigPreferServerCipherSuites(t *testing.T) {
func TestConvertTLSConfigPreferServerCipherSuites(t *testing.T) {
// prefer server cipher suites
config := Config{Enabled: true, PreferServerCipherSuites: true}
result, err := config.build()
err := config.buildStandardTLSConfig()
if err != nil {
t.Fatalf("Did not expect an error, but got %v", err)
}
if got, want := result.PreferServerCipherSuites, true; got != want {
if got, want := config.tlsConfig.PreferServerCipherSuites, true; got != want {
t.Errorf("Expected PreferServerCipherSuites==%v but got %v", want, got)
}
}
func TestMakeTLSConfigTLSEnabledDisabled(t *testing.T) {
func TestMakeTLSConfigTLSEnabledDisabledError(t *testing.T) {
// verify handling when Enabled is true and false
configs := []*Config{
{Enabled: true},
{Enabled: false},
}
err := CheckConfigs(configs)
_, err := MakeTLSConfig(configs)
if err == nil {
t.Fatalf("Expected an error, but got %v", err)
}
}
func TestMakeTLSConfigCipherSuites(t *testing.T) {
func TestConvertTLSConfigCipherSuites(t *testing.T) {
// ensure cipher suites are unioned and
// that TLS_FALLBACK_SCSV is prepended
configs := []*Config{
......@@ -67,10 +67,13 @@ func TestMakeTLSConfigCipherSuites(t *testing.T) {
}
for i, config := range configs {
cfg, _ := config.build()
if !reflect.DeepEqual(cfg.CipherSuites, expectedCiphers[i]) {
t.Errorf("Expected ciphers %v but got %v", expectedCiphers[i], cfg.CipherSuites)
err := config.buildStandardTLSConfig()
if err != nil {
t.Errorf("Test %d: Expected no error, got: %v", i, err)
}
if !reflect.DeepEqual(config.tlsConfig.CipherSuites, expectedCiphers[i]) {
t.Errorf("Test %d: Expected ciphers %v but got %v",
i, expectedCiphers[i], config.tlsConfig.CipherSuites)
}
}
......
......@@ -13,18 +13,19 @@ import (
// configGroup is a type that keys configs by their hostname
// (hostnames can have wildcard characters; use the getConfig
// method to get a config by matching its hostname). Its
// GetCertificate function can be used with tls.Config.
type ConfigGroup map[string]*Config
// method to get a config by matching its hostname).
type configGroup map[string]*Config
// getConfig gets the config by the first key match for name.
// In other words, "sub.foo.bar" will get the config for "*.foo.bar"
// if that is the closest match. This function MAY return nil
// if no match is found.
// if that is the closest match. If no match is found, the first
// (random) config will be loaded, which will defer any TLS alerts
// to the certificate validation (this may or may not be ideal;
// let's talk about it if this becomes problematic).
//
// This function follows nearly the same logic to lookup
// a hostname as the getCertificate function uses.
func (cg ConfigGroup) getConfig(name string) *Config {
func (cg configGroup) getConfig(name string) *Config {
name = strings.ToLower(name)
// exact match? great, let's use it
......@@ -42,14 +43,36 @@ func (cg ConfigGroup) getConfig(name string) *Config {
}
}
// as last resort, try a config that serves all names
// as a fallback, try a config that serves all names
if config, ok := cg[""]; ok {
return config
}
// as a last resort, use a random config
// (even if the config isn't for that hostname,
// it should help us serve clients without SNI
// or at least defer TLS alerts to the cert)
for _, config := range cg {
return config
}
return nil
}
// GetConfigForClient gets a TLS configuration satisfying clientHello.
// In getting the configuration, it abides the rules and settings
// defined in the Config that matches clientHello.ServerName. If no
// tls.Config is set on the matching Config, a nil value is returned.
//
// This method is safe for use as a tls.Config.GetConfigForClient callback.
func (cg configGroup) GetConfigForClient(clientHello *tls.ClientHelloInfo) (*tls.Config, error) {
config := cg.getConfig(clientHello.ServerName)
if config != nil {
return config.tlsConfig, nil
}
return nil, nil
}
// GetCertificate gets a certificate to satisfy clientHello. In getting
// the certificate, it abides the rules and settings defined in the
// Config that matches clientHello.ServerName. It first checks the in-
......@@ -58,27 +81,11 @@ func (cg ConfigGroup) getConfig(name string) *Config {
// via ACME.
//
// This method is safe for use as a tls.Config.GetCertificate callback.
func (cg ConfigGroup) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
cert, err := cg.getCertDuringHandshake(strings.ToLower(clientHello.ServerName), true, true)
func (cfg *Config) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
cert, err := cfg.getCertDuringHandshake(strings.ToLower(clientHello.ServerName), true, true)
return &cert.Certificate, err
}
// GetConfigForClient gets a TLS configuration satisfying clientHello. In getting
// the configuration, it abides the rules and settings defined in the
// Config that matches clientHello.ServerName.
//
// This method is safe for use as a tls.Config.GetConfigForClient callback.
func (cg ConfigGroup) GetConfigForClient(clientHello *tls.ClientHelloInfo) (*tls.Config, error) {
config := cg.getConfig(clientHello.ServerName)
if config != nil {
return config.tlsConfig, nil
}
return nil, nil
}
// getCertDuringHandshake will get a certificate for name. It first tries
// the in-memory cache. If no certificate for name is in the cache, the
// config most closely corresponding to name will be loaded. If that config
......@@ -90,21 +97,20 @@ func (cg ConfigGroup) GetConfigForClient(clientHello *tls.ClientHelloInfo) (*tls
// certificate is available.
//
// This function is safe for concurrent use.
func (cg ConfigGroup) getCertDuringHandshake(name string, loadIfNecessary, obtainIfNecessary bool) (Certificate, error) {
func (cfg *Config) getCertDuringHandshake(name string, loadIfNecessary, obtainIfNecessary bool) (Certificate, error) {
// First check our in-memory cache to see if we've already loaded it
cert, matched, defaulted := getCertificate(name)
if matched {
return cert, nil
}
// Get the relevant TLS config for this name. If OnDemand is enabled,
// then we might be able to load or obtain a needed certificate.
cfg := cg.getConfig(name)
if cfg != nil && cfg.OnDemand && loadIfNecessary {
// If OnDemand is enabled, then we might be able to load or
// obtain a needed certificate
if cfg.OnDemand && loadIfNecessary {
// Then check to see if we have one on disk
loadedCert, err := CacheManagedCertificate(name, cfg)
loadedCert, err := cfg.CacheManagedCertificate(name)
if err == nil {
loadedCert, err = cg.handshakeMaintenance(name, loadedCert)
loadedCert, err = cfg.handshakeMaintenance(name, loadedCert)
if err != nil {
log.Printf("[ERROR] Maintaining newly-loaded certificate for %s: %v", name, err)
}
......@@ -116,7 +122,7 @@ func (cg ConfigGroup) getCertDuringHandshake(name string, loadIfNecessary, obtai
name = strings.ToLower(name)
// Make sure aren't over any applicable limits
err := cg.checkLimitsForObtainingNewCerts(name, cfg)
err := cfg.checkLimitsForObtainingNewCerts(name)
if err != nil {
return Certificate{}, err
}
......@@ -127,7 +133,7 @@ func (cg ConfigGroup) getCertDuringHandshake(name string, loadIfNecessary, obtai
}
// Obtain certificate from the CA
return cg.obtainOnDemandCertificate(name, cfg)
return cfg.obtainOnDemandCertificate(name)
}
}
......@@ -143,7 +149,7 @@ func (cg ConfigGroup) getCertDuringHandshake(name string, loadIfNecessary, obtai
// now according to mitigating factors we keep track of and preferences the
// user has set. If a non-nil error is returned, do not issue a new certificate
// for name.
func (cg ConfigGroup) checkLimitsForObtainingNewCerts(name string, cfg *Config) error {
func (cfg *Config) checkLimitsForObtainingNewCerts(name string) error {
// User can set hard limit for number of certs for the process to issue
if cfg.OnDemandState.MaxObtain > 0 &&
atomic.LoadInt32(&cfg.OnDemandState.ObtainedCount) >= cfg.OnDemandState.MaxObtain {
......@@ -167,7 +173,7 @@ func (cg ConfigGroup) checkLimitsForObtainingNewCerts(name string, cfg *Config)
return fmt.Errorf("%s: throttled; last certificate was obtained %v ago", name, since)
}
// 👍Good to go
// Good to go 👍
return nil
}
......@@ -176,7 +182,7 @@ func (cg ConfigGroup) checkLimitsForObtainingNewCerts(name string, cfg *Config)
// name, it will wait and use what the other goroutine obtained.
//
// This function is safe for use by multiple concurrent goroutines.
func (cg ConfigGroup) obtainOnDemandCertificate(name string, cfg *Config) (Certificate, error) {
func (cfg *Config) obtainOnDemandCertificate(name string) (Certificate, error) {
// We must protect this process from happening concurrently, so synchronize.
obtainCertWaitChansMu.Lock()
wait, ok := obtainCertWaitChans[name]
......@@ -185,7 +191,7 @@ func (cg ConfigGroup) obtainOnDemandCertificate(name string, cfg *Config) (Certi
// wait for it to finish obtaining the cert and then we'll use it.
obtainCertWaitChansMu.Unlock()
<-wait
return cg.getCertDuringHandshake(name, true, false)
return cfg.getCertDuringHandshake(name, true, false)
}
// looks like it's up to us to do all the work and obtain the cert.
......@@ -228,19 +234,19 @@ func (cg ConfigGroup) obtainOnDemandCertificate(name string, cfg *Config) (Certi
lastIssueTimeMu.Unlock()
// certificate is already on disk; now just start over to load it and serve it
return cg.getCertDuringHandshake(name, true, false)
return cfg.getCertDuringHandshake(name, true, false)
}
// handshakeMaintenance performs a check on cert for expiration and OCSP
// validity.
//
// This function is safe for use by multiple concurrent goroutines.
func (cg ConfigGroup) handshakeMaintenance(name string, cert Certificate) (Certificate, error) {
func (cfg *Config) handshakeMaintenance(name string, cert Certificate) (Certificate, error) {
// Check cert expiration
timeLeft := cert.NotAfter.Sub(time.Now().UTC())
if timeLeft < RenewDurationBefore {
log.Printf("[INFO] Certificate for %v expires in %v; attempting renewal", cert.Names, timeLeft)
return cg.renewDynamicCertificate(name, cert.Config)
return cfg.renewDynamicCertificate(name)
}
// Check OCSP staple validity
......@@ -268,7 +274,7 @@ func (cg ConfigGroup) handshakeMaintenance(name string, cert Certificate) (Certi
// usable. name should already be lower-cased before calling this function.
//
// This function is safe for use by multiple concurrent goroutines.
func (cg ConfigGroup) renewDynamicCertificate(name string, cfg *Config) (Certificate, error) {
func (cfg *Config) renewDynamicCertificate(name string) (Certificate, error) {
obtainCertWaitChansMu.Lock()
wait, ok := obtainCertWaitChans[name]
if ok {
......@@ -276,7 +282,7 @@ func (cg ConfigGroup) renewDynamicCertificate(name string, cfg *Config) (Certifi
// wait for it to finish, then we'll use the new one.
obtainCertWaitChansMu.Unlock()
<-wait
return cg.getCertDuringHandshake(name, true, false)
return cfg.getCertDuringHandshake(name, true, false)
}
// looks like it's up to us to do all the work and renew the cert
......@@ -300,7 +306,7 @@ func (cg ConfigGroup) renewDynamicCertificate(name string, cfg *Config) (Certifi
return Certificate{}, err
}
return cg.getCertDuringHandshake(name, true, false)
return cfg.getCertDuringHandshake(name, true, false)
}
// obtainCertWaitChans is used to coordinate obtaining certs for each hostname.
......
......@@ -9,7 +9,7 @@ import (
func TestGetCertificate(t *testing.T) {
defer func() { certCache = make(map[string]Certificate) }()
cg := make(ConfigGroup)
cfg := new(Config)
hello := &tls.ClientHelloInfo{ServerName: "example.com"}
helloSub := &tls.ClientHelloInfo{ServerName: "sub.example.com"}
......@@ -17,10 +17,10 @@ func TestGetCertificate(t *testing.T) {
helloNoMatch := &tls.ClientHelloInfo{ServerName: "nomatch"}
// When cache is empty
if cert, err := cg.GetCertificate(hello); err == nil {
if cert, err := cfg.GetCertificate(hello); err == nil {
t.Errorf("GetCertificate should return error when cache is empty, got: %v", cert)
}
if cert, err := cg.GetCertificate(helloNoSNI); err == nil {
if cert, err := cfg.GetCertificate(helloNoSNI); err == nil {
t.Errorf("GetCertificate should return error when cache is empty even if server name is blank, got: %v", cert)
}
......@@ -28,12 +28,12 @@ func TestGetCertificate(t *testing.T) {
defaultCert := Certificate{Names: []string{"example.com", ""}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"example.com"}}}}
certCache[""] = defaultCert
certCache["example.com"] = defaultCert
if cert, err := cg.GetCertificate(hello); err != nil {
if cert, err := cfg.GetCertificate(hello); err != nil {
t.Errorf("Got an error but shouldn't have, when cert exists in cache: %v", err)
} else if cert.Leaf.DNSNames[0] != "example.com" {
t.Errorf("Got wrong certificate with exact match; expected 'example.com', got: %v", cert)
}
if cert, err := cg.GetCertificate(helloNoSNI); err != nil {
if cert, err := cfg.GetCertificate(helloNoSNI); err != nil {
t.Errorf("Got an error with no SNI but shouldn't have, when cert exists in cache: %v", err)
} else if cert.Leaf.DNSNames[0] != "example.com" {
t.Errorf("Got wrong certificate for no SNI; expected 'example.com' as default, got: %v", cert)
......@@ -41,14 +41,14 @@ func TestGetCertificate(t *testing.T) {
// When retrieving wildcard certificate
certCache["*.example.com"] = Certificate{Names: []string{"*.example.com"}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"*.example.com"}}}}
if cert, err := cg.GetCertificate(helloSub); err != nil {
if cert, err := cfg.GetCertificate(helloSub); err != nil {
t.Errorf("Didn't get wildcard cert, got: cert=%v, err=%v ", cert, err)
} else if cert.Leaf.DNSNames[0] != "*.example.com" {
t.Errorf("Got wrong certificate, expected wildcard: %v", cert)
}
// When no certificate matches, the default is returned
if cert, err := cg.GetCertificate(helloNoMatch); err != nil {
if cert, err := cfg.GetCertificate(helloNoMatch); err != nil {
t.Errorf("Expected default certificate with no error when no matches, got err: %v", err)
} else if cert.Leaf.DNSNames[0] != "example.com" {
t.Errorf("Expected default cert with no matches, got: %v", cert)
......
......@@ -152,7 +152,7 @@ func RenewManagedCertificates(allowPrompts bool) (err error) {
delete(certCache, "")
certCacheMu.Unlock()
}
_, err := CacheManagedCertificate(cert.Names[0], cert.Config)
_, err := cert.Config.CacheManagedCertificate(cert.Names[0])
if err != nil {
if allowPrompts {
return err // operator is present, so report error immediately
......
......@@ -164,21 +164,15 @@ func setupTLS(c *caddy.Controller) error {
return c.Errf("Unsupported Storage provider '%s'", args[0])
}
config.StorageProvider = args[0]
case "http2":
case "alpn":
args := c.RemainingArgs()
if len(args) != 1 {
if len(args) == 0 {
return c.ArgErr()
}
switch args[0] {
case "off":
config.DisableHTTP2 = true
default:
c.ArgErr()
for _, arg := range args {
config.ALPN = append(config.ALPN, arg)
}
case "muststaple":
case "must_staple":
config.MustStaple = true
default:
return c.Errf("Unknown keyword '%s'", c.Val())
......
......@@ -91,8 +91,8 @@ func TestSetupParseBasic(t *testing.T) {
t.Error("Expected PreferServerCipherSuites = true, but was false")
}
if cfg.DisableHTTP2 {
t.Error("Expected HTTP2 to be enabled by default")
if len(cfg.ALPN) != 0 {
t.Error("Expected ALPN empty by default")
}
// Ensure curve count is correct
......@@ -121,8 +121,8 @@ func TestSetupParseWithOptionalParams(t *testing.T) {
params := `tls ` + certFile + ` ` + keyFile + ` {
protocols tls1.0 tls1.2
ciphers RSA-AES256-CBC-SHA ECDHE-RSA-AES128-GCM-SHA256 ECDHE-ECDSA-AES256-GCM-SHA384
muststaple
http2 off
must_staple
alpn http/1.1
}`
cfg := new(Config)
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
......@@ -149,8 +149,8 @@ func TestSetupParseWithOptionalParams(t *testing.T) {
t.Error("Expected must staple to be true")
}
if !cfg.DisableHTTP2 {
t.Error("Expected HTTP2 to be disabled")
if len(cfg.ALPN) != 1 || cfg.ALPN[0] != "http/1.1" {
t.Errorf("Expected ALPN to contain only 'http/1.1' but got: %v", cfg.ALPN)
}
}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment