Commit 6a2c0962 authored by Stefano Garzarella's avatar Stefano Garzarella Committed by David S. Miller

vsock: prevent transport modules unloading

This patch adds 'module' member in the 'struct vsock_transport'
in order to get/put the transport module. This prevents the
module unloading while sockets are assigned to it.

We increase the module refcnt when a socket is assigned to a
transport, and we decrease the module refcnt when the socket
is destructed.
Reviewed-by: default avatarStefan Hajnoczi <stefanha@redhat.com>
Reviewed-by: default avatarJorgen Hansen <jhansen@vmware.com>
Signed-off-by: default avatarStefano Garzarella <sgarzare@redhat.com>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent b1bba80a
...@@ -386,6 +386,8 @@ static bool vhost_vsock_more_replies(struct vhost_vsock *vsock) ...@@ -386,6 +386,8 @@ static bool vhost_vsock_more_replies(struct vhost_vsock *vsock)
static struct virtio_transport vhost_transport = { static struct virtio_transport vhost_transport = {
.transport = { .transport = {
.module = THIS_MODULE,
.get_local_cid = vhost_transport_get_local_cid, .get_local_cid = vhost_transport_get_local_cid,
.init = virtio_transport_do_socket_init, .init = virtio_transport_do_socket_init,
......
...@@ -100,6 +100,8 @@ struct vsock_transport_send_notify_data { ...@@ -100,6 +100,8 @@ struct vsock_transport_send_notify_data {
#define VSOCK_TRANSPORT_F_DGRAM 0x00000004 #define VSOCK_TRANSPORT_F_DGRAM 0x00000004
struct vsock_transport { struct vsock_transport {
struct module *module;
/* Initialize/tear-down socket. */ /* Initialize/tear-down socket. */
int (*init)(struct vsock_sock *, struct vsock_sock *); int (*init)(struct vsock_sock *, struct vsock_sock *);
void (*destruct)(struct vsock_sock *); void (*destruct)(struct vsock_sock *);
......
...@@ -380,6 +380,16 @@ void vsock_enqueue_accept(struct sock *listener, struct sock *connected) ...@@ -380,6 +380,16 @@ void vsock_enqueue_accept(struct sock *listener, struct sock *connected)
} }
EXPORT_SYMBOL_GPL(vsock_enqueue_accept); EXPORT_SYMBOL_GPL(vsock_enqueue_accept);
static void vsock_deassign_transport(struct vsock_sock *vsk)
{
if (!vsk->transport)
return;
vsk->transport->destruct(vsk);
module_put(vsk->transport->module);
vsk->transport = NULL;
}
/* Assign a transport to a socket and call the .init transport callback. /* Assign a transport to a socket and call the .init transport callback.
* *
* Note: for stream socket this must be called when vsk->remote_addr is set * Note: for stream socket this must be called when vsk->remote_addr is set
...@@ -418,10 +428,13 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk) ...@@ -418,10 +428,13 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
return 0; return 0;
vsk->transport->release(vsk); vsk->transport->release(vsk);
vsk->transport->destruct(vsk); vsock_deassign_transport(vsk);
} }
if (!new_transport) /* We increase the module refcnt to prevent the transport unloading
* while there are open sockets assigned to it.
*/
if (!new_transport || !try_module_get(new_transport->module))
return -ENODEV; return -ENODEV;
vsk->transport = new_transport; vsk->transport = new_transport;
...@@ -741,8 +754,7 @@ static void vsock_sk_destruct(struct sock *sk) ...@@ -741,8 +754,7 @@ static void vsock_sk_destruct(struct sock *sk)
{ {
struct vsock_sock *vsk = vsock_sk(sk); struct vsock_sock *vsk = vsock_sk(sk);
if (vsk->transport) vsock_deassign_transport(vsk);
vsk->transport->destruct(vsk);
/* When clearing these addresses, there's no need to set the family and /* When clearing these addresses, there's no need to set the family and
* possibly register the address family with the kernel. * possibly register the address family with the kernel.
......
...@@ -857,6 +857,8 @@ int hvs_notify_send_post_enqueue(struct vsock_sock *vsk, ssize_t written, ...@@ -857,6 +857,8 @@ int hvs_notify_send_post_enqueue(struct vsock_sock *vsk, ssize_t written,
} }
static struct vsock_transport hvs_transport = { static struct vsock_transport hvs_transport = {
.module = THIS_MODULE,
.get_local_cid = hvs_get_local_cid, .get_local_cid = hvs_get_local_cid,
.init = hvs_sock_init, .init = hvs_sock_init,
......
...@@ -462,6 +462,8 @@ static void virtio_vsock_rx_done(struct virtqueue *vq) ...@@ -462,6 +462,8 @@ static void virtio_vsock_rx_done(struct virtqueue *vq)
static struct virtio_transport virtio_transport = { static struct virtio_transport virtio_transport = {
.transport = { .transport = {
.module = THIS_MODULE,
.get_local_cid = virtio_transport_get_local_cid, .get_local_cid = virtio_transport_get_local_cid,
.init = virtio_transport_do_socket_init, .init = virtio_transport_do_socket_init,
......
...@@ -2020,6 +2020,7 @@ static u32 vmci_transport_get_local_cid(void) ...@@ -2020,6 +2020,7 @@ static u32 vmci_transport_get_local_cid(void)
} }
static struct vsock_transport vmci_transport = { static struct vsock_transport vmci_transport = {
.module = THIS_MODULE,
.init = vmci_transport_socket_init, .init = vmci_transport_socket_init,
.destruct = vmci_transport_destruct, .destruct = vmci_transport_destruct,
.release = vmci_transport_release, .release = vmci_transport_release,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment