Commit fb74c277 authored by David Ahern's avatar David Ahern Committed by David S. Miller

net: ipv4: add second dif to udp socket lookups

Add a second device index, sdif, to udp socket lookups. sdif is the
index for ingress devices enslaved to an l3mdev. It allows the lookups
to consider the enslaved device as well as the L3 domain when searching
for a socket.

Early demux lookups are handled in the next patch as part of INET_MATCH
changes.
Signed-off-by: default avatarDavid Ahern <dsahern@gmail.com>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent 46d4b68f
...@@ -78,6 +78,16 @@ struct ipcm_cookie { ...@@ -78,6 +78,16 @@ struct ipcm_cookie {
#define IPCB(skb) ((struct inet_skb_parm*)((skb)->cb)) #define IPCB(skb) ((struct inet_skb_parm*)((skb)->cb))
#define PKTINFO_SKB_CB(skb) ((struct in_pktinfo *)((skb)->cb)) #define PKTINFO_SKB_CB(skb) ((struct in_pktinfo *)((skb)->cb))
/* return enslaved device index if relevant */
static inline int inet_sdif(struct sk_buff *skb)
{
#if IS_ENABLED(CONFIG_NET_L3_MASTER_DEV)
if (skb && ipv4_l3mdev_skb(IPCB(skb)->flags))
return IPCB(skb)->iif;
#endif
return 0;
}
struct ip_ra_chain { struct ip_ra_chain {
struct ip_ra_chain __rcu *next; struct ip_ra_chain __rcu *next;
struct sock *sk; struct sock *sk;
......
...@@ -287,7 +287,7 @@ int udp_lib_setsockopt(struct sock *sk, int level, int optname, ...@@ -287,7 +287,7 @@ int udp_lib_setsockopt(struct sock *sk, int level, int optname,
struct sock *udp4_lib_lookup(struct net *net, __be32 saddr, __be16 sport, struct sock *udp4_lib_lookup(struct net *net, __be32 saddr, __be16 sport,
__be32 daddr, __be16 dport, int dif); __be32 daddr, __be16 dport, int dif);
struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr, __be16 sport, struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr, __be16 sport,
__be32 daddr, __be16 dport, int dif, __be32 daddr, __be16 dport, int dif, int sdif,
struct udp_table *tbl, struct sk_buff *skb); struct udp_table *tbl, struct sk_buff *skb);
struct sock *udp4_lib_lookup_skb(struct sk_buff *skb, struct sock *udp4_lib_lookup_skb(struct sk_buff *skb,
__be16 sport, __be16 dport); __be16 sport, __be16 dport);
......
...@@ -380,8 +380,8 @@ int udp_v4_get_port(struct sock *sk, unsigned short snum) ...@@ -380,8 +380,8 @@ int udp_v4_get_port(struct sock *sk, unsigned short snum)
static int compute_score(struct sock *sk, struct net *net, static int compute_score(struct sock *sk, struct net *net,
__be32 saddr, __be16 sport, __be32 saddr, __be16 sport,
__be32 daddr, unsigned short hnum, int dif, __be32 daddr, unsigned short hnum,
bool exact_dif) int dif, int sdif, bool exact_dif)
{ {
int score; int score;
struct inet_sock *inet; struct inet_sock *inet;
...@@ -413,10 +413,15 @@ static int compute_score(struct sock *sk, struct net *net, ...@@ -413,10 +413,15 @@ static int compute_score(struct sock *sk, struct net *net,
} }
if (sk->sk_bound_dev_if || exact_dif) { if (sk->sk_bound_dev_if || exact_dif) {
if (sk->sk_bound_dev_if != dif) bool dev_match = (sk->sk_bound_dev_if == dif ||
sk->sk_bound_dev_if == sdif);
if (exact_dif && !dev_match)
return -1; return -1;
if (sk->sk_bound_dev_if && dev_match)
score += 4; score += 4;
} }
if (sk->sk_incoming_cpu == raw_smp_processor_id()) if (sk->sk_incoming_cpu == raw_smp_processor_id())
score++; score++;
return score; return score;
...@@ -437,7 +442,8 @@ static u32 udp_ehashfn(const struct net *net, const __be32 laddr, ...@@ -437,7 +442,8 @@ static u32 udp_ehashfn(const struct net *net, const __be32 laddr,
/* called with rcu_read_lock() */ /* called with rcu_read_lock() */
static struct sock *udp4_lib_lookup2(struct net *net, static struct sock *udp4_lib_lookup2(struct net *net,
__be32 saddr, __be16 sport, __be32 saddr, __be16 sport,
__be32 daddr, unsigned int hnum, int dif, bool exact_dif, __be32 daddr, unsigned int hnum,
int dif, int sdif, bool exact_dif,
struct udp_hslot *hslot2, struct udp_hslot *hslot2,
struct sk_buff *skb) struct sk_buff *skb)
{ {
...@@ -449,7 +455,7 @@ static struct sock *udp4_lib_lookup2(struct net *net, ...@@ -449,7 +455,7 @@ static struct sock *udp4_lib_lookup2(struct net *net,
badness = 0; badness = 0;
udp_portaddr_for_each_entry_rcu(sk, &hslot2->head) { udp_portaddr_for_each_entry_rcu(sk, &hslot2->head) {
score = compute_score(sk, net, saddr, sport, score = compute_score(sk, net, saddr, sport,
daddr, hnum, dif, exact_dif); daddr, hnum, dif, sdif, exact_dif);
if (score > badness) { if (score > badness) {
reuseport = sk->sk_reuseport; reuseport = sk->sk_reuseport;
if (reuseport) { if (reuseport) {
...@@ -477,8 +483,8 @@ static struct sock *udp4_lib_lookup2(struct net *net, ...@@ -477,8 +483,8 @@ static struct sock *udp4_lib_lookup2(struct net *net,
* harder than this. -DaveM * harder than this. -DaveM
*/ */
struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr, struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr,
__be16 sport, __be32 daddr, __be16 dport, __be16 sport, __be32 daddr, __be16 dport, int dif,
int dif, struct udp_table *udptable, struct sk_buff *skb) int sdif, struct udp_table *udptable, struct sk_buff *skb)
{ {
struct sock *sk, *result; struct sock *sk, *result;
unsigned short hnum = ntohs(dport); unsigned short hnum = ntohs(dport);
...@@ -496,7 +502,7 @@ struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr, ...@@ -496,7 +502,7 @@ struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr,
goto begin; goto begin;
result = udp4_lib_lookup2(net, saddr, sport, result = udp4_lib_lookup2(net, saddr, sport,
daddr, hnum, dif, daddr, hnum, dif, sdif,
exact_dif, hslot2, skb); exact_dif, hslot2, skb);
if (!result) { if (!result) {
unsigned int old_slot2 = slot2; unsigned int old_slot2 = slot2;
...@@ -511,7 +517,7 @@ struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr, ...@@ -511,7 +517,7 @@ struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr,
goto begin; goto begin;
result = udp4_lib_lookup2(net, saddr, sport, result = udp4_lib_lookup2(net, saddr, sport,
daddr, hnum, dif, daddr, hnum, dif, sdif,
exact_dif, hslot2, skb); exact_dif, hslot2, skb);
} }
return result; return result;
...@@ -521,7 +527,7 @@ struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr, ...@@ -521,7 +527,7 @@ struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr,
badness = 0; badness = 0;
sk_for_each_rcu(sk, &hslot->head) { sk_for_each_rcu(sk, &hslot->head) {
score = compute_score(sk, net, saddr, sport, score = compute_score(sk, net, saddr, sport,
daddr, hnum, dif, exact_dif); daddr, hnum, dif, sdif, exact_dif);
if (score > badness) { if (score > badness) {
reuseport = sk->sk_reuseport; reuseport = sk->sk_reuseport;
if (reuseport) { if (reuseport) {
...@@ -554,7 +560,7 @@ static inline struct sock *__udp4_lib_lookup_skb(struct sk_buff *skb, ...@@ -554,7 +560,7 @@ static inline struct sock *__udp4_lib_lookup_skb(struct sk_buff *skb,
return __udp4_lib_lookup(dev_net(skb->dev), iph->saddr, sport, return __udp4_lib_lookup(dev_net(skb->dev), iph->saddr, sport,
iph->daddr, dport, inet_iif(skb), iph->daddr, dport, inet_iif(skb),
udptable, skb); inet_sdif(skb), udptable, skb);
} }
struct sock *udp4_lib_lookup_skb(struct sk_buff *skb, struct sock *udp4_lib_lookup_skb(struct sk_buff *skb,
...@@ -576,7 +582,7 @@ struct sock *udp4_lib_lookup(struct net *net, __be32 saddr, __be16 sport, ...@@ -576,7 +582,7 @@ struct sock *udp4_lib_lookup(struct net *net, __be32 saddr, __be16 sport,
struct sock *sk; struct sock *sk;
sk = __udp4_lib_lookup(net, saddr, sport, daddr, dport, sk = __udp4_lib_lookup(net, saddr, sport, daddr, dport,
dif, &udp_table, NULL); dif, 0, &udp_table, NULL);
if (sk && !refcount_inc_not_zero(&sk->sk_refcnt)) if (sk && !refcount_inc_not_zero(&sk->sk_refcnt))
sk = NULL; sk = NULL;
return sk; return sk;
...@@ -587,7 +593,7 @@ EXPORT_SYMBOL_GPL(udp4_lib_lookup); ...@@ -587,7 +593,7 @@ EXPORT_SYMBOL_GPL(udp4_lib_lookup);
static inline bool __udp_is_mcast_sock(struct net *net, struct sock *sk, static inline bool __udp_is_mcast_sock(struct net *net, struct sock *sk,
__be16 loc_port, __be32 loc_addr, __be16 loc_port, __be32 loc_addr,
__be16 rmt_port, __be32 rmt_addr, __be16 rmt_port, __be32 rmt_addr,
int dif, unsigned short hnum) int dif, int sdif, unsigned short hnum)
{ {
struct inet_sock *inet = inet_sk(sk); struct inet_sock *inet = inet_sk(sk);
...@@ -597,7 +603,8 @@ static inline bool __udp_is_mcast_sock(struct net *net, struct sock *sk, ...@@ -597,7 +603,8 @@ static inline bool __udp_is_mcast_sock(struct net *net, struct sock *sk,
(inet->inet_dport != rmt_port && inet->inet_dport) || (inet->inet_dport != rmt_port && inet->inet_dport) ||
(inet->inet_rcv_saddr && inet->inet_rcv_saddr != loc_addr) || (inet->inet_rcv_saddr && inet->inet_rcv_saddr != loc_addr) ||
ipv6_only_sock(sk) || ipv6_only_sock(sk) ||
(sk->sk_bound_dev_if && sk->sk_bound_dev_if != dif)) (sk->sk_bound_dev_if && sk->sk_bound_dev_if != dif &&
sk->sk_bound_dev_if != sdif))
return false; return false;
if (!ip_mc_sf_allow(sk, loc_addr, rmt_addr, dif)) if (!ip_mc_sf_allow(sk, loc_addr, rmt_addr, dif))
return false; return false;
...@@ -628,8 +635,8 @@ void __udp4_lib_err(struct sk_buff *skb, u32 info, struct udp_table *udptable) ...@@ -628,8 +635,8 @@ void __udp4_lib_err(struct sk_buff *skb, u32 info, struct udp_table *udptable)
struct net *net = dev_net(skb->dev); struct net *net = dev_net(skb->dev);
sk = __udp4_lib_lookup(net, iph->daddr, uh->dest, sk = __udp4_lib_lookup(net, iph->daddr, uh->dest,
iph->saddr, uh->source, skb->dev->ifindex, udptable, iph->saddr, uh->source, skb->dev->ifindex, 0,
NULL); udptable, NULL);
if (!sk) { if (!sk) {
__ICMP_INC_STATS(net, ICMP_MIB_INERRORS); __ICMP_INC_STATS(net, ICMP_MIB_INERRORS);
return; /* No socket for error */ return; /* No socket for error */
...@@ -1953,6 +1960,7 @@ static int __udp4_lib_mcast_deliver(struct net *net, struct sk_buff *skb, ...@@ -1953,6 +1960,7 @@ static int __udp4_lib_mcast_deliver(struct net *net, struct sk_buff *skb,
unsigned int hash2 = 0, hash2_any = 0, use_hash2 = (hslot->count > 10); unsigned int hash2 = 0, hash2_any = 0, use_hash2 = (hslot->count > 10);
unsigned int offset = offsetof(typeof(*sk), sk_node); unsigned int offset = offsetof(typeof(*sk), sk_node);
int dif = skb->dev->ifindex; int dif = skb->dev->ifindex;
int sdif = inet_sdif(skb);
struct hlist_node *node; struct hlist_node *node;
struct sk_buff *nskb; struct sk_buff *nskb;
...@@ -1967,7 +1975,7 @@ static int __udp4_lib_mcast_deliver(struct net *net, struct sk_buff *skb, ...@@ -1967,7 +1975,7 @@ static int __udp4_lib_mcast_deliver(struct net *net, struct sk_buff *skb,
sk_for_each_entry_offset_rcu(sk, node, &hslot->head, offset) { sk_for_each_entry_offset_rcu(sk, node, &hslot->head, offset) {
if (!__udp_is_mcast_sock(net, sk, uh->dest, daddr, if (!__udp_is_mcast_sock(net, sk, uh->dest, daddr,
uh->source, saddr, dif, hnum)) uh->source, saddr, dif, sdif, hnum))
continue; continue;
if (!first) { if (!first) {
...@@ -2157,7 +2165,7 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable, ...@@ -2157,7 +2165,7 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
static struct sock *__udp4_lib_mcast_demux_lookup(struct net *net, static struct sock *__udp4_lib_mcast_demux_lookup(struct net *net,
__be16 loc_port, __be32 loc_addr, __be16 loc_port, __be32 loc_addr,
__be16 rmt_port, __be32 rmt_addr, __be16 rmt_port, __be32 rmt_addr,
int dif) int dif, int sdif)
{ {
struct sock *sk, *result; struct sock *sk, *result;
unsigned short hnum = ntohs(loc_port); unsigned short hnum = ntohs(loc_port);
...@@ -2171,7 +2179,7 @@ static struct sock *__udp4_lib_mcast_demux_lookup(struct net *net, ...@@ -2171,7 +2179,7 @@ static struct sock *__udp4_lib_mcast_demux_lookup(struct net *net,
result = NULL; result = NULL;
sk_for_each_rcu(sk, &hslot->head) { sk_for_each_rcu(sk, &hslot->head) {
if (__udp_is_mcast_sock(net, sk, loc_port, loc_addr, if (__udp_is_mcast_sock(net, sk, loc_port, loc_addr,
rmt_port, rmt_addr, dif, hnum)) { rmt_port, rmt_addr, dif, sdif, hnum)) {
if (result) if (result)
return NULL; return NULL;
result = sk; result = sk;
...@@ -2216,6 +2224,7 @@ void udp_v4_early_demux(struct sk_buff *skb) ...@@ -2216,6 +2224,7 @@ void udp_v4_early_demux(struct sk_buff *skb)
struct sock *sk = NULL; struct sock *sk = NULL;
struct dst_entry *dst; struct dst_entry *dst;
int dif = skb->dev->ifindex; int dif = skb->dev->ifindex;
int sdif = inet_sdif(skb);
int ours; int ours;
/* validate the packet */ /* validate the packet */
...@@ -2241,7 +2250,8 @@ void udp_v4_early_demux(struct sk_buff *skb) ...@@ -2241,7 +2250,8 @@ void udp_v4_early_demux(struct sk_buff *skb)
} }
sk = __udp4_lib_mcast_demux_lookup(net, uh->dest, iph->daddr, sk = __udp4_lib_mcast_demux_lookup(net, uh->dest, iph->daddr,
uh->source, iph->saddr, dif); uh->source, iph->saddr,
dif, sdif);
} else if (skb->pkt_type == PACKET_HOST) { } else if (skb->pkt_type == PACKET_HOST) {
sk = __udp4_lib_demux_lookup(net, uh->dest, iph->daddr, sk = __udp4_lib_demux_lookup(net, uh->dest, iph->daddr,
uh->source, iph->saddr, dif); uh->source, iph->saddr, dif);
......
...@@ -45,7 +45,7 @@ static int udp_dump_one(struct udp_table *tbl, struct sk_buff *in_skb, ...@@ -45,7 +45,7 @@ static int udp_dump_one(struct udp_table *tbl, struct sk_buff *in_skb,
sk = __udp4_lib_lookup(net, sk = __udp4_lib_lookup(net,
req->id.idiag_src[0], req->id.idiag_sport, req->id.idiag_src[0], req->id.idiag_sport,
req->id.idiag_dst[0], req->id.idiag_dport, req->id.idiag_dst[0], req->id.idiag_dport,
req->id.idiag_if, tbl, NULL); req->id.idiag_if, 0, tbl, NULL);
#if IS_ENABLED(CONFIG_IPV6) #if IS_ENABLED(CONFIG_IPV6)
else if (req->sdiag_family == AF_INET6) else if (req->sdiag_family == AF_INET6)
sk = __udp6_lib_lookup(net, sk = __udp6_lib_lookup(net,
...@@ -182,7 +182,7 @@ static int __udp_diag_destroy(struct sk_buff *in_skb, ...@@ -182,7 +182,7 @@ static int __udp_diag_destroy(struct sk_buff *in_skb,
sk = __udp4_lib_lookup(net, sk = __udp4_lib_lookup(net,
req->id.idiag_dst[0], req->id.idiag_dport, req->id.idiag_dst[0], req->id.idiag_dport,
req->id.idiag_src[0], req->id.idiag_sport, req->id.idiag_src[0], req->id.idiag_sport,
req->id.idiag_if, tbl, NULL); req->id.idiag_if, 0, tbl, NULL);
#if IS_ENABLED(CONFIG_IPV6) #if IS_ENABLED(CONFIG_IPV6)
else if (req->sdiag_family == AF_INET6) { else if (req->sdiag_family == AF_INET6) {
if (ipv6_addr_v4mapped((struct in6_addr *)req->id.idiag_dst) && if (ipv6_addr_v4mapped((struct in6_addr *)req->id.idiag_dst) &&
...@@ -190,7 +190,7 @@ static int __udp_diag_destroy(struct sk_buff *in_skb, ...@@ -190,7 +190,7 @@ static int __udp_diag_destroy(struct sk_buff *in_skb,
sk = __udp4_lib_lookup(net, sk = __udp4_lib_lookup(net,
req->id.idiag_dst[3], req->id.idiag_dport, req->id.idiag_dst[3], req->id.idiag_dport,
req->id.idiag_src[3], req->id.idiag_sport, req->id.idiag_src[3], req->id.idiag_sport,
req->id.idiag_if, tbl, NULL); req->id.idiag_if, 0, tbl, NULL);
else else
sk = __udp6_lib_lookup(net, sk = __udp6_lib_lookup(net,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment