Commit 9bcb5a57 authored by David S. Miller's avatar David S. Miller

Merge branch 'net-l3mdev-Support-for-sockets-bound-to-enslaved-device'

David Ahern says:

====================
net: l3mdev: Support for sockets bound to enslaved device

A missing piece to the VRF puzzle is the ability to bind sockets to
devices enslaved to a VRF. This patch set adds the enslaved device
index, sdif, to IPv4 and IPv6 socket lookups. The end result for users
is the following scope options for services:

1. "global" services - sockets not bound to any device

   Allows 1 service to work across all network interfaces with
   connected sockets bound to the VRF the connection originates
   (Requires net.ipv4.tcp_l3mdev_accept=1 for TCP and
    net.ipv4.udp_l3mdev_accept=1 for UDP)

2. "VRF" local services - sockets bound to a VRF

   Sockets work across all network interfaces enslaved to a VRF but
   are limited to just the one VRF.

3. "device" services - sockets bound to a specific network interface

   Service works only through the one specific interface.

v3
- convert __inet_lookup_established in dccp_v4_err; missed in v2

v2
- remove sk_lookup struct and add sdif as an argument to existing
  functions

Changes since RFC:
- no significant logic changes; mainly whitespace cleanups
====================
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parents 46d4b68f 5108ab4b
...@@ -118,7 +118,8 @@ extern int ip_mc_msfget(struct sock *sk, struct ip_msfilter *msf, ...@@ -118,7 +118,8 @@ extern int ip_mc_msfget(struct sock *sk, struct ip_msfilter *msf,
struct ip_msfilter __user *optval, int __user *optlen); struct ip_msfilter __user *optval, int __user *optlen);
extern int ip_mc_gsfget(struct sock *sk, struct group_filter *gsf, extern int ip_mc_gsfget(struct sock *sk, struct group_filter *gsf,
struct group_filter __user *optval, int __user *optlen); struct group_filter __user *optval, int __user *optlen);
extern int ip_mc_sf_allow(struct sock *sk, __be32 local, __be32 rmt, int dif); extern int ip_mc_sf_allow(struct sock *sk, __be32 local, __be32 rmt,
int dif, int sdif);
extern void ip_mc_init_dev(struct in_device *); extern void ip_mc_init_dev(struct in_device *);
extern void ip_mc_destroy_dev(struct in_device *); extern void ip_mc_destroy_dev(struct in_device *);
extern void ip_mc_up(struct in_device *); extern void ip_mc_up(struct in_device *);
......
...@@ -158,6 +158,16 @@ static inline bool inet6_is_jumbogram(const struct sk_buff *skb) ...@@ -158,6 +158,16 @@ static inline bool inet6_is_jumbogram(const struct sk_buff *skb)
return !!(IP6CB(skb)->flags & IP6SKB_JUMBOGRAM); return !!(IP6CB(skb)->flags & IP6SKB_JUMBOGRAM);
} }
/* can not be used in TCP layer after tcp_v6_fill_cb */
static inline int inet6_sdif(const struct sk_buff *skb)
{
#if IS_ENABLED(CONFIG_NET_L3_MASTER_DEV)
if (skb && ipv6_l3mdev_skb(IP6CB(skb)->flags))
return IP6CB(skb)->iif;
#endif
return 0;
}
/* can not be used in TCP layer after tcp_v6_fill_cb */ /* can not be used in TCP layer after tcp_v6_fill_cb */
static inline bool inet6_exact_dif_match(struct net *net, struct sk_buff *skb) static inline bool inet6_exact_dif_match(struct net *net, struct sk_buff *skb)
{ {
......
...@@ -49,7 +49,8 @@ struct sock *__inet6_lookup_established(struct net *net, ...@@ -49,7 +49,8 @@ struct sock *__inet6_lookup_established(struct net *net,
const struct in6_addr *saddr, const struct in6_addr *saddr,
const __be16 sport, const __be16 sport,
const struct in6_addr *daddr, const struct in6_addr *daddr,
const u16 hnum, const int dif); const u16 hnum, const int dif,
const int sdif);
struct sock *inet6_lookup_listener(struct net *net, struct sock *inet6_lookup_listener(struct net *net,
struct inet_hashinfo *hashinfo, struct inet_hashinfo *hashinfo,
...@@ -57,7 +58,8 @@ struct sock *inet6_lookup_listener(struct net *net, ...@@ -57,7 +58,8 @@ struct sock *inet6_lookup_listener(struct net *net,
const struct in6_addr *saddr, const struct in6_addr *saddr,
const __be16 sport, const __be16 sport,
const struct in6_addr *daddr, const struct in6_addr *daddr,
const unsigned short hnum, const int dif); const unsigned short hnum,
const int dif, const int sdif);
static inline struct sock *__inet6_lookup(struct net *net, static inline struct sock *__inet6_lookup(struct net *net,
struct inet_hashinfo *hashinfo, struct inet_hashinfo *hashinfo,
...@@ -66,24 +68,25 @@ static inline struct sock *__inet6_lookup(struct net *net, ...@@ -66,24 +68,25 @@ static inline struct sock *__inet6_lookup(struct net *net,
const __be16 sport, const __be16 sport,
const struct in6_addr *daddr, const struct in6_addr *daddr,
const u16 hnum, const u16 hnum,
const int dif, const int dif, const int sdif,
bool *refcounted) bool *refcounted)
{ {
struct sock *sk = __inet6_lookup_established(net, hashinfo, saddr, struct sock *sk = __inet6_lookup_established(net, hashinfo, saddr,
sport, daddr, hnum, dif); sport, daddr, hnum,
dif, sdif);
*refcounted = true; *refcounted = true;
if (sk) if (sk)
return sk; return sk;
*refcounted = false; *refcounted = false;
return inet6_lookup_listener(net, hashinfo, skb, doff, saddr, sport, return inet6_lookup_listener(net, hashinfo, skb, doff, saddr, sport,
daddr, hnum, dif); daddr, hnum, dif, sdif);
} }
static inline struct sock *__inet6_lookup_skb(struct inet_hashinfo *hashinfo, static inline struct sock *__inet6_lookup_skb(struct inet_hashinfo *hashinfo,
struct sk_buff *skb, int doff, struct sk_buff *skb, int doff,
const __be16 sport, const __be16 sport,
const __be16 dport, const __be16 dport,
int iif, int iif, int sdif,
bool *refcounted) bool *refcounted)
{ {
struct sock *sk = skb_steal_sock(skb); struct sock *sk = skb_steal_sock(skb);
...@@ -95,7 +98,7 @@ static inline struct sock *__inet6_lookup_skb(struct inet_hashinfo *hashinfo, ...@@ -95,7 +98,7 @@ static inline struct sock *__inet6_lookup_skb(struct inet_hashinfo *hashinfo,
return __inet6_lookup(dev_net(skb_dst(skb)->dev), hashinfo, skb, return __inet6_lookup(dev_net(skb_dst(skb)->dev), hashinfo, skb,
doff, &ipv6_hdr(skb)->saddr, sport, doff, &ipv6_hdr(skb)->saddr, sport,
&ipv6_hdr(skb)->daddr, ntohs(dport), &ipv6_hdr(skb)->daddr, ntohs(dport),
iif, refcounted); iif, sdif, refcounted);
} }
struct sock *inet6_lookup(struct net *net, struct inet_hashinfo *hashinfo, struct sock *inet6_lookup(struct net *net, struct inet_hashinfo *hashinfo,
...@@ -107,13 +110,14 @@ struct sock *inet6_lookup(struct net *net, struct inet_hashinfo *hashinfo, ...@@ -107,13 +110,14 @@ struct sock *inet6_lookup(struct net *net, struct inet_hashinfo *hashinfo,
int inet6_hash(struct sock *sk); int inet6_hash(struct sock *sk);
#endif /* IS_ENABLED(CONFIG_IPV6) */ #endif /* IS_ENABLED(CONFIG_IPV6) */
#define INET6_MATCH(__sk, __net, __saddr, __daddr, __ports, __dif) \ #define INET6_MATCH(__sk, __net, __saddr, __daddr, __ports, __dif, __sdif) \
(((__sk)->sk_portpair == (__ports)) && \ (((__sk)->sk_portpair == (__ports)) && \
((__sk)->sk_family == AF_INET6) && \ ((__sk)->sk_family == AF_INET6) && \
ipv6_addr_equal(&(__sk)->sk_v6_daddr, (__saddr)) && \ ipv6_addr_equal(&(__sk)->sk_v6_daddr, (__saddr)) && \
ipv6_addr_equal(&(__sk)->sk_v6_rcv_saddr, (__daddr)) && \ ipv6_addr_equal(&(__sk)->sk_v6_rcv_saddr, (__daddr)) && \
(!(__sk)->sk_bound_dev_if || \ (!(__sk)->sk_bound_dev_if || \
((__sk)->sk_bound_dev_if == (__dif))) && \ ((__sk)->sk_bound_dev_if == (__dif)) || \
((__sk)->sk_bound_dev_if == (__sdif))) && \
net_eq(sock_net(__sk), (__net))) net_eq(sock_net(__sk), (__net)))
#endif /* _INET6_HASHTABLES_H */ #endif /* _INET6_HASHTABLES_H */
...@@ -221,16 +221,16 @@ struct sock *__inet_lookup_listener(struct net *net, ...@@ -221,16 +221,16 @@ struct sock *__inet_lookup_listener(struct net *net,
const __be32 saddr, const __be16 sport, const __be32 saddr, const __be16 sport,
const __be32 daddr, const __be32 daddr,
const unsigned short hnum, const unsigned short hnum,
const int dif); const int dif, const int sdif);
static inline struct sock *inet_lookup_listener(struct net *net, static inline struct sock *inet_lookup_listener(struct net *net,
struct inet_hashinfo *hashinfo, struct inet_hashinfo *hashinfo,
struct sk_buff *skb, int doff, struct sk_buff *skb, int doff,
__be32 saddr, __be16 sport, __be32 saddr, __be16 sport,
__be32 daddr, __be16 dport, int dif) __be32 daddr, __be16 dport, int dif, int sdif)
{ {
return __inet_lookup_listener(net, hashinfo, skb, doff, saddr, sport, return __inet_lookup_listener(net, hashinfo, skb, doff, saddr, sport,
daddr, ntohs(dport), dif); daddr, ntohs(dport), dif, sdif);
} }
/* Socket demux engine toys. */ /* Socket demux engine toys. */
...@@ -262,22 +262,24 @@ static inline struct sock *inet_lookup_listener(struct net *net, ...@@ -262,22 +262,24 @@ static inline struct sock *inet_lookup_listener(struct net *net,
(((__force __u64)(__be32)(__daddr)) << 32) | \ (((__force __u64)(__be32)(__daddr)) << 32) | \
((__force __u64)(__be32)(__saddr))) ((__force __u64)(__be32)(__saddr)))
#endif /* __BIG_ENDIAN */ #endif /* __BIG_ENDIAN */
#define INET_MATCH(__sk, __net, __cookie, __saddr, __daddr, __ports, __dif) \ #define INET_MATCH(__sk, __net, __cookie, __saddr, __daddr, __ports, __dif, __sdif) \
(((__sk)->sk_portpair == (__ports)) && \ (((__sk)->sk_portpair == (__ports)) && \
((__sk)->sk_addrpair == (__cookie)) && \ ((__sk)->sk_addrpair == (__cookie)) && \
(!(__sk)->sk_bound_dev_if || \ (!(__sk)->sk_bound_dev_if || \
((__sk)->sk_bound_dev_if == (__dif))) && \ ((__sk)->sk_bound_dev_if == (__dif)) || \
((__sk)->sk_bound_dev_if == (__sdif))) && \
net_eq(sock_net(__sk), (__net))) net_eq(sock_net(__sk), (__net)))
#else /* 32-bit arch */ #else /* 32-bit arch */
#define INET_ADDR_COOKIE(__name, __saddr, __daddr) \ #define INET_ADDR_COOKIE(__name, __saddr, __daddr) \
const int __name __deprecated __attribute__((unused)) const int __name __deprecated __attribute__((unused))
#define INET_MATCH(__sk, __net, __cookie, __saddr, __daddr, __ports, __dif) \ #define INET_MATCH(__sk, __net, __cookie, __saddr, __daddr, __ports, __dif, __sdif) \
(((__sk)->sk_portpair == (__ports)) && \ (((__sk)->sk_portpair == (__ports)) && \
((__sk)->sk_daddr == (__saddr)) && \ ((__sk)->sk_daddr == (__saddr)) && \
((__sk)->sk_rcv_saddr == (__daddr)) && \ ((__sk)->sk_rcv_saddr == (__daddr)) && \
(!(__sk)->sk_bound_dev_if || \ (!(__sk)->sk_bound_dev_if || \
((__sk)->sk_bound_dev_if == (__dif))) && \ ((__sk)->sk_bound_dev_if == (__dif)) || \
((__sk)->sk_bound_dev_if == (__sdif))) && \
net_eq(sock_net(__sk), (__net))) net_eq(sock_net(__sk), (__net)))
#endif /* 64-bit arch */ #endif /* 64-bit arch */
...@@ -288,7 +290,7 @@ struct sock *__inet_lookup_established(struct net *net, ...@@ -288,7 +290,7 @@ struct sock *__inet_lookup_established(struct net *net,
struct inet_hashinfo *hashinfo, struct inet_hashinfo *hashinfo,
const __be32 saddr, const __be16 sport, const __be32 saddr, const __be16 sport,
const __be32 daddr, const u16 hnum, const __be32 daddr, const u16 hnum,
const int dif); const int dif, const int sdif);
static inline struct sock * static inline struct sock *
inet_lookup_established(struct net *net, struct inet_hashinfo *hashinfo, inet_lookup_established(struct net *net, struct inet_hashinfo *hashinfo,
...@@ -297,7 +299,7 @@ static inline struct sock * ...@@ -297,7 +299,7 @@ static inline struct sock *
const int dif) const int dif)
{ {
return __inet_lookup_established(net, hashinfo, saddr, sport, daddr, return __inet_lookup_established(net, hashinfo, saddr, sport, daddr,
ntohs(dport), dif); ntohs(dport), dif, 0);
} }
static inline struct sock *__inet_lookup(struct net *net, static inline struct sock *__inet_lookup(struct net *net,
...@@ -305,20 +307,20 @@ static inline struct sock *__inet_lookup(struct net *net, ...@@ -305,20 +307,20 @@ static inline struct sock *__inet_lookup(struct net *net,
struct sk_buff *skb, int doff, struct sk_buff *skb, int doff,
const __be32 saddr, const __be16 sport, const __be32 saddr, const __be16 sport,
const __be32 daddr, const __be16 dport, const __be32 daddr, const __be16 dport,
const int dif, const int dif, const int sdif,
bool *refcounted) bool *refcounted)
{ {
u16 hnum = ntohs(dport); u16 hnum = ntohs(dport);
struct sock *sk; struct sock *sk;
sk = __inet_lookup_established(net, hashinfo, saddr, sport, sk = __inet_lookup_established(net, hashinfo, saddr, sport,
daddr, hnum, dif); daddr, hnum, dif, sdif);
*refcounted = true; *refcounted = true;
if (sk) if (sk)
return sk; return sk;
*refcounted = false; *refcounted = false;
return __inet_lookup_listener(net, hashinfo, skb, doff, saddr, return __inet_lookup_listener(net, hashinfo, skb, doff, saddr,
sport, daddr, hnum, dif); sport, daddr, hnum, dif, sdif);
} }
static inline struct sock *inet_lookup(struct net *net, static inline struct sock *inet_lookup(struct net *net,
...@@ -332,7 +334,7 @@ static inline struct sock *inet_lookup(struct net *net, ...@@ -332,7 +334,7 @@ static inline struct sock *inet_lookup(struct net *net,
bool refcounted; bool refcounted;
sk = __inet_lookup(net, hashinfo, skb, doff, saddr, sport, daddr, sk = __inet_lookup(net, hashinfo, skb, doff, saddr, sport, daddr,
dport, dif, &refcounted); dport, dif, 0, &refcounted);
if (sk && !refcounted && !refcount_inc_not_zero(&sk->sk_refcnt)) if (sk && !refcounted && !refcount_inc_not_zero(&sk->sk_refcnt))
sk = NULL; sk = NULL;
...@@ -344,6 +346,7 @@ static inline struct sock *__inet_lookup_skb(struct inet_hashinfo *hashinfo, ...@@ -344,6 +346,7 @@ static inline struct sock *__inet_lookup_skb(struct inet_hashinfo *hashinfo,
int doff, int doff,
const __be16 sport, const __be16 sport,
const __be16 dport, const __be16 dport,
const int sdif,
bool *refcounted) bool *refcounted)
{ {
struct sock *sk = skb_steal_sock(skb); struct sock *sk = skb_steal_sock(skb);
...@@ -355,7 +358,7 @@ static inline struct sock *__inet_lookup_skb(struct inet_hashinfo *hashinfo, ...@@ -355,7 +358,7 @@ static inline struct sock *__inet_lookup_skb(struct inet_hashinfo *hashinfo,
return __inet_lookup(dev_net(skb_dst(skb)->dev), hashinfo, skb, return __inet_lookup(dev_net(skb_dst(skb)->dev), hashinfo, skb,
doff, iph->saddr, sport, doff, iph->saddr, sport,
iph->daddr, dport, inet_iif(skb), iph->daddr, dport, inet_iif(skb), sdif,
refcounted); refcounted);
} }
......
...@@ -78,6 +78,16 @@ struct ipcm_cookie { ...@@ -78,6 +78,16 @@ struct ipcm_cookie {
#define IPCB(skb) ((struct inet_skb_parm*)((skb)->cb)) #define IPCB(skb) ((struct inet_skb_parm*)((skb)->cb))
#define PKTINFO_SKB_CB(skb) ((struct in_pktinfo *)((skb)->cb)) #define PKTINFO_SKB_CB(skb) ((struct in_pktinfo *)((skb)->cb))
/* return enslaved device index if relevant */
static inline int inet_sdif(struct sk_buff *skb)
{
#if IS_ENABLED(CONFIG_NET_L3_MASTER_DEV)
if (skb && ipv4_l3mdev_skb(IPCB(skb)->flags))
return IPCB(skb)->iif;
#endif
return 0;
}
struct ip_ra_chain { struct ip_ra_chain {
struct ip_ra_chain __rcu *next; struct ip_ra_chain __rcu *next;
struct sock *sk; struct sock *sk;
......
...@@ -26,7 +26,7 @@ extern struct proto raw_prot; ...@@ -26,7 +26,7 @@ extern struct proto raw_prot;
extern struct raw_hashinfo raw_v4_hashinfo; extern struct raw_hashinfo raw_v4_hashinfo;
struct sock *__raw_v4_lookup(struct net *net, struct sock *sk, struct sock *__raw_v4_lookup(struct net *net, struct sock *sk,
unsigned short num, __be32 raddr, unsigned short num, __be32 raddr,
__be32 laddr, int dif); __be32 laddr, int dif, int sdif);
int raw_abort(struct sock *sk, int err); int raw_abort(struct sock *sk, int err);
void raw_icmp_error(struct sk_buff *, int, u32); void raw_icmp_error(struct sk_buff *, int, u32);
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
extern struct raw_hashinfo raw_v6_hashinfo; extern struct raw_hashinfo raw_v6_hashinfo;
struct sock *__raw_v6_lookup(struct net *net, struct sock *sk, struct sock *__raw_v6_lookup(struct net *net, struct sock *sk,
unsigned short num, const struct in6_addr *loc_addr, unsigned short num, const struct in6_addr *loc_addr,
const struct in6_addr *rmt_addr, int dif); const struct in6_addr *rmt_addr, int dif, int sdif);
int raw_abort(struct sock *sk, int err); int raw_abort(struct sock *sk, int err);
......
...@@ -827,6 +827,16 @@ static inline int tcp_v6_iif(const struct sk_buff *skb) ...@@ -827,6 +827,16 @@ static inline int tcp_v6_iif(const struct sk_buff *skb)
return l3_slave ? skb->skb_iif : TCP_SKB_CB(skb)->header.h6.iif; return l3_slave ? skb->skb_iif : TCP_SKB_CB(skb)->header.h6.iif;
} }
/* TCP_SKB_CB reference means this can not be used from early demux */
static inline int tcp_v6_sdif(const struct sk_buff *skb)
{
#if IS_ENABLED(CONFIG_NET_L3_MASTER_DEV)
if (skb && ipv6_l3mdev_skb(TCP_SKB_CB(skb)->header.h6.flags))
return TCP_SKB_CB(skb)->header.h6.iif;
#endif
return 0;
}
#endif #endif
/* TCP_SKB_CB reference means this can not be used from early demux */ /* TCP_SKB_CB reference means this can not be used from early demux */
...@@ -840,6 +850,16 @@ static inline bool inet_exact_dif_match(struct net *net, struct sk_buff *skb) ...@@ -840,6 +850,16 @@ static inline bool inet_exact_dif_match(struct net *net, struct sk_buff *skb)
return false; return false;
} }
/* TCP_SKB_CB reference means this can not be used from early demux */
static inline int tcp_v4_sdif(struct sk_buff *skb)
{
#if IS_ENABLED(CONFIG_NET_L3_MASTER_DEV)
if (skb && ipv4_l3mdev_skb(TCP_SKB_CB(skb)->header.h4.flags))
return TCP_SKB_CB(skb)->header.h4.iif;
#endif
return 0;
}
/* Due to TSO, an SKB can be composed of multiple actual /* Due to TSO, an SKB can be composed of multiple actual
* packets. To keep these tracked properly, we use this. * packets. To keep these tracked properly, we use this.
*/ */
......
...@@ -287,7 +287,7 @@ int udp_lib_setsockopt(struct sock *sk, int level, int optname, ...@@ -287,7 +287,7 @@ int udp_lib_setsockopt(struct sock *sk, int level, int optname,
struct sock *udp4_lib_lookup(struct net *net, __be32 saddr, __be16 sport, struct sock *udp4_lib_lookup(struct net *net, __be32 saddr, __be16 sport,
__be32 daddr, __be16 dport, int dif); __be32 daddr, __be16 dport, int dif);
struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr, __be16 sport, struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr, __be16 sport,
__be32 daddr, __be16 dport, int dif, __be32 daddr, __be16 dport, int dif, int sdif,
struct udp_table *tbl, struct sk_buff *skb); struct udp_table *tbl, struct sk_buff *skb);
struct sock *udp4_lib_lookup_skb(struct sk_buff *skb, struct sock *udp4_lib_lookup_skb(struct sk_buff *skb,
__be16 sport, __be16 dport); __be16 sport, __be16 dport);
...@@ -298,7 +298,7 @@ struct sock *udp6_lib_lookup(struct net *net, ...@@ -298,7 +298,7 @@ struct sock *udp6_lib_lookup(struct net *net,
struct sock *__udp6_lib_lookup(struct net *net, struct sock *__udp6_lib_lookup(struct net *net,
const struct in6_addr *saddr, __be16 sport, const struct in6_addr *saddr, __be16 sport,
const struct in6_addr *daddr, __be16 dport, const struct in6_addr *daddr, __be16 dport,
int dif, struct udp_table *tbl, int dif, int sdif, struct udp_table *tbl,
struct sk_buff *skb); struct sk_buff *skb);
struct sock *udp6_lib_lookup_skb(struct sk_buff *skb, struct sock *udp6_lib_lookup_skb(struct sk_buff *skb,
__be16 sport, __be16 dport); __be16 sport, __be16 dport);
......
...@@ -256,7 +256,7 @@ static void dccp_v4_err(struct sk_buff *skb, u32 info) ...@@ -256,7 +256,7 @@ static void dccp_v4_err(struct sk_buff *skb, u32 info)
sk = __inet_lookup_established(net, &dccp_hashinfo, sk = __inet_lookup_established(net, &dccp_hashinfo,
iph->daddr, dh->dccph_dport, iph->daddr, dh->dccph_dport,
iph->saddr, ntohs(dh->dccph_sport), iph->saddr, ntohs(dh->dccph_sport),
inet_iif(skb)); inet_iif(skb), 0);
if (!sk) { if (!sk) {
__ICMP_INC_STATS(net, ICMP_MIB_INERRORS); __ICMP_INC_STATS(net, ICMP_MIB_INERRORS);
return; return;
...@@ -804,7 +804,7 @@ static int dccp_v4_rcv(struct sk_buff *skb) ...@@ -804,7 +804,7 @@ static int dccp_v4_rcv(struct sk_buff *skb)
lookup: lookup:
sk = __inet_lookup_skb(&dccp_hashinfo, skb, __dccp_hdr_len(dh), sk = __inet_lookup_skb(&dccp_hashinfo, skb, __dccp_hdr_len(dh),
dh->dccph_sport, dh->dccph_dport, &refcounted); dh->dccph_sport, dh->dccph_dport, 0, &refcounted);
if (!sk) { if (!sk) {
dccp_pr_debug("failed to look up flow ID in table and " dccp_pr_debug("failed to look up flow ID in table and "
"get corresponding socket\n"); "get corresponding socket\n");
......
...@@ -89,7 +89,7 @@ static void dccp_v6_err(struct sk_buff *skb, struct inet6_skb_parm *opt, ...@@ -89,7 +89,7 @@ static void dccp_v6_err(struct sk_buff *skb, struct inet6_skb_parm *opt,
sk = __inet6_lookup_established(net, &dccp_hashinfo, sk = __inet6_lookup_established(net, &dccp_hashinfo,
&hdr->daddr, dh->dccph_dport, &hdr->daddr, dh->dccph_dport,
&hdr->saddr, ntohs(dh->dccph_sport), &hdr->saddr, ntohs(dh->dccph_sport),
inet6_iif(skb)); inet6_iif(skb), 0);
if (!sk) { if (!sk) {
__ICMP6_INC_STATS(net, __in6_dev_get(skb->dev), __ICMP6_INC_STATS(net, __in6_dev_get(skb->dev),
...@@ -687,7 +687,7 @@ static int dccp_v6_rcv(struct sk_buff *skb) ...@@ -687,7 +687,7 @@ static int dccp_v6_rcv(struct sk_buff *skb)
lookup: lookup:
sk = __inet6_lookup_skb(&dccp_hashinfo, skb, __dccp_hdr_len(dh), sk = __inet6_lookup_skb(&dccp_hashinfo, skb, __dccp_hdr_len(dh),
dh->dccph_sport, dh->dccph_dport, dh->dccph_sport, dh->dccph_dport,
inet6_iif(skb), &refcounted); inet6_iif(skb), 0, &refcounted);
if (!sk) { if (!sk) {
dccp_pr_debug("failed to look up flow ID in table and " dccp_pr_debug("failed to look up flow ID in table and "
"get corresponding socket\n"); "get corresponding socket\n");
......
...@@ -2549,7 +2549,8 @@ int ip_mc_gsfget(struct sock *sk, struct group_filter *gsf, ...@@ -2549,7 +2549,8 @@ int ip_mc_gsfget(struct sock *sk, struct group_filter *gsf,
/* /*
* check if a multicast source filter allows delivery for a given <src,dst,intf> * check if a multicast source filter allows delivery for a given <src,dst,intf>
*/ */
int ip_mc_sf_allow(struct sock *sk, __be32 loc_addr, __be32 rmt_addr, int dif) int ip_mc_sf_allow(struct sock *sk, __be32 loc_addr, __be32 rmt_addr,
int dif, int sdif)
{ {
struct inet_sock *inet = inet_sk(sk); struct inet_sock *inet = inet_sk(sk);
struct ip_mc_socklist *pmc; struct ip_mc_socklist *pmc;
...@@ -2564,7 +2565,8 @@ int ip_mc_sf_allow(struct sock *sk, __be32 loc_addr, __be32 rmt_addr, int dif) ...@@ -2564,7 +2565,8 @@ int ip_mc_sf_allow(struct sock *sk, __be32 loc_addr, __be32 rmt_addr, int dif)
rcu_read_lock(); rcu_read_lock();
for_each_pmc_rcu(inet, pmc) { for_each_pmc_rcu(inet, pmc) {
if (pmc->multi.imr_multiaddr.s_addr == loc_addr && if (pmc->multi.imr_multiaddr.s_addr == loc_addr &&
pmc->multi.imr_ifindex == dif) (pmc->multi.imr_ifindex == dif ||
(sdif && pmc->multi.imr_ifindex == sdif)))
break; break;
} }
ret = inet->mc_all; ret = inet->mc_all;
......
...@@ -170,7 +170,7 @@ EXPORT_SYMBOL_GPL(__inet_inherit_port); ...@@ -170,7 +170,7 @@ EXPORT_SYMBOL_GPL(__inet_inherit_port);
static inline int compute_score(struct sock *sk, struct net *net, static inline int compute_score(struct sock *sk, struct net *net,
const unsigned short hnum, const __be32 daddr, const unsigned short hnum, const __be32 daddr,
const int dif, bool exact_dif) const int dif, const int sdif, bool exact_dif)
{ {
int score = -1; int score = -1;
struct inet_sock *inet = inet_sk(sk); struct inet_sock *inet = inet_sk(sk);
...@@ -185,9 +185,13 @@ static inline int compute_score(struct sock *sk, struct net *net, ...@@ -185,9 +185,13 @@ static inline int compute_score(struct sock *sk, struct net *net,
score += 4; score += 4;
} }
if (sk->sk_bound_dev_if || exact_dif) { if (sk->sk_bound_dev_if || exact_dif) {
if (sk->sk_bound_dev_if != dif) bool dev_match = (sk->sk_bound_dev_if == dif ||
sk->sk_bound_dev_if == sdif);
if (exact_dif && !dev_match)
return -1; return -1;
score += 4; if (sk->sk_bound_dev_if && dev_match)
score += 4;
} }
if (sk->sk_incoming_cpu == raw_smp_processor_id()) if (sk->sk_incoming_cpu == raw_smp_processor_id())
score++; score++;
...@@ -208,7 +212,7 @@ struct sock *__inet_lookup_listener(struct net *net, ...@@ -208,7 +212,7 @@ struct sock *__inet_lookup_listener(struct net *net,
struct sk_buff *skb, int doff, struct sk_buff *skb, int doff,
const __be32 saddr, __be16 sport, const __be32 saddr, __be16 sport,
const __be32 daddr, const unsigned short hnum, const __be32 daddr, const unsigned short hnum,
const int dif) const int dif, const int sdif)
{ {
unsigned int hash = inet_lhashfn(net, hnum); unsigned int hash = inet_lhashfn(net, hnum);
struct inet_listen_hashbucket *ilb = &hashinfo->listening_hash[hash]; struct inet_listen_hashbucket *ilb = &hashinfo->listening_hash[hash];
...@@ -218,7 +222,8 @@ struct sock *__inet_lookup_listener(struct net *net, ...@@ -218,7 +222,8 @@ struct sock *__inet_lookup_listener(struct net *net,
u32 phash = 0; u32 phash = 0;
sk_for_each_rcu(sk, &ilb->head) { sk_for_each_rcu(sk, &ilb->head) {
score = compute_score(sk, net, hnum, daddr, dif, exact_dif); score = compute_score(sk, net, hnum, daddr,
dif, sdif, exact_dif);
if (score > hiscore) { if (score > hiscore) {
reuseport = sk->sk_reuseport; reuseport = sk->sk_reuseport;
if (reuseport) { if (reuseport) {
...@@ -268,7 +273,7 @@ struct sock *__inet_lookup_established(struct net *net, ...@@ -268,7 +273,7 @@ struct sock *__inet_lookup_established(struct net *net,
struct inet_hashinfo *hashinfo, struct inet_hashinfo *hashinfo,
const __be32 saddr, const __be16 sport, const __be32 saddr, const __be16 sport,
const __be32 daddr, const u16 hnum, const __be32 daddr, const u16 hnum,
const int dif) const int dif, const int sdif)
{ {
INET_ADDR_COOKIE(acookie, saddr, daddr); INET_ADDR_COOKIE(acookie, saddr, daddr);
const __portpair ports = INET_COMBINED_PORTS(sport, hnum); const __portpair ports = INET_COMBINED_PORTS(sport, hnum);
...@@ -286,11 +291,12 @@ struct sock *__inet_lookup_established(struct net *net, ...@@ -286,11 +291,12 @@ struct sock *__inet_lookup_established(struct net *net,
if (sk->sk_hash != hash) if (sk->sk_hash != hash)
continue; continue;
if (likely(INET_MATCH(sk, net, acookie, if (likely(INET_MATCH(sk, net, acookie,
saddr, daddr, ports, dif))) { saddr, daddr, ports, dif, sdif))) {
if (unlikely(!refcount_inc_not_zero(&sk->sk_refcnt))) if (unlikely(!refcount_inc_not_zero(&sk->sk_refcnt)))
goto out; goto out;
if (unlikely(!INET_MATCH(sk, net, acookie, if (unlikely(!INET_MATCH(sk, net, acookie,
saddr, daddr, ports, dif))) { saddr, daddr, ports,
dif, sdif))) {
sock_gen_put(sk); sock_gen_put(sk);
goto begin; goto begin;
} }
...@@ -321,9 +327,10 @@ static int __inet_check_established(struct inet_timewait_death_row *death_row, ...@@ -321,9 +327,10 @@ static int __inet_check_established(struct inet_timewait_death_row *death_row,
__be32 daddr = inet->inet_rcv_saddr; __be32 daddr = inet->inet_rcv_saddr;
__be32 saddr = inet->inet_daddr; __be32 saddr = inet->inet_daddr;
int dif = sk->sk_bound_dev_if; int dif = sk->sk_bound_dev_if;
struct net *net = sock_net(sk);
int sdif = l3mdev_master_ifindex_by_index(net, dif);
INET_ADDR_COOKIE(acookie, saddr, daddr); INET_ADDR_COOKIE(acookie, saddr, daddr);
const __portpair ports = INET_COMBINED_PORTS(inet->inet_dport, lport); const __portpair ports = INET_COMBINED_PORTS(inet->inet_dport, lport);
struct net *net = sock_net(sk);
unsigned int hash = inet_ehashfn(net, daddr, lport, unsigned int hash = inet_ehashfn(net, daddr, lport,
saddr, inet->inet_dport); saddr, inet->inet_dport);
struct inet_ehash_bucket *head = inet_ehash_bucket(hinfo, hash); struct inet_ehash_bucket *head = inet_ehash_bucket(hinfo, hash);
...@@ -339,7 +346,7 @@ static int __inet_check_established(struct inet_timewait_death_row *death_row, ...@@ -339,7 +346,7 @@ static int __inet_check_established(struct inet_timewait_death_row *death_row,
continue; continue;
if (likely(INET_MATCH(sk2, net, acookie, if (likely(INET_MATCH(sk2, net, acookie,
saddr, daddr, ports, dif))) { saddr, daddr, ports, dif, sdif))) {
if (sk2->sk_state == TCP_TIME_WAIT) { if (sk2->sk_state == TCP_TIME_WAIT) {
tw = inet_twsk(sk2); tw = inet_twsk(sk2);
if (twsk_unique(sk, sk2, twp)) if (twsk_unique(sk, sk2, twp))
......
...@@ -122,7 +122,8 @@ void raw_unhash_sk(struct sock *sk) ...@@ -122,7 +122,8 @@ void raw_unhash_sk(struct sock *sk)
EXPORT_SYMBOL_GPL(raw_unhash_sk); EXPORT_SYMBOL_GPL(raw_unhash_sk);
struct sock *__raw_v4_lookup(struct net *net, struct sock *sk, struct sock *__raw_v4_lookup(struct net *net, struct sock *sk,
unsigned short num, __be32 raddr, __be32 laddr, int dif) unsigned short num, __be32 raddr, __be32 laddr,
int dif, int sdif)
{ {
sk_for_each_from(sk) { sk_for_each_from(sk) {
struct inet_sock *inet = inet_sk(sk); struct inet_sock *inet = inet_sk(sk);
...@@ -130,7 +131,8 @@ struct sock *__raw_v4_lookup(struct net *net, struct sock *sk, ...@@ -130,7 +131,8 @@ struct sock *__raw_v4_lookup(struct net *net, struct sock *sk,
if (net_eq(sock_net(sk), net) && inet->inet_num == num && if (net_eq(sock_net(sk), net) && inet->inet_num == num &&
!(inet->inet_daddr && inet->inet_daddr != raddr) && !(inet->inet_daddr && inet->inet_daddr != raddr) &&
!(inet->inet_rcv_saddr && inet->inet_rcv_saddr != laddr) && !(inet->inet_rcv_saddr && inet->inet_rcv_saddr != laddr) &&
!(sk->sk_bound_dev_if && sk->sk_bound_dev_if != dif)) !(sk->sk_bound_dev_if && sk->sk_bound_dev_if != dif &&
sk->sk_bound_dev_if != sdif))
goto found; /* gotcha */ goto found; /* gotcha */
} }
sk = NULL; sk = NULL;
...@@ -171,6 +173,7 @@ static int icmp_filter(const struct sock *sk, const struct sk_buff *skb) ...@@ -171,6 +173,7 @@ static int icmp_filter(const struct sock *sk, const struct sk_buff *skb)
*/ */
static int raw_v4_input(struct sk_buff *skb, const struct iphdr *iph, int hash) static int raw_v4_input(struct sk_buff *skb, const struct iphdr *iph, int hash)
{ {
int sdif = inet_sdif(skb);
struct sock *sk; struct sock *sk;
struct hlist_head *head; struct hlist_head *head;
int delivered = 0; int delivered = 0;
...@@ -184,13 +187,13 @@ static int raw_v4_input(struct sk_buff *skb, const struct iphdr *iph, int hash) ...@@ -184,13 +187,13 @@ static int raw_v4_input(struct sk_buff *skb, const struct iphdr *iph, int hash)
net = dev_net(skb->dev); net = dev_net(skb->dev);
sk = __raw_v4_lookup(net, __sk_head(head), iph->protocol, sk = __raw_v4_lookup(net, __sk_head(head), iph->protocol,
iph->saddr, iph->daddr, iph->saddr, iph->daddr,
skb->dev->ifindex); skb->dev->ifindex, sdif);
while (sk) { while (sk) {
delivered = 1; delivered = 1;
if ((iph->protocol != IPPROTO_ICMP || !icmp_filter(sk, skb)) && if ((iph->protocol != IPPROTO_ICMP || !icmp_filter(sk, skb)) &&
ip_mc_sf_allow(sk, iph->daddr, iph->saddr, ip_mc_sf_allow(sk, iph->daddr, iph->saddr,
skb->dev->ifindex)) { skb->dev->ifindex, sdif)) {
struct sk_buff *clone = skb_clone(skb, GFP_ATOMIC); struct sk_buff *clone = skb_clone(skb, GFP_ATOMIC);
/* Not releasing hash table! */ /* Not releasing hash table! */
...@@ -199,7 +202,7 @@ static int raw_v4_input(struct sk_buff *skb, const struct iphdr *iph, int hash) ...@@ -199,7 +202,7 @@ static int raw_v4_input(struct sk_buff *skb, const struct iphdr *iph, int hash)
} }
sk = __raw_v4_lookup(net, sk_next(sk), iph->protocol, sk = __raw_v4_lookup(net, sk_next(sk), iph->protocol,
iph->saddr, iph->daddr, iph->saddr, iph->daddr,
skb->dev->ifindex); skb->dev->ifindex, sdif);
} }
out: out:
read_unlock(&raw_v4_hashinfo.lock); read_unlock(&raw_v4_hashinfo.lock);
...@@ -297,12 +300,15 @@ void raw_icmp_error(struct sk_buff *skb, int protocol, u32 info) ...@@ -297,12 +300,15 @@ void raw_icmp_error(struct sk_buff *skb, int protocol, u32 info)
read_lock(&raw_v4_hashinfo.lock); read_lock(&raw_v4_hashinfo.lock);
raw_sk = sk_head(&raw_v4_hashinfo.ht[hash]); raw_sk = sk_head(&raw_v4_hashinfo.ht[hash]);
if (raw_sk) { if (raw_sk) {
int dif = skb->dev->ifindex;
int sdif = inet_sdif(skb);
iph = (const struct iphdr *)skb->data; iph = (const struct iphdr *)skb->data;
net = dev_net(skb->dev); net = dev_net(skb->dev);
while ((raw_sk = __raw_v4_lookup(net, raw_sk, protocol, while ((raw_sk = __raw_v4_lookup(net, raw_sk, protocol,
iph->daddr, iph->saddr, iph->daddr, iph->saddr,
skb->dev->ifindex)) != NULL) { dif, sdif)) != NULL) {
raw_err(raw_sk, skb, info); raw_err(raw_sk, skb, info);
raw_sk = sk_next(raw_sk); raw_sk = sk_next(raw_sk);
iph = (const struct iphdr *)skb->data; iph = (const struct iphdr *)skb->data;
......
...@@ -46,13 +46,13 @@ static struct sock *raw_lookup(struct net *net, struct sock *from, ...@@ -46,13 +46,13 @@ static struct sock *raw_lookup(struct net *net, struct sock *from,
sk = __raw_v4_lookup(net, from, r->sdiag_raw_protocol, sk = __raw_v4_lookup(net, from, r->sdiag_raw_protocol,
r->id.idiag_dst[0], r->id.idiag_dst[0],
r->id.idiag_src[0], r->id.idiag_src[0],
r->id.idiag_if); r->id.idiag_if, 0);
#if IS_ENABLED(CONFIG_IPV6) #if IS_ENABLED(CONFIG_IPV6)
else else
sk = __raw_v6_lookup(net, from, r->sdiag_raw_protocol, sk = __raw_v6_lookup(net, from, r->sdiag_raw_protocol,
(const struct in6_addr *)r->id.idiag_src, (const struct in6_addr *)r->id.idiag_src,
(const struct in6_addr *)r->id.idiag_dst, (const struct in6_addr *)r->id.idiag_dst,
r->id.idiag_if); r->id.idiag_if, 0);
#endif #endif
return sk; return sk;
} }
......
...@@ -383,7 +383,7 @@ void tcp_v4_err(struct sk_buff *icmp_skb, u32 info) ...@@ -383,7 +383,7 @@ void tcp_v4_err(struct sk_buff *icmp_skb, u32 info)
sk = __inet_lookup_established(net, &tcp_hashinfo, iph->daddr, sk = __inet_lookup_established(net, &tcp_hashinfo, iph->daddr,
th->dest, iph->saddr, ntohs(th->source), th->dest, iph->saddr, ntohs(th->source),
inet_iif(icmp_skb)); inet_iif(icmp_skb), 0);
if (!sk) { if (!sk) {
__ICMP_INC_STATS(net, ICMP_MIB_INERRORS); __ICMP_INC_STATS(net, ICMP_MIB_INERRORS);
return; return;
...@@ -659,7 +659,8 @@ static void tcp_v4_send_reset(const struct sock *sk, struct sk_buff *skb) ...@@ -659,7 +659,8 @@ static void tcp_v4_send_reset(const struct sock *sk, struct sk_buff *skb)
sk1 = __inet_lookup_listener(net, &tcp_hashinfo, NULL, 0, sk1 = __inet_lookup_listener(net, &tcp_hashinfo, NULL, 0,
ip_hdr(skb)->saddr, ip_hdr(skb)->saddr,
th->source, ip_hdr(skb)->daddr, th->source, ip_hdr(skb)->daddr,
ntohs(th->source), inet_iif(skb)); ntohs(th->source), inet_iif(skb),
tcp_v4_sdif(skb));
/* don't send rst if it can't find key */ /* don't send rst if it can't find key */
if (!sk1) if (!sk1)
goto out; goto out;
...@@ -1523,7 +1524,7 @@ void tcp_v4_early_demux(struct sk_buff *skb) ...@@ -1523,7 +1524,7 @@ void tcp_v4_early_demux(struct sk_buff *skb)
sk = __inet_lookup_established(dev_net(skb->dev), &tcp_hashinfo, sk = __inet_lookup_established(dev_net(skb->dev), &tcp_hashinfo,
iph->saddr, th->source, iph->saddr, th->source,
iph->daddr, ntohs(th->dest), iph->daddr, ntohs(th->dest),
skb->skb_iif); skb->skb_iif, inet_sdif(skb));
if (sk) { if (sk) {
skb->sk = sk; skb->sk = sk;
skb->destructor = sock_edemux; skb->destructor = sock_edemux;
...@@ -1588,6 +1589,7 @@ EXPORT_SYMBOL(tcp_filter); ...@@ -1588,6 +1589,7 @@ EXPORT_SYMBOL(tcp_filter);
int tcp_v4_rcv(struct sk_buff *skb) int tcp_v4_rcv(struct sk_buff *skb)
{ {
struct net *net = dev_net(skb->dev); struct net *net = dev_net(skb->dev);
int sdif = inet_sdif(skb);
const struct iphdr *iph; const struct iphdr *iph;
const struct tcphdr *th; const struct tcphdr *th;
bool refcounted; bool refcounted;
...@@ -1638,7 +1640,7 @@ int tcp_v4_rcv(struct sk_buff *skb) ...@@ -1638,7 +1640,7 @@ int tcp_v4_rcv(struct sk_buff *skb)
lookup: lookup:
sk = __inet_lookup_skb(&tcp_hashinfo, skb, __tcp_hdrlen(th), th->source, sk = __inet_lookup_skb(&tcp_hashinfo, skb, __tcp_hdrlen(th), th->source,
th->dest, &refcounted); th->dest, sdif, &refcounted);
if (!sk) if (!sk)
goto no_tcp_socket; goto no_tcp_socket;
...@@ -1766,7 +1768,8 @@ int tcp_v4_rcv(struct sk_buff *skb) ...@@ -1766,7 +1768,8 @@ int tcp_v4_rcv(struct sk_buff *skb)
__tcp_hdrlen(th), __tcp_hdrlen(th),
iph->saddr, th->source, iph->saddr, th->source,
iph->daddr, th->dest, iph->daddr, th->dest,
inet_iif(skb)); inet_iif(skb),
sdif);
if (sk2) { if (sk2) {
inet_twsk_deschedule_put(inet_twsk(sk)); inet_twsk_deschedule_put(inet_twsk(sk));
sk = sk2; sk = sk2;
......
...@@ -380,8 +380,8 @@ int udp_v4_get_port(struct sock *sk, unsigned short snum) ...@@ -380,8 +380,8 @@ int udp_v4_get_port(struct sock *sk, unsigned short snum)
static int compute_score(struct sock *sk, struct net *net, static int compute_score(struct sock *sk, struct net *net,
__be32 saddr, __be16 sport, __be32 saddr, __be16 sport,
__be32 daddr, unsigned short hnum, int dif, __be32 daddr, unsigned short hnum,
bool exact_dif) int dif, int sdif, bool exact_dif)
{ {
int score; int score;
struct inet_sock *inet; struct inet_sock *inet;
...@@ -413,10 +413,15 @@ static int compute_score(struct sock *sk, struct net *net, ...@@ -413,10 +413,15 @@ static int compute_score(struct sock *sk, struct net *net,
} }
if (sk->sk_bound_dev_if || exact_dif) { if (sk->sk_bound_dev_if || exact_dif) {
if (sk->sk_bound_dev_if != dif) bool dev_match = (sk->sk_bound_dev_if == dif ||
sk->sk_bound_dev_if == sdif);
if (exact_dif && !dev_match)
return -1; return -1;
score += 4; if (sk->sk_bound_dev_if && dev_match)
score += 4;
} }
if (sk->sk_incoming_cpu == raw_smp_processor_id()) if (sk->sk_incoming_cpu == raw_smp_processor_id())
score++; score++;
return score; return score;
...@@ -436,10 +441,11 @@ static u32 udp_ehashfn(const struct net *net, const __be32 laddr, ...@@ -436,10 +441,11 @@ static u32 udp_ehashfn(const struct net *net, const __be32 laddr,
/* called with rcu_read_lock() */ /* called with rcu_read_lock() */
static struct sock *udp4_lib_lookup2(struct net *net, static struct sock *udp4_lib_lookup2(struct net *net,
__be32 saddr, __be16 sport, __be32 saddr, __be16 sport,
__be32 daddr, unsigned int hnum, int dif, bool exact_dif, __be32 daddr, unsigned int hnum,
struct udp_hslot *hslot2, int dif, int sdif, bool exact_dif,
struct sk_buff *skb) struct udp_hslot *hslot2,
struct sk_buff *skb)
{ {
struct sock *sk, *result; struct sock *sk, *result;
int score, badness, matches = 0, reuseport = 0; int score, badness, matches = 0, reuseport = 0;
...@@ -449,7 +455,7 @@ static struct sock *udp4_lib_lookup2(struct net *net, ...@@ -449,7 +455,7 @@ static struct sock *udp4_lib_lookup2(struct net *net,
badness = 0; badness = 0;
udp_portaddr_for_each_entry_rcu(sk, &hslot2->head) { udp_portaddr_for_each_entry_rcu(sk, &hslot2->head) {
score = compute_score(sk, net, saddr, sport, score = compute_score(sk, net, saddr, sport,
daddr, hnum, dif, exact_dif); daddr, hnum, dif, sdif, exact_dif);
if (score > badness) { if (score > badness) {
reuseport = sk->sk_reuseport; reuseport = sk->sk_reuseport;
if (reuseport) { if (reuseport) {
...@@ -477,8 +483,8 @@ static struct sock *udp4_lib_lookup2(struct net *net, ...@@ -477,8 +483,8 @@ static struct sock *udp4_lib_lookup2(struct net *net,
* harder than this. -DaveM * harder than this. -DaveM
*/ */
struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr, struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr,
__be16 sport, __be32 daddr, __be16 dport, __be16 sport, __be32 daddr, __be16 dport, int dif,
int dif, struct udp_table *udptable, struct sk_buff *skb) int sdif, struct udp_table *udptable, struct sk_buff *skb)
{ {
struct sock *sk, *result; struct sock *sk, *result;
unsigned short hnum = ntohs(dport); unsigned short hnum = ntohs(dport);
...@@ -496,7 +502,7 @@ struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr, ...@@ -496,7 +502,7 @@ struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr,
goto begin; goto begin;
result = udp4_lib_lookup2(net, saddr, sport, result = udp4_lib_lookup2(net, saddr, sport,
daddr, hnum, dif, daddr, hnum, dif, sdif,
exact_dif, hslot2, skb); exact_dif, hslot2, skb);
if (!result) { if (!result) {
unsigned int old_slot2 = slot2; unsigned int old_slot2 = slot2;
...@@ -511,7 +517,7 @@ struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr, ...@@ -511,7 +517,7 @@ struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr,
goto begin; goto begin;
result = udp4_lib_lookup2(net, saddr, sport, result = udp4_lib_lookup2(net, saddr, sport,
daddr, hnum, dif, daddr, hnum, dif, sdif,
exact_dif, hslot2, skb); exact_dif, hslot2, skb);
} }
return result; return result;
...@@ -521,7 +527,7 @@ struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr, ...@@ -521,7 +527,7 @@ struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr,
badness = 0; badness = 0;
sk_for_each_rcu(sk, &hslot->head) { sk_for_each_rcu(sk, &hslot->head) {
score = compute_score(sk, net, saddr, sport, score = compute_score(sk, net, saddr, sport,
daddr, hnum, dif, exact_dif); daddr, hnum, dif, sdif, exact_dif);
if (score > badness) { if (score > badness) {
reuseport = sk->sk_reuseport; reuseport = sk->sk_reuseport;
if (reuseport) { if (reuseport) {
...@@ -554,7 +560,7 @@ static inline struct sock *__udp4_lib_lookup_skb(struct sk_buff *skb, ...@@ -554,7 +560,7 @@ static inline struct sock *__udp4_lib_lookup_skb(struct sk_buff *skb,
return __udp4_lib_lookup(dev_net(skb->dev), iph->saddr, sport, return __udp4_lib_lookup(dev_net(skb->dev), iph->saddr, sport,
iph->daddr, dport, inet_iif(skb), iph->daddr, dport, inet_iif(skb),
udptable, skb); inet_sdif(skb), udptable, skb);
} }
struct sock *udp4_lib_lookup_skb(struct sk_buff *skb, struct sock *udp4_lib_lookup_skb(struct sk_buff *skb,
...@@ -576,7 +582,7 @@ struct sock *udp4_lib_lookup(struct net *net, __be32 saddr, __be16 sport, ...@@ -576,7 +582,7 @@ struct sock *udp4_lib_lookup(struct net *net, __be32 saddr, __be16 sport,
struct sock *sk; struct sock *sk;
sk = __udp4_lib_lookup(net, saddr, sport, daddr, dport, sk = __udp4_lib_lookup(net, saddr, sport, daddr, dport,
dif, &udp_table, NULL); dif, 0, &udp_table, NULL);
if (sk && !refcount_inc_not_zero(&sk->sk_refcnt)) if (sk && !refcount_inc_not_zero(&sk->sk_refcnt))
sk = NULL; sk = NULL;
return sk; return sk;
...@@ -587,7 +593,7 @@ EXPORT_SYMBOL_GPL(udp4_lib_lookup); ...@@ -587,7 +593,7 @@ EXPORT_SYMBOL_GPL(udp4_lib_lookup);
static inline bool __udp_is_mcast_sock(struct net *net, struct sock *sk, static inline bool __udp_is_mcast_sock(struct net *net, struct sock *sk,
__be16 loc_port, __be32 loc_addr, __be16 loc_port, __be32 loc_addr,
__be16 rmt_port, __be32 rmt_addr, __be16 rmt_port, __be32 rmt_addr,
int dif, unsigned short hnum) int dif, int sdif, unsigned short hnum)
{ {
struct inet_sock *inet = inet_sk(sk); struct inet_sock *inet = inet_sk(sk);
...@@ -597,9 +603,10 @@ static inline bool __udp_is_mcast_sock(struct net *net, struct sock *sk, ...@@ -597,9 +603,10 @@ static inline bool __udp_is_mcast_sock(struct net *net, struct sock *sk,
(inet->inet_dport != rmt_port && inet->inet_dport) || (inet->inet_dport != rmt_port && inet->inet_dport) ||
(inet->inet_rcv_saddr && inet->inet_rcv_saddr != loc_addr) || (inet->inet_rcv_saddr && inet->inet_rcv_saddr != loc_addr) ||
ipv6_only_sock(sk) || ipv6_only_sock(sk) ||
(sk->sk_bound_dev_if && sk->sk_bound_dev_if != dif)) (sk->sk_bound_dev_if && sk->sk_bound_dev_if != dif &&
sk->sk_bound_dev_if != sdif))
return false; return false;
if (!ip_mc_sf_allow(sk, loc_addr, rmt_addr, dif)) if (!ip_mc_sf_allow(sk, loc_addr, rmt_addr, dif, sdif))
return false; return false;
return true; return true;
} }
...@@ -628,8 +635,8 @@ void __udp4_lib_err(struct sk_buff *skb, u32 info, struct udp_table *udptable) ...@@ -628,8 +635,8 @@ void __udp4_lib_err(struct sk_buff *skb, u32 info, struct udp_table *udptable)
struct net *net = dev_net(skb->dev); struct net *net = dev_net(skb->dev);
sk = __udp4_lib_lookup(net, iph->daddr, uh->dest, sk = __udp4_lib_lookup(net, iph->daddr, uh->dest,
iph->saddr, uh->source, skb->dev->ifindex, udptable, iph->saddr, uh->source, skb->dev->ifindex, 0,
NULL); udptable, NULL);
if (!sk) { if (!sk) {
__ICMP_INC_STATS(net, ICMP_MIB_INERRORS); __ICMP_INC_STATS(net, ICMP_MIB_INERRORS);
return; /* No socket for error */ return; /* No socket for error */
...@@ -1953,6 +1960,7 @@ static int __udp4_lib_mcast_deliver(struct net *net, struct sk_buff *skb, ...@@ -1953,6 +1960,7 @@ static int __udp4_lib_mcast_deliver(struct net *net, struct sk_buff *skb,
unsigned int hash2 = 0, hash2_any = 0, use_hash2 = (hslot->count > 10); unsigned int hash2 = 0, hash2_any = 0, use_hash2 = (hslot->count > 10);
unsigned int offset = offsetof(typeof(*sk), sk_node); unsigned int offset = offsetof(typeof(*sk), sk_node);
int dif = skb->dev->ifindex; int dif = skb->dev->ifindex;
int sdif = inet_sdif(skb);
struct hlist_node *node; struct hlist_node *node;
struct sk_buff *nskb; struct sk_buff *nskb;
...@@ -1967,7 +1975,7 @@ static int __udp4_lib_mcast_deliver(struct net *net, struct sk_buff *skb, ...@@ -1967,7 +1975,7 @@ static int __udp4_lib_mcast_deliver(struct net *net, struct sk_buff *skb,
sk_for_each_entry_offset_rcu(sk, node, &hslot->head, offset) { sk_for_each_entry_offset_rcu(sk, node, &hslot->head, offset) {
if (!__udp_is_mcast_sock(net, sk, uh->dest, daddr, if (!__udp_is_mcast_sock(net, sk, uh->dest, daddr,
uh->source, saddr, dif, hnum)) uh->source, saddr, dif, sdif, hnum))
continue; continue;
if (!first) { if (!first) {
...@@ -2157,7 +2165,7 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable, ...@@ -2157,7 +2165,7 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
static struct sock *__udp4_lib_mcast_demux_lookup(struct net *net, static struct sock *__udp4_lib_mcast_demux_lookup(struct net *net,
__be16 loc_port, __be32 loc_addr, __be16 loc_port, __be32 loc_addr,
__be16 rmt_port, __be32 rmt_addr, __be16 rmt_port, __be32 rmt_addr,
int dif) int dif, int sdif)
{ {
struct sock *sk, *result; struct sock *sk, *result;
unsigned short hnum = ntohs(loc_port); unsigned short hnum = ntohs(loc_port);
...@@ -2171,7 +2179,7 @@ static struct sock *__udp4_lib_mcast_demux_lookup(struct net *net, ...@@ -2171,7 +2179,7 @@ static struct sock *__udp4_lib_mcast_demux_lookup(struct net *net,
result = NULL; result = NULL;
sk_for_each_rcu(sk, &hslot->head) { sk_for_each_rcu(sk, &hslot->head) {
if (__udp_is_mcast_sock(net, sk, loc_port, loc_addr, if (__udp_is_mcast_sock(net, sk, loc_port, loc_addr,
rmt_port, rmt_addr, dif, hnum)) { rmt_port, rmt_addr, dif, sdif, hnum)) {
if (result) if (result)
return NULL; return NULL;
result = sk; result = sk;
...@@ -2188,7 +2196,7 @@ static struct sock *__udp4_lib_mcast_demux_lookup(struct net *net, ...@@ -2188,7 +2196,7 @@ static struct sock *__udp4_lib_mcast_demux_lookup(struct net *net,
static struct sock *__udp4_lib_demux_lookup(struct net *net, static struct sock *__udp4_lib_demux_lookup(struct net *net,
__be16 loc_port, __be32 loc_addr, __be16 loc_port, __be32 loc_addr,
__be16 rmt_port, __be32 rmt_addr, __be16 rmt_port, __be32 rmt_addr,
int dif) int dif, int sdif)
{ {
unsigned short hnum = ntohs(loc_port); unsigned short hnum = ntohs(loc_port);
unsigned int hash2 = udp4_portaddr_hash(net, loc_addr, hnum); unsigned int hash2 = udp4_portaddr_hash(net, loc_addr, hnum);
...@@ -2200,7 +2208,7 @@ static struct sock *__udp4_lib_demux_lookup(struct net *net, ...@@ -2200,7 +2208,7 @@ static struct sock *__udp4_lib_demux_lookup(struct net *net,
udp_portaddr_for_each_entry_rcu(sk, &hslot2->head) { udp_portaddr_for_each_entry_rcu(sk, &hslot2->head) {
if (INET_MATCH(sk, net, acookie, rmt_addr, if (INET_MATCH(sk, net, acookie, rmt_addr,
loc_addr, ports, dif)) loc_addr, ports, dif, sdif))
return sk; return sk;
/* Only check first socket in chain */ /* Only check first socket in chain */
break; break;
...@@ -2216,6 +2224,7 @@ void udp_v4_early_demux(struct sk_buff *skb) ...@@ -2216,6 +2224,7 @@ void udp_v4_early_demux(struct sk_buff *skb)
struct sock *sk = NULL; struct sock *sk = NULL;
struct dst_entry *dst; struct dst_entry *dst;
int dif = skb->dev->ifindex; int dif = skb->dev->ifindex;
int sdif = inet_sdif(skb);
int ours; int ours;
/* validate the packet */ /* validate the packet */
...@@ -2241,10 +2250,11 @@ void udp_v4_early_demux(struct sk_buff *skb) ...@@ -2241,10 +2250,11 @@ void udp_v4_early_demux(struct sk_buff *skb)
} }
sk = __udp4_lib_mcast_demux_lookup(net, uh->dest, iph->daddr, sk = __udp4_lib_mcast_demux_lookup(net, uh->dest, iph->daddr,
uh->source, iph->saddr, dif); uh->source, iph->saddr,
dif, sdif);
} else if (skb->pkt_type == PACKET_HOST) { } else if (skb->pkt_type == PACKET_HOST) {
sk = __udp4_lib_demux_lookup(net, uh->dest, iph->daddr, sk = __udp4_lib_demux_lookup(net, uh->dest, iph->daddr,
uh->source, iph->saddr, dif); uh->source, iph->saddr, dif, sdif);
} }
if (!sk || !refcount_inc_not_zero(&sk->sk_refcnt)) if (!sk || !refcount_inc_not_zero(&sk->sk_refcnt))
......
...@@ -45,7 +45,7 @@ static int udp_dump_one(struct udp_table *tbl, struct sk_buff *in_skb, ...@@ -45,7 +45,7 @@ static int udp_dump_one(struct udp_table *tbl, struct sk_buff *in_skb,
sk = __udp4_lib_lookup(net, sk = __udp4_lib_lookup(net,
req->id.idiag_src[0], req->id.idiag_sport, req->id.idiag_src[0], req->id.idiag_sport,
req->id.idiag_dst[0], req->id.idiag_dport, req->id.idiag_dst[0], req->id.idiag_dport,
req->id.idiag_if, tbl, NULL); req->id.idiag_if, 0, tbl, NULL);
#if IS_ENABLED(CONFIG_IPV6) #if IS_ENABLED(CONFIG_IPV6)
else if (req->sdiag_family == AF_INET6) else if (req->sdiag_family == AF_INET6)
sk = __udp6_lib_lookup(net, sk = __udp6_lib_lookup(net,
...@@ -53,7 +53,7 @@ static int udp_dump_one(struct udp_table *tbl, struct sk_buff *in_skb, ...@@ -53,7 +53,7 @@ static int udp_dump_one(struct udp_table *tbl, struct sk_buff *in_skb,
req->id.idiag_sport, req->id.idiag_sport,
(struct in6_addr *)req->id.idiag_dst, (struct in6_addr *)req->id.idiag_dst,
req->id.idiag_dport, req->id.idiag_dport,
req->id.idiag_if, tbl, NULL); req->id.idiag_if, 0, tbl, NULL);
#endif #endif
if (sk && !refcount_inc_not_zero(&sk->sk_refcnt)) if (sk && !refcount_inc_not_zero(&sk->sk_refcnt))
sk = NULL; sk = NULL;
...@@ -182,7 +182,7 @@ static int __udp_diag_destroy(struct sk_buff *in_skb, ...@@ -182,7 +182,7 @@ static int __udp_diag_destroy(struct sk_buff *in_skb,
sk = __udp4_lib_lookup(net, sk = __udp4_lib_lookup(net,
req->id.idiag_dst[0], req->id.idiag_dport, req->id.idiag_dst[0], req->id.idiag_dport,
req->id.idiag_src[0], req->id.idiag_sport, req->id.idiag_src[0], req->id.idiag_sport,
req->id.idiag_if, tbl, NULL); req->id.idiag_if, 0, tbl, NULL);
#if IS_ENABLED(CONFIG_IPV6) #if IS_ENABLED(CONFIG_IPV6)
else if (req->sdiag_family == AF_INET6) { else if (req->sdiag_family == AF_INET6) {
if (ipv6_addr_v4mapped((struct in6_addr *)req->id.idiag_dst) && if (ipv6_addr_v4mapped((struct in6_addr *)req->id.idiag_dst) &&
...@@ -190,7 +190,7 @@ static int __udp_diag_destroy(struct sk_buff *in_skb, ...@@ -190,7 +190,7 @@ static int __udp_diag_destroy(struct sk_buff *in_skb,
sk = __udp4_lib_lookup(net, sk = __udp4_lib_lookup(net,
req->id.idiag_dst[3], req->id.idiag_dport, req->id.idiag_dst[3], req->id.idiag_dport,
req->id.idiag_src[3], req->id.idiag_sport, req->id.idiag_src[3], req->id.idiag_sport,
req->id.idiag_if, tbl, NULL); req->id.idiag_if, 0, tbl, NULL);
else else
sk = __udp6_lib_lookup(net, sk = __udp6_lib_lookup(net,
...@@ -198,7 +198,7 @@ static int __udp_diag_destroy(struct sk_buff *in_skb, ...@@ -198,7 +198,7 @@ static int __udp_diag_destroy(struct sk_buff *in_skb,
req->id.idiag_dport, req->id.idiag_dport,
(struct in6_addr *)req->id.idiag_src, (struct in6_addr *)req->id.idiag_src,
req->id.idiag_sport, req->id.idiag_sport,
req->id.idiag_if, tbl, NULL); req->id.idiag_if, 0, tbl, NULL);
} }
#endif #endif
else { else {
......
...@@ -56,7 +56,7 @@ struct sock *__inet6_lookup_established(struct net *net, ...@@ -56,7 +56,7 @@ struct sock *__inet6_lookup_established(struct net *net,
const __be16 sport, const __be16 sport,
const struct in6_addr *daddr, const struct in6_addr *daddr,
const u16 hnum, const u16 hnum,
const int dif) const int dif, const int sdif)
{ {
struct sock *sk; struct sock *sk;
const struct hlist_nulls_node *node; const struct hlist_nulls_node *node;
...@@ -73,12 +73,12 @@ struct sock *__inet6_lookup_established(struct net *net, ...@@ -73,12 +73,12 @@ struct sock *__inet6_lookup_established(struct net *net,
sk_nulls_for_each_rcu(sk, node, &head->chain) { sk_nulls_for_each_rcu(sk, node, &head->chain) {
if (sk->sk_hash != hash) if (sk->sk_hash != hash)
continue; continue;
if (!INET6_MATCH(sk, net, saddr, daddr, ports, dif)) if (!INET6_MATCH(sk, net, saddr, daddr, ports, dif, sdif))
continue; continue;
if (unlikely(!refcount_inc_not_zero(&sk->sk_refcnt))) if (unlikely(!refcount_inc_not_zero(&sk->sk_refcnt)))
goto out; goto out;
if (unlikely(!INET6_MATCH(sk, net, saddr, daddr, ports, dif))) { if (unlikely(!INET6_MATCH(sk, net, saddr, daddr, ports, dif, sdif))) {
sock_gen_put(sk); sock_gen_put(sk);
goto begin; goto begin;
} }
...@@ -96,7 +96,7 @@ EXPORT_SYMBOL(__inet6_lookup_established); ...@@ -96,7 +96,7 @@ EXPORT_SYMBOL(__inet6_lookup_established);
static inline int compute_score(struct sock *sk, struct net *net, static inline int compute_score(struct sock *sk, struct net *net,
const unsigned short hnum, const unsigned short hnum,
const struct in6_addr *daddr, const struct in6_addr *daddr,
const int dif, bool exact_dif) const int dif, const int sdif, bool exact_dif)
{ {
int score = -1; int score = -1;
...@@ -110,9 +110,13 @@ static inline int compute_score(struct sock *sk, struct net *net, ...@@ -110,9 +110,13 @@ static inline int compute_score(struct sock *sk, struct net *net,
score++; score++;
} }
if (sk->sk_bound_dev_if || exact_dif) { if (sk->sk_bound_dev_if || exact_dif) {
if (sk->sk_bound_dev_if != dif) bool dev_match = (sk->sk_bound_dev_if == dif ||
sk->sk_bound_dev_if == sdif);
if (exact_dif && !dev_match)
return -1; return -1;
score++; if (sk->sk_bound_dev_if && dev_match)
score++;
} }
if (sk->sk_incoming_cpu == raw_smp_processor_id()) if (sk->sk_incoming_cpu == raw_smp_processor_id())
score++; score++;
...@@ -126,7 +130,7 @@ struct sock *inet6_lookup_listener(struct net *net, ...@@ -126,7 +130,7 @@ struct sock *inet6_lookup_listener(struct net *net,
struct sk_buff *skb, int doff, struct sk_buff *skb, int doff,
const struct in6_addr *saddr, const struct in6_addr *saddr,
const __be16 sport, const struct in6_addr *daddr, const __be16 sport, const struct in6_addr *daddr,
const unsigned short hnum, const int dif) const unsigned short hnum, const int dif, const int sdif)
{ {
unsigned int hash = inet_lhashfn(net, hnum); unsigned int hash = inet_lhashfn(net, hnum);
struct inet_listen_hashbucket *ilb = &hashinfo->listening_hash[hash]; struct inet_listen_hashbucket *ilb = &hashinfo->listening_hash[hash];
...@@ -136,7 +140,7 @@ struct sock *inet6_lookup_listener(struct net *net, ...@@ -136,7 +140,7 @@ struct sock *inet6_lookup_listener(struct net *net,
u32 phash = 0; u32 phash = 0;
sk_for_each(sk, &ilb->head) { sk_for_each(sk, &ilb->head) {
score = compute_score(sk, net, hnum, daddr, dif, exact_dif); score = compute_score(sk, net, hnum, daddr, dif, sdif, exact_dif);
if (score > hiscore) { if (score > hiscore) {
reuseport = sk->sk_reuseport; reuseport = sk->sk_reuseport;
if (reuseport) { if (reuseport) {
...@@ -171,7 +175,7 @@ struct sock *inet6_lookup(struct net *net, struct inet_hashinfo *hashinfo, ...@@ -171,7 +175,7 @@ struct sock *inet6_lookup(struct net *net, struct inet_hashinfo *hashinfo,
bool refcounted; bool refcounted;
sk = __inet6_lookup(net, hashinfo, skb, doff, saddr, sport, daddr, sk = __inet6_lookup(net, hashinfo, skb, doff, saddr, sport, daddr,
ntohs(dport), dif, &refcounted); ntohs(dport), dif, 0, &refcounted);
if (sk && !refcounted && !refcount_inc_not_zero(&sk->sk_refcnt)) if (sk && !refcounted && !refcount_inc_not_zero(&sk->sk_refcnt))
sk = NULL; sk = NULL;
return sk; return sk;
...@@ -187,8 +191,9 @@ static int __inet6_check_established(struct inet_timewait_death_row *death_row, ...@@ -187,8 +191,9 @@ static int __inet6_check_established(struct inet_timewait_death_row *death_row,
const struct in6_addr *daddr = &sk->sk_v6_rcv_saddr; const struct in6_addr *daddr = &sk->sk_v6_rcv_saddr;
const struct in6_addr *saddr = &sk->sk_v6_daddr; const struct in6_addr *saddr = &sk->sk_v6_daddr;
const int dif = sk->sk_bound_dev_if; const int dif = sk->sk_bound_dev_if;
const __portpair ports = INET_COMBINED_PORTS(inet->inet_dport, lport);
struct net *net = sock_net(sk); struct net *net = sock_net(sk);
const int sdif = l3mdev_master_ifindex_by_index(net, dif);
const __portpair ports = INET_COMBINED_PORTS(inet->inet_dport, lport);
const unsigned int hash = inet6_ehashfn(net, daddr, lport, saddr, const unsigned int hash = inet6_ehashfn(net, daddr, lport, saddr,
inet->inet_dport); inet->inet_dport);
struct inet_ehash_bucket *head = inet_ehash_bucket(hinfo, hash); struct inet_ehash_bucket *head = inet_ehash_bucket(hinfo, hash);
...@@ -203,7 +208,8 @@ static int __inet6_check_established(struct inet_timewait_death_row *death_row, ...@@ -203,7 +208,8 @@ static int __inet6_check_established(struct inet_timewait_death_row *death_row,
if (sk2->sk_hash != hash) if (sk2->sk_hash != hash)
continue; continue;
if (likely(INET6_MATCH(sk2, net, saddr, daddr, ports, dif))) { if (likely(INET6_MATCH(sk2, net, saddr, daddr, ports,
dif, sdif))) {
if (sk2->sk_state == TCP_TIME_WAIT) { if (sk2->sk_state == TCP_TIME_WAIT) {
tw = inet_twsk(sk2); tw = inet_twsk(sk2);
if (twsk_unique(sk, sk2, twp)) if (twsk_unique(sk, sk2, twp))
......
...@@ -72,7 +72,7 @@ EXPORT_SYMBOL_GPL(raw_v6_hashinfo); ...@@ -72,7 +72,7 @@ EXPORT_SYMBOL_GPL(raw_v6_hashinfo);
struct sock *__raw_v6_lookup(struct net *net, struct sock *sk, struct sock *__raw_v6_lookup(struct net *net, struct sock *sk,
unsigned short num, const struct in6_addr *loc_addr, unsigned short num, const struct in6_addr *loc_addr,
const struct in6_addr *rmt_addr, int dif) const struct in6_addr *rmt_addr, int dif, int sdif)
{ {
bool is_multicast = ipv6_addr_is_multicast(loc_addr); bool is_multicast = ipv6_addr_is_multicast(loc_addr);
...@@ -86,7 +86,9 @@ struct sock *__raw_v6_lookup(struct net *net, struct sock *sk, ...@@ -86,7 +86,9 @@ struct sock *__raw_v6_lookup(struct net *net, struct sock *sk,
!ipv6_addr_equal(&sk->sk_v6_daddr, rmt_addr)) !ipv6_addr_equal(&sk->sk_v6_daddr, rmt_addr))
continue; continue;
if (sk->sk_bound_dev_if && sk->sk_bound_dev_if != dif) if (sk->sk_bound_dev_if &&
sk->sk_bound_dev_if != dif &&
sk->sk_bound_dev_if != sdif)
continue; continue;
if (!ipv6_addr_any(&sk->sk_v6_rcv_saddr)) { if (!ipv6_addr_any(&sk->sk_v6_rcv_saddr)) {
...@@ -178,7 +180,8 @@ static bool ipv6_raw_deliver(struct sk_buff *skb, int nexthdr) ...@@ -178,7 +180,8 @@ static bool ipv6_raw_deliver(struct sk_buff *skb, int nexthdr)
goto out; goto out;
net = dev_net(skb->dev); net = dev_net(skb->dev);
sk = __raw_v6_lookup(net, sk, nexthdr, daddr, saddr, inet6_iif(skb)); sk = __raw_v6_lookup(net, sk, nexthdr, daddr, saddr,
inet6_iif(skb), inet6_sdif(skb));
while (sk) { while (sk) {
int filtered; int filtered;
...@@ -222,7 +225,7 @@ static bool ipv6_raw_deliver(struct sk_buff *skb, int nexthdr) ...@@ -222,7 +225,7 @@ static bool ipv6_raw_deliver(struct sk_buff *skb, int nexthdr)
} }
} }
sk = __raw_v6_lookup(net, sk_next(sk), nexthdr, daddr, saddr, sk = __raw_v6_lookup(net, sk_next(sk), nexthdr, daddr, saddr,
inet6_iif(skb)); inet6_iif(skb), inet6_sdif(skb));
} }
out: out:
read_unlock(&raw_v6_hashinfo.lock); read_unlock(&raw_v6_hashinfo.lock);
...@@ -378,7 +381,7 @@ void raw6_icmp_error(struct sk_buff *skb, int nexthdr, ...@@ -378,7 +381,7 @@ void raw6_icmp_error(struct sk_buff *skb, int nexthdr,
net = dev_net(skb->dev); net = dev_net(skb->dev);
while ((sk = __raw_v6_lookup(net, sk, nexthdr, saddr, daddr, while ((sk = __raw_v6_lookup(net, sk, nexthdr, saddr, daddr,
inet6_iif(skb)))) { inet6_iif(skb), inet6_iif(skb)))) {
rawv6_err(sk, skb, NULL, type, code, rawv6_err(sk, skb, NULL, type, code,
inner_offset, info); inner_offset, info);
sk = sk_next(sk); sk = sk_next(sk);
......
...@@ -350,7 +350,7 @@ static void tcp_v6_err(struct sk_buff *skb, struct inet6_skb_parm *opt, ...@@ -350,7 +350,7 @@ static void tcp_v6_err(struct sk_buff *skb, struct inet6_skb_parm *opt,
sk = __inet6_lookup_established(net, &tcp_hashinfo, sk = __inet6_lookup_established(net, &tcp_hashinfo,
&hdr->daddr, th->dest, &hdr->daddr, th->dest,
&hdr->saddr, ntohs(th->source), &hdr->saddr, ntohs(th->source),
skb->dev->ifindex); skb->dev->ifindex, inet6_sdif(skb));
if (!sk) { if (!sk) {
__ICMP6_INC_STATS(net, __in6_dev_get(skb->dev), __ICMP6_INC_STATS(net, __in6_dev_get(skb->dev),
...@@ -918,7 +918,8 @@ static void tcp_v6_send_reset(const struct sock *sk, struct sk_buff *skb) ...@@ -918,7 +918,8 @@ static void tcp_v6_send_reset(const struct sock *sk, struct sk_buff *skb)
&tcp_hashinfo, NULL, 0, &tcp_hashinfo, NULL, 0,
&ipv6h->saddr, &ipv6h->saddr,
th->source, &ipv6h->daddr, th->source, &ipv6h->daddr,
ntohs(th->source), tcp_v6_iif(skb)); ntohs(th->source), tcp_v6_iif(skb),
tcp_v6_sdif(skb));
if (!sk1) if (!sk1)
goto out; goto out;
...@@ -1397,6 +1398,7 @@ static void tcp_v6_fill_cb(struct sk_buff *skb, const struct ipv6hdr *hdr, ...@@ -1397,6 +1398,7 @@ static void tcp_v6_fill_cb(struct sk_buff *skb, const struct ipv6hdr *hdr,
static int tcp_v6_rcv(struct sk_buff *skb) static int tcp_v6_rcv(struct sk_buff *skb)
{ {
int sdif = inet6_sdif(skb);
const struct tcphdr *th; const struct tcphdr *th;
const struct ipv6hdr *hdr; const struct ipv6hdr *hdr;
bool refcounted; bool refcounted;
...@@ -1430,7 +1432,7 @@ static int tcp_v6_rcv(struct sk_buff *skb) ...@@ -1430,7 +1432,7 @@ static int tcp_v6_rcv(struct sk_buff *skb)
lookup: lookup:
sk = __inet6_lookup_skb(&tcp_hashinfo, skb, __tcp_hdrlen(th), sk = __inet6_lookup_skb(&tcp_hashinfo, skb, __tcp_hdrlen(th),
th->source, th->dest, inet6_iif(skb), th->source, th->dest, inet6_iif(skb), sdif,
&refcounted); &refcounted);
if (!sk) if (!sk)
goto no_tcp_socket; goto no_tcp_socket;
...@@ -1563,7 +1565,8 @@ static int tcp_v6_rcv(struct sk_buff *skb) ...@@ -1563,7 +1565,8 @@ static int tcp_v6_rcv(struct sk_buff *skb)
skb, __tcp_hdrlen(th), skb, __tcp_hdrlen(th),
&ipv6_hdr(skb)->saddr, th->source, &ipv6_hdr(skb)->saddr, th->source,
&ipv6_hdr(skb)->daddr, &ipv6_hdr(skb)->daddr,
ntohs(th->dest), tcp_v6_iif(skb)); ntohs(th->dest), tcp_v6_iif(skb),
sdif);
if (sk2) { if (sk2) {
struct inet_timewait_sock *tw = inet_twsk(sk); struct inet_timewait_sock *tw = inet_twsk(sk);
inet_twsk_deschedule_put(tw); inet_twsk_deschedule_put(tw);
...@@ -1610,7 +1613,7 @@ static void tcp_v6_early_demux(struct sk_buff *skb) ...@@ -1610,7 +1613,7 @@ static void tcp_v6_early_demux(struct sk_buff *skb)
sk = __inet6_lookup_established(dev_net(skb->dev), &tcp_hashinfo, sk = __inet6_lookup_established(dev_net(skb->dev), &tcp_hashinfo,
&hdr->saddr, th->source, &hdr->saddr, th->source,
&hdr->daddr, ntohs(th->dest), &hdr->daddr, ntohs(th->dest),
inet6_iif(skb)); inet6_iif(skb), inet6_sdif(skb));
if (sk) { if (sk) {
skb->sk = sk; skb->sk = sk;
skb->destructor = sock_edemux; skb->destructor = sock_edemux;
......
...@@ -129,7 +129,7 @@ static void udp_v6_rehash(struct sock *sk) ...@@ -129,7 +129,7 @@ static void udp_v6_rehash(struct sock *sk)
static int compute_score(struct sock *sk, struct net *net, static int compute_score(struct sock *sk, struct net *net,
const struct in6_addr *saddr, __be16 sport, const struct in6_addr *saddr, __be16 sport,
const struct in6_addr *daddr, unsigned short hnum, const struct in6_addr *daddr, unsigned short hnum,
int dif, bool exact_dif) int dif, int sdif, bool exact_dif)
{ {
int score; int score;
struct inet_sock *inet; struct inet_sock *inet;
...@@ -161,9 +161,13 @@ static int compute_score(struct sock *sk, struct net *net, ...@@ -161,9 +161,13 @@ static int compute_score(struct sock *sk, struct net *net,
} }
if (sk->sk_bound_dev_if || exact_dif) { if (sk->sk_bound_dev_if || exact_dif) {
if (sk->sk_bound_dev_if != dif) bool dev_match = (sk->sk_bound_dev_if == dif ||
sk->sk_bound_dev_if == sdif);
if (exact_dif && !dev_match)
return -1; return -1;
score++; if (sk->sk_bound_dev_if && dev_match)
score++;
} }
if (sk->sk_incoming_cpu == raw_smp_processor_id()) if (sk->sk_incoming_cpu == raw_smp_processor_id())
...@@ -175,9 +179,9 @@ static int compute_score(struct sock *sk, struct net *net, ...@@ -175,9 +179,9 @@ static int compute_score(struct sock *sk, struct net *net,
/* called with rcu_read_lock() */ /* called with rcu_read_lock() */
static struct sock *udp6_lib_lookup2(struct net *net, static struct sock *udp6_lib_lookup2(struct net *net,
const struct in6_addr *saddr, __be16 sport, const struct in6_addr *saddr, __be16 sport,
const struct in6_addr *daddr, unsigned int hnum, int dif, const struct in6_addr *daddr, unsigned int hnum,
bool exact_dif, struct udp_hslot *hslot2, int dif, int sdif, bool exact_dif,
struct sk_buff *skb) struct udp_hslot *hslot2, struct sk_buff *skb)
{ {
struct sock *sk, *result; struct sock *sk, *result;
int score, badness, matches = 0, reuseport = 0; int score, badness, matches = 0, reuseport = 0;
...@@ -187,7 +191,7 @@ static struct sock *udp6_lib_lookup2(struct net *net, ...@@ -187,7 +191,7 @@ static struct sock *udp6_lib_lookup2(struct net *net,
badness = -1; badness = -1;
udp_portaddr_for_each_entry_rcu(sk, &hslot2->head) { udp_portaddr_for_each_entry_rcu(sk, &hslot2->head) {
score = compute_score(sk, net, saddr, sport, score = compute_score(sk, net, saddr, sport,
daddr, hnum, dif, exact_dif); daddr, hnum, dif, sdif, exact_dif);
if (score > badness) { if (score > badness) {
reuseport = sk->sk_reuseport; reuseport = sk->sk_reuseport;
if (reuseport) { if (reuseport) {
...@@ -214,10 +218,10 @@ static struct sock *udp6_lib_lookup2(struct net *net, ...@@ -214,10 +218,10 @@ static struct sock *udp6_lib_lookup2(struct net *net,
/* rcu_read_lock() must be held */ /* rcu_read_lock() must be held */
struct sock *__udp6_lib_lookup(struct net *net, struct sock *__udp6_lib_lookup(struct net *net,
const struct in6_addr *saddr, __be16 sport, const struct in6_addr *saddr, __be16 sport,
const struct in6_addr *daddr, __be16 dport, const struct in6_addr *daddr, __be16 dport,
int dif, struct udp_table *udptable, int dif, int sdif, struct udp_table *udptable,
struct sk_buff *skb) struct sk_buff *skb)
{ {
struct sock *sk, *result; struct sock *sk, *result;
unsigned short hnum = ntohs(dport); unsigned short hnum = ntohs(dport);
...@@ -235,7 +239,7 @@ struct sock *__udp6_lib_lookup(struct net *net, ...@@ -235,7 +239,7 @@ struct sock *__udp6_lib_lookup(struct net *net,
goto begin; goto begin;
result = udp6_lib_lookup2(net, saddr, sport, result = udp6_lib_lookup2(net, saddr, sport,
daddr, hnum, dif, exact_dif, daddr, hnum, dif, sdif, exact_dif,
hslot2, skb); hslot2, skb);
if (!result) { if (!result) {
unsigned int old_slot2 = slot2; unsigned int old_slot2 = slot2;
...@@ -250,7 +254,7 @@ struct sock *__udp6_lib_lookup(struct net *net, ...@@ -250,7 +254,7 @@ struct sock *__udp6_lib_lookup(struct net *net,
goto begin; goto begin;
result = udp6_lib_lookup2(net, saddr, sport, result = udp6_lib_lookup2(net, saddr, sport,
daddr, hnum, dif, daddr, hnum, dif, sdif,
exact_dif, hslot2, exact_dif, hslot2,
skb); skb);
} }
...@@ -261,7 +265,7 @@ struct sock *__udp6_lib_lookup(struct net *net, ...@@ -261,7 +265,7 @@ struct sock *__udp6_lib_lookup(struct net *net,
badness = -1; badness = -1;
sk_for_each_rcu(sk, &hslot->head) { sk_for_each_rcu(sk, &hslot->head) {
score = compute_score(sk, net, saddr, sport, daddr, hnum, dif, score = compute_score(sk, net, saddr, sport, daddr, hnum, dif,
exact_dif); sdif, exact_dif);
if (score > badness) { if (score > badness) {
reuseport = sk->sk_reuseport; reuseport = sk->sk_reuseport;
if (reuseport) { if (reuseport) {
...@@ -294,7 +298,7 @@ static struct sock *__udp6_lib_lookup_skb(struct sk_buff *skb, ...@@ -294,7 +298,7 @@ static struct sock *__udp6_lib_lookup_skb(struct sk_buff *skb,
return __udp6_lib_lookup(dev_net(skb->dev), &iph->saddr, sport, return __udp6_lib_lookup(dev_net(skb->dev), &iph->saddr, sport,
&iph->daddr, dport, inet6_iif(skb), &iph->daddr, dport, inet6_iif(skb),
udptable, skb); inet6_sdif(skb), udptable, skb);
} }
struct sock *udp6_lib_lookup_skb(struct sk_buff *skb, struct sock *udp6_lib_lookup_skb(struct sk_buff *skb,
...@@ -304,7 +308,7 @@ struct sock *udp6_lib_lookup_skb(struct sk_buff *skb, ...@@ -304,7 +308,7 @@ struct sock *udp6_lib_lookup_skb(struct sk_buff *skb,
return __udp6_lib_lookup(dev_net(skb->dev), &iph->saddr, sport, return __udp6_lib_lookup(dev_net(skb->dev), &iph->saddr, sport,
&iph->daddr, dport, inet6_iif(skb), &iph->daddr, dport, inet6_iif(skb),
&udp_table, skb); inet6_sdif(skb), &udp_table, skb);
} }
EXPORT_SYMBOL_GPL(udp6_lib_lookup_skb); EXPORT_SYMBOL_GPL(udp6_lib_lookup_skb);
...@@ -320,7 +324,7 @@ struct sock *udp6_lib_lookup(struct net *net, const struct in6_addr *saddr, __be ...@@ -320,7 +324,7 @@ struct sock *udp6_lib_lookup(struct net *net, const struct in6_addr *saddr, __be
struct sock *sk; struct sock *sk;
sk = __udp6_lib_lookup(net, saddr, sport, daddr, dport, sk = __udp6_lib_lookup(net, saddr, sport, daddr, dport,
dif, &udp_table, NULL); dif, 0, &udp_table, NULL);
if (sk && !refcount_inc_not_zero(&sk->sk_refcnt)) if (sk && !refcount_inc_not_zero(&sk->sk_refcnt))
sk = NULL; sk = NULL;
return sk; return sk;
...@@ -501,7 +505,7 @@ void __udp6_lib_err(struct sk_buff *skb, struct inet6_skb_parm *opt, ...@@ -501,7 +505,7 @@ void __udp6_lib_err(struct sk_buff *skb, struct inet6_skb_parm *opt,
struct net *net = dev_net(skb->dev); struct net *net = dev_net(skb->dev);
sk = __udp6_lib_lookup(net, daddr, uh->dest, saddr, uh->source, sk = __udp6_lib_lookup(net, daddr, uh->dest, saddr, uh->source,
inet6_iif(skb), udptable, skb); inet6_iif(skb), 0, udptable, skb);
if (!sk) { if (!sk) {
__ICMP6_INC_STATS(net, __in6_dev_get(skb->dev), __ICMP6_INC_STATS(net, __in6_dev_get(skb->dev),
ICMP6_MIB_INERRORS); ICMP6_MIB_INERRORS);
...@@ -893,7 +897,7 @@ int __udp6_lib_rcv(struct sk_buff *skb, struct udp_table *udptable, ...@@ -893,7 +897,7 @@ int __udp6_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
static struct sock *__udp6_lib_demux_lookup(struct net *net, static struct sock *__udp6_lib_demux_lookup(struct net *net,
__be16 loc_port, const struct in6_addr *loc_addr, __be16 loc_port, const struct in6_addr *loc_addr,
__be16 rmt_port, const struct in6_addr *rmt_addr, __be16 rmt_port, const struct in6_addr *rmt_addr,
int dif) int dif, int sdif)
{ {
unsigned short hnum = ntohs(loc_port); unsigned short hnum = ntohs(loc_port);
unsigned int hash2 = udp6_portaddr_hash(net, loc_addr, hnum); unsigned int hash2 = udp6_portaddr_hash(net, loc_addr, hnum);
...@@ -904,7 +908,7 @@ static struct sock *__udp6_lib_demux_lookup(struct net *net, ...@@ -904,7 +908,7 @@ static struct sock *__udp6_lib_demux_lookup(struct net *net,
udp_portaddr_for_each_entry_rcu(sk, &hslot2->head) { udp_portaddr_for_each_entry_rcu(sk, &hslot2->head) {
if (sk->sk_state == TCP_ESTABLISHED && if (sk->sk_state == TCP_ESTABLISHED &&
INET6_MATCH(sk, net, rmt_addr, loc_addr, ports, dif)) INET6_MATCH(sk, net, rmt_addr, loc_addr, ports, dif, sdif))
return sk; return sk;
/* Only check first socket in chain */ /* Only check first socket in chain */
break; break;
...@@ -919,6 +923,7 @@ static void udp_v6_early_demux(struct sk_buff *skb) ...@@ -919,6 +923,7 @@ static void udp_v6_early_demux(struct sk_buff *skb)
struct sock *sk; struct sock *sk;
struct dst_entry *dst; struct dst_entry *dst;
int dif = skb->dev->ifindex; int dif = skb->dev->ifindex;
int sdif = inet6_sdif(skb);
if (!pskb_may_pull(skb, skb_transport_offset(skb) + if (!pskb_may_pull(skb, skb_transport_offset(skb) +
sizeof(struct udphdr))) sizeof(struct udphdr)))
...@@ -930,7 +935,7 @@ static void udp_v6_early_demux(struct sk_buff *skb) ...@@ -930,7 +935,7 @@ static void udp_v6_early_demux(struct sk_buff *skb)
sk = __udp6_lib_demux_lookup(net, uh->dest, sk = __udp6_lib_demux_lookup(net, uh->dest,
&ipv6_hdr(skb)->daddr, &ipv6_hdr(skb)->daddr,
uh->source, &ipv6_hdr(skb)->saddr, uh->source, &ipv6_hdr(skb)->saddr,
dif); dif, sdif);
else else
return; return;
......
...@@ -125,7 +125,7 @@ nf_tproxy_get_sock_v4(struct net *net, struct sk_buff *skb, void *hp, ...@@ -125,7 +125,7 @@ nf_tproxy_get_sock_v4(struct net *net, struct sk_buff *skb, void *hp,
__tcp_hdrlen(tcph), __tcp_hdrlen(tcph),
saddr, sport, saddr, sport,
daddr, dport, daddr, dport,
in->ifindex); in->ifindex, 0);
if (sk && !refcount_inc_not_zero(&sk->sk_refcnt)) if (sk && !refcount_inc_not_zero(&sk->sk_refcnt))
sk = NULL; sk = NULL;
...@@ -195,7 +195,7 @@ nf_tproxy_get_sock_v6(struct net *net, struct sk_buff *skb, int thoff, void *hp, ...@@ -195,7 +195,7 @@ nf_tproxy_get_sock_v6(struct net *net, struct sk_buff *skb, int thoff, void *hp,
thoff + __tcp_hdrlen(tcph), thoff + __tcp_hdrlen(tcph),
saddr, sport, saddr, sport,
daddr, ntohs(dport), daddr, ntohs(dport),
in->ifindex); in->ifindex, 0);
if (sk && !refcount_inc_not_zero(&sk->sk_refcnt)) if (sk && !refcount_inc_not_zero(&sk->sk_refcnt))
sk = NULL; sk = NULL;
...@@ -208,7 +208,7 @@ nf_tproxy_get_sock_v6(struct net *net, struct sk_buff *skb, int thoff, void *hp, ...@@ -208,7 +208,7 @@ nf_tproxy_get_sock_v6(struct net *net, struct sk_buff *skb, int thoff, void *hp,
case NFT_LOOKUP_ESTABLISHED: case NFT_LOOKUP_ESTABLISHED:
sk = __inet6_lookup_established(net, &tcp_hashinfo, sk = __inet6_lookup_established(net, &tcp_hashinfo,
saddr, sport, daddr, ntohs(dport), saddr, sport, daddr, ntohs(dport),
in->ifindex); in->ifindex, 0);
break; break;
default: default:
BUG(); BUG();
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment