Diffstat (limited to 'drivers/iommu/generic_pt')
23 files changed, 5810 insertions, 0 deletions
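For orientation before the per-file diffs: each page table format in this series is compiled through the shared template in fmt/iommu_template.h. A format supplies a <fmt>.h and defs_<fmt>.h header pair plus a tiny instantiation unit that sets a few macros and includes the template, which then builds the common iommu_pt.h operations against that format (or the kunit variant when GENERIC_PT_KUNIT is defined). A minimal sketch of such an instantiation unit, modeled on the iommu_vtdss.c and iommu_x86_64.c files added below (the format name "newfmt" and its feature selection are hypothetical):

// SPDX-License-Identifier: GPL-2.0-only
/* Hypothetical instantiation unit following the pattern of iommu_vtdss.c */
#define PT_FMT newfmt			/* pulls in newfmt.h and defs_newfmt.h */
#define PT_SUPPORTED_FEATURES (BIT(PT_FEAT_FLUSH_RANGE))
/* optional: PT_FMT_VARIANT and PT_FORCE_ENABLED_FEATURES */

#include "iommu_template.h"

The fmt/Makefile's create_format macro then emits iommu_newfmt.o and a matching kunit_iommu_newfmt.o for the iommu_pt_kunit_test module, built from the same source with -DGENERIC_PT_KUNIT=1.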
diff --git a/drivers/iommu/generic_pt/.kunitconfig b/drivers/iommu/generic_pt/.kunitconfig new file mode 100644 index 000000000000..52ac9e661ffd --- /dev/null +++ b/drivers/iommu/generic_pt/.kunitconfig @@ -0,0 +1,14 @@ +CONFIG_KUNIT=y +CONFIG_GENERIC_PT=y +CONFIG_DEBUG_GENERIC_PT=y +CONFIG_IOMMU_PT=y +CONFIG_IOMMU_PT_AMDV1=y +CONFIG_IOMMU_PT_VTDSS=y +CONFIG_IOMMU_PT_X86_64=y +CONFIG_IOMMU_PT_KUNIT_TEST=y + +CONFIG_IOMMUFD=y +CONFIG_DEBUG_KERNEL=y +CONFIG_FAULT_INJECTION=y +CONFIG_RUNTIME_TESTING_MENU=y +CONFIG_IOMMUFD_TEST=y diff --git a/drivers/iommu/generic_pt/Kconfig b/drivers/iommu/generic_pt/Kconfig new file mode 100644 index 000000000000..ce4fb4786914 --- /dev/null +++ b/drivers/iommu/generic_pt/Kconfig @@ -0,0 +1,79 @@ +# SPDX-License-Identifier: GPL-2.0-only + +menuconfig GENERIC_PT + bool "Generic Radix Page Table" if COMPILE_TEST + help + Generic library for building radix tree page tables. + + Generic PT provides a set of HW page table formats and a common + set of APIs to work with them. + +if GENERIC_PT +config DEBUG_GENERIC_PT + bool "Extra debugging checks for GENERIC_PT" + help + Enable extra run time debugging checks for GENERIC_PT code. This + incurs a runtime cost and should not be enabled for production + kernels. + + The kunit tests require this to be enabled to get full coverage. + +config IOMMU_PT + tristate "IOMMU Page Tables" + select IOMMU_API + depends on IOMMU_SUPPORT + depends on GENERIC_PT + help + Generic library for building IOMMU page tables + + IOMMU_PT provides an implementation of the page table operations + related to struct iommu_domain using GENERIC_PT. It provides a single + implementation of the page table operations that can be shared by + multiple drivers. + +if IOMMU_PT +config IOMMU_PT_AMDV1 + tristate "IOMMU page table for 64-bit AMD IOMMU v1" + depends on !GENERIC_ATOMIC64 # for cmpxchg64 + help + iommu_domain implementation for the AMD v1 page table. AMDv1 is the + "host" page table. It supports granular page sizes of almost every + power of 2 and decodes the full 64-bit IOVA space. + + Selected automatically by an IOMMU driver that uses this format. + +config IOMMU_PT_VTDSS + tristate "IOMMU page table for Intel VT-d Second Stage" + depends on !GENERIC_ATOMIC64 # for cmpxchg64 + help + iommu_domain implementation for the Intel VT-d's 64 bit 3/4/5 + level Second Stage page table. It is similar to the X86_64 format with + 4K/2M/1G page sizes. + + Selected automatically by an IOMMU driver that uses this format. + +config IOMMU_PT_X86_64 + tristate "IOMMU page table for x86 64-bit, 4/5 levels" + depends on !GENERIC_ATOMIC64 # for cmpxchg64 + help + iommu_domain implementation for the x86 64-bit 4/5 level page table. + It supports 4K/2M/1G page sizes and can decode a sign-extended + portion of the 64-bit IOVA space. + + Selected automatically by an IOMMU driver that uses this format. + +config IOMMU_PT_KUNIT_TEST + tristate "IOMMU Page Table KUnit Test" if !KUNIT_ALL_TESTS + depends on KUNIT + depends on IOMMU_PT_AMDV1 || !IOMMU_PT_AMDV1 + depends on IOMMU_PT_X86_64 || !IOMMU_PT_X86_64 + depends on IOMMU_PT_VTDSS || !IOMMU_PT_VTDSS + default KUNIT_ALL_TESTS + help + Enable kunit tests for GENERIC_PT and IOMMU_PT that covers all the + enabled page table formats. The test covers most of the GENERIC_PT + functions provided by the page table format, as well as covering the + iommu_domain related functions. 
+ +endif +endif diff --git a/drivers/iommu/generic_pt/fmt/Makefile b/drivers/iommu/generic_pt/fmt/Makefile new file mode 100644 index 000000000000..976b49ec97dc --- /dev/null +++ b/drivers/iommu/generic_pt/fmt/Makefile @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: GPL-2.0 + +iommu_pt_fmt-$(CONFIG_IOMMU_PT_AMDV1) += amdv1 +iommu_pt_fmt-$(CONFIG_IOMMUFD_TEST) += mock + +iommu_pt_fmt-$(CONFIG_IOMMU_PT_VTDSS) += vtdss + +iommu_pt_fmt-$(CONFIG_IOMMU_PT_X86_64) += x86_64 + +IOMMU_PT_KUNIT_TEST := +define create_format +obj-$(2) += iommu_$(1).o +iommu_pt_kunit_test-y += kunit_iommu_$(1).o +CFLAGS_kunit_iommu_$(1).o += -DGENERIC_PT_KUNIT=1 +IOMMU_PT_KUNIT_TEST := iommu_pt_kunit_test.o + +endef + +$(eval $(foreach fmt,$(iommu_pt_fmt-y),$(call create_format,$(fmt),y))) +$(eval $(foreach fmt,$(iommu_pt_fmt-m),$(call create_format,$(fmt),m))) + +# The kunit objects are constructed by compiling the main source +# with -DGENERIC_PT_KUNIT +$(obj)/kunit_iommu_%.o: $(src)/iommu_%.c FORCE + $(call rule_mkdir) + $(call if_changed_dep,cc_o_c) + +obj-$(CONFIG_IOMMU_PT_KUNIT_TEST) += $(IOMMU_PT_KUNIT_TEST) diff --git a/drivers/iommu/generic_pt/fmt/amdv1.h b/drivers/iommu/generic_pt/fmt/amdv1.h new file mode 100644 index 000000000000..aa8e1a8ec95f --- /dev/null +++ b/drivers/iommu/generic_pt/fmt/amdv1.h @@ -0,0 +1,411 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES + * + * AMD IOMMU v1 page table + * + * This is described in Section "2.2.3 I/O Page Tables for Host Translations" + * of the "AMD I/O Virtualization Technology (IOMMU) Specification" + * + * Note the level numbering here matches the core code, so level 0 is the same + * as mode 1. + * + */ +#ifndef __GENERIC_PT_FMT_AMDV1_H +#define __GENERIC_PT_FMT_AMDV1_H + +#include "defs_amdv1.h" +#include "../pt_defs.h" + +#include <asm/page.h> +#include <linux/bitfield.h> +#include <linux/container_of.h> +#include <linux/mem_encrypt.h> +#include <linux/minmax.h> +#include <linux/sizes.h> +#include <linux/string.h> + +enum { + PT_ITEM_WORD_SIZE = sizeof(u64), + /* + * The IOMMUFD selftest uses the AMDv1 format with some alterations It + * uses a 2k page size to test cases where the CPU page size is not the + * same. + */ +#ifdef AMDV1_IOMMUFD_SELFTEST + PT_MAX_VA_ADDRESS_LG2 = 56, + PT_MAX_OUTPUT_ADDRESS_LG2 = 51, + PT_MAX_TOP_LEVEL = 4, + PT_GRANULE_LG2SZ = 11, +#else + PT_MAX_VA_ADDRESS_LG2 = 64, + PT_MAX_OUTPUT_ADDRESS_LG2 = 52, + PT_MAX_TOP_LEVEL = 5, + PT_GRANULE_LG2SZ = 12, +#endif + PT_TABLEMEM_LG2SZ = 12, + + /* The DTE only has these bits for the top phyiscal address */ + PT_TOP_PHYS_MASK = GENMASK_ULL(51, 12), +}; + +/* PTE bits */ +enum { + AMDV1PT_FMT_PR = BIT(0), + AMDV1PT_FMT_D = BIT(6), + AMDV1PT_FMT_NEXT_LEVEL = GENMASK_ULL(11, 9), + AMDV1PT_FMT_OA = GENMASK_ULL(51, 12), + AMDV1PT_FMT_FC = BIT_ULL(60), + AMDV1PT_FMT_IR = BIT_ULL(61), + AMDV1PT_FMT_IW = BIT_ULL(62), +}; + +/* + * gcc 13 has a bug where it thinks the output of FIELD_GET() is an enum, make + * these defines to avoid it. 
+ */ +#define AMDV1PT_FMT_NL_DEFAULT 0 +#define AMDV1PT_FMT_NL_SIZE 7 + +static inline pt_oaddr_t amdv1pt_table_pa(const struct pt_state *pts) +{ + u64 entry = pts->entry; + + if (pts_feature(pts, PT_FEAT_AMDV1_ENCRYPT_TABLES)) + entry = __sme_clr(entry); + return oalog2_mul(FIELD_GET(AMDV1PT_FMT_OA, entry), PT_GRANULE_LG2SZ); +} +#define pt_table_pa amdv1pt_table_pa + +/* Returns the oa for the start of the contiguous entry */ +static inline pt_oaddr_t amdv1pt_entry_oa(const struct pt_state *pts) +{ + u64 entry = pts->entry; + pt_oaddr_t oa; + + if (pts_feature(pts, PT_FEAT_AMDV1_ENCRYPT_TABLES)) + entry = __sme_clr(entry); + oa = FIELD_GET(AMDV1PT_FMT_OA, entry); + + if (FIELD_GET(AMDV1PT_FMT_NEXT_LEVEL, entry) == AMDV1PT_FMT_NL_SIZE) { + unsigned int sz_bits = oaffz(oa); + + oa = oalog2_set_mod(oa, 0, sz_bits); + } else if (PT_WARN_ON(FIELD_GET(AMDV1PT_FMT_NEXT_LEVEL, entry) != + AMDV1PT_FMT_NL_DEFAULT)) + return 0; + return oalog2_mul(oa, PT_GRANULE_LG2SZ); +} +#define pt_entry_oa amdv1pt_entry_oa + +static inline bool amdv1pt_can_have_leaf(const struct pt_state *pts) +{ + /* + * Table 15: Page Table Level Parameters + * The top most level cannot have translation entries + */ + return pts->level < PT_MAX_TOP_LEVEL; +} +#define pt_can_have_leaf amdv1pt_can_have_leaf + +/* Body in pt_fmt_defaults.h */ +static inline unsigned int pt_table_item_lg2sz(const struct pt_state *pts); + +static inline unsigned int +amdv1pt_entry_num_contig_lg2(const struct pt_state *pts) +{ + u32 code; + + if (FIELD_GET(AMDV1PT_FMT_NEXT_LEVEL, pts->entry) == + AMDV1PT_FMT_NL_DEFAULT) + return ilog2(1); + + PT_WARN_ON(FIELD_GET(AMDV1PT_FMT_NEXT_LEVEL, pts->entry) != + AMDV1PT_FMT_NL_SIZE); + + /* + * The contiguous size is encoded in the length of a string of 1's in + * the low bits of the OA. Reverse the equation: + * code = log2_to_int(num_contig_lg2 + item_lg2sz - + * PT_GRANULE_LG2SZ - 1) - 1 + * Which can be expressed as: + * num_contig_lg2 = oalog2_ffz(code) + 1 - + * item_lg2sz - PT_GRANULE_LG2SZ + * + * Assume the bit layout is correct and remove the masking. Reorganize + * the equation to move all the arithmetic before the ffz. + */ + code = pts->entry >> (__bf_shf(AMDV1PT_FMT_OA) - 1 + + pt_table_item_lg2sz(pts) - PT_GRANULE_LG2SZ); + return ffz_t(u32, code); +} +#define pt_entry_num_contig_lg2 amdv1pt_entry_num_contig_lg2 + +static inline unsigned int amdv1pt_num_items_lg2(const struct pt_state *pts) +{ + /* + * Top entry covers bits [63:57] only, this is handled through + * max_vasz_lg2. + */ + if (PT_WARN_ON(pts->level == 5)) + return 7; + return PT_TABLEMEM_LG2SZ - ilog2(sizeof(u64)); +} +#define pt_num_items_lg2 amdv1pt_num_items_lg2 + +static inline pt_vaddr_t amdv1pt_possible_sizes(const struct pt_state *pts) +{ + unsigned int isz_lg2 = pt_table_item_lg2sz(pts); + + if (!amdv1pt_can_have_leaf(pts)) + return 0; + + /* + * Table 14: Example Page Size Encodings + * Address bits 51:32 can be used to encode page sizes greater than 4 + * Gbytes. Address bits 63:52 are zero-extended. + * + * 512GB Pages are not supported due to a hardware bug. + * Otherwise every power of two size is supported. 
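+ *
+ * As a worked example of that encoding (values follow from the code here,
+ * not quoted from the spec table): a 2 Mbyte leaf keeps Next Level = 7 and
+ * stores a run of ones in the low address bits, PA[20:12] = 0_1111_1111b,
+ * so the position of the first zero bit lets amdv1pt_entry_num_contig_lg2()
+ * and amdv1pt_entry_oa() recover both the page size and the aligned output
+ * address.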
+ */ + return GENMASK_ULL(min(51, isz_lg2 + amdv1pt_num_items_lg2(pts) - 1), + isz_lg2) & ~SZ_512G; +} +#define pt_possible_sizes amdv1pt_possible_sizes + +static inline enum pt_entry_type amdv1pt_load_entry_raw(struct pt_state *pts) +{ + const u64 *tablep = pt_cur_table(pts, u64) + pts->index; + unsigned int next_level; + u64 entry; + + pts->entry = entry = READ_ONCE(*tablep); + if (!(entry & AMDV1PT_FMT_PR)) + return PT_ENTRY_EMPTY; + + next_level = FIELD_GET(AMDV1PT_FMT_NEXT_LEVEL, pts->entry); + if (pts->level == 0 || next_level == AMDV1PT_FMT_NL_DEFAULT || + next_level == AMDV1PT_FMT_NL_SIZE) + return PT_ENTRY_OA; + return PT_ENTRY_TABLE; +} +#define pt_load_entry_raw amdv1pt_load_entry_raw + +static inline void +amdv1pt_install_leaf_entry(struct pt_state *pts, pt_oaddr_t oa, + unsigned int oasz_lg2, + const struct pt_write_attrs *attrs) +{ + unsigned int isz_lg2 = pt_table_item_lg2sz(pts); + u64 *tablep = pt_cur_table(pts, u64) + pts->index; + u64 entry; + + if (!pt_check_install_leaf_args(pts, oa, oasz_lg2)) + return; + + entry = AMDV1PT_FMT_PR | + FIELD_PREP(AMDV1PT_FMT_OA, log2_div(oa, PT_GRANULE_LG2SZ)) | + attrs->descriptor_bits; + + if (oasz_lg2 == isz_lg2) { + entry |= FIELD_PREP(AMDV1PT_FMT_NEXT_LEVEL, + AMDV1PT_FMT_NL_DEFAULT); + WRITE_ONCE(*tablep, entry); + } else { + unsigned int num_contig_lg2 = oasz_lg2 - isz_lg2; + u64 *end = tablep + log2_to_int(num_contig_lg2); + + entry |= FIELD_PREP(AMDV1PT_FMT_NEXT_LEVEL, + AMDV1PT_FMT_NL_SIZE) | + FIELD_PREP(AMDV1PT_FMT_OA, + oalog2_to_int(oasz_lg2 - PT_GRANULE_LG2SZ - + 1) - + 1); + + /* See amdv1pt_clear_entries() */ + if (num_contig_lg2 <= ilog2(32)) { + for (; tablep != end; tablep++) + WRITE_ONCE(*tablep, entry); + } else { + memset64(tablep, entry, log2_to_int(num_contig_lg2)); + } + } + pts->entry = entry; +} +#define pt_install_leaf_entry amdv1pt_install_leaf_entry + +static inline bool amdv1pt_install_table(struct pt_state *pts, + pt_oaddr_t table_pa, + const struct pt_write_attrs *attrs) +{ + u64 entry; + + /* + * IR and IW are ANDed from the table levels along with the PTE. We + * always control permissions from the PTE, so always set IR and IW for + * tables. + */ + entry = AMDV1PT_FMT_PR | + FIELD_PREP(AMDV1PT_FMT_NEXT_LEVEL, pts->level) | + FIELD_PREP(AMDV1PT_FMT_OA, + log2_div(table_pa, PT_GRANULE_LG2SZ)) | + AMDV1PT_FMT_IR | AMDV1PT_FMT_IW; + if (pts_feature(pts, PT_FEAT_AMDV1_ENCRYPT_TABLES)) + entry = __sme_set(entry); + return pt_table_install64(pts, entry); +} +#define pt_install_table amdv1pt_install_table + +static inline void amdv1pt_attr_from_entry(const struct pt_state *pts, + struct pt_write_attrs *attrs) +{ + attrs->descriptor_bits = + pts->entry & (AMDV1PT_FMT_FC | AMDV1PT_FMT_IR | AMDV1PT_FMT_IW); +} +#define pt_attr_from_entry amdv1pt_attr_from_entry + +static inline void amdv1pt_clear_entries(struct pt_state *pts, + unsigned int num_contig_lg2) +{ + u64 *tablep = pt_cur_table(pts, u64) + pts->index; + u64 *end = tablep + log2_to_int(num_contig_lg2); + + /* + * gcc generates rep stos for the io-pgtable code, and this difference + * can show in microbenchmarks with larger contiguous page sizes. + * rep is slower for small cases. 
+ */ + if (num_contig_lg2 <= ilog2(32)) { + for (; tablep != end; tablep++) + WRITE_ONCE(*tablep, 0); + } else { + memset64(tablep, 0, log2_to_int(num_contig_lg2)); + } +} +#define pt_clear_entries amdv1pt_clear_entries + +static inline bool amdv1pt_entry_is_write_dirty(const struct pt_state *pts) +{ + unsigned int num_contig_lg2 = amdv1pt_entry_num_contig_lg2(pts); + u64 *tablep = pt_cur_table(pts, u64) + + log2_set_mod(pts->index, 0, num_contig_lg2); + u64 *end = tablep + log2_to_int(num_contig_lg2); + + for (; tablep != end; tablep++) + if (READ_ONCE(*tablep) & AMDV1PT_FMT_D) + return true; + return false; +} +#define pt_entry_is_write_dirty amdv1pt_entry_is_write_dirty + +static inline void amdv1pt_entry_make_write_clean(struct pt_state *pts) +{ + unsigned int num_contig_lg2 = amdv1pt_entry_num_contig_lg2(pts); + u64 *tablep = pt_cur_table(pts, u64) + + log2_set_mod(pts->index, 0, num_contig_lg2); + u64 *end = tablep + log2_to_int(num_contig_lg2); + + for (; tablep != end; tablep++) + WRITE_ONCE(*tablep, READ_ONCE(*tablep) & ~(u64)AMDV1PT_FMT_D); +} +#define pt_entry_make_write_clean amdv1pt_entry_make_write_clean + +static inline bool amdv1pt_entry_make_write_dirty(struct pt_state *pts) +{ + u64 *tablep = pt_cur_table(pts, u64) + pts->index; + u64 new = pts->entry | AMDV1PT_FMT_D; + + return try_cmpxchg64(tablep, &pts->entry, new); +} +#define pt_entry_make_write_dirty amdv1pt_entry_make_write_dirty + +/* --- iommu */ +#include <linux/generic_pt/iommu.h> +#include <linux/iommu.h> + +#define pt_iommu_table pt_iommu_amdv1 + +/* The common struct is in the per-format common struct */ +static inline struct pt_common *common_from_iommu(struct pt_iommu *iommu_table) +{ + return &container_of(iommu_table, struct pt_iommu_amdv1, iommu) + ->amdpt.common; +} + +static inline struct pt_iommu *iommu_from_common(struct pt_common *common) +{ + return &container_of(common, struct pt_iommu_amdv1, amdpt.common)->iommu; +} + +static inline int amdv1pt_iommu_set_prot(struct pt_common *common, + struct pt_write_attrs *attrs, + unsigned int iommu_prot) +{ + u64 pte = 0; + + if (pt_feature(common, PT_FEAT_AMDV1_FORCE_COHERENCE)) + pte |= AMDV1PT_FMT_FC; + if (iommu_prot & IOMMU_READ) + pte |= AMDV1PT_FMT_IR; + if (iommu_prot & IOMMU_WRITE) + pte |= AMDV1PT_FMT_IW; + + /* + * Ideally we'd have an IOMMU_ENCRYPTED flag set by higher levels to + * control this. For now if the tables use sme_set then so do the ptes. 
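+ *
+ * For example (derived from the code in this function): an IOMMU_READ |
+ * IOMMU_WRITE request produces IR | IW here, plus FC when
+ * PT_FEAT_AMDV1_FORCE_COHERENCE is set, and with
+ * PT_FEAT_AMDV1_ENCRYPT_TABLES the same __sme_set() bit is applied to these
+ * leaf bits as amdv1pt_install_table() applies to table pointers. PR and
+ * the output address are added later by amdv1pt_install_leaf_entry().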
+ */ + if (pt_feature(common, PT_FEAT_AMDV1_ENCRYPT_TABLES)) + pte = __sme_set(pte); + + attrs->descriptor_bits = pte; + return 0; +} +#define pt_iommu_set_prot amdv1pt_iommu_set_prot + +static inline int amdv1pt_iommu_fmt_init(struct pt_iommu_amdv1 *iommu_table, + const struct pt_iommu_amdv1_cfg *cfg) +{ + struct pt_amdv1 *table = &iommu_table->amdpt; + unsigned int max_vasz_lg2 = PT_MAX_VA_ADDRESS_LG2; + + if (cfg->starting_level == 0 || cfg->starting_level > PT_MAX_TOP_LEVEL) + return -EINVAL; + + if (!pt_feature(&table->common, PT_FEAT_DYNAMIC_TOP) && + cfg->starting_level != PT_MAX_TOP_LEVEL) + max_vasz_lg2 = PT_GRANULE_LG2SZ + + (PT_TABLEMEM_LG2SZ - ilog2(sizeof(u64))) * + (cfg->starting_level + 1); + + table->common.max_vasz_lg2 = + min(max_vasz_lg2, cfg->common.hw_max_vasz_lg2); + table->common.max_oasz_lg2 = + min(PT_MAX_OUTPUT_ADDRESS_LG2, cfg->common.hw_max_oasz_lg2); + pt_top_set_level(&table->common, cfg->starting_level); + return 0; +} +#define pt_iommu_fmt_init amdv1pt_iommu_fmt_init + +#ifndef PT_FMT_VARIANT +static inline void +amdv1pt_iommu_fmt_hw_info(struct pt_iommu_amdv1 *table, + const struct pt_range *top_range, + struct pt_iommu_amdv1_hw_info *info) +{ + info->host_pt_root = virt_to_phys(top_range->top_table); + PT_WARN_ON(info->host_pt_root & ~PT_TOP_PHYS_MASK); + info->mode = top_range->top_level + 1; +} +#define pt_iommu_fmt_hw_info amdv1pt_iommu_fmt_hw_info +#endif + +#if defined(GENERIC_PT_KUNIT) +static const struct pt_iommu_amdv1_cfg amdv1_kunit_fmt_cfgs[] = { + /* Matches what io_pgtable does */ + [0] = { .starting_level = 2 }, +}; +#define kunit_fmt_cfgs amdv1_kunit_fmt_cfgs +enum { KUNIT_FMT_FEATURES = 0 }; +#endif + +#endif diff --git a/drivers/iommu/generic_pt/fmt/defs_amdv1.h b/drivers/iommu/generic_pt/fmt/defs_amdv1.h new file mode 100644 index 000000000000..0b9614ca6d10 --- /dev/null +++ b/drivers/iommu/generic_pt/fmt/defs_amdv1.h @@ -0,0 +1,21 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES + * + */ +#ifndef __GENERIC_PT_FMT_DEFS_AMDV1_H +#define __GENERIC_PT_FMT_DEFS_AMDV1_H + +#include <linux/generic_pt/common.h> +#include <linux/types.h> + +typedef u64 pt_vaddr_t; +typedef u64 pt_oaddr_t; + +struct amdv1pt_write_attrs { + u64 descriptor_bits; + gfp_t gfp; +}; +#define pt_write_attrs amdv1pt_write_attrs + +#endif diff --git a/drivers/iommu/generic_pt/fmt/defs_vtdss.h b/drivers/iommu/generic_pt/fmt/defs_vtdss.h new file mode 100644 index 000000000000..4a239bcaae2a --- /dev/null +++ b/drivers/iommu/generic_pt/fmt/defs_vtdss.h @@ -0,0 +1,21 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES + * + */ +#ifndef __GENERIC_PT_FMT_DEFS_VTDSS_H +#define __GENERIC_PT_FMT_DEFS_VTDSS_H + +#include <linux/generic_pt/common.h> +#include <linux/types.h> + +typedef u64 pt_vaddr_t; +typedef u64 pt_oaddr_t; + +struct vtdss_pt_write_attrs { + u64 descriptor_bits; + gfp_t gfp; +}; +#define pt_write_attrs vtdss_pt_write_attrs + +#endif diff --git a/drivers/iommu/generic_pt/fmt/defs_x86_64.h b/drivers/iommu/generic_pt/fmt/defs_x86_64.h new file mode 100644 index 000000000000..6f589e1f55d3 --- /dev/null +++ b/drivers/iommu/generic_pt/fmt/defs_x86_64.h @@ -0,0 +1,21 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES + * + */ +#ifndef __GENERIC_PT_FMT_DEFS_X86_64_H +#define __GENERIC_PT_FMT_DEFS_X86_64_H + +#include <linux/generic_pt/common.h> +#include <linux/types.h> + +typedef u64 
pt_vaddr_t; +typedef u64 pt_oaddr_t; + +struct x86_64_pt_write_attrs { + u64 descriptor_bits; + gfp_t gfp; +}; +#define pt_write_attrs x86_64_pt_write_attrs + +#endif diff --git a/drivers/iommu/generic_pt/fmt/iommu_amdv1.c b/drivers/iommu/generic_pt/fmt/iommu_amdv1.c new file mode 100644 index 000000000000..72a2337d0c55 --- /dev/null +++ b/drivers/iommu/generic_pt/fmt/iommu_amdv1.c @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES + */ +#define PT_FMT amdv1 +#define PT_SUPPORTED_FEATURES \ + (BIT(PT_FEAT_FULL_VA) | BIT(PT_FEAT_DYNAMIC_TOP) | \ + BIT(PT_FEAT_FLUSH_RANGE) | BIT(PT_FEAT_FLUSH_RANGE_NO_GAPS) | \ + BIT(PT_FEAT_AMDV1_ENCRYPT_TABLES) | \ + BIT(PT_FEAT_AMDV1_FORCE_COHERENCE)) +#define PT_FORCE_ENABLED_FEATURES \ + (BIT(PT_FEAT_DYNAMIC_TOP) | BIT(PT_FEAT_AMDV1_ENCRYPT_TABLES) | \ + BIT(PT_FEAT_AMDV1_FORCE_COHERENCE)) + +#include "iommu_template.h" diff --git a/drivers/iommu/generic_pt/fmt/iommu_mock.c b/drivers/iommu/generic_pt/fmt/iommu_mock.c new file mode 100644 index 000000000000..74e597cba9d9 --- /dev/null +++ b/drivers/iommu/generic_pt/fmt/iommu_mock.c @@ -0,0 +1,10 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES + */ +#define AMDV1_IOMMUFD_SELFTEST 1 +#define PT_FMT amdv1 +#define PT_FMT_VARIANT mock +#define PT_SUPPORTED_FEATURES 0 + +#include "iommu_template.h" diff --git a/drivers/iommu/generic_pt/fmt/iommu_template.h b/drivers/iommu/generic_pt/fmt/iommu_template.h new file mode 100644 index 000000000000..d28e86abdf2e --- /dev/null +++ b/drivers/iommu/generic_pt/fmt/iommu_template.h @@ -0,0 +1,48 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES + * + * Template to build the iommu module and kunit from the format and + * implementation headers. + * + * The format should have: + * #define PT_FMT <name> + * #define PT_SUPPORTED_FEATURES (BIT(PT_FEAT_xx) | BIT(PT_FEAT_yy)) + * And optionally: + * #define PT_FORCE_ENABLED_FEATURES .. + * #define PT_FMT_VARIANT <suffix> + */ +#include <linux/args.h> +#include <linux/stringify.h> + +#ifdef PT_FMT_VARIANT +#define PTPFX_RAW \ + CONCATENATE(CONCATENATE(PT_FMT, _), PT_FMT_VARIANT) +#else +#define PTPFX_RAW PT_FMT +#endif + +#define PTPFX CONCATENATE(PTPFX_RAW, _) + +#define _PT_FMT_H PT_FMT.h +#define PT_FMT_H __stringify(_PT_FMT_H) + +#define _PT_DEFS_H CONCATENATE(defs_, _PT_FMT_H) +#define PT_DEFS_H __stringify(_PT_DEFS_H) + +#include <linux/generic_pt/common.h> +#include PT_DEFS_H +#include "../pt_defs.h" +#include PT_FMT_H +#include "../pt_common.h" + +#ifndef GENERIC_PT_KUNIT +#include "../iommu_pt.h" +#else +/* + * The makefile will compile the .c file twice, once with GENERIC_PT_KUNIT set + * which means we are building the kunit module. 
+ */ +#include "../kunit_generic_pt.h" +#include "../kunit_iommu_pt.h" +#endif diff --git a/drivers/iommu/generic_pt/fmt/iommu_vtdss.c b/drivers/iommu/generic_pt/fmt/iommu_vtdss.c new file mode 100644 index 000000000000..f551711e2a33 --- /dev/null +++ b/drivers/iommu/generic_pt/fmt/iommu_vtdss.c @@ -0,0 +1,10 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES + */ +#define PT_FMT vtdss +#define PT_SUPPORTED_FEATURES \ + (BIT(PT_FEAT_FLUSH_RANGE) | BIT(PT_FEAT_VTDSS_FORCE_COHERENCE) | \ + BIT(PT_FEAT_VTDSS_FORCE_WRITEABLE) | BIT(PT_FEAT_DMA_INCOHERENT)) + +#include "iommu_template.h" diff --git a/drivers/iommu/generic_pt/fmt/iommu_x86_64.c b/drivers/iommu/generic_pt/fmt/iommu_x86_64.c new file mode 100644 index 000000000000..5472660c2d71 --- /dev/null +++ b/drivers/iommu/generic_pt/fmt/iommu_x86_64.c @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES + */ +#define PT_FMT x86_64 +#define PT_SUPPORTED_FEATURES \ + (BIT(PT_FEAT_SIGN_EXTEND) | BIT(PT_FEAT_FLUSH_RANGE) | \ + BIT(PT_FEAT_FLUSH_RANGE_NO_GAPS) | \ + BIT(PT_FEAT_X86_64_AMD_ENCRYPT_TABLES) | BIT(PT_FEAT_DMA_INCOHERENT)) + +#include "iommu_template.h" diff --git a/drivers/iommu/generic_pt/fmt/vtdss.h b/drivers/iommu/generic_pt/fmt/vtdss.h new file mode 100644 index 000000000000..f5f8981edde7 --- /dev/null +++ b/drivers/iommu/generic_pt/fmt/vtdss.h @@ -0,0 +1,285 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES + * + * Intel VT-d Second Stage 5/4 level page table + * + * This is described in + * Section "3.7 Second-Stage Translation" + * Section "9.8 Second-Stage Paging Entries" + * + * Of the "Intel Virtualization Technology for Directed I/O Architecture + * Specification". 
+ * + * The named levels in the spec map to the pts->level as: + * Table/SS-PTE - 0 + * Directory/SS-PDE - 1 + * Directory Ptr/SS-PDPTE - 2 + * PML4/SS-PML4E - 3 + * PML5/SS-PML5E - 4 + */ +#ifndef __GENERIC_PT_FMT_VTDSS_H +#define __GENERIC_PT_FMT_VTDSS_H + +#include "defs_vtdss.h" +#include "../pt_defs.h" + +#include <linux/bitfield.h> +#include <linux/container_of.h> +#include <linux/log2.h> + +enum { + PT_MAX_OUTPUT_ADDRESS_LG2 = 52, + PT_MAX_VA_ADDRESS_LG2 = 57, + PT_ITEM_WORD_SIZE = sizeof(u64), + PT_MAX_TOP_LEVEL = 4, + PT_GRANULE_LG2SZ = 12, + PT_TABLEMEM_LG2SZ = 12, + + /* SSPTPTR is 4k aligned and limited by HAW */ + PT_TOP_PHYS_MASK = GENMASK_ULL(63, 12), +}; + +/* Shared descriptor bits */ +enum { + VTDSS_FMT_R = BIT(0), + VTDSS_FMT_W = BIT(1), + VTDSS_FMT_A = BIT(8), + VTDSS_FMT_D = BIT(9), + VTDSS_FMT_SNP = BIT(11), + VTDSS_FMT_OA = GENMASK_ULL(51, 12), +}; + +/* PDPTE/PDE */ +enum { + VTDSS_FMT_PS = BIT(7), +}; + +#define common_to_vtdss_pt(common_ptr) \ + container_of_const(common_ptr, struct pt_vtdss, common) +#define to_vtdss_pt(pts) common_to_vtdss_pt((pts)->range->common) + +static inline pt_oaddr_t vtdss_pt_table_pa(const struct pt_state *pts) +{ + return oalog2_mul(FIELD_GET(VTDSS_FMT_OA, pts->entry), + PT_TABLEMEM_LG2SZ); +} +#define pt_table_pa vtdss_pt_table_pa + +static inline pt_oaddr_t vtdss_pt_entry_oa(const struct pt_state *pts) +{ + return oalog2_mul(FIELD_GET(VTDSS_FMT_OA, pts->entry), + PT_GRANULE_LG2SZ); +} +#define pt_entry_oa vtdss_pt_entry_oa + +static inline bool vtdss_pt_can_have_leaf(const struct pt_state *pts) +{ + return pts->level <= 2; +} +#define pt_can_have_leaf vtdss_pt_can_have_leaf + +static inline unsigned int vtdss_pt_num_items_lg2(const struct pt_state *pts) +{ + return PT_TABLEMEM_LG2SZ - ilog2(sizeof(u64)); +} +#define pt_num_items_lg2 vtdss_pt_num_items_lg2 + +static inline enum pt_entry_type vtdss_pt_load_entry_raw(struct pt_state *pts) +{ + const u64 *tablep = pt_cur_table(pts, u64); + u64 entry; + + pts->entry = entry = READ_ONCE(tablep[pts->index]); + if (!entry) + return PT_ENTRY_EMPTY; + if (pts->level == 0 || + (vtdss_pt_can_have_leaf(pts) && (pts->entry & VTDSS_FMT_PS))) + return PT_ENTRY_OA; + return PT_ENTRY_TABLE; +} +#define pt_load_entry_raw vtdss_pt_load_entry_raw + +static inline void +vtdss_pt_install_leaf_entry(struct pt_state *pts, pt_oaddr_t oa, + unsigned int oasz_lg2, + const struct pt_write_attrs *attrs) +{ + u64 *tablep = pt_cur_table(pts, u64); + u64 entry; + + if (!pt_check_install_leaf_args(pts, oa, oasz_lg2)) + return; + + entry = FIELD_PREP(VTDSS_FMT_OA, log2_div(oa, PT_GRANULE_LG2SZ)) | + attrs->descriptor_bits; + if (pts->level != 0) + entry |= VTDSS_FMT_PS; + + WRITE_ONCE(tablep[pts->index], entry); + pts->entry = entry; +} +#define pt_install_leaf_entry vtdss_pt_install_leaf_entry + +static inline bool vtdss_pt_install_table(struct pt_state *pts, + pt_oaddr_t table_pa, + const struct pt_write_attrs *attrs) +{ + u64 entry; + + entry = VTDSS_FMT_R | VTDSS_FMT_W | + FIELD_PREP(VTDSS_FMT_OA, log2_div(table_pa, PT_GRANULE_LG2SZ)); + return pt_table_install64(pts, entry); +} +#define pt_install_table vtdss_pt_install_table + +static inline void vtdss_pt_attr_from_entry(const struct pt_state *pts, + struct pt_write_attrs *attrs) +{ + attrs->descriptor_bits = pts->entry & + (VTDSS_FMT_R | VTDSS_FMT_W | VTDSS_FMT_SNP); +} +#define pt_attr_from_entry vtdss_pt_attr_from_entry + +static inline bool vtdss_pt_entry_is_write_dirty(const struct pt_state *pts) +{ + u64 *tablep = pt_cur_table(pts, u64) + pts->index; + + 
return READ_ONCE(*tablep) & VTDSS_FMT_D; +} +#define pt_entry_is_write_dirty vtdss_pt_entry_is_write_dirty + +static inline void vtdss_pt_entry_make_write_clean(struct pt_state *pts) +{ + u64 *tablep = pt_cur_table(pts, u64) + pts->index; + + WRITE_ONCE(*tablep, READ_ONCE(*tablep) & ~(u64)VTDSS_FMT_D); +} +#define pt_entry_make_write_clean vtdss_pt_entry_make_write_clean + +static inline bool vtdss_pt_entry_make_write_dirty(struct pt_state *pts) +{ + u64 *tablep = pt_cur_table(pts, u64) + pts->index; + u64 new = pts->entry | VTDSS_FMT_D; + + return try_cmpxchg64(tablep, &pts->entry, new); +} +#define pt_entry_make_write_dirty vtdss_pt_entry_make_write_dirty + +static inline unsigned int vtdss_pt_max_sw_bit(struct pt_common *common) +{ + return 10; +} +#define pt_max_sw_bit vtdss_pt_max_sw_bit + +static inline u64 vtdss_pt_sw_bit(unsigned int bitnr) +{ + if (__builtin_constant_p(bitnr) && bitnr > 10) + BUILD_BUG(); + + /* Bits marked Ignored in the specification */ + switch (bitnr) { + case 0: + return BIT(10); + case 1 ... 9: + return BIT_ULL((bitnr - 1) + 52); + case 10: + return BIT_ULL(63); + /* Some bits in 9-3 are available in some entries */ + default: + PT_WARN_ON(true); + return 0; + } +} +#define pt_sw_bit vtdss_pt_sw_bit + +/* --- iommu */ +#include <linux/generic_pt/iommu.h> +#include <linux/iommu.h> + +#define pt_iommu_table pt_iommu_vtdss + +/* The common struct is in the per-format common struct */ +static inline struct pt_common *common_from_iommu(struct pt_iommu *iommu_table) +{ + return &container_of(iommu_table, struct pt_iommu_table, iommu) + ->vtdss_pt.common; +} + +static inline struct pt_iommu *iommu_from_common(struct pt_common *common) +{ + return &container_of(common, struct pt_iommu_table, vtdss_pt.common) + ->iommu; +} + +static inline int vtdss_pt_iommu_set_prot(struct pt_common *common, + struct pt_write_attrs *attrs, + unsigned int iommu_prot) +{ + u64 pte = 0; + + /* + * VTDSS does not have a present bit, so we tell if any entry is present + * by checking for R or W. 
+ */ + if (!(iommu_prot & (IOMMU_READ | IOMMU_WRITE))) + return -EINVAL; + + if (iommu_prot & IOMMU_READ) + pte |= VTDSS_FMT_R; + if (iommu_prot & IOMMU_WRITE) + pte |= VTDSS_FMT_W; + if (pt_feature(common, PT_FEAT_VTDSS_FORCE_COHERENCE)) + pte |= VTDSS_FMT_SNP; + + if (pt_feature(common, PT_FEAT_VTDSS_FORCE_WRITEABLE) && + !(iommu_prot & IOMMU_WRITE)) { + pr_err_ratelimited( + "Read-only mapping is disallowed on the domain which serves as the parent in a nested configuration, due to HW errata (ERRATA_772415_SPR17)\n"); + return -EINVAL; + } + + attrs->descriptor_bits = pte; + return 0; +} +#define pt_iommu_set_prot vtdss_pt_iommu_set_prot + +static inline int vtdss_pt_iommu_fmt_init(struct pt_iommu_vtdss *iommu_table, + const struct pt_iommu_vtdss_cfg *cfg) +{ + struct pt_vtdss *table = &iommu_table->vtdss_pt; + + if (cfg->top_level > 4 || cfg->top_level < 2) + return -EOPNOTSUPP; + + pt_top_set_level(&table->common, cfg->top_level); + return 0; +} +#define pt_iommu_fmt_init vtdss_pt_iommu_fmt_init + +static inline void +vtdss_pt_iommu_fmt_hw_info(struct pt_iommu_vtdss *table, + const struct pt_range *top_range, + struct pt_iommu_vtdss_hw_info *info) +{ + info->ssptptr = virt_to_phys(top_range->top_table); + PT_WARN_ON(info->ssptptr & ~PT_TOP_PHYS_MASK); + /* + * top_level = 2 = 3 level table aw=1 + * top_level = 3 = 4 level table aw=2 + * top_level = 4 = 5 level table aw=3 + */ + info->aw = top_range->top_level - 1; +} +#define pt_iommu_fmt_hw_info vtdss_pt_iommu_fmt_hw_info + +#if defined(GENERIC_PT_KUNIT) +static const struct pt_iommu_vtdss_cfg vtdss_kunit_fmt_cfgs[] = { + [0] = { .common.hw_max_vasz_lg2 = 39, .top_level = 2}, + [1] = { .common.hw_max_vasz_lg2 = 48, .top_level = 3}, + [2] = { .common.hw_max_vasz_lg2 = 57, .top_level = 4}, +}; +#define kunit_fmt_cfgs vtdss_kunit_fmt_cfgs +enum { KUNIT_FMT_FEATURES = BIT(PT_FEAT_VTDSS_FORCE_WRITEABLE) }; +#endif +#endif diff --git a/drivers/iommu/generic_pt/fmt/x86_64.h b/drivers/iommu/generic_pt/fmt/x86_64.h new file mode 100644 index 000000000000..210748d9d6e8 --- /dev/null +++ b/drivers/iommu/generic_pt/fmt/x86_64.h @@ -0,0 +1,279 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES + * + * x86 page table. Supports the 4 and 5 level variations. + * + * The 4 and 5 level version is described in: + * Section "4.4 4-Level Paging and 5-Level Paging" of the Intel Software + * Developer's Manual Volume 3 + * + * Section "9.7 First-Stage Paging Entries" of the "Intel Virtualization + * Technology for Directed I/O Architecture Specification" + * + * Section "2.2.6 I/O Page Tables for Guest Translations" of the "AMD I/O + * Virtualization Technology (IOMMU) Specification" + * + * It is used by x86 CPUs, AMD and VT-d IOMMU HW. + * + * Note the 3 level format is very similar and almost implemented here. The + * reserved/ignored layout is different and there are functional bit + * differences. + * + * This format uses PT_FEAT_SIGN_EXTEND to have a upper/non-canonical/lower + * split. PT_FEAT_SIGN_EXTEND is optional as AMD IOMMU sometimes uses non-sign + * extended addressing with this page table format. 
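+ *
+ * As an illustration of the split (standard x86 canonical addressing, using
+ * the 4 level case): with PT_FEAT_SIGN_EXTEND a 48 bit table covers the
+ * lower half 0 - 0x7fff_ffff_ffff and the sign extended upper half
+ * 0xffff_8000_0000_0000 - 0xffff_ffff_ffff_ffff, while the non-sign
+ * extended AMD use covers a flat 0 - 2^47 - 1 range (see the kunit
+ * configurations at the end of this file).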
+ * + * The named levels in the spec map to the pts->level as: + * Table/PTE - 0 + * Directory/PDE - 1 + * Directory Ptr/PDPTE - 2 + * PML4/PML4E - 3 + * PML5/PML5E - 4 + */ +#ifndef __GENERIC_PT_FMT_X86_64_H +#define __GENERIC_PT_FMT_X86_64_H + +#include "defs_x86_64.h" +#include "../pt_defs.h" + +#include <linux/bitfield.h> +#include <linux/container_of.h> +#include <linux/log2.h> +#include <linux/mem_encrypt.h> + +enum { + PT_MAX_OUTPUT_ADDRESS_LG2 = 52, + PT_MAX_VA_ADDRESS_LG2 = 57, + PT_ITEM_WORD_SIZE = sizeof(u64), + PT_MAX_TOP_LEVEL = 4, + PT_GRANULE_LG2SZ = 12, + PT_TABLEMEM_LG2SZ = 12, + + /* + * For AMD the GCR3 Base only has these bits. For VT-d FSPTPTR is 4k + * aligned and is limited by the architected HAW + */ + PT_TOP_PHYS_MASK = GENMASK_ULL(51, 12), +}; + +/* Shared descriptor bits */ +enum { + X86_64_FMT_P = BIT(0), + X86_64_FMT_RW = BIT(1), + X86_64_FMT_U = BIT(2), + X86_64_FMT_A = BIT(5), + X86_64_FMT_D = BIT(6), + X86_64_FMT_OA = GENMASK_ULL(51, 12), + X86_64_FMT_XD = BIT_ULL(63), +}; + +/* PDPTE/PDE */ +enum { + X86_64_FMT_PS = BIT(7), +}; + +static inline pt_oaddr_t x86_64_pt_table_pa(const struct pt_state *pts) +{ + u64 entry = pts->entry; + + if (pts_feature(pts, PT_FEAT_X86_64_AMD_ENCRYPT_TABLES)) + entry = __sme_clr(entry); + return oalog2_mul(FIELD_GET(X86_64_FMT_OA, entry), + PT_TABLEMEM_LG2SZ); +} +#define pt_table_pa x86_64_pt_table_pa + +static inline pt_oaddr_t x86_64_pt_entry_oa(const struct pt_state *pts) +{ + u64 entry = pts->entry; + + if (pts_feature(pts, PT_FEAT_X86_64_AMD_ENCRYPT_TABLES)) + entry = __sme_clr(entry); + return oalog2_mul(FIELD_GET(X86_64_FMT_OA, entry), + PT_GRANULE_LG2SZ); +} +#define pt_entry_oa x86_64_pt_entry_oa + +static inline bool x86_64_pt_can_have_leaf(const struct pt_state *pts) +{ + return pts->level <= 2; +} +#define pt_can_have_leaf x86_64_pt_can_have_leaf + +static inline unsigned int x86_64_pt_num_items_lg2(const struct pt_state *pts) +{ + return PT_TABLEMEM_LG2SZ - ilog2(sizeof(u64)); +} +#define pt_num_items_lg2 x86_64_pt_num_items_lg2 + +static inline enum pt_entry_type x86_64_pt_load_entry_raw(struct pt_state *pts) +{ + const u64 *tablep = pt_cur_table(pts, u64); + u64 entry; + + pts->entry = entry = READ_ONCE(tablep[pts->index]); + if (!(entry & X86_64_FMT_P)) + return PT_ENTRY_EMPTY; + if (pts->level == 0 || + (x86_64_pt_can_have_leaf(pts) && (entry & X86_64_FMT_PS))) + return PT_ENTRY_OA; + return PT_ENTRY_TABLE; +} +#define pt_load_entry_raw x86_64_pt_load_entry_raw + +static inline void +x86_64_pt_install_leaf_entry(struct pt_state *pts, pt_oaddr_t oa, + unsigned int oasz_lg2, + const struct pt_write_attrs *attrs) +{ + u64 *tablep = pt_cur_table(pts, u64); + u64 entry; + + if (!pt_check_install_leaf_args(pts, oa, oasz_lg2)) + return; + + entry = X86_64_FMT_P | + FIELD_PREP(X86_64_FMT_OA, log2_div(oa, PT_GRANULE_LG2SZ)) | + attrs->descriptor_bits; + if (pts->level != 0) + entry |= X86_64_FMT_PS; + + WRITE_ONCE(tablep[pts->index], entry); + pts->entry = entry; +} +#define pt_install_leaf_entry x86_64_pt_install_leaf_entry + +static inline bool x86_64_pt_install_table(struct pt_state *pts, + pt_oaddr_t table_pa, + const struct pt_write_attrs *attrs) +{ + u64 entry; + + entry = X86_64_FMT_P | X86_64_FMT_RW | X86_64_FMT_U | X86_64_FMT_A | + FIELD_PREP(X86_64_FMT_OA, log2_div(table_pa, PT_GRANULE_LG2SZ)); + if (pts_feature(pts, PT_FEAT_X86_64_AMD_ENCRYPT_TABLES)) + entry = __sme_set(entry); + return pt_table_install64(pts, entry); +} +#define pt_install_table x86_64_pt_install_table + +static inline void 
x86_64_pt_attr_from_entry(const struct pt_state *pts, + struct pt_write_attrs *attrs) +{ + attrs->descriptor_bits = pts->entry & + (X86_64_FMT_RW | X86_64_FMT_U | X86_64_FMT_A | + X86_64_FMT_D | X86_64_FMT_XD); +} +#define pt_attr_from_entry x86_64_pt_attr_from_entry + +static inline unsigned int x86_64_pt_max_sw_bit(struct pt_common *common) +{ + return 12; +} +#define pt_max_sw_bit x86_64_pt_max_sw_bit + +static inline u64 x86_64_pt_sw_bit(unsigned int bitnr) +{ + if (__builtin_constant_p(bitnr) && bitnr > 12) + BUILD_BUG(); + + /* Bits marked Ignored/AVL in the specification */ + switch (bitnr) { + case 0: + return BIT(9); + case 1: + return BIT(11); + case 2 ... 12: + return BIT_ULL((bitnr - 2) + 52); + /* Some bits in 8,6,4,3 are available in some entries */ + default: + PT_WARN_ON(true); + return 0; + } +} +#define pt_sw_bit x86_64_pt_sw_bit + +/* --- iommu */ +#include <linux/generic_pt/iommu.h> +#include <linux/iommu.h> + +#define pt_iommu_table pt_iommu_x86_64 + +/* The common struct is in the per-format common struct */ +static inline struct pt_common *common_from_iommu(struct pt_iommu *iommu_table) +{ + return &container_of(iommu_table, struct pt_iommu_table, iommu) + ->x86_64_pt.common; +} + +static inline struct pt_iommu *iommu_from_common(struct pt_common *common) +{ + return &container_of(common, struct pt_iommu_table, x86_64_pt.common) + ->iommu; +} + +static inline int x86_64_pt_iommu_set_prot(struct pt_common *common, + struct pt_write_attrs *attrs, + unsigned int iommu_prot) +{ + u64 pte; + + pte = X86_64_FMT_U | X86_64_FMT_A; + if (iommu_prot & IOMMU_WRITE) + pte |= X86_64_FMT_RW | X86_64_FMT_D; + + /* + * Ideally we'd have an IOMMU_ENCRYPTED flag set by higher levels to + * control this. For now if the tables use sme_set then so do the ptes. 
+ */ + if (pt_feature(common, PT_FEAT_X86_64_AMD_ENCRYPT_TABLES)) + pte = __sme_set(pte); + + attrs->descriptor_bits = pte; + return 0; +} +#define pt_iommu_set_prot x86_64_pt_iommu_set_prot + +static inline int +x86_64_pt_iommu_fmt_init(struct pt_iommu_x86_64 *iommu_table, + const struct pt_iommu_x86_64_cfg *cfg) +{ + struct pt_x86_64 *table = &iommu_table->x86_64_pt; + + if (cfg->top_level < 3 || cfg->top_level > 4) + return -EOPNOTSUPP; + + pt_top_set_level(&table->common, cfg->top_level); + + table->common.max_oasz_lg2 = + min(PT_MAX_OUTPUT_ADDRESS_LG2, cfg->common.hw_max_oasz_lg2); + return 0; +} +#define pt_iommu_fmt_init x86_64_pt_iommu_fmt_init + +static inline void +x86_64_pt_iommu_fmt_hw_info(struct pt_iommu_x86_64 *table, + const struct pt_range *top_range, + struct pt_iommu_x86_64_hw_info *info) +{ + info->gcr3_pt = virt_to_phys(top_range->top_table); + PT_WARN_ON(info->gcr3_pt & ~PT_TOP_PHYS_MASK); + info->levels = top_range->top_level + 1; +} +#define pt_iommu_fmt_hw_info x86_64_pt_iommu_fmt_hw_info + +#if defined(GENERIC_PT_KUNIT) +static const struct pt_iommu_x86_64_cfg x86_64_kunit_fmt_cfgs[] = { + [0] = { .common.features = BIT(PT_FEAT_SIGN_EXTEND), + .common.hw_max_vasz_lg2 = 48, .top_level = 3 }, + [1] = { .common.features = BIT(PT_FEAT_SIGN_EXTEND), + .common.hw_max_vasz_lg2 = 57, .top_level = 4 }, + /* AMD IOMMU PASID 0 formats with no SIGN_EXTEND */ + [2] = { .common.hw_max_vasz_lg2 = 47, .top_level = 3 }, + [3] = { .common.hw_max_vasz_lg2 = 56, .top_level = 4}, +}; +#define kunit_fmt_cfgs x86_64_kunit_fmt_cfgs +enum { KUNIT_FMT_FEATURES = BIT(PT_FEAT_SIGN_EXTEND)}; +#endif +#endif diff --git a/drivers/iommu/generic_pt/iommu_pt.h b/drivers/iommu/generic_pt/iommu_pt.h new file mode 100644 index 000000000000..97aeda1ad01c --- /dev/null +++ b/drivers/iommu/generic_pt/iommu_pt.h @@ -0,0 +1,1289 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES + * + * "Templated C code" for implementing the iommu operations for page tables. + * This is compiled multiple times, over all the page table formats to pick up + * the per-format definitions. 
+ */ +#ifndef __GENERIC_PT_IOMMU_PT_H +#define __GENERIC_PT_IOMMU_PT_H + +#include "pt_iter.h" + +#include <linux/export.h> +#include <linux/iommu.h> +#include "../iommu-pages.h" +#include <linux/cleanup.h> +#include <linux/dma-mapping.h> + +enum { + SW_BIT_CACHE_FLUSH_DONE = 0, +}; + +static void flush_writes_range(const struct pt_state *pts, + unsigned int start_index, unsigned int end_index) +{ + if (pts_feature(pts, PT_FEAT_DMA_INCOHERENT)) + iommu_pages_flush_incoherent( + iommu_from_common(pts->range->common)->iommu_device, + pts->table, start_index * PT_ITEM_WORD_SIZE, + (end_index - start_index) * PT_ITEM_WORD_SIZE); +} + +static void flush_writes_item(const struct pt_state *pts) +{ + if (pts_feature(pts, PT_FEAT_DMA_INCOHERENT)) + iommu_pages_flush_incoherent( + iommu_from_common(pts->range->common)->iommu_device, + pts->table, pts->index * PT_ITEM_WORD_SIZE, + PT_ITEM_WORD_SIZE); +} + +static void gather_range_pages(struct iommu_iotlb_gather *iotlb_gather, + struct pt_iommu *iommu_table, pt_vaddr_t iova, + pt_vaddr_t len, + struct iommu_pages_list *free_list) +{ + struct pt_common *common = common_from_iommu(iommu_table); + + if (pt_feature(common, PT_FEAT_DMA_INCOHERENT)) + iommu_pages_stop_incoherent_list(free_list, + iommu_table->iommu_device); + + if (pt_feature(common, PT_FEAT_FLUSH_RANGE_NO_GAPS) && + iommu_iotlb_gather_is_disjoint(iotlb_gather, iova, len)) { + iommu_iotlb_sync(&iommu_table->domain, iotlb_gather); + /* + * Note that the sync frees the gather's free list, so we must + * not have any pages on that list that are covered by iova/len + */ + } else if (pt_feature(common, PT_FEAT_FLUSH_RANGE)) { + iommu_iotlb_gather_add_range(iotlb_gather, iova, len); + } + + iommu_pages_list_splice(free_list, &iotlb_gather->freelist); +} + +#define DOMAIN_NS(op) CONCATENATE(CONCATENATE(pt_iommu_, PTPFX), op) + +static int make_range_ul(struct pt_common *common, struct pt_range *range, + unsigned long iova, unsigned long len) +{ + unsigned long last; + + if (unlikely(len == 0)) + return -EINVAL; + + if (check_add_overflow(iova, len - 1, &last)) + return -EOVERFLOW; + + *range = pt_make_range(common, iova, last); + if (sizeof(iova) > sizeof(range->va)) { + if (unlikely(range->va != iova || range->last_va != last)) + return -EOVERFLOW; + } + return 0; +} + +static __maybe_unused int make_range_u64(struct pt_common *common, + struct pt_range *range, u64 iova, + u64 len) +{ + if (unlikely(iova > ULONG_MAX || len > ULONG_MAX)) + return -EOVERFLOW; + return make_range_ul(common, range, iova, len); +} + +/* + * Some APIs use unsigned long, while othersuse dma_addr_t as the type. Dispatch + * to the correct validation based on the type. + */ +#define make_range_no_check(common, range, iova, len) \ + ({ \ + int ret; \ + if (sizeof(iova) > sizeof(unsigned long) || \ + sizeof(len) > sizeof(unsigned long)) \ + ret = make_range_u64(common, range, iova, len); \ + else \ + ret = make_range_ul(common, range, iova, len); \ + ret; \ + }) + +#define make_range(common, range, iova, len) \ + ({ \ + int ret = make_range_no_check(common, range, iova, len); \ + if (!ret) \ + ret = pt_check_range(range); \ + ret; \ + }) + +static inline unsigned int compute_best_pgsize(struct pt_state *pts, + pt_oaddr_t oa) +{ + struct pt_iommu *iommu_table = iommu_from_common(pts->range->common); + + if (!pt_can_have_leaf(pts)) + return 0; + + /* + * The page size is limited by the domain's bitmap. This allows the core + * code to reduce the supported page sizes by changing the bitmap. 
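+ *
+ * For example (illustrative numbers): a request whose IOVA and physical
+ * address are both 2M aligned with at least 2M remaining before last_va can
+ * use a 2M leaf when the bitmap allows it; if the physical address is only
+ * 4K aligned the larger sizes are skipped, and 0 is returned when nothing
+ * this level can express fits.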
+ */ + return pt_compute_best_pgsize(pt_possible_sizes(pts) & + iommu_table->domain.pgsize_bitmap, + pts->range->va, pts->range->last_va, oa); +} + +static __always_inline int __do_iova_to_phys(struct pt_range *range, void *arg, + unsigned int level, + struct pt_table_p *table, + pt_level_fn_t descend_fn) +{ + struct pt_state pts = pt_init(range, level, table); + pt_oaddr_t *res = arg; + + switch (pt_load_single_entry(&pts)) { + case PT_ENTRY_EMPTY: + return -ENOENT; + case PT_ENTRY_TABLE: + return pt_descend(&pts, arg, descend_fn); + case PT_ENTRY_OA: + *res = pt_entry_oa_exact(&pts); + return 0; + } + return -ENOENT; +} +PT_MAKE_LEVELS(__iova_to_phys, __do_iova_to_phys); + +/** + * iova_to_phys() - Return the output address for the given IOVA + * @domain: Table to query + * @iova: IO virtual address to query + * + * Determine the output address from the given IOVA. @iova may have any + * alignment, the returned physical will be adjusted with any sub page offset. + * + * Context: The caller must hold a read range lock that includes @iova. + * + * Return: 0 if there is no translation for the given iova. + */ +phys_addr_t DOMAIN_NS(iova_to_phys)(struct iommu_domain *domain, + dma_addr_t iova) +{ + struct pt_iommu *iommu_table = + container_of(domain, struct pt_iommu, domain); + struct pt_range range; + pt_oaddr_t res; + int ret; + + ret = make_range(common_from_iommu(iommu_table), &range, iova, 1); + if (ret) + return ret; + + ret = pt_walk_range(&range, __iova_to_phys, &res); + /* PHYS_ADDR_MAX would be a better error code */ + if (ret) + return 0; + return res; +} +EXPORT_SYMBOL_NS_GPL(DOMAIN_NS(iova_to_phys), "GENERIC_PT_IOMMU"); + +struct pt_iommu_dirty_args { + struct iommu_dirty_bitmap *dirty; + unsigned int flags; +}; + +static void record_dirty(struct pt_state *pts, + struct pt_iommu_dirty_args *dirty, + unsigned int num_contig_lg2) +{ + pt_vaddr_t dirty_len; + + if (num_contig_lg2 != ilog2(1)) { + unsigned int index = pts->index; + unsigned int end_index = log2_set_mod_max_t( + unsigned int, pts->index, num_contig_lg2); + + /* Adjust for being contained inside a contiguous page */ + end_index = min(end_index, pts->end_index); + dirty_len = (end_index - index) * + log2_to_int(pt_table_item_lg2sz(pts)); + } else { + dirty_len = log2_to_int(pt_table_item_lg2sz(pts)); + } + + if (dirty->dirty->bitmap) + iova_bitmap_set(dirty->dirty->bitmap, pts->range->va, + dirty_len); + + if (!(dirty->flags & IOMMU_DIRTY_NO_CLEAR)) { + /* + * No write log required because DMA incoherence and atomic + * dirty tracking bits can't work together + */ + pt_entry_make_write_clean(pts); + iommu_iotlb_gather_add_range(dirty->dirty->gather, + pts->range->va, dirty_len); + } +} + +static inline int __read_and_clear_dirty(struct pt_range *range, void *arg, + unsigned int level, + struct pt_table_p *table) +{ + struct pt_state pts = pt_init(range, level, table); + struct pt_iommu_dirty_args *dirty = arg; + int ret; + + for_each_pt_level_entry(&pts) { + if (pts.type == PT_ENTRY_TABLE) { + ret = pt_descend(&pts, arg, __read_and_clear_dirty); + if (ret) + return ret; + continue; + } + if (pts.type == PT_ENTRY_OA && pt_entry_is_write_dirty(&pts)) + record_dirty(&pts, dirty, + pt_entry_num_contig_lg2(&pts)); + } + return 0; +} + +/** + * read_and_clear_dirty() - Manipulate the HW set write dirty state + * @domain: Domain to manipulate + * @iova: IO virtual address to start + * @size: Length of the IOVA + * @flags: A bitmap of IOMMU_DIRTY_NO_CLEAR + * @dirty: Place to store the dirty bits + * + * Iterate over all the 
entries in the mapped range and record their write dirty + * status in iommu_dirty_bitmap. If IOMMU_DIRTY_NO_CLEAR is not specified then + * the entries will be left dirty, otherwise they are returned to being not + * write dirty. + * + * Context: The caller must hold a read range lock that includes @iova. + * + * Returns: -ERRNO on failure, 0 on success. + */ +int DOMAIN_NS(read_and_clear_dirty)(struct iommu_domain *domain, + unsigned long iova, size_t size, + unsigned long flags, + struct iommu_dirty_bitmap *dirty) +{ + struct pt_iommu *iommu_table = + container_of(domain, struct pt_iommu, domain); + struct pt_iommu_dirty_args dirty_args = { + .dirty = dirty, + .flags = flags, + }; + struct pt_range range; + int ret; + +#if !IS_ENABLED(CONFIG_IOMMUFD_DRIVER) || !defined(pt_entry_is_write_dirty) + return -EOPNOTSUPP; +#endif + + ret = make_range(common_from_iommu(iommu_table), &range, iova, size); + if (ret) + return ret; + + ret = pt_walk_range(&range, __read_and_clear_dirty, &dirty_args); + PT_WARN_ON(ret); + return ret; +} +EXPORT_SYMBOL_NS_GPL(DOMAIN_NS(read_and_clear_dirty), "GENERIC_PT_IOMMU"); + +static inline int __set_dirty(struct pt_range *range, void *arg, + unsigned int level, struct pt_table_p *table) +{ + struct pt_state pts = pt_init(range, level, table); + + switch (pt_load_single_entry(&pts)) { + case PT_ENTRY_EMPTY: + return -ENOENT; + case PT_ENTRY_TABLE: + return pt_descend(&pts, arg, __set_dirty); + case PT_ENTRY_OA: + if (!pt_entry_make_write_dirty(&pts)) + return -EAGAIN; + return 0; + } + return -ENOENT; +} + +static int __maybe_unused NS(set_dirty)(struct pt_iommu *iommu_table, + dma_addr_t iova) +{ + struct pt_range range; + int ret; + + ret = make_range(common_from_iommu(iommu_table), &range, iova, 1); + if (ret) + return ret; + + /* + * Note: There is no locking here yet, if the test suite races this it + * can crash. It should use RCU locking eventually. + */ + return pt_walk_range(&range, __set_dirty, NULL); +} + +struct pt_iommu_collect_args { + struct iommu_pages_list free_list; + /* Fail if any OAs are within the range */ + u8 check_mapped : 1; +}; + +static int __collect_tables(struct pt_range *range, void *arg, + unsigned int level, struct pt_table_p *table) +{ + struct pt_state pts = pt_init(range, level, table); + struct pt_iommu_collect_args *collect = arg; + int ret; + + if (!collect->check_mapped && !pt_can_have_table(&pts)) + return 0; + + for_each_pt_level_entry(&pts) { + if (pts.type == PT_ENTRY_TABLE) { + iommu_pages_list_add(&collect->free_list, pts.table_lower); + ret = pt_descend(&pts, arg, __collect_tables); + if (ret) + return ret; + continue; + } + if (pts.type == PT_ENTRY_OA && collect->check_mapped) + return -EADDRINUSE; + } + return 0; +} + +enum alloc_mode {ALLOC_NORMAL, ALLOC_DEFER_COHERENT_FLUSH}; + +/* Allocate a table, the empty table will be ready to be installed. 
*/ +static inline struct pt_table_p *_table_alloc(struct pt_common *common, + size_t lg2sz, gfp_t gfp, + enum alloc_mode mode) +{ + struct pt_iommu *iommu_table = iommu_from_common(common); + struct pt_table_p *table_mem; + + table_mem = iommu_alloc_pages_node_sz(iommu_table->nid, gfp, + log2_to_int(lg2sz)); + if (pt_feature(common, PT_FEAT_DMA_INCOHERENT) && + mode == ALLOC_NORMAL) { + int ret = iommu_pages_start_incoherent( + table_mem, iommu_table->iommu_device); + if (ret) { + iommu_free_pages(table_mem); + return ERR_PTR(ret); + } + } + return table_mem; +} + +static inline struct pt_table_p *table_alloc_top(struct pt_common *common, + uintptr_t top_of_table, + gfp_t gfp, + enum alloc_mode mode) +{ + /* + * Top doesn't need the free list or otherwise, so it technically + * doesn't need to use iommu pages. Use the API anyhow as the top is + * usually not smaller than PAGE_SIZE to keep things simple. + */ + return _table_alloc(common, pt_top_memsize_lg2(common, top_of_table), + gfp, mode); +} + +/* Allocate an interior table */ +static inline struct pt_table_p *table_alloc(const struct pt_state *parent_pts, + gfp_t gfp, enum alloc_mode mode) +{ + struct pt_state child_pts = + pt_init(parent_pts->range, parent_pts->level - 1, NULL); + + return _table_alloc(parent_pts->range->common, + pt_num_items_lg2(&child_pts) + + ilog2(PT_ITEM_WORD_SIZE), + gfp, mode); +} + +static inline int pt_iommu_new_table(struct pt_state *pts, + struct pt_write_attrs *attrs) +{ + struct pt_table_p *table_mem; + phys_addr_t phys; + + /* Given PA/VA/length can't be represented */ + if (PT_WARN_ON(!pt_can_have_table(pts))) + return -ENXIO; + + table_mem = table_alloc(pts, attrs->gfp, ALLOC_NORMAL); + if (IS_ERR(table_mem)) + return PTR_ERR(table_mem); + + phys = virt_to_phys(table_mem); + if (!pt_install_table(pts, phys, attrs)) { + iommu_pages_free_incoherent( + table_mem, + iommu_from_common(pts->range->common)->iommu_device); + return -EAGAIN; + } + + if (pts_feature(pts, PT_FEAT_DMA_INCOHERENT)) { + flush_writes_item(pts); + pt_set_sw_bit_release(pts, SW_BIT_CACHE_FLUSH_DONE); + } + + if (IS_ENABLED(CONFIG_DEBUG_GENERIC_PT)) { + /* + * The underlying table can't store the physical table address. + * This happens when kunit testing tables outside their normal + * environment where a CPU might be limited. + */ + pt_load_single_entry(pts); + if (PT_WARN_ON(pt_table_pa(pts) != phys)) { + pt_clear_entries(pts, ilog2(1)); + iommu_pages_free_incoherent( + table_mem, iommu_from_common(pts->range->common) + ->iommu_device); + return -EINVAL; + } + } + + pts->table_lower = table_mem; + return 0; +} + +struct pt_iommu_map_args { + struct iommu_iotlb_gather *iotlb_gather; + struct pt_write_attrs attrs; + pt_oaddr_t oa; + unsigned int leaf_pgsize_lg2; + unsigned int leaf_level; +}; + +/* + * This will recursively check any tables in the block to validate they are + * empty and then free them through the gather. 
+ */ +static int clear_contig(const struct pt_state *start_pts, + struct iommu_iotlb_gather *iotlb_gather, + unsigned int step, unsigned int pgsize_lg2) +{ + struct pt_iommu *iommu_table = + iommu_from_common(start_pts->range->common); + struct pt_range range = *start_pts->range; + struct pt_state pts = + pt_init(&range, start_pts->level, start_pts->table); + struct pt_iommu_collect_args collect = { .check_mapped = true }; + int ret; + + pts.index = start_pts->index; + pts.end_index = start_pts->index + step; + for (; _pt_iter_load(&pts); pt_next_entry(&pts)) { + if (pts.type == PT_ENTRY_TABLE) { + collect.free_list = + IOMMU_PAGES_LIST_INIT(collect.free_list); + ret = pt_walk_descend_all(&pts, __collect_tables, + &collect); + if (ret) + return ret; + + /* + * The table item must be cleared before we can update + * the gather + */ + pt_clear_entries(&pts, ilog2(1)); + flush_writes_item(&pts); + + iommu_pages_list_add(&collect.free_list, + pt_table_ptr(&pts)); + gather_range_pages( + iotlb_gather, iommu_table, range.va, + log2_to_int(pt_table_item_lg2sz(&pts)), + &collect.free_list); + } else if (pts.type != PT_ENTRY_EMPTY) { + return -EADDRINUSE; + } + } + return 0; +} + +static int __map_range_leaf(struct pt_range *range, void *arg, + unsigned int level, struct pt_table_p *table) +{ + struct pt_state pts = pt_init(range, level, table); + struct pt_iommu_map_args *map = arg; + unsigned int leaf_pgsize_lg2 = map->leaf_pgsize_lg2; + unsigned int start_index; + pt_oaddr_t oa = map->oa; + unsigned int step; + bool need_contig; + int ret = 0; + + PT_WARN_ON(map->leaf_level != level); + PT_WARN_ON(!pt_can_have_leaf(&pts)); + + step = log2_to_int_t(unsigned int, + leaf_pgsize_lg2 - pt_table_item_lg2sz(&pts)); + need_contig = leaf_pgsize_lg2 != pt_table_item_lg2sz(&pts); + + _pt_iter_first(&pts); + start_index = pts.index; + do { + pts.type = pt_load_entry_raw(&pts); + if (pts.type != PT_ENTRY_EMPTY || need_contig) { + if (pts.index != start_index) + pt_index_to_va(&pts); + ret = clear_contig(&pts, map->iotlb_gather, step, + leaf_pgsize_lg2); + if (ret) + break; + } + + if (IS_ENABLED(CONFIG_DEBUG_GENERIC_PT)) { + pt_index_to_va(&pts); + PT_WARN_ON(compute_best_pgsize(&pts, oa) != + leaf_pgsize_lg2); + } + pt_install_leaf_entry(&pts, oa, leaf_pgsize_lg2, &map->attrs); + + oa += log2_to_int(leaf_pgsize_lg2); + pts.index += step; + } while (pts.index < pts.end_index); + + flush_writes_range(&pts, start_index, pts.index); + + map->oa = oa; + return ret; +} + +static int __map_range(struct pt_range *range, void *arg, unsigned int level, + struct pt_table_p *table) +{ + struct pt_state pts = pt_init(range, level, table); + struct pt_iommu_map_args *map = arg; + int ret; + + PT_WARN_ON(map->leaf_level == level); + PT_WARN_ON(!pt_can_have_table(&pts)); + + _pt_iter_first(&pts); + + /* Descend to a child table */ + do { + pts.type = pt_load_entry_raw(&pts); + + if (pts.type != PT_ENTRY_TABLE) { + if (pts.type != PT_ENTRY_EMPTY) + return -EADDRINUSE; + ret = pt_iommu_new_table(&pts, &map->attrs); + if (ret) { + /* + * Racing with another thread installing a table + */ + if (ret == -EAGAIN) + continue; + return ret; + } + } else { + pts.table_lower = pt_table_ptr(&pts); + /* + * Racing with a shared pt_iommu_new_table()? The other + * thread is still flushing the cache, so we have to + * also flush it to ensure that when our thread's map + * completes all the table items leading to our mapping + * are visible. 
+ * + * This requires the pt_set_bit_release() to be a + * release of the cache flush so that this can acquire + * visibility at the iommu. + */ + if (pts_feature(&pts, PT_FEAT_DMA_INCOHERENT) && + !pt_test_sw_bit_acquire(&pts, + SW_BIT_CACHE_FLUSH_DONE)) + flush_writes_item(&pts); + } + + /* + * The already present table can possibly be shared with another + * concurrent map. + */ + if (map->leaf_level == level - 1) + ret = pt_descend(&pts, arg, __map_range_leaf); + else + ret = pt_descend(&pts, arg, __map_range); + if (ret) + return ret; + + pts.index++; + pt_index_to_va(&pts); + if (pts.index >= pts.end_index) + break; + } while (true); + return 0; +} + +/* + * Fast path for the easy case of mapping a 4k page to an already allocated + * table. This is a common workload. If it returns EAGAIN run the full algorithm + * instead. + */ +static __always_inline int __do_map_single_page(struct pt_range *range, + void *arg, unsigned int level, + struct pt_table_p *table, + pt_level_fn_t descend_fn) +{ + struct pt_state pts = pt_init(range, level, table); + struct pt_iommu_map_args *map = arg; + + pts.type = pt_load_single_entry(&pts); + if (level == 0) { + if (pts.type != PT_ENTRY_EMPTY) + return -EADDRINUSE; + pt_install_leaf_entry(&pts, map->oa, PAGE_SHIFT, + &map->attrs); + /* No flush, not used when incoherent */ + map->oa += PAGE_SIZE; + return 0; + } + if (pts.type == PT_ENTRY_TABLE) + return pt_descend(&pts, arg, descend_fn); + /* Something else, use the slow path */ + return -EAGAIN; +} +PT_MAKE_LEVELS(__map_single_page, __do_map_single_page); + +/* + * Add a table to the top, increasing the top level as much as necessary to + * encompass range. + */ +static int increase_top(struct pt_iommu *iommu_table, struct pt_range *range, + struct pt_iommu_map_args *map) +{ + struct iommu_pages_list free_list = IOMMU_PAGES_LIST_INIT(free_list); + struct pt_common *common = common_from_iommu(iommu_table); + uintptr_t top_of_table = READ_ONCE(common->top_of_table); + uintptr_t new_top_of_table = top_of_table; + struct pt_table_p *table_mem; + unsigned int new_level; + spinlock_t *domain_lock; + unsigned long flags; + int ret; + + while (true) { + struct pt_range top_range = + _pt_top_range(common, new_top_of_table); + struct pt_state pts = pt_init_top(&top_range); + + top_range.va = range->va; + top_range.last_va = range->last_va; + + if (!pt_check_range(&top_range) && + map->leaf_level <= pts.level) { + new_level = pts.level; + break; + } + + pts.level++; + if (pts.level > PT_MAX_TOP_LEVEL || + pt_table_item_lg2sz(&pts) >= common->max_vasz_lg2) { + ret = -ERANGE; + goto err_free; + } + + table_mem = + table_alloc_top(common, _pt_top_set(NULL, pts.level), + map->attrs.gfp, ALLOC_DEFER_COHERENT_FLUSH); + if (IS_ERR(table_mem)) { + ret = PTR_ERR(table_mem); + goto err_free; + } + iommu_pages_list_add(&free_list, table_mem); + + /* The new table links to the lower table always at index 0 */ + top_range.va = 0; + top_range.top_level = pts.level; + pts.table_lower = pts.table; + pts.table = table_mem; + pt_load_single_entry(&pts); + PT_WARN_ON(pts.index != 0); + pt_install_table(&pts, virt_to_phys(pts.table_lower), + &map->attrs); + new_top_of_table = _pt_top_set(pts.table, pts.level); + } + + /* + * Avoid double flushing, flush it once after all pt_install_table() + */ + if (pt_feature(common, PT_FEAT_DMA_INCOHERENT)) { + ret = iommu_pages_start_incoherent_list( + &free_list, iommu_table->iommu_device); + if (ret) + goto err_free; + } + + /* + * top_of_table is write locked by the spinlock, but readers 
can use + * READ_ONCE() to get the value. Since we encode both the level and the + * pointer in one quanta the lockless reader will always see something + * valid. The HW must be updated to the new level under the spinlock + * before top_of_table is updated so that concurrent readers don't map + * into the new level until it is fully functional. If another thread + * already updated it while we were working then throw everything away + * and try again. + */ + domain_lock = iommu_table->driver_ops->get_top_lock(iommu_table); + spin_lock_irqsave(domain_lock, flags); + if (common->top_of_table != top_of_table || + top_of_table == new_top_of_table) { + spin_unlock_irqrestore(domain_lock, flags); + ret = -EAGAIN; + goto err_free; + } + + /* + * We do not issue any flushes for change_top on the expectation that + * any walk cache will not become a problem by adding another layer to + * the tree. Misses will rewalk from the updated top pointer, hits + * continue to be correct. Negative caching is fine too since all the + * new IOVA added by the new top is non-present. + */ + iommu_table->driver_ops->change_top( + iommu_table, virt_to_phys(table_mem), new_level); + WRITE_ONCE(common->top_of_table, new_top_of_table); + spin_unlock_irqrestore(domain_lock, flags); + return 0; + +err_free: + if (pt_feature(common, PT_FEAT_DMA_INCOHERENT)) + iommu_pages_stop_incoherent_list(&free_list, + iommu_table->iommu_device); + iommu_put_pages_list(&free_list); + return ret; +} + +static int check_map_range(struct pt_iommu *iommu_table, struct pt_range *range, + struct pt_iommu_map_args *map) +{ + struct pt_common *common = common_from_iommu(iommu_table); + int ret; + + do { + ret = pt_check_range(range); + if (!pt_feature(common, PT_FEAT_DYNAMIC_TOP)) + return ret; + + if (!ret && map->leaf_level <= range->top_level) + break; + + ret = increase_top(iommu_table, range, map); + if (ret && ret != -EAGAIN) + return ret; + + /* Reload the new top */ + *range = pt_make_range(common, range->va, range->last_va); + } while (ret); + PT_WARN_ON(pt_check_range(range)); + return 0; +} + +static int do_map(struct pt_range *range, struct pt_common *common, + bool single_page, struct pt_iommu_map_args *map) +{ + /* + * The __map_single_page() fast path does not support DMA_INCOHERENT + * flushing to keep its .text small. + */ + if (single_page && !pt_feature(common, PT_FEAT_DMA_INCOHERENT)) { + int ret; + + ret = pt_walk_range(range, __map_single_page, map); + if (ret != -EAGAIN) + return ret; + /* EAGAIN falls through to the full path */ + } + + if (map->leaf_level == range->top_level) + return pt_walk_range(range, __map_range_leaf, map); + return pt_walk_range(range, __map_range, map); +} + +/** + * map_pages() - Install translation for an IOVA range + * @domain: Domain to manipulate + * @iova: IO virtual address to start + * @paddr: Physical/Output address to start + * @pgsize: Length of each page + * @pgcount: Length of the range in pgsize units starting from @iova + * @prot: A bitmap of IOMMU_READ/WRITE/CACHE/NOEXEC/MMIO + * @gfp: GFP flags for any memory allocations + * @mapped: Total bytes successfully mapped + * + * The range starting at IOVA will have paddr installed into it. The caller + * must specify a valid pgsize and pgcount to segment the range into compatible + * blocks. + * + * On error the caller will probably want to invoke unmap on the range from iova + * up to the amount indicated by @mapped to return the table back to an + * unchanged state. 
+ *
+ * Context: The caller must hold a write range lock that includes the whole
+ * range.
+ *
+ * Returns: -ERRNO on failure, 0 on success. The number of bytes of VA that
+ * were mapped is added to @mapped; @mapped is not zeroed first.
+ */
+int DOMAIN_NS(map_pages)(struct iommu_domain *domain, unsigned long iova,
+			 phys_addr_t paddr, size_t pgsize, size_t pgcount,
+			 int prot, gfp_t gfp, size_t *mapped)
+{
+	struct pt_iommu *iommu_table =
+		container_of(domain, struct pt_iommu, domain);
+	pt_vaddr_t pgsize_bitmap = iommu_table->domain.pgsize_bitmap;
+	struct pt_common *common = common_from_iommu(iommu_table);
+	struct iommu_iotlb_gather iotlb_gather;
+	pt_vaddr_t len = pgsize * pgcount;
+	struct pt_iommu_map_args map = {
+		.iotlb_gather = &iotlb_gather,
+		.oa = paddr,
+		.leaf_pgsize_lg2 = vaffs(pgsize),
+	};
+	bool single_page = false;
+	struct pt_range range;
+	int ret;
+
+	iommu_iotlb_gather_init(&iotlb_gather);
+
+	if (WARN_ON(!(prot & (IOMMU_READ | IOMMU_WRITE))))
+		return -EINVAL;
+
+	/* Check the paddr doesn't exceed what the table can store */
+	if ((sizeof(pt_oaddr_t) < sizeof(paddr) &&
+	     (pt_vaddr_t)paddr > PT_VADDR_MAX) ||
+	    (common->max_oasz_lg2 != PT_VADDR_MAX_LG2 &&
+	     oalog2_div(paddr, common->max_oasz_lg2)))
+		return -ERANGE;
+
+	ret = pt_iommu_set_prot(common, &map.attrs, prot);
+	if (ret)
+		return ret;
+	map.attrs.gfp = gfp;
+
+	ret = make_range_no_check(common, &range, iova, len);
+	if (ret)
+		return ret;
+
+	/* Calculate target page size and level for the leaves */
+	if (pt_has_system_page_size(common) && pgsize == PAGE_SIZE &&
+	    pgcount == 1) {
+		PT_WARN_ON(!(pgsize_bitmap & PAGE_SIZE));
+		if (log2_mod(iova | paddr, PAGE_SHIFT))
+			return -ENXIO;
+		map.leaf_pgsize_lg2 = PAGE_SHIFT;
+		map.leaf_level = 0;
+		single_page = true;
+	} else {
+		map.leaf_pgsize_lg2 = pt_compute_best_pgsize(
+			pgsize_bitmap, range.va, range.last_va, paddr);
+		if (!map.leaf_pgsize_lg2)
+			return -ENXIO;
+		map.leaf_level =
+			pt_pgsz_lg2_to_level(common, map.leaf_pgsize_lg2);
+	}
+
+	ret = check_map_range(iommu_table, &range, &map);
+	if (ret)
+		return ret;
+
+	PT_WARN_ON(map.leaf_level > range.top_level);
+
+	ret = do_map(&range, common, single_page, &map);
+
+	/*
+	 * Table levels were freed and replaced with large items, flush any walk
+	 * cache that may refer to the freed levels.
+	 */
+	if (!iommu_pages_list_empty(&iotlb_gather.freelist))
+		iommu_iotlb_sync(&iommu_table->domain, &iotlb_gather);
+
+	/* Bytes successfully mapped */
+	PT_WARN_ON(!ret && map.oa - paddr != len);
+	*mapped += map.oa - paddr;
+	return ret;
+}
+EXPORT_SYMBOL_NS_GPL(DOMAIN_NS(map_pages), "GENERIC_PT_IOMMU");
+
+struct pt_unmap_args {
+	struct iommu_pages_list free_list;
+	pt_vaddr_t unmapped;
+};
+
+static __maybe_unused int __unmap_range(struct pt_range *range, void *arg,
+					unsigned int level,
+					struct pt_table_p *table)
+{
+	struct pt_state pts = pt_init(range, level, table);
+	struct pt_unmap_args *unmap = arg;
+	unsigned int num_oas = 0;
+	unsigned int start_index;
+	int ret = 0;
+
+	_pt_iter_first(&pts);
+	start_index = pts.index;
+	pts.type = pt_load_entry_raw(&pts);
+	/*
+	 * A starting index is in the middle of a contiguous entry
+	 *
+	 * The IOMMU API does not require drivers to support unmapping parts of
+	 * large pages. Long ago VFIO would try to split maps but the current
+	 * version never does.
+	 *
+	 * Instead, when unmap reaches a partial unmap of the start of a large
+	 * IOPTE it should remove the entire IOPTE and return that size to the
+	 * caller.
+	 */
+	if (pts.type == PT_ENTRY_OA) {
+		if (log2_mod(range->va, pt_entry_oa_lg2sz(&pts)))
+			return -EINVAL;
+		/* Micro optimization */
+		goto start_oa;
+	}
+
+	do {
+		if (pts.type != PT_ENTRY_OA) {
+			bool fully_covered;
+
+			if (pts.type != PT_ENTRY_TABLE) {
+				ret = -EINVAL;
+				break;
+			}
+
+			if (pts.index != start_index)
+				pt_index_to_va(&pts);
+			pts.table_lower = pt_table_ptr(&pts);
+
+			fully_covered = pt_entry_fully_covered(
+				&pts, pt_table_item_lg2sz(&pts));
+
+			ret = pt_descend(&pts, arg, __unmap_range);
+			if (ret)
+				break;
+
+			/*
+			 * If the unmapping range fully covers the table then we
+			 * can free it as well. The clear is delayed until we
+			 * succeed in clearing the lower table levels.
+			 */
+			if (fully_covered) {
+				iommu_pages_list_add(&unmap->free_list,
+						     pts.table_lower);
+				pt_clear_entries(&pts, ilog2(1));
+			}
+			pts.index++;
+		} else {
+			unsigned int num_contig_lg2;
+start_oa:
+			/*
+			 * If the caller requested a last address that falls
+			 * within a single entry then the entire entry is
+			 * unmapped and the length returned will be larger than
+			 * requested.
+			 */
+			num_contig_lg2 = pt_entry_num_contig_lg2(&pts);
+			pt_clear_entries(&pts, num_contig_lg2);
+			num_oas += log2_to_int(num_contig_lg2);
+			pts.index += log2_to_int(num_contig_lg2);
+		}
+		if (pts.index >= pts.end_index)
+			break;
+		pts.type = pt_load_entry_raw(&pts);
+	} while (true);
+
+	unmap->unmapped += log2_mul(num_oas, pt_table_item_lg2sz(&pts));
+	flush_writes_range(&pts, start_index, pts.index);
+
+	return ret;
+}
+
+/**
+ * unmap_pages() - Make a range of IOVA empty/not present
+ * @domain: Domain to manipulate
+ * @iova: IO virtual address to start
+ * @pgsize: Length of each page
+ * @pgcount: Length of the range in pgsize units starting from @iova
+ * @iotlb_gather: Gather struct that must be flushed on return
+ *
+ * unmap_pages() will remove a translation created by map_pages(). It cannot
+ * subdivide a mapping created by map_pages(), so it should be called with IOVA
+ * ranges that match those passed to map_pages(). The IOVA range can aggregate
+ * contiguous map_pages() calls so long as no individual range is split.
+ *
+ * Context: The caller must hold a write range lock that includes
+ * the whole range.
+ *
+ * Returns: Number of bytes of VA unmapped. iova + res will be the point
+ * unmapping stopped.
+ */ +size_t DOMAIN_NS(unmap_pages)(struct iommu_domain *domain, unsigned long iova, + size_t pgsize, size_t pgcount, + struct iommu_iotlb_gather *iotlb_gather) +{ + struct pt_iommu *iommu_table = + container_of(domain, struct pt_iommu, domain); + struct pt_unmap_args unmap = { .free_list = IOMMU_PAGES_LIST_INIT( + unmap.free_list) }; + pt_vaddr_t len = pgsize * pgcount; + struct pt_range range; + int ret; + + ret = make_range(common_from_iommu(iommu_table), &range, iova, len); + if (ret) + return 0; + + pt_walk_range(&range, __unmap_range, &unmap); + + gather_range_pages(iotlb_gather, iommu_table, iova, len, + &unmap.free_list); + + return unmap.unmapped; +} +EXPORT_SYMBOL_NS_GPL(DOMAIN_NS(unmap_pages), "GENERIC_PT_IOMMU"); + +static void NS(get_info)(struct pt_iommu *iommu_table, + struct pt_iommu_info *info) +{ + struct pt_common *common = common_from_iommu(iommu_table); + struct pt_range range = pt_top_range(common); + struct pt_state pts = pt_init_top(&range); + pt_vaddr_t pgsize_bitmap = 0; + + if (pt_feature(common, PT_FEAT_DYNAMIC_TOP)) { + for (pts.level = 0; pts.level <= PT_MAX_TOP_LEVEL; + pts.level++) { + if (pt_table_item_lg2sz(&pts) >= common->max_vasz_lg2) + break; + pgsize_bitmap |= pt_possible_sizes(&pts); + } + } else { + for (pts.level = 0; pts.level <= range.top_level; pts.level++) + pgsize_bitmap |= pt_possible_sizes(&pts); + } + + /* Hide page sizes larger than the maximum OA */ + info->pgsize_bitmap = oalog2_mod(pgsize_bitmap, common->max_oasz_lg2); +} + +static void NS(deinit)(struct pt_iommu *iommu_table) +{ + struct pt_common *common = common_from_iommu(iommu_table); + struct pt_range range = pt_all_range(common); + struct pt_iommu_collect_args collect = { + .free_list = IOMMU_PAGES_LIST_INIT(collect.free_list), + }; + + iommu_pages_list_add(&collect.free_list, range.top_table); + pt_walk_range(&range, __collect_tables, &collect); + + /* + * The driver has to already have fenced the HW access to the page table + * and invalidated any caching referring to this memory. + */ + if (pt_feature(common, PT_FEAT_DMA_INCOHERENT)) + iommu_pages_stop_incoherent_list(&collect.free_list, + iommu_table->iommu_device); + iommu_put_pages_list(&collect.free_list); +} + +static const struct pt_iommu_ops NS(ops) = { +#if IS_ENABLED(CONFIG_IOMMUFD_DRIVER) && defined(pt_entry_is_write_dirty) && \ + IS_ENABLED(CONFIG_IOMMUFD_TEST) && defined(pt_entry_make_write_dirty) + .set_dirty = NS(set_dirty), +#endif + .get_info = NS(get_info), + .deinit = NS(deinit), +}; + +static int pt_init_common(struct pt_common *common) +{ + struct pt_range top_range = pt_top_range(common); + + if (PT_WARN_ON(top_range.top_level > PT_MAX_TOP_LEVEL)) + return -EINVAL; + + if (top_range.top_level == PT_MAX_TOP_LEVEL || + common->max_vasz_lg2 == top_range.max_vasz_lg2) + common->features &= ~BIT(PT_FEAT_DYNAMIC_TOP); + + if (top_range.max_vasz_lg2 == PT_VADDR_MAX_LG2) + common->features |= BIT(PT_FEAT_FULL_VA); + + /* Requested features must match features compiled into this format */ + if ((common->features & ~(unsigned int)PT_SUPPORTED_FEATURES) || + (!IS_ENABLED(CONFIG_DEBUG_GENERIC_PT) && + (common->features & PT_FORCE_ENABLED_FEATURES) != + PT_FORCE_ENABLED_FEATURES)) + return -EOPNOTSUPP; + + /* + * Check if the top level of the page table is too small to hold the + * specified maxvasz. 
+ */ + if (!pt_feature(common, PT_FEAT_DYNAMIC_TOP) && + top_range.top_level != PT_MAX_TOP_LEVEL) { + struct pt_state pts = { .range = &top_range, + .level = top_range.top_level }; + + if (common->max_vasz_lg2 > + pt_num_items_lg2(&pts) + pt_table_item_lg2sz(&pts)) + return -EOPNOTSUPP; + } + + if (common->max_oasz_lg2 == 0) + common->max_oasz_lg2 = pt_max_oa_lg2(common); + else + common->max_oasz_lg2 = min(common->max_oasz_lg2, + pt_max_oa_lg2(common)); + return 0; +} + +static int pt_iommu_init_domain(struct pt_iommu *iommu_table, + struct iommu_domain *domain) +{ + struct pt_common *common = common_from_iommu(iommu_table); + struct pt_iommu_info info; + struct pt_range range; + + NS(get_info)(iommu_table, &info); + + domain->type = __IOMMU_DOMAIN_PAGING; + domain->pgsize_bitmap = info.pgsize_bitmap; + + if (pt_feature(common, PT_FEAT_DYNAMIC_TOP)) + range = _pt_top_range(common, + _pt_top_set(NULL, PT_MAX_TOP_LEVEL)); + else + range = pt_top_range(common); + + /* A 64-bit high address space table on a 32-bit system cannot work. */ + domain->geometry.aperture_start = (unsigned long)range.va; + if ((pt_vaddr_t)domain->geometry.aperture_start != range.va) + return -EOVERFLOW; + + /* + * The aperture is limited to what the API can do after considering all + * the different types dma_addr_t/unsigned long/pt_vaddr_t that are used + * to store a VA. Set the aperture to something that is valid for all + * cases. Saturate instead of truncate the end if the types are smaller + * than the top range. aperture_end should be called aperture_last. + */ + domain->geometry.aperture_end = (unsigned long)range.last_va; + if ((pt_vaddr_t)domain->geometry.aperture_end != range.last_va) { + domain->geometry.aperture_end = ULONG_MAX; + domain->pgsize_bitmap &= ULONG_MAX; + } + domain->geometry.force_aperture = true; + + return 0; +} + +static void pt_iommu_zero(struct pt_iommu_table *fmt_table) +{ + struct pt_iommu *iommu_table = &fmt_table->iommu; + struct pt_iommu cfg = *iommu_table; + + static_assert(offsetof(struct pt_iommu_table, iommu.domain) == 0); + memset_after(fmt_table, 0, iommu.domain); + + /* The caller can initialize some of these values */ + iommu_table->iommu_device = cfg.iommu_device; + iommu_table->driver_ops = cfg.driver_ops; + iommu_table->nid = cfg.nid; +} + +#define pt_iommu_table_cfg CONCATENATE(pt_iommu_table, _cfg) +#define pt_iommu_init CONCATENATE(CONCATENATE(pt_iommu_, PTPFX), init) + +int pt_iommu_init(struct pt_iommu_table *fmt_table, + const struct pt_iommu_table_cfg *cfg, gfp_t gfp) +{ + struct pt_iommu *iommu_table = &fmt_table->iommu; + struct pt_common *common = common_from_iommu(iommu_table); + struct pt_table_p *table_mem; + int ret; + + if (cfg->common.hw_max_vasz_lg2 > PT_MAX_VA_ADDRESS_LG2 || + !cfg->common.hw_max_vasz_lg2 || !cfg->common.hw_max_oasz_lg2) + return -EINVAL; + + pt_iommu_zero(fmt_table); + common->features = cfg->common.features; + common->max_vasz_lg2 = cfg->common.hw_max_vasz_lg2; + common->max_oasz_lg2 = cfg->common.hw_max_oasz_lg2; + ret = pt_iommu_fmt_init(fmt_table, cfg); + if (ret) + return ret; + + if (cfg->common.hw_max_oasz_lg2 > pt_max_oa_lg2(common)) + return -EINVAL; + + ret = pt_init_common(common); + if (ret) + return ret; + + if (pt_feature(common, PT_FEAT_DYNAMIC_TOP) && + WARN_ON(!iommu_table->driver_ops || + !iommu_table->driver_ops->change_top || + !iommu_table->driver_ops->get_top_lock)) + return -EINVAL; + + if (pt_feature(common, PT_FEAT_SIGN_EXTEND) && + (pt_feature(common, PT_FEAT_FULL_VA) || + pt_feature(common, 
PT_FEAT_DYNAMIC_TOP))) + return -EINVAL; + + if (pt_feature(common, PT_FEAT_DMA_INCOHERENT) && + WARN_ON(!iommu_table->iommu_device)) + return -EINVAL; + + ret = pt_iommu_init_domain(iommu_table, &iommu_table->domain); + if (ret) + return ret; + + table_mem = table_alloc_top(common, common->top_of_table, gfp, + ALLOC_NORMAL); + if (IS_ERR(table_mem)) + return PTR_ERR(table_mem); + pt_top_set(common, table_mem, pt_top_get_level(common)); + + /* Must be last, see pt_iommu_deinit() */ + iommu_table->ops = &NS(ops); + return 0; +} +EXPORT_SYMBOL_NS_GPL(pt_iommu_init, "GENERIC_PT_IOMMU"); + +#ifdef pt_iommu_fmt_hw_info +#define pt_iommu_table_hw_info CONCATENATE(pt_iommu_table, _hw_info) +#define pt_iommu_hw_info CONCATENATE(CONCATENATE(pt_iommu_, PTPFX), hw_info) +void pt_iommu_hw_info(struct pt_iommu_table *fmt_table, + struct pt_iommu_table_hw_info *info) +{ + struct pt_iommu *iommu_table = &fmt_table->iommu; + struct pt_common *common = common_from_iommu(iommu_table); + struct pt_range top_range = pt_top_range(common); + + pt_iommu_fmt_hw_info(fmt_table, &top_range, info); +} +EXPORT_SYMBOL_NS_GPL(pt_iommu_hw_info, "GENERIC_PT_IOMMU"); +#endif + +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("IOMMU Page table implementation for " __stringify(PTPFX_RAW)); +MODULE_IMPORT_NS("GENERIC_PT"); +/* For iommu_dirty_bitmap_record() */ +MODULE_IMPORT_NS("IOMMUFD"); + +#endif /* __GENERIC_PT_IOMMU_PT_H */ diff --git a/drivers/iommu/generic_pt/kunit_generic_pt.h b/drivers/iommu/generic_pt/kunit_generic_pt.h new file mode 100644 index 000000000000..68278bf15cfe --- /dev/null +++ b/drivers/iommu/generic_pt/kunit_generic_pt.h @@ -0,0 +1,823 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES + * + * Test the format API directly. + * + */ +#include "kunit_iommu.h" +#include "pt_iter.h" + +static void do_map(struct kunit *test, pt_vaddr_t va, pt_oaddr_t pa, + pt_vaddr_t len) +{ + struct kunit_iommu_priv *priv = test->priv; + int ret; + + KUNIT_ASSERT_EQ(test, len, (size_t)len); + + ret = iommu_map(&priv->domain, va, pa, len, IOMMU_READ | IOMMU_WRITE, + GFP_KERNEL); + KUNIT_ASSERT_NO_ERRNO_FN(test, "map_pages", ret); +} + +#define KUNIT_ASSERT_PT_LOAD(test, pts, entry) \ + ({ \ + pt_load_entry(pts); \ + KUNIT_ASSERT_EQ(test, (pts)->type, entry); \ + }) + +struct check_levels_arg { + struct kunit *test; + void *fn_arg; + void (*fn)(struct kunit *test, struct pt_state *pts, void *arg); +}; + +static int __check_all_levels(struct pt_range *range, void *arg, + unsigned int level, struct pt_table_p *table) +{ + struct pt_state pts = pt_init(range, level, table); + struct check_levels_arg *chk = arg; + struct kunit *test = chk->test; + int ret; + + _pt_iter_first(&pts); + + + /* + * If we were able to use the full VA space this should always be the + * last index in each table. 
+ */ + if (!(IS_32BIT && range->max_vasz_lg2 > 32)) { + if (pt_feature(range->common, PT_FEAT_SIGN_EXTEND) && + pts.level == pts.range->top_level) + KUNIT_ASSERT_EQ(test, pts.index, + log2_to_int(range->max_vasz_lg2 - 1 - + pt_table_item_lg2sz(&pts)) - + 1); + else + KUNIT_ASSERT_EQ(test, pts.index, + log2_to_int(pt_table_oa_lg2sz(&pts) - + pt_table_item_lg2sz(&pts)) - + 1); + } + + if (pt_can_have_table(&pts)) { + pt_load_single_entry(&pts); + KUNIT_ASSERT_EQ(test, pts.type, PT_ENTRY_TABLE); + ret = pt_descend(&pts, arg, __check_all_levels); + KUNIT_ASSERT_EQ(test, ret, 0); + + /* Index 0 is used by the test */ + if (IS_32BIT && !pts.index) + return 0; + KUNIT_ASSERT_NE(chk->test, pts.index, 0); + } + + /* + * A format should not create a table with only one entry, at least this + * test approach won't work. + */ + KUNIT_ASSERT_GT(chk->test, pts.end_index, 1); + + /* + * For increase top we end up using index 0 for the original top's tree, + * so use index 1 for testing instead. + */ + pts.index = 0; + pt_index_to_va(&pts); + pt_load_single_entry(&pts); + if (pts.type == PT_ENTRY_TABLE && pts.end_index > 2) { + pts.index = 1; + pt_index_to_va(&pts); + } + (*chk->fn)(chk->test, &pts, chk->fn_arg); + return 0; +} + +/* + * Call fn for each level in the table with a pts setup to index 0 in a table + * for that level. This allows writing tests that run on every level. + * The test can use every index in the table except the last one. + */ +static void check_all_levels(struct kunit *test, + void (*fn)(struct kunit *test, + struct pt_state *pts, void *arg), + void *fn_arg) +{ + struct kunit_iommu_priv *priv = test->priv; + struct pt_range range = pt_top_range(priv->common); + struct check_levels_arg chk = { + .test = test, + .fn = fn, + .fn_arg = fn_arg, + }; + int ret; + + if (pt_feature(priv->common, PT_FEAT_DYNAMIC_TOP) && + priv->common->max_vasz_lg2 > range.max_vasz_lg2) + range.last_va = fvalog2_set_mod_max(range.va, + priv->common->max_vasz_lg2); + + /* + * Map a page at the highest VA, this will populate all the levels so we + * can then iterate over them. Index 0 will be used for testing. + */ + if (IS_32BIT && range.max_vasz_lg2 > 32) + range.last_va = (u32)range.last_va; + range.va = range.last_va - (priv->smallest_pgsz - 1); + do_map(test, range.va, 0, priv->smallest_pgsz); + + range = pt_make_range(priv->common, range.va, range.last_va); + ret = pt_walk_range(&range, __check_all_levels, &chk); + KUNIT_ASSERT_EQ(test, ret, 0); +} + +static void test_init(struct kunit *test) +{ + struct kunit_iommu_priv *priv = test->priv; + + /* Fixture does the setup */ + KUNIT_ASSERT_NE(test, priv->info.pgsize_bitmap, 0); +} + +/* + * Basic check that the log2_* functions are working, especially at the integer + * limits. 
+ */ +static void test_bitops(struct kunit *test) +{ + int i; + + KUNIT_ASSERT_EQ(test, fls_t(u32, 0), 0); + KUNIT_ASSERT_EQ(test, fls_t(u32, 1), 1); + KUNIT_ASSERT_EQ(test, fls_t(u32, BIT(2)), 3); + KUNIT_ASSERT_EQ(test, fls_t(u32, U32_MAX), 32); + + KUNIT_ASSERT_EQ(test, fls_t(u64, 0), 0); + KUNIT_ASSERT_EQ(test, fls_t(u64, 1), 1); + KUNIT_ASSERT_EQ(test, fls_t(u64, BIT(2)), 3); + KUNIT_ASSERT_EQ(test, fls_t(u64, U64_MAX), 64); + + KUNIT_ASSERT_EQ(test, ffs_t(u32, 1), 0); + KUNIT_ASSERT_EQ(test, ffs_t(u32, BIT(2)), 2); + KUNIT_ASSERT_EQ(test, ffs_t(u32, BIT(31)), 31); + + KUNIT_ASSERT_EQ(test, ffs_t(u64, 1), 0); + KUNIT_ASSERT_EQ(test, ffs_t(u64, BIT(2)), 2); + KUNIT_ASSERT_EQ(test, ffs_t(u64, BIT_ULL(63)), 63); + + for (i = 0; i != 31; i++) + KUNIT_ASSERT_EQ(test, ffz_t(u64, BIT_ULL(i) - 1), i); + + for (i = 0; i != 63; i++) + KUNIT_ASSERT_EQ(test, ffz_t(u64, BIT_ULL(i) - 1), i); + + for (i = 0; i != 32; i++) { + u64 val = get_random_u64(); + + KUNIT_ASSERT_EQ(test, log2_mod_t(u32, val, ffs_t(u32, val)), 0); + KUNIT_ASSERT_EQ(test, log2_mod_t(u64, val, ffs_t(u64, val)), 0); + + KUNIT_ASSERT_EQ(test, log2_mod_t(u32, val, ffz_t(u32, val)), + log2_to_max_int_t(u32, ffz_t(u32, val))); + KUNIT_ASSERT_EQ(test, log2_mod_t(u64, val, ffz_t(u64, val)), + log2_to_max_int_t(u64, ffz_t(u64, val))); + } +} + +static unsigned int ref_best_pgsize(pt_vaddr_t pgsz_bitmap, pt_vaddr_t va, + pt_vaddr_t last_va, pt_oaddr_t oa) +{ + pt_vaddr_t pgsz_lg2; + + /* Brute force the constraints described in pt_compute_best_pgsize() */ + for (pgsz_lg2 = PT_VADDR_MAX_LG2 - 1; pgsz_lg2 != 0; pgsz_lg2--) { + if ((pgsz_bitmap & log2_to_int(pgsz_lg2)) && + log2_mod(va, pgsz_lg2) == 0 && + oalog2_mod(oa, pgsz_lg2) == 0 && + va + log2_to_int(pgsz_lg2) - 1 <= last_va && + log2_div_eq(va, va + log2_to_int(pgsz_lg2) - 1, pgsz_lg2) && + oalog2_div_eq(oa, oa + log2_to_int(pgsz_lg2) - 1, pgsz_lg2)) + return pgsz_lg2; + } + return 0; +} + +/* Check that the bit logic in pt_compute_best_pgsize() works. 
*/ +static void test_best_pgsize(struct kunit *test) +{ + unsigned int a_lg2; + unsigned int b_lg2; + unsigned int c_lg2; + + /* Try random prefixes with every suffix combination */ + for (a_lg2 = 1; a_lg2 != 10; a_lg2++) { + for (b_lg2 = 1; b_lg2 != 10; b_lg2++) { + for (c_lg2 = 1; c_lg2 != 10; c_lg2++) { + pt_vaddr_t pgsz_bitmap = get_random_u64(); + pt_vaddr_t va = get_random_u64() << a_lg2; + pt_oaddr_t oa = get_random_u64() << b_lg2; + pt_vaddr_t last_va = log2_set_mod_max( + get_random_u64(), c_lg2); + + if (va > last_va) + swap(va, last_va); + KUNIT_ASSERT_EQ( + test, + pt_compute_best_pgsize(pgsz_bitmap, va, + last_va, oa), + ref_best_pgsize(pgsz_bitmap, va, + last_va, oa)); + } + } + } + + /* 0 prefix, every suffix */ + for (c_lg2 = 1; c_lg2 != PT_VADDR_MAX_LG2 - 1; c_lg2++) { + pt_vaddr_t pgsz_bitmap = get_random_u64(); + pt_vaddr_t va = 0; + pt_oaddr_t oa = 0; + pt_vaddr_t last_va = log2_set_mod_max(0, c_lg2); + + KUNIT_ASSERT_EQ(test, + pt_compute_best_pgsize(pgsz_bitmap, va, last_va, + oa), + ref_best_pgsize(pgsz_bitmap, va, last_va, oa)); + } + + /* 1's prefix, every suffix */ + for (a_lg2 = 1; a_lg2 != 10; a_lg2++) { + for (b_lg2 = 1; b_lg2 != 10; b_lg2++) { + for (c_lg2 = 1; c_lg2 != 10; c_lg2++) { + pt_vaddr_t pgsz_bitmap = get_random_u64(); + pt_vaddr_t va = PT_VADDR_MAX << a_lg2; + pt_oaddr_t oa = PT_VADDR_MAX << b_lg2; + pt_vaddr_t last_va = PT_VADDR_MAX; + + KUNIT_ASSERT_EQ( + test, + pt_compute_best_pgsize(pgsz_bitmap, va, + last_va, oa), + ref_best_pgsize(pgsz_bitmap, va, + last_va, oa)); + } + } + } + + /* pgsize_bitmap is always 0 */ + for (a_lg2 = 1; a_lg2 != 10; a_lg2++) { + for (b_lg2 = 1; b_lg2 != 10; b_lg2++) { + for (c_lg2 = 1; c_lg2 != 10; c_lg2++) { + pt_vaddr_t pgsz_bitmap = 0; + pt_vaddr_t va = get_random_u64() << a_lg2; + pt_oaddr_t oa = get_random_u64() << b_lg2; + pt_vaddr_t last_va = log2_set_mod_max( + get_random_u64(), c_lg2); + + if (va > last_va) + swap(va, last_va); + KUNIT_ASSERT_EQ( + test, + pt_compute_best_pgsize(pgsz_bitmap, va, + last_va, oa), + 0); + } + } + } + + if (sizeof(pt_vaddr_t) <= 4) + return; + + /* over 32 bit page sizes */ + for (a_lg2 = 32; a_lg2 != 42; a_lg2++) { + for (b_lg2 = 32; b_lg2 != 42; b_lg2++) { + for (c_lg2 = 32; c_lg2 != 42; c_lg2++) { + pt_vaddr_t pgsz_bitmap = get_random_u64(); + pt_vaddr_t va = get_random_u64() << a_lg2; + pt_oaddr_t oa = get_random_u64() << b_lg2; + pt_vaddr_t last_va = log2_set_mod_max( + get_random_u64(), c_lg2); + + if (va > last_va) + swap(va, last_va); + KUNIT_ASSERT_EQ( + test, + pt_compute_best_pgsize(pgsz_bitmap, va, + last_va, oa), + ref_best_pgsize(pgsz_bitmap, va, + last_va, oa)); + } + } + } +} + +/* + * Check that pt_install_table() and pt_table_pa() match + */ +static void test_lvl_table_ptr(struct kunit *test, struct pt_state *pts, + void *arg) +{ + struct kunit_iommu_priv *priv = test->priv; + pt_oaddr_t paddr = + log2_set_mod(priv->test_oa, 0, priv->smallest_pgsz_lg2); + struct pt_write_attrs attrs = {}; + + if (!pt_can_have_table(pts)) + return; + + KUNIT_ASSERT_NO_ERRNO_FN(test, "pt_iommu_set_prot", + pt_iommu_set_prot(pts->range->common, &attrs, + IOMMU_READ)); + + pt_load_single_entry(pts); + KUNIT_ASSERT_PT_LOAD(test, pts, PT_ENTRY_EMPTY); + + KUNIT_ASSERT_TRUE(test, pt_install_table(pts, paddr, &attrs)); + + /* A second install should pass because install updates pts->entry. 
*/ + KUNIT_ASSERT_EQ(test, pt_install_table(pts, paddr, &attrs), true); + + KUNIT_ASSERT_PT_LOAD(test, pts, PT_ENTRY_TABLE); + KUNIT_ASSERT_EQ(test, pt_table_pa(pts), paddr); + + pt_clear_entries(pts, ilog2(1)); + KUNIT_ASSERT_PT_LOAD(test, pts, PT_ENTRY_EMPTY); +} + +static void test_table_ptr(struct kunit *test) +{ + check_all_levels(test, test_lvl_table_ptr, NULL); +} + +struct lvl_radix_arg { + pt_vaddr_t vbits; +}; + +/* + * Check pt_table_oa_lg2sz() and pt_table_item_lg2sz() they need to decode a + * continuous list of VA across all the levels that covers the entire advertised + * VA space. + */ +static void test_lvl_radix(struct kunit *test, struct pt_state *pts, void *arg) +{ + unsigned int table_lg2sz = pt_table_oa_lg2sz(pts); + unsigned int isz_lg2 = pt_table_item_lg2sz(pts); + struct lvl_radix_arg *radix = arg; + + /* Every bit below us is decoded */ + KUNIT_ASSERT_EQ(test, log2_set_mod_max(0, isz_lg2), radix->vbits); + + /* We are not decoding bits someone else is */ + KUNIT_ASSERT_EQ(test, log2_div(radix->vbits, isz_lg2), 0); + + /* Can't decode past the pt_vaddr_t size */ + KUNIT_ASSERT_LE(test, table_lg2sz, PT_VADDR_MAX_LG2); + KUNIT_ASSERT_EQ(test, fvalog2_div(table_lg2sz, PT_MAX_VA_ADDRESS_LG2), + 0); + + radix->vbits = fvalog2_set_mod_max(0, table_lg2sz); +} + +static void test_max_va(struct kunit *test) +{ + struct kunit_iommu_priv *priv = test->priv; + struct pt_range range = pt_top_range(priv->common); + + KUNIT_ASSERT_GE(test, priv->common->max_vasz_lg2, range.max_vasz_lg2); +} + +static void test_table_radix(struct kunit *test) +{ + struct kunit_iommu_priv *priv = test->priv; + struct lvl_radix_arg radix = { .vbits = priv->smallest_pgsz - 1 }; + struct pt_range range; + + check_all_levels(test, test_lvl_radix, &radix); + + range = pt_top_range(priv->common); + if (range.max_vasz_lg2 == PT_VADDR_MAX_LG2) { + KUNIT_ASSERT_EQ(test, radix.vbits, PT_VADDR_MAX); + } else { + if (!IS_32BIT) + KUNIT_ASSERT_EQ(test, + log2_set_mod_max(0, range.max_vasz_lg2), + radix.vbits); + KUNIT_ASSERT_EQ(test, log2_div(radix.vbits, range.max_vasz_lg2), + 0); + } +} + +static unsigned int safe_pt_num_items_lg2(const struct pt_state *pts) +{ + struct pt_range top_range = pt_top_range(pts->range->common); + struct pt_state top_pts = pt_init_top(&top_range); + + /* + * Avoid calling pt_num_items_lg2() on the top, instead we can derive + * the size of the top table from the top range. + */ + if (pts->level == top_range.top_level) + return ilog2(pt_range_to_end_index(&top_pts)); + return pt_num_items_lg2(pts); +} + +static void test_lvl_possible_sizes(struct kunit *test, struct pt_state *pts, + void *arg) +{ + unsigned int num_items_lg2 = safe_pt_num_items_lg2(pts); + pt_vaddr_t pgsize_bitmap = pt_possible_sizes(pts); + unsigned int isz_lg2 = pt_table_item_lg2sz(pts); + + if (!pt_can_have_leaf(pts)) { + KUNIT_ASSERT_EQ(test, pgsize_bitmap, 0); + return; + } + + /* No bits for sizes that would be outside this table */ + KUNIT_ASSERT_EQ(test, log2_mod(pgsize_bitmap, isz_lg2), 0); + KUNIT_ASSERT_EQ( + test, fvalog2_div(pgsize_bitmap, num_items_lg2 + isz_lg2), 0); + + /* + * Non contiguous must be supported. AMDv1 has a HW bug where it does + * not support it on one of the levels. 
+ */ + if ((u64)pgsize_bitmap != 0xff0000000000ULL || + strcmp(__stringify(PTPFX_RAW), "amdv1") != 0) + KUNIT_ASSERT_TRUE(test, pgsize_bitmap & log2_to_int(isz_lg2)); + else + KUNIT_ASSERT_NE(test, pgsize_bitmap, 0); + + /* A contiguous entry should not span the whole table */ + if (num_items_lg2 + isz_lg2 != PT_VADDR_MAX_LG2) + KUNIT_ASSERT_FALSE( + test, + pgsize_bitmap & log2_to_int(num_items_lg2 + isz_lg2)); +} + +static void test_entry_possible_sizes(struct kunit *test) +{ + check_all_levels(test, test_lvl_possible_sizes, NULL); +} + +static void sweep_all_pgsizes(struct kunit *test, struct pt_state *pts, + struct pt_write_attrs *attrs, + pt_oaddr_t test_oaddr) +{ + pt_vaddr_t pgsize_bitmap = pt_possible_sizes(pts); + unsigned int isz_lg2 = pt_table_item_lg2sz(pts); + unsigned int len_lg2; + + if (pts->index != 0) + return; + + for (len_lg2 = 0; len_lg2 < PT_VADDR_MAX_LG2 - 1; len_lg2++) { + struct pt_state sub_pts = *pts; + pt_oaddr_t oaddr; + + if (!(pgsize_bitmap & log2_to_int(len_lg2))) + continue; + + oaddr = log2_set_mod(test_oaddr, 0, len_lg2); + pt_install_leaf_entry(pts, oaddr, len_lg2, attrs); + /* Verify that every contiguous item translates correctly */ + for (sub_pts.index = 0; + sub_pts.index != log2_to_int(len_lg2 - isz_lg2); + sub_pts.index++) { + KUNIT_ASSERT_PT_LOAD(test, &sub_pts, PT_ENTRY_OA); + KUNIT_ASSERT_EQ(test, pt_item_oa(&sub_pts), + oaddr + sub_pts.index * + oalog2_mul(1, isz_lg2)); + KUNIT_ASSERT_EQ(test, pt_entry_oa(&sub_pts), oaddr); + KUNIT_ASSERT_EQ(test, pt_entry_num_contig_lg2(&sub_pts), + len_lg2 - isz_lg2); + } + + pt_clear_entries(pts, len_lg2 - isz_lg2); + KUNIT_ASSERT_PT_LOAD(test, pts, PT_ENTRY_EMPTY); + } +} + +/* + * Check that pt_install_leaf_entry() and pt_entry_oa() match. + * Check that pt_clear_entries() works. 
+ */ +static void test_lvl_entry_oa(struct kunit *test, struct pt_state *pts, + void *arg) +{ + unsigned int max_oa_lg2 = pts->range->common->max_oasz_lg2; + struct kunit_iommu_priv *priv = test->priv; + struct pt_write_attrs attrs = {}; + + if (!pt_can_have_leaf(pts)) + return; + + KUNIT_ASSERT_NO_ERRNO_FN(test, "pt_iommu_set_prot", + pt_iommu_set_prot(pts->range->common, &attrs, + IOMMU_READ)); + + sweep_all_pgsizes(test, pts, &attrs, priv->test_oa); + + /* Check that the table can store the boundary OAs */ + sweep_all_pgsizes(test, pts, &attrs, 0); + if (max_oa_lg2 == PT_OADDR_MAX_LG2) + sweep_all_pgsizes(test, pts, &attrs, PT_OADDR_MAX); + else + sweep_all_pgsizes(test, pts, &attrs, + oalog2_to_max_int(max_oa_lg2)); +} + +static void test_entry_oa(struct kunit *test) +{ + check_all_levels(test, test_lvl_entry_oa, NULL); +} + +/* Test pt_attr_from_entry() */ +static void test_lvl_attr_from_entry(struct kunit *test, struct pt_state *pts, + void *arg) +{ + pt_vaddr_t pgsize_bitmap = pt_possible_sizes(pts); + unsigned int isz_lg2 = pt_table_item_lg2sz(pts); + struct kunit_iommu_priv *priv = test->priv; + unsigned int len_lg2; + unsigned int prot; + + if (!pt_can_have_leaf(pts)) + return; + + for (len_lg2 = 0; len_lg2 < PT_VADDR_MAX_LG2; len_lg2++) { + if (!(pgsize_bitmap & log2_to_int(len_lg2))) + continue; + for (prot = 0; prot <= (IOMMU_READ | IOMMU_WRITE | IOMMU_CACHE | + IOMMU_NOEXEC | IOMMU_MMIO); + prot++) { + pt_oaddr_t oaddr; + struct pt_write_attrs attrs = {}; + u64 good_entry; + + /* + * If the format doesn't support this combination of + * prot bits skip it + */ + if (pt_iommu_set_prot(pts->range->common, &attrs, + prot)) { + /* But RW has to be supported */ + KUNIT_ASSERT_NE(test, prot, + IOMMU_READ | IOMMU_WRITE); + continue; + } + + oaddr = log2_set_mod(priv->test_oa, 0, len_lg2); + pt_install_leaf_entry(pts, oaddr, len_lg2, &attrs); + KUNIT_ASSERT_PT_LOAD(test, pts, PT_ENTRY_OA); + + good_entry = pts->entry; + + memset(&attrs, 0, sizeof(attrs)); + pt_attr_from_entry(pts, &attrs); + + pt_clear_entries(pts, len_lg2 - isz_lg2); + KUNIT_ASSERT_PT_LOAD(test, pts, PT_ENTRY_EMPTY); + + pt_install_leaf_entry(pts, oaddr, len_lg2, &attrs); + KUNIT_ASSERT_PT_LOAD(test, pts, PT_ENTRY_OA); + + /* + * The descriptor produced by pt_attr_from_entry() + * produce an identical entry value when re-written + */ + KUNIT_ASSERT_EQ(test, good_entry, pts->entry); + + pt_clear_entries(pts, len_lg2 - isz_lg2); + } + } +} + +static void test_attr_from_entry(struct kunit *test) +{ + check_all_levels(test, test_lvl_attr_from_entry, NULL); +} + +static void test_lvl_dirty(struct kunit *test, struct pt_state *pts, void *arg) +{ + pt_vaddr_t pgsize_bitmap = pt_possible_sizes(pts); + unsigned int isz_lg2 = pt_table_item_lg2sz(pts); + struct kunit_iommu_priv *priv = test->priv; + unsigned int start_idx = pts->index; + struct pt_write_attrs attrs = {}; + unsigned int len_lg2; + + if (!pt_can_have_leaf(pts)) + return; + + KUNIT_ASSERT_NO_ERRNO_FN(test, "pt_iommu_set_prot", + pt_iommu_set_prot(pts->range->common, &attrs, + IOMMU_READ | IOMMU_WRITE)); + + for (len_lg2 = 0; len_lg2 < PT_VADDR_MAX_LG2; len_lg2++) { + pt_oaddr_t oaddr; + unsigned int i; + + if (!(pgsize_bitmap & log2_to_int(len_lg2))) + continue; + + oaddr = log2_set_mod(priv->test_oa, 0, len_lg2); + pt_install_leaf_entry(pts, oaddr, len_lg2, &attrs); + KUNIT_ASSERT_PT_LOAD(test, pts, PT_ENTRY_OA); + + pt_load_entry(pts); + pt_entry_make_write_clean(pts); + pt_load_entry(pts); + KUNIT_ASSERT_FALSE(test, pt_entry_is_write_dirty(pts)); + + for (i = 
0; i != log2_to_int(len_lg2 - isz_lg2); i++) { + /* dirty every contiguous entry */ + pts->index = start_idx + i; + pt_load_entry(pts); + KUNIT_ASSERT_TRUE(test, pt_entry_make_write_dirty(pts)); + pts->index = start_idx; + pt_load_entry(pts); + KUNIT_ASSERT_TRUE(test, pt_entry_is_write_dirty(pts)); + + pt_entry_make_write_clean(pts); + pt_load_entry(pts); + KUNIT_ASSERT_FALSE(test, pt_entry_is_write_dirty(pts)); + } + + pt_clear_entries(pts, len_lg2 - isz_lg2); + } +} + +static __maybe_unused void test_dirty(struct kunit *test) +{ + struct kunit_iommu_priv *priv = test->priv; + + if (!pt_dirty_supported(priv->common)) + kunit_skip(test, + "Page table features do not support dirty tracking"); + + check_all_levels(test, test_lvl_dirty, NULL); +} + +static void test_lvl_sw_bit_leaf(struct kunit *test, struct pt_state *pts, + void *arg) +{ + struct kunit_iommu_priv *priv = test->priv; + pt_vaddr_t pgsize_bitmap = pt_possible_sizes(pts); + unsigned int isz_lg2 = pt_table_item_lg2sz(pts); + struct pt_write_attrs attrs = {}; + unsigned int len_lg2; + + if (!pt_can_have_leaf(pts)) + return; + if (pts->index != 0) + return; + + KUNIT_ASSERT_NO_ERRNO_FN(test, "pt_iommu_set_prot", + pt_iommu_set_prot(pts->range->common, &attrs, + IOMMU_READ)); + + for (len_lg2 = 0; len_lg2 < PT_VADDR_MAX_LG2 - 1; len_lg2++) { + pt_oaddr_t paddr = log2_set_mod(priv->test_oa, 0, len_lg2); + struct pt_write_attrs new_attrs = {}; + unsigned int bitnr; + + if (!(pgsize_bitmap & log2_to_int(len_lg2))) + continue; + + pt_install_leaf_entry(pts, paddr, len_lg2, &attrs); + + for (bitnr = 0; bitnr <= pt_max_sw_bit(pts->range->common); + bitnr++) + KUNIT_ASSERT_FALSE(test, + pt_test_sw_bit_acquire(pts, bitnr)); + + for (bitnr = 0; bitnr <= pt_max_sw_bit(pts->range->common); + bitnr++) { + KUNIT_ASSERT_FALSE(test, + pt_test_sw_bit_acquire(pts, bitnr)); + pt_set_sw_bit_release(pts, bitnr); + KUNIT_ASSERT_TRUE(test, + pt_test_sw_bit_acquire(pts, bitnr)); + } + + for (bitnr = 0; bitnr <= pt_max_sw_bit(pts->range->common); + bitnr++) + KUNIT_ASSERT_TRUE(test, + pt_test_sw_bit_acquire(pts, bitnr)); + + KUNIT_ASSERT_EQ(test, pt_item_oa(pts), paddr); + + /* SW bits didn't leak into the attrs */ + pt_attr_from_entry(pts, &new_attrs); + KUNIT_ASSERT_MEMEQ(test, &new_attrs, &attrs, sizeof(attrs)); + + pt_clear_entries(pts, len_lg2 - isz_lg2); + KUNIT_ASSERT_PT_LOAD(test, pts, PT_ENTRY_EMPTY); + } +} + +static __maybe_unused void test_sw_bit_leaf(struct kunit *test) +{ + check_all_levels(test, test_lvl_sw_bit_leaf, NULL); +} + +static void test_lvl_sw_bit_table(struct kunit *test, struct pt_state *pts, + void *arg) +{ + struct kunit_iommu_priv *priv = test->priv; + struct pt_write_attrs attrs = {}; + pt_oaddr_t paddr = + log2_set_mod(priv->test_oa, 0, priv->smallest_pgsz_lg2); + unsigned int bitnr; + + if (!pt_can_have_leaf(pts)) + return; + if (pts->index != 0) + return; + + KUNIT_ASSERT_NO_ERRNO_FN(test, "pt_iommu_set_prot", + pt_iommu_set_prot(pts->range->common, &attrs, + IOMMU_READ)); + + KUNIT_ASSERT_TRUE(test, pt_install_table(pts, paddr, &attrs)); + + for (bitnr = 0; bitnr <= pt_max_sw_bit(pts->range->common); bitnr++) + KUNIT_ASSERT_FALSE(test, pt_test_sw_bit_acquire(pts, bitnr)); + + for (bitnr = 0; bitnr <= pt_max_sw_bit(pts->range->common); bitnr++) { + KUNIT_ASSERT_FALSE(test, pt_test_sw_bit_acquire(pts, bitnr)); + pt_set_sw_bit_release(pts, bitnr); + KUNIT_ASSERT_TRUE(test, pt_test_sw_bit_acquire(pts, bitnr)); + } + + for (bitnr = 0; bitnr <= pt_max_sw_bit(pts->range->common); bitnr++) + KUNIT_ASSERT_TRUE(test, 
pt_test_sw_bit_acquire(pts, bitnr)); + + KUNIT_ASSERT_EQ(test, pt_table_pa(pts), paddr); + + pt_clear_entries(pts, ilog2(1)); + KUNIT_ASSERT_PT_LOAD(test, pts, PT_ENTRY_EMPTY); +} + +static __maybe_unused void test_sw_bit_table(struct kunit *test) +{ + check_all_levels(test, test_lvl_sw_bit_table, NULL); +} + +static struct kunit_case generic_pt_test_cases[] = { + KUNIT_CASE_FMT(test_init), + KUNIT_CASE_FMT(test_bitops), + KUNIT_CASE_FMT(test_best_pgsize), + KUNIT_CASE_FMT(test_table_ptr), + KUNIT_CASE_FMT(test_max_va), + KUNIT_CASE_FMT(test_table_radix), + KUNIT_CASE_FMT(test_entry_possible_sizes), + KUNIT_CASE_FMT(test_entry_oa), + KUNIT_CASE_FMT(test_attr_from_entry), +#ifdef pt_entry_is_write_dirty + KUNIT_CASE_FMT(test_dirty), +#endif +#ifdef pt_sw_bit + KUNIT_CASE_FMT(test_sw_bit_leaf), + KUNIT_CASE_FMT(test_sw_bit_table), +#endif + {}, +}; + +static int pt_kunit_generic_pt_init(struct kunit *test) +{ + struct kunit_iommu_priv *priv; + int ret; + + priv = kunit_kzalloc(test, sizeof(*priv), GFP_KERNEL); + if (!priv) + return -ENOMEM; + ret = pt_kunit_priv_init(test, priv); + if (ret) { + kunit_kfree(test, priv); + return ret; + } + test->priv = priv; + return 0; +} + +static void pt_kunit_generic_pt_exit(struct kunit *test) +{ + struct kunit_iommu_priv *priv = test->priv; + + if (!test->priv) + return; + + pt_iommu_deinit(priv->iommu); + kunit_kfree(test, test->priv); +} + +static struct kunit_suite NS(generic_pt_suite) = { + .name = __stringify(NS(fmt_test)), + .init = pt_kunit_generic_pt_init, + .exit = pt_kunit_generic_pt_exit, + .test_cases = generic_pt_test_cases, +}; +kunit_test_suites(&NS(generic_pt_suite)); diff --git a/drivers/iommu/generic_pt/kunit_iommu.h b/drivers/iommu/generic_pt/kunit_iommu.h new file mode 100644 index 000000000000..22c9e4c4dd97 --- /dev/null +++ b/drivers/iommu/generic_pt/kunit_iommu.h @@ -0,0 +1,184 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES + */ +#ifndef __GENERIC_PT_KUNIT_IOMMU_H +#define __GENERIC_PT_KUNIT_IOMMU_H + +#define GENERIC_PT_KUNIT 1 +#include <kunit/device.h> +#include <kunit/test.h> +#include "../iommu-pages.h" +#include "pt_iter.h" + +#define pt_iommu_table_cfg CONCATENATE(pt_iommu_table, _cfg) +#define pt_iommu_init CONCATENATE(CONCATENATE(pt_iommu_, PTPFX), init) +int pt_iommu_init(struct pt_iommu_table *fmt_table, + const struct pt_iommu_table_cfg *cfg, gfp_t gfp); + +/* The format can provide a list of configurations it would like to test */ +#ifdef kunit_fmt_cfgs +static const void *kunit_pt_gen_params_cfg(struct kunit *test, const void *prev, + char *desc) +{ + uintptr_t cfg_id = (uintptr_t)prev; + + cfg_id++; + if (cfg_id >= ARRAY_SIZE(kunit_fmt_cfgs) + 1) + return NULL; + snprintf(desc, KUNIT_PARAM_DESC_SIZE, "%s_cfg_%u", + __stringify(PTPFX_RAW), (unsigned int)(cfg_id - 1)); + return (void *)cfg_id; +} +#define KUNIT_CASE_FMT(test_name) \ + KUNIT_CASE_PARAM(test_name, kunit_pt_gen_params_cfg) +#else +#define KUNIT_CASE_FMT(test_name) KUNIT_CASE(test_name) +#endif + +#define KUNIT_ASSERT_NO_ERRNO(test, ret) \ + KUNIT_ASSERT_EQ_MSG(test, ret, 0, KUNIT_SUBSUBTEST_INDENT "errno %pe", \ + ERR_PTR(ret)) + +#define KUNIT_ASSERT_NO_ERRNO_FN(test, fn, ret) \ + KUNIT_ASSERT_EQ_MSG(test, ret, 0, \ + KUNIT_SUBSUBTEST_INDENT "errno %pe from %s", \ + ERR_PTR(ret), fn) + +/* + * When the test is run on a 32 bit system unsigned long can be 32 bits. This + * cause the iommu op signatures to be restricted to 32 bits. 
Meaning the test + * has to be mindful not to create any VA's over the 32 bit limit. Reduce the + * scope of the testing as the main purpose of checking on full 32 bit is to + * look for 32bitism in the core code. Run the test on i386 with X86_PAE=y to + * get the full coverage when dma_addr_t & phys_addr_t are 8 bytes + */ +#define IS_32BIT (sizeof(unsigned long) == 4) + +struct kunit_iommu_priv { + union { + struct iommu_domain domain; + struct pt_iommu_table fmt_table; + }; + spinlock_t top_lock; + struct device *dummy_dev; + struct pt_iommu *iommu; + struct pt_common *common; + struct pt_iommu_table_cfg cfg; + struct pt_iommu_info info; + unsigned int smallest_pgsz_lg2; + pt_vaddr_t smallest_pgsz; + unsigned int largest_pgsz_lg2; + pt_oaddr_t test_oa; + pt_vaddr_t safe_pgsize_bitmap; + unsigned long orig_nr_secondary_pagetable; + +}; +PT_IOMMU_CHECK_DOMAIN(struct kunit_iommu_priv, fmt_table.iommu, domain); + +static void pt_kunit_iotlb_sync(struct iommu_domain *domain, + struct iommu_iotlb_gather *gather) +{ + iommu_put_pages_list(&gather->freelist); +} + +#define IOMMU_PT_DOMAIN_OPS1(x) IOMMU_PT_DOMAIN_OPS(x) +static const struct iommu_domain_ops kunit_pt_ops = { + IOMMU_PT_DOMAIN_OPS1(PTPFX_RAW), + .iotlb_sync = &pt_kunit_iotlb_sync, +}; + +static void pt_kunit_change_top(struct pt_iommu *iommu_table, + phys_addr_t top_paddr, unsigned int top_level) +{ +} + +static spinlock_t *pt_kunit_get_top_lock(struct pt_iommu *iommu_table) +{ + struct kunit_iommu_priv *priv = container_of( + iommu_table, struct kunit_iommu_priv, fmt_table.iommu); + + return &priv->top_lock; +} + +static const struct pt_iommu_driver_ops pt_kunit_driver_ops = { + .change_top = &pt_kunit_change_top, + .get_top_lock = &pt_kunit_get_top_lock, +}; + +static int pt_kunit_priv_init(struct kunit *test, struct kunit_iommu_priv *priv) +{ + unsigned int va_lg2sz; + int ret; + + /* Enough so the memory allocator works */ + priv->dummy_dev = kunit_device_register(test, "pt_kunit_dev"); + if (IS_ERR(priv->dummy_dev)) + return PTR_ERR(priv->dummy_dev); + set_dev_node(priv->dummy_dev, NUMA_NO_NODE); + + spin_lock_init(&priv->top_lock); + +#ifdef kunit_fmt_cfgs + priv->cfg = kunit_fmt_cfgs[((uintptr_t)test->param_value) - 1]; + /* + * The format can set a list of features that the kunit_fmt_cfgs + * controls, other features are default to on. + */ + priv->cfg.common.features |= PT_SUPPORTED_FEATURES & + (~KUNIT_FMT_FEATURES); +#else + priv->cfg.common.features = PT_SUPPORTED_FEATURES; +#endif + + /* Defaults, for the kunit */ + if (!priv->cfg.common.hw_max_vasz_lg2) + priv->cfg.common.hw_max_vasz_lg2 = PT_MAX_VA_ADDRESS_LG2; + if (!priv->cfg.common.hw_max_oasz_lg2) + priv->cfg.common.hw_max_oasz_lg2 = pt_max_oa_lg2(NULL); + + priv->fmt_table.iommu.nid = NUMA_NO_NODE; + priv->fmt_table.iommu.driver_ops = &pt_kunit_driver_ops; + priv->fmt_table.iommu.iommu_device = priv->dummy_dev; + priv->domain.ops = &kunit_pt_ops; + ret = pt_iommu_init(&priv->fmt_table, &priv->cfg, GFP_KERNEL); + if (ret) { + if (ret == -EOVERFLOW) + kunit_skip( + test, + "This configuration cannot be tested on 32 bit"); + return ret; + } + + priv->iommu = &priv->fmt_table.iommu; + priv->common = common_from_iommu(&priv->fmt_table.iommu); + priv->iommu->ops->get_info(priv->iommu, &priv->info); + + /* + * size_t is used to pass the mapping length, it can be 32 bit, truncate + * the pagesizes so we don't use large sizes. 
+ */ + priv->info.pgsize_bitmap = (size_t)priv->info.pgsize_bitmap; + + priv->smallest_pgsz_lg2 = vaffs(priv->info.pgsize_bitmap); + priv->smallest_pgsz = log2_to_int(priv->smallest_pgsz_lg2); + priv->largest_pgsz_lg2 = + vafls((dma_addr_t)priv->info.pgsize_bitmap) - 1; + + priv->test_oa = + oalog2_mod(0x74a71445deadbeef, priv->common->max_oasz_lg2); + + /* + * We run out of VA space if the mappings get too big, make something + * smaller that can safely pass through dma_addr_t API. + */ + va_lg2sz = priv->common->max_vasz_lg2; + if (IS_32BIT && va_lg2sz > 32) + va_lg2sz = 32; + priv->safe_pgsize_bitmap = + log2_mod(priv->info.pgsize_bitmap, va_lg2sz - 1); + + return 0; +} + +#endif diff --git a/drivers/iommu/generic_pt/kunit_iommu_pt.h b/drivers/iommu/generic_pt/kunit_iommu_pt.h new file mode 100644 index 000000000000..e8a63c8ea850 --- /dev/null +++ b/drivers/iommu/generic_pt/kunit_iommu_pt.h @@ -0,0 +1,487 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES + */ +#include "kunit_iommu.h" +#include "pt_iter.h" +#include <linux/generic_pt/iommu.h> +#include <linux/iommu.h> + +static void do_map(struct kunit *test, pt_vaddr_t va, pt_oaddr_t pa, + pt_vaddr_t len); + +struct count_valids { + u64 per_size[PT_VADDR_MAX_LG2]; +}; + +static int __count_valids(struct pt_range *range, void *arg, unsigned int level, + struct pt_table_p *table) +{ + struct pt_state pts = pt_init(range, level, table); + struct count_valids *valids = arg; + + for_each_pt_level_entry(&pts) { + if (pts.type == PT_ENTRY_TABLE) { + pt_descend(&pts, arg, __count_valids); + continue; + } + if (pts.type == PT_ENTRY_OA) { + valids->per_size[pt_entry_oa_lg2sz(&pts)]++; + continue; + } + } + return 0; +} + +/* + * Number of valid table entries. This counts contiguous entries as a single + * valid. 
+ */ +static unsigned int count_valids(struct kunit *test) +{ + struct kunit_iommu_priv *priv = test->priv; + struct pt_range range = pt_top_range(priv->common); + struct count_valids valids = {}; + u64 total = 0; + unsigned int i; + + KUNIT_ASSERT_NO_ERRNO(test, + pt_walk_range(&range, __count_valids, &valids)); + + for (i = 0; i != ARRAY_SIZE(valids.per_size); i++) + total += valids.per_size[i]; + return total; +} + +/* Only a single page size is present, count the number of valid entries */ +static unsigned int count_valids_single(struct kunit *test, pt_vaddr_t pgsz) +{ + struct kunit_iommu_priv *priv = test->priv; + struct pt_range range = pt_top_range(priv->common); + struct count_valids valids = {}; + u64 total = 0; + unsigned int i; + + KUNIT_ASSERT_NO_ERRNO(test, + pt_walk_range(&range, __count_valids, &valids)); + + for (i = 0; i != ARRAY_SIZE(valids.per_size); i++) { + if ((1ULL << i) == pgsz) + total = valids.per_size[i]; + else + KUNIT_ASSERT_EQ(test, valids.per_size[i], 0); + } + return total; +} + +static void do_unmap(struct kunit *test, pt_vaddr_t va, pt_vaddr_t len) +{ + struct kunit_iommu_priv *priv = test->priv; + size_t ret; + + ret = iommu_unmap(&priv->domain, va, len); + KUNIT_ASSERT_EQ(test, ret, len); +} + +static void check_iova(struct kunit *test, pt_vaddr_t va, pt_oaddr_t pa, + pt_vaddr_t len) +{ + struct kunit_iommu_priv *priv = test->priv; + pt_vaddr_t pfn = log2_div(va, priv->smallest_pgsz_lg2); + pt_vaddr_t end_pfn = pfn + log2_div(len, priv->smallest_pgsz_lg2); + + for (; pfn != end_pfn; pfn++) { + phys_addr_t res = iommu_iova_to_phys(&priv->domain, + pfn * priv->smallest_pgsz); + + KUNIT_ASSERT_EQ(test, res, (phys_addr_t)pa); + if (res != pa) + break; + pa += priv->smallest_pgsz; + } +} + +static void test_increase_level(struct kunit *test) +{ + struct kunit_iommu_priv *priv = test->priv; + struct pt_common *common = priv->common; + + if (!pt_feature(common, PT_FEAT_DYNAMIC_TOP)) + kunit_skip(test, "PT_FEAT_DYNAMIC_TOP not set for this format"); + + if (IS_32BIT) + kunit_skip(test, "Unable to test on 32bit"); + + KUNIT_ASSERT_GT(test, common->max_vasz_lg2, + pt_top_range(common).max_vasz_lg2); + + /* Add every possible level to the max */ + while (common->max_vasz_lg2 != pt_top_range(common).max_vasz_lg2) { + struct pt_range top_range = pt_top_range(common); + + if (top_range.va == 0) + do_map(test, top_range.last_va + 1, 0, + priv->smallest_pgsz); + else + do_map(test, top_range.va - priv->smallest_pgsz, 0, + priv->smallest_pgsz); + + KUNIT_ASSERT_EQ(test, pt_top_range(common).top_level, + top_range.top_level + 1); + KUNIT_ASSERT_GE(test, common->max_vasz_lg2, + pt_top_range(common).max_vasz_lg2); + } +} + +static void test_map_simple(struct kunit *test) +{ + struct kunit_iommu_priv *priv = test->priv; + struct pt_range range = pt_top_range(priv->common); + struct count_valids valids = {}; + pt_vaddr_t pgsize_bitmap = priv->safe_pgsize_bitmap; + unsigned int pgsz_lg2; + pt_vaddr_t cur_va; + + /* Map every reported page size */ + cur_va = range.va + priv->smallest_pgsz * 256; + for (pgsz_lg2 = 0; pgsz_lg2 != PT_VADDR_MAX_LG2; pgsz_lg2++) { + pt_oaddr_t paddr = log2_set_mod(priv->test_oa, 0, pgsz_lg2); + u64 len = log2_to_int(pgsz_lg2); + + if (!(pgsize_bitmap & len)) + continue; + + cur_va = ALIGN(cur_va, len); + do_map(test, cur_va, paddr, len); + if (len <= SZ_2G) + check_iova(test, cur_va, paddr, len); + cur_va += len; + } + + /* The read interface reports that every page size was created */ + range = pt_top_range(priv->common); + 
KUNIT_ASSERT_NO_ERRNO(test, + pt_walk_range(&range, __count_valids, &valids)); + for (pgsz_lg2 = 0; pgsz_lg2 != PT_VADDR_MAX_LG2; pgsz_lg2++) { + if (pgsize_bitmap & (1ULL << pgsz_lg2)) + KUNIT_ASSERT_EQ(test, valids.per_size[pgsz_lg2], 1); + else + KUNIT_ASSERT_EQ(test, valids.per_size[pgsz_lg2], 0); + } + + /* Unmap works */ + range = pt_top_range(priv->common); + cur_va = range.va + priv->smallest_pgsz * 256; + for (pgsz_lg2 = 0; pgsz_lg2 != PT_VADDR_MAX_LG2; pgsz_lg2++) { + u64 len = log2_to_int(pgsz_lg2); + + if (!(pgsize_bitmap & len)) + continue; + cur_va = ALIGN(cur_va, len); + do_unmap(test, cur_va, len); + cur_va += len; + } + KUNIT_ASSERT_EQ(test, count_valids(test), 0); +} + +/* + * Test to convert a table pointer into an OA by mapping something small, + * unmapping it so as to leave behind a table pointer, then mapping something + * larger that will convert the table into an OA. + */ +static void test_map_table_to_oa(struct kunit *test) +{ + struct kunit_iommu_priv *priv = test->priv; + pt_vaddr_t limited_pgbitmap = + priv->info.pgsize_bitmap % (IS_32BIT ? SZ_2G : SZ_16G); + struct pt_range range = pt_top_range(priv->common); + unsigned int pgsz_lg2; + pt_vaddr_t max_pgsize; + pt_vaddr_t cur_va; + + max_pgsize = 1ULL << (vafls(limited_pgbitmap) - 1); + KUNIT_ASSERT_TRUE(test, priv->info.pgsize_bitmap & max_pgsize); + + for (pgsz_lg2 = 0; pgsz_lg2 != PT_VADDR_MAX_LG2; pgsz_lg2++) { + pt_oaddr_t paddr = log2_set_mod(priv->test_oa, 0, pgsz_lg2); + u64 len = log2_to_int(pgsz_lg2); + pt_vaddr_t offset; + + if (!(priv->info.pgsize_bitmap & len)) + continue; + if (len > max_pgsize) + break; + + cur_va = ALIGN(range.va + priv->smallest_pgsz * 256, + max_pgsize); + for (offset = 0; offset != max_pgsize; offset += len) + do_map(test, cur_va + offset, paddr + offset, len); + check_iova(test, cur_va, paddr, max_pgsize); + KUNIT_ASSERT_EQ(test, count_valids_single(test, len), + log2_div(max_pgsize, pgsz_lg2)); + + if (len == max_pgsize) { + do_unmap(test, cur_va, max_pgsize); + } else { + do_unmap(test, cur_va, max_pgsize / 2); + for (offset = max_pgsize / 2; offset != max_pgsize; + offset += len) + do_unmap(test, cur_va + offset, len); + } + + KUNIT_ASSERT_EQ(test, count_valids(test), 0); + } +} + +/* + * Test unmapping a small page at the start of a large page. This always unmaps + * the large page. 
+ */
+static void test_unmap_split(struct kunit *test)
+{
+	struct kunit_iommu_priv *priv = test->priv;
+	struct pt_range top_range = pt_top_range(priv->common);
+	pt_vaddr_t pgsize_bitmap = priv->safe_pgsize_bitmap;
+	unsigned int pgsz_lg2;
+	unsigned int count = 0;
+
+	for (pgsz_lg2 = 0; pgsz_lg2 != PT_VADDR_MAX_LG2; pgsz_lg2++) {
+		pt_vaddr_t base_len = log2_to_int(pgsz_lg2);
+		unsigned int next_pgsz_lg2;
+
+		if (!(pgsize_bitmap & base_len))
+			continue;
+
+		for (next_pgsz_lg2 = pgsz_lg2 + 1;
+		     next_pgsz_lg2 != PT_VADDR_MAX_LG2; next_pgsz_lg2++) {
+			pt_vaddr_t next_len = log2_to_int(next_pgsz_lg2);
+			pt_vaddr_t vaddr = top_range.va;
+			pt_oaddr_t paddr = 0;
+			size_t gnmapped;
+
+			if (!(pgsize_bitmap & next_len))
+				continue;
+
+			do_map(test, vaddr, paddr, next_len);
+			gnmapped = iommu_unmap(&priv->domain, vaddr, base_len);
+			KUNIT_ASSERT_EQ(test, gnmapped, next_len);
+
+			/* Make sure unmap doesn't keep going */
+			do_map(test, vaddr, paddr, next_len);
+			do_map(test, vaddr + next_len, paddr, next_len);
+			gnmapped = iommu_unmap(&priv->domain, vaddr, base_len);
+			KUNIT_ASSERT_EQ(test, gnmapped, next_len);
+			gnmapped = iommu_unmap(&priv->domain, vaddr + next_len,
+					       next_len);
+			KUNIT_ASSERT_EQ(test, gnmapped, next_len);
+
+			count++;
+		}
+	}
+
+	if (count == 0)
+		kunit_skip(test, "Test needs two page sizes");
+}
+
+static void unmap_collisions(struct kunit *test, struct maple_tree *mt,
+			     pt_vaddr_t start, pt_vaddr_t last)
+{
+	struct kunit_iommu_priv *priv = test->priv;
+	MA_STATE(mas, mt, start, last);
+	void *entry;
+
+	mtree_lock(mt);
+	mas_for_each(&mas, entry, last) {
+		pt_vaddr_t mas_start = mas.index;
+		pt_vaddr_t len = (mas.last - mas_start) + 1;
+		pt_oaddr_t paddr;
+
+		mas_erase(&mas);
+		mas_pause(&mas);
+		mtree_unlock(mt);
+
+		paddr = oalog2_mod(mas_start, priv->common->max_oasz_lg2);
+		check_iova(test, mas_start, paddr, len);
+		do_unmap(test, mas_start, len);
+		mtree_lock(mt);
+	}
+	mtree_unlock(mt);
+}
+
+static void clamp_range(struct kunit *test, struct pt_range *range)
+{
+	struct kunit_iommu_priv *priv = test->priv;
+
+	if (range->last_va - range->va > SZ_1G)
+		range->last_va = range->va + SZ_1G;
+	KUNIT_ASSERT_NE(test, range->last_va, PT_VADDR_MAX);
+	if (range->va <= MAPLE_RESERVED_RANGE)
+		range->va =
+			ALIGN(MAPLE_RESERVED_RANGE, priv->smallest_pgsz);
+}
+
+/*
+ * Randomly map and unmap ranges that can use large physical pages. If a random
+ * range overlaps with existing ranges then unmap them. This hits all the
+ * special cases.
+ */ +static void test_random_map(struct kunit *test) +{ + struct kunit_iommu_priv *priv = test->priv; + struct pt_range upper_range = pt_upper_range(priv->common); + struct pt_range top_range = pt_top_range(priv->common); + struct maple_tree mt; + unsigned int iter; + + mt_init(&mt); + + /* + * Shrink the range so randomization is more likely to have + * intersections + */ + clamp_range(test, &top_range); + clamp_range(test, &upper_range); + + for (iter = 0; iter != 1000; iter++) { + struct pt_range *range = &top_range; + pt_oaddr_t paddr; + pt_vaddr_t start; + pt_vaddr_t end; + int ret; + + if (pt_feature(priv->common, PT_FEAT_SIGN_EXTEND) && + ULONG_MAX >= PT_VADDR_MAX && get_random_u32_inclusive(0, 1)) + range = &upper_range; + + start = get_random_u32_below( + min(U32_MAX, range->last_va - range->va)); + end = get_random_u32_below( + min(U32_MAX, range->last_va - start)); + + start = ALIGN_DOWN(start, priv->smallest_pgsz); + end = ALIGN(end, priv->smallest_pgsz); + start += range->va; + end += start; + if (start < range->va || end > range->last_va + 1 || + start >= end) + continue; + + /* Try overmapping to test the failure handling */ + paddr = oalog2_mod(start, priv->common->max_oasz_lg2); + ret = iommu_map(&priv->domain, start, paddr, end - start, + IOMMU_READ | IOMMU_WRITE, GFP_KERNEL); + if (ret) { + KUNIT_ASSERT_EQ(test, ret, -EADDRINUSE); + unmap_collisions(test, &mt, start, end - 1); + do_map(test, start, paddr, end - start); + } + + KUNIT_ASSERT_NO_ERRNO_FN(test, "mtree_insert_range", + mtree_insert_range(&mt, start, end - 1, + XA_ZERO_ENTRY, + GFP_KERNEL)); + + check_iova(test, start, paddr, end - start); + if (iter % 100) + cond_resched(); + } + + unmap_collisions(test, &mt, 0, PT_VADDR_MAX); + KUNIT_ASSERT_EQ(test, count_valids(test), 0); + + mtree_destroy(&mt); +} + +/* See https://lore.kernel.org/r/b9b18a03-63a2-4065-a27e-d92dd5c860bc@amd.com */ +static void test_pgsize_boundary(struct kunit *test) +{ + struct kunit_iommu_priv *priv = test->priv; + struct pt_range top_range = pt_top_range(priv->common); + + if (top_range.va != 0 || top_range.last_va < 0xfef9ffff || + priv->smallest_pgsz != SZ_4K) + kunit_skip(test, "Format does not have the required range"); + + do_map(test, 0xfef80000, 0x208b95d000, 0xfef9ffff - 0xfef80000 + 1); +} + +/* See https://lore.kernel.org/r/20250826143816.38686-1-eugkoira@amazon.com */ +static void test_mixed(struct kunit *test) +{ + struct kunit_iommu_priv *priv = test->priv; + struct pt_range top_range = pt_top_range(priv->common); + u64 start = 0x3fe400ULL << 12; + u64 end = 0x4c0600ULL << 12; + pt_vaddr_t len = end - start; + pt_oaddr_t oa = start; + + if (top_range.last_va <= start || sizeof(unsigned long) == 4) + kunit_skip(test, "range is too small"); + if ((priv->safe_pgsize_bitmap & GENMASK(30, 21)) != (BIT(30) | BIT(21))) + kunit_skip(test, "incompatible psize"); + + do_map(test, start, oa, len); + /* 14 2M, 3 1G, 3 2M */ + KUNIT_ASSERT_EQ(test, count_valids(test), 20); + check_iova(test, start, oa, len); +} + +static struct kunit_case iommu_test_cases[] = { + KUNIT_CASE_FMT(test_increase_level), + KUNIT_CASE_FMT(test_map_simple), + KUNIT_CASE_FMT(test_map_table_to_oa), + KUNIT_CASE_FMT(test_unmap_split), + KUNIT_CASE_FMT(test_random_map), + KUNIT_CASE_FMT(test_pgsize_boundary), + KUNIT_CASE_FMT(test_mixed), + {}, +}; + +static int pt_kunit_iommu_init(struct kunit *test) +{ + struct kunit_iommu_priv *priv; + int ret; + + priv = kunit_kzalloc(test, sizeof(*priv), GFP_KERNEL); + if (!priv) + return -ENOMEM; + + 
priv->orig_nr_secondary_pagetable = + global_node_page_state(NR_SECONDARY_PAGETABLE); + ret = pt_kunit_priv_init(test, priv); + if (ret) { + kunit_kfree(test, priv); + return ret; + } + test->priv = priv; + return 0; +} + +static void pt_kunit_iommu_exit(struct kunit *test) +{ + struct kunit_iommu_priv *priv = test->priv; + + if (!test->priv) + return; + + pt_iommu_deinit(priv->iommu); + /* + * Look for memory leaks, assumes kunit is running isolated and nothing + * else is using secondary page tables. + */ + KUNIT_ASSERT_EQ(test, priv->orig_nr_secondary_pagetable, + global_node_page_state(NR_SECONDARY_PAGETABLE)); + kunit_kfree(test, test->priv); +} + +static struct kunit_suite NS(iommu_suite) = { + .name = __stringify(NS(iommu_test)), + .init = pt_kunit_iommu_init, + .exit = pt_kunit_iommu_exit, + .test_cases = iommu_test_cases, +}; +kunit_test_suites(&NS(iommu_suite)); + +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("Kunit for generic page table"); +MODULE_IMPORT_NS("GENERIC_PT_IOMMU"); diff --git a/drivers/iommu/generic_pt/pt_common.h b/drivers/iommu/generic_pt/pt_common.h new file mode 100644 index 000000000000..e1123d35c907 --- /dev/null +++ b/drivers/iommu/generic_pt/pt_common.h @@ -0,0 +1,389 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES + * + * This header is included after the format. It contains definitions + * that build on the format definitions to create the basic format API. + * + * The format API is listed here, with kdocs. The functions without bodies are + * implemented in the format using the pattern: + * static inline FMTpt_XXX(..) {..} + * #define pt_XXX FMTpt_XXX + * + * If the format doesn't implement a function then pt_fmt_defaults.h can provide + * a generic version. + * + * The routines marked "@pts: Entry to query" operate on the entire contiguous + * entry and can be called with a pts->index pointing to any sub item that makes + * up that entry. + * + * The header order is: + * pt_defs.h + * FMT.h + * pt_common.h + */ +#ifndef __GENERIC_PT_PT_COMMON_H +#define __GENERIC_PT_PT_COMMON_H + +#include "pt_defs.h" +#include "pt_fmt_defaults.h" + +/** + * pt_attr_from_entry() - Convert the permission bits back to attrs + * @pts: Entry to convert from + * @attrs: Resulting attrs + * + * Fill in the attrs with the permission bits encoded in the current leaf entry. + * The attrs should be usable with pt_install_leaf_entry() to reconstruct the + * same entry. + */ +static inline void pt_attr_from_entry(const struct pt_state *pts, + struct pt_write_attrs *attrs); + +/** + * pt_can_have_leaf() - True if the current level can have an OA entry + * @pts: The current level + * + * True if the current level can support pt_install_leaf_entry(). A leaf + * entry produce an OA. + */ +static inline bool pt_can_have_leaf(const struct pt_state *pts); + +/** + * pt_can_have_table() - True if the current level can have a lower table + * @pts: The current level + * + * Every level except 0 is allowed to have a lower table. + */ +static inline bool pt_can_have_table(const struct pt_state *pts) +{ + /* No further tables at level 0 */ + return pts->level > 0; +} + +/** + * pt_clear_entries() - Make entries empty (non-present) + * @pts: Starting table index + * @num_contig_lg2: Number of contiguous items to clear + * + * Clear a run of entries. A cleared entry will load back as PT_ENTRY_EMPTY + * and does not have any effect on table walking. The starting index must be + * aligned to num_contig_lg2. 
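+ *
+ * For example (illustrative only), clearing a contiguous entry made up of
+ * 16 items starting at an aligned pts->index would be::
+ *
+ *	pt_clear_entries(pts, ilog2(16));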
+ */ +static inline void pt_clear_entries(struct pt_state *pts, + unsigned int num_contig_lg2); + +/** + * pt_entry_make_write_dirty() - Make an entry dirty + * @pts: Table entry to change + * + * Make pt_entry_is_write_dirty() return true for this entry. This can be called + * asynchronously with any other table manipulation under a RCU lock and must + * not corrupt the table. + */ +static inline bool pt_entry_make_write_dirty(struct pt_state *pts); + +/** + * pt_entry_make_write_clean() - Make the entry write clean + * @pts: Table entry to change + * + * Modify the entry so that pt_entry_is_write_dirty() == false. The HW will + * eventually be notified of this change via a TLB flush, which is the point + * that the HW must become synchronized. Any "write dirty" prior to the TLB + * flush can be lost, but once the TLB flush completes all writes must make + * their entries write dirty. + * + * The format should alter the entry in a way that is compatible with any + * concurrent update from HW. The entire contiguous entry is changed. + */ +static inline void pt_entry_make_write_clean(struct pt_state *pts); + +/** + * pt_entry_is_write_dirty() - True if the entry has been written to + * @pts: Entry to query + * + * "write dirty" means that the HW has written to the OA translated + * by this entry. If the entry is contiguous then the consolidated + * "write dirty" for all the items must be returned. + */ +static inline bool pt_entry_is_write_dirty(const struct pt_state *pts); + +/** + * pt_dirty_supported() - True if the page table supports dirty tracking + * @common: Page table to query + */ +static inline bool pt_dirty_supported(struct pt_common *common); + +/** + * pt_entry_num_contig_lg2() - Number of contiguous items for this leaf entry + * @pts: Entry to query + * + * Return the number of contiguous items this leaf entry spans. If the entry + * is single item it returns ilog2(1). + */ +static inline unsigned int pt_entry_num_contig_lg2(const struct pt_state *pts); + +/** + * pt_entry_oa() - Output Address for this leaf entry + * @pts: Entry to query + * + * Return the output address for the start of the entry. If the entry + * is contiguous this returns the same value for each sub-item. I.e.:: + * + * log2_mod(pt_entry_oa(), pt_entry_oa_lg2sz()) == 0 + * + * See pt_item_oa(). The format should implement one of these two functions + * depending on how it stores the OAs in the table. + */ +static inline pt_oaddr_t pt_entry_oa(const struct pt_state *pts); + +/** + * pt_entry_oa_lg2sz() - Return the size of an OA entry + * @pts: Entry to query + * + * If the entry is not contiguous this returns pt_table_item_lg2sz(), otherwise + * it returns the total VA/OA size of the entire contiguous entry. + */ +static inline unsigned int pt_entry_oa_lg2sz(const struct pt_state *pts) +{ + return pt_entry_num_contig_lg2(pts) + pt_table_item_lg2sz(pts); +} + +/** + * pt_entry_oa_exact() - Return the complete OA for an entry + * @pts: Entry to query + * + * During iteration the first entry could have a VA with an offset from the + * natural start of the entry. Return the exact OA including the pts's VA + * offset. + */ +static inline pt_oaddr_t pt_entry_oa_exact(const struct pt_state *pts) +{ + return _pt_entry_oa_fast(pts) | + log2_mod(pts->range->va, pt_entry_oa_lg2sz(pts)); +} + +/** + * pt_full_va_prefix() - The top bits of the VA + * @common: Page table to query + * + * This is usually 0, but some formats have their VA space going downward from + * PT_VADDR_MAX, and will return that instead. 
This value must always be + * adjusted by struct pt_common max_vasz_lg2. + */ +static inline pt_vaddr_t pt_full_va_prefix(const struct pt_common *common); + +/** + * pt_has_system_page_size() - True if level 0 can install a PAGE_SHIFT entry + * @common: Page table to query + * + * If true the caller can use, at level 0, pt_install_leaf_entry(PAGE_SHIFT). + * This is useful to create optimized paths for common cases of PAGE_SIZE + * mappings. + */ +static inline bool pt_has_system_page_size(const struct pt_common *common); + +/** + * pt_install_leaf_entry() - Write a leaf entry to the table + * @pts: Table index to change + * @oa: Output Address for this leaf + * @oasz_lg2: Size in VA/OA for this leaf + * @attrs: Attributes to modify the entry + * + * A leaf OA entry will return PT_ENTRY_OA from pt_load_entry(). It translates + * the VA indicated by pts to the given OA. + * + * For a single item non-contiguous entry oasz_lg2 is pt_table_item_lg2sz(). + * For contiguous it is pt_table_item_lg2sz() + num_contig_lg2. + * + * This must not be called if pt_can_have_leaf() == false. Contiguous sizes + * not indicated by pt_possible_sizes() must not be specified. + */ +static inline void pt_install_leaf_entry(struct pt_state *pts, pt_oaddr_t oa, + unsigned int oasz_lg2, + const struct pt_write_attrs *attrs); + +/** + * pt_install_table() - Write a table entry to the table + * @pts: Table index to change + * @table_pa: CPU physical address of the lower table's memory + * @attrs: Attributes to modify the table index + * + * A table entry will return PT_ENTRY_TABLE from pt_load_entry(). The table_pa + * is the table at pts->level - 1. This is done by cmpxchg so pts must have the + * current entry loaded. The pts is updated with the installed entry. + * + * This must not be called if pt_can_have_table() == false. + * + * Returns: true if the table was installed successfully. + */ +static inline bool pt_install_table(struct pt_state *pts, pt_oaddr_t table_pa, + const struct pt_write_attrs *attrs); + +/** + * pt_item_oa() - Output Address for this leaf item + * @pts: Item to query + * + * Return the output address for this item. If the item is part of a contiguous + * entry it returns the value of the OA for this individual sub item. + * + * See pt_entry_oa(). The format should implement one of these two functions + * depending on how it stores the OA's in the table. + */ +static inline pt_oaddr_t pt_item_oa(const struct pt_state *pts); + +/** + * pt_load_entry_raw() - Read from the location pts points at into the pts + * @pts: Table index to load + * + * Return the type of entry that was loaded. pts->entry will be filled in with + * the entry's content. See pt_load_entry() + */ +static inline enum pt_entry_type pt_load_entry_raw(struct pt_state *pts); + +/** + * pt_max_oa_lg2() - Return the maximum OA the table format can hold + * @common: Page table to query + * + * The value oalog2_to_max_int(pt_max_oa_lg2()) is the MAX for the + * OA. This is the absolute maximum address the table can hold. struct pt_common + * max_oasz_lg2 sets a lower dynamic maximum based on HW capability. + */ +static inline unsigned int +pt_max_oa_lg2(const struct pt_common *common); + +/** + * pt_num_items_lg2() - Return the number of items in this table level + * @pts: The current level + * + * The number of items in a table level defines the number of bits this level + * decodes from the VA. This function is not called for the top level, + * so it does not need to compute a special value for the top case. 
The + * result for the top is based on pt_common max_vasz_lg2. + * + * The value is used as part of determining the table indexes via the + * equation:: + * + * log2_mod(log2_div(VA, pt_table_item_lg2sz()), pt_num_items_lg2()) + */ +static inline unsigned int pt_num_items_lg2(const struct pt_state *pts); + +/** + * pt_pgsz_lg2_to_level - Return the level that maps the page size + * @common: Page table to query + * @pgsize_lg2: Log2 page size + * + * Returns the table level that will map the given page size. The page + * size must be part of the pt_possible_sizes() for some level. + */ +static inline unsigned int pt_pgsz_lg2_to_level(struct pt_common *common, + unsigned int pgsize_lg2); + +/** + * pt_possible_sizes() - Return a bitmap of possible output sizes at this level + * @pts: The current level + * + * Each level has a list of possible output sizes that can be installed as + * leaf entries. If pt_can_have_leaf() is false returns zero. + * + * Otherwise the bit in position pt_table_item_lg2sz() should be set indicating + * that a non-contiguous single item leaf entry is supported. The following + * pt_num_items_lg2() number of bits can be set indicating contiguous entries + * are supported. Bit pt_table_item_lg2sz() + pt_num_items_lg2() must not be + * set, contiguous entries cannot span the entire table. + * + * The OR of pt_possible_sizes() of all levels is the typical bitmask of all + * supported sizes in the entire table. + */ +static inline pt_vaddr_t pt_possible_sizes(const struct pt_state *pts); + +/** + * pt_table_item_lg2sz() - Size of a single item entry in this table level + * @pts: The current level + * + * The size of the item specifies how much VA and OA a single item occupies. + * + * See pt_entry_oa_lg2sz() for the same value including the effect of contiguous + * entries. + */ +static inline unsigned int pt_table_item_lg2sz(const struct pt_state *pts); + +/** + * pt_table_oa_lg2sz() - Return the VA/OA size of the entire table + * @pts: The current level + * + * Return the size of VA decoded by the entire table level. + */ +static inline unsigned int pt_table_oa_lg2sz(const struct pt_state *pts) +{ + if (pts->range->top_level == pts->level) + return pts->range->max_vasz_lg2; + return min_t(unsigned int, pts->range->common->max_vasz_lg2, + pt_num_items_lg2(pts) + pt_table_item_lg2sz(pts)); +} + +/** + * pt_table_pa() - Return the CPU physical address of the table entry + * @pts: Entry to query + * + * This is only ever called on PT_ENTRY_TABLE entries. Must return the same + * value passed to pt_install_table(). + */ +static inline pt_oaddr_t pt_table_pa(const struct pt_state *pts); + +/** + * pt_table_ptr() - Return a CPU pointer for a table item + * @pts: Entry to query + * + * Same as pt_table_pa() but returns a CPU pointer. + */ +static inline struct pt_table_p *pt_table_ptr(const struct pt_state *pts) +{ + return __va(pt_table_pa(pts)); +} + +/** + * pt_max_sw_bit() - Return the maximum software bit usable for any level and + * entry + * @common: Page table + * + * The swbit can be passed as bitnr to the other sw_bit functions. + */ +static inline unsigned int pt_max_sw_bit(struct pt_common *common); + +/** + * pt_test_sw_bit_acquire() - Read a software bit in an item + * @pts: Entry to read + * @bitnr: Bit to read + * + * Software bits are ignored by HW and can be used for any purpose by the + * software. This does a test bit and acquire operation. 
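+ *
+ * The acquire pairs with pt_set_sw_bit_release(): once this returns true for
+ * a bit, data published before the releasing store is visible.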
+ */ +static inline bool pt_test_sw_bit_acquire(struct pt_state *pts, + unsigned int bitnr); + +/** + * pt_set_sw_bit_release() - Set a software bit in an item + * @pts: Entry to set + * @bitnr: Bit to set + * + * Software bits are ignored by HW and can be used for any purpose by the + * software. This does a set bit and release operation. + */ +static inline void pt_set_sw_bit_release(struct pt_state *pts, + unsigned int bitnr); + +/** + * pt_load_entry() - Read from the location pts points at into the pts + * @pts: Table index to load + * + * Set the type of entry that was loaded. pts->entry and pts->table_lower + * will be filled in with the entry's content. + */ +static inline void pt_load_entry(struct pt_state *pts) +{ + pts->type = pt_load_entry_raw(pts); + if (pts->type == PT_ENTRY_TABLE) + pts->table_lower = pt_table_ptr(pts); +} +#endif diff --git a/drivers/iommu/generic_pt/pt_defs.h b/drivers/iommu/generic_pt/pt_defs.h new file mode 100644 index 000000000000..c25544d72f97 --- /dev/null +++ b/drivers/iommu/generic_pt/pt_defs.h @@ -0,0 +1,332 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES + * + * This header is included before the format. It contains definitions + * that are required to compile the format. The header order is: + * pt_defs.h + * fmt_XX.h + * pt_common.h + */ +#ifndef __GENERIC_PT_DEFS_H +#define __GENERIC_PT_DEFS_H + +#include <linux/generic_pt/common.h> + +#include <linux/types.h> +#include <linux/atomic.h> +#include <linux/bits.h> +#include <linux/limits.h> +#include <linux/bug.h> +#include <linux/kconfig.h> +#include "pt_log2.h" + +/* Header self-compile default defines */ +#ifndef pt_write_attrs +typedef u64 pt_vaddr_t; +typedef u64 pt_oaddr_t; +#endif + +struct pt_table_p; + +enum { + PT_VADDR_MAX = sizeof(pt_vaddr_t) == 8 ? U64_MAX : U32_MAX, + PT_VADDR_MAX_LG2 = sizeof(pt_vaddr_t) == 8 ? 64 : 32, + PT_OADDR_MAX = sizeof(pt_oaddr_t) == 8 ? U64_MAX : U32_MAX, + PT_OADDR_MAX_LG2 = sizeof(pt_oaddr_t) == 8 ? 64 : 32, +}; + +/* + * The format instantiation can have features wired off or on to optimize the + * code gen. Supported features are just a reflection of what the current set of + * kernel users want to use. + */ +#ifndef PT_SUPPORTED_FEATURES +#define PT_SUPPORTED_FEATURES 0 +#endif + +/* + * When in debug mode we compile all formats with all features. This allows the + * kunit to test the full matrix. SIGN_EXTEND can't co-exist with DYNAMIC_TOP or + * FULL_VA. DMA_INCOHERENT requires a SW bit that not all formats have + */ +#if IS_ENABLED(CONFIG_DEBUG_GENERIC_PT) +enum { + PT_ORIG_SUPPORTED_FEATURES = PT_SUPPORTED_FEATURES, + PT_DEBUG_SUPPORTED_FEATURES = + UINT_MAX & + ~((PT_ORIG_SUPPORTED_FEATURES & BIT(PT_FEAT_DMA_INCOHERENT) ? + 0 : + BIT(PT_FEAT_DMA_INCOHERENT))) & + ~((PT_ORIG_SUPPORTED_FEATURES & BIT(PT_FEAT_SIGN_EXTEND)) ? + BIT(PT_FEAT_DYNAMIC_TOP) | BIT(PT_FEAT_FULL_VA) : + BIT(PT_FEAT_SIGN_EXTEND)), +}; +#undef PT_SUPPORTED_FEATURES +#define PT_SUPPORTED_FEATURES PT_DEBUG_SUPPORTED_FEATURES +#endif + +#ifndef PT_FORCE_ENABLED_FEATURES +#define PT_FORCE_ENABLED_FEATURES 0 +#endif + +/** + * DOC: Generic Page Table Language + * + * Language used in Generic Page Table + * VA + * The input address to the page table, often the virtual address. + * OA + * The output address from the page table, often the physical address. + * leaf + * An entry that results in an output address. + * start/end + * An half-open range, e.g. [0,0) refers to no VA. 
+ * start/last + * An inclusive closed range, e.g. [0,0] refers to the VA 0 + * common + * The generic page table container struct pt_common + * level + * Level 0 is always a table of only leaves with no futher table pointers. + * Increasing levels increase the size of the table items. The least + * significant VA bits used to index page tables are used to index the Level + * 0 table. The various labels for table levels used by HW descriptions are + * not used. + * top_level + * The inclusive highest level of the table. A two-level table + * has a top level of 1. + * table + * A linear array of translation items for that level. + * index + * The position in a table of an element: item = table[index] + * item + * A single index in a table + * entry + * A single logical element in a table. If contiguous pages are not + * supported then item and entry are the same thing, otherwise entry refers + * to all the items that comprise a single contiguous translation. + * item/entry_size + * The number of bytes of VA the table index translates for. + * If the item is a table entry then the next table covers + * this size. If the entry translates to an output address then the + * full OA is: OA | (VA % entry_size) + * contig_count + * The number of consecutive items fused into a single entry. + * item_size * contig_count is the size of that entry's translation. + * lg2 + * Indicates the value is encoded as log2, i.e. 1<<x is the actual value. + * Normally the compiler is fine to optimize divide and mod with log2 values + * automatically when inlining, however if the values are not constant + * expressions it can't. So we do it by hand; we want to avoid 64-bit + * divmod. + */ + +/* Returned by pt_load_entry() and for_each_pt_level_entry() */ +enum pt_entry_type { + PT_ENTRY_EMPTY, + /* Entry is valid and points to a lower table level */ + PT_ENTRY_TABLE, + /* Entry is valid and returns an output address */ + PT_ENTRY_OA, +}; + +struct pt_range { + struct pt_common *common; + struct pt_table_p *top_table; + pt_vaddr_t va; + pt_vaddr_t last_va; + u8 top_level; + u8 max_vasz_lg2; +}; + +/* + * Similar to xa_state, this records information about an in-progress parse at a + * single level. + */ +struct pt_state { + struct pt_range *range; + struct pt_table_p *table; + struct pt_table_p *table_lower; + u64 entry; + enum pt_entry_type type; + unsigned short index; + unsigned short end_index; + u8 level; +}; + +#define pt_cur_table(pts, type) ((type *)((pts)->table)) + +/* + * Try to install a new table pointer. The locking methodology requires this to + * be atomic (multiple threads can race to install a pointer). The losing + * threads will fail the atomic and return false. They should free any memory + * and reparse the table level again. + */ +#if !IS_ENABLED(CONFIG_GENERIC_ATOMIC64) +static inline bool pt_table_install64(struct pt_state *pts, u64 table_entry) +{ + u64 *entryp = pt_cur_table(pts, u64) + pts->index; + u64 old_entry = pts->entry; + bool ret; + + /* + * Ensure the zero'd table content itself is visible before its PTE can + * be. release is a NOP on !SMP, but the HW is still doing an acquire. 
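+ * The release in try_cmpxchg64_release() provides that ordering on SMP
+ * builds; on !SMP the release compiles away, so the explicit dma_wmb() is
+ * still needed because the IOMMU hardware walks the table concurrently with
+ * the CPU.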
+ */ + if (!IS_ENABLED(CONFIG_SMP)) + dma_wmb(); + ret = try_cmpxchg64_release(entryp, &old_entry, table_entry); + if (ret) + pts->entry = table_entry; + return ret; +} +#endif + +static inline bool pt_table_install32(struct pt_state *pts, u32 table_entry) +{ + u32 *entryp = pt_cur_table(pts, u32) + pts->index; + u32 old_entry = pts->entry; + bool ret; + + /* + * Ensure the zero'd table content itself is visible before its PTE can + * be. release is a NOP on !SMP, but the HW is still doing an acquire. + */ + if (!IS_ENABLED(CONFIG_SMP)) + dma_wmb(); + ret = try_cmpxchg_release(entryp, &old_entry, table_entry); + if (ret) + pts->entry = table_entry; + return ret; +} + +#define PT_SUPPORTED_FEATURE(feature_nr) (PT_SUPPORTED_FEATURES & BIT(feature_nr)) + +static inline bool pt_feature(const struct pt_common *common, + unsigned int feature_nr) +{ + if (PT_FORCE_ENABLED_FEATURES & BIT(feature_nr)) + return true; + if (!PT_SUPPORTED_FEATURE(feature_nr)) + return false; + return common->features & BIT(feature_nr); +} + +static inline bool pts_feature(const struct pt_state *pts, + unsigned int feature_nr) +{ + return pt_feature(pts->range->common, feature_nr); +} + +/* + * PT_WARN_ON is used for invariants that the kunit should be checking can't + * happen. + */ +#if IS_ENABLED(CONFIG_DEBUG_GENERIC_PT) +#define PT_WARN_ON WARN_ON +#else +static inline bool PT_WARN_ON(bool condition) +{ + return false; +} +#endif + +/* These all work on the VA type */ +#define log2_to_int(a_lg2) log2_to_int_t(pt_vaddr_t, a_lg2) +#define log2_to_max_int(a_lg2) log2_to_max_int_t(pt_vaddr_t, a_lg2) +#define log2_div(a, b_lg2) log2_div_t(pt_vaddr_t, a, b_lg2) +#define log2_div_eq(a, b, c_lg2) log2_div_eq_t(pt_vaddr_t, a, b, c_lg2) +#define log2_mod(a, b_lg2) log2_mod_t(pt_vaddr_t, a, b_lg2) +#define log2_mod_eq_max(a, b_lg2) log2_mod_eq_max_t(pt_vaddr_t, a, b_lg2) +#define log2_set_mod(a, val, b_lg2) log2_set_mod_t(pt_vaddr_t, a, val, b_lg2) +#define log2_set_mod_max(a, b_lg2) log2_set_mod_max_t(pt_vaddr_t, a, b_lg2) +#define log2_mul(a, b_lg2) log2_mul_t(pt_vaddr_t, a, b_lg2) +#define vaffs(a) ffs_t(pt_vaddr_t, a) +#define vafls(a) fls_t(pt_vaddr_t, a) +#define vaffz(a) ffz_t(pt_vaddr_t, a) + +/* + * The full VA (fva) versions permit the lg2 value to be == PT_VADDR_MAX_LG2 and + * generate a useful defined result. The non-fva versions will malfunction at + * this extreme. 
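+ * For example, log2_div(a, PT_VADDR_MAX_LG2) would shift the value by its
+ * full bit width, which is undefined in C, while fvalog2_div() special cases
+ * this and returns 0.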
+ */ +static inline pt_vaddr_t fvalog2_div(pt_vaddr_t a, unsigned int b_lg2) +{ + if (PT_SUPPORTED_FEATURE(PT_FEAT_FULL_VA) && b_lg2 == PT_VADDR_MAX_LG2) + return 0; + return log2_div_t(pt_vaddr_t, a, b_lg2); +} + +static inline pt_vaddr_t fvalog2_mod(pt_vaddr_t a, unsigned int b_lg2) +{ + if (PT_SUPPORTED_FEATURE(PT_FEAT_FULL_VA) && b_lg2 == PT_VADDR_MAX_LG2) + return a; + return log2_mod_t(pt_vaddr_t, a, b_lg2); +} + +static inline bool fvalog2_div_eq(pt_vaddr_t a, pt_vaddr_t b, + unsigned int c_lg2) +{ + if (PT_SUPPORTED_FEATURE(PT_FEAT_FULL_VA) && c_lg2 == PT_VADDR_MAX_LG2) + return true; + return log2_div_eq_t(pt_vaddr_t, a, b, c_lg2); +} + +static inline pt_vaddr_t fvalog2_set_mod(pt_vaddr_t a, pt_vaddr_t val, + unsigned int b_lg2) +{ + if (PT_SUPPORTED_FEATURE(PT_FEAT_FULL_VA) && b_lg2 == PT_VADDR_MAX_LG2) + return val; + return log2_set_mod_t(pt_vaddr_t, a, val, b_lg2); +} + +static inline pt_vaddr_t fvalog2_set_mod_max(pt_vaddr_t a, unsigned int b_lg2) +{ + if (PT_SUPPORTED_FEATURE(PT_FEAT_FULL_VA) && b_lg2 == PT_VADDR_MAX_LG2) + return PT_VADDR_MAX; + return log2_set_mod_max_t(pt_vaddr_t, a, b_lg2); +} + +/* These all work on the OA type */ +#define oalog2_to_int(a_lg2) log2_to_int_t(pt_oaddr_t, a_lg2) +#define oalog2_to_max_int(a_lg2) log2_to_max_int_t(pt_oaddr_t, a_lg2) +#define oalog2_div(a, b_lg2) log2_div_t(pt_oaddr_t, a, b_lg2) +#define oalog2_div_eq(a, b, c_lg2) log2_div_eq_t(pt_oaddr_t, a, b, c_lg2) +#define oalog2_mod(a, b_lg2) log2_mod_t(pt_oaddr_t, a, b_lg2) +#define oalog2_mod_eq_max(a, b_lg2) log2_mod_eq_max_t(pt_oaddr_t, a, b_lg2) +#define oalog2_set_mod(a, val, b_lg2) log2_set_mod_t(pt_oaddr_t, a, val, b_lg2) +#define oalog2_set_mod_max(a, b_lg2) log2_set_mod_max_t(pt_oaddr_t, a, b_lg2) +#define oalog2_mul(a, b_lg2) log2_mul_t(pt_oaddr_t, a, b_lg2) +#define oaffs(a) ffs_t(pt_oaddr_t, a) +#define oafls(a) fls_t(pt_oaddr_t, a) +#define oaffz(a) ffz_t(pt_oaddr_t, a) + +static inline uintptr_t _pt_top_set(struct pt_table_p *table_mem, + unsigned int top_level) +{ + return top_level | (uintptr_t)table_mem; +} + +static inline void pt_top_set(struct pt_common *common, + struct pt_table_p *table_mem, + unsigned int top_level) +{ + WRITE_ONCE(common->top_of_table, _pt_top_set(table_mem, top_level)); +} + +static inline void pt_top_set_level(struct pt_common *common, + unsigned int top_level) +{ + pt_top_set(common, NULL, top_level); +} + +static inline unsigned int pt_top_get_level(const struct pt_common *common) +{ + return READ_ONCE(common->top_of_table) % (1 << PT_TOP_LEVEL_BITS); +} + +static inline bool pt_check_install_leaf_args(struct pt_state *pts, + pt_oaddr_t oa, + unsigned int oasz_lg2); + +#endif diff --git a/drivers/iommu/generic_pt/pt_fmt_defaults.h b/drivers/iommu/generic_pt/pt_fmt_defaults.h new file mode 100644 index 000000000000..69fb7c2314ca --- /dev/null +++ b/drivers/iommu/generic_pt/pt_fmt_defaults.h @@ -0,0 +1,295 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES + * + * Default definitions for formats that don't define these functions. + */ +#ifndef __GENERIC_PT_PT_FMT_DEFAULTS_H +#define __GENERIC_PT_PT_FMT_DEFAULTS_H + +#include "pt_defs.h" +#include <linux/log2.h> + +/* Header self-compile default defines */ +#ifndef pt_load_entry_raw +#include "fmt/amdv1.h" +#endif + +/* + * The format must provide PT_GRANULE_LG2SZ, PT_TABLEMEM_LG2SZ, and + * PT_ITEM_WORD_SIZE. They must be the same at every level excluding the top. 
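+ *
+ * For instance, with the common 4K granule, 4K table memory and 8 byte items
+ * each level decodes 9 bits of VA: level 0 items span 2^12 bytes, level 1
+ * items 2^21, level 2 items 2^30, and so on.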
+ */ +#ifndef pt_table_item_lg2sz +static inline unsigned int pt_table_item_lg2sz(const struct pt_state *pts) +{ + return PT_GRANULE_LG2SZ + + (PT_TABLEMEM_LG2SZ - ilog2(PT_ITEM_WORD_SIZE)) * pts->level; +} +#endif + +#ifndef pt_pgsz_lg2_to_level +static inline unsigned int pt_pgsz_lg2_to_level(struct pt_common *common, + unsigned int pgsize_lg2) +{ + return ((unsigned int)(pgsize_lg2 - PT_GRANULE_LG2SZ)) / + (PT_TABLEMEM_LG2SZ - ilog2(PT_ITEM_WORD_SIZE)); +} +#endif + +/* + * If not supplied by the format then contiguous pages are not supported. + * + * If contiguous pages are supported then the format must also provide + * pt_contig_count_lg2() if it supports a single contiguous size per level, + * or pt_possible_sizes() if it supports multiple sizes per level. + */ +#ifndef pt_entry_num_contig_lg2 +static inline unsigned int pt_entry_num_contig_lg2(const struct pt_state *pts) +{ + return ilog2(1); +} + +/* + * Return the number of contiguous OA items forming an entry at this table level + */ +static inline unsigned short pt_contig_count_lg2(const struct pt_state *pts) +{ + return ilog2(1); +} +#endif + +/* If not supplied by the format then dirty tracking is not supported */ +#ifndef pt_entry_is_write_dirty +static inline bool pt_entry_is_write_dirty(const struct pt_state *pts) +{ + return false; +} + +static inline void pt_entry_make_write_clean(struct pt_state *pts) +{ +} + +static inline bool pt_dirty_supported(struct pt_common *common) +{ + return false; +} +#else +/* If not supplied then dirty tracking is always enabled */ +#ifndef pt_dirty_supported +static inline bool pt_dirty_supported(struct pt_common *common) +{ + return true; +} +#endif +#endif + +#ifndef pt_entry_make_write_dirty +static inline bool pt_entry_make_write_dirty(struct pt_state *pts) +{ + return false; +} +#endif + +/* + * Format supplies either: + * pt_entry_oa - OA is at the start of a contiguous entry + * or + * pt_item_oa - OA is adjusted for every item in a contiguous entry + * + * Build the missing one + * + * The internal helper _pt_entry_oa_fast() allows generating + * an efficient pt_entry_oa_exact(), it doesn't care which + * option is selected. + */ +#ifdef pt_entry_oa +static inline pt_oaddr_t pt_item_oa(const struct pt_state *pts) +{ + return pt_entry_oa(pts) | + log2_mul(pts->index, pt_table_item_lg2sz(pts)); +} +#define _pt_entry_oa_fast pt_entry_oa +#endif + +#ifdef pt_item_oa +static inline pt_oaddr_t pt_entry_oa(const struct pt_state *pts) +{ + return log2_set_mod(pt_item_oa(pts), 0, + pt_entry_num_contig_lg2(pts) + + pt_table_item_lg2sz(pts)); +} +#define _pt_entry_oa_fast pt_item_oa +#endif + +/* + * If not supplied by the format then use the constant + * PT_MAX_OUTPUT_ADDRESS_LG2. 
+ */ +#ifndef pt_max_oa_lg2 +static inline unsigned int +pt_max_oa_lg2(const struct pt_common *common) +{ + return PT_MAX_OUTPUT_ADDRESS_LG2; +} +#endif + +#ifndef pt_has_system_page_size +static inline bool pt_has_system_page_size(const struct pt_common *common) +{ + return PT_GRANULE_LG2SZ == PAGE_SHIFT; +} +#endif + +/* + * If not supplied by the format then assume only one contiguous size determined + * by pt_contig_count_lg2() + */ +#ifndef pt_possible_sizes +static inline unsigned short pt_contig_count_lg2(const struct pt_state *pts); + +/* Return a bitmap of possible leaf page sizes at this level */ +static inline pt_vaddr_t pt_possible_sizes(const struct pt_state *pts) +{ + unsigned int isz_lg2 = pt_table_item_lg2sz(pts); + + if (!pt_can_have_leaf(pts)) + return 0; + return log2_to_int(isz_lg2) | + log2_to_int(pt_contig_count_lg2(pts) + isz_lg2); +} +#endif + +/* If not supplied by the format then use 0. */ +#ifndef pt_full_va_prefix +static inline pt_vaddr_t pt_full_va_prefix(const struct pt_common *common) +{ + return 0; +} +#endif + +/* If not supplied by the format then zero fill using PT_ITEM_WORD_SIZE */ +#ifndef pt_clear_entries +static inline void pt_clear_entries64(struct pt_state *pts, + unsigned int num_contig_lg2) +{ + u64 *tablep = pt_cur_table(pts, u64) + pts->index; + u64 *end = tablep + log2_to_int(num_contig_lg2); + + PT_WARN_ON(log2_mod(pts->index, num_contig_lg2)); + for (; tablep != end; tablep++) + WRITE_ONCE(*tablep, 0); +} + +static inline void pt_clear_entries32(struct pt_state *pts, + unsigned int num_contig_lg2) +{ + u32 *tablep = pt_cur_table(pts, u32) + pts->index; + u32 *end = tablep + log2_to_int(num_contig_lg2); + + PT_WARN_ON(log2_mod(pts->index, num_contig_lg2)); + for (; tablep != end; tablep++) + WRITE_ONCE(*tablep, 0); +} + +static inline void pt_clear_entries(struct pt_state *pts, + unsigned int num_contig_lg2) +{ + if (PT_ITEM_WORD_SIZE == sizeof(u32)) + pt_clear_entries32(pts, num_contig_lg2); + else + pt_clear_entries64(pts, num_contig_lg2); +} +#define pt_clear_entries pt_clear_entries +#endif + +/* If not supplied then SW bits are not supported */ +#ifdef pt_sw_bit +static inline bool pt_test_sw_bit_acquire(struct pt_state *pts, + unsigned int bitnr) +{ + /* Acquire, pairs with pt_set_sw_bit_release() */ + smp_mb(); + /* For a contiguous entry the sw bit is only stored in the first item. 
*/ + return pts->entry & pt_sw_bit(bitnr); +} +#define pt_test_sw_bit_acquire pt_test_sw_bit_acquire + +static inline void pt_set_sw_bit_release(struct pt_state *pts, + unsigned int bitnr) +{ +#if !IS_ENABLED(CONFIG_GENERIC_ATOMIC64) + if (PT_ITEM_WORD_SIZE == sizeof(u64)) { + u64 *entryp = pt_cur_table(pts, u64) + pts->index; + u64 old_entry = pts->entry; + u64 new_entry; + + do { + new_entry = old_entry | pt_sw_bit(bitnr); + } while (!try_cmpxchg64_release(entryp, &old_entry, new_entry)); + pts->entry = new_entry; + return; + } +#endif + if (PT_ITEM_WORD_SIZE == sizeof(u32)) { + u32 *entryp = pt_cur_table(pts, u32) + pts->index; + u32 old_entry = pts->entry; + u32 new_entry; + + do { + new_entry = old_entry | pt_sw_bit(bitnr); + } while (!try_cmpxchg_release(entryp, &old_entry, new_entry)); + pts->entry = new_entry; + } else + BUILD_BUG(); +} +#define pt_set_sw_bit_release pt_set_sw_bit_release +#else +static inline unsigned int pt_max_sw_bit(struct pt_common *common) +{ + return 0; +} + +extern void __pt_no_sw_bit(void); +static inline bool pt_test_sw_bit_acquire(struct pt_state *pts, + unsigned int bitnr) +{ + __pt_no_sw_bit(); + return false; +} + +static inline void pt_set_sw_bit_release(struct pt_state *pts, + unsigned int bitnr) +{ + __pt_no_sw_bit(); +} +#endif + +/* + * Format can call in the pt_install_leaf_entry() to check the arguments are all + * aligned correctly. + */ +static inline bool pt_check_install_leaf_args(struct pt_state *pts, + pt_oaddr_t oa, + unsigned int oasz_lg2) +{ + unsigned int isz_lg2 = pt_table_item_lg2sz(pts); + + if (PT_WARN_ON(oalog2_mod(oa, oasz_lg2))) + return false; + +#ifdef pt_possible_sizes + if (PT_WARN_ON(isz_lg2 > oasz_lg2 || + oasz_lg2 > isz_lg2 + pt_num_items_lg2(pts))) + return false; +#else + if (PT_WARN_ON(oasz_lg2 != isz_lg2 && + oasz_lg2 != isz_lg2 + pt_contig_count_lg2(pts))) + return false; +#endif + + if (PT_WARN_ON(oalog2_mod(pts->index, oasz_lg2 - isz_lg2))) + return false; + return true; +} + +#endif diff --git a/drivers/iommu/generic_pt/pt_iter.h b/drivers/iommu/generic_pt/pt_iter.h new file mode 100644 index 000000000000..c0d8617cce29 --- /dev/null +++ b/drivers/iommu/generic_pt/pt_iter.h @@ -0,0 +1,636 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES + * + * Iterators for Generic Page Table + */ +#ifndef __GENERIC_PT_PT_ITER_H +#define __GENERIC_PT_PT_ITER_H + +#include "pt_common.h" + +#include <linux/errno.h> + +/* + * Use to mangle symbols so that backtraces and the symbol table are + * understandable. Any non-inlined function should get mangled like this. + */ +#define NS(fn) CONCATENATE(PTPFX, fn) + +/** + * pt_check_range() - Validate the range can be iterated + * @range: Range to validate + * + * Check that VA and last_va fall within the permitted range of VAs. If the + * format is using PT_FEAT_SIGN_EXTEND then this also checks the sign extension + * is correct. + */ +static inline int pt_check_range(struct pt_range *range) +{ + pt_vaddr_t prefix; + + PT_WARN_ON(!range->max_vasz_lg2); + + if (pt_feature(range->common, PT_FEAT_SIGN_EXTEND)) { + PT_WARN_ON(range->common->max_vasz_lg2 != range->max_vasz_lg2); + prefix = fvalog2_div(range->va, range->max_vasz_lg2 - 1) ? 
+ PT_VADDR_MAX : + 0; + } else { + prefix = pt_full_va_prefix(range->common); + } + + if (!fvalog2_div_eq(range->va, prefix, range->max_vasz_lg2) || + !fvalog2_div_eq(range->last_va, prefix, range->max_vasz_lg2)) + return -ERANGE; + return 0; +} + +/** + * pt_index_to_va() - Update range->va to the current pts->index + * @pts: Iteration State + * + * Adjust range->va to match the current index. This is done in a lazy manner + * since computing the VA takes several instructions and is rarely required. + */ +static inline void pt_index_to_va(struct pt_state *pts) +{ + pt_vaddr_t lower_va; + + lower_va = log2_mul(pts->index, pt_table_item_lg2sz(pts)); + pts->range->va = fvalog2_set_mod(pts->range->va, lower_va, + pt_table_oa_lg2sz(pts)); +} + +/* + * Add index_count_lg2 number of entries to pts's VA and index. The VA will be + * adjusted to the end of the contiguous block if it is currently in the middle. + */ +static inline void _pt_advance(struct pt_state *pts, + unsigned int index_count_lg2) +{ + pts->index = log2_set_mod(pts->index + log2_to_int(index_count_lg2), 0, + index_count_lg2); +} + +/** + * pt_entry_fully_covered() - Check if the item or entry is entirely contained + * within pts->range + * @pts: Iteration State + * @oasz_lg2: The size of the item to check, pt_table_item_lg2sz() or + * pt_entry_oa_lg2sz() + * + * Returns: true if the item is fully enclosed by the pts->range. + */ +static inline bool pt_entry_fully_covered(const struct pt_state *pts, + unsigned int oasz_lg2) +{ + struct pt_range *range = pts->range; + + /* Range begins at the start of the entry */ + if (log2_mod(pts->range->va, oasz_lg2)) + return false; + + /* Range ends past the end of the entry */ + if (!log2_div_eq(range->va, range->last_va, oasz_lg2)) + return true; + + /* Range ends at the end of the entry */ + return log2_mod_eq_max(range->last_va, oasz_lg2); +} + +/** + * pt_range_to_index() - Starting index for an iteration + * @pts: Iteration State + * + * Return: the starting index for the iteration in pts. + */ +static inline unsigned int pt_range_to_index(const struct pt_state *pts) +{ + unsigned int isz_lg2 = pt_table_item_lg2sz(pts); + + PT_WARN_ON(pts->level > pts->range->top_level); + if (pts->range->top_level == pts->level) + return log2_div(fvalog2_mod(pts->range->va, + pts->range->max_vasz_lg2), + isz_lg2); + return log2_mod(log2_div(pts->range->va, isz_lg2), + pt_num_items_lg2(pts)); +} + +/** + * pt_range_to_end_index() - Ending index iteration + * @pts: Iteration State + * + * Return: the last index for the iteration in pts. 
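+ *
+ * The returned index is exclusive; iteration stops once pts->index reaches it.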
+ */ +static inline unsigned int pt_range_to_end_index(const struct pt_state *pts) +{ + unsigned int isz_lg2 = pt_table_item_lg2sz(pts); + struct pt_range *range = pts->range; + unsigned int num_entries_lg2; + + if (range->va == range->last_va) + return pts->index + 1; + + if (pts->range->top_level == pts->level) + return log2_div(fvalog2_mod(pts->range->last_va, + pts->range->max_vasz_lg2), + isz_lg2) + + 1; + + num_entries_lg2 = pt_num_items_lg2(pts); + + /* last_va falls within this table */ + if (log2_div_eq(range->va, range->last_va, num_entries_lg2 + isz_lg2)) + return log2_mod(log2_div(pts->range->last_va, isz_lg2), + num_entries_lg2) + + 1; + + return log2_to_int(num_entries_lg2); +} + +static inline void _pt_iter_first(struct pt_state *pts) +{ + pts->index = pt_range_to_index(pts); + pts->end_index = pt_range_to_end_index(pts); + PT_WARN_ON(pts->index > pts->end_index); +} + +static inline bool _pt_iter_load(struct pt_state *pts) +{ + if (pts->index >= pts->end_index) + return false; + pt_load_entry(pts); + return true; +} + +/** + * pt_next_entry() - Advance pts to the next entry + * @pts: Iteration State + * + * Update pts to go to the next index at this level. If pts is pointing at a + * contiguous entry then the index may advance my more than one. + */ +static inline void pt_next_entry(struct pt_state *pts) +{ + if (pts->type == PT_ENTRY_OA && + !__builtin_constant_p(pt_entry_num_contig_lg2(pts) == 0)) + _pt_advance(pts, pt_entry_num_contig_lg2(pts)); + else + pts->index++; + pt_index_to_va(pts); +} + +/** + * for_each_pt_level_entry() - For loop wrapper over entries in the range + * @pts: Iteration State + * + * This is the basic iteration primitive. It iterates over all the entries in + * pts->range that fall within the pts's current table level. Each step does + * pt_load_entry(pts). + */ +#define for_each_pt_level_entry(pts) \ + for (_pt_iter_first(pts); _pt_iter_load(pts); pt_next_entry(pts)) + +/** + * pt_load_single_entry() - Version of pt_load_entry() usable within a walker + * @pts: Iteration State + * + * Alternative to for_each_pt_level_entry() if the walker function uses only a + * single entry. + */ +static inline enum pt_entry_type pt_load_single_entry(struct pt_state *pts) +{ + pts->index = pt_range_to_index(pts); + pt_load_entry(pts); + return pts->type; +} + +static __always_inline struct pt_range _pt_top_range(struct pt_common *common, + uintptr_t top_of_table) +{ + struct pt_range range = { + .common = common, + .top_table = + (struct pt_table_p *)(top_of_table & + ~(uintptr_t)PT_TOP_LEVEL_MASK), + .top_level = top_of_table % (1 << PT_TOP_LEVEL_BITS), + }; + struct pt_state pts = { .range = &range, .level = range.top_level }; + unsigned int max_vasz_lg2; + + max_vasz_lg2 = common->max_vasz_lg2; + if (pt_feature(common, PT_FEAT_DYNAMIC_TOP) && + pts.level != PT_MAX_TOP_LEVEL) + max_vasz_lg2 = min_t(unsigned int, common->max_vasz_lg2, + pt_num_items_lg2(&pts) + + pt_table_item_lg2sz(&pts)); + + /* + * The top range will default to the lower region only with sign extend. + */ + range.max_vasz_lg2 = max_vasz_lg2; + if (pt_feature(common, PT_FEAT_SIGN_EXTEND)) + max_vasz_lg2--; + + range.va = fvalog2_set_mod(pt_full_va_prefix(common), 0, max_vasz_lg2); + range.last_va = + fvalog2_set_mod_max(pt_full_va_prefix(common), max_vasz_lg2); + return range; +} + +/** + * pt_top_range() - Return a range that spans part of the top level + * @common: Table + * + * For PT_FEAT_SIGN_EXTEND this will return the lower range, and cover half the + * total page table. 
Otherwise it returns the entire page table. + */ +static __always_inline struct pt_range pt_top_range(struct pt_common *common) +{ + /* + * The top pointer can change without locking. We capture the value and + * it's level here and are safe to walk it so long as both values are + * captured without tearing. + */ + return _pt_top_range(common, READ_ONCE(common->top_of_table)); +} + +/** + * pt_all_range() - Return a range that spans the entire page table + * @common: Table + * + * The returned range spans the whole page table. Due to how PT_FEAT_SIGN_EXTEND + * is supported range->va and range->last_va will be incorrect during the + * iteration and must not be accessed. + */ +static inline struct pt_range pt_all_range(struct pt_common *common) +{ + struct pt_range range = pt_top_range(common); + + if (!pt_feature(common, PT_FEAT_SIGN_EXTEND)) + return range; + + /* + * Pretend the table is linear from 0 without a sign extension. This + * generates the correct indexes for iteration. + */ + range.last_va = fvalog2_set_mod_max(0, range.max_vasz_lg2); + return range; +} + +/** + * pt_upper_range() - Return a range that spans part of the top level + * @common: Table + * + * For PT_FEAT_SIGN_EXTEND this will return the upper range, and cover half the + * total page table. Otherwise it returns the entire page table. + */ +static inline struct pt_range pt_upper_range(struct pt_common *common) +{ + struct pt_range range = pt_top_range(common); + + if (!pt_feature(common, PT_FEAT_SIGN_EXTEND)) + return range; + + range.va = fvalog2_set_mod(PT_VADDR_MAX, 0, range.max_vasz_lg2 - 1); + range.last_va = PT_VADDR_MAX; + return range; +} + +/** + * pt_make_range() - Return a range that spans part of the table + * @common: Table + * @va: Start address + * @last_va: Last address + * + * The caller must validate the range with pt_check_range() before using it. + */ +static __always_inline struct pt_range +pt_make_range(struct pt_common *common, pt_vaddr_t va, pt_vaddr_t last_va) +{ + struct pt_range range = + _pt_top_range(common, READ_ONCE(common->top_of_table)); + + range.va = va; + range.last_va = last_va; + + return range; +} + +/* + * Span a slice of the table starting at a lower table level from an active + * walk. + */ +static __always_inline struct pt_range +pt_make_child_range(const struct pt_range *parent, pt_vaddr_t va, + pt_vaddr_t last_va) +{ + struct pt_range range = *parent; + + range.va = va; + range.last_va = last_va; + + PT_WARN_ON(last_va < va); + PT_WARN_ON(pt_check_range(&range)); + + return range; +} + +/** + * pt_init() - Initialize a pt_state on the stack + * @range: Range pointer to embed in the state + * @level: Table level for the state + * @table: Pointer to the table memory at level + * + * Helper to initialize the on-stack pt_state from walker arguments. + */ +static __always_inline struct pt_state +pt_init(struct pt_range *range, unsigned int level, struct pt_table_p *table) +{ + struct pt_state pts = { + .range = range, + .table = table, + .level = level, + }; + return pts; +} + +/** + * pt_init_top() - Initialize a pt_state on the stack + * @range: Range pointer to embed in the state + * + * The pt_state points to the top most level. 
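+ *
+ * A typical walk starts from here (illustrative sketch)::
+ *
+ *	struct pt_range range = pt_top_range(common);
+ *	struct pt_state pts = pt_init_top(&range);
+ *
+ *	for_each_pt_level_entry(&pts)
+ *		... inspect pts.type and pts.entry ...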
+ */ +static __always_inline struct pt_state pt_init_top(struct pt_range *range) +{ + return pt_init(range, range->top_level, range->top_table); +} + +typedef int (*pt_level_fn_t)(struct pt_range *range, void *arg, + unsigned int level, struct pt_table_p *table); + +/** + * pt_descend() - Recursively invoke the walker for the lower level + * @pts: Iteration State + * @arg: Value to pass to the function + * @fn: Walker function to call + * + * pts must point to a table item. Invoke fn as a walker on the table + * pts points to. + */ +static __always_inline int pt_descend(struct pt_state *pts, void *arg, + pt_level_fn_t fn) +{ + int ret; + + if (PT_WARN_ON(!pts->table_lower)) + return -EINVAL; + + ret = (*fn)(pts->range, arg, pts->level - 1, pts->table_lower); + return ret; +} + +/** + * pt_walk_range() - Walk over a VA range + * @range: Range pointer + * @fn: Walker function to call + * @arg: Value to pass to the function + * + * Walk over a VA range. The caller should have done a validity check, at + * least calling pt_check_range(), when building range. The walk will + * start at the top most table. + */ +static __always_inline int pt_walk_range(struct pt_range *range, + pt_level_fn_t fn, void *arg) +{ + return fn(range, arg, range->top_level, range->top_table); +} + +/* + * pt_walk_descend() - Recursively invoke the walker for a slice of a lower + * level + * @pts: Iteration State + * @va: Start address + * @last_va: Last address + * @fn: Walker function to call + * @arg: Value to pass to the function + * + * With pts pointing at a table item this will descend and over a slice of the + * lower table. The caller must ensure that va/last_va are within the table + * item. This creates a new walk and does not alter pts or pts->range. + */ +static __always_inline int pt_walk_descend(const struct pt_state *pts, + pt_vaddr_t va, pt_vaddr_t last_va, + pt_level_fn_t fn, void *arg) +{ + struct pt_range range = pt_make_child_range(pts->range, va, last_va); + + if (PT_WARN_ON(!pt_can_have_table(pts)) || + PT_WARN_ON(!pts->table_lower)) + return -EINVAL; + + return fn(&range, arg, pts->level - 1, pts->table_lower); +} + +/* + * pt_walk_descend_all() - Recursively invoke the walker for a table item + * @parent_pts: Iteration State + * @fn: Walker function to call + * @arg: Value to pass to the function + * + * With pts pointing at a table item this will descend and over the entire lower + * table. This creates a new walk and does not alter pts or pts->range. + */ +static __always_inline int +pt_walk_descend_all(const struct pt_state *parent_pts, pt_level_fn_t fn, + void *arg) +{ + unsigned int isz_lg2 = pt_table_item_lg2sz(parent_pts); + + return pt_walk_descend(parent_pts, + log2_set_mod(parent_pts->range->va, 0, isz_lg2), + log2_set_mod_max(parent_pts->range->va, isz_lg2), + fn, arg); +} + +/** + * pt_range_slice() - Return a range that spans indexes + * @pts: Iteration State + * @start_index: Starting index within pts + * @end_index: Ending index within pts + * + * Create a range than spans an index range of the current table level + * pt_state points at. 
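+ *
+ * The end_index is exclusive, matching pt_range_to_end_index().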
+ */ +static inline struct pt_range pt_range_slice(const struct pt_state *pts, + unsigned int start_index, + unsigned int end_index) +{ + unsigned int table_lg2sz = pt_table_oa_lg2sz(pts); + pt_vaddr_t last_va; + pt_vaddr_t va; + + va = fvalog2_set_mod(pts->range->va, + log2_mul(start_index, pt_table_item_lg2sz(pts)), + table_lg2sz); + last_va = fvalog2_set_mod( + pts->range->va, + log2_mul(end_index, pt_table_item_lg2sz(pts)) - 1, table_lg2sz); + return pt_make_child_range(pts->range, va, last_va); +} + +/** + * pt_top_memsize_lg2() + * @common: Table + * @top_of_table: Top of table value from _pt_top_set() + * + * Compute the allocation size of the top table. For PT_FEAT_DYNAMIC_TOP this + * will compute the top size assuming the table will grow. + */ +static inline unsigned int pt_top_memsize_lg2(struct pt_common *common, + uintptr_t top_of_table) +{ + struct pt_range range = _pt_top_range(common, top_of_table); + struct pt_state pts = pt_init_top(&range); + unsigned int num_items_lg2; + + num_items_lg2 = common->max_vasz_lg2 - pt_table_item_lg2sz(&pts); + if (range.top_level != PT_MAX_TOP_LEVEL && + pt_feature(common, PT_FEAT_DYNAMIC_TOP)) + num_items_lg2 = min(num_items_lg2, pt_num_items_lg2(&pts)); + + /* Round up the allocation size to the minimum alignment */ + return max(ffs_t(u64, PT_TOP_PHYS_MASK), + num_items_lg2 + ilog2(PT_ITEM_WORD_SIZE)); +} + +/** + * pt_compute_best_pgsize() - Determine the best page size for leaf entries + * @pgsz_bitmap: Permitted page sizes + * @va: Starting virtual address for the leaf entry + * @last_va: Last virtual address for the leaf entry, sets the max page size + * @oa: Starting output address for the leaf entry + * + * Compute the largest page size for va, last_va, and oa together and return it + * in lg2. The largest page size depends on the format's supported page sizes at + * this level, and the relative alignment of the VA and OA addresses. 0 means + * the OA cannot be stored with the provided pgsz_bitmap. + */ +static inline unsigned int pt_compute_best_pgsize(pt_vaddr_t pgsz_bitmap, + pt_vaddr_t va, + pt_vaddr_t last_va, + pt_oaddr_t oa) +{ + unsigned int best_pgsz_lg2; + unsigned int pgsz_lg2; + pt_vaddr_t len = last_va - va + 1; + pt_vaddr_t mask; + + if (PT_WARN_ON(va >= last_va)) + return 0; + + /* + * Given a VA/OA pair the best page size is the largest page size + * where: + * + * 1) VA and OA start at the page. Bitwise this is the count of least + * significant 0 bits. + * This also implies that last_va/oa has the same prefix as va/oa. + */ + mask = va | oa; + + /* + * 2) The page size is not larger than the last_va (length). Since page + * sizes are always power of two this can't be larger than the + * largest power of two factor of the length. 
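+ *    In other words the limit is the highest power of two that does not
+ *    exceed the length, which is what the vafls() term below ORs into the
+ *    mask.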
+ */ + mask |= log2_to_int(vafls(len) - 1); + + best_pgsz_lg2 = vaffs(mask); + + /* Choose the highest bit <= best_pgsz_lg2 */ + if (best_pgsz_lg2 < PT_VADDR_MAX_LG2 - 1) + pgsz_bitmap = log2_mod(pgsz_bitmap, best_pgsz_lg2 + 1); + + pgsz_lg2 = vafls(pgsz_bitmap); + if (!pgsz_lg2) + return 0; + + pgsz_lg2--; + + PT_WARN_ON(log2_mod(va, pgsz_lg2) != 0); + PT_WARN_ON(oalog2_mod(oa, pgsz_lg2) != 0); + PT_WARN_ON(va + log2_to_int(pgsz_lg2) - 1 > last_va); + PT_WARN_ON(!log2_div_eq(va, va + log2_to_int(pgsz_lg2) - 1, pgsz_lg2)); + PT_WARN_ON( + !oalog2_div_eq(oa, oa + log2_to_int(pgsz_lg2) - 1, pgsz_lg2)); + return pgsz_lg2; +} + +#define _PT_MAKE_CALL_LEVEL(fn) \ + static __always_inline int fn(struct pt_range *range, void *arg, \ + unsigned int level, \ + struct pt_table_p *table) \ + { \ + static_assert(PT_MAX_TOP_LEVEL <= 5); \ + if (level == 0) \ + return CONCATENATE(fn, 0)(range, arg, 0, table); \ + if (level == 1 || PT_MAX_TOP_LEVEL == 1) \ + return CONCATENATE(fn, 1)(range, arg, 1, table); \ + if (level == 2 || PT_MAX_TOP_LEVEL == 2) \ + return CONCATENATE(fn, 2)(range, arg, 2, table); \ + if (level == 3 || PT_MAX_TOP_LEVEL == 3) \ + return CONCATENATE(fn, 3)(range, arg, 3, table); \ + if (level == 4 || PT_MAX_TOP_LEVEL == 4) \ + return CONCATENATE(fn, 4)(range, arg, 4, table); \ + return CONCATENATE(fn, 5)(range, arg, 5, table); \ + } + +static inline int __pt_make_level_fn_err(struct pt_range *range, void *arg, + unsigned int unused_level, + struct pt_table_p *table) +{ + static_assert(PT_MAX_TOP_LEVEL <= 5); + return -EPROTOTYPE; +} + +#define __PT_MAKE_LEVEL_FN(fn, level, descend_fn, do_fn) \ + static inline int fn(struct pt_range *range, void *arg, \ + unsigned int unused_level, \ + struct pt_table_p *table) \ + { \ + return do_fn(range, arg, level, table, descend_fn); \ + } + +/** + * PT_MAKE_LEVELS() - Build an unwound walker + * @fn: Name of the walker function + * @do_fn: Function to call at each level + * + * This builds a function call tree that can be fully inlined. + * The caller must provide a function body in an __always_inline function:: + * + * static __always_inline int do_fn(struct pt_range *range, void *arg, + * unsigned int level, struct pt_table_p *table, + * pt_level_fn_t descend_fn) + * + * An inline function will be created for each table level that calls do_fn with + * a compile time constant for level and a pointer to the next lower function. + * This generates an optimally inlined walk where each of the functions sees a + * constant level and can codegen the exact constants/etc for that level. + * + * Note this can produce a lot of code! 
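+ *
+ * Usage sketch (the walker names here are hypothetical)::
+ *
+ *	PT_MAKE_LEVELS(count_entries, __count_entries);
+ *
+ * The generated count_entries() has the pt_level_fn_t signature and can be
+ * passed to pt_walk_range().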
+ */ +#define PT_MAKE_LEVELS(fn, do_fn) \ + __PT_MAKE_LEVEL_FN(CONCATENATE(fn, 0), 0, __pt_make_level_fn_err, \ + do_fn); \ + __PT_MAKE_LEVEL_FN(CONCATENATE(fn, 1), 1, CONCATENATE(fn, 0), do_fn); \ + __PT_MAKE_LEVEL_FN(CONCATENATE(fn, 2), 2, CONCATENATE(fn, 1), do_fn); \ + __PT_MAKE_LEVEL_FN(CONCATENATE(fn, 3), 3, CONCATENATE(fn, 2), do_fn); \ + __PT_MAKE_LEVEL_FN(CONCATENATE(fn, 4), 4, CONCATENATE(fn, 3), do_fn); \ + __PT_MAKE_LEVEL_FN(CONCATENATE(fn, 5), 5, CONCATENATE(fn, 4), do_fn); \ + _PT_MAKE_CALL_LEVEL(fn) + +#endif diff --git a/drivers/iommu/generic_pt/pt_log2.h b/drivers/iommu/generic_pt/pt_log2.h new file mode 100644 index 000000000000..6dbbed119238 --- /dev/null +++ b/drivers/iommu/generic_pt/pt_log2.h @@ -0,0 +1,122 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES + * + * Helper macros for working with log2 values + * + */ +#ifndef __GENERIC_PT_LOG2_H +#define __GENERIC_PT_LOG2_H +#include <linux/bitops.h> +#include <linux/limits.h> + +/* Compute a */ +#define log2_to_int_t(type, a_lg2) ((type)(((type)1) << (a_lg2))) +static_assert(log2_to_int_t(unsigned int, 0) == 1); + +/* Compute a - 1 (aka all low bits set) */ +#define log2_to_max_int_t(type, a_lg2) ((type)(log2_to_int_t(type, a_lg2) - 1)) + +/* Compute a / b */ +#define log2_div_t(type, a, b_lg2) ((type)(((type)a) >> (b_lg2))) +static_assert(log2_div_t(unsigned int, 4, 2) == 1); + +/* + * Compute: + * a / c == b / c + * aka the high bits are equal + */ +#define log2_div_eq_t(type, a, b, c_lg2) \ + (log2_div_t(type, (a) ^ (b), c_lg2) == 0) +static_assert(log2_div_eq_t(unsigned int, 1, 1, 2)); + +/* Compute a % b */ +#define log2_mod_t(type, a, b_lg2) \ + ((type)(((type)a) & log2_to_max_int_t(type, b_lg2))) +static_assert(log2_mod_t(unsigned int, 1, 2) == 1); + +/* + * Compute: + * a % b == b - 1 + * aka the low bits are all 1s + */ +#define log2_mod_eq_max_t(type, a, b_lg2) \ + (log2_mod_t(type, a, b_lg2) == log2_to_max_int_t(type, b_lg2)) +static_assert(log2_mod_eq_max_t(unsigned int, 3, 2)); + +/* + * Return a value such that: + * a / b == ret / b + * ret % b == val + * aka set the low bits to val. val must be < b + */ +#define log2_set_mod_t(type, a, val, b_lg2) \ + ((((type)(a)) & (~log2_to_max_int_t(type, b_lg2))) | ((type)(val))) +static_assert(log2_set_mod_t(unsigned int, 3, 1, 2) == 1); + +/* Return a value such that: + * a / b == ret / b + * ret % b == b - 1 + * aka set the low bits to all 1s + */ +#define log2_set_mod_max_t(type, a, b_lg2) \ + (((type)(a)) | log2_to_max_int_t(type, b_lg2)) +static_assert(log2_set_mod_max_t(unsigned int, 2, 2) == 3); + +/* Compute a * b */ +#define log2_mul_t(type, a, b_lg2) ((type)(((type)a) << (b_lg2))) +static_assert(log2_mul_t(unsigned int, 2, 2) == 8); + +#define _dispatch_sz(type, fn, a) \ + (sizeof(type) == 4 ? 
fn##32((u32)a) : fn##64(a)) + +/* + * Return the highest value such that: + * fls_t(u32, 0) == 0 + * fls_t(u3, 1) == 1 + * a >= log2_to_int(ret - 1) + * aka find last set bit + */ +static inline unsigned int fls32(u32 a) +{ + return fls(a); +} +#define fls_t(type, a) _dispatch_sz(type, fls, a) + +/* + * Return the highest value such that: + * ffs_t(u32, 0) == UNDEFINED + * ffs_t(u32, 1) == 0 + * log_mod(a, ret) == 0 + * aka find first set bit + */ +static inline unsigned int __ffs32(u32 a) +{ + return __ffs(a); +} +#define ffs_t(type, a) _dispatch_sz(type, __ffs, a) + +/* + * Return the highest value such that: + * ffz_t(u32, U32_MAX) == UNDEFINED + * ffz_t(u32, 0) == 0 + * ffz_t(u32, 1) == 1 + * log_mod(a, ret) == log_to_max_int(ret) + * aka find first zero bit + */ +static inline unsigned int ffz32(u32 a) +{ + return ffz(a); +} +static inline unsigned int ffz64(u64 a) +{ + if (sizeof(u64) == sizeof(unsigned long)) + return ffz(a); + + if ((u32)a == U32_MAX) + return ffz32(a >> 32) + 32; + return ffz32(a); +} +#define ffz_t(type, a) _dispatch_sz(type, ffz, a) + +#endif |
