Commit d5a8ac28 authored by Sowmini Varadhan, committed by David S. Miller

RDS-TCP: Make RDS-TCP work correctly when it is set up in a netns other than init_net

Open the sockets calling sock_create_kern() with the correct struct net
pointer, and use that struct net pointer when verifying the
address passed to rds_bind().
Signed-off-by: Sowmini Varadhan <sowmini.varadhan@oracle.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
parent 1ebd08a7
...@@ -185,7 +185,8 @@ int rds_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len) ...@@ -185,7 +185,8 @@ int rds_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
ret = 0; ret = 0;
goto out; goto out;
} }
trans = rds_trans_get_preferred(sin->sin_addr.s_addr); trans = rds_trans_get_preferred(sock_net(sock->sk),
sin->sin_addr.s_addr);
if (!trans) { if (!trans) {
ret = -EADDRNOTAVAIL; ret = -EADDRNOTAVAIL;
rds_remove_bound(rs); rds_remove_bound(rs);
......
...@@ -117,7 +117,8 @@ static void rds_conn_reset(struct rds_connection *conn) ...@@ -117,7 +117,8 @@ static void rds_conn_reset(struct rds_connection *conn)
* For now they are not garbage collected once they're created. They * For now they are not garbage collected once they're created. They
* are torn down as the module is removed, if ever. * are torn down as the module is removed, if ever.
*/ */
static struct rds_connection *__rds_conn_create(__be32 laddr, __be32 faddr, static struct rds_connection *__rds_conn_create(struct net *net,
__be32 laddr, __be32 faddr,
struct rds_transport *trans, gfp_t gfp, struct rds_transport *trans, gfp_t gfp,
int is_outgoing) int is_outgoing)
{ {
...@@ -157,6 +158,7 @@ static struct rds_connection *__rds_conn_create(__be32 laddr, __be32 faddr, ...@@ -157,6 +158,7 @@ static struct rds_connection *__rds_conn_create(__be32 laddr, __be32 faddr,
conn->c_faddr = faddr; conn->c_faddr = faddr;
spin_lock_init(&conn->c_lock); spin_lock_init(&conn->c_lock);
conn->c_next_tx_seq = 1; conn->c_next_tx_seq = 1;
rds_conn_net_set(conn, net);
init_waitqueue_head(&conn->c_waitq); init_waitqueue_head(&conn->c_waitq);
INIT_LIST_HEAD(&conn->c_send_queue); INIT_LIST_HEAD(&conn->c_send_queue);
...@@ -174,7 +176,7 @@ static struct rds_connection *__rds_conn_create(__be32 laddr, __be32 faddr, ...@@ -174,7 +176,7 @@ static struct rds_connection *__rds_conn_create(__be32 laddr, __be32 faddr,
* can bind to the destination address then we'd rather the messages * can bind to the destination address then we'd rather the messages
* flow through loopback rather than either transport. * flow through loopback rather than either transport.
*/ */
loop_trans = rds_trans_get_preferred(faddr); loop_trans = rds_trans_get_preferred(net, faddr);
if (loop_trans) { if (loop_trans) {
rds_trans_put(loop_trans); rds_trans_put(loop_trans);
conn->c_loopback = 1; conn->c_loopback = 1;
...@@ -260,17 +262,19 @@ static struct rds_connection *__rds_conn_create(__be32 laddr, __be32 faddr, ...@@ -260,17 +262,19 @@ static struct rds_connection *__rds_conn_create(__be32 laddr, __be32 faddr,
return conn; return conn;
} }
struct rds_connection *rds_conn_create(__be32 laddr, __be32 faddr, struct rds_connection *rds_conn_create(struct net *net,
__be32 laddr, __be32 faddr,
struct rds_transport *trans, gfp_t gfp) struct rds_transport *trans, gfp_t gfp)
{ {
return __rds_conn_create(laddr, faddr, trans, gfp, 0); return __rds_conn_create(net, laddr, faddr, trans, gfp, 0);
} }
EXPORT_SYMBOL_GPL(rds_conn_create); EXPORT_SYMBOL_GPL(rds_conn_create);
struct rds_connection *rds_conn_create_outgoing(__be32 laddr, __be32 faddr, struct rds_connection *rds_conn_create_outgoing(struct net *net,
__be32 laddr, __be32 faddr,
struct rds_transport *trans, gfp_t gfp) struct rds_transport *trans, gfp_t gfp)
{ {
return __rds_conn_create(laddr, faddr, trans, gfp, 1); return __rds_conn_create(net, laddr, faddr, trans, gfp, 1);
} }
EXPORT_SYMBOL_GPL(rds_conn_create_outgoing); EXPORT_SYMBOL_GPL(rds_conn_create_outgoing);
......
...@@ -317,7 +317,7 @@ static void rds_ib_ic_info(struct socket *sock, unsigned int len, ...@@ -317,7 +317,7 @@ static void rds_ib_ic_info(struct socket *sock, unsigned int len,
* allowed to influence which paths have priority. We could call userspace * allowed to influence which paths have priority. We could call userspace
* asserting this policy "routing". * asserting this policy "routing".
*/ */
static int rds_ib_laddr_check(__be32 addr) static int rds_ib_laddr_check(struct net *net, __be32 addr)
{ {
int ret; int ret;
struct rdma_cm_id *cm_id; struct rdma_cm_id *cm_id;
......
...@@ -448,8 +448,9 @@ int rds_ib_cm_handle_connect(struct rdma_cm_id *cm_id, ...@@ -448,8 +448,9 @@ int rds_ib_cm_handle_connect(struct rdma_cm_id *cm_id,
(unsigned long long)be64_to_cpu(lguid), (unsigned long long)be64_to_cpu(lguid),
(unsigned long long)be64_to_cpu(fguid)); (unsigned long long)be64_to_cpu(fguid));
conn = rds_conn_create(dp->dp_daddr, dp->dp_saddr, &rds_ib_transport, /* RDS/IB is not currently netns aware, thus init_net */
GFP_KERNEL); conn = rds_conn_create(&init_net, dp->dp_daddr, dp->dp_saddr,
&rds_ib_transport, GFP_KERNEL);
if (IS_ERR(conn)) { if (IS_ERR(conn)) {
rdsdebug("rds_conn_create failed (%ld)\n", PTR_ERR(conn)); rdsdebug("rds_conn_create failed (%ld)\n", PTR_ERR(conn));
conn = NULL; conn = NULL;
......
...@@ -218,7 +218,7 @@ static void rds_iw_ic_info(struct socket *sock, unsigned int len, ...@@ -218,7 +218,7 @@ static void rds_iw_ic_info(struct socket *sock, unsigned int len,
* allowed to influence which paths have priority. We could call userspace * allowed to influence which paths have priority. We could call userspace
* asserting this policy "routing". * asserting this policy "routing".
*/ */
static int rds_iw_laddr_check(__be32 addr) static int rds_iw_laddr_check(struct net *net, __be32 addr)
{ {
int ret; int ret;
struct rdma_cm_id *cm_id; struct rdma_cm_id *cm_id;
......
...@@ -398,8 +398,9 @@ int rds_iw_cm_handle_connect(struct rdma_cm_id *cm_id, ...@@ -398,8 +398,9 @@ int rds_iw_cm_handle_connect(struct rdma_cm_id *cm_id,
&dp->dp_saddr, &dp->dp_daddr, &dp->dp_saddr, &dp->dp_daddr,
RDS_PROTOCOL_MAJOR(version), RDS_PROTOCOL_MINOR(version)); RDS_PROTOCOL_MAJOR(version), RDS_PROTOCOL_MINOR(version));
conn = rds_conn_create(dp->dp_daddr, dp->dp_saddr, &rds_iw_transport, /* RDS/IW is not currently netns aware, thus init_net */
GFP_KERNEL); conn = rds_conn_create(&init_net, dp->dp_daddr, dp->dp_saddr,
&rds_iw_transport, GFP_KERNEL);
if (IS_ERR(conn)) { if (IS_ERR(conn)) {
rdsdebug("rds_conn_create failed (%ld)\n", PTR_ERR(conn)); rdsdebug("rds_conn_create failed (%ld)\n", PTR_ERR(conn));
conn = NULL; conn = NULL;
......
...@@ -128,8 +128,21 @@ struct rds_connection { ...@@ -128,8 +128,21 @@ struct rds_connection {
/* Protocol version */ /* Protocol version */
unsigned int c_version; unsigned int c_version;
possible_net_t c_net;
}; };
/* Return the struct net pointer stored in conn->c_net, read via read_pnet(). */
static inline
struct net *rds_conn_net(struct rds_connection *conn)
{
return read_pnet(&conn->c_net);
}
/* Store the given struct net pointer in conn->c_net, written via write_pnet(). */
static inline
void rds_conn_net_set(struct rds_connection *conn, struct net *net)
{
write_pnet(&conn->c_net, net);
}
#define RDS_FLAG_CONG_BITMAP 0x01 #define RDS_FLAG_CONG_BITMAP 0x01
#define RDS_FLAG_ACK_REQUIRED 0x02 #define RDS_FLAG_ACK_REQUIRED 0x02
#define RDS_FLAG_RETRANSMITTED 0x04 #define RDS_FLAG_RETRANSMITTED 0x04
...@@ -417,7 +430,7 @@ struct rds_transport { ...@@ -417,7 +430,7 @@ struct rds_transport {
unsigned int t_prefer_loopback:1; unsigned int t_prefer_loopback:1;
unsigned int t_type; unsigned int t_type;
int (*laddr_check)(__be32 addr); int (*laddr_check)(struct net *net, __be32 addr);
int (*conn_alloc)(struct rds_connection *conn, gfp_t gfp); int (*conn_alloc)(struct rds_connection *conn, gfp_t gfp);
void (*conn_free)(void *data); void (*conn_free)(void *data);
int (*conn_connect)(struct rds_connection *conn); int (*conn_connect)(struct rds_connection *conn);
...@@ -608,9 +621,11 @@ struct rds_message *rds_cong_update_alloc(struct rds_connection *conn); ...@@ -608,9 +621,11 @@ struct rds_message *rds_cong_update_alloc(struct rds_connection *conn);
/* conn.c */ /* conn.c */
int rds_conn_init(void); int rds_conn_init(void);
void rds_conn_exit(void); void rds_conn_exit(void);
struct rds_connection *rds_conn_create(__be32 laddr, __be32 faddr, struct rds_connection *rds_conn_create(struct net *net,
__be32 laddr, __be32 faddr,
struct rds_transport *trans, gfp_t gfp); struct rds_transport *trans, gfp_t gfp);
struct rds_connection *rds_conn_create_outgoing(__be32 laddr, __be32 faddr, struct rds_connection *rds_conn_create_outgoing(struct net *net,
__be32 laddr, __be32 faddr,
struct rds_transport *trans, gfp_t gfp); struct rds_transport *trans, gfp_t gfp);
void rds_conn_shutdown(struct rds_connection *conn); void rds_conn_shutdown(struct rds_connection *conn);
void rds_conn_destroy(struct rds_connection *conn); void rds_conn_destroy(struct rds_connection *conn);
...@@ -795,7 +810,7 @@ void rds_connect_complete(struct rds_connection *conn); ...@@ -795,7 +810,7 @@ void rds_connect_complete(struct rds_connection *conn);
/* transport.c */ /* transport.c */
int rds_trans_register(struct rds_transport *trans); int rds_trans_register(struct rds_transport *trans);
void rds_trans_unregister(struct rds_transport *trans); void rds_trans_unregister(struct rds_transport *trans);
struct rds_transport *rds_trans_get_preferred(__be32 addr); struct rds_transport *rds_trans_get_preferred(struct net *net, __be32 addr);
void rds_trans_put(struct rds_transport *trans); void rds_trans_put(struct rds_transport *trans);
unsigned int rds_trans_stats_info_copy(struct rds_info_iterator *iter, unsigned int rds_trans_stats_info_copy(struct rds_info_iterator *iter,
unsigned int avail); unsigned int avail);
......
...@@ -1023,7 +1023,8 @@ int rds_sendmsg(struct socket *sock, struct msghdr *msg, size_t payload_len) ...@@ -1023,7 +1023,8 @@ int rds_sendmsg(struct socket *sock, struct msghdr *msg, size_t payload_len)
if (rs->rs_conn && rs->rs_conn->c_faddr == daddr) if (rs->rs_conn && rs->rs_conn->c_faddr == daddr)
conn = rs->rs_conn; conn = rs->rs_conn;
else { else {
conn = rds_conn_create_outgoing(rs->rs_bound_addr, daddr, conn = rds_conn_create_outgoing(sock_net(sock->sk),
rs->rs_bound_addr, daddr,
rs->rs_transport, rs->rs_transport,
sock->sk->sk_allocation); sock->sk->sk_allocation);
if (IS_ERR(conn)) { if (IS_ERR(conn)) {
......
...@@ -189,9 +189,9 @@ static void rds_tcp_tc_info(struct socket *sock, unsigned int len, ...@@ -189,9 +189,9 @@ static void rds_tcp_tc_info(struct socket *sock, unsigned int len,
spin_unlock_irqrestore(&rds_tcp_tc_list_lock, flags); spin_unlock_irqrestore(&rds_tcp_tc_list_lock, flags);
} }
static int rds_tcp_laddr_check(__be32 addr) static int rds_tcp_laddr_check(struct net *net, __be32 addr)
{ {
if (inet_addr_type(&init_net, addr) == RTN_LOCAL) if (inet_addr_type(net, addr) == RTN_LOCAL)
return 0; return 0;
return -EADDRNOTAVAIL; return -EADDRNOTAVAIL;
} }
......
...@@ -79,7 +79,8 @@ int rds_tcp_conn_connect(struct rds_connection *conn) ...@@ -79,7 +79,8 @@ int rds_tcp_conn_connect(struct rds_connection *conn)
struct sockaddr_in src, dest; struct sockaddr_in src, dest;
int ret; int ret;
ret = sock_create(PF_INET, SOCK_STREAM, IPPROTO_TCP, &sock); ret = sock_create_kern(rds_conn_net(conn), PF_INET,
SOCK_STREAM, IPPROTO_TCP, &sock);
if (ret < 0) if (ret < 0)
goto out; goto out;
......
...@@ -85,8 +85,9 @@ static int rds_tcp_accept_one(struct socket *sock) ...@@ -85,8 +85,9 @@ static int rds_tcp_accept_one(struct socket *sock)
struct inet_sock *inet; struct inet_sock *inet;
struct rds_tcp_connection *rs_tcp; struct rds_tcp_connection *rs_tcp;
ret = sock_create_lite(sock->sk->sk_family, sock->sk->sk_type, ret = sock_create_kern(sock_net(sock->sk), sock->sk->sk_family,
sock->sk->sk_protocol, &new_sock); sock->sk->sk_type, sock->sk->sk_protocol,
&new_sock);
if (ret) if (ret)
goto out; goto out;
...@@ -108,7 +109,8 @@ static int rds_tcp_accept_one(struct socket *sock) ...@@ -108,7 +109,8 @@ static int rds_tcp_accept_one(struct socket *sock)
&inet->inet_saddr, ntohs(inet->inet_sport), &inet->inet_saddr, ntohs(inet->inet_sport),
&inet->inet_daddr, ntohs(inet->inet_dport)); &inet->inet_daddr, ntohs(inet->inet_dport));
conn = rds_conn_create(inet->inet_saddr, inet->inet_daddr, conn = rds_conn_create(sock_net(sock->sk),
inet->inet_saddr, inet->inet_daddr,
&rds_tcp_transport, GFP_KERNEL); &rds_tcp_transport, GFP_KERNEL);
if (IS_ERR(conn)) { if (IS_ERR(conn)) {
ret = PTR_ERR(conn); ret = PTR_ERR(conn);
...@@ -187,7 +189,13 @@ int rds_tcp_listen_init(void) ...@@ -187,7 +189,13 @@ int rds_tcp_listen_init(void)
struct socket *sock = NULL; struct socket *sock = NULL;
int ret; int ret;
ret = sock_create(PF_INET, SOCK_STREAM, IPPROTO_TCP, &sock); /* MUST call sock_create_kern directly so that we avoid get_net()
* in sk_alloc(). Doing a get_net() will result in cleanup_net()
* never getting invoked, which will leave sock and other things
* in limbo.
*/
ret = sock_create_kern(current->nsproxy->net_ns, PF_INET,
SOCK_STREAM, IPPROTO_TCP, &sock);
if (ret < 0) if (ret < 0)
goto out; goto out;
......
...@@ -77,7 +77,7 @@ void rds_trans_put(struct rds_transport *trans) ...@@ -77,7 +77,7 @@ void rds_trans_put(struct rds_transport *trans)
module_put(trans->t_owner); module_put(trans->t_owner);
} }
struct rds_transport *rds_trans_get_preferred(__be32 addr) struct rds_transport *rds_trans_get_preferred(struct net *net, __be32 addr)
{ {
struct rds_transport *ret = NULL; struct rds_transport *ret = NULL;
struct rds_transport *trans; struct rds_transport *trans;
...@@ -90,7 +90,7 @@ struct rds_transport *rds_trans_get_preferred(__be32 addr) ...@@ -90,7 +90,7 @@ struct rds_transport *rds_trans_get_preferred(__be32 addr)
for (i = 0; i < RDS_TRANS_COUNT; i++) { for (i = 0; i < RDS_TRANS_COUNT; i++) {
trans = transports[i]; trans = transports[i];
if (trans && (trans->laddr_check(addr) == 0) && if (trans && (trans->laddr_check(net, addr) == 0) &&
(!trans->t_owner || try_module_get(trans->t_owner))) { (!trans->t_owner || try_module_get(trans->t_owner))) {
ret = trans; ret = trans;
break; break;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment