Commit 6b1e6cc7 authored by Jason Wang, committed by Michael S. Tsirkin

vhost: new device IOTLB API

This patch implements a device IOTLB for vhost. It can be used together with a
userspace (QEMU) implementation of DMA remapping to emulate an IOMMU for the
guest.

The idea is simple: cache translations in a software device IOTLB (implemented
as an interval tree) inside vhost, and use the vhost_net file descriptor to
report IOTLB misses and to receive IOTLB updates and invalidations. When vhost
hits an IOTLB miss, userspace can read the faulting address, size and access
type from the file. After userspace finishes the translation, it writes the
translated address back to the vhost_net file to update the device IOTLB.

When the device IOTLB is enabled by setting VIRTIO_F_IOMMU_PLATFORM, all vq
addresses set by ioctl are treated as IOVAs instead of virtual addresses, and
accesses can only be done through the IOTLB instead of by direct userspace
memory access. Before each round of vq processing, all vq metadata is
prefetched into the device IOTLB to make sure no translation fault happens
during vq processing.

In most cases virtqueues are contiguous even in virtual address space, so the
IOTLB translation for the virtqueue itself may make things a little slower.
A fast-path cache could be added on top of this patch.
Signed-off-by: Jason Wang <jasowang@redhat.com>
[mst: use virtio feature bit: VHOST_F_DEVICE_IOTLB -> VIRTIO_F_IOMMU_PLATFORM ]
[mst: fix build warnings ]
Signed-off-by: Michael S. Tsirkin <mst@redhat.com>
[ weiyj.lk: missing unlock on error ]
Signed-off-by: Wei Yongjun <weiyj.lk@gmail.com>
parent b2fbd8b0
@@ -61,7 +61,8 @@ MODULE_PARM_DESC(experimental_zcopytx, "Enable Zero Copy TX;"
enum {
VHOST_NET_FEATURES = VHOST_FEATURES |
(1ULL << VHOST_NET_F_VIRTIO_NET_HDR) |
(1ULL << VIRTIO_NET_F_MRG_RXBUF)
(1ULL << VIRTIO_NET_F_MRG_RXBUF) |
(1ULL << VIRTIO_F_IOMMU_PLATFORM)
};
enum {
@@ -308,7 +309,7 @@ static int vhost_net_tx_get_vq_desc(struct vhost_net *net,
{
unsigned long uninitialized_var(endtime);
int r = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
out_num, in_num, NULL, NULL);
if (r == vq->num && vq->busyloop_timeout) {
preempt_disable();
@@ -318,7 +319,7 @@ static int vhost_net_tx_get_vq_desc(struct vhost_net *net,
cpu_relax_lowlatency();
preempt_enable();
r = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
out_num, in_num, NULL, NULL);
}
return r;
@@ -351,6 +352,9 @@ static void handle_tx(struct vhost_net *net)
if (!sock)
goto out;
if (!vq_iotlb_prefetch(vq))
goto out;
vhost_disable_notify(&net->dev, vq);
hdr_size = nvq->vhost_hlen;
@@ -612,6 +616,10 @@ static void handle_rx(struct vhost_net *net)
sock = vq->private_data;
if (!sock)
goto out;
if (!vq_iotlb_prefetch(vq))
goto out;
vhost_disable_notify(&net->dev, vq);
vhost_hlen = nvq->vhost_hlen;
@@ -1080,10 +1088,14 @@ static int vhost_net_set_features(struct vhost_net *n, u64 features)
}
mutex_lock(&n->dev.mutex);
if ((features & (1 << VHOST_F_LOG_ALL)) &&
!vhost_log_access_ok(&n->dev)) {
!vhost_log_access_ok(&n->dev))
mutex_unlock(&n->dev.mutex);
goto out_unlock;
return -EFAULT;
if ((features & (1ULL << VIRTIO_F_IOMMU_PLATFORM))) {
if (vhost_init_device_iotlb(&n->dev, true))
goto out_unlock;
}
for (i = 0; i < VHOST_NET_VQ_MAX; ++i) {
mutex_lock(&n->vqs[i].vq.mutex);
n->vqs[i].vq.acked_features = features;
@@ -1093,6 +1105,10 @@ static int vhost_net_set_features(struct vhost_net *n, u64 features)
}
mutex_unlock(&n->dev.mutex);
return 0;
out_unlock:
mutex_unlock(&n->dev.mutex);
return -EFAULT;
}
static long vhost_net_set_owner(struct vhost_net *n)
@@ -1166,9 +1182,40 @@ static long vhost_net_compat_ioctl(struct file *f, unsigned int ioctl,
}
#endif
static ssize_t vhost_net_chr_read_iter(struct kiocb *iocb, struct iov_iter *to)
{
struct file *file = iocb->ki_filp;
struct vhost_net *n = file->private_data;
struct vhost_dev *dev = &n->dev;
int noblock = file->f_flags & O_NONBLOCK;
return vhost_chr_read_iter(dev, to, noblock);
}
static ssize_t vhost_net_chr_write_iter(struct kiocb *iocb,
struct iov_iter *from)
{
struct file *file = iocb->ki_filp;
struct vhost_net *n = file->private_data;
struct vhost_dev *dev = &n->dev;
return vhost_chr_write_iter(dev, from);
}
static unsigned int vhost_net_chr_poll(struct file *file, poll_table *wait)
{
struct vhost_net *n = file->private_data;
struct vhost_dev *dev = &n->dev;
return vhost_chr_poll(file, dev, wait);
}
static const struct file_operations vhost_net_fops = {
.owner = THIS_MODULE,
.release = vhost_net_release,
.read_iter = vhost_net_chr_read_iter,
.write_iter = vhost_net_chr_write_iter,
.poll = vhost_net_chr_poll,
.unlocked_ioctl = vhost_net_ioctl,
#ifdef CONFIG_COMPAT
.compat_ioctl = vhost_net_compat_ioctl,
...
@@ -35,6 +35,10 @@ static ushort max_mem_regions = 64;
module_param(max_mem_regions, ushort, 0444);
MODULE_PARM_DESC(max_mem_regions,
"Maximum number of memory regions in memory map. (default: 64)");
static int max_iotlb_entries = 2048;
module_param(max_iotlb_entries, int, 0444);
MODULE_PARM_DESC(max_iotlb_entries,
"Maximum number of iotlb entries. (default: 2048)");
enum {
VHOST_MEMORY_F_LOG = 0x1,
@@ -306,6 +310,7 @@ static void vhost_vq_reset(struct vhost_dev *dev,
vhost_disable_cross_endian(vq);
vq->busyloop_timeout = 0;
vq->umem = NULL;
vq->iotlb = NULL;
}
static int vhost_worker(void *data)
@@ -400,9 +405,14 @@ void vhost_dev_init(struct vhost_dev *dev,
dev->log_ctx = NULL;
dev->log_file = NULL;
dev->umem = NULL;
dev->iotlb = NULL;
dev->mm = NULL;
dev->worker = NULL;
init_llist_head(&dev->work_list);
init_waitqueue_head(&dev->wait);
INIT_LIST_HEAD(&dev->read_list);
INIT_LIST_HEAD(&dev->pending_list);
spin_lock_init(&dev->iotlb_lock);
for (i = 0; i < dev->nvqs; ++i) {
@@ -550,6 +560,15 @@ void vhost_dev_stop(struct vhost_dev *dev)
}
EXPORT_SYMBOL_GPL(vhost_dev_stop);
static void vhost_umem_free(struct vhost_umem *umem,
struct vhost_umem_node *node)
{
vhost_umem_interval_tree_remove(node, &umem->umem_tree);
list_del(&node->link);
kfree(node);
umem->numem--;
}
static void vhost_umem_clean(struct vhost_umem *umem)
{
struct vhost_umem_node *node, *tmp;
@@ -557,14 +576,31 @@ static void vhost_umem_clean(struct vhost_umem *umem)
if (!umem)
return;
list_for_each_entry_safe(node, tmp, &umem->umem_list, link) {
list_for_each_entry_safe(node, tmp, &umem->umem_list, link)
vhost_umem_interval_tree_remove(node, &umem->umem_tree);
vhost_umem_free(umem, node);
list_del(&node->link);
kvfree(node);
}
kvfree(umem);
}
static void vhost_clear_msg(struct vhost_dev *dev)
{
struct vhost_msg_node *node, *n;
spin_lock(&dev->iotlb_lock);
list_for_each_entry_safe(node, n, &dev->read_list, node) {
list_del(&node->node);
kfree(node);
}
list_for_each_entry_safe(node, n, &dev->pending_list, node) {
list_del(&node->node);
kfree(node);
}
spin_unlock(&dev->iotlb_lock);
}
/* Caller should have device mutex if and only if locked is set */
void vhost_dev_cleanup(struct vhost_dev *dev, bool locked)
{
@@ -593,6 +629,10 @@ void vhost_dev_cleanup(struct vhost_dev *dev, bool locked)
/* No one will access memory at this point */
vhost_umem_clean(dev->umem);
dev->umem = NULL;
vhost_umem_clean(dev->iotlb);
dev->iotlb = NULL;
vhost_clear_msg(dev);
wake_up_interruptible_poll(&dev->wait, POLLIN | POLLRDNORM);
WARN_ON(!llist_empty(&dev->work_list));
if (dev->worker) {
kthread_stop(dev->worker);
@@ -668,28 +708,381 @@ static int memory_access_ok(struct vhost_dev *d, struct vhost_umem *umem,
return 1;
}
#define vhost_put_user(vq, x, ptr) __put_user(x, ptr)
static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
struct iovec iov[], int iov_size, int access);
static int vhost_copy_to_user(struct vhost_virtqueue *vq, void *to,
const void *from, unsigned size)
{
return __copy_to_user(to, from, size);
int ret;
}
#define vhost_get_user(vq, x, ptr) __get_user(x, ptr)
if (!vq->iotlb)
return __copy_to_user(to, from, size);
else {
/* This function should be called after iotlb
* prefetch, which means we're sure that all vq
* could be access through iotlb. So -EAGAIN should
* not happen in this case.
*/
/* TODO: more fast path */
struct iov_iter t;
ret = translate_desc(vq, (u64)(uintptr_t)to, size, vq->iotlb_iov,
ARRAY_SIZE(vq->iotlb_iov),
VHOST_ACCESS_WO);
if (ret < 0)
goto out;
iov_iter_init(&t, WRITE, vq->iotlb_iov, ret, size);
ret = copy_to_iter(from, size, &t);
if (ret == size)
ret = 0;
}
out:
return ret;
}
static int vhost_copy_from_user(struct vhost_virtqueue *vq, void *to,
void *from, unsigned size)
{
return __copy_from_user(to, from, size);
int ret;
if (!vq->iotlb)
return __copy_from_user(to, from, size);
else {
/* This function should be called after iotlb
* prefetch, which means we're sure that vq
* could be access through iotlb. So -EAGAIN should
* not happen in this case.
*/
/* TODO: more fast path */
struct iov_iter f;
ret = translate_desc(vq, (u64)(uintptr_t)from, size, vq->iotlb_iov,
ARRAY_SIZE(vq->iotlb_iov),
VHOST_ACCESS_RO);
if (ret < 0) {
vq_err(vq, "IOTLB translation failure: uaddr "
"%p size 0x%llx\n", from,
(unsigned long long) size);
goto out;
}
iov_iter_init(&f, READ, vq->iotlb_iov, ret, size);
ret = copy_from_iter(to, size, &f);
if (ret == size)
ret = 0;
}
out:
return ret;
}
static void __user *__vhost_get_user(struct vhost_virtqueue *vq,
void *addr, unsigned size)
{
int ret;
/* This function should be called after iotlb
* prefetch, which means we're sure that vq
* could be access through iotlb. So -EAGAIN should
* not happen in this case.
*/
/* TODO: more fast path */
ret = translate_desc(vq, (u64)(uintptr_t)addr, size, vq->iotlb_iov,
ARRAY_SIZE(vq->iotlb_iov),
VHOST_ACCESS_RO);
if (ret < 0) {
vq_err(vq, "IOTLB translation failure: uaddr "
"%p size 0x%llx\n", addr,
(unsigned long long) size);
return NULL;
}
if (ret != 1 || vq->iotlb_iov[0].iov_len != size) {
vq_err(vq, "Non atomic userspace memory access: uaddr "
"%p size 0x%llx\n", addr,
(unsigned long long) size);
return NULL;
}
return vq->iotlb_iov[0].iov_base;
}
#define vhost_put_user(vq, x, ptr) \
({ \
int ret = -EFAULT; \
if (!vq->iotlb) { \
ret = __put_user(x, ptr); \
} else { \
__typeof__(ptr) to = \
(__typeof__(ptr)) __vhost_get_user(vq, ptr, sizeof(*ptr)); \
if (to != NULL) \
ret = __put_user(x, to); \
else \
ret = -EFAULT; \
} \
ret; \
})
#define vhost_get_user(vq, x, ptr) \
({ \
int ret; \
if (!vq->iotlb) { \
ret = __get_user(x, ptr); \
} else { \
__typeof__(ptr) from = \
(__typeof__(ptr)) __vhost_get_user(vq, ptr, sizeof(*ptr)); \
if (from != NULL) \
ret = __get_user(x, from); \
else \
ret = -EFAULT; \
} \
ret; \
})
static void vhost_dev_lock_vqs(struct vhost_dev *d)
{
int i = 0;
for (i = 0; i < d->nvqs; ++i)
mutex_lock(&d->vqs[i]->mutex);
}
static void vhost_dev_unlock_vqs(struct vhost_dev *d)
{
int i = 0;
for (i = 0; i < d->nvqs; ++i)
mutex_unlock(&d->vqs[i]->mutex);
}
static int vhost_new_umem_range(struct vhost_umem *umem,
u64 start, u64 size, u64 end,
u64 userspace_addr, int perm)
{
struct vhost_umem_node *tmp, *node = kmalloc(sizeof(*node), GFP_ATOMIC);
if (!node)
return -ENOMEM;
if (umem->numem == max_iotlb_entries) {
tmp = list_first_entry(&umem->umem_list, typeof(*tmp), link);
vhost_umem_free(umem, tmp);
}
node->start = start;
node->size = size;
node->last = end;
node->userspace_addr = userspace_addr;
node->perm = perm;
INIT_LIST_HEAD(&node->link);
list_add_tail(&node->link, &umem->umem_list);
vhost_umem_interval_tree_insert(node, &umem->umem_tree);
umem->numem++;
return 0;
}
static void vhost_del_umem_range(struct vhost_umem *umem,
u64 start, u64 end)
{
struct vhost_umem_node *node;
while ((node = vhost_umem_interval_tree_iter_first(&umem->umem_tree,
start, end)))
vhost_umem_free(umem, node);
}
static void vhost_iotlb_notify_vq(struct vhost_dev *d,
struct vhost_iotlb_msg *msg)
{
struct vhost_msg_node *node, *n;
spin_lock(&d->iotlb_lock);
list_for_each_entry_safe(node, n, &d->pending_list, node) {
struct vhost_iotlb_msg *vq_msg = &node->msg.iotlb;
if (msg->iova <= vq_msg->iova &&
msg->iova + msg->size - 1 > vq_msg->iova &&
vq_msg->type == VHOST_IOTLB_MISS) {
vhost_poll_queue(&node->vq->poll);
list_del(&node->node);
kfree(node);
}
}
spin_unlock(&d->iotlb_lock);
}
static int umem_access_ok(u64 uaddr, u64 size, int access)
{
unsigned long a = uaddr;
if ((access & VHOST_ACCESS_RO) &&
!access_ok(VERIFY_READ, (void __user *)a, size))
return -EFAULT;
if ((access & VHOST_ACCESS_WO) &&
!access_ok(VERIFY_WRITE, (void __user *)a, size))
return -EFAULT;
return 0;
}
int vhost_process_iotlb_msg(struct vhost_dev *dev,
struct vhost_iotlb_msg *msg)
{
int ret = 0;
vhost_dev_lock_vqs(dev);
switch (msg->type) {
case VHOST_IOTLB_UPDATE:
if (!dev->iotlb) {
ret = -EFAULT;
break;
}
if (umem_access_ok(msg->uaddr, msg->size, msg->perm)) {
ret = -EFAULT;
break;
}
if (vhost_new_umem_range(dev->iotlb, msg->iova, msg->size,
msg->iova + msg->size - 1,
msg->uaddr, msg->perm)) {
ret = -ENOMEM;
break;
}
vhost_iotlb_notify_vq(dev, msg);
break;
case VHOST_IOTLB_INVALIDATE:
vhost_del_umem_range(dev->iotlb, msg->iova,
msg->iova + msg->size - 1);
break;
default:
ret = -EINVAL;
break;
}
vhost_dev_unlock_vqs(dev);
return ret;
}
ssize_t vhost_chr_write_iter(struct vhost_dev *dev,
struct iov_iter *from)
{
struct vhost_msg_node node;
unsigned size = sizeof(struct vhost_msg);
size_t ret;
int err;
if (iov_iter_count(from) < size)
return 0;
ret = copy_from_iter(&node.msg, size, from);
if (ret != size)
goto done;
switch (node.msg.type) {
case VHOST_IOTLB_MSG:
err = vhost_process_iotlb_msg(dev, &node.msg.iotlb);
if (err)
ret = err;
break;
default:
ret = -EINVAL;
break;
}
done:
return ret;
}
EXPORT_SYMBOL(vhost_chr_write_iter);
unsigned int vhost_chr_poll(struct file *file, struct vhost_dev *dev,
poll_table *wait)
{
unsigned int mask = 0;
poll_wait(file, &dev->wait, wait);
if (!list_empty(&dev->read_list))
mask |= POLLIN | POLLRDNORM;
return mask;
}
EXPORT_SYMBOL(vhost_chr_poll);
ssize_t vhost_chr_read_iter(struct vhost_dev *dev, struct iov_iter *to,
int noblock)
{
DEFINE_WAIT(wait);
struct vhost_msg_node *node;
ssize_t ret = 0;
unsigned size = sizeof(struct vhost_msg);
if (iov_iter_count(to) < size)
return 0;
while (1) {
if (!noblock)
prepare_to_wait(&dev->wait, &wait,
TASK_INTERRUPTIBLE);
node = vhost_dequeue_msg(dev, &dev->read_list);
if (node)
break;
if (noblock) {
ret = -EAGAIN;
break;
}
if (signal_pending(current)) {
ret = -ERESTARTSYS;
break;
}
if (!dev->iotlb) {
ret = -EBADFD;
break;
}
schedule();
}
if (!noblock)
finish_wait(&dev->wait, &wait);
if (node) {
ret = copy_to_iter(&node->msg, size, to);
if (ret != size || node->msg.type != VHOST_IOTLB_MISS) {
kfree(node);
return ret;
}
vhost_enqueue_msg(dev, &dev->pending_list, node);
}
return ret;
}
EXPORT_SYMBOL_GPL(vhost_chr_read_iter);
static int vhost_iotlb_miss(struct vhost_virtqueue *vq, u64 iova, int access)
{
struct vhost_dev *dev = vq->dev;
struct vhost_msg_node *node;
struct vhost_iotlb_msg *msg;
node = vhost_new_msg(vq, VHOST_IOTLB_MISS);
if (!node)
return -ENOMEM;
msg = &node->msg.iotlb;
msg->type = VHOST_IOTLB_MISS;
msg->iova = iova;
msg->perm = access;
vhost_enqueue_msg(dev, &dev->read_list, node);
return 0;
}
static int vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
struct vring_desc __user *desc,
struct vring_avail __user *avail,
struct vring_used __user *used)
{
size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
return access_ok(VERIFY_READ, desc, num * sizeof *desc) &&
access_ok(VERIFY_READ, avail,
sizeof *avail + num * sizeof *avail->ring + s) &&
@@ -697,6 +1090,54 @@ static int vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
sizeof *used + num * sizeof *used->ring + s);
}
static int iotlb_access_ok(struct vhost_virtqueue *vq,
int access, u64 addr, u64 len)
{
const struct vhost_umem_node *node;
struct vhost_umem *umem = vq->iotlb;
u64 s = 0, size;
while (len > s) {
node = vhost_umem_interval_tree_iter_first(&umem->umem_tree,
addr,
addr + len - 1);
if (node == NULL || node->start > addr) {
vhost_iotlb_miss(vq, addr, access);
return false;
} else if (!(node->perm & access)) {
/* Report the possible access violation by
* request another translation from userspace.
*/
return false;
}
size = node->size - addr + node->start;
s += size;
addr += size;
}
return true;
}
int vq_iotlb_prefetch(struct vhost_virtqueue *vq)
{
size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
unsigned int num = vq->num;
if (!vq->iotlb)
return 1;
return iotlb_access_ok(vq, VHOST_ACCESS_RO, (u64)(uintptr_t)vq->desc,
num * sizeof *vq->desc) &&
iotlb_access_ok(vq, VHOST_ACCESS_RO, (u64)(uintptr_t)vq->avail,
sizeof *vq->avail +
num * sizeof *vq->avail->ring + s) &&
iotlb_access_ok(vq, VHOST_ACCESS_WO, (u64)(uintptr_t)vq->used,
sizeof *vq->used +
num * sizeof *vq->used->ring + s);
}
EXPORT_SYMBOL_GPL(vq_iotlb_prefetch);
/* Can we log writes? */
/* Caller should have device mutex but not vq mutex */
int vhost_log_access_ok(struct vhost_dev *dev)
@@ -723,16 +1164,35 @@ static int vq_log_access_ok(struct vhost_virtqueue *vq,
/* Caller should have vq mutex and device mutex */
int vhost_vq_access_ok(struct vhost_virtqueue *vq)
{
if (vq->iotlb) {
/* When device IOTLB was used, the access validation
* will be validated during prefetching.
*/
return 1;
}
return vq_access_ok(vq, vq->num, vq->desc, vq->avail, vq->used) &&
vq_log_access_ok(vq, vq->log_base);
}
EXPORT_SYMBOL_GPL(vhost_vq_access_ok);
static struct vhost_umem *vhost_umem_alloc(void)
{
struct vhost_umem *umem = vhost_kvzalloc(sizeof(*umem));
if (!umem)
return NULL;
umem->umem_tree = RB_ROOT;
umem->numem = 0;
INIT_LIST_HEAD(&umem->umem_list);
return umem;
}
static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
{
struct vhost_memory mem, *newmem;
struct vhost_memory_region *region;
struct vhost_umem_node *node;
struct vhost_umem *newumem, *oldumem;
unsigned long size = offsetof(struct vhost_memory, regions);
int i;
@@ -754,28 +1214,23 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
return -EFAULT;
}
newumem = vhost_kvzalloc(sizeof(*newumem));
newumem = vhost_umem_alloc();
if (!newumem) {
kvfree(newmem);
return -ENOMEM;
}
newumem->umem_tree = RB_ROOT;
INIT_LIST_HEAD(&newumem->umem_list);
for (region = newmem->regions;
region < newmem->regions + mem.nregions;
region++) {
node = vhost_kvzalloc(sizeof(*node));
if (vhost_new_umem_range(newumem,
if (!node)
region->guest_phys_addr,
region->memory_size,
region->guest_phys_addr +
region->memory_size - 1,
region->userspace_addr,
VHOST_ACCESS_RW))
goto err;
node->start = region->guest_phys_addr;
node->size = region->memory_size;
node->last = node->start + node->size - 1;
node->userspace_addr = region->userspace_addr;
INIT_LIST_HEAD(&node->link);
list_add_tail(&node->link, &newumem->umem_list);
vhost_umem_interval_tree_insert(node, &newumem->umem_tree);
}
if (!memory_access_ok(d, newumem, 0))
@@ -1019,6 +1474,30 @@ long vhost_vring_ioctl(struct vhost_dev *d, int ioctl, void __user *argp)
}
EXPORT_SYMBOL_GPL(vhost_vring_ioctl);
int vhost_init_device_iotlb(struct vhost_dev *d, bool enabled)
{
struct vhost_umem *niotlb, *oiotlb;
int i;
niotlb = vhost_umem_alloc();
if (!niotlb)
return -ENOMEM;
oiotlb = d->iotlb;
d->iotlb = niotlb;
for (i = 0; i < d->nvqs; ++i) {
mutex_lock(&d->vqs[i]->mutex);
d->vqs[i]->iotlb = niotlb;
mutex_unlock(&d->vqs[i]->mutex);
}
vhost_umem_clean(oiotlb);
return 0;
}
EXPORT_SYMBOL_GPL(vhost_init_device_iotlb);
/* Caller must have device mutex */
long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp)
{
@@ -1233,15 +1712,20 @@ int vhost_vq_init_access(struct vhost_virtqueue *vq)
if (r)
goto err;
vq->signalled_used_valid = false;
if (!access_ok(VERIFY_READ, &vq->used->idx, sizeof vq->used->idx)) {
if (!vq->iotlb &&
!access_ok(VERIFY_READ, &vq->used->idx, sizeof vq->used->idx)) {
r = -EFAULT;
goto err;
}
r = vhost_get_user(vq, last_used_idx, &vq->used->idx);
if (r)
if (r) {
vq_err(vq, "Can't access used idx at %p\n",
&vq->used->idx);
goto err;
}
vq->last_used_idx = vhost16_to_cpu(vq, last_used_idx);
return 0;
err:
vq->is_le = is_le;
return r;
@@ -1249,10 +1733,11 @@ int vhost_vq_init_access(struct vhost_virtqueue *vq)
EXPORT_SYMBOL_GPL(vhost_vq_init_access);
static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
struct iovec iov[], int iov_size)
struct iovec iov[], int iov_size, int access)
{
const struct vhost_umem_node *node;
struct vhost_umem *umem = vq->umem;
struct vhost_dev *dev = vq->dev;
struct vhost_umem *umem = dev->iotlb ? dev->iotlb : dev->umem;
struct iovec *_iov;
u64 s = 0;
int ret = 0;
@@ -1263,12 +1748,21 @@ static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
ret = -ENOBUFS;
break;
}
node = vhost_umem_interval_tree_iter_first(&umem->umem_tree,
addr, addr + len - 1);
if (node == NULL || node->start > addr) {
ret = -EFAULT;
if (umem != dev->iotlb) {
ret = -EFAULT;
break;
}
ret = -EAGAIN;
break;
} else if (!(node->perm & access)) {
ret = -EPERM;
break;
}
_iov = iov + ret;
size = node->size - addr + node->start;
_iov->iov_len = min((u64)len - s, size);
@@ -1279,6 +1773,8 @@ static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
++ret;
}
if (ret == -EAGAIN)
vhost_iotlb_miss(vq, addr, access);
return ret;
}
@@ -1313,7 +1809,7 @@ static int get_indirect(struct vhost_virtqueue *vq,
unsigned int i = 0, count, found = 0;
u32 len = vhost32_to_cpu(vq, indirect->len);
struct iov_iter from;
int ret;
int ret, access;
/* Sanity check */
if (unlikely(len % sizeof desc)) {
@@ -1325,9 +1821,10 @@ static int get_indirect(struct vhost_virtqueue *vq,
}
ret = translate_desc(vq, vhost64_to_cpu(vq, indirect->addr), len, vq->indirect,
UIO_MAXIOV);
UIO_MAXIOV, VHOST_ACCESS_RO);
if (unlikely(ret < 0)) {
vq_err(vq, "Translation failure %d in indirect.\n", ret);
if (ret != -EAGAIN)
vq_err(vq, "Translation failure %d in indirect.\n", ret);
return ret;
}
iov_iter_init(&from, READ, vq->indirect, ret, len);
@@ -1365,16 +1862,22 @@ static int get_indirect(struct vhost_virtqueue *vq,
return -EINVAL;
}
if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE))
access = VHOST_ACCESS_WO;
else
access = VHOST_ACCESS_RO;
ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr),
vhost32_to_cpu(vq, desc.len), iov + iov_count,
iov_size - iov_count);
iov_size - iov_count, access);
if (unlikely(ret < 0)) {
vq_err(vq, "Translation failure %d indirect idx %d\n",
if (ret != -EAGAIN)
ret, i);
vq_err(vq, "Translation failure %d indirect idx %d\n",
ret, i);
return ret;
}
/* If this is an input descriptor, increment that count. */
if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE)) {
if (access == VHOST_ACCESS_WO) {
*in_num += ret;
if (unlikely(log)) {
log[*log_num].addr = vhost64_to_cpu(vq, desc.addr);
@@ -1413,7 +1916,7 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
u16 last_avail_idx;
__virtio16 avail_idx;
__virtio16 ring_head;
int ret;
int ret, access;
/* Check it isn't doing very strange things with descriptor numbers. */
last_avail_idx = vq->last_avail_idx;
@@ -1487,22 +1990,28 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
out_num, in_num,
log, log_num, &desc);
if (unlikely(ret < 0)) {
vq_err(vq, "Failure detected "
if (ret != -EAGAIN)
"in indirect descriptor at idx %d\n", i);
vq_err(vq, "Failure detected "
"in indirect descriptor at idx %d\n", i);
return ret;
}
continue;
}
if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE))
access = VHOST_ACCESS_WO;
else
access = VHOST_ACCESS_RO;
ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr),
vhost32_to_cpu(vq, desc.len), iov + iov_count,
iov_size - iov_count);
iov_size - iov_count, access);
if (unlikely(ret < 0)) {
vq_err(vq, "Translation failure %d descriptor idx %d\n",
if (ret != -EAGAIN)
ret, i);
vq_err(vq, "Translation failure %d descriptor idx %d\n",
ret, i);
return ret;
}
if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE)) {
if (access == VHOST_ACCESS_WO) {
/* If this is an input descriptor,
* increment that count. */
*in_num += ret;
@@ -1768,6 +2277,47 @@ void vhost_disable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
}
EXPORT_SYMBOL_GPL(vhost_disable_notify);
/* Create a new message. */
struct vhost_msg_node *vhost_new_msg(struct vhost_virtqueue *vq, int type)
{
struct vhost_msg_node *node = kmalloc(sizeof *node, GFP_KERNEL);
if (!node)
return NULL;
node->vq = vq;
node->msg.type = type;
return node;
}
EXPORT_SYMBOL_GPL(vhost_new_msg);
void vhost_enqueue_msg(struct vhost_dev *dev, struct list_head *head,
struct vhost_msg_node *node)
{
spin_lock(&dev->iotlb_lock);
list_add_tail(&node->node, head);
spin_unlock(&dev->iotlb_lock);
wake_up_interruptible_poll(&dev->wait, POLLIN | POLLRDNORM);
}
EXPORT_SYMBOL_GPL(vhost_enqueue_msg);
struct vhost_msg_node *vhost_dequeue_msg(struct vhost_dev *dev,
struct list_head *head)
{
struct vhost_msg_node *node = NULL;
spin_lock(&dev->iotlb_lock);
if (!list_empty(head)) {
node = list_first_entry(head, struct vhost_msg_node,
node);
list_del(&node->node);
}
spin_unlock(&dev->iotlb_lock);
return node;
}
EXPORT_SYMBOL_GPL(vhost_dequeue_msg);
static int __init vhost_init(void)
{
return 0;
...
@@ -65,13 +65,15 @@ struct vhost_umem_node {
__u64 last;
__u64 size;
__u64 userspace_addr;
__u64 flags_padding;
__u32 perm;
__u32 flags_padding;
__u64 __subtree_last;
};
struct vhost_umem {
struct rb_root umem_tree;
struct list_head umem_list;
int numem;
};
/* The virtqueue structure describes a queue attached to a device. */
@@ -119,10 +121,12 @@ struct vhost_virtqueue {
u64 log_addr;
struct iovec iov[UIO_MAXIOV];
struct iovec iotlb_iov[64];
struct iovec *indirect;
struct vring_used_elem *heads;
/* Protected by virtqueue mutex. */
struct vhost_umem *umem;
struct vhost_umem *iotlb;
void *private_data;
u64 acked_features;
/* Log write descriptors */
@@ -139,6 +143,12 @@ struct vhost_virtqueue {
u32 busyloop_timeout;
};
struct vhost_msg_node {
struct vhost_msg msg;
struct vhost_virtqueue *vq;
struct list_head node;
};
struct vhost_dev {
struct mm_struct *mm;
struct mutex mutex;
@@ -149,6 +159,11 @@ struct vhost_dev {
struct llist_head work_list;
struct task_struct *worker;
struct vhost_umem *umem;
struct vhost_umem *iotlb;
spinlock_t iotlb_lock;
struct list_head read_list;
struct list_head pending_list;
wait_queue_head_t wait;
};
void vhost_dev_init(struct vhost_dev *, struct vhost_virtqueue **vqs, int nvqs);
@@ -185,6 +200,21 @@ bool vhost_enable_notify(struct vhost_dev *, struct vhost_virtqueue *);
int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log,
unsigned int log_num, u64 len);
int vq_iotlb_prefetch(struct vhost_virtqueue *vq);
struct vhost_msg_node *vhost_new_msg(struct vhost_virtqueue *vq, int type);
void vhost_enqueue_msg(struct vhost_dev *dev,
struct list_head *head,
struct vhost_msg_node *node);
struct vhost_msg_node *vhost_dequeue_msg(struct vhost_dev *dev,
struct list_head *head);
unsigned int vhost_chr_poll(struct file *file, struct vhost_dev *dev,
poll_table *wait);
ssize_t vhost_chr_read_iter(struct vhost_dev *dev, struct iov_iter *to,
int noblock);
ssize_t vhost_chr_write_iter(struct vhost_dev *dev,
struct iov_iter *from);
int vhost_init_device_iotlb(struct vhost_dev *d, bool enabled);
#define vq_err(vq, fmt, ...) do { \
pr_debug(pr_fmt(fmt), ##__VA_ARGS__); \
...
@@ -47,6 +47,32 @@ struct vhost_vring_addr {
__u64 log_guest_addr;
};
/* no alignment requirement */
struct vhost_iotlb_msg {
__u64 iova;
__u64 size;
__u64 uaddr;
#define VHOST_ACCESS_RO 0x1
#define VHOST_ACCESS_WO 0x2
#define VHOST_ACCESS_RW 0x3
__u8 perm;
#define VHOST_IOTLB_MISS 1
#define VHOST_IOTLB_UPDATE 2
#define VHOST_IOTLB_INVALIDATE 3
#define VHOST_IOTLB_ACCESS_FAIL 4
__u8 type;
};
#define VHOST_IOTLB_MSG 0x1
struct vhost_msg {
int type;
union {
struct vhost_iotlb_msg iotlb;
__u8 padding[64];
};
};
struct vhost_memory_region {
__u64 guest_phys_addr;
__u64 memory_size; /* bytes */
@@ -146,6 +172,8 @@ struct vhost_memory {
#define VHOST_F_LOG_ALL 26
/* vhost-net should add virtio_net_hdr for RX, and strip for TX packets. */
#define VHOST_NET_F_VIRTIO_NET_HDR 27
/* Vhost have device IOTLB */
#define VHOST_F_DEVICE_IOTLB 63
/* VHOST_SCSI specific definitions */
...