Commit 5f179793 authored by Linus Torvalds's avatar Linus Torvalds

Merge tag 'for_linus' of git://git.kernel.org/pub/scm/linux/kernel/git/mst/vhost

Pull vhost/virtio fixes from Michael Tsirkin:
 "A couple of last-minute fixes"

* tag 'for_linus' of git://git.kernel.org/pub/scm/linux/kernel/git/mst/vhost:
  vhost/vsock: fix use-after-free in network stack callers
  virtio/s390: fix race in ccw_io_helper()
  virtio/s390: avoid race on vcdev->config
  vhost/vsock: fix reset orphans race with close timeout
parents b8bf4692 834e772c
...@@ -56,6 +56,7 @@ struct virtio_ccw_device { ...@@ -56,6 +56,7 @@ struct virtio_ccw_device {
unsigned int revision; /* Transport revision */ unsigned int revision; /* Transport revision */
wait_queue_head_t wait_q; wait_queue_head_t wait_q;
spinlock_t lock; spinlock_t lock;
struct mutex io_lock; /* Serializes I/O requests */
struct list_head virtqueues; struct list_head virtqueues;
unsigned long indicators; unsigned long indicators;
unsigned long indicators2; unsigned long indicators2;
...@@ -296,6 +297,7 @@ static int ccw_io_helper(struct virtio_ccw_device *vcdev, ...@@ -296,6 +297,7 @@ static int ccw_io_helper(struct virtio_ccw_device *vcdev,
unsigned long flags; unsigned long flags;
int flag = intparm & VIRTIO_CCW_INTPARM_MASK; int flag = intparm & VIRTIO_CCW_INTPARM_MASK;
mutex_lock(&vcdev->io_lock);
do { do {
spin_lock_irqsave(get_ccwdev_lock(vcdev->cdev), flags); spin_lock_irqsave(get_ccwdev_lock(vcdev->cdev), flags);
ret = ccw_device_start(vcdev->cdev, ccw, intparm, 0, 0); ret = ccw_device_start(vcdev->cdev, ccw, intparm, 0, 0);
...@@ -308,7 +310,9 @@ static int ccw_io_helper(struct virtio_ccw_device *vcdev, ...@@ -308,7 +310,9 @@ static int ccw_io_helper(struct virtio_ccw_device *vcdev,
cpu_relax(); cpu_relax();
} while (ret == -EBUSY); } while (ret == -EBUSY);
wait_event(vcdev->wait_q, doing_io(vcdev, flag) == 0); wait_event(vcdev->wait_q, doing_io(vcdev, flag) == 0);
return ret ? ret : vcdev->err; ret = ret ? ret : vcdev->err;
mutex_unlock(&vcdev->io_lock);
return ret;
} }
static void virtio_ccw_drop_indicator(struct virtio_ccw_device *vcdev, static void virtio_ccw_drop_indicator(struct virtio_ccw_device *vcdev,
...@@ -828,6 +832,7 @@ static void virtio_ccw_get_config(struct virtio_device *vdev, ...@@ -828,6 +832,7 @@ static void virtio_ccw_get_config(struct virtio_device *vdev,
int ret; int ret;
struct ccw1 *ccw; struct ccw1 *ccw;
void *config_area; void *config_area;
unsigned long flags;
ccw = kzalloc(sizeof(*ccw), GFP_DMA | GFP_KERNEL); ccw = kzalloc(sizeof(*ccw), GFP_DMA | GFP_KERNEL);
if (!ccw) if (!ccw)
...@@ -846,11 +851,13 @@ static void virtio_ccw_get_config(struct virtio_device *vdev, ...@@ -846,11 +851,13 @@ static void virtio_ccw_get_config(struct virtio_device *vdev,
if (ret) if (ret)
goto out_free; goto out_free;
spin_lock_irqsave(&vcdev->lock, flags);
memcpy(vcdev->config, config_area, offset + len); memcpy(vcdev->config, config_area, offset + len);
if (buf)
memcpy(buf, &vcdev->config[offset], len);
if (vcdev->config_ready < offset + len) if (vcdev->config_ready < offset + len)
vcdev->config_ready = offset + len; vcdev->config_ready = offset + len;
spin_unlock_irqrestore(&vcdev->lock, flags);
if (buf)
memcpy(buf, config_area + offset, len);
out_free: out_free:
kfree(config_area); kfree(config_area);
...@@ -864,6 +871,7 @@ static void virtio_ccw_set_config(struct virtio_device *vdev, ...@@ -864,6 +871,7 @@ static void virtio_ccw_set_config(struct virtio_device *vdev,
struct virtio_ccw_device *vcdev = to_vc_device(vdev); struct virtio_ccw_device *vcdev = to_vc_device(vdev);
struct ccw1 *ccw; struct ccw1 *ccw;
void *config_area; void *config_area;
unsigned long flags;
ccw = kzalloc(sizeof(*ccw), GFP_DMA | GFP_KERNEL); ccw = kzalloc(sizeof(*ccw), GFP_DMA | GFP_KERNEL);
if (!ccw) if (!ccw)
...@@ -876,9 +884,11 @@ static void virtio_ccw_set_config(struct virtio_device *vdev, ...@@ -876,9 +884,11 @@ static void virtio_ccw_set_config(struct virtio_device *vdev,
/* Make sure we don't overwrite fields. */ /* Make sure we don't overwrite fields. */
if (vcdev->config_ready < offset) if (vcdev->config_ready < offset)
virtio_ccw_get_config(vdev, 0, NULL, offset); virtio_ccw_get_config(vdev, 0, NULL, offset);
spin_lock_irqsave(&vcdev->lock, flags);
memcpy(&vcdev->config[offset], buf, len); memcpy(&vcdev->config[offset], buf, len);
/* Write the config area to the host. */ /* Write the config area to the host. */
memcpy(config_area, vcdev->config, sizeof(vcdev->config)); memcpy(config_area, vcdev->config, sizeof(vcdev->config));
spin_unlock_irqrestore(&vcdev->lock, flags);
ccw->cmd_code = CCW_CMD_WRITE_CONF; ccw->cmd_code = CCW_CMD_WRITE_CONF;
ccw->flags = 0; ccw->flags = 0;
ccw->count = offset + len; ccw->count = offset + len;
...@@ -1247,6 +1257,7 @@ static int virtio_ccw_online(struct ccw_device *cdev) ...@@ -1247,6 +1257,7 @@ static int virtio_ccw_online(struct ccw_device *cdev)
init_waitqueue_head(&vcdev->wait_q); init_waitqueue_head(&vcdev->wait_q);
INIT_LIST_HEAD(&vcdev->virtqueues); INIT_LIST_HEAD(&vcdev->virtqueues);
spin_lock_init(&vcdev->lock); spin_lock_init(&vcdev->lock);
mutex_init(&vcdev->io_lock);
spin_lock_irqsave(get_ccwdev_lock(cdev), flags); spin_lock_irqsave(get_ccwdev_lock(cdev), flags);
dev_set_drvdata(&cdev->dev, vcdev); dev_set_drvdata(&cdev->dev, vcdev);
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include <net/sock.h> #include <net/sock.h>
#include <linux/virtio_vsock.h> #include <linux/virtio_vsock.h>
#include <linux/vhost.h> #include <linux/vhost.h>
#include <linux/hashtable.h>
#include <net/af_vsock.h> #include <net/af_vsock.h>
#include "vhost.h" #include "vhost.h"
...@@ -27,14 +28,14 @@ enum { ...@@ -27,14 +28,14 @@ enum {
/* Used to track all the vhost_vsock instances on the system. */ /* Used to track all the vhost_vsock instances on the system. */
static DEFINE_SPINLOCK(vhost_vsock_lock); static DEFINE_SPINLOCK(vhost_vsock_lock);
static LIST_HEAD(vhost_vsock_list); static DEFINE_READ_MOSTLY_HASHTABLE(vhost_vsock_hash, 8);
struct vhost_vsock { struct vhost_vsock {
struct vhost_dev dev; struct vhost_dev dev;
struct vhost_virtqueue vqs[2]; struct vhost_virtqueue vqs[2];
/* Link to global vhost_vsock_list, protected by vhost_vsock_lock */ /* Link to global vhost_vsock_hash, writes use vhost_vsock_lock */
struct list_head list; struct hlist_node hash;
struct vhost_work send_pkt_work; struct vhost_work send_pkt_work;
spinlock_t send_pkt_list_lock; spinlock_t send_pkt_list_lock;
...@@ -50,11 +51,14 @@ static u32 vhost_transport_get_local_cid(void) ...@@ -50,11 +51,14 @@ static u32 vhost_transport_get_local_cid(void)
return VHOST_VSOCK_DEFAULT_HOST_CID; return VHOST_VSOCK_DEFAULT_HOST_CID;
} }
static struct vhost_vsock *__vhost_vsock_get(u32 guest_cid) /* Callers that dereference the return value must hold vhost_vsock_lock or the
* RCU read lock.
*/
static struct vhost_vsock *vhost_vsock_get(u32 guest_cid)
{ {
struct vhost_vsock *vsock; struct vhost_vsock *vsock;
list_for_each_entry(vsock, &vhost_vsock_list, list) { hash_for_each_possible_rcu(vhost_vsock_hash, vsock, hash, guest_cid) {
u32 other_cid = vsock->guest_cid; u32 other_cid = vsock->guest_cid;
/* Skip instances that have no CID yet */ /* Skip instances that have no CID yet */
...@@ -69,17 +73,6 @@ static struct vhost_vsock *__vhost_vsock_get(u32 guest_cid) ...@@ -69,17 +73,6 @@ static struct vhost_vsock *__vhost_vsock_get(u32 guest_cid)
return NULL; return NULL;
} }
static struct vhost_vsock *vhost_vsock_get(u32 guest_cid)
{
struct vhost_vsock *vsock;
spin_lock_bh(&vhost_vsock_lock);
vsock = __vhost_vsock_get(guest_cid);
spin_unlock_bh(&vhost_vsock_lock);
return vsock;
}
static void static void
vhost_transport_do_send_pkt(struct vhost_vsock *vsock, vhost_transport_do_send_pkt(struct vhost_vsock *vsock,
struct vhost_virtqueue *vq) struct vhost_virtqueue *vq)
...@@ -210,9 +203,12 @@ vhost_transport_send_pkt(struct virtio_vsock_pkt *pkt) ...@@ -210,9 +203,12 @@ vhost_transport_send_pkt(struct virtio_vsock_pkt *pkt)
struct vhost_vsock *vsock; struct vhost_vsock *vsock;
int len = pkt->len; int len = pkt->len;
rcu_read_lock();
/* Find the vhost_vsock according to guest context id */ /* Find the vhost_vsock according to guest context id */
vsock = vhost_vsock_get(le64_to_cpu(pkt->hdr.dst_cid)); vsock = vhost_vsock_get(le64_to_cpu(pkt->hdr.dst_cid));
if (!vsock) { if (!vsock) {
rcu_read_unlock();
virtio_transport_free_pkt(pkt); virtio_transport_free_pkt(pkt);
return -ENODEV; return -ENODEV;
} }
...@@ -225,6 +221,8 @@ vhost_transport_send_pkt(struct virtio_vsock_pkt *pkt) ...@@ -225,6 +221,8 @@ vhost_transport_send_pkt(struct virtio_vsock_pkt *pkt)
spin_unlock_bh(&vsock->send_pkt_list_lock); spin_unlock_bh(&vsock->send_pkt_list_lock);
vhost_work_queue(&vsock->dev, &vsock->send_pkt_work); vhost_work_queue(&vsock->dev, &vsock->send_pkt_work);
rcu_read_unlock();
return len; return len;
} }
...@@ -234,12 +232,15 @@ vhost_transport_cancel_pkt(struct vsock_sock *vsk) ...@@ -234,12 +232,15 @@ vhost_transport_cancel_pkt(struct vsock_sock *vsk)
struct vhost_vsock *vsock; struct vhost_vsock *vsock;
struct virtio_vsock_pkt *pkt, *n; struct virtio_vsock_pkt *pkt, *n;
int cnt = 0; int cnt = 0;
int ret = -ENODEV;
LIST_HEAD(freeme); LIST_HEAD(freeme);
rcu_read_lock();
/* Find the vhost_vsock according to guest context id */ /* Find the vhost_vsock according to guest context id */
vsock = vhost_vsock_get(vsk->remote_addr.svm_cid); vsock = vhost_vsock_get(vsk->remote_addr.svm_cid);
if (!vsock) if (!vsock)
return -ENODEV; goto out;
spin_lock_bh(&vsock->send_pkt_list_lock); spin_lock_bh(&vsock->send_pkt_list_lock);
list_for_each_entry_safe(pkt, n, &vsock->send_pkt_list, list) { list_for_each_entry_safe(pkt, n, &vsock->send_pkt_list, list) {
...@@ -265,7 +266,10 @@ vhost_transport_cancel_pkt(struct vsock_sock *vsk) ...@@ -265,7 +266,10 @@ vhost_transport_cancel_pkt(struct vsock_sock *vsk)
vhost_poll_queue(&tx_vq->poll); vhost_poll_queue(&tx_vq->poll);
} }
return 0; ret = 0;
out:
rcu_read_unlock();
return ret;
} }
static struct virtio_vsock_pkt * static struct virtio_vsock_pkt *
...@@ -533,10 +537,6 @@ static int vhost_vsock_dev_open(struct inode *inode, struct file *file) ...@@ -533,10 +537,6 @@ static int vhost_vsock_dev_open(struct inode *inode, struct file *file)
spin_lock_init(&vsock->send_pkt_list_lock); spin_lock_init(&vsock->send_pkt_list_lock);
INIT_LIST_HEAD(&vsock->send_pkt_list); INIT_LIST_HEAD(&vsock->send_pkt_list);
vhost_work_init(&vsock->send_pkt_work, vhost_transport_send_pkt_work); vhost_work_init(&vsock->send_pkt_work, vhost_transport_send_pkt_work);
spin_lock_bh(&vhost_vsock_lock);
list_add_tail(&vsock->list, &vhost_vsock_list);
spin_unlock_bh(&vhost_vsock_lock);
return 0; return 0;
out: out:
...@@ -563,13 +563,21 @@ static void vhost_vsock_reset_orphans(struct sock *sk) ...@@ -563,13 +563,21 @@ static void vhost_vsock_reset_orphans(struct sock *sk)
* executing. * executing.
*/ */
if (!vhost_vsock_get(vsk->remote_addr.svm_cid)) { /* If the peer is still valid, no need to reset connection */
sock_set_flag(sk, SOCK_DONE); if (vhost_vsock_get(vsk->remote_addr.svm_cid))
vsk->peer_shutdown = SHUTDOWN_MASK; return;
sk->sk_state = SS_UNCONNECTED;
sk->sk_err = ECONNRESET; /* If the close timeout is pending, let it expire. This avoids races
sk->sk_error_report(sk); * with the timeout callback.
} */
if (vsk->close_work_scheduled)
return;
sock_set_flag(sk, SOCK_DONE);
vsk->peer_shutdown = SHUTDOWN_MASK;
sk->sk_state = SS_UNCONNECTED;
sk->sk_err = ECONNRESET;
sk->sk_error_report(sk);
} }
static int vhost_vsock_dev_release(struct inode *inode, struct file *file) static int vhost_vsock_dev_release(struct inode *inode, struct file *file)
...@@ -577,9 +585,13 @@ static int vhost_vsock_dev_release(struct inode *inode, struct file *file) ...@@ -577,9 +585,13 @@ static int vhost_vsock_dev_release(struct inode *inode, struct file *file)
struct vhost_vsock *vsock = file->private_data; struct vhost_vsock *vsock = file->private_data;
spin_lock_bh(&vhost_vsock_lock); spin_lock_bh(&vhost_vsock_lock);
list_del(&vsock->list); if (vsock->guest_cid)
hash_del_rcu(&vsock->hash);
spin_unlock_bh(&vhost_vsock_lock); spin_unlock_bh(&vhost_vsock_lock);
/* Wait for other CPUs to finish using vsock */
synchronize_rcu();
/* Iterating over all connections for all CIDs to find orphans is /* Iterating over all connections for all CIDs to find orphans is
* inefficient. Room for improvement here. */ * inefficient. Room for improvement here. */
vsock_for_each_connected_socket(vhost_vsock_reset_orphans); vsock_for_each_connected_socket(vhost_vsock_reset_orphans);
...@@ -620,12 +632,17 @@ static int vhost_vsock_set_cid(struct vhost_vsock *vsock, u64 guest_cid) ...@@ -620,12 +632,17 @@ static int vhost_vsock_set_cid(struct vhost_vsock *vsock, u64 guest_cid)
/* Refuse if CID is already in use */ /* Refuse if CID is already in use */
spin_lock_bh(&vhost_vsock_lock); spin_lock_bh(&vhost_vsock_lock);
other = __vhost_vsock_get(guest_cid); other = vhost_vsock_get(guest_cid);
if (other && other != vsock) { if (other && other != vsock) {
spin_unlock_bh(&vhost_vsock_lock); spin_unlock_bh(&vhost_vsock_lock);
return -EADDRINUSE; return -EADDRINUSE;
} }
if (vsock->guest_cid)
hash_del_rcu(&vsock->hash);
vsock->guest_cid = guest_cid; vsock->guest_cid = guest_cid;
hash_add_rcu(vhost_vsock_hash, &vsock->hash, guest_cid);
spin_unlock_bh(&vhost_vsock_lock); spin_unlock_bh(&vhost_vsock_lock);
return 0; return 0;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment