Commit c3bea3d2 authored by Jason Gunthorpe's avatar Jason Gunthorpe Committed by Doug Ledford

RDMA/uverbs: Use the iterator for ib_uverbs_unmarshall_recv()

This has a very complicated memory layout, with two flex arrays. Use
the iterator API to make reading it clearer.
Signed-off-by: default avatarJason Gunthorpe <jgg@mellanox.com>
Signed-off-by: default avatarLeon Romanovsky <leonro@mellanox.com>
Signed-off-by: default avatarDoug Ledford <dledford@redhat.com>
parent 335708c7
...@@ -150,6 +150,17 @@ static int uverbs_request_next(struct uverbs_req_iter *iter, void *val, ...@@ -150,6 +150,17 @@ static int uverbs_request_next(struct uverbs_req_iter *iter, void *val,
return 0; return 0;
} }
static const void __user *uverbs_request_next_ptr(struct uverbs_req_iter *iter,
size_t len)
{
const void __user *res = iter->cur;
if (iter->cur + len > iter->end)
return ERR_PTR(-ENOSPC);
iter->cur += len;
return res;
}
static int uverbs_request_finish(struct uverbs_req_iter *iter) static int uverbs_request_finish(struct uverbs_req_iter *iter)
{ {
if (!ib_is_buffer_cleared(iter->cur, iter->end - iter->cur)) if (!ib_is_buffer_cleared(iter->cur, iter->end - iter->cur))
...@@ -2073,16 +2084,23 @@ static int ib_uverbs_post_send(struct uverbs_attr_bundle *attrs, ...@@ -2073,16 +2084,23 @@ static int ib_uverbs_post_send(struct uverbs_attr_bundle *attrs,
int is_ud; int is_ud;
int ret, ret2; int ret, ret2;
size_t next_size; size_t next_size;
const struct ib_sge __user *sgls;
const void __user *wqes;
struct uverbs_req_iter iter;
if (copy_from_user(&cmd, buf, sizeof cmd)) ret = uverbs_request_start(attrs, &iter, &cmd, sizeof(cmd));
return -EFAULT; if (ret)
return ret;
if (in_len < sizeof cmd + cmd.wqe_size * cmd.wr_count + wqes = uverbs_request_next_ptr(&iter, cmd.wqe_size * cmd.wr_count);
cmd.sge_count * sizeof (struct ib_uverbs_sge)) if (IS_ERR(wqes))
return -EINVAL; return PTR_ERR(wqes);
sgls = uverbs_request_next_ptr(
if (cmd.wqe_size < sizeof (struct ib_uverbs_send_wr)) &iter, cmd.sge_count * sizeof(struct ib_uverbs_sge));
return -EINVAL; if (IS_ERR(sgls))
return PTR_ERR(sgls);
ret = uverbs_request_finish(&iter);
if (ret)
return ret;
user_wr = kmalloc(cmd.wqe_size, GFP_KERNEL); user_wr = kmalloc(cmd.wqe_size, GFP_KERNEL);
if (!user_wr) if (!user_wr)
...@@ -2096,8 +2114,7 @@ static int ib_uverbs_post_send(struct uverbs_attr_bundle *attrs, ...@@ -2096,8 +2114,7 @@ static int ib_uverbs_post_send(struct uverbs_attr_bundle *attrs,
sg_ind = 0; sg_ind = 0;
last = NULL; last = NULL;
for (i = 0; i < cmd.wr_count; ++i) { for (i = 0; i < cmd.wr_count; ++i) {
if (copy_from_user(user_wr, if (copy_from_user(user_wr, wqes + i * cmd.wqe_size,
buf + sizeof cmd + i * cmd.wqe_size,
cmd.wqe_size)) { cmd.wqe_size)) {
ret = -EFAULT; ret = -EFAULT;
goto out_put; goto out_put;
...@@ -2205,11 +2222,9 @@ static int ib_uverbs_post_send(struct uverbs_attr_bundle *attrs, ...@@ -2205,11 +2222,9 @@ static int ib_uverbs_post_send(struct uverbs_attr_bundle *attrs,
if (next->num_sge) { if (next->num_sge) {
next->sg_list = (void *) next + next->sg_list = (void *) next +
ALIGN(next_size, sizeof(struct ib_sge)); ALIGN(next_size, sizeof(struct ib_sge));
if (copy_from_user(next->sg_list, if (copy_from_user(next->sg_list, sgls + sg_ind,
buf + sizeof cmd + next->num_sge *
cmd.wr_count * cmd.wqe_size + sizeof(struct ib_sge))) {
sg_ind * sizeof (struct ib_sge),
next->num_sge * sizeof (struct ib_sge))) {
ret = -EFAULT; ret = -EFAULT;
goto out_put; goto out_put;
} }
...@@ -2248,25 +2263,32 @@ static int ib_uverbs_post_send(struct uverbs_attr_bundle *attrs, ...@@ -2248,25 +2263,32 @@ static int ib_uverbs_post_send(struct uverbs_attr_bundle *attrs,
return ret; return ret;
} }
static struct ib_recv_wr *ib_uverbs_unmarshall_recv(const char __user *buf, static struct ib_recv_wr *
int in_len, ib_uverbs_unmarshall_recv(struct uverbs_req_iter *iter, u32 wr_count,
u32 wr_count, u32 wqe_size, u32 sge_count)
u32 sge_count,
u32 wqe_size)
{ {
struct ib_uverbs_recv_wr *user_wr; struct ib_uverbs_recv_wr *user_wr;
struct ib_recv_wr *wr = NULL, *last, *next; struct ib_recv_wr *wr = NULL, *last, *next;
int sg_ind; int sg_ind;
int i; int i;
int ret; int ret;
const struct ib_sge __user *sgls;
if (in_len < wqe_size * wr_count + const void __user *wqes;
sge_count * sizeof (struct ib_uverbs_sge))
return ERR_PTR(-EINVAL);
if (wqe_size < sizeof (struct ib_uverbs_recv_wr)) if (wqe_size < sizeof (struct ib_uverbs_recv_wr))
return ERR_PTR(-EINVAL); return ERR_PTR(-EINVAL);
wqes = uverbs_request_next_ptr(iter, wqe_size * wr_count);
if (IS_ERR(wqes))
return ERR_CAST(wqes);
sgls = uverbs_request_next_ptr(
iter, sge_count * sizeof(struct ib_uverbs_sge));
if (IS_ERR(sgls))
return ERR_CAST(sgls);
ret = uverbs_request_finish(iter);
if (ret)
return ERR_PTR(ret);
user_wr = kmalloc(wqe_size, GFP_KERNEL); user_wr = kmalloc(wqe_size, GFP_KERNEL);
if (!user_wr) if (!user_wr)
return ERR_PTR(-ENOMEM); return ERR_PTR(-ENOMEM);
...@@ -2274,7 +2296,7 @@ static struct ib_recv_wr *ib_uverbs_unmarshall_recv(const char __user *buf, ...@@ -2274,7 +2296,7 @@ static struct ib_recv_wr *ib_uverbs_unmarshall_recv(const char __user *buf,
sg_ind = 0; sg_ind = 0;
last = NULL; last = NULL;
for (i = 0; i < wr_count; ++i) { for (i = 0; i < wr_count; ++i) {
if (copy_from_user(user_wr, buf + i * wqe_size, if (copy_from_user(user_wr, wqes + i * wqe_size,
wqe_size)) { wqe_size)) {
ret = -EFAULT; ret = -EFAULT;
goto err; goto err;
...@@ -2313,10 +2335,9 @@ static struct ib_recv_wr *ib_uverbs_unmarshall_recv(const char __user *buf, ...@@ -2313,10 +2335,9 @@ static struct ib_recv_wr *ib_uverbs_unmarshall_recv(const char __user *buf,
if (next->num_sge) { if (next->num_sge) {
next->sg_list = (void *) next + next->sg_list = (void *) next +
ALIGN(sizeof *next, sizeof (struct ib_sge)); ALIGN(sizeof *next, sizeof (struct ib_sge));
if (copy_from_user(next->sg_list, if (copy_from_user(next->sg_list, sgls + sg_ind,
buf + wr_count * wqe_size + next->num_sge *
sg_ind * sizeof (struct ib_sge), sizeof(struct ib_sge))) {
next->num_sge * sizeof (struct ib_sge))) {
ret = -EFAULT; ret = -EFAULT;
goto err; goto err;
} }
...@@ -2349,13 +2370,14 @@ static int ib_uverbs_post_recv(struct uverbs_attr_bundle *attrs, ...@@ -2349,13 +2370,14 @@ static int ib_uverbs_post_recv(struct uverbs_attr_bundle *attrs,
const struct ib_recv_wr *bad_wr; const struct ib_recv_wr *bad_wr;
struct ib_qp *qp; struct ib_qp *qp;
int ret, ret2; int ret, ret2;
struct uverbs_req_iter iter;
if (copy_from_user(&cmd, buf, sizeof cmd)) ret = uverbs_request_start(attrs, &iter, &cmd, sizeof(cmd));
return -EFAULT; if (ret)
return ret;
wr = ib_uverbs_unmarshall_recv(buf + sizeof cmd, wr = ib_uverbs_unmarshall_recv(&iter, cmd.wr_count, cmd.wqe_size,
in_len - sizeof cmd, cmd.wr_count, cmd.sge_count);
cmd.sge_count, cmd.wqe_size);
if (IS_ERR(wr)) if (IS_ERR(wr))
return PTR_ERR(wr); return PTR_ERR(wr);
...@@ -2400,13 +2422,14 @@ static int ib_uverbs_post_srq_recv(struct uverbs_attr_bundle *attrs, ...@@ -2400,13 +2422,14 @@ static int ib_uverbs_post_srq_recv(struct uverbs_attr_bundle *attrs,
const struct ib_recv_wr *bad_wr; const struct ib_recv_wr *bad_wr;
struct ib_srq *srq; struct ib_srq *srq;
int ret, ret2; int ret, ret2;
struct uverbs_req_iter iter;
if (copy_from_user(&cmd, buf, sizeof cmd)) ret = uverbs_request_start(attrs, &iter, &cmd, sizeof(cmd));
return -EFAULT; if (ret)
return ret;
wr = ib_uverbs_unmarshall_recv(buf + sizeof cmd, wr = ib_uverbs_unmarshall_recv(&iter, cmd.wr_count, cmd.wqe_size,
in_len - sizeof cmd, cmd.wr_count, cmd.sge_count);
cmd.sge_count, cmd.wqe_size);
if (IS_ERR(wr)) if (IS_ERR(wr))
return PTR_ERR(wr); return PTR_ERR(wr);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment