diff options
Diffstat (limited to 'arch/riscv/net/bpf_jit_comp64.c')
-rw-r--r-- | arch/riscv/net/bpf_jit_comp64.c | 573 |
1 files changed, 200 insertions, 373 deletions
diff --git a/arch/riscv/net/bpf_jit_comp64.c b/arch/riscv/net/bpf_jit_comp64.c index 9883a55d61b5..45cbc7c6fe49 100644 --- a/arch/riscv/net/bpf_jit_comp64.c +++ b/arch/riscv/net/bpf_jit_comp64.c @@ -18,7 +18,7 @@ #define RV_MAX_REG_ARGS 8 #define RV_FENTRY_NINSNS 2 #define RV_FENTRY_NBYTES (RV_FENTRY_NINSNS * 4) -#define RV_KCFI_NINSNS (IS_ENABLED(CONFIG_CFI_CLANG) ? 1 : 0) +#define RV_KCFI_NINSNS (IS_ENABLED(CONFIG_CFI) ? 1 : 0) /* imm that allows emit_imm to emit max count insns */ #define RV_MAX_COUNT_IMM 0x7FFF7FF7FF7FF7FF @@ -469,142 +469,96 @@ static int emit_call(u64 addr, bool fixed_addr, struct rv_jit_context *ctx) static inline void emit_kcfi(u32 hash, struct rv_jit_context *ctx) { - if (IS_ENABLED(CONFIG_CFI_CLANG)) + if (IS_ENABLED(CONFIG_CFI)) emit(hash, ctx); } -static int emit_load_8(bool sign_ext, u8 rd, s32 off, u8 rs, struct rv_jit_context *ctx) +static void emit_ldx_insn(u8 rd, s16 off, u8 rs, u8 size, bool sign_ext, + struct rv_jit_context *ctx) { - int insns_start; - - if (is_12b_int(off)) { - insns_start = ctx->ninsns; - if (sign_ext) - emit(rv_lb(rd, off, rs), ctx); - else - emit(rv_lbu(rd, off, rs), ctx); - return ctx->ninsns - insns_start; - } - - emit_imm(RV_REG_T1, off, ctx); - emit_add(RV_REG_T1, RV_REG_T1, rs, ctx); - insns_start = ctx->ninsns; - if (sign_ext) - emit(rv_lb(rd, 0, RV_REG_T1), ctx); - else - emit(rv_lbu(rd, 0, RV_REG_T1), ctx); - return ctx->ninsns - insns_start; -} - -static int emit_load_16(bool sign_ext, u8 rd, s32 off, u8 rs, struct rv_jit_context *ctx) -{ - int insns_start; - - if (is_12b_int(off)) { - insns_start = ctx->ninsns; - if (sign_ext) - emit(rv_lh(rd, off, rs), ctx); - else - emit(rv_lhu(rd, off, rs), ctx); - return ctx->ninsns - insns_start; - } - - emit_imm(RV_REG_T1, off, ctx); - emit_add(RV_REG_T1, RV_REG_T1, rs, ctx); - insns_start = ctx->ninsns; - if (sign_ext) - emit(rv_lh(rd, 0, RV_REG_T1), ctx); - else - emit(rv_lhu(rd, 0, RV_REG_T1), ctx); - return ctx->ninsns - insns_start; -} - -static int emit_load_32(bool sign_ext, u8 rd, s32 off, u8 rs, struct rv_jit_context *ctx) -{ - int insns_start; - - if (is_12b_int(off)) { - insns_start = ctx->ninsns; - if (sign_ext) - emit(rv_lw(rd, off, rs), ctx); - else - emit(rv_lwu(rd, off, rs), ctx); - return ctx->ninsns - insns_start; - } - - emit_imm(RV_REG_T1, off, ctx); - emit_add(RV_REG_T1, RV_REG_T1, rs, ctx); - insns_start = ctx->ninsns; - if (sign_ext) - emit(rv_lw(rd, 0, RV_REG_T1), ctx); - else - emit(rv_lwu(rd, 0, RV_REG_T1), ctx); - return ctx->ninsns - insns_start; -} - -static int emit_load_64(bool sign_ext, u8 rd, s32 off, u8 rs, struct rv_jit_context *ctx) -{ - int insns_start; - - if (is_12b_int(off)) { - insns_start = ctx->ninsns; + switch (size) { + case BPF_B: + emit(sign_ext ? rv_lb(rd, off, rs) : rv_lbu(rd, off, rs), ctx); + break; + case BPF_H: + emit(sign_ext ? rv_lh(rd, off, rs) : rv_lhu(rd, off, rs), ctx); + break; + case BPF_W: + emit(sign_ext ? rv_lw(rd, off, rs) : rv_lwu(rd, off, rs), ctx); + break; + case BPF_DW: emit_ld(rd, off, rs, ctx); - return ctx->ninsns - insns_start; + break; } - emit_imm(RV_REG_T1, off, ctx); - emit_add(RV_REG_T1, RV_REG_T1, rs, ctx); - insns_start = ctx->ninsns; - emit_ld(rd, 0, RV_REG_T1, ctx); - return ctx->ninsns - insns_start; } -static void emit_store_8(u8 rd, s32 off, u8 rs, struct rv_jit_context *ctx) +static void emit_stx_insn(u8 rd, s16 off, u8 rs, u8 size, struct rv_jit_context *ctx) { - if (is_12b_int(off)) { + switch (size) { + case BPF_B: emit(rv_sb(rd, off, rs), ctx); - return; + break; + case BPF_H: + emit(rv_sh(rd, off, rs), ctx); + break; + case BPF_W: + emit_sw(rd, off, rs, ctx); + break; + case BPF_DW: + emit_sd(rd, off, rs, ctx); + break; } - - emit_imm(RV_REG_T1, off, ctx); - emit_add(RV_REG_T1, RV_REG_T1, rd, ctx); - emit(rv_sb(RV_REG_T1, 0, rs), ctx); } -static void emit_store_16(u8 rd, s32 off, u8 rs, struct rv_jit_context *ctx) +static void emit_ldx(u8 rd, s16 off, u8 rs, u8 size, bool sign_ext, + struct rv_jit_context *ctx) { if (is_12b_int(off)) { - emit(rv_sh(rd, off, rs), ctx); + ctx->ex_insn_off = ctx->ninsns; + emit_ldx_insn(rd, off, rs, size, sign_ext, ctx); + ctx->ex_jmp_off = ctx->ninsns; return; } emit_imm(RV_REG_T1, off, ctx); - emit_add(RV_REG_T1, RV_REG_T1, rd, ctx); - emit(rv_sh(RV_REG_T1, 0, rs), ctx); + emit_add(RV_REG_T1, RV_REG_T1, rs, ctx); + ctx->ex_insn_off = ctx->ninsns; + emit_ldx_insn(rd, 0, RV_REG_T1, size, sign_ext, ctx); + ctx->ex_jmp_off = ctx->ninsns; } -static void emit_store_32(u8 rd, s32 off, u8 rs, struct rv_jit_context *ctx) +static void emit_st(u8 rd, s16 off, s32 imm, u8 size, struct rv_jit_context *ctx) { + emit_imm(RV_REG_T1, imm, ctx); if (is_12b_int(off)) { - emit_sw(rd, off, rs, ctx); + ctx->ex_insn_off = ctx->ninsns; + emit_stx_insn(rd, off, RV_REG_T1, size, ctx); + ctx->ex_jmp_off = ctx->ninsns; return; } - emit_imm(RV_REG_T1, off, ctx); - emit_add(RV_REG_T1, RV_REG_T1, rd, ctx); - emit_sw(RV_REG_T1, 0, rs, ctx); + emit_imm(RV_REG_T2, off, ctx); + emit_add(RV_REG_T2, RV_REG_T2, rd, ctx); + ctx->ex_insn_off = ctx->ninsns; + emit_stx_insn(RV_REG_T2, 0, RV_REG_T1, size, ctx); + ctx->ex_jmp_off = ctx->ninsns; } -static void emit_store_64(u8 rd, s32 off, u8 rs, struct rv_jit_context *ctx) +static void emit_stx(u8 rd, s16 off, u8 rs, u8 size, struct rv_jit_context *ctx) { if (is_12b_int(off)) { - emit_sd(rd, off, rs, ctx); + ctx->ex_insn_off = ctx->ninsns; + emit_stx_insn(rd, off, rs, size, ctx); + ctx->ex_jmp_off = ctx->ninsns; return; } emit_imm(RV_REG_T1, off, ctx); emit_add(RV_REG_T1, RV_REG_T1, rd, ctx); - emit_sd(RV_REG_T1, 0, rs, ctx); + ctx->ex_insn_off = ctx->ninsns; + emit_stx_insn(RV_REG_T1, 0, rs, size, ctx); + ctx->ex_jmp_off = ctx->ninsns; } static int emit_atomic_ld_st(u8 rd, u8 rs, const struct bpf_insn *insn, @@ -617,20 +571,12 @@ static int emit_atomic_ld_st(u8 rd, u8 rs, const struct bpf_insn *insn, switch (imm) { /* dst_reg = load_acquire(src_reg + off16) */ case BPF_LOAD_ACQ: - switch (BPF_SIZE(code)) { - case BPF_B: - emit_load_8(false, rd, off, rs, ctx); - break; - case BPF_H: - emit_load_16(false, rd, off, rs, ctx); - break; - case BPF_W: - emit_load_32(false, rd, off, rs, ctx); - break; - case BPF_DW: - emit_load_64(false, rd, off, rs, ctx); - break; + if (BPF_MODE(code) == BPF_PROBE_ATOMIC) { + emit_add(RV_REG_T2, rs, RV_REG_ARENA, ctx); + rs = RV_REG_T2; } + + emit_ldx(rd, off, rs, BPF_SIZE(code), false, ctx); emit_fence_r_rw(ctx); /* If our next insn is a redundant zext, return 1 to tell @@ -641,21 +587,13 @@ static int emit_atomic_ld_st(u8 rd, u8 rs, const struct bpf_insn *insn, break; /* store_release(dst_reg + off16, src_reg) */ case BPF_STORE_REL: - emit_fence_rw_w(ctx); - switch (BPF_SIZE(code)) { - case BPF_B: - emit_store_8(rd, off, rs, ctx); - break; - case BPF_H: - emit_store_16(rd, off, rs, ctx); - break; - case BPF_W: - emit_store_32(rd, off, rs, ctx); - break; - case BPF_DW: - emit_store_64(rd, off, rs, ctx); - break; + if (BPF_MODE(code) == BPF_PROBE_ATOMIC) { + emit_add(RV_REG_T2, rd, RV_REG_ARENA, ctx); + rd = RV_REG_T2; } + + emit_fence_rw_w(ctx); + emit_stx(rd, off, rs, BPF_SIZE(code), ctx); break; default: pr_err_once("bpf-jit: invalid atomic load/store opcode %02x\n", imm); @@ -668,17 +606,15 @@ static int emit_atomic_ld_st(u8 rd, u8 rs, const struct bpf_insn *insn, static int emit_atomic_rmw(u8 rd, u8 rs, const struct bpf_insn *insn, struct rv_jit_context *ctx) { - u8 r0, code = insn->code; + u8 code = insn->code; s16 off = insn->off; s32 imm = insn->imm; - int jmp_offset; - bool is64; + bool is64 = BPF_SIZE(code) == BPF_DW; if (BPF_SIZE(code) != BPF_W && BPF_SIZE(code) != BPF_DW) { pr_err_once("bpf-jit: 1- and 2-byte RMW atomics are not supported\n"); return -EINVAL; } - is64 = BPF_SIZE(code) == BPF_DW; if (off) { if (is_12b_int(off)) { @@ -690,72 +626,82 @@ static int emit_atomic_rmw(u8 rd, u8 rs, const struct bpf_insn *insn, rd = RV_REG_T1; } + if (BPF_MODE(code) == BPF_PROBE_ATOMIC) { + emit_add(RV_REG_T1, rd, RV_REG_ARENA, ctx); + rd = RV_REG_T1; + } + switch (imm) { /* lock *(u32/u64 *)(dst_reg + off16) <op>= src_reg */ case BPF_ADD: + ctx->ex_insn_off = ctx->ninsns; emit(is64 ? rv_amoadd_d(RV_REG_ZERO, rs, rd, 0, 0) : rv_amoadd_w(RV_REG_ZERO, rs, rd, 0, 0), ctx); + ctx->ex_jmp_off = ctx->ninsns; break; case BPF_AND: + ctx->ex_insn_off = ctx->ninsns; emit(is64 ? rv_amoand_d(RV_REG_ZERO, rs, rd, 0, 0) : rv_amoand_w(RV_REG_ZERO, rs, rd, 0, 0), ctx); + ctx->ex_jmp_off = ctx->ninsns; break; case BPF_OR: + ctx->ex_insn_off = ctx->ninsns; emit(is64 ? rv_amoor_d(RV_REG_ZERO, rs, rd, 0, 0) : rv_amoor_w(RV_REG_ZERO, rs, rd, 0, 0), ctx); + ctx->ex_jmp_off = ctx->ninsns; break; case BPF_XOR: + ctx->ex_insn_off = ctx->ninsns; emit(is64 ? rv_amoxor_d(RV_REG_ZERO, rs, rd, 0, 0) : rv_amoxor_w(RV_REG_ZERO, rs, rd, 0, 0), ctx); + ctx->ex_jmp_off = ctx->ninsns; break; /* src_reg = atomic_fetch_<op>(dst_reg + off16, src_reg) */ case BPF_ADD | BPF_FETCH: + ctx->ex_insn_off = ctx->ninsns; emit(is64 ? rv_amoadd_d(rs, rs, rd, 1, 1) : rv_amoadd_w(rs, rs, rd, 1, 1), ctx); + ctx->ex_jmp_off = ctx->ninsns; if (!is64) emit_zextw(rs, rs, ctx); break; case BPF_AND | BPF_FETCH: + ctx->ex_insn_off = ctx->ninsns; emit(is64 ? rv_amoand_d(rs, rs, rd, 1, 1) : rv_amoand_w(rs, rs, rd, 1, 1), ctx); + ctx->ex_jmp_off = ctx->ninsns; if (!is64) emit_zextw(rs, rs, ctx); break; case BPF_OR | BPF_FETCH: + ctx->ex_insn_off = ctx->ninsns; emit(is64 ? rv_amoor_d(rs, rs, rd, 1, 1) : rv_amoor_w(rs, rs, rd, 1, 1), ctx); + ctx->ex_jmp_off = ctx->ninsns; if (!is64) emit_zextw(rs, rs, ctx); break; case BPF_XOR | BPF_FETCH: + ctx->ex_insn_off = ctx->ninsns; emit(is64 ? rv_amoxor_d(rs, rs, rd, 1, 1) : rv_amoxor_w(rs, rs, rd, 1, 1), ctx); + ctx->ex_jmp_off = ctx->ninsns; if (!is64) emit_zextw(rs, rs, ctx); break; /* src_reg = atomic_xchg(dst_reg + off16, src_reg); */ case BPF_XCHG: + ctx->ex_insn_off = ctx->ninsns; emit(is64 ? rv_amoswap_d(rs, rs, rd, 1, 1) : rv_amoswap_w(rs, rs, rd, 1, 1), ctx); + ctx->ex_jmp_off = ctx->ninsns; if (!is64) emit_zextw(rs, rs, ctx); break; /* r0 = atomic_cmpxchg(dst_reg + off16, r0, src_reg); */ case BPF_CMPXCHG: - r0 = bpf_to_rv_reg(BPF_REG_0, ctx); - if (is64) - emit_mv(RV_REG_T2, r0, ctx); - else - emit_addiw(RV_REG_T2, r0, 0, ctx); - emit(is64 ? rv_lr_d(r0, 0, rd, 0, 0) : - rv_lr_w(r0, 0, rd, 0, 0), ctx); - jmp_offset = ninsns_rvoff(8); - emit(rv_bne(RV_REG_T2, r0, jmp_offset >> 1), ctx); - emit(is64 ? rv_sc_d(RV_REG_T3, rs, rd, 0, 1) : - rv_sc_w(RV_REG_T3, rs, rd, 0, 1), ctx); - jmp_offset = ninsns_rvoff(-6); - emit(rv_bne(RV_REG_T3, 0, jmp_offset >> 1), ctx); - emit_fence_rw_rw(ctx); + emit_cmpxchg(rd, rs, regmap[BPF_REG_0], is64, ctx); break; default: pr_err_once("bpf-jit: invalid atomic RMW opcode %02x\n", imm); @@ -765,6 +711,39 @@ static int emit_atomic_rmw(u8 rd, u8 rs, const struct bpf_insn *insn, return 0; } +/* + * Sign-extend the register if necessary + */ +static int sign_extend(u8 rd, u8 rs, u8 sz, bool sign, struct rv_jit_context *ctx) +{ + if (!sign && (sz == 1 || sz == 2)) { + if (rd != rs) + emit_mv(rd, rs, ctx); + return 0; + } + + switch (sz) { + case 1: + emit_sextb(rd, rs, ctx); + break; + case 2: + emit_sexth(rd, rs, ctx); + break; + case 4: + emit_sextw(rd, rs, ctx); + break; + case 8: + if (rd != rs) + emit_mv(rd, rs, ctx); + break; + default: + pr_err("bpf-jit: invalid size %d for sign_extend\n", sz); + return -EINVAL; + } + + return 0; +} + #define BPF_FIXUP_OFFSET_MASK GENMASK(26, 0) #define BPF_FIXUP_REG_MASK GENMASK(31, 27) #define REG_DONT_CLEAR_MARKER 0 /* RV_REG_ZERO unused in pt_regmap */ @@ -783,9 +762,8 @@ bool ex_handler_bpf(const struct exception_table_entry *ex, } /* For accesses to BTF pointers, add an entry to the exception table */ -static int add_exception_handler(const struct bpf_insn *insn, - struct rv_jit_context *ctx, - int dst_reg, int insn_len) +static int add_exception_handler(const struct bpf_insn *insn, int dst_reg, + struct rv_jit_context *ctx) { struct exception_table_entry *ex; unsigned long pc; @@ -793,21 +771,23 @@ static int add_exception_handler(const struct bpf_insn *insn, off_t fixup_offset; if (!ctx->insns || !ctx->ro_insns || !ctx->prog->aux->extable || - (BPF_MODE(insn->code) != BPF_PROBE_MEM && BPF_MODE(insn->code) != BPF_PROBE_MEMSX && - BPF_MODE(insn->code) != BPF_PROBE_MEM32)) + ctx->ex_insn_off <= 0 || ctx->ex_jmp_off <= 0) return 0; - if (WARN_ON_ONCE(ctx->nexentries >= ctx->prog->aux->num_exentries)) - return -EINVAL; + if (BPF_MODE(insn->code) != BPF_PROBE_MEM && + BPF_MODE(insn->code) != BPF_PROBE_MEMSX && + BPF_MODE(insn->code) != BPF_PROBE_MEM32 && + BPF_MODE(insn->code) != BPF_PROBE_ATOMIC) + return 0; - if (WARN_ON_ONCE(insn_len > ctx->ninsns)) + if (WARN_ON_ONCE(ctx->nexentries >= ctx->prog->aux->num_exentries)) return -EINVAL; - if (WARN_ON_ONCE(!rvc_enabled() && insn_len == 1)) + if (WARN_ON_ONCE(ctx->ex_insn_off > ctx->ninsns || ctx->ex_jmp_off > ctx->ninsns)) return -EINVAL; ex = &ctx->prog->aux->extable[ctx->nexentries]; - pc = (unsigned long)&ctx->ro_insns[ctx->ninsns - insn_len]; + pc = (unsigned long)&ctx->ro_insns[ctx->ex_insn_off]; /* * This is the relative offset of the instruction that may fault from @@ -831,7 +811,7 @@ static int add_exception_handler(const struct bpf_insn *insn, * that may fault. The execution will jump to this after handling the * fault. */ - fixup_offset = (long)&ex->fixup - (pc + insn_len * sizeof(u16)); + fixup_offset = (long)&ex->fixup - (long)&ctx->ro_insns[ctx->ex_jmp_off]; if (!FIELD_FIT(BPF_FIXUP_OFFSET_MASK, fixup_offset)) return -ERANGE; @@ -848,6 +828,8 @@ static int add_exception_handler(const struct bpf_insn *insn, FIELD_PREP(BPF_FIXUP_REG_MASK, dst_reg); ex->type = EX_TYPE_BPF; + ctx->ex_insn_off = 0; + ctx->ex_jmp_off = 0; ctx->nexentries++; return 0; } @@ -1079,10 +1061,9 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, stack_size += 16; save_ret = flags & (BPF_TRAMP_F_CALL_ORIG | BPF_TRAMP_F_RET_FENTRY_RET); - if (save_ret) { + if (save_ret) stack_size += 16; /* Save both A5 (BPF R0) and A0 */ - retval_off = stack_size; - } + retval_off = stack_size; stack_size += nr_arg_slots * 8; args_off = stack_size; @@ -1226,8 +1207,15 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, restore_args(min_t(int, nr_arg_slots, RV_MAX_REG_ARGS), args_off, ctx); if (save_ret) { - emit_ld(RV_REG_A0, -retval_off, RV_REG_FP, ctx); emit_ld(regmap[BPF_REG_0], -(retval_off - 8), RV_REG_FP, ctx); + if (is_struct_ops) { + ret = sign_extend(RV_REG_A0, regmap[BPF_REG_0], m->ret_size, + m->ret_flags & BTF_FMODEL_SIGNED_ARG, ctx); + if (ret) + goto out; + } else { + emit_ld(RV_REG_A0, -retval_off, RV_REG_FP, ctx); + } } emit_ld(RV_REG_S1, -sreg_off, RV_REG_FP, ctx); @@ -1320,7 +1308,6 @@ int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *ro_image, goto out; } - bpf_flush_icache(ro_image, ro_image_end); out: kvfree(image); return ret < 0 ? ret : size; @@ -1857,7 +1844,6 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx, case BPF_LDX | BPF_PROBE_MEM32 | BPF_DW: { bool sign_ext; - int insn_len; sign_ext = BPF_MODE(insn->code) == BPF_MEMSX || BPF_MODE(insn->code) == BPF_PROBE_MEMSX; @@ -1867,22 +1853,9 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx, rs = RV_REG_T2; } - switch (BPF_SIZE(code)) { - case BPF_B: - insn_len = emit_load_8(sign_ext, rd, off, rs, ctx); - break; - case BPF_H: - insn_len = emit_load_16(sign_ext, rd, off, rs, ctx); - break; - case BPF_W: - insn_len = emit_load_32(sign_ext, rd, off, rs, ctx); - break; - case BPF_DW: - insn_len = emit_load_64(sign_ext, rd, off, rs, ctx); - break; - } + emit_ldx(rd, off, rs, BPF_SIZE(code), sign_ext, ctx); - ret = add_exception_handler(insn, ctx, rd, insn_len); + ret = add_exception_handler(insn, rd, ctx); if (ret) return ret; @@ -1890,238 +1863,73 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx, return 1; break; } + /* speculation barrier */ case BPF_ST | BPF_NOSPEC: break; /* ST: *(size *)(dst + off) = imm */ case BPF_ST | BPF_MEM | BPF_B: - emit_imm(RV_REG_T1, imm, ctx); - if (is_12b_int(off)) { - emit(rv_sb(rd, off, RV_REG_T1), ctx); - break; - } - - emit_imm(RV_REG_T2, off, ctx); - emit_add(RV_REG_T2, RV_REG_T2, rd, ctx); - emit(rv_sb(RV_REG_T2, 0, RV_REG_T1), ctx); - break; - case BPF_ST | BPF_MEM | BPF_H: - emit_imm(RV_REG_T1, imm, ctx); - if (is_12b_int(off)) { - emit(rv_sh(rd, off, RV_REG_T1), ctx); - break; - } - - emit_imm(RV_REG_T2, off, ctx); - emit_add(RV_REG_T2, RV_REG_T2, rd, ctx); - emit(rv_sh(RV_REG_T2, 0, RV_REG_T1), ctx); - break; case BPF_ST | BPF_MEM | BPF_W: - emit_imm(RV_REG_T1, imm, ctx); - if (is_12b_int(off)) { - emit_sw(rd, off, RV_REG_T1, ctx); - break; - } - - emit_imm(RV_REG_T2, off, ctx); - emit_add(RV_REG_T2, RV_REG_T2, rd, ctx); - emit_sw(RV_REG_T2, 0, RV_REG_T1, ctx); - break; case BPF_ST | BPF_MEM | BPF_DW: - emit_imm(RV_REG_T1, imm, ctx); - if (is_12b_int(off)) { - emit_sd(rd, off, RV_REG_T1, ctx); - break; - } - - emit_imm(RV_REG_T2, off, ctx); - emit_add(RV_REG_T2, RV_REG_T2, rd, ctx); - emit_sd(RV_REG_T2, 0, RV_REG_T1, ctx); - break; - + /* ST | PROBE_MEM32: *(size *)(dst + RV_REG_ARENA + off) = imm */ case BPF_ST | BPF_PROBE_MEM32 | BPF_B: case BPF_ST | BPF_PROBE_MEM32 | BPF_H: case BPF_ST | BPF_PROBE_MEM32 | BPF_W: case BPF_ST | BPF_PROBE_MEM32 | BPF_DW: - { - int insn_len, insns_start; - - emit_add(RV_REG_T3, rd, RV_REG_ARENA, ctx); - rd = RV_REG_T3; - - /* Load imm to a register then store it */ - emit_imm(RV_REG_T1, imm, ctx); - - switch (BPF_SIZE(code)) { - case BPF_B: - if (is_12b_int(off)) { - insns_start = ctx->ninsns; - emit(rv_sb(rd, off, RV_REG_T1), ctx); - insn_len = ctx->ninsns - insns_start; - break; - } - - emit_imm(RV_REG_T2, off, ctx); - emit_add(RV_REG_T2, RV_REG_T2, rd, ctx); - insns_start = ctx->ninsns; - emit(rv_sb(RV_REG_T2, 0, RV_REG_T1), ctx); - insn_len = ctx->ninsns - insns_start; - break; - case BPF_H: - if (is_12b_int(off)) { - insns_start = ctx->ninsns; - emit(rv_sh(rd, off, RV_REG_T1), ctx); - insn_len = ctx->ninsns - insns_start; - break; - } - - emit_imm(RV_REG_T2, off, ctx); - emit_add(RV_REG_T2, RV_REG_T2, rd, ctx); - insns_start = ctx->ninsns; - emit(rv_sh(RV_REG_T2, 0, RV_REG_T1), ctx); - insn_len = ctx->ninsns - insns_start; - break; - case BPF_W: - if (is_12b_int(off)) { - insns_start = ctx->ninsns; - emit_sw(rd, off, RV_REG_T1, ctx); - insn_len = ctx->ninsns - insns_start; - break; - } - - emit_imm(RV_REG_T2, off, ctx); - emit_add(RV_REG_T2, RV_REG_T2, rd, ctx); - insns_start = ctx->ninsns; - emit_sw(RV_REG_T2, 0, RV_REG_T1, ctx); - insn_len = ctx->ninsns - insns_start; - break; - case BPF_DW: - if (is_12b_int(off)) { - insns_start = ctx->ninsns; - emit_sd(rd, off, RV_REG_T1, ctx); - insn_len = ctx->ninsns - insns_start; - break; - } - - emit_imm(RV_REG_T2, off, ctx); - emit_add(RV_REG_T2, RV_REG_T2, rd, ctx); - insns_start = ctx->ninsns; - emit_sd(RV_REG_T2, 0, RV_REG_T1, ctx); - insn_len = ctx->ninsns - insns_start; - break; + if (BPF_MODE(insn->code) == BPF_PROBE_MEM32) { + emit_add(RV_REG_T3, rd, RV_REG_ARENA, ctx); + rd = RV_REG_T3; } - ret = add_exception_handler(insn, ctx, REG_DONT_CLEAR_MARKER, - insn_len); + emit_st(rd, off, imm, BPF_SIZE(code), ctx); + + ret = add_exception_handler(insn, REG_DONT_CLEAR_MARKER, ctx); if (ret) return ret; - break; - } /* STX: *(size *)(dst + off) = src */ case BPF_STX | BPF_MEM | BPF_B: - emit_store_8(rd, off, rs, ctx); - break; case BPF_STX | BPF_MEM | BPF_H: - emit_store_16(rd, off, rs, ctx); - break; case BPF_STX | BPF_MEM | BPF_W: - emit_store_32(rd, off, rs, ctx); - break; case BPF_STX | BPF_MEM | BPF_DW: - emit_store_64(rd, off, rs, ctx); + /* STX | PROBE_MEM32: *(size *)(dst + RV_REG_ARENA + off) = src */ + case BPF_STX | BPF_PROBE_MEM32 | BPF_B: + case BPF_STX | BPF_PROBE_MEM32 | BPF_H: + case BPF_STX | BPF_PROBE_MEM32 | BPF_W: + case BPF_STX | BPF_PROBE_MEM32 | BPF_DW: + if (BPF_MODE(insn->code) == BPF_PROBE_MEM32) { + emit_add(RV_REG_T2, rd, RV_REG_ARENA, ctx); + rd = RV_REG_T2; + } + + emit_stx(rd, off, rs, BPF_SIZE(code), ctx); + + ret = add_exception_handler(insn, REG_DONT_CLEAR_MARKER, ctx); + if (ret) + return ret; break; + + /* Atomics */ case BPF_STX | BPF_ATOMIC | BPF_B: case BPF_STX | BPF_ATOMIC | BPF_H: case BPF_STX | BPF_ATOMIC | BPF_W: case BPF_STX | BPF_ATOMIC | BPF_DW: + case BPF_STX | BPF_PROBE_ATOMIC | BPF_B: + case BPF_STX | BPF_PROBE_ATOMIC | BPF_H: + case BPF_STX | BPF_PROBE_ATOMIC | BPF_W: + case BPF_STX | BPF_PROBE_ATOMIC | BPF_DW: if (bpf_atomic_is_load_store(insn)) ret = emit_atomic_ld_st(rd, rs, insn, ctx); else ret = emit_atomic_rmw(rd, rs, insn, ctx); - if (ret) - return ret; - break; - case BPF_STX | BPF_PROBE_MEM32 | BPF_B: - case BPF_STX | BPF_PROBE_MEM32 | BPF_H: - case BPF_STX | BPF_PROBE_MEM32 | BPF_W: - case BPF_STX | BPF_PROBE_MEM32 | BPF_DW: - { - int insn_len, insns_start; - - emit_add(RV_REG_T2, rd, RV_REG_ARENA, ctx); - rd = RV_REG_T2; - - switch (BPF_SIZE(code)) { - case BPF_B: - if (is_12b_int(off)) { - insns_start = ctx->ninsns; - emit(rv_sb(rd, off, rs), ctx); - insn_len = ctx->ninsns - insns_start; - break; - } - - emit_imm(RV_REG_T1, off, ctx); - emit_add(RV_REG_T1, RV_REG_T1, rd, ctx); - insns_start = ctx->ninsns; - emit(rv_sb(RV_REG_T1, 0, rs), ctx); - insn_len = ctx->ninsns - insns_start; - break; - case BPF_H: - if (is_12b_int(off)) { - insns_start = ctx->ninsns; - emit(rv_sh(rd, off, rs), ctx); - insn_len = ctx->ninsns - insns_start; - break; - } - - emit_imm(RV_REG_T1, off, ctx); - emit_add(RV_REG_T1, RV_REG_T1, rd, ctx); - insns_start = ctx->ninsns; - emit(rv_sh(RV_REG_T1, 0, rs), ctx); - insn_len = ctx->ninsns - insns_start; - break; - case BPF_W: - if (is_12b_int(off)) { - insns_start = ctx->ninsns; - emit_sw(rd, off, rs, ctx); - insn_len = ctx->ninsns - insns_start; - break; - } - - emit_imm(RV_REG_T1, off, ctx); - emit_add(RV_REG_T1, RV_REG_T1, rd, ctx); - insns_start = ctx->ninsns; - emit_sw(RV_REG_T1, 0, rs, ctx); - insn_len = ctx->ninsns - insns_start; - break; - case BPF_DW: - if (is_12b_int(off)) { - insns_start = ctx->ninsns; - emit_sd(rd, off, rs, ctx); - insn_len = ctx->ninsns - insns_start; - break; - } - - emit_imm(RV_REG_T1, off, ctx); - emit_add(RV_REG_T1, RV_REG_T1, rd, ctx); - insns_start = ctx->ninsns; - emit_sd(RV_REG_T1, 0, rs, ctx); - insn_len = ctx->ninsns - insns_start; - break; - } - - ret = add_exception_handler(insn, ctx, REG_DONT_CLEAR_MARKER, - insn_len); + ret = ret ?: add_exception_handler(insn, REG_DONT_CLEAR_MARKER, ctx); if (ret) return ret; - break; - } default: pr_err("bpf-jit: unknown opcode %02x\n", code); @@ -2249,6 +2057,25 @@ bool bpf_jit_supports_arena(void) return true; } +bool bpf_jit_supports_insn(struct bpf_insn *insn, bool in_arena) +{ + if (in_arena) { + switch (insn->code) { + case BPF_STX | BPF_ATOMIC | BPF_W: + case BPF_STX | BPF_ATOMIC | BPF_DW: + if (insn->imm == BPF_CMPXCHG) + return rv_ext_enabled(ZACAS); + break; + case BPF_LDX | BPF_MEMSX | BPF_B: + case BPF_LDX | BPF_MEMSX | BPF_H: + case BPF_LDX | BPF_MEMSX | BPF_W: + return false; + } + } + + return true; +} + bool bpf_jit_supports_percpu_insn(void) { return true; |