Commit 3679d585 authored by Andrey Ignatov's avatar Andrey Ignatov Committed by Daniel Borkmann

net: Introduce __inet_bind() and __inet6_bind

Refactor `bind()` code to make it ready to be called from BPF helper
function `bpf_bind()` (will be added soon). Implementation of
`inet_bind()` and `inet6_bind()` is separated into `__inet_bind()` and
`__inet6_bind()` correspondingly. These function can be used from both
`sk_prot->bind` and `bpf_bind()` contexts.

New functions have two additional arguments.

`force_bind_address_no_port` forces binding to IP only w/o checking
`inet_sock.bind_address_no_port` field. It'll allow to bind local end of
a connection to desired IP in `bpf_bind()` w/o changing
`bind_address_no_port` field of a socket. It's useful since `bpf_bind()`
can return an error and we'd need to restore original value of
`bind_address_no_port` in that case if we changed this before calling to
the helper.

`with_lock` specifies whether to lock socket when working with `struct
sk` or not. The argument is set to `true` for `sk_prot->bind`, i.e. old
behavior is preserved. But it will be set to `false` for `bpf_bind()`
use-case. The reason is all call-sites, where `bpf_bind()` will be
called, already hold that socket lock.
Signed-off-by: default avatarAndrey Ignatov <rdna@fb.com>
Acked-by: default avatarAlexei Starovoitov <ast@kernel.org>
Signed-off-by: default avatarAlexei Starovoitov <ast@kernel.org>
Signed-off-by: default avatarDaniel Borkmann <daniel@iogearbox.net>
parent e50b0a6f
...@@ -32,6 +32,8 @@ int inet_shutdown(struct socket *sock, int how); ...@@ -32,6 +32,8 @@ int inet_shutdown(struct socket *sock, int how);
int inet_listen(struct socket *sock, int backlog); int inet_listen(struct socket *sock, int backlog);
void inet_sock_destruct(struct sock *sk); void inet_sock_destruct(struct sock *sk);
int inet_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len); int inet_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len);
int __inet_bind(struct sock *sk, struct sockaddr *uaddr, int addr_len,
bool force_bind_address_no_port, bool with_lock);
int inet_getname(struct socket *sock, struct sockaddr *uaddr, int inet_getname(struct socket *sock, struct sockaddr *uaddr,
int peer); int peer);
int inet_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg); int inet_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg);
......
...@@ -1066,6 +1066,8 @@ void ipv6_local_error(struct sock *sk, int err, struct flowi6 *fl6, u32 info); ...@@ -1066,6 +1066,8 @@ void ipv6_local_error(struct sock *sk, int err, struct flowi6 *fl6, u32 info);
void ipv6_local_rxpmtu(struct sock *sk, struct flowi6 *fl6, u32 mtu); void ipv6_local_rxpmtu(struct sock *sk, struct flowi6 *fl6, u32 mtu);
int inet6_release(struct socket *sock); int inet6_release(struct socket *sock);
int __inet6_bind(struct sock *sock, struct sockaddr *uaddr, int addr_len,
bool force_bind_address_no_port, bool with_lock);
int inet6_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len); int inet6_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len);
int inet6_getname(struct socket *sock, struct sockaddr *uaddr, int inet6_getname(struct socket *sock, struct sockaddr *uaddr,
int peer); int peer);
......
...@@ -432,30 +432,37 @@ EXPORT_SYMBOL(inet_release); ...@@ -432,30 +432,37 @@ EXPORT_SYMBOL(inet_release);
int inet_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len) int inet_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
{ {
struct sockaddr_in *addr = (struct sockaddr_in *)uaddr;
struct sock *sk = sock->sk; struct sock *sk = sock->sk;
struct inet_sock *inet = inet_sk(sk);
struct net *net = sock_net(sk);
unsigned short snum;
int chk_addr_ret;
u32 tb_id = RT_TABLE_LOCAL;
int err; int err;
/* If the socket has its own bind function then use it. (RAW) */ /* If the socket has its own bind function then use it. (RAW) */
if (sk->sk_prot->bind) { if (sk->sk_prot->bind) {
err = sk->sk_prot->bind(sk, uaddr, addr_len); return sk->sk_prot->bind(sk, uaddr, addr_len);
goto out;
} }
err = -EINVAL;
if (addr_len < sizeof(struct sockaddr_in)) if (addr_len < sizeof(struct sockaddr_in))
goto out; return -EINVAL;
/* BPF prog is run before any checks are done so that if the prog /* BPF prog is run before any checks are done so that if the prog
* changes context in a wrong way it will be caught. * changes context in a wrong way it will be caught.
*/ */
err = BPF_CGROUP_RUN_PROG_INET4_BIND(sk, uaddr); err = BPF_CGROUP_RUN_PROG_INET4_BIND(sk, uaddr);
if (err) if (err)
goto out; return err;
return __inet_bind(sk, uaddr, addr_len, false, true);
}
EXPORT_SYMBOL(inet_bind);
int __inet_bind(struct sock *sk, struct sockaddr *uaddr, int addr_len,
bool force_bind_address_no_port, bool with_lock)
{
struct sockaddr_in *addr = (struct sockaddr_in *)uaddr;
struct inet_sock *inet = inet_sk(sk);
struct net *net = sock_net(sk);
unsigned short snum;
int chk_addr_ret;
u32 tb_id = RT_TABLE_LOCAL;
int err;
if (addr->sin_family != AF_INET) { if (addr->sin_family != AF_INET) {
/* Compatibility games : accept AF_UNSPEC (mapped to AF_INET) /* Compatibility games : accept AF_UNSPEC (mapped to AF_INET)
...@@ -499,6 +506,7 @@ int inet_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len) ...@@ -499,6 +506,7 @@ int inet_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
* would be illegal to use them (multicast/broadcast) in * would be illegal to use them (multicast/broadcast) in
* which case the sending device address is used. * which case the sending device address is used.
*/ */
if (with_lock)
lock_sock(sk); lock_sock(sk);
/* Check these errors (active socket, double bind). */ /* Check these errors (active socket, double bind). */
...@@ -511,7 +519,8 @@ int inet_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len) ...@@ -511,7 +519,8 @@ int inet_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
inet->inet_saddr = 0; /* Use device */ inet->inet_saddr = 0; /* Use device */
/* Make sure we are allowed to bind here. */ /* Make sure we are allowed to bind here. */
if ((snum || !inet->bind_address_no_port) && if ((snum || !(inet->bind_address_no_port ||
force_bind_address_no_port)) &&
sk->sk_prot->get_port(sk, snum)) { sk->sk_prot->get_port(sk, snum)) {
inet->inet_saddr = inet->inet_rcv_saddr = 0; inet->inet_saddr = inet->inet_rcv_saddr = 0;
err = -EADDRINUSE; err = -EADDRINUSE;
...@@ -528,11 +537,11 @@ int inet_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len) ...@@ -528,11 +537,11 @@ int inet_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
sk_dst_reset(sk); sk_dst_reset(sk);
err = 0; err = 0;
out_release_sock: out_release_sock:
if (with_lock)
release_sock(sk); release_sock(sk);
out: out:
return err; return err;
} }
EXPORT_SYMBOL(inet_bind);
int inet_dgram_connect(struct socket *sock, struct sockaddr *uaddr, int inet_dgram_connect(struct socket *sock, struct sockaddr *uaddr,
int addr_len, int flags) int addr_len, int flags)
......
...@@ -277,15 +277,7 @@ static int inet6_create(struct net *net, struct socket *sock, int protocol, ...@@ -277,15 +277,7 @@ static int inet6_create(struct net *net, struct socket *sock, int protocol,
/* bind for INET6 API */ /* bind for INET6 API */
int inet6_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len) int inet6_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
{ {
struct sockaddr_in6 *addr = (struct sockaddr_in6 *)uaddr;
struct sock *sk = sock->sk; struct sock *sk = sock->sk;
struct inet_sock *inet = inet_sk(sk);
struct ipv6_pinfo *np = inet6_sk(sk);
struct net *net = sock_net(sk);
__be32 v4addr = 0;
unsigned short snum;
bool saved_ipv6only;
int addr_type = 0;
int err = 0; int err = 0;
/* If the socket has its own bind function then use it. */ /* If the socket has its own bind function then use it. */
...@@ -302,11 +294,28 @@ int inet6_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len) ...@@ -302,11 +294,28 @@ int inet6_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
if (err) if (err)
return err; return err;
return __inet6_bind(sk, uaddr, addr_len, false, true);
}
EXPORT_SYMBOL(inet6_bind);
int __inet6_bind(struct sock *sk, struct sockaddr *uaddr, int addr_len,
bool force_bind_address_no_port, bool with_lock)
{
struct sockaddr_in6 *addr = (struct sockaddr_in6 *)uaddr;
struct inet_sock *inet = inet_sk(sk);
struct ipv6_pinfo *np = inet6_sk(sk);
struct net *net = sock_net(sk);
__be32 v4addr = 0;
unsigned short snum;
bool saved_ipv6only;
int addr_type = 0;
int err = 0;
if (addr->sin6_family != AF_INET6) if (addr->sin6_family != AF_INET6)
return -EAFNOSUPPORT; return -EAFNOSUPPORT;
addr_type = ipv6_addr_type(&addr->sin6_addr); addr_type = ipv6_addr_type(&addr->sin6_addr);
if ((addr_type & IPV6_ADDR_MULTICAST) && sock->type == SOCK_STREAM) if ((addr_type & IPV6_ADDR_MULTICAST) && sk->sk_type == SOCK_STREAM)
return -EINVAL; return -EINVAL;
snum = ntohs(addr->sin6_port); snum = ntohs(addr->sin6_port);
...@@ -314,6 +323,7 @@ int inet6_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len) ...@@ -314,6 +323,7 @@ int inet6_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
!ns_capable(net->user_ns, CAP_NET_BIND_SERVICE)) !ns_capable(net->user_ns, CAP_NET_BIND_SERVICE))
return -EACCES; return -EACCES;
if (with_lock)
lock_sock(sk); lock_sock(sk);
/* Check these errors (active socket, double bind). */ /* Check these errors (active socket, double bind). */
...@@ -402,7 +412,8 @@ int inet6_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len) ...@@ -402,7 +412,8 @@ int inet6_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
sk->sk_ipv6only = 1; sk->sk_ipv6only = 1;
/* Make sure we are allowed to bind here. */ /* Make sure we are allowed to bind here. */
if ((snum || !inet->bind_address_no_port) && if ((snum || !(inet->bind_address_no_port ||
force_bind_address_no_port)) &&
sk->sk_prot->get_port(sk, snum)) { sk->sk_prot->get_port(sk, snum)) {
sk->sk_ipv6only = saved_ipv6only; sk->sk_ipv6only = saved_ipv6only;
inet_reset_saddr(sk); inet_reset_saddr(sk);
...@@ -418,13 +429,13 @@ int inet6_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len) ...@@ -418,13 +429,13 @@ int inet6_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
inet->inet_dport = 0; inet->inet_dport = 0;
inet->inet_daddr = 0; inet->inet_daddr = 0;
out: out:
if (with_lock)
release_sock(sk); release_sock(sk);
return err; return err;
out_unlock: out_unlock:
rcu_read_unlock(); rcu_read_unlock();
goto out; goto out;
} }
EXPORT_SYMBOL(inet6_bind);
int inet6_release(struct socket *sock) int inet6_release(struct socket *sock)
{ {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment