Commit b9d92bc3 authored by Stephen Hemminger's avatar Stephen Hemminger Committed by David S. Miller

[IPV4/IPV6]: inetsw using RCU.

parent cd65eaf5
...@@ -80,11 +80,9 @@ struct inet_protosw { ...@@ -80,11 +80,9 @@ struct inet_protosw {
extern struct inet_protocol *inet_protocol_base; extern struct inet_protocol *inet_protocol_base;
extern struct inet_protocol *inet_protos[MAX_INET_PROTOS]; extern struct inet_protocol *inet_protos[MAX_INET_PROTOS];
extern struct list_head inetsw[SOCK_MAX];
#if defined(CONFIG_IPV6) || defined (CONFIG_IPV6_MODULE) #if defined(CONFIG_IPV6) || defined (CONFIG_IPV6_MODULE)
extern struct inet6_protocol *inet6_protos[MAX_INET_PROTOS]; extern struct inet6_protocol *inet6_protos[MAX_INET_PROTOS];
extern struct list_head inetsw6[SOCK_MAX];
#endif #endif
extern int inet_add_protocol(struct inet_protocol *prot, unsigned char num); extern int inet_add_protocol(struct inet_protocol *prot, unsigned char num);
......
...@@ -94,7 +94,6 @@ ...@@ -94,7 +94,6 @@
#include <linux/inet.h> #include <linux/inet.h>
#include <linux/igmp.h> #include <linux/igmp.h>
#include <linux/netdevice.h> #include <linux/netdevice.h>
#include <linux/brlock.h>
#include <net/ip.h> #include <net/ip.h>
#include <net/protocol.h> #include <net/protocol.h>
#include <net/arp.h> #include <net/arp.h>
...@@ -129,7 +128,8 @@ static kmem_cache_t *raw4_sk_cachep; ...@@ -129,7 +128,8 @@ static kmem_cache_t *raw4_sk_cachep;
/* The inetsw table contains everything that inet_create needs to /* The inetsw table contains everything that inet_create needs to
* build a new socket. * build a new socket.
*/ */
struct list_head inetsw[SOCK_MAX]; static struct list_head inetsw[SOCK_MAX];
static spinlock_t inetsw_lock = SPIN_LOCK_UNLOCKED;
/* New destruction routine */ /* New destruction routine */
...@@ -337,8 +337,8 @@ static int inet_create(struct socket *sock, int protocol) ...@@ -337,8 +337,8 @@ static int inet_create(struct socket *sock, int protocol)
/* Look for the requested type/protocol pair. */ /* Look for the requested type/protocol pair. */
answer = NULL; answer = NULL;
br_read_lock_bh(BR_NETPROTO_LOCK); rcu_read_lock();
list_for_each(p, &inetsw[sock->type]) { list_for_each_rcu(p, &inetsw[sock->type]) {
answer = list_entry(p, struct inet_protosw, list); answer = list_entry(p, struct inet_protosw, list);
/* Check the non-wild match. */ /* Check the non-wild match. */
...@@ -356,7 +356,6 @@ static int inet_create(struct socket *sock, int protocol) ...@@ -356,7 +356,6 @@ static int inet_create(struct socket *sock, int protocol)
} }
answer = NULL; answer = NULL;
} }
br_read_unlock_bh(BR_NETPROTO_LOCK);
err = -ESOCKTNOSUPPORT; err = -ESOCKTNOSUPPORT;
if (!answer) if (!answer)
...@@ -373,6 +372,7 @@ static int inet_create(struct socket *sock, int protocol) ...@@ -373,6 +372,7 @@ static int inet_create(struct socket *sock, int protocol)
sk->no_check = answer->no_check; sk->no_check = answer->no_check;
if (INET_PROTOSW_REUSE & answer->flags) if (INET_PROTOSW_REUSE & answer->flags)
sk->reuse = 1; sk->reuse = 1;
rcu_read_unlock();
inet = inet_sk(sk); inet = inet_sk(sk);
...@@ -427,6 +427,7 @@ static int inet_create(struct socket *sock, int protocol) ...@@ -427,6 +427,7 @@ static int inet_create(struct socket *sock, int protocol)
out: out:
return err; return err;
out_sk_free: out_sk_free:
rcu_read_unlock();
sk_free(sk); sk_free(sk);
goto out; goto out;
} }
...@@ -979,7 +980,7 @@ void inet_register_protosw(struct inet_protosw *p) ...@@ -979,7 +980,7 @@ void inet_register_protosw(struct inet_protosw *p)
int protocol = p->protocol; int protocol = p->protocol;
struct list_head *last_perm; struct list_head *last_perm;
br_write_lock_bh(BR_NETPROTO_LOCK); spin_lock_bh(&inetsw_lock);
if (p->type > SOCK_MAX) if (p->type > SOCK_MAX)
goto out_illegal; goto out_illegal;
...@@ -1008,9 +1009,12 @@ void inet_register_protosw(struct inet_protosw *p) ...@@ -1008,9 +1009,12 @@ void inet_register_protosw(struct inet_protosw *p)
* non-permanent entry. This means that when we remove this entry, the * non-permanent entry. This means that when we remove this entry, the
* system automatically returns to the old behavior. * system automatically returns to the old behavior.
*/ */
list_add(&p->list, last_perm); list_add_rcu(&p->list, last_perm);
out: out:
br_write_unlock_bh(BR_NETPROTO_LOCK); spin_unlock_bh(&inetsw_lock);
synchronize_kernel();
return; return;
out_permanent: out_permanent:
...@@ -1032,9 +1036,11 @@ void inet_unregister_protosw(struct inet_protosw *p) ...@@ -1032,9 +1036,11 @@ void inet_unregister_protosw(struct inet_protosw *p)
"Attempt to unregister permanent protocol %d.\n", "Attempt to unregister permanent protocol %d.\n",
p->protocol); p->protocol);
} else { } else {
br_write_lock_bh(BR_NETPROTO_LOCK); spin_lock_bh(&inetsw_lock);
list_del(&p->list); list_del_rcu(&p->list);
br_write_unlock_bh(BR_NETPROTO_LOCK); spin_unlock_bh(&inetsw_lock);
synchronize_kernel();
} }
} }
......
...@@ -695,15 +695,12 @@ static void icmp_unreach(struct sk_buff *skb) ...@@ -695,15 +695,12 @@ static void icmp_unreach(struct sk_buff *skb)
} }
read_unlock(&raw_v4_lock); read_unlock(&raw_v4_lock);
/* rcu_read_lock();
* This can't change while we are doing it.
* Callers have obtained BR_NETPROTO_LOCK so
* we are OK.
*/
ipprot = inet_protos[hash]; ipprot = inet_protos[hash];
smp_read_barrier_depends();
if (ipprot && ipprot->err_handler) if (ipprot && ipprot->err_handler)
ipprot->err_handler(skb, info); ipprot->err_handler(skb, info);
rcu_read_unlock();
out: out:
return; return;
......
...@@ -215,6 +215,7 @@ static inline int ip_local_deliver_finish(struct sk_buff *skb) ...@@ -215,6 +215,7 @@ static inline int ip_local_deliver_finish(struct sk_buff *skb)
/* Point into the IP datagram, just past the header. */ /* Point into the IP datagram, just past the header. */
skb->h.raw = skb->data; skb->h.raw = skb->data;
rcu_read_lock();
{ {
/* Note: See raw.c and net/raw.h, RAWV4_HTABLE_SIZE==MAX_INET_PROTOS */ /* Note: See raw.c and net/raw.h, RAWV4_HTABLE_SIZE==MAX_INET_PROTOS */
int protocol = skb->nh.iph->protocol; int protocol = skb->nh.iph->protocol;
...@@ -235,10 +236,11 @@ static inline int ip_local_deliver_finish(struct sk_buff *skb) ...@@ -235,10 +236,11 @@ static inline int ip_local_deliver_finish(struct sk_buff *skb)
if ((ipprot = inet_protos[hash]) != NULL) { if ((ipprot = inet_protos[hash]) != NULL) {
int ret; int ret;
smp_read_barrier_depends();
if (!ipprot->no_policy && if (!ipprot->no_policy &&
!xfrm4_policy_check(NULL, XFRM_POLICY_IN, skb)) { !xfrm4_policy_check(NULL, XFRM_POLICY_IN, skb)) {
kfree_skb(skb); kfree_skb(skb);
return 0; goto out;
} }
ret = ipprot->handler(skb); ret = ipprot->handler(skb);
if (ret < 0) { if (ret < 0) {
...@@ -258,6 +260,8 @@ static inline int ip_local_deliver_finish(struct sk_buff *skb) ...@@ -258,6 +260,8 @@ static inline int ip_local_deliver_finish(struct sk_buff *skb)
kfree_skb(skb); kfree_skb(skb);
} }
} }
out:
rcu_read_unlock();
return 0; return 0;
} }
......
...@@ -37,7 +37,6 @@ ...@@ -37,7 +37,6 @@
#include <linux/inet.h> #include <linux/inet.h>
#include <linux/netdevice.h> #include <linux/netdevice.h>
#include <linux/timer.h> #include <linux/timer.h>
#include <linux/brlock.h>
#include <net/ip.h> #include <net/ip.h>
#include <net/protocol.h> #include <net/protocol.h>
#include <net/tcp.h> #include <net/tcp.h>
...@@ -49,6 +48,7 @@ ...@@ -49,6 +48,7 @@
#include <linux/igmp.h> #include <linux/igmp.h>
struct inet_protocol *inet_protos[MAX_INET_PROTOS]; struct inet_protocol *inet_protos[MAX_INET_PROTOS];
static spinlock_t inet_proto_lock = SPIN_LOCK_UNLOCKED;
/* /*
* Add a protocol handler to the hash tables * Add a protocol handler to the hash tables
...@@ -60,16 +60,14 @@ int inet_add_protocol(struct inet_protocol *prot, unsigned char protocol) ...@@ -60,16 +60,14 @@ int inet_add_protocol(struct inet_protocol *prot, unsigned char protocol)
hash = protocol & (MAX_INET_PROTOS - 1); hash = protocol & (MAX_INET_PROTOS - 1);
br_write_lock_bh(BR_NETPROTO_LOCK); spin_lock_bh(&inet_proto_lock);
if (inet_protos[hash]) { if (inet_protos[hash]) {
ret = -1; ret = -1;
} else { } else {
inet_protos[hash] = prot; inet_protos[hash] = prot;
ret = 0; ret = 0;
} }
spin_unlock_bh(&inet_proto_lock);
br_write_unlock_bh(BR_NETPROTO_LOCK);
return ret; return ret;
} }
...@@ -84,16 +82,15 @@ int inet_del_protocol(struct inet_protocol *prot, unsigned char protocol) ...@@ -84,16 +82,15 @@ int inet_del_protocol(struct inet_protocol *prot, unsigned char protocol)
hash = protocol & (MAX_INET_PROTOS - 1); hash = protocol & (MAX_INET_PROTOS - 1);
br_write_lock_bh(BR_NETPROTO_LOCK); spin_lock_bh(&inet_proto_lock);
if (inet_protos[hash] == prot) { if (inet_protos[hash] == prot) {
inet_protos[hash] = NULL; inet_protos[hash] = NULL;
ret = 0; ret = 0;
} else { } else {
ret = -1; ret = -1;
} }
spin_unlock_bh(&inet_proto_lock);
br_write_unlock_bh(BR_NETPROTO_LOCK);
return ret; return ret;
} }
...@@ -45,7 +45,6 @@ ...@@ -45,7 +45,6 @@
#include <linux/inet.h> #include <linux/inet.h>
#include <linux/netdevice.h> #include <linux/netdevice.h>
#include <linux/icmpv6.h> #include <linux/icmpv6.h>
#include <linux/brlock.h>
#include <linux/smp_lock.h> #include <linux/smp_lock.h>
#include <net/ip.h> #include <net/ip.h>
...@@ -102,7 +101,8 @@ kmem_cache_t *raw6_sk_cachep; ...@@ -102,7 +101,8 @@ kmem_cache_t *raw6_sk_cachep;
/* The inetsw table contains everything that inet_create needs to /* The inetsw table contains everything that inet_create needs to
* build a new socket. * build a new socket.
*/ */
struct list_head inetsw6[SOCK_MAX]; static struct list_head inetsw6[SOCK_MAX];
static spinlock_t inetsw6_lock = SPIN_LOCK_UNLOCKED;
static void inet6_sock_destruct(struct sock *sk) static void inet6_sock_destruct(struct sock *sk)
{ {
...@@ -162,8 +162,8 @@ static int inet6_create(struct socket *sock, int protocol) ...@@ -162,8 +162,8 @@ static int inet6_create(struct socket *sock, int protocol)
/* Look for the requested type/protocol pair. */ /* Look for the requested type/protocol pair. */
answer = NULL; answer = NULL;
br_read_lock_bh(BR_NETPROTO_LOCK); rcu_read_lock();
list_for_each(p, &inetsw6[sock->type]) { list_for_each_rcu(p, &inetsw6[sock->type]) {
answer = list_entry(p, struct inet_protosw, list); answer = list_entry(p, struct inet_protosw, list);
/* Check the non-wild match. */ /* Check the non-wild match. */
...@@ -181,7 +181,6 @@ static int inet6_create(struct socket *sock, int protocol) ...@@ -181,7 +181,6 @@ static int inet6_create(struct socket *sock, int protocol)
} }
answer = NULL; answer = NULL;
} }
br_read_unlock_bh(BR_NETPROTO_LOCK);
if (!answer) if (!answer)
goto free_and_badtype; goto free_and_badtype;
...@@ -198,6 +197,7 @@ static int inet6_create(struct socket *sock, int protocol) ...@@ -198,6 +197,7 @@ static int inet6_create(struct socket *sock, int protocol)
sk->no_check = answer->no_check; sk->no_check = answer->no_check;
if (INET_PROTOSW_REUSE & answer->flags) if (INET_PROTOSW_REUSE & answer->flags)
sk->reuse = 1; sk->reuse = 1;
rcu_read_unlock();
inet = inet_sk(sk); inet = inet_sk(sk);
...@@ -260,12 +260,15 @@ static int inet6_create(struct socket *sock, int protocol) ...@@ -260,12 +260,15 @@ static int inet6_create(struct socket *sock, int protocol)
return 0; return 0;
free_and_badtype: free_and_badtype:
rcu_read_unlock();
sk_free(sk); sk_free(sk);
return -ESOCKTNOSUPPORT; return -ESOCKTNOSUPPORT;
free_and_badperm: free_and_badperm:
rcu_read_unlock();
sk_free(sk); sk_free(sk);
return -EPERM; return -EPERM;
free_and_noproto: free_and_noproto:
rcu_read_unlock();
sk_free(sk); sk_free(sk);
return -EPROTONOSUPPORT; return -EPROTONOSUPPORT;
do_oom: do_oom:
...@@ -574,7 +577,7 @@ inet6_register_protosw(struct inet_protosw *p) ...@@ -574,7 +577,7 @@ inet6_register_protosw(struct inet_protosw *p)
int protocol = p->protocol; int protocol = p->protocol;
struct list_head *last_perm; struct list_head *last_perm;
br_write_lock_bh(BR_NETPROTO_LOCK); spin_lock_bh(&inetsw6_lock);
if (p->type > SOCK_MAX) if (p->type > SOCK_MAX)
goto out_illegal; goto out_illegal;
...@@ -603,9 +606,9 @@ inet6_register_protosw(struct inet_protosw *p) ...@@ -603,9 +606,9 @@ inet6_register_protosw(struct inet_protosw *p)
* non-permanent entry. This means that when we remove this entry, the * non-permanent entry. This means that when we remove this entry, the
* system automatically returns to the old behavior. * system automatically returns to the old behavior.
*/ */
list_add(&p->list, last_perm); list_add_rcu(&p->list, last_perm);
out: out:
br_write_unlock_bh(BR_NETPROTO_LOCK); spin_unlock_bh(&inetsw6_lock);
return; return;
out_permanent: out_permanent:
...@@ -623,7 +626,17 @@ inet6_register_protosw(struct inet_protosw *p) ...@@ -623,7 +626,17 @@ inet6_register_protosw(struct inet_protosw *p)
void void
inet6_unregister_protosw(struct inet_protosw *p) inet6_unregister_protosw(struct inet_protosw *p)
{ {
inet_unregister_protosw(p); if (INET_PROTOSW_PERMANENT & p->flags) {
printk(KERN_ERR
"Attempt to unregister permanent protocol %d.\n",
p->protocol);
} else {
spin_lock_bh(&inetsw6_lock);
list_del_rcu(&p->list);
spin_unlock_bh(&inetsw6_lock);
synchronize_kernel();
}
} }
int int
......
...@@ -456,9 +456,12 @@ static void icmpv6_notify(struct sk_buff *skb, int type, int code, u32 info) ...@@ -456,9 +456,12 @@ static void icmpv6_notify(struct sk_buff *skb, int type, int code, u32 info)
hash = nexthdr & (MAX_INET_PROTOS - 1); hash = nexthdr & (MAX_INET_PROTOS - 1);
rcu_read_lock();
ipprot = inet6_protos[hash]; ipprot = inet6_protos[hash];
smp_read_barrier_depends();
if (ipprot && ipprot->err_handler) if (ipprot && ipprot->err_handler)
ipprot->err_handler(skb, NULL, type, code, inner_offset, info); ipprot->err_handler(skb, NULL, type, code, inner_offset, info);
rcu_read_unlock();
read_lock(&raw_v6_lock); read_lock(&raw_v6_lock);
if ((sk = raw_v6_htable[hash]) != NULL) { if ((sk = raw_v6_htable[hash]) != NULL) {
......
...@@ -152,6 +152,7 @@ static inline int ip6_input_finish(struct sk_buff *skb) ...@@ -152,6 +152,7 @@ static inline int ip6_input_finish(struct sk_buff *skb)
skb->h.raw += (skb->h.raw[1]+1)<<3; skb->h.raw += (skb->h.raw[1]+1)<<3;
} }
rcu_read_lock();
resubmit: resubmit:
if (!pskb_pull(skb, skb->h.raw - skb->data)) if (!pskb_pull(skb, skb->h.raw - skb->data))
goto discard; goto discard;
...@@ -165,6 +166,7 @@ static inline int ip6_input_finish(struct sk_buff *skb) ...@@ -165,6 +166,7 @@ static inline int ip6_input_finish(struct sk_buff *skb)
if ((ipprot = inet6_protos[hash]) != NULL) { if ((ipprot = inet6_protos[hash]) != NULL) {
int ret; int ret;
smp_read_barrier_depends();
if (ipprot->flags & INET6_PROTO_FINAL) { if (ipprot->flags & INET6_PROTO_FINAL) {
if (!cksum_sub && skb->ip_summed == CHECKSUM_HW) { if (!cksum_sub && skb->ip_summed == CHECKSUM_HW) {
skb->csum = csum_sub(skb->csum, skb->csum = csum_sub(skb->csum,
...@@ -173,10 +175,8 @@ static inline int ip6_input_finish(struct sk_buff *skb) ...@@ -173,10 +175,8 @@ static inline int ip6_input_finish(struct sk_buff *skb)
} }
} }
if (!(ipprot->flags & INET6_PROTO_NOPOLICY) && if (!(ipprot->flags & INET6_PROTO_NOPOLICY) &&
!xfrm6_policy_check(NULL, XFRM_POLICY_IN, skb)) { !xfrm6_policy_check(NULL, XFRM_POLICY_IN, skb))
kfree_skb(skb); goto discard;
return 0;
}
ret = ipprot->handler(&skb, &nhoff); ret = ipprot->handler(&skb, &nhoff);
if (ret > 0) if (ret > 0)
...@@ -194,10 +194,11 @@ static inline int ip6_input_finish(struct sk_buff *skb) ...@@ -194,10 +194,11 @@ static inline int ip6_input_finish(struct sk_buff *skb)
kfree_skb(skb); kfree_skb(skb);
} }
} }
rcu_read_unlock();
return 0; return 0;
discard: discard:
rcu_read_unlock();
kfree_skb(skb); kfree_skb(skb);
return 0; return 0;
} }
......
...@@ -32,7 +32,6 @@ ...@@ -32,7 +32,6 @@
#include <linux/in6.h> #include <linux/in6.h>
#include <linux/netdevice.h> #include <linux/netdevice.h>
#include <linux/if_arp.h> #include <linux/if_arp.h>
#include <linux/brlock.h>
#include <net/sock.h> #include <net/sock.h>
#include <net/snmp.h> #include <net/snmp.h>
...@@ -41,12 +40,14 @@ ...@@ -41,12 +40,14 @@
#include <net/protocol.h> #include <net/protocol.h>
struct inet6_protocol *inet6_protos[MAX_INET_PROTOS]; struct inet6_protocol *inet6_protos[MAX_INET_PROTOS];
static spinlock_t inet6_proto_lock = SPIN_LOCK_UNLOCKED;
int inet6_add_protocol(struct inet6_protocol *prot, unsigned char protocol) int inet6_add_protocol(struct inet6_protocol *prot, unsigned char protocol)
{ {
int ret, hash = protocol & (MAX_INET_PROTOS - 1); int ret, hash = protocol & (MAX_INET_PROTOS - 1);
br_write_lock_bh(BR_NETPROTO_LOCK); spin_lock_bh(&inet6_proto_lock);
if (inet6_protos[hash]) { if (inet6_protos[hash]) {
ret = -1; ret = -1;
...@@ -55,7 +56,7 @@ int inet6_add_protocol(struct inet6_protocol *prot, unsigned char protocol) ...@@ -55,7 +56,7 @@ int inet6_add_protocol(struct inet6_protocol *prot, unsigned char protocol)
ret = 0; ret = 0;
} }
br_write_unlock_bh(BR_NETPROTO_LOCK); spin_unlock_bh(&inet6_proto_lock);
return ret; return ret;
} }
...@@ -68,7 +69,7 @@ int inet6_del_protocol(struct inet6_protocol *prot, unsigned char protocol) ...@@ -68,7 +69,7 @@ int inet6_del_protocol(struct inet6_protocol *prot, unsigned char protocol)
{ {
int ret, hash = protocol & (MAX_INET_PROTOS - 1); int ret, hash = protocol & (MAX_INET_PROTOS - 1);
br_write_lock_bh(BR_NETPROTO_LOCK); spin_lock_bh(&inet6_proto_lock);
if (inet6_protos[hash] != prot) { if (inet6_protos[hash] != prot) {
ret = -1; ret = -1;
...@@ -77,7 +78,7 @@ int inet6_del_protocol(struct inet6_protocol *prot, unsigned char protocol) ...@@ -77,7 +78,7 @@ int inet6_del_protocol(struct inet6_protocol *prot, unsigned char protocol)
ret = 0; ret = 0;
} }
br_write_unlock_bh(BR_NETPROTO_LOCK); spin_unlock_bh(&inet6_proto_lock);
return ret; return ret;
} }
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment