Commit c581ec49 authored by Jeff R. Allen's avatar Jeff R. Allen Committed by Adam Langley

crypto/tls: Improve TLS Client Authentication

Fix incorrect marshal/unmarshal of certificateRequest.
Add support for configuring client-auth on the server side.
Fix the certificate selection in the client side.
Update generate_cert.go to new time package

Fixes #2521.

R=krautz, agl, bradfitz
CC=golang-dev, mikkel
https://golang.org/cl/5448093
parent 8f1cb093
...@@ -111,6 +111,18 @@ type ConnectionState struct { ...@@ -111,6 +111,18 @@ type ConnectionState struct {
VerifiedChains [][]*x509.Certificate VerifiedChains [][]*x509.Certificate
} }
// ClientAuthType declares the policy the server will follow for
// TLS Client Authentication.
type ClientAuthType int
const (
NoClientCert ClientAuthType = iota
RequestClientCert
RequireAnyClientCert
VerifyClientCertIfGiven
RequireAndVerifyClientCert
)
// A Config structure is used to configure a TLS client or server. After one // A Config structure is used to configure a TLS client or server. After one
// has been passed to a TLS function it must not be modified. // has been passed to a TLS function it must not be modified.
type Config struct { type Config struct {
...@@ -120,7 +132,7 @@ type Config struct { ...@@ -120,7 +132,7 @@ type Config struct {
Rand io.Reader Rand io.Reader
// Time returns the current time as the number of seconds since the epoch. // Time returns the current time as the number of seconds since the epoch.
// If Time is nil, TLS uses the system time.Seconds. // If Time is nil, TLS uses time.Now.
Time func() time.Time Time func() time.Time
// Certificates contains one or more certificate chains // Certificates contains one or more certificate chains
...@@ -148,11 +160,14 @@ type Config struct { ...@@ -148,11 +160,14 @@ type Config struct {
// hosting. // hosting.
ServerName string ServerName string
// AuthenticateClient controls whether a server will request a certificate // ClientAuth determines the server's policy for
// from the client. It does not require that the client send a // TLS Client Authentication. The default is NoClientCert.
// certificate nor does it require that the certificate sent be ClientAuth ClientAuthType
// anything more than self-signed.
AuthenticateClient bool // ClientCAs defines the set of root certificate authorities
// that servers use if required to verify a client certificate
// by the policy in ClientAuth.
ClientCAs *x509.CertPool
// InsecureSkipVerify controls whether a client verifies the // InsecureSkipVerify controls whether a client verifies the
// server's certificate chain and host name. // server's certificate chain and host name.
...@@ -259,6 +274,11 @@ type Certificate struct { ...@@ -259,6 +274,11 @@ type Certificate struct {
// OCSPStaple contains an optional OCSP response which will be served // OCSPStaple contains an optional OCSP response which will be served
// to clients that request it. // to clients that request it.
OCSPStaple []byte OCSPStaple []byte
// Leaf is the parsed form of the leaf certificate, which may be
// initialized using x509.ParseCertificate to reduce per-handshake
// processing for TLS clients doing client authentication. If nil, the
// leaf certificate will be parsed as needed.
Leaf *x509.Certificate
} }
// A TLS record. // A TLS record.
......
...@@ -5,12 +5,14 @@ ...@@ -5,12 +5,14 @@
package tls package tls
import ( import (
"bytes"
"crypto" "crypto"
"crypto/rsa" "crypto/rsa"
"crypto/subtle" "crypto/subtle"
"crypto/x509" "crypto/x509"
"errors" "errors"
"io" "io"
"strconv"
) )
func (c *Conn) clientHandshake() error { func (c *Conn) clientHandshake() error {
...@@ -162,10 +164,23 @@ func (c *Conn) clientHandshake() error { ...@@ -162,10 +164,23 @@ func (c *Conn) clientHandshake() error {
} }
} }
transmitCert := false var certToSend *Certificate
certReq, ok := msg.(*certificateRequestMsg) certReq, ok := msg.(*certificateRequestMsg)
if ok { if ok {
// We only accept certificates with RSA keys. // RFC 4346 on the certificateAuthorities field:
// A list of the distinguished names of acceptable certificate
// authorities. These distinguished names may specify a desired
// distinguished name for a root CA or for a subordinate CA;
// thus, this message can be used to describe both known roots
// and a desired authorization space. If the
// certificate_authorities list is empty then the client MAY
// send any certificate of the appropriate
// ClientCertificateType, unless there is some external
// arrangement to the contrary.
finishedHash.Write(certReq.marshal())
// For now, we only know how to sign challenges with RSA
rsaAvail := false rsaAvail := false
for _, certType := range certReq.certificateTypes { for _, certType := range certReq.certificateTypes {
if certType == certTypeRSASign { if certType == certTypeRSASign {
...@@ -174,23 +189,41 @@ func (c *Conn) clientHandshake() error { ...@@ -174,23 +189,41 @@ func (c *Conn) clientHandshake() error {
} }
} }
// For now, only send a certificate back if the server gives us an // We need to search our list of client certs for one
// empty list of certificateAuthorities. // where SignatureAlgorithm is RSA and the Issuer is in
// // certReq.certificateAuthorities
// RFC 4346 on the certificateAuthorities field: findCert:
// A list of the distinguished names of acceptable certificate for i, cert := range c.config.Certificates {
// authorities. These distinguished names may specify a desired if !rsaAvail {
// distinguished name for a root CA or for a subordinate CA; thus, continue
// this message can be used to describe both known roots and a
// desired authorization space. If the certificate_authorities
// list is empty then the client MAY send any certificate of the
// appropriate ClientCertificateType, unless there is some
// external arrangement to the contrary.
if rsaAvail && len(certReq.certificateAuthorities) == 0 {
transmitCert = true
} }
finishedHash.Write(certReq.marshal()) leaf := cert.Leaf
if leaf == nil {
if leaf, err = x509.ParseCertificate(cert.Certificate[0]); err != nil {
c.sendAlert(alertInternalError)
return errors.New("tls: failed to parse client certificate #" + strconv.Itoa(i) + ": " + err.Error())
}
}
if leaf.PublicKeyAlgorithm != x509.RSA {
continue
}
if len(certReq.certificateAuthorities) == 0 {
// they gave us an empty list, so just take the
// first RSA cert from c.config.Certificates
certToSend = &cert
break
}
for _, ca := range certReq.certificateAuthorities {
if bytes.Equal(leaf.RawIssuer, ca) {
certToSend = &cert
break findCert
}
}
}
msg, err = c.readHandshake() msg, err = c.readHandshake()
if err != nil { if err != nil {
...@@ -204,17 +237,9 @@ func (c *Conn) clientHandshake() error { ...@@ -204,17 +237,9 @@ func (c *Conn) clientHandshake() error {
} }
finishedHash.Write(shd.marshal()) finishedHash.Write(shd.marshal())
var cert *x509.Certificate if certToSend != nil {
if transmitCert {
certMsg = new(certificateMsg) certMsg = new(certificateMsg)
if len(c.config.Certificates) > 0 { certMsg.certificates = certToSend.Certificate
cert, err = x509.ParseCertificate(c.config.Certificates[0].Certificate[0])
if err == nil && cert.PublicKeyAlgorithm == x509.RSA {
certMsg.certificates = c.config.Certificates[0].Certificate
} else {
cert = nil
}
}
finishedHash.Write(certMsg.marshal()) finishedHash.Write(certMsg.marshal())
c.writeRecord(recordTypeHandshake, certMsg.marshal()) c.writeRecord(recordTypeHandshake, certMsg.marshal())
} }
...@@ -229,7 +254,7 @@ func (c *Conn) clientHandshake() error { ...@@ -229,7 +254,7 @@ func (c *Conn) clientHandshake() error {
c.writeRecord(recordTypeHandshake, ckx.marshal()) c.writeRecord(recordTypeHandshake, ckx.marshal())
} }
if cert != nil { if certToSend != nil {
certVerify := new(certificateVerifyMsg) certVerify := new(certificateVerifyMsg)
digest := make([]byte, 0, 36) digest := make([]byte, 0, 36)
digest = finishedHash.serverMD5.Sum(digest) digest = finishedHash.serverMD5.Sum(digest)
......
...@@ -881,9 +881,11 @@ func (m *certificateRequestMsg) marshal() (x []byte) { ...@@ -881,9 +881,11 @@ func (m *certificateRequestMsg) marshal() (x []byte) {
// See http://tools.ietf.org/html/rfc4346#section-7.4.4 // See http://tools.ietf.org/html/rfc4346#section-7.4.4
length := 1 + len(m.certificateTypes) + 2 length := 1 + len(m.certificateTypes) + 2
casLength := 0
for _, ca := range m.certificateAuthorities { for _, ca := range m.certificateAuthorities {
length += 2 + len(ca) casLength += 2 + len(ca)
} }
length += casLength
x = make([]byte, 4+length) x = make([]byte, 4+length)
x[0] = typeCertificateRequest x[0] = typeCertificateRequest
...@@ -895,10 +897,8 @@ func (m *certificateRequestMsg) marshal() (x []byte) { ...@@ -895,10 +897,8 @@ func (m *certificateRequestMsg) marshal() (x []byte) {
copy(x[5:], m.certificateTypes) copy(x[5:], m.certificateTypes)
y := x[5+len(m.certificateTypes):] y := x[5+len(m.certificateTypes):]
y[0] = uint8(casLength >> 8)
numCA := len(m.certificateAuthorities) y[1] = uint8(casLength)
y[0] = uint8(numCA >> 8)
y[1] = uint8(numCA)
y = y[2:] y = y[2:]
for _, ca := range m.certificateAuthorities { for _, ca := range m.certificateAuthorities {
y[0] = uint8(len(ca) >> 8) y[0] = uint8(len(ca) >> 8)
...@@ -909,7 +909,6 @@ func (m *certificateRequestMsg) marshal() (x []byte) { ...@@ -909,7 +909,6 @@ func (m *certificateRequestMsg) marshal() (x []byte) {
} }
m.raw = x m.raw = x
return return
} }
...@@ -937,31 +936,34 @@ func (m *certificateRequestMsg) unmarshal(data []byte) bool { ...@@ -937,31 +936,34 @@ func (m *certificateRequestMsg) unmarshal(data []byte) bool {
} }
data = data[numCertTypes:] data = data[numCertTypes:]
if len(data) < 2 { if len(data) < 2 {
return false return false
} }
casLength := uint16(data[0])<<8 | uint16(data[1])
numCAs := uint16(data[0])<<16 | uint16(data[1])
data = data[2:] data = data[2:]
if len(data) < int(casLength) {
m.certificateAuthorities = make([][]byte, numCAs)
for i := uint16(0); i < numCAs; i++ {
if len(data) < 2 {
return false return false
} }
caLen := uint16(data[0])<<16 | uint16(data[1]) cas := make([]byte, casLength)
copy(cas, data)
data = data[casLength:]
data = data[2:] m.certificateAuthorities = nil
if len(data) < int(caLen) { for len(cas) > 0 {
if len(cas) < 2 {
return false return false
} }
caLen := uint16(cas[0])<<8 | uint16(cas[1])
cas = cas[2:]
ca := make([]byte, caLen) if len(cas) < int(caLen) {
copy(ca, data) return false
m.certificateAuthorities[i] = ca
data = data[caLen:]
} }
m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen])
cas = cas[caLen:]
}
if len(data) > 0 { if len(data) > 0 {
return false return false
} }
......
...@@ -150,14 +150,19 @@ FindCipherSuite: ...@@ -150,14 +150,19 @@ FindCipherSuite:
c.writeRecord(recordTypeHandshake, skx.marshal()) c.writeRecord(recordTypeHandshake, skx.marshal())
} }
if config.AuthenticateClient { if config.ClientAuth >= RequestClientCert {
// Request a client certificate // Request a client certificate
certReq := new(certificateRequestMsg) certReq := new(certificateRequestMsg)
certReq.certificateTypes = []byte{certTypeRSASign} certReq.certificateTypes = []byte{certTypeRSASign}
// An empty list of certificateAuthorities signals to // An empty list of certificateAuthorities signals to
// the client that it may send any certificate in response // the client that it may send any certificate in response
// to our request. // to our request. When we know the CAs we trust, then
// we can send them down, so that the client can choose
// an appropriate certificate to give to us.
if config.ClientCAs != nil {
certReq.certificateAuthorities = config.ClientCAs.Subjects()
}
finishedHash.Write(certReq.marshal()) finishedHash.Write(certReq.marshal())
c.writeRecord(recordTypeHandshake, certReq.marshal()) c.writeRecord(recordTypeHandshake, certReq.marshal())
} }
...@@ -166,52 +171,87 @@ FindCipherSuite: ...@@ -166,52 +171,87 @@ FindCipherSuite:
finishedHash.Write(helloDone.marshal()) finishedHash.Write(helloDone.marshal())
c.writeRecord(recordTypeHandshake, helloDone.marshal()) c.writeRecord(recordTypeHandshake, helloDone.marshal())
var pub *rsa.PublicKey var pub *rsa.PublicKey // public key for client auth, if any
if config.AuthenticateClient {
// Get client certificate
msg, err = c.readHandshake() msg, err = c.readHandshake()
if err != nil { if err != nil {
return err return err
} }
certMsg, ok = msg.(*certificateMsg)
if !ok { // If we requested a client certificate, then the client must send a
return c.sendAlert(alertUnexpectedMessage) // certificate message, even if it's empty.
if config.ClientAuth >= RequestClientCert {
if certMsg, ok = msg.(*certificateMsg); !ok {
return c.sendAlert(alertHandshakeFailure)
} }
finishedHash.Write(certMsg.marshal()) finishedHash.Write(certMsg.marshal())
if len(certMsg.certificates) == 0 {
// The client didn't actually send a certificate
switch config.ClientAuth {
case RequireAnyClientCert, RequireAndVerifyClientCert:
c.sendAlert(alertBadCertificate)
return errors.New("tls: client didn't provide a certificate")
}
}
certs := make([]*x509.Certificate, len(certMsg.certificates)) certs := make([]*x509.Certificate, len(certMsg.certificates))
for i, asn1Data := range certMsg.certificates { for i, asn1Data := range certMsg.certificates {
cert, err := x509.ParseCertificate(asn1Data) if certs[i], err = x509.ParseCertificate(asn1Data); err != nil {
if err != nil {
c.sendAlert(alertBadCertificate) c.sendAlert(alertBadCertificate)
return errors.New("could not parse client's certificate: " + err.Error()) return errors.New("tls: failed to parse client certificate: " + err.Error())
}
}
if c.config.ClientAuth >= VerifyClientCertIfGiven && len(certs) > 0 {
opts := x509.VerifyOptions{
Roots: c.config.ClientCAs,
CurrentTime: c.config.time(),
Intermediates: x509.NewCertPool(),
}
for i, cert := range certs {
if i == 0 {
continue
} }
certs[i] = cert opts.Intermediates.AddCert(cert)
} }
// TODO(agl): do better validation of certs: max path length, name restrictions etc. chains, err := certs[0].Verify(opts)
for i := 1; i < len(certs); i++ { if err != nil {
if err := certs[i-1].CheckSignatureFrom(certs[i]); err != nil {
c.sendAlert(alertBadCertificate) c.sendAlert(alertBadCertificate)
return errors.New("could not validate certificate signature: " + err.Error()) return errors.New("tls: failed to verify client's certificate: " + err.Error())
} }
ok := false
for _, ku := range certs[0].ExtKeyUsage {
if ku == x509.ExtKeyUsageClientAuth {
ok = true
break
}
}
if !ok {
c.sendAlert(alertHandshakeFailure)
return errors.New("tls: client's certificate's extended key usage doesn't permit it to be used for client authentication")
}
c.verifiedChains = chains
} }
if len(certs) > 0 { if len(certs) > 0 {
key, ok := certs[0].PublicKey.(*rsa.PublicKey) if pub, ok = certs[0].PublicKey.(*rsa.PublicKey); !ok {
if !ok {
return c.sendAlert(alertUnsupportedCertificate) return c.sendAlert(alertUnsupportedCertificate)
} }
pub = key
c.peerCertificates = certs c.peerCertificates = certs
} }
}
// Get client key exchange
msg, err = c.readHandshake() msg, err = c.readHandshake()
if err != nil { if err != nil {
return err return err
} }
}
// Get client key exchange
ckx, ok := msg.(*clientKeyExchangeMsg) ckx, ok := msg.(*clientKeyExchangeMsg)
if !ok { if !ok {
return c.sendAlert(alertUnexpectedMessage) return c.sendAlert(alertUnexpectedMessage)
......
This diff is collapsed.
...@@ -120,7 +120,7 @@ func Dial(network, addr string, config *Config) (*Conn, error) { ...@@ -120,7 +120,7 @@ func Dial(network, addr string, config *Config) (*Conn, error) {
// LoadX509KeyPair reads and parses a public/private key pair from a pair of // LoadX509KeyPair reads and parses a public/private key pair from a pair of
// files. The files must contain PEM encoded data. // files. The files must contain PEM encoded data.
func LoadX509KeyPair(certFile string, keyFile string) (cert Certificate, err error) { func LoadX509KeyPair(certFile, keyFile string) (cert Certificate, err error) {
certPEMBlock, err := ioutil.ReadFile(certFile) certPEMBlock, err := ioutil.ReadFile(certFile)
if err != nil { if err != nil {
return return
......
...@@ -101,3 +101,13 @@ func (s *CertPool) AppendCertsFromPEM(pemCerts []byte) (ok bool) { ...@@ -101,3 +101,13 @@ func (s *CertPool) AppendCertsFromPEM(pemCerts []byte) (ok bool) {
return return
} }
// Subjects returns a list of the DER-encoded subjects of
// all of the certificates in the pool.
func (s *CertPool) Subjects() (res [][]byte) {
res = make([][]byte, len(s.certs))
for i, c := range s.certs {
res[i] = c.RawSubject
}
return
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment