Commit cff3e758 authored by Adam Langley's avatar Adam Langley Committed by Brad Fitzpatrick

crypto/tls: add Config.GetConfigForClient

GetConfigForClient allows the tls.Config to be updated on a per-client
basis.

Fixes #16066.
Fixes #15707.
Fixes #15699.

Change-Id: I2c675a443d557f969441226729f98502b38901ea
Reviewed-on: https://go-review.googlesource.com/30790
Run-TryBot: Adam Langley <agl@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: default avatarBrad Fitzpatrick <bradfitz@golang.org>
parent 7e2bf952
...@@ -303,7 +303,27 @@ type Config struct { ...@@ -303,7 +303,27 @@ type Config struct {
// If GetCertificate is nil or returns nil, then the certificate is // If GetCertificate is nil or returns nil, then the certificate is
// retrieved from NameToCertificate. If NameToCertificate is nil, the // retrieved from NameToCertificate. If NameToCertificate is nil, the
// first element of Certificates will be used. // first element of Certificates will be used.
GetCertificate func(clientHello *ClientHelloInfo) (*Certificate, error) GetCertificate func(*ClientHelloInfo) (*Certificate, error)
// GetConfigForClient, if not nil, is called after a ClientHello is
// received from a client. It may return a non-nil Config in order to
// change the Config that will be used to handle this connection. If
// the returned Config is nil, the original Config will be used. The
// Config returned by this callback may not be subsequently modified.
//
// If GetConfigForClient is nil, the Config passed to Server() will be
// used for all connections.
//
// Uniquely for the fields in the returned Config, session ticket keys
// will be duplicated from the original Config if not set.
// Specifically, if SetSessionTicketKeys was called on the original
// config but not on the returned config then the ticket keys from the
// original config will be copied into the new config before use.
// Otherwise, if SessionTicketKey was set in the original config but
// not in the returned config then it will be copied into the returned
// config before use. If neither of those cases applies then the key
// material from the returned config will be used for session tickets.
GetConfigForClient func(*ClientHelloInfo) (*Config, error)
// RootCAs defines the set of root certificate authorities // RootCAs defines the set of root certificate authorities
// that clients use when verifying server certificates. // that clients use when verifying server certificates.
...@@ -398,13 +418,17 @@ type Config struct { ...@@ -398,13 +418,17 @@ type Config struct {
serverInitOnce sync.Once // guards calling (*Config).serverInit serverInitOnce sync.Once // guards calling (*Config).serverInit
// mutex protects sessionTicketKeys // mutex protects sessionTicketKeys and originalConfig.
mutex sync.RWMutex mutex sync.RWMutex
// sessionTicketKeys contains zero or more ticket keys. If the length // sessionTicketKeys contains zero or more ticket keys. If the length
// is zero, SessionTicketsDisabled must be true. The first key is used // is zero, SessionTicketsDisabled must be true. The first key is used
// for new tickets and any subsequent keys can be used to decrypt old // for new tickets and any subsequent keys can be used to decrypt old
// tickets. // tickets.
sessionTicketKeys []ticketKey sessionTicketKeys []ticketKey
// originalConfig is set to the Config that was passed to Server if
// this Config is returned by a GetConfigForClient callback. It's used
// by serverInit in order to copy session ticket keys if needed.
originalConfig *Config
} }
// ticketKeyNameLen is the number of bytes of identifier that is prepended to // ticketKeyNameLen is the number of bytes of identifier that is prepended to
...@@ -434,12 +458,18 @@ func ticketKeyFromBytes(b [32]byte) (key ticketKey) { ...@@ -434,12 +458,18 @@ func ticketKeyFromBytes(b [32]byte) (key ticketKey) {
// Clone returns a shallow clone of c. // Clone returns a shallow clone of c.
// Only the exported fields are copied. // Only the exported fields are copied.
func (c *Config) Clone() *Config { func (c *Config) Clone() *Config {
var sessionTicketKeys []ticketKey
c.mutex.RLock()
sessionTicketKeys = c.sessionTicketKeys
c.mutex.RUnlock()
return &Config{ return &Config{
Rand: c.Rand, Rand: c.Rand,
Time: c.Time, Time: c.Time,
Certificates: c.Certificates, Certificates: c.Certificates,
NameToCertificate: c.NameToCertificate, NameToCertificate: c.NameToCertificate,
GetCertificate: c.GetCertificate, GetCertificate: c.GetCertificate,
GetConfigForClient: c.GetConfigForClient,
RootCAs: c.RootCAs, RootCAs: c.RootCAs,
NextProtos: c.NextProtos, NextProtos: c.NextProtos,
ServerName: c.ServerName, ServerName: c.ServerName,
...@@ -457,6 +487,8 @@ func (c *Config) Clone() *Config { ...@@ -457,6 +487,8 @@ func (c *Config) Clone() *Config {
DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled, DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
Renegotiation: c.Renegotiation, Renegotiation: c.Renegotiation,
KeyLogWriter: c.KeyLogWriter, KeyLogWriter: c.KeyLogWriter,
sessionTicketKeys: sessionTicketKeys,
// originalConfig is deliberately not duplicated.
} }
} }
...@@ -465,6 +497,11 @@ func (c *Config) serverInit() { ...@@ -465,6 +497,11 @@ func (c *Config) serverInit() {
return return
} }
var originalConfig *Config
c.mutex.Lock()
originalConfig, c.originalConfig = c.originalConfig, nil
c.mutex.Unlock()
alreadySet := false alreadySet := false
for _, b := range c.SessionTicketKey { for _, b := range c.SessionTicketKey {
if b != 0 { if b != 0 {
...@@ -474,13 +511,21 @@ func (c *Config) serverInit() { ...@@ -474,13 +511,21 @@ func (c *Config) serverInit() {
} }
if !alreadySet { if !alreadySet {
if _, err := io.ReadFull(c.rand(), c.SessionTicketKey[:]); err != nil { if originalConfig != nil {
copy(c.SessionTicketKey[:], originalConfig.SessionTicketKey[:])
} else if _, err := io.ReadFull(c.rand(), c.SessionTicketKey[:]); err != nil {
c.SessionTicketsDisabled = true c.SessionTicketsDisabled = true
return return
} }
} }
c.sessionTicketKeys = []ticketKey{ticketKeyFromBytes(c.SessionTicketKey)} if originalConfig != nil {
originalConfig.mutex.RLock()
c.sessionTicketKeys = originalConfig.sessionTicketKeys
originalConfig.mutex.RUnlock()
} else {
c.sessionTicketKeys = []ticketKey{ticketKeyFromBytes(c.SessionTicketKey)}
}
} }
func (c *Config) ticketKeys() []ticketKey { func (c *Config) ticketKeys() []ticketKey {
......
...@@ -37,11 +37,9 @@ type serverHandshakeState struct { ...@@ -37,11 +37,9 @@ type serverHandshakeState struct {
// serverHandshake performs a TLS handshake as a server. // serverHandshake performs a TLS handshake as a server.
// c.out.Mutex <= L; c.handshakeMutex <= L. // c.out.Mutex <= L; c.handshakeMutex <= L.
func (c *Conn) serverHandshake() error { func (c *Conn) serverHandshake() error {
config := c.config
// If this is the first server handshake, we generate a random key to // If this is the first server handshake, we generate a random key to
// encrypt the tickets with. // encrypt the tickets with.
config.serverInitOnce.Do(config.serverInit) c.config.serverInitOnce.Do(c.config.serverInit)
hs := serverHandshakeState{ hs := serverHandshakeState{
c: c, c: c,
...@@ -112,7 +110,6 @@ func (c *Conn) serverHandshake() error { ...@@ -112,7 +110,6 @@ func (c *Conn) serverHandshake() error {
// readClientHello reads a ClientHello message from the client and decides // readClientHello reads a ClientHello message from the client and decides
// whether we will perform session resumption. // whether we will perform session resumption.
func (hs *serverHandshakeState) readClientHello() (isResume bool, err error) { func (hs *serverHandshakeState) readClientHello() (isResume bool, err error) {
config := hs.c.config
c := hs.c c := hs.c
msg, err := c.readHandshake() msg, err := c.readHandshake()
...@@ -125,7 +122,29 @@ func (hs *serverHandshakeState) readClientHello() (isResume bool, err error) { ...@@ -125,7 +122,29 @@ func (hs *serverHandshakeState) readClientHello() (isResume bool, err error) {
c.sendAlert(alertUnexpectedMessage) c.sendAlert(alertUnexpectedMessage)
return false, unexpectedMessageError(hs.clientHello, msg) return false, unexpectedMessageError(hs.clientHello, msg)
} }
c.vers, ok = config.mutualVersion(hs.clientHello.vers)
clientHelloInfo := &ClientHelloInfo{
CipherSuites: hs.clientHello.cipherSuites,
ServerName: hs.clientHello.serverName,
SupportedCurves: hs.clientHello.supportedCurves,
SupportedPoints: hs.clientHello.supportedPoints,
}
if c.config.GetConfigForClient != nil {
if newConfig, err := c.config.GetConfigForClient(clientHelloInfo); err != nil {
c.sendAlert(alertInternalError)
return false, err
} else if newConfig != nil {
newConfig.mutex.Lock()
newConfig.originalConfig = c.config
newConfig.mutex.Unlock()
newConfig.serverInitOnce.Do(newConfig.serverInit)
c.config = newConfig
}
}
c.vers, ok = c.config.mutualVersion(hs.clientHello.vers)
if !ok { if !ok {
c.sendAlert(alertProtocolVersion) c.sendAlert(alertProtocolVersion)
return false, fmt.Errorf("tls: client offered an unsupported, maximum protocol version of %x", hs.clientHello.vers) return false, fmt.Errorf("tls: client offered an unsupported, maximum protocol version of %x", hs.clientHello.vers)
...@@ -135,7 +154,7 @@ func (hs *serverHandshakeState) readClientHello() (isResume bool, err error) { ...@@ -135,7 +154,7 @@ func (hs *serverHandshakeState) readClientHello() (isResume bool, err error) {
hs.hello = new(serverHelloMsg) hs.hello = new(serverHelloMsg)
supportedCurve := false supportedCurve := false
preferredCurves := config.curvePreferences() preferredCurves := c.config.curvePreferences()
Curves: Curves:
for _, curve := range hs.clientHello.supportedCurves { for _, curve := range hs.clientHello.supportedCurves {
for _, supported := range preferredCurves { for _, supported := range preferredCurves {
...@@ -171,7 +190,7 @@ Curves: ...@@ -171,7 +190,7 @@ Curves:
hs.hello.vers = c.vers hs.hello.vers = c.vers
hs.hello.random = make([]byte, 32) hs.hello.random = make([]byte, 32)
_, err = io.ReadFull(config.rand(), hs.hello.random) _, err = io.ReadFull(c.config.rand(), hs.hello.random)
if err != nil { if err != nil {
c.sendAlert(alertInternalError) c.sendAlert(alertInternalError)
return false, err return false, err
...@@ -196,20 +215,15 @@ Curves: ...@@ -196,20 +215,15 @@ Curves:
} else { } else {
// Although sending an empty NPN extension is reasonable, Firefox has // Although sending an empty NPN extension is reasonable, Firefox has
// had a bug around this. Best to send nothing at all if // had a bug around this. Best to send nothing at all if
// config.NextProtos is empty. See // c.config.NextProtos is empty. See
// https://golang.org/issue/5445. // https://golang.org/issue/5445.
if hs.clientHello.nextProtoNeg && len(config.NextProtos) > 0 { if hs.clientHello.nextProtoNeg && len(c.config.NextProtos) > 0 {
hs.hello.nextProtoNeg = true hs.hello.nextProtoNeg = true
hs.hello.nextProtos = config.NextProtos hs.hello.nextProtos = c.config.NextProtos
} }
} }
hs.cert, err = config.getCertificate(&ClientHelloInfo{ hs.cert, err = c.config.getCertificate(clientHelloInfo)
CipherSuites: hs.clientHello.cipherSuites,
ServerName: hs.clientHello.serverName,
SupportedCurves: hs.clientHello.supportedCurves,
SupportedPoints: hs.clientHello.supportedPoints,
})
if err != nil { if err != nil {
c.sendAlert(alertInternalError) c.sendAlert(alertInternalError)
return false, err return false, err
...@@ -354,18 +368,17 @@ func (hs *serverHandshakeState) doResumeHandshake() error { ...@@ -354,18 +368,17 @@ func (hs *serverHandshakeState) doResumeHandshake() error {
} }
func (hs *serverHandshakeState) doFullHandshake() error { func (hs *serverHandshakeState) doFullHandshake() error {
config := hs.c.config
c := hs.c c := hs.c
if hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0 { if hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0 {
hs.hello.ocspStapling = true hs.hello.ocspStapling = true
} }
hs.hello.ticketSupported = hs.clientHello.ticketSupported && !config.SessionTicketsDisabled hs.hello.ticketSupported = hs.clientHello.ticketSupported && !c.config.SessionTicketsDisabled
hs.hello.cipherSuite = hs.suite.id hs.hello.cipherSuite = hs.suite.id
hs.finishedHash = newFinishedHash(hs.c.vers, hs.suite) hs.finishedHash = newFinishedHash(hs.c.vers, hs.suite)
if config.ClientAuth == NoClientCert { if c.config.ClientAuth == NoClientCert {
// No need to keep a full record of the handshake if client // No need to keep a full record of the handshake if client
// certificates won't be used. // certificates won't be used.
hs.finishedHash.discardHandshakeBuffer() hs.finishedHash.discardHandshakeBuffer()
...@@ -394,7 +407,7 @@ func (hs *serverHandshakeState) doFullHandshake() error { ...@@ -394,7 +407,7 @@ func (hs *serverHandshakeState) doFullHandshake() error {
} }
keyAgreement := hs.suite.ka(c.vers) keyAgreement := hs.suite.ka(c.vers)
skx, err := keyAgreement.generateServerKeyExchange(config, hs.cert, hs.clientHello, hs.hello) skx, err := keyAgreement.generateServerKeyExchange(c.config, hs.cert, hs.clientHello, hs.hello)
if err != nil { if err != nil {
c.sendAlert(alertHandshakeFailure) c.sendAlert(alertHandshakeFailure)
return err return err
...@@ -406,7 +419,7 @@ func (hs *serverHandshakeState) doFullHandshake() error { ...@@ -406,7 +419,7 @@ func (hs *serverHandshakeState) doFullHandshake() error {
} }
} }
if config.ClientAuth >= RequestClientCert { if c.config.ClientAuth >= RequestClientCert {
// Request a client certificate // Request a client certificate
certReq := new(certificateRequestMsg) certReq := new(certificateRequestMsg)
certReq.certificateTypes = []byte{ certReq.certificateTypes = []byte{
...@@ -423,8 +436,8 @@ func (hs *serverHandshakeState) doFullHandshake() error { ...@@ -423,8 +436,8 @@ func (hs *serverHandshakeState) doFullHandshake() error {
// to our request. When we know the CAs we trust, then // to our request. When we know the CAs we trust, then
// we can send them down, so that the client can choose // we can send them down, so that the client can choose
// an appropriate certificate to give to us. // an appropriate certificate to give to us.
if config.ClientCAs != nil { if c.config.ClientCAs != nil {
certReq.certificateAuthorities = config.ClientCAs.Subjects() certReq.certificateAuthorities = c.config.ClientCAs.Subjects()
} }
hs.finishedHash.Write(certReq.marshal()) hs.finishedHash.Write(certReq.marshal())
if _, err := c.writeRecord(recordTypeHandshake, certReq.marshal()); err != nil { if _, err := c.writeRecord(recordTypeHandshake, certReq.marshal()); err != nil {
...@@ -452,7 +465,7 @@ func (hs *serverHandshakeState) doFullHandshake() error { ...@@ -452,7 +465,7 @@ func (hs *serverHandshakeState) doFullHandshake() error {
var ok bool var ok bool
// If we requested a client certificate, then the client must send a // If we requested a client certificate, then the client must send a
// certificate message, even if it's empty. // certificate message, even if it's empty.
if config.ClientAuth >= RequestClientCert { if c.config.ClientAuth >= RequestClientCert {
if certMsg, ok = msg.(*certificateMsg); !ok { if certMsg, ok = msg.(*certificateMsg); !ok {
c.sendAlert(alertUnexpectedMessage) c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(certMsg, msg) return unexpectedMessageError(certMsg, msg)
...@@ -461,7 +474,7 @@ func (hs *serverHandshakeState) doFullHandshake() error { ...@@ -461,7 +474,7 @@ func (hs *serverHandshakeState) doFullHandshake() error {
if len(certMsg.certificates) == 0 { if len(certMsg.certificates) == 0 {
// The client didn't actually send a certificate // The client didn't actually send a certificate
switch config.ClientAuth { switch c.config.ClientAuth {
case RequireAnyClientCert, RequireAndVerifyClientCert: case RequireAnyClientCert, RequireAndVerifyClientCert:
c.sendAlert(alertBadCertificate) c.sendAlert(alertBadCertificate)
return errors.New("tls: client didn't provide a certificate") return errors.New("tls: client didn't provide a certificate")
...@@ -487,13 +500,13 @@ func (hs *serverHandshakeState) doFullHandshake() error { ...@@ -487,13 +500,13 @@ func (hs *serverHandshakeState) doFullHandshake() error {
} }
hs.finishedHash.Write(ckx.marshal()) hs.finishedHash.Write(ckx.marshal())
preMasterSecret, err := keyAgreement.processClientKeyExchange(config, hs.cert, ckx, c.vers) preMasterSecret, err := keyAgreement.processClientKeyExchange(c.config, hs.cert, ckx, c.vers)
if err != nil { if err != nil {
c.sendAlert(alertHandshakeFailure) c.sendAlert(alertHandshakeFailure)
return err return err
} }
hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.clientHello.random, hs.hello.random) hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.clientHello.random, hs.hello.random)
if err := config.writeKeyLog(hs.clientHello.random, hs.masterSecret); err != nil { if err := c.config.writeKeyLog(hs.clientHello.random, hs.masterSecret); err != nil {
c.sendAlert(alertInternalError) c.sendAlert(alertInternalError)
return err return err
} }
......
...@@ -1141,6 +1141,141 @@ func TestSNIGivenOnFailure(t *testing.T) { ...@@ -1141,6 +1141,141 @@ func TestSNIGivenOnFailure(t *testing.T) {
} }
} }
var getConfigForClientTests = []struct {
setup func(config *Config)
callback func(clientHello *ClientHelloInfo) (*Config, error)
errorSubstring string
verify func(config *Config) error
}{
{
nil,
func(clientHello *ClientHelloInfo) (*Config, error) {
return nil, nil
},
"",
nil,
},
{
nil,
func(clientHello *ClientHelloInfo) (*Config, error) {
return nil, errors.New("should bubble up")
},
"should bubble up",
nil,
},
{
nil,
func(clientHello *ClientHelloInfo) (*Config, error) {
config := testConfig.Clone()
// Setting a maximum version of TLS 1.1 should cause
// the handshake to fail.
config.MaxVersion = VersionTLS11
return config, nil
},
"version 301 when expecting version 302",
nil,
},
{
func(config *Config) {
for i := range config.SessionTicketKey {
config.SessionTicketKey[i] = byte(i)
}
config.sessionTicketKeys = nil
},
func(clientHello *ClientHelloInfo) (*Config, error) {
config := testConfig.Clone()
for i := range config.SessionTicketKey {
config.SessionTicketKey[i] = 0
}
config.sessionTicketKeys = nil
return config, nil
},
"",
func(config *Config) error {
// The value of SessionTicketKey should have been
// duplicated into the per-connection Config.
for i := range config.SessionTicketKey {
if b := config.SessionTicketKey[i]; b != byte(i) {
return fmt.Errorf("SessionTicketKey was not duplicated from original Config: byte %d has value %d", i, b)
}
}
return nil
},
},
{
func(config *Config) {
var dummyKey [32]byte
for i := range dummyKey {
dummyKey[i] = byte(i)
}
config.SetSessionTicketKeys([][32]byte{dummyKey})
},
func(clientHello *ClientHelloInfo) (*Config, error) {
config := testConfig.Clone()
config.sessionTicketKeys = nil
return config, nil
},
"",
func(config *Config) error {
// The session ticket keys should have been duplicated
// into the per-connection Config.
if l := len(config.sessionTicketKeys); l != 1 {
return fmt.Errorf("got len(sessionTicketKeys) == %d, wanted 1", l)
}
return nil
},
},
}
func TestGetConfigForClient(t *testing.T) {
serverConfig := testConfig.Clone()
clientConfig := testConfig.Clone()
clientConfig.MinVersion = VersionTLS12
for i, test := range getConfigForClientTests {
if test.setup != nil {
test.setup(serverConfig)
}
var configReturned *Config
serverConfig.GetConfigForClient = func(clientHello *ClientHelloInfo) (*Config, error) {
config, err := test.callback(clientHello)
configReturned = config
return config, err
}
c, s := net.Pipe()
done := make(chan error)
go func() {
defer s.Close()
done <- Server(s, serverConfig).Handshake()
}()
clientErr := Client(c, clientConfig).Handshake()
c.Close()
serverErr := <-done
if len(test.errorSubstring) == 0 {
if serverErr != nil || clientErr != nil {
t.Errorf("%#d: expected no error but got serverErr: %q, clientErr: %q", i, serverErr, clientErr)
}
if test.verify != nil {
if err := test.verify(configReturned); err != nil {
t.Errorf("#%d: verify returned error: %v", i, err)
}
}
} else {
if serverErr == nil {
t.Errorf("%#d: expected error containing %q but got no error", i, test.errorSubstring)
} else if !strings.Contains(serverErr.Error(), test.errorSubstring) {
t.Errorf("%#d: expected error to contain %q but it was %q", i, test.errorSubstring, serverErr)
}
}
}
}
func bigFromString(s string) *big.Int { func bigFromString(s string) *big.Int {
ret := new(big.Int) ret := new(big.Int)
ret.SetString(s, 10) ret.SetString(s, 10)
......
...@@ -477,7 +477,7 @@ func TestClone(t *testing.T) { ...@@ -477,7 +477,7 @@ func TestClone(t *testing.T) {
case "Rand": case "Rand":
f.Set(reflect.ValueOf(io.Reader(os.Stdin))) f.Set(reflect.ValueOf(io.Reader(os.Stdin)))
continue continue
case "Time", "GetCertificate": case "Time", "GetCertificate", "GetConfigForClient":
// DeepEqual can't compare functions. // DeepEqual can't compare functions.
continue continue
case "Certificates": case "Certificates":
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment