Commit 8dd014ad authored by David Stevens, committed by Michael S. Tsirkin

vhost-net: mergeable buffers support

This adds support for mergeable buffers in vhost-net: this is needed
for older guests without indirect buffer support, as well
as for zero copy with some devices.

Includes changes by Michael S. Tsirkin to make the
patch as low-risk as possible (i.e., close to no changes
when the feature is disabled).
Signed-off-by: David Stevens <dlstevens@us.ibm.com>
Signed-off-by: Michael S. Tsirkin <mst@redhat.com>
parent 9e3d1957
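Background for the diff below: with VIRTIO_NET_F_MRG_RXBUF negotiated, a received packet may be spread across several descriptor chains, and the virtio-net header grows a field telling the guest how many chains to merge. A minimal userspace sketch of that wire format (struct layouts follow include/linux/virtio_net.h; the demo code and example sizes are illustrative, not part of this patch):

/* Illustrative userspace model of the mergeable-rxbuf wire format;
 * layouts follow include/linux/virtio_net.h, everything else is a
 * made-up demo and not part of this patch. */
#include <stdint.h>
#include <stdio.h>
#include <string.h>

struct virtio_net_hdr {
        uint8_t  flags;
        uint8_t  gso_type;
        uint16_t hdr_len;
        uint16_t gso_size;
        uint16_t csum_start;
        uint16_t csum_offset;
};

/* VIRTIO_NET_F_MRG_RXBUF prepends this instead of the bare header. */
struct virtio_net_hdr_mrg_rxbuf {
        struct virtio_net_hdr hdr;
        uint16_t num_buffers;   /* descriptor chains this packet spans */
};

int main(void)
{
        struct virtio_net_hdr_mrg_rxbuf h;

        memset(&h, 0, sizeof h);
        h.hdr.gso_type = 0;     /* VIRTIO_NET_HDR_GSO_NONE */
        /* handle_rx_mergeable() below stores get_rx_bufs()'s headcount
         * here after recvmsg() has filled the guest buffers: */
        h.num_buffers = 3;      /* e.g. 4000 bytes in 1500-byte buffers */
        printf("header is %zu bytes, num_buffers = %u\n",
               sizeof h, h.num_buffers);
        return 0;
}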
--- a/drivers/vhost/net.c
+++ b/drivers/vhost/net.c
@@ -74,6 +74,22 @@ static int move_iovec_hdr(struct iovec *from, struct iovec *to,
 	}
 	return seg;
 }
+/* Copy iovec entries for len bytes from iovec. */
+static void copy_iovec_hdr(const struct iovec *from, struct iovec *to,
+			   size_t len, int iovcount)
+{
+	int seg = 0;
+	size_t size;
+	while (len && seg < iovcount) {
+		size = min(from->iov_len, len);
+		to->iov_base = from->iov_base;
+		to->iov_len = size;
+		len -= size;
+		++from;
+		++to;
+		++seg;
+	}
+}
 
 /* Caller must have TX VQ lock */
 static void tx_poll_stop(struct vhost_net *net)
@@ -129,7 +145,7 @@ static void handle_tx(struct vhost_net *net)
 	if (wmem < sock->sk->sk_sndbuf / 2)
 		tx_poll_stop(net);
-	hdr_size = vq->hdr_size;
+	hdr_size = vq->vhost_hlen;
 
 	for (;;) {
 		head = vhost_get_vq_desc(&net->dev, vq, vq->iov,
@@ -172,7 +188,7 @@ static void handle_tx(struct vhost_net *net)
 		/* TODO: Check specific error and bomb out unless ENOBUFS? */
 		err = sock->ops->sendmsg(NULL, sock, &msg, len);
 		if (unlikely(err < 0)) {
-			vhost_discard_vq_desc(vq);
+			vhost_discard_vq_desc(vq, 1);
 			tx_poll_start(net, sock);
 			break;
 		}
@@ -191,9 +207,82 @@ static void handle_tx(struct vhost_net *net)
 	unuse_mm(net->dev.mm);
 }
 
+static int peek_head_len(struct sock *sk)
+{
+	struct sk_buff *head;
+	int len = 0;
+
+	lock_sock(sk);
+	head = skb_peek(&sk->sk_receive_queue);
+	if (head)
+		len = head->len;
+	release_sock(sk);
+	return len;
+}
+
+/* This is a multi-buffer version of vhost_get_vq_desc, that works if
+ *	vq has read descriptors only.
+ * @vq		- the relevant virtqueue
+ * @datalen	- data length we'll be reading
+ * @iovcount	- returned count of io vectors we fill
+ * @log		- vhost log
+ * @log_num	- log offset
+ *	returns number of buffer heads allocated, negative on error
+ */
+static int get_rx_bufs(struct vhost_virtqueue *vq,
+		       struct vring_used_elem *heads,
+		       int datalen,
+		       unsigned *iovcount,
+		       struct vhost_log *log,
+		       unsigned *log_num)
+{
+	unsigned int out, in;
+	int seg = 0;
+	int headcount = 0;
+	unsigned d;
+	int r, nlogs = 0;
+
+	while (datalen > 0) {
+		if (unlikely(headcount >= VHOST_NET_MAX_SG)) {
+			r = -ENOBUFS;
+			goto err;
+		}
+		d = vhost_get_vq_desc(vq->dev, vq, vq->iov + seg,
+				      ARRAY_SIZE(vq->iov) - seg, &out,
+				      &in, log, log_num);
+		if (d == vq->num) {
+			r = 0;
+			goto err;
+		}
+		if (unlikely(out || in <= 0)) {
+			vq_err(vq, "unexpected descriptor format for RX: "
+			       "out %d, in %d\n", out, in);
+			r = -EINVAL;
+			goto err;
+		}
+		if (unlikely(log)) {
+			nlogs += *log_num;
+			log += *log_num;
+		}
+		heads[headcount].id = d;
+		heads[headcount].len = iov_length(vq->iov + seg, in);
+		datalen -= heads[headcount].len;
+		++headcount;
+		seg += in;
+	}
+	heads[headcount - 1].len += datalen;
+	*iovcount = seg;
+	if (unlikely(log))
+		*log_num = nlogs;
+	return headcount;
+err:
+	vhost_discard_vq_desc(vq, headcount);
+	return r;
+}
+
 /* Expects to be always run from workqueue - which acts as
  * read-side critical section for our kind of RCU. */
-static void handle_rx(struct vhost_net *net)
+static void handle_rx_big(struct vhost_net *net)
 {
 	struct vhost_virtqueue *vq = &net->dev.vqs[VHOST_NET_VQ_RX];
 	unsigned out, in, log, s;
@@ -223,7 +312,7 @@ static void handle_rx(struct vhost_net *net)
 	use_mm(net->dev.mm);
 	mutex_lock(&vq->mutex);
 	vhost_disable_notify(vq);
-	hdr_size = vq->hdr_size;
+	hdr_size = vq->vhost_hlen;
 
 	vq_log = unlikely(vhost_has_feature(&net->dev, VHOST_F_LOG_ALL)) ?
 		vq->log : NULL;
@@ -270,14 +359,14 @@ static void handle_rx(struct vhost_net *net)
 			       len, MSG_DONTWAIT | MSG_TRUNC);
 		/* TODO: Check specific error and bomb out unless EAGAIN? */
 		if (err < 0) {
-			vhost_discard_vq_desc(vq);
+			vhost_discard_vq_desc(vq, 1);
 			break;
 		}
 		/* TODO: Should check and handle checksum. */
 		if (err > len) {
 			pr_debug("Discarded truncated rx packet: "
 				 " len %d > %zd\n", err, len);
-			vhost_discard_vq_desc(vq);
+			vhost_discard_vq_desc(vq, 1);
 			continue;
 		}
 		len = err;
@@ -302,6 +391,123 @@ static void handle_rx(struct vhost_net *net)
 	unuse_mm(net->dev.mm);
 }
 
+/* Expects to be always run from workqueue - which acts as
+ * read-side critical section for our kind of RCU. */
+static void handle_rx_mergeable(struct vhost_net *net)
+{
+	struct vhost_virtqueue *vq = &net->dev.vqs[VHOST_NET_VQ_RX];
+	unsigned uninitialized_var(in), log;
+	struct vhost_log *vq_log;
+	struct msghdr msg = {
+		.msg_name = NULL,
+		.msg_namelen = 0,
+		.msg_control = NULL, /* FIXME: get and handle RX aux data. */
+		.msg_controllen = 0,
+		.msg_iov = vq->iov,
+		.msg_flags = MSG_DONTWAIT,
+	};
+	struct virtio_net_hdr_mrg_rxbuf hdr = {
+		.hdr.flags = 0,
+		.hdr.gso_type = VIRTIO_NET_HDR_GSO_NONE
+	};
+	size_t total_len = 0;
+	int err, headcount;
+	size_t vhost_hlen, sock_hlen;
+	size_t vhost_len, sock_len;
+	struct socket *sock = rcu_dereference(vq->private_data);
+	if (!sock || skb_queue_empty(&sock->sk->sk_receive_queue))
+		return;
+
+	use_mm(net->dev.mm);
+	mutex_lock(&vq->mutex);
+	vhost_disable_notify(vq);
+	vhost_hlen = vq->vhost_hlen;
+	sock_hlen = vq->sock_hlen;
+
+	vq_log = unlikely(vhost_has_feature(&net->dev, VHOST_F_LOG_ALL)) ?
+		vq->log : NULL;
+
+	while ((sock_len = peek_head_len(sock->sk))) {
+		sock_len += sock_hlen;
+		vhost_len = sock_len + vhost_hlen;
+		headcount = get_rx_bufs(vq, vq->heads, vhost_len,
+					&in, vq_log, &log);
+		/* On error, stop handling until the next kick. */
+		if (unlikely(headcount < 0))
+			break;
+		/* OK, now we need to know about added descriptors. */
+		if (!headcount) {
+			if (unlikely(vhost_enable_notify(vq))) {
+				/* They have slipped one in as we were
+				 * doing that: check again. */
+				vhost_disable_notify(vq);
+				continue;
+			}
+			/* Nothing new? Wait for eventfd to tell us
+			 * they refilled. */
+			break;
+		}
+		/* We don't need to be notified again. */
+		if (unlikely((vhost_hlen)))
+			/* Skip header. TODO: support TSO. */
+			move_iovec_hdr(vq->iov, vq->hdr, vhost_hlen, in);
+		else
+			/* Copy the header for use in VIRTIO_NET_F_MRG_RXBUF:
+			 * needed because recvmsg can modify msg_iov. */
+			copy_iovec_hdr(vq->iov, vq->hdr, sock_hlen, in);
+		msg.msg_iovlen = in;
+		err = sock->ops->recvmsg(NULL, sock, &msg,
+					 sock_len, MSG_DONTWAIT | MSG_TRUNC);
+		/* Userspace might have consumed the packet meanwhile:
+		 * it's not supposed to do this usually, but might be hard
+		 * to prevent. Discard data we got (if any) and keep going. */
+		if (unlikely(err != sock_len)) {
+			pr_debug("Discarded rx packet: "
+				 " len %d, expected %zd\n", err, sock_len);
+			vhost_discard_vq_desc(vq, headcount);
+			continue;
+		}
+		if (unlikely(vhost_hlen) &&
+		    memcpy_toiovecend(vq->hdr, (unsigned char *)&hdr, 0,
+				      vhost_hlen)) {
+			vq_err(vq, "Unable to write vnet_hdr at addr %p\n",
+			       vq->iov->iov_base);
+			break;
+		}
+		/* TODO: Should check and handle checksum. */
+		if (vhost_has_feature(&net->dev, VIRTIO_NET_F_MRG_RXBUF) &&
+		    memcpy_toiovecend(vq->hdr, (unsigned char *)&headcount,
+				      offsetof(typeof(hdr), num_buffers),
+				      sizeof hdr.num_buffers)) {
+			vq_err(vq, "Failed num_buffers write");
+			vhost_discard_vq_desc(vq, headcount);
+			break;
+		}
+		vhost_add_used_and_signal_n(&net->dev, vq, vq->heads,
+					    headcount);
+		if (unlikely(vq_log))
+			vhost_log_write(vq, vq_log, log, vhost_len);
+		total_len += vhost_len;
+		if (unlikely(total_len >= VHOST_NET_WEIGHT)) {
+			vhost_poll_queue(&vq->poll);
+			break;
+		}
+	}
+
+	mutex_unlock(&vq->mutex);
+	unuse_mm(net->dev.mm);
+}
+
+static void handle_rx(struct vhost_net *net)
+{
+	if (vhost_has_feature(&net->dev, VIRTIO_NET_F_MRG_RXBUF))
+		handle_rx_mergeable(net);
+	else
+		handle_rx_big(net);
+}
+
 static void handle_tx_kick(struct vhost_work *work)
 {
 	struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
@@ -577,9 +783,21 @@ static long vhost_net_reset_owner(struct vhost_net *n)
 static int vhost_net_set_features(struct vhost_net *n, u64 features)
 {
-	size_t hdr_size = features & (1 << VHOST_NET_F_VIRTIO_NET_HDR) ?
-		sizeof(struct virtio_net_hdr) : 0;
+	size_t vhost_hlen, sock_hlen, hdr_len;
 	int i;
+
+	hdr_len = (features & (1 << VIRTIO_NET_F_MRG_RXBUF)) ?
+			sizeof(struct virtio_net_hdr_mrg_rxbuf) :
+			sizeof(struct virtio_net_hdr);
+	if (features & (1 << VHOST_NET_F_VIRTIO_NET_HDR)) {
+		/* vhost provides vnet_hdr */
+		vhost_hlen = hdr_len;
+		sock_hlen = 0;
+	} else {
+		/* socket provides vnet_hdr */
+		vhost_hlen = 0;
+		sock_hlen = hdr_len;
+	}
 	mutex_lock(&n->dev.mutex);
 	if ((features & (1 << VHOST_F_LOG_ALL)) &&
 	    !vhost_log_access_ok(&n->dev)) {
@@ -590,7 +808,8 @@ static int vhost_net_set_features(struct vhost_net *n, u64 features)
 	smp_wmb();
 	for (i = 0; i < VHOST_NET_VQ_MAX; ++i) {
 		mutex_lock(&n->vqs[i].mutex);
-		n->vqs[i].hdr_size = hdr_size;
+		n->vqs[i].vhost_hlen = vhost_hlen;
+		n->vqs[i].sock_hlen = sock_hlen;
 		mutex_unlock(&n->vqs[i].mutex);
 	}
 	vhost_net_flush(n);
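The accounting in get_rx_bufs() above is easiest to see with numbers: descriptor chains are claimed until they cover the packet, then the last head's reported length is trimmed by the (now negative) remainder. A standalone model of just that loop (hypothetical names; descriptor fetching, logging, and error unwinding are omitted):

/* Userspace model of the get_rx_bufs() accounting; buf_len stands in
 * for iov_length() over one descriptor chain. Illustrative only. */
#include <stdio.h>

#define MAX_HEADS 8

struct used_elem { unsigned id; unsigned len; };

/* Caller guarantees datalen > 0, as the kernel does. */
static int model_get_rx_bufs(struct used_elem *heads, int datalen,
                             const unsigned *buf_len, unsigned nbufs)
{
        int headcount = 0;
        unsigned d = 0;

        while (datalen > 0) {
                if (headcount >= MAX_HEADS || d >= nbufs)
                        return -1;      /* ring exhausted: caller unwinds */
                heads[headcount].id = d;
                heads[headcount].len = buf_len[d];
                datalen -= heads[headcount].len;
                ++headcount;
                ++d;
        }
        /* datalen is now <= 0; trim the last head to bytes actually used. */
        heads[headcount - 1].len += datalen;
        return headcount;
}

int main(void)
{
        struct used_elem heads[MAX_HEADS];
        unsigned bufs[] = { 1500, 1500, 1500 }; /* three guest buffers */
        int n = model_get_rx_bufs(heads, 4000, bufs, 3);

        for (int i = 0; i < n; i++)
                printf("head %d: len %u\n", i, heads[i].len);
        return 0;
}

With a 4000-byte packet and three 1500-byte buffers this yields heads of 1500, 1500, and 1000 bytes, and the headcount of 3 is what later lands in the header's num_buffers field.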
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -149,7 +149,8 @@ static void vhost_vq_reset(struct vhost_dev *dev,
 	vq->used_flags = 0;
 	vq->log_used = false;
 	vq->log_addr = -1ull;
-	vq->hdr_size = 0;
+	vq->vhost_hlen = 0;
+	vq->sock_hlen = 0;
 	vq->private_data = NULL;
 	vq->log_base = NULL;
 	vq->error_ctx = NULL;
@@ -1101,9 +1102,9 @@ int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
 }
 
 /* Reverse the effect of vhost_get_vq_desc. Useful for error handling. */
-void vhost_discard_vq_desc(struct vhost_virtqueue *vq)
+void vhost_discard_vq_desc(struct vhost_virtqueue *vq, int n)
 {
-	vq->last_avail_idx--;
+	vq->last_avail_idx -= n;
 }
 
 /* After we've used one of their buffers, we tell them about it. We'll then
@@ -1148,6 +1149,67 @@ int vhost_add_used(struct vhost_virtqueue *vq, unsigned int head, int len)
 	return 0;
 }
 
+static int __vhost_add_used_n(struct vhost_virtqueue *vq,
+			      struct vring_used_elem *heads,
+			      unsigned count)
+{
+	struct vring_used_elem __user *used;
+	int start;
+
+	start = vq->last_used_idx % vq->num;
+	used = vq->used->ring + start;
+	if (copy_to_user(used, heads, count * sizeof *used)) {
+		vq_err(vq, "Failed to write used");
+		return -EFAULT;
+	}
+	if (unlikely(vq->log_used)) {
+		/* Make sure data is seen before log. */
+		smp_wmb();
+		/* Log used ring entry write. */
+		log_write(vq->log_base,
+			  vq->log_addr +
+			   ((void __user *)used - (void __user *)vq->used),
+			  count * sizeof *used);
+	}
+	vq->last_used_idx += count;
+	return 0;
+}
+
+/* After we've used one of their buffers, we tell them about it. We'll then
+ * want to notify the guest, using eventfd. */
+int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads,
+		     unsigned count)
+{
+	int start, n, r;
+
+	start = vq->last_used_idx % vq->num;
+	n = vq->num - start;
+	if (n < count) {
+		r = __vhost_add_used_n(vq, heads, n);
+		if (r < 0)
+			return r;
+		heads += n;
+		count -= n;
+	}
+	r = __vhost_add_used_n(vq, heads, count);
+
+	/* Make sure buffer is written before we update index. */
+	smp_wmb();
+	if (put_user(vq->last_used_idx, &vq->used->idx)) {
+		vq_err(vq, "Failed to increment used idx");
+		return -EFAULT;
+	}
+	if (unlikely(vq->log_used)) {
+		/* Log used index update. */
+		log_write(vq->log_base,
+			  vq->log_addr + offsetof(struct vring_used, idx),
+			  sizeof vq->used->idx);
+		if (vq->log_ctx)
+			eventfd_signal(vq->log_ctx, 1);
+	}
+	return r;
+}
+
 /* This actually signals the guest, using eventfd. */
 void vhost_signal(struct vhost_dev *dev, struct vhost_virtqueue *vq)
 {
@@ -1182,6 +1244,15 @@ void vhost_add_used_and_signal(struct vhost_dev *dev,
 	vhost_signal(dev, vq);
 }
 
+/* multi-buffer version of vhost_add_used_and_signal */
+void vhost_add_used_and_signal_n(struct vhost_dev *dev,
+				 struct vhost_virtqueue *vq,
+				 struct vring_used_elem *heads, unsigned count)
+{
+	vhost_add_used_n(vq, heads, count);
+	vhost_signal(dev, vq);
+}
+
 /* OK, now we need to know about added descriptors. */
 bool vhost_enable_notify(struct vhost_virtqueue *vq)
 {
@@ -1206,7 +1277,7 @@ bool vhost_enable_notify(struct vhost_virtqueue *vq)
 		return false;
 	}
 
-	return avail_idx != vq->last_avail_idx;
+	return avail_idx != vq->avail_idx;
 }
 
 /* We don't need to be notified again. */
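vhost_add_used_n() above writes used entries with at most two copy_to_user() calls because the used ring is circular: one chunk up to the end of the ring, then the remainder from slot 0. A sketch of just the split arithmetic (standalone model, not kernel code):

/* Model of the wrap-around split in vhost_add_used_n(). */
#include <stdio.h>

static void add_used_n(unsigned num, unsigned last_used_idx, unsigned count)
{
        unsigned start = last_used_idx % num;
        unsigned n = num - start;       /* room before the ring wraps */

        if (n < count) {
                printf("chunk 1: %u entries at slot %u\n", n, start);
                printf("chunk 2: %u entries at slot 0\n", count - n);
        } else {
                printf("one chunk: %u entries at slot %u\n", count, start);
        }
        /* The kernel then issues smp_wmb() and publishes the new
         * used->idx, so the guest sees the entries before the index. */
}

int main(void)
{
        add_used_n(256, 250, 10);       /* splits into 6 + 4 */
        return 0;
}

Publishing used->idx only after the entries, with the write barrier in between, is what lets the guest consume all count entries in one pass.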
--- a/drivers/vhost/vhost.h
+++ b/drivers/vhost/vhost.h
@@ -96,7 +96,9 @@ struct vhost_virtqueue {
 	struct iovec indirect[VHOST_NET_MAX_SG];
 	struct iovec iov[VHOST_NET_MAX_SG];
 	struct iovec hdr[VHOST_NET_MAX_SG];
-	size_t hdr_size;
+	size_t vhost_hlen;
+	size_t sock_hlen;
+	struct vring_used_elem heads[VHOST_NET_MAX_SG];
 	/* We use a kind of RCU to access private pointer.
 	 * All readers access it from worker, which makes it possible to
 	 * flush the vhost_work instead of synchronize_rcu. Therefore readers do
@@ -139,12 +141,16 @@ int vhost_get_vq_desc(struct vhost_dev *, struct vhost_virtqueue *,
 		      struct iovec iov[], unsigned int iov_count,
 		      unsigned int *out_num, unsigned int *in_num,
 		      struct vhost_log *log, unsigned int *log_num);
-void vhost_discard_vq_desc(struct vhost_virtqueue *);
+void vhost_discard_vq_desc(struct vhost_virtqueue *, int n);
 
 int vhost_add_used(struct vhost_virtqueue *, unsigned int head, int len);
-void vhost_signal(struct vhost_dev *, struct vhost_virtqueue *);
+int vhost_add_used_n(struct vhost_virtqueue *, struct vring_used_elem *heads,
+		     unsigned count);
 void vhost_add_used_and_signal(struct vhost_dev *, struct vhost_virtqueue *,
-			       unsigned int head, int len);
+			       unsigned int id, int len);
+void vhost_add_used_and_signal_n(struct vhost_dev *, struct vhost_virtqueue *,
+				 struct vring_used_elem *heads, unsigned count);
+void vhost_signal(struct vhost_dev *, struct vhost_virtqueue *);
 void vhost_disable_notify(struct vhost_virtqueue *);
 bool vhost_enable_notify(struct vhost_virtqueue *);
...@@ -161,7 +167,8 @@ enum { ...@@ -161,7 +167,8 @@ enum {
VHOST_FEATURES = (1 << VIRTIO_F_NOTIFY_ON_EMPTY) | VHOST_FEATURES = (1 << VIRTIO_F_NOTIFY_ON_EMPTY) |
(1 << VIRTIO_RING_F_INDIRECT_DESC) | (1 << VIRTIO_RING_F_INDIRECT_DESC) |
(1 << VHOST_F_LOG_ALL) | (1 << VHOST_F_LOG_ALL) |
(1 << VHOST_NET_F_VIRTIO_NET_HDR), (1 << VHOST_NET_F_VIRTIO_NET_HDR) |
(1 << VIRTIO_NET_F_MRG_RXBUF),
}; };
static inline int vhost_has_feature(struct vhost_dev *dev, int bit) static inline int vhost_has_feature(struct vhost_dev *dev, int bit)
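Since VHOST_FEATURES now advertises VIRTIO_NET_F_MRG_RXBUF, a userspace driver can ack the feature through the existing vhost ioctls. A minimal sketch, assuming the standard VHOST_GET_FEATURES/VHOST_SET_FEATURES interface from linux/vhost.h (error handling trimmed; the helper name and device path are illustrative):

/* Illustrative only: ack mergeable RX buffers on a vhost-net device. */
#include <fcntl.h>
#include <stdint.h>
#include <sys/ioctl.h>
#include <linux/vhost.h>
#include <linux/virtio_net.h>

int ack_mrg_rxbuf(void)
{
        uint64_t offered, acked;
        int fd = open("/dev/vhost-net", O_RDWR); /* path is an assumption */

        if (fd < 0)
                return -1;
        if (ioctl(fd, VHOST_SET_OWNER, NULL) < 0 ||
            ioctl(fd, VHOST_GET_FEATURES, &offered) < 0)
                return -1;
        if (!(offered & (1ULL << VIRTIO_NET_F_MRG_RXBUF)))
                return -1;      /* backend predates this patch */
        /* A real driver would ack the intersection with what the guest
         * negotiated; here we ack just the one bit we care about. */
        acked = 1ULL << VIRTIO_NET_F_MRG_RXBUF;
        if (ioctl(fd, VHOST_SET_FEATURES, &acked) < 0)
                return -1;
        return fd;
}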