Commit 286d8d1e authored by Mateusz Gajewski's avatar Mateusz Gajewski Committed by Matt Holt

tls: Per-site TLS configs using GetClientConfig, including http2 switch (#1389)

* Remove manual TLS clone method

* WiP tls

* Use GetClientConfig for tls.Config

* gofmt -s -w

* GetConfig

* Handshake

* Removed comment

* Disable HTTP2 on demand

* Remove junk

* Remove http2 enable (no-op)
parent 977a3c32
...@@ -31,6 +31,7 @@ type Server struct { ...@@ -31,6 +31,7 @@ type Server struct {
connTimeout time.Duration // max time to wait for a connection before force stop connTimeout time.Duration // max time to wait for a connection before force stop
tlsGovChan chan struct{} // close to stop the TLS maintenance goroutine tlsGovChan chan struct{} // close to stop the TLS maintenance goroutine
vhosts *vhostTrie vhosts *vhostTrie
tlsConfig caddytls.ConfigGroup
} }
// ensure it satisfies the interface // ensure it satisfies the interface
...@@ -72,16 +73,31 @@ func NewServer(addr string, group []*SiteConfig) (*Server, error) { ...@@ -72,16 +73,31 @@ func NewServer(addr string, group []*SiteConfig) (*Server, error) {
} }
// Set up TLS configuration // Set up TLS configuration
var tlsConfigs []*caddytls.Config tlsConfigs := make(caddytls.ConfigGroup)
var allConfigs []*caddytls.Config
for _, site := range group { for _, site := range group {
tlsConfigs = append(tlsConfigs, site.TLS)
if err := site.TLS.Build(tlsConfigs); err != nil {
return nil, err
}
tlsConfigs[site.TLS.Hostname] = site.TLS
allConfigs = append(allConfigs, site.TLS)
} }
var err error
s.Server.TLSConfig, err = caddytls.MakeTLSConfig(tlsConfigs) // Check if configs are valid
if err != nil { if err := caddytls.CheckConfigs(allConfigs); err != nil {
return nil, err return nil, err
} }
s.tlsConfig = tlsConfigs
s.Server.TLSConfig = &tls.Config{
GetConfigForClient: s.tlsConfig.GetConfigForClient,
GetCertificate: s.tlsConfig.GetCertificate,
}
// As of Go 1.7, HTTP/2 is enabled only if NextProtos includes the string "h2" // As of Go 1.7, HTTP/2 is enabled only if NextProtos includes the string "h2"
if HTTP2 && s.Server.TLSConfig != nil && len(s.Server.TLSConfig.NextProtos) == 0 { if HTTP2 && s.Server.TLSConfig != nil && len(s.Server.TLSConfig.NextProtos) == 0 {
s.Server.TLSConfig.NextProtos = []string{"h2"} s.Server.TLSConfig.NextProtos = []string{"h2"}
......
...@@ -442,7 +442,7 @@ func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport { ...@@ -442,7 +442,7 @@ func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport {
if b, _ := base.(*http.Transport); b != nil { if b, _ := base.(*http.Transport); b != nil {
tlsClientConfig := b.TLSClientConfig tlsClientConfig := b.TLSClientConfig
if tlsClientConfig.NextProtos != nil { if tlsClientConfig.NextProtos != nil {
tlsClientConfig = cloneTLSClientConfig(tlsClientConfig) tlsClientConfig = tlsClientConfig.Clone()
tlsClientConfig.NextProtos = nil tlsClientConfig.NextProtos = nil
} }
...@@ -566,37 +566,6 @@ func (tlsHandshakeTimeoutError) Timeout() bool { return true } ...@@ -566,37 +566,6 @@ func (tlsHandshakeTimeoutError) Timeout() bool { return true }
func (tlsHandshakeTimeoutError) Temporary() bool { return true } func (tlsHandshakeTimeoutError) Temporary() bool { return true }
func (tlsHandshakeTimeoutError) Error() string { return "net/http: TLS handshake timeout" } func (tlsHandshakeTimeoutError) Error() string { return "net/http: TLS handshake timeout" }
// cloneTLSClientConfig is like cloneTLSConfig but omits
// the fields SessionTicketsDisabled and SessionTicketKey.
// This makes it safe to call cloneTLSClientConfig on a config
// in active use by a server.
func cloneTLSClientConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
}
return &tls.Config{
Rand: cfg.Rand,
Time: cfg.Time,
Certificates: cfg.Certificates,
NameToCertificate: cfg.NameToCertificate,
GetCertificate: cfg.GetCertificate,
RootCAs: cfg.RootCAs,
NextProtos: cfg.NextProtos,
ServerName: cfg.ServerName,
ClientAuth: cfg.ClientAuth,
ClientCAs: cfg.ClientCAs,
InsecureSkipVerify: cfg.InsecureSkipVerify,
CipherSuites: cfg.CipherSuites,
PreferServerCipherSuites: cfg.PreferServerCipherSuites,
ClientSessionCache: cfg.ClientSessionCache,
MinVersion: cfg.MinVersion,
MaxVersion: cfg.MaxVersion,
CurvePreferences: cfg.CurvePreferences,
DynamicRecordSizingDisabled: cfg.DynamicRecordSizingDisabled,
Renegotiation: cfg.Renegotiation,
}
}
func requestIsWebsocket(req *http.Request) bool { func requestIsWebsocket(req *http.Request) bool {
return strings.ToLower(req.Header.Get("Upgrade")) == "websocket" && strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") return strings.ToLower(req.Header.Get("Upgrade")) == "websocket" && strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade")
} }
......
...@@ -108,6 +108,12 @@ type Config struct { ...@@ -108,6 +108,12 @@ type Config struct {
// Add the must staple TLS extension to the CSR generated by lego/acme // Add the must staple TLS extension to the CSR generated by lego/acme
MustStaple bool MustStaple bool
// Disables HTTP2 completely
DisableHTTP2 bool
// Holds final tls.Config
tlsConfig *tls.Config
} }
// OnDemandState contains some state relevant for providing // OnDemandState contains some state relevant for providing
...@@ -217,88 +223,70 @@ func (c *Config) StorageFor(caURL string) (Storage, error) { ...@@ -217,88 +223,70 @@ func (c *Config) StorageFor(caURL string) (Storage, error) {
return s, nil return s, nil
} }
// MakeTLSConfig reduces configs into a single tls.Config. func (cfg *Config) Build(group ConfigGroup) error {
// If TLS is to be disabled, a nil tls.Config will be returned. config, err := cfg.build()
func MakeTLSConfig(configs []*Config) (*tls.Config, error) {
if len(configs) == 0 { if err != nil {
return nil, nil return err
} }
cfg.tlsConfig = config
cfg.tlsConfig.GetCertificate = group.GetCertificate
return nil
}
func (cfg *Config) build() (*tls.Config, error) {
config := new(tls.Config) config := new(tls.Config)
ciphersAdded := make(map[uint16]struct{}) ciphersAdded := make(map[uint16]struct{})
curvesAdded := make(map[tls.CurveID]struct{}) curvesAdded := make(map[tls.CurveID]struct{})
configMap := make(configGroup)
for i, cfg := range configs { // Add cipher suites
if cfg == nil { for _, ciph := range cfg.Ciphers {
// avoid nil pointer dereference below if _, ok := ciphersAdded[ciph]; !ok {
configs[i] = new(Config) ciphersAdded[ciph] = struct{}{}
continue config.CipherSuites = append(config.CipherSuites, ciph)
} }
}
// Key this config by its hostname; this config.PreferServerCipherSuites = cfg.PreferServerCipherSuites
// overwrites configs with the same hostname
configMap[cfg.Hostname] = cfg
// Can't serve TLS and not-TLS on same port
if i > 0 && cfg.Enabled != configs[i-1].Enabled {
thisConfProto, lastConfProto := "not TLS", "not TLS"
if cfg.Enabled {
thisConfProto = "TLS"
}
if configs[i-1].Enabled {
lastConfProto = "TLS"
}
return nil, fmt.Errorf("cannot multiplex %s (%s) and %s (%s) on same listener",
configs[i-1].Hostname, lastConfProto, cfg.Hostname, thisConfProto)
}
if !cfg.Enabled { // Union curves
continue for _, curv := range cfg.CurvePreferences {
if _, ok := curvesAdded[curv]; !ok {
curvesAdded[curv] = struct{}{}
config.CurvePreferences = append(config.CurvePreferences, curv)
} }
}
// Union cipher suites config.MinVersion = cfg.ProtocolMinVersion
for _, ciph := range cfg.Ciphers { config.MaxVersion = cfg.ProtocolMaxVersion
if _, ok := ciphersAdded[ciph]; !ok { config.ClientAuth = cfg.ClientAuth
ciphersAdded[ciph] = struct{}{}
config.CipherSuites = append(config.CipherSuites, ciph)
}
}
// Can't resolve conflicting PreferServerCipherSuites settings // Set up client authentication if enabled
if i > 0 && cfg.PreferServerCipherSuites != configs[i-1].PreferServerCipherSuites { if config.ClientAuth != tls.NoClientCert {
return nil, fmt.Errorf("cannot both PreferServerCipherSuites and not prefer them") pool := x509.NewCertPool()
} clientCertsAdded := make(map[string]struct{})
config.PreferServerCipherSuites = cfg.PreferServerCipherSuites
// Union curves for _, caFile := range cfg.ClientCerts {
for _, curv := range cfg.CurvePreferences { // don't add cert to pool more than once
if _, ok := curvesAdded[curv]; !ok { if _, ok := clientCertsAdded[caFile]; ok {
curvesAdded[curv] = struct{}{} continue
config.CurvePreferences = append(config.CurvePreferences, curv)
} }
} clientCertsAdded[caFile] = struct{}{}
// Go with the widest range of protocol versions // Any client with a certificate from this CA will be allowed to connect
if config.MinVersion == 0 || cfg.ProtocolMinVersion < config.MinVersion { caCrt, err := ioutil.ReadFile(caFile)
config.MinVersion = cfg.ProtocolMinVersion if err != nil {
} return nil, err
if cfg.ProtocolMaxVersion > config.MaxVersion { }
config.MaxVersion = cfg.ProtocolMaxVersion
}
// Go with the strictest ClientAuth type if !pool.AppendCertsFromPEM(caCrt) {
if cfg.ClientAuth > config.ClientAuth { return nil, fmt.Errorf("error loading client certificate '%s': no certificates were successfully parsed", caFile)
config.ClientAuth = cfg.ClientAuth }
} }
}
// Is TLS disabled? If so, we're done here. config.ClientCAs = pool
// By now, we know that all configs agree
// whether it is or not, so we can just look
// at the first one.
if len(configs) == 0 || !configs[0].Enabled {
return nil, nil
} }
// Default cipher suites // Default cipher suites
...@@ -311,41 +299,42 @@ func MakeTLSConfig(configs []*Config) (*tls.Config, error) { ...@@ -311,41 +299,42 @@ func MakeTLSConfig(configs []*Config) (*tls.Config, error) {
config.CipherSuites = append([]uint16{tls.TLS_FALLBACK_SCSV}, config.CipherSuites...) config.CipherSuites = append([]uint16{tls.TLS_FALLBACK_SCSV}, config.CipherSuites...)
} }
// Default curves if cfg.DisableHTTP2 {
if len(config.CurvePreferences) == 0 { config.NextProtos = []string{}
config.CurvePreferences = defaultCurves } else {
config.NextProtos = []string{"h2"}
} }
// Set up client authentication if enabled return config, nil
if config.ClientAuth != tls.NoClientCert { }
pool := x509.NewCertPool()
clientCertsAdded := make(map[string]struct{}) // CheckConfigs checks if multiple TLS configs does not collide with each other
for _, cfg := range configs { func CheckConfigs(configs []*Config) error {
for _, caFile := range cfg.ClientCerts { if len(configs) == 0 {
// don't add cert to pool more than once return nil
if _, ok := clientCertsAdded[caFile]; ok { }
continue
} for i, cfg := range configs {
clientCertsAdded[caFile] = struct{}{}
// Can't serve TLS and not-TLS on same port
// Any client with a certificate from this CA will be allowed to connect if i > 0 && cfg.Enabled != configs[i-1].Enabled {
caCrt, err := ioutil.ReadFile(caFile) thisConfProto, lastConfProto := "not TLS", "not TLS"
if err != nil { if cfg.Enabled {
return nil, err thisConfProto = "TLS"
} }
if configs[i-1].Enabled {
if !pool.AppendCertsFromPEM(caCrt) { lastConfProto = "TLS"
return nil, fmt.Errorf("error loading client certificate '%s': no certificates were successfully parsed", caFile)
}
} }
return fmt.Errorf("cannot multiplex %s (%s) and %s (%s) on same listener",
configs[i-1].Hostname, lastConfProto, cfg.Hostname, thisConfProto)
} }
config.ClientCAs = pool
}
// Associate the GetCertificate callback, or almost nothing we just did will work if !cfg.Enabled {
config.GetCertificate = configMap.GetCertificate continue
}
}
return config, nil return nil
} }
// ConfigGetter gets a Config keyed by key. // ConfigGetter gets a Config keyed by key.
......
...@@ -10,14 +10,12 @@ import ( ...@@ -10,14 +10,12 @@ import (
func TestMakeTLSConfigProtocolVersions(t *testing.T) { func TestMakeTLSConfigProtocolVersions(t *testing.T) {
// same min and max protocol versions // same min and max protocol versions
configs := []*Config{ config := Config{
{ Enabled: true,
Enabled: true, ProtocolMinVersion: tls.VersionTLS12,
ProtocolMinVersion: tls.VersionTLS12, ProtocolMaxVersion: tls.VersionTLS12,
ProtocolMaxVersion: tls.VersionTLS12,
},
} }
result, err := MakeTLSConfig(configs) result, err := config.build()
if err != nil { if err != nil {
t.Fatalf("Did not expect an error, but got %v", err) t.Fatalf("Did not expect an error, but got %v", err)
} }
...@@ -31,28 +29,14 @@ func TestMakeTLSConfigProtocolVersions(t *testing.T) { ...@@ -31,28 +29,14 @@ func TestMakeTLSConfigProtocolVersions(t *testing.T) {
func TestMakeTLSConfigPreferServerCipherSuites(t *testing.T) { func TestMakeTLSConfigPreferServerCipherSuites(t *testing.T) {
// prefer server cipher suites // prefer server cipher suites
configs := []*Config{{Enabled: true, PreferServerCipherSuites: true}} config := Config{Enabled: true, PreferServerCipherSuites: true}
result, err := MakeTLSConfig(configs) result, err := config.build()
if err != nil { if err != nil {
t.Fatalf("Did not expect an error, but got %v", err) t.Fatalf("Did not expect an error, but got %v", err)
} }
if got, want := result.PreferServerCipherSuites, true; got != want { if got, want := result.PreferServerCipherSuites, true; got != want {
t.Errorf("Expected PreferServerCipherSuites==%v but got %v", want, got) t.Errorf("Expected PreferServerCipherSuites==%v but got %v", want, got)
} }
// make sure we don't get an error if there's a conflict
// when both of the configs have TLS disabled
configs = []*Config{
{Enabled: false, PreferServerCipherSuites: false},
{Enabled: false, PreferServerCipherSuites: true},
}
result, err = MakeTLSConfig(configs)
if err != nil {
t.Fatalf("Did not expect an error when TLS is disabled, but got '%v'", err)
}
if result != nil {
t.Errorf("Expected nil result because TLS disabled, got: %+v", err)
}
} }
func TestMakeTLSConfigTLSEnabledDisabled(t *testing.T) { func TestMakeTLSConfigTLSEnabledDisabled(t *testing.T) {
...@@ -61,20 +45,10 @@ func TestMakeTLSConfigTLSEnabledDisabled(t *testing.T) { ...@@ -61,20 +45,10 @@ func TestMakeTLSConfigTLSEnabledDisabled(t *testing.T) {
{Enabled: true}, {Enabled: true},
{Enabled: false}, {Enabled: false},
} }
_, err := MakeTLSConfig(configs) err := CheckConfigs(configs)
if err == nil { if err == nil {
t.Fatalf("Expected an error, but got %v", err) t.Fatalf("Expected an error, but got %v", err)
} }
// verify that when disabled, a nil pair is returned
configs = []*Config{{}, {}}
result, err := MakeTLSConfig(configs)
if err != nil {
t.Errorf("Did not expect an error, but got %v", err)
}
if result != nil {
t.Errorf("Expected a nil *tls.Config result, got %+v", result)
}
} }
func TestMakeTLSConfigCipherSuites(t *testing.T) { func TestMakeTLSConfigCipherSuites(t *testing.T) {
...@@ -83,25 +57,22 @@ func TestMakeTLSConfigCipherSuites(t *testing.T) { ...@@ -83,25 +57,22 @@ func TestMakeTLSConfigCipherSuites(t *testing.T) {
configs := []*Config{ configs := []*Config{
{Enabled: true, Ciphers: []uint16{0xc02c, 0xc030}}, {Enabled: true, Ciphers: []uint16{0xc02c, 0xc030}},
{Enabled: true, Ciphers: []uint16{0xc012, 0xc030, 0xc00a}}, {Enabled: true, Ciphers: []uint16{0xc012, 0xc030, 0xc00a}},
} {Enabled: true, Ciphers: nil},
result, err := MakeTLSConfig(configs)
if err != nil {
t.Fatalf("Did not expect an error, but got %v", err)
}
expected := []uint16{tls.TLS_FALLBACK_SCSV, 0xc02c, 0xc030, 0xc012, 0xc00a}
if !reflect.DeepEqual(result.CipherSuites, expected) {
t.Errorf("Expected ciphers %v but got %v", expected, result.CipherSuites)
} }
// use default suites if none specified expectedCiphers := [][]uint16{
configs = []*Config{{Enabled: true}} {tls.TLS_FALLBACK_SCSV, 0xc02c, 0xc030},
result, err = MakeTLSConfig(configs) {tls.TLS_FALLBACK_SCSV, 0xc012, 0xc030, 0xc00a},
if err != nil { append([]uint16{tls.TLS_FALLBACK_SCSV}, defaultCiphers...),
t.Fatalf("Did not expect an error, but got %v", err)
} }
expected = append([]uint16{tls.TLS_FALLBACK_SCSV}, defaultCiphers...)
if !reflect.DeepEqual(result.CipherSuites, expected) { for i, config := range configs {
t.Errorf("Expected default ciphers %v but got %v", expected, result.CipherSuites) cfg, _ := config.build()
if !reflect.DeepEqual(cfg.CipherSuites, expectedCiphers[i]) {
t.Errorf("Expected ciphers %v but got %v", expectedCiphers[i], cfg.CipherSuites)
}
} }
} }
......
...@@ -15,7 +15,7 @@ import ( ...@@ -15,7 +15,7 @@ import (
// (hostnames can have wildcard characters; use the getConfig // (hostnames can have wildcard characters; use the getConfig
// method to get a config by matching its hostname). Its // method to get a config by matching its hostname). Its
// GetCertificate function can be used with tls.Config. // GetCertificate function can be used with tls.Config.
type configGroup map[string]*Config type ConfigGroup map[string]*Config
// getConfig gets the config by the first key match for name. // getConfig gets the config by the first key match for name.
// In other words, "sub.foo.bar" will get the config for "*.foo.bar" // In other words, "sub.foo.bar" will get the config for "*.foo.bar"
...@@ -24,7 +24,7 @@ type configGroup map[string]*Config ...@@ -24,7 +24,7 @@ type configGroup map[string]*Config
// //
// This function follows nearly the same logic to lookup // This function follows nearly the same logic to lookup
// a hostname as the getCertificate function uses. // a hostname as the getCertificate function uses.
func (cg configGroup) getConfig(name string) *Config { func (cg ConfigGroup) getConfig(name string) *Config {
name = strings.ToLower(name) name = strings.ToLower(name)
// exact match? great, let's use it // exact match? great, let's use it
...@@ -58,11 +58,27 @@ func (cg configGroup) getConfig(name string) *Config { ...@@ -58,11 +58,27 @@ func (cg configGroup) getConfig(name string) *Config {
// via ACME. // via ACME.
// //
// This method is safe for use as a tls.Config.GetCertificate callback. // This method is safe for use as a tls.Config.GetCertificate callback.
func (cg configGroup) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { func (cg ConfigGroup) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
cert, err := cg.getCertDuringHandshake(strings.ToLower(clientHello.ServerName), true, true) cert, err := cg.getCertDuringHandshake(strings.ToLower(clientHello.ServerName), true, true)
return &cert.Certificate, err return &cert.Certificate, err
} }
// GetConfigForClient gets a TLS configuration satisfying clientHello. In getting
// the configuration, it abides the rules and settings defined in the
// Config that matches clientHello.ServerName.
//
// This method is safe for use as a tls.Config.GetConfigForClient callback.
func (cg ConfigGroup) GetConfigForClient(clientHello *tls.ClientHelloInfo) (*tls.Config, error) {
config := cg.getConfig(clientHello.ServerName)
if config != nil {
return config.tlsConfig, nil
}
return nil, nil
}
// getCertDuringHandshake will get a certificate for name. It first tries // getCertDuringHandshake will get a certificate for name. It first tries
// the in-memory cache. If no certificate for name is in the cache, the // the in-memory cache. If no certificate for name is in the cache, the
// config most closely corresponding to name will be loaded. If that config // config most closely corresponding to name will be loaded. If that config
...@@ -74,7 +90,7 @@ func (cg configGroup) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Cer ...@@ -74,7 +90,7 @@ func (cg configGroup) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Cer
// certificate is available. // certificate is available.
// //
// This function is safe for concurrent use. // This function is safe for concurrent use.
func (cg configGroup) getCertDuringHandshake(name string, loadIfNecessary, obtainIfNecessary bool) (Certificate, error) { func (cg ConfigGroup) getCertDuringHandshake(name string, loadIfNecessary, obtainIfNecessary bool) (Certificate, error) {
// First check our in-memory cache to see if we've already loaded it // First check our in-memory cache to see if we've already loaded it
cert, matched, defaulted := getCertificate(name) cert, matched, defaulted := getCertificate(name)
if matched { if matched {
...@@ -127,7 +143,7 @@ func (cg configGroup) getCertDuringHandshake(name string, loadIfNecessary, obtai ...@@ -127,7 +143,7 @@ func (cg configGroup) getCertDuringHandshake(name string, loadIfNecessary, obtai
// now according to mitigating factors we keep track of and preferences the // now according to mitigating factors we keep track of and preferences the
// user has set. If a non-nil error is returned, do not issue a new certificate // user has set. If a non-nil error is returned, do not issue a new certificate
// for name. // for name.
func (cg configGroup) checkLimitsForObtainingNewCerts(name string, cfg *Config) error { func (cg ConfigGroup) checkLimitsForObtainingNewCerts(name string, cfg *Config) error {
// User can set hard limit for number of certs for the process to issue // User can set hard limit for number of certs for the process to issue
if cfg.OnDemandState.MaxObtain > 0 && if cfg.OnDemandState.MaxObtain > 0 &&
atomic.LoadInt32(&cfg.OnDemandState.ObtainedCount) >= cfg.OnDemandState.MaxObtain { atomic.LoadInt32(&cfg.OnDemandState.ObtainedCount) >= cfg.OnDemandState.MaxObtain {
...@@ -160,7 +176,7 @@ func (cg configGroup) checkLimitsForObtainingNewCerts(name string, cfg *Config) ...@@ -160,7 +176,7 @@ func (cg configGroup) checkLimitsForObtainingNewCerts(name string, cfg *Config)
// name, it will wait and use what the other goroutine obtained. // name, it will wait and use what the other goroutine obtained.
// //
// This function is safe for use by multiple concurrent goroutines. // This function is safe for use by multiple concurrent goroutines.
func (cg configGroup) obtainOnDemandCertificate(name string, cfg *Config) (Certificate, error) { func (cg ConfigGroup) obtainOnDemandCertificate(name string, cfg *Config) (Certificate, error) {
// We must protect this process from happening concurrently, so synchronize. // We must protect this process from happening concurrently, so synchronize.
obtainCertWaitChansMu.Lock() obtainCertWaitChansMu.Lock()
wait, ok := obtainCertWaitChans[name] wait, ok := obtainCertWaitChans[name]
...@@ -219,7 +235,7 @@ func (cg configGroup) obtainOnDemandCertificate(name string, cfg *Config) (Certi ...@@ -219,7 +235,7 @@ func (cg configGroup) obtainOnDemandCertificate(name string, cfg *Config) (Certi
// validity. // validity.
// //
// This function is safe for use by multiple concurrent goroutines. // This function is safe for use by multiple concurrent goroutines.
func (cg configGroup) handshakeMaintenance(name string, cert Certificate) (Certificate, error) { func (cg ConfigGroup) handshakeMaintenance(name string, cert Certificate) (Certificate, error) {
// Check cert expiration // Check cert expiration
timeLeft := cert.NotAfter.Sub(time.Now().UTC()) timeLeft := cert.NotAfter.Sub(time.Now().UTC())
if timeLeft < RenewDurationBefore { if timeLeft < RenewDurationBefore {
...@@ -252,7 +268,7 @@ func (cg configGroup) handshakeMaintenance(name string, cert Certificate) (Certi ...@@ -252,7 +268,7 @@ func (cg configGroup) handshakeMaintenance(name string, cert Certificate) (Certi
// usable. name should already be lower-cased before calling this function. // usable. name should already be lower-cased before calling this function.
// //
// This function is safe for use by multiple concurrent goroutines. // This function is safe for use by multiple concurrent goroutines.
func (cg configGroup) renewDynamicCertificate(name string, cfg *Config) (Certificate, error) { func (cg ConfigGroup) renewDynamicCertificate(name string, cfg *Config) (Certificate, error) {
obtainCertWaitChansMu.Lock() obtainCertWaitChansMu.Lock()
wait, ok := obtainCertWaitChans[name] wait, ok := obtainCertWaitChans[name]
if ok { if ok {
......
...@@ -9,7 +9,7 @@ import ( ...@@ -9,7 +9,7 @@ import (
func TestGetCertificate(t *testing.T) { func TestGetCertificate(t *testing.T) {
defer func() { certCache = make(map[string]Certificate) }() defer func() { certCache = make(map[string]Certificate) }()
cg := make(configGroup) cg := make(ConfigGroup)
hello := &tls.ClientHelloInfo{ServerName: "example.com"} hello := &tls.ClientHelloInfo{ServerName: "example.com"}
helloSub := &tls.ClientHelloInfo{ServerName: "sub.example.com"} helloSub := &tls.ClientHelloInfo{ServerName: "sub.example.com"}
......
...@@ -164,6 +164,20 @@ func setupTLS(c *caddy.Controller) error { ...@@ -164,6 +164,20 @@ func setupTLS(c *caddy.Controller) error {
return c.Errf("Unsupported Storage provider '%s'", args[0]) return c.Errf("Unsupported Storage provider '%s'", args[0])
} }
config.StorageProvider = args[0] config.StorageProvider = args[0]
case "http2":
args := c.RemainingArgs()
if len(args) != 1 {
return c.ArgErr()
}
switch args[0] {
case "off":
config.DisableHTTP2 = true
default:
c.ArgErr()
}
case "muststaple": case "muststaple":
config.MustStaple = true config.MustStaple = true
default: default:
......
...@@ -91,6 +91,10 @@ func TestSetupParseBasic(t *testing.T) { ...@@ -91,6 +91,10 @@ func TestSetupParseBasic(t *testing.T) {
t.Error("Expected PreferServerCipherSuites = true, but was false") t.Error("Expected PreferServerCipherSuites = true, but was false")
} }
if cfg.DisableHTTP2 {
t.Error("Expected HTTP2 to be enabled by default")
}
// Ensure curve count is correct // Ensure curve count is correct
if len(cfg.CurvePreferences) != len(defaultCurves) { if len(cfg.CurvePreferences) != len(defaultCurves) {
t.Errorf("Expected %v Curves, got %v", len(defaultCurves), len(cfg.CurvePreferences)) t.Errorf("Expected %v Curves, got %v", len(defaultCurves), len(cfg.CurvePreferences))
...@@ -118,6 +122,7 @@ func TestSetupParseWithOptionalParams(t *testing.T) { ...@@ -118,6 +122,7 @@ func TestSetupParseWithOptionalParams(t *testing.T) {
protocols tls1.0 tls1.2 protocols tls1.0 tls1.2
ciphers RSA-AES256-CBC-SHA ECDHE-RSA-AES128-GCM-SHA256 ECDHE-ECDSA-AES256-GCM-SHA384 ciphers RSA-AES256-CBC-SHA ECDHE-RSA-AES128-GCM-SHA256 ECDHE-ECDSA-AES256-GCM-SHA384
muststaple muststaple
http2 off
}` }`
cfg := new(Config) cfg := new(Config)
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg }) RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
...@@ -141,7 +146,11 @@ func TestSetupParseWithOptionalParams(t *testing.T) { ...@@ -141,7 +146,11 @@ func TestSetupParseWithOptionalParams(t *testing.T) {
} }
if !cfg.MustStaple { if !cfg.MustStaple {
t.Errorf("Expected must staple to be true") t.Error("Expected must staple to be true")
}
if !cfg.DisableHTTP2 {
t.Error("Expected HTTP2 to be disabled")
} }
} }
...@@ -184,7 +193,7 @@ func TestSetupParseWithWrongOptionalParams(t *testing.T) { ...@@ -184,7 +193,7 @@ func TestSetupParseWithWrongOptionalParams(t *testing.T) {
c = caddy.NewTestController("", params) c = caddy.NewTestController("", params)
err = setupTLS(c) err = setupTLS(c)
if err == nil { if err == nil {
t.Errorf("Expected errors, but no error returned") t.Error("Expected errors, but no error returned")
} }
// Test key_type wrong params // Test key_type wrong params
...@@ -196,7 +205,7 @@ func TestSetupParseWithWrongOptionalParams(t *testing.T) { ...@@ -196,7 +205,7 @@ func TestSetupParseWithWrongOptionalParams(t *testing.T) {
c = caddy.NewTestController("", params) c = caddy.NewTestController("", params)
err = setupTLS(c) err = setupTLS(c)
if err == nil { if err == nil {
t.Errorf("Expected errors, but no error returned") t.Error("Expected errors, but no error returned")
} }
// Test curves wrong params // Test curves wrong params
...@@ -208,7 +217,7 @@ func TestSetupParseWithWrongOptionalParams(t *testing.T) { ...@@ -208,7 +217,7 @@ func TestSetupParseWithWrongOptionalParams(t *testing.T) {
c = caddy.NewTestController("", params) c = caddy.NewTestController("", params)
err = setupTLS(c) err = setupTLS(c)
if err == nil { if err == nil {
t.Errorf("Expected errors, but no error returned") t.Error("Expected errors, but no error returned")
} }
} }
...@@ -222,7 +231,7 @@ func TestSetupParseWithClientAuth(t *testing.T) { ...@@ -222,7 +231,7 @@ func TestSetupParseWithClientAuth(t *testing.T) {
c := caddy.NewTestController("", params) c := caddy.NewTestController("", params)
err := setupTLS(c) err := setupTLS(c)
if err == nil { if err == nil {
t.Errorf("Expected an error, but no error returned") t.Error("Expected an error, but no error returned")
} }
noCAs, twoCAs := []string{}, []string{"client_ca.crt", "client2_ca.crt"} noCAs, twoCAs := []string{}, []string{"client_ca.crt", "client2_ca.crt"}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment