Commit 2a8c81ff authored by Adam Langley's avatar Adam Langley Committed by Andrew Gerrand

crypto/tls: buffer handshake messages.

This change causes TLS handshake messages to be buffered and written in
a single Write to the underlying net.Conn.

There are two reasons to want to do this:

Firstly, it's slightly preferable to do this in order to save sending
several, small packets over the network where a single one will do.

Secondly, since 37c28759 errors from
Write have been returned from a handshake. This means that, if a peer
closes the connection during a handshake, a “broken pipe” error may
result from tls.Conn.Handshake(). This can mask any, more detailed,
fatal alerts that the peer may have sent because a read will never
happen.

Buffering handshake messages means that the peer will not receive, and
possibly reject, any of a flow while it's still being written.

Fixes #15709

Change-Id: I38dcff1abecc06e52b2de647ea98713ce0fb9a21
Reviewed-on: https://go-review.googlesource.com/23609Reviewed-by: default avatarAndrew Gerrand <adg@golang.org>
Run-TryBot: Andrew Gerrand <adg@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
parent f6c02419
...@@ -71,10 +71,12 @@ type Conn struct { ...@@ -71,10 +71,12 @@ type Conn struct {
clientProtocolFallback bool clientProtocolFallback bool
// input/output // input/output
in, out halfConn // in.Mutex < out.Mutex in, out halfConn // in.Mutex < out.Mutex
rawInput *block // raw input, right off the wire rawInput *block // raw input, right off the wire
input *block // application data waiting to be read input *block // application data waiting to be read
hand bytes.Buffer // handshake data waiting to be read hand bytes.Buffer // handshake data waiting to be read
buffering bool // whether records are buffered in sendBuf
sendBuf []byte // a buffer of records waiting to be sent
// bytesSent counts the bytes of application data sent. // bytesSent counts the bytes of application data sent.
// packetsSent counts packets. // packetsSent counts packets.
...@@ -803,6 +805,30 @@ func (c *Conn) maxPayloadSizeForWrite(typ recordType, explicitIVLen int) int { ...@@ -803,6 +805,30 @@ func (c *Conn) maxPayloadSizeForWrite(typ recordType, explicitIVLen int) int {
return n return n
} }
// c.out.Mutex <= L.
func (c *Conn) write(data []byte) (int, error) {
if c.buffering {
c.sendBuf = append(c.sendBuf, data...)
return len(data), nil
}
n, err := c.conn.Write(data)
c.bytesSent += int64(n)
return n, err
}
func (c *Conn) flush() (int, error) {
if len(c.sendBuf) == 0 {
return 0, nil
}
n, err := c.conn.Write(c.sendBuf)
c.bytesSent += int64(n)
c.sendBuf = nil
c.buffering = false
return n, err
}
// writeRecordLocked writes a TLS record with the given type and payload to the // writeRecordLocked writes a TLS record with the given type and payload to the
// connection and updates the record layer state. // connection and updates the record layer state.
// c.out.Mutex <= L. // c.out.Mutex <= L.
...@@ -862,10 +888,9 @@ func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) { ...@@ -862,10 +888,9 @@ func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) {
} }
copy(b.data[recordHeaderLen+explicitIVLen:], data) copy(b.data[recordHeaderLen+explicitIVLen:], data)
c.out.encrypt(b, explicitIVLen) c.out.encrypt(b, explicitIVLen)
if _, err := c.conn.Write(b.data); err != nil { if _, err := c.write(b.data); err != nil {
return n, err return n, err
} }
c.bytesSent += int64(m)
n += m n += m
data = data[m:] data = data[m:]
} }
......
...@@ -206,6 +206,7 @@ NextCipherSuite: ...@@ -206,6 +206,7 @@ NextCipherSuite:
hs.finishedHash.Write(hs.hello.marshal()) hs.finishedHash.Write(hs.hello.marshal())
hs.finishedHash.Write(hs.serverHello.marshal()) hs.finishedHash.Write(hs.serverHello.marshal())
c.buffering = true
if isResume { if isResume {
if err := hs.establishKeys(); err != nil { if err := hs.establishKeys(); err != nil {
return err return err
...@@ -220,6 +221,9 @@ NextCipherSuite: ...@@ -220,6 +221,9 @@ NextCipherSuite:
if err := hs.sendFinished(c.clientFinished[:]); err != nil { if err := hs.sendFinished(c.clientFinished[:]); err != nil {
return err return err
} }
if _, err := c.flush(); err != nil {
return err
}
} else { } else {
if err := hs.doFullHandshake(); err != nil { if err := hs.doFullHandshake(); err != nil {
return err return err
...@@ -230,6 +234,9 @@ NextCipherSuite: ...@@ -230,6 +234,9 @@ NextCipherSuite:
if err := hs.sendFinished(c.clientFinished[:]); err != nil { if err := hs.sendFinished(c.clientFinished[:]); err != nil {
return err return err
} }
if _, err := c.flush(); err != nil {
return err
}
c.clientFinishedIsFirst = true c.clientFinishedIsFirst = true
if err := hs.readSessionTicket(); err != nil { if err := hs.readSessionTicket(); err != nil {
return err return err
......
...@@ -983,7 +983,7 @@ func (b *brokenConn) Write(data []byte) (int, error) { ...@@ -983,7 +983,7 @@ func (b *brokenConn) Write(data []byte) (int, error) {
func TestFailedWrite(t *testing.T) { func TestFailedWrite(t *testing.T) {
// Test that a write error during the handshake is returned. // Test that a write error during the handshake is returned.
for _, breakAfter := range []int{0, 1, 2, 3} { for _, breakAfter := range []int{0, 1} {
c, s := net.Pipe() c, s := net.Pipe()
done := make(chan bool) done := make(chan bool)
...@@ -1003,3 +1003,45 @@ func TestFailedWrite(t *testing.T) { ...@@ -1003,3 +1003,45 @@ func TestFailedWrite(t *testing.T) {
<-done <-done
} }
} }
// writeCountingConn wraps a net.Conn and counts the number of Write calls.
type writeCountingConn struct {
net.Conn
// numWrites is the number of writes that have been done.
numWrites int
}
func (wcc *writeCountingConn) Write(data []byte) (int, error) {
wcc.numWrites++
return wcc.Conn.Write(data)
}
func TestBuffering(t *testing.T) {
c, s := net.Pipe()
done := make(chan bool)
clientWCC := &writeCountingConn{Conn: c}
serverWCC := &writeCountingConn{Conn: s}
go func() {
Server(serverWCC, testConfig).Handshake()
serverWCC.Close()
done <- true
}()
err := Client(clientWCC, testConfig).Handshake()
if err != nil {
t.Fatal(err)
}
clientWCC.Close()
<-done
if n := clientWCC.numWrites; n != 2 {
t.Errorf("expected client handshake to complete with only two writes, but saw %d", n)
}
if n := serverWCC.numWrites; n != 2 {
t.Errorf("expected server handshake to complete with only two writes, but saw %d", n)
}
}
...@@ -52,6 +52,7 @@ func (c *Conn) serverHandshake() error { ...@@ -52,6 +52,7 @@ func (c *Conn) serverHandshake() error {
} }
// For an overview of TLS handshaking, see https://tools.ietf.org/html/rfc5246#section-7.3 // For an overview of TLS handshaking, see https://tools.ietf.org/html/rfc5246#section-7.3
c.buffering = true
if isResume { if isResume {
// The client has included a session ticket and so we do an abbreviated handshake. // The client has included a session ticket and so we do an abbreviated handshake.
if err := hs.doResumeHandshake(); err != nil { if err := hs.doResumeHandshake(); err != nil {
...@@ -71,6 +72,9 @@ func (c *Conn) serverHandshake() error { ...@@ -71,6 +72,9 @@ func (c *Conn) serverHandshake() error {
if err := hs.sendFinished(c.serverFinished[:]); err != nil { if err := hs.sendFinished(c.serverFinished[:]); err != nil {
return err return err
} }
if _, err := c.flush(); err != nil {
return err
}
c.clientFinishedIsFirst = false c.clientFinishedIsFirst = false
if err := hs.readFinished(nil); err != nil { if err := hs.readFinished(nil); err != nil {
return err return err
...@@ -89,12 +93,16 @@ func (c *Conn) serverHandshake() error { ...@@ -89,12 +93,16 @@ func (c *Conn) serverHandshake() error {
return err return err
} }
c.clientFinishedIsFirst = true c.clientFinishedIsFirst = true
c.buffering = true
if err := hs.sendSessionTicket(); err != nil { if err := hs.sendSessionTicket(); err != nil {
return err return err
} }
if err := hs.sendFinished(nil); err != nil { if err := hs.sendFinished(nil); err != nil {
return err return err
} }
if _, err := c.flush(); err != nil {
return err
}
} }
c.handshakeComplete = true c.handshakeComplete = true
...@@ -430,6 +438,10 @@ func (hs *serverHandshakeState) doFullHandshake() error { ...@@ -430,6 +438,10 @@ func (hs *serverHandshakeState) doFullHandshake() error {
return err return err
} }
if _, err := c.flush(); err != nil {
return err
}
var pub crypto.PublicKey // public key for client auth, if any var pub crypto.PublicKey // public key for client auth, if any
msg, err := c.readHandshake() msg, err := c.readHandshake()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment