Commit 5f5402b7 authored by Ian Gudger's avatar Ian Gudger Committed by Brad Fitzpatrick

net: fix handling of Conns created by Resolver.Dial

The DNS client in net is documented to treat Conns returned by
Resolver.Dial which implement PacketConn as UDP and those which don't as
TCP regardless of what was requested. golang.org/cl/37879 changed the
DNS client to assume that the Conn returned by Resolver.Dial was the
requested type which broke compatibility.

Fixes #26573
Updates #16218

Change-Id: Idf4f073a4cc3b1db36a3804898df206907f9c43c
Reviewed-on: https://go-review.googlesource.com/125735
Run-TryBot: Ian Gudger <igudger@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: default avatarBrad Fitzpatrick <bradfitz@golang.org>
parent c1e1e882
...@@ -137,10 +137,10 @@ func (r *Resolver) exchange(ctx context.Context, server string, q dnsmessage.Que ...@@ -137,10 +137,10 @@ func (r *Resolver) exchange(ctx context.Context, server string, q dnsmessage.Que
} }
var p dnsmessage.Parser var p dnsmessage.Parser
var h dnsmessage.Header var h dnsmessage.Header
if network == "tcp" { if _, ok := c.(PacketConn); ok {
p, h, err = dnsStreamRoundTrip(c, id, q, tcpReq)
} else {
p, h, err = dnsPacketRoundTrip(c, id, q, udpReq) p, h, err = dnsPacketRoundTrip(c, id, q, udpReq)
} else {
p, h, err = dnsStreamRoundTrip(c, id, q, tcpReq)
} }
c.Close() c.Close()
if err != nil { if err != nil {
......
...@@ -113,7 +113,7 @@ var specialDomainNameTests = []struct { ...@@ -113,7 +113,7 @@ var specialDomainNameTests = []struct {
} }
func TestSpecialDomainName(t *testing.T) { func TestSpecialDomainName(t *testing.T) {
fake := fakeDNSServer{func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { fake := fakeDNSServer{rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
r := dnsmessage.Message{ r := dnsmessage.Message{
Header: dnsmessage.Header{ Header: dnsmessage.Header{
ID: q.ID, ID: q.ID,
...@@ -189,7 +189,7 @@ func TestAvoidDNSName(t *testing.T) { ...@@ -189,7 +189,7 @@ func TestAvoidDNSName(t *testing.T) {
} }
} }
var fakeDNSServerSuccessful = fakeDNSServer{func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { var fakeDNSServerSuccessful = fakeDNSServer{rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
r := dnsmessage.Message{ r := dnsmessage.Message{
Header: dnsmessage.Header{ Header: dnsmessage.Header{
ID: q.ID, ID: q.ID,
...@@ -473,7 +473,7 @@ var goLookupIPWithResolverConfigTests = []struct { ...@@ -473,7 +473,7 @@ var goLookupIPWithResolverConfigTests = []struct {
func TestGoLookupIPWithResolverConfig(t *testing.T) { func TestGoLookupIPWithResolverConfig(t *testing.T) {
defer dnsWaitGroup.Wait() defer dnsWaitGroup.Wait()
fake := fakeDNSServer{func(n, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { fake := fakeDNSServer{rh: func(n, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
switch s { switch s {
case "[2001:4860:4860::8888]:53", "8.8.8.8:53": case "[2001:4860:4860::8888]:53", "8.8.8.8:53":
break break
...@@ -571,7 +571,7 @@ func TestGoLookupIPWithResolverConfig(t *testing.T) { ...@@ -571,7 +571,7 @@ func TestGoLookupIPWithResolverConfig(t *testing.T) {
func TestGoLookupIPOrderFallbackToFile(t *testing.T) { func TestGoLookupIPOrderFallbackToFile(t *testing.T) {
defer dnsWaitGroup.Wait() defer dnsWaitGroup.Wait()
fake := fakeDNSServer{func(n, s string, q dnsmessage.Message, tm time.Time) (dnsmessage.Message, error) { fake := fakeDNSServer{rh: func(n, s string, q dnsmessage.Message, tm time.Time) (dnsmessage.Message, error) {
r := dnsmessage.Message{ r := dnsmessage.Message{
Header: dnsmessage.Header{ Header: dnsmessage.Header{
ID: q.ID, ID: q.ID,
...@@ -641,7 +641,7 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) { ...@@ -641,7 +641,7 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
fake := fakeDNSServer{func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { fake := fakeDNSServer{rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
r := dnsmessage.Message{ r := dnsmessage.Message{
Header: dnsmessage.Header{ Header: dnsmessage.Header{
ID: q.ID, ID: q.ID,
...@@ -696,7 +696,7 @@ func TestIgnoreLameReferrals(t *testing.T) { ...@@ -696,7 +696,7 @@ func TestIgnoreLameReferrals(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
t.Log(s, q) t.Log(s, q)
r := dnsmessage.Message{ r := dnsmessage.Message{
Header: dnsmessage.Header{ Header: dnsmessage.Header{
...@@ -788,12 +788,15 @@ func BenchmarkGoLookupIPWithBrokenNameServer(b *testing.B) { ...@@ -788,12 +788,15 @@ func BenchmarkGoLookupIPWithBrokenNameServer(b *testing.B) {
} }
type fakeDNSServer struct { type fakeDNSServer struct {
rh func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error) rh func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error)
alwaysTCP bool
} }
func (server *fakeDNSServer) DialContext(_ context.Context, n, s string) (Conn, error) { func (server *fakeDNSServer) DialContext(_ context.Context, n, s string) (Conn, error) {
tcp := n == "tcp" || n == "tcp4" || n == "tcp6" if server.alwaysTCP || n == "tcp" || n == "tcp4" || n == "tcp6" {
return &fakeDNSConn{tcp: tcp, server: server, n: n, s: s}, nil return &fakeDNSConn{tcp: true, server: server, n: n, s: s}, nil
}
return &fakeDNSPacketConn{fakeDNSConn: fakeDNSConn{tcp: false, server: server, n: n, s: s}}, nil
} }
type fakeDNSConn struct { type fakeDNSConn struct {
...@@ -846,10 +849,6 @@ func (f *fakeDNSConn) Read(b []byte) (int, error) { ...@@ -846,10 +849,6 @@ func (f *fakeDNSConn) Read(b []byte) (int, error) {
return len(bb), nil return len(bb), nil
} }
func (f *fakeDNSConn) ReadFrom(b []byte) (int, Addr, error) {
return 0, nil, nil
}
func (f *fakeDNSConn) Write(b []byte) (int, error) { func (f *fakeDNSConn) Write(b []byte) (int, error) {
if f.tcp && len(b) >= 2 { if f.tcp && len(b) >= 2 {
b = b[2:] b = b[2:]
...@@ -860,15 +859,24 @@ func (f *fakeDNSConn) Write(b []byte) (int, error) { ...@@ -860,15 +859,24 @@ func (f *fakeDNSConn) Write(b []byte) (int, error) {
return len(b), nil return len(b), nil
} }
func (f *fakeDNSConn) WriteTo(b []byte, addr Addr) (int, error) {
return 0, nil
}
func (f *fakeDNSConn) SetDeadline(t time.Time) error { func (f *fakeDNSConn) SetDeadline(t time.Time) error {
f.t = t f.t = t
return nil return nil
} }
type fakeDNSPacketConn struct {
PacketConn
fakeDNSConn
}
func (f *fakeDNSPacketConn) SetDeadline(t time.Time) error {
return f.fakeDNSConn.SetDeadline(t)
}
func (f *fakeDNSPacketConn) Close() error {
return f.fakeDNSConn.Close()
}
// UDP round-tripper algorithm should ignore invalid DNS responses (issue 13281). // UDP round-tripper algorithm should ignore invalid DNS responses (issue 13281).
func TestIgnoreDNSForgeries(t *testing.T) { func TestIgnoreDNSForgeries(t *testing.T) {
c, s := Pipe() c, s := Pipe()
...@@ -973,7 +981,7 @@ func TestRetryTimeout(t *testing.T) { ...@@ -973,7 +981,7 @@ func TestRetryTimeout(t *testing.T) {
var deadline0 time.Time var deadline0 time.Time
fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) { fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
t.Log(s, q, deadline) t.Log(s, q, deadline)
if deadline.IsZero() { if deadline.IsZero() {
...@@ -1034,7 +1042,7 @@ func testRotate(t *testing.T, rotate bool, nameservers, wantServers []string) { ...@@ -1034,7 +1042,7 @@ func testRotate(t *testing.T, rotate bool, nameservers, wantServers []string) {
} }
var usedServers []string var usedServers []string
fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) { fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
usedServers = append(usedServers, s) usedServers = append(usedServers, s)
return mockTXTResponse(q), nil return mockTXTResponse(q), nil
}} }}
...@@ -1218,7 +1226,7 @@ func TestStrictErrorsLookupIP(t *testing.T) { ...@@ -1218,7 +1226,7 @@ func TestStrictErrorsLookupIP(t *testing.T) {
} }
for i, tt := range cases { for i, tt := range cases {
fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) { fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
t.Log(s, q) t.Log(s, q)
switch tt.resolveWhich(q.Questions[0]) { switch tt.resolveWhich(q.Questions[0]) {
...@@ -1356,7 +1364,7 @@ func TestStrictErrorsLookupTXT(t *testing.T) { ...@@ -1356,7 +1364,7 @@ func TestStrictErrorsLookupTXT(t *testing.T) {
const searchY = "test.y.golang.org." const searchY = "test.y.golang.org."
const txt = "Hello World" const txt = "Hello World"
fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) { fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
t.Log(s, q) t.Log(s, q)
switch q.Questions[0].Name.String() { switch q.Questions[0].Name.String() {
...@@ -1402,7 +1410,7 @@ func TestStrictErrorsLookupTXT(t *testing.T) { ...@@ -1402,7 +1410,7 @@ func TestStrictErrorsLookupTXT(t *testing.T) {
func TestDNSGoroutineRace(t *testing.T) { func TestDNSGoroutineRace(t *testing.T) {
defer dnsWaitGroup.Wait() defer dnsWaitGroup.Wait()
fake := fakeDNSServer{func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error) { fake := fakeDNSServer{rh: func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error) {
time.Sleep(10 * time.Microsecond) time.Sleep(10 * time.Microsecond)
return dnsmessage.Message{}, poll.ErrTimeout return dnsmessage.Message{}, poll.ErrTimeout
}} }}
...@@ -1502,3 +1510,28 @@ func TestIssue12778(t *testing.T) { ...@@ -1502,3 +1510,28 @@ func TestIssue12778(t *testing.T) {
t.Fatalf("Err = %#v; wanted %q", de.Err, errNoSuchHost.Error()) t.Fatalf("Err = %#v; wanted %q", de.Err, errNoSuchHost.Error())
} }
} }
// Issue 26573: verify that Conns that don't implement PacketConn are treated
// as streams even when udp was requested.
func TestDNSDialTCP(t *testing.T) {
fake := fakeDNSServer{
rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
r := dnsmessage.Message{
Header: dnsmessage.Header{
ID: q.Header.ID,
Response: true,
RCode: dnsmessage.RCodeSuccess,
},
Questions: q.Questions,
}
return r, nil
},
alwaysTCP: true,
}
r := Resolver{PreferGo: true, Dial: fake.DialContext}
ctx := context.Background()
_, _, err := r.exchange(ctx, "0.0.0.0", mustQuestion("com.", dnsmessage.TypeALL, dnsmessage.ClassINET), time.Second)
if err != nil {
t.Fatal("exhange failed:", err)
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment