From 91b3e8241acc8c8a7b4cc4f67633486965852943 Mon Sep 17 00:00:00 2001 From: Jeff Garzik Date: Sun, 7 Dec 2025 19:51:05 -0500 Subject: [PATCH 1/8] [cc] backend regalloc refactor --- cc/arch/aarch64/codegen.rs | 953 +----------------------------------- cc/arch/aarch64/lir.rs | 2 +- cc/arch/aarch64/mod.rs | 1 + cc/arch/aarch64/regalloc.rs | 919 ++++++++++++++++++++++++++++++++++ cc/arch/x86_64/codegen.rs | 813 +----------------------------- cc/arch/x86_64/lir.rs | 2 +- cc/arch/x86_64/mod.rs | 1 + cc/arch/x86_64/regalloc.rs | 785 +++++++++++++++++++++++++++++ 8 files changed, 1710 insertions(+), 1766 deletions(-) create mode 100644 cc/arch/aarch64/regalloc.rs create mode 100644 cc/arch/x86_64/regalloc.rs diff --git a/cc/arch/aarch64/codegen.rs b/cc/arch/aarch64/codegen.rs index a9a8354f..1da18f10 100644 --- a/cc/arch/aarch64/codegen.rs +++ b/cc/arch/aarch64/codegen.rs @@ -18,6 +18,7 @@ // use crate::arch::aarch64::lir::{Aarch64Inst, CallTarget, Cond, GpOperand, MemAddr}; +use crate::arch::aarch64::regalloc::{Loc, Reg, RegAlloc, VReg}; use crate::arch::codegen::CodeGenerator; use crate::arch::lir::{Directive, FpSize, Label, OperandSize, Symbol}; use crate::arch::DEFAULT_LIR_BUFFER_CAPACITY; @@ -26,958 +27,6 @@ use crate::target::Target; use crate::types::{TypeId, TypeModifiers, TypeTable}; use std::collections::HashMap; -// ============================================================================ -// AArch64 Register Definitions -// ============================================================================ - -/// AArch64 physical registers -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub enum Reg { - // General purpose registers x0-x28 - X0, - X1, - X2, - X3, - X4, - X5, - X6, - X7, - X8, - X9, - X10, - X11, - X12, - X13, - X14, - X15, - X16, - X17, - X18, - X19, - X20, - X21, - X22, - X23, - X24, - X25, - X26, - X27, - X28, - // Frame pointer - X29, - // Link register - X30, - // Stack pointer (special, shares encoding with XZR in some contexts) - SP, -} - -impl Reg { - /// Get 64-bit register name - pub fn name64(&self) -> &'static str { - match self { - Reg::X0 => "x0", - Reg::X1 => "x1", - Reg::X2 => "x2", - Reg::X3 => "x3", - Reg::X4 => "x4", - Reg::X5 => "x5", - Reg::X6 => "x6", - Reg::X7 => "x7", - Reg::X8 => "x8", - Reg::X9 => "x9", - Reg::X10 => "x10", - Reg::X11 => "x11", - Reg::X12 => "x12", - Reg::X13 => "x13", - Reg::X14 => "x14", - Reg::X15 => "x15", - Reg::X16 => "x16", - Reg::X17 => "x17", - Reg::X18 => "x18", - Reg::X19 => "x19", - Reg::X20 => "x20", - Reg::X21 => "x21", - Reg::X22 => "x22", - Reg::X23 => "x23", - Reg::X24 => "x24", - Reg::X25 => "x25", - Reg::X26 => "x26", - Reg::X27 => "x27", - Reg::X28 => "x28", - Reg::X29 => "x29", - Reg::X30 => "x30", - Reg::SP => "sp", - } - } - - /// Get 32-bit register name - pub fn name32(&self) -> &'static str { - match self { - Reg::X0 => "w0", - Reg::X1 => "w1", - Reg::X2 => "w2", - Reg::X3 => "w3", - Reg::X4 => "w4", - Reg::X5 => "w5", - Reg::X6 => "w6", - Reg::X7 => "w7", - Reg::X8 => "w8", - Reg::X9 => "w9", - Reg::X10 => "w10", - Reg::X11 => "w11", - Reg::X12 => "w12", - Reg::X13 => "w13", - Reg::X14 => "w14", - Reg::X15 => "w15", - Reg::X16 => "w16", - Reg::X17 => "w17", - Reg::X18 => "w18", - Reg::X19 => "w19", - Reg::X20 => "w20", - Reg::X21 => "w21", - Reg::X22 => "w22", - Reg::X23 => "w23", - Reg::X24 => "w24", - Reg::X25 => "w25", - Reg::X26 => "w26", - Reg::X27 => "w27", - Reg::X28 => "w28", - Reg::X29 => "w29", - Reg::X30 => "w30", - Reg::SP => "sp", // SP doesn't have a 32-bit form in normal use - } - } - 
- /// Get register name for a given bit size - pub fn name_for_size(&self, bits: u32) -> &'static str { - match bits { - 8 | 16 | 32 => self.name32(), - _ => self.name64(), - } - } - - /// Is this a callee-saved register? - pub fn is_callee_saved(&self) -> bool { - matches!( - self, - Reg::X19 - | Reg::X20 - | Reg::X21 - | Reg::X22 - | Reg::X23 - | Reg::X24 - | Reg::X25 - | Reg::X26 - | Reg::X27 - | Reg::X28 - ) - } - - /// Argument registers in order (AAPCS64) - pub fn arg_regs() -> &'static [Reg] { - &[ - Reg::X0, - Reg::X1, - Reg::X2, - Reg::X3, - Reg::X4, - Reg::X5, - Reg::X6, - Reg::X7, - ] - } - - /// All allocatable registers - /// Excludes: x8 (indirect result), x16/x17 (linker scratch), - /// x18 (platform), x29 (fp), x30 (lr), sp - pub fn allocatable() -> &'static [Reg] { - &[ - Reg::X0, - Reg::X1, - Reg::X2, - Reg::X3, - Reg::X4, - Reg::X5, - Reg::X6, - Reg::X7, - // Skip x8 (indirect result register for large struct returns per AAPCS64) - Reg::X9, - Reg::X10, - Reg::X11, - Reg::X12, - Reg::X13, - Reg::X14, - Reg::X15, - // Skip x16, x17 (linker scratch) - // Skip x18 (platform reserved) - Reg::X19, - Reg::X20, - Reg::X21, - Reg::X22, - Reg::X23, - Reg::X24, - Reg::X25, - Reg::X26, - Reg::X27, - Reg::X28, - // Skip x29 (fp), x30 (lr) - ] - } - - /// Scratch registers for codegen (not allocatable, used for temporaries) - /// x16 (IP0) and x17 (IP1) are linker scratch registers - pub fn scratch_regs() -> (Reg, Reg) { - (Reg::X16, Reg::X17) - } - - /// Frame pointer register - pub fn fp() -> Reg { - Reg::X29 - } - - /// Link register - pub fn lr() -> Reg { - Reg::X30 - } - - /// Stack pointer register - pub fn sp() -> Reg { - Reg::SP - } - - /// Platform reserved register (x18) - /// Should not be used; this is only for documentation and completeness - pub fn platform_reserved() -> Reg { - Reg::X18 - } -} - -// ============================================================================ -// AArch64 Floating-Point Register Definitions -// ============================================================================ - -/// AArch64 SIMD/FP registers (V0-V31, accessed as D0-D31 for double, S0-S31 for float) -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub enum VReg { - V0, - V1, - V2, - V3, - V4, - V5, - V6, - V7, - V8, - V9, - V10, - V11, - V12, - V13, - V14, - V15, - V16, - V17, - V18, - V19, - V20, - V21, - V22, - V23, - V24, - V25, - V26, - V27, - V28, - V29, - V30, - V31, -} - -impl VReg { - /// Get 64-bit (double) register name - pub fn name_d(&self) -> &'static str { - match self { - VReg::V0 => "d0", - VReg::V1 => "d1", - VReg::V2 => "d2", - VReg::V3 => "d3", - VReg::V4 => "d4", - VReg::V5 => "d5", - VReg::V6 => "d6", - VReg::V7 => "d7", - VReg::V8 => "d8", - VReg::V9 => "d9", - VReg::V10 => "d10", - VReg::V11 => "d11", - VReg::V12 => "d12", - VReg::V13 => "d13", - VReg::V14 => "d14", - VReg::V15 => "d15", - VReg::V16 => "d16", - VReg::V17 => "d17", - VReg::V18 => "d18", - VReg::V19 => "d19", - VReg::V20 => "d20", - VReg::V21 => "d21", - VReg::V22 => "d22", - VReg::V23 => "d23", - VReg::V24 => "d24", - VReg::V25 => "d25", - VReg::V26 => "d26", - VReg::V27 => "d27", - VReg::V28 => "d28", - VReg::V29 => "d29", - VReg::V30 => "d30", - VReg::V31 => "d31", - } - } - - /// Get 32-bit (float) register name - pub fn name_s(&self) -> &'static str { - match self { - VReg::V0 => "s0", - VReg::V1 => "s1", - VReg::V2 => "s2", - VReg::V3 => "s3", - VReg::V4 => "s4", - VReg::V5 => "s5", - VReg::V6 => "s6", - VReg::V7 => "s7", - VReg::V8 => "s8", - VReg::V9 => "s9", - VReg::V10 => 
"s10", - VReg::V11 => "s11", - VReg::V12 => "s12", - VReg::V13 => "s13", - VReg::V14 => "s14", - VReg::V15 => "s15", - VReg::V16 => "s16", - VReg::V17 => "s17", - VReg::V18 => "s18", - VReg::V19 => "s19", - VReg::V20 => "s20", - VReg::V21 => "s21", - VReg::V22 => "s22", - VReg::V23 => "s23", - VReg::V24 => "s24", - VReg::V25 => "s25", - VReg::V26 => "s26", - VReg::V27 => "s27", - VReg::V28 => "s28", - VReg::V29 => "s29", - VReg::V30 => "s30", - VReg::V31 => "s31", - } - } - - /// Get register name for a given size (32 = float, 64 = double) - pub fn name_for_size(&self, bits: u32) -> &'static str { - if bits <= 32 { - self.name_s() - } else { - self.name_d() - } - } - - /// Is this a callee-saved FP register? - /// AAPCS64: v8-v15 (lower 64 bits) are callee-saved - pub fn is_callee_saved(&self) -> bool { - matches!( - self, - VReg::V8 - | VReg::V9 - | VReg::V10 - | VReg::V11 - | VReg::V12 - | VReg::V13 - | VReg::V14 - | VReg::V15 - ) - } - - /// FP argument registers in order (AAPCS64) - pub fn arg_regs() -> &'static [VReg] { - &[ - VReg::V0, - VReg::V1, - VReg::V2, - VReg::V3, - VReg::V4, - VReg::V5, - VReg::V6, - VReg::V7, - ] - } - - /// All allocatable FP registers - /// Note: V16, V17, V18 are reserved as scratch registers for codegen - pub fn allocatable() -> &'static [VReg] { - &[ - VReg::V0, - VReg::V1, - VReg::V2, - VReg::V3, - VReg::V4, - VReg::V5, - VReg::V6, - VReg::V7, - VReg::V8, - VReg::V9, - VReg::V10, - VReg::V11, - VReg::V12, - VReg::V13, - VReg::V14, - VReg::V15, - // V16, V17, V18 reserved for scratch - VReg::V19, - VReg::V20, - VReg::V21, - VReg::V22, - VReg::V23, - VReg::V24, - VReg::V25, - VReg::V26, - VReg::V27, - VReg::V28, - VReg::V29, - VReg::V30, - VReg::V31, - ] - } -} - -// ============================================================================ -// Operand - Location of a value (register or memory) -// ============================================================================ - -/// Location of a value -#[derive(Debug, Clone)] -pub enum Loc { - /// In a register - Reg(Reg), - /// In a floating-point register - VReg(VReg), - /// On the stack at [sp + offset] (positive offset from sp after allocation) - Stack(i32), - /// Immediate constant - Imm(i64), - /// Floating-point immediate constant - FImm(f64), - /// Global symbol - Global(String), -} - -// ============================================================================ -// Register Allocator (Linear Scan) -// ============================================================================ - -/// Live interval for a pseudo -#[derive(Debug, Clone)] -struct LiveInterval { - pseudo: PseudoId, - start: usize, - end: usize, -} - -/// Simple linear scan register allocator for AArch64 -pub struct RegAlloc { - /// Mapping from pseudo to location - locations: HashMap, - /// Free registers - free_regs: Vec, - /// Free FP registers - free_fp_regs: Vec, - /// Active intervals (sorted by end point) - active: Vec<(LiveInterval, Reg)>, - /// Active FP intervals (sorted by end point) - active_fp: Vec<(LiveInterval, VReg)>, - /// Next stack slot offset - stack_offset: i32, - /// Callee-saved registers that were used - used_callee_saved: Vec, - /// Callee-saved FP registers that were used - used_callee_saved_fp: Vec, - /// Set of pseudos that need FP registers (determined by analyzing the IR) - fp_pseudos: std::collections::HashSet, -} - -impl RegAlloc { - pub fn new() -> Self { - Self { - locations: HashMap::new(), - free_regs: Reg::allocatable().to_vec(), - free_fp_regs: VReg::allocatable().to_vec(), - active: 
Vec::new(), - active_fp: Vec::new(), - stack_offset: 0, - used_callee_saved: Vec::new(), - used_callee_saved_fp: Vec::new(), - fp_pseudos: std::collections::HashSet::new(), - } - } - - /// Perform register allocation for a function - pub fn allocate(&mut self, func: &Function, types: &TypeTable) -> HashMap { - // Reset state - self.locations.clear(); - self.free_regs = Reg::allocatable().to_vec(); - self.free_fp_regs = VReg::allocatable().to_vec(); - self.active.clear(); - self.active_fp.clear(); - self.stack_offset = 0; - self.used_callee_saved.clear(); - self.used_callee_saved_fp.clear(); - self.fp_pseudos.clear(); - - // Identify which pseudos need FP registers - self.identify_fp_pseudos(func, types); - - // Pre-allocate argument registers (integer and FP separately) - let int_arg_regs = Reg::arg_regs(); - let fp_arg_regs = VReg::arg_regs(); - let mut int_arg_idx = 0usize; - let mut fp_arg_idx = 0usize; - - // AAPCS64: Detect if there's a hidden return pointer for large struct returns. - // Unlike x86-64 where sret goes in RDI (first arg), AAPCS64 uses X8 (indirect - // result register) which does NOT shift other arguments. - let sret_pseudo = func - .pseudos - .iter() - .find(|p| matches!(p.kind, PseudoKind::Arg(0)) && p.name.as_deref() == Some("__sret")); - - // If there's a hidden return pointer, allocate X8 for it - if let Some(sret) = sret_pseudo { - self.locations.insert(sret.id, Loc::Reg(Reg::X8)); - // X8 is not in allocatable list, so no need to remove from free_regs - } - - // Offset for matching arg_idx: if sret exists, real params have arg_idx starting at 1 - let arg_idx_offset: u32 = if sret_pseudo.is_some() { 1 } else { 0 }; - - for (i, (_name, typ)) in func.params.iter().enumerate() { - // Find the pseudo for this argument - for pseudo in &func.pseudos { - if let PseudoKind::Arg(arg_idx) = pseudo.kind { - // Match by adjusted index (sret shifts arg_idx but not register allocation) - if arg_idx == (i as u32) + arg_idx_offset { - let is_fp = types.is_float(*typ); - if is_fp { - if fp_arg_idx < fp_arg_regs.len() { - self.locations - .insert(pseudo.id, Loc::VReg(fp_arg_regs[fp_arg_idx])); - self.free_fp_regs.retain(|&r| r != fp_arg_regs[fp_arg_idx]); - self.fp_pseudos.insert(pseudo.id); - } else { - // FP argument on stack - let offset = 16 + (fp_arg_idx - fp_arg_regs.len()) as i32 * 8; - self.locations.insert(pseudo.id, Loc::Stack(offset)); - self.fp_pseudos.insert(pseudo.id); - } - fp_arg_idx += 1; - } else { - if int_arg_idx < int_arg_regs.len() { - self.locations - .insert(pseudo.id, Loc::Reg(int_arg_regs[int_arg_idx])); - self.free_regs.retain(|&r| r != int_arg_regs[int_arg_idx]); - } else { - // Integer argument on stack - let offset = 16 + (int_arg_idx - int_arg_regs.len()) as i32 * 8; - self.locations.insert(pseudo.id, Loc::Stack(offset)); - } - int_arg_idx += 1; - } - break; - } - } - } - } - - // Compute live intervals - let intervals = self.compute_live_intervals(func); - - // Allocate registers using linear scan - for interval in intervals { - self.expire_old_intervals(interval.start); - - // Check if already allocated (e.g., arguments) - if self.locations.contains_key(&interval.pseudo) { - continue; - } - - // Check if this is a constant - constants don't need registers - if let Some(pseudo) = func.pseudos.iter().find(|p| p.id == interval.pseudo) { - match &pseudo.kind { - PseudoKind::Val(v) => { - self.locations.insert(interval.pseudo, Loc::Imm(*v)); - continue; - } - PseudoKind::FVal(v) => { - self.locations.insert(interval.pseudo, Loc::FImm(*v)); - 
self.fp_pseudos.insert(interval.pseudo); - continue; - } - PseudoKind::Sym(name) => { - // Check if this is a local variable or a global symbol - if let Some(local) = func.locals.get(name) { - // Local variable - allocate stack space based on actual type size - let size = (types.size_bits(local.typ) / 8).max(8) as i32; - // Align to 8 bytes - let aligned_size = (size + 7) & !7; - self.stack_offset += aligned_size; - self.locations - .insert(interval.pseudo, Loc::Stack(-self.stack_offset)); - } else { - // Global symbol - self.locations - .insert(interval.pseudo, Loc::Global(name.clone())); - } - continue; - } - _ => {} - } - } - - // Check if this pseudo needs an FP register - let needs_fp = self.fp_pseudos.contains(&interval.pseudo); - - if needs_fp { - // Try to allocate an FP register - if let Some(reg) = self.free_fp_regs.pop() { - if reg.is_callee_saved() && !self.used_callee_saved_fp.contains(®) { - self.used_callee_saved_fp.push(reg); - } - self.locations.insert(interval.pseudo, Loc::VReg(reg)); - self.active_fp.push((interval.clone(), reg)); - self.active_fp.sort_by_key(|(i, _)| i.end); - } else { - // Spill to stack - self.stack_offset += 8; - self.locations - .insert(interval.pseudo, Loc::Stack(-self.stack_offset)); - } - } else { - // Try to allocate an integer register - if let Some(reg) = self.free_regs.pop() { - if reg.is_callee_saved() && !self.used_callee_saved.contains(®) { - self.used_callee_saved.push(reg); - } - self.locations.insert(interval.pseudo, Loc::Reg(reg)); - self.active.push((interval.clone(), reg)); - self.active.sort_by_key(|(i, _)| i.end); - } else { - // Spill to stack - self.stack_offset += 8; - self.locations - .insert(interval.pseudo, Loc::Stack(-self.stack_offset)); - } - } - } - - self.locations.clone() - } - - /// Identify which pseudos need FP registers by scanning the IR - fn identify_fp_pseudos(&mut self, func: &Function, types: &TypeTable) { - for block in &func.blocks { - for insn in &block.insns { - // Check if this instruction produces a floating-point result - let is_fp_op = matches!( - insn.op, - Opcode::FAdd - | Opcode::FSub - | Opcode::FMul - | Opcode::FDiv - | Opcode::FNeg - | Opcode::FCvtF - | Opcode::UCvtF - | Opcode::SCvtF - ); - - if is_fp_op { - if let Some(target) = insn.target { - self.fp_pseudos.insert(target); - } - } - - // Also check if the type is floating point - // (but exclude comparisons which always produce int regardless of operand type) - if let Some(typ) = insn.typ { - if types.is_float(typ) - && !matches!( - insn.op, - Opcode::FCmpOEq - | Opcode::FCmpONe - | Opcode::FCmpOLt - | Opcode::FCmpOLe - | Opcode::FCmpOGt - | Opcode::FCmpOGe - ) - { - if let Some(target) = insn.target { - self.fp_pseudos.insert(target); - } - } - } - } - } - - // Also mark pseudos that are FVal constants - for pseudo in &func.pseudos { - if matches!(pseudo.kind, PseudoKind::FVal(_)) { - self.fp_pseudos.insert(pseudo.id); - } - } - } - - fn expire_old_intervals(&mut self, point: usize) { - // Expire integer register intervals - let mut to_remove = Vec::new(); - for (i, (interval, reg)) in self.active.iter().enumerate() { - if interval.end < point { - self.free_regs.push(*reg); - to_remove.push(i); - } - } - // Remove in reverse order to preserve indices - for i in to_remove.into_iter().rev() { - self.active.remove(i); - } - - // Expire FP register intervals - let mut to_remove_fp = Vec::new(); - for (i, (interval, reg)) in self.active_fp.iter().enumerate() { - if interval.end < point { - self.free_fp_regs.push(*reg); - to_remove_fp.push(i); - } 
-        }
-        for i in to_remove_fp.into_iter().rev() {
-            self.active_fp.remove(i);
-        }
-    }
-
-    fn compute_live_intervals(&self, func: &Function) -> Vec<LiveInterval> {
-        use crate::ir::BasicBlockId;
-
-        // Track earliest definition, latest definition, and latest use for each pseudo
-        struct IntervalInfo {
-            pseudo: PseudoId,
-            first_def: usize,
-            last_def: usize,
-            last_use: usize,
-        }
-
-        let mut intervals: HashMap<PseudoId, IntervalInfo> = HashMap::new();
-        let mut pos = 0usize;
-
-        // First pass: compute block start and end positions
-        let mut block_start_pos: HashMap<BasicBlockId, usize> = HashMap::new();
-        let mut block_end_pos: HashMap<BasicBlockId, usize> = HashMap::new();
-        let mut temp_pos = 0usize;
-        for block in &func.blocks {
-            block_start_pos.insert(block.id, temp_pos);
-            temp_pos += block.insns.len();
-            block_end_pos.insert(block.id, temp_pos.saturating_sub(1));
-        }
-
-        // Collect phi sources with their source blocks for later processing
-        // Also track phi targets - they need intervals extended to cover all their copy definitions
-        let mut phi_sources: Vec<(BasicBlockId, PseudoId)> = Vec::new();
-        let mut phi_targets: Vec<(BasicBlockId, PseudoId)> = Vec::new();
-
-        for block in &func.blocks {
-            for insn in &block.insns {
-                // Definition point - track both first and last definition
-                // This is important because phi elimination creates multiple definitions
-                // of the same pseudo (via Copy instructions in predecessor blocks)
-                if let Some(target) = insn.target {
-                    intervals
-                        .entry(target)
-                        .and_modify(|info| {
-                            info.first_def = info.first_def.min(pos);
-                            info.last_def = info.last_def.max(pos);
-                        })
-                        .or_insert(IntervalInfo {
-                            pseudo: target,
-                            first_def: pos,
-                            last_def: pos,
-                            last_use: pos,
-                        });
-                }
-
-                // Use points
-                for &src in &insn.src {
-                    if let Some(info) = intervals.get_mut(&src) {
-                        info.last_use = info.last_use.max(pos);
-                    } else {
-                        intervals.insert(
-                            src,
-                            IntervalInfo {
-                                pseudo: src,
-                                first_def: pos,
-                                last_def: pos,
-                                last_use: pos,
-                            },
-                        );
-                    }
-                }
-
-                // Collect phi sources - they need to be live at the END of their source block
-                // Also collect phi targets - they have Copy definitions in each source block
-                for (src_bb, pseudo) in &insn.phi_list {
-                    phi_sources.push((*src_bb, *pseudo));
-                    // The phi target (if present) is defined via Copy at the end of each source block
-                    if let Some(target) = insn.target {
-                        phi_targets.push((*src_bb, target));
-                    }
-                }
-
-                pos += 1;
-            }
-        }
-
-        // Process phi sources: extend their live interval to the end of their source block
-        // This is critical for loops where phi sources come from later blocks via back edges
-        for (src_bb, pseudo) in phi_sources {
-            if let Some(&end_pos) = block_end_pos.get(&src_bb) {
-                if let Some(info) = intervals.get_mut(&pseudo) {
-                    info.last_use = info.last_use.max(end_pos);
-                    info.last_def = info.last_def.max(end_pos);
-                } else {
-                    intervals.insert(
-                        pseudo,
-                        IntervalInfo {
-                            pseudo,
-                            first_def: end_pos,
-                            last_def: end_pos,
-                            last_use: end_pos,
-                        },
-                    );
-                }
-            }
-        }
-
-        // Process phi targets: extend their live interval to cover all Copy definitions
-        // The phi target is defined via Copy at the end of each source block
-        // For loops, this means the target must be live until the last Copy (at the loop back edge)
-        for (src_bb, target) in phi_targets {
-            if let Some(&end_pos) = block_end_pos.get(&src_bb) {
-                if let Some(info) = intervals.get_mut(&target) {
-                    info.last_def = info.last_def.max(end_pos);
-                } else {
-                    intervals.insert(
-                        target,
-                        IntervalInfo {
-                            pseudo: target,
-                            first_def: end_pos,
-                            last_def: end_pos,
-                            last_use: end_pos,
-                        },
-                    );
-                }
-            }
-        }
-
-        
// Detect loops via back edges and extend lifetimes of variables used in loop bodies - // A back edge is a branch from a block to an earlier block (lower start position) - // Any variable used in a block that's part of a loop must be live until the back edge - let mut loop_back_edges: Vec<(BasicBlockId, BasicBlockId, usize)> = Vec::new(); // (from, to, from_end_pos) - for block in &func.blocks { - // Check the last instruction for branch targets - if let Some(last_insn) = block.insns.last() { - // Collect branch targets (bb_true for unconditional/conditional, bb_false for conditional) - let mut targets = Vec::new(); - if let Some(target) = last_insn.bb_true { - targets.push(target); - } - if let Some(target) = last_insn.bb_false { - targets.push(target); - } - - // Check if any target is a back edge (to an earlier block) - let from_start = block_start_pos.get(&block.id).copied().unwrap_or(0); - for target_bb in targets { - let target_start = block_start_pos.get(&target_bb).copied().unwrap_or(0); - if target_start < from_start { - // This is a back edge - the loop spans from target_bb to block - let from_end = block_end_pos.get(&block.id).copied().unwrap_or(0); - loop_back_edges.push((block.id, target_bb, from_end)); - } - } - } - } - - // For each loop, extend lifetimes of variables used within the loop - for (_from_bb, to_bb, back_edge_pos) in &loop_back_edges { - let loop_start = block_start_pos.get(to_bb).copied().unwrap_or(0); - - // Extend any variable that: - // 1. Is defined before the loop (first_def < loop_start) - // 2. Is used inside the loop (last_use >= loop_start && last_use <= back_edge_pos) - for info in intervals.values_mut() { - if info.first_def < loop_start - && info.last_use >= loop_start - && info.last_use <= *back_edge_pos - { - // This variable is used inside the loop but defined outside - // Extend its lifetime to the back edge - info.last_use = info.last_use.max(*back_edge_pos); - } - } - } - - // Find the maximum position in the function - let max_pos = pos.saturating_sub(1); - - // Convert to LiveInterval - // The interval spans from first_def to max(last_def, last_use) - // This ensures that pseudos with multiple definitions (from phi copies) - // stay live across all their definitions - // - // IMPORTANT: For loop variables, if last_def > last_use (definition comes after use - // due to back edge), the variable must stay live until the end of the function - // because it will be used again in the next loop iteration. 
- let mut result: Vec<_> = intervals - .into_values() - .map(|info| { - let end = if info.last_def > info.last_use { - // Loop variable: definition after use means we wrap around - // Extend to end of function to ensure it stays live - max_pos - } else { - info.last_def.max(info.last_use) - }; - LiveInterval { - pseudo: info.pseudo, - start: info.first_def, - end, - } - }) - .collect(); - result.sort_by_key(|i| i.start); - result - } - - /// Get stack size needed (aligned to 16 bytes) - pub fn stack_size(&self) -> i32 { - // Align to 16 bytes - (self.stack_offset + 15) & !15 - } - - /// Get callee-saved registers that need to be preserved - pub fn callee_saved_used(&self) -> &[Reg] { - &self.used_callee_saved - } -} - -impl Default for RegAlloc { - fn default() -> Self { - Self::new() - } -} - // ============================================================================ // AArch64 Code Generator // ============================================================================ diff --git a/cc/arch/aarch64/lir.rs b/cc/arch/aarch64/lir.rs index a696522f..4d8f3310 100644 --- a/cc/arch/aarch64/lir.rs +++ b/cc/arch/aarch64/lir.rs @@ -13,7 +13,7 @@ // enabling peephole optimizations before final assembly emission. // -use super::codegen::{Reg, VReg}; +use super::regalloc::{Reg, VReg}; use crate::arch::lir::{Directive, EmitAsm, FpSize, Label, OperandSize, Symbol}; use crate::target::{Os, Target}; use std::fmt::{self, Write}; diff --git a/cc/arch/aarch64/mod.rs b/cc/arch/aarch64/mod.rs index 342feccc..fde8937f 100644 --- a/cc/arch/aarch64/mod.rs +++ b/cc/arch/aarch64/mod.rs @@ -12,5 +12,6 @@ pub mod codegen; pub mod lir; pub mod macros; +pub mod regalloc; pub use macros::get_macros; diff --git a/cc/arch/aarch64/regalloc.rs b/cc/arch/aarch64/regalloc.rs new file mode 100644 index 00000000..4083c51f --- /dev/null +++ b/cc/arch/aarch64/regalloc.rs @@ -0,0 +1,919 @@ +// +// Copyright (c) 2024 Jeff Garzik +// +// This file is part of the posixutils-rs project covered under +// the MIT License. For the full license text, please see the LICENSE +// file in the root directory of this project. 
+// SPDX-License-Identifier: MIT +// +// AArch64 Register Allocator +// Linear scan register allocation for AArch64 +// + +use crate::ir::{BasicBlockId, Function, Opcode, PseudoId, PseudoKind}; +use crate::types::TypeTable; +use std::collections::{HashMap, HashSet}; + +// ============================================================================ +// AArch64 Register Definitions +// ============================================================================ + +/// AArch64 physical registers +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum Reg { + // General purpose registers x0-x28 + X0, + X1, + X2, + X3, + X4, + X5, + X6, + X7, + X8, + X9, + X10, + X11, + X12, + X13, + X14, + X15, + X16, + X17, + X18, + X19, + X20, + X21, + X22, + X23, + X24, + X25, + X26, + X27, + X28, + // Frame pointer + X29, + // Link register + X30, + // Stack pointer (special, shares encoding with XZR in some contexts) + SP, +} + +impl Reg { + /// Get 64-bit register name + pub fn name64(&self) -> &'static str { + match self { + Reg::X0 => "x0", + Reg::X1 => "x1", + Reg::X2 => "x2", + Reg::X3 => "x3", + Reg::X4 => "x4", + Reg::X5 => "x5", + Reg::X6 => "x6", + Reg::X7 => "x7", + Reg::X8 => "x8", + Reg::X9 => "x9", + Reg::X10 => "x10", + Reg::X11 => "x11", + Reg::X12 => "x12", + Reg::X13 => "x13", + Reg::X14 => "x14", + Reg::X15 => "x15", + Reg::X16 => "x16", + Reg::X17 => "x17", + Reg::X18 => "x18", + Reg::X19 => "x19", + Reg::X20 => "x20", + Reg::X21 => "x21", + Reg::X22 => "x22", + Reg::X23 => "x23", + Reg::X24 => "x24", + Reg::X25 => "x25", + Reg::X26 => "x26", + Reg::X27 => "x27", + Reg::X28 => "x28", + Reg::X29 => "x29", + Reg::X30 => "x30", + Reg::SP => "sp", + } + } + + /// Get 32-bit register name + pub fn name32(&self) -> &'static str { + match self { + Reg::X0 => "w0", + Reg::X1 => "w1", + Reg::X2 => "w2", + Reg::X3 => "w3", + Reg::X4 => "w4", + Reg::X5 => "w5", + Reg::X6 => "w6", + Reg::X7 => "w7", + Reg::X8 => "w8", + Reg::X9 => "w9", + Reg::X10 => "w10", + Reg::X11 => "w11", + Reg::X12 => "w12", + Reg::X13 => "w13", + Reg::X14 => "w14", + Reg::X15 => "w15", + Reg::X16 => "w16", + Reg::X17 => "w17", + Reg::X18 => "w18", + Reg::X19 => "w19", + Reg::X20 => "w20", + Reg::X21 => "w21", + Reg::X22 => "w22", + Reg::X23 => "w23", + Reg::X24 => "w24", + Reg::X25 => "w25", + Reg::X26 => "w26", + Reg::X27 => "w27", + Reg::X28 => "w28", + Reg::X29 => "w29", + Reg::X30 => "w30", + Reg::SP => "sp", // SP doesn't have a 32-bit form in normal use + } + } + + /// Get register name for a given bit size + pub fn name_for_size(&self, bits: u32) -> &'static str { + match bits { + 8 | 16 | 32 => self.name32(), + _ => self.name64(), + } + } + + /// Is this a callee-saved register? 
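+    /// (AAPCS64 note, added for clarity: x19-x28 must be preserved across
+    /// calls; x29/x30 are assumed to be saved by the prologue, not tracked here.)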
+ pub fn is_callee_saved(&self) -> bool { + matches!( + self, + Reg::X19 + | Reg::X20 + | Reg::X21 + | Reg::X22 + | Reg::X23 + | Reg::X24 + | Reg::X25 + | Reg::X26 + | Reg::X27 + | Reg::X28 + ) + } + + /// Argument registers in order (AAPCS64) + pub fn arg_regs() -> &'static [Reg] { + &[ + Reg::X0, + Reg::X1, + Reg::X2, + Reg::X3, + Reg::X4, + Reg::X5, + Reg::X6, + Reg::X7, + ] + } + + /// All allocatable registers + /// Excludes: x8 (indirect result), x16/x17 (linker scratch), + /// x18 (platform), x29 (fp), x30 (lr), sp + pub fn allocatable() -> &'static [Reg] { + &[ + Reg::X0, + Reg::X1, + Reg::X2, + Reg::X3, + Reg::X4, + Reg::X5, + Reg::X6, + Reg::X7, + // Skip x8 (indirect result register for large struct returns per AAPCS64) + Reg::X9, + Reg::X10, + Reg::X11, + Reg::X12, + Reg::X13, + Reg::X14, + Reg::X15, + // Skip x16, x17 (linker scratch) + // Skip x18 (platform reserved) + Reg::X19, + Reg::X20, + Reg::X21, + Reg::X22, + Reg::X23, + Reg::X24, + Reg::X25, + Reg::X26, + Reg::X27, + Reg::X28, + // Skip x29 (fp), x30 (lr) + ] + } + + /// Scratch registers for codegen (not allocatable, used for temporaries) + /// x16 (IP0) and x17 (IP1) are linker scratch registers + pub fn scratch_regs() -> (Reg, Reg) { + (Reg::X16, Reg::X17) + } + + /// Frame pointer register + pub fn fp() -> Reg { + Reg::X29 + } + + /// Link register + pub fn lr() -> Reg { + Reg::X30 + } + + /// Stack pointer register + pub fn sp() -> Reg { + Reg::SP + } + + /// Platform reserved register (x18) + /// Should not be used; this is only for documentation and completeness + pub fn platform_reserved() -> Reg { + Reg::X18 + } +} + +// ============================================================================ +// AArch64 Floating-Point Register Definitions +// ============================================================================ + +/// AArch64 SIMD/FP registers (V0-V31, accessed as D0-D31 for double, S0-S31 for float) +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum VReg { + V0, + V1, + V2, + V3, + V4, + V5, + V6, + V7, + V8, + V9, + V10, + V11, + V12, + V13, + V14, + V15, + V16, + V17, + V18, + V19, + V20, + V21, + V22, + V23, + V24, + V25, + V26, + V27, + V28, + V29, + V30, + V31, +} + +impl VReg { + /// Get 64-bit (double) register name + pub fn name_d(&self) -> &'static str { + match self { + VReg::V0 => "d0", + VReg::V1 => "d1", + VReg::V2 => "d2", + VReg::V3 => "d3", + VReg::V4 => "d4", + VReg::V5 => "d5", + VReg::V6 => "d6", + VReg::V7 => "d7", + VReg::V8 => "d8", + VReg::V9 => "d9", + VReg::V10 => "d10", + VReg::V11 => "d11", + VReg::V12 => "d12", + VReg::V13 => "d13", + VReg::V14 => "d14", + VReg::V15 => "d15", + VReg::V16 => "d16", + VReg::V17 => "d17", + VReg::V18 => "d18", + VReg::V19 => "d19", + VReg::V20 => "d20", + VReg::V21 => "d21", + VReg::V22 => "d22", + VReg::V23 => "d23", + VReg::V24 => "d24", + VReg::V25 => "d25", + VReg::V26 => "d26", + VReg::V27 => "d27", + VReg::V28 => "d28", + VReg::V29 => "d29", + VReg::V30 => "d30", + VReg::V31 => "d31", + } + } + + /// Get 32-bit (float) register name + pub fn name_s(&self) -> &'static str { + match self { + VReg::V0 => "s0", + VReg::V1 => "s1", + VReg::V2 => "s2", + VReg::V3 => "s3", + VReg::V4 => "s4", + VReg::V5 => "s5", + VReg::V6 => "s6", + VReg::V7 => "s7", + VReg::V8 => "s8", + VReg::V9 => "s9", + VReg::V10 => "s10", + VReg::V11 => "s11", + VReg::V12 => "s12", + VReg::V13 => "s13", + VReg::V14 => "s14", + VReg::V15 => "s15", + VReg::V16 => "s16", + VReg::V17 => "s17", + VReg::V18 => "s18", + VReg::V19 => "s19", + VReg::V20 => "s20", 
+            VReg::V21 => "s21",
+            VReg::V22 => "s22",
+            VReg::V23 => "s23",
+            VReg::V24 => "s24",
+            VReg::V25 => "s25",
+            VReg::V26 => "s26",
+            VReg::V27 => "s27",
+            VReg::V28 => "s28",
+            VReg::V29 => "s29",
+            VReg::V30 => "s30",
+            VReg::V31 => "s31",
+        }
+    }
+
+    /// Get register name for a given size (32 = float, 64 = double)
+    pub fn name_for_size(&self, bits: u32) -> &'static str {
+        if bits <= 32 {
+            self.name_s()
+        } else {
+            self.name_d()
+        }
+    }
+
+    /// Is this a callee-saved FP register?
+    /// AAPCS64: v8-v15 (lower 64 bits) are callee-saved
+    pub fn is_callee_saved(&self) -> bool {
+        matches!(
+            self,
+            VReg::V8
+                | VReg::V9
+                | VReg::V10
+                | VReg::V11
+                | VReg::V12
+                | VReg::V13
+                | VReg::V14
+                | VReg::V15
+        )
+    }
+
+    /// FP argument registers in order (AAPCS64)
+    pub fn arg_regs() -> &'static [VReg] {
+        &[
+            VReg::V0,
+            VReg::V1,
+            VReg::V2,
+            VReg::V3,
+            VReg::V4,
+            VReg::V5,
+            VReg::V6,
+            VReg::V7,
+        ]
+    }
+
+    /// All allocatable FP registers
+    /// Note: V16, V17, V18 are reserved as scratch registers for codegen
+    pub fn allocatable() -> &'static [VReg] {
+        &[
+            VReg::V0,
+            VReg::V1,
+            VReg::V2,
+            VReg::V3,
+            VReg::V4,
+            VReg::V5,
+            VReg::V6,
+            VReg::V7,
+            VReg::V8,
+            VReg::V9,
+            VReg::V10,
+            VReg::V11,
+            VReg::V12,
+            VReg::V13,
+            VReg::V14,
+            VReg::V15,
+            // V16, V17, V18 reserved for scratch
+            VReg::V19,
+            VReg::V20,
+            VReg::V21,
+            VReg::V22,
+            VReg::V23,
+            VReg::V24,
+            VReg::V25,
+            VReg::V26,
+            VReg::V27,
+            VReg::V28,
+            VReg::V29,
+            VReg::V30,
+            VReg::V31,
+        ]
+    }
+}
+
+// ============================================================================
+// Operand - Location of a value (register or memory)
+// ============================================================================
+
+/// Location of a value
+#[derive(Debug, Clone)]
+pub enum Loc {
+    /// In a register
+    Reg(Reg),
+    /// In a floating-point register
+    VReg(VReg),
+    /// On the stack at [sp + offset] (positive offset from sp after allocation)
+    Stack(i32),
+    /// Immediate constant
+    Imm(i64),
+    /// Floating-point immediate constant
+    FImm(f64),
+    /// Global symbol
+    Global(String),
+}
+
+// ============================================================================
+// Register Allocator (Linear Scan)
+// ============================================================================
+
+/// Live interval for a pseudo
+#[derive(Debug, Clone)]
+struct LiveInterval {
+    pseudo: PseudoId,
+    start: usize,
+    end: usize,
+}
+
+/// Simple linear scan register allocator for AArch64
+pub struct RegAlloc {
+    /// Mapping from pseudo to location
+    locations: HashMap<PseudoId, Loc>,
+    /// Free registers
+    free_regs: Vec<Reg>,
+    /// Free FP registers
+    free_fp_regs: Vec<VReg>,
+    /// Active intervals (sorted by end point)
+    active: Vec<(LiveInterval, Reg)>,
+    /// Active FP intervals (sorted by end point)
+    active_fp: Vec<(LiveInterval, VReg)>,
+    /// Next stack slot offset
+    stack_offset: i32,
+    /// Callee-saved registers that were used
+    used_callee_saved: Vec<Reg>,
+    /// Callee-saved FP registers that were used
+    used_callee_saved_fp: Vec<VReg>,
+    /// Set of pseudos that need FP registers (determined by analyzing the IR)
+    fp_pseudos: HashSet<PseudoId>,
+}
+
+impl RegAlloc {
+    pub fn new() -> Self {
+        Self {
+            locations: HashMap::new(),
+            free_regs: Reg::allocatable().to_vec(),
+            free_fp_regs: VReg::allocatable().to_vec(),
+            active: Vec::new(),
+            active_fp: Vec::new(),
+            stack_offset: 0,
+            used_callee_saved: Vec::new(),
+            used_callee_saved_fp: Vec::new(),
+            fp_pseudos: HashSet::new(),
+        }
+    }
+
+    /// Perform register allocation for a function
+    pub fn allocate(&mut self, func: &Function, types: &TypeTable) -> HashMap<PseudoId, Loc> {
+        self.reset_state();
+        self.identify_fp_pseudos(func, types);
+        self.allocate_arguments(func, types);
+
+        let intervals = self.compute_live_intervals(func);
+        self.run_linear_scan(func, types, intervals);
+
+        self.locations.clone()
+    }
+
+    /// Reset allocator state for a new function
+    fn reset_state(&mut self) {
+        self.locations.clear();
+        self.free_regs = Reg::allocatable().to_vec();
+        self.free_fp_regs = VReg::allocatable().to_vec();
+        self.active.clear();
+        self.active_fp.clear();
+        self.stack_offset = 0;
+        self.used_callee_saved.clear();
+        self.used_callee_saved_fp.clear();
+        self.fp_pseudos.clear();
+    }
+
+    /// Pre-allocate argument registers per AAPCS64
+    fn allocate_arguments(&mut self, func: &Function, types: &TypeTable) {
+        let int_arg_regs = Reg::arg_regs();
+        let fp_arg_regs = VReg::arg_regs();
+        let mut int_arg_idx = 0usize;
+        let mut fp_arg_idx = 0usize;
+
+        // Detect hidden return pointer for large struct returns
+        let sret_pseudo = func
+            .pseudos
+            .iter()
+            .find(|p| matches!(p.kind, PseudoKind::Arg(0)) && p.name.as_deref() == Some("__sret"));
+
+        // Allocate X8 for hidden return pointer if present
+        if let Some(sret) = sret_pseudo {
+            self.locations.insert(sret.id, Loc::Reg(Reg::X8));
+        }
+
+        let arg_idx_offset: u32 = if sret_pseudo.is_some() { 1 } else { 0 };
+
+        for (i, (_name, typ)) in func.params.iter().enumerate() {
+            for pseudo in &func.pseudos {
+                if let PseudoKind::Arg(arg_idx) = pseudo.kind {
+                    if arg_idx == (i as u32) + arg_idx_offset {
+                        let is_fp = types.is_float(*typ);
+                        if is_fp {
+                            if fp_arg_idx < fp_arg_regs.len() {
+                                self.locations
+                                    .insert(pseudo.id, Loc::VReg(fp_arg_regs[fp_arg_idx]));
+                                self.free_fp_regs.retain(|&r| r != fp_arg_regs[fp_arg_idx]);
+                                self.fp_pseudos.insert(pseudo.id);
+                            } else {
+                                let offset = 16 + (fp_arg_idx - fp_arg_regs.len()) as i32 * 8;
+                                self.locations.insert(pseudo.id, Loc::Stack(offset));
+                                self.fp_pseudos.insert(pseudo.id);
+                            }
+                            fp_arg_idx += 1;
+                        } else {
+                            if int_arg_idx < int_arg_regs.len() {
+                                self.locations
+                                    .insert(pseudo.id, Loc::Reg(int_arg_regs[int_arg_idx]));
+                                self.free_regs.retain(|&r| r != int_arg_regs[int_arg_idx]);
+                            } else {
+                                let offset = 16 + (int_arg_idx - int_arg_regs.len()) as i32 * 8;
+                                self.locations.insert(pseudo.id, Loc::Stack(offset));
+                            }
+                            int_arg_idx += 1;
+                        }
+                        break;
+                    }
+                }
+            }
+        }
+    }
+
+    /// Run the linear scan allocation algorithm
+    fn run_linear_scan(
+        &mut self,
+        func: &Function,
+        types: &TypeTable,
+        intervals: Vec<LiveInterval>,
+    ) {
+        for interval in intervals {
+            self.expire_old_intervals(interval.start);
+
+            if self.locations.contains_key(&interval.pseudo) {
+                continue;
+            }
+
+            // Handle constants and symbols
+            if let Some(pseudo) = func.pseudos.iter().find(|p| p.id == interval.pseudo) {
+                match &pseudo.kind {
+                    PseudoKind::Val(v) => {
+                        self.locations.insert(interval.pseudo, Loc::Imm(*v));
+                        continue;
+                    }
+                    PseudoKind::FVal(v) => {
+                        self.locations.insert(interval.pseudo, Loc::FImm(*v));
+                        self.fp_pseudos.insert(interval.pseudo);
+                        continue;
+                    }
+                    PseudoKind::Sym(name) => {
+                        if let Some(local) = func.locals.get(name) {
+                            let size = (types.size_bits(local.typ) / 8) as i32;
+                            let size = size.max(8);
+                            let aligned_size = (size + 7) & !7;
+                            self.stack_offset += aligned_size;
+                            self.locations
+                                .insert(interval.pseudo, Loc::Stack(-self.stack_offset));
+                            if types.is_float(local.typ) {
+                                self.fp_pseudos.insert(interval.pseudo);
+                            }
+                        } else {
+                            self.locations
+                                .insert(interval.pseudo, Loc::Global(name.clone()));
+                        }
+                        continue;
+                    }
+                    _ => {}
+                }
+            }
+
+            // Allocate register based on type
+            let needs_fp = self.fp_pseudos.contains(&interval.pseudo);
+
+            if needs_fp {
+                if let Some(reg) = self.free_fp_regs.pop() {
+                    if reg.is_callee_saved() && !self.used_callee_saved_fp.contains(&reg) {
+                        self.used_callee_saved_fp.push(reg);
+                    }
+                    self.locations.insert(interval.pseudo, Loc::VReg(reg));
+                    self.active_fp.push((interval.clone(), reg));
+                    self.active_fp.sort_by_key(|(i, _)| i.end);
+                } else {
+                    self.stack_offset += 8;
+                    self.locations
+                        .insert(interval.pseudo, Loc::Stack(-self.stack_offset));
+                }
+            } else if let Some(reg) = self.free_regs.pop() {
+                if reg.is_callee_saved() && !self.used_callee_saved.contains(&reg) {
+                    self.used_callee_saved.push(reg);
+                }
+                self.locations.insert(interval.pseudo, Loc::Reg(reg));
+                self.active.push((interval.clone(), reg));
+                self.active.sort_by_key(|(i, _)| i.end);
+            } else {
+                self.stack_offset += 8;
+                self.locations
+                    .insert(interval.pseudo, Loc::Stack(-self.stack_offset));
+            }
+        }
+    }
+
+    /// Identify which pseudos need FP registers by scanning the IR
+    fn identify_fp_pseudos(&mut self, func: &Function, types: &TypeTable) {
+        for block in &func.blocks {
+            for insn in &block.insns {
+                let is_fp_op = matches!(
+                    insn.op,
+                    Opcode::FAdd
+                        | Opcode::FSub
+                        | Opcode::FMul
+                        | Opcode::FDiv
+                        | Opcode::FNeg
+                        | Opcode::FCvtF
+                        | Opcode::UCvtF
+                        | Opcode::SCvtF
+                );
+
+                if is_fp_op {
+                    if let Some(target) = insn.target {
+                        self.fp_pseudos.insert(target);
+                    }
+                }
+
+                if let Some(typ) = insn.typ {
+                    if types.is_float(typ)
+                        && !matches!(
+                            insn.op,
+                            Opcode::FCmpOEq
+                                | Opcode::FCmpONe
+                                | Opcode::FCmpOLt
+                                | Opcode::FCmpOLe
+                                | Opcode::FCmpOGt
+                                | Opcode::FCmpOGe
+                        )
+                    {
+                        if let Some(target) = insn.target {
+                            self.fp_pseudos.insert(target);
+                        }
+                    }
+                }
+            }
+        }
+
+        // Also mark pseudos that are FVal constants
+        for pseudo in &func.pseudos {
+            if matches!(pseudo.kind, PseudoKind::FVal(_)) {
+                self.fp_pseudos.insert(pseudo.id);
+            }
+        }
+    }
+
+    fn expire_old_intervals(&mut self, point: usize) {
+        let mut to_remove = Vec::new();
+        for (i, (interval, reg)) in self.active.iter().enumerate() {
+            if interval.end < point {
+                self.free_regs.push(*reg);
+                to_remove.push(i);
+            }
+        }
+        for i in to_remove.into_iter().rev() {
+            self.active.remove(i);
+        }
+
+        let mut to_remove_fp = Vec::new();
+        for (i, (interval, reg)) in self.active_fp.iter().enumerate() {
+            if interval.end < point {
+                self.free_fp_regs.push(*reg);
+                to_remove_fp.push(i);
+            }
+        }
+        for i in to_remove_fp.into_iter().rev() {
+            self.active_fp.remove(i);
+        }
+    }
+
+    fn compute_live_intervals(&self, func: &Function) -> Vec<LiveInterval> {
+        struct IntervalInfo {
+            pseudo: PseudoId,
+            first_def: usize,
+            last_def: usize,
+            last_use: usize,
+        }
+
+        let mut intervals: HashMap<PseudoId, IntervalInfo> = HashMap::new();
+        let mut pos = 0usize;
+
+        let mut block_start_pos: HashMap<BasicBlockId, usize> = HashMap::new();
+        let mut block_end_pos: HashMap<BasicBlockId, usize> = HashMap::new();
+        let mut temp_pos = 0usize;
+        for block in &func.blocks {
+            block_start_pos.insert(block.id, temp_pos);
+            temp_pos += block.insns.len();
+            block_end_pos.insert(block.id, temp_pos.saturating_sub(1));
+        }
+
+        let mut phi_sources: Vec<(BasicBlockId, PseudoId)> = Vec::new();
+        let mut phi_targets: Vec<(BasicBlockId, PseudoId)> = Vec::new();
+
+        for block in &func.blocks {
+            for insn in &block.insns {
+                if let Some(target) = insn.target {
+                    intervals
+                        .entry(target)
+                        .and_modify(|info| {
+                            info.first_def = info.first_def.min(pos);
+                            info.last_def = info.last_def.max(pos);
+                        })
+                        .or_insert(IntervalInfo {
+                            pseudo: target,
+                            first_def: pos,
+                            last_def: pos,
+                            last_use: pos,
+                        
}); + } + + for &src in &insn.src { + if let Some(info) = intervals.get_mut(&src) { + info.last_use = info.last_use.max(pos); + } else { + intervals.insert( + src, + IntervalInfo { + pseudo: src, + first_def: pos, + last_def: pos, + last_use: pos, + }, + ); + } + } + + for (src_bb, pseudo) in &insn.phi_list { + phi_sources.push((*src_bb, *pseudo)); + if let Some(target) = insn.target { + phi_targets.push((*src_bb, target)); + } + } + + pos += 1; + } + } + + // Extend phi source intervals + for (src_bb, pseudo) in phi_sources { + if let Some(&end_pos) = block_end_pos.get(&src_bb) { + if let Some(info) = intervals.get_mut(&pseudo) { + info.last_use = info.last_use.max(end_pos); + info.last_def = info.last_def.max(end_pos); + } else { + intervals.insert( + pseudo, + IntervalInfo { + pseudo, + first_def: end_pos, + last_def: end_pos, + last_use: end_pos, + }, + ); + } + } + } + + // Extend phi target intervals + for (src_bb, target) in phi_targets { + if let Some(&end_pos) = block_end_pos.get(&src_bb) { + if let Some(info) = intervals.get_mut(&target) { + info.last_def = info.last_def.max(end_pos); + } else { + intervals.insert( + target, + IntervalInfo { + pseudo: target, + first_def: end_pos, + last_def: end_pos, + last_use: end_pos, + }, + ); + } + } + } + + // Handle loop back edges + let mut loop_back_edges: Vec<(BasicBlockId, BasicBlockId, usize)> = Vec::new(); + for block in &func.blocks { + if let Some(last_insn) = block.insns.last() { + let mut targets = Vec::new(); + if let Some(target) = last_insn.bb_true { + targets.push(target); + } + if let Some(target) = last_insn.bb_false { + targets.push(target); + } + + let from_start = block_start_pos.get(&block.id).copied().unwrap_or(0); + for target_bb in targets { + let target_start = block_start_pos.get(&target_bb).copied().unwrap_or(0); + if target_start < from_start { + let from_end = block_end_pos.get(&block.id).copied().unwrap_or(0); + loop_back_edges.push((block.id, target_bb, from_end)); + } + } + } + } + + // Extend lifetimes for loop variables + for (_from_bb, to_bb, back_edge_pos) in &loop_back_edges { + let loop_start = block_start_pos.get(to_bb).copied().unwrap_or(0); + + for info in intervals.values_mut() { + if info.first_def < loop_start + && info.last_use >= loop_start + && info.last_use <= *back_edge_pos + { + info.last_use = info.last_use.max(*back_edge_pos); + } + } + } + + let max_pos = pos.saturating_sub(1); + + let mut result: Vec<_> = intervals + .into_values() + .map(|info| { + let end = if info.last_def > info.last_use { + max_pos + } else { + info.last_def.max(info.last_use) + }; + LiveInterval { + pseudo: info.pseudo, + start: info.first_def, + end, + } + }) + .collect(); + result.sort_by_key(|i| i.start); + result + } + + /// Get stack size needed (aligned to 16 bytes) + pub fn stack_size(&self) -> i32 { + (self.stack_offset + 15) & !15 + } + + /// Get callee-saved registers that need to be preserved + pub fn callee_saved_used(&self) -> &[Reg] { + &self.used_callee_saved + } +} + +impl Default for RegAlloc { + fn default() -> Self { + Self::new() + } +} diff --git a/cc/arch/x86_64/codegen.rs b/cc/arch/x86_64/codegen.rs index e5a35bcc..1fd36a7c 100644 --- a/cc/arch/x86_64/codegen.rs +++ b/cc/arch/x86_64/codegen.rs @@ -19,824 +19,13 @@ use crate::arch::lir::{Directive, FpSize, Label, OperandSize, Symbol}; use crate::arch::x86_64::lir::{ CallTarget, GpOperand, IntCC, MemAddr, ShiftCount, X86Inst, XmmOperand, }; +use crate::arch::x86_64::regalloc::{Loc, Reg, RegAlloc, XmmReg}; use 
crate::arch::DEFAULT_LIR_BUFFER_CAPACITY; use crate::ir::{Function, Initializer, Instruction, Module, Opcode, Pseudo, PseudoId, PseudoKind}; use crate::target::Target; use crate::types::{TypeId, TypeModifiers, TypeTable}; use std::collections::HashMap; -// ============================================================================ -// x86-64 Register Definitions -// ============================================================================ - -/// x86-64 physical registers -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub enum Reg { - // 64-bit general purpose registers - Rax, - Rbx, - Rcx, - Rdx, - Rsi, - Rdi, - Rbp, - Rsp, - R8, - R9, - R10, - R11, - R12, - R13, - R14, - R15, -} - -impl Reg { - /// Get AT&T syntax name for 64-bit register - pub fn name64(&self) -> &'static str { - match self { - Reg::Rax => "%rax", - Reg::Rbx => "%rbx", - Reg::Rcx => "%rcx", - Reg::Rdx => "%rdx", - Reg::Rsi => "%rsi", - Reg::Rdi => "%rdi", - Reg::Rbp => "%rbp", - Reg::Rsp => "%rsp", - Reg::R8 => "%r8", - Reg::R9 => "%r9", - Reg::R10 => "%r10", - Reg::R11 => "%r11", - Reg::R12 => "%r12", - Reg::R13 => "%r13", - Reg::R14 => "%r14", - Reg::R15 => "%r15", - } - } - - /// Get AT&T syntax name for 32-bit register - pub fn name32(&self) -> &'static str { - match self { - Reg::Rax => "%eax", - Reg::Rbx => "%ebx", - Reg::Rcx => "%ecx", - Reg::Rdx => "%edx", - Reg::Rsi => "%esi", - Reg::Rdi => "%edi", - Reg::Rbp => "%ebp", - Reg::Rsp => "%esp", - Reg::R8 => "%r8d", - Reg::R9 => "%r9d", - Reg::R10 => "%r10d", - Reg::R11 => "%r11d", - Reg::R12 => "%r12d", - Reg::R13 => "%r13d", - Reg::R14 => "%r14d", - Reg::R15 => "%r15d", - } - } - - /// Get AT&T syntax name for 16-bit register - pub fn name16(&self) -> &'static str { - match self { - Reg::Rax => "%ax", - Reg::Rbx => "%bx", - Reg::Rcx => "%cx", - Reg::Rdx => "%dx", - Reg::Rsi => "%si", - Reg::Rdi => "%di", - Reg::Rbp => "%bp", - Reg::Rsp => "%sp", - Reg::R8 => "%r8w", - Reg::R9 => "%r9w", - Reg::R10 => "%r10w", - Reg::R11 => "%r11w", - Reg::R12 => "%r12w", - Reg::R13 => "%r13w", - Reg::R14 => "%r14w", - Reg::R15 => "%r15w", - } - } - - /// Get AT&T syntax name for 8-bit register (low byte) - pub fn name8(&self) -> &'static str { - match self { - Reg::Rax => "%al", - Reg::Rbx => "%bl", - Reg::Rcx => "%cl", - Reg::Rdx => "%dl", - Reg::Rsi => "%sil", - Reg::Rdi => "%dil", - Reg::Rbp => "%bpl", - Reg::Rsp => "%spl", - Reg::R8 => "%r8b", - Reg::R9 => "%r9b", - Reg::R10 => "%r10b", - Reg::R11 => "%r11b", - Reg::R12 => "%r12b", - Reg::R13 => "%r13b", - Reg::R14 => "%r14b", - Reg::R15 => "%r15b", - } - } - - /// Get register name for a given bit size - pub fn name_for_size(&self, bits: u32) -> &'static str { - match bits { - 8 => self.name8(), - 16 => self.name16(), - 32 => self.name32(), - _ => self.name64(), - } - } - - /// Is this a callee-saved register? 
- pub fn is_callee_saved(&self) -> bool { - matches!( - self, - Reg::Rbx | Reg::Rbp | Reg::R12 | Reg::R13 | Reg::R14 | Reg::R15 - ) - } - - /// Argument registers in order (System V AMD64 ABI) - pub fn arg_regs() -> &'static [Reg] { - &[Reg::Rdi, Reg::Rsi, Reg::Rdx, Reg::Rcx, Reg::R8, Reg::R9] - } - - /// All allocatable registers (excluding RSP, RBP, and R10) - /// R10 is reserved as scratch for division instructions - pub fn allocatable() -> &'static [Reg] { - &[ - Reg::Rax, - Reg::Rbx, - Reg::Rcx, - Reg::Rdx, - Reg::Rsi, - Reg::Rdi, - Reg::R8, - Reg::R9, - // R10 is reserved for division scratch - Reg::R11, - Reg::R12, - Reg::R13, - Reg::R14, - Reg::R15, - ] - } - - /// Stack pointer register - pub fn sp() -> Reg { - Reg::Rsp - } - - /// Base/frame pointer register - pub fn bp() -> Reg { - Reg::Rbp - } -} - -// ============================================================================ -// XMM Register Definitions (SSE/FP) -// ============================================================================ - -/// x86-64 XMM registers for floating-point operations -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub enum XmmReg { - Xmm0, - Xmm1, - Xmm2, - Xmm3, - Xmm4, - Xmm5, - Xmm6, - Xmm7, - Xmm8, - Xmm9, - Xmm10, - Xmm11, - Xmm12, - Xmm13, - Xmm14, - Xmm15, -} - -impl XmmReg { - /// Get AT&T syntax name for XMM register - pub fn name(&self) -> &'static str { - match self { - XmmReg::Xmm0 => "%xmm0", - XmmReg::Xmm1 => "%xmm1", - XmmReg::Xmm2 => "%xmm2", - XmmReg::Xmm3 => "%xmm3", - XmmReg::Xmm4 => "%xmm4", - XmmReg::Xmm5 => "%xmm5", - XmmReg::Xmm6 => "%xmm6", - XmmReg::Xmm7 => "%xmm7", - XmmReg::Xmm8 => "%xmm8", - XmmReg::Xmm9 => "%xmm9", - XmmReg::Xmm10 => "%xmm10", - XmmReg::Xmm11 => "%xmm11", - XmmReg::Xmm12 => "%xmm12", - XmmReg::Xmm13 => "%xmm13", - XmmReg::Xmm14 => "%xmm14", - XmmReg::Xmm15 => "%xmm15", - } - } - - /// Floating-point argument registers (System V AMD64 ABI) - pub fn arg_regs() -> &'static [XmmReg] { - &[ - XmmReg::Xmm0, - XmmReg::Xmm1, - XmmReg::Xmm2, - XmmReg::Xmm3, - XmmReg::Xmm4, - XmmReg::Xmm5, - XmmReg::Xmm6, - XmmReg::Xmm7, - ] - } - - /// All allocatable XMM registers - /// XMM0-XMM7 are caller-saved (scratch), XMM8-XMM15 are callee-saved on Windows but not on System V - /// XMM14 and XMM15 are reserved as scratch registers for codegen operations - pub fn allocatable() -> &'static [XmmReg] { - &[ - XmmReg::Xmm0, - XmmReg::Xmm1, - XmmReg::Xmm2, - XmmReg::Xmm3, - XmmReg::Xmm4, - XmmReg::Xmm5, - XmmReg::Xmm6, - XmmReg::Xmm7, - XmmReg::Xmm8, - XmmReg::Xmm9, - XmmReg::Xmm10, - XmmReg::Xmm11, - XmmReg::Xmm12, - XmmReg::Xmm13, - // XMM14 and XMM15 are reserved for scratch use in codegen - ] - } -} - -// ============================================================================ -// Operand - Location of a value (register or memory) -// ============================================================================ - -/// Location of a value -#[derive(Debug, Clone)] -pub enum Loc { - /// In a general-purpose register - Reg(Reg), - /// In an XMM register (floating-point) - Xmm(XmmReg), - /// On the stack at [rbp - offset] - Stack(i32), - /// Immediate integer constant - Imm(i64), - /// Immediate float constant (value, size in bits) - FImm(f64, u32), - /// Global symbol - Global(String), -} - -// ============================================================================ -// Register Allocator (Linear Scan) -// ============================================================================ - -/// Live interval for a pseudo -#[derive(Debug, Clone)] -struct LiveInterval 
{
-    pseudo: PseudoId,
-    start: usize,
-    end: usize,
-}
-
-/// Simple linear scan register allocator for x86-64
-pub struct RegAlloc {
-    /// Mapping from pseudo to location
-    locations: HashMap<PseudoId, Loc>,
-    /// Free general-purpose registers
-    free_regs: Vec<Reg>,
-    /// Free XMM registers (for floating-point)
-    free_xmm_regs: Vec<XmmReg>,
-    /// Active integer register intervals (sorted by end point)
-    active: Vec<(LiveInterval, Reg)>,
-    /// Active XMM register intervals (sorted by end point)
-    active_xmm: Vec<(LiveInterval, XmmReg)>,
-    /// Next stack slot offset
-    stack_offset: i32,
-    /// Callee-saved registers that were used
-    used_callee_saved: Vec<Reg>,
-    /// Track which pseudos need FP registers (based on type)
-    fp_pseudos: std::collections::HashSet<PseudoId>,
-}
-
-impl RegAlloc {
-    pub fn new() -> Self {
-        Self {
-            locations: HashMap::new(),
-            free_regs: Reg::allocatable().to_vec(),
-            free_xmm_regs: XmmReg::allocatable().to_vec(),
-            active: Vec::new(),
-            active_xmm: Vec::new(),
-            stack_offset: 0,
-            used_callee_saved: Vec::new(),
-            fp_pseudos: std::collections::HashSet::new(),
-        }
-    }
-
-    /// Perform register allocation for a function
-    pub fn allocate(&mut self, func: &Function, types: &TypeTable) -> HashMap<PseudoId, Loc> {
-        // Reset state
-        self.locations.clear();
-        self.free_regs = Reg::allocatable().to_vec();
-        self.free_xmm_regs = XmmReg::allocatable().to_vec();
-        self.active.clear();
-        self.active_xmm.clear();
-        self.stack_offset = 0;
-        self.used_callee_saved.clear();
-        self.fp_pseudos.clear();
-
-        // Scan instructions to identify which pseudos need FP registers
-        self.identify_fp_pseudos(func, types);
-
-        // Pre-allocate argument registers
-        // System V AMD64 ABI: integer args in RDI, RSI, RDX, RCX, R8, R9
-        // FP args in XMM0-XMM7
-        let int_arg_regs = Reg::arg_regs();
-        let fp_arg_regs = XmmReg::arg_regs();
-        let mut int_arg_idx = 0;
-        let mut fp_arg_idx = 0;
-
-        // Detect if there's a hidden return pointer (for functions returning large structs)
-        // The __sret pseudo has arg_idx=0 and shifts all other arg indices by 1
-        let sret_pseudo = func
-            .pseudos
-            .iter()
-            .find(|p| matches!(p.kind, PseudoKind::Arg(0)) && p.name.as_deref() == Some("__sret"));
-        let arg_idx_offset: u32 = if sret_pseudo.is_some() { 1 } else { 0 };
-
-        // If there's a hidden return pointer, allocate RDI for it
-        if let Some(sret) = sret_pseudo {
-            self.locations.insert(sret.id, Loc::Reg(int_arg_regs[0]));
-            self.free_regs.retain(|&r| r != int_arg_regs[0]);
-            int_arg_idx += 1;
-        }
-
-        for (i, (_name, typ)) in func.params.iter().enumerate() {
-            // Find the pseudo for this argument
-            for pseudo in &func.pseudos {
-                if let PseudoKind::Arg(arg_idx) = pseudo.kind {
-                    if arg_idx == (i as u32) + arg_idx_offset {
-                        let is_fp = types.is_float(*typ);
-                        if is_fp {
-                            // FP argument
-                            if fp_arg_idx < fp_arg_regs.len() {
-                                self.locations
-                                    .insert(pseudo.id, Loc::Xmm(fp_arg_regs[fp_arg_idx]));
-                                self.free_xmm_regs.retain(|&r| r != fp_arg_regs[fp_arg_idx]);
-                                self.fp_pseudos.insert(pseudo.id);
-                            } else {
-                                // FP arg on stack
-                                let offset =
-                                    16 + (i - int_arg_regs.len() - fp_arg_regs.len()) as i32 * 8;
-                                self.locations.insert(pseudo.id, Loc::Stack(-offset));
-                            }
-                            fp_arg_idx += 1;
-                        } else {
-                            // Integer argument
-                            if int_arg_idx < int_arg_regs.len() {
-                                self.locations
-                                    .insert(pseudo.id, Loc::Reg(int_arg_regs[int_arg_idx]));
-                                self.free_regs.retain(|&r| r != int_arg_regs[int_arg_idx]);
-                            } else {
-                                // Integer arg on stack
-                                let offset = 16 + (i - int_arg_regs.len()) as i32 * 8;
-                                self.locations.insert(pseudo.id, Loc::Stack(-offset));
-                            }
-                            int_arg_idx += 1;
-                        }
-                        break;
-                    }
-                }
-            }
-        }
-
-        // Compute live intervals (simplified: just walk instructions)
-        let intervals = self.compute_live_intervals(func);
-
-        // Find all call positions - arguments in caller-saved registers that live across
-        // calls need to be spilled to stack
-        let mut call_positions: Vec<usize> = Vec::new();
-        let mut pos = 0usize;
-        for block in &func.blocks {
-            for insn in &block.insns {
-                if insn.op == Opcode::Call {
-                    call_positions.push(pos);
-                }
-                pos += 1;
-            }
-        }
-
-        // Check arguments in caller-saved registers - if their interval crosses a call,
-        // spill them to the stack instead
-        let int_arg_regs_set: Vec<Reg> = Reg::arg_regs().to_vec();
-        for interval in &intervals {
-            if let Some(Loc::Reg(reg)) = self.locations.get(&interval.pseudo) {
-                // Check if this is an argument register (caller-saved)
-                if int_arg_regs_set.contains(reg) {
-                    // Check if interval crosses any call
-                    let crosses_call = call_positions
-                        .iter()
-                        .any(|&call_pos| interval.start <= call_pos && call_pos < interval.end);
-                    if crosses_call {
-                        // Spill to stack - remove from register, allocate stack slot
-                        let reg_to_restore = *reg;
-                        self.stack_offset += 8;
-                        self.locations
-                            .insert(interval.pseudo, Loc::Stack(self.stack_offset));
-                        // Restore the register to free list for other uses
-                        self.free_regs.push(reg_to_restore);
-                    }
-                }
-            }
-        }
-
-        // Force alloca results to the stack. Alloca results typically have long live
-        // ranges (used across multiple loop iterations) and putting them in registers
-        // leads to clobbering issues when the codegen needs temp registers.
-        for block in &func.blocks {
-            for insn in &block.insns {
-                if insn.op == Opcode::Alloca {
-                    if let Some(target) = insn.target {
-                        self.stack_offset += 8;
-                        self.locations.insert(target, Loc::Stack(self.stack_offset));
-                    }
-                }
-            }
-        }
-
-        // Allocate registers using linear scan
-        for interval in intervals {
-            self.expire_old_intervals(interval.start);
-
-            // Check if already allocated (e.g., arguments)
-            if self.locations.contains_key(&interval.pseudo) {
-                continue;
-            }
-
-            // Check if this is a constant - constants don't need registers
-            if let Some(pseudo) = func.pseudos.iter().find(|p| p.id == interval.pseudo) {
-                match &pseudo.kind {
-                    PseudoKind::Val(v) => {
-                        self.locations.insert(interval.pseudo, Loc::Imm(*v));
-                        continue;
-                    }
-                    PseudoKind::FVal(v) => {
-                        // Float constants are stored as FImm
-                        // Find the size from the SetVal instruction that defines this constant
-                        let size = func
-                            .blocks
-                            .iter()
-                            .flat_map(|b| &b.insns)
-                            .find(|insn| {
-                                insn.op == crate::ir::Opcode::SetVal
-                                    && insn.target == Some(interval.pseudo)
-                            })
-                            .map(|insn| insn.size)
-                            .unwrap_or(64); // Default to double if not found
-                        self.locations.insert(interval.pseudo, Loc::FImm(*v, size));
-                        self.fp_pseudos.insert(interval.pseudo);
-                        continue;
-                    }
-                    PseudoKind::Sym(name) => {
-                        // Check if this is a local variable or a global symbol
-                        if let Some(local_var) = func.locals.get(name) {
-                            // Local variable - allocate stack space based on type size
-                            let size = (types.size_bits(local_var.typ) / 8) as i32;
-                            let size = std::cmp::max(size, 8); // Minimum 8 bytes
-                            // Align to 8 bytes
-                            let aligned_size = (size + 7) & !7;
-                            self.stack_offset += aligned_size;
-                            self.locations
-                                .insert(interval.pseudo, Loc::Stack(self.stack_offset));
-                            // Mark as FP if the type is float
-                            if types.is_float(local_var.typ) {
-                                self.fp_pseudos.insert(interval.pseudo);
-                            }
-                        } else {
-                            // Global symbol
-                            self.locations
-                                .insert(interval.pseudo, Loc::Global(name.clone()));
-                        }
-                        continue;
-                    }
-                    _ => {}
-                }
-            }
-
-            // Determine if this pseudo needs an FP register
-            let needs_fp = self.fp_pseudos.contains(&interval.pseudo);
-
-            if needs_fp {
-                // Try to allocate an XMM register
-                if let Some(xmm) = self.free_xmm_regs.pop() {
-                    self.locations.insert(interval.pseudo, Loc::Xmm(xmm));
-                    self.active_xmm.push((interval.clone(), xmm));
-                    self.active_xmm.sort_by_key(|(i, _)| i.end);
-                } else {
-                    // Spill to stack
-                    self.stack_offset += 8;
-                    self.locations
-                        .insert(interval.pseudo, Loc::Stack(self.stack_offset));
-                }
-            } else {
-                // Try to allocate a general-purpose register
-                if let Some(reg) = self.free_regs.pop() {
-                    if reg.is_callee_saved() && !self.used_callee_saved.contains(&reg) {
-                        self.used_callee_saved.push(reg);
-                    }
-                    self.locations.insert(interval.pseudo, Loc::Reg(reg));
-                    self.active.push((interval.clone(), reg));
-                    // Keep active sorted by end point
-                    self.active.sort_by_key(|(i, _)| i.end);
-                } else {
-                    // Spill to stack
-                    self.stack_offset += 8;
-                    self.locations
-                        .insert(interval.pseudo, Loc::Stack(self.stack_offset));
-                }
-            }
-        }
-
-        self.locations.clone()
-    }
-
-    /// Scan function to identify which pseudos need FP registers
-    fn identify_fp_pseudos(&mut self, func: &Function, types: &TypeTable) {
-        for block in &func.blocks {
-            for insn in &block.insns {
-                // Check instruction type - FP instructions produce FP results
-                let is_fp_op = matches!(
-                    insn.op,
-                    Opcode::FAdd
-                        | Opcode::FSub
-                        | Opcode::FMul
-                        | Opcode::FDiv
-                        | Opcode::FNeg
-                        | Opcode::FCmpOEq
-                        | Opcode::FCmpONe
-                        | Opcode::FCmpOLt
-                        | Opcode::FCmpOLe
-                        | Opcode::FCmpOGt
-                        | Opcode::FCmpOGe
-                        | Opcode::UCvtF
-                        | Opcode::SCvtF
-                        | Opcode::FCvtF
-                );
-
-                // Mark target pseudo as FP if this is an FP operation
-                // (except comparisons which produce int)
-                if is_fp_op
-                    && !matches!(
-                        insn.op,
-                        Opcode::FCmpOEq
-                            | Opcode::FCmpONe
-                            | Opcode::FCmpOLt
-                            | Opcode::FCmpOLe
-                            | Opcode::FCmpOGt
-                            | Opcode::FCmpOGe
-                    )
-                {
-                    if let Some(target) = insn.target {
-                        self.fp_pseudos.insert(target);
-                    }
-                }
-
-                // Also check the type if available
-                // (but exclude comparisons which always produce int regardless of operand type)
-                if let Some(typ) = insn.typ {
-                    if types.is_float(typ)
-                        && !matches!(
-                            insn.op,
-                            Opcode::FCmpOEq
-                                | Opcode::FCmpONe
-                                | Opcode::FCmpOLt
-                                | Opcode::FCmpOLe
-                                | Opcode::FCmpOGt
-                                | Opcode::FCmpOGe
-                        )
-                    {
-                        if let Some(target) = insn.target {
-                            self.fp_pseudos.insert(target);
-                        }
-                    }
-                }
-            }
-        }
-    }
-
-    fn expire_old_intervals(&mut self, point: usize) {
-        // Expire integer register intervals
-        let mut to_remove = Vec::new();
-        for (i, (interval, reg)) in self.active.iter().enumerate() {
-            if interval.end < point {
-                self.free_regs.push(*reg);
-                to_remove.push(i);
-            }
-        }
-        // Remove in reverse order to preserve indices
-        for i in to_remove.into_iter().rev() {
-            self.active.remove(i);
-        }
-
-        // Expire XMM register intervals
-        let mut to_remove_xmm = Vec::new();
-        for (i, (interval, xmm)) in self.active_xmm.iter().enumerate() {
-            if interval.end < point {
-                self.free_xmm_regs.push(*xmm);
-                to_remove_xmm.push(i);
-            }
-        }
-        for i in to_remove_xmm.into_iter().rev() {
-            self.active_xmm.remove(i);
-        }
-    }
-
-    fn compute_live_intervals(&self, func: &Function) -> Vec<LiveInterval> {
-        use crate::ir::BasicBlockId;
-
-        // Track earliest definition, latest definition, and latest use for each pseudo
-        struct IntervalInfo {
-            pseudo: PseudoId,
-            first_def: usize,
-            last_def: usize,
-            last_use: usize,
-        }
-
-        let mut intervals: HashMap<PseudoId, IntervalInfo> = HashMap::new();
-        let mut pos = 0usize;
-
-        // First pass: compute block end positions
- let mut block_end_pos: HashMap = HashMap::new(); - let mut temp_pos = 0usize; - for block in &func.blocks { - temp_pos += block.insns.len(); - block_end_pos.insert(block.id, temp_pos.saturating_sub(1)); - } - - // Collect phi sources with their source blocks for later processing - // Also track phi targets - they need intervals extended to cover all their copy definitions - let mut phi_sources: Vec<(BasicBlockId, PseudoId)> = Vec::new(); - let mut phi_targets: Vec<(BasicBlockId, PseudoId)> = Vec::new(); - - for block in &func.blocks { - for insn in &block.insns { - // Definition point - track both first and last definition - // This is important because phi elimination creates multiple definitions - // of the same pseudo (via Copy instructions in predecessor blocks) - if let Some(target) = insn.target { - intervals - .entry(target) - .and_modify(|info| { - info.first_def = info.first_def.min(pos); - info.last_def = info.last_def.max(pos); - }) - .or_insert(IntervalInfo { - pseudo: target, - first_def: pos, - last_def: pos, - last_use: pos, - }); - } - - // Use points - for &src in &insn.src { - if let Some(info) = intervals.get_mut(&src) { - info.last_use = info.last_use.max(pos); - } else { - intervals.insert( - src, - IntervalInfo { - pseudo: src, - first_def: pos, - last_def: pos, - last_use: pos, - }, - ); - } - } - - // Collect phi sources - they need to be live at the END of their source block - // Also collect phi targets - they have Copy definitions in each source block - for (src_bb, pseudo) in &insn.phi_list { - phi_sources.push((*src_bb, *pseudo)); - // The phi target (if present) is defined via Copy at the end of each source block - if let Some(target) = insn.target { - phi_targets.push((*src_bb, target)); - } - } - - pos += 1; - } - } - - // Process phi sources: extend their live interval to the end of their source block - // This is critical for loops where phi sources come from later blocks via back edges - // NOTE: We only extend last_use, NOT last_def, because phi sources are USED (not defined) - // at the copy point. Setting last_def incorrectly can confuse the interval calculation. 
- for (src_bb, pseudo) in phi_sources { - if let Some(&end_pos) = block_end_pos.get(&src_bb) { - if let Some(info) = intervals.get_mut(&pseudo) { - info.last_use = info.last_use.max(end_pos); - } else { - intervals.insert( - pseudo, - IntervalInfo { - pseudo, - first_def: end_pos, - last_def: end_pos, - last_use: end_pos, - }, - ); - } - } - } - - // Process phi targets: extend their live interval to cover all Copy definitions - // The phi target is defined via Copy at the end of each source block - // For loops, this means the target must be live until the last Copy (at the loop back edge) - for (src_bb, target) in phi_targets { - if let Some(&end_pos) = block_end_pos.get(&src_bb) { - if let Some(info) = intervals.get_mut(&target) { - info.last_def = info.last_def.max(end_pos); - } else { - intervals.insert( - target, - IntervalInfo { - pseudo: target, - first_def: end_pos, - last_def: end_pos, - last_use: end_pos, - }, - ); - } - } - } - - // Find the maximum position in the function - let max_pos = pos.saturating_sub(1); - - // Convert to LiveInterval - // The interval spans from first_def to max(last_def, last_use) - // This ensures that pseudos with multiple definitions (from phi copies) - // stay live across all their definitions - // - // IMPORTANT: For loop variables, if last_def > last_use (definition comes after use - // due to back edge), the variable must stay live until the end of the function - // because it will be used again in the next loop iteration. - let mut result: Vec<_> = intervals - .into_values() - .map(|info| { - let end = if info.last_def > info.last_use { - // Loop variable: definition after use means we wrap around - // Extend to end of function to ensure it stays live - max_pos - } else { - info.last_def.max(info.last_use) - }; - LiveInterval { - pseudo: info.pseudo, - start: info.first_def, - end, - } - }) - .collect(); - result.sort_by_key(|i| i.start); - result - } - - /// Get stack size needed (aligned to 16 bytes) - pub fn stack_size(&self) -> i32 { - // Align to 16 bytes - (self.stack_offset + 15) & !15 - } - - /// Get callee-saved registers that need to be preserved - pub fn callee_saved_used(&self) -> &[Reg] { - &self.used_callee_saved - } -} - -impl Default for RegAlloc { - fn default() -> Self { - Self::new() - } -} - // ============================================================================ // x86-64 Code Generator // ============================================================================ diff --git a/cc/arch/x86_64/lir.rs b/cc/arch/x86_64/lir.rs index 98afd458..d09ef7cd 100644 --- a/cc/arch/x86_64/lir.rs +++ b/cc/arch/x86_64/lir.rs @@ -13,7 +13,7 @@ // enabling peephole optimizations before final assembly emission. // -use super::codegen::{Reg, XmmReg}; +use super::regalloc::{Reg, XmmReg}; use crate::arch::lir::{Directive, EmitAsm, FpSize, Label, OperandSize, Symbol}; use crate::target::{Os, Target}; use std::fmt::{self, Write}; diff --git a/cc/arch/x86_64/mod.rs b/cc/arch/x86_64/mod.rs index feb19551..2a13e3d9 100644 --- a/cc/arch/x86_64/mod.rs +++ b/cc/arch/x86_64/mod.rs @@ -12,5 +12,6 @@ pub mod codegen; pub mod lir; pub mod macros; +pub mod regalloc; pub use macros::get_macros; diff --git a/cc/arch/x86_64/regalloc.rs b/cc/arch/x86_64/regalloc.rs new file mode 100644 index 00000000..1a3797a6 --- /dev/null +++ b/cc/arch/x86_64/regalloc.rs @@ -0,0 +1,785 @@ +// +// Copyright (c) 2024 Jeff Garzik +// +// This file is part of the posixutils-rs project covered under +// the MIT License. 
For the full license text, please see the LICENSE +// file in the root directory of this project. +// SPDX-License-Identifier: MIT +// +// x86-64 Register Allocator +// Linear scan register allocation for x86-64 +// + +use crate::ir::{Function, Opcode, PseudoId, PseudoKind}; +use crate::types::TypeTable; +use std::collections::{HashMap, HashSet}; + +// ============================================================================ +// x86-64 Register Definitions +// ============================================================================ + +/// x86-64 physical registers +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum Reg { + // 64-bit general purpose registers + Rax, + Rbx, + Rcx, + Rdx, + Rsi, + Rdi, + Rbp, + Rsp, + R8, + R9, + R10, + R11, + R12, + R13, + R14, + R15, +} + +impl Reg { + /// Get AT&T syntax name for 64-bit register + pub fn name64(&self) -> &'static str { + match self { + Reg::Rax => "%rax", + Reg::Rbx => "%rbx", + Reg::Rcx => "%rcx", + Reg::Rdx => "%rdx", + Reg::Rsi => "%rsi", + Reg::Rdi => "%rdi", + Reg::Rbp => "%rbp", + Reg::Rsp => "%rsp", + Reg::R8 => "%r8", + Reg::R9 => "%r9", + Reg::R10 => "%r10", + Reg::R11 => "%r11", + Reg::R12 => "%r12", + Reg::R13 => "%r13", + Reg::R14 => "%r14", + Reg::R15 => "%r15", + } + } + + /// Get AT&T syntax name for 32-bit register + pub fn name32(&self) -> &'static str { + match self { + Reg::Rax => "%eax", + Reg::Rbx => "%ebx", + Reg::Rcx => "%ecx", + Reg::Rdx => "%edx", + Reg::Rsi => "%esi", + Reg::Rdi => "%edi", + Reg::Rbp => "%ebp", + Reg::Rsp => "%esp", + Reg::R8 => "%r8d", + Reg::R9 => "%r9d", + Reg::R10 => "%r10d", + Reg::R11 => "%r11d", + Reg::R12 => "%r12d", + Reg::R13 => "%r13d", + Reg::R14 => "%r14d", + Reg::R15 => "%r15d", + } + } + + /// Get AT&T syntax name for 16-bit register + pub fn name16(&self) -> &'static str { + match self { + Reg::Rax => "%ax", + Reg::Rbx => "%bx", + Reg::Rcx => "%cx", + Reg::Rdx => "%dx", + Reg::Rsi => "%si", + Reg::Rdi => "%di", + Reg::Rbp => "%bp", + Reg::Rsp => "%sp", + Reg::R8 => "%r8w", + Reg::R9 => "%r9w", + Reg::R10 => "%r10w", + Reg::R11 => "%r11w", + Reg::R12 => "%r12w", + Reg::R13 => "%r13w", + Reg::R14 => "%r14w", + Reg::R15 => "%r15w", + } + } + + /// Get AT&T syntax name for 8-bit register (low byte) + pub fn name8(&self) -> &'static str { + match self { + Reg::Rax => "%al", + Reg::Rbx => "%bl", + Reg::Rcx => "%cl", + Reg::Rdx => "%dl", + Reg::Rsi => "%sil", + Reg::Rdi => "%dil", + Reg::Rbp => "%bpl", + Reg::Rsp => "%spl", + Reg::R8 => "%r8b", + Reg::R9 => "%r9b", + Reg::R10 => "%r10b", + Reg::R11 => "%r11b", + Reg::R12 => "%r12b", + Reg::R13 => "%r13b", + Reg::R14 => "%r14b", + Reg::R15 => "%r15b", + } + } + + /// Get register name for a given bit size + pub fn name_for_size(&self, bits: u32) -> &'static str { + match bits { + 8 => self.name8(), + 16 => self.name16(), + 32 => self.name32(), + _ => self.name64(), + } + } + + /// Is this a callee-saved register? 
+    pub fn is_callee_saved(&self) -> bool {
+        matches!(
+            self,
+            Reg::Rbx | Reg::Rbp | Reg::R12 | Reg::R13 | Reg::R14 | Reg::R15
+        )
+    }
+
+    /// Argument registers in order (System V AMD64 ABI)
+    pub fn arg_regs() -> &'static [Reg] {
+        &[Reg::Rdi, Reg::Rsi, Reg::Rdx, Reg::Rcx, Reg::R8, Reg::R9]
+    }
+
+    /// All allocatable registers (excluding RSP, RBP, and R10)
+    /// R10 is reserved as scratch for division instructions
+    pub fn allocatable() -> &'static [Reg] {
+        &[
+            Reg::Rax,
+            Reg::Rbx,
+            Reg::Rcx,
+            Reg::Rdx,
+            Reg::Rsi,
+            Reg::Rdi,
+            Reg::R8,
+            Reg::R9,
+            // R10 is reserved for division scratch
+            Reg::R11,
+            Reg::R12,
+            Reg::R13,
+            Reg::R14,
+            Reg::R15,
+        ]
+    }
+
+    /// Stack pointer register
+    pub fn sp() -> Reg {
+        Reg::Rsp
+    }
+
+    /// Base/frame pointer register
+    pub fn bp() -> Reg {
+        Reg::Rbp
+    }
+}
+
+// ============================================================================
+// XMM Register Definitions (SSE/FP)
+// ============================================================================
+
+/// x86-64 XMM registers for floating-point operations
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub enum XmmReg {
+    Xmm0,
+    Xmm1,
+    Xmm2,
+    Xmm3,
+    Xmm4,
+    Xmm5,
+    Xmm6,
+    Xmm7,
+    Xmm8,
+    Xmm9,
+    Xmm10,
+    Xmm11,
+    Xmm12,
+    Xmm13,
+    Xmm14,
+    Xmm15,
+}
+
+impl XmmReg {
+    /// Get AT&T syntax name for XMM register
+    pub fn name(&self) -> &'static str {
+        match self {
+            XmmReg::Xmm0 => "%xmm0",
+            XmmReg::Xmm1 => "%xmm1",
+            XmmReg::Xmm2 => "%xmm2",
+            XmmReg::Xmm3 => "%xmm3",
+            XmmReg::Xmm4 => "%xmm4",
+            XmmReg::Xmm5 => "%xmm5",
+            XmmReg::Xmm6 => "%xmm6",
+            XmmReg::Xmm7 => "%xmm7",
+            XmmReg::Xmm8 => "%xmm8",
+            XmmReg::Xmm9 => "%xmm9",
+            XmmReg::Xmm10 => "%xmm10",
+            XmmReg::Xmm11 => "%xmm11",
+            XmmReg::Xmm12 => "%xmm12",
+            XmmReg::Xmm13 => "%xmm13",
+            XmmReg::Xmm14 => "%xmm14",
+            XmmReg::Xmm15 => "%xmm15",
+        }
+    }
+
+    /// Floating-point argument registers (System V AMD64 ABI)
+    pub fn arg_regs() -> &'static [XmmReg] {
+        &[
+            XmmReg::Xmm0,
+            XmmReg::Xmm1,
+            XmmReg::Xmm2,
+            XmmReg::Xmm3,
+            XmmReg::Xmm4,
+            XmmReg::Xmm5,
+            XmmReg::Xmm6,
+            XmmReg::Xmm7,
+        ]
+    }
+
+    /// All allocatable XMM registers
+    /// All XMM registers are caller-saved (scratch) on System V; only Windows treats XMM6-XMM15 as callee-saved
+    /// XMM14 and XMM15 are reserved as scratch registers for codegen operations
+    pub fn allocatable() -> &'static [XmmReg] {
+        &[
+            XmmReg::Xmm0,
+            XmmReg::Xmm1,
+            XmmReg::Xmm2,
+            XmmReg::Xmm3,
+            XmmReg::Xmm4,
+            XmmReg::Xmm5,
+            XmmReg::Xmm6,
+            XmmReg::Xmm7,
+            XmmReg::Xmm8,
+            XmmReg::Xmm9,
+            XmmReg::Xmm10,
+            XmmReg::Xmm11,
+            XmmReg::Xmm12,
+            XmmReg::Xmm13,
+            // XMM14 and XMM15 are reserved for scratch use in codegen
+        ]
+    }
+}
+
+// ============================================================================
+// Operand - Location of a value (register or memory)
+// ============================================================================
+
+/// Location of a value
+#[derive(Debug, Clone)]
+pub enum Loc {
+    /// In a general-purpose register
+    Reg(Reg),
+    /// In an XMM register (floating-point)
+    Xmm(XmmReg),
+    /// On the stack at [rbp - offset]
+    Stack(i32),
+    /// Immediate integer constant
+    Imm(i64),
+    /// Immediate float constant (value, size in bits)
+    FImm(f64, u32),
+    /// Global symbol
+    Global(String),
+}
+
+// ============================================================================
+// Register Allocator (Linear Scan)
+// ============================================================================
+
+/// Live interval for a pseudo
+#[derive(Debug, Clone)]
+struct LiveInterval {
+    pseudo: PseudoId,
+    start: usize,
+    end: usize,
+}
+
+/// Simple linear scan register allocator for x86-64
+pub struct RegAlloc {
+    /// Mapping from pseudo to location
+    locations: HashMap<PseudoId, Loc>,
+    /// Free general-purpose registers
+    free_regs: Vec<Reg>,
+    /// Free XMM registers (for floating-point)
+    free_xmm_regs: Vec<XmmReg>,
+    /// Active integer register intervals (sorted by end point)
+    active: Vec<(LiveInterval, Reg)>,
+    /// Active XMM register intervals (sorted by end point)
+    active_xmm: Vec<(LiveInterval, XmmReg)>,
+    /// Next stack slot offset
+    stack_offset: i32,
+    /// Callee-saved registers that were used
+    used_callee_saved: Vec<Reg>,
+    /// Track which pseudos need FP registers (based on type)
+    fp_pseudos: HashSet<PseudoId>,
+}
+
+impl RegAlloc {
+    pub fn new() -> Self {
+        Self {
+            locations: HashMap::new(),
+            free_regs: Reg::allocatable().to_vec(),
+            free_xmm_regs: XmmReg::allocatable().to_vec(),
+            active: Vec::new(),
+            active_xmm: Vec::new(),
+            stack_offset: 0,
+            used_callee_saved: Vec::new(),
+            fp_pseudos: HashSet::new(),
+        }
+    }
+
+    /// Perform register allocation for a function
+    pub fn allocate(&mut self, func: &Function, types: &TypeTable) -> HashMap<PseudoId, Loc> {
+        self.reset_state();
+        self.identify_fp_pseudos(func, types);
+        self.allocate_arguments(func, types);
+
+        let intervals = self.compute_live_intervals(func);
+
+        self.spill_args_across_calls(func, &intervals);
+        self.allocate_alloca_to_stack(func);
+        self.run_linear_scan(func, types, intervals);
+
+        self.locations.clone()
+    }
+
+    /// Reset allocator state for a new function
+    fn reset_state(&mut self) {
+        self.locations.clear();
+        self.free_regs = Reg::allocatable().to_vec();
+        self.free_xmm_regs = XmmReg::allocatable().to_vec();
+        self.active.clear();
+        self.active_xmm.clear();
+        self.stack_offset = 0;
+        self.used_callee_saved.clear();
+        self.fp_pseudos.clear();
+    }
+
+    /// Pre-allocate argument registers per System V AMD64 ABI
+    fn allocate_arguments(&mut self, func: &Function, types: &TypeTable) {
+        let int_arg_regs = Reg::arg_regs();
+        let fp_arg_regs = XmmReg::arg_regs();
+        let mut int_arg_idx = 0;
+        let mut fp_arg_idx = 0;
+
+        // Detect hidden return pointer for large struct returns
+        let sret_pseudo = func
+            .pseudos
+            .iter()
+            .find(|p| matches!(p.kind, PseudoKind::Arg(0)) && p.name.as_deref() == Some("__sret"));
+        let arg_idx_offset: u32 = if sret_pseudo.is_some() { 1 } else { 0 };
+
+        // Allocate RDI for hidden return pointer if present
+        if let Some(sret) = sret_pseudo {
+            self.locations.insert(sret.id, Loc::Reg(int_arg_regs[0]));
+            self.free_regs.retain(|&r| r != int_arg_regs[0]);
+            int_arg_idx += 1;
+        }
+
+        for (i, (_name, typ)) in func.params.iter().enumerate() {
+            for pseudo in &func.pseudos {
+                if let PseudoKind::Arg(arg_idx) = pseudo.kind {
+                    if arg_idx == (i as u32) + arg_idx_offset {
+                        let is_fp = types.is_float(*typ);
+                        if is_fp {
+                            if fp_arg_idx < fp_arg_regs.len() {
+                                self.locations
+                                    .insert(pseudo.id, Loc::Xmm(fp_arg_regs[fp_arg_idx]));
+                                self.free_xmm_regs.retain(|&r| r != fp_arg_regs[fp_arg_idx]);
+                                self.fp_pseudos.insert(pseudo.id);
+                            } else {
+                                let offset =
+                                    16 + (i - int_arg_regs.len() - fp_arg_regs.len()) as i32 * 8;
+                                self.locations.insert(pseudo.id, Loc::Stack(-offset));
+                            }
+                            fp_arg_idx += 1;
+                        } else {
+                            if int_arg_idx < int_arg_regs.len() {
+                                self.locations
+                                    .insert(pseudo.id, Loc::Reg(int_arg_regs[int_arg_idx]));
+                                self.free_regs.retain(|&r| r != int_arg_regs[int_arg_idx]);
+                            } else {
+                                let offset = 16 + (i - int_arg_regs.len()) as i32 * 8;
+                                self.locations.insert(pseudo.id, Loc::Stack(-offset));
+                            }
+
int_arg_idx += 1; + } + break; + } + } + } + } + } + + /// Spill arguments in caller-saved registers if their interval crosses a call + fn spill_args_across_calls(&mut self, func: &Function, intervals: &[LiveInterval]) { + // Find all call positions + let mut call_positions: Vec = Vec::new(); + let mut pos = 0usize; + for block in &func.blocks { + for insn in &block.insns { + if insn.op == Opcode::Call { + call_positions.push(pos); + } + pos += 1; + } + } + + // Check arguments in caller-saved registers + let int_arg_regs_set: Vec = Reg::arg_regs().to_vec(); + for interval in intervals { + if let Some(Loc::Reg(reg)) = self.locations.get(&interval.pseudo) { + if int_arg_regs_set.contains(reg) { + let crosses_call = call_positions + .iter() + .any(|&call_pos| interval.start <= call_pos && call_pos < interval.end); + if crosses_call { + let reg_to_restore = *reg; + self.stack_offset += 8; + self.locations + .insert(interval.pseudo, Loc::Stack(self.stack_offset)); + self.free_regs.push(reg_to_restore); + } + } + } + } + } + + /// Force alloca results to stack to avoid clobbering issues + fn allocate_alloca_to_stack(&mut self, func: &Function) { + for block in &func.blocks { + for insn in &block.insns { + if insn.op == Opcode::Alloca { + if let Some(target) = insn.target { + self.stack_offset += 8; + self.locations.insert(target, Loc::Stack(self.stack_offset)); + } + } + } + } + } + + /// Run the linear scan allocation algorithm + fn run_linear_scan( + &mut self, + func: &Function, + types: &TypeTable, + intervals: Vec, + ) { + for interval in intervals { + self.expire_old_intervals(interval.start); + + if self.locations.contains_key(&interval.pseudo) { + continue; + } + + // Handle constants and symbols + if let Some(pseudo) = func.pseudos.iter().find(|p| p.id == interval.pseudo) { + match &pseudo.kind { + PseudoKind::Val(v) => { + self.locations.insert(interval.pseudo, Loc::Imm(*v)); + continue; + } + PseudoKind::FVal(v) => { + let size = func + .blocks + .iter() + .flat_map(|b| &b.insns) + .find(|insn| { + insn.op == Opcode::SetVal && insn.target == Some(interval.pseudo) + }) + .map(|insn| insn.size) + .unwrap_or(64); + self.locations.insert(interval.pseudo, Loc::FImm(*v, size)); + self.fp_pseudos.insert(interval.pseudo); + continue; + } + PseudoKind::Sym(name) => { + if let Some(local_var) = func.locals.get(name) { + let size = (types.size_bits(local_var.typ) / 8) as i32; + let size = size.max(8); + let aligned_size = (size + 7) & !7; + self.stack_offset += aligned_size; + self.locations + .insert(interval.pseudo, Loc::Stack(self.stack_offset)); + if types.is_float(local_var.typ) { + self.fp_pseudos.insert(interval.pseudo); + } + } else { + self.locations + .insert(interval.pseudo, Loc::Global(name.clone())); + } + continue; + } + _ => {} + } + } + + // Allocate register based on type + let needs_fp = self.fp_pseudos.contains(&interval.pseudo); + + if needs_fp { + if let Some(xmm) = self.free_xmm_regs.pop() { + self.locations.insert(interval.pseudo, Loc::Xmm(xmm)); + self.active_xmm.push((interval.clone(), xmm)); + self.active_xmm.sort_by_key(|(i, _)| i.end); + } else { + self.stack_offset += 8; + self.locations + .insert(interval.pseudo, Loc::Stack(self.stack_offset)); + } + } else if let Some(reg) = self.free_regs.pop() { + if reg.is_callee_saved() && !self.used_callee_saved.contains(®) { + self.used_callee_saved.push(reg); + } + self.locations.insert(interval.pseudo, Loc::Reg(reg)); + self.active.push((interval.clone(), reg)); + self.active.sort_by_key(|(i, _)| i.end); + } else { + 
self.stack_offset += 8; + self.locations + .insert(interval.pseudo, Loc::Stack(self.stack_offset)); + } + } + } + + /// Scan function to identify which pseudos need FP registers + fn identify_fp_pseudos(&mut self, func: &Function, types: &TypeTable) { + for block in &func.blocks { + for insn in &block.insns { + let is_fp_op = matches!( + insn.op, + Opcode::FAdd + | Opcode::FSub + | Opcode::FMul + | Opcode::FDiv + | Opcode::FNeg + | Opcode::FCmpOEq + | Opcode::FCmpONe + | Opcode::FCmpOLt + | Opcode::FCmpOLe + | Opcode::FCmpOGt + | Opcode::FCmpOGe + | Opcode::UCvtF + | Opcode::SCvtF + | Opcode::FCvtF + ); + + // Mark target as FP if this is an FP operation (except comparisons) + if is_fp_op + && !matches!( + insn.op, + Opcode::FCmpOEq + | Opcode::FCmpONe + | Opcode::FCmpOLt + | Opcode::FCmpOLe + | Opcode::FCmpOGt + | Opcode::FCmpOGe + ) + { + if let Some(target) = insn.target { + self.fp_pseudos.insert(target); + } + } + + // Also check the type if available + if let Some(typ) = insn.typ { + if types.is_float(typ) + && !matches!( + insn.op, + Opcode::FCmpOEq + | Opcode::FCmpONe + | Opcode::FCmpOLt + | Opcode::FCmpOLe + | Opcode::FCmpOGt + | Opcode::FCmpOGe + ) + { + if let Some(target) = insn.target { + self.fp_pseudos.insert(target); + } + } + } + } + } + } + + fn expire_old_intervals(&mut self, point: usize) { + // Expire integer register intervals + let mut to_remove = Vec::new(); + for (i, (interval, reg)) in self.active.iter().enumerate() { + if interval.end < point { + self.free_regs.push(*reg); + to_remove.push(i); + } + } + for i in to_remove.into_iter().rev() { + self.active.remove(i); + } + + // Expire XMM register intervals + let mut to_remove_xmm = Vec::new(); + for (i, (interval, xmm)) in self.active_xmm.iter().enumerate() { + if interval.end < point { + self.free_xmm_regs.push(*xmm); + to_remove_xmm.push(i); + } + } + for i in to_remove_xmm.into_iter().rev() { + self.active_xmm.remove(i); + } + } + + fn compute_live_intervals(&self, func: &Function) -> Vec { + use crate::ir::BasicBlockId; + + struct IntervalInfo { + pseudo: PseudoId, + first_def: usize, + last_def: usize, + last_use: usize, + } + + let mut intervals: HashMap = HashMap::new(); + let mut pos = 0usize; + + // First pass: compute block end positions + let mut block_end_pos: HashMap = HashMap::new(); + let mut temp_pos = 0usize; + for block in &func.blocks { + temp_pos += block.insns.len(); + block_end_pos.insert(block.id, temp_pos.saturating_sub(1)); + } + + // Collect phi sources and targets + let mut phi_sources: Vec<(BasicBlockId, PseudoId)> = Vec::new(); + let mut phi_targets: Vec<(BasicBlockId, PseudoId)> = Vec::new(); + + for block in &func.blocks { + for insn in &block.insns { + if let Some(target) = insn.target { + intervals + .entry(target) + .and_modify(|info| { + info.first_def = info.first_def.min(pos); + info.last_def = info.last_def.max(pos); + }) + .or_insert(IntervalInfo { + pseudo: target, + first_def: pos, + last_def: pos, + last_use: pos, + }); + } + + for &src in &insn.src { + if let Some(info) = intervals.get_mut(&src) { + info.last_use = info.last_use.max(pos); + } else { + intervals.insert( + src, + IntervalInfo { + pseudo: src, + first_def: pos, + last_def: pos, + last_use: pos, + }, + ); + } + } + + for (src_bb, pseudo) in &insn.phi_list { + phi_sources.push((*src_bb, *pseudo)); + if let Some(target) = insn.target { + phi_targets.push((*src_bb, target)); + } + } + + pos += 1; + } + } + + // Extend phi source intervals to end of their source block + for (src_bb, pseudo) in phi_sources { + if 
let Some(&end_pos) = block_end_pos.get(&src_bb) { + if let Some(info) = intervals.get_mut(&pseudo) { + info.last_use = info.last_use.max(end_pos); + } else { + intervals.insert( + pseudo, + IntervalInfo { + pseudo, + first_def: end_pos, + last_def: end_pos, + last_use: end_pos, + }, + ); + } + } + } + + // Extend phi target intervals + for (src_bb, target) in phi_targets { + if let Some(&end_pos) = block_end_pos.get(&src_bb) { + if let Some(info) = intervals.get_mut(&target) { + info.last_def = info.last_def.max(end_pos); + } else { + intervals.insert( + target, + IntervalInfo { + pseudo: target, + first_def: end_pos, + last_def: end_pos, + last_use: end_pos, + }, + ); + } + } + } + + let max_pos = pos.saturating_sub(1); + + let mut result: Vec<_> = intervals + .into_values() + .map(|info| { + let end = if info.last_def > info.last_use { + max_pos + } else { + info.last_def.max(info.last_use) + }; + LiveInterval { + pseudo: info.pseudo, + start: info.first_def, + end, + } + }) + .collect(); + result.sort_by_key(|i| i.start); + result + } + + /// Get stack size needed (aligned to 16 bytes) + pub fn stack_size(&self) -> i32 { + (self.stack_offset + 15) & !15 + } + + /// Get callee-saved registers that need to be preserved + pub fn callee_saved_used(&self) -> &[Reg] { + &self.used_callee_saved + } +} + +impl Default for RegAlloc { + fn default() -> Self { + Self::new() + } +} From 62a3f7ddcb49f71eaea3ca5436f1dca1ec732e7f Mon Sep 17 00:00:00 2001 From: Jeff Garzik Date: Sun, 7 Dec 2025 20:18:31 -0500 Subject: [PATCH 2/8] [cc] aarch64 callee saved FP registers 1. cc/arch/aarch64/regalloc.rs - Added callee_saved_fp_used() getter (line 915) 2. cc/arch/aarch64/lir.rs - Added new LIR instructions: - StpFp - Store pair of FP registers (line 587) - LdpFp - Load pair of FP registers (line 595) - Emit implementations for both (lines 1310-1334) 3. cc/arch/aarch64/codegen.rs - Full integration: - Get FP callee-saved registers from allocator (line 280) - Frame size calculation includes FP callee-saved space (lines 295-297) - frame_info type extended to (i32, Vec, Vec) (line 604) - Prologue saves FP callee-saved registers using stp (lines 435-481) - Epilogue restores FP callee-saved registers using ldp (lines 702-728) Key AAPCS64 compliance: Per the ABI, only the lower 64 bits of V8-V15 need preservation, so we save/restore as d8-d15 (using FpSize::Double). 
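As a point of reference, the prologue's pair-wise walk over the used FP
callee-saved registers reduces to the following standalone sketch
(illustrative only: it prints mnemonics, the register names are
placeholders, and the real code emits Aarch64Inst::StpFp/StrFp LIR with
MemAddr::BaseOffset addressing rather than text):

    // Sketch of the stp/str pairing walk over d8-d15 (assumed helper, not the codegen API).
    fn sketch_fp_saves(regs: &[&str], mut offset: i32) {
        let mut i = 0;
        while i < regs.len() {
            if i + 1 < regs.len() {
                // Two registers remain: save them as a pair with stp.
                println!("    stp {}, {}, [x29, #{}]", regs[i], regs[i + 1], offset);
                i += 2;
            } else {
                // Odd register out: a single str keeps the slot layout uniform.
                println!("    str {}, [x29, #{}]", regs[i], offset);
                i += 1;
            }
            offset += 16; // one 16-byte slot per pair, preserving stack alignment
        }
    }

For example, with regs = ["d8", "d9", "d10"] this prints an stp for d8/d9
followed by a single str for d10; the epilogue mirrors the same walk with
ldp/ldr.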
--- cc/arch/aarch64/codegen.rs | 104 ++++++++++++++++++++++++++++++++---- cc/arch/aarch64/lir.rs | 42 +++++++++++++++ cc/arch/aarch64/regalloc.rs | 5 ++ 3 files changed, 142 insertions(+), 9 deletions(-) diff --git a/cc/arch/aarch64/codegen.rs b/cc/arch/aarch64/codegen.rs index 1da18f10..1269947e 100644 --- a/cc/arch/aarch64/codegen.rs +++ b/cc/arch/aarch64/codegen.rs @@ -277,6 +277,7 @@ impl Aarch64CodeGen { let stack_size = alloc.stack_size(); let callee_saved = alloc.callee_saved_used().to_vec(); + let callee_saved_fp = alloc.callee_saved_fp_used().to_vec(); // For variadic functions on Linux/FreeBSD, we need extra space for the register save area // AAPCS64: 8 GP regs (x0-x7) * 8 bytes = 64 bytes @@ -286,16 +287,20 @@ impl Aarch64CodeGen { let reg_save_area_size: i32 = if is_variadic && !is_darwin { 64 } else { 0 }; // Calculate total frame size - // Need space for: fp/lr (16 bytes) + callee-saved regs + local vars + reg save area - // Round up callee-saved count to even for 16-byte alignment - let callee_saved_pairs = (callee_saved.len() + 1) / 2; - let callee_saved_size = callee_saved_pairs as i32 * 16; + // Need space for: fp/lr (16 bytes) + GP callee-saved + FP callee-saved + local vars + reg save area + // Round up callee-saved counts to even for 16-byte alignment + // Note: AAPCS64 only requires the lower 64 bits of V8-V15 to be preserved (d8-d15) + let callee_saved_gp_pairs = (callee_saved.len() + 1) / 2; + let callee_saved_gp_size = callee_saved_gp_pairs as i32 * 16; + let callee_saved_fp_pairs = (callee_saved_fp.len() + 1) / 2; + let callee_saved_fp_size = callee_saved_fp_pairs as i32 * 16; // 8 bytes per d-reg, 16 per pair + let callee_saved_size = callee_saved_gp_size + callee_saved_fp_size; let total_frame = 16 + callee_saved_size + stack_size + reg_save_area_size; // Ensure 16-byte alignment let total_frame = (total_frame + 15) & !15; // Track register save area offset for va_start (offset from FP) - // Layout: [fp/lr][callee-saved][locals][reg_save_area] + // Layout: [fp/lr][GP callee-saved][FP callee-saved][locals][reg_save_area] // The save area is at FP + 16 + callee_saved_size + stack_size self.reg_save_area_offset = if is_variadic { 16 + callee_saved_size + stack_size @@ -426,6 +431,54 @@ impl Aarch64CodeGen { } offset += 16; } + + // Save FP callee-saved registers (d8-d15) in pairs + // AAPCS64 only requires preserving the lower 64 bits + let mut i = 0; + while i < callee_saved_fp.len() { + if i + 1 < callee_saved_fp.len() { + self.push_lir(Aarch64Inst::StpFp { + size: FpSize::Double, + src1: callee_saved_fp[i], + src2: callee_saved_fp[i + 1], + addr: MemAddr::BaseOffset { + base: Reg::X29, // fp + offset, + }, + }); + if self.emit_debug { + let cfi_offset1 = -(total_frame - offset); + let cfi_offset2 = -(total_frame - offset - 8); + self.push_lir(Aarch64Inst::Directive(Directive::cfi_offset( + callee_saved_fp[i].name_d(), + cfi_offset1, + ))); + self.push_lir(Aarch64Inst::Directive(Directive::cfi_offset( + callee_saved_fp[i + 1].name_d(), + cfi_offset2, + ))); + } + i += 2; + } else { + self.push_lir(Aarch64Inst::StrFp { + size: FpSize::Double, + src: callee_saved_fp[i], + addr: MemAddr::BaseOffset { + base: Reg::X29, // fp + offset, + }, + }); + if self.emit_debug { + let cfi_offset = -(total_frame - offset); + self.push_lir(Aarch64Inst::Directive(Directive::cfi_offset( + callee_saved_fp[i].name_d(), + cfi_offset, + ))); + } + i += 1; + } + offset += 16; + } } else { // Minimal frame: stp x29, x30, [sp, #-16]! 
self.push_lir(Aarch64Inst::Stp { @@ -548,7 +601,7 @@ impl Aarch64CodeGen { } // Store frame size for epilogue - let frame_info = (total_frame, callee_saved.clone()); + let frame_info = (total_frame, callee_saved.clone(), callee_saved_fp.clone()); // Emit basic blocks for block in &func.blocks { @@ -564,7 +617,7 @@ impl Aarch64CodeGen { fn emit_block( &mut self, block: &crate::ir::BasicBlock, - frame_info: &(i32, Vec), + frame_info: &(i32, Vec, Vec), types: &TypeTable, ) { // Always emit block ID label for consistency with jumps @@ -580,11 +633,16 @@ impl Aarch64CodeGen { } } - fn emit_insn(&mut self, insn: &Instruction, frame_info: &(i32, Vec), types: &TypeTable) { + fn emit_insn( + &mut self, + insn: &Instruction, + frame_info: &(i32, Vec, Vec), + types: &TypeTable, + ) { // Emit .loc directive for debug info self.emit_loc(insn); - let (total_frame, callee_saved) = frame_info; + let (total_frame, callee_saved, callee_saved_fp) = frame_info; match insn.op { Opcode::Entry => { @@ -645,6 +703,34 @@ impl Aarch64CodeGen { } offset += 16; } + + // Restore FP callee-saved registers (d8-d15) + let mut i = 0; + while i < callee_saved_fp.len() { + if i + 1 < callee_saved_fp.len() { + self.push_lir(Aarch64Inst::LdpFp { + size: FpSize::Double, + addr: MemAddr::BaseOffset { + base: Reg::sp(), + offset, + }, + dst1: callee_saved_fp[i], + dst2: callee_saved_fp[i + 1], + }); + i += 2; + } else { + self.push_lir(Aarch64Inst::LdrFp { + size: FpSize::Double, + addr: MemAddr::BaseOffset { + base: Reg::sp(), + offset, + }, + dst: callee_saved_fp[i], + }); + i += 1; + } + offset += 16; + } } // Restore fp/lr and deallocate stack diff --git a/cc/arch/aarch64/lir.rs b/cc/arch/aarch64/lir.rs index 4d8f3310..e65a99e2 100644 --- a/cc/arch/aarch64/lir.rs +++ b/cc/arch/aarch64/lir.rs @@ -584,6 +584,22 @@ pub enum Aarch64Inst { addr: MemAddr, }, + /// STP (FP) - Store pair of FP registers (for callee-saved register saves) + StpFp { + size: FpSize, + src1: VReg, + src2: VReg, + addr: MemAddr, + }, + + /// LDP (FP) - Load pair of FP registers (for callee-saved register restores) + LdpFp { + size: FpSize, + addr: MemAddr, + dst1: VReg, + dst2: VReg, + }, + /// FADD - FP add Fadd { size: FpSize, @@ -1291,6 +1307,32 @@ impl EmitAsm for Aarch64Inst { let _ = writeln!(out, " str {}, {}", name, addr.format()); } + Aarch64Inst::StpFp { + size, + src1, + src2, + addr, + } => { + let (name1, name2) = match size { + FpSize::Single => (src1.name_s(), src2.name_s()), + FpSize::Double => (src1.name_d(), src2.name_d()), + }; + let _ = writeln!(out, " stp {}, {}, {}", name1, name2, addr.format()); + } + + Aarch64Inst::LdpFp { + size, + addr, + dst1, + dst2, + } => { + let (name1, name2) = match size { + FpSize::Single => (dst1.name_s(), dst2.name_s()), + FpSize::Double => (dst1.name_d(), dst2.name_d()), + }; + let _ = writeln!(out, " ldp {}, {}, {}", name1, name2, addr.format()); + } + Aarch64Inst::LdrFpSymOffset { size, sym, diff --git a/cc/arch/aarch64/regalloc.rs b/cc/arch/aarch64/regalloc.rs index 4083c51f..903d8da0 100644 --- a/cc/arch/aarch64/regalloc.rs +++ b/cc/arch/aarch64/regalloc.rs @@ -910,6 +910,11 @@ impl RegAlloc { pub fn callee_saved_used(&self) -> &[Reg] { &self.used_callee_saved } + + /// Get callee-saved floating-point registers that need to be preserved + pub fn callee_saved_fp_used(&self) -> &[VReg] { + &self.used_callee_saved_fp + } } impl Default for RegAlloc { From 08481b6256ffd948702d1a1e24ec04247e625826 Mon Sep 17 00:00:00 2001 From: Jeff Garzik Date: Sun, 7 Dec 2025 20:39:13 -0500 Subject: [PATCH 3/8] [cc] 
aarch64 phi src fix --- cc/arch/aarch64/regalloc.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cc/arch/aarch64/regalloc.rs b/cc/arch/aarch64/regalloc.rs index 903d8da0..5882e788 100644 --- a/cc/arch/aarch64/regalloc.rs +++ b/cc/arch/aarch64/regalloc.rs @@ -804,12 +804,11 @@ impl RegAlloc { } } - // Extend phi source intervals + // Extend phi source intervals to end of their source block for (src_bb, pseudo) in phi_sources { if let Some(&end_pos) = block_end_pos.get(&src_bb) { if let Some(info) = intervals.get_mut(&pseudo) { info.last_use = info.last_use.max(end_pos); - info.last_def = info.last_def.max(end_pos); } else { intervals.insert( pseudo, From ab83e9d04de5a05ef97dea13158443af95f1bb98 Mon Sep 17 00:00:00 2001 From: Jeff Garzik Date: Sun, 7 Dec 2025 20:45:04 -0500 Subject: [PATCH 4/8] [cc] aarch64 fix: spill args across calls --- cc/arch/aarch64/regalloc.rs | 55 +++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/cc/arch/aarch64/regalloc.rs b/cc/arch/aarch64/regalloc.rs index 5882e788..efb93fb5 100644 --- a/cc/arch/aarch64/regalloc.rs +++ b/cc/arch/aarch64/regalloc.rs @@ -512,6 +512,8 @@ impl RegAlloc { self.allocate_arguments(func, types); let intervals = self.compute_live_intervals(func); + + self.spill_args_across_calls(func, &intervals); self.run_linear_scan(func, types, intervals); self.locations.clone() @@ -585,6 +587,59 @@ impl RegAlloc { } } + /// Spill arguments in caller-saved registers if their interval crosses a call + fn spill_args_across_calls(&mut self, func: &Function, intervals: &[LiveInterval]) { + // Find all call positions + let mut call_positions: Vec = Vec::new(); + let mut pos = 0usize; + for block in &func.blocks { + for insn in &block.insns { + if insn.op == Opcode::Call { + call_positions.push(pos); + } + pos += 1; + } + } + + // Check GP arguments in caller-saved registers (x0-x7) + let int_arg_regs_set: Vec = Reg::arg_regs().to_vec(); + for interval in intervals { + if let Some(Loc::Reg(reg)) = self.locations.get(&interval.pseudo) { + if int_arg_regs_set.contains(reg) { + let crosses_call = call_positions + .iter() + .any(|&call_pos| interval.start <= call_pos && call_pos < interval.end); + if crosses_call { + let reg_to_restore = *reg; + self.stack_offset += 8; + self.locations + .insert(interval.pseudo, Loc::Stack(-self.stack_offset)); + self.free_regs.push(reg_to_restore); + } + } + } + } + + // Check FP arguments in caller-saved registers (v0-v7) + let fp_arg_regs_set: Vec = VReg::arg_regs().to_vec(); + for interval in intervals { + if let Some(Loc::VReg(reg)) = self.locations.get(&interval.pseudo) { + if fp_arg_regs_set.contains(reg) { + let crosses_call = call_positions + .iter() + .any(|&call_pos| interval.start <= call_pos && call_pos < interval.end); + if crosses_call { + let reg_to_restore = *reg; + self.stack_offset += 8; + self.locations + .insert(interval.pseudo, Loc::Stack(-self.stack_offset)); + self.free_fp_regs.push(reg_to_restore); + } + } + } + } + } + /// Run the linear scan allocation algorithm fn run_linear_scan( &mut self, From 622948d811b90f93a87653afde9d9b1bd91fc3f0 Mon Sep 17 00:00:00 2001 From: Jeff Garzik Date: Sun, 7 Dec 2025 20:56:18 -0500 Subject: [PATCH 5/8] [cc] aarch64 alloca and FImm fixes --- cc/arch/aarch64/codegen.rs | 28 +++++++++++++++------------- cc/arch/aarch64/regalloc.rs | 31 ++++++++++++++++++++++++++++--- 2 files changed, 43 insertions(+), 16 deletions(-) diff --git a/cc/arch/aarch64/codegen.rs b/cc/arch/aarch64/codegen.rs index 1269947e..d8b43f8f 100644 
--- a/cc/arch/aarch64/codegen.rs +++ b/cc/arch/aarch64/codegen.rs @@ -653,7 +653,7 @@ impl Aarch64CodeGen { // Move return value to x0 (integer) or v0 (float) if present if let Some(&src) = insn.src.first() { let src_loc = self.get_location(src); - let is_fp = matches!(src_loc, Loc::VReg(_) | Loc::FImm(_)); + let is_fp = matches!(src_loc, Loc::VReg(_) | Loc::FImm(..)); if is_fp { // FP return value goes in V0 @@ -826,7 +826,7 @@ impl Aarch64CodeGen { src: *v, }); } - Loc::FImm(f) => { + Loc::FImm(f, _) => { // FP immediate as condition - branch based on non-zero if *f != 0.0 { if let Some(target) = insn.bb_true { @@ -906,7 +906,7 @@ impl Aarch64CodeGen { Loc::Global(name) => { self.emit_load_global(name, scratch0, op_size); } - Loc::VReg(_) | Loc::FImm(_) => { + Loc::VReg(_) | Loc::FImm(..) => { // FP values shouldn't be used in switch statements // This is unreachable in valid C code } @@ -1322,14 +1322,15 @@ impl Aarch64CodeGen { dst, }); } - Loc::FImm(f) => { - // Load FP immediate as integer bits - let bits = if size <= 32 { + Loc::FImm(f, imm_size) => { + // Use the size from the FImm, not the passed-in size + // This ensures float constants are loaded as float, not double + let bits = if imm_size <= 32 { (f as f32).to_bits() as i64 } else { f.to_bits() as i64 }; - self.emit_mov_imm(dst, bits, size); + self.emit_mov_imm(dst, bits, imm_size); } } } @@ -2154,7 +2155,7 @@ impl Aarch64CodeGen { types.is_float(typ) } else { let arg_loc = self.get_location(arg); - matches!(arg_loc, Loc::VReg(_) | Loc::FImm(_)) + matches!(arg_loc, Loc::VReg(_) | Loc::FImm(..)) }; let arg_size = if let Some(typ) = arg_type { @@ -2240,7 +2241,7 @@ impl Aarch64CodeGen { } else { // Fall back to location-based detection for backwards compatibility let arg_loc = self.get_location(arg); - matches!(arg_loc, Loc::VReg(_) | Loc::FImm(_)) + matches!(arg_loc, Loc::VReg(_) | Loc::FImm(..)) }; // Get argument size from type, with minimum 32-bit for register ops @@ -2321,7 +2322,7 @@ impl Aarch64CodeGen { let is_fp_result = if let Some(typ) = insn.typ { types.is_float(typ) } else { - matches!(dst_loc, Loc::VReg(_) | Loc::FImm(_)) + matches!(dst_loc, Loc::VReg(_) | Loc::FImm(..)) }; // Get return value size from type @@ -2498,7 +2499,7 @@ impl Aarch64CodeGen { // Check if this is a FP copy (source or dest is in VReg or is FImm) let is_fp_copy = - matches!(&src_loc, Loc::VReg(_) | Loc::FImm(_)) || matches!(&dst_loc, Loc::VReg(_)); + matches!(&src_loc, Loc::VReg(_) | Loc::FImm(..)) || matches!(&dst_loc, Loc::VReg(_)); // Determine if the type is unsigned (for proper sign/zero extension) // For plain char, use target.char_signed to determine signedness @@ -2598,10 +2599,11 @@ impl Aarch64CodeGen { dst, }); } - Loc::FImm(f) => { + Loc::FImm(f, imm_size) => { // Load FP constant using integer register + // Use the size from the FImm for correct constant representation let (scratch0, _) = Reg::scratch_regs(); - let bits = if size <= 32 { + let bits = if imm_size <= 32 { (f as f32).to_bits() as i64 } else { f.to_bits() as i64 diff --git a/cc/arch/aarch64/regalloc.rs b/cc/arch/aarch64/regalloc.rs index efb93fb5..42d0a322 100644 --- a/cc/arch/aarch64/regalloc.rs +++ b/cc/arch/aarch64/regalloc.rs @@ -450,8 +450,8 @@ pub enum Loc { Stack(i32), /// Immediate constant Imm(i64), - /// Floating-point immediate constant - FImm(f64), + /// Floating-point immediate constant (value, size in bits) + FImm(f64, u32), /// Global symbol Global(String), } @@ -514,6 +514,7 @@ impl RegAlloc { let intervals = self.compute_live_intervals(func); 
         self.spill_args_across_calls(func, &intervals);
+        self.allocate_alloca_to_stack(func);
         self.run_linear_scan(func, types, intervals);
 
         self.locations.clone()
@@ -587,6 +588,21 @@ impl RegAlloc {
         }
     }
 
+    /// Force alloca results to stack to avoid clobbering issues
+    fn allocate_alloca_to_stack(&mut self, func: &Function) {
+        for block in &func.blocks {
+            for insn in &block.insns {
+                if insn.op == Opcode::Alloca {
+                    if let Some(target) = insn.target {
+                        self.stack_offset += 8;
+                        self.locations
+                            .insert(target, Loc::Stack(-self.stack_offset));
+                    }
+                }
+            }
+        }
+    }
+
     /// Spill arguments in caller-saved registers if their interval crosses a call
     fn spill_args_across_calls(&mut self, func: &Function, intervals: &[LiveInterval]) {
         // Find all call positions
@@ -662,7 +678,16 @@ impl RegAlloc {
                 continue;
             }
             PseudoKind::FVal(v) => {
-                self.locations.insert(interval.pseudo, Loc::FImm(*v));
+                let size = func
+                    .blocks
+                    .iter()
+                    .flat_map(|b| &b.insns)
+                    .find(|insn| {
+                        insn.op == Opcode::SetVal && insn.target == Some(interval.pseudo)
+                    })
+                    .map(|insn| insn.size)
+                    .unwrap_or(64);
+                self.locations.insert(interval.pseudo, Loc::FImm(*v, size));
                 self.fp_pseudos.insert(interval.pseudo);
                 continue;
             }

From 98b8ac03e47bc4ba342c391497ebf06bfea145e4 Mon Sep 17 00:00:00 2001
From: Jeff Garzik
Date: Sun, 7 Dec 2025 21:25:34 -0500
Subject: [PATCH 6/8] [cc] add loop back-edge detection to x86_64 register allocator

Changes to cc/arch/x86_64/regalloc.rs:

1. Added block_start_pos tracking (line 656) - needed to detect back edges
2. Added loop back-edge detection (lines 751-772) - identifies branches to earlier blocks
3. Added lifetime extension for loop variables (lines 774-786) - extends last_use for variables that live across loop iterations

The logic is identical to AArch64: if a variable is defined before a loop,
used inside the loop, and there's a back edge, its live interval is extended
to the back edge position. This prevents the register allocator from
prematurely freeing registers that are still needed in subsequent loop
iterations.
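In sketch form, the detection added here boils down to comparing linearized
block positions; Block below is a simplified stand-in for the compiler's IR
types (ids double as indices into the slice), not the actual interface:

    // A branch to a block that starts earlier in the linear order is a back edge.
    struct Block {
        start: usize,      // linear position of the block's first instruction
        end: usize,        // linear position of its last instruction (the branch)
        succs: Vec<usize>, // branch targets (bb_true/bb_false analogues)
    }

    fn back_edges(blocks: &[Block]) -> Vec<(usize, usize)> {
        let mut edges = Vec::new();
        for b in blocks {
            for &t in &b.succs {
                if blocks[t].start < b.start {
                    edges.push((blocks[t].start, b.end)); // (loop start, back-edge position)
                }
            }
        }
        edges
    }

Each interval whose first_def precedes the loop start and whose last_use
falls between the loop start and the back-edge position then has its
last_use raised to the back-edge position, which is the extension loop in
the diff below.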
--- cc/arch/x86_64/regalloc.rs | 41 +++++++++++++++++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/cc/arch/x86_64/regalloc.rs b/cc/arch/x86_64/regalloc.rs index 1a3797a6..6f698816 100644 --- a/cc/arch/x86_64/regalloc.rs +++ b/cc/arch/x86_64/regalloc.rs @@ -652,10 +652,12 @@ impl RegAlloc { let mut intervals: HashMap = HashMap::new(); let mut pos = 0usize; - // First pass: compute block end positions + // First pass: compute block start and end positions + let mut block_start_pos: HashMap = HashMap::new(); let mut block_end_pos: HashMap = HashMap::new(); let mut temp_pos = 0usize; for block in &func.blocks { + block_start_pos.insert(block.id, temp_pos); temp_pos += block.insns.len(); block_end_pos.insert(block.id, temp_pos.saturating_sub(1)); } @@ -746,6 +748,43 @@ impl RegAlloc { } } + // Handle loop back edges + let mut loop_back_edges: Vec<(BasicBlockId, BasicBlockId, usize)> = Vec::new(); + for block in &func.blocks { + if let Some(last_insn) = block.insns.last() { + let mut targets = Vec::new(); + if let Some(target) = last_insn.bb_true { + targets.push(target); + } + if let Some(target) = last_insn.bb_false { + targets.push(target); + } + + let from_start = block_start_pos.get(&block.id).copied().unwrap_or(0); + for target_bb in targets { + let target_start = block_start_pos.get(&target_bb).copied().unwrap_or(0); + if target_start < from_start { + let from_end = block_end_pos.get(&block.id).copied().unwrap_or(0); + loop_back_edges.push((block.id, target_bb, from_end)); + } + } + } + } + + // Extend lifetimes for loop variables + for (_from_bb, to_bb, back_edge_pos) in &loop_back_edges { + let loop_start = block_start_pos.get(to_bb).copied().unwrap_or(0); + + for info in intervals.values_mut() { + if info.first_def < loop_start + && info.last_use >= loop_start + && info.last_use <= *back_edge_pos + { + info.last_use = info.last_use.max(*back_edge_pos); + } + } + } + let max_pos = pos.saturating_sub(1); let mut result: Vec<_> = intervals From 1f2eff684312a27aba979bc906feb144d86029e3 Mon Sep 17 00:00:00 2001 From: Jeff Garzik Date: Sun, 7 Dec 2025 22:01:16 -0500 Subject: [PATCH 7/8] [cc] common regalloc patterns --- cc/arch/aarch64/regalloc.rs | 85 ++++++++++--------------------------- cc/arch/mod.rs | 1 + cc/arch/regalloc.rs | 70 ++++++++++++++++++++++++++++++ cc/arch/x86_64/regalloc.rs | 65 ++++++---------------------- 4 files changed, 108 insertions(+), 113 deletions(-) create mode 100644 cc/arch/regalloc.rs diff --git a/cc/arch/aarch64/regalloc.rs b/cc/arch/aarch64/regalloc.rs index 42d0a322..f8882617 100644 --- a/cc/arch/aarch64/regalloc.rs +++ b/cc/arch/aarch64/regalloc.rs @@ -10,6 +10,9 @@ // Linear scan register allocation for AArch64 // +use crate::arch::regalloc::{ + expire_intervals, find_call_positions, interval_crosses_call, LiveInterval, +}; use crate::ir::{BasicBlockId, Function, Opcode, PseudoId, PseudoKind}; use crate::types::TypeTable; use std::collections::{HashMap, HashSet}; @@ -460,14 +463,6 @@ pub enum Loc { // Register Allocator (Linear Scan) // ============================================================================ -/// Live interval for a pseudo -#[derive(Debug, Clone)] -struct LiveInterval { - pseudo: PseudoId, - start: usize, - end: usize, -} - /// Simple linear scan register allocator for AArch64 pub struct RegAlloc { /// Mapping from pseudo to location @@ -605,33 +600,20 @@ impl RegAlloc { /// Spill arguments in caller-saved registers if their interval crosses a call fn spill_args_across_calls(&mut self, func: 
&Function, intervals: &[LiveInterval]) { - // Find all call positions - let mut call_positions: Vec = Vec::new(); - let mut pos = 0usize; - for block in &func.blocks { - for insn in &block.insns { - if insn.op == Opcode::Call { - call_positions.push(pos); - } - pos += 1; - } - } + let call_positions = find_call_positions(func); // Check GP arguments in caller-saved registers (x0-x7) let int_arg_regs_set: Vec = Reg::arg_regs().to_vec(); for interval in intervals { if let Some(Loc::Reg(reg)) = self.locations.get(&interval.pseudo) { - if int_arg_regs_set.contains(reg) { - let crosses_call = call_positions - .iter() - .any(|&call_pos| interval.start <= call_pos && call_pos < interval.end); - if crosses_call { - let reg_to_restore = *reg; - self.stack_offset += 8; - self.locations - .insert(interval.pseudo, Loc::Stack(-self.stack_offset)); - self.free_regs.push(reg_to_restore); - } + if int_arg_regs_set.contains(reg) + && interval_crosses_call(interval, &call_positions) + { + let reg_to_restore = *reg; + self.stack_offset += 8; + self.locations + .insert(interval.pseudo, Loc::Stack(-self.stack_offset)); + self.free_regs.push(reg_to_restore); } } } @@ -640,17 +622,13 @@ impl RegAlloc { let fp_arg_regs_set: Vec = VReg::arg_regs().to_vec(); for interval in intervals { if let Some(Loc::VReg(reg)) = self.locations.get(&interval.pseudo) { - if fp_arg_regs_set.contains(reg) { - let crosses_call = call_positions - .iter() - .any(|&call_pos| interval.start <= call_pos && call_pos < interval.end); - if crosses_call { - let reg_to_restore = *reg; - self.stack_offset += 8; - self.locations - .insert(interval.pseudo, Loc::Stack(-self.stack_offset)); - self.free_fp_regs.push(reg_to_restore); - } + if fp_arg_regs_set.contains(reg) && interval_crosses_call(interval, &call_positions) + { + let reg_to_restore = *reg; + self.stack_offset += 8; + self.locations + .insert(interval.pseudo, Loc::Stack(-self.stack_offset)); + self.free_fp_regs.push(reg_to_restore); } } } @@ -794,27 +772,10 @@ impl RegAlloc { } fn expire_old_intervals(&mut self, point: usize) { - let mut to_remove = Vec::new(); - for (i, (interval, reg)) in self.active.iter().enumerate() { - if interval.end < point { - self.free_regs.push(*reg); - to_remove.push(i); - } - } - for i in to_remove.into_iter().rev() { - self.active.remove(i); - } - - let mut to_remove_fp = Vec::new(); - for (i, (interval, reg)) in self.active_fp.iter().enumerate() { - if interval.end < point { - self.free_fp_regs.push(*reg); - to_remove_fp.push(i); - } - } - for i in to_remove_fp.into_iter().rev() { - self.active_fp.remove(i); - } + // Expire GP register intervals + expire_intervals(&mut self.active, &mut self.free_regs, point); + // Expire FP register intervals + expire_intervals(&mut self.active_fp, &mut self.free_fp_regs, point); } fn compute_live_intervals(&self, func: &Function) -> Vec { diff --git a/cc/arch/mod.rs b/cc/arch/mod.rs index 58992add..b2ba319b 100644 --- a/cc/arch/mod.rs +++ b/cc/arch/mod.rs @@ -15,6 +15,7 @@ pub const DEFAULT_LIR_BUFFER_CAPACITY: usize = 5000; pub mod aarch64; pub mod codegen; pub mod lir; +pub mod regalloc; pub mod x86_64; use crate::target::{Arch, Target}; diff --git a/cc/arch/regalloc.rs b/cc/arch/regalloc.rs new file mode 100644 index 00000000..e94fa994 --- /dev/null +++ b/cc/arch/regalloc.rs @@ -0,0 +1,70 @@ +// +// Copyright (c) 2024 Jeff Garzik +// +// This file is part of the posixutils-rs project covered under +// the MIT License. For the full license text, please see the LICENSE +// file in the root directory of this project. 
+// SPDX-License-Identifier: MIT
+//
+// Common register allocator utilities shared between architectures
+//
+
+use crate::ir::{Function, Opcode, PseudoId};
+
+// ============================================================================
+// Common Types
+// ============================================================================
+
+/// Live interval for a pseudo-register
+#[derive(Debug, Clone)]
+pub struct LiveInterval {
+    pub pseudo: PseudoId,
+    pub start: usize,
+    pub end: usize,
+}
+
+// ============================================================================
+// Common Functions
+// ============================================================================
+
+/// Expire old intervals from the active list, returning freed registers to the free list.
+/// Generic over register type R (works with both GP and FP register types).
+pub fn expire_intervals<R: Copy>(
+    active: &mut Vec<(LiveInterval, R)>,
+    free_regs: &mut Vec<R>,
+    point: usize,
+) {
+    let mut to_remove = Vec::new();
+    for (i, (interval, reg)) in active.iter().enumerate() {
+        if interval.end < point {
+            free_regs.push(*reg);
+            to_remove.push(i);
+        }
+    }
+    for i in to_remove.into_iter().rev() {
+        active.remove(i);
+    }
+}
+
+/// Find all positions of call instructions in a function.
+/// Used by spill_args_across_calls to identify where arguments may be clobbered.
+pub fn find_call_positions(func: &Function) -> Vec<usize> {
+    let mut call_positions = Vec::new();
+    let mut pos = 0usize;
+    for block in &func.blocks {
+        for insn in &block.insns {
+            if insn.op == Opcode::Call {
+                call_positions.push(pos);
+            }
+            pos += 1;
+        }
+    }
+    call_positions
+}
+
+/// Check if a live interval crosses any call position.
+pub fn interval_crosses_call(interval: &LiveInterval, call_positions: &[usize]) -> bool {
+    call_positions
+        .iter()
+        .any(|&call_pos| interval.start <= call_pos && call_pos < interval.end)
+}
diff --git a/cc/arch/x86_64/regalloc.rs b/cc/arch/x86_64/regalloc.rs
index 6f698816..e69879d5 100644
--- a/cc/arch/x86_64/regalloc.rs
+++ b/cc/arch/x86_64/regalloc.rs
@@ -10,6 +10,9 @@
 // Linear scan register allocation for x86-64
 //
 
+use crate::arch::regalloc::{
+    expire_intervals, find_call_positions, interval_crosses_call, LiveInterval,
+};
 use crate::ir::{Function, Opcode, PseudoId, PseudoKind};
 use crate::types::TypeTable;
 use std::collections::{HashMap, HashSet};
@@ -295,14 +298,6 @@ pub enum Loc {
 // Register Allocator (Linear Scan)
 // ============================================================================
 
-/// Live interval for a pseudo
-#[derive(Debug, Clone)]
-struct LiveInterval {
-    pseudo: PseudoId,
-    start: usize,
-    end: usize,
-}
-
 /// Simple linear scan register allocator for x86-64
 pub struct RegAlloc {
     /// Mapping from pseudo to location
@@ -422,33 +417,20 @@ impl RegAlloc {
 
     /// Spill arguments in caller-saved registers if their interval crosses a call
     fn spill_args_across_calls(&mut self, func: &Function, intervals: &[LiveInterval]) {
-        // Find all call positions
-        let mut call_positions: Vec<usize> = Vec::new();
-        let mut pos = 0usize;
-        for block in &func.blocks {
-            for insn in &block.insns {
-                if insn.op == Opcode::Call {
-                    call_positions.push(pos);
-                }
-                pos += 1;
-            }
-        }
+        let call_positions = find_call_positions(func);
 
+        // Check arguments in caller-saved registers
         let int_arg_regs_set: Vec<Reg> = Reg::arg_regs().to_vec();
         for interval in intervals {
             if let Some(Loc::Reg(reg)) = self.locations.get(&interval.pseudo) {
-                if int_arg_regs_set.contains(reg) {
-                    let crosses_call = call_positions
-                        .iter()
-                        .any(|&call_pos|
interval.start <= call_pos && call_pos < interval.end); - if crosses_call { - let reg_to_restore = *reg; - self.stack_offset += 8; - self.locations - .insert(interval.pseudo, Loc::Stack(self.stack_offset)); - self.free_regs.push(reg_to_restore); - } + if int_arg_regs_set.contains(reg) + && interval_crosses_call(interval, &call_positions) + { + let reg_to_restore = *reg; + self.stack_offset += 8; + self.locations + .insert(interval.pseudo, Loc::Stack(self.stack_offset)); + self.free_regs.push(reg_to_restore); } } } @@ -615,28 +597,9 @@ impl RegAlloc { fn expire_old_intervals(&mut self, point: usize) { // Expire integer register intervals - let mut to_remove = Vec::new(); - for (i, (interval, reg)) in self.active.iter().enumerate() { - if interval.end < point { - self.free_regs.push(*reg); - to_remove.push(i); - } - } - for i in to_remove.into_iter().rev() { - self.active.remove(i); - } - + expire_intervals(&mut self.active, &mut self.free_regs, point); // Expire XMM register intervals - let mut to_remove_xmm = Vec::new(); - for (i, (interval, xmm)) in self.active_xmm.iter().enumerate() { - if interval.end < point { - self.free_xmm_regs.push(*xmm); - to_remove_xmm.push(i); - } - } - for i in to_remove_xmm.into_iter().rev() { - self.active_xmm.remove(i); - } + expire_intervals(&mut self.active_xmm, &mut self.free_xmm_regs, point); } fn compute_live_intervals(&self, func: &Function) -> Vec { From 00e6ff54d9bf9d5e2f07b9176640ae898cb8009a Mon Sep 17 00:00:00 2001 From: Jeff Garzik Date: Sun, 7 Dec 2025 22:55:16 -0500 Subject: [PATCH 8/8] [cc/codegen] cleanup: extract common code --- cc/arch/aarch64/codegen.rs | 43 +++++--------------------------------- cc/arch/codegen.rs | 37 +++++++++++++++++++++++++++++++- cc/arch/x86_64/codegen.rs | 39 +++------------------------------- 3 files changed, 44 insertions(+), 75 deletions(-) diff --git a/cc/arch/aarch64/codegen.rs b/cc/arch/aarch64/codegen.rs index d8b43f8f..eeee6546 100644 --- a/cc/arch/aarch64/codegen.rs +++ b/cc/arch/aarch64/codegen.rs @@ -19,7 +19,7 @@ use crate::arch::aarch64::lir::{Aarch64Inst, CallTarget, Cond, GpOperand, MemAddr}; use crate::arch::aarch64::regalloc::{Loc, Reg, RegAlloc, VReg}; -use crate::arch::codegen::CodeGenerator; +use crate::arch::codegen::{escape_string, is_variadic_function, CodeGenerator}; use crate::arch::lir::{Directive, FpSize, Label, OperandSize, Symbol}; use crate::arch::DEFAULT_LIR_BUFFER_CAPACITY; use crate::ir::{Function, Initializer, Instruction, Module, Opcode, Pseudo, PseudoId, PseudoKind}; @@ -87,18 +87,6 @@ impl Aarch64CodeGen { } } - /// Check if a function contains va_start (indicating it's variadic) - fn is_variadic_function(func: &Function) -> bool { - for block in &func.blocks { - for insn in &block.insns { - if matches!(insn.op, crate::ir::Opcode::VaStart) { - return true; - } - } - } - false - } - /// Compute the actual FP-relative offset for a stack location. 
     /// For local variables (negative offsets), this accounts for the
     /// register save area in varargs functions which is placed at the
@@ -236,39 +224,18 @@
         for (label, content) in strings {
             // Local label for string literal
             self.push_lir(Aarch64Inst::Directive(Directive::local_label(label)));
-            self.push_lir(Aarch64Inst::Directive(Directive::Asciz(
-                Self::escape_string(content),
-            )));
+            self.push_lir(Aarch64Inst::Directive(Directive::Asciz(escape_string(
+                content,
+            ))));
         }
 
         // Switch back to text section for functions
         self.push_lir(Aarch64Inst::Directive(Directive::Text));
     }
 
-    fn escape_string(s: &str) -> String {
-        let mut result = String::new();
-        for c in s.chars() {
-            match c {
-                '\n' => result.push_str("\\n"),
-                '\r' => result.push_str("\\r"),
-                '\t' => result.push_str("\\t"),
-                '\\' => result.push_str("\\\\"),
-                '"' => result.push_str("\\\""),
-                c if c.is_ascii_graphic() || c == ' ' => result.push(c),
-                c => {
-                    // Escape non-printable as octal
-                    for byte in c.to_string().as_bytes() {
-                        result.push_str(&format!("\\{:03o}", byte));
-                    }
-                }
-            }
-        }
-        result
-    }
-
     fn emit_function(&mut self, func: &Function, types: &TypeTable) {
         // Check if this function uses varargs
-        let is_variadic = Self::is_variadic_function(func);
+        let is_variadic = is_variadic_function(func);
 
         // Register allocation
         let mut alloc = RegAlloc::new();
diff --git a/cc/arch/codegen.rs b/cc/arch/codegen.rs
index 53bfb432..ff4ea353 100644
--- a/cc/arch/codegen.rs
+++ b/cc/arch/codegen.rs
@@ -9,7 +9,7 @@
 // Architecture-independent code generation interface
 //
 
-use crate::ir::Module;
+use crate::ir::{Function, Module, Opcode};
 use crate::target::Target;
 use crate::types::TypeTable;

@@ -47,6 +47,41 @@ pub fn generate_header_comments(target: &Target) -> Vec<String> {
     comments
 }
 
+/// Check if a function uses variadic arguments (contains VaStart opcode)
+pub fn is_variadic_function(func: &Function) -> bool {
+    for block in &func.blocks {
+        for insn in &block.insns {
+            if matches!(insn.op, Opcode::VaStart) {
+                return true;
+            }
+        }
+    }
+    false
+}
+
+/// Escape a string for assembly output (.ascii/.asciz directives)
+/// Non-printable and non-ASCII characters are escaped as octal byte sequences.
+pub fn escape_string(s: &str) -> String {
+    let mut result = String::new();
+    for c in s.chars() {
+        match c {
+            '\n' => result.push_str("\\n"),
+            '\r' => result.push_str("\\r"),
+            '\t' => result.push_str("\\t"),
+            '\\' => result.push_str("\\\\"),
+            '"' => result.push_str("\\\""),
+            c if c.is_ascii_graphic() || c == ' ' => result.push(c),
+            c => {
+                // Escape non-printable as octal bytes (handles UTF-8 correctly)
+                for byte in c.to_string().as_bytes() {
+                    result.push_str(&format!("\\{:03o}", byte));
+                }
+            }
+        }
+    }
+    result
+}
+
 /// Trait for architecture-specific code generators
 pub trait CodeGenerator {
     /// Generate assembly code for the given IR module
diff --git a/cc/arch/x86_64/codegen.rs b/cc/arch/x86_64/codegen.rs
index 1fd36a7c..039a6790 100644
--- a/cc/arch/x86_64/codegen.rs
+++ b/cc/arch/x86_64/codegen.rs
@@ -12,7 +12,7 @@
 // Uses linear scan register allocation and System V AMD64 ABI.
 //
 
-use crate::arch::codegen::CodeGenerator;
+use crate::arch::codegen::{escape_string, is_variadic_function, CodeGenerator};
 #[allow(unused_imports)]
 use crate::arch::lir::{Directive, FpSize, Label, OperandSize, Symbol};
 #[allow(unused_imports)]
@@ -211,24 +211,12 @@ impl X86_64CodeGen {
         }
     }
 
-    /// Check if a function contains va_start (indicating it's variadic)
-    fn is_variadic_function(func: &Function) -> bool {
-        for block in &func.blocks {
-            for insn in &block.insns {
-                if matches!(insn.op, crate::ir::Opcode::VaStart) {
-                    return true;
-                }
-            }
-        }
-        false
-    }
-
     fn emit_function(&mut self, func: &Function, types: &TypeTable) {
         // Save current function name for unique label generation
         self.current_fn = func.name.clone();
 
         // Check if this function uses varargs
-        let is_variadic = Self::is_variadic_function(func);
+        let is_variadic = is_variadic_function(func);
 
         // Register allocation
         let mut alloc = RegAlloc::new();
@@ -4302,31 +4290,10 @@
             // String labels are local (start with .)
             self.push_lir(X86Inst::Directive(Directive::local_label(label.clone())));
             // Emit string with proper escaping
-            self.push_lir(X86Inst::Directive(Directive::Asciz(Self::escape_string(
-                content,
-            ))));
+            self.push_lir(X86Inst::Directive(Directive::Asciz(escape_string(content))));
         }
 
         // Switch back to text section for functions
         self.push_lir(X86Inst::Directive(Directive::Text));
     }
-
-    fn escape_string(s: &str) -> String {
-        let mut result = String::new();
-        for c in s.chars() {
-            match c {
-                '\n' => result.push_str("\\n"),
-                '\r' => result.push_str("\\r"),
-                '\t' => result.push_str("\\t"),
-                '\\' => result.push_str("\\\\"),
-                '"' => result.push_str("\\\""),
-                c if c.is_ascii_graphic() || c == ' ' => result.push(c),
-                c => {
-                    // Emit as octal escape
-                    result.push_str(&format!("\\{:03o}", c as u32));
-                }
-            }
-        }
-        result
-    }
 }
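
Two illustrative sketches follow; they are notes on the refactor above, not part of the patches.

First, the shared expire_intervals helper in isolation. Because it is generic over any Copy register type, the x86-64 allocator can drive both its GP and XMM active lists through the one function, and an interval expires as soon as its end precedes the current scan point. This is a minimal self-contained sketch: ToyReg, main, and the local PseudoId/LiveInterval stand-ins are invented for illustration; the real types live in cc/ir and cc/arch/regalloc.rs.

    // Stand-ins for crate::ir::PseudoId and the shared LiveInterval,
    // so the sketch compiles on its own.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    struct PseudoId(usize);

    #[derive(Debug, Clone)]
    struct LiveInterval {
        pseudo: PseudoId,
        start: usize,
        end: usize,
    }

    // Same shape as the helper added in cc/arch/regalloc.rs.
    fn expire_intervals<R: Copy>(
        active: &mut Vec<(LiveInterval, R)>,
        free_regs: &mut Vec<R>,
        point: usize,
    ) {
        let mut to_remove = Vec::new();
        for (i, (interval, reg)) in active.iter().enumerate() {
            if interval.end < point {
                free_regs.push(*reg);
                to_remove.push(i);
            }
        }
        // Remove back to front so earlier indices stay valid.
        for i in to_remove.into_iter().rev() {
            active.remove(i);
        }
    }

    // Invented toy register type standing in for the per-arch register enums.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    enum ToyReg {
        R0,
        R1,
    }

    fn main() {
        let mut active = vec![
            (LiveInterval { pseudo: PseudoId(0), start: 0, end: 3 }, ToyReg::R0),
            (LiveInterval { pseudo: PseudoId(1), start: 2, end: 9 }, ToyReg::R1),
        ];
        let mut free: Vec<ToyReg> = Vec::new();

        // At scan point 5 the first interval has ended (3 < 5), so R0 is freed;
        // the second interval (end 9) stays active.
        expire_intervals(&mut active, &mut free, 5);
        assert_eq!(free, vec![ToyReg::R0]);
        assert_eq!(active.len(), 1);
        println!("still active: {:?}", active[0].0.pseudo);
    }

The same half-open convention appears in interval_crosses_call: an argument register is spilled when start <= call_pos < end, i.e. when a call executes while the value is still live in a caller-saved register.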
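
Second, why the shared escape_string escapes per UTF-8 byte. The x86-64 version deleted in PATCH 8/8 formatted the char's code point (c as u32), so 'é' (U+00E9) became \351, a Latin-1 byte rather than its UTF-8 encoding, and any code point above 0o777 produced more octal digits than the three an assembler consumes per escape. The shared helper keeps the aarch64 byte-wise behavior. Below is a standalone copy of the shared function with an invented demo driver; only main is new.

    // Standalone copy of escape_string from cc/arch/codegen.rs.
    fn escape_string(s: &str) -> String {
        let mut result = String::new();
        for c in s.chars() {
            match c {
                '\n' => result.push_str("\\n"),
                '\r' => result.push_str("\\r"),
                '\t' => result.push_str("\\t"),
                '\\' => result.push_str("\\\\"),
                '"' => result.push_str("\\\""),
                c if c.is_ascii_graphic() || c == ' ' => result.push(c),
                c => {
                    // One octal escape per UTF-8 byte, so multi-byte
                    // characters round-trip through .asciz correctly.
                    for byte in c.to_string().as_bytes() {
                        result.push_str(&format!("\\{:03o}", byte));
                    }
                }
            }
        }
        result
    }

    fn main() {
        // 'é' is 0xC3 0xA9 in UTF-8, hence two escapes: \303\251.
        assert_eq!(escape_string("é"), "\\303\\251");
        // The removed code-point formatting would have yielded \351 instead.
        println!("{}", escape_string("say \"hi\"\té"));
    }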