From e16149fddca8af49aca6d3be0fd4de5a1561f2ae Mon Sep 17 00:00:00 2001 From: Rusty Wagner Date: Tue, 6 Jan 2026 20:10:04 -0700 Subject: [PATCH] Add APIs to allow calling conventions to provide requirements for being considered during heuristic calling convention detection --- arch/x86/arch_x86.cpp | 5 ++ binaryninjaapi.h | 19 ++++++ binaryninjacore.h | 4 ++ callingconvention.cpp | 62 +++++++++++++++++++ python/callingconvention.py | 54 +++++++++++++++++ rust/src/calling_convention.rs | 106 +++++++++++++++++++++++++++++++++ 6 files changed, 250 insertions(+) diff --git a/arch/x86/arch_x86.cpp b/arch/x86/arch_x86.cpp index 708135a7da..52ce0e5fc0 100644 --- a/arch/x86/arch_x86.cpp +++ b/arch/x86/arch_x86.cpp @@ -3814,6 +3814,11 @@ class X86ThiscallCallingConvention: public X86BaseCallingConvention return vector{ XED_REG_ECX }; } + virtual vector GetRequiredArgumentRegisters() override + { + return vector{ XED_REG_ECX }; + } + virtual bool IsStackAdjustedOnReturn() override { return true; diff --git a/binaryninjaapi.h b/binaryninjaapi.h index db54a9a8be..f37b32931e 100644 --- a/binaryninjaapi.h +++ b/binaryninjaapi.h @@ -17223,6 +17223,8 @@ namespace BinaryNinja { static uint32_t* GetCalleeSavedRegistersCallback(void* ctxt, size_t* count); static uint32_t* GetIntegerArgumentRegistersCallback(void* ctxt, size_t* count); static uint32_t* GetFloatArgumentRegistersCallback(void* ctxt, size_t* count); + static uint32_t* GetRequiredArgumentRegistersCallback(void* ctxt, size_t* count); + static uint32_t* GetRequiredClobberedRegistersCallback(void* ctxt, size_t* count); static void FreeRegisterListCallback(void* ctxt, uint32_t* regs, size_t len); static bool AreArgumentRegistersSharedIndexCallback(void* ctxt); @@ -17255,6 +17257,21 @@ namespace BinaryNinja { virtual std::vector GetIntegerArgumentRegisters(); virtual std::vector GetFloatArgumentRegisters(); + + /*! Gets the set of registers that must be arguments for heuristic calling convention + detection to consider this calling convention as a valid option. + + \return The set of registers that must be arguments + */ + virtual std::vector GetRequiredArgumentRegisters(); + + /*! Gets the set of registers that must be clobbered for heuristic calling convention + detection to consider this calling convention as a valid option. + + \return The set of registers that must be clobbered + */ + virtual std::vector GetRequiredClobberedRegisters(); + virtual bool AreArgumentRegistersSharedIndex(); virtual bool AreArgumentRegistersUsedForVarArgs(); virtual bool IsStackReservedForArgumentRegisters(); @@ -17287,6 +17304,8 @@ namespace BinaryNinja { virtual std::vector GetIntegerArgumentRegisters() override; virtual std::vector GetFloatArgumentRegisters() override; + virtual std::vector GetRequiredArgumentRegisters() override; + virtual std::vector GetRequiredClobberedRegisters() override; virtual bool AreArgumentRegistersSharedIndex() override; virtual bool AreArgumentRegistersUsedForVarArgs() override; virtual bool IsStackReservedForArgumentRegisters() override; diff --git a/binaryninjacore.h b/binaryninjacore.h index 8144ed7e2e..daace59993 100644 --- a/binaryninjacore.h +++ b/binaryninjacore.h @@ -2770,6 +2770,8 @@ extern "C" uint32_t* (*getCalleeSavedRegisters)(void* ctxt, size_t* count); uint32_t* (*getIntegerArgumentRegisters)(void* ctxt, size_t* count); uint32_t* (*getFloatArgumentRegisters)(void* ctxt, size_t* count); + uint32_t* (*getRequiredArgumentRegisters)(void* ctxt, size_t* count); + uint32_t* (*getRequiredClobberedRegisters)(void* ctxt, size_t* count); void (*freeRegisterList)(void* ctxt, uint32_t* regs, size_t len); bool (*areArgumentRegistersSharedIndex)(void* ctxt); @@ -7412,6 +7414,8 @@ extern "C" BINARYNINJACOREAPI uint32_t* BNGetIntegerArgumentRegisters(BNCallingConvention* cc, size_t* count); BINARYNINJACOREAPI uint32_t* BNGetFloatArgumentRegisters(BNCallingConvention* cc, size_t* count); + BINARYNINJACOREAPI uint32_t* BNGetRequiredArgumentRegisters(BNCallingConvention* cc, size_t* count); + BINARYNINJACOREAPI uint32_t* BNGetRequiredClobberedRegisters(BNCallingConvention* cc, size_t* count); BINARYNINJACOREAPI bool BNAreArgumentRegistersSharedIndex(BNCallingConvention* cc); BINARYNINJACOREAPI bool BNAreArgumentRegistersUsedForVarArgs(BNCallingConvention* cc); BINARYNINJACOREAPI bool BNIsStackReservedForArgumentRegisters(BNCallingConvention* cc); diff --git a/callingconvention.cpp b/callingconvention.cpp index abc9305358..b7ee5bbe7d 100644 --- a/callingconvention.cpp +++ b/callingconvention.cpp @@ -39,6 +39,8 @@ CallingConvention::CallingConvention(Architecture* arch, const string& name) cc.getCalleeSavedRegisters = GetCalleeSavedRegistersCallback; cc.getIntegerArgumentRegisters = GetIntegerArgumentRegistersCallback; cc.getFloatArgumentRegisters = GetFloatArgumentRegistersCallback; + cc.getRequiredArgumentRegisters = GetRequiredArgumentRegistersCallback; + cc.getRequiredClobberedRegisters = GetRequiredClobberedRegistersCallback; cc.freeRegisterList = FreeRegisterListCallback; cc.areArgumentRegistersSharedIndex = AreArgumentRegistersSharedIndexCallback; cc.areArgumentRegistersUsedForVarArgs = AreArgumentRegistersUsedForVarArgsCallback; @@ -119,6 +121,32 @@ uint32_t* CallingConvention::GetFloatArgumentRegistersCallback(void* ctxt, size_ } +uint32_t* CallingConvention::GetRequiredArgumentRegistersCallback(void* ctxt, size_t* count) +{ + CallbackRef cc(ctxt); + vector regs = cc->GetRequiredArgumentRegisters(); + *count = regs.size(); + + uint32_t* result = new uint32_t[regs.size()]; + for (size_t i = 0; i < regs.size(); i++) + result[i] = regs[i]; + return result; +} + + +uint32_t* CallingConvention::GetRequiredClobberedRegistersCallback(void* ctxt, size_t* count) +{ + CallbackRef cc(ctxt); + vector regs = cc->GetRequiredClobberedRegisters(); + *count = regs.size(); + + uint32_t* result = new uint32_t[regs.size()]; + for (size_t i = 0; i < regs.size(); i++) + result[i] = regs[i]; + return result; +} + + void CallingConvention::FreeRegisterListCallback(void*, uint32_t* regs, size_t) { delete[] regs; @@ -284,6 +312,18 @@ vector CallingConvention::GetFloatArgumentRegisters() } +vector CallingConvention::GetRequiredArgumentRegisters() +{ + return vector(); +} + + +vector CallingConvention::GetRequiredClobberedRegisters() +{ + return vector(); +} + + bool CallingConvention::AreArgumentRegistersSharedIndex() { return false; @@ -417,6 +457,28 @@ vector CoreCallingConvention::GetFloatArgumentRegisters() } +vector CoreCallingConvention::GetRequiredArgumentRegisters() +{ + size_t count; + uint32_t* regs = BNGetRequiredArgumentRegisters(m_object, &count); + vector result; + result.insert(result.end(), regs, ®s[count]); + BNFreeRegisterList(regs); + return result; +} + + +vector CoreCallingConvention::GetRequiredClobberedRegisters() +{ + size_t count; + uint32_t* regs = BNGetRequiredClobberedRegisters(m_object, &count); + vector result; + result.insert(result.end(), regs, ®s[count]); + BNFreeRegisterList(regs); + return result; +} + + bool CoreCallingConvention::AreArgumentRegistersSharedIndex() { return BNAreArgumentRegistersSharedIndex(m_object); diff --git a/python/callingconvention.py b/python/callingconvention.py index 2023e37f6f..32fa6d8177 100644 --- a/python/callingconvention.py +++ b/python/callingconvention.py @@ -40,6 +40,8 @@ class CallingConvention: callee_saved_regs = [] int_arg_regs = [] float_arg_regs = [] + required_arg_regs = [] + required_clobbered_regs = [] arg_regs_share_index = False arg_regs_for_varargs = True stack_reserved_for_arg_regs = False @@ -70,6 +72,8 @@ def __init__( self._get_int_arg_regs ) self._cb.getFloatArgumentRegisters = self._cb.getFloatArgumentRegisters.__class__(self._get_float_arg_regs) + self._cb.getRequiredArgumentRegisters = self._cb.getRequiredArgumentRegisters.__class__(self._get_required_arg_regs) + self._cb.getRequiredClobberedRegisters = self._cb.getRequiredClobberedRegisters.__class__(self._get_required_clobbered_regs) self._cb.freeRegisterList = self._cb.freeRegisterList.__class__(self._free_register_list) self._cb.areArgumentRegistersSharedIndex = self._cb.areArgumentRegistersSharedIndex.__class__( self._arg_regs_share_index @@ -161,6 +165,26 @@ def __init__( core.BNFreeRegisterList(regs) self.__dict__["float_arg_regs"] = result + count = ctypes.c_ulonglong() + regs = core.BNGetRequiredArgumentRegisters(_handle, count) + assert regs is not None, "core.BNGetRequiredArgumentRegisters returned None" + result = [] + arch = self.arch + for i in range(0, count.value): + result.append(arch.get_reg_name(regs[i])) + core.BNFreeRegisterList(regs) + self.__dict__["required_arg_regs"] = result + + count = ctypes.c_ulonglong() + regs = core.BNGetRequiredClobberedRegisters(_handle, count) + assert regs is not None, "core.BNGetRequiredClobberedRegisters returned None" + result = [] + arch = self.arch + for i in range(0, count.value): + result.append(arch.get_reg_name(regs[i])) + core.BNFreeRegisterList(regs) + self.__dict__["required_clobbered_regs"] = result + reg = core.BNGetIntegerReturnValueRegister(_handle) if reg == 0xffffffff: self.__dict__["int_return_reg"] = None @@ -281,6 +305,36 @@ def _get_float_arg_regs(self, ctxt, count): count[0] = 0 return None + def _get_required_arg_regs(self, ctxt, count): + try: + regs = self.__class__.required_arg_regs + count[0] = len(regs) + reg_buf = (ctypes.c_uint * len(regs))() + for i in range(0, len(regs)): + reg_buf[i] = self.arch.regs[regs[i]].index + result = ctypes.cast(reg_buf, ctypes.c_void_p) + self._pending_reg_lists[result.value] = (result, reg_buf) + return result.value + except: + log_error_for_exception("Unhandled Python exception in CallingConvention._get_required_arg_regs") + count[0] = 0 + return None + + def _get_required_clobbered_regs(self, ctxt, count): + try: + regs = self.__class__.required_clobbered_regs + count[0] = len(regs) + reg_buf = (ctypes.c_uint * len(regs))() + for i in range(0, len(regs)): + reg_buf[i] = self.arch.regs[regs[i]].index + result = ctypes.cast(reg_buf, ctypes.c_void_p) + self._pending_reg_lists[result.value] = (result, reg_buf) + return result.value + except: + log_error_for_exception("Unhandled Python exception in CallingConvention._get_required_clobbered_regs") + count[0] = 0 + return None + def _free_register_list(self, ctxt, regs, count): try: buf = ctypes.cast(regs, ctypes.c_void_p) diff --git a/rust/src/calling_convention.rs b/rust/src/calling_convention.rs index 4dc54a5b7a..a2c4fd9418 100644 --- a/rust/src/calling_convention.rs +++ b/rust/src/calling_convention.rs @@ -39,6 +39,12 @@ pub trait CallingConvention: Sync { fn callee_saved_registers(&self) -> Vec; fn int_arg_registers(&self) -> Vec; fn float_arg_registers(&self) -> Vec; + fn required_argument_registers(&self) -> Vec { + Vec::new() + } + fn required_clobbered_registers(&self) -> Vec { + Vec::new() + } fn arg_registers_shared_index(&self) -> bool; fn reserved_stack_space_for_arg_registers(&self) -> bool; @@ -163,6 +169,54 @@ where }) } + extern "C" fn cb_required_argument_registers( + ctxt: *mut c_void, + count: *mut usize, + ) -> *mut u32 + where + C: CallingConvention, + { + ffi_wrap!("CallingConvention::required_argument_registers", unsafe { + let ctxt = &*(ctxt as *mut CustomCallingConventionContext); + let mut regs: Vec<_> = ctxt + .cc + .required_argument_registers() + .iter() + .map(|r| r.0) + .collect(); + + // SAFETY: `count` is an out parameter + *count = regs.len(); + let regs_ptr = regs.as_mut_ptr(); + std::mem::forget(regs); + regs_ptr + }) + } + + extern "C" fn cb_required_clobbered_registers( + ctxt: *mut c_void, + count: *mut usize, + ) -> *mut u32 + where + C: CallingConvention, + { + ffi_wrap!("CallingConvention::required_clobbered_registers", unsafe { + let ctxt = &*(ctxt as *mut CustomCallingConventionContext); + let mut regs: Vec<_> = ctxt + .cc + .required_clobbered_registers() + .iter() + .map(|r| r.0) + .collect(); + + // SAFETY: `count` is an out parameter + *count = regs.len(); + let regs_ptr = regs.as_mut_ptr(); + std::mem::forget(regs); + regs_ptr + }) + } + extern "C" fn cb_arg_shared_index(ctxt: *mut c_void) -> bool where C: CallingConvention, @@ -390,6 +444,8 @@ where getCalleeSavedRegisters: Some(cb_callee_saved::), getIntegerArgumentRegisters: Some(cb_int_args::), getFloatArgumentRegisters: Some(cb_float_args::), + getRequiredArgumentRegisters: Some(cb_required_argument_registers::), + getRequiredClobberedRegisters: Some(cb_required_clobbered_registers::), freeRegisterList: Some(cb_free_register_list), areArgumentRegistersSharedIndex: Some(cb_arg_shared_index::), @@ -519,6 +575,14 @@ impl Debug for CoreCallingConvention { .field("callee_saved_registers", &self.callee_saved_registers()) .field("int_arg_registers", &self.int_arg_registers()) .field("float_arg_registers", &self.float_arg_registers()) + .field( + "required_argument_registers", + &self.required_argument_registers(), + ) + .field( + "required_clobbered_registers", + &self.required_clobbered_registers(), + ) .field( "arg_registers_shared_index", &self.arg_registers_shared_index(), @@ -611,6 +675,34 @@ impl CallingConvention for CoreCallingConvention { } } + fn required_argument_registers(&self) -> Vec { + unsafe { + let mut count = 0; + let regs_ptr = BNGetRequiredArgumentRegisters(self.handle, &mut count); + let regs: Vec = std::slice::from_raw_parts(regs_ptr, count) + .iter() + .copied() + .map(RegisterId::from) + .collect(); + BNFreeRegisterList(regs_ptr); + regs + } + } + + fn required_clobbered_registers(&self) -> Vec { + unsafe { + let mut count = 0; + let regs_ptr = BNGetRequiredClobberedRegisters(self.handle, &mut count); + let regs: Vec = std::slice::from_raw_parts(regs_ptr, count) + .iter() + .copied() + .map(RegisterId::from) + .collect(); + BNFreeRegisterList(regs_ptr); + regs + } + } + fn arg_registers_shared_index(&self) -> bool { unsafe { BNAreArgumentRegistersSharedIndex(self.handle) } } @@ -738,6 +830,8 @@ pub struct ConventionBuilder { callee_saved_registers: Vec, int_arg_registers: Vec, float_arg_registers: Vec, + required_argument_registers: Vec, + required_clobbered_registers: Vec, arg_registers_shared_index: bool, reserved_stack_space_for_arg_registers: bool, @@ -807,6 +901,8 @@ impl ConventionBuilder { callee_saved_registers: Vec::new(), int_arg_registers: Vec::new(), float_arg_registers: Vec::new(), + required_argument_registers: Vec::new(), + required_clobbered_registers: Vec::new(), arg_registers_shared_index: false, reserved_stack_space_for_arg_registers: false, @@ -832,6 +928,8 @@ impl ConventionBuilder { reg_list!(callee_saved_registers); reg_list!(int_arg_registers); reg_list!(float_arg_registers); + reg_list!(required_argument_registers); + reg_list!(required_clobbered_registers); bool_arg!(arg_registers_shared_index); bool_arg!(reserved_stack_space_for_arg_registers); @@ -871,6 +969,14 @@ impl CallingConvention for ConventionBuilder { self.float_arg_registers.clone() } + fn required_argument_registers(&self) -> Vec { + self.required_argument_registers.clone() + } + + fn required_clobbered_registers(&self) -> Vec { + self.required_clobbered_registers.clone() + } + fn arg_registers_shared_index(&self) -> bool { self.arg_registers_shared_index }