Commit 790053b2 authored by Mikio Hara's avatar Mikio Hara

net: filter destination addresses when source address is specified

This change filters out destination addresses by address family when
source address is specified to avoid running Dial operation with wrong
addressing scopes.

Fixes #11837.

Change-Id: I10b7a1fa325add2cd8ed58f105d527700a10d342
Reviewed-on: https://go-review.googlesource.com/20586Reviewed-by: default avatarPaul Marks <pmarks@google.com>
parent 76b724cc
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
package net package net
import ( import (
"errors"
"runtime" "runtime"
"time" "time"
) )
...@@ -140,8 +139,11 @@ func parseNetwork(net string) (afnet string, proto int, err error) { ...@@ -140,8 +139,11 @@ func parseNetwork(net string) (afnet string, proto int, err error) {
return "", 0, UnknownNetworkError(net) return "", 0, UnknownNetworkError(net)
} }
func resolveAddrList(op, net, addr string, deadline time.Time) (addrList, error) { // resolverAddrList resolves addr using hint and returns a list of
afnet, _, err := parseNetwork(net) // addresses. The result contains at least one address when error is
// nil.
func resolveAddrList(op, network, addr string, hint Addr, deadline time.Time) (addrList, error) {
afnet, _, err := parseNetwork(network)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -154,9 +156,59 @@ func resolveAddrList(op, net, addr string, deadline time.Time) (addrList, error) ...@@ -154,9 +156,59 @@ func resolveAddrList(op, net, addr string, deadline time.Time) (addrList, error)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if op == "dial" && hint != nil && addr.Network() != hint.Network() {
return nil, &AddrError{Err: "mismatched local address type", Addr: hint.String()}
}
return addrList{addr}, nil return addrList{addr}, nil
} }
return internetAddrList(afnet, addr, deadline) addrs, err := internetAddrList(afnet, addr, deadline)
if err != nil || op != "dial" || hint == nil {
return addrs, err
}
var (
tcp *TCPAddr
udp *UDPAddr
ip *IPAddr
wildcard bool
)
switch hint := hint.(type) {
case *TCPAddr:
tcp = hint
wildcard = tcp.isWildcard()
case *UDPAddr:
udp = hint
wildcard = udp.isWildcard()
case *IPAddr:
ip = hint
wildcard = ip.isWildcard()
}
naddrs := addrs[:0]
for _, addr := range addrs {
if addr.Network() != hint.Network() {
return nil, &AddrError{Err: "mismatched local address type", Addr: hint.String()}
}
switch addr := addr.(type) {
case *TCPAddr:
if !wildcard && !addr.isWildcard() && !addr.IP.matchAddrFamily(tcp.IP) {
continue
}
naddrs = append(naddrs, addr)
case *UDPAddr:
if !wildcard && !addr.isWildcard() && !addr.IP.matchAddrFamily(udp.IP) {
continue
}
naddrs = append(naddrs, addr)
case *IPAddr:
if !wildcard && !addr.isWildcard() && !addr.IP.matchAddrFamily(ip.IP) {
continue
}
naddrs = append(naddrs, addr)
}
}
if len(naddrs) == 0 {
return nil, errNoSuitableAddress
}
return naddrs, nil
} }
// Dial connects to the address on the named network. // Dial connects to the address on the named network.
...@@ -214,7 +266,7 @@ type dialContext struct { ...@@ -214,7 +266,7 @@ type dialContext struct {
// parameters. // parameters.
func (d *Dialer) Dial(network, address string) (Conn, error) { func (d *Dialer) Dial(network, address string) (Conn, error) {
finalDeadline := d.deadline(time.Now()) finalDeadline := d.deadline(time.Now())
addrs, err := resolveAddrList("dial", network, address, finalDeadline) addrs, err := resolveAddrList("dial", network, address, d.LocalAddr, finalDeadline)
if err != nil { if err != nil {
return nil, &OpError{Op: "dial", Net: network, Source: nil, Addr: nil, Err: err} return nil, &OpError{Op: "dial", Net: network, Source: nil, Addr: nil, Err: err}
} }
...@@ -387,9 +439,6 @@ func dialSerial(ctx *dialContext, ras addrList, cancel <-chan struct{}) (Conn, e ...@@ -387,9 +439,6 @@ func dialSerial(ctx *dialContext, ras addrList, cancel <-chan struct{}) (Conn, e
// dial function, because some OSes don't implement the deadline feature. // dial function, because some OSes don't implement the deadline feature.
func dialSingle(ctx *dialContext, ra Addr, deadline time.Time, cancel <-chan struct{}) (c Conn, err error) { func dialSingle(ctx *dialContext, ra Addr, deadline time.Time, cancel <-chan struct{}) (c Conn, err error) {
la := ctx.LocalAddr la := ctx.LocalAddr
if la != nil && la.Network() != ra.Network() {
return nil, &OpError{Op: "dial", Net: ctx.network, Source: la, Addr: ra, Err: errors.New("mismatched local address type " + la.Network())}
}
switch ra := ra.(type) { switch ra := ra.(type) {
case *TCPAddr: case *TCPAddr:
la, _ := la.(*TCPAddr) la, _ := la.(*TCPAddr)
...@@ -420,7 +469,7 @@ func dialSingle(ctx *dialContext, ra Addr, deadline time.Time, cancel <-chan str ...@@ -420,7 +469,7 @@ func dialSingle(ctx *dialContext, ra Addr, deadline time.Time, cancel <-chan str
// instead of just the interface with the given host address. // instead of just the interface with the given host address.
// See Dial for more details about address syntax. // See Dial for more details about address syntax.
func Listen(net, laddr string) (Listener, error) { func Listen(net, laddr string) (Listener, error) {
addrs, err := resolveAddrList("listen", net, laddr, noDeadline) addrs, err := resolveAddrList("listen", net, laddr, nil, noDeadline)
if err != nil { if err != nil {
return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: nil, Err: err} return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: nil, Err: err}
} }
...@@ -447,7 +496,7 @@ func Listen(net, laddr string) (Listener, error) { ...@@ -447,7 +496,7 @@ func Listen(net, laddr string) (Listener, error) {
// instead of just the interface with the given host address. // instead of just the interface with the given host address.
// See Dial for the syntax of laddr. // See Dial for the syntax of laddr.
func ListenPacket(net, laddr string) (PacketConn, error) { func ListenPacket(net, laddr string) (PacketConn, error) {
addrs, err := resolveAddrList("listen", net, laddr, noDeadline) addrs, err := resolveAddrList("listen", net, laddr, nil, noDeadline)
if err != nil { if err != nil {
return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: nil, Err: err} return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: nil, Err: err}
} }
......
...@@ -646,41 +646,118 @@ func TestDialerPartialDeadline(t *testing.T) { ...@@ -646,41 +646,118 @@ func TestDialerPartialDeadline(t *testing.T) {
} }
} }
type dialerLocalAddrTest struct {
network, raddr string
laddr Addr
error
}
var dialerLocalAddrTests = []dialerLocalAddrTest{
{"tcp4", "127.0.0.1", nil, nil},
{"tcp4", "127.0.0.1", &TCPAddr{}, nil},
{"tcp4", "127.0.0.1", &TCPAddr{IP: ParseIP("0.0.0.0")}, nil},
{"tcp4", "127.0.0.1", &TCPAddr{IP: ParseIP("0.0.0.0").To4()}, nil},
{"tcp4", "127.0.0.1", &TCPAddr{IP: ParseIP("::")}, &AddrError{Err: "some error"}},
{"tcp4", "127.0.0.1", &TCPAddr{IP: ParseIP("127.0.0.1").To4()}, nil},
{"tcp4", "127.0.0.1", &TCPAddr{IP: ParseIP("127.0.0.1").To16()}, nil},
{"tcp4", "127.0.0.1", &TCPAddr{IP: IPv6loopback}, errNoSuitableAddress},
{"tcp4", "127.0.0.1", &UDPAddr{}, &AddrError{Err: "some error"}},
{"tcp4", "127.0.0.1", &UnixAddr{}, &AddrError{Err: "some error"}},
{"tcp6", "::1", nil, nil},
{"tcp6", "::1", &TCPAddr{}, nil},
{"tcp6", "::1", &TCPAddr{IP: ParseIP("0.0.0.0")}, nil},
{"tcp6", "::1", &TCPAddr{IP: ParseIP("0.0.0.0").To4()}, nil},
{"tcp6", "::1", &TCPAddr{IP: ParseIP("::")}, nil},
{"tcp6", "::1", &TCPAddr{IP: ParseIP("127.0.0.1").To4()}, errNoSuitableAddress},
{"tcp6", "::1", &TCPAddr{IP: ParseIP("127.0.0.1").To16()}, errNoSuitableAddress},
{"tcp6", "::1", &TCPAddr{IP: IPv6loopback}, nil},
{"tcp6", "::1", &UDPAddr{}, &AddrError{Err: "some error"}},
{"tcp6", "::1", &UnixAddr{}, &AddrError{Err: "some error"}},
{"tcp", "127.0.0.1", nil, nil},
{"tcp", "127.0.0.1", &TCPAddr{}, nil},
{"tcp", "127.0.0.1", &TCPAddr{IP: ParseIP("0.0.0.0")}, nil},
{"tcp", "127.0.0.1", &TCPAddr{IP: ParseIP("0.0.0.0").To4()}, nil},
{"tcp", "127.0.0.1", &TCPAddr{IP: ParseIP("127.0.0.1").To4()}, nil},
{"tcp", "127.0.0.1", &TCPAddr{IP: ParseIP("127.0.0.1").To16()}, nil},
{"tcp", "127.0.0.1", &TCPAddr{IP: IPv6loopback}, errNoSuitableAddress},
{"tcp", "127.0.0.1", &UDPAddr{}, &AddrError{Err: "some error"}},
{"tcp", "127.0.0.1", &UnixAddr{}, &AddrError{Err: "some error"}},
{"tcp", "::1", nil, nil},
{"tcp", "::1", &TCPAddr{}, nil},
{"tcp", "::1", &TCPAddr{IP: ParseIP("0.0.0.0")}, nil},
{"tcp", "::1", &TCPAddr{IP: ParseIP("0.0.0.0").To4()}, nil},
{"tcp", "::1", &TCPAddr{IP: ParseIP("::")}, nil},
{"tcp", "::1", &TCPAddr{IP: ParseIP("127.0.0.1").To4()}, errNoSuitableAddress},
{"tcp", "::1", &TCPAddr{IP: ParseIP("127.0.0.1").To16()}, errNoSuitableAddress},
{"tcp", "::1", &TCPAddr{IP: IPv6loopback}, nil},
{"tcp", "::1", &UDPAddr{}, &AddrError{Err: "some error"}},
{"tcp", "::1", &UnixAddr{}, &AddrError{Err: "some error"}},
}
func TestDialerLocalAddr(t *testing.T) { func TestDialerLocalAddr(t *testing.T) {
ch := make(chan error, 1) if !supportsIPv4 || !supportsIPv6 {
t.Skip("both IPv4 and IPv6 are required")
}
if supportsIPv4map {
dialerLocalAddrTests = append(dialerLocalAddrTests, dialerLocalAddrTest{
"tcp", "127.0.0.1", &TCPAddr{IP: ParseIP("::")}, nil,
})
} else {
dialerLocalAddrTests = append(dialerLocalAddrTests, dialerLocalAddrTest{
"tcp", "127.0.0.1", &TCPAddr{IP: ParseIP("::")}, &AddrError{Err: "some error"},
})
}
origTestHookLookupIP := testHookLookupIP
defer func() { testHookLookupIP = origTestHookLookupIP }()
testHookLookupIP = lookupLocalhost
handler := func(ls *localServer, ln Listener) { handler := func(ls *localServer, ln Listener) {
for {
c, err := ln.Accept() c, err := ln.Accept()
if err != nil { if err != nil {
ch <- err
return return
} }
defer c.Close() c.Close()
ch <- nil
} }
ls, err := newLocalServer("tcp") }
var err error
var lss [2]*localServer
for i, network := range []string{"tcp4", "tcp6"} {
lss[i], err = newLocalServer(network)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer ls.teardown() defer lss[i].teardown()
if err := ls.buildup(handler); err != nil { if err := lss[i].buildup(handler); err != nil {
t.Fatal(err) t.Fatal(err)
} }
}
laddr, err := ResolveTCPAddr(ls.Listener.Addr().Network(), ls.Listener.Addr().String()) for _, tt := range dialerLocalAddrTests {
if err != nil { d := &Dialer{LocalAddr: tt.laddr}
t.Fatal(err) var addr string
ip := ParseIP(tt.raddr)
if ip.To4() != nil {
addr = lss[0].Listener.Addr().String()
} }
laddr.Port = 0 if ip.To16() != nil && ip.To4() == nil {
d := &Dialer{LocalAddr: laddr} addr = lss[1].Listener.Addr().String()
c, err := d.Dial(ls.Listener.Addr().Network(), ls.Addr().String()) }
if err != nil { c, err := d.Dial(tt.network, addr)
t.Fatal(err) if err == nil && tt.error != nil || err != nil && tt.error == nil {
t.Errorf("%s %v->%s: got %v; want %v", tt.network, tt.laddr, tt.raddr, err, tt.error)
} }
defer c.Close()
c.Read(make([]byte, 1))
err = <-ch
if err != nil { if err != nil {
t.Error(err) if perr := parseDialError(err); perr != nil {
t.Error(perr)
}
continue
}
c.Close()
} }
} }
......
...@@ -96,7 +96,7 @@ second: ...@@ -96,7 +96,7 @@ second:
goto third goto third
} }
switch nestedErr { switch nestedErr {
case errCanceled, errClosing, errMissingAddress: case errCanceled, errClosing, errMissingAddress, errNoSuitableAddress:
return nil return nil
} }
return fmt.Errorf("unexpected type on 2nd nested level: %T", nestedErr) return fmt.Errorf("unexpected type on 2nd nested level: %T", nestedErr)
...@@ -416,7 +416,7 @@ second: ...@@ -416,7 +416,7 @@ second:
goto third goto third
} }
switch nestedErr { switch nestedErr {
case errCanceled, errClosing, errTimeout, ErrWriteToConnected, io.ErrUnexpectedEOF: case errCanceled, errClosing, errMissingAddress, errTimeout, ErrWriteToConnected, io.ErrUnexpectedEOF:
return nil return nil
} }
return fmt.Errorf("unexpected type on 2nd nested level: %T", nestedErr) return fmt.Errorf("unexpected type on 2nd nested level: %T", nestedErr)
......
...@@ -377,6 +377,10 @@ func bytesEqual(x, y []byte) bool { ...@@ -377,6 +377,10 @@ func bytesEqual(x, y []byte) bool {
return true return true
} }
func (ip IP) matchAddrFamily(x IP) bool {
return ip.To4() != nil && x.To4() != nil || ip.To16() != nil && ip.To4() == nil && x.To16() != nil && x.To4() == nil
}
// If mask is a sequence of 1 bits followed by 0 bits, // If mask is a sequence of 1 bits followed by 0 bits,
// return the number of 1 bits. // return the number of 1 bits.
func simpleMaskLength(mask IPMask) int { func simpleMaskLength(mask IPMask) int {
......
...@@ -6,10 +6,7 @@ ...@@ -6,10 +6,7 @@
package net package net
import ( import "time"
"errors"
"time"
)
var ( var (
// supportsIPv4 reports whether the platform supports IPv4 // supportsIPv4 reports whether the platform supports IPv4
...@@ -73,8 +70,6 @@ func (addrs addrList) partition(strategy func(Addr) bool) (primaries, fallbacks ...@@ -73,8 +70,6 @@ func (addrs addrList) partition(strategy func(Addr) bool) (primaries, fallbacks
return return
} }
var errNoSuitableAddress = errors.New("no suitable address found")
// filterAddrList applies a filter to a list of IP addresses, // filterAddrList applies a filter to a list of IP addresses,
// yielding a list of Addr objects. Known filters are nil, ipv4only, // yielding a list of Addr objects. Known filters are nil, ipv4only,
// and ipv6only. It returns every address when the filter is nil. // and ipv6only. It returns every address when the filter is nil.
......
...@@ -364,6 +364,9 @@ type Error interface { ...@@ -364,6 +364,9 @@ type Error interface {
// Various errors contained in OpError. // Various errors contained in OpError.
var ( var (
// For connection setup operations.
errNoSuitableAddress = errors.New("no suitable address found")
// For connection setup and write operations. // For connection setup and write operations.
errMissingAddress = errors.New("missing address") errMissingAddress = errors.New("missing address")
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment