Commit 2bf338f2 authored by Ohad Sharabi, committed by Oded Gabbay

habanalabs: make some MMU functions common

Some MMU functions can be used by different versions of our MMUs, so
move them to be common.
Signed-off-by: Ohad Sharabi <osharabi@habana.ai>
Reviewed-by: Oded Gabbay <ogabbay@kernel.org>
Signed-off-by: Oded Gabbay <ogabbay@kernel.org>
parent d280d595
...@@ -402,8 +402,11 @@ enum hl_device_hw_state { ...@@ -402,8 +402,11 @@ enum hl_device_hw_state {
* @hop4_mask: mask to get the PTE address in hop 4. * @hop4_mask: mask to get the PTE address in hop 4.
* @hop5_mask: mask to get the PTE address in hop 5. * @hop5_mask: mask to get the PTE address in hop 5.
* @last_mask: mask to get the bit indicating this is the last hop. * @last_mask: mask to get the bit indicating this is the last hop.
* @pgt_size: size for page tables.
* @page_size: default page size used to allocate memory. * @page_size: default page size used to allocate memory.
* @num_hops: The amount of hops supported by the translation table. * @num_hops: The amount of hops supported by the translation table.
* @hop_table_size: HOP table size.
* @hop0_tables_total_size: total size for all HOP0 tables.
* @host_resident: Should the MMU page table reside in host memory or in the * @host_resident: Should the MMU page table reside in host memory or in the
* device DRAM. * device DRAM.
*/ */
...@@ -423,8 +426,11 @@ struct hl_mmu_properties { ...@@ -423,8 +426,11 @@ struct hl_mmu_properties {
u64 hop4_mask; u64 hop4_mask;
u64 hop5_mask; u64 hop5_mask;
u64 last_mask; u64 last_mask;
u64 pgt_size;
u32 page_size; u32 page_size;
u32 num_hops; u32 num_hops;
u32 hop_table_size;
u32 hop0_tables_total_size;
u8 host_resident; u8 host_resident;
}; };
...@@ -3015,6 +3021,9 @@ int hl_mmu_unmap_contiguous(struct hl_ctx *ctx, u64 virt_addr, u32 size); ...@@ -3015,6 +3021,9 @@ int hl_mmu_unmap_contiguous(struct hl_ctx *ctx, u64 virt_addr, u32 size);
int hl_mmu_invalidate_cache(struct hl_device *hdev, bool is_hard, u32 flags); int hl_mmu_invalidate_cache(struct hl_device *hdev, bool is_hard, u32 flags);
int hl_mmu_invalidate_cache_range(struct hl_device *hdev, bool is_hard, int hl_mmu_invalidate_cache_range(struct hl_device *hdev, bool is_hard,
u32 flags, u32 asid, u64 va, u64 size); u32 flags, u32 asid, u64 va, u64 size);
u64 hl_mmu_get_next_hop_addr(struct hl_ctx *ctx, u64 curr_pte);
u64 hl_mmu_get_hop_pte_phys_addr(struct hl_ctx *ctx, struct hl_mmu_properties *mmu_prop,
u8 hop_idx, u64 hop_addr, u64 virt_addr);
void hl_mmu_swap_out(struct hl_ctx *ctx); void hl_mmu_swap_out(struct hl_ctx *ctx);
void hl_mmu_swap_in(struct hl_ctx *ctx); void hl_mmu_swap_in(struct hl_ctx *ctx);
int hl_mmu_if_set_funcs(struct hl_device *hdev); int hl_mmu_if_set_funcs(struct hl_device *hdev);
......
...@@ -662,3 +662,58 @@ int hl_mmu_invalidate_cache_range(struct hl_device *hdev, bool is_hard, ...@@ -662,3 +662,58 @@ int hl_mmu_invalidate_cache_range(struct hl_device *hdev, bool is_hard,
return rc; return rc;
} }
/**
 * hl_mmu_get_next_hop_addr() - extract the next HOP address from a PTE
 * @ctx: pointer to the context structure (not referenced by this function).
 * @curr_pte: value of the current PTE.
 *
 * @return @curr_pte masked with HOP_PHYS_ADDR_MASK when the PAGE_PRESENT_MASK
 * bit is set in it, otherwise ULLONG_MAX.
 */
u64 hl_mmu_get_next_hop_addr(struct hl_ctx *ctx, u64 curr_pte)
{
	/* a PTE without the present bit does not point to a next hop */
	if (!(curr_pte & PAGE_PRESENT_MASK))
		return ULLONG_MAX;

	return curr_pte & HOP_PHYS_ADDR_MASK;
}
/**
 * hl_mmu_get_hop_pte_phys_addr() - compute the address of a PTE inside a HOP
 * @ctx: pointer to the context structure; its device provides the PTE size
 *       and is used for error reporting.
 * @mmu_prop: MMU properties holding the per-hop masks and shifts and the
 *            number of hops.
 * @hop_idx: HOP index.
 * @hop_addr: HOP address.
 * @virt_addr: virtual address for the translation.
 *
 * The PTE index inside the hop is (@virt_addr & hop mask) >> hop shift. The
 * returned address is @hop_addr plus that index multiplied by the ASIC's
 * mmu_pte_size.
 *
 * @return the matching PTE address on success, otherwise U64_MAX when
 * @hop_idx is not smaller than the number of hops in @mmu_prop.
 */
u64 hl_mmu_get_hop_pte_phys_addr(struct hl_ctx *ctx, struct hl_mmu_properties *mmu_prop,
					u8 hop_idx, u64 hop_addr, u64 virt_addr)
{
	/* currently max number of HOPs is 6 */
	const u64 hop_masks[] = {
		mmu_prop->hop0_mask, mmu_prop->hop1_mask, mmu_prop->hop2_mask,
		mmu_prop->hop3_mask, mmu_prop->hop4_mask, mmu_prop->hop5_mask,
	};
	const u64 hop_shifts[] = {
		mmu_prop->hop0_shift, mmu_prop->hop1_shift, mmu_prop->hop2_shift,
		mmu_prop->hop3_shift, mmu_prop->hop4_shift, mmu_prop->hop5_shift,
	};
	u64 pte_idx;
	u8 sel;

	if (hop_idx >= mmu_prop->num_hops) {
		dev_err_ratelimited(ctx->hdev->dev, "Invalid hop index %d\n", hop_idx);
		return U64_MAX;
	}

	/* every index above 4 uses the hop 5 mask and shift */
	sel = hop_idx < 5 ? hop_idx : 5;
	pte_idx = (virt_addr & hop_masks[sel]) >> hop_shifts[sel];

	return hop_addr + ctx->hdev->asic_prop.mmu_pte_size * pte_idx;
}
...@@ -217,18 +217,10 @@ static inline u64 get_hop4_pte_addr(struct hl_ctx *ctx, ...@@ -217,18 +217,10 @@ static inline u64 get_hop4_pte_addr(struct hl_ctx *ctx,
mmu_prop->hop4_shift); mmu_prop->hop4_shift);
} }
/*
 * Return the next hop address encoded in curr_pte (the PTE masked with
 * HOP_PHYS_ADDR_MASK), or ULLONG_MAX when the PAGE_PRESENT_MASK bit is not
 * set in it. ctx is not referenced.
 */
static inline u64 get_next_hop_addr(struct hl_ctx *ctx, u64 curr_pte)
{
	u64 next_hop = ULLONG_MAX;

	if (curr_pte & PAGE_PRESENT_MASK)
		next_hop = curr_pte & HOP_PHYS_ADDR_MASK;

	return next_hop;
}
static inline u64 get_alloc_next_hop_addr(struct hl_ctx *ctx, u64 curr_pte, static inline u64 get_alloc_next_hop_addr(struct hl_ctx *ctx, u64 curr_pte,
bool *is_new_hop) bool *is_new_hop)
{ {
u64 hop_addr = get_next_hop_addr(ctx, curr_pte); u64 hop_addr = hl_mmu_get_next_hop_addr(ctx, curr_pte);
if (hop_addr == ULLONG_MAX) { if (hop_addr == ULLONG_MAX) {
hop_addr = alloc_hop(ctx); hop_addr = alloc_hop(ctx);
...@@ -546,7 +538,7 @@ static int _hl_mmu_v1_unmap(struct hl_ctx *ctx, ...@@ -546,7 +538,7 @@ static int _hl_mmu_v1_unmap(struct hl_ctx *ctx,
curr_pte = *(u64 *) (uintptr_t) hop0_pte_addr; curr_pte = *(u64 *) (uintptr_t) hop0_pte_addr;
hop1_addr = get_next_hop_addr(ctx, curr_pte); hop1_addr = hl_mmu_get_next_hop_addr(ctx, curr_pte);
if (hop1_addr == ULLONG_MAX) if (hop1_addr == ULLONG_MAX)
goto not_mapped; goto not_mapped;
...@@ -555,7 +547,7 @@ static int _hl_mmu_v1_unmap(struct hl_ctx *ctx, ...@@ -555,7 +547,7 @@ static int _hl_mmu_v1_unmap(struct hl_ctx *ctx,
curr_pte = *(u64 *) (uintptr_t) hop1_pte_addr; curr_pte = *(u64 *) (uintptr_t) hop1_pte_addr;
hop2_addr = get_next_hop_addr(ctx, curr_pte); hop2_addr = hl_mmu_get_next_hop_addr(ctx, curr_pte);
if (hop2_addr == ULLONG_MAX) if (hop2_addr == ULLONG_MAX)
goto not_mapped; goto not_mapped;
...@@ -564,7 +556,7 @@ static int _hl_mmu_v1_unmap(struct hl_ctx *ctx, ...@@ -564,7 +556,7 @@ static int _hl_mmu_v1_unmap(struct hl_ctx *ctx,
curr_pte = *(u64 *) (uintptr_t) hop2_pte_addr; curr_pte = *(u64 *) (uintptr_t) hop2_pte_addr;
hop3_addr = get_next_hop_addr(ctx, curr_pte); hop3_addr = hl_mmu_get_next_hop_addr(ctx, curr_pte);
if (hop3_addr == ULLONG_MAX) if (hop3_addr == ULLONG_MAX)
goto not_mapped; goto not_mapped;
...@@ -582,7 +574,7 @@ static int _hl_mmu_v1_unmap(struct hl_ctx *ctx, ...@@ -582,7 +574,7 @@ static int _hl_mmu_v1_unmap(struct hl_ctx *ctx,
} }
if (!is_huge) { if (!is_huge) {
hop4_addr = get_next_hop_addr(ctx, curr_pte); hop4_addr = hl_mmu_get_next_hop_addr(ctx, curr_pte);
if (hop4_addr == ULLONG_MAX) if (hop4_addr == ULLONG_MAX)
goto not_mapped; goto not_mapped;
...@@ -845,27 +837,6 @@ static void hl_mmu_v1_swap_in(struct hl_ctx *ctx) ...@@ -845,27 +837,6 @@ static void hl_mmu_v1_swap_in(struct hl_ctx *ctx)
} }
/*
 * Dispatch to the per-hop PTE address helper that matches hop_num (0..4).
 * Any other hop_num yields U64_MAX.
 */
static inline u64 get_hop_pte_addr(struct hl_ctx *ctx,
					struct hl_mmu_properties *mmu_prop,
					int hop_num, u64 hop_addr, u64 virt_addr)
{
	if (hop_num == 0)
		return get_hop0_pte_addr(ctx, mmu_prop, hop_addr, virt_addr);

	if (hop_num == 1)
		return get_hop1_pte_addr(ctx, mmu_prop, hop_addr, virt_addr);

	if (hop_num == 2)
		return get_hop2_pte_addr(ctx, mmu_prop, hop_addr, virt_addr);

	if (hop_num == 3)
		return get_hop3_pte_addr(ctx, mmu_prop, hop_addr, virt_addr);

	if (hop_num == 4)
		return get_hop4_pte_addr(ctx, mmu_prop, hop_addr, virt_addr);

	/* unsupported hop number */
	return U64_MAX;
}
static int hl_mmu_v1_get_tlb_info(struct hl_ctx *ctx, u64 virt_addr, static int hl_mmu_v1_get_tlb_info(struct hl_ctx *ctx, u64 virt_addr,
struct hl_mmu_hop_info *hops) struct hl_mmu_hop_info *hops)
{ {
...@@ -906,7 +877,7 @@ static int hl_mmu_v1_get_tlb_info(struct hl_ctx *ctx, u64 virt_addr, ...@@ -906,7 +877,7 @@ static int hl_mmu_v1_get_tlb_info(struct hl_ctx *ctx, u64 virt_addr,
hops->hop_info[0].hop_addr = get_phys_hop0_addr(ctx); hops->hop_info[0].hop_addr = get_phys_hop0_addr(ctx);
hops->hop_info[0].hop_pte_addr = hops->hop_info[0].hop_pte_addr =
get_hop_pte_addr(ctx, mmu_prop, 0, hl_mmu_get_hop_pte_phys_addr(ctx, mmu_prop, 0,
hops->hop_info[0].hop_addr, virt_addr); hops->hop_info[0].hop_addr, virt_addr);
hops->hop_info[0].hop_pte_val = hops->hop_info[0].hop_pte_val =
hdev->asic_funcs->read_pte(hdev, hdev->asic_funcs->read_pte(hdev,
...@@ -914,13 +885,13 @@ static int hl_mmu_v1_get_tlb_info(struct hl_ctx *ctx, u64 virt_addr, ...@@ -914,13 +885,13 @@ static int hl_mmu_v1_get_tlb_info(struct hl_ctx *ctx, u64 virt_addr,
for (i = 1 ; i < used_hops ; i++) { for (i = 1 ; i < used_hops ; i++) {
hops->hop_info[i].hop_addr = hops->hop_info[i].hop_addr =
get_next_hop_addr(ctx, hl_mmu_get_next_hop_addr(ctx,
hops->hop_info[i - 1].hop_pte_val); hops->hop_info[i - 1].hop_pte_val);
if (hops->hop_info[i].hop_addr == ULLONG_MAX) if (hops->hop_info[i].hop_addr == ULLONG_MAX)
return -EFAULT; return -EFAULT;
hops->hop_info[i].hop_pte_addr = hops->hop_info[i].hop_pte_addr =
get_hop_pte_addr(ctx, mmu_prop, i, hl_mmu_get_hop_pte_phys_addr(ctx, mmu_prop, i,
hops->hop_info[i].hop_addr, hops->hop_info[i].hop_addr,
virt_addr); virt_addr);
hops->hop_info[i].hop_pte_val = hops->hop_info[i].hop_pte_val =
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment