Commit 9239f3cb authored by Jiri Tyr's avatar Jiri Tyr Committed by Toby Allen

Adding TLS client cert placeholders (#2217)

* Adding TLS client cert placeholders

* Use function to get the peer certificate

* Changing SHA1 to SHA256

* Use UTC instead of GMT

* Adding tests

* Adding getters for Protocol and Cipher
parent b7a7fd46
...@@ -16,6 +16,10 @@ package httpserver ...@@ -16,6 +16,10 @@ package httpserver
import ( import (
"bytes" "bytes"
"crypto/sha256"
"crypto/x509"
"encoding/pem"
"fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"net" "net"
...@@ -243,6 +247,15 @@ func round(d, r time.Duration) time.Duration { ...@@ -243,6 +247,15 @@ func round(d, r time.Duration) time.Duration {
return d return d
} }
// getPeerCert returns peer certificate
func (r *replacer) getPeerCert() *x509.Certificate {
if r.request.TLS != nil && len(r.request.TLS.PeerCertificates) > 0 {
return r.request.TLS.PeerCertificates[0]
}
return nil
}
// getSubstitution retrieves value from corresponding key // getSubstitution retrieves value from corresponding key
func (r *replacer) getSubstitution(key string) string { func (r *replacer) getSubstitution(key string) string {
// search custom replacements first // search custom replacements first
...@@ -413,22 +426,80 @@ func (r *replacer) getSubstitution(key string) string { ...@@ -413,22 +426,80 @@ func (r *replacer) getSubstitution(key string) string {
return strconv.FormatInt(convertToMilliseconds(elapsedDuration), 10) return strconv.FormatInt(convertToMilliseconds(elapsedDuration), 10)
case "{tls_protocol}": case "{tls_protocol}":
if r.request.TLS != nil { if r.request.TLS != nil {
for k, v := range caddytls.SupportedProtocols { if name, err := caddytls.GetSupportedProtocolName(r.request.TLS.Version); err == nil {
if v == r.request.TLS.Version { return name
return k } else {
}
}
return "tls" // this should never happen, but guard in case return "tls" // this should never happen, but guard in case
} }
}
return r.emptyValue // because not using a secure channel return r.emptyValue // because not using a secure channel
case "{tls_cipher}": case "{tls_cipher}":
if r.request.TLS != nil { if r.request.TLS != nil {
for k, v := range caddytls.SupportedCiphersMap { if name, err := caddytls.GetSupportedCipherName(r.request.TLS.CipherSuite); err == nil {
if v == r.request.TLS.CipherSuite { return name
return k } else {
return "UNKNOWN" // this should never happen, but guard in case
} }
} }
return "UNKNOWN" // this should never happen, but guard in case return r.emptyValue
case "{tls_client_escaped_cert}":
cert := r.getPeerCert()
if cert != nil {
pemBlock := pem.Block{
Type: "CERTIFICATE",
Bytes: cert.Raw,
}
return url.QueryEscape(string(pem.EncodeToMemory(&pemBlock)))
}
return r.emptyValue
case "{tls_client_fingerprint}":
cert := r.getPeerCert()
if cert != nil {
return fmt.Sprintf("%x", sha256.Sum256(cert.Raw))
}
return r.emptyValue
case "{tls_client_i_dn}":
cert := r.getPeerCert()
if cert != nil {
return cert.Issuer.String()
}
return r.emptyValue
case "{tls_client_raw_cert}":
cert := r.getPeerCert()
if cert != nil {
return string(cert.Raw)
}
return r.emptyValue
case "{tls_client_s_dn}":
cert := r.getPeerCert()
if cert != nil {
return cert.Subject.String()
}
return r.emptyValue
case "{tls_client_serial}":
cert := r.getPeerCert()
if cert != nil {
return fmt.Sprintf("%x", cert.SerialNumber)
}
return r.emptyValue
case "{tls_client_v_end}":
cert := r.getPeerCert()
if cert != nil {
return cert.NotAfter.In(time.UTC).Format("Jan 02 15:04:05 2006 MST")
}
return r.emptyValue
case "{tls_client_v_remain}":
cert := r.getPeerCert()
if cert != nil {
now := time.Now().In(time.UTC)
days := int64(cert.NotAfter.Sub(now).Seconds() / 86400)
return strconv.FormatInt(days, 10)
}
return r.emptyValue
case "{tls_client_v_start}":
cert := r.getPeerCert()
if cert != nil {
return cert.NotBefore.Format("Jan 02 15:04:05 2006 MST")
} }
return r.emptyValue return r.emptyValue
default: default:
......
...@@ -16,12 +16,21 @@ package httpserver ...@@ -16,12 +16,21 @@ package httpserver
import ( import (
"context" "context"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url"
"os" "os"
"strconv"
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/mholt/caddy/caddytls"
) )
func TestNewReplacer(t *testing.T) { func TestNewReplacer(t *testing.T) {
...@@ -147,6 +156,102 @@ func TestReplace(t *testing.T) { ...@@ -147,6 +156,102 @@ func TestReplace(t *testing.T) {
} }
} }
func TestTlsReplace(t *testing.T) {
w := httptest.NewRecorder()
recordRequest := NewResponseRecorder(w)
clientCertText := []byte(`-----BEGIN CERTIFICATE-----
MIIB9jCCAV+gAwIBAgIBAjANBgkqhkiG9w0BAQsFADAYMRYwFAYDVQQDDA1DYWRk
eSBUZXN0IENBMB4XDTE4MDcyNDIxMzUwNVoXDTI4MDcyMTIxMzUwNVowHTEbMBkG
A1UEAwwSY2xpZW50LmxvY2FsZG9tYWluMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCB
iQKBgQDFDEpzF0ew68teT3xDzcUxVFaTII+jXH1ftHXxxP4BEYBU4q90qzeKFneF
z83I0nC0WAQ45ZwHfhLMYHFzHPdxr6+jkvKPASf0J2v2HDJuTM1bHBbik5Ls5eq+
fVZDP8o/VHKSBKxNs8Goc2NTsr5b07QTIpkRStQK+RJALk4x9QIDAQABo0swSTAJ
BgNVHRMEAjAAMAsGA1UdDwQEAwIHgDAaBgNVHREEEzARgglsb2NhbGhvc3SHBH8A
AAEwEwYDVR0lBAwwCgYIKwYBBQUHAwIwDQYJKoZIhvcNAQELBQADgYEANSjz2Sk+
eqp31wM9il1n+guTNyxJd+FzVAH+hCZE5K+tCgVDdVFUlDEHHbS/wqb2PSIoouLV
3Q9fgDkiUod+uIK0IynzIKvw+Cjg+3nx6NQ0IM0zo8c7v398RzB4apbXKZyeeqUH
9fNwfEi+OoXR6s+upSKobCmLGLGi9Na5s5g=
-----END CERTIFICATE-----`)
block, _ := pem.Decode(clientCertText)
if block == nil {
t.Fatalf("failed to decode PEM certificate")
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
t.Fatalf("failed to decode PEM certificate: %v", err)
}
request := &http.Request{
Method: "GET",
Host: "foo.com",
URL: &url.URL{
Scheme: "https",
Path: "/path/",
Host: "foo.com",
},
Header: http.Header{},
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
RemoteAddr: "192.0.2.1:1234",
RequestURI: "https://foo.com/path/",
TLS: &tls.ConnectionState{
Version: tls.VersionTLS12,
HandshakeComplete: true,
ServerName: "foo.com",
CipherSuite: tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
PeerCertificates: []*x509.Certificate{cert},
},
}
repl := NewReplacer(request, recordRequest, "-")
now := time.Now().In(time.UTC)
days := int64(cert.NotAfter.Sub(now).Seconds() / 86400)
pemBlock := pem.Block{
Type: "CERTIFICATE",
Bytes: cert.Raw,
}
protocol, _ := caddytls.GetSupportedProtocolName(request.TLS.Version)
cipher, _ := caddytls.GetSupportedCipherName(request.TLS.CipherSuite)
cEscapedCert := url.QueryEscape(string(pem.EncodeToMemory(&pemBlock)))
cFingerprint := fmt.Sprintf("%x", sha256.Sum256(cert.Raw))
cIDn := cert.Issuer.String()
cRawCert := string(cert.Raw)
cSDn := cert.Subject.String()
cSerial := fmt.Sprintf("%x", cert.SerialNumber)
cVEnd := cert.NotAfter.In(time.UTC).Format("Jan 02 15:04:05 2006 MST")
cVRemain := strconv.FormatInt(days, 10)
cVStart := cert.NotBefore.Format("Jan 02 15:04:05 2006 MST")
testCases := []struct {
template string
expect string
}{
{"{tls_protocol}", protocol},
{"{tls_cipher}", cipher},
{"{tls_client_escaped_cert}", cEscapedCert},
{"{tls_client_fingerprint}", cFingerprint},
{"{tls_client_i_dn}", cIDn},
{"{tls_client_raw_cert}", cRawCert},
{"{tls_client_s_dn}", cSDn},
{"{tls_client_serial}", cSerial},
{"{tls_client_v_end}", cVEnd},
{"{tls_client_v_remain}", cVRemain},
{"{tls_client_v_start}", cVStart},
}
for _, c := range testCases {
if expected, actual := c.expect, repl.Replace(c.template); expected != actual {
t.Errorf("for template '%s', expected '%s', got '%s'", c.template, expected, actual)
}
}
}
func BenchmarkReplace(b *testing.B) { func BenchmarkReplace(b *testing.B) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
recordRequest := NewResponseRecorder(w) recordRequest := NewResponseRecorder(w)
......
...@@ -17,6 +17,7 @@ package caddytls ...@@ -17,6 +17,7 @@ package caddytls
import ( import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
...@@ -584,6 +585,17 @@ var SupportedProtocols = map[string]uint16{ ...@@ -584,6 +585,17 @@ var SupportedProtocols = map[string]uint16{
"tls1.2": tls.VersionTLS12, "tls1.2": tls.VersionTLS12,
} }
// GetSupportedProtocolName returns the protocol name
func GetSupportedProtocolName(protocol uint16) (string, error) {
for k, v := range SupportedProtocols {
if v == protocol {
return k, nil
}
}
return "", errors.New("name: unsuported protocol")
}
// Map of supported ciphers, used only for parsing config. // Map of supported ciphers, used only for parsing config.
// //
// Note that, at time of writing, HTTP/2 blacklists 276 cipher suites, // Note that, at time of writing, HTTP/2 blacklists 276 cipher suites,
...@@ -611,6 +623,17 @@ var SupportedCiphersMap = map[string]uint16{ ...@@ -611,6 +623,17 @@ var SupportedCiphersMap = map[string]uint16{
"RSA-3DES-EDE-CBC-SHA": tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA, "RSA-3DES-EDE-CBC-SHA": tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA,
} }
// GetSupportedCipherName returns the cipher name
func GetSupportedCipherName(cipher uint16) (string, error) {
for k, v := range SupportedCiphersMap {
if v == cipher {
return k, nil
}
}
return "", errors.New("name: unsuported cipher")
}
// List of all the ciphers we want to use by default // List of all the ciphers we want to use by default
var defaultCiphers = []uint16{ var defaultCiphers = []uint16{
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment