Commit e6e8b723 authored by Adam Langley's avatar Adam Langley

crypto/tls: don't always use the default private key.

When SNI based certificate selection is enabled, we previously used
the default private key even if we selected a non-default certificate.

Fixes #3367.

R=golang-dev, bradfitz
CC=golang-dev
https://golang.org/cl/5987058
parent 55af51d5
...@@ -23,8 +23,8 @@ type keyAgreement interface { ...@@ -23,8 +23,8 @@ type keyAgreement interface {
// In the case that the key agreement protocol doesn't use a // In the case that the key agreement protocol doesn't use a
// ServerKeyExchange message, generateServerKeyExchange can return nil, // ServerKeyExchange message, generateServerKeyExchange can return nil,
// nil. // nil.
generateServerKeyExchange(*Config, *clientHelloMsg, *serverHelloMsg) (*serverKeyExchangeMsg, error) generateServerKeyExchange(*Config, *Certificate, *clientHelloMsg, *serverHelloMsg) (*serverKeyExchangeMsg, error)
processClientKeyExchange(*Config, *clientKeyExchangeMsg, uint16) ([]byte, error) processClientKeyExchange(*Config, *Certificate, *clientKeyExchangeMsg, uint16) ([]byte, error)
// On the client side, the next two methods are called in order. // On the client side, the next two methods are called in order.
......
...@@ -112,37 +112,38 @@ FindCipherSuite: ...@@ -112,37 +112,38 @@ FindCipherSuite:
hello.nextProtoNeg = true hello.nextProtoNeg = true
hello.nextProtos = config.NextProtos hello.nextProtos = config.NextProtos
} }
if clientHello.ocspStapling && len(config.Certificates[0].OCSPStaple) > 0 {
hello.ocspStapling = true
}
finishedHash.Write(hello.marshal())
c.writeRecord(recordTypeHandshake, hello.marshal())
if len(config.Certificates) == 0 { if len(config.Certificates) == 0 {
return c.sendAlert(alertInternalError) return c.sendAlert(alertInternalError)
} }
cert := &config.Certificates[0]
certMsg := new(certificateMsg)
if len(clientHello.serverName) > 0 { if len(clientHello.serverName) > 0 {
c.serverName = clientHello.serverName c.serverName = clientHello.serverName
certMsg.certificates = config.getCertificateForName(clientHello.serverName).Certificate cert = config.getCertificateForName(clientHello.serverName)
} else {
certMsg.certificates = config.Certificates[0].Certificate
} }
if clientHello.ocspStapling && len(cert.OCSPStaple) > 0 {
hello.ocspStapling = true
}
finishedHash.Write(hello.marshal())
c.writeRecord(recordTypeHandshake, hello.marshal())
certMsg := new(certificateMsg)
certMsg.certificates = cert.Certificate
finishedHash.Write(certMsg.marshal()) finishedHash.Write(certMsg.marshal())
c.writeRecord(recordTypeHandshake, certMsg.marshal()) c.writeRecord(recordTypeHandshake, certMsg.marshal())
if hello.ocspStapling { if hello.ocspStapling {
certStatus := new(certificateStatusMsg) certStatus := new(certificateStatusMsg)
certStatus.statusType = statusTypeOCSP certStatus.statusType = statusTypeOCSP
certStatus.response = config.Certificates[0].OCSPStaple certStatus.response = cert.OCSPStaple
finishedHash.Write(certStatus.marshal()) finishedHash.Write(certStatus.marshal())
c.writeRecord(recordTypeHandshake, certStatus.marshal()) c.writeRecord(recordTypeHandshake, certStatus.marshal())
} }
keyAgreement := suite.ka() keyAgreement := suite.ka()
skx, err := keyAgreement.generateServerKeyExchange(config, clientHello, hello) skx, err := keyAgreement.generateServerKeyExchange(config, cert, clientHello, hello)
if err != nil { if err != nil {
c.sendAlert(alertHandshakeFailure) c.sendAlert(alertHandshakeFailure)
return err return err
...@@ -288,7 +289,7 @@ FindCipherSuite: ...@@ -288,7 +289,7 @@ FindCipherSuite:
finishedHash.Write(certVerify.marshal()) finishedHash.Write(certVerify.marshal())
} }
preMasterSecret, err := keyAgreement.processClientKeyExchange(config, ckx, c.vers) preMasterSecret, err := keyAgreement.processClientKeyExchange(config, cert, ckx, c.vers)
if err != nil { if err != nil {
c.sendAlert(alertHandshakeFailure) c.sendAlert(alertHandshakeFailure)
return err return err
......
This diff is collapsed.
...@@ -20,11 +20,11 @@ import ( ...@@ -20,11 +20,11 @@ import (
// encrypts the pre-master secret to the server's public key. // encrypts the pre-master secret to the server's public key.
type rsaKeyAgreement struct{} type rsaKeyAgreement struct{}
func (ka rsaKeyAgreement) generateServerKeyExchange(config *Config, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) { func (ka rsaKeyAgreement) generateServerKeyExchange(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) {
return nil, nil return nil, nil
} }
func (ka rsaKeyAgreement) processClientKeyExchange(config *Config, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) { func (ka rsaKeyAgreement) processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) {
preMasterSecret := make([]byte, 48) preMasterSecret := make([]byte, 48)
_, err := io.ReadFull(config.rand(), preMasterSecret[2:]) _, err := io.ReadFull(config.rand(), preMasterSecret[2:])
if err != nil { if err != nil {
...@@ -44,7 +44,7 @@ func (ka rsaKeyAgreement) processClientKeyExchange(config *Config, ckx *clientKe ...@@ -44,7 +44,7 @@ func (ka rsaKeyAgreement) processClientKeyExchange(config *Config, ckx *clientKe
ciphertext = ckx.ciphertext[2:] ciphertext = ckx.ciphertext[2:]
} }
err = rsa.DecryptPKCS1v15SessionKey(config.rand(), config.Certificates[0].PrivateKey.(*rsa.PrivateKey), ciphertext, preMasterSecret) err = rsa.DecryptPKCS1v15SessionKey(config.rand(), cert.PrivateKey.(*rsa.PrivateKey), ciphertext, preMasterSecret)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -109,7 +109,7 @@ type ecdheRSAKeyAgreement struct { ...@@ -109,7 +109,7 @@ type ecdheRSAKeyAgreement struct {
x, y *big.Int x, y *big.Int
} }
func (ka *ecdheRSAKeyAgreement) generateServerKeyExchange(config *Config, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) { func (ka *ecdheRSAKeyAgreement) generateServerKeyExchange(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) {
var curveid uint16 var curveid uint16
Curve: Curve:
...@@ -151,7 +151,7 @@ Curve: ...@@ -151,7 +151,7 @@ Curve:
copy(serverECDHParams[4:], ecdhePublic) copy(serverECDHParams[4:], ecdhePublic)
md5sha1 := md5SHA1Hash(clientHello.random, hello.random, serverECDHParams) md5sha1 := md5SHA1Hash(clientHello.random, hello.random, serverECDHParams)
sig, err := rsa.SignPKCS1v15(config.rand(), config.Certificates[0].PrivateKey.(*rsa.PrivateKey), crypto.MD5SHA1, md5sha1) sig, err := rsa.SignPKCS1v15(config.rand(), cert.PrivateKey.(*rsa.PrivateKey), crypto.MD5SHA1, md5sha1)
if err != nil { if err != nil {
return nil, errors.New("failed to sign ECDHE parameters: " + err.Error()) return nil, errors.New("failed to sign ECDHE parameters: " + err.Error())
} }
...@@ -167,7 +167,7 @@ Curve: ...@@ -167,7 +167,7 @@ Curve:
return skx, nil return skx, nil
} }
func (ka *ecdheRSAKeyAgreement) processClientKeyExchange(config *Config, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) { func (ka *ecdheRSAKeyAgreement) processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) {
if len(ckx.ciphertext) == 0 || int(ckx.ciphertext[0]) != len(ckx.ciphertext)-1 { if len(ckx.ciphertext) == 0 || int(ckx.ciphertext[0]) != len(ckx.ciphertext)-1 {
return nil, errors.New("bad ClientKeyExchange") return nil, errors.New("bad ClientKeyExchange")
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment