Commit 227b60f5 authored by Stephen Hemminger's avatar Stephen Hemminger Committed by David S. Miller

[INET]: local port range robustness

Expansion of original idea from Denis V. Lunev <den@openvz.org>

Add robustness and locking to the local_port_range sysctl.
1. Enforce that low < high when setting.
2. Use seqlock to ensure atomic update.

The locking might seem like overkill, but there are
cases where sysadmin might want to change value in the
middle of a DoS attack.
Signed-off-by: default avatarStephen Hemminger <shemminger@linux-foundation.org>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent 06393009
...@@ -1866,13 +1866,14 @@ static int cma_alloc_port(struct idr *ps, struct rdma_id_private *id_priv, ...@@ -1866,13 +1866,14 @@ static int cma_alloc_port(struct idr *ps, struct rdma_id_private *id_priv,
static int cma_alloc_any_port(struct idr *ps, struct rdma_id_private *id_priv) static int cma_alloc_any_port(struct idr *ps, struct rdma_id_private *id_priv)
{ {
struct rdma_bind_list *bind_list; struct rdma_bind_list *bind_list;
int port, ret; int port, ret, low, high;
bind_list = kzalloc(sizeof *bind_list, GFP_KERNEL); bind_list = kzalloc(sizeof *bind_list, GFP_KERNEL);
if (!bind_list) if (!bind_list)
return -ENOMEM; return -ENOMEM;
retry: retry:
/* FIXME: add proper port randomization per like inet_csk_get_port */
do { do {
ret = idr_get_new_above(ps, bind_list, next_port, &port); ret = idr_get_new_above(ps, bind_list, next_port, &port);
} while ((ret == -EAGAIN) && idr_pre_get(ps, GFP_KERNEL)); } while ((ret == -EAGAIN) && idr_pre_get(ps, GFP_KERNEL));
...@@ -1880,18 +1881,19 @@ static int cma_alloc_any_port(struct idr *ps, struct rdma_id_private *id_priv) ...@@ -1880,18 +1881,19 @@ static int cma_alloc_any_port(struct idr *ps, struct rdma_id_private *id_priv)
if (ret) if (ret)
goto err1; goto err1;
if (port > sysctl_local_port_range[1]) { inet_get_local_port_range(&low, &high);
if (next_port != sysctl_local_port_range[0]) { if (port > high) {
if (next_port != low) {
idr_remove(ps, port); idr_remove(ps, port);
next_port = sysctl_local_port_range[0]; next_port = low;
goto retry; goto retry;
} }
ret = -EADDRNOTAVAIL; ret = -EADDRNOTAVAIL;
goto err2; goto err2;
} }
if (port == sysctl_local_port_range[1]) if (port == high)
next_port = sysctl_local_port_range[0]; next_port = low;
else else
next_port = port + 1; next_port = port + 1;
...@@ -2769,12 +2771,12 @@ static void cma_remove_one(struct ib_device *device) ...@@ -2769,12 +2771,12 @@ static void cma_remove_one(struct ib_device *device)
static int cma_init(void) static int cma_init(void)
{ {
int ret; int ret, low, high;
get_random_bytes(&next_port, sizeof next_port); get_random_bytes(&next_port, sizeof next_port);
next_port = ((unsigned int) next_port % inet_get_local_port_range(&low, &high);
(sysctl_local_port_range[1] - sysctl_local_port_range[0])) + next_port = ((unsigned int) next_port % (high - low)) + low;
sysctl_local_port_range[0];
cma_wq = create_singlethread_workqueue("rdma_cm"); cma_wq = create_singlethread_workqueue("rdma_cm");
if (!cma_wq) if (!cma_wq)
return -ENOMEM; return -ENOMEM;
......
...@@ -171,7 +171,8 @@ extern unsigned long snmp_fold_field(void *mib[], int offt); ...@@ -171,7 +171,8 @@ extern unsigned long snmp_fold_field(void *mib[], int offt);
extern int snmp_mib_init(void *ptr[2], size_t mibsize, size_t mibalign); extern int snmp_mib_init(void *ptr[2], size_t mibsize, size_t mibalign);
extern void snmp_mib_free(void *ptr[2]); extern void snmp_mib_free(void *ptr[2]);
extern int sysctl_local_port_range[2]; extern void inet_get_local_port_range(int *low, int *high);
extern int sysctl_ip_default_ttl; extern int sysctl_ip_default_ttl;
extern int sysctl_ip_nonlocal_bind; extern int sysctl_ip_nonlocal_bind;
......
...@@ -33,6 +33,19 @@ EXPORT_SYMBOL(inet_csk_timer_bug_msg); ...@@ -33,6 +33,19 @@ EXPORT_SYMBOL(inet_csk_timer_bug_msg);
* This array holds the first and last local port number. * This array holds the first and last local port number.
*/ */
int sysctl_local_port_range[2] = { 32768, 61000 }; int sysctl_local_port_range[2] = { 32768, 61000 };
DEFINE_SEQLOCK(sysctl_port_range_lock);
void inet_get_local_port_range(int *low, int *high)
{
unsigned seq;
do {
seq = read_seqbegin(&sysctl_port_range_lock);
*low = sysctl_local_port_range[0];
*high = sysctl_local_port_range[1];
} while (read_seqretry(&sysctl_port_range_lock, seq));
}
EXPORT_SYMBOL(inet_get_local_port_range);
int inet_csk_bind_conflict(const struct sock *sk, int inet_csk_bind_conflict(const struct sock *sk,
const struct inet_bind_bucket *tb) const struct inet_bind_bucket *tb)
...@@ -77,10 +90,11 @@ int inet_csk_get_port(struct inet_hashinfo *hashinfo, ...@@ -77,10 +90,11 @@ int inet_csk_get_port(struct inet_hashinfo *hashinfo,
local_bh_disable(); local_bh_disable();
if (!snum) { if (!snum) {
int low = sysctl_local_port_range[0]; int remaining, rover, low, high;
int high = sysctl_local_port_range[1];
int remaining = (high - low) + 1; inet_get_local_port_range(&low, &high);
int rover = net_random() % (high - low) + low; remaining = high - low;
rover = net_random() % remaining + low;
do { do {
head = &hashinfo->bhash[inet_bhashfn(rover, hashinfo->bhash_size)]; head = &hashinfo->bhash[inet_bhashfn(rover, hashinfo->bhash_size)];
......
...@@ -279,19 +279,18 @@ int inet_hash_connect(struct inet_timewait_death_row *death_row, ...@@ -279,19 +279,18 @@ int inet_hash_connect(struct inet_timewait_death_row *death_row,
int ret; int ret;
if (!snum) { if (!snum) {
int low = sysctl_local_port_range[0]; int i, remaining, low, high, port;
int high = sysctl_local_port_range[1];
int range = high - low;
int i;
int port;
static u32 hint; static u32 hint;
u32 offset = hint + inet_sk_port_offset(sk); u32 offset = hint + inet_sk_port_offset(sk);
struct hlist_node *node; struct hlist_node *node;
struct inet_timewait_sock *tw = NULL; struct inet_timewait_sock *tw = NULL;
inet_get_local_port_range(&low, &high);
remaining = high - low;
local_bh_disable(); local_bh_disable();
for (i = 1; i <= range; i++) { for (i = 1; i <= remaining; i++) {
port = low + (i + offset) % range; port = low + (i + offset) % remaining;
head = &hinfo->bhash[inet_bhashfn(port, hinfo->bhash_size)]; head = &hinfo->bhash[inet_bhashfn(port, hinfo->bhash_size)];
spin_lock(&head->lock); spin_lock(&head->lock);
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include <linux/sysctl.h> #include <linux/sysctl.h>
#include <linux/igmp.h> #include <linux/igmp.h>
#include <linux/inetdevice.h> #include <linux/inetdevice.h>
#include <linux/seqlock.h>
#include <net/snmp.h> #include <net/snmp.h>
#include <net/icmp.h> #include <net/icmp.h>
#include <net/ip.h> #include <net/ip.h>
...@@ -89,6 +90,74 @@ static int ipv4_sysctl_forward_strategy(ctl_table *table, ...@@ -89,6 +90,74 @@ static int ipv4_sysctl_forward_strategy(ctl_table *table,
return 1; return 1;
} }
extern seqlock_t sysctl_port_range_lock;
extern int sysctl_local_port_range[2];
/* Update system visible IP port range */
static void set_local_port_range(int range[2])
{
write_seqlock(&sysctl_port_range_lock);
sysctl_local_port_range[0] = range[0];
sysctl_local_port_range[1] = range[1];
write_sequnlock(&sysctl_port_range_lock);
}
/* Validate changes from /proc interface. */
static int ipv4_local_port_range(ctl_table *table, int write, struct file *filp,
void __user *buffer,
size_t *lenp, loff_t *ppos)
{
int ret;
int range[2] = { sysctl_local_port_range[0],
sysctl_local_port_range[1] };
ctl_table tmp = {
.data = &range,
.maxlen = sizeof(range),
.mode = table->mode,
.extra1 = &ip_local_port_range_min,
.extra2 = &ip_local_port_range_max,
};
ret = proc_dointvec_minmax(&tmp, write, filp, buffer, lenp, ppos);
if (write && ret == 0) {
if (range[1] <= range[0])
ret = -EINVAL;
else
set_local_port_range(range);
}
return ret;
}
/* Validate changes from sysctl interface. */
static int ipv4_sysctl_local_port_range(ctl_table *table, int __user *name,
int nlen, void __user *oldval,
size_t __user *oldlenp,
void __user *newval, size_t newlen)
{
int ret;
int range[2] = { sysctl_local_port_range[0],
sysctl_local_port_range[1] };
ctl_table tmp = {
.data = &range,
.maxlen = sizeof(range),
.mode = table->mode,
.extra1 = &ip_local_port_range_min,
.extra2 = &ip_local_port_range_max,
};
ret = sysctl_intvec(&tmp, name, nlen, oldval, oldlenp, newval, newlen);
if (ret == 0 && newval && newlen) {
if (range[1] <= range[0])
ret = -EINVAL;
else
set_local_port_range(range);
}
return ret;
}
static int proc_tcp_congestion_control(ctl_table *ctl, int write, struct file * filp, static int proc_tcp_congestion_control(ctl_table *ctl, int write, struct file * filp,
void __user *buffer, size_t *lenp, loff_t *ppos) void __user *buffer, size_t *lenp, loff_t *ppos)
{ {
...@@ -427,10 +496,8 @@ ctl_table ipv4_table[] = { ...@@ -427,10 +496,8 @@ ctl_table ipv4_table[] = {
.data = &sysctl_local_port_range, .data = &sysctl_local_port_range,
.maxlen = sizeof(sysctl_local_port_range), .maxlen = sizeof(sysctl_local_port_range),
.mode = 0644, .mode = 0644,
.proc_handler = &proc_dointvec_minmax, .proc_handler = &ipv4_local_port_range,
.strategy = &sysctl_intvec, .strategy = &ipv4_sysctl_local_port_range,
.extra1 = ip_local_port_range_min,
.extra2 = ip_local_port_range_max
}, },
{ {
.ctl_name = NET_IPV4_ICMP_ECHO_IGNORE_ALL, .ctl_name = NET_IPV4_ICMP_ECHO_IGNORE_ALL,
......
...@@ -2470,6 +2470,5 @@ EXPORT_SYMBOL(tcp_v4_syn_recv_sock); ...@@ -2470,6 +2470,5 @@ EXPORT_SYMBOL(tcp_v4_syn_recv_sock);
EXPORT_SYMBOL(tcp_proc_register); EXPORT_SYMBOL(tcp_proc_register);
EXPORT_SYMBOL(tcp_proc_unregister); EXPORT_SYMBOL(tcp_proc_unregister);
#endif #endif
EXPORT_SYMBOL(sysctl_local_port_range);
EXPORT_SYMBOL(sysctl_tcp_low_latency); EXPORT_SYMBOL(sysctl_tcp_low_latency);
...@@ -147,11 +147,11 @@ int __udp_lib_get_port(struct sock *sk, unsigned short snum, ...@@ -147,11 +147,11 @@ int __udp_lib_get_port(struct sock *sk, unsigned short snum,
write_lock_bh(&udp_hash_lock); write_lock_bh(&udp_hash_lock);
if (!snum) { if (!snum) {
int i; int i, low, high;
int low = sysctl_local_port_range[0];
int high = sysctl_local_port_range[1];
unsigned rover, best, best_size_so_far; unsigned rover, best, best_size_so_far;
inet_get_local_port_range(&low, &high);
best_size_so_far = UINT_MAX; best_size_so_far = UINT_MAX;
best = rover = net_random() % (high - low) + low; best = rover = net_random() % (high - low) + low;
......
...@@ -254,18 +254,18 @@ int inet6_hash_connect(struct inet_timewait_death_row *death_row, ...@@ -254,18 +254,18 @@ int inet6_hash_connect(struct inet_timewait_death_row *death_row,
int ret; int ret;
if (snum == 0) { if (snum == 0) {
const int low = sysctl_local_port_range[0]; int i, port, low, high, remaining;
const int high = sysctl_local_port_range[1];
const int range = high - low;
int i, port;
static u32 hint; static u32 hint;
const u32 offset = hint + inet6_sk_port_offset(sk); const u32 offset = hint + inet6_sk_port_offset(sk);
struct hlist_node *node; struct hlist_node *node;
struct inet_timewait_sock *tw = NULL; struct inet_timewait_sock *tw = NULL;
inet_get_local_port_range(&low, &high);
remaining = high - low;
local_bh_disable(); local_bh_disable();
for (i = 1; i <= range; i++) { for (i = 1; i <= remaining; i++) {
port = low + (i + offset) % range; port = low + (i + offset) % remaining;
head = &hinfo->bhash[inet_bhashfn(port, hinfo->bhash_size)]; head = &hinfo->bhash[inet_bhashfn(port, hinfo->bhash_size)];
spin_lock(&head->lock); spin_lock(&head->lock);
......
...@@ -5315,11 +5315,12 @@ static long sctp_get_port_local(struct sock *sk, union sctp_addr *addr) ...@@ -5315,11 +5315,12 @@ static long sctp_get_port_local(struct sock *sk, union sctp_addr *addr)
if (snum == 0) { if (snum == 0) {
/* Search for an available port. */ /* Search for an available port. */
unsigned int low = sysctl_local_port_range[0]; int low, high, remaining, index;
unsigned int high = sysctl_local_port_range[1]; unsigned int rover;
unsigned int remaining = (high - low) + 1;
unsigned int rover = net_random() % remaining + low; inet_get_local_port_range(&low, &high);
int index; remaining = (high - low) + 1;
rover = net_random() % remaining + low;
do { do {
rover++; rover++;
......
...@@ -47,7 +47,7 @@ ...@@ -47,7 +47,7 @@
#include <linux/netfilter_ipv6.h> #include <linux/netfilter_ipv6.h>
#include <linux/tty.h> #include <linux/tty.h>
#include <net/icmp.h> #include <net/icmp.h>
#include <net/ip.h> /* for sysctl_local_port_range[] */ #include <net/ip.h> /* for local_port_range[] */
#include <net/tcp.h> /* struct or_callable used in sock_rcv_skb */ #include <net/tcp.h> /* struct or_callable used in sock_rcv_skb */
#include <asm/uaccess.h> #include <asm/uaccess.h>
#include <asm/ioctls.h> #include <asm/ioctls.h>
...@@ -3232,8 +3232,6 @@ static int selinux_socket_post_create(struct socket *sock, int family, ...@@ -3232,8 +3232,6 @@ static int selinux_socket_post_create(struct socket *sock, int family,
/* Range of port numbers used to automatically bind. /* Range of port numbers used to automatically bind.
Need to determine whether we should perform a name_bind Need to determine whether we should perform a name_bind
permission check between the socket and the port number. */ permission check between the socket and the port number. */
#define ip_local_port_range_0 sysctl_local_port_range[0]
#define ip_local_port_range_1 sysctl_local_port_range[1]
static int selinux_socket_bind(struct socket *sock, struct sockaddr *address, int addrlen) static int selinux_socket_bind(struct socket *sock, struct sockaddr *address, int addrlen)
{ {
...@@ -3276,10 +3274,16 @@ static int selinux_socket_bind(struct socket *sock, struct sockaddr *address, in ...@@ -3276,10 +3274,16 @@ static int selinux_socket_bind(struct socket *sock, struct sockaddr *address, in
addrp = (char *)&addr6->sin6_addr.s6_addr; addrp = (char *)&addr6->sin6_addr.s6_addr;
} }
if (snum&&(snum < max(PROT_SOCK,ip_local_port_range_0) || if (snum) {
snum > ip_local_port_range_1)) { int low, high;
err = security_port_sid(sk->sk_family, sk->sk_type,
sk->sk_protocol, snum, &sid); inet_get_local_port_range(&low, &high);
if (snum < max(PROT_SOCK, low) || snum > high) {
err = security_port_sid(sk->sk_family,
sk->sk_type,
sk->sk_protocol, snum,
&sid);
if (err) if (err)
goto out; goto out;
AVC_AUDIT_DATA_INIT(&ad,NET); AVC_AUDIT_DATA_INIT(&ad,NET);
...@@ -3291,6 +3295,7 @@ static int selinux_socket_bind(struct socket *sock, struct sockaddr *address, in ...@@ -3291,6 +3295,7 @@ static int selinux_socket_bind(struct socket *sock, struct sockaddr *address, in
if (err) if (err)
goto out; goto out;
} }
}
switch(isec->sclass) { switch(isec->sclass) {
case SECCLASS_TCP_SOCKET: case SECCLASS_TCP_SOCKET:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment