Commit 923510c8 authored by Peter Zijlstra's avatar Peter Zijlstra Committed by Ingo Molnar

x86/static_call: Add support for Jcc tail-calls

Clang likes to create conditional tail calls like:

  0000000000000350 <amd_pmu_add_event>:
  350:       0f 1f 44 00 00          nopl   0x0(%rax,%rax,1) 351: R_X86_64_NONE      __fentry__-0x4
  355:       48 83 bf 20 01 00 00 00         cmpq   $0x0,0x120(%rdi)
  35d:       0f 85 00 00 00 00       jne    363 <amd_pmu_add_event+0x13>     35f: R_X86_64_PLT32     __SCT__amd_pmu_branch_add-0x4
  363:       e9 00 00 00 00          jmp    368 <amd_pmu_add_event+0x18>     364: R_X86_64_PLT32     __x86_return_thunk-0x4

Where 0x35d is a static call site that's turned into a conditional
tail-call using the Jcc class of instructions.

Teach the in-line static call text patching about this.

Notably, since there is no conditional-ret, in that case patch the Jcc
to point at an empty stub function that does the ret -- or the return
thunk when needed.
Reported-by: default avatar"Erhard F." <erhard_f@mailbox.org>
Signed-off-by: default avatarPeter Zijlstra (Intel) <peterz@infradead.org>
Signed-off-by: default avatarIngo Molnar <mingo@kernel.org>
Reviewed-by: default avatarMasami Hiramatsu (Google) <mhiramat@kernel.org>
Link: https://lore.kernel.org/r/Y9Kdg9QjHkr9G5b5@hirez.programming.kicks-ass.net
parent ac0ee0a9
...@@ -9,6 +9,7 @@ enum insn_type { ...@@ -9,6 +9,7 @@ enum insn_type {
NOP = 1, /* site cond-call */ NOP = 1, /* site cond-call */
JMP = 2, /* tramp / site tail-call */ JMP = 2, /* tramp / site tail-call */
RET = 3, /* tramp / site cond-tail-call */ RET = 3, /* tramp / site cond-tail-call */
JCC = 4,
}; };
/* /*
...@@ -25,12 +26,40 @@ static const u8 xor5rax[] = { 0x2e, 0x2e, 0x2e, 0x31, 0xc0 }; ...@@ -25,12 +26,40 @@ static const u8 xor5rax[] = { 0x2e, 0x2e, 0x2e, 0x31, 0xc0 };
static const u8 retinsn[] = { RET_INSN_OPCODE, 0xcc, 0xcc, 0xcc, 0xcc }; static const u8 retinsn[] = { RET_INSN_OPCODE, 0xcc, 0xcc, 0xcc, 0xcc };
static u8 __is_Jcc(u8 *insn) /* Jcc.d32 */
{
u8 ret = 0;
if (insn[0] == 0x0f) {
u8 tmp = insn[1];
if ((tmp & 0xf0) == 0x80)
ret = tmp;
}
return ret;
}
extern void __static_call_return(void);
asm (".global __static_call_return\n\t"
".type __static_call_return, @function\n\t"
ASM_FUNC_ALIGN "\n\t"
"__static_call_return:\n\t"
ANNOTATE_NOENDBR
ANNOTATE_RETPOLINE_SAFE
"ret; int3\n\t"
".size __static_call_return, . - __static_call_return \n\t");
static void __ref __static_call_transform(void *insn, enum insn_type type, static void __ref __static_call_transform(void *insn, enum insn_type type,
void *func, bool modinit) void *func, bool modinit)
{ {
const void *emulate = NULL; const void *emulate = NULL;
int size = CALL_INSN_SIZE; int size = CALL_INSN_SIZE;
const void *code; const void *code;
u8 op, buf[6];
if ((type == JMP || type == RET) && (op = __is_Jcc(insn)))
type = JCC;
switch (type) { switch (type) {
case CALL: case CALL:
...@@ -57,6 +86,20 @@ static void __ref __static_call_transform(void *insn, enum insn_type type, ...@@ -57,6 +86,20 @@ static void __ref __static_call_transform(void *insn, enum insn_type type,
else else
code = &retinsn; code = &retinsn;
break; break;
case JCC:
if (!func) {
func = __static_call_return;
if (cpu_feature_enabled(X86_FEATURE_RETHUNK))
func = x86_return_thunk;
}
buf[0] = 0x0f;
__text_gen_insn(buf+1, op, insn+1, func, 5);
code = buf;
size = 6;
break;
} }
if (memcmp(insn, code, size) == 0) if (memcmp(insn, code, size) == 0)
...@@ -68,9 +111,9 @@ static void __ref __static_call_transform(void *insn, enum insn_type type, ...@@ -68,9 +111,9 @@ static void __ref __static_call_transform(void *insn, enum insn_type type,
text_poke_bp(insn, code, size, emulate); text_poke_bp(insn, code, size, emulate);
} }
static void __static_call_validate(void *insn, bool tail, bool tramp) static void __static_call_validate(u8 *insn, bool tail, bool tramp)
{ {
u8 opcode = *(u8 *)insn; u8 opcode = insn[0];
if (tramp && memcmp(insn+5, tramp_ud, 3)) { if (tramp && memcmp(insn+5, tramp_ud, 3)) {
pr_err("trampoline signature fail"); pr_err("trampoline signature fail");
...@@ -79,7 +122,8 @@ static void __static_call_validate(void *insn, bool tail, bool tramp) ...@@ -79,7 +122,8 @@ static void __static_call_validate(void *insn, bool tail, bool tramp)
if (tail) { if (tail) {
if (opcode == JMP32_INSN_OPCODE || if (opcode == JMP32_INSN_OPCODE ||
opcode == RET_INSN_OPCODE) opcode == RET_INSN_OPCODE ||
__is_Jcc(insn))
return; return;
} else { } else {
if (opcode == CALL_INSN_OPCODE || if (opcode == CALL_INSN_OPCODE ||
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment