Commit 1e240e8d authored by Christoph Hellwig's avatar Christoph Hellwig Committed by Jason Gunthorpe

memremap: move dev_pagemap callbacks into a separate structure

The dev_pagemap is a growing too many callbacks.  Move them into a
separate ops structure so that they are not duplicated for multiple
instances, and an attacker can't easily overwrite them.
Signed-off-by: default avatarChristoph Hellwig <hch@lst.de>
Reviewed-by: default avatarLogan Gunthorpe <logang@deltatee.com>
Reviewed-by: default avatarJason Gunthorpe <jgg@mellanox.com>
Reviewed-by: default avatarDan Williams <dan.j.williams@intel.com>
Tested-by: default avatarDan Williams <dan.j.williams@intel.com>
Signed-off-by: default avatarJason Gunthorpe <jgg@mellanox.com>
parent 3ed2dcdf
...@@ -36,9 +36,8 @@ static void dev_dax_percpu_exit(struct percpu_ref *ref) ...@@ -36,9 +36,8 @@ static void dev_dax_percpu_exit(struct percpu_ref *ref)
percpu_ref_exit(ref); percpu_ref_exit(ref);
} }
static void dev_dax_percpu_kill(struct percpu_ref *data) static void dev_dax_percpu_kill(struct percpu_ref *ref)
{ {
struct percpu_ref *ref = data;
struct dev_dax *dev_dax = ref_to_dev_dax(ref); struct dev_dax *dev_dax = ref_to_dev_dax(ref);
dev_dbg(&dev_dax->dev, "%s\n", __func__); dev_dbg(&dev_dax->dev, "%s\n", __func__);
...@@ -442,6 +441,11 @@ static void dev_dax_kill(void *dev_dax) ...@@ -442,6 +441,11 @@ static void dev_dax_kill(void *dev_dax)
kill_dev_dax(dev_dax); kill_dev_dax(dev_dax);
} }
static const struct dev_pagemap_ops dev_dax_pagemap_ops = {
.kill = dev_dax_percpu_kill,
.cleanup = dev_dax_percpu_exit,
};
int dev_dax_probe(struct device *dev) int dev_dax_probe(struct device *dev)
{ {
struct dev_dax *dev_dax = to_dev_dax(dev); struct dev_dax *dev_dax = to_dev_dax(dev);
...@@ -466,9 +470,8 @@ int dev_dax_probe(struct device *dev) ...@@ -466,9 +470,8 @@ int dev_dax_probe(struct device *dev)
return rc; return rc;
dev_dax->pgmap.ref = &dev_dax->ref; dev_dax->pgmap.ref = &dev_dax->ref;
dev_dax->pgmap.kill = dev_dax_percpu_kill;
dev_dax->pgmap.cleanup = dev_dax_percpu_exit;
dev_dax->pgmap.type = MEMORY_DEVICE_DEVDAX; dev_dax->pgmap.type = MEMORY_DEVICE_DEVDAX;
dev_dax->pgmap.ops = &dev_dax_pagemap_ops;
addr = devm_memremap_pages(dev, &dev_dax->pgmap); addr = devm_memremap_pages(dev, &dev_dax->pgmap);
if (IS_ERR(addr)) if (IS_ERR(addr))
return PTR_ERR(addr); return PTR_ERR(addr);
......
...@@ -16,7 +16,7 @@ struct dev_dax *__dax_pmem_probe(struct device *dev, enum dev_dax_subsys subsys) ...@@ -16,7 +16,7 @@ struct dev_dax *__dax_pmem_probe(struct device *dev, enum dev_dax_subsys subsys)
struct dev_dax *dev_dax; struct dev_dax *dev_dax;
struct nd_namespace_io *nsio; struct nd_namespace_io *nsio;
struct dax_region *dax_region; struct dax_region *dax_region;
struct dev_pagemap pgmap = { 0 }; struct dev_pagemap pgmap = { };
struct nd_namespace_common *ndns; struct nd_namespace_common *ndns;
struct nd_dax *nd_dax = to_nd_dax(dev); struct nd_dax *nd_dax = to_nd_dax(dev);
struct nd_pfn *nd_pfn = &nd_dax->nd_pfn; struct nd_pfn *nd_pfn = &nd_dax->nd_pfn;
......
...@@ -303,7 +303,7 @@ static const struct attribute_group *pmem_attribute_groups[] = { ...@@ -303,7 +303,7 @@ static const struct attribute_group *pmem_attribute_groups[] = {
NULL, NULL,
}; };
static void __pmem_release_queue(struct percpu_ref *ref) static void pmem_pagemap_cleanup(struct percpu_ref *ref)
{ {
struct request_queue *q; struct request_queue *q;
...@@ -313,10 +313,10 @@ static void __pmem_release_queue(struct percpu_ref *ref) ...@@ -313,10 +313,10 @@ static void __pmem_release_queue(struct percpu_ref *ref)
static void pmem_release_queue(void *ref) static void pmem_release_queue(void *ref)
{ {
__pmem_release_queue(ref); pmem_pagemap_cleanup(ref);
} }
static void pmem_freeze_queue(struct percpu_ref *ref) static void pmem_pagemap_kill(struct percpu_ref *ref)
{ {
struct request_queue *q; struct request_queue *q;
...@@ -339,19 +339,24 @@ static void pmem_release_pgmap_ops(void *__pgmap) ...@@ -339,19 +339,24 @@ static void pmem_release_pgmap_ops(void *__pgmap)
dev_pagemap_put_ops(); dev_pagemap_put_ops();
} }
static void fsdax_pagefree(struct page *page, void *data) static void pmem_pagemap_page_free(struct page *page, void *data)
{ {
wake_up_var(&page->_refcount); wake_up_var(&page->_refcount);
} }
static const struct dev_pagemap_ops fsdax_pagemap_ops = {
.page_free = pmem_pagemap_page_free,
.kill = pmem_pagemap_kill,
.cleanup = pmem_pagemap_cleanup,
};
static int setup_pagemap_fsdax(struct device *dev, struct dev_pagemap *pgmap) static int setup_pagemap_fsdax(struct device *dev, struct dev_pagemap *pgmap)
{ {
dev_pagemap_get_ops(); dev_pagemap_get_ops();
if (devm_add_action_or_reset(dev, pmem_release_pgmap_ops, pgmap)) if (devm_add_action_or_reset(dev, pmem_release_pgmap_ops, pgmap))
return -ENOMEM; return -ENOMEM;
pgmap->type = MEMORY_DEVICE_FS_DAX; pgmap->type = MEMORY_DEVICE_FS_DAX;
pgmap->page_free = fsdax_pagefree; pgmap->ops = &fsdax_pagemap_ops;
return 0; return 0;
} }
...@@ -409,8 +414,6 @@ static int pmem_attach_disk(struct device *dev, ...@@ -409,8 +414,6 @@ static int pmem_attach_disk(struct device *dev,
pmem->pfn_flags = PFN_DEV; pmem->pfn_flags = PFN_DEV;
pmem->pgmap.ref = &q->q_usage_counter; pmem->pgmap.ref = &q->q_usage_counter;
pmem->pgmap.kill = pmem_freeze_queue;
pmem->pgmap.cleanup = __pmem_release_queue;
if (is_nd_pfn(dev)) { if (is_nd_pfn(dev)) {
if (setup_pagemap_fsdax(dev, &pmem->pgmap)) if (setup_pagemap_fsdax(dev, &pmem->pgmap))
return -ENOMEM; return -ENOMEM;
......
...@@ -153,6 +153,11 @@ static int pci_p2pdma_setup(struct pci_dev *pdev) ...@@ -153,6 +153,11 @@ static int pci_p2pdma_setup(struct pci_dev *pdev)
return error; return error;
} }
static const struct dev_pagemap_ops pci_p2pdma_pagemap_ops = {
.kill = pci_p2pdma_percpu_kill,
.cleanup = pci_p2pdma_percpu_cleanup,
};
/** /**
* pci_p2pdma_add_resource - add memory for use as p2p memory * pci_p2pdma_add_resource - add memory for use as p2p memory
* @pdev: the device to add the memory to * @pdev: the device to add the memory to
...@@ -208,8 +213,7 @@ int pci_p2pdma_add_resource(struct pci_dev *pdev, int bar, size_t size, ...@@ -208,8 +213,7 @@ int pci_p2pdma_add_resource(struct pci_dev *pdev, int bar, size_t size,
pgmap->type = MEMORY_DEVICE_PCI_P2PDMA; pgmap->type = MEMORY_DEVICE_PCI_P2PDMA;
pgmap->pci_p2pdma_bus_offset = pci_bus_address(pdev, bar) - pgmap->pci_p2pdma_bus_offset = pci_bus_address(pdev, bar) -
pci_resource_start(pdev, bar); pci_resource_start(pdev, bar);
pgmap->kill = pci_p2pdma_percpu_kill; pgmap->ops = &pci_p2pdma_pagemap_ops;
pgmap->cleanup = pci_p2pdma_percpu_cleanup;
addr = devm_memremap_pages(&pdev->dev, pgmap); addr = devm_memremap_pages(&pdev->dev, pgmap);
if (IS_ERR(addr)) { if (IS_ERR(addr)) {
......
...@@ -63,41 +63,45 @@ enum memory_type { ...@@ -63,41 +63,45 @@ enum memory_type {
MEMORY_DEVICE_PCI_P2PDMA, MEMORY_DEVICE_PCI_P2PDMA,
}; };
/* struct dev_pagemap_ops {
* Additional notes about MEMORY_DEVICE_PRIVATE may be found in /*
* include/linux/hmm.h and Documentation/vm/hmm.rst. There is also a brief * Called once the page refcount reaches 1. (ZONE_DEVICE pages never
* explanation in include/linux/memory_hotplug.h. * reach 0 refcount unless there is a refcount bug. This allows the
* * device driver to implement its own memory management.)
* The page_free() callback is called once the page refcount reaches 1 */
* (ZONE_DEVICE pages never reach 0 refcount unless there is a refcount bug. void (*page_free)(struct page *page, void *data);
* This allows the device driver to implement its own memory management.)
*/ /*
typedef void (*dev_page_free_t)(struct page *page, void *data); * Transition the refcount in struct dev_pagemap to the dead state.
*/
void (*kill)(struct percpu_ref *ref);
/*
* Wait for refcount in struct dev_pagemap to be idle and reap it.
*/
void (*cleanup)(struct percpu_ref *ref);
};
/** /**
* struct dev_pagemap - metadata for ZONE_DEVICE mappings * struct dev_pagemap - metadata for ZONE_DEVICE mappings
* @page_free: free page callback when page refcount reaches 1
* @altmap: pre-allocated/reserved memory for vmemmap allocations * @altmap: pre-allocated/reserved memory for vmemmap allocations
* @res: physical address range covered by @ref * @res: physical address range covered by @ref
* @ref: reference count that pins the devm_memremap_pages() mapping * @ref: reference count that pins the devm_memremap_pages() mapping
* @kill: callback to transition @ref to the dead state
* @cleanup: callback to wait for @ref to be idle and reap it
* @dev: host device of the mapping for debug * @dev: host device of the mapping for debug
* @data: private data pointer for page_free() * @data: private data pointer for page_free()
* @type: memory type: see MEMORY_* in memory_hotplug.h * @type: memory type: see MEMORY_* in memory_hotplug.h
* @ops: method table
*/ */
struct dev_pagemap { struct dev_pagemap {
dev_page_free_t page_free;
struct vmem_altmap altmap; struct vmem_altmap altmap;
bool altmap_valid; bool altmap_valid;
struct resource res; struct resource res;
struct percpu_ref *ref; struct percpu_ref *ref;
void (*kill)(struct percpu_ref *ref);
void (*cleanup)(struct percpu_ref *ref);
struct device *dev; struct device *dev;
void *data; void *data;
enum memory_type type; enum memory_type type;
u64 pci_p2pdma_bus_offset; u64 pci_p2pdma_bus_offset;
const struct dev_pagemap_ops *ops;
}; };
#ifdef CONFIG_ZONE_DEVICE #ifdef CONFIG_ZONE_DEVICE
......
...@@ -92,10 +92,10 @@ static void devm_memremap_pages_release(void *data) ...@@ -92,10 +92,10 @@ static void devm_memremap_pages_release(void *data)
unsigned long pfn; unsigned long pfn;
int nid; int nid;
pgmap->kill(pgmap->ref); pgmap->ops->kill(pgmap->ref);
for_each_device_pfn(pfn, pgmap) for_each_device_pfn(pfn, pgmap)
put_page(pfn_to_page(pfn)); put_page(pfn_to_page(pfn));
pgmap->cleanup(pgmap->ref); pgmap->ops->cleanup(pgmap->ref);
/* pages are dead and unused, undo the arch mapping */ /* pages are dead and unused, undo the arch mapping */
align_start = res->start & ~(SECTION_SIZE - 1); align_start = res->start & ~(SECTION_SIZE - 1);
...@@ -128,8 +128,8 @@ static void devm_memremap_pages_release(void *data) ...@@ -128,8 +128,8 @@ static void devm_memremap_pages_release(void *data)
* @pgmap: pointer to a struct dev_pagemap * @pgmap: pointer to a struct dev_pagemap
* *
* Notes: * Notes:
* 1/ At a minimum the res, ref and type members of @pgmap must be initialized * 1/ At a minimum the res, ref and type and ops members of @pgmap must be
* by the caller before passing it to this function * initialized by the caller before passing it to this function
* *
* 2/ The altmap field may optionally be initialized, in which case altmap_valid * 2/ The altmap field may optionally be initialized, in which case altmap_valid
* must be set to true * must be set to true
...@@ -179,7 +179,8 @@ void *devm_memremap_pages(struct device *dev, struct dev_pagemap *pgmap) ...@@ -179,7 +179,8 @@ void *devm_memremap_pages(struct device *dev, struct dev_pagemap *pgmap)
break; break;
} }
if (!pgmap->ref || !pgmap->kill || !pgmap->cleanup) { if (!pgmap->ref || !pgmap->ops || !pgmap->ops->kill ||
!pgmap->ops->cleanup) {
WARN(1, "Missing reference count teardown definition\n"); WARN(1, "Missing reference count teardown definition\n");
return ERR_PTR(-EINVAL); return ERR_PTR(-EINVAL);
} }
...@@ -293,9 +294,8 @@ void *devm_memremap_pages(struct device *dev, struct dev_pagemap *pgmap) ...@@ -293,9 +294,8 @@ void *devm_memremap_pages(struct device *dev, struct dev_pagemap *pgmap)
err_pfn_remap: err_pfn_remap:
pgmap_array_delete(res); pgmap_array_delete(res);
err_array: err_array:
pgmap->kill(pgmap->ref); pgmap->ops->kill(pgmap->ref);
pgmap->cleanup(pgmap->ref); pgmap->ops->cleanup(pgmap->ref);
return ERR_PTR(error); return ERR_PTR(error);
} }
EXPORT_SYMBOL_GPL(devm_memremap_pages); EXPORT_SYMBOL_GPL(devm_memremap_pages);
...@@ -388,7 +388,7 @@ void __put_devmap_managed_page(struct page *page) ...@@ -388,7 +388,7 @@ void __put_devmap_managed_page(struct page *page)
mem_cgroup_uncharge(page); mem_cgroup_uncharge(page);
page->pgmap->page_free(page, page->pgmap->data); page->pgmap->ops->page_free(page, page->pgmap->data);
} else if (!count) } else if (!count)
__put_page(page); __put_page(page);
} }
......
...@@ -1384,6 +1384,12 @@ static void hmm_devmem_free(struct page *page, void *data) ...@@ -1384,6 +1384,12 @@ static void hmm_devmem_free(struct page *page, void *data)
devmem->ops->free(devmem, page); devmem->ops->free(devmem, page);
} }
static const struct dev_pagemap_ops hmm_pagemap_ops = {
.page_free = hmm_devmem_free,
.kill = hmm_devmem_ref_kill,
.cleanup = hmm_devmem_ref_exit,
};
/* /*
* hmm_devmem_add() - hotplug ZONE_DEVICE memory for device memory * hmm_devmem_add() - hotplug ZONE_DEVICE memory for device memory
* *
...@@ -1438,12 +1444,10 @@ struct hmm_devmem *hmm_devmem_add(const struct hmm_devmem_ops *ops, ...@@ -1438,12 +1444,10 @@ struct hmm_devmem *hmm_devmem_add(const struct hmm_devmem_ops *ops,
devmem->pagemap.type = MEMORY_DEVICE_PRIVATE; devmem->pagemap.type = MEMORY_DEVICE_PRIVATE;
devmem->pagemap.res = *devmem->resource; devmem->pagemap.res = *devmem->resource;
devmem->pagemap.page_free = hmm_devmem_free; devmem->pagemap.ops = &hmm_pagemap_ops;
devmem->pagemap.altmap_valid = false; devmem->pagemap.altmap_valid = false;
devmem->pagemap.ref = &devmem->ref; devmem->pagemap.ref = &devmem->ref;
devmem->pagemap.data = devmem; devmem->pagemap.data = devmem;
devmem->pagemap.kill = hmm_devmem_ref_kill;
devmem->pagemap.cleanup = hmm_devmem_ref_exit;
result = devm_memremap_pages(devmem->device, &devmem->pagemap); result = devm_memremap_pages(devmem->device, &devmem->pagemap);
if (IS_ERR(result)) if (IS_ERR(result))
......
...@@ -100,9 +100,10 @@ static void nfit_test_kill(void *_pgmap) ...@@ -100,9 +100,10 @@ static void nfit_test_kill(void *_pgmap)
{ {
struct dev_pagemap *pgmap = _pgmap; struct dev_pagemap *pgmap = _pgmap;
WARN_ON(!pgmap || !pgmap->ref || !pgmap->kill || !pgmap->cleanup); WARN_ON(!pgmap || !pgmap->ref || !pgmap->ops || !pgmap->ops->kill ||
pgmap->kill(pgmap->ref); !pgmap->ops->cleanup);
pgmap->cleanup(pgmap->ref); pgmap->ops->kill(pgmap->ref);
pgmap->ops->cleanup(pgmap->ref);
} }
void *__wrap_devm_memremap_pages(struct device *dev, struct dev_pagemap *pgmap) void *__wrap_devm_memremap_pages(struct device *dev, struct dev_pagemap *pgmap)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment