Commit 89be8aa5 authored by Benjamin Tissoires's avatar Benjamin Tissoires

HID: bpf: actually free hdev memory after attaching a HID-BPF program

Turns out that I got my reference counts wrong and each successful
bus_find_device() actually calls get_device(), and we need to manually
call put_device().

Ensure each bus_find_device() gets a matching put_device() when releasing
the bpf programs and fix all the error paths.

Cc: <stable@vger.kernel.org>
Fixes: f5c27da4 ("HID: initial BPF implementation")
Link: https://lore.kernel.org/r/20240124-b4-hid-bpf-fixes-v2-2-052520b1e5e6@kernel.orgSigned-off-by: default avatarBenjamin Tissoires <bentiss@kernel.org>
parent 7cdd2108
...@@ -292,7 +292,7 @@ hid_bpf_attach_prog(unsigned int hid_id, int prog_fd, __u32 flags) ...@@ -292,7 +292,7 @@ hid_bpf_attach_prog(unsigned int hid_id, int prog_fd, __u32 flags)
struct hid_device *hdev; struct hid_device *hdev;
struct bpf_prog *prog; struct bpf_prog *prog;
struct device *dev; struct device *dev;
int fd; int err, fd;
if (!hid_bpf_ops) if (!hid_bpf_ops)
return -EINVAL; return -EINVAL;
...@@ -311,14 +311,24 @@ hid_bpf_attach_prog(unsigned int hid_id, int prog_fd, __u32 flags) ...@@ -311,14 +311,24 @@ hid_bpf_attach_prog(unsigned int hid_id, int prog_fd, __u32 flags)
* on errors or when it'll be detached * on errors or when it'll be detached
*/ */
prog = bpf_prog_get(prog_fd); prog = bpf_prog_get(prog_fd);
if (IS_ERR(prog)) if (IS_ERR(prog)) {
return PTR_ERR(prog); err = PTR_ERR(prog);
goto out_dev_put;
}
fd = do_hid_bpf_attach_prog(hdev, prog_fd, prog, flags); fd = do_hid_bpf_attach_prog(hdev, prog_fd, prog, flags);
if (fd < 0) if (fd < 0) {
bpf_prog_put(prog); err = fd;
goto out_prog_put;
}
return fd; return fd;
out_prog_put:
bpf_prog_put(prog);
out_dev_put:
put_device(dev);
return err;
} }
/** /**
...@@ -345,8 +355,10 @@ hid_bpf_allocate_context(unsigned int hid_id) ...@@ -345,8 +355,10 @@ hid_bpf_allocate_context(unsigned int hid_id)
hdev = to_hid_device(dev); hdev = to_hid_device(dev);
ctx_kern = kzalloc(sizeof(*ctx_kern), GFP_KERNEL); ctx_kern = kzalloc(sizeof(*ctx_kern), GFP_KERNEL);
if (!ctx_kern) if (!ctx_kern) {
put_device(dev);
return NULL; return NULL;
}
ctx_kern->ctx.hid = hdev; ctx_kern->ctx.hid = hdev;
...@@ -363,10 +375,15 @@ noinline void ...@@ -363,10 +375,15 @@ noinline void
hid_bpf_release_context(struct hid_bpf_ctx *ctx) hid_bpf_release_context(struct hid_bpf_ctx *ctx)
{ {
struct hid_bpf_ctx_kern *ctx_kern; struct hid_bpf_ctx_kern *ctx_kern;
struct hid_device *hid;
ctx_kern = container_of(ctx, struct hid_bpf_ctx_kern, ctx); ctx_kern = container_of(ctx, struct hid_bpf_ctx_kern, ctx);
hid = (struct hid_device *)ctx_kern->ctx.hid; /* ignore const */
kfree(ctx_kern); kfree(ctx_kern);
/* get_device() is called by bus_find_device() */
put_device(&hid->dev);
} }
/** /**
......
...@@ -196,6 +196,7 @@ static void __hid_bpf_do_release_prog(int map_fd, unsigned int idx) ...@@ -196,6 +196,7 @@ static void __hid_bpf_do_release_prog(int map_fd, unsigned int idx)
static void hid_bpf_release_progs(struct work_struct *work) static void hid_bpf_release_progs(struct work_struct *work)
{ {
int i, j, n, map_fd = -1; int i, j, n, map_fd = -1;
bool hdev_destroyed;
if (!jmp_table.map) if (!jmp_table.map)
return; return;
...@@ -220,6 +221,12 @@ static void hid_bpf_release_progs(struct work_struct *work) ...@@ -220,6 +221,12 @@ static void hid_bpf_release_progs(struct work_struct *work)
if (entry->hdev) { if (entry->hdev) {
hdev = entry->hdev; hdev = entry->hdev;
type = entry->type; type = entry->type;
/*
* hdev is still valid, even if we are called after hid_destroy_device():
* when hid_bpf_attach() gets called, it takes a ref on the dev through
* bus_find_device()
*/
hdev_destroyed = hdev->bpf.destroyed;
hid_bpf_populate_hdev(hdev, type); hid_bpf_populate_hdev(hdev, type);
...@@ -232,12 +239,19 @@ static void hid_bpf_release_progs(struct work_struct *work) ...@@ -232,12 +239,19 @@ static void hid_bpf_release_progs(struct work_struct *work)
if (test_bit(next->idx, jmp_table.enabled)) if (test_bit(next->idx, jmp_table.enabled))
continue; continue;
if (next->hdev == hdev && next->type == type) if (next->hdev == hdev && next->type == type) {
/*
* clear the hdev reference and decrement the device ref
* that was taken during bus_find_device() while calling
* hid_bpf_attach()
*/
next->hdev = NULL; next->hdev = NULL;
put_device(&hdev->dev);
}
} }
/* if type was rdesc fixup, reconnect device */ /* if type was rdesc fixup and the device is not gone, reconnect device */
if (type == HID_BPF_PROG_TYPE_RDESC_FIXUP) if (type == HID_BPF_PROG_TYPE_RDESC_FIXUP && !hdev_destroyed)
hid_bpf_reconnect(hdev); hid_bpf_reconnect(hdev);
} }
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment