Commit ddf4b1fd authored by W. Mark Kubacki's avatar W. Mark Kubacki

Merge pull request #757 from mholt/extend-tls-client-auth

Extend tls client auth
parents 4e98cc30 69c2d78f
...@@ -83,10 +83,30 @@ func Setup(c *setup.Controller) (middleware.Middleware, error) { ...@@ -83,10 +83,30 @@ func Setup(c *setup.Controller) (middleware.Middleware, error) {
c.TLS.Ciphers = append(c.TLS.Ciphers, value) c.TLS.Ciphers = append(c.TLS.Ciphers, value)
} }
case "clients": case "clients":
c.TLS.ClientCerts = c.RemainingArgs() clientCertList := c.RemainingArgs()
if len(c.TLS.ClientCerts) == 0 { if len(clientCertList) == 0 {
return nil, c.ArgErr() return nil, c.ArgErr()
} }
listStart, mustProvideCA := 1, true
switch clientCertList[0] {
case "request":
c.TLS.ClientAuth = tls.RequestClientCert
mustProvideCA = false
case "require":
c.TLS.ClientAuth = tls.RequireAnyClientCert
mustProvideCA = false
case "verify_if_given":
c.TLS.ClientAuth = tls.VerifyClientCertIfGiven
default:
c.TLS.ClientAuth = tls.RequireAndVerifyClientCert
listStart = 0
}
if mustProvideCA && len(clientCertList) <= listStart {
return nil, c.ArgErr()
}
c.TLS.ClientCerts = clientCertList[listStart:]
case "load": case "load":
c.Args(&loadDir) c.Args(&loadDir)
c.TLS.Manual = true c.TLS.Manual = true
......
...@@ -189,33 +189,68 @@ func TestSetupParseWithWrongOptionalParams(t *testing.T) { ...@@ -189,33 +189,68 @@ func TestSetupParseWithWrongOptionalParams(t *testing.T) {
} }
func TestSetupParseWithClientAuth(t *testing.T) { func TestSetupParseWithClientAuth(t *testing.T) {
// Test missing client cert file
params := `tls ` + certFile + ` ` + keyFile + ` { params := `tls ` + certFile + ` ` + keyFile + ` {
clients client_ca.crt client2_ca.crt clients
}` }`
c := setup.NewTestController(params) c := setup.NewTestController(params)
_, err := Setup(c) _, err := Setup(c)
if err != nil { if err == nil {
t.Errorf("Expected no errors, got: %v", err) t.Errorf("Expected an error, but no error returned")
} }
if count := len(c.TLS.ClientCerts); count != 2 { noCAs, twoCAs := []string{}, []string{"client_ca.crt", "client2_ca.crt"}
t.Fatalf("Expected two client certs, had %d", count) for caseNumber, caseData := range []struct {
} params string
if actual := c.TLS.ClientCerts[0]; actual != "client_ca.crt" { clientAuthType tls.ClientAuthType
t.Errorf("Expected first client cert file to be '%s', but was '%s'", "client_ca.crt", actual) expectedErr bool
} expectedCAs []string
if actual := c.TLS.ClientCerts[1]; actual != "client2_ca.crt" { }{
t.Errorf("Expected second client cert file to be '%s', but was '%s'", "client2_ca.crt", actual) {"", tls.NoClientCert, false, noCAs},
} {`tls ` + certFile + ` ` + keyFile + ` {
clients client_ca.crt client2_ca.crt
}`, tls.RequireAndVerifyClientCert, false, twoCAs},
// now come modifier
{`tls ` + certFile + ` ` + keyFile + ` {
clients request
}`, tls.RequestClientCert, false, noCAs},
{`tls ` + certFile + ` ` + keyFile + ` {
clients require
}`, tls.RequireAnyClientCert, false, noCAs},
{`tls ` + certFile + ` ` + keyFile + ` {
clients verify_if_given client_ca.crt client2_ca.crt
}`, tls.VerifyClientCertIfGiven, false, twoCAs},
{`tls ` + certFile + ` ` + keyFile + ` {
clients verify_if_given
}`, tls.VerifyClientCertIfGiven, true, noCAs},
} {
c := setup.NewTestController(caseData.params)
_, err := Setup(c)
if caseData.expectedErr {
if err == nil {
t.Errorf("In case %d: Expected an error, got: %v", caseNumber, err)
}
continue
}
if err != nil {
t.Errorf("In case %d: Expected no errors, got: %v", caseNumber, err)
}
// Test missing client cert file if caseData.clientAuthType != c.TLS.ClientAuth {
params = `tls ` + certFile + ` ` + keyFile + ` { t.Errorf("In case %d: Expected TLS client auth type %v, got: %v",
clients caseNumber, caseData.clientAuthType, c.TLS.ClientAuth)
}` }
c = setup.NewTestController(params)
_, err = Setup(c) if count := len(c.TLS.ClientCerts); count < len(caseData.expectedCAs) {
if err == nil { t.Fatalf("In case %d: Expected %d client certs, had %d", caseNumber, len(caseData.expectedCAs), count)
t.Errorf("Expected an error, but no error returned") }
for idx, expected := range caseData.expectedCAs {
if actual := c.TLS.ClientCerts[idx]; actual != expected {
t.Errorf("In case %d: Expected %dth client cert file to be '%s', but was '%s'",
caseNumber, idx, expected, actual)
}
}
} }
} }
......
package server package server
import ( import (
"crypto/tls"
"net" "net"
"github.com/mholt/caddy/middleware" "github.com/mholt/caddy/middleware"
...@@ -75,4 +76,5 @@ type TLSConfig struct { ...@@ -75,4 +76,5 @@ type TLSConfig struct {
ProtocolMaxVersion uint16 ProtocolMaxVersion uint16
PreferServerCipherSuites bool PreferServerCipherSuites bool
ClientCerts []string ClientCerts []string
ClientAuth tls.ClientAuthType
} }
...@@ -14,6 +14,7 @@ import ( ...@@ -14,6 +14,7 @@ import (
"net" "net"
"net/http" "net/http"
"os" "os"
"path/filepath"
"runtime" "runtime"
"strings" "strings"
"sync" "sync"
...@@ -332,6 +333,16 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { ...@@ -332,6 +333,16 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
} }
// Use URL.RawPath If you need the original, "raw" URL.Path in your middleware.
// Collapse any ./ ../ /// madness here instead of doing that in every plugin.
if r.URL.Path != "/" {
path := filepath.Clean(r.URL.Path)
if !strings.HasPrefix(path, "/") {
path = "/" + path
}
r.URL.Path = path
}
// Execute the optional request callback if it exists and it's not disabled // Execute the optional request callback if it exists and it's not disabled
if s.ReqCallback != nil && !s.vhosts[host].config.TLS.Manual && s.ReqCallback(w, r) { if s.ReqCallback != nil && !s.vhosts[host].config.TLS.Manual && s.ReqCallback(w, r) {
return return
...@@ -368,17 +379,19 @@ func DefaultErrorFunc(w http.ResponseWriter, r *http.Request, status int) { ...@@ -368,17 +379,19 @@ func DefaultErrorFunc(w http.ResponseWriter, r *http.Request, status int) {
// setupClientAuth sets up TLS client authentication only if // setupClientAuth sets up TLS client authentication only if
// any of the TLS configs specified at least one cert file. // any of the TLS configs specified at least one cert file.
func setupClientAuth(tlsConfigs []TLSConfig, config *tls.Config) error { func setupClientAuth(tlsConfigs []TLSConfig, config *tls.Config) error {
var clientAuth bool whatClientAuth := tls.NoClientCert
for _, cfg := range tlsConfigs { for _, cfg := range tlsConfigs {
if len(cfg.ClientCerts) > 0 { if whatClientAuth < cfg.ClientAuth { // Use the most restrictive.
clientAuth = true whatClientAuth = cfg.ClientAuth
break
} }
} }
if clientAuth { if whatClientAuth != tls.NoClientCert {
pool := x509.NewCertPool() pool := x509.NewCertPool()
for _, cfg := range tlsConfigs { for _, cfg := range tlsConfigs {
if len(cfg.ClientCerts) == 0 {
continue
}
for _, caFile := range cfg.ClientCerts { for _, caFile := range cfg.ClientCerts {
caCrt, err := ioutil.ReadFile(caFile) // Anyone that gets a cert from this CA can connect caCrt, err := ioutil.ReadFile(caFile) // Anyone that gets a cert from this CA can connect
if err != nil { if err != nil {
...@@ -390,7 +403,7 @@ func setupClientAuth(tlsConfigs []TLSConfig, config *tls.Config) error { ...@@ -390,7 +403,7 @@ func setupClientAuth(tlsConfigs []TLSConfig, config *tls.Config) error {
} }
} }
config.ClientCAs = pool config.ClientCAs = pool
config.ClientAuth = tls.RequireAndVerifyClientCert config.ClientAuth = whatClientAuth
} }
return nil return nil
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment