Commit 1df62ca6 authored by Russ Cox's avatar Russ Cox

crypto/tls: fix handshake message test

This test breaks when I make reflect.DeepEqual
distinguish empty slices from nil slices.

R=agl
CC=golang-dev
https://golang.org/cl/5369110
parent ba98a7ee
...@@ -4,6 +4,8 @@ ...@@ -4,6 +4,8 @@
package tls package tls
import "bytes"
type clientHelloMsg struct { type clientHelloMsg struct {
raw []byte raw []byte
vers uint16 vers uint16
...@@ -18,6 +20,25 @@ type clientHelloMsg struct { ...@@ -18,6 +20,25 @@ type clientHelloMsg struct {
supportedPoints []uint8 supportedPoints []uint8
} }
func (m *clientHelloMsg) equal(i interface{}) bool {
m1, ok := i.(*clientHelloMsg)
if !ok {
return false
}
return bytes.Equal(m.raw, m1.raw) &&
m.vers == m1.vers &&
bytes.Equal(m.random, m1.random) &&
bytes.Equal(m.sessionId, m1.sessionId) &&
eqUint16s(m.cipherSuites, m1.cipherSuites) &&
bytes.Equal(m.compressionMethods, m1.compressionMethods) &&
m.nextProtoNeg == m1.nextProtoNeg &&
m.serverName == m1.serverName &&
m.ocspStapling == m1.ocspStapling &&
eqUint16s(m.supportedCurves, m1.supportedCurves) &&
bytes.Equal(m.supportedPoints, m1.supportedPoints)
}
func (m *clientHelloMsg) marshal() []byte { func (m *clientHelloMsg) marshal() []byte {
if m.raw != nil { if m.raw != nil {
return m.raw return m.raw
...@@ -309,6 +330,23 @@ type serverHelloMsg struct { ...@@ -309,6 +330,23 @@ type serverHelloMsg struct {
ocspStapling bool ocspStapling bool
} }
func (m *serverHelloMsg) equal(i interface{}) bool {
m1, ok := i.(*serverHelloMsg)
if !ok {
return false
}
return bytes.Equal(m.raw, m1.raw) &&
m.vers == m1.vers &&
bytes.Equal(m.random, m1.random) &&
bytes.Equal(m.sessionId, m1.sessionId) &&
m.cipherSuite == m1.cipherSuite &&
m.compressionMethod == m1.compressionMethod &&
m.nextProtoNeg == m1.nextProtoNeg &&
eqStrings(m.nextProtos, m1.nextProtos) &&
m.ocspStapling == m1.ocspStapling
}
func (m *serverHelloMsg) marshal() []byte { func (m *serverHelloMsg) marshal() []byte {
if m.raw != nil { if m.raw != nil {
return m.raw return m.raw
...@@ -463,6 +501,16 @@ type certificateMsg struct { ...@@ -463,6 +501,16 @@ type certificateMsg struct {
certificates [][]byte certificates [][]byte
} }
func (m *certificateMsg) equal(i interface{}) bool {
m1, ok := i.(*certificateMsg)
if !ok {
return false
}
return bytes.Equal(m.raw, m1.raw) &&
eqByteSlices(m.certificates, m1.certificates)
}
func (m *certificateMsg) marshal() (x []byte) { func (m *certificateMsg) marshal() (x []byte) {
if m.raw != nil { if m.raw != nil {
return m.raw return m.raw
...@@ -540,6 +588,16 @@ type serverKeyExchangeMsg struct { ...@@ -540,6 +588,16 @@ type serverKeyExchangeMsg struct {
key []byte key []byte
} }
func (m *serverKeyExchangeMsg) equal(i interface{}) bool {
m1, ok := i.(*serverKeyExchangeMsg)
if !ok {
return false
}
return bytes.Equal(m.raw, m1.raw) &&
bytes.Equal(m.key, m1.key)
}
func (m *serverKeyExchangeMsg) marshal() []byte { func (m *serverKeyExchangeMsg) marshal() []byte {
if m.raw != nil { if m.raw != nil {
return m.raw return m.raw
...@@ -571,6 +629,17 @@ type certificateStatusMsg struct { ...@@ -571,6 +629,17 @@ type certificateStatusMsg struct {
response []byte response []byte
} }
func (m *certificateStatusMsg) equal(i interface{}) bool {
m1, ok := i.(*certificateStatusMsg)
if !ok {
return false
}
return bytes.Equal(m.raw, m1.raw) &&
m.statusType == m1.statusType &&
bytes.Equal(m.response, m1.response)
}
func (m *certificateStatusMsg) marshal() []byte { func (m *certificateStatusMsg) marshal() []byte {
if m.raw != nil { if m.raw != nil {
return m.raw return m.raw
...@@ -622,6 +691,11 @@ func (m *certificateStatusMsg) unmarshal(data []byte) bool { ...@@ -622,6 +691,11 @@ func (m *certificateStatusMsg) unmarshal(data []byte) bool {
type serverHelloDoneMsg struct{} type serverHelloDoneMsg struct{}
func (m *serverHelloDoneMsg) equal(i interface{}) bool {
_, ok := i.(*serverHelloDoneMsg)
return ok
}
func (m *serverHelloDoneMsg) marshal() []byte { func (m *serverHelloDoneMsg) marshal() []byte {
x := make([]byte, 4) x := make([]byte, 4)
x[0] = typeServerHelloDone x[0] = typeServerHelloDone
...@@ -637,6 +711,16 @@ type clientKeyExchangeMsg struct { ...@@ -637,6 +711,16 @@ type clientKeyExchangeMsg struct {
ciphertext []byte ciphertext []byte
} }
func (m *clientKeyExchangeMsg) equal(i interface{}) bool {
m1, ok := i.(*clientKeyExchangeMsg)
if !ok {
return false
}
return bytes.Equal(m.raw, m1.raw) &&
bytes.Equal(m.ciphertext, m1.ciphertext)
}
func (m *clientKeyExchangeMsg) marshal() []byte { func (m *clientKeyExchangeMsg) marshal() []byte {
if m.raw != nil { if m.raw != nil {
return m.raw return m.raw
...@@ -671,6 +755,16 @@ type finishedMsg struct { ...@@ -671,6 +755,16 @@ type finishedMsg struct {
verifyData []byte verifyData []byte
} }
func (m *finishedMsg) equal(i interface{}) bool {
m1, ok := i.(*finishedMsg)
if !ok {
return false
}
return bytes.Equal(m.raw, m1.raw) &&
bytes.Equal(m.verifyData, m1.verifyData)
}
func (m *finishedMsg) marshal() (x []byte) { func (m *finishedMsg) marshal() (x []byte) {
if m.raw != nil { if m.raw != nil {
return m.raw return m.raw
...@@ -698,6 +792,16 @@ type nextProtoMsg struct { ...@@ -698,6 +792,16 @@ type nextProtoMsg struct {
proto string proto string
} }
func (m *nextProtoMsg) equal(i interface{}) bool {
m1, ok := i.(*nextProtoMsg)
if !ok {
return false
}
return bytes.Equal(m.raw, m1.raw) &&
m.proto == m1.proto
}
func (m *nextProtoMsg) marshal() []byte { func (m *nextProtoMsg) marshal() []byte {
if m.raw != nil { if m.raw != nil {
return m.raw return m.raw
...@@ -759,6 +863,17 @@ type certificateRequestMsg struct { ...@@ -759,6 +863,17 @@ type certificateRequestMsg struct {
certificateAuthorities [][]byte certificateAuthorities [][]byte
} }
func (m *certificateRequestMsg) equal(i interface{}) bool {
m1, ok := i.(*certificateRequestMsg)
if !ok {
return false
}
return bytes.Equal(m.raw, m1.raw) &&
bytes.Equal(m.certificateTypes, m1.certificateTypes) &&
eqByteSlices(m.certificateAuthorities, m1.certificateAuthorities)
}
func (m *certificateRequestMsg) marshal() (x []byte) { func (m *certificateRequestMsg) marshal() (x []byte) {
if m.raw != nil { if m.raw != nil {
return m.raw return m.raw
...@@ -859,6 +974,16 @@ type certificateVerifyMsg struct { ...@@ -859,6 +974,16 @@ type certificateVerifyMsg struct {
signature []byte signature []byte
} }
func (m *certificateVerifyMsg) equal(i interface{}) bool {
m1, ok := i.(*certificateVerifyMsg)
if !ok {
return false
}
return bytes.Equal(m.raw, m1.raw) &&
bytes.Equal(m.signature, m1.signature)
}
func (m *certificateVerifyMsg) marshal() (x []byte) { func (m *certificateVerifyMsg) marshal() (x []byte) {
if m.raw != nil { if m.raw != nil {
return m.raw return m.raw
...@@ -902,3 +1027,39 @@ func (m *certificateVerifyMsg) unmarshal(data []byte) bool { ...@@ -902,3 +1027,39 @@ func (m *certificateVerifyMsg) unmarshal(data []byte) bool {
return true return true
} }
func eqUint16s(x, y []uint16) bool {
if len(x) != len(y) {
return false
}
for i, v := range x {
if y[i] != v {
return false
}
}
return true
}
func eqStrings(x, y []string) bool {
if len(x) != len(y) {
return false
}
for i, v := range x {
if y[i] != v {
return false
}
}
return true
}
func eqByteSlices(x, y [][]byte) bool {
if len(x) != len(y) {
return false
}
for i, v := range x {
if !bytes.Equal(v, y[i]) {
return false
}
}
return true
}
...@@ -27,10 +27,12 @@ var tests = []interface{}{ ...@@ -27,10 +27,12 @@ var tests = []interface{}{
type testMessage interface { type testMessage interface {
marshal() []byte marshal() []byte
unmarshal([]byte) bool unmarshal([]byte) bool
equal(interface{}) bool
} }
func TestMarshalUnmarshal(t *testing.T) { func TestMarshalUnmarshal(t *testing.T) {
rand := rand.New(rand.NewSource(0)) rand := rand.New(rand.NewSource(0))
for i, iface := range tests { for i, iface := range tests {
ty := reflect.ValueOf(iface).Type() ty := reflect.ValueOf(iface).Type()
...@@ -54,7 +56,7 @@ func TestMarshalUnmarshal(t *testing.T) { ...@@ -54,7 +56,7 @@ func TestMarshalUnmarshal(t *testing.T) {
} }
m2.marshal() // to fill any marshal cache in the message m2.marshal() // to fill any marshal cache in the message
if !reflect.DeepEqual(m1, m2) { if !m1.equal(m2) {
t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled) t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled)
break break
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment