diff --git a/src/pkg/crypto/tls/conn.go b/src/pkg/crypto/tls/conn.go index fac65afd9cfb84807d79abdfcb99f19bd76898a3..3d018c0c7a6d903c5d9178eb1c010db7a34e055a 100644 --- a/src/pkg/crypto/tls/conn.go +++ b/src/pkg/crypto/tls/conn.go @@ -470,6 +470,19 @@ Again: if n > maxCiphertext { return c.sendAlert(alertRecordOverflow) } + if !c.haveVers { + // First message, be extra suspicious: + // this might not be a TLS client. + // Bail out before reading a full 'body', if possible. + // The current max version is 3.1. + // If the version is >= 16.0, it's probably not real. + // Similarly, a clientHello message encodes in + // well under a kilobyte. If the length is >= 12 kB, + // it's probably not real. + if (typ != recordTypeAlert && typ != want) || vers >= 0x1000 || n >= 0x3000 { + return c.sendAlert(alertUnexpectedMessage) + } + } if err := b.readFromUntil(c.conn, recordHeaderLen+n); err != nil { if err == os.EOF { err = io.ErrUnexpectedEOF diff --git a/src/pkg/crypto/tls/handshake_server_test.go b/src/pkg/crypto/tls/handshake_server_test.go index b77646e4383cb0da92785824033b1281ade80120..c1b37be275db2d3ca974ca1e28e0d1db8a42dace 100644 --- a/src/pkg/crypto/tls/handshake_server_test.go +++ b/src/pkg/crypto/tls/handshake_server_test.go @@ -142,6 +142,52 @@ func TestHandshakeServerAES(t *testing.T) { testServerScript(t, "AES", aesServerScript, aesConfig) } +func TestUnexpectedTLS(t *testing.T) { + l, err := Listen("tcp", "127.0.0.1:0", testConfig) + if err != nil { + t.Fatal(err) + } + ch := make(chan os.Error, 1) + done := make(chan bool) + go func() { + // Simulate HTTP client trying to do unencrypted HTTP on TLS port. + c, err := net.Dial("tcp", l.Addr().String()) + if err != nil { + ch <- err + <-done + return + } + defer func() { + <-done + c.Close() + }() + _, err = c.Write([]byte("GET / HTTP/1.0\r\nHost: www.google.com\r\n\r\n")) + if err != nil { + ch <- err + return + } + ch <- nil + }() + + c, err := l.Accept() + if err != nil { + t.Fatal(err) + } + buf := make([]byte, 100) + n, err := c.Read(buf) + if n > 0 || err == nil { + t.Errorf("TLS Read = %d, %v, want error", n, err) + } + t.Logf("%d, %v", n, err) + + err = <-ch + done <- true + if err != nil { + t.Errorf("TLS Write: %v", err) + } + +} + var serve = flag.Bool("serve", false, "run a TLS server on :10443") func TestRunServer(t *testing.T) {