Commit 25030a7f authored by Gerrit Renker's avatar Gerrit Renker Committed by David S. Miller

[UDP]: Unify UDPv4 and UDPv6 ->get_port()

This patch creates one common function which is called by
udp_v4_get_port() and udp_v6_get_port(). As a result,
  * duplicated code is removed
  * udp_port_rover and local port lookup can now be removed from udp.h
  * further savings follow since the same function will be used by UDP-Litev4
    and UDP-Litev6

In contrast to the patch sent in response to Yoshifujis comments
(fixed by this variant), the code below also removes the
EXPORT_SYMBOL(udp_port_rover), since udp_port_rover can now remain
local to net/ipv4/udp.c.
Signed-off-by: default avatarGerrit Renker <gerrit@erg.abdn.ac.uk>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent 6a28ec8c
...@@ -30,25 +30,9 @@ ...@@ -30,25 +30,9 @@
#define UDP_HTABLE_SIZE 128 #define UDP_HTABLE_SIZE 128
/* udp.c: This needs to be shared by v4 and v6 because the lookup
* and hashing code needs to work with different AF's yet
* the port space is shared.
*/
extern struct hlist_head udp_hash[UDP_HTABLE_SIZE]; extern struct hlist_head udp_hash[UDP_HTABLE_SIZE];
extern rwlock_t udp_hash_lock; extern rwlock_t udp_hash_lock;
extern int udp_port_rover;
static inline int udp_lport_inuse(u16 num)
{
struct sock *sk;
struct hlist_node *node;
sk_for_each(sk, node, &udp_hash[num & (UDP_HTABLE_SIZE - 1)])
if (inet_sk(sk)->num == num)
return 1;
return 0;
}
/* Note: this must match 'valbool' in sock_setsockopt */ /* Note: this must match 'valbool' in sock_setsockopt */
#define UDP_CSUM_NOXMIT 1 #define UDP_CSUM_NOXMIT 1
...@@ -63,6 +47,8 @@ extern struct proto udp_prot; ...@@ -63,6 +47,8 @@ extern struct proto udp_prot;
struct sk_buff; struct sk_buff;
extern int udp_get_port(struct sock *sk, unsigned short snum,
int (*saddr_cmp)(struct sock *, struct sock *));
extern void udp_err(struct sk_buff *, u32); extern void udp_err(struct sk_buff *, u32);
extern int udp_sendmsg(struct kiocb *iocb, struct sock *sk, extern int udp_sendmsg(struct kiocb *iocb, struct sock *sk,
......
...@@ -118,14 +118,34 @@ DEFINE_SNMP_STAT(struct udp_mib, udp_statistics) __read_mostly; ...@@ -118,14 +118,34 @@ DEFINE_SNMP_STAT(struct udp_mib, udp_statistics) __read_mostly;
struct hlist_head udp_hash[UDP_HTABLE_SIZE]; struct hlist_head udp_hash[UDP_HTABLE_SIZE];
DEFINE_RWLOCK(udp_hash_lock); DEFINE_RWLOCK(udp_hash_lock);
/* Shared by v4/v6 udp. */ /* Shared by v4/v6 udp_get_port */
int udp_port_rover; int udp_port_rover;
static int udp_v4_get_port(struct sock *sk, unsigned short snum) static inline int udp_lport_inuse(u16 num)
{
struct sock *sk;
struct hlist_node *node;
sk_for_each(sk, node, &udp_hash[num & (UDP_HTABLE_SIZE - 1)])
if (inet_sk(sk)->num == num)
return 1;
return 0;
}
/**
* udp_get_port - common port lookup for IPv4 and IPv6
*
* @sk: socket struct in question
* @snum: port number to look up
* @saddr_comp: AF-dependent comparison of bound local IP addresses
*/
int udp_get_port(struct sock *sk, unsigned short snum,
int (*saddr_cmp)(struct sock *sk1, struct sock *sk2))
{ {
struct hlist_node *node; struct hlist_node *node;
struct hlist_head *head;
struct sock *sk2; struct sock *sk2;
struct inet_sock *inet = inet_sk(sk); int error = 1;
write_lock_bh(&udp_hash_lock); write_lock_bh(&udp_hash_lock);
if (snum == 0) { if (snum == 0) {
...@@ -137,11 +157,10 @@ static int udp_v4_get_port(struct sock *sk, unsigned short snum) ...@@ -137,11 +157,10 @@ static int udp_v4_get_port(struct sock *sk, unsigned short snum)
best_size_so_far = 32767; best_size_so_far = 32767;
best = result = udp_port_rover; best = result = udp_port_rover;
for (i = 0; i < UDP_HTABLE_SIZE; i++, result++) { for (i = 0; i < UDP_HTABLE_SIZE; i++, result++) {
struct hlist_head *list;
int size; int size;
list = &udp_hash[result & (UDP_HTABLE_SIZE - 1)]; head = &udp_hash[result & (UDP_HTABLE_SIZE - 1)];
if (hlist_empty(list)) { if (hlist_empty(head)) {
if (result > sysctl_local_port_range[1]) if (result > sysctl_local_port_range[1])
result = sysctl_local_port_range[0] + result = sysctl_local_port_range[0] +
((result - sysctl_local_port_range[0]) & ((result - sysctl_local_port_range[0]) &
...@@ -149,12 +168,11 @@ static int udp_v4_get_port(struct sock *sk, unsigned short snum) ...@@ -149,12 +168,11 @@ static int udp_v4_get_port(struct sock *sk, unsigned short snum)
goto gotit; goto gotit;
} }
size = 0; size = 0;
sk_for_each(sk2, node, list) sk_for_each(sk2, node, head)
if (++size >= best_size_so_far) if (++size < best_size_so_far) {
goto next;
best_size_so_far = size; best_size_so_far = size;
best = result; best = result;
next:; }
} }
result = best; result = best;
for(i = 0; i < (1 << 16) / UDP_HTABLE_SIZE; i++, result += UDP_HTABLE_SIZE) { for(i = 0; i < (1 << 16) / UDP_HTABLE_SIZE; i++, result += UDP_HTABLE_SIZE) {
...@@ -170,38 +188,44 @@ static int udp_v4_get_port(struct sock *sk, unsigned short snum) ...@@ -170,38 +188,44 @@ static int udp_v4_get_port(struct sock *sk, unsigned short snum)
gotit: gotit:
udp_port_rover = snum = result; udp_port_rover = snum = result;
} else { } else {
sk_for_each(sk2, node, head = &udp_hash[snum & (UDP_HTABLE_SIZE - 1)];
&udp_hash[snum & (UDP_HTABLE_SIZE - 1)]) {
struct inet_sock *inet2 = inet_sk(sk2);
if (inet2->num == snum && sk_for_each(sk2, node, head)
if (inet_sk(sk2)->num == snum &&
sk2 != sk && sk2 != sk &&
!ipv6_only_sock(sk2) && (!sk2->sk_reuse || !sk->sk_reuse) &&
(!sk2->sk_bound_dev_if || (!sk2->sk_bound_dev_if || !sk->sk_bound_dev_if
!sk->sk_bound_dev_if || || sk2->sk_bound_dev_if == sk->sk_bound_dev_if) &&
sk2->sk_bound_dev_if == sk->sk_bound_dev_if) && (*saddr_cmp)(sk, sk2) )
(!inet2->rcv_saddr ||
!inet->rcv_saddr ||
inet2->rcv_saddr == inet->rcv_saddr) &&
(!sk2->sk_reuse || !sk->sk_reuse))
goto fail; goto fail;
} }
} inet_sk(sk)->num = snum;
inet->num = snum;
if (sk_unhashed(sk)) { if (sk_unhashed(sk)) {
struct hlist_head *h = &udp_hash[snum & (UDP_HTABLE_SIZE - 1)]; head = &udp_hash[snum & (UDP_HTABLE_SIZE - 1)];
sk_add_node(sk, head);
sk_add_node(sk, h);
sock_prot_inc_use(sk->sk_prot); sock_prot_inc_use(sk->sk_prot);
} }
write_unlock_bh(&udp_hash_lock); error = 0;
return 0;
fail: fail:
write_unlock_bh(&udp_hash_lock); write_unlock_bh(&udp_hash_lock);
return 1; return error;
} }
static inline int ipv4_rcv_saddr_equal(struct sock *sk1, struct sock *sk2)
{
struct inet_sock *inet1 = inet_sk(sk1), *inet2 = inet_sk(sk2);
return ( !ipv6_only_sock(sk2) &&
(!inet1->rcv_saddr || !inet2->rcv_saddr ||
inet1->rcv_saddr == inet2->rcv_saddr ));
}
static inline int udp_v4_get_port(struct sock *sk, unsigned short snum)
{
return udp_get_port(sk, snum, ipv4_rcv_saddr_equal);
}
static void udp_v4_hash(struct sock *sk) static void udp_v4_hash(struct sock *sk)
{ {
BUG(); BUG();
...@@ -1596,7 +1620,7 @@ EXPORT_SYMBOL(udp_disconnect); ...@@ -1596,7 +1620,7 @@ EXPORT_SYMBOL(udp_disconnect);
EXPORT_SYMBOL(udp_hash); EXPORT_SYMBOL(udp_hash);
EXPORT_SYMBOL(udp_hash_lock); EXPORT_SYMBOL(udp_hash_lock);
EXPORT_SYMBOL(udp_ioctl); EXPORT_SYMBOL(udp_ioctl);
EXPORT_SYMBOL(udp_port_rover); EXPORT_SYMBOL(udp_get_port);
EXPORT_SYMBOL(udp_prot); EXPORT_SYMBOL(udp_prot);
EXPORT_SYMBOL(udp_sendmsg); EXPORT_SYMBOL(udp_sendmsg);
EXPORT_SYMBOL(udp_poll); EXPORT_SYMBOL(udp_poll);
......
...@@ -61,81 +61,9 @@ ...@@ -61,81 +61,9 @@
DEFINE_SNMP_STAT(struct udp_mib, udp_stats_in6) __read_mostly; DEFINE_SNMP_STAT(struct udp_mib, udp_stats_in6) __read_mostly;
/* Grrr, addr_type already calculated by caller, but I don't want static inline int udp_v6_get_port(struct sock *sk, unsigned short snum)
* to add some silly "cookie" argument to this method just for that.
*/
static int udp_v6_get_port(struct sock *sk, unsigned short snum)
{ {
struct sock *sk2; return udp_get_port(sk, snum, ipv6_rcv_saddr_equal);
struct hlist_node *node;
write_lock_bh(&udp_hash_lock);
if (snum == 0) {
int best_size_so_far, best, result, i;
if (udp_port_rover > sysctl_local_port_range[1] ||
udp_port_rover < sysctl_local_port_range[0])
udp_port_rover = sysctl_local_port_range[0];
best_size_so_far = 32767;
best = result = udp_port_rover;
for (i = 0; i < UDP_HTABLE_SIZE; i++, result++) {
int size;
struct hlist_head *list;
list = &udp_hash[result & (UDP_HTABLE_SIZE - 1)];
if (hlist_empty(list)) {
if (result > sysctl_local_port_range[1])
result = sysctl_local_port_range[0] +
((result - sysctl_local_port_range[0]) &
(UDP_HTABLE_SIZE - 1));
goto gotit;
}
size = 0;
sk_for_each(sk2, node, list)
if (++size >= best_size_so_far)
goto next;
best_size_so_far = size;
best = result;
next:;
}
result = best;
for(i = 0; i < (1 << 16) / UDP_HTABLE_SIZE; i++, result += UDP_HTABLE_SIZE) {
if (result > sysctl_local_port_range[1])
result = sysctl_local_port_range[0]
+ ((result - sysctl_local_port_range[0]) &
(UDP_HTABLE_SIZE - 1));
if (!udp_lport_inuse(result))
break;
}
if (i >= (1 << 16) / UDP_HTABLE_SIZE)
goto fail;
gotit:
udp_port_rover = snum = result;
} else {
sk_for_each(sk2, node,
&udp_hash[snum & (UDP_HTABLE_SIZE - 1)]) {
if (inet_sk(sk2)->num == snum &&
sk2 != sk &&
(!sk2->sk_bound_dev_if ||
!sk->sk_bound_dev_if ||
sk2->sk_bound_dev_if == sk->sk_bound_dev_if) &&
(!sk2->sk_reuse || !sk->sk_reuse) &&
ipv6_rcv_saddr_equal(sk, sk2))
goto fail;
}
}
inet_sk(sk)->num = snum;
if (sk_unhashed(sk)) {
sk_add_node(sk, &udp_hash[snum & (UDP_HTABLE_SIZE - 1)]);
sock_prot_inc_use(sk->sk_prot);
}
write_unlock_bh(&udp_hash_lock);
return 0;
fail:
write_unlock_bh(&udp_hash_lock);
return 1;
} }
static void udp_v6_hash(struct sock *sk) static void udp_v6_hash(struct sock *sk)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment