diff --git a/cc/README.md b/cc/README.md index 086f5cc1..3b3277c6 100644 --- a/cc/README.md +++ b/cc/README.md @@ -94,7 +94,9 @@ Supported: - Bitfields (named, unnamed, zero-width for alignment) Not yet implemented: +- goto, longjmp, setjmp - `inline` and inlining support +- multi-register returns (for structs larger than 8 bytes) - -fverbose-asm - Complex initializers - VLAs (variable-length arrays) @@ -130,4 +132,8 @@ Please run `cargo fmt` before committing code, and `cargo clippy` regularly whil cargo fmt && cargo clippy -p posixutils-cc ``` +DO NOT `allow(dead_code)` to fix warnings. Instead, remove dead code; do +not leave it around as a maintenance burden (and LLM token +tax). + Read CONTRIBUTING.md in the root of the repository for more details. diff --git a/cc/arch/aarch64/codegen.rs b/cc/arch/aarch64/codegen.rs index 03d4b623..9210a62c 100644 --- a/cc/arch/aarch64/codegen.rs +++ b/cc/arch/aarch64/codegen.rs @@ -20,10 +20,10 @@ use crate::arch::aarch64::lir::{Aarch64Inst, CallTarget, Cond, GpOperand, MemAddr}; use crate::arch::codegen::CodeGenerator; use crate::arch::lir::{Directive, FpSize, Label, OperandSize, Symbol}; +use crate::arch::DEFAULT_LIR_BUFFER_CAPACITY; use crate::ir::{Function, Initializer, Instruction, Module, Opcode, Pseudo, PseudoId, PseudoKind}; -use crate::linearize::MAX_REGISTER_AGGREGATE_BITS; use crate::target::Target; -use crate::types::{Type, TypeModifiers}; +use crate::types::{TypeId, TypeModifiers, TypeTable}; use std::collections::HashMap; // ============================================================================ @@ -518,7 +518,7 @@ impl RegAlloc { } /// Perform register allocation for a function - pub fn allocate(&mut self, func: &Function) -> HashMap { + pub fn allocate(&mut self, func: &Function, types: &TypeTable) -> HashMap { // Reset state self.locations.clear(); self.free_regs = Reg::allocatable().to_vec(); @@ -531,7 +531,7 @@ impl RegAlloc { self.fp_pseudos.clear(); // Identify which pseudos need FP registers - 
self.identify_fp_pseudos(func); + self.identify_fp_pseudos(func, types); // Pre-allocate argument registers (integer and FP separately) let int_arg_regs = Reg::arg_regs(); @@ -562,7 +562,7 @@ impl RegAlloc { if let PseudoKind::Arg(arg_idx) = pseudo.kind { // Match by adjusted index (sret shifts arg_idx but not register allocation) if arg_idx == (i as u32) + arg_idx_offset { - let is_fp = typ.is_float(); + let is_fp = types.is_float(*typ); if is_fp { if fp_arg_idx < fp_arg_regs.len() { self.locations @@ -622,7 +622,7 @@ impl RegAlloc { // Check if this is a local variable or a global symbol if let Some(local) = func.locals.get(name) { // Local variable - allocate stack space based on actual type size - let size = local.typ.size_bytes().max(8) as i32; + let size = (types.size_bits(local.typ) / 8).max(8) as i32; // Align to 8 bytes let aligned_size = (size + 7) & !7; self.stack_offset += aligned_size; @@ -679,7 +679,7 @@ impl RegAlloc { } /// Identify which pseudos need FP registers by scanning the IR - fn identify_fp_pseudos(&mut self, func: &Function) { + fn identify_fp_pseudos(&mut self, func: &Function, types: &TypeTable) { for block in &func.blocks { for insn in &block.insns { // Check if this instruction produces a floating-point result @@ -703,8 +703,8 @@ impl RegAlloc { // Also check if the type is floating point // (but exclude comparisons which always produce int regardless of operand type) - if let Some(ref typ) = insn.typ { - if typ.is_float() + if let Some(typ) = insn.typ { + if types.is_float(typ) && !matches!( insn.op, Opcode::FCmpOEq @@ -1022,7 +1022,7 @@ impl Aarch64CodeGen { Self { target, output: String::new(), - lir_buffer: Vec::new(), + lir_buffer: Vec::with_capacity(DEFAULT_LIR_BUFFER_CAPACITY), locations: HashMap::new(), pseudos: Vec::new(), current_fn: String::new(), @@ -1109,15 +1109,15 @@ impl Aarch64CodeGen { self.push_lir(Aarch64Inst::Directive(Directive::Text)); } - fn emit_global(&mut self, name: &str, typ: &Type, init: &Initializer) { - 
let size = typ.size_bits() / 8; + fn emit_global(&mut self, name: &str, typ: &TypeId, init: &Initializer, types: &TypeTable) { + let size = types.size_bits(*typ) / 8; let size = if size == 0 { 8 } else { size }; // Default to 8 bytes // Check storage class - skip .globl for static - let is_static = typ.modifiers.contains(TypeModifiers::STATIC); + let is_static = types.get(*typ).modifiers.contains(TypeModifiers::STATIC); // Get alignment from type info - let align = typ.alignment() as u32; + let align = types.alignment(*typ) as u32; // Use .comm for uninitialized external (non-static) globals let use_bss = matches!(init, Initializer::None) && !is_static; @@ -1217,13 +1217,13 @@ impl Aarch64CodeGen { result } - fn emit_function(&mut self, func: &Function) { + fn emit_function(&mut self, func: &Function, types: &TypeTable) { // Check if this function uses varargs let is_variadic = Self::is_variadic_function(func); // Register allocation let mut alloc = RegAlloc::new(); - self.locations = alloc.allocate(func); + self.locations = alloc.allocate(func, types); self.pseudos = func.pseudos.clone(); let stack_size = alloc.stack_size(); @@ -1463,7 +1463,7 @@ impl Aarch64CodeGen { self.num_fixed_gp_params = func .params .iter() - .filter(|(_, typ)| !typ.is_float()) + .filter(|(_, typ)| !types.is_float(*typ)) .count(); } @@ -1503,7 +1503,7 @@ impl Aarch64CodeGen { // Emit basic blocks for block in &func.blocks { - self.emit_block(block, &frame_info); + self.emit_block(block, &frame_info, types); } // CFI: End procedure @@ -1512,7 +1512,12 @@ impl Aarch64CodeGen { } } - fn emit_block(&mut self, block: &crate::ir::BasicBlock, frame_info: &(i32, Vec)) { + fn emit_block( + &mut self, + block: &crate::ir::BasicBlock, + frame_info: &(i32, Vec), + types: &TypeTable, + ) { // Emit block label using LIR (include function name for uniqueness) if let Some(label) = &block.label { // LIR: named block label (using Raw since format differs from standard) @@ -1530,11 +1535,11 @@ impl 
Aarch64CodeGen { // Emit instructions for insn in &block.insns { - self.emit_insn(insn, frame_info); + self.emit_insn(insn, frame_info, types); } } - fn emit_insn(&mut self, insn: &Instruction, frame_info: &(i32, Vec)) { + fn emit_insn(&mut self, insn: &Instruction, frame_info: &(i32, Vec), types: &TypeTable) { // Emit .loc directive for debug info self.emit_loc(insn); @@ -1862,7 +1867,7 @@ impl Aarch64CodeGen { } Opcode::Load => { - self.emit_load(insn, *total_frame); + self.emit_load(insn, *total_frame, types); } Opcode::Store => { @@ -1870,7 +1875,7 @@ impl Aarch64CodeGen { } Opcode::Call => { - self.emit_call(insn, *total_frame); + self.emit_call(insn, *total_frame, types); } Opcode::SetVal => { @@ -1914,13 +1919,7 @@ impl Aarch64CodeGen { Opcode::Copy => { if let (Some(target), Some(&src)) = (insn.target, insn.src.first()) { // Pass the type for proper sign/zero extension - self.emit_copy_with_type( - src, - target, - insn.size, - insn.typ.as_ref(), - *total_frame, - ); + self.emit_copy_with_type(src, target, insn.size, insn.typ, *total_frame, types); } } @@ -2006,7 +2005,7 @@ impl Aarch64CodeGen { } Opcode::VaArg => { - self.emit_va_arg(insn, *total_frame); + self.emit_va_arg(insn, *total_frame, types); } Opcode::VaEnd => { @@ -2532,7 +2531,7 @@ impl Aarch64CodeGen { } } - fn emit_load(&mut self, insn: &Instruction, frame_size: i32) { + fn emit_load(&mut self, insn: &Instruction, frame_size: i32, types: &TypeTable) { let mem_size = insn.size; let reg_size = insn.size.max(32); let addr = match insn.src.first() { @@ -2546,8 +2545,7 @@ impl Aarch64CodeGen { let dst_loc = self.get_location(target); // Check if this is an FP load - let is_fp = - insn.typ.as_ref().is_some_and(|t| t.is_float()) || matches!(dst_loc, Loc::VReg(_)); + let is_fp = insn.typ.is_some_and(|t| types.is_float(t)) || matches!(dst_loc, Loc::VReg(_)); if is_fp { self.emit_fp_load(insn, frame_size); @@ -2561,10 +2559,10 @@ impl Aarch64CodeGen { // Determine if we need sign or zero extension for 
small types // For plain char, use target.char_signed to determine signedness - let is_unsigned = insn.typ.as_ref().is_some_and(|t| { - if t.is_unsigned() { + let is_unsigned = insn.typ.is_some_and(|t| { + if types.is_unsigned(t) { true - } else if t.is_plain_char() { + } else if types.is_plain_char(t) { // Plain char: unsigned if target says char is not signed !self.target.char_signed } else { @@ -2977,7 +2975,7 @@ impl Aarch64CodeGen { } } - fn emit_call(&mut self, insn: &Instruction, frame_size: i32) { + fn emit_call(&mut self, insn: &Instruction, frame_size: i32, types: &TypeTable) { let func_name = match &insn.func_name { Some(n) => n.clone(), None => return, @@ -3004,26 +3002,9 @@ impl Aarch64CodeGen { let is_darwin_variadic = self.target.os == crate::target::Os::MacOS && insn.variadic_arg_start.is_some(); - // Check if this call returns a large struct - // If so, the first argument is the sret pointer and goes in X8 (not X0) - let returns_large_struct = insn.typ.as_ref().is_some_and(|t| { - (t.kind == crate::types::TypeKind::Struct || t.kind == crate::types::TypeKind::Union) - && t.size_bits() > MAX_REGISTER_AGGREGATE_BITS - }); - - // Also check if return type is a pointer to a large struct (linearizer wraps it) - let returns_large_struct = returns_large_struct - || insn.typ.as_ref().is_some_and(|t| { - if let Some(pointee) = t.get_base() { - (pointee.kind == crate::types::TypeKind::Struct - || pointee.kind == crate::types::TypeKind::Union) - && pointee.size_bits() > MAX_REGISTER_AGGREGATE_BITS - } else { - false - } - }); - - let args_start = if returns_large_struct && !insn.src.is_empty() { + // Check if this call returns a large struct via sret (hidden pointer argument). + // The linearizer sets is_sret_call=true and puts the sret pointer as the first arg. 
+ let args_start = if insn.is_sret_call && !insn.src.is_empty() { // First argument is sret pointer - move to X8 self.emit_move(insn.src[0], Reg::X8, 64, frame_size); 1 // Skip first arg in main loop @@ -3041,16 +3022,16 @@ impl Aarch64CodeGen { // Process all arguments for (i, &arg) in insn.src.iter().enumerate().skip(args_start) { - let arg_type = insn.arg_types.get(i); + let arg_type = insn.arg_types.get(i).copied(); let is_fp = if let Some(typ) = arg_type { - typ.is_float() + types.is_float(typ) } else { let arg_loc = self.get_location(arg); matches!(arg_loc, Loc::VReg(_) | Loc::FImm(_)) }; let arg_size = if let Some(typ) = arg_type { - typ.size_bits().max(32) + types.size_bits(typ).max(32) } else { 64 }; @@ -3062,7 +3043,7 @@ impl Aarch64CodeGen { // Fixed arg - use registers as normal if is_fp { let fp_size = if let Some(typ) = arg_type { - typ.size_bits() + types.size_bits(typ) } else { 64 }; @@ -3126,9 +3107,9 @@ impl Aarch64CodeGen { // Move arguments to registers for (i, &arg) in insn.src.iter().enumerate().skip(args_start) { // Get argument type if available, otherwise fall back to location-based detection - let arg_type = insn.arg_types.get(i); + let arg_type = insn.arg_types.get(i).copied(); let is_fp = if let Some(typ) = arg_type { - typ.is_float() + types.is_float(typ) } else { // Fall back to location-based detection for backwards compatibility let arg_loc = self.get_location(arg); @@ -3137,7 +3118,7 @@ impl Aarch64CodeGen { // Get argument size from type, with minimum 32-bit for register ops let arg_size = if let Some(typ) = arg_type { - typ.size_bits().max(32) + types.size_bits(typ).max(32) } else { 64 // Default for backwards compatibility }; @@ -3145,7 +3126,7 @@ impl Aarch64CodeGen { if is_fp { // FP size from type (32 for float, 64 for double) let fp_size = if let Some(typ) = arg_type { - typ.size_bits() + types.size_bits(typ) } else { 64 }; @@ -3210,8 +3191,8 @@ impl Aarch64CodeGen { if let Some(target) = insn.target { let dst_loc = 
self.get_location(target); // Check if return value is floating-point based on type or location - let is_fp_result = if let Some(ref typ) = insn.typ { - typ.is_float() + let is_fp_result = if let Some(typ) = insn.typ { + types.is_float(typ) } else { matches!(dst_loc, Loc::VReg(_) | Loc::FImm(_)) }; @@ -3378,8 +3359,9 @@ impl Aarch64CodeGen { src: PseudoId, dst: PseudoId, size: u32, - typ: Option<&Type>, + typ: Option, frame_size: i32, + types: &TypeTable, ) { // Keep actual size for handling narrow types let actual_size = size; @@ -3394,9 +3376,9 @@ impl Aarch64CodeGen { // Determine if the type is unsigned (for proper sign/zero extension) // For plain char, use target.char_signed to determine signedness let is_unsigned = typ.is_some_and(|t| { - if t.is_unsigned() { + if types.is_unsigned(t) { true - } else if t.is_plain_char() { + } else if types.is_plain_char(t) { // Plain char: unsigned if target says char is not signed !self.target.char_signed } else { @@ -3955,7 +3937,7 @@ impl Aarch64CodeGen { /// Emit va_arg: Get the next variadic argument of the specified type /// Note: ap_addr is the ADDRESS of the va_list variable (from symaddr), not the va_list itself - fn emit_va_arg(&mut self, insn: &Instruction, frame_size: i32) { + fn emit_va_arg(&mut self, insn: &Instruction, frame_size: i32, types: &TypeTable) { let ap_addr = match insn.src.first() { Some(&s) => s, None => return, @@ -3965,9 +3947,8 @@ impl Aarch64CodeGen { None => return, }; - let default_type = Type::basic(crate::types::TypeKind::Int); - let arg_type = insn.typ.as_ref().unwrap_or(&default_type); - let arg_size = arg_type.size_bits().max(32); + let arg_type = insn.typ.unwrap_or(types.int_id); + let arg_size = types.size_bits(arg_type).max(32); let arg_bytes = (arg_size / 8).max(8) as i64; // Minimum 8 bytes per slot on ARM64 let ap_loc = self.get_location(ap_addr); @@ -4012,9 +3993,9 @@ impl Aarch64CodeGen { }); // Load the argument from *ap - if arg_type.is_float() { + if types.is_float(arg_type) 
{ // Load floating point value - let fp_size = arg_type.size_bits(); + let fp_size = types.size_bits(arg_type); let fp_size_enum = FpSize::from_bits(fp_size); self.push_lir(Aarch64Inst::LdrFp { size: fp_size_enum, @@ -4494,7 +4475,7 @@ impl Aarch64CodeGen { // ============================================================================ impl CodeGenerator for Aarch64CodeGen { - fn generate(&mut self, module: &Module) -> String { + fn generate(&mut self, module: &Module, types: &TypeTable) -> String { self.output.clear(); self.lir_buffer.clear(); self.last_debug_line = 0; @@ -4515,7 +4496,7 @@ impl CodeGenerator for Aarch64CodeGen { // Emit globals for (name, typ, init) in &module.globals { - self.emit_global(name, typ, init); + self.emit_global(name, typ, init, types); } // Emit string literals @@ -4523,7 +4504,7 @@ impl CodeGenerator for Aarch64CodeGen { // Emit functions for func in &module.functions { - self.emit_function(func); + self.emit_function(func, types); } // Flush all buffered LIR instructions to output diff --git a/cc/arch/codegen.rs b/cc/arch/codegen.rs index fc0610af..53bfb432 100644 --- a/cc/arch/codegen.rs +++ b/cc/arch/codegen.rs @@ -11,6 +11,7 @@ use crate::ir::Module; use crate::target::Target; +use crate::types::TypeTable; /// pcc version string for assembly header pub const PCC_VERSION: &str = env!("CARGO_PKG_VERSION"); @@ -49,7 +50,7 @@ pub fn generate_header_comments(target: &Target) -> Vec { /// Trait for architecture-specific code generators pub trait CodeGenerator { /// Generate assembly code for the given IR module - fn generate(&mut self, module: &Module) -> String; + fn generate(&mut self, module: &Module, types: &TypeTable) -> String; /// Set whether to emit basic unwind tables (cfi_startproc/cfi_endproc) fn set_emit_unwind_tables(&mut self, emit: bool); diff --git a/cc/arch/mod.rs b/cc/arch/mod.rs index 36282129..09bafb26 100644 --- a/cc/arch/mod.rs +++ b/cc/arch/mod.rs @@ -9,6 +9,9 @@ // Architecture-specific predefined macros and 
code generators // +/// Default capacity for LIR instruction buffers (reduces reallocation overhead) +pub const DEFAULT_LIR_BUFFER_CAPACITY: usize = 5000; + pub mod aarch64; pub mod codegen; pub mod lir; diff --git a/cc/arch/x86_64/codegen.rs b/cc/arch/x86_64/codegen.rs index cf52771d..844c7408 100644 --- a/cc/arch/x86_64/codegen.rs +++ b/cc/arch/x86_64/codegen.rs @@ -19,9 +19,10 @@ use crate::arch::lir::{Directive, FpSize, Label, OperandSize, Symbol}; use crate::arch::x86_64::lir::{ CallTarget, GpOperand, IntCC, MemAddr, ShiftCount, X86Inst, XmmOperand, }; +use crate::arch::DEFAULT_LIR_BUFFER_CAPACITY; use crate::ir::{Function, Initializer, Instruction, Module, Opcode, Pseudo, PseudoId, PseudoKind}; use crate::target::Target; -use crate::types::{Type, TypeModifiers}; +use crate::types::{TypeId, TypeModifiers, TypeTable}; use std::collections::HashMap; // ============================================================================ @@ -348,7 +349,7 @@ impl RegAlloc { } /// Perform register allocation for a function - pub fn allocate(&mut self, func: &Function) -> HashMap { + pub fn allocate(&mut self, func: &Function, types: &TypeTable) -> HashMap { // Reset state self.locations.clear(); self.free_regs = Reg::allocatable().to_vec(); @@ -360,7 +361,7 @@ impl RegAlloc { self.fp_pseudos.clear(); // Scan instructions to identify which pseudos need FP registers - self.identify_fp_pseudos(func); + self.identify_fp_pseudos(func, types); // Pre-allocate argument registers // System V AMD64 ABI: integer args in RDI, RSI, RDX, RCX, R8, R9 @@ -390,7 +391,7 @@ impl RegAlloc { for pseudo in &func.pseudos { if let PseudoKind::Arg(arg_idx) = pseudo.kind { if arg_idx == (i as u32) + arg_idx_offset { - let is_fp = typ.is_float(); + let is_fp = types.is_float(*typ); if is_fp { // FP argument if fp_arg_idx < fp_arg_regs.len() { @@ -515,7 +516,7 @@ impl RegAlloc { // Check if this is a local variable or a global symbol if let Some(local_var) = func.locals.get(name) { // Local variable 
- allocate stack space based on type size - let size = (local_var.typ.size_bits() / 8) as i32; + let size = (types.size_bits(local_var.typ) / 8) as i32; let size = std::cmp::max(size, 8); // Minimum 8 bytes // Align to 8 bytes let aligned_size = (size + 7) & !7; @@ -523,7 +524,7 @@ impl RegAlloc { self.locations .insert(interval.pseudo, Loc::Stack(self.stack_offset)); // Mark as FP if the type is float - if local_var.typ.is_float() { + if types.is_float(local_var.typ) { self.fp_pseudos.insert(interval.pseudo); } } else { @@ -575,7 +576,7 @@ impl RegAlloc { } /// Scan function to identify which pseudos need FP registers - fn identify_fp_pseudos(&mut self, func: &Function) { + fn identify_fp_pseudos(&mut self, func: &Function, types: &TypeTable) { for block in &func.blocks { for insn in &block.insns { // Check instruction type - FP instructions produce FP results @@ -617,8 +618,8 @@ impl RegAlloc { // Also check the type if available // (but exclude comparisons which always produce int regardless of operand type) - if let Some(ref typ) = insn.typ { - if typ.is_float() + if let Some(typ) = insn.typ { + if types.is_float(typ) && !matches!( insn.op, Opcode::FCmpOEq @@ -879,7 +880,7 @@ impl X86_64CodeGen { Self { target, output: String::new(), - lir_buffer: Vec::new(), + lir_buffer: Vec::with_capacity(DEFAULT_LIR_BUFFER_CAPACITY), locations: HashMap::new(), pseudos: Vec::new(), current_fn: String::new(), @@ -956,15 +957,15 @@ impl X86_64CodeGen { self.push_lir(X86Inst::Directive(Directive::Text)); } - fn emit_global(&mut self, name: &str, typ: &Type, init: &Initializer) { - let size = typ.size_bits() / 8; + fn emit_global(&mut self, name: &str, typ: &TypeId, init: &Initializer, types: &TypeTable) { + let size = types.size_bits(*typ) / 8; let size = if size == 0 { 8 } else { size }; // Default to 8 bytes // Check storage class - skip .globl for static - let is_static = typ.modifiers.contains(TypeModifiers::STATIC); + let is_static = 
types.get(*typ).modifiers.contains(TypeModifiers::STATIC); // Get alignment from type info - let align = typ.alignment() as u32; + let align = types.alignment(*typ) as u32; // Use .comm for uninitialized external (non-static) globals let use_bss = matches!(init, Initializer::None) && !is_static; @@ -1033,7 +1034,7 @@ impl X86_64CodeGen { false } - fn emit_function(&mut self, func: &Function) { + fn emit_function(&mut self, func: &Function, types: &TypeTable) { // Save current function name for unique label generation self.current_fn = func.name.clone(); @@ -1042,7 +1043,7 @@ impl X86_64CodeGen { // Register allocation let mut alloc = RegAlloc::new(); - self.locations = alloc.allocate(func); + self.locations = alloc.allocate(func, types); self.pseudos = func.pseudos.clone(); let stack_size = alloc.stack_size(); @@ -1178,14 +1179,14 @@ impl X86_64CodeGen { for pseudo in &func.pseudos { if let PseudoKind::Arg(arg_idx) = pseudo.kind { if arg_idx == (i as u32) + arg_idx_offset { - let is_fp = typ.is_float(); + let is_fp = types.is_float(*typ); if is_fp { // FP argument if fp_arg_idx < fp_arg_regs.len() { if let Some(Loc::Stack(offset)) = self.locations.get(&pseudo.id) { // Move from FP arg register to stack let adjusted = offset + self.callee_saved_offset; - let fp_size = if typ.size_bits() == 32 { + let fp_size = if types.size_bits(*typ) == 32 { FpSize::Single } else { FpSize::Double @@ -1231,7 +1232,7 @@ impl X86_64CodeGen { self.num_fixed_gp_params = func .params .iter() - .filter(|(_, typ)| !typ.is_float()) + .filter(|(_, typ)| !types.is_float(*typ)) .count(); if has_sret { self.num_fixed_gp_params += 1; // Account for hidden sret pointer @@ -1240,7 +1241,7 @@ impl X86_64CodeGen { // Emit basic blocks for block in &func.blocks { - self.emit_block(block); + self.emit_block(block, types); } // CFI: End procedure @@ -1249,7 +1250,7 @@ impl X86_64CodeGen { } } - fn emit_block(&mut self, block: &crate::ir::BasicBlock) { + fn emit_block(&mut self, block: 
&crate::ir::BasicBlock, types: &TypeTable) { // Emit block label (include function name for uniqueness) if let Some(label) = &block.label { // LIR: named block label (using Raw since format differs from standard) @@ -1267,11 +1268,11 @@ impl X86_64CodeGen { // Emit instructions for insn in &block.insns { - self.emit_insn(insn); + self.emit_insn(insn, types); } } - fn emit_insn(&mut self, insn: &Instruction) { + fn emit_insn(&mut self, insn: &Instruction, types: &TypeTable) { // Emit .loc directive for debug info self.emit_loc(insn); @@ -1286,7 +1287,7 @@ impl X86_64CodeGen { if let Some(src) = insn.src.first() { let src_loc = self.get_location(*src); let is_fp = matches!(src_loc, Loc::Xmm(_) | Loc::FImm(..)) - || insn.typ.as_ref().is_some_and(|t| t.is_float()); + || insn.typ.is_some_and(|t| types.is_float(t)); if is_fp { // Float return value goes in XMM0 self.emit_fp_move(*src, XmmReg::Xmm0, insn.size); @@ -1550,15 +1551,15 @@ impl X86_64CodeGen { } Opcode::Load => { - self.emit_load(insn); + self.emit_load(insn, types); } Opcode::Store => { - self.emit_store(insn); + self.emit_store(insn, types); } Opcode::Call => { - self.emit_call(insn); + self.emit_call(insn, types); } Opcode::SetVal => { @@ -1593,7 +1594,7 @@ impl X86_64CodeGen { Opcode::Copy => { if let (Some(target), Some(&src)) = (insn.target, insn.src.first()) { // Pass the type to emit_copy for proper sign/zero extension - self.emit_copy_with_type(src, target, insn.size, insn.typ.as_ref()); + self.emit_copy_with_type(src, target, insn.size, insn.typ, types); } } @@ -1653,7 +1654,7 @@ impl X86_64CodeGen { } Opcode::VaArg => { - self.emit_va_arg(insn); + self.emit_va_arg(insn, types); } Opcode::VaEnd => { @@ -2873,7 +2874,7 @@ impl X86_64CodeGen { } } - fn emit_load(&mut self, insn: &Instruction) { + fn emit_load(&mut self, insn: &Instruction, types: &TypeTable) { let mem_size = insn.size; let reg_size = insn.size.max(32); let addr = match insn.src.first() { @@ -2887,8 +2888,7 @@ impl X86_64CodeGen { let 
dst_loc = self.get_location(target); // Check if this is an FP load - let is_fp = - insn.typ.as_ref().is_some_and(|t| t.is_float()) || matches!(dst_loc, Loc::Xmm(_)); + let is_fp = insn.typ.is_some_and(|t| types.is_float(t)) || matches!(dst_loc, Loc::Xmm(_)); if is_fp { self.emit_fp_load(insn); @@ -2903,10 +2903,10 @@ impl X86_64CodeGen { // Determine if we need sign or zero extension for small types // is_unsigned() returns true for explicitly unsigned types // For plain char, use target.char_signed to determine signedness - let is_unsigned = insn.typ.as_ref().is_some_and(|t| { - if t.is_unsigned() { + let is_unsigned = insn.typ.is_some_and(|t| { + if types.is_unsigned(t) { true - } else if t.is_plain_char() { + } else if types.is_plain_char(t) { // Plain char: unsigned if target says char is not signed !self.target.char_signed } else { @@ -3127,7 +3127,7 @@ impl X86_64CodeGen { } } - fn emit_store(&mut self, insn: &Instruction) { + fn emit_store(&mut self, insn: &Instruction, types: &TypeTable) { // Use actual size for memory stores (8, 16, 32, 64 bits) // This is critical for char/short types that need byte/word stores let mem_size = insn.size; @@ -3141,7 +3141,7 @@ impl X86_64CodeGen { // Check if this is an FP store let value_loc = self.get_location(value); - let is_fp = insn.typ.as_ref().is_some_and(|t| t.is_float()) + let is_fp = insn.typ.is_some_and(|t| types.is_float(t)) || matches!(value_loc, Loc::Xmm(_) | Loc::FImm(..)); if is_fp { @@ -3350,7 +3350,7 @@ impl X86_64CodeGen { } } - fn emit_call(&mut self, insn: &Instruction) { + fn emit_call(&mut self, insn: &Instruction, types: &TypeTable) { let func_name = match &insn.func_name { Some(n) => n.clone(), None => return, @@ -3379,9 +3379,9 @@ impl X86_64CodeGen { let mut temp_fp_idx = 0; for i in 0..insn.src.len() { - let arg_type = insn.arg_types.get(i); + let arg_type = insn.arg_types.get(i).copied(); let is_fp = if let Some(typ) = arg_type { - typ.is_float() + types.is_float(typ) } else { let arg_loc = 
self.get_location(insn.src[i]); matches!(arg_loc, Loc::Xmm(_) | Loc::FImm(..)) @@ -3421,9 +3421,9 @@ impl X86_64CodeGen { // Push stack arguments in reverse order for &i in stack_arg_indices.iter().rev() { let arg = insn.src[i]; - let arg_type = insn.arg_types.get(i); + let arg_type = insn.arg_types.get(i).copied(); let is_fp = if let Some(typ) = arg_type { - typ.is_float() + types.is_float(typ) } else { let arg_loc = self.get_location(arg); matches!(arg_loc, Loc::Xmm(_) | Loc::FImm(..)) @@ -3431,7 +3431,7 @@ impl X86_64CodeGen { if is_fp { let fp_size = if let Some(typ) = arg_type { - typ.size_bits() + types.size_bits(typ) } else { 64 }; @@ -3452,7 +3452,7 @@ impl X86_64CodeGen { }); } else { let arg_size = if let Some(typ) = arg_type { - typ.size_bits().max(32) + types.size_bits(typ).max(32) } else { 64 }; @@ -3472,9 +3472,9 @@ impl X86_64CodeGen { } let arg = insn.src[i]; // Get argument type if available, otherwise fall back to location-based detection - let arg_type = insn.arg_types.get(i); + let arg_type = insn.arg_types.get(i).copied(); let is_fp = if let Some(typ) = arg_type { - typ.is_float() + types.is_float(typ) } else { // Fall back to location-based detection for backwards compatibility let arg_loc = self.get_location(arg); @@ -3483,7 +3483,7 @@ impl X86_64CodeGen { // Get argument size from type, with minimum 32-bit for register ops let arg_size = if let Some(typ) = arg_type { - typ.size_bits().max(32) + types.size_bits(typ).max(32) } else { 64 // Default for backwards compatibility }; @@ -3491,7 +3491,7 @@ impl X86_64CodeGen { if is_fp { // FP size from type (32 for float, 64 for double) let fp_size = if let Some(typ) = arg_type { - typ.size_bits() + types.size_bits(typ) } else { 64 }; @@ -3534,8 +3534,8 @@ impl X86_64CodeGen { if let Some(target) = insn.target { let dst_loc = self.get_location(target); // Check if return value is floating-point based on its location or type - let is_fp_result = if let Some(ref typ) = insn.typ { - typ.is_float() + 
let is_fp_result = if let Some(typ) = insn.typ { + types.is_float(typ) } else { matches!(dst_loc, Loc::Xmm(_) | Loc::FImm(..)) }; @@ -3710,7 +3710,14 @@ impl X86_64CodeGen { } } - fn emit_copy_with_type(&mut self, src: PseudoId, dst: PseudoId, size: u32, typ: Option<&Type>) { + fn emit_copy_with_type( + &mut self, + src: PseudoId, + dst: PseudoId, + size: u32, + typ: Option, + types: &TypeTable, + ) { // Keep actual size for handling narrow types let actual_size = size; let reg_size = size.max(32); @@ -3724,9 +3731,9 @@ impl X86_64CodeGen { // Determine if the type is unsigned (for proper sign/zero extension) // For plain char, use target.char_signed to determine signedness let is_unsigned = typ.is_some_and(|t| { - if t.is_unsigned() { + if types.is_unsigned(t) { true - } else if t.is_plain_char() { + } else if types.is_plain_char(t) { // Plain char: unsigned if target says char is not signed !self.target.char_signed } else { @@ -3953,7 +3960,7 @@ impl X86_64CodeGen { } /// Emit va_arg: Get the next variadic argument - fn emit_va_arg(&mut self, insn: &Instruction) { + fn emit_va_arg(&mut self, insn: &Instruction, types: &TypeTable) { let ap_addr = match insn.src.first() { Some(&s) => s, None => return, @@ -3963,9 +3970,8 @@ impl X86_64CodeGen { None => return, }; - let default_type = Type::basic(crate::types::TypeKind::Int); - let arg_type = insn.typ.as_ref().unwrap_or(&default_type); - let arg_size = arg_type.size_bits().max(32); + let arg_type = insn.typ.unwrap_or(types.int_id); + let arg_size = types.size_bits(arg_type).max(32); let arg_bytes = (arg_size / 8).max(8) as i32; // Minimum 8 bytes per slot let ap_loc = self.get_location(ap_addr); @@ -3977,7 +3983,7 @@ impl X86_64CodeGen { match &ap_loc { Loc::Stack(ap_offset) => { - if arg_type.is_float() { + if types.is_float(arg_type) { // For float args, use overflow_arg_area (FP register save not implemented) // LIR: load overflow_arg_area pointer self.push_lir(X86Inst::Mov { @@ -3989,7 +3995,7 @@ impl 
X86_64CodeGen { dst: GpOperand::Reg(Reg::Rax), }); - let fp_size = arg_type.size_bits(); + let fp_size = types.size_bits(arg_type); let lir_fp_size = if fp_size <= 32 { FpSize::Single } else { @@ -4223,7 +4229,7 @@ impl X86_64CodeGen { } Loc::Reg(ap_reg) => { // Similar logic for when va_list is in a register - if arg_type.is_float() { + if types.is_float(arg_type) { // LIR: load overflow_arg_area pointer (at offset 8 from va_list base) self.push_lir(X86Inst::Mov { size: OperandSize::B64, @@ -4234,7 +4240,7 @@ impl X86_64CodeGen { dst: GpOperand::Reg(Reg::Rax), }); - let fp_size = arg_type.size_bits(); + let fp_size = types.size_bits(arg_type); let lir_fp_size = if fp_size <= 32 { FpSize::Single } else { @@ -5062,7 +5068,7 @@ impl X86_64CodeGen { // ============================================================================ impl CodeGenerator for X86_64CodeGen { - fn generate(&mut self, module: &Module) -> String { + fn generate(&mut self, module: &Module, types: &TypeTable) -> String { self.output.clear(); self.last_debug_line = 0; self.last_debug_file = 0; @@ -5082,7 +5088,7 @@ impl CodeGenerator for X86_64CodeGen { // Emit globals for (name, typ, init) in &module.globals { - self.emit_global(name, typ, init); + self.emit_global(name, typ, init, types); } // Emit string literals @@ -5092,7 +5098,7 @@ impl CodeGenerator for X86_64CodeGen { // Emit functions for func in &module.functions { - self.emit_function(func); + self.emit_function(func, types); } // Emit all buffered LIR instructions to output string diff --git a/cc/dominate.rs b/cc/dominate.rs index fd15b681..93d5e3c8 100644 --- a/cc/dominate.rs +++ b/cc/dominate.rs @@ -499,7 +499,7 @@ fn visit_domtree( mod tests { use super::*; use crate::ir::{BasicBlock, Instruction, Opcode}; - use crate::types::{Type, TypeKind}; + use crate::types::TypeTable; fn make_test_cfg() -> Function { // Create a simple CFG: @@ -514,7 +514,8 @@ mod tests { // v // exit(4) - let mut func = Function::new("test", 
Type::basic(TypeKind::Void)); + let types = TypeTable::new(); + let mut func = Function::new("test", types.void_id); let mut entry = BasicBlock::new(BasicBlockId(0)); entry.children = vec![BasicBlockId(1), BasicBlockId(2)]; diff --git a/cc/ir.rs b/cc/ir.rs index 405c7659..8ab66aaf 100644 --- a/cc/ir.rs +++ b/cc/ir.rs @@ -15,7 +15,7 @@ // use crate::diag::Position; -use crate::types::Type; +use crate::types::TypeId; use std::collections::HashMap; use std::fmt; @@ -419,8 +419,8 @@ pub struct Instruction { pub target: Option, /// Source operands pub src: Vec, - /// Type of the result - pub typ: Option, + /// Type of the result (interned TypeId) + pub typ: Option, /// For branches: true target pub bb_true: Option, /// For conditional branches: false target @@ -439,11 +439,14 @@ pub struct Instruction { pub switch_cases: Vec<(i64, BasicBlockId)>, /// For switch: default block (if no case matches) pub switch_default: Option, - /// For calls: argument types (parallel to src for Call instructions) - pub arg_types: Vec, + /// For calls: argument types (parallel to src for Call instructions, interned TypeIds) + pub arg_types: Vec, /// For variadic calls: index where variadic arguments start (0-based) /// All arguments at this index and beyond are variadic (should be passed on stack) pub variadic_arg_start: Option, + /// For calls: true if this call returns a large struct via sret (hidden pointer arg). + /// The first element of `src` is the sret pointer when this is true. 
+ pub is_sret_call: bool, /// Source position for debug info pub pos: Option, } @@ -466,6 +469,7 @@ impl Default for Instruction { switch_default: None, arg_types: Vec::new(), variadic_arg_start: None, + is_sret_call: false, pos: None, } } @@ -504,13 +508,19 @@ impl Instruction { self } - /// Set the type - pub fn with_type(mut self, typ: Type) -> Self { - self.size = typ.size_bits(); + /// Set the type (caller should also call with_size if needed) + pub fn with_type(mut self, typ: TypeId) -> Self { self.typ = Some(typ); self } + /// Set type and size together (convenience for callers with TypeTable access) + pub fn with_type_and_size(mut self, typ: TypeId, size: u32) -> Self { + self.typ = Some(typ); + self.size = size; + self + } + /// Set the true branch target pub fn with_bb_true(mut self, bb: BasicBlockId) -> Self { self.bb_true = Some(bb); @@ -557,10 +567,8 @@ impl Instruction { } /// Create a return instruction with type - pub fn ret_typed(src: Option, typ: Type) -> Self { - let mut insn = Self::ret(src).with_type(typ.clone()); - insn.size = typ.size_bits(); - insn + pub fn ret_typed(src: Option, typ: TypeId, size: u32) -> Self { + Self::ret(src).with_type_and_size(typ, size) } /// Create an unconditional branch @@ -594,37 +602,44 @@ impl Instruction { } /// Create a binary operation - pub fn binop(op: Opcode, target: PseudoId, src1: PseudoId, src2: PseudoId, typ: Type) -> Self { + pub fn binop( + op: Opcode, + target: PseudoId, + src1: PseudoId, + src2: PseudoId, + typ: TypeId, + size: u32, + ) -> Self { Self::new(op) .with_target(target) .with_src2(src1, src2) - .with_type(typ) + .with_type_and_size(typ, size) } /// Create a unary operation - pub fn unop(op: Opcode, target: PseudoId, src: PseudoId, typ: Type) -> Self { + pub fn unop(op: Opcode, target: PseudoId, src: PseudoId, typ: TypeId, size: u32) -> Self { Self::new(op) .with_target(target) .with_src(src) - .with_type(typ) + .with_type_and_size(typ, size) } /// Create a load instruction - pub fn 
load(target: PseudoId, addr: PseudoId, offset: i64, typ: Type) -> Self { + pub fn load(target: PseudoId, addr: PseudoId, offset: i64, typ: TypeId, size: u32) -> Self { Self::new(Opcode::Load) .with_target(target) .with_src(addr) .with_offset(offset) - .with_type(typ) + .with_type_and_size(typ, size) } /// Create a store instruction - pub fn store(value: PseudoId, addr: PseudoId, offset: i64, typ: Type) -> Self { + pub fn store(value: PseudoId, addr: PseudoId, offset: i64, typ: TypeId, size: u32) -> Self { Self::new(Opcode::Store) .with_src(addr) .with_src(value) .with_offset(offset) - .with_type(typ) + .with_type_and_size(typ, size) } /// Create a call instruction @@ -632,10 +647,13 @@ impl Instruction { target: Option, func: &str, args: Vec, - arg_types: Vec, - ret_type: Type, + arg_types: Vec, + ret_type: TypeId, + ret_size: u32, ) -> Self { - let mut insn = Self::new(Opcode::Call).with_func(func).with_type(ret_type); + let mut insn = Self::new(Opcode::Call) + .with_func(func) + .with_type_and_size(ret_type, ret_size); if let Some(t) = target { insn.target = Some(t); } @@ -645,16 +663,18 @@ impl Instruction { } /// Create a symbol address instruction (get address of a symbol like string literals) - pub fn sym_addr(target: PseudoId, sym: PseudoId, typ: Type) -> Self { + pub fn sym_addr(target: PseudoId, sym: PseudoId, typ: TypeId) -> Self { Self::new(Opcode::SymAddr) .with_target(target) .with_src(sym) - .with_type(typ) + .with_type_and_size(typ, 64) // Pointers are always 64-bit } /// Create a phi node - pub fn phi(target: PseudoId, typ: Type) -> Self { - Self::new(Opcode::Phi).with_target(target).with_type(typ) + pub fn phi(target: PseudoId, typ: TypeId, size: u32) -> Self { + Self::new(Opcode::Phi) + .with_target(target) + .with_type_and_size(typ, size) } /// Create a select (ternary) instruction @@ -663,12 +683,13 @@ impl Instruction { cond: PseudoId, if_true: PseudoId, if_false: PseudoId, - typ: Type, + typ: TypeId, + size: u32, ) -> Self { 
Self::new(Opcode::Select) .with_target(target) .with_src3(cond, if_true, if_false) - .with_type(typ) + .with_type_and_size(typ, size) } } @@ -883,8 +904,8 @@ impl fmt::Display for BasicBlock { pub struct LocalVar { /// Symbol pseudo for this variable (address) pub sym: PseudoId, - /// Type of the variable - pub typ: Type, + /// Type of the variable (interned TypeId) + pub typ: TypeId, /// Is this variable volatile? pub is_volatile: bool, /// Block where this variable was declared (for scope-aware phi placement) @@ -897,10 +918,10 @@ pub struct LocalVar { pub struct Function { /// Function name pub name: String, - /// Return type - pub return_type: Type, - /// Parameter names and types - pub params: Vec<(String, Type)>, + /// Return type (interned TypeId) + pub return_type: TypeId, + /// Parameter names and types (interned TypeIds) + pub params: Vec<(String, TypeId)>, /// All basic blocks pub blocks: Vec, /// Entry block ID @@ -919,7 +940,7 @@ impl Default for Function { fn default() -> Self { Self { name: String::new(), - return_type: Type::default(), + return_type: TypeId::INVALID, params: Vec::new(), blocks: Vec::new(), entry: BasicBlockId(0), @@ -933,7 +954,7 @@ impl Default for Function { impl Function { /// Create a new function - pub fn new(name: impl Into, return_type: Type) -> Self { + pub fn new(name: impl Into, return_type: TypeId) -> Self { Self { name: name.into(), return_type, @@ -942,7 +963,7 @@ impl Function { } /// Add a parameter - pub fn add_param(&mut self, name: impl Into, typ: Type) { + pub fn add_param(&mut self, name: impl Into, typ: TypeId) { self.params.push((name.into(), typ)); } @@ -971,7 +992,7 @@ impl Function { &mut self, name: impl Into, sym: PseudoId, - typ: Type, + typ: TypeId, is_volatile: bool, decl_block: Option, ) { @@ -1015,12 +1036,12 @@ impl Function { impl fmt::Display for Function { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { // Function header - write!(f, "define {} {}(", self.return_type, self.name)?; + 
write!(f, "define type#{} {}(", self.return_type.0, self.name)?; for (i, (name, typ)) in self.params.iter().enumerate() { if i > 0 { write!(f, ", ")?; } - write!(f, "{} %{}", typ, name)?; + write!(f, "type#{} %{}", typ.0, name)?; } writeln!(f, ") {{")?; @@ -1069,7 +1090,7 @@ pub struct Module { /// Functions pub functions: Vec, /// Global variables (name, type, initializer) - pub globals: Vec<(String, Type, Initializer)>, + pub globals: Vec<(String, TypeId, Initializer)>, /// String literals (label, content) pub strings: Vec<(String, String)>, /// Generate debug info @@ -1096,7 +1117,7 @@ impl Module { } /// Add a global variable - pub fn add_global(&mut self, name: impl Into, typ: Type, init: Initializer) { + pub fn add_global(&mut self, name: impl Into, typ: TypeId, init: Initializer) { self.globals.push((name.into(), typ, init)); } @@ -1119,8 +1140,8 @@ impl fmt::Display for Module { // Globals for (name, typ, init) in &self.globals { match init { - Initializer::None => writeln!(f, "@{}: {}", name, typ)?, - _ => writeln!(f, "@{}: {} = {}", name, typ, init)?, + Initializer::None => writeln!(f, "@{}: type#{}", name, typ.0)?, + _ => writeln!(f, "@{}: type#{} = {}", name, typ.0, init)?, } } @@ -1144,7 +1165,7 @@ impl fmt::Display for Module { #[cfg(test)] mod tests { use super::*; - use crate::types::TypeKind; + use crate::types::{Type, TypeTable}; #[test] fn test_opcode_is_terminator() { @@ -1175,8 +1196,15 @@ mod tests { #[test] fn test_instruction_binop() { - let typ = Type::basic(TypeKind::Int); - let insn = Instruction::binop(Opcode::Add, PseudoId(3), PseudoId(1), PseudoId(2), typ); + let types = TypeTable::new(); + let insn = Instruction::binop( + Opcode::Add, + PseudoId(3), + PseudoId(1), + PseudoId(2), + types.int_id, + 32, + ); assert_eq!(insn.op, Opcode::Add); assert_eq!(insn.target, Some(PseudoId(3))); assert_eq!(insn.src.len(), 2); @@ -1184,8 +1212,15 @@ mod tests { #[test] fn test_instruction_display() { - let typ = Type::basic(TypeKind::Int); - let 
insn = Instruction::binop(Opcode::Add, PseudoId(3), PseudoId(1), PseudoId(2), typ); + let types = TypeTable::new(); + let insn = Instruction::binop( + Opcode::Add, + PseudoId(3), + PseudoId(1), + PseudoId(2), + types.int_id, + 32, + ); let s = format!("{}", insn); assert!(s.contains("add")); assert!(s.contains("%3")); @@ -1207,8 +1242,9 @@ mod tests { #[test] fn test_function_display() { - let mut func = Function::new("main", Type::basic(TypeKind::Int)); - func.add_param("argc", Type::basic(TypeKind::Int)); + let types = TypeTable::new(); + let mut func = Function::new("main", types.int_id); + func.add_param("argc", types.int_id); let mut entry = BasicBlock::new(BasicBlockId(0)); entry.add_insn(Instruction::ret(Some(PseudoId(0)))); @@ -1236,17 +1272,16 @@ mod tests { #[test] fn test_call_instruction() { - let typ = Type::basic(TypeKind::Int); - let arg_types = vec![ - Type::pointer(Type::basic(TypeKind::Char)), - Type::basic(TypeKind::Int), - ]; + let mut types = TypeTable::new(); + let char_ptr = types.intern(Type::pointer(types.char_id)); + let arg_types = vec![char_ptr, types.int_id]; let call = Instruction::call( Some(PseudoId(1)), "printf", vec![PseudoId(2), PseudoId(3)], arg_types.clone(), - typ, + types.int_id, + 32, ); assert_eq!(call.op, Opcode::Call); assert_eq!(call.func_name, Some("printf".to_string())); @@ -1256,24 +1291,25 @@ mod tests { #[test] fn test_load_store() { - let typ = Type::basic(TypeKind::Int); + let types = TypeTable::new(); - let load = Instruction::load(PseudoId(1), PseudoId(2), 8, typ.clone()); + let load = Instruction::load(PseudoId(1), PseudoId(2), 8, types.int_id, 32); assert_eq!(load.op, Opcode::Load); assert_eq!(load.offset, 8); - let store = Instruction::store(PseudoId(1), PseudoId(2), 0, typ); + let store = Instruction::store(PseudoId(1), PseudoId(2), 0, types.int_id, 32); assert_eq!(store.op, Opcode::Store); assert_eq!(store.src.len(), 2); } #[test] fn test_module() { + let types = TypeTable::new(); let mut module = 
Module::new(); - module.add_global("counter", Type::basic(TypeKind::Int), Initializer::Int(0)); + module.add_global("counter", types.int_id, Initializer::Int(0)); - let func = Function::new("main", Type::basic(TypeKind::Int)); + let func = Function::new("main", types.int_id); module.add_function(func); assert_eq!(module.globals.len(), 1); diff --git a/cc/linearize.rs b/cc/linearize.rs index 9628d59a..36103ad4 100644 --- a/cc/linearize.rs +++ b/cc/linearize.rs @@ -20,21 +20,17 @@ use crate::parse::ast::{ }; use crate::ssa::ssa_convert; use crate::symbol::SymbolTable; -use crate::types::{MemberInfo, Type, TypeKind, TypeModifiers}; +use crate::target::Target; +use crate::types::{MemberInfo, TypeId, TypeKind, TypeModifiers, TypeTable}; use std::collections::HashMap; -/// Maximum size (in bits) for aggregate types (struct/union) to be passed or -/// returned by value in registers. Aggregates larger than this require -/// indirect passing (pointer) or sret (struct return pointer). -pub const MAX_REGISTER_AGGREGATE_BITS: u32 = 64; - /// Information about a local variable #[derive(Clone)] struct LocalVarInfo { /// Symbol pseudo (address of the variable) sym: PseudoId, /// Type of the variable - typ: Type, + typ: TypeId, } /// Information about a static local variable @@ -43,7 +39,7 @@ struct StaticLocalInfo { /// Global symbol name (unique across translation unit) global_name: String, /// Type of the variable - typ: Type, + typ: TypeId, } // ============================================================================ @@ -76,6 +72,8 @@ pub struct Linearizer<'a> { run_ssa: bool, /// Symbol table for looking up enum constants, etc. 
symbols: &'a SymbolTable, + /// Type table for type information + types: &'a TypeTable, /// Hidden struct return pointer (for functions returning large structs) struct_return_ptr: Option, /// Size of struct being returned (for functions returning large structs) @@ -89,11 +87,13 @@ pub struct Linearizer<'a> { static_locals: HashMap, /// Current source position for debug info current_pos: Option, + /// Target configuration (architecture, ABI details) + target: &'a Target, } impl<'a> Linearizer<'a> { /// Create a new linearizer - pub fn new(symbols: &'a SymbolTable) -> Self { + pub fn new(symbols: &'a SymbolTable, types: &'a TypeTable, target: &'a Target) -> Self { Self { module: Module::new(), current_func: None, @@ -107,21 +107,23 @@ impl<'a> Linearizer<'a> { continue_targets: Vec::new(), run_ssa: true, // Enable SSA conversion by default symbols, + types, struct_return_ptr: None, struct_return_size: 0, current_func_name: String::new(), static_local_counter: 0, static_locals: HashMap::new(), current_pos: None, + target, } } /// Create a linearizer with SSA conversion disabled (for testing) #[cfg(test)] - pub fn new_no_ssa(symbols: &'a SymbolTable) -> Self { + pub fn new_no_ssa(symbols: &'a SymbolTable, types: &'a TypeTable, target: &'a Target) -> Self { Self { run_ssa: false, - ..Self::new(symbols) + ..Self::new(symbols, types, target) } } @@ -179,10 +181,10 @@ impl<'a> Linearizer<'a> { /// Apply C99 integer promotions (6.3.1.1) /// Types smaller than int are promoted to int (or unsigned int if int can't hold all values) - fn integer_promote(typ: &Type) -> Type { + fn integer_promote(&self, typ_id: TypeId) -> TypeId { // Integer promotions apply to _Bool, char, short (and their unsigned variants) // They are promoted to int if int can represent all values, otherwise unsigned int - match typ.kind { + match self.types.kind(typ_id) { TypeKind::Bool | TypeKind::Char | TypeKind::Short => { // int (32-bit signed) can represent all values of: // - _Bool (0-1) @@ -191,15 
+193,15 @@ impl<'a> Linearizer<'a> { // - short/signed short (-32768 to 32767) // - unsigned short (0 to 65535) // So always promote to int - Type::basic(TypeKind::Int) + self.types.int_id } - _ => typ.clone(), + _ => typ_id, } } /// Compute the common type for usual arithmetic conversions (C99 6.3.1.8) /// Returns the wider type that both operands should be converted to - fn common_type(left: &Type, right: &Type) -> Type { + fn common_type(&self, left: TypeId, right: TypeId) -> TypeId { // C99 6.3.1.8 usual arithmetic conversions: // 1. If either is long double, convert to long double // 2. Else if either is double, convert to double @@ -212,36 +214,38 @@ impl<'a> Linearizer<'a> { // e. Otherwise convert both to unsigned version of signed type // Check for floating point types - let left_float = left.is_float(); - let right_float = right.is_float(); + let left_float = self.types.is_float(left); + let right_float = self.types.is_float(right); + + let left_kind = self.types.kind(left); + let right_kind = self.types.kind(right); if left_float || right_float { // At least one operand is floating point // Use the wider floating point type - if left.kind == TypeKind::LongDouble || right.kind == TypeKind::LongDouble { - return Type::basic(TypeKind::LongDouble); + if left_kind == TypeKind::LongDouble || right_kind == TypeKind::LongDouble { + return self.types.longdouble_id; } - if left.kind == TypeKind::Double || right.kind == TypeKind::Double { - return Type::basic(TypeKind::Double); + if left_kind == TypeKind::Double || right_kind == TypeKind::Double { + return self.types.double_id; } // Both are float or one is float and one is integer - return Type::basic(TypeKind::Float); + return self.types.float_id; } // Apply integer promotions first (C99 6.3.1.1) - let left_promoted = Self::integer_promote(left); - let right_promoted = Self::integer_promote(right); + let left_promoted = self.integer_promote(left); + let right_promoted = self.integer_promote(right); - let 
left_size = left_promoted.size_bits(); - let right_size = right_promoted.size_bits(); - let left_unsigned = left_promoted.is_unsigned(); - let right_unsigned = right_promoted.is_unsigned(); + let left_size = self.types.size_bits(left_promoted); + let right_size = self.types.size_bits(right_promoted); + let left_unsigned = self.types.is_unsigned(left_promoted); + let right_unsigned = self.types.is_unsigned(right_promoted); + let left_kind = self.types.kind(left_promoted); + let right_kind = self.types.kind(right_promoted); // If both have same type after promotion, use that type - if left_promoted.kind == right_promoted.kind - && left_unsigned == right_unsigned - && left_size == right_size - { + if left_kind == right_kind && left_unsigned == right_unsigned && left_size == right_size { return left_promoted; } @@ -255,53 +259,55 @@ impl<'a> Linearizer<'a> { } // Mixed signedness case - let (signed_typ, unsigned_typ) = if left_unsigned { - (&right_promoted, &left_promoted) + let (signed_id, unsigned_id) = if left_unsigned { + (right_promoted, left_promoted) } else { - (&left_promoted, &right_promoted) + (left_promoted, right_promoted) }; - let signed_size = signed_typ.size_bits(); - let unsigned_size = unsigned_typ.size_bits(); + let signed_size = self.types.size_bits(signed_id); + let unsigned_size = self.types.size_bits(unsigned_id); // If unsigned has rank >= signed, convert to unsigned if unsigned_size >= signed_size { - return unsigned_typ.clone(); + return unsigned_id; } // If signed type can represent all values of unsigned type, use signed // (This is true when signed_size > unsigned_size on our platforms) if signed_size > unsigned_size { - return signed_typ.clone(); + return signed_id; } // Otherwise convert both to unsigned version of signed type // (This case shouldn't happen on LP64 since we already handled size comparisons) - Type::with_modifiers(signed_typ.kind, TypeModifiers::UNSIGNED) + self.types.unsigned_version(signed_id) } /// Emit a type conversion 
if needed /// Returns the (possibly converted) pseudo ID - fn emit_convert(&mut self, val: PseudoId, from_typ: &Type, to_typ: &Type) -> PseudoId { - let from_size = from_typ.size_bits(); - let to_size = to_typ.size_bits(); - let from_float = from_typ.is_float(); - let to_float = to_typ.is_float(); + fn emit_convert(&mut self, val: PseudoId, from_typ: TypeId, to_typ: TypeId) -> PseudoId { + let from_size = self.types.size_bits(from_typ); + let to_size = self.types.size_bits(to_typ); + let from_float = self.types.is_float(from_typ); + let to_float = self.types.is_float(to_typ); + let from_kind = self.types.kind(from_typ); + let to_kind = self.types.kind(to_typ); // Same type and size - no conversion needed - if from_typ.kind == to_typ.kind && from_size == to_size { + if from_kind == to_kind && from_size == to_size { return val; } // Array to pointer conversion (decay) - no actual conversion needed // The array value is already the address of the first element (64-bit) - if from_typ.kind == TypeKind::Array && to_typ.kind == TypeKind::Pointer { + if from_kind == TypeKind::Array && to_kind == TypeKind::Pointer { return val; } // Pointer to pointer conversion - no actual conversion needed // All pointers are the same size (64-bit) - if from_typ.kind == TypeKind::Pointer && to_typ.kind == TypeKind::Pointer { + if from_kind == TypeKind::Pointer && to_kind == TypeKind::Pointer { return val; } @@ -309,7 +315,7 @@ impl<'a> Linearizer<'a> { // When any scalar value is converted to _Bool: // - Result is 0 if the value compares equal to 0 // - Result is 1 otherwise - if to_typ.kind == TypeKind::Bool && from_typ.kind != TypeKind::Bool { + if to_kind == TypeKind::Bool && from_kind != TypeKind::Bool { let result = self.alloc_pseudo(); let pseudo = Pseudo::reg(result, result.0); if let Some(func) = &mut self.current_func { @@ -317,7 +323,7 @@ impl<'a> Linearizer<'a> { } // Create a zero constant for comparison - let zero = self.emit_const(0, from_typ.clone()); + let zero = 
self.emit_const(0, from_typ); // Compare val != 0 let opcode = if from_float { @@ -326,8 +332,7 @@ impl<'a> Linearizer<'a> { Opcode::SetNe }; - let mut insn = Instruction::binop(opcode, result, val, zero, to_typ.clone()); - insn.size = to_size; + let mut insn = Instruction::binop(opcode, result, val, zero, to_typ, to_size); insn.src_size = from_size; self.emit(insn); @@ -347,22 +352,21 @@ impl<'a> Linearizer<'a> { Opcode::FCvtF } else if from_float { // Float to integer - if to_typ.is_unsigned() { + if self.types.is_unsigned(to_typ) { Opcode::FCvtU } else { Opcode::FCvtS } } else { // Integer to float - if from_typ.is_unsigned() { + if self.types.is_unsigned(from_typ) { Opcode::UCvtF } else { Opcode::SCvtF } }; - let mut insn = Instruction::unop(opcode, result, val, to_typ.clone()); - insn.size = to_size; + let mut insn = Instruction::unop(opcode, result, val, to_typ, to_size); insn.src_size = from_size; self.emit(insn); return result; @@ -382,19 +386,17 @@ impl<'a> Linearizer<'a> { if to_size > from_size { // Extending - use sign or zero extension based on source type - let opcode = if from_typ.is_unsigned() { + let opcode = if self.types.is_unsigned(from_typ) { Opcode::Zext } else { Opcode::Sext }; - let mut insn = Instruction::unop(opcode, result, val, to_typ.clone()); - insn.size = to_size; + let mut insn = Instruction::unop(opcode, result, val, to_typ, to_size); insn.src_size = from_size; self.emit(insn); } else { // Truncating - let mut insn = Instruction::unop(Opcode::Trunc, result, val, to_typ.clone()); - insn.size = to_size; + let mut insn = Instruction::unop(Opcode::Trunc, result, val, to_typ, to_size); insn.src_size = from_size; self.emit(insn); } @@ -446,12 +448,16 @@ impl<'a> Linearizer<'a> { fn linearize_global_decl(&mut self, decl: &Declaration) { for declarator in &decl.declarators { // Skip function declarations (they're just forward declarations for external functions) - if declarator.typ.kind == TypeKind::Function { + if 
self.types.kind(declarator.typ) == TypeKind::Function { continue; } // Skip extern declarations - they don't define storage - if declarator.typ.modifiers.contains(TypeModifiers::EXTERN) { + if self + .types + .modifiers(declarator.typ) + .contains(TypeModifiers::EXTERN) + { continue; } @@ -465,7 +471,7 @@ impl<'a> Linearizer<'a> { }); self.module - .add_global(&declarator.name, declarator.typ.clone(), init); + .add_global(&declarator.name, declarator.typ, init); } } @@ -491,16 +497,19 @@ impl<'a> Linearizer<'a> { // Note: static_locals is NOT cleared - it persists across functions // Create function - let is_static = func.return_type.modifiers.contains(TypeModifiers::STATIC); - let mut ir_func = Function::new(&func.name, func.return_type.clone()); + let is_static = self + .types + .modifiers(func.return_type) + .contains(TypeModifiers::STATIC); + let mut ir_func = Function::new(&func.name, func.return_type); ir_func.is_static = is_static; + let ret_kind = self.types.kind(func.return_type); // Check if function returns a large struct // Large structs are returned via a hidden first parameter (sret) // that points to caller-allocated space - let returns_large_struct = (func.return_type.kind == TypeKind::Struct - || func.return_type.kind == TypeKind::Union) - && func.return_type.size_bits() > MAX_REGISTER_AGGREGATE_BITS; + let returns_large_struct = (ret_kind == TypeKind::Struct || ret_kind == TypeKind::Union) + && self.types.size_bits(func.return_type) > self.target.max_aggregate_register_bits; // Argument index offset: if returning large struct, first arg is hidden return pointer let arg_offset: u32 = if returns_large_struct { 1 } else { 0 }; @@ -511,17 +520,17 @@ impl<'a> Linearizer<'a> { let sret_pseudo = Pseudo::arg(sret_id, 0).with_name("__sret"); ir_func.add_pseudo(sret_pseudo); self.struct_return_ptr = Some(sret_id); - self.struct_return_size = func.return_type.size_bits(); + self.struct_return_size = self.types.size_bits(func.return_type); } // Add parameters 
// For struct/union parameters, we need to copy them to local storage // so member access works properly - let mut struct_params: Vec<(String, Type, PseudoId)> = Vec::new(); + let mut struct_params: Vec<(String, TypeId, PseudoId)> = Vec::new(); for (i, param) in func.params.iter().enumerate() { let name = param.name.clone().unwrap_or_else(|| format!("arg{}", i)); - ir_func.add_param(&name, param.typ.clone()); + ir_func.add_param(&name, param.typ); // Create argument pseudo (offset by 1 if there's a hidden return pointer) let pseudo_id = self.alloc_pseudo(); @@ -529,8 +538,9 @@ impl<'a> Linearizer<'a> { ir_func.add_pseudo(pseudo); // For struct/union types, we'll copy to a local later - if param.typ.kind == TypeKind::Struct || param.typ.kind == TypeKind::Union { - struct_params.push((name, param.typ.clone(), pseudo_id)); + let param_kind = self.types.kind(param.typ); + if param_kind == TypeKind::Struct || param_kind == TypeKind::Union { + struct_params.push((name, param.typ, pseudo_id)); } else { self.var_map.insert(name, pseudo_id); } @@ -552,15 +562,16 @@ impl<'a> Linearizer<'a> { let sym = Pseudo::sym(local_sym, name.clone()); if let Some(func) = &mut self.current_func { func.add_pseudo(sym); - let is_volatile = typ.modifiers.contains(TypeModifiers::VOLATILE); - func.add_local(&name, local_sym, typ.clone(), is_volatile, None); + let is_volatile = self.types.modifiers(typ).contains(TypeModifiers::VOLATILE); + func.add_local(&name, local_sym, typ, is_volatile, None); } + let typ_size = self.types.size_bits(typ); // For large structs, arg_pseudo is a pointer to the struct // We need to copy the data from that pointer to local storage - if typ.size_bits() > MAX_REGISTER_AGGREGATE_BITS { + if typ_size > self.target.max_aggregate_register_bits { // arg_pseudo is a pointer - copy each 8-byte chunk - let struct_size = typ.size_bits() / 8; + let struct_size = typ_size / 8; let mut offset = 0i64; while offset < struct_size as i64 { // Load 8 bytes from arg_pseudo + offset 
@@ -573,20 +584,22 @@ impl<'a> Linearizer<'a> { temp, arg_pseudo, offset, - Type::basic(TypeKind::Long), + self.types.long_id, + 64, )); // Store to local_sym + offset self.emit(Instruction::store( temp, local_sym, offset, - Type::basic(TypeKind::Long), + self.types.long_id, + 64, )); offset += 8; } } else { // Small struct: arg_pseudo contains the value directly - self.emit(Instruction::store(arg_pseudo, local_sym, 0, typ.clone())); + self.emit(Instruction::store(arg_pseudo, local_sym, 0, typ, typ_size)); } // Register as a local variable @@ -604,11 +617,11 @@ impl<'a> Linearizer<'a> { // Ensure function ends with a return if !self.is_terminated() { - if func.return_type.kind == TypeKind::Void { + if ret_kind == TypeKind::Void { self.emit(Instruction::ret(None)); } else { // Return 0 as default - let zero = self.emit_const(0, Type::basic(TypeKind::Int)); + let zero = self.emit_const(0, self.types.int_id); self.emit(Instruction::ret(Some(zero))); } } @@ -616,7 +629,7 @@ impl<'a> Linearizer<'a> { // Run SSA conversion if enabled if self.run_ssa { if let Some(ref mut ir_func) = self.current_func { - ssa_convert(ir_func); + ssa_convert(ir_func, self.types); } } @@ -697,23 +710,30 @@ impl<'a> Linearizer<'a> { temp, src_addr, byte_offset, - Type::basic(TypeKind::Long), + self.types.long_id, + 64, )); // Store to destination self.emit(Instruction::store( temp, sret_ptr, byte_offset, - Type::basic(TypeKind::Long), + self.types.long_id, + 64, )); byte_offset += 8; } // Return the hidden pointer (ABI requirement) - self.emit(Instruction::ret_typed(Some(sret_ptr), Type::pointer(typ))); + self.emit(Instruction::ret_typed( + Some(sret_ptr), + self.types.void_ptr_id, + 64, + )); } else { let val = self.linearize_expr(e); - self.emit(Instruction::ret_typed(Some(val), typ)); + let typ_size = self.types.size_bits(typ); + self.emit(Instruction::ret_typed(Some(val), typ, typ_size)); } } else { self.emit(Instruction::ret(None)); @@ -771,10 +791,10 @@ impl<'a> Linearizer<'a> { fn 
linearize_local_decl(&mut self, decl: &Declaration) { for declarator in &decl.declarators { - let typ = declarator.typ.clone(); + let typ = declarator.typ; // Check if this is a static local variable - if typ.modifiers.contains(TypeModifiers::STATIC) { + if self.types.modifiers(typ).contains(TypeModifiers::STATIC) { // Static local: create a global with unique name self.linearize_static_local(declarator); continue; @@ -789,37 +809,27 @@ impl<'a> Linearizer<'a> { func.add_pseudo(sym); // Register with function's local variable tracking for SSA // Pass the current basic block as the declaration block for scope-aware phi placement - let is_volatile = typ.modifiers.contains(TypeModifiers::VOLATILE); - func.add_local( - &unique_name, - sym_id, - typ.clone(), - is_volatile, - self.current_bb, - ); + let is_volatile = self.types.modifiers(typ).contains(TypeModifiers::VOLATILE); + func.add_local(&unique_name, sym_id, typ, is_volatile, self.current_bb); } // Track in linearizer's locals map - self.locals.insert( - declarator.name.clone(), - LocalVarInfo { - sym: sym_id, - typ: typ.clone(), - }, - ); + self.locals + .insert(declarator.name.clone(), LocalVarInfo { sym: sym_id, typ }); // If there's an initializer, emit Store(s) if let Some(init) = &declarator.init { if let ExprKind::InitList { elements } = &init.kind { // Handle initializer list for arrays and structs - self.linearize_init_list(sym_id, &typ, elements); + self.linearize_init_list(sym_id, typ, elements); } else { // Simple scalar initializer let val = self.linearize_expr(init); // Convert the value to the target type (important for _Bool normalization) let init_type = self.expr_type(init); - let converted = self.emit_convert(val, &init_type, &typ); - self.emit(Instruction::store(converted, sym_id, 0, typ)); + let converted = self.emit_convert(val, init_type, typ); + let size = self.types.size_bits(typ); + self.emit(Instruction::store(converted, sym_id, 0, typ, size)); } } } @@ -845,7 +855,7 @@ impl<'a> 
Linearizer<'a> { key, StaticLocalInfo { global_name: global_name.clone(), - typ: declarator.typ.clone(), + typ: declarator.typ, }, ); @@ -856,7 +866,7 @@ impl<'a> Linearizer<'a> { LocalVarInfo { // Use a sentinel value - we'll handle static locals specially sym: PseudoId(u32::MAX), - typ: declarator.typ.clone(), + typ: declarator.typ, }, ); @@ -883,12 +893,11 @@ impl<'a> Linearizer<'a> { }); // Add as a global (type already has STATIC modifier which codegen uses) - self.module - .add_global(&global_name, declarator.typ.clone(), init); + self.module.add_global(&global_name, declarator.typ, init); } /// Linearize an initializer list for arrays or structs - fn linearize_init_list(&mut self, base_sym: PseudoId, typ: &Type, elements: &[InitElement]) { + fn linearize_init_list(&mut self, base_sym: PseudoId, typ: TypeId, elements: &[InitElement]) { self.linearize_init_list_at_offset(base_sym, 0, typ, elements); } @@ -897,17 +906,13 @@ impl<'a> Linearizer<'a> { &mut self, base_sym: PseudoId, base_offset: i64, - typ: &Type, + typ: TypeId, elements: &[InitElement], ) { - match typ.kind { + match self.types.kind(typ) { TypeKind::Array => { - let elem_type = typ - .base - .as_ref() - .map(|b| (**b).clone()) - .unwrap_or_else(|| Type::basic(TypeKind::Int)); - let elem_size = elem_type.size_bits() / 8; + let elem_type = self.types.base_type(typ).unwrap_or(self.types.int_id); + let elem_size = self.types.size_bits(elem_type) / 8; for (idx, element) in elements.iter().enumerate() { // Calculate the actual index considering designators @@ -932,27 +937,26 @@ impl<'a> Linearizer<'a> { self.linearize_init_list_at_offset( base_sym, offset, - &elem_type, + elem_type, nested_elems, ); } else { // Scalar value let val = self.linearize_expr(&element.value); let val_type = self.expr_type(&element.value); - let converted = self.emit_convert(val, &val_type, &elem_type); + let converted = self.emit_convert(val, val_type, elem_type); + let elem_size = self.types.size_bits(elem_type); 
self.emit(Instruction::store( - converted, - base_sym, - offset, - elem_type.clone(), + converted, base_sym, offset, elem_type, elem_size, )); } } } TypeKind::Struct | TypeKind::Union => { // Get struct fields from the type's composite data - if let Some(composite) = &typ.composite { - let members = &composite.members; + if let Some(composite) = self.types.get(typ).composite.as_ref() { + // Clone members to avoid borrow issues + let members: Vec<_> = composite.members.clone(); for (idx, element) in elements.iter().enumerate() { // Find the field (by designator or position) @@ -972,7 +976,7 @@ impl<'a> Linearizer<'a> { }; let offset = base_offset + member.offset as i64; - let field_type = member.typ.clone(); + let field_type = member.typ; // Handle nested initializer lists or scalar values if let ExprKind::InitList { @@ -983,15 +987,18 @@ impl<'a> Linearizer<'a> { self.linearize_init_list_at_offset( base_sym, offset, - &field_type, + field_type, nested_elems, ); } else { // Scalar value let val = self.linearize_expr(&element.value); let val_type = self.expr_type(&element.value); - let converted = self.emit_convert(val, &val_type, &field_type); - self.emit(Instruction::store(converted, base_sym, offset, field_type)); + let converted = self.emit_convert(val, val_type, field_type); + let field_size = self.types.size_bits(field_type); + self.emit(Instruction::store( + converted, base_sym, offset, field_type, field_size, + )); } } } @@ -1001,12 +1008,14 @@ impl<'a> Linearizer<'a> { if let Some(element) = elements.first() { let val = self.linearize_expr(&element.value); let val_type = self.expr_type(&element.value); - let converted = self.emit_convert(val, &val_type, typ); + let converted = self.emit_convert(val, val_type, typ); + let typ_size = self.types.size_bits(typ); self.emit(Instruction::store( converted, base_sym, base_offset, - typ.clone(), + typ, + typ_size, )); } } @@ -1215,7 +1224,7 @@ impl<'a> Linearizer<'a> { // Linearize the switch expression let switch_val 
= self.linearize_expr(expr); let expr_type = self.expr_type(expr); - let size = expr_type.size_bits(); + let size = self.types.size_bits(expr_type); let exit_bb = self.alloc_bb(); @@ -1453,8 +1462,8 @@ impl<'a> Linearizer<'a> { /// Get the type of an expression. /// PANICS if expression has no type - type evaluation pass must run first. /// Following sparse's design: the IR should ALWAYS receive fully typed input. - fn expr_type(&self, expr: &Expr) -> Type { - expr.typ.clone().expect( + fn expr_type(&self, expr: &Expr) -> TypeId { + expr.typ.expect( "BUG: expression has no type. Type evaluation pass must run before linearization.", ) } @@ -1479,7 +1488,7 @@ impl<'a> Linearizer<'a> { self.emit(Instruction::sym_addr( result, sym_id, - Type::pointer(static_info.typ), + self.types.pointer_to(static_info.typ), )); return result; } @@ -1488,7 +1497,7 @@ impl<'a> Linearizer<'a> { self.emit(Instruction::sym_addr( result, local.sym, - Type::pointer(local.typ), + self.types.pointer_to(local.typ), )); result } else { @@ -1500,7 +1509,11 @@ impl<'a> Linearizer<'a> { } let result = self.alloc_pseudo(); let typ = self.expr_type(expr); - self.emit(Instruction::sym_addr(result, sym_id, Type::pointer(typ))); + self.emit(Instruction::sym_addr( + result, + sym_id, + self.types.pointer_to(typ), + )); result } } @@ -1518,28 +1531,29 @@ impl<'a> Linearizer<'a> { // s.m as lvalue = &s + offset(m) let base = self.linearize_lvalue(inner); let struct_type = self.expr_type(inner); - let member_info = struct_type - .find_member(member) - .unwrap_or_else(|| MemberInfo { - offset: 0, - typ: self.expr_type(expr), - bit_offset: None, - bit_width: None, - storage_unit_size: None, - }); + let member_info = + self.types + .find_member(struct_type, member) + .unwrap_or_else(|| MemberInfo { + offset: 0, + typ: self.expr_type(expr), + bit_offset: None, + bit_width: None, + storage_unit_size: None, + }); if member_info.offset == 0 { base } else { - let offset_val = - self.emit_const(member_info.offset 
as i64, Type::basic(TypeKind::Long)); + let offset_val = self.emit_const(member_info.offset as i64, self.types.long_id); let result = self.alloc_pseudo(); self.emit(Instruction::binop( Opcode::Add, result, base, offset_val, - Type::basic(TypeKind::Long), + self.types.long_id, + 64, )); result } @@ -1551,33 +1565,33 @@ impl<'a> Linearizer<'a> { // ptr->m as lvalue = ptr + offset(m) let ptr = self.linearize_expr(inner); let ptr_type = self.expr_type(inner); - let struct_type = ptr_type - .base - .as_ref() - .map(|b| b.as_ref().clone()) + let struct_type = self + .types + .base_type(ptr_type) .unwrap_or_else(|| self.expr_type(expr)); - let member_info = struct_type - .find_member(member) - .unwrap_or_else(|| MemberInfo { - offset: 0, - typ: self.expr_type(expr), - bit_offset: None, - bit_width: None, - storage_unit_size: None, - }); + let member_info = + self.types + .find_member(struct_type, member) + .unwrap_or_else(|| MemberInfo { + offset: 0, + typ: self.expr_type(expr), + bit_offset: None, + bit_width: None, + storage_unit_size: None, + }); if member_info.offset == 0 { ptr } else { - let offset_val = - self.emit_const(member_info.offset as i64, Type::basic(TypeKind::Long)); + let offset_val = self.emit_const(member_info.offset as i64, self.types.long_id); let result = self.alloc_pseudo(); self.emit(Instruction::binop( Opcode::Add, result, ptr, offset_val, - Type::basic(TypeKind::Long), + self.types.long_id, + 64, )); result } @@ -1588,8 +1602,9 @@ impl<'a> Linearizer<'a> { let array_type = self.expr_type(array); let index_type = self.expr_type(index); + let array_kind = self.types.kind(array_type); let (ptr_expr, idx_expr, idx_type) = - if array_type.kind == TypeKind::Pointer || array_type.kind == TypeKind::Array { + if array_kind == TypeKind::Pointer || array_kind == TypeKind::Array { (array, index, index_type) } else { // Swap: index is actually the pointer/array @@ -1599,11 +1614,11 @@ impl<'a> Linearizer<'a> { let arr = self.linearize_expr(ptr_expr); let idx 
= self.linearize_expr(idx_expr); let elem_type = self.expr_type(expr); - let elem_size = elem_type.size_bits() / 8; - let elem_size_val = self.emit_const(elem_size as i64, Type::basic(TypeKind::Long)); + let elem_size = self.types.size_bits(elem_type) / 8; + let elem_size_val = self.emit_const(elem_size as i64, self.types.long_id); // Sign-extend index to 64-bit for proper pointer arithmetic (negative indices) - let idx_extended = self.emit_convert(idx, &idx_type, &Type::basic(TypeKind::Long)); + let idx_extended = self.emit_convert(idx, idx_type, self.types.long_id); let offset = self.alloc_pseudo(); self.emit(Instruction::binop( @@ -1611,7 +1626,8 @@ impl<'a> Linearizer<'a> { offset, idx_extended, elem_size_val, - Type::basic(TypeKind::Long), + self.types.long_id, + 64, )); let addr = self.alloc_pseudo(); @@ -1620,7 +1636,8 @@ impl<'a> Linearizer<'a> { addr, arr, offset, - Type::basic(TypeKind::Long), + self.types.long_id, + 64, )); addr } @@ -1678,7 +1695,7 @@ impl<'a> Linearizer<'a> { ExprKind::Ident { name, .. 
} => { // First check if it's an enum constant if let Some(value) = self.symbols.get_enum_value(name) { - self.emit_const(value, Type::basic(TypeKind::Int)) + self.emit_const(value, self.types.int_id) } // Check if it's a local variable else if let Some(local) = self.locals.get(name).cloned() { @@ -1694,22 +1711,17 @@ impl<'a> Linearizer<'a> { } let typ = static_info.typ; // Arrays decay to pointers - get address, not value - if typ.kind == TypeKind::Array { + if self.types.kind(typ) == TypeKind::Array { let result = self.alloc_pseudo(); - self.emit(Instruction::sym_addr( - result, - sym_id, - Type::pointer( - typ.base - .as_ref() - .map(|b| (**b).clone()) - .unwrap_or_else(|| Type::basic(TypeKind::Int)), - ), - )); + let elem_type = + self.types.base_type(typ).unwrap_or(self.types.int_id); + let ptr_type = self.types.pointer_to(elem_type); + self.emit(Instruction::sym_addr(result, sym_id, ptr_type)); return result; } else { let result = self.alloc_pseudo(); - self.emit(Instruction::load(result, sym_id, 0, typ)); + let size = self.types.size_bits(typ); + self.emit(Instruction::load(result, sym_id, 0, typ, size)); return result; } } @@ -1720,21 +1732,14 @@ impl<'a> Linearizer<'a> { func.add_pseudo(pseudo); } // Arrays decay to pointers - get address, not value - if local.typ.kind == TypeKind::Array { - self.emit(Instruction::sym_addr( - result, - local.sym, - Type::pointer( - local - .typ - .base - .as_ref() - .map(|b| (**b).clone()) - .unwrap_or_else(|| Type::basic(TypeKind::Int)), - ), - )); + if self.types.kind(local.typ) == TypeKind::Array { + let elem_type = + self.types.base_type(local.typ).unwrap_or(self.types.int_id); + let ptr_type = self.types.pointer_to(elem_type); + self.emit(Instruction::sym_addr(result, local.sym, ptr_type)); } else { - self.emit(Instruction::load(result, local.sym, 0, local.typ)); + let size = self.types.size_bits(local.typ); + self.emit(Instruction::load(result, local.sym, 0, local.typ, size)); } result } @@ -1751,22 +1756,16 @@ 
impl<'a> Linearizer<'a> { } let typ = self.expr_type(expr); // Arrays decay to pointers - get address, not value - if typ.kind == TypeKind::Array { + if self.types.kind(typ) == TypeKind::Array { let result = self.alloc_pseudo(); - self.emit(Instruction::sym_addr( - result, - sym_id, - Type::pointer( - typ.base - .as_ref() - .map(|b| (**b).clone()) - .unwrap_or_else(|| Type::basic(TypeKind::Int)), - ), - )); + let elem_type = self.types.base_type(typ).unwrap_or(self.types.int_id); + let ptr_type = self.types.pointer_to(elem_type); + self.emit(Instruction::sym_addr(result, sym_id, ptr_type)); result } else { let result = self.alloc_pseudo(); - self.emit(Instruction::load(result, sym_id, 0, typ)); + let size = self.types.size_bits(typ); + self.emit(Instruction::load(result, sym_id, 0, typ, size)); result } } @@ -1782,22 +1781,18 @@ impl<'a> Linearizer<'a> { if *op == UnaryOp::PreInc || *op == UnaryOp::PreDec { let val = self.linearize_expr(operand); let typ = self.expr_type(operand); - let is_float = typ.is_float(); - let is_ptr = typ.kind == TypeKind::Pointer; + let is_float = self.types.is_float(typ); + let is_ptr = self.types.kind(typ) == TypeKind::Pointer; // Compute new value - for pointers, scale by element size let increment = if is_ptr { - let elem_type = typ - .base - .as_ref() - .map(|b| (**b).clone()) - .unwrap_or_else(|| Type::basic(TypeKind::Char)); - let elem_size = elem_type.size_bits() / 8; - self.emit_const(elem_size as i64, Type::basic(TypeKind::Long)) + let elem_type = self.types.base_type(typ).unwrap_or(self.types.char_id); + let elem_size = self.types.size_bits(elem_type) / 8; + self.emit_const(elem_size as i64, self.types.long_id) } else if is_float { - self.emit_fconst(1.0, typ.clone()) + self.emit_fconst(1.0, typ) } else { - self.emit_const(1, typ.clone()) + self.emit_const(1, typ) }; let result = self.alloc_pseudo(); let pseudo = Pseudo::reg(result, result.0); @@ -1815,17 +1810,14 @@ impl<'a> Linearizer<'a> { } else { Opcode::Sub }; + let size 
= self.types.size_bits(typ); self.emit(Instruction::binop( - opcode, - result, - val, - increment, - typ.clone(), + opcode, result, val, increment, typ, size, )); // For _Bool, normalize the result (any non-zero -> 1) - let final_result = if typ.kind == TypeKind::Bool { - self.emit_convert(result, &Type::basic(TypeKind::Int), &typ) + let final_result = if self.types.kind(typ) == TypeKind::Bool { + self.emit_convert(result, self.types.int_id, typ) } else { result }; @@ -1833,7 +1825,14 @@ impl<'a> Linearizer<'a> { // Store back to the variable if let ExprKind::Ident { name, .. } = &operand.kind { if let Some(local) = self.locals.get(name).cloned() { - self.emit(Instruction::store(final_result, local.sym, 0, typ)); + let store_size = self.types.size_bits(typ); + self.emit(Instruction::store( + final_result, + local.sym, + 0, + typ, + store_size, + )); } else if self.var_map.contains_key(name) { self.var_map.insert(name.clone(), final_result); } else { @@ -1843,7 +1842,8 @@ impl<'a> Linearizer<'a> { if let Some(func) = &mut self.current_func { func.add_pseudo(pseudo); } - self.emit(Instruction::store(final_result, sym_id, 0, typ)); + let store_size = self.types.size_bits(typ); + self.emit(Instruction::store(final_result, sym_id, 0, typ, store_size)); } } @@ -1877,13 +1877,15 @@ impl<'a> Linearizer<'a> { let result_typ = self.expr_type(expr); // Check for pointer arithmetic: ptr +/- int or int + ptr + let left_kind = self.types.kind(left_typ); + let right_kind = self.types.kind(right_typ); let left_is_ptr_or_arr = - left_typ.kind == TypeKind::Pointer || left_typ.kind == TypeKind::Array; + left_kind == TypeKind::Pointer || left_kind == TypeKind::Array; let right_is_ptr_or_arr = - right_typ.kind == TypeKind::Pointer || right_typ.kind == TypeKind::Array; + right_kind == TypeKind::Pointer || right_kind == TypeKind::Array; let is_ptr_arith = (*op == BinaryOp::Add || *op == BinaryOp::Sub) - && ((left_is_ptr_or_arr && right_typ.is_integer()) - || (left_typ.is_integer() && 
right_is_ptr_or_arr)); + && ((left_is_ptr_or_arr && self.types.is_integer(right_typ)) + || (self.types.is_integer(left_typ) && right_is_ptr_or_arr)); // Check for pointer difference: ptr - ptr let is_ptr_diff = *op == BinaryOp::Sub && left_is_ptr_or_arr && right_is_ptr_or_arr; @@ -1900,26 +1902,24 @@ impl<'a> Linearizer<'a> { byte_diff, left_val, right_val, - Type::basic(TypeKind::Long), + self.types.long_id, + 64, )); // Get element size from the pointer type - let elem_type = left_typ - .base - .as_ref() - .map(|b| (**b).clone()) - .unwrap_or_else(|| Type::basic(TypeKind::Char)); - let elem_size = elem_type.size_bits() / 8; + let elem_type = self.types.base_type(left_typ).unwrap_or(self.types.char_id); + let elem_size = self.types.size_bits(elem_type) / 8; // Divide by element size - let scale = self.emit_const(elem_size as i64, Type::basic(TypeKind::Long)); + let scale = self.emit_const(elem_size as i64, self.types.long_id); let result = self.alloc_pseudo(); self.emit(Instruction::binop( Opcode::DivS, result, byte_diff, scale, - Type::basic(TypeKind::Long), + self.types.long_id, + 64, )); result } else if is_ptr_arith { @@ -1927,37 +1927,31 @@ impl<'a> Linearizer<'a> { let (ptr_val, ptr_typ, int_val) = if left_is_ptr_or_arr { let ptr = self.linearize_expr(left); let int = self.linearize_expr(right); - (ptr, left_typ.clone(), int) + (ptr, left_typ, int) } else { // int + ptr case let int = self.linearize_expr(left); let ptr = self.linearize_expr(right); - (ptr, right_typ.clone(), int) + (ptr, right_typ, int) }; // Get element size - let elem_type = ptr_typ - .base - .as_ref() - .map(|b| (**b).clone()) - .unwrap_or_else(|| Type::basic(TypeKind::Char)); - let elem_size = elem_type.size_bits() / 8; + let elem_type = self.types.base_type(ptr_typ).unwrap_or(self.types.char_id); + let elem_size = self.types.size_bits(elem_type) / 8; // Scale the integer by element size - let scale = self.emit_const(elem_size as i64, Type::basic(TypeKind::Long)); + let scale = 
self.emit_const(elem_size as i64, self.types.long_id); let scaled_offset = self.alloc_pseudo(); // Extend int_val to 64-bit for proper address arithmetic - let int_val_extended = self.emit_convert( - int_val, - &Type::basic(TypeKind::Int), - &Type::basic(TypeKind::Long), - ); + let int_val_extended = + self.emit_convert(int_val, self.types.int_id, self.types.long_id); self.emit(Instruction::binop( Opcode::Mul, scaled_offset, int_val_extended, scale, - Type::basic(TypeKind::Long), + self.types.long_id, + 64, )); // Add (or subtract) to pointer @@ -1972,16 +1966,17 @@ impl<'a> Linearizer<'a> { result, ptr_val, scaled_offset, - Type::basic(TypeKind::Long), + self.types.long_id, + 64, )); result } else { // For comparisons, compute common type for both operands // (usual arithmetic conversions) let operand_typ = if op.is_comparison() { - Self::common_type(&left_typ, &right_typ) + self.common_type(left_typ, right_typ) } else { - result_typ.clone() + result_typ }; // Linearize operands @@ -1989,8 +1984,8 @@ impl<'a> Linearizer<'a> { let right_val = self.linearize_expr(right); // Emit type conversions if needed - let left_val = self.emit_convert(left_val, &left_typ, &operand_typ); - let right_val = self.emit_convert(right_val, &right_typ, &operand_typ); + let left_val = self.emit_convert(left_val, left_typ, operand_typ); + let right_val = self.emit_convert(right_val, right_typ, operand_typ); self.emit_binary(*op, left_val, right_val, result_typ, operand_typ) } @@ -2001,8 +1996,8 @@ impl<'a> Linearizer<'a> { ExprKind::PostInc(operand) => { let val = self.linearize_expr(operand); let typ = self.expr_type(operand); - let is_float = typ.is_float(); - let is_ptr = typ.kind == TypeKind::Pointer; + let is_float = self.types.is_float(typ); + let is_ptr = self.types.kind(typ) == TypeKind::Pointer; // For locals, we need to save the old value before updating // because the pseudo will be reloaded from stack which gets overwritten @@ -2023,7 +2018,7 @@ impl<'a> Linearizer<'a> { 
Instruction::new(Opcode::Copy) .with_target(temp) .with_src(val) - .with_size(typ.size_bits()), + .with_size(self.types.size_bits(typ)), ); temp } else { @@ -2032,17 +2027,13 @@ impl<'a> Linearizer<'a> { // For pointers, increment by element size; for others, increment by 1 let increment = if is_ptr { - let elem_type = typ - .base - .as_ref() - .map(|b| (**b).clone()) - .unwrap_or_else(|| Type::basic(TypeKind::Char)); - let elem_size = elem_type.size_bits() / 8; - self.emit_const(elem_size as i64, Type::basic(TypeKind::Long)) + let elem_type = self.types.base_type(typ).unwrap_or(self.types.char_id); + let elem_size = self.types.size_bits(elem_type) / 8; + self.emit_const(elem_size as i64, self.types.long_id) } else if is_float { - self.emit_fconst(1.0, typ.clone()) + self.emit_fconst(1.0, typ) } else { - self.emit_const(1, typ.clone()) + self.emit_const(1, typ) }; let result = self.alloc_pseudo(); let pseudo = Pseudo::reg(result, result.0); @@ -2050,27 +2041,31 @@ impl<'a> Linearizer<'a> { func.add_pseudo(pseudo); } let opcode = if is_float { Opcode::FAdd } else { Opcode::Add }; - let arith_type = if is_ptr { - Type::basic(TypeKind::Long) - } else { - typ.clone() - }; + let arith_type = if is_ptr { self.types.long_id } else { typ }; + let arith_size = self.types.size_bits(arith_type); self.emit(Instruction::binop( - opcode, result, val, increment, arith_type, + opcode, result, val, increment, arith_type, arith_size, )); // For _Bool, normalize the result (any non-zero -> 1) - let final_result = if typ.kind == TypeKind::Bool { - self.emit_convert(result, &Type::basic(TypeKind::Int), &typ) + let final_result = if self.types.kind(typ) == TypeKind::Bool { + self.emit_convert(result, self.types.int_id, typ) } else { result }; // Store to local, update parameter mapping, or store through pointer + let store_size = self.types.size_bits(typ); match &operand.kind { ExprKind::Ident { name, .. 
} => { if let Some(local) = self.locals.get(name).cloned() { - self.emit(Instruction::store(final_result, local.sym, 0, typ)); + self.emit(Instruction::store( + final_result, + local.sym, + 0, + typ, + store_size, + )); } else if self.var_map.contains_key(name) { self.var_map.insert(name.clone(), final_result); } else { @@ -2080,7 +2075,7 @@ impl<'a> Linearizer<'a> { if let Some(func) = &mut self.current_func { func.add_pseudo(pseudo); } - self.emit(Instruction::store(final_result, sym_id, 0, typ)); + self.emit(Instruction::store(final_result, sym_id, 0, typ, store_size)); } } ExprKind::Unary { @@ -2089,7 +2084,7 @@ impl<'a> Linearizer<'a> { } => { // (*p)++ - store back through the pointer let addr = self.linearize_expr(ptr_expr); - self.emit(Instruction::store(final_result, addr, 0, typ)); + self.emit(Instruction::store(final_result, addr, 0, typ, store_size)); } _ => {} } @@ -2100,8 +2095,8 @@ impl<'a> Linearizer<'a> { ExprKind::PostDec(operand) => { let val = self.linearize_expr(operand); let typ = self.expr_type(operand); - let is_float = typ.is_float(); - let is_ptr = typ.kind == TypeKind::Pointer; + let is_float = self.types.is_float(typ); + let is_ptr = self.types.kind(typ) == TypeKind::Pointer; // For locals, we need to save the old value before updating // because the pseudo will be reloaded from stack which gets overwritten @@ -2122,7 +2117,7 @@ impl<'a> Linearizer<'a> { Instruction::new(Opcode::Copy) .with_target(temp) .with_src(val) - .with_size(typ.size_bits()), + .with_size(self.types.size_bits(typ)), ); temp } else { @@ -2131,17 +2126,13 @@ impl<'a> Linearizer<'a> { // For pointers, decrement by element size; for others, decrement by 1 let decrement = if is_ptr { - let elem_type = typ - .base - .as_ref() - .map(|b| (**b).clone()) - .unwrap_or_else(|| Type::basic(TypeKind::Char)); - let elem_size = elem_type.size_bits() / 8; - self.emit_const(elem_size as i64, Type::basic(TypeKind::Long)) + let elem_type = 
self.types.base_type(typ).unwrap_or(self.types.char_id); + let elem_size = self.types.size_bits(elem_type) / 8; + self.emit_const(elem_size as i64, self.types.long_id) } else if is_float { - self.emit_fconst(1.0, typ.clone()) + self.emit_fconst(1.0, typ) } else { - self.emit_const(1, typ.clone()) + self.emit_const(1, typ) }; let result = self.alloc_pseudo(); let pseudo = Pseudo::reg(result, result.0); @@ -2149,27 +2140,31 @@ impl<'a> Linearizer<'a> { func.add_pseudo(pseudo); } let opcode = if is_float { Opcode::FSub } else { Opcode::Sub }; - let arith_type = if is_ptr { - Type::basic(TypeKind::Long) - } else { - typ.clone() - }; + let arith_type = if is_ptr { self.types.long_id } else { typ }; + let arith_size = self.types.size_bits(arith_type); self.emit(Instruction::binop( - opcode, result, val, decrement, arith_type, + opcode, result, val, decrement, arith_type, arith_size, )); // For _Bool, normalize the result (any non-zero -> 1) - let final_result = if typ.kind == TypeKind::Bool { - self.emit_convert(result, &Type::basic(TypeKind::Int), &typ) + let final_result = if self.types.kind(typ) == TypeKind::Bool { + self.emit_convert(result, self.types.int_id, typ) } else { result }; // Store to local, update parameter mapping, or store through pointer + let store_size = self.types.size_bits(typ); match &operand.kind { ExprKind::Ident { name, .. 
} => { if let Some(local) = self.locals.get(name).cloned() { - self.emit(Instruction::store(final_result, local.sym, 0, typ)); + self.emit(Instruction::store( + final_result, + local.sym, + 0, + typ, + store_size, + )); } else if self.var_map.contains_key(name) { self.var_map.insert(name.clone(), final_result); } else { @@ -2179,7 +2174,7 @@ impl<'a> Linearizer<'a> { if let Some(func) = &mut self.current_func { func.add_pseudo(pseudo); } - self.emit(Instruction::store(final_result, sym_id, 0, typ)); + self.emit(Instruction::store(final_result, sym_id, 0, typ, store_size)); } } ExprKind::Unary { @@ -2188,7 +2183,7 @@ impl<'a> Linearizer<'a> { } => { // (*p)-- - store back through the pointer let addr = self.linearize_expr(ptr_expr); - self.emit(Instruction::store(final_result, addr, 0, typ)); + self.emit(Instruction::store(final_result, addr, 0, typ, store_size)); } _ => {} } @@ -2207,9 +2202,10 @@ impl<'a> Linearizer<'a> { let result = self.alloc_pseudo(); let typ = self.expr_type(expr); // Use evaluated type + let size = self.types.size_bits(typ); self.emit(Instruction::select( - result, cond_val, then_val, else_val, typ, + result, cond_val, then_val, else_val, typ, size, )); result } @@ -2225,10 +2221,11 @@ impl<'a> Linearizer<'a> { // Check if this is a variadic function call // If the function expression has a type, check its variadic flag - let variadic_arg_start = if let Some(ref func_type) = func.typ { - if func_type.variadic { + let variadic_arg_start = if let Some(func_type) = func.typ { + let ft = self.types.get(func_type); + if ft.variadic { // Variadic args start after the fixed parameters - func_type.params.as_ref().map(|p| p.len()) + ft.params.as_ref().map(|p| p.len()) } else { None } @@ -2238,9 +2235,10 @@ impl<'a> Linearizer<'a> { // Check if function returns a large struct // If so, allocate space and pass address as hidden first argument - let returns_large_struct = (typ.kind == TypeKind::Struct - || typ.kind == TypeKind::Union) - && 
typ.size_bits() > MAX_REGISTER_AGGREGATE_BITS; + let typ_kind = self.types.kind(typ); + let returns_large_struct = (typ_kind == TypeKind::Struct + || typ_kind == TypeKind::Union) + && self.types.size_bits(typ) > self.target.max_aggregate_register_bits; let (result_sym, mut arg_vals, mut arg_types_vec) = if returns_large_struct { // Allocate local storage for the return value @@ -2252,7 +2250,7 @@ impl<'a> Linearizer<'a> { func.add_local( format!("__sret_{}", sret_sym.0), sret_sym, - typ.clone(), + typ, false, self.current_bb, ); @@ -2267,11 +2265,11 @@ impl<'a> Linearizer<'a> { self.emit(Instruction::sym_addr( sret_addr, sret_sym, - Type::pointer(typ.clone()), + self.types.pointer_to(typ), )); // Hidden return pointer is the first argument (pointer type) - (sret_sym, vec![sret_addr], vec![Type::pointer(typ.clone())]) + (sret_sym, vec![sret_addr], vec![self.types.pointer_to(typ)]) } else { let result = self.alloc_pseudo(); (result, Vec::new(), Vec::new()) @@ -2281,13 +2279,13 @@ impl<'a> Linearizer<'a> { // For large structs, pass by reference (address) instead of by value for a in args.iter() { let arg_type = self.expr_type(a); - let arg_val = if (arg_type.kind == TypeKind::Struct - || arg_type.kind == TypeKind::Union) - && arg_type.size_bits() > MAX_REGISTER_AGGREGATE_BITS + let arg_kind = self.types.kind(arg_type); + let arg_val = if (arg_kind == TypeKind::Struct || arg_kind == TypeKind::Union) + && self.types.size_bits(arg_type) > self.target.max_aggregate_register_bits { // Large struct: pass address instead of value // The argument type becomes a pointer - arg_types_vec.push(Type::pointer(arg_type)); + arg_types_vec.push(self.types.pointer_to(arg_type)); self.linearize_lvalue(a) } else { arg_types_vec.push(arg_type); @@ -2304,24 +2302,29 @@ impl<'a> Linearizer<'a> { if let Some(func) = &mut self.current_func { func.add_pseudo(result_pseudo); } + let ptr_typ = self.types.pointer_to(typ); let mut call_insn = Instruction::call( Some(result), &func_name, arg_vals, 
arg_types_vec, - Type::pointer(typ), + ptr_typ, + 64, // pointers are 64-bit ); call_insn.variadic_arg_start = variadic_arg_start; + call_insn.is_sret_call = true; self.emit(call_insn); // Return the symbol (address) where struct is stored result_sym } else { + let ret_size = self.types.size_bits(typ); let mut call_insn = Instruction::call( Some(result_sym), &func_name, arg_vals, arg_types_vec, typ, + ret_size, ); call_insn.variadic_arg_start = variadic_arg_start; self.emit(call_insn); @@ -2338,30 +2341,32 @@ impl<'a> Linearizer<'a> { let struct_type = self.expr_type(inner_expr); // Look up member offset and type - let member_info = struct_type - .find_member(member) - .unwrap_or_else(|| MemberInfo { - offset: 0, - typ: self.expr_type(expr), - bit_offset: None, - bit_width: None, - storage_unit_size: None, - }); + let member_info = + self.types + .find_member(struct_type, member) + .unwrap_or_else(|| MemberInfo { + offset: 0, + typ: self.expr_type(expr), + bit_offset: None, + bit_width: None, + storage_unit_size: None, + }); // If member type is an array, return the address (arrays decay to pointers) - if member_info.typ.kind == TypeKind::Array { + if self.types.kind(member_info.typ) == TypeKind::Array { if member_info.offset == 0 { base } else { let result = self.alloc_pseudo(); let offset_val = - self.emit_const(member_info.offset as i64, Type::basic(TypeKind::Long)); + self.emit_const(member_info.offset as i64, self.types.long_id); self.emit(Instruction::binop( Opcode::Add, result, base, offset_val, - Type::basic(TypeKind::Long), + self.types.long_id, + 64, )); result } @@ -2377,15 +2382,17 @@ impl<'a> Linearizer<'a> { bit_offset, bit_width, storage_size, - &member_info.typ, + member_info.typ, ) } else { let result = self.alloc_pseudo(); + let size = self.types.size_bits(member_info.typ); self.emit(Instruction::load( result, base, member_info.offset as i64, member_info.typ, + size, )); result } @@ -2400,37 +2407,38 @@ impl<'a> Linearizer<'a> { let ptr_type = 
self.expr_type(inner_expr); // Dereference pointer to get struct type - let struct_type = ptr_type - .base - .as_ref() - .map(|b| b.as_ref().clone()) + let struct_type = self + .types + .base_type(ptr_type) .unwrap_or_else(|| self.expr_type(expr)); // Look up member offset and type - let member_info = struct_type - .find_member(member) - .unwrap_or_else(|| MemberInfo { - offset: 0, - typ: self.expr_type(expr), - bit_offset: None, - bit_width: None, - storage_unit_size: None, - }); + let member_info = + self.types + .find_member(struct_type, member) + .unwrap_or_else(|| MemberInfo { + offset: 0, + typ: self.expr_type(expr), + bit_offset: None, + bit_width: None, + storage_unit_size: None, + }); // If member type is an array, return the address (arrays decay to pointers) - if member_info.typ.kind == TypeKind::Array { + if self.types.kind(member_info.typ) == TypeKind::Array { if member_info.offset == 0 { ptr } else { let result = self.alloc_pseudo(); let offset_val = - self.emit_const(member_info.offset as i64, Type::basic(TypeKind::Long)); + self.emit_const(member_info.offset as i64, self.types.long_id); self.emit(Instruction::binop( Opcode::Add, result, ptr, offset_val, - Type::basic(TypeKind::Long), + self.types.long_id, + 64, )); result } @@ -2446,15 +2454,17 @@ impl<'a> Linearizer<'a> { bit_offset, bit_width, storage_size, - &member_info.typ, + member_info.typ, ) } else { let result = self.alloc_pseudo(); + let size = self.types.size_bits(member_info.typ); self.emit(Instruction::load( result, ptr, member_info.offset as i64, member_info.typ, + size, )); result } @@ -2466,8 +2476,9 @@ impl<'a> Linearizer<'a> { let array_type = self.expr_type(array); let index_type = self.expr_type(index); + let array_kind = self.types.kind(array_type); let (ptr_expr, idx_expr, idx_type) = - if array_type.kind == TypeKind::Pointer || array_type.kind == TypeKind::Array { + if array_kind == TypeKind::Pointer || array_kind == TypeKind::Array { (array, index, index_type) } else { // 
Swap: index is actually the pointer/array @@ -2479,31 +2490,40 @@ impl<'a> Linearizer<'a> { // Get element type from the expression type let elem_type = self.expr_type(expr); - let elem_size = elem_type.size_bits() / 8; - let elem_size_val = self.emit_const(elem_size as i64, Type::basic(TypeKind::Long)); + let elem_size = self.types.size_bits(elem_type) / 8; + let elem_size_val = self.emit_const(elem_size as i64, self.types.long_id); // Sign-extend index to 64-bit for proper pointer arithmetic (negative indices) - let idx_extended = self.emit_convert(idx, &idx_type, &Type::basic(TypeKind::Long)); + let idx_extended = self.emit_convert(idx, idx_type, self.types.long_id); let offset = self.alloc_pseudo(); - let ptr_typ = Type::basic(TypeKind::Long); + let ptr_typ = self.types.long_id; self.emit(Instruction::binop( Opcode::Mul, offset, idx_extended, elem_size_val, - ptr_typ.clone(), + ptr_typ, + 64, )); let addr = self.alloc_pseudo(); - self.emit(Instruction::binop(Opcode::Add, addr, arr, offset, ptr_typ)); + self.emit(Instruction::binop( + Opcode::Add, + addr, + arr, + offset, + ptr_typ, + 64, + )); // If element type is an array, just return the address (arrays decay to pointers) - if elem_type.kind == TypeKind::Array { + if self.types.kind(elem_type) == TypeKind::Array { addr } else { let result = self.alloc_pseudo(); - self.emit(Instruction::load(result, addr, 0, elem_type)); + let size = self.types.size_bits(elem_type); + self.emit(Instruction::load(result, addr, 0, elem_type, size)); result } } @@ -2514,10 +2534,11 @@ impl<'a> Linearizer<'a> { } => { let src = self.linearize_expr(inner_expr); let src_type = self.expr_type(inner_expr); + let cast_type = *cast_type; // Copy the TypeId // Emit conversion if needed - let src_is_float = src_type.is_float(); - let dst_is_float = cast_type.is_float(); + let src_is_float = self.types.is_float(src_type); + let dst_is_float = self.types.is_float(cast_type); if src_is_float && !dst_is_float { // Float to integer conversion 
@@ -2527,7 +2548,7 @@ impl<'a> Linearizer<'a> { func.add_pseudo(pseudo); } // FCvtS for signed int, FCvtU for unsigned - let opcode = if cast_type.is_unsigned() { + let opcode = if self.types.is_unsigned(cast_type) { Opcode::FCvtU } else { Opcode::FCvtS @@ -2535,8 +2556,8 @@ impl<'a> Linearizer<'a> { let mut insn = Instruction::new(opcode) .with_target(result) .with_src(src) - .with_type(cast_type.clone()); - insn.src_size = src_type.size_bits(); + .with_type(cast_type); + insn.src_size = self.types.size_bits(src_type); self.emit(insn); result } else if !src_is_float && dst_is_float { @@ -2547,7 +2568,7 @@ impl<'a> Linearizer<'a> { func.add_pseudo(pseudo); } // SCvtF for signed int, UCvtF for unsigned - let opcode = if src_type.is_unsigned() { + let opcode = if self.types.is_unsigned(src_type) { Opcode::UCvtF } else { Opcode::SCvtF @@ -2555,13 +2576,15 @@ impl<'a> Linearizer<'a> { let mut insn = Instruction::new(opcode) .with_target(result) .with_src(src) - .with_type(cast_type.clone()); - insn.src_size = src_type.size_bits(); + .with_type(cast_type); + insn.src_size = self.types.size_bits(src_type); self.emit(insn); result } else if src_is_float && dst_is_float { // Float to float conversion (e.g., float to double) - if src_type.size_bits() != cast_type.size_bits() { + let src_size = self.types.size_bits(src_type); + let dst_size = self.types.size_bits(cast_type); + if src_size != dst_size { let result = self.alloc_pseudo(); let pseudo = Pseudo::reg(result, result.0); if let Some(func) = &mut self.current_func { @@ -2570,8 +2593,8 @@ impl<'a> Linearizer<'a> { let mut insn = Instruction::new(Opcode::FCvtF) .with_target(result) .with_src(src) - .with_type(cast_type.clone()); - insn.src_size = src_type.size_bits(); + .with_type(cast_type); + insn.src_size = src_size; self.emit(insn); result } else { @@ -2580,28 +2603,28 @@ impl<'a> Linearizer<'a> { } else { // Integer to integer conversion // Use emit_convert for proper type conversions including _Bool - 
self.emit_convert(src, &src_type, cast_type) + self.emit_convert(src, src_type, cast_type) } } ExprKind::SizeofType(typ) => { - let size = typ.size_bits() / 8; + let size = self.types.size_bits(*typ) / 8; // sizeof returns size_t, which is unsigned long in our implementation - let result_typ = Type::with_modifiers(TypeKind::Long, TypeModifiers::UNSIGNED); + let result_typ = self.types.ulong_id; self.emit_const(size as i64, result_typ) } ExprKind::SizeofExpr(inner_expr) => { // Get type from expression and compute size let inner_typ = self.expr_type(inner_expr); - let size = inner_typ.size_bits() / 8; + let size = self.types.size_bits(inner_typ) / 8; // sizeof returns size_t, which is unsigned long in our implementation - let result_typ = Type::with_modifiers(TypeKind::Long, TypeModifiers::UNSIGNED); + let result_typ = self.types.ulong_id; self.emit_const(size as i64, result_typ) } ExprKind::Comma(exprs) => { - let mut result = self.emit_const(0, Type::basic(TypeKind::Int)); + let mut result = self.emit_const(0, self.types.int_id); for e in exprs { result = self.linearize_expr(e); } @@ -2628,7 +2651,7 @@ impl<'a> Linearizer<'a> { .with_target(result) .with_src(ap_addr) .with_func(last_param.clone()) - .with_type(Type::basic(TypeKind::Void)); + .with_type(self.types.void_id); self.emit(insn); result } @@ -2642,7 +2665,7 @@ impl<'a> Linearizer<'a> { let insn = Instruction::new(Opcode::VaArg) .with_target(result) .with_src(ap_addr) - .with_type(arg_type.clone()); + .with_type(*arg_type); self.emit(insn); result } @@ -2655,7 +2678,7 @@ impl<'a> Linearizer<'a> { let insn = Instruction::new(Opcode::VaEnd) .with_target(result) .with_src(ap_addr) - .with_type(Type::basic(TypeKind::Void)); + .with_type(self.types.void_id); self.emit(insn); result } @@ -2670,7 +2693,7 @@ impl<'a> Linearizer<'a> { .with_target(result) .with_src(dest_addr) .with_src(src_addr) - .with_type(Type::basic(TypeKind::Void)); + .with_type(self.types.void_id); self.emit(insn); result } @@ -2686,10 
+2709,7 @@ impl<'a> Linearizer<'a> { .with_target(result) .with_src(arg_val) .with_size(16) - .with_type(Type::with_modifiers( - TypeKind::Short, - TypeModifiers::UNSIGNED, - )); + .with_type(self.types.ushort_id); self.emit(insn); result } @@ -2702,7 +2722,7 @@ impl<'a> Linearizer<'a> { .with_target(result) .with_src(arg_val) .with_size(32) - .with_type(Type::with_modifiers(TypeKind::Int, TypeModifiers::UNSIGNED)); + .with_type(self.types.uint_id); self.emit(insn); result } @@ -2715,10 +2735,7 @@ impl<'a> Linearizer<'a> { .with_target(result) .with_src(arg_val) .with_size(64) - .with_type(Type::with_modifiers( - TypeKind::LongLong, - TypeModifiers::UNSIGNED, - )); + .with_type(self.types.ulonglong_id); self.emit(insn); result } @@ -2730,15 +2747,14 @@ impl<'a> Linearizer<'a> { let insn = Instruction::new(Opcode::Alloca) .with_target(result) .with_src(size_val) - .with_size(64) // Returns pointer (64-bit) - .with_type(Type::pointer(Type::basic(TypeKind::Void))); + .with_type_and_size(self.types.void_ptr_id, 64); self.emit(insn); result } } } - fn emit_const(&mut self, val: i64, typ: Type) -> PseudoId { + fn emit_const(&mut self, val: i64, typ: TypeId) -> PseudoId { let id = self.alloc_pseudo(); let pseudo = Pseudo::val(id, val); if let Some(func) = &mut self.current_func { @@ -2748,13 +2764,13 @@ impl<'a> Linearizer<'a> { // Emit setval instruction let insn = Instruction::new(Opcode::SetVal) .with_target(id) - .with_type(typ); + .with_type_and_size(typ, self.types.size_bits(typ)); self.emit(insn); id } - fn emit_fconst(&mut self, val: f64, typ: Type) -> PseudoId { + fn emit_fconst(&mut self, val: f64, typ: TypeId) -> PseudoId { let id = self.alloc_pseudo(); let pseudo = Pseudo::fval(id, val); if let Some(func) = &mut self.current_func { @@ -2764,7 +2780,7 @@ impl<'a> Linearizer<'a> { // Emit setval instruction let insn = Instruction::new(Opcode::SetVal) .with_target(id) - .with_type(typ); + .with_type_and_size(typ, self.types.size_bits(typ)); self.emit(insn); id @@ 
-2779,16 +2795,17 @@ impl<'a> Linearizer<'a> { bit_offset: u32, bit_width: u32, storage_size: u32, - typ: &Type, + typ: TypeId, ) -> PseudoId { // Determine storage type based on storage unit size let storage_type = match storage_size { - 1 => Type::with_modifiers(TypeKind::Char, TypeModifiers::UNSIGNED), - 2 => Type::with_modifiers(TypeKind::Short, TypeModifiers::UNSIGNED), - 4 => Type::with_modifiers(TypeKind::Int, TypeModifiers::UNSIGNED), - 8 => Type::with_modifiers(TypeKind::Long, TypeModifiers::UNSIGNED), - _ => Type::with_modifiers(TypeKind::Int, TypeModifiers::UNSIGNED), + 1 => self.types.uchar_id, + 2 => self.types.ushort_id, + 4 => self.types.uint_id, + 8 => self.types.ulong_id, + _ => self.types.uint_id, }; + let storage_bits = storage_size * 8; // 1. Load the entire storage unit let storage_val = self.alloc_pseudo(); @@ -2796,19 +2813,21 @@ impl<'a> Linearizer<'a> { storage_val, base, byte_offset as i64, - storage_type.clone(), + storage_type, + storage_bits, )); // 2. Shift right by bit_offset (using logical shift for unsigned extraction) let shifted = if bit_offset > 0 { - let shift_amount = self.emit_const(bit_offset as i64, Type::basic(TypeKind::Int)); + let shift_amount = self.emit_const(bit_offset as i64, self.types.int_id); let shifted = self.alloc_pseudo(); self.emit(Instruction::binop( Opcode::Lsr, shifted, storage_val, shift_amount, - storage_type.clone(), + storage_type, + storage_bits, )); shifted } else { @@ -2817,19 +2836,20 @@ impl<'a> Linearizer<'a> { // 3. Mask to bit_width bits let mask = (1u64 << bit_width) - 1; - let mask_val = self.emit_const(mask as i64, storage_type.clone()); + let mask_val = self.emit_const(mask as i64, storage_type); let masked = self.alloc_pseudo(); self.emit(Instruction::binop( Opcode::And, masked, shifted, mask_val, - storage_type.clone(), + storage_type, + storage_bits, )); // 4. 
Sign extend if this is a signed bitfield - if !typ.is_unsigned() && bit_width < storage_size * 8 { - self.emit_sign_extend_bitfield(masked, bit_width, storage_size * 8) + if !self.types.is_unsigned(typ) && bit_width < storage_bits { + self.emit_sign_extend_bitfield(masked, bit_width, storage_bits) } else { masked } @@ -2844,16 +2864,17 @@ impl<'a> Linearizer<'a> { ) -> PseudoId { // Sign extend by shifting left then arithmetic shifting right let shift_amount = target_bits - bit_width; - let typ = Type::basic(TypeKind::Int); + let typ = self.types.int_id; - let shift_val = self.emit_const(shift_amount as i64, typ.clone()); + let shift_val = self.emit_const(shift_amount as i64, typ); let shifted_left = self.alloc_pseudo(); self.emit(Instruction::binop( Opcode::Shl, shifted_left, value, shift_val, - typ.clone(), + typ, + 32, )); let result = self.alloc_pseudo(); @@ -2863,6 +2884,7 @@ impl<'a> Linearizer<'a> { shifted_left, shift_val, typ, + 32, )); result } @@ -2879,12 +2901,13 @@ impl<'a> Linearizer<'a> { ) { // Determine storage type based on storage unit size let storage_type = match storage_size { - 1 => Type::with_modifiers(TypeKind::Char, TypeModifiers::UNSIGNED), - 2 => Type::with_modifiers(TypeKind::Short, TypeModifiers::UNSIGNED), - 4 => Type::with_modifiers(TypeKind::Int, TypeModifiers::UNSIGNED), - 8 => Type::with_modifiers(TypeKind::Long, TypeModifiers::UNSIGNED), - _ => Type::with_modifiers(TypeKind::Int, TypeModifiers::UNSIGNED), + 1 => self.types.uchar_id, + 2 => self.types.ushort_id, + 4 => self.types.uint_id, + 8 => self.types.ulong_id, + _ => self.types.uint_id, }; + let storage_bits = storage_size * 8; // 1. Load current storage unit value let old_val = self.alloc_pseudo(); @@ -2892,13 +2915,14 @@ impl<'a> Linearizer<'a> { old_val, base, byte_offset as i64, - storage_type.clone(), + storage_type, + storage_bits, )); // 2. 
Create mask for the bitfield bits: ~(((1 << width) - 1) << offset) let field_mask = ((1u64 << bit_width) - 1) << bit_offset; let clear_mask = !field_mask; - let clear_mask_val = self.emit_const(clear_mask as i64, storage_type.clone()); + let clear_mask_val = self.emit_const(clear_mask as i64, storage_type); // 3. Clear the bitfield bits in old value let cleared = self.alloc_pseudo(); @@ -2907,30 +2931,33 @@ impl<'a> Linearizer<'a> { cleared, old_val, clear_mask_val, - storage_type.clone(), + storage_type, + storage_bits, )); // 4. Mask new value to bit_width and shift to position let value_mask = (1u64 << bit_width) - 1; - let value_mask_val = self.emit_const(value_mask as i64, storage_type.clone()); + let value_mask_val = self.emit_const(value_mask as i64, storage_type); let masked_new = self.alloc_pseudo(); self.emit(Instruction::binop( Opcode::And, masked_new, new_value, value_mask_val, - storage_type.clone(), + storage_type, + storage_bits, )); let positioned = if bit_offset > 0 { - let shift_val = self.emit_const(bit_offset as i64, Type::basic(TypeKind::Int)); + let shift_val = self.emit_const(bit_offset as i64, self.types.int_id); let positioned = self.alloc_pseudo(); self.emit(Instruction::binop( Opcode::Shl, positioned, masked_new, shift_val, - storage_type.clone(), + storage_type, + storage_bits, )); positioned } else { @@ -2944,7 +2971,8 @@ impl<'a> Linearizer<'a> { combined, cleared, positioned, - storage_type.clone(), + storage_type, + storage_bits, )); // 6. 
Store back @@ -2953,12 +2981,14 @@ impl<'a> Linearizer<'a> { base, byte_offset as i64, storage_type, + storage_bits, )); } - fn emit_unary(&mut self, op: UnaryOp, src: PseudoId, typ: Type) -> PseudoId { + fn emit_unary(&mut self, op: UnaryOp, src: PseudoId, typ: TypeId) -> PseudoId { let result = self.alloc_pseudo(); - let is_float = typ.is_float(); + let is_float = self.types.is_float(typ); + let size = self.types.size_bits(typ); let opcode = match op { UnaryOp::Neg => { @@ -2971,11 +3001,25 @@ impl<'a> Linearizer<'a> { UnaryOp::Not => { // Logical not: compare with 0 if is_float { - let zero = self.emit_fconst(0.0, typ.clone()); - self.emit(Instruction::binop(Opcode::FCmpOEq, result, src, zero, typ)); + let zero = self.emit_fconst(0.0, typ); + self.emit(Instruction::binop( + Opcode::FCmpOEq, + result, + src, + zero, + typ, + size, + )); } else { - let zero = self.emit_const(0, typ.clone()); - self.emit(Instruction::binop(Opcode::SetEq, result, src, zero, typ)); + let zero = self.emit_const(0, typ); + self.emit(Instruction::binop( + Opcode::SetEq, + result, + src, + zero, + typ, + size, + )); } return result; } @@ -2986,53 +3030,49 @@ impl<'a> Linearizer<'a> { UnaryOp::Deref => { // Dereferencing a pointer-to-array gives an array, which is just an address // (arrays decay to their first element's address) - if typ.kind == TypeKind::Array { + if self.types.kind(typ) == TypeKind::Array { return src; } - self.emit(Instruction::load(result, src, 0, typ)); + self.emit(Instruction::load(result, src, 0, typ, size)); return result; } UnaryOp::PreInc => { - let is_ptr = typ.kind == TypeKind::Pointer; + let is_ptr = self.types.kind(typ) == TypeKind::Pointer; let increment = if is_ptr { - let elem_type = typ - .base - .as_ref() - .map(|b| (**b).clone()) - .unwrap_or_else(|| Type::basic(TypeKind::Char)); - let elem_size = elem_type.size_bits() / 8; - self.emit_const(elem_size as i64, Type::basic(TypeKind::Long)) + let elem_type = 
self.types.base_type(typ).unwrap_or(self.types.char_id); + let elem_size = self.types.size_bits(elem_type) / 8; + self.emit_const(elem_size as i64, self.types.long_id) } else if is_float { - self.emit_fconst(1.0, typ.clone()) + self.emit_fconst(1.0, typ) } else { - self.emit_const(1, typ.clone()) + self.emit_const(1, typ) }; let opcode = if is_float { Opcode::FAdd } else { Opcode::Add }; - self.emit(Instruction::binop(opcode, result, src, increment, typ)); + self.emit(Instruction::binop( + opcode, result, src, increment, typ, size, + )); return result; } UnaryOp::PreDec => { - let is_ptr = typ.kind == TypeKind::Pointer; + let is_ptr = self.types.kind(typ) == TypeKind::Pointer; let decrement = if is_ptr { - let elem_type = typ - .base - .as_ref() - .map(|b| (**b).clone()) - .unwrap_or_else(|| Type::basic(TypeKind::Char)); - let elem_size = elem_type.size_bits() / 8; - self.emit_const(elem_size as i64, Type::basic(TypeKind::Long)) + let elem_type = self.types.base_type(typ).unwrap_or(self.types.char_id); + let elem_size = self.types.size_bits(elem_type) / 8; + self.emit_const(elem_size as i64, self.types.long_id) } else if is_float { - self.emit_fconst(1.0, typ.clone()) + self.emit_fconst(1.0, typ) } else { - self.emit_const(1, typ.clone()) + self.emit_const(1, typ) }; let opcode = if is_float { Opcode::FSub } else { Opcode::Sub }; - self.emit(Instruction::binop(opcode, result, src, decrement, typ)); + self.emit(Instruction::binop( + opcode, result, src, decrement, typ, size, + )); return result; } }; - self.emit(Instruction::unop(opcode, result, src, typ)); + self.emit(Instruction::unop(opcode, result, src, typ, size)); result } @@ -3041,12 +3081,13 @@ impl<'a> Linearizer<'a> { op: BinaryOp, left: PseudoId, right: PseudoId, - result_typ: Type, - operand_typ: Type, + result_typ: TypeId, + operand_typ: TypeId, ) -> PseudoId { let result = self.alloc_pseudo(); - let is_float = operand_typ.is_float(); + let is_float = self.types.is_float(operand_typ); + let is_unsigned 
= self.types.is_unsigned(operand_typ); let opcode = match op { BinaryOp::Add => { @@ -3073,7 +3114,7 @@ impl<'a> Linearizer<'a> { BinaryOp::Div => { if is_float { Opcode::FDiv - } else if operand_typ.is_unsigned() { + } else if is_unsigned { Opcode::DivU } else { Opcode::DivS @@ -3082,7 +3123,7 @@ impl<'a> Linearizer<'a> { BinaryOp::Mod => { // Modulo is not supported for floats in hardware - use fmod() library call // For now, use integer modulo (semantic analysis should catch float % float) - if operand_typ.is_unsigned() { + if is_unsigned { Opcode::ModU } else { Opcode::ModS @@ -3091,7 +3132,7 @@ impl<'a> Linearizer<'a> { BinaryOp::Lt => { if is_float { Opcode::FCmpOLt - } else if operand_typ.is_unsigned() { + } else if is_unsigned { Opcode::SetB } else { Opcode::SetLt @@ -3100,7 +3141,7 @@ impl<'a> Linearizer<'a> { BinaryOp::Gt => { if is_float { Opcode::FCmpOGt - } else if operand_typ.is_unsigned() { + } else if is_unsigned { Opcode::SetA } else { Opcode::SetGt @@ -3109,7 +3150,7 @@ impl<'a> Linearizer<'a> { BinaryOp::Le => { if is_float { Opcode::FCmpOLe - } else if operand_typ.is_unsigned() { + } else if is_unsigned { Opcode::SetBe } else { Opcode::SetLe @@ -3118,7 +3159,7 @@ impl<'a> Linearizer<'a> { BinaryOp::Ge => { if is_float { Opcode::FCmpOGe - } else if operand_typ.is_unsigned() { + } else if is_unsigned { Opcode::SetAe } else { Opcode::SetGe @@ -3149,7 +3190,7 @@ impl<'a> Linearizer<'a> { BinaryOp::Shl => Opcode::Shl, BinaryOp::Shr => { // Logical shift for unsigned, arithmetic for signed - if operand_typ.is_unsigned() { + if is_unsigned { Opcode::Lsr } else { Opcode::Asr @@ -3178,19 +3219,24 @@ impl<'a> Linearizer<'a> { | Opcode::FCmpOGe => operand_typ, _ => result_typ, }; - self.emit(Instruction::binop(opcode, result, left, right, insn_typ)); + let insn_size = self.types.size_bits(insn_typ); + self.emit(Instruction::binop( + opcode, result, left, right, insn_typ, insn_size, + )); result } - fn emit_compare_zero(&mut self, val: PseudoId, 
operand_typ: &Type) -> PseudoId { + fn emit_compare_zero(&mut self, val: PseudoId, operand_typ: TypeId) -> PseudoId { let result = self.alloc_pseudo(); - let zero = self.emit_const(0, operand_typ.clone()); + let zero = self.emit_const(0, operand_typ); + let size = self.types.size_bits(operand_typ); self.emit(Instruction::binop( Opcode::SetNe, result, val, zero, - operand_typ.clone(), + operand_typ, + size, )); result } @@ -3199,7 +3245,7 @@ impl<'a> Linearizer<'a> { /// If a is false, skip evaluation of b and return 0. /// Otherwise, evaluate b and return (b != 0). fn emit_logical_and(&mut self, left: &Expr, right: &Expr) -> PseudoId { - let result_typ = Type::basic(TypeKind::Int); + let result_typ = self.types.int_id; // Create basic blocks let eval_b_bb = self.alloc_bb(); @@ -3208,11 +3254,11 @@ impl<'a> Linearizer<'a> { // Evaluate LHS let left_typ = self.expr_type(left); let left_val = self.linearize_expr(left); - let left_bool = self.emit_compare_zero(left_val, &left_typ); + let left_bool = self.emit_compare_zero(left_val, left_typ); // Emit the short-circuit value (0) BEFORE the branch, while still in LHS block // This value will be used if we short-circuit (LHS is false) - let zero = self.emit_const(0, result_typ.clone()); + let zero = self.emit_const(0, result_typ); // Get the block where LHS evaluation ended (may differ from initial block // if LHS contains nested control flow) @@ -3227,7 +3273,7 @@ impl<'a> Linearizer<'a> { self.switch_bb(eval_b_bb); let right_typ = self.expr_type(right); let right_val = self.linearize_expr(right); - let right_bool = self.emit_compare_zero(right_val, &right_typ); + let right_bool = self.emit_compare_zero(right_val, right_typ); // Get the actual block where RHS evaluation ended (may differ from eval_b_bb // if RHS contains nested control flow like another &&/||) @@ -3243,7 +3289,7 @@ impl<'a> Linearizer<'a> { // Result is 0 if we came from lhs_end_bb (LHS was false), // or right_bool if we came from rhs_end_bb (LHS was 
true) let result = self.alloc_pseudo(); - let mut phi_insn = Instruction::phi(result, result_typ); + let mut phi_insn = Instruction::phi(result, result_typ, 32); phi_insn.phi_list.push((lhs_end_bb, zero)); phi_insn.phi_list.push((rhs_end_bb, right_bool)); self.emit(phi_insn); @@ -3255,7 +3301,7 @@ impl<'a> Linearizer<'a> { /// If a is true, skip evaluation of b and return 1. /// Otherwise, evaluate b and return (b != 0). fn emit_logical_or(&mut self, left: &Expr, right: &Expr) -> PseudoId { - let result_typ = Type::basic(TypeKind::Int); + let result_typ = self.types.int_id; // Create basic blocks let eval_b_bb = self.alloc_bb(); @@ -3264,11 +3310,11 @@ impl<'a> Linearizer<'a> { // Evaluate LHS let left_typ = self.expr_type(left); let left_val = self.linearize_expr(left); - let left_bool = self.emit_compare_zero(left_val, &left_typ); + let left_bool = self.emit_compare_zero(left_val, left_typ); // Emit the short-circuit value (1) BEFORE the branch, while still in LHS block // This value will be used if we short-circuit (LHS is true) - let one = self.emit_const(1, result_typ.clone()); + let one = self.emit_const(1, result_typ); // Get the block where LHS evaluation ended (may differ from initial block // if LHS contains nested control flow) @@ -3283,7 +3329,7 @@ impl<'a> Linearizer<'a> { self.switch_bb(eval_b_bb); let right_typ = self.expr_type(right); let right_val = self.linearize_expr(right); - let right_bool = self.emit_compare_zero(right_val, &right_typ); + let right_bool = self.emit_compare_zero(right_val, right_typ); // Get the actual block where RHS evaluation ended (may differ from eval_b_bb // if RHS contains nested control flow like another &&/||) @@ -3299,7 +3345,7 @@ impl<'a> Linearizer<'a> { // Result is 1 if we came from lhs_end_bb (LHS was true), // or right_bool if we came from rhs_end_bb (LHS was false) let result = self.alloc_pseudo(); - let mut phi_insn = Instruction::phi(result, result_typ); + let mut phi_insn = Instruction::phi(result, 
result_typ, 32); phi_insn.phi_list.push((lhs_end_bb, one)); phi_insn.phi_list.push((rhs_end_bb, right_bool)); self.emit(phi_insn); @@ -3313,23 +3359,22 @@ impl<'a> Linearizer<'a> { let value_typ = self.expr_type(value); // Check for pointer compound assignment (p += n or p -= n) - let is_ptr_arith = target_typ.kind == TypeKind::Pointer - && value_typ.is_integer() + let is_ptr_arith = self.types.kind(target_typ) == TypeKind::Pointer + && self.types.is_integer(value_typ) && (op == AssignOp::AddAssign || op == AssignOp::SubAssign); // Convert RHS to target type if needed (but not for pointer arithmetic) let rhs = if is_ptr_arith { // For pointer arithmetic, scale the integer by element size - let elem_type = target_typ - .base - .as_ref() - .map(|b| (**b).clone()) - .unwrap_or_else(|| Type::basic(TypeKind::Char)); - let elem_size = elem_type.size_bits() / 8; - let scale = self.emit_const(elem_size as i64, Type::basic(TypeKind::Long)); + let elem_type = self + .types + .base_type(target_typ) + .unwrap_or(self.types.char_id); + let elem_size = self.types.size_bits(elem_type) / 8; + let scale = self.emit_const(elem_size as i64, self.types.long_id); // Extend the integer to 64-bit for proper arithmetic - let rhs_extended = self.emit_convert(rhs, &value_typ, &Type::basic(TypeKind::Long)); + let rhs_extended = self.emit_convert(rhs, value_typ, self.types.long_id); let scaled = self.alloc_pseudo(); let pseudo = Pseudo::reg(scaled, scaled.0); @@ -3341,11 +3386,12 @@ impl<'a> Linearizer<'a> { scaled, rhs_extended, scale, - Type::basic(TypeKind::Long), + self.types.long_id, + 64, )); scaled } else { - self.emit_convert(rhs, &value_typ, &target_typ) + self.emit_convert(rhs, value_typ, target_typ) }; let final_val = match op { @@ -3359,7 +3405,8 @@ impl<'a> Linearizer<'a> { func.add_pseudo(pseudo); } - let is_float = target_typ.is_float(); + let is_float = self.types.is_float(target_typ); + let is_unsigned = self.types.is_unsigned(target_typ); let opcode = match op { 
AssignOp::AddAssign => { if is_float { @@ -3385,7 +3432,7 @@ impl<'a> Linearizer<'a> { AssignOp::DivAssign => { if is_float { Opcode::FDiv - } else if target_typ.is_unsigned() { + } else if is_unsigned { Opcode::DivU } else { Opcode::DivS @@ -3393,7 +3440,7 @@ impl<'a> Linearizer<'a> { } AssignOp::ModAssign => { // Modulo not supported for floats - if target_typ.is_unsigned() { + if is_unsigned { Opcode::ModU } else { Opcode::ModS @@ -3404,7 +3451,7 @@ impl<'a> Linearizer<'a> { AssignOp::XorAssign => Opcode::Xor, AssignOp::ShlAssign => Opcode::Shl, AssignOp::ShrAssign => { - if target_typ.is_unsigned() { + if is_unsigned { Opcode::Lsr } else { Opcode::Asr @@ -3415,16 +3462,20 @@ impl<'a> Linearizer<'a> { // For pointer arithmetic, use Long type for the operation let arith_type = if is_ptr_arith { - Type::basic(TypeKind::Long) + self.types.long_id } else { - target_typ.clone() + target_typ }; - self.emit(Instruction::binop(opcode, result, lhs, rhs, arith_type)); + let arith_size = self.types.size_bits(arith_type); + self.emit(Instruction::binop( + opcode, result, lhs, rhs, arith_type, arith_size, + )); result } }; // Store based on target expression type + let target_size = self.types.size_bits(target_typ); match &target.kind { ExprKind::Ident { name, .. 
} => { if let Some(local) = self.locals.get(name).cloned() { @@ -3438,7 +3489,13 @@ impl<'a> Linearizer<'a> { if let Some(func) = &mut self.current_func { func.add_pseudo(pseudo); } - self.emit(Instruction::store(final_val, sym_id, 0, target_typ.clone())); + self.emit(Instruction::store( + final_val, + sym_id, + 0, + target_typ, + target_size, + )); } } else { // Regular local variable: emit Store @@ -3446,7 +3503,8 @@ impl<'a> Linearizer<'a> { final_val, local.sym, 0, - target_typ.clone(), + target_typ, + target_size, )); } } else if self.var_map.contains_key(name) { @@ -3461,22 +3519,29 @@ impl<'a> Linearizer<'a> { if let Some(func) = &mut self.current_func { func.add_pseudo(pseudo); } - self.emit(Instruction::store(final_val, sym_id, 0, target_typ.clone())); + self.emit(Instruction::store( + final_val, + sym_id, + 0, + target_typ, + target_size, + )); } } ExprKind::Member { expr, member } => { // Struct member: get address and store with offset let base = self.linearize_lvalue(expr); let struct_type = self.expr_type(expr); - let member_info = struct_type - .find_member(member) - .unwrap_or_else(|| MemberInfo { - offset: 0, - typ: target_typ.clone(), - bit_offset: None, - bit_width: None, - storage_unit_size: None, - }); + let member_info = + self.types + .find_member(struct_type, member) + .unwrap_or(MemberInfo { + offset: 0, + typ: target_typ, + bit_offset: None, + bit_width: None, + storage_unit_size: None, + }); if let (Some(bit_offset), Some(bit_width), Some(storage_size)) = ( member_info.bit_offset, member_info.bit_width, @@ -3492,11 +3557,13 @@ impl<'a> Linearizer<'a> { final_val, ); } else { + let member_size = self.types.size_bits(member_info.typ); self.emit(Instruction::store( final_val, base, member_info.offset as i64, member_info.typ, + member_size, )); } } @@ -3504,20 +3571,17 @@ impl<'a> Linearizer<'a> { // Pointer member: pointer value is the base address let ptr = self.linearize_expr(expr); let ptr_type = self.expr_type(expr); - let struct_type = 
ptr_type - .base - .as_ref() - .map(|b| b.as_ref().clone()) - .unwrap_or_else(|| target_typ.clone()); - let member_info = struct_type - .find_member(member) - .unwrap_or_else(|| MemberInfo { - offset: 0, - typ: target_typ.clone(), - bit_offset: None, - bit_width: None, - storage_unit_size: None, - }); + let struct_type = self.types.base_type(ptr_type).unwrap_or(target_typ); + let member_info = + self.types + .find_member(struct_type, member) + .unwrap_or(MemberInfo { + offset: 0, + typ: target_typ, + bit_offset: None, + bit_width: None, + storage_unit_size: None, + }); if let (Some(bit_offset), Some(bit_width), Some(storage_size)) = ( member_info.bit_offset, member_info.bit_width, @@ -3533,11 +3597,13 @@ impl<'a> Linearizer<'a> { final_val, ); } else { + let member_size = self.types.size_bits(member_info.typ); self.emit(Instruction::store( final_val, ptr, member_info.offset as i64, member_info.typ, + member_size, )); } } @@ -3547,7 +3613,13 @@ impl<'a> Linearizer<'a> { } => { // Dereference: store to the pointer address let ptr = self.linearize_expr(operand); - self.emit(Instruction::store(final_val, ptr, 0, target_typ.clone())); + self.emit(Instruction::store( + final_val, + ptr, + 0, + target_typ, + target_size, + )); } ExprKind::Index { array, index } => { // Array subscript: compute address and store @@ -3555,8 +3627,9 @@ impl<'a> Linearizer<'a> { let array_type = self.expr_type(array); let index_type = self.expr_type(index); + let array_kind = self.types.kind(array_type); let (ptr_expr, idx_expr, idx_type) = - if array_type.kind == TypeKind::Pointer || array_type.kind == TypeKind::Array { + if array_kind == TypeKind::Pointer || array_kind == TypeKind::Array { (array, index, index_type) } else { // Swap: index is actually the pointer/array @@ -3565,26 +3638,40 @@ impl<'a> Linearizer<'a> { let arr = self.linearize_expr(ptr_expr); let idx = self.linearize_expr(idx_expr); - let elem_size = target_typ.size_bits() / 8; - let elem_size_val = self.emit_const(elem_size 
as i64, Type::basic(TypeKind::Long)); + let elem_size = target_size / 8; + let elem_size_val = self.emit_const(elem_size as i64, self.types.long_id); // Sign-extend index to 64-bit for proper pointer arithmetic (negative indices) - let idx_extended = self.emit_convert(idx, &idx_type, &Type::basic(TypeKind::Long)); + let idx_extended = self.emit_convert(idx, idx_type, self.types.long_id); let offset = self.alloc_pseudo(); - let ptr_typ = Type::basic(TypeKind::Long); + let ptr_typ = self.types.long_id; self.emit(Instruction::binop( Opcode::Mul, offset, idx_extended, elem_size_val, - ptr_typ.clone(), + ptr_typ, + 64, )); let addr = self.alloc_pseudo(); - self.emit(Instruction::binop(Opcode::Add, addr, arr, offset, ptr_typ)); + self.emit(Instruction::binop( + Opcode::Add, + addr, + arr, + offset, + ptr_typ, + 64, + )); - self.emit(Instruction::store(final_val, addr, 0, target_typ.clone())); + self.emit(Instruction::store( + final_val, + addr, + 0, + target_typ, + target_size, + )); } _ => { // Other lvalues - should not happen for valid C code @@ -3601,18 +3688,25 @@ impl<'a> Linearizer<'a> { /// Linearize an AST to IR (convenience wrapper for tests) #[cfg(test)] -pub fn linearize(tu: &TranslationUnit, symbols: &SymbolTable) -> Module { - linearize_with_debug(tu, symbols, false, None) +pub fn linearize( + tu: &TranslationUnit, + symbols: &SymbolTable, + types: &TypeTable, + target: &Target, +) -> Module { + linearize_with_debug(tu, symbols, types, target, false, None) } /// Linearize an AST to IR with debug info support pub fn linearize_with_debug( tu: &TranslationUnit, symbols: &SymbolTable, + types: &TypeTable, + target: &Target, debug: bool, source_file: Option<&str>, ) -> Module { - let mut linearizer = Linearizer::new(symbols); + let mut linearizer = Linearizer::new(symbols, types, target); let mut module = linearizer.linearize(tu); module.debug = debug; if let Some(path) = source_file { @@ -3642,14 +3736,15 @@ mod tests { } } - fn test_linearize(tu: 
&TranslationUnit) -> Module { + fn test_linearize(tu: &TranslationUnit, types: &TypeTable) -> Module { let symbols = SymbolTable::new(); - linearize(tu, &symbols) + let target = Target::host(); + linearize(tu, &symbols, types, &target) } - fn make_simple_func(name: &str, body: Stmt) -> FunctionDef { + fn make_simple_func(name: &str, body: Stmt, types: &TypeTable) -> FunctionDef { FunctionDef { - return_type: Type::basic(TypeKind::Int), + return_type: types.int_id, name: name.to_string(), params: vec![], body, @@ -3659,12 +3754,13 @@ mod tests { #[test] fn test_linearize_empty_function() { - let func = make_simple_func("test", Stmt::Block(vec![])); + let types = TypeTable::new(); + let func = make_simple_func("test", Stmt::Block(vec![]), &types); let tu = TranslationUnit { items: vec![ExternalDecl::FunctionDef(func)], }; - let module = test_linearize(&tu); + let module = test_linearize(&tu, &types); assert_eq!(module.functions.len(), 1); assert_eq!(module.functions[0].name, "test"); assert!(!module.functions[0].blocks.is_empty()); @@ -3672,68 +3768,79 @@ mod tests { #[test] fn test_linearize_return() { - let func = make_simple_func("test", Stmt::Return(Some(Expr::int(42)))); + let types = TypeTable::new(); + let func = make_simple_func("test", Stmt::Return(Some(Expr::int(42, &types))), &types); let tu = TranslationUnit { items: vec![ExternalDecl::FunctionDef(func)], }; - let module = test_linearize(&tu); + let module = test_linearize(&tu, &types); let ir = format!("{}", module); assert!(ir.contains("ret")); } #[test] fn test_linearize_if() { + let types = TypeTable::new(); let func = make_simple_func( "test", Stmt::If { - cond: Expr::int(1), - then_stmt: Box::new(Stmt::Return(Some(Expr::int(1)))), - else_stmt: Some(Box::new(Stmt::Return(Some(Expr::int(0))))), + cond: Expr::int(1, &types), + then_stmt: Box::new(Stmt::Return(Some(Expr::int(1, &types)))), + else_stmt: Some(Box::new(Stmt::Return(Some(Expr::int(0, &types))))), }, + &types, ); let tu = TranslationUnit { 
items: vec![ExternalDecl::FunctionDef(func)], }; - let module = test_linearize(&tu); + let module = test_linearize(&tu, &types); let ir = format!("{}", module); assert!(ir.contains("cbr")); // Conditional branch } #[test] fn test_linearize_while() { + let types = TypeTable::new(); let func = make_simple_func( "test", Stmt::While { - cond: Expr::int(1), + cond: Expr::int(1, &types), body: Box::new(Stmt::Break), }, + &types, ); let tu = TranslationUnit { items: vec![ExternalDecl::FunctionDef(func)], }; - let module = test_linearize(&tu); + let module = test_linearize(&tu, &types); assert!(module.functions[0].blocks.len() >= 3); // cond, body, exit } #[test] fn test_linearize_for() { + let types = TypeTable::new(); // for (int i = 0; i < 10; i++) { } - let int_type = Type::basic(TypeKind::Int); - let i_var = Expr::var_typed("i", int_type.clone()); + let int_type = types.int_id; + let i_var = Expr::var_typed("i", int_type); let func = make_simple_func( "test", Stmt::For { init: Some(ForInit::Declaration(Declaration { declarators: vec![crate::parse::ast::InitDeclarator { name: "i".to_string(), - typ: int_type.clone(), - init: Some(Expr::int(0)), + typ: int_type, + init: Some(Expr::int(0, &types)), }], })), - cond: Some(Expr::binary(BinaryOp::Lt, i_var.clone(), Expr::int(10))), + cond: Some(Expr::binary( + BinaryOp::Lt, + i_var.clone(), + Expr::int(10, &types), + &types, + )), post: Some(Expr::typed( ExprKind::PostInc(Box::new(i_var)), int_type, @@ -3741,31 +3848,40 @@ mod tests { )), body: Box::new(Stmt::Empty), }, + &types, ); let tu = TranslationUnit { items: vec![ExternalDecl::FunctionDef(func)], }; - let module = test_linearize(&tu); + let module = test_linearize(&tu, &types); assert!(module.functions[0].blocks.len() >= 4); // entry, cond, body, post, exit } #[test] fn test_linearize_binary_expr() { + let types = TypeTable::new(); // return 1 + 2 * 3; let func = make_simple_func( "test", Stmt::Return(Some(Expr::binary( BinaryOp::Add, - Expr::int(1), - 
Expr::binary(BinaryOp::Mul, Expr::int(2), Expr::int(3)), + Expr::int(1, &types), + Expr::binary( + BinaryOp::Mul, + Expr::int(2, &types), + Expr::int(3, &types), + &types, + ), + &types, ))), + &types, ); let tu = TranslationUnit { items: vec![ExternalDecl::FunctionDef(func)], }; - let module = test_linearize(&tu); + let module = test_linearize(&tu, &types); let ir = format!("{}", module); assert!(ir.contains("mul")); assert!(ir.contains("add")); @@ -3773,24 +3889,26 @@ mod tests { #[test] fn test_linearize_function_with_params() { - let int_type = Type::basic(TypeKind::Int); + let types = TypeTable::new(); + let int_type = types.int_id; let func = FunctionDef { - return_type: int_type.clone(), + return_type: int_type, name: "add".to_string(), params: vec![ Parameter { name: Some("a".to_string()), - typ: int_type.clone(), + typ: int_type, }, Parameter { name: Some("b".to_string()), - typ: int_type.clone(), + typ: int_type, }, ], body: Stmt::Return(Some(Expr::binary( BinaryOp::Add, - Expr::var_typed("a", int_type.clone()), + Expr::var_typed("a", int_type), Expr::var_typed("b", int_type), + &types, ))), pos: test_pos(), }; @@ -3798,7 +3916,7 @@ mod tests { items: vec![ExternalDecl::FunctionDef(func)], }; - let module = test_linearize(&tu); + let module = test_linearize(&tu, &types); let ir = format!("{}", module); assert!(ir.contains("add")); assert!(ir.contains("%a")); @@ -3807,18 +3925,21 @@ mod tests { #[test] fn test_linearize_call() { + let types = TypeTable::new(); let func = make_simple_func( "test", Stmt::Return(Some(Expr::call( Expr::var("foo"), - vec![Expr::int(1), Expr::int(2)], + vec![Expr::int(1, &types), Expr::int(2, &types)], + &types, ))), + &types, ); let tu = TranslationUnit { items: vec![ExternalDecl::FunctionDef(func)], }; - let module = test_linearize(&tu); + let module = test_linearize(&tu, &types); let ir = format!("{}", module); assert!(ir.contains("call")); assert!(ir.contains("foo")); @@ -3826,38 +3947,45 @@ mod tests { #[test] fn 
test_linearize_comparison() { + let types = TypeTable::new(); let func = make_simple_func( "test", - Stmt::Return(Some(Expr::binary(BinaryOp::Lt, Expr::int(1), Expr::int(2)))), + Stmt::Return(Some(Expr::binary( + BinaryOp::Lt, + Expr::int(1, &types), + Expr::int(2, &types), + &types, + ))), + &types, ); let tu = TranslationUnit { items: vec![ExternalDecl::FunctionDef(func)], }; - let module = test_linearize(&tu); + let module = test_linearize(&tu, &types); let ir = format!("{}", module); assert!(ir.contains("setlt")); } #[test] fn test_linearize_unsigned_comparison() { - use crate::types::TypeModifiers; + let types = TypeTable::new(); // Create unsigned comparison: (unsigned)1 < (unsigned)2 - let uint_type = Type::with_modifiers(TypeKind::Int, TypeModifiers::UNSIGNED); - let mut left = Expr::int(1); - left.typ = Some(uint_type.clone()); - let mut right = Expr::int(2); + let uint_type = types.uint_id; + let mut left = Expr::int(1, &types); + left.typ = Some(uint_type); + let mut right = Expr::int(2, &types); right.typ = Some(uint_type); - let mut cmp = Expr::binary(BinaryOp::Lt, left, right); - cmp.typ = Some(Type::basic(TypeKind::Int)); + let mut cmp = Expr::binary(BinaryOp::Lt, left, right, &types); + cmp.typ = Some(types.int_id); - let func = make_simple_func("test", Stmt::Return(Some(cmp))); + let func = make_simple_func("test", Stmt::Return(Some(cmp)), &types); let tu = TranslationUnit { items: vec![ExternalDecl::FunctionDef(func)], }; - let module = test_linearize(&tu); + let module = test_linearize(&tu, &types); let ir = format!("{}", module); // Should use unsigned comparison opcode (setb = set if below) assert!( @@ -3869,12 +3997,13 @@ mod tests { #[test] fn test_display_module() { - let func = make_simple_func("main", Stmt::Return(Some(Expr::int(0)))); + let types = TypeTable::new(); + let func = make_simple_func("main", Stmt::Return(Some(Expr::int(0, &types))), &types); let tu = TranslationUnit { items: vec![ExternalDecl::FunctionDef(func)], }; - let 
module = test_linearize(&tu); + let module = test_linearize(&tu, &types); let ir = format!("{}", module); // Should have proper structure @@ -3886,36 +4015,40 @@ mod tests { #[test] fn test_type_propagation_expr_type() { - use crate::types::TypeModifiers; + let types = TypeTable::new(); // Create an expression with a type annotation - let mut expr = Expr::int(42); + let mut expr = Expr::int(42, &types); // Simulate type evaluation having set the type - expr.typ = Some(Type::basic(TypeKind::Int)); + expr.typ = Some(types.int_id); // Create linearizer and test that expr_type reads from the expression let symbols = SymbolTable::new(); - let linearizer = Linearizer::new(&symbols); + let target = Target::host(); + let linearizer = Linearizer::new(&symbols, &types, &target); let typ = linearizer.expr_type(&expr); - assert_eq!(typ.kind, TypeKind::Int); + assert_eq!(types.kind(typ), TypeKind::Int); // Test with unsigned type - let mut unsigned_expr = Expr::int(42); - unsigned_expr.typ = Some(Type::with_modifiers(TypeKind::Int, TypeModifiers::UNSIGNED)); + let mut unsigned_expr = Expr::int(42, &types); + unsigned_expr.typ = Some(types.uint_id); let typ = linearizer.expr_type(&unsigned_expr); - assert!(typ.is_unsigned()); + assert!(types.is_unsigned(typ)); } #[test] fn test_type_propagation_double_literal() { + let types = TypeTable::new(); + // Create a double literal let mut expr = Expr::new(ExprKind::FloatLit(3.14), test_pos()); - expr.typ = Some(Type::basic(TypeKind::Double)); + expr.typ = Some(types.double_id); let symbols = SymbolTable::new(); - let linearizer = Linearizer::new(&symbols); + let target = Target::host(); + let linearizer = Linearizer::new(&symbols, &types, &target); let typ = linearizer.expr_type(&expr); - assert_eq!(typ.kind, TypeKind::Double); + assert_eq!(types.kind(typ), TypeKind::Double); } // ======================================================================== @@ -3923,26 +4056,28 @@ mod tests { // 
======================================================================== /// Helper to linearize without SSA conversion (for comparing before/after) - fn linearize_no_ssa(tu: &TranslationUnit) -> Module { + fn linearize_no_ssa(tu: &TranslationUnit, types: &TypeTable) -> Module { let symbols = SymbolTable::new(); - let mut linearizer = Linearizer::new_no_ssa(&symbols); + let target = Target::host(); + let mut linearizer = Linearizer::new_no_ssa(&symbols, types, &target); linearizer.linearize(tu) } #[test] fn test_local_var_emits_load_store() { + let types = TypeTable::new(); // int test() { int x = 1; return x; } - let int_type = Type::basic(TypeKind::Int); + let int_type = types.int_id; let func = FunctionDef { - return_type: int_type.clone(), + return_type: int_type, name: "test".to_string(), params: vec![], body: Stmt::Block(vec![ BlockItem::Declaration(Declaration { declarators: vec![crate::parse::ast::InitDeclarator { name: "x".to_string(), - typ: int_type.clone(), - init: Some(Expr::int(1)), + typ: int_type, + init: Some(Expr::int(1, &types)), }], }), BlockItem::Statement(Stmt::Return(Some(Expr::var_typed("x", int_type)))), @@ -3954,7 +4089,7 @@ mod tests { }; // Without SSA, should have store and load - let module = linearize_no_ssa(&tu); + let module = linearize_no_ssa(&tu, &types); let ir = format!("{}", module); assert!( ir.contains("store"), @@ -3970,35 +4105,37 @@ mod tests { #[test] fn test_ssa_converts_local_to_phi() { + let types = TypeTable::new(); // int test(int cond) { // int x = 1; // if (cond) x = 2; // return x; // } - let int_type = Type::basic(TypeKind::Int); + let int_type = types.int_id; let func = FunctionDef { - return_type: int_type.clone(), + return_type: int_type, name: "test".to_string(), params: vec![Parameter { name: Some("cond".to_string()), - typ: int_type.clone(), + typ: int_type, }], body: Stmt::Block(vec![ // int x = 1; BlockItem::Declaration(Declaration { declarators: vec![crate::parse::ast::InitDeclarator { name: 
"x".to_string(), - typ: int_type.clone(), - init: Some(Expr::int(1)), + typ: int_type, + init: Some(Expr::int(1, &types)), }], }), // if (cond) x = 2; BlockItem::Statement(Stmt::If { - cond: Expr::var_typed("cond", int_type.clone()), + cond: Expr::var_typed("cond", int_type), then_stmt: Box::new(Stmt::Expr(Expr::assign( - Expr::var_typed("x", int_type.clone()), - Expr::int(2), + Expr::var_typed("x", int_type), + Expr::int(2, &types), + &types, ))), else_stmt: None, }), @@ -4012,7 +4149,7 @@ mod tests { }; // With SSA, should have phi node at merge point - let module = test_linearize(&tu); + let module = test_linearize(&tu, &types); let ir = format!("{}", module); // Should have a phi instruction @@ -4025,17 +4162,18 @@ mod tests { #[test] fn test_ssa_loop_variable() { + let types = TypeTable::new(); // int test() { // int i = 0; // while (i < 10) { i = i + 1; } // return i; // } - let int_type = Type::basic(TypeKind::Int); + let int_type = types.int_id; - let i_var = || Expr::var_typed("i", int_type.clone()); + let i_var = || Expr::var_typed("i", int_type); let func = FunctionDef { - return_type: int_type.clone(), + return_type: int_type, name: "test".to_string(), params: vec![], body: Stmt::Block(vec![ @@ -4043,16 +4181,17 @@ mod tests { BlockItem::Declaration(Declaration { declarators: vec![crate::parse::ast::InitDeclarator { name: "i".to_string(), - typ: int_type.clone(), - init: Some(Expr::int(0)), + typ: int_type, + init: Some(Expr::int(0, &types)), }], }), // while (i < 10) { i = i + 1; } BlockItem::Statement(Stmt::While { - cond: Expr::binary(BinaryOp::Lt, i_var(), Expr::int(10)), + cond: Expr::binary(BinaryOp::Lt, i_var(), Expr::int(10, &types), &types), body: Box::new(Stmt::Expr(Expr::assign( i_var(), - Expr::binary(BinaryOp::Add, i_var(), Expr::int(1)), + Expr::binary(BinaryOp::Add, i_var(), Expr::int(1, &types), &types), + &types, ))), }), // return i; @@ -4065,7 +4204,7 @@ mod tests { }; // With SSA, should have phi node at loop header - let module = 
test_linearize(&tu); + let module = test_linearize(&tu, &types); let ir = format!("{}", module); // Loop should have a phi at the condition block @@ -4074,29 +4213,31 @@ mod tests { #[test] fn test_short_circuit_and() { + let types = TypeTable::new(); // int test(int a, int b) { // return a && b; // } // Short-circuit: if a is false, don't evaluate b - let int_type = Type::basic(TypeKind::Int); + let int_type = types.int_id; let func = FunctionDef { - return_type: int_type.clone(), + return_type: int_type, name: "test".to_string(), params: vec![ Parameter { name: Some("a".to_string()), - typ: int_type.clone(), + typ: int_type, }, Parameter { name: Some("b".to_string()), - typ: int_type.clone(), + typ: int_type, }, ], body: Stmt::Return(Some(Expr::binary( BinaryOp::LogAnd, - Expr::var_typed("a", int_type.clone()), + Expr::var_typed("a", int_type), Expr::var_typed("b", int_type), + &types, ))), pos: test_pos(), }; @@ -4104,7 +4245,7 @@ mod tests { items: vec![ExternalDecl::FunctionDef(func)], }; - let module = test_linearize(&tu); + let module = test_linearize(&tu, &types); let ir = format!("{}", module); // Short-circuit AND should have: @@ -4124,29 +4265,31 @@ mod tests { #[test] fn test_short_circuit_or() { + let types = TypeTable::new(); // int test(int a, int b) { // return a || b; // } // Short-circuit: if a is true, don't evaluate b - let int_type = Type::basic(TypeKind::Int); + let int_type = types.int_id; let func = FunctionDef { - return_type: int_type.clone(), + return_type: int_type, name: "test".to_string(), params: vec![ Parameter { name: Some("a".to_string()), - typ: int_type.clone(), + typ: int_type, }, Parameter { name: Some("b".to_string()), - typ: int_type.clone(), + typ: int_type, }, ], body: Stmt::Return(Some(Expr::binary( BinaryOp::LogOr, - Expr::var_typed("a", int_type.clone()), + Expr::var_typed("a", int_type), Expr::var_typed("b", int_type), + &types, ))), pos: test_pos(), }; @@ -4154,7 +4297,7 @@ mod tests { items: 
vec![ExternalDecl::FunctionDef(func)], }; - let module = test_linearize(&tu); + let module = test_linearize(&tu, &types); let ir = format!("{}", module); // Short-circuit OR should have: diff --git a/cc/lower.rs b/cc/lower.rs index f079bd57..bb1fd0e7 100644 --- a/cc/lower.rs +++ b/cc/lower.rs @@ -140,7 +140,7 @@ pub fn lower_function(func: &mut Function) { mod tests { use super::*; use crate::ir::{BasicBlock, Instruction, Opcode, Pseudo, PseudoId}; - use crate::types::{Type, TypeKind}; + use crate::types::TypeTable; fn make_loop_cfg() -> Function { // Create a simple loop CFG: @@ -158,7 +158,9 @@ mod tests { // With phi node in cond block: // %3 = phi [.L0: %1], [.L2: %2] - let mut func = Function::new("test", Type::basic(TypeKind::Int)); + let types = TypeTable::new(); + let int_type = types.int_id; + let mut func = Function::new("test", int_type); // Entry block let mut entry = BasicBlock::new(BasicBlockId(0)); @@ -177,7 +179,7 @@ mod tests { cond.children = vec![BasicBlockId(2), BasicBlockId(3)]; // Phi node: %3 = phi [.L0: %1], [.L2: %2] - let mut phi = Instruction::phi(PseudoId(3), Type::basic(TypeKind::Int)); + let mut phi = Instruction::phi(PseudoId(3), int_type, 32); phi.phi_list = vec![ (BasicBlockId(0), PseudoId(1)), (BasicBlockId(2), PseudoId(2)), @@ -199,7 +201,8 @@ mod tests { PseudoId(2), PseudoId(3), PseudoId(4), // constant 1 - Type::basic(TypeKind::Int), + int_type, + 32, ); body.add_insn(add); body.add_insn(Instruction::br(BasicBlockId(1))); diff --git a/cc/main.rs b/cc/main.rs index aa3d7977..81dccb39 100644 --- a/cc/main.rs +++ b/cc/main.rs @@ -201,12 +201,13 @@ fn process_file( return Ok(()); } - // Create symbol table BEFORE parsing + // Create symbol table and type table BEFORE parsing // symbols are bound during parsing let mut symbols = SymbolTable::new(); + let mut types = types::TypeTable::new(); // Parse (this also binds symbols to the symbol table) - let mut parser = CParser::new(&preprocessed, idents, &mut symbols); + let mut parser = 
CParser::new(&preprocessed, idents, &mut symbols, &mut types); let ast = parser .parse_translation_unit() .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, format!("parse error: {}", e)))?; @@ -217,8 +218,14 @@ fn process_file( } // Linearize to IR - let mut module = - linearize::linearize_with_debug(&ast, &symbols, args.debug, Some(display_path)); + let mut module = linearize::linearize_with_debug( + &ast, + &symbols, + &types, + target, + args.debug, + Some(display_path), + ); if args.dump_ir { print!("{}", module); @@ -232,7 +239,7 @@ fn process_file( let emit_unwind_tables = !args.no_unwind_tables; let mut codegen = arch::codegen::create_codegen_with_options(target.clone(), emit_unwind_tables); - let asm = codegen.generate(&module); + let asm = codegen.generate(&module, &types); if args.dump_asm { print!("{}", asm); diff --git a/cc/parse/ast.rs b/cc/parse/ast.rs index 3ad5e588..8ae4cb50 100644 --- a/cc/parse/ast.rs +++ b/cc/parse/ast.rs @@ -11,7 +11,7 @@ // use crate::diag::Position; -use crate::types::Type; +use crate::types::TypeId; // ============================================================================ // Operators @@ -116,8 +116,8 @@ pub struct Expr { /// The expression kind/variant pub kind: ExprKind, /// The computed type of this expression (like sparse's expr->ctype) - /// None before type evaluation, Some after - pub typ: Option, + /// None before type evaluation, Some after (interned TypeId) + pub typ: Option, /// Source position for debug info pub pos: Position, } @@ -132,8 +132,8 @@ impl Expr { } } - /// Create an expression with a known type and position - pub fn typed(kind: ExprKind, typ: Type, pos: Position) -> Self { + /// Create an expression with a known type ID and position + pub fn typed(kind: ExprKind, typ: TypeId, pos: Position) -> Self { Self { kind, typ: Some(typ), @@ -151,9 +151,9 @@ impl Expr { } } - /// Create an expression with a known type but default position (for tests) + /// Create an expression with a known type ID 
but default position (for tests) #[cfg(test)] - pub fn typed_unpositioned(kind: ExprKind, typ: Type) -> Self { + pub fn typed_unpositioned(kind: ExprKind, typ: TypeId) -> Self { Self { kind, typ: Some(typ), @@ -239,12 +239,12 @@ pub enum ExprKind { /// Type cast: (type)expr Cast { - cast_type: Type, + cast_type: TypeId, expr: Box, }, /// sizeof type: sizeof(int) - SizeofType(Type), + SizeofType(TypeId), /// sizeof expression: sizeof expr SizeofExpr(Box), @@ -274,8 +274,8 @@ pub enum ExprKind { VaArg { /// The va_list to read from ap: Box, - /// The type of argument to retrieve - arg_type: Type, + /// The type of argument to retrieve (interned TypeId) + arg_type: TypeId, }, /// __builtin_va_end(ap) @@ -351,13 +351,13 @@ pub struct InitElement { // Test-only helper constructors for AST nodes #[cfg(test)] -use crate::types::TypeKind; +use crate::types::TypeTable; #[cfg(test)] impl Expr { /// Create an integer literal (typed as int) - no position (for tests/internal use) - pub fn int(value: i64) -> Self { - Expr::typed_unpositioned(ExprKind::IntLit(value), Type::basic(TypeKind::Int)) + pub fn int(value: i64, types: &TypeTable) -> Self { + Expr::typed_unpositioned(ExprKind::IntLit(value), types.int_id) } /// Create a variable reference (untyped - needs type evaluation) - no position @@ -368,7 +368,7 @@ impl Expr { } /// Create a variable reference with a known type - no position - pub fn var_typed(name: &str, typ: Type) -> Self { + pub fn var_typed(name: &str, typ: TypeId) -> Self { Expr::typed_unpositioned( ExprKind::Ident { name: name.to_string(), @@ -377,8 +377,8 @@ impl Expr { ) } - /// Create a binary expression (result type matches left operand for arithmetic) - pub fn binary(op: BinaryOp, left: Expr, right: Expr) -> Self { + /// Create a binary expression (using TypeTable for type inference) + pub fn binary(op: BinaryOp, left: Expr, right: Expr, types: &TypeTable) -> Self { // Derive type from operands - comparisons return int, arithmetic uses left type let 
result_type = match op { BinaryOp::Lt @@ -388,11 +388,8 @@ impl Expr { | BinaryOp::Eq | BinaryOp::Ne | BinaryOp::LogAnd - | BinaryOp::LogOr => Type::basic(TypeKind::Int), - _ => left - .typ - .clone() - .unwrap_or_else(|| Type::basic(TypeKind::Int)), + | BinaryOp::LogOr => types.int_id, + _ => left.typ.unwrap_or(types.int_id), }; let pos = left.pos; Expr::typed( @@ -407,13 +404,10 @@ impl Expr { } /// Create a unary expression - pub fn unary(op: UnaryOp, operand: Expr) -> Self { + pub fn unary(op: UnaryOp, operand: Expr, types: &TypeTable) -> Self { let result_type = match op { - UnaryOp::Not => Type::basic(TypeKind::Int), - _ => operand - .typ - .clone() - .unwrap_or_else(|| Type::basic(TypeKind::Int)), + UnaryOp::Not => types.int_id, + _ => operand.typ.unwrap_or(types.int_id), }; let pos = operand.pos; Expr::typed( @@ -427,11 +421,8 @@ impl Expr { } /// Create an assignment (result type is target type) - pub fn assign(target: Expr, value: Expr) -> Self { - let result_type = target - .typ - .clone() - .unwrap_or_else(|| Type::basic(TypeKind::Int)); + pub fn assign(target: Expr, value: Expr, types: &TypeTable) -> Self { + let result_type = target.typ.unwrap_or(types.int_id); let pos = target.pos; Expr::typed( ExprKind::Assign { @@ -445,14 +436,14 @@ impl Expr { } /// Create a function call (returns int by default - proper type needs evaluation) - pub fn call(func: Expr, args: Vec) -> Self { + pub fn call(func: Expr, args: Vec, types: &TypeTable) -> Self { let pos = func.pos; Expr::typed( ExprKind::Call { func: Box::new(func), args, }, - Type::basic(TypeKind::Int), + types.int_id, pos, ) } @@ -552,8 +543,8 @@ pub struct Declaration { pub struct InitDeclarator { /// The name being declared pub name: String, - /// The complete type (after applying declarator modifiers) - pub typ: Type, + /// The complete type (after applying declarator modifiers) - interned TypeId + pub typ: TypeId, /// Optional initializer pub init: Option, } @@ -561,7 +552,7 @@ pub struct 
InitDeclarator { #[cfg(test)] impl Declaration { /// Create a simple declaration with one variable - pub fn simple(name: &str, typ: Type, init: Option) -> Self { + pub fn simple(name: &str, typ: TypeId, init: Option) -> Self { Declaration { declarators: vec![InitDeclarator { name: name.to_string(), @@ -580,14 +571,15 @@ impl Declaration { #[derive(Debug, Clone)] pub struct Parameter { pub name: Option, - pub typ: Type, + /// Parameter type (interned TypeId) + pub typ: TypeId, } /// A function definition #[derive(Debug, Clone)] pub struct FunctionDef { - /// Return type - pub return_type: Type, + /// Return type (interned TypeId) + pub return_type: TypeId, /// Function name pub name: String, /// Parameters @@ -644,7 +636,8 @@ mod tests { #[test] fn test_int_literal() { - let expr = Expr::int(42); + let types = TypeTable::new(); + let expr = Expr::int(42, &types); match expr.kind { ExprKind::IntLit(v) => assert_eq!(v, 42), _ => panic!("Expected IntLit"), @@ -653,8 +646,14 @@ mod tests { #[test] fn test_binary_expr() { + let types = TypeTable::new(); // 1 + 2 - let expr = Expr::binary(BinaryOp::Add, Expr::int(1), Expr::int(2)); + let expr = Expr::binary( + BinaryOp::Add, + Expr::int(1, &types), + Expr::int(2, &types), + &types, + ); match expr.kind { ExprKind::Binary { op, left, right } => { @@ -674,9 +673,15 @@ mod tests { #[test] fn test_nested_binary() { + let types = TypeTable::new(); // 1 + 2 * 3 (represented as 1 + (2 * 3)) - let mul = Expr::binary(BinaryOp::Mul, Expr::int(2), Expr::int(3)); - let add = Expr::binary(BinaryOp::Add, Expr::int(1), mul); + let mul = Expr::binary( + BinaryOp::Mul, + Expr::int(2, &types), + Expr::int(3, &types), + &types, + ); + let add = Expr::binary(BinaryOp::Add, Expr::int(1, &types), mul, &types); match add.kind { ExprKind::Binary { op, left, right } => { @@ -696,8 +701,9 @@ mod tests { #[test] fn test_unary_expr() { + let types = TypeTable::new(); // -x - let expr = Expr::unary(UnaryOp::Neg, Expr::var("x")); + let expr = 
Expr::unary(UnaryOp::Neg, Expr::var("x"), &types); match expr.kind { ExprKind::Unary { op, operand } => { @@ -713,8 +719,9 @@ mod tests { #[test] fn test_assignment() { + let types = TypeTable::new(); // x = 5 - let expr = Expr::assign(Expr::var("x"), Expr::int(5)); + let expr = Expr::assign(Expr::var("x"), Expr::int(5, &types), &types); match expr.kind { ExprKind::Assign { op, target, value } => { @@ -734,8 +741,13 @@ mod tests { #[test] fn test_function_call() { + let types = TypeTable::new(); // foo(1, 2) - let expr = Expr::call(Expr::var("foo"), vec![Expr::int(1), Expr::int(2)]); + let expr = Expr::call( + Expr::var("foo"), + vec![Expr::int(1, &types), Expr::int(2, &types)], + &types, + ); match expr.kind { ExprKind::Call { func, args } => { @@ -751,10 +763,11 @@ mod tests { #[test] fn test_if_stmt() { + let types = TypeTable::new(); // if (x) return 1; let stmt = Stmt::If { cond: Expr::var("x"), - then_stmt: Box::new(Stmt::Return(Some(Expr::int(1)))), + then_stmt: Box::new(Stmt::Return(Some(Expr::int(1, &types)))), else_stmt: None, }; @@ -811,21 +824,23 @@ mod tests { #[test] fn test_declaration() { + let types = TypeTable::new(); // int x = 5; - let decl = Declaration::simple("x", Type::basic(TypeKind::Int), Some(Expr::int(5))); + let decl = Declaration::simple("x", types.int_id, Some(Expr::int(5, &types))); assert_eq!(decl.declarators.len(), 1); assert_eq!(decl.declarators[0].name, "x"); - assert_eq!(decl.declarators[0].typ.kind, TypeKind::Int); + assert_eq!(types.kind(decl.declarators[0].typ), TypeKind::Int); assert!(decl.declarators[0].init.is_some()); } #[test] fn test_translation_unit() { + let types = TypeTable::new(); let mut tu = TranslationUnit::new(); // Add a declaration - let decl = Declaration::simple("x", Type::basic(TypeKind::Int), None); + let decl = Declaration::simple("x", types.int_id, None); tu.add(ExternalDecl::Declaration(decl)); assert_eq!(tu.items.len(), 1); @@ -833,13 +848,14 @@ mod tests { #[test] fn test_for_loop() { + let types = 
TypeTable::new(); // for (int i = 0; i < 10; i++) {} let init = ForInit::Declaration(Declaration::simple( "i", - Type::basic(TypeKind::Int), - Some(Expr::int(0)), + types.int_id, + Some(Expr::int(0, &types)), )); - let cond = Expr::binary(BinaryOp::Lt, Expr::var("i"), Expr::int(10)); + let cond = Expr::binary(BinaryOp::Lt, Expr::var("i"), Expr::int(10, &types), &types); let post = Expr::new_unpositioned(ExprKind::PostInc(Box::new(Expr::var("i")))); let stmt = Stmt::For { diff --git a/cc/parse/parser.rs b/cc/parse/parser.rs index 09bc005f..ef981404 100644 --- a/cc/parse/parser.rs +++ b/cc/parse/parser.rs @@ -17,7 +17,9 @@ use super::ast::{ use crate::diag; use crate::symbol::{Namespace, Symbol, SymbolTable}; use crate::token::lexer::{IdentTable, Position, SpecialToken, Token, TokenType, TokenValue}; -use crate::types::{CompositeType, EnumConstant, StructMember, Type, TypeKind, TypeModifiers}; +use crate::types::{ + CompositeType, EnumConstant, StructMember, Type, TypeId, TypeKind, TypeModifiers, TypeTable, +}; use std::fmt; // ============================================================================ @@ -174,17 +176,25 @@ pub struct Parser<'a> { idents: &'a IdentTable, /// Symbol table for binding declarations (like sparse's bind_symbol) symbols: &'a mut SymbolTable, + /// Type table for interning types + types: &'a mut TypeTable, /// Current position in token stream pos: usize, } impl<'a> Parser<'a> { - /// Create a new parser with a symbol table - pub fn new(tokens: &'a [Token], idents: &'a IdentTable, symbols: &'a mut SymbolTable) -> Self { + /// Create a new parser with a symbol table and type table + pub fn new( + tokens: &'a [Token], + idents: &'a IdentTable, + symbols: &'a mut SymbolTable, + types: &'a mut TypeTable, + ) -> Self { Self { tokens, idents, symbols, + types, pos: 0, } } @@ -526,7 +536,7 @@ impl<'a> Parser<'a> { self.advance(); let right = self.parse_assignment_expr()?; // Comma expression type is the type of the rightmost expression - let 
result_typ = right.typ.clone(); + let result_typ = right.typ; // Build comma expression let Expr { kind, typ, pos } = expr; @@ -705,10 +715,7 @@ impl<'a> Parser<'a> { // The result type is the common type of then and else branches // For simplicity, use then_expr's type (proper impl would need type promotion) - let typ = then_expr - .typ - .clone() - .unwrap_or_else(|| Type::basic(TypeKind::Int)); + let typ = then_expr.typ.unwrap_or(self.types.int_id); let pos = cond.pos; Ok(Self::typed_expr( @@ -732,7 +739,7 @@ impl<'a> Parser<'a> { while self.is_special_token(SpecialToken::LogicalOr) { self.advance(); let right = self.parse_logical_and_expr()?; - left = Self::make_binary(BinaryOp::LogOr, left, right); + left = self.make_binary(BinaryOp::LogOr, left, right); } Ok(left) @@ -745,7 +752,7 @@ impl<'a> Parser<'a> { while self.is_special_token(SpecialToken::LogicalAnd) { self.advance(); let right = self.parse_bitwise_or_expr()?; - left = Self::make_binary(BinaryOp::LogAnd, left, right); + left = self.make_binary(BinaryOp::LogAnd, left, right); } Ok(left) @@ -759,7 +766,7 @@ impl<'a> Parser<'a> { while self.is_special(b'|') && !self.is_special_token(SpecialToken::LogicalOr) { self.advance(); let right = self.parse_bitwise_xor_expr()?; - left = Self::make_binary(BinaryOp::BitOr, left, right); + left = self.make_binary(BinaryOp::BitOr, left, right); } Ok(left) @@ -772,7 +779,7 @@ impl<'a> Parser<'a> { while self.is_special(b'^') && !self.is_special_token(SpecialToken::XorAssign) { self.advance(); let right = self.parse_bitwise_and_expr()?; - left = Self::make_binary(BinaryOp::BitXor, left, right); + left = self.make_binary(BinaryOp::BitXor, left, right); } Ok(left) @@ -786,7 +793,7 @@ impl<'a> Parser<'a> { while self.is_special(b'&') && !self.is_special_token(SpecialToken::LogicalAnd) { self.advance(); let right = self.parse_equality_expr()?; - left = Self::make_binary(BinaryOp::BitAnd, left, right); + left = self.make_binary(BinaryOp::BitAnd, left, right); } Ok(left) @@ 
-808,7 +815,7 @@ impl<'a> Parser<'a> { if let Some(binary_op) = op { self.advance(); let right = self.parse_relational_expr()?; - left = Self::make_binary(binary_op, left, right); + left = self.make_binary(binary_op, left, right); } else { break; } @@ -837,7 +844,7 @@ impl<'a> Parser<'a> { if let Some(binary_op) = op { self.advance(); let right = self.parse_shift_expr()?; - left = Self::make_binary(binary_op, left, right); + left = self.make_binary(binary_op, left, right); } else { break; } @@ -862,7 +869,7 @@ impl<'a> Parser<'a> { if let Some(binary_op) = op { self.advance(); let right = self.parse_additive_expr()?; - left = Self::make_binary(binary_op, left, right); + left = self.make_binary(binary_op, left, right); } else { break; } @@ -887,7 +894,7 @@ impl<'a> Parser<'a> { if let Some(binary_op) = op { self.advance(); let right = self.parse_multiplicative_expr()?; - left = Self::make_binary(binary_op, left, right); + left = self.make_binary(binary_op, left, right); } else { break; } @@ -914,7 +921,7 @@ impl<'a> Parser<'a> { if let Some(binary_op) = op { self.advance(); let right = self.parse_unary_expr()?; - left = Self::make_binary(binary_op, left, right); + left = self.make_binary(binary_op, left, right); } else { break; } @@ -933,10 +940,7 @@ impl<'a> Parser<'a> { // Check for const modification self.check_const_assignment(&operand, op_pos); // PreInc has same type as operand - let typ = operand - .typ - .clone() - .unwrap_or_else(|| Type::basic(TypeKind::Int)); + let typ = operand.typ.unwrap_or(self.types.int_id); return Ok(Self::typed_expr( ExprKind::Unary { op: UnaryOp::PreInc, @@ -954,10 +958,7 @@ impl<'a> Parser<'a> { // Check for const modification self.check_const_assignment(&operand, op_pos); // PreDec has same type as operand - let typ = operand - .typ - .clone() - .unwrap_or_else(|| Type::basic(TypeKind::Int)); + let typ = operand.typ.unwrap_or(self.types.int_id); return Ok(Self::typed_expr( ExprKind::Unary { op: UnaryOp::PreDec, @@ -973,11 +974,8 
@@ impl<'a> Parser<'a> { self.advance(); let operand = self.parse_unary_expr()?; // AddrOf produces pointer to operand's type - let base_type = operand - .typ - .clone() - .unwrap_or_else(|| Type::basic(TypeKind::Int)); - let ptr_type = Type::pointer(base_type); + let base_type = operand.typ.unwrap_or(self.types.int_id); + let ptr_type = self.types.intern(Type::pointer(base_type)); return Ok(Self::typed_expr( ExprKind::Unary { op: UnaryOp::AddrOf, @@ -995,9 +993,8 @@ impl<'a> Parser<'a> { // Deref produces the base type of the pointer let typ = operand .typ - .as_ref() - .and_then(|t| t.base.as_ref().map(|b| (**b).clone())) - .unwrap_or_else(|| Type::basic(TypeKind::Int)); + .and_then(|t| self.types.base_type(t)) + .unwrap_or(self.types.int_id); return Ok(Self::typed_expr( ExprKind::Unary { op: UnaryOp::Deref, @@ -1019,10 +1016,7 @@ impl<'a> Parser<'a> { self.advance(); let operand = self.parse_unary_expr()?; // Neg has same type as operand - let typ = operand - .typ - .clone() - .unwrap_or_else(|| Type::basic(TypeKind::Int)); + let typ = operand.typ.unwrap_or(self.types.int_id); return Ok(Self::typed_expr( ExprKind::Unary { op: UnaryOp::Neg, @@ -1038,18 +1032,15 @@ impl<'a> Parser<'a> { self.advance(); let operand = self.parse_unary_expr()?; // BitNot: C99 integer promotion - types smaller than int promote to int - let op_typ = operand - .typ - .clone() - .unwrap_or_else(|| Type::basic(TypeKind::Int)); + let op_typ = operand.typ.unwrap_or(self.types.int_id); // Apply integer promotion: _Bool, char, short -> int - let typ = if matches!( - op_typ.kind, - TypeKind::Bool | TypeKind::Char | TypeKind::Short - ) { - Type::basic(TypeKind::Int) - } else { - op_typ + let typ = { + let kind = self.types.kind(op_typ); + if matches!(kind, TypeKind::Bool | TypeKind::Char | TypeKind::Short) { + self.types.int_id + } else { + op_typ + } }; return Ok(Self::typed_expr( ExprKind::Unary { @@ -1071,7 +1062,7 @@ impl<'a> Parser<'a> { op: UnaryOp::Not, operand: Box::new(operand), }, - 
Type::basic(TypeKind::Int), + self.types.int_id, op_pos, )); } @@ -1094,7 +1085,7 @@ impl<'a> Parser<'a> { fn parse_sizeof(&mut self) -> ParseResult { let sizeof_pos = self.current_pos(); // sizeof returns size_t, which is unsigned long in our implementation - let size_t = Type::with_modifiers(TypeKind::Long, TypeModifiers::UNSIGNED); + let size_t = self.types.ulong_id; if self.is_special(b'(') { // Could be sizeof(type) or sizeof(expr) @@ -1151,14 +1142,14 @@ impl<'a> Parser<'a> { } /// Parse a type name (required, returns error if not a type) - fn parse_type_name(&mut self) -> ParseResult { + fn parse_type_name(&mut self) -> ParseResult { self.try_parse_type_name() .ok_or_else(|| ParseError::new("expected type name".to_string(), self.current_pos())) } /// Try to parse a type name for casts and sizeof /// Supports compound types like `unsigned char`, `long long`, pointers, etc. - fn try_parse_type_name(&mut self) -> Option { + fn try_parse_type_name(&mut self) -> Option { if self.peek() != TokenType::Ident { return None; } @@ -1278,22 +1269,67 @@ impl<'a> Parser<'a> { parsed_something = true; } "struct" => { - // For cast to struct, parse struct tag + self.advance(); // consume 'struct' + // For struct tag reference, look up directly in symbol table + if let Some(tag_name) = self.get_ident_name(self.current()) { + if !self.is_special(b'{') { + // This is a tag reference (e.g., "struct Point*") + // Look up the existing tag to get its TypeId directly + self.advance(); // consume tag name + if let Some(existing) = self.symbols.lookup_tag(&tag_name) { + let mut result_id = existing.typ; + // Handle pointer + while self.is_special(b'*') { + self.advance(); + result_id = self.types.intern(Type::pointer(result_id)); + } + // Handle array declarators + while self.is_special(b'[') { + self.advance(); + if let Ok(size_expr) = self.parse_conditional_expr() { + if let Some(size) = self.eval_const_expr(&size_expr) { + result_id = self + .types + .intern(Type::array(result_id, 
size as usize)); + } + } + if !self.is_special(b']') { + return None; + } + self.advance(); + } + return Some(result_id); + } + // Tag not found - return incomplete struct type + let incomplete = Type::incomplete_struct(tag_name); + let mut result_id = self.types.intern(incomplete); + while self.is_special(b'*') { + self.advance(); + result_id = self.types.intern(Type::pointer(result_id)); + } + return Some(result_id); + } + } + // Fall back to full struct parsing for definitions + // (rewind position first since we consumed 'struct') + self.pos -= 1; if let Ok(struct_type) = self.parse_struct_or_union_specifier(false) { + // Intern base struct type with modifiers let mut typ = struct_type; typ.modifiers |= modifiers; + let mut result_id = self.types.intern(typ); // Handle pointer - let mut result = typ; while self.is_special(b'*') { self.advance(); - result = Type::pointer(result); + result_id = self.types.intern(Type::pointer(result_id)); } // Handle array declarators while self.is_special(b'[') { self.advance(); if let Ok(size_expr) = self.parse_conditional_expr() { if let Some(size) = self.eval_const_expr(&size_expr) { - result.array_size = Some(size as usize); + result_id = + self.types.intern(Type::array(result_id, size as usize)); } } if !self.is_special(b']') { @@ -1301,25 +1337,66 @@ impl<'a> Parser<'a> { } self.advance(); } - return Some(result); + return Some(result_id); } return None; } "union" => { + self.advance(); // consume 'union' + // For union tag reference, look up directly in symbol table + if let Some(tag_name) = self.get_ident_name(self.current()) { + if !self.is_special(b'{') { + // This is a tag reference + self.advance(); // consume tag name + if let Some(existing) = self.symbols.lookup_tag(&tag_name) { + let mut result_id = existing.typ; + while self.is_special(b'*') { + self.advance(); + result_id = self.types.intern(Type::pointer(result_id)); + } + while self.is_special(b'[') { + self.advance(); + if let Ok(size_expr) = 
self.parse_conditional_expr() { + if let Some(size) = self.eval_const_expr(&size_expr) { + result_id = self + .types + .intern(Type::array(result_id, size as usize)); + } + } + if !self.is_special(b']') { + return None; + } + self.advance(); + } + return Some(result_id); + } + // Tag not found - return incomplete union type + let incomplete = Type::incomplete_union(tag_name); + let mut result_id = self.types.intern(incomplete); + while self.is_special(b'*') { + self.advance(); + result_id = self.types.intern(Type::pointer(result_id)); + } + return Some(result_id); + } + } + // Fall back to full union parsing for definitions + self.pos -= 1; if let Ok(union_type) = self.parse_struct_or_union_specifier(true) { let mut typ = union_type; typ.modifiers |= modifiers; - let mut result = typ; + let mut result_id = self.types.intern(typ); while self.is_special(b'*') { self.advance(); - result = Type::pointer(result); + result_id = self.types.intern(Type::pointer(result_id)); } // Handle array declarators while self.is_special(b'[') { self.advance(); if let Ok(size_expr) = self.parse_conditional_expr() { if let Some(size) = self.eval_const_expr(&size_expr) { - result.array_size = Some(size as usize); + result_id = + self.types.intern(Type::array(result_id, size as usize)); } } if !self.is_special(b']') { @@ -1327,7 +1404,7 @@ impl<'a> Parser<'a> { } self.advance(); } - return Some(result); + return Some(result_id); } return None; } @@ -1335,17 +1412,18 @@ impl<'a> Parser<'a> { if let Ok(enum_type) = self.parse_enum_specifier() { let mut typ = enum_type; typ.modifiers |= modifiers; - let mut result = typ; + let mut result_id = self.types.intern(typ); while self.is_special(b'*') { self.advance(); - result = Type::pointer(result); + result_id = self.types.intern(Type::pointer(result_id)); } // Handle array declarators while self.is_special(b'[') { self.advance(); if let Ok(size_expr) = self.parse_conditional_expr() { if let Some(size) = self.eval_const_expr(&size_expr) { - 
result.array_size = Some(size as usize); + result_id = + self.types.intern(Type::array(result_id, size as usize)); } } if !self.is_special(b']') { @@ -1353,28 +1431,30 @@ impl<'a> Parser<'a> { } self.advance(); } - return Some(result); + return Some(result_id); } return None; } _ => { // Check if it's a typedef name if base_kind.is_none() { - if let Some(typedef_type) = self.symbols.lookup_typedef(&name).cloned() { + if let Some(typedef_type_id) = self.symbols.lookup_typedef(&name) { self.advance(); - let mut result = typedef_type; - result.modifiers |= modifiers; + // For typedef, we already have a TypeId - just apply pointer/array modifiers + let mut result_id = typedef_type_id; // Handle pointer declarators while self.is_special(b'*') { self.advance(); - result = Type::pointer(result); + result_id = self.types.intern(Type::pointer(result_id)); } // Handle array declarators while self.is_special(b'[') { self.advance(); if let Ok(size_expr) = self.parse_conditional_expr() { if let Some(size) = self.eval_const_expr(&size_expr) { - result.array_size = Some(size as usize); + result_id = self + .types + .intern(Type::array(result_id, size as usize)); } } if !self.is_special(b']') { @@ -1382,7 +1462,7 @@ impl<'a> Parser<'a> { } self.advance(); } - return Some(result); + return Some(result_id); } } break; @@ -1396,12 +1476,13 @@ impl<'a> Parser<'a> { // If we only have modifiers like `unsigned` without a base type, default to int let kind = base_kind.unwrap_or(TypeKind::Int); - let mut typ = Type::with_modifiers(kind, modifiers); + let typ = Type::with_modifiers(kind, modifiers); + let mut result_id = self.types.intern(typ); // Handle pointer declarators while self.is_special(b'*') { self.advance(); - typ = Type::pointer(typ); + result_id = self.types.intern(Type::pointer(result_id)); } // Handle array declarators: int[10], char[20], etc. 
@@ -1409,7 +1490,7 @@ impl<'a> Parser<'a> { self.advance(); if let Ok(size_expr) = self.parse_conditional_expr() { if let Some(size) = self.eval_const_expr(&size_expr) { - typ.array_size = Some(size as usize); + result_id = self.types.intern(Type::array(result_id, size as usize)); } } if !self.is_special(b']') { @@ -1418,7 +1499,7 @@ impl<'a> Parser<'a> { self.advance(); } - Some(typ) + Some(result_id) } /// Parse postfix expression: x++, x--, x[i], x.member, x->member, x(args) @@ -1435,10 +1516,7 @@ impl<'a> Parser<'a> { // Check for const modification self.check_const_assignment(&expr, op_pos); // PostInc has same type as operand - let typ = expr - .typ - .clone() - .unwrap_or_else(|| Type::basic(TypeKind::Int)); + let typ = expr.typ.unwrap_or(self.types.int_id); expr = Self::typed_expr(ExprKind::PostInc(Box::new(expr)), typ, base_pos); } else if self.is_special_token(SpecialToken::Decrement) { let op_pos = self.current_pos(); @@ -1446,10 +1524,7 @@ impl<'a> Parser<'a> { // Check for const modification self.check_const_assignment(&expr, op_pos); // PostDec has same type as operand - let typ = expr - .typ - .clone() - .unwrap_or_else(|| Type::basic(TypeKind::Int)); + let typ = expr.typ.unwrap_or(self.types.int_id); expr = Self::typed_expr(ExprKind::PostDec(Box::new(expr)), typ, base_pos); } else if self.is_special(b'[') { // Array subscript @@ -1459,10 +1534,8 @@ impl<'a> Parser<'a> { // Get element type from array/pointer type let elem_type = expr .typ - .as_ref() - .and_then(|t| t.base.as_ref()) - .map(|b| (**b).clone()) - .unwrap_or_else(|| Type::basic(TypeKind::Int)); + .and_then(|t| self.types.base_type(t)) + .unwrap_or(self.types.int_id); expr = Self::typed_expr( ExprKind::Index { array: Box::new(expr), @@ -1478,10 +1551,9 @@ impl<'a> Parser<'a> { // Get member type from struct type let member_type = expr .typ - .as_ref() - .and_then(|t| t.find_member(&member)) + .and_then(|t| self.types.find_member(t, &member)) .map(|info| info.typ) - .unwrap_or_else(|| 
Type::basic(TypeKind::Int)); + .unwrap_or(self.types.int_id); expr = Self::typed_expr( ExprKind::Member { expr: Box::new(expr), @@ -1497,11 +1569,10 @@ impl<'a> Parser<'a> { // Get member type: dereference pointer to get struct, then find member let member_type = expr .typ - .as_ref() - .and_then(|t| t.base.as_ref()) - .and_then(|struct_type| struct_type.find_member(&member)) + .and_then(|t| self.types.base_type(t)) + .and_then(|struct_type| self.types.find_member(struct_type, &member)) .map(|info| info.typ) - .unwrap_or_else(|| Type::basic(TypeKind::Int)); + .unwrap_or(self.types.int_id); expr = Self::typed_expr( ExprKind::Arrow { expr: Box::new(expr), @@ -1521,15 +1592,14 @@ impl<'a> Parser<'a> { // and the return type is stored in base let return_type = expr .typ - .as_ref() .and_then(|t| { - if t.kind == TypeKind::Function { - t.base.as_ref().map(|b| (**b).clone()) + if self.types.kind(t) == TypeKind::Function { + self.types.base_type(t) } else { None } }) - .unwrap_or_else(|| Type::basic(TypeKind::Int)); // Default to int + .unwrap_or(self.types.int_id); // Default to int expr = Self::typed_expr( ExprKind::Call { @@ -1591,9 +1661,13 @@ impl<'a> Parser<'a> { operand, } = &target.kind { - if let Some(ptr_type) = &operand.typ { - if let Some(base_type) = &ptr_type.base { - if base_type.modifiers.contains(TypeModifiers::CONST) { + if let Some(ptr_type_id) = operand.typ { + if let Some(base_type_id) = self.types.base_type(ptr_type_id) { + if self + .types + .modifiers(base_type_id) + .contains(TypeModifiers::CONST) + { diag::error(pos, "assignment of read-only location"); return; // Don't duplicate with the general const check } @@ -1602,8 +1676,8 @@ impl<'a> Parser<'a> { } // Check if target type has CONST modifier (direct const variable) - if let Some(typ) = &target.typ { - if typ.modifiers.contains(TypeModifiers::CONST) { + if let Some(typ_id) = target.typ { + if self.types.modifiers(typ_id).contains(TypeModifiers::CONST) { // Get variable name if it's an 
identifier let var_name = match &target.kind { ExprKind::Ident { name } => format!(" '{}'", name), @@ -1619,7 +1693,7 @@ impl<'a> Parser<'a> { /// Parse primary expression: literals, identifiers, parenthesized expressions /// Create a typed expression with position - fn typed_expr(kind: ExprKind, typ: Type, pos: Position) -> Expr { + fn typed_expr(kind: ExprKind, typ: TypeId, pos: Position) -> Expr { Expr { kind, typ: Some(typ), @@ -1628,16 +1702,10 @@ impl<'a> Parser<'a> { } /// Create a typed binary expression, computing result type from operands - fn make_binary(op: BinaryOp, left: Expr, right: Expr) -> Expr { + fn make_binary(&mut self, op: BinaryOp, left: Expr, right: Expr) -> Expr { // Compute result type based on operator and operand types - let left_type = left - .typ - .clone() - .unwrap_or_else(|| Type::basic(TypeKind::Int)); - let right_type = right - .typ - .clone() - .unwrap_or_else(|| Type::basic(TypeKind::Int)); + let left_type = left.typ.unwrap_or(self.types.int_id); + let right_type = right.typ.unwrap_or(self.types.int_id); let result_type = match op { // Comparison and logical operators always return int @@ -1648,50 +1716,51 @@ impl<'a> Parser<'a> { | BinaryOp::Le | BinaryOp::Ge | BinaryOp::LogAnd - | BinaryOp::LogOr => Type::basic(TypeKind::Int), + | BinaryOp::LogOr => self.types.int_id, // Arithmetic operators use usual arithmetic conversions // But Add/Sub with pointers/arrays need special handling BinaryOp::Add | BinaryOp::Sub => { + let left_kind = self.types.kind(left_type); + let right_kind = self.types.kind(right_type); let left_is_ptr_or_arr = - left_type.kind == TypeKind::Pointer || left_type.kind == TypeKind::Array; + left_kind == TypeKind::Pointer || left_kind == TypeKind::Array; let right_is_ptr_or_arr = - right_type.kind == TypeKind::Pointer || right_type.kind == TypeKind::Array; + right_kind == TypeKind::Pointer || right_kind == TypeKind::Array; - if left_is_ptr_or_arr && right_type.is_integer() { + if left_is_ptr_or_arr && 
self.types.is_integer(right_type) { // ptr + int or arr + int -> pointer to element type - if left_type.kind == TypeKind::Array { + if left_kind == TypeKind::Array { // Array decays to pointer - let elem_type = left_type - .base - .as_ref() - .map(|b| (**b).clone()) - .unwrap_or_else(|| Type::basic(TypeKind::Int)); - Type::pointer(elem_type) + let elem_type = + self.types.base_type(left_type).unwrap_or(self.types.int_id); + self.types.intern(Type::pointer(elem_type)) } else { - left_type.clone() + left_type } - } else if left_type.is_integer() && right_is_ptr_or_arr && op == BinaryOp::Add { + } else if self.types.is_integer(left_type) + && right_is_ptr_or_arr + && op == BinaryOp::Add + { // int + ptr or int + arr -> pointer to element type - if right_type.kind == TypeKind::Array { - let elem_type = right_type - .base - .as_ref() - .map(|b| (**b).clone()) - .unwrap_or_else(|| Type::basic(TypeKind::Int)); - Type::pointer(elem_type) + if right_kind == TypeKind::Array { + let elem_type = self + .types + .base_type(right_type) + .unwrap_or(self.types.int_id); + self.types.intern(Type::pointer(elem_type)) } else { - right_type.clone() + right_type } } else if left_is_ptr_or_arr && right_is_ptr_or_arr && op == BinaryOp::Sub { // ptr - ptr -> ptrdiff_t (long) - Type::basic(TypeKind::Long) + self.types.long_id } else { - Self::usual_arithmetic_conversions(&left_type, &right_type) + self.usual_arithmetic_conversions(left_type, right_type) } } BinaryOp::Mul | BinaryOp::Div | BinaryOp::Mod => { - Self::usual_arithmetic_conversions(&left_type, &right_type) + self.usual_arithmetic_conversions(left_type, right_type) } // Bitwise and shift operators use integer promotions @@ -1699,7 +1768,7 @@ impl<'a> Parser<'a> { | BinaryOp::BitOr | BinaryOp::BitXor | BinaryOp::Shl - | BinaryOp::Shr => Self::usual_arithmetic_conversions(&left_type, &right_type), + | BinaryOp::Shr => self.usual_arithmetic_conversions(left_type, right_type), }; let pos = left.pos; @@ -1715,39 +1784,40 @@ impl<'a> 
Parser<'a> { } /// Compute usual arithmetic conversions (C99 6.3.1.8) - fn usual_arithmetic_conversions(left: &Type, right: &Type) -> Type { - use crate::types::TypeModifiers; - + fn usual_arithmetic_conversions(&mut self, left: TypeId, right: TypeId) -> TypeId { // C99 6.3.1.8: Usual arithmetic conversions // 1. If either is long double, result is long double // 2. If either is double, result is double // 3. If either is float, result is float // 4. Otherwise, integer promotions apply - if left.kind == TypeKind::LongDouble || right.kind == TypeKind::LongDouble { - Type::basic(TypeKind::LongDouble) - } else if left.kind == TypeKind::Double || right.kind == TypeKind::Double { - Type::basic(TypeKind::Double) - } else if left.kind == TypeKind::Float || right.kind == TypeKind::Float { - Type::basic(TypeKind::Float) - } else if left.kind == TypeKind::LongLong || right.kind == TypeKind::LongLong { + let left_kind = self.types.kind(left); + let right_kind = self.types.kind(right); + + if left_kind == TypeKind::LongDouble || right_kind == TypeKind::LongDouble { + self.types.longdouble_id + } else if left_kind == TypeKind::Double || right_kind == TypeKind::Double { + self.types.double_id + } else if left_kind == TypeKind::Float || right_kind == TypeKind::Float { + self.types.float_id + } else if left_kind == TypeKind::LongLong || right_kind == TypeKind::LongLong { // If either is unsigned long long, result is unsigned long long - if left.is_unsigned() || right.is_unsigned() { - Type::with_modifiers(TypeKind::LongLong, TypeModifiers::UNSIGNED) + if self.types.is_unsigned(left) || self.types.is_unsigned(right) { + self.types.ulonglong_id } else { - Type::basic(TypeKind::LongLong) + self.types.longlong_id } - } else if left.kind == TypeKind::Long || right.kind == TypeKind::Long { + } else if left_kind == TypeKind::Long || right_kind == TypeKind::Long { // If either is unsigned long, result is unsigned long - if left.is_unsigned() || right.is_unsigned() { - 
Type::with_modifiers(TypeKind::Long, TypeModifiers::UNSIGNED) + if self.types.is_unsigned(left) || self.types.is_unsigned(right) { + self.types.ulong_id } else { - Type::basic(TypeKind::Long) + self.types.long_id } - } else if left.is_unsigned() || right.is_unsigned() { - Type::with_modifiers(TypeKind::Int, TypeModifiers::UNSIGNED) + } else if self.types.is_unsigned(left) || self.types.is_unsigned(right) { + self.types.uint_id } else { - Type::basic(TypeKind::Int) + self.types.int_id } } @@ -1784,7 +1854,7 @@ impl<'a> Parser<'a> { ap: Box::new(ap), last_param, }, - Type::basic(TypeKind::Void), + self.types.void_id, token_pos, )); } @@ -1799,7 +1869,7 @@ impl<'a> Parser<'a> { return Ok(Self::typed_expr( ExprKind::VaArg { ap: Box::new(ap), - arg_type: arg_type.clone(), + arg_type, }, arg_type, token_pos, @@ -1812,7 +1882,7 @@ impl<'a> Parser<'a> { self.expect_special(b')')?; return Ok(Self::typed_expr( ExprKind::VaEnd { ap: Box::new(ap) }, - Type::basic(TypeKind::Void), + self.types.void_id, token_pos, )); } @@ -1828,7 +1898,7 @@ impl<'a> Parser<'a> { dest: Box::new(dest), src: Box::new(src), }, - Type::basic(TypeKind::Void), + self.types.void_id, token_pos, )); } @@ -1839,7 +1909,7 @@ impl<'a> Parser<'a> { self.expect_special(b')')?; return Ok(Self::typed_expr( ExprKind::Bswap16 { arg: Box::new(arg) }, - Type::with_modifiers(TypeKind::Short, TypeModifiers::UNSIGNED), + self.types.ushort_id, token_pos, )); } @@ -1850,7 +1920,7 @@ impl<'a> Parser<'a> { self.expect_special(b')')?; return Ok(Self::typed_expr( ExprKind::Bswap32 { arg: Box::new(arg) }, - Type::with_modifiers(TypeKind::Int, TypeModifiers::UNSIGNED), + self.types.uint_id, token_pos, )); } @@ -1861,7 +1931,7 @@ impl<'a> Parser<'a> { self.expect_special(b')')?; return Ok(Self::typed_expr( ExprKind::Bswap64 { arg: Box::new(arg) }, - Type::with_modifiers(TypeKind::LongLong, TypeModifiers::UNSIGNED), + self.types.ulonglong_id, token_pos, )); } @@ -1874,7 +1944,7 @@ impl<'a> Parser<'a> { ExprKind::Alloca { size: 
Box::new(size), }, - Type::pointer(Type::basic(TypeKind::Void)), + self.types.void_ptr_id, token_pos, )); } @@ -1888,7 +1958,7 @@ impl<'a> Parser<'a> { let is_constant = self.eval_const_expr(&arg).is_some(); return Ok(Self::typed_expr( ExprKind::IntLit(if is_constant { 1 } else { 0 }), - Type::basic(TypeKind::Int), + self.types.int_id, token_pos, )); } @@ -1901,10 +1971,10 @@ impl<'a> Parser<'a> { let type2 = self.parse_type_name()?; self.expect_special(b')')?; // Check type compatibility (ignoring qualifiers) - let compatible = type1.types_compatible(&type2); + let compatible = self.types.types_compatible(type1, type2); return Ok(Self::typed_expr( ExprKind::IntLit(if compatible { 1 } else { 0 }), - Type::basic(TypeKind::Int), + self.types.int_id, token_pos, )); } @@ -1915,8 +1985,8 @@ impl<'a> Parser<'a> { let typ = self .symbols .lookup(&name, Namespace::Ordinary) - .map(|s| s.typ.clone()) - .unwrap_or_else(|| Type::basic(TypeKind::Int)); // Default to int if not found + .map(|s| s.typ) + .unwrap_or(self.types.int_id); // Default to int if not found Ok(Self::typed_expr(ExprKind::Ident { name }, typ, token_pos)) } else { Err(ParseError::new("invalid identifier token", token.pos)) @@ -1931,7 +2001,7 @@ impl<'a> Parser<'a> { let c = self.parse_char_literal(s); Ok(Self::typed_expr( ExprKind::CharLit(c), - Type::basic(TypeKind::Int), + self.types.int_id, token_pos, )) } else { @@ -1948,7 +2018,7 @@ impl<'a> Parser<'a> { // String literal type is char* Ok(Self::typed_expr( ExprKind::StringLit(parsed), - Type::pointer(Type::basic(TypeKind::Char)), + self.types.char_ptr_id, token_pos, )) } else { @@ -1969,7 +2039,7 @@ impl<'a> Parser<'a> { // Cast expression has the cast type return Ok(Self::typed_expr( ExprKind::Cast { - cast_type: typ.clone(), + cast_type: typ, expr: Box::new(expr), }, typ, @@ -2034,11 +2104,11 @@ impl<'a> Parser<'a> { .map_err(|_| ParseError::new(format!("invalid float literal: {}", s), pos))? 
}; let typ = if is_float_suffix { - Type::basic(TypeKind::Float) + self.types.float_id } else if is_longdouble_suffix { - Type::basic(TypeKind::LongDouble) + self.types.longdouble_id } else { - Type::basic(TypeKind::Double) + self.types.double_id }; Ok(Self::typed_expr(ExprKind::FloatLit(value), typ, pos)) } else { @@ -2071,18 +2141,12 @@ impl<'a> Parser<'a> { let value = value_u64 as i64; let typ = match (is_longlong, is_long, is_unsigned) { - (true, _, false) => Type::basic(TypeKind::LongLong), - (true, _, true) => { - Type::with_modifiers(TypeKind::LongLong, TypeModifiers::UNSIGNED) - } - (false, true, false) => Type::basic(TypeKind::Long), - (false, true, true) => { - Type::with_modifiers(TypeKind::Long, TypeModifiers::UNSIGNED) - } - (false, false, false) => Type::basic(TypeKind::Int), - (false, false, true) => { - Type::with_modifiers(TypeKind::Int, TypeModifiers::UNSIGNED) - } + (true, _, false) => self.types.longlong_id, + (true, _, true) => self.types.ulonglong_id, + (false, true, false) => self.types.long_id, + (false, true, true) => self.types.ulong_id, + (false, false, false) => self.types.int_id, + (false, false, true) => self.types.uint_id, }; Ok(Self::typed_expr(ExprKind::IntLit(value), typ, pos)) } @@ -2570,6 +2634,7 @@ impl Parser<'_> { fn parse_declaration(&mut self) -> ParseResult { // Parse type specifiers let base_type = self.parse_type_specifier()?; + let base_type_id = self.types.intern(base_type); // Parse declarators let mut declarators = Vec::new(); @@ -2578,7 +2643,7 @@ impl Parser<'_> { // e.g., "struct point { int x; int y; };" if !self.is_special(b';') { loop { - let (name, typ) = self.parse_declarator(base_type.clone())?; + let (name, typ) = self.parse_declarator(base_type_id)?; let init = if self.is_special(b'=') { self.advance(); Some(self.parse_initializer()?) 
@@ -2608,6 +2673,10 @@ impl Parser<'_> { fn parse_declaration_and_bind(&mut self) -> ParseResult { // Parse type specifiers let base_type = self.parse_type_specifier()?; + // Check modifiers from the specifier before interning (storage class is not part of type) + let is_typedef = base_type.modifiers.contains(TypeModifiers::TYPEDEF); + // Intern the base type (without storage class specifiers for the actual type) + let base_type_id = self.types.intern(base_type); // Parse declarators let mut declarators = Vec::new(); @@ -2616,9 +2685,7 @@ impl Parser<'_> { // e.g., "struct point { int x; int y; };" if !self.is_special(b';') { loop { - let (name, typ) = self.parse_declarator(base_type.clone())?; - // Check if this is a typedef declaration - let is_typedef = base_type.modifiers.contains(TypeModifiers::TYPEDEF); + let (name, typ) = self.parse_declarator(base_type_id)?; let init = if self.is_special(b'=') { if is_typedef { @@ -2636,13 +2703,11 @@ impl Parser<'_> { // Bind to symbol table (like sparse's bind_symbol) if !name.is_empty() { if is_typedef { - // Strip TYPEDEF modifier from the type being aliased - let mut aliased_type = typ.clone(); - aliased_type.modifiers.remove(TypeModifiers::TYPEDEF); - let sym = Symbol::typedef(name.clone(), aliased_type, self.symbols.depth()); + // For typedef, the type being aliased is the declarator type + let sym = Symbol::typedef(name.clone(), typ, self.symbols.depth()); let _ = self.symbols.declare(sym); } else { - let sym = Symbol::variable(name.clone(), typ.clone(), self.symbols.depth()); + let sym = Symbol::variable(name.clone(), typ, self.symbols.depth()); let _ = self.symbols.declare(sym); } } @@ -2806,10 +2871,13 @@ impl Parser<'_> { // Check if it's a typedef name // Only consume the typedef if we haven't already seen a base type if base_kind.is_none() { - if let Some(typedef_type) = self.symbols.lookup_typedef(&name).cloned() { + if let Some(typedef_type_id) = self.symbols.lookup_typedef(&name) { self.advance(); - let mut 
result = typedef_type; - // Merge in any modifiers we collected (const, volatile, etc.) + // Get the underlying type and merge in any modifiers we collected + let typedef_type = self.types.get(typedef_type_id); + let mut result = typedef_type.clone(); + // Strip TYPEDEF modifier - we're using the typedef, not defining one + result.modifiers &= !TypeModifiers::TYPEDEF; result.modifiers |= modifiers; return Ok(result); } @@ -2863,7 +2931,8 @@ impl Parser<'_> { next_value = value + 1; // Register enum constant in symbol table (Ordinary namespace) - let sym = Symbol::enum_constant(name, value, self.symbols.depth()); + let sym = + Symbol::enum_constant(name, value, self.types.int_id, self.symbols.depth()); let _ = self.symbols.declare(sym); if self.is_special(b',') { @@ -2888,20 +2957,23 @@ impl Parser<'_> { is_complete: true, }; + let enum_type = Type::enum_type(composite); + // Register tag if present if let Some(ref tag_name) = tag { - let typ = Type::enum_type(composite.clone()); - let sym = Symbol::tag(tag_name.clone(), typ, self.symbols.depth()); + let enum_type_id = self.types.intern(enum_type.clone()); + let sym = Symbol::tag(tag_name.clone(), enum_type_id, self.symbols.depth()); let _ = self.symbols.declare(sym); } - Ok(Type::enum_type(composite)) + Ok(enum_type) } else { // Forward reference - look up existing tag if let Some(ref tag_name) = tag { // Look up or create incomplete type if let Some(existing) = self.symbols.lookup_tag(tag_name) { - Ok(existing.typ.clone()) + // Return a clone of the underlying type + Ok(self.types.get(existing.typ).clone()) } else { Ok(Type::incomplete_enum(tag_name.clone())) } @@ -2952,6 +3024,7 @@ impl Parser<'_> { while !self.is_special(b'}') && !self.is_eof() { // Parse member declaration let member_base_type = self.parse_type_specifier()?; + let member_base_type_id = self.types.intern(member_base_type); // Check for unnamed bitfield (starts with ':') if self.is_special(b':') { @@ -2961,7 +3034,7 @@ impl Parser<'_> { 
members.push(StructMember { name: String::new(), - typ: member_base_type.clone(), + typ: member_base_type_id, offset: 0, bit_offset: None, bit_width: Some(width), @@ -2973,14 +3046,14 @@ impl Parser<'_> { } loop { - let (name, typ) = self.parse_declarator(member_base_type.clone())?; + let (name, typ) = self.parse_declarator(member_base_type_id)?; // Check for bitfield: name : width let bit_width = if self.is_special(b':') { self.advance(); // consume ':' let width = self.parse_bitfield_width()?; // Validate bitfield type and width - self.validate_bitfield(&typ, width)?; + self.validate_bitfield(typ, width)?; Some(width) } else { None @@ -3012,9 +3085,9 @@ impl Parser<'_> { // Compute layout let (size, align) = if is_union { - CompositeType::compute_union_layout(&mut members) + self.types.compute_union_layout(&mut members) } else { - CompositeType::compute_struct_layout(&mut members) + self.types.compute_struct_layout(&mut members) }; let composite = CompositeType { @@ -3026,28 +3099,26 @@ impl Parser<'_> { is_complete: true, }; + let struct_type = if is_union { + Type::union_type(composite) + } else { + Type::struct_type(composite) + }; + // Register tag if present if let Some(ref tag_name) = tag { - let typ = if is_union { - Type::union_type(composite.clone()) - } else { - Type::struct_type(composite.clone()) - }; - let sym = Symbol::tag(tag_name.clone(), typ, self.symbols.depth()); + let typ_id = self.types.intern(struct_type.clone()); + let sym = Symbol::tag(tag_name.clone(), typ_id, self.symbols.depth()); let _ = self.symbols.declare(sym); } - if is_union { - Ok(Type::union_type(composite)) - } else { - Ok(Type::struct_type(composite)) - } + Ok(struct_type) } else { // Forward reference if let Some(ref tag_name) = tag { // Look up existing tag if let Some(existing) = self.symbols.lookup_tag(tag_name) { - Ok(existing.typ.clone()) + Ok(self.types.get(existing.typ).clone()) } else if is_union { Ok(Type::incomplete_union(tag_name.clone())) } else { @@ -3069,7 
+3140,7 @@ impl Parser<'_> { /// - `(*p)` means p is a pointer /// - `[3]` after the parens means "to array of 3" /// So p is "pointer to array of 3 ints" - fn parse_declarator(&mut self, base_type: Type) -> ParseResult<(String, Type)> { + fn parse_declarator(&mut self, base_type_id: TypeId) -> ParseResult<(String, TypeId)> { // Collect pointer modifiers (they bind tighter than array/function) let mut pointer_modifiers: Vec = Vec::new(); while self.is_special(b'*') { @@ -3103,7 +3174,7 @@ impl Parser<'_> { // Check for parenthesized declarator: int (*p)[3] // The paren comes AFTER pointers, e.g. int *(*p)[3] = pointer to (pointer to array of 3 ints) - let (name, inner_type) = if self.is_special(b'(') { + let (name, inner_type_id) = if self.is_special(b'(') { // Check if this looks like a function parameter list or a grouped declarator // A grouped declarator will have * or identifier immediately after ( let saved_pos = self.pos; @@ -3120,14 +3191,13 @@ impl Parser<'_> { // because [3] comes after the ) // So we use a placeholder and fix it up after - // Parse the inner declarator with a placeholder type - let placeholder = Type::basic(TypeKind::Void); - let (inner_name, inner_decl_type) = self.parse_declarator(placeholder)?; + // Parse the inner declarator with a placeholder type (void) + let (inner_name, inner_decl_type_id) = self.parse_declarator(self.types.void_id)?; self.expect_special(b')')?; // Now parse any suffix modifiers (arrays, function params) // These apply to the base type, not the inner declarator - (inner_name, Some(inner_decl_type)) + (inner_name, Some(inner_decl_type_id)) } else { // Not a grouped declarator, restore position self.pos = saved_pos; @@ -3159,62 +3229,57 @@ impl Parser<'_> { // Handle function declarators: void (*fp)(int, char) // This parses the parameter list after a grouped declarator - let func_params: Option<(Vec, bool)> = if self.is_special(b'(') { + let func_params: Option<(Vec, bool)> = if self.is_special(b'(') { 
self.advance(); let (params, variadic) = self.parse_parameter_list()?; self.expect_special(b')')?; - Some((params.iter().map(|p| p.typ.clone()).collect(), variadic)) + Some((params.iter().map(|p| p.typ).collect(), variadic)) } else { None }; // Build the type from the base type - let mut result_type = base_type; + let mut result_type_id = base_type_id; - if inner_type.is_some() { + if inner_type_id.is_some() { // Grouped declarator: int (*p)[3] or void (*fp)(int) // Arrays/functions in suffix apply to the base type first // Then we substitute into the inner declarator // Apply function parameters to base type first (if present) // For void (*fp)(int): base is void, suffix (int) -> Function(void, [int]) - if let Some((param_types, variadic)) = func_params { - result_type = Type { - kind: TypeKind::Function, - modifiers: TypeModifiers::empty(), - base: Some(Box::new(result_type)), - array_size: None, - params: Some(param_types), - variadic, - composite: None, - }; + if let Some((param_type_ids, variadic)) = func_params { + let func_type = Type::function(result_type_id, param_type_ids, variadic); + result_type_id = self.types.intern(func_type); } // Apply array dimensions to base type // For int (*p)[3]: base is int, suffix [3] -> Array(3, int) for size in dimensions.into_iter().rev() { - result_type = Type { + let arr_type = Type { kind: TypeKind::Array, modifiers: TypeModifiers::empty(), - base: Some(Box::new(result_type)), + base: Some(result_type_id), array_size: size, params: None, variadic: false, composite: None, }; + result_type_id = self.types.intern(arr_type); } // Apply any outer pointers (before the parens) for modifiers in pointer_modifiers.into_iter().rev() { - result_type = Type { + let ptr_type = Type { kind: TypeKind::Pointer, modifiers, - base: Some(Box::new(result_type)), + base: Some(result_type_id), array_size: None, params: None, variadic: false, composite: None, }; + result_type_id = self.types.intern(ptr_type); } // Substitute into inner 
declarator @@ -3222,73 +3287,83 @@ impl Parser<'_> { // -> Pointer(Array(3, int)) // For void (*fp)(int): inner_decl is Pointer(Void), result_type is Function(void, [int]) // -> Pointer(Function(void, [int])) - result_type = Self::substitute_base_type(inner_type.unwrap(), result_type); + result_type_id = self.substitute_base_type(inner_type_id.unwrap(), result_type_id); } else { // Simple declarator: char *arr[3] // Pointers bind tighter than arrays: *arr[3] = array of pointers // Apply pointer modifiers to base type first for modifiers in pointer_modifiers.into_iter().rev() { - result_type = Type { + let ptr_type = Type { kind: TypeKind::Pointer, modifiers, - base: Some(Box::new(result_type)), + base: Some(result_type_id), array_size: None, params: None, variadic: false, composite: None, }; + result_type_id = self.types.intern(ptr_type); } // Then apply array dimensions // For char *arr[3]: result_type is char*, suffix [3] -> Array(3, char*) for size in dimensions.into_iter().rev() { - result_type = Type { + let arr_type = Type { kind: TypeKind::Array, modifiers: TypeModifiers::empty(), - base: Some(Box::new(result_type)), + base: Some(result_type_id), array_size: size, params: None, variadic: false, composite: None, }; + result_type_id = self.types.intern(arr_type); } } - Ok((name, result_type)) + Ok((name, result_type_id)) } /// Substitute the actual base type into a declarator parsed with a placeholder /// For int (*p)[3]: inner_decl is Pointer(Void), actual_base is Array(3, int) /// Result should be Pointer(Array(3, int)) - fn substitute_base_type(decl_type: Type, actual_base: Type) -> Type { + fn substitute_base_type(&mut self, decl_type_id: TypeId, actual_base_id: TypeId) -> TypeId { + let decl_type = self.types.get(decl_type_id); match decl_type.kind { - TypeKind::Void => actual_base, - TypeKind::Pointer => Type { - kind: TypeKind::Pointer, - modifiers: decl_type.modifiers, - base: Some(Box::new(Self::substitute_base_type( - *decl_type.base.unwrap(), - 
actual_base, - ))), - array_size: None, - params: None, - variadic: false, - composite: None, - }, - TypeKind::Array => Type { - kind: TypeKind::Array, - modifiers: decl_type.modifiers, - base: Some(Box::new(Self::substitute_base_type( - *decl_type.base.unwrap(), - actual_base, - ))), - array_size: decl_type.array_size, - params: None, - variadic: false, - composite: None, - }, - _ => decl_type, // Other types don't need substitution + TypeKind::Void => actual_base_id, + TypeKind::Pointer => { + let inner_base_id = decl_type.base.unwrap(); + let decl_modifiers = decl_type.modifiers; + let new_base_id = self.substitute_base_type(inner_base_id, actual_base_id); + let ptr_type = Type { + kind: TypeKind::Pointer, + modifiers: decl_modifiers, + base: Some(new_base_id), + array_size: None, + params: None, + variadic: false, + composite: None, + }; + self.types.intern(ptr_type) + } + TypeKind::Array => { + let inner_base_id = decl_type.base.unwrap(); + let decl_modifiers = decl_type.modifiers; + let decl_array_size = decl_type.array_size; + let new_base_id = self.substitute_base_type(inner_base_id, actual_base_id); + let arr_type = Type { + kind: TypeKind::Array, + modifiers: decl_modifiers, + base: Some(new_base_id), + array_size: decl_array_size, + params: None, + variadic: false, + composite: None, + }; + self.types.intern(arr_type) + } + _ => decl_type_id, // Other types don't need substitution } } @@ -3302,9 +3377,9 @@ impl Parser<'_> { let func_pos = self.current_pos(); // Parse return type let return_type = self.parse_type_specifier()?; + let mut ret_type_id = self.types.intern(return_type); // Handle pointer in return type with qualifiers - let mut ret_type = return_type; while self.is_special(b'*') { self.advance(); let mut ptr_modifiers = TypeModifiers::empty(); @@ -3332,15 +3407,16 @@ impl Parser<'_> { } } - ret_type = Type { + let ptr_type = Type { kind: TypeKind::Pointer, modifiers: ptr_modifiers, - base: Some(Box::new(ret_type)), + base: Some(ret_type_id), 
array_size: None, params: None, variadic: false, composite: None, }; + ret_type_id = self.types.intern(ptr_type); } // Parse function name @@ -3352,12 +3428,13 @@ impl Parser<'_> { self.expect_special(b')')?; // Build the function type - let param_types: Vec = params.iter().map(|p| p.typ.clone()).collect(); - let func_type = Type::function(ret_type.clone(), param_types, variadic); + let param_type_ids: Vec = params.iter().map(|p| p.typ).collect(); + let func_type = Type::function(ret_type_id, param_type_ids, variadic); + let func_type_id = self.types.intern(func_type); // Bind function to symbol table at current (global) scope // Like sparse's bind_symbol() in parse.c - let func_sym = Symbol::function(name.clone(), func_type, self.symbols.depth()); + let func_sym = Symbol::function(name.clone(), func_type_id, self.symbols.depth()); let _ = self.symbols.declare(func_sym); // Ignore redefinition errors for now // Enter function scope for parameters and body @@ -3367,7 +3444,7 @@ impl Parser<'_> { for param in ¶ms { if let Some(param_name) = ¶m.name { let param_sym = - Symbol::parameter(param_name.clone(), param.typ.clone(), self.symbols.depth()); + Symbol::parameter(param_name.clone(), param.typ, self.symbols.depth()); let _ = self.symbols.declare(param_sym); } } @@ -3379,7 +3456,7 @@ impl Parser<'_> { self.symbols.leave_scope(); Ok(FunctionDef { - return_type: ret_type, + return_type: ret_type_id, name, params, body, @@ -3421,9 +3498,9 @@ impl Parser<'_> { // Parse parameter type let param_type = self.parse_type_specifier()?; + let mut typ_id = self.types.intern(param_type); // Handle pointer with qualifiers (const, volatile, restrict) - let mut typ = param_type; while self.is_special(b'*') { self.advance(); let mut ptr_modifiers = TypeModifiers::empty(); @@ -3451,15 +3528,16 @@ impl Parser<'_> { } } - typ = Type { + let ptr_type = Type { kind: TypeKind::Pointer, modifiers: ptr_modifiers, - base: Some(Box::new(typ)), + base: Some(typ_id), array_size: None, params: 
None, variadic: false, composite: None, }; + typ_id = self.types.intern(ptr_type); } // Parse optional parameter name @@ -3491,20 +3569,21 @@ impl Parser<'_> { self.expect_special(b']')?; // Convert array to pointer - typ = Type { + let ptr_type = Type { kind: TypeKind::Pointer, modifiers: TypeModifiers::empty(), - base: Some(Box::new(typ)), + base: Some(typ_id), array_size: None, params: None, variadic: false, composite: None, }; + typ_id = self.types.intern(ptr_type); } params.push(Parameter { name: param_name, - typ, + typ: typ_id, }); if self.is_special(b',') { @@ -3538,6 +3617,9 @@ impl Parser<'_> { let decl_pos = self.current_pos(); // Parse type specifier let base_type = self.parse_type_specifier()?; + // Check modifiers before interning (storage class specifiers) + let is_typedef = base_type.modifiers.contains(TypeModifiers::TYPEDEF); + let base_type_id = self.types.intern(base_type); // Check for standalone type definition (e.g., "enum Color { ... };") // This happens when a composite type is defined but no variables are declared @@ -3558,13 +3640,12 @@ impl Parser<'_> { if self.is_special(b'*') { // This is a grouped declarator - use parse_declarator self.pos = saved_pos; // restore position before '(' - let (name, typ) = self.parse_declarator(base_type.clone())?; + let (name, typ) = self.parse_declarator(base_type_id)?; // Skip any __attribute__ after declarator self.skip_extensions(); // Handle initializer - let is_typedef = base_type.modifiers.contains(TypeModifiers::TYPEDEF); let init = if self.is_special(b'=') { if is_typedef { return Err(ParseError::new( @@ -3582,12 +3663,10 @@ impl Parser<'_> { // Add to symbol table if is_typedef { - let mut aliased_type = typ.clone(); - aliased_type.modifiers.remove(TypeModifiers::TYPEDEF); - let sym = Symbol::typedef(name.clone(), aliased_type, self.symbols.depth()); + let sym = Symbol::typedef(name.clone(), typ, self.symbols.depth()); let _ = self.symbols.declare(sym); } else { - let var_sym = 
Symbol::variable(name.clone(), typ.clone(), self.symbols.depth()); + let var_sym = Symbol::variable(name.clone(), typ, self.symbols.depth()); let _ = self.symbols.declare(var_sym); } @@ -3600,7 +3679,7 @@ impl Parser<'_> { } // Handle pointer with qualifiers (const, volatile, restrict) - let mut typ = base_type.clone(); + let mut typ_id = base_type_id; while self.is_special(b'*') { self.advance(); let mut ptr_modifiers = TypeModifiers::empty(); @@ -3628,32 +3707,32 @@ impl Parser<'_> { } } - typ = Type { + let ptr_type = Type { kind: TypeKind::Pointer, modifiers: ptr_modifiers, - base: Some(Box::new(typ)), + base: Some(typ_id), array_size: None, params: None, variadic: false, composite: None, }; + typ_id = self.types.intern(ptr_type); } // Check again for grouped declarator after pointer modifiers: char *(*fp)(int) - // At this point typ is char*, and we see (*fp)(int) + // At this point typ_id is char*, and we see (*fp)(int) if self.is_special(b'(') { let saved_pos = self.pos; self.advance(); // consume '(' if self.is_special(b'*') { // This is a grouped declarator - use parse_declarator self.pos = saved_pos; // restore position before '(' - let (name, full_typ) = self.parse_declarator(typ)?; + let (name, full_typ) = self.parse_declarator(typ_id)?; // Skip any __attribute__ after declarator self.skip_extensions(); // Handle initializer - let is_typedef = base_type.modifiers.contains(TypeModifiers::TYPEDEF); let init = if self.is_special(b'=') { if is_typedef { return Err(ParseError::new( @@ -3671,13 +3750,10 @@ impl Parser<'_> { // Add to symbol table if is_typedef { - let mut aliased_type = full_typ.clone(); - aliased_type.modifiers.remove(TypeModifiers::TYPEDEF); - let sym = Symbol::typedef(name.clone(), aliased_type, self.symbols.depth()); + let sym = Symbol::typedef(name.clone(), full_typ, self.symbols.depth()); let _ = self.symbols.declare(sym); } else { - let var_sym = - Symbol::variable(name.clone(), full_typ.clone(), self.symbols.depth()); + let var_sym = 
Symbol::variable(name.clone(), full_typ, self.symbols.depth()); let _ = self.symbols.declare(var_sym); } @@ -3709,12 +3785,10 @@ impl Parser<'_> { if self.is_special(b'{') { // Function definition // Add function to symbol table so it can be called by other functions - let func_type = Type::function( - typ.clone(), - params.iter().map(|p| p.typ.clone()).collect(), - variadic, - ); - let func_sym = Symbol::function(name.clone(), func_type, self.symbols.depth()); + let param_type_ids: Vec = params.iter().map(|p| p.typ).collect(); + let func_type = Type::function(typ_id, param_type_ids.clone(), variadic); + let func_type_id = self.types.intern(func_type); + let func_sym = Symbol::function(name.clone(), func_type_id, self.symbols.depth()); let _ = self.symbols.declare(func_sym); // Enter function scope for parameters @@ -3723,11 +3797,8 @@ impl Parser<'_> { // Bind parameters in function scope for param in ¶ms { if let Some(param_name) = ¶m.name { - let param_sym = Symbol::parameter( - param_name.clone(), - param.typ.clone(), - self.symbols.depth(), - ); + let param_sym = + Symbol::parameter(param_name.clone(), param.typ, self.symbols.depth()); let _ = self.symbols.declare(param_sym); } } @@ -3739,7 +3810,7 @@ impl Parser<'_> { self.symbols.leave_scope(); return Ok(ExternalDecl::FunctionDef(FunctionDef { - return_type: typ, + return_type: typ_id, name, params, body, @@ -3748,34 +3819,28 @@ impl Parser<'_> { } else { // Function declaration self.expect_special(b';')?; - let func_type = Type::function( - typ, - params.iter().map(|p| p.typ.clone()).collect(), - variadic, - ); + let param_type_ids: Vec = params.iter().map(|p| p.typ).collect(); + let func_type = Type::function(typ_id, param_type_ids, variadic); + let func_type_id = self.types.intern(func_type); // Add function declaration to symbol table so the variadic flag // is available when the function is called - let func_sym = - Symbol::function(name.clone(), func_type.clone(), self.symbols.depth()); + let func_sym 
= Symbol::function(name.clone(), func_type_id, self.symbols.depth()); let _ = self.symbols.declare(func_sym); return Ok(ExternalDecl::Declaration(Declaration { declarators: vec![InitDeclarator { name, - typ: func_type, + typ: func_type_id, init: None, }], })); } } - // Check if this is a typedef declaration - let is_typedef = base_type.modifiers.contains(TypeModifiers::TYPEDEF); - // Variable/typedef declaration let mut declarators = Vec::new(); // Handle array - collect dimensions first, build type from right to left - let mut var_type = typ; + let mut var_type_id = typ_id; let mut dimensions: Vec> = Vec::new(); while self.is_special(b'[') { self.advance(); @@ -3793,7 +3858,8 @@ impl Parser<'_> { } // Build type from right to left (innermost dimension first) for size in dimensions.into_iter().rev() { - var_type = Type::array(var_type, size.unwrap_or(0)); + let arr_type = Type::array(var_type_id, size.unwrap_or(0)); + var_type_id = self.types.intern(arr_type); } // Skip any __attribute__ after variable name/array declarator @@ -3815,27 +3881,24 @@ impl Parser<'_> { // Add to symbol table if is_typedef { - // Strip TYPEDEF modifier from the aliased type - let mut aliased_type = var_type.clone(); - aliased_type.modifiers.remove(TypeModifiers::TYPEDEF); - let sym = Symbol::typedef(name.clone(), aliased_type, self.symbols.depth()); + let sym = Symbol::typedef(name.clone(), var_type_id, self.symbols.depth()); let _ = self.symbols.declare(sym); } else { // Add global variable to symbol table so it can be referenced by later code - let var_sym = Symbol::variable(name.clone(), var_type.clone(), self.symbols.depth()); + let var_sym = Symbol::variable(name.clone(), var_type_id, self.symbols.depth()); let _ = self.symbols.declare(var_sym); } declarators.push(InitDeclarator { name, - typ: var_type, + typ: var_type_id, init, }); // Handle additional declarators while self.is_special(b',') { self.advance(); - let (decl_name, decl_type) = 
self.parse_declarator(base_type.clone())?; + let (decl_name, decl_type) = self.parse_declarator(base_type_id)?; let decl_init = if self.is_special(b'=') { if is_typedef { return Err(ParseError::new( @@ -3850,13 +3913,10 @@ impl Parser<'_> { }; // Add to symbol table if is_typedef { - let mut aliased_type = decl_type.clone(); - aliased_type.modifiers.remove(TypeModifiers::TYPEDEF); - let sym = Symbol::typedef(decl_name.clone(), aliased_type, self.symbols.depth()); + let sym = Symbol::typedef(decl_name.clone(), decl_type, self.symbols.depth()); let _ = self.symbols.declare(sym); } else { - let var_sym = - Symbol::variable(decl_name.clone(), decl_type.clone(), self.symbols.depth()); + let var_sym = Symbol::variable(decl_name.clone(), decl_type, self.symbols.depth()); let _ = self.symbols.declare(var_sym); } declarators.push(InitDeclarator { @@ -3958,10 +4018,11 @@ impl Parser<'_> { } /// Validate a bitfield declaration - fn validate_bitfield(&self, typ: &Type, width: u32) -> ParseResult<()> { + fn validate_bitfield(&self, typ_id: TypeId, width: u32) -> ParseResult<()> { // Check allowed types: _Bool, int, unsigned int (and their signed/unsigned variants) + let kind = self.types.kind(typ_id); let valid_type = matches!( - typ.kind, + kind, TypeKind::Bool | TypeKind::Int | TypeKind::Char | TypeKind::Short | TypeKind::Long ); @@ -3973,7 +4034,7 @@ impl Parser<'_> { } // Check that width doesn't exceed type size - let max_width = typ.size_bits(); + let max_width = self.types.size_bits(typ_id); if width > max_width { return Err(ParseError::new( format!("bitfield width {} exceeds type size {}", width, max_width), @@ -3995,14 +4056,16 @@ mod tests { use crate::symbol::SymbolTable; use crate::token::lexer::Tokenizer; - fn parse_expr(input: &str) -> ParseResult { + fn parse_expr(input: &str) -> ParseResult<(Expr, TypeTable)> { let mut tokenizer = Tokenizer::new(input.as_bytes(), 0); let tokens = tokenizer.tokenize(); let idents = tokenizer.ident_table(); let mut symbols = 
SymbolTable::new(); - let mut parser = Parser::new(&tokens, idents, &mut symbols); + let mut types = TypeTable::new(); + let mut parser = Parser::new(&tokens, idents, &mut symbols, &mut types); parser.skip_stream_tokens(); - parser.parse_expression() + let expr = parser.parse_expression()?; + Ok((expr, types)) } // ======================================================================== @@ -4011,25 +4074,25 @@ mod tests { #[test] fn test_int_literal() { - let expr = parse_expr("42").unwrap(); + let (expr, _types) = parse_expr("42").unwrap(); assert!(matches!(expr.kind, ExprKind::IntLit(42))); } #[test] fn test_hex_literal() { - let expr = parse_expr("0xFF").unwrap(); + let (expr, _types) = parse_expr("0xFF").unwrap(); assert!(matches!(expr.kind, ExprKind::IntLit(255))); } #[test] fn test_octal_literal() { - let expr = parse_expr("0777").unwrap(); + let (expr, _types) = parse_expr("0777").unwrap(); assert!(matches!(expr.kind, ExprKind::IntLit(511))); } #[test] fn test_float_literal() { - let expr = parse_expr("3.14").unwrap(); + let (expr, _types) = parse_expr("3.14").unwrap(); match expr.kind { ExprKind::FloatLit(v) => assert!((v - 3.14).abs() < 0.001), _ => panic!("Expected FloatLit"), @@ -4038,19 +4101,19 @@ mod tests { #[test] fn test_char_literal() { - let expr = parse_expr("'a'").unwrap(); + let (expr, _types) = parse_expr("'a'").unwrap(); assert!(matches!(expr.kind, ExprKind::CharLit('a'))); } #[test] fn test_char_escape() { - let expr = parse_expr("'\\n'").unwrap(); + let (expr, _types) = parse_expr("'\\n'").unwrap(); assert!(matches!(expr.kind, ExprKind::CharLit('\n'))); } #[test] fn test_string_literal() { - let expr = parse_expr("\"hello\"").unwrap(); + let (expr, _types) = parse_expr("\"hello\"").unwrap(); match expr.kind { ExprKind::StringLit(s) => assert_eq!(s, "hello"), _ => panic!("Expected StringLit"), @@ -4063,91 +4126,91 @@ mod tests { #[test] fn test_char_escape_newline() { - let expr = parse_expr("'\\n'").unwrap(); + let (expr, _types) = 
parse_expr("'\\n'").unwrap(); assert!(matches!(expr.kind, ExprKind::CharLit('\n'))); } #[test] fn test_char_escape_tab() { - let expr = parse_expr("'\\t'").unwrap(); + let (expr, _types) = parse_expr("'\\t'").unwrap(); assert!(matches!(expr.kind, ExprKind::CharLit('\t'))); } #[test] fn test_char_escape_carriage_return() { - let expr = parse_expr("'\\r'").unwrap(); + let (expr, _types) = parse_expr("'\\r'").unwrap(); assert!(matches!(expr.kind, ExprKind::CharLit('\r'))); } #[test] fn test_char_escape_backslash() { - let expr = parse_expr("'\\\\'").unwrap(); + let (expr, _types) = parse_expr("'\\\\'").unwrap(); assert!(matches!(expr.kind, ExprKind::CharLit('\\'))); } #[test] fn test_char_escape_single_quote() { - let expr = parse_expr("'\\''").unwrap(); + let (expr, _types) = parse_expr("'\\''").unwrap(); assert!(matches!(expr.kind, ExprKind::CharLit('\''))); } #[test] fn test_char_escape_double_quote() { - let expr = parse_expr("'\\\"'").unwrap(); + let (expr, _types) = parse_expr("'\\\"'").unwrap(); assert!(matches!(expr.kind, ExprKind::CharLit('"'))); } #[test] fn test_char_escape_bell() { - let expr = parse_expr("'\\a'").unwrap(); + let (expr, _types) = parse_expr("'\\a'").unwrap(); assert!(matches!(expr.kind, ExprKind::CharLit('\x07'))); } #[test] fn test_char_escape_backspace() { - let expr = parse_expr("'\\b'").unwrap(); + let (expr, _types) = parse_expr("'\\b'").unwrap(); assert!(matches!(expr.kind, ExprKind::CharLit('\x08'))); } #[test] fn test_char_escape_formfeed() { - let expr = parse_expr("'\\f'").unwrap(); + let (expr, _types) = parse_expr("'\\f'").unwrap(); assert!(matches!(expr.kind, ExprKind::CharLit('\x0C'))); } #[test] fn test_char_escape_vertical_tab() { - let expr = parse_expr("'\\v'").unwrap(); + let (expr, _types) = parse_expr("'\\v'").unwrap(); assert!(matches!(expr.kind, ExprKind::CharLit('\x0B'))); } #[test] fn test_char_escape_null() { - let expr = parse_expr("'\\0'").unwrap(); + let (expr, _types) = parse_expr("'\\0'").unwrap(); 
assert!(matches!(expr.kind, ExprKind::CharLit('\0'))); } #[test] fn test_char_escape_hex() { - let expr = parse_expr("'\\x41'").unwrap(); + let (expr, _types) = parse_expr("'\\x41'").unwrap(); assert!(matches!(expr.kind, ExprKind::CharLit('A'))); } #[test] fn test_char_escape_hex_lowercase() { - let expr = parse_expr("'\\x0a'").unwrap(); + let (expr, _types) = parse_expr("'\\x0a'").unwrap(); assert!(matches!(expr.kind, ExprKind::CharLit('\n'))); } #[test] fn test_char_escape_octal() { - let expr = parse_expr("'\\101'").unwrap(); + let (expr, _types) = parse_expr("'\\101'").unwrap(); assert!(matches!(expr.kind, ExprKind::CharLit('A'))); // octal 101 = 65 = 'A' } #[test] fn test_char_escape_octal_012() { - let expr = parse_expr("'\\012'").unwrap(); + let (expr, _types) = parse_expr("'\\012'").unwrap(); assert!(matches!(expr.kind, ExprKind::CharLit('\n'))); // octal 012 = 10 = '\n' } @@ -4157,7 +4220,7 @@ mod tests { #[test] fn test_string_escape_newline() { - let expr = parse_expr("\"hello\\nworld\"").unwrap(); + let (expr, _types) = parse_expr("\"hello\\nworld\"").unwrap(); match expr.kind { ExprKind::StringLit(s) => assert_eq!(s, "hello\nworld"), _ => panic!("Expected StringLit"), @@ -4166,7 +4229,7 @@ mod tests { #[test] fn test_string_escape_tab() { - let expr = parse_expr("\"hello\\tworld\"").unwrap(); + let (expr, _types) = parse_expr("\"hello\\tworld\"").unwrap(); match expr.kind { ExprKind::StringLit(s) => assert_eq!(s, "hello\tworld"), _ => panic!("Expected StringLit"), @@ -4175,7 +4238,7 @@ mod tests { #[test] fn test_string_escape_carriage_return() { - let expr = parse_expr("\"hello\\rworld\"").unwrap(); + let (expr, _types) = parse_expr("\"hello\\rworld\"").unwrap(); match expr.kind { ExprKind::StringLit(s) => assert_eq!(s, "hello\rworld"), _ => panic!("Expected StringLit"), @@ -4184,7 +4247,7 @@ mod tests { #[test] fn test_string_escape_backslash() { - let expr = parse_expr("\"hello\\\\world\"").unwrap(); + let (expr, _types) = 
parse_expr("\"hello\\\\world\"").unwrap(); match expr.kind { ExprKind::StringLit(s) => assert_eq!(s, "hello\\world"), _ => panic!("Expected StringLit"), @@ -4193,7 +4256,7 @@ mod tests { #[test] fn test_string_escape_double_quote() { - let expr = parse_expr("\"hello\\\"world\"").unwrap(); + let (expr, _types) = parse_expr("\"hello\\\"world\"").unwrap(); match expr.kind { ExprKind::StringLit(s) => assert_eq!(s, "hello\"world"), _ => panic!("Expected StringLit"), @@ -4202,7 +4265,7 @@ mod tests { #[test] fn test_string_escape_bell() { - let expr = parse_expr("\"\\a\"").unwrap(); + let (expr, _types) = parse_expr("\"\\a\"").unwrap(); match expr.kind { ExprKind::StringLit(s) => assert_eq!(s, "\x07"), _ => panic!("Expected StringLit"), @@ -4211,7 +4274,7 @@ mod tests { #[test] fn test_string_escape_backspace() { - let expr = parse_expr("\"\\b\"").unwrap(); + let (expr, _types) = parse_expr("\"\\b\"").unwrap(); match expr.kind { ExprKind::StringLit(s) => assert_eq!(s, "\x08"), _ => panic!("Expected StringLit"), @@ -4220,7 +4283,7 @@ mod tests { #[test] fn test_string_escape_formfeed() { - let expr = parse_expr("\"\\f\"").unwrap(); + let (expr, _types) = parse_expr("\"\\f\"").unwrap(); match expr.kind { ExprKind::StringLit(s) => assert_eq!(s, "\x0C"), _ => panic!("Expected StringLit"), @@ -4229,7 +4292,7 @@ mod tests { #[test] fn test_string_escape_vertical_tab() { - let expr = parse_expr("\"\\v\"").unwrap(); + let (expr, _types) = parse_expr("\"\\v\"").unwrap(); match expr.kind { ExprKind::StringLit(s) => assert_eq!(s, "\x0B"), _ => panic!("Expected StringLit"), @@ -4238,7 +4301,7 @@ mod tests { #[test] fn test_string_escape_null() { - let expr = parse_expr("\"hello\\0world\"").unwrap(); + let (expr, _types) = parse_expr("\"hello\\0world\"").unwrap(); match expr.kind { ExprKind::StringLit(s) => { assert_eq!(s.len(), 11); @@ -4250,7 +4313,7 @@ mod tests { #[test] fn test_string_escape_hex() { - let expr = parse_expr("\"\\x41\\x42\\x43\"").unwrap(); + let (expr, _types) = 
parse_expr("\"\\x41\\x42\\x43\"").unwrap(); match expr.kind { ExprKind::StringLit(s) => assert_eq!(s, "ABC"), _ => panic!("Expected StringLit"), @@ -4259,7 +4322,7 @@ mod tests { #[test] fn test_string_escape_octal() { - let expr = parse_expr("\"\\101\\102\\103\"").unwrap(); + let (expr, _types) = parse_expr("\"\\101\\102\\103\"").unwrap(); match expr.kind { ExprKind::StringLit(s) => assert_eq!(s, "ABC"), // octal 101,102,103 = A,B,C _ => panic!("Expected StringLit"), @@ -4268,7 +4331,7 @@ mod tests { #[test] fn test_string_escape_octal_012() { - let expr = parse_expr("\"line1\\012line2\"").unwrap(); + let (expr, _types) = parse_expr("\"line1\\012line2\"").unwrap(); match expr.kind { ExprKind::StringLit(s) => assert_eq!(s, "line1\nline2"), // octal 012 = newline _ => panic!("Expected StringLit"), @@ -4277,7 +4340,7 @@ mod tests { #[test] fn test_string_multiple_escapes() { - let expr = parse_expr("\"\\t\\n\\r\\\\\"").unwrap(); + let (expr, _types) = parse_expr("\"\\t\\n\\r\\\\\"").unwrap(); match expr.kind { ExprKind::StringLit(s) => assert_eq!(s, "\t\n\r\\"), _ => panic!("Expected StringLit"), @@ -4286,7 +4349,7 @@ mod tests { #[test] fn test_string_mixed_content() { - let expr = parse_expr("\"Name:\\tJohn\\nAge:\\t30\"").unwrap(); + let (expr, _types) = parse_expr("\"Name:\\tJohn\\nAge:\\t30\"").unwrap(); match expr.kind { ExprKind::StringLit(s) => assert_eq!(s, "Name:\tJohn\nAge:\t30"), _ => panic!("Expected StringLit"), @@ -4295,7 +4358,7 @@ mod tests { #[test] fn test_string_empty() { - let expr = parse_expr("\"\"").unwrap(); + let (expr, _types) = parse_expr("\"\"").unwrap(); match expr.kind { ExprKind::StringLit(s) => assert_eq!(s, ""), _ => panic!("Expected StringLit"), @@ -4305,58 +4368,58 @@ mod tests { #[test] fn test_integer_literal_suffixes() { // Plain int - let expr = parse_expr("42").unwrap(); - assert_eq!(expr.typ.as_ref().unwrap().kind, TypeKind::Int); - assert!(!expr.typ.as_ref().unwrap().is_unsigned()); + let (expr, types) = 
parse_expr("42").unwrap(); + assert_eq!(types.kind(expr.typ.unwrap()), TypeKind::Int); + assert!(!types.is_unsigned(expr.typ.unwrap())); // Unsigned int - let expr = parse_expr("42U").unwrap(); - assert_eq!(expr.typ.as_ref().unwrap().kind, TypeKind::Int); - assert!(expr.typ.as_ref().unwrap().is_unsigned()); + let (expr, types) = parse_expr("42U").unwrap(); + assert_eq!(types.kind(expr.typ.unwrap()), TypeKind::Int); + assert!(types.is_unsigned(expr.typ.unwrap())); // Long - let expr = parse_expr("42L").unwrap(); - assert_eq!(expr.typ.as_ref().unwrap().kind, TypeKind::Long); - assert!(!expr.typ.as_ref().unwrap().is_unsigned()); + let (expr, types) = parse_expr("42L").unwrap(); + assert_eq!(types.kind(expr.typ.unwrap()), TypeKind::Long); + assert!(!types.is_unsigned(expr.typ.unwrap())); // Unsigned long (UL) - let expr = parse_expr("42UL").unwrap(); - assert_eq!(expr.typ.as_ref().unwrap().kind, TypeKind::Long); - assert!(expr.typ.as_ref().unwrap().is_unsigned()); + let (expr, types) = parse_expr("42UL").unwrap(); + assert_eq!(types.kind(expr.typ.unwrap()), TypeKind::Long); + assert!(types.is_unsigned(expr.typ.unwrap())); // Unsigned long (LU) - let expr = parse_expr("42LU").unwrap(); - assert_eq!(expr.typ.as_ref().unwrap().kind, TypeKind::Long); - assert!(expr.typ.as_ref().unwrap().is_unsigned()); + let (expr, types) = parse_expr("42LU").unwrap(); + assert_eq!(types.kind(expr.typ.unwrap()), TypeKind::Long); + assert!(types.is_unsigned(expr.typ.unwrap())); // Long long - let expr = parse_expr("42LL").unwrap(); - assert_eq!(expr.typ.as_ref().unwrap().kind, TypeKind::LongLong); - assert!(!expr.typ.as_ref().unwrap().is_unsigned()); + let (expr, types) = parse_expr("42LL").unwrap(); + assert_eq!(types.kind(expr.typ.unwrap()), TypeKind::LongLong); + assert!(!types.is_unsigned(expr.typ.unwrap())); // Unsigned long long (ULL) - let expr = parse_expr("42ULL").unwrap(); - assert_eq!(expr.typ.as_ref().unwrap().kind, TypeKind::LongLong); - 
assert!(expr.typ.as_ref().unwrap().is_unsigned()); + let (expr, types) = parse_expr("42ULL").unwrap(); + assert_eq!(types.kind(expr.typ.unwrap()), TypeKind::LongLong); + assert!(types.is_unsigned(expr.typ.unwrap())); // Unsigned long long (LLU) - let expr = parse_expr("42LLU").unwrap(); - assert_eq!(expr.typ.as_ref().unwrap().kind, TypeKind::LongLong); - assert!(expr.typ.as_ref().unwrap().is_unsigned()); + let (expr, types) = parse_expr("42LLU").unwrap(); + assert_eq!(types.kind(expr.typ.unwrap()), TypeKind::LongLong); + assert!(types.is_unsigned(expr.typ.unwrap())); // Hex with suffix - let expr = parse_expr("0xFFLL").unwrap(); - assert_eq!(expr.typ.as_ref().unwrap().kind, TypeKind::LongLong); - assert!(!expr.typ.as_ref().unwrap().is_unsigned()); + let (expr, types) = parse_expr("0xFFLL").unwrap(); + assert_eq!(types.kind(expr.typ.unwrap()), TypeKind::LongLong); + assert!(!types.is_unsigned(expr.typ.unwrap())); - let expr = parse_expr("0xFFULL").unwrap(); - assert_eq!(expr.typ.as_ref().unwrap().kind, TypeKind::LongLong); - assert!(expr.typ.as_ref().unwrap().is_unsigned()); + let (expr, types) = parse_expr("0xFFULL").unwrap(); + assert_eq!(types.kind(expr.typ.unwrap()), TypeKind::LongLong); + assert!(types.is_unsigned(expr.typ.unwrap())); } #[test] fn test_identifier() { - let expr = parse_expr("foo").unwrap(); + let (expr, _types) = parse_expr("foo").unwrap(); match expr.kind { ExprKind::Ident { name, .. } => assert_eq!(name, "foo"), _ => panic!("Expected Ident"), @@ -4369,7 +4432,7 @@ mod tests { #[test] fn test_addition() { - let expr = parse_expr("1 + 2").unwrap(); + let (expr, _types) = parse_expr("1 + 2").unwrap(); match expr.kind { ExprKind::Binary { op, left, right } => { assert_eq!(op, BinaryOp::Add); @@ -4382,7 +4445,7 @@ mod tests { #[test] fn test_subtraction() { - let expr = parse_expr("5 - 3").unwrap(); + let (expr, _types) = parse_expr("5 - 3").unwrap(); match expr.kind { ExprKind::Binary { op, .. 
} => assert_eq!(op, BinaryOp::Sub), _ => panic!("Expected Binary"), @@ -4391,7 +4454,7 @@ mod tests { #[test] fn test_multiplication() { - let expr = parse_expr("2 * 3").unwrap(); + let (expr, _types) = parse_expr("2 * 3").unwrap(); match expr.kind { ExprKind::Binary { op, .. } => assert_eq!(op, BinaryOp::Mul), _ => panic!("Expected Binary"), @@ -4400,7 +4463,7 @@ mod tests { #[test] fn test_division() { - let expr = parse_expr("10 / 2").unwrap(); + let (expr, _types) = parse_expr("10 / 2").unwrap(); match expr.kind { ExprKind::Binary { op, .. } => assert_eq!(op, BinaryOp::Div), _ => panic!("Expected Binary"), @@ -4409,7 +4472,7 @@ mod tests { #[test] fn test_modulo() { - let expr = parse_expr("10 % 3").unwrap(); + let (expr, _types) = parse_expr("10 % 3").unwrap(); match expr.kind { ExprKind::Binary { op, .. } => assert_eq!(op, BinaryOp::Mod), _ => panic!("Expected Binary"), @@ -4419,7 +4482,7 @@ mod tests { #[test] fn test_precedence_mul_add() { // 1 + 2 * 3 should be 1 + (2 * 3) - let expr = parse_expr("1 + 2 * 3").unwrap(); + let (expr, _types) = parse_expr("1 + 2 * 3").unwrap(); match expr.kind { ExprKind::Binary { op, left, right } => { assert_eq!(op, BinaryOp::Add); @@ -4436,7 +4499,7 @@ mod tests { #[test] fn test_left_associativity() { // 1 - 2 - 3 should be (1 - 2) - 3 - let expr = parse_expr("1 - 2 - 3").unwrap(); + let (expr, _types) = parse_expr("1 - 2 - 3").unwrap(); match expr.kind { ExprKind::Binary { op, left, right } => { assert_eq!(op, BinaryOp::Sub); @@ -4456,25 +4519,25 @@ mod tests { #[test] fn test_comparison_ops() { - let expr = parse_expr("a < b").unwrap(); + let (expr, _types) = parse_expr("a < b").unwrap(); match expr.kind { ExprKind::Binary { op, .. } => assert_eq!(op, BinaryOp::Lt), _ => panic!("Expected Binary"), } - let expr = parse_expr("a > b").unwrap(); + let (expr, _types) = parse_expr("a > b").unwrap(); match expr.kind { ExprKind::Binary { op, .. 
} => assert_eq!(op, BinaryOp::Gt), _ => panic!("Expected Binary"), } - let expr = parse_expr("a <= b").unwrap(); + let (expr, _types) = parse_expr("a <= b").unwrap(); match expr.kind { ExprKind::Binary { op, .. } => assert_eq!(op, BinaryOp::Le), _ => panic!("Expected Binary"), } - let expr = parse_expr("a >= b").unwrap(); + let (expr, _types) = parse_expr("a >= b").unwrap(); match expr.kind { ExprKind::Binary { op, .. } => assert_eq!(op, BinaryOp::Ge), _ => panic!("Expected Binary"), @@ -4483,13 +4546,13 @@ mod tests { #[test] fn test_equality_ops() { - let expr = parse_expr("a == b").unwrap(); + let (expr, _types) = parse_expr("a == b").unwrap(); match expr.kind { ExprKind::Binary { op, .. } => assert_eq!(op, BinaryOp::Eq), _ => panic!("Expected Binary"), } - let expr = parse_expr("a != b").unwrap(); + let (expr, _types) = parse_expr("a != b").unwrap(); match expr.kind { ExprKind::Binary { op, .. } => assert_eq!(op, BinaryOp::Ne), _ => panic!("Expected Binary"), @@ -4498,13 +4561,13 @@ mod tests { #[test] fn test_logical_ops() { - let expr = parse_expr("a && b").unwrap(); + let (expr, _types) = parse_expr("a && b").unwrap(); match expr.kind { ExprKind::Binary { op, .. } => assert_eq!(op, BinaryOp::LogAnd), _ => panic!("Expected Binary"), } - let expr = parse_expr("a || b").unwrap(); + let (expr, _types) = parse_expr("a || b").unwrap(); match expr.kind { ExprKind::Binary { op, .. } => assert_eq!(op, BinaryOp::LogOr), _ => panic!("Expected Binary"), @@ -4513,19 +4576,19 @@ mod tests { #[test] fn test_bitwise_ops() { - let expr = parse_expr("a & b").unwrap(); + let (expr, _types) = parse_expr("a & b").unwrap(); match expr.kind { ExprKind::Binary { op, .. } => assert_eq!(op, BinaryOp::BitAnd), _ => panic!("Expected Binary"), } - let expr = parse_expr("a | b").unwrap(); + let (expr, _types) = parse_expr("a | b").unwrap(); match expr.kind { ExprKind::Binary { op, .. 
} => assert_eq!(op, BinaryOp::BitOr), _ => panic!("Expected Binary"), } - let expr = parse_expr("a ^ b").unwrap(); + let (expr, _types) = parse_expr("a ^ b").unwrap(); match expr.kind { ExprKind::Binary { op, .. } => assert_eq!(op, BinaryOp::BitXor), _ => panic!("Expected Binary"), @@ -4534,13 +4597,13 @@ mod tests { #[test] fn test_shift_ops() { - let expr = parse_expr("a << b").unwrap(); + let (expr, _types) = parse_expr("a << b").unwrap(); match expr.kind { ExprKind::Binary { op, .. } => assert_eq!(op, BinaryOp::Shl), _ => panic!("Expected Binary"), } - let expr = parse_expr("a >> b").unwrap(); + let (expr, _types) = parse_expr("a >> b").unwrap(); match expr.kind { ExprKind::Binary { op, .. } => assert_eq!(op, BinaryOp::Shr), _ => panic!("Expected Binary"), @@ -4553,7 +4616,7 @@ mod tests { #[test] fn test_unary_neg() { - let expr = parse_expr("-x").unwrap(); + let (expr, _types) = parse_expr("-x").unwrap(); match expr.kind { ExprKind::Unary { op, operand } => { assert_eq!(op, UnaryOp::Neg); @@ -4568,7 +4631,7 @@ mod tests { #[test] fn test_unary_not() { - let expr = parse_expr("!x").unwrap(); + let (expr, _types) = parse_expr("!x").unwrap(); match expr.kind { ExprKind::Unary { op, .. } => assert_eq!(op, UnaryOp::Not), _ => panic!("Expected Unary"), @@ -4577,7 +4640,7 @@ mod tests { #[test] fn test_unary_bitnot() { - let expr = parse_expr("~x").unwrap(); + let (expr, _types) = parse_expr("~x").unwrap(); match expr.kind { ExprKind::Unary { op, .. } => assert_eq!(op, UnaryOp::BitNot), _ => panic!("Expected Unary"), @@ -4586,7 +4649,7 @@ mod tests { #[test] fn test_unary_addr() { - let expr = parse_expr("&x").unwrap(); + let (expr, _types) = parse_expr("&x").unwrap(); match expr.kind { ExprKind::Unary { op, .. 
} => assert_eq!(op, UnaryOp::AddrOf), _ => panic!("Expected Unary"), @@ -4595,7 +4658,7 @@ mod tests { #[test] fn test_unary_deref() { - let expr = parse_expr("*p").unwrap(); + let (expr, _types) = parse_expr("*p").unwrap(); match expr.kind { ExprKind::Unary { op, .. } => assert_eq!(op, UnaryOp::Deref), _ => panic!("Expected Unary"), @@ -4604,7 +4667,7 @@ mod tests { #[test] fn test_pre_increment() { - let expr = parse_expr("++x").unwrap(); + let (expr, _types) = parse_expr("++x").unwrap(); match expr.kind { ExprKind::Unary { op, .. } => assert_eq!(op, UnaryOp::PreInc), _ => panic!("Expected Unary"), @@ -4613,7 +4676,7 @@ mod tests { #[test] fn test_pre_decrement() { - let expr = parse_expr("--x").unwrap(); + let (expr, _types) = parse_expr("--x").unwrap(); match expr.kind { ExprKind::Unary { op, .. } => assert_eq!(op, UnaryOp::PreDec), _ => panic!("Expected Unary"), @@ -4626,19 +4689,19 @@ mod tests { #[test] fn test_post_increment() { - let expr = parse_expr("x++").unwrap(); + let (expr, _types) = parse_expr("x++").unwrap(); assert!(matches!(expr.kind, ExprKind::PostInc(_))); } #[test] fn test_post_decrement() { - let expr = parse_expr("x--").unwrap(); + let (expr, _types) = parse_expr("x--").unwrap(); assert!(matches!(expr.kind, ExprKind::PostDec(_))); } #[test] fn test_array_subscript() { - let expr = parse_expr("arr[5]").unwrap(); + let (expr, _types) = parse_expr("arr[5]").unwrap(); match expr.kind { ExprKind::Index { array, index } => { match array.kind { @@ -4653,7 +4716,7 @@ mod tests { #[test] fn test_member_access() { - let expr = parse_expr("obj.field").unwrap(); + let (expr, _types) = parse_expr("obj.field").unwrap(); match expr.kind { ExprKind::Member { expr, member } => { match expr.kind { @@ -4668,7 +4731,7 @@ mod tests { #[test] fn test_arrow_access() { - let expr = parse_expr("ptr->field").unwrap(); + let (expr, _types) = parse_expr("ptr->field").unwrap(); match expr.kind { ExprKind::Arrow { expr, member } => { match expr.kind { @@ -4683,7 +4746,7 
@@ mod tests { #[test] fn test_function_call_no_args() { - let expr = parse_expr("foo()").unwrap(); + let (expr, _types) = parse_expr("foo()").unwrap(); match expr.kind { ExprKind::Call { func, args } => { match func.kind { @@ -4698,7 +4761,7 @@ mod tests { #[test] fn test_function_call_with_args() { - let expr = parse_expr("foo(1, 2, 3)").unwrap(); + let (expr, _types) = parse_expr("foo(1, 2, 3)").unwrap(); match expr.kind { ExprKind::Call { func, args } => { match func.kind { @@ -4714,7 +4777,7 @@ mod tests { #[test] fn test_chained_postfix() { // obj.arr[0]->next - let expr = parse_expr("obj.arr[0]").unwrap(); + let (expr, _types) = parse_expr("obj.arr[0]").unwrap(); match expr.kind { ExprKind::Index { array, index } => { match array.kind { @@ -4739,7 +4802,7 @@ mod tests { #[test] fn test_simple_assignment() { - let expr = parse_expr("x = 5").unwrap(); + let (expr, _types) = parse_expr("x = 5").unwrap(); match expr.kind { ExprKind::Assign { op, target, value } => { assert_eq!(op, AssignOp::Assign); @@ -4755,19 +4818,19 @@ mod tests { #[test] fn test_compound_assignments() { - let expr = parse_expr("x += 5").unwrap(); + let (expr, _types) = parse_expr("x += 5").unwrap(); match expr.kind { ExprKind::Assign { op, .. } => assert_eq!(op, AssignOp::AddAssign), _ => panic!("Expected Assign"), } - let expr = parse_expr("x -= 5").unwrap(); + let (expr, _types) = parse_expr("x -= 5").unwrap(); match expr.kind { ExprKind::Assign { op, .. } => assert_eq!(op, AssignOp::SubAssign), _ => panic!("Expected Assign"), } - let expr = parse_expr("x *= 5").unwrap(); + let (expr, _types) = parse_expr("x *= 5").unwrap(); match expr.kind { ExprKind::Assign { op, .. 
} => assert_eq!(op, AssignOp::MulAssign), _ => panic!("Expected Assign"), @@ -4777,7 +4840,7 @@ mod tests { #[test] fn test_assignment_right_associativity() { // a = b = c should be a = (b = c) - let expr = parse_expr("a = b = c").unwrap(); + let (expr, _types) = parse_expr("a = b = c").unwrap(); match expr.kind { ExprKind::Assign { target, value, .. } => { match target.kind { @@ -4802,7 +4865,7 @@ mod tests { #[test] fn test_ternary() { - let expr = parse_expr("a ? b : c").unwrap(); + let (expr, _types) = parse_expr("a ? b : c").unwrap(); match expr.kind { ExprKind::Conditional { cond, @@ -4829,7 +4892,7 @@ mod tests { #[test] fn test_nested_ternary() { // a ? b : c ? d : e should be a ? b : (c ? d : e) - let expr = parse_expr("a ? b : c ? d : e").unwrap(); + let (expr, _types) = parse_expr("a ? b : c ? d : e").unwrap(); match expr.kind { ExprKind::Conditional { else_expr, .. } => { assert!(matches!(else_expr.kind, ExprKind::Conditional { .. })); @@ -4844,7 +4907,7 @@ mod tests { #[test] fn test_comma_expr() { - let expr = parse_expr("a, b, c").unwrap(); + let (expr, _types) = parse_expr("a, b, c").unwrap(); match expr.kind { ExprKind::Comma(exprs) => assert_eq!(exprs.len(), 3), _ => panic!("Expected Comma"), @@ -4857,15 +4920,15 @@ mod tests { #[test] fn test_sizeof_expr() { - let expr = parse_expr("sizeof x").unwrap(); + let (expr, _types) = parse_expr("sizeof x").unwrap(); assert!(matches!(expr.kind, ExprKind::SizeofExpr(_))); } #[test] fn test_sizeof_type() { - let expr = parse_expr("sizeof(int)").unwrap(); + let (expr, types) = parse_expr("sizeof(int)").unwrap(); match expr.kind { - ExprKind::SizeofType(typ) => assert_eq!(typ.kind, TypeKind::Int), + ExprKind::SizeofType(typ) => assert_eq!(types.kind(typ), TypeKind::Int), _ => panic!("Expected SizeofType"), } } @@ -4873,7 +4936,7 @@ mod tests { #[test] fn test_sizeof_paren_expr() { // sizeof(x) where x is not a type - let expr = parse_expr("sizeof(x)").unwrap(); + let (expr, _types) = 
parse_expr("sizeof(x)").unwrap(); assert!(matches!(expr.kind, ExprKind::SizeofExpr(_))); } @@ -4883,10 +4946,10 @@ mod tests { #[test] fn test_cast() { - let expr = parse_expr("(int)x").unwrap(); + let (expr, types) = parse_expr("(int)x").unwrap(); match expr.kind { ExprKind::Cast { cast_type, expr } => { - assert_eq!(cast_type.kind, TypeKind::Int); + assert_eq!(types.kind(cast_type), TypeKind::Int); match expr.kind { ExprKind::Ident { name, .. } => assert_eq!(name, "x"), _ => panic!("Expected Ident"), @@ -4898,11 +4961,14 @@ mod tests { #[test] fn test_cast_unsigned_char() { - let expr = parse_expr("(unsigned char)x").unwrap(); + let (expr, types) = parse_expr("(unsigned char)x").unwrap(); match expr.kind { ExprKind::Cast { cast_type, .. } => { - assert_eq!(cast_type.kind, TypeKind::Char); - assert!(cast_type.modifiers.contains(TypeModifiers::UNSIGNED)); + assert_eq!(types.kind(cast_type), TypeKind::Char); + assert!(types + .get(cast_type) + .modifiers + .contains(TypeModifiers::UNSIGNED)); } _ => panic!("Expected Cast"), } @@ -4910,11 +4976,14 @@ mod tests { #[test] fn test_cast_signed_int() { - let expr = parse_expr("(signed int)x").unwrap(); + let (expr, types) = parse_expr("(signed int)x").unwrap(); match expr.kind { ExprKind::Cast { cast_type, .. } => { - assert_eq!(cast_type.kind, TypeKind::Int); - assert!(cast_type.modifiers.contains(TypeModifiers::SIGNED)); + assert_eq!(types.kind(cast_type), TypeKind::Int); + assert!(types + .get(cast_type) + .modifiers + .contains(TypeModifiers::SIGNED)); } _ => panic!("Expected Cast"), } @@ -4922,11 +4991,14 @@ mod tests { #[test] fn test_cast_unsigned_long() { - let expr = parse_expr("(unsigned long)x").unwrap(); + let (expr, types) = parse_expr("(unsigned long)x").unwrap(); match expr.kind { ExprKind::Cast { cast_type, .. 
} => { - assert_eq!(cast_type.kind, TypeKind::Long); - assert!(cast_type.modifiers.contains(TypeModifiers::UNSIGNED)); + assert_eq!(types.kind(cast_type), TypeKind::Long); + assert!(types + .get(cast_type) + .modifiers + .contains(TypeModifiers::UNSIGNED)); } _ => panic!("Expected Cast"), } @@ -4934,10 +5006,10 @@ mod tests { #[test] fn test_cast_long_long() { - let expr = parse_expr("(long long)x").unwrap(); + let (expr, types) = parse_expr("(long long)x").unwrap(); match expr.kind { ExprKind::Cast { cast_type, .. } => { - assert_eq!(cast_type.kind, TypeKind::LongLong); + assert_eq!(types.kind(cast_type), TypeKind::LongLong); } _ => panic!("Expected Cast"), } @@ -4945,11 +5017,14 @@ mod tests { #[test] fn test_cast_unsigned_long_long() { - let expr = parse_expr("(unsigned long long)x").unwrap(); + let (expr, types) = parse_expr("(unsigned long long)x").unwrap(); match expr.kind { ExprKind::Cast { cast_type, .. } => { - assert_eq!(cast_type.kind, TypeKind::LongLong); - assert!(cast_type.modifiers.contains(TypeModifiers::UNSIGNED)); + assert_eq!(types.kind(cast_type), TypeKind::LongLong); + assert!(types + .get(cast_type) + .modifiers + .contains(TypeModifiers::UNSIGNED)); } _ => panic!("Expected Cast"), } @@ -4957,12 +5032,12 @@ mod tests { #[test] fn test_cast_pointer() { - let expr = parse_expr("(int*)x").unwrap(); + let (expr, types) = parse_expr("(int*)x").unwrap(); match expr.kind { ExprKind::Cast { cast_type, .. } => { - assert_eq!(cast_type.kind, TypeKind::Pointer); - let base = cast_type.get_base().unwrap(); - assert_eq!(base.kind, TypeKind::Int); + assert_eq!(types.kind(cast_type), TypeKind::Pointer); + let base = types.base_type(cast_type).unwrap(); + assert_eq!(types.kind(base), TypeKind::Int); } _ => panic!("Expected Cast"), } @@ -4970,12 +5045,12 @@ mod tests { #[test] fn test_cast_void_pointer() { - let expr = parse_expr("(void*)x").unwrap(); + let (expr, types) = parse_expr("(void*)x").unwrap(); match expr.kind { ExprKind::Cast { cast_type, .. 
} => { - assert_eq!(cast_type.kind, TypeKind::Pointer); - let base = cast_type.get_base().unwrap(); - assert_eq!(base.kind, TypeKind::Void); + assert_eq!(types.kind(cast_type), TypeKind::Pointer); + let base = types.base_type(cast_type).unwrap(); + assert_eq!(types.kind(base), TypeKind::Void); } _ => panic!("Expected Cast"), } @@ -4983,13 +5058,13 @@ mod tests { #[test] fn test_cast_unsigned_char_pointer() { - let expr = parse_expr("(unsigned char*)x").unwrap(); + let (expr, types) = parse_expr("(unsigned char*)x").unwrap(); match expr.kind { ExprKind::Cast { cast_type, .. } => { - assert_eq!(cast_type.kind, TypeKind::Pointer); - let base = cast_type.get_base().unwrap(); - assert_eq!(base.kind, TypeKind::Char); - assert!(base.modifiers.contains(TypeModifiers::UNSIGNED)); + assert_eq!(types.kind(cast_type), TypeKind::Pointer); + let base = types.base_type(cast_type).unwrap(); + assert_eq!(types.kind(base), TypeKind::Char); + assert!(types.get(base).modifiers.contains(TypeModifiers::UNSIGNED)); } _ => panic!("Expected Cast"), } @@ -4997,11 +5072,14 @@ mod tests { #[test] fn test_cast_const_int() { - let expr = parse_expr("(const int)x").unwrap(); + let (expr, types) = parse_expr("(const int)x").unwrap(); match expr.kind { ExprKind::Cast { cast_type, .. } => { - assert_eq!(cast_type.kind, TypeKind::Int); - assert!(cast_type.modifiers.contains(TypeModifiers::CONST)); + assert_eq!(types.kind(cast_type), TypeKind::Int); + assert!(types + .get(cast_type) + .modifiers + .contains(TypeModifiers::CONST)); } _ => panic!("Expected Cast"), } @@ -5009,14 +5087,14 @@ mod tests { #[test] fn test_cast_double_pointer() { - let expr = parse_expr("(int**)x").unwrap(); + let (expr, types) = parse_expr("(int**)x").unwrap(); match expr.kind { ExprKind::Cast { cast_type, .. 
} => { - assert_eq!(cast_type.kind, TypeKind::Pointer); - let base = cast_type.get_base().unwrap(); - assert_eq!(base.kind, TypeKind::Pointer); - let innermost = base.get_base().unwrap(); - assert_eq!(innermost.kind, TypeKind::Int); + assert_eq!(types.kind(cast_type), TypeKind::Pointer); + let base = types.base_type(cast_type).unwrap(); + assert_eq!(types.kind(base), TypeKind::Pointer); + let innermost = types.base_type(base).unwrap(); + assert_eq!(types.kind(innermost), TypeKind::Int); } _ => panic!("Expected Cast"), } @@ -5024,11 +5102,11 @@ mod tests { #[test] fn test_sizeof_compound_type() { - let expr = parse_expr("sizeof(unsigned long long)").unwrap(); + let (expr, types) = parse_expr("sizeof(unsigned long long)").unwrap(); match expr.kind { ExprKind::SizeofType(typ) => { - assert_eq!(typ.kind, TypeKind::LongLong); - assert!(typ.modifiers.contains(TypeModifiers::UNSIGNED)); + assert_eq!(types.kind(typ), TypeKind::LongLong); + assert!(types.get(typ).modifiers.contains(TypeModifiers::UNSIGNED)); } _ => panic!("Expected SizeofType"), } @@ -5036,10 +5114,10 @@ mod tests { #[test] fn test_sizeof_pointer_type() { - let expr = parse_expr("sizeof(int*)").unwrap(); + let (expr, types) = parse_expr("sizeof(int*)").unwrap(); match expr.kind { ExprKind::SizeofType(typ) => { - assert_eq!(typ.kind, TypeKind::Pointer); + assert_eq!(types.kind(typ), TypeKind::Pointer); } _ => panic!("Expected SizeofType"), } @@ -5051,7 +5129,7 @@ mod tests { #[test] fn test_parentheses() { - let expr = parse_expr("(1 + 2) * 3").unwrap(); + let (expr, _types) = parse_expr("(1 + 2) * 3").unwrap(); match expr.kind { ExprKind::Binary { op, left, .. } => { assert_eq!(op, BinaryOp::Mul); @@ -5071,14 +5149,14 @@ mod tests { #[test] fn test_complex_expr() { // x = a + b * c - d / e - let expr = parse_expr("x = a + b * c - d / e").unwrap(); + let (expr, _types) = parse_expr("x = a + b * c - d / e").unwrap(); assert!(matches!(expr.kind, ExprKind::Assign { .. 
})); } #[test] fn test_function_call_complex() { // foo(a + b, c * d) - let expr = parse_expr("foo(a + b, c * d)").unwrap(); + let (expr, _types) = parse_expr("foo(a + b, c * d)").unwrap(); match expr.kind { ExprKind::Call { args, .. } => { assert_eq!(args.len(), 2); @@ -5104,7 +5182,7 @@ mod tests { #[test] fn test_pointer_arithmetic() { // *p++ - let expr = parse_expr("*p++").unwrap(); + let (expr, _types) = parse_expr("*p++").unwrap(); match expr.kind { ExprKind::Unary { op: UnaryOp::Deref, @@ -5125,7 +5203,8 @@ mod tests { let tokens = tokenizer.tokenize(); let idents = tokenizer.ident_table(); let mut symbols = SymbolTable::new(); - let mut parser = Parser::new(&tokens, idents, &mut symbols); + let mut types = TypeTable::new(); + let mut parser = Parser::new(&tokens, idents, &mut symbols, &mut types); parser.skip_stream_tokens(); parser.parse_statement() } @@ -5332,34 +5411,36 @@ mod tests { // Declaration tests // ======================================================================== - fn parse_decl(input: &str) -> ParseResult { + fn parse_decl(input: &str) -> ParseResult<(Declaration, TypeTable)> { let mut tokenizer = Tokenizer::new(input.as_bytes(), 0); let tokens = tokenizer.tokenize(); let idents = tokenizer.ident_table(); let mut symbols = SymbolTable::new(); - let mut parser = Parser::new(&tokens, idents, &mut symbols); + let mut types = TypeTable::new(); + let mut parser = Parser::new(&tokens, idents, &mut symbols, &mut types); parser.skip_stream_tokens(); - parser.parse_declaration() + let decl = parser.parse_declaration()?; + Ok((decl, types)) } #[test] fn test_simple_decl() { - let decl = parse_decl("int x;").unwrap(); + let (decl, types) = parse_decl("int x;").unwrap(); assert_eq!(decl.declarators.len(), 1); assert_eq!(decl.declarators[0].name, "x"); - assert_eq!(decl.declarators[0].typ.kind, TypeKind::Int); + assert_eq!(types.kind(decl.declarators[0].typ), TypeKind::Int); } #[test] fn test_decl_with_init() { - let decl = parse_decl("int x = 
5;").unwrap(); + let (decl, _types) = parse_decl("int x = 5;").unwrap(); assert_eq!(decl.declarators.len(), 1); assert!(decl.declarators[0].init.is_some()); } #[test] fn test_multiple_declarators() { - let decl = parse_decl("int x, y, z;").unwrap(); + let (decl, _types) = parse_decl("int x, y, z;").unwrap(); assert_eq!(decl.declarators.len(), 3); assert_eq!(decl.declarators[0].name, "x"); assert_eq!(decl.declarators[1].name, "y"); @@ -5368,120 +5449,124 @@ mod tests { #[test] fn test_pointer_decl() { - let decl = parse_decl("int *p;").unwrap(); - assert_eq!(decl.declarators[0].typ.kind, TypeKind::Pointer); + let (decl, types) = parse_decl("int *p;").unwrap(); + assert_eq!(types.kind(decl.declarators[0].typ), TypeKind::Pointer); } #[test] fn test_array_decl() { - let decl = parse_decl("int arr[10];").unwrap(); - assert_eq!(decl.declarators[0].typ.kind, TypeKind::Array); - assert_eq!(decl.declarators[0].typ.array_size, Some(10)); + let (decl, types) = parse_decl("int arr[10];").unwrap(); + assert_eq!(types.kind(decl.declarators[0].typ), TypeKind::Array); + assert_eq!(types.get(decl.declarators[0].typ).array_size, Some(10)); } #[test] fn test_const_decl() { - let decl = parse_decl("const int x = 5;").unwrap(); - assert!(decl.declarators[0] - .typ + let (decl, types) = parse_decl("const int x = 5;").unwrap(); + assert!(types + .get(decl.declarators[0].typ) .modifiers .contains(TypeModifiers::CONST)); } #[test] fn test_unsigned_decl() { - let decl = parse_decl("unsigned int x;").unwrap(); - assert!(decl.declarators[0] - .typ + let (decl, types) = parse_decl("unsigned int x;").unwrap(); + assert!(types + .get(decl.declarators[0].typ) .modifiers .contains(TypeModifiers::UNSIGNED)); } #[test] fn test_long_long_decl() { - let decl = parse_decl("long long x;").unwrap(); - assert_eq!(decl.declarators[0].typ.kind, TypeKind::LongLong); + let (decl, types) = parse_decl("long long x;").unwrap(); + assert_eq!(types.kind(decl.declarators[0].typ), TypeKind::LongLong); } // 
======================================================================== // Function parsing tests // ======================================================================== - fn parse_func(input: &str) -> ParseResult { + fn parse_func(input: &str) -> ParseResult<(FunctionDef, TypeTable)> { let mut tokenizer = Tokenizer::new(input.as_bytes(), 0); let tokens = tokenizer.tokenize(); let idents = tokenizer.ident_table(); let mut symbols = SymbolTable::new(); - let mut parser = Parser::new(&tokens, idents, &mut symbols); + let mut types = TypeTable::new(); + let mut parser = Parser::new(&tokens, idents, &mut symbols, &mut types); parser.skip_stream_tokens(); - parser.parse_function_def() + let func = parser.parse_function_def()?; + Ok((func, types)) } #[test] fn test_simple_function() { - let func = parse_func("int main() { return 0; }").unwrap(); + let (func, types) = parse_func("int main() { return 0; }").unwrap(); assert_eq!(func.name, "main"); - assert_eq!(func.return_type.kind, TypeKind::Int); + assert_eq!(types.kind(func.return_type), TypeKind::Int); assert!(func.params.is_empty()); } #[test] fn test_function_with_params() { - let func = parse_func("int add(int a, int b) { return a + b; }").unwrap(); + let (func, _types) = parse_func("int add(int a, int b) { return a + b; }").unwrap(); assert_eq!(func.name, "add"); assert_eq!(func.params.len(), 2); } #[test] fn test_void_function() { - let func = parse_func("void foo(void) { }").unwrap(); - assert_eq!(func.return_type.kind, TypeKind::Void); + let (func, types) = parse_func("void foo(void) { }").unwrap(); + assert_eq!(types.kind(func.return_type), TypeKind::Void); assert!(func.params.is_empty()); } #[test] fn test_variadic_function() { // Variadic functions are parsed but variadic info is not tracked in FunctionDef - let func = parse_func("int printf(char *fmt, ...) { return 0; }").unwrap(); + let (func, _types) = parse_func("int printf(char *fmt, ...) 
{ return 0; }").unwrap(); assert_eq!(func.name, "printf"); } #[test] fn test_pointer_return() { - let func = parse_func("int *getptr() { return 0; }").unwrap(); - assert_eq!(func.return_type.kind, TypeKind::Pointer); + let (func, types) = parse_func("int *getptr() { return 0; }").unwrap(); + assert_eq!(types.kind(func.return_type), TypeKind::Pointer); } // ======================================================================== // Translation unit tests // ======================================================================== - fn parse_tu(input: &str) -> ParseResult { + fn parse_tu(input: &str) -> ParseResult<(TranslationUnit, TypeTable)> { let mut tokenizer = Tokenizer::new(input.as_bytes(), 0); let tokens = tokenizer.tokenize(); let idents = tokenizer.ident_table(); let mut symbols = SymbolTable::new(); - let mut parser = Parser::new(&tokens, idents, &mut symbols); - parser.parse_translation_unit() + let mut types = TypeTable::new(); + let mut parser = Parser::new(&tokens, idents, &mut symbols, &mut types); + let tu = parser.parse_translation_unit()?; + Ok((tu, types)) } #[test] fn test_simple_program() { - let tu = parse_tu("int main() { return 0; }").unwrap(); + let (tu, _types) = parse_tu("int main() { return 0; }").unwrap(); assert_eq!(tu.items.len(), 1); assert!(matches!(tu.items[0], ExternalDecl::FunctionDef(_))); } #[test] fn test_global_var() { - let tu = parse_tu("int x = 5;").unwrap(); + let (tu, _types) = parse_tu("int x = 5;").unwrap(); assert_eq!(tu.items.len(), 1); assert!(matches!(tu.items[0], ExternalDecl::Declaration(_))); } #[test] fn test_multiple_items() { - let tu = parse_tu("int x; int main() { return x; }").unwrap(); + let (tu, _types) = parse_tu("int x; int main() { return x; }").unwrap(); assert_eq!(tu.items.len(), 2); assert!(matches!(tu.items[0], ExternalDecl::Declaration(_))); assert!(matches!(tu.items[1], ExternalDecl::FunctionDef(_))); @@ -5489,12 +5574,12 @@ mod tests { #[test] fn test_function_declaration() { - let tu = 
parse_tu("int foo(int x);").unwrap(); + let (tu, types) = parse_tu("int foo(int x);").unwrap(); assert_eq!(tu.items.len(), 1); match &tu.items[0] { ExternalDecl::Declaration(decl) => { assert_eq!(decl.declarators[0].name, "foo"); - assert_eq!(decl.declarators[0].typ.kind, TypeKind::Function); + assert_eq!(types.kind(decl.declarators[0].typ), TypeKind::Function); } _ => panic!("Expected Declaration"), } @@ -5503,7 +5588,7 @@ mod tests { #[test] fn test_struct_only_declaration() { // Struct definition without a variable declarator - let tu = parse_tu("struct point { int x; int y; };").unwrap(); + let (tu, _types) = parse_tu("struct point { int x; int y; };").unwrap(); assert_eq!(tu.items.len(), 1); match &tu.items[0] { ExternalDecl::Declaration(decl) => { @@ -5517,13 +5602,13 @@ mod tests { #[test] fn test_struct_with_variable_declaration() { // Struct definition with a variable declarator - let tu = parse_tu("struct point { int x; int y; } p;").unwrap(); + let (tu, types) = parse_tu("struct point { int x; int y; } p;").unwrap(); assert_eq!(tu.items.len(), 1); match &tu.items[0] { ExternalDecl::Declaration(decl) => { assert_eq!(decl.declarators.len(), 1); assert_eq!(decl.declarators[0].name, "p"); - assert_eq!(decl.declarators[0].typ.kind, TypeKind::Struct); + assert_eq!(types.kind(decl.declarators[0].typ), TypeKind::Struct); } _ => panic!("Expected Declaration"), } @@ -5536,15 +5621,15 @@ mod tests { #[test] fn test_typedef_basic() { // Basic typedef declaration - let tu = parse_tu("typedef int myint;").unwrap(); + let (tu, types) = parse_tu("typedef int myint;").unwrap(); assert_eq!(tu.items.len(), 1); match &tu.items[0] { ExternalDecl::Declaration(decl) => { assert_eq!(decl.declarators.len(), 1); assert_eq!(decl.declarators[0].name, "myint"); // The type includes the TYPEDEF modifier - assert!(decl.declarators[0] - .typ + assert!(types + .get(decl.declarators[0].typ) .modifiers .contains(TypeModifiers::TYPEDEF)); } @@ -5555,7 +5640,7 @@ mod tests { #[test] fn 
test_typedef_usage() { // Typedef declaration followed by usage - let tu = parse_tu("typedef int myint; myint x;").unwrap(); + let (tu, types) = parse_tu("typedef int myint; myint x;").unwrap(); assert_eq!(tu.items.len(), 2); // First item: typedef declaration @@ -5571,7 +5656,7 @@ mod tests { ExternalDecl::Declaration(decl) => { assert_eq!(decl.declarators[0].name, "x"); // The variable should have int type (resolved from typedef) - assert_eq!(decl.declarators[0].typ.kind, TypeKind::Int); + assert_eq!(types.kind(decl.declarators[0].typ), TypeKind::Int); } _ => panic!("Expected variable Declaration"), } @@ -5580,14 +5665,14 @@ mod tests { #[test] fn test_typedef_pointer() { // Typedef for pointer type - let tu = parse_tu("typedef int *intptr; intptr p;").unwrap(); + let (tu, types) = parse_tu("typedef int *intptr; intptr p;").unwrap(); assert_eq!(tu.items.len(), 2); // Variable should have pointer type match &tu.items[1] { ExternalDecl::Declaration(decl) => { assert_eq!(decl.declarators[0].name, "p"); - assert_eq!(decl.declarators[0].typ.kind, TypeKind::Pointer); + assert_eq!(types.kind(decl.declarators[0].typ), TypeKind::Pointer); } _ => panic!("Expected variable Declaration"), } @@ -5596,14 +5681,14 @@ mod tests { #[test] fn test_typedef_struct() { // Typedef for anonymous struct - let tu = parse_tu("typedef struct { int x; int y; } Point; Point p;").unwrap(); + let (tu, types) = parse_tu("typedef struct { int x; int y; } Point; Point p;").unwrap(); assert_eq!(tu.items.len(), 2); // Variable should have struct type match &tu.items[1] { ExternalDecl::Declaration(decl) => { assert_eq!(decl.declarators[0].name, "p"); - assert_eq!(decl.declarators[0].typ.kind, TypeKind::Struct); + assert_eq!(types.kind(decl.declarators[0].typ), TypeKind::Struct); } _ => panic!("Expected variable Declaration"), } @@ -5612,14 +5697,14 @@ mod tests { #[test] fn test_typedef_chained() { // Chained typedef: typedef of typedef - let tu = parse_tu("typedef int myint; typedef myint myint2; 
myint2 x;").unwrap(); + let (tu, types) = parse_tu("typedef int myint; typedef myint myint2; myint2 x;").unwrap(); assert_eq!(tu.items.len(), 3); // Final variable should resolve to int match &tu.items[2] { ExternalDecl::Declaration(decl) => { assert_eq!(decl.declarators[0].name, "x"); - assert_eq!(decl.declarators[0].typ.kind, TypeKind::Int); + assert_eq!(types.kind(decl.declarators[0].typ), TypeKind::Int); } _ => panic!("Expected variable Declaration"), } @@ -5628,7 +5713,7 @@ mod tests { #[test] fn test_typedef_multiple() { // Multiple typedefs in one declaration - let tu = parse_tu("typedef int INT, *INTPTR;").unwrap(); + let (tu, types) = parse_tu("typedef int INT, *INTPTR;").unwrap(); assert_eq!(tu.items.len(), 1); match &tu.items[0] { @@ -5637,7 +5722,7 @@ mod tests { assert_eq!(decl.declarators[0].name, "INT"); assert_eq!(decl.declarators[1].name, "INTPTR"); // INTPTR should be a pointer type - assert_eq!(decl.declarators[1].typ.kind, TypeKind::Pointer); + assert_eq!(types.kind(decl.declarators[1].typ), TypeKind::Pointer); } _ => panic!("Expected Declaration"), } @@ -5646,7 +5731,7 @@ mod tests { #[test] fn test_typedef_in_function() { // Typedef used in function parameter and return type - let tu = + let (tu, types) = parse_tu("typedef int myint; myint add(myint a, myint b) { return a + b; }").unwrap(); assert_eq!(tu.items.len(), 2); @@ -5654,11 +5739,33 @@ mod tests { ExternalDecl::FunctionDef(func) => { assert_eq!(func.name, "add"); // Return type should resolve to int - assert_eq!(func.return_type.kind, TypeKind::Int); + assert_eq!(types.kind(func.return_type), TypeKind::Int); // Parameters should also resolve to int assert_eq!(func.params.len(), 2); - assert_eq!(func.params[0].typ.kind, TypeKind::Int); - assert_eq!(func.params[1].typ.kind, TypeKind::Int); + assert_eq!(types.kind(func.params[0].typ), TypeKind::Int); + assert_eq!(types.kind(func.params[1].typ), TypeKind::Int); + } + _ => panic!("Expected FunctionDef"), + } + } + + #[test] + fn 
test_typedef_local_variable() { + // Typedef used as local variable type inside function body + let (tu, _types) = + parse_tu("typedef int myint; int main(void) { myint x; x = 42; return 0; }").unwrap(); + assert_eq!(tu.items.len(), 2); + + match &tu.items[1] { + ExternalDecl::FunctionDef(func) => { + assert_eq!(func.name, "main"); + // Check that the body parsed correctly + match &func.body { + Stmt::Block(items) => { + assert!(items.len() >= 2, "Expected at least 2 block items"); + } + _ => panic!("Expected Block statement"), + } } _ => panic!("Expected FunctionDef"), } @@ -5671,7 +5778,7 @@ mod tests { #[test] fn test_restrict_pointer_decl() { // Local variable with restrict qualifier - let tu = parse_tu("int main(void) { int * restrict p; return 0; }").unwrap(); + let (tu, _types) = parse_tu("int main(void) { int * restrict p; return 0; }").unwrap(); assert_eq!(tu.items.len(), 1); // Just verify it parses without error } @@ -5679,8 +5786,9 @@ mod tests { #[test] fn test_restrict_function_param() { // Function with restrict-qualified pointer parameters - let tu = parse_tu("void copy(int * restrict dest, int * restrict src) { *dest = *src; }") - .unwrap(); + let (tu, types) = + parse_tu("void copy(int * restrict dest, int * restrict src) { *dest = *src; }") + .unwrap(); assert_eq!(tu.items.len(), 1); match &tu.items[0] { @@ -5688,14 +5796,14 @@ mod tests { assert_eq!(func.name, "copy"); assert_eq!(func.params.len(), 2); // Both params should be restrict-qualified pointers - assert_eq!(func.params[0].typ.kind, TypeKind::Pointer); - assert!(func.params[0] - .typ + assert_eq!(types.kind(func.params[0].typ), TypeKind::Pointer); + assert!(types + .get(func.params[0].typ) .modifiers .contains(TypeModifiers::RESTRICT)); - assert_eq!(func.params[1].typ.kind, TypeKind::Pointer); - assert!(func.params[1] - .typ + assert_eq!(types.kind(func.params[1].typ), TypeKind::Pointer); + assert!(types + .get(func.params[1].typ) .modifiers .contains(TypeModifiers::RESTRICT)); } @@ 
-5706,7 +5814,8 @@ mod tests { #[test] fn test_restrict_with_const() { // Pointer with both const and restrict qualifiers - let tu = parse_tu("int main(void) { int * const restrict p = 0; return 0; }").unwrap(); + let (tu, _types) = + parse_tu("int main(void) { int * const restrict p = 0; return 0; }").unwrap(); assert_eq!(tu.items.len(), 1); // Just verify it parses without error - both qualifiers should be accepted } @@ -5714,15 +5823,15 @@ mod tests { #[test] fn test_restrict_global_pointer() { // Global pointer with restrict qualifier - let tu = parse_tu("int * restrict global_ptr;").unwrap(); + let (tu, types) = parse_tu("int * restrict global_ptr;").unwrap(); assert_eq!(tu.items.len(), 1); match &tu.items[0] { ExternalDecl::Declaration(decl) => { assert_eq!(decl.declarators[0].name, "global_ptr"); - assert_eq!(decl.declarators[0].typ.kind, TypeKind::Pointer); - assert!(decl.declarators[0] - .typ + assert_eq!(types.kind(decl.declarators[0].typ), TypeKind::Pointer); + assert!(types + .get(decl.declarators[0].typ) .modifiers .contains(TypeModifiers::RESTRICT)); } @@ -5737,14 +5846,14 @@ mod tests { #[test] fn test_volatile_basic() { // Basic volatile variable - let tu = parse_tu("volatile int x;").unwrap(); + let (tu, types) = parse_tu("volatile int x;").unwrap(); assert_eq!(tu.items.len(), 1); match &tu.items[0] { ExternalDecl::Declaration(decl) => { assert_eq!(decl.declarators[0].name, "x"); - assert!(decl.declarators[0] - .typ + assert!(types + .get(decl.declarators[0].typ) .modifiers .contains(TypeModifiers::VOLATILE)); } @@ -5755,16 +5864,19 @@ mod tests { #[test] fn test_volatile_pointer() { // Pointer to volatile int - let tu = parse_tu("volatile int *p;").unwrap(); + let (tu, types) = parse_tu("volatile int *p;").unwrap(); assert_eq!(tu.items.len(), 1); match &tu.items[0] { ExternalDecl::Declaration(decl) => { assert_eq!(decl.declarators[0].name, "p"); - assert_eq!(decl.declarators[0].typ.kind, TypeKind::Pointer); + 
assert_eq!(types.kind(decl.declarators[0].typ), TypeKind::Pointer); // The base type should be volatile - let base = decl.declarators[0].typ.base.as_ref().unwrap(); - assert!(base.modifiers.contains(TypeModifiers::VOLATILE)); + let base_id = types.base_type(decl.declarators[0].typ).unwrap(); + assert!(types + .get(base_id) + .modifiers + .contains(TypeModifiers::VOLATILE)); } _ => panic!("Expected Declaration"), } @@ -5773,16 +5885,16 @@ mod tests { #[test] fn test_volatile_pointer_itself() { // Volatile pointer to int (pointer itself is volatile) - let tu = parse_tu("int * volatile p;").unwrap(); + let (tu, types) = parse_tu("int * volatile p;").unwrap(); assert_eq!(tu.items.len(), 1); match &tu.items[0] { ExternalDecl::Declaration(decl) => { assert_eq!(decl.declarators[0].name, "p"); - assert_eq!(decl.declarators[0].typ.kind, TypeKind::Pointer); + assert_eq!(types.kind(decl.declarators[0].typ), TypeKind::Pointer); // The pointer type itself should be volatile - assert!(decl.declarators[0] - .typ + assert!(types + .get(decl.declarators[0].typ) .modifiers .contains(TypeModifiers::VOLATILE)); } @@ -5793,18 +5905,18 @@ mod tests { #[test] fn test_volatile_const_combined() { // Both const and volatile - let tu = parse_tu("const volatile int x;").unwrap(); + let (tu, types) = parse_tu("const volatile int x;").unwrap(); assert_eq!(tu.items.len(), 1); match &tu.items[0] { ExternalDecl::Declaration(decl) => { assert_eq!(decl.declarators[0].name, "x"); - assert!(decl.declarators[0] - .typ + assert!(types + .get(decl.declarators[0].typ) .modifiers .contains(TypeModifiers::VOLATILE)); - assert!(decl.declarators[0] - .typ + assert!(types + .get(decl.declarators[0].typ) .modifiers .contains(TypeModifiers::CONST)); } @@ -5815,7 +5927,7 @@ mod tests { #[test] fn test_volatile_function_param() { // Function with volatile pointer parameter - let tu = parse_tu("void foo(volatile int *p) { *p = 1; }").unwrap(); + let (tu, types) = parse_tu("void foo(volatile int *p) { *p = 1; 
}").unwrap(); assert_eq!(tu.items.len(), 1); match &tu.items[0] { @@ -5823,9 +5935,12 @@ mod tests { assert_eq!(func.name, "foo"); assert_eq!(func.params.len(), 1); // Parameter is pointer to volatile int - assert_eq!(func.params[0].typ.kind, TypeKind::Pointer); - let base = func.params[0].typ.base.as_ref().unwrap(); - assert!(base.modifiers.contains(TypeModifiers::VOLATILE)); + assert_eq!(types.kind(func.params[0].typ), TypeKind::Pointer); + let base_id = types.base_type(func.params[0].typ).unwrap(); + assert!(types + .get(base_id) + .modifiers + .contains(TypeModifiers::VOLATILE)); } _ => panic!("Expected FunctionDef"), } @@ -5838,7 +5953,7 @@ mod tests { #[test] fn test_attribute_on_function_declaration() { // Attribute on function declaration - let tu = parse_tu("void foo(void) __attribute__((noreturn));").unwrap(); + let (tu, _types) = parse_tu("void foo(void) __attribute__((noreturn));").unwrap(); assert_eq!(tu.items.len(), 1); match &tu.items[0] { @@ -5853,14 +5968,14 @@ mod tests { #[test] fn test_attribute_on_struct() { // Attribute between struct keyword and name (with variable) - let tu = parse_tu("struct __attribute__((packed)) foo { int x; } s;").unwrap(); + let (tu, types) = parse_tu("struct __attribute__((packed)) foo { int x; } s;").unwrap(); assert_eq!(tu.items.len(), 1); match &tu.items[0] { ExternalDecl::Declaration(decl) => { assert_eq!(decl.declarators.len(), 1); assert_eq!(decl.declarators[0].name, "s"); - assert_eq!(decl.declarators[0].typ.kind, TypeKind::Struct); + assert_eq!(types.kind(decl.declarators[0].typ), TypeKind::Struct); } _ => panic!("Expected Declaration"), } @@ -5869,14 +5984,15 @@ mod tests { #[test] fn test_attribute_after_struct() { // Attribute after struct closing brace (with variable) - let tu = parse_tu("struct foo { int x; } __attribute__((aligned(16))) s;").unwrap(); + let (tu, types) = + parse_tu("struct foo { int x; } __attribute__((aligned(16))) s;").unwrap(); assert_eq!(tu.items.len(), 1); match &tu.items[0] { 
ExternalDecl::Declaration(decl) => { assert_eq!(decl.declarators.len(), 1); assert_eq!(decl.declarators[0].name, "s"); - assert_eq!(decl.declarators[0].typ.kind, TypeKind::Struct); + assert_eq!(types.kind(decl.declarators[0].typ), TypeKind::Struct); } _ => panic!("Expected Declaration"), } @@ -5885,7 +6001,7 @@ mod tests { #[test] fn test_attribute_on_struct_only() { // Attribute on struct-only definition (no variable) - let tu = parse_tu("struct __attribute__((packed)) foo { int x; };").unwrap(); + let (tu, _types) = parse_tu("struct __attribute__((packed)) foo { int x; };").unwrap(); assert_eq!(tu.items.len(), 1); match &tu.items[0] { @@ -5900,7 +6016,7 @@ mod tests { #[test] fn test_attribute_on_variable() { // Attribute on variable declaration - let tu = parse_tu("int x __attribute__((aligned(8)));").unwrap(); + let (tu, _types) = parse_tu("int x __attribute__((aligned(8)));").unwrap(); assert_eq!(tu.items.len(), 1); match &tu.items[0] { @@ -5915,7 +6031,7 @@ mod tests { #[test] fn test_attribute_multiple() { // Multiple attributes in one list - let tu = parse_tu("void foo(void) __attribute__((noreturn, cold));").unwrap(); + let (tu, _types) = parse_tu("void foo(void) __attribute__((noreturn, cold));").unwrap(); assert_eq!(tu.items.len(), 1); match &tu.items[0] { @@ -5929,7 +6045,7 @@ mod tests { #[test] fn test_attribute_with_args() { // Attribute with multiple arguments - let tu = parse_tu( + let (tu, _types) = parse_tu( "void foo(const char *fmt, ...) 
__attribute__((__format__(__printf__, 1, 2)));", ) .unwrap(); @@ -5946,7 +6062,8 @@ mod tests { #[test] fn test_attribute_before_declaration() { // Attribute before declaration - let tu = parse_tu("__attribute__((visibility(\"default\"))) int exported_var;").unwrap(); + let (tu, _types) = + parse_tu("__attribute__((visibility(\"default\"))) int exported_var;").unwrap(); assert_eq!(tu.items.len(), 1); match &tu.items[0] { @@ -5960,7 +6077,7 @@ mod tests { #[test] fn test_attribute_underscore_variant() { // __attribute variant (single underscore pair) - let tu = parse_tu("void foo(void) __attribute((noreturn));").unwrap(); + let (tu, _types) = parse_tu("void foo(void) __attribute((noreturn));").unwrap(); assert_eq!(tu.items.len(), 1); match &tu.items[0] { @@ -5982,7 +6099,8 @@ mod tests { #[test] fn test_const_assignment_parses() { // Assignment to const variable should still parse (errors are reported but parsing continues) - let tu = parse_tu("int main(void) { const int x = 42; x = 10; return 0; }").unwrap(); + let (tu, _types) = + parse_tu("int main(void) { const int x = 42; x = 10; return 0; }").unwrap(); assert_eq!(tu.items.len(), 1); // Verify we got a function definition assert!(matches!(tu.items[0], ExternalDecl::FunctionDef(_))); @@ -5991,7 +6109,7 @@ mod tests { #[test] fn test_const_pointer_deref_parses() { // Assignment through pointer to const should still parse - let tu = + let (tu, _types) = parse_tu("int main(void) { int v = 1; const int *p = &v; *p = 2; return 0; }").unwrap(); assert_eq!(tu.items.len(), 1); assert!(matches!(tu.items[0], ExternalDecl::FunctionDef(_))); @@ -6000,7 +6118,8 @@ mod tests { #[test] fn test_const_usage_valid() { // Valid const usage - reading const values - let tu = parse_tu("int main(void) { const int x = 42; int y = x + 1; return y; }").unwrap(); + let (tu, _types) = + parse_tu("int main(void) { const int x = 42; int y = x + 1; return y; }").unwrap(); assert_eq!(tu.items.len(), 1); assert!(matches!(tu.items[0], 
ExternalDecl::FunctionDef(_))); } @@ -6008,7 +6127,7 @@ mod tests { #[test] fn test_const_pointer_types() { // Different const pointer combinations - let tu = parse_tu( + let (tu, _types) = parse_tu( "int main(void) { int v = 1; const int *a = &v; int * const b = &v; const int * const c = &v; return 0; }", ) .unwrap(); @@ -6022,20 +6141,20 @@ mod tests { #[test] fn test_function_decl_no_params() { - let tu = parse_tu("int foo(void);").unwrap(); + let (tu, types) = parse_tu("int foo(void);").unwrap(); assert_eq!(tu.items.len(), 1); match &tu.items[0] { ExternalDecl::Declaration(decl) => { assert_eq!(decl.declarators.len(), 1); assert_eq!(decl.declarators[0].name, "foo"); - assert_eq!(decl.declarators[0].typ.kind, TypeKind::Function); - assert!(!decl.declarators[0].typ.variadic); + assert_eq!(types.kind(decl.declarators[0].typ), TypeKind::Function); + assert!(!types.is_variadic(decl.declarators[0].typ)); // Check return type - if let Some(ref base) = decl.declarators[0].typ.base { - assert_eq!(base.kind, TypeKind::Int); + if let Some(base_id) = types.base_type(decl.declarators[0].typ) { + assert_eq!(types.kind(base_id), TypeKind::Int); } // Check params (void means empty) - if let Some(ref params) = decl.declarators[0].typ.params { + if let Some(params) = types.params(decl.declarators[0].typ) { assert!(params.is_empty()); } } @@ -6045,16 +6164,16 @@ mod tests { #[test] fn test_function_decl_one_param() { - let tu = parse_tu("int square(int x);").unwrap(); + let (tu, types) = parse_tu("int square(int x);").unwrap(); assert_eq!(tu.items.len(), 1); match &tu.items[0] { ExternalDecl::Declaration(decl) => { assert_eq!(decl.declarators[0].name, "square"); - assert_eq!(decl.declarators[0].typ.kind, TypeKind::Function); - assert!(!decl.declarators[0].typ.variadic); - if let Some(ref params) = decl.declarators[0].typ.params { + assert_eq!(types.kind(decl.declarators[0].typ), TypeKind::Function); + assert!(!types.is_variadic(decl.declarators[0].typ)); + if let Some(params) = 
types.params(decl.declarators[0].typ) { assert_eq!(params.len(), 1); - assert_eq!(params[0].kind, TypeKind::Int); + assert_eq!(types.kind(params[0]), TypeKind::Int); } } _ => panic!("Expected Declaration"), @@ -6063,17 +6182,17 @@ mod tests { #[test] fn test_function_decl_multiple_params() { - let tu = parse_tu("int add(int a, int b, int c);").unwrap(); + let (tu, types) = parse_tu("int add(int a, int b, int c);").unwrap(); assert_eq!(tu.items.len(), 1); match &tu.items[0] { ExternalDecl::Declaration(decl) => { assert_eq!(decl.declarators[0].name, "add"); - assert_eq!(decl.declarators[0].typ.kind, TypeKind::Function); - assert!(!decl.declarators[0].typ.variadic); - if let Some(ref params) = decl.declarators[0].typ.params { + assert_eq!(types.kind(decl.declarators[0].typ), TypeKind::Function); + assert!(!types.is_variadic(decl.declarators[0].typ)); + if let Some(params) = types.params(decl.declarators[0].typ) { assert_eq!(params.len(), 3); for p in params { - assert_eq!(p.kind, TypeKind::Int); + assert_eq!(types.kind(*p), TypeKind::Int); } } } @@ -6083,14 +6202,14 @@ mod tests { #[test] fn test_function_decl_void_return() { - let tu = parse_tu("void do_something(int x);").unwrap(); + let (tu, types) = parse_tu("void do_something(int x);").unwrap(); assert_eq!(tu.items.len(), 1); match &tu.items[0] { ExternalDecl::Declaration(decl) => { assert_eq!(decl.declarators[0].name, "do_something"); - assert_eq!(decl.declarators[0].typ.kind, TypeKind::Function); - if let Some(ref base) = decl.declarators[0].typ.base { - assert_eq!(base.kind, TypeKind::Void); + assert_eq!(types.kind(decl.declarators[0].typ), TypeKind::Function); + if let Some(base_id) = types.base_type(decl.declarators[0].typ) { + assert_eq!(types.kind(base_id), TypeKind::Void); } } _ => panic!("Expected Declaration"), @@ -6099,14 +6218,14 @@ mod tests { #[test] fn test_function_decl_pointer_return() { - let tu = parse_tu("char *get_string(void);").unwrap(); + let (tu, types) = parse_tu("char 
*get_string(void);").unwrap(); assert_eq!(tu.items.len(), 1); match &tu.items[0] { ExternalDecl::Declaration(decl) => { assert_eq!(decl.declarators[0].name, "get_string"); - assert_eq!(decl.declarators[0].typ.kind, TypeKind::Function); - if let Some(ref base) = decl.declarators[0].typ.base { - assert_eq!(base.kind, TypeKind::Pointer); + assert_eq!(types.kind(decl.declarators[0].typ), TypeKind::Function); + if let Some(base_id) = types.base_type(decl.declarators[0].typ) { + assert_eq!(types.kind(base_id), TypeKind::Pointer); } } _ => panic!("Expected Declaration"), @@ -6115,15 +6234,15 @@ mod tests { #[test] fn test_function_decl_pointer_param() { - let tu = parse_tu("void process(int *data, int count);").unwrap(); + let (tu, types) = parse_tu("void process(int *data, int count);").unwrap(); assert_eq!(tu.items.len(), 1); match &tu.items[0] { ExternalDecl::Declaration(decl) => { assert_eq!(decl.declarators[0].name, "process"); - if let Some(ref params) = decl.declarators[0].typ.params { + if let Some(params) = types.params(decl.declarators[0].typ) { assert_eq!(params.len(), 2); - assert_eq!(params[0].kind, TypeKind::Pointer); - assert_eq!(params[1].kind, TypeKind::Int); + assert_eq!(types.kind(params[0]), TypeKind::Pointer); + assert_eq!(types.kind(params[1]), TypeKind::Int); } } _ => panic!("Expected Declaration"), @@ -6137,21 +6256,21 @@ mod tests { #[test] fn test_function_decl_variadic_printf() { // Classic printf prototype - let tu = parse_tu("int printf(const char *fmt, ...);").unwrap(); + let (tu, types) = parse_tu("int printf(const char *fmt, ...);").unwrap(); assert_eq!(tu.items.len(), 1); match &tu.items[0] { ExternalDecl::Declaration(decl) => { assert_eq!(decl.declarators[0].name, "printf"); - assert_eq!(decl.declarators[0].typ.kind, TypeKind::Function); + assert_eq!(types.kind(decl.declarators[0].typ), TypeKind::Function); // Should be marked as variadic assert!( - decl.declarators[0].typ.variadic, + types.is_variadic(decl.declarators[0].typ), "printf 
should be marked as variadic" ); - if let Some(ref params) = decl.declarators[0].typ.params { + if let Some(params) = types.params(decl.declarators[0].typ) { // Only the fixed parameter (fmt) should be in params assert_eq!(params.len(), 1); - assert_eq!(params[0].kind, TypeKind::Pointer); + assert_eq!(types.kind(params[0]), TypeKind::Pointer); } } _ => panic!("Expected Declaration"), @@ -6160,16 +6279,16 @@ mod tests { #[test] fn test_function_decl_variadic_sprintf() { - let tu = parse_tu("int sprintf(char *buf, const char *fmt, ...);").unwrap(); + let (tu, types) = parse_tu("int sprintf(char *buf, const char *fmt, ...);").unwrap(); assert_eq!(tu.items.len(), 1); match &tu.items[0] { ExternalDecl::Declaration(decl) => { assert_eq!(decl.declarators[0].name, "sprintf"); assert!( - decl.declarators[0].typ.variadic, + types.is_variadic(decl.declarators[0].typ), "sprintf should be marked as variadic" ); - if let Some(ref params) = decl.declarators[0].typ.params { + if let Some(params) = types.params(decl.declarators[0].typ) { // Two fixed parameters: buf and fmt assert_eq!(params.len(), 2); } @@ -6181,18 +6300,18 @@ mod tests { #[test] fn test_function_decl_variadic_custom() { // Custom variadic function with int first param - let tu = parse_tu("int sum_ints(int count, ...);").unwrap(); + let (tu, types) = parse_tu("int sum_ints(int count, ...);").unwrap(); assert_eq!(tu.items.len(), 1); match &tu.items[0] { ExternalDecl::Declaration(decl) => { assert_eq!(decl.declarators[0].name, "sum_ints"); assert!( - decl.declarators[0].typ.variadic, + types.is_variadic(decl.declarators[0].typ), "sum_ints should be marked as variadic" ); - if let Some(ref params) = decl.declarators[0].typ.params { + if let Some(params) = types.params(decl.declarators[0].typ) { assert_eq!(params.len(), 1); - assert_eq!(params[0].kind, TypeKind::Int); + assert_eq!(types.kind(params[0]), TypeKind::Int); } } _ => panic!("Expected Declaration"), @@ -6202,20 +6321,20 @@ mod tests { #[test] fn 
test_function_decl_variadic_multiple_fixed() { // Variadic function with multiple fixed parameters - let tu = parse_tu("int variadic_func(int a, double b, char *c, ...);").unwrap(); + let (tu, types) = parse_tu("int variadic_func(int a, double b, char *c, ...);").unwrap(); assert_eq!(tu.items.len(), 1); match &tu.items[0] { ExternalDecl::Declaration(decl) => { assert_eq!(decl.declarators[0].name, "variadic_func"); assert!( - decl.declarators[0].typ.variadic, + types.is_variadic(decl.declarators[0].typ), "variadic_func should be marked as variadic" ); - if let Some(ref params) = decl.declarators[0].typ.params { + if let Some(params) = types.params(decl.declarators[0].typ) { assert_eq!(params.len(), 3); - assert_eq!(params[0].kind, TypeKind::Int); - assert_eq!(params[1].kind, TypeKind::Double); - assert_eq!(params[2].kind, TypeKind::Pointer); + assert_eq!(types.kind(params[0]), TypeKind::Int); + assert_eq!(types.kind(params[1]), TypeKind::Double); + assert_eq!(types.kind(params[2]), TypeKind::Pointer); } } _ => panic!("Expected Declaration"), @@ -6225,17 +6344,17 @@ mod tests { #[test] fn test_function_decl_variadic_void_return() { // Variadic function with void return type - let tu = parse_tu("void log_message(const char *fmt, ...);").unwrap(); + let (tu, types) = parse_tu("void log_message(const char *fmt, ...);").unwrap(); assert_eq!(tu.items.len(), 1); match &tu.items[0] { ExternalDecl::Declaration(decl) => { assert_eq!(decl.declarators[0].name, "log_message"); assert!( - decl.declarators[0].typ.variadic, + types.is_variadic(decl.declarators[0].typ), "log_message should be marked as variadic" ); - if let Some(ref base) = decl.declarators[0].typ.base { - assert_eq!(base.kind, TypeKind::Void); + if let Some(base_id) = types.base_type(decl.declarators[0].typ) { + assert_eq!(types.kind(base_id), TypeKind::Void); } } _ => panic!("Expected Declaration"), @@ -6245,13 +6364,13 @@ mod tests { #[test] fn test_function_decl_not_variadic() { // Make sure non-variadic 
functions are NOT marked as variadic - let tu = parse_tu("int regular_func(int a, int b);").unwrap(); + let (tu, types) = parse_tu("int regular_func(int a, int b);").unwrap(); assert_eq!(tu.items.len(), 1); match &tu.items[0] { ExternalDecl::Declaration(decl) => { assert_eq!(decl.declarators[0].name, "regular_func"); assert!( - !decl.declarators[0].typ.variadic, + !types.is_variadic(decl.declarators[0].typ), "regular_func should NOT be marked as variadic" ); } @@ -6262,7 +6381,7 @@ mod tests { #[test] fn test_variadic_function_definition() { // Variadic function definition (not just declaration) - let func = parse_func("int my_printf(char *fmt, ...) { return 0; }").unwrap(); + let (func, _types) = parse_func("int my_printf(char *fmt, ...) { return 0; }").unwrap(); assert_eq!(func.name, "my_printf"); // Note: FunctionDef doesn't directly expose variadic, but the function // body can use va_start etc. This test just ensures parsing succeeds. @@ -6272,7 +6391,7 @@ mod tests { #[test] fn test_multiple_function_decls_mixed() { // Mix of variadic and non-variadic declarations - let tu = parse_tu( + let (tu, types) = parse_tu( "int printf(const char *fmt, ...); int puts(const char *s); int sprintf(char *buf, const char *fmt, ...);", ) .unwrap(); @@ -6282,7 +6401,7 @@ mod tests { match &tu.items[0] { ExternalDecl::Declaration(decl) => { assert_eq!(decl.declarators[0].name, "printf"); - assert!(decl.declarators[0].typ.variadic); + assert!(types.is_variadic(decl.declarators[0].typ)); } _ => panic!("Expected Declaration"), } @@ -6291,7 +6410,7 @@ mod tests { match &tu.items[1] { ExternalDecl::Declaration(decl) => { assert_eq!(decl.declarators[0].name, "puts"); - assert!(!decl.declarators[0].typ.variadic); + assert!(!types.is_variadic(decl.declarators[0].typ)); } _ => panic!("Expected Declaration"), } @@ -6300,7 +6419,7 @@ mod tests { match &tu.items[2] { ExternalDecl::Declaration(decl) => { assert_eq!(decl.declarators[0].name, "sprintf"); - 
assert!(decl.declarators[0].typ.variadic); + assert!(types.is_variadic(decl.declarators[0].typ)); } _ => panic!("Expected Declaration"), } @@ -6309,14 +6428,14 @@ mod tests { #[test] fn test_function_decl_with_struct_param() { // Function declaration with struct parameter - let tu = + let (tu, types) = parse_tu("struct point { int x; int y; }; void move_point(struct point p);").unwrap(); assert_eq!(tu.items.len(), 2); match &tu.items[1] { ExternalDecl::Declaration(decl) => { assert_eq!(decl.declarators[0].name, "move_point"); - assert_eq!(decl.declarators[0].typ.kind, TypeKind::Function); - assert!(!decl.declarators[0].typ.variadic); + assert_eq!(types.kind(decl.declarators[0].typ), TypeKind::Function); + assert!(!types.is_variadic(decl.declarators[0].typ)); } _ => panic!("Expected Declaration"), } @@ -6325,19 +6444,18 @@ mod tests { #[test] fn test_function_decl_array_decay() { // Array parameters decay to pointers in function declarations - let tu = parse_tu("void process_array(int arr[]);").unwrap(); + let (tu, types) = parse_tu("void process_array(int arr[]);").unwrap(); assert_eq!(tu.items.len(), 1); match &tu.items[0] { ExternalDecl::Declaration(decl) => { assert_eq!(decl.declarators[0].name, "process_array"); - assert_eq!(decl.declarators[0].typ.kind, TypeKind::Function); + assert_eq!(types.kind(decl.declarators[0].typ), TypeKind::Function); // The array parameter should decay to pointer - if let Some(ref params) = decl.declarators[0].typ.params { + if let Some(params) = types.params(decl.declarators[0].typ) { assert_eq!(params.len(), 1); // Array params in function declarations become pointers - assert!( - params[0].kind == TypeKind::Pointer || params[0].kind == TypeKind::Array - ); + let p_kind = types.kind(params[0]); + assert!(p_kind == TypeKind::Pointer || p_kind == TypeKind::Array); } } _ => panic!("Expected Declaration"), @@ -6351,25 +6469,25 @@ mod tests { #[test] fn test_function_pointer_declaration() { // Basic function pointer: void (*fp)(int) - 
let tu = parse_tu("void (*fp)(int);").unwrap(); + let (tu, types) = parse_tu("void (*fp)(int);").unwrap(); assert_eq!(tu.items.len(), 1); match &tu.items[0] { ExternalDecl::Declaration(decl) => { assert_eq!(decl.declarators.len(), 1); assert_eq!(decl.declarators[0].name, "fp"); // fp should be a pointer to a function - assert_eq!(decl.declarators[0].typ.kind, TypeKind::Pointer); + assert_eq!(types.kind(decl.declarators[0].typ), TypeKind::Pointer); // The base type of the pointer should be a function - if let Some(ref base) = decl.declarators[0].typ.base { - assert_eq!(base.kind, TypeKind::Function); + if let Some(base_id) = types.base_type(decl.declarators[0].typ) { + assert_eq!(types.kind(base_id), TypeKind::Function); // Function returns void - if let Some(ref ret) = base.base { - assert_eq!(ret.kind, TypeKind::Void); + if let Some(ret_id) = types.base_type(base_id) { + assert_eq!(types.kind(ret_id), TypeKind::Void); } // Function takes one int parameter - if let Some(ref params) = base.params { + if let Some(params) = types.params(base_id) { assert_eq!(params.len(), 1); - assert_eq!(params[0].kind, TypeKind::Int); + assert_eq!(types.kind(params[0]), TypeKind::Int); } } else { panic!("Expected function pointer base type"); @@ -6382,16 +6500,16 @@ mod tests { #[test] fn test_function_pointer_no_params() { // Function pointer with no parameters: int (*fp)(void) - let tu = parse_tu("int (*fp)(void);").unwrap(); + let (tu, types) = parse_tu("int (*fp)(void);").unwrap(); assert_eq!(tu.items.len(), 1); match &tu.items[0] { ExternalDecl::Declaration(decl) => { assert_eq!(decl.declarators[0].name, "fp"); - assert_eq!(decl.declarators[0].typ.kind, TypeKind::Pointer); - if let Some(ref base) = decl.declarators[0].typ.base { - assert_eq!(base.kind, TypeKind::Function); + assert_eq!(types.kind(decl.declarators[0].typ), TypeKind::Pointer); + if let Some(base_id) = types.base_type(decl.declarators[0].typ) { + assert_eq!(types.kind(base_id), TypeKind::Function); // (void) means 
no parameters - if let Some(ref params) = base.params { + if let Some(params) = types.params(base_id) { assert!(params.is_empty()); } } @@ -6403,19 +6521,19 @@ mod tests { #[test] fn test_function_pointer_multiple_params() { // Function pointer with multiple parameters: int (*fp)(int, char, double) - let tu = parse_tu("int (*fp)(int, char, double);").unwrap(); + let (tu, types) = parse_tu("int (*fp)(int, char, double);").unwrap(); assert_eq!(tu.items.len(), 1); match &tu.items[0] { ExternalDecl::Declaration(decl) => { assert_eq!(decl.declarators[0].name, "fp"); - assert_eq!(decl.declarators[0].typ.kind, TypeKind::Pointer); - if let Some(ref base) = decl.declarators[0].typ.base { - assert_eq!(base.kind, TypeKind::Function); - if let Some(ref params) = base.params { + assert_eq!(types.kind(decl.declarators[0].typ), TypeKind::Pointer); + if let Some(base_id) = types.base_type(decl.declarators[0].typ) { + assert_eq!(types.kind(base_id), TypeKind::Function); + if let Some(params) = types.params(base_id) { assert_eq!(params.len(), 3); - assert_eq!(params[0].kind, TypeKind::Int); - assert_eq!(params[1].kind, TypeKind::Char); - assert_eq!(params[2].kind, TypeKind::Double); + assert_eq!(types.kind(params[0]), TypeKind::Int); + assert_eq!(types.kind(params[1]), TypeKind::Char); + assert_eq!(types.kind(params[2]), TypeKind::Double); } } } @@ -6426,19 +6544,19 @@ mod tests { #[test] fn test_function_pointer_returning_pointer() { // Function pointer returning a pointer: char *(*fp)(int) - let tu = parse_tu("char *(*fp)(int);").unwrap(); + let (tu, types) = parse_tu("char *(*fp)(int);").unwrap(); assert_eq!(tu.items.len(), 1); match &tu.items[0] { ExternalDecl::Declaration(decl) => { assert_eq!(decl.declarators[0].name, "fp"); - assert_eq!(decl.declarators[0].typ.kind, TypeKind::Pointer); - if let Some(ref base) = decl.declarators[0].typ.base { - assert_eq!(base.kind, TypeKind::Function); + assert_eq!(types.kind(decl.declarators[0].typ), TypeKind::Pointer); + if let 
Some(base_id) = types.base_type(decl.declarators[0].typ) { + assert_eq!(types.kind(base_id), TypeKind::Function); // Return type is char* - if let Some(ref ret) = base.base { - assert_eq!(ret.kind, TypeKind::Pointer); - if let Some(ref char_type) = ret.base { - assert_eq!(char_type.kind, TypeKind::Char); + if let Some(ret_id) = types.base_type(base_id) { + assert_eq!(types.kind(ret_id), TypeKind::Pointer); + if let Some(char_id) = types.base_type(ret_id) { + assert_eq!(types.kind(char_id), TypeKind::Char); } } } @@ -6450,18 +6568,18 @@ mod tests { #[test] fn test_function_pointer_variadic() { // Variadic function pointer: int (*fp)(const char *, ...) - let tu = parse_tu("int (*fp)(const char *, ...);").unwrap(); + let (tu, types) = parse_tu("int (*fp)(const char *, ...);").unwrap(); assert_eq!(tu.items.len(), 1); match &tu.items[0] { ExternalDecl::Declaration(decl) => { assert_eq!(decl.declarators[0].name, "fp"); - assert_eq!(decl.declarators[0].typ.kind, TypeKind::Pointer); - if let Some(ref base) = decl.declarators[0].typ.base { - assert_eq!(base.kind, TypeKind::Function); - assert!(base.variadic, "Function should be variadic"); - if let Some(ref params) = base.params { + assert_eq!(types.kind(decl.declarators[0].typ), TypeKind::Pointer); + if let Some(base_id) = types.base_type(decl.declarators[0].typ) { + assert_eq!(types.kind(base_id), TypeKind::Function); + assert!(types.is_variadic(base_id), "Function should be variadic"); + if let Some(params) = types.params(base_id) { assert_eq!(params.len(), 1); - assert_eq!(params[0].kind, TypeKind::Pointer); + assert_eq!(types.kind(params[0]), TypeKind::Pointer); } } } @@ -6476,14 +6594,15 @@ mod tests { #[test] fn test_bitfield_basic() { // Basic bitfield parsing - include a variable declarator - let tu = parse_tu("struct flags { unsigned int a : 4; unsigned int b : 4; } f;").unwrap(); + let (tu, types) = + parse_tu("struct flags { unsigned int a : 4; unsigned int b : 4; } f;").unwrap(); assert_eq!(tu.items.len(), 1); 
match &tu.items[0] { ExternalDecl::Declaration(decl) => { assert_eq!(decl.declarators.len(), 1); assert_eq!(decl.declarators[0].name, "f"); - assert_eq!(decl.declarators[0].typ.kind, TypeKind::Struct); - if let Some(ref composite) = decl.declarators[0].typ.composite { + assert_eq!(types.kind(decl.declarators[0].typ), TypeKind::Struct); + if let Some(composite) = types.composite(decl.declarators[0].typ) { assert_eq!(composite.members.len(), 2); // First bitfield assert_eq!(composite.members[0].name, "a"); @@ -6502,7 +6621,7 @@ mod tests { #[test] fn test_bitfield_unnamed() { // Unnamed bitfield for padding - let tu = parse_tu( + let (tu, types) = parse_tu( "struct padded { unsigned int a : 4; unsigned int : 4; unsigned int b : 8; } p;", ) .unwrap(); @@ -6510,7 +6629,7 @@ mod tests { match &tu.items[0] { ExternalDecl::Declaration(decl) => { assert_eq!(decl.declarators.len(), 1); - if let Some(ref composite) = decl.declarators[0].typ.composite { + if let Some(composite) = types.composite(decl.declarators[0].typ) { assert_eq!(composite.members.len(), 3); // First named bitfield assert_eq!(composite.members[0].name, "a"); @@ -6530,7 +6649,7 @@ mod tests { #[test] fn test_bitfield_zero_width() { // Zero-width bitfield forces alignment - let tu = parse_tu( + let (tu, types) = parse_tu( "struct aligned { unsigned int a : 4; unsigned int : 0; unsigned int b : 4; } x;", ) .unwrap(); @@ -6538,7 +6657,7 @@ mod tests { match &tu.items[0] { ExternalDecl::Declaration(decl) => { assert_eq!(decl.declarators.len(), 1); - if let Some(ref composite) = decl.declarators[0].typ.composite { + if let Some(composite) = types.composite(decl.declarators[0].typ) { assert_eq!(composite.members.len(), 3); // After zero-width bitfield, b should start at new storage unit assert_eq!(composite.members[2].name, "b"); @@ -6556,12 +6675,13 @@ mod tests { #[test] fn test_bitfield_mixed_with_regular() { // Bitfield mixed with regular member - let tu = parse_tu("struct mixed { int x; unsigned int bits : 
8; int y; } m;").unwrap(); + let (tu, types) = + parse_tu("struct mixed { int x; unsigned int bits : 8; int y; } m;").unwrap(); assert_eq!(tu.items.len(), 1); match &tu.items[0] { ExternalDecl::Declaration(decl) => { assert_eq!(decl.declarators.len(), 1); - if let Some(ref composite) = decl.declarators[0].typ.composite { + if let Some(composite) = types.composite(decl.declarators[0].typ) { assert_eq!(composite.members.len(), 3); // x is regular member assert_eq!(composite.members[0].name, "x"); @@ -6581,14 +6701,14 @@ mod tests { #[test] fn test_bitfield_struct_size() { // Verify struct size calculation with bitfields - let tu = parse_tu( + let (tu, types) = parse_tu( "struct small { unsigned int a : 1; unsigned int b : 1; unsigned int c : 1; } s;", ) .unwrap(); match &tu.items[0] { ExternalDecl::Declaration(decl) => { assert_eq!(decl.declarators.len(), 1); - if let Some(ref composite) = decl.declarators[0].typ.composite { + if let Some(composite) = types.composite(decl.declarators[0].typ) { // Three 1-bit fields should fit in one 4-byte int assert_eq!(composite.size, 4); } diff --git a/cc/ssa.rs b/cc/ssa.rs index 6fb8a8cc..870813cd 100644 --- a/cc/ssa.rs +++ b/cc/ssa.rs @@ -18,19 +18,9 @@ use crate::dominate::{compute_dominance_frontiers, domtree_build, idf_compute}; use crate::ir::{ BasicBlockId, Function, InsnRef, Instruction, Opcode, Pseudo, PseudoId, PseudoKind, }; -use crate::types::Type; +use crate::types::{TypeId, TypeTable}; use std::collections::{HashMap, HashSet}; -// ============================================================================ -// Promotability Check -// ============================================================================ - -/// Check if a type is promotable to SSA form. -/// Only scalar types (integers, pointers, floats) can be promoted. 
-fn is_promotable_type(typ: &Type) -> bool { - typ.is_scalar() -} - // ============================================================================ // SSA Conversion State // ============================================================================ @@ -112,7 +102,9 @@ impl<'a> SsaConverter<'a> { #[derive(Default)] struct VarInfo { /// Type of the variable - typ: Type, + typ: TypeId, + /// Size of the type in bits + size: u32, /// Blocks that store to this variable def_blocks: Vec, /// Total number of stores @@ -128,19 +120,20 @@ struct VarInfo { } /// Analyze a variable to determine if it can be promoted to SSA. -fn analyze_variable(func: &Function, var_name: &str) -> Option { +fn analyze_variable(func: &Function, types: &TypeTable, var_name: &str) -> Option { let local = func.get_local(var_name)?; let sym_id = local.sym; - let typ = local.typ.clone(); + let typ = local.typ; let decl_block = local.decl_block; - // Check basic promotability - if local.is_volatile || !is_promotable_type(&typ) { + // Check basic promotability - only scalar types can be promoted + if local.is_volatile || !types.is_scalar(typ) { return None; } let mut info = VarInfo { typ, + size: types.size_bits(typ), decl_block, ..Default::default() }; @@ -232,7 +225,7 @@ fn insert_phi_nodes(converter: &mut SsaConverter, var_name: &str, var_info: &Var let target = converter.alloc_phi(); // Create phi instruction - let phi = Instruction::phi(target, var_info.typ.clone()); + let phi = Instruction::phi(target, var_info.typ, var_info.size); // Store the variable name for later phi renaming // We'll use the phi_list field to store this temporarily @@ -610,7 +603,7 @@ fn remove_dead_stores(func: &mut Function, dead_stores: &[InsnRef]) { /// - Record stores as new definitions /// - Fill in phi operands from predecessors /// 4. 
Remove dead stores -pub fn ssa_convert(func: &mut Function) { +pub fn ssa_convert(func: &mut Function, types: &TypeTable) { if func.blocks.is_empty() { return; } @@ -625,7 +618,7 @@ pub fn ssa_convert(func: &mut Function) { let local_names: Vec = converter.func.locals.keys().cloned().collect(); for var_name in &local_names { - if let Some(var_info) = analyze_variable(converter.func, var_name) { + if let Some(var_info) = analyze_variable(converter.func, types, var_name) { // Skip if all usage is in a single block (no phi needed) if var_info.single_block.is_some() { // Could do local rewriting here but skip for now @@ -666,9 +659,8 @@ pub fn ssa_convert(func: &mut Function) { mod tests { use super::*; use crate::ir::BasicBlock; - use crate::types::TypeKind; - fn make_simple_if_cfg() -> Function { + fn make_simple_if_cfg(types: &TypeTable) -> Function { // Create a CFG with a simple if-then-else: // // int x = 1; @@ -683,7 +675,8 @@ mod tests { // v v // merge(3) - let mut func = Function::new("test", Type::basic(TypeKind::Int)); + let int_id = types.int_id; + let mut func = Function::new("test", int_id); // Create symbol pseudo for local variable 'x' let x_sym = PseudoId(0); @@ -692,7 +685,7 @@ mod tests { func.add_local( "x", x_sym, - Type::basic(TypeKind::Int), + int_id, false, // not volatile Some(BasicBlockId(0)), ); @@ -710,12 +703,7 @@ mod tests { entry.children = vec![BasicBlockId(1), BasicBlockId(2)]; entry.add_insn(Instruction::new(Opcode::Entry)); // Store x = 1 - entry.add_insn(Instruction::store( - val1, - x_sym, - 0, - Type::basic(TypeKind::Int), - )); + entry.add_insn(Instruction::store(val1, x_sym, 0, int_id, 32)); // Conditional branch entry.add_insn(Instruction::cbr(cond, BasicBlockId(1), BasicBlockId(2))); @@ -723,12 +711,7 @@ mod tests { let mut then_bb = BasicBlock::new(BasicBlockId(1)); then_bb.parents = vec![BasicBlockId(0)]; then_bb.children = vec![BasicBlockId(3)]; - then_bb.add_insn(Instruction::store( - val2, - x_sym, - 0, - 
Type::basic(TypeKind::Int), - )); + then_bb.add_insn(Instruction::store(val2, x_sym, 0, int_id, 32)); then_bb.add_insn(Instruction::br(BasicBlockId(3))); // Else block: (no assignment) @@ -743,12 +726,7 @@ mod tests { let result = PseudoId(4); func.add_pseudo(Pseudo::reg(result, 0)); // Load x - merge.add_insn(Instruction::load( - result, - x_sym, - 0, - Type::basic(TypeKind::Int), - )); + merge.add_insn(Instruction::load(result, x_sym, 0, int_id, 32)); merge.add_insn(Instruction::ret(Some(result))); func.entry = BasicBlockId(0); @@ -758,9 +736,10 @@ mod tests { #[test] fn test_analyze_variable() { - let func = make_simple_if_cfg(); + let types = TypeTable::new(); + let func = make_simple_if_cfg(&types); - let info = analyze_variable(&func, "x").unwrap(); + let info = analyze_variable(&func, &types, "x").unwrap(); assert_eq!(info.store_count, 2); // One in entry, one in then assert_eq!(info.def_blocks.len(), 2); assert!(!info.addr_taken); @@ -768,8 +747,9 @@ mod tests { #[test] fn test_ssa_convert_creates_phi() { - let mut func = make_simple_if_cfg(); - ssa_convert(&mut func); + let types = TypeTable::new(); + let mut func = make_simple_if_cfg(&types); + ssa_convert(&mut func, &types); // After SSA conversion, merge block should have a phi node let merge = func.get_block(BasicBlockId(3)).unwrap(); @@ -780,7 +760,8 @@ mod tests { #[test] fn test_domtree_built() { - let mut func = make_simple_if_cfg(); + let types = TypeTable::new(); + let mut func = make_simple_if_cfg(&types); domtree_build(&mut func); // Entry should be the root (no idom) diff --git a/cc/symbol.rs b/cc/symbol.rs index 14e3fa64..c8ad969d 100644 --- a/cc/symbol.rs +++ b/cc/symbol.rs @@ -10,7 +10,7 @@ // Based on sparse's scope-aware symbol management // -use crate::types::Type; +use crate::types::TypeId; use std::collections::HashMap; // ============================================================================ @@ -82,8 +82,8 @@ pub struct Symbol { /// Which namespace this symbol belongs to pub 
namespace: Namespace, - /// The type of this symbol - pub typ: Type, + /// The type of this symbol (interned TypeId) + pub typ: TypeId, /// Scope depth where this symbol was declared pub scope_depth: u32, @@ -97,7 +97,7 @@ pub struct Symbol { impl Symbol { /// Create a new variable symbol - pub fn variable(name: String, typ: Type, scope_depth: u32) -> Self { + pub fn variable(name: String, typ: TypeId, scope_depth: u32) -> Self { Self { name, kind: SymbolKind::Variable, @@ -110,7 +110,7 @@ impl Symbol { } /// Create a new function symbol - pub fn function(name: String, typ: Type, scope_depth: u32) -> Self { + pub fn function(name: String, typ: TypeId, scope_depth: u32) -> Self { Self { name, kind: SymbolKind::Function, @@ -123,7 +123,7 @@ impl Symbol { } /// Create a new parameter symbol - pub fn parameter(name: String, typ: Type, scope_depth: u32) -> Self { + pub fn parameter(name: String, typ: TypeId, scope_depth: u32) -> Self { Self { name, kind: SymbolKind::Parameter, @@ -135,13 +135,13 @@ impl Symbol { } } - /// Create a new enum constant symbol - pub fn enum_constant(name: String, value: i64, scope_depth: u32) -> Self { + /// Create a new enum constant symbol (requires int_id from TypeTable) + pub fn enum_constant(name: String, value: i64, int_id: TypeId, scope_depth: u32) -> Self { Self { name, kind: SymbolKind::EnumConstant, namespace: Namespace::Ordinary, - typ: Type::basic(crate::types::TypeKind::Int), + typ: int_id, scope_depth, defined: true, enum_value: Some(value), @@ -149,7 +149,7 @@ impl Symbol { } /// Create a new tag symbol (struct/union/enum tag) - pub fn tag(name: String, typ: Type, scope_depth: u32) -> Self { + pub fn tag(name: String, typ: TypeId, scope_depth: u32) -> Self { Self { name, kind: SymbolKind::Tag, @@ -162,7 +162,7 @@ impl Symbol { } /// Create a new typedef symbol - pub fn typedef(name: String, typ: Type, scope_depth: u32) -> Self { + pub fn typedef(name: String, typ: TypeId, scope_depth: u32) -> Self { Self { name, kind: 
SymbolKind::Typedef, @@ -335,11 +335,11 @@ impl SymbolTable { } /// Look up a typedef by name - /// Returns the aliased type if found - pub fn lookup_typedef(&self, name: &str) -> Option<&Type> { + /// Returns the aliased TypeId if found + pub fn lookup_typedef(&self, name: &str) -> Option { self.lookup(name, Namespace::Ordinary).and_then(|s| { if s.is_typedef() { - Some(&s.typ) + Some(s.typ) // TypeId is Copy } else { None } @@ -392,14 +392,15 @@ impl std::error::Error for SymbolError {} #[cfg(test)] mod tests { use super::*; - use crate::types::{Type, TypeKind}; + use crate::types::{Type, TypeKind, TypeTable}; #[test] fn test_declare_and_lookup() { + let types = TypeTable::new(); let mut table = SymbolTable::new(); // Declare a variable - let sym = Symbol::variable("x".to_string(), Type::basic(TypeKind::Int), 0); + let sym = Symbol::variable("x".to_string(), types.int_id, 0); let _id = table.declare(sym).unwrap(); // Look it up @@ -410,17 +411,18 @@ mod tests { #[test] fn test_scopes() { + let types = TypeTable::new(); let mut table = SymbolTable::new(); // Declare x in global scope - let sym1 = Symbol::variable("x".to_string(), Type::basic(TypeKind::Int), 0); + let sym1 = Symbol::variable("x".to_string(), types.int_id, 0); table.declare(sym1).unwrap(); // Enter a new scope table.enter_scope(); // Declare y in inner scope - let sym2 = Symbol::variable("y".to_string(), Type::basic(TypeKind::Char), 0); + let sym2 = Symbol::variable("y".to_string(), types.char_id, 0); table.declare(sym2).unwrap(); // Both should be visible @@ -437,41 +439,43 @@ mod tests { #[test] fn test_shadowing() { + let types = TypeTable::new(); let mut table = SymbolTable::new(); // Declare x as int in global scope - let sym1 = Symbol::variable("x".to_string(), Type::basic(TypeKind::Int), 0); + let sym1 = Symbol::variable("x".to_string(), types.int_id, 0); table.declare(sym1).unwrap(); // Enter a new scope table.enter_scope(); // Shadow x with char - let sym2 = 
Symbol::variable("x".to_string(), Type::basic(TypeKind::Char), 0); + let sym2 = Symbol::variable("x".to_string(), types.char_id, 0); table.declare(sym2).unwrap(); // Should find the inner x (char) let found = table.lookup("x", Namespace::Ordinary).unwrap(); - assert_eq!(found.typ.kind, TypeKind::Char); + assert_eq!(types.kind(found.typ), TypeKind::Char); // Leave scope table.leave_scope(); // Should find the outer x (int) let found = table.lookup("x", Namespace::Ordinary).unwrap(); - assert_eq!(found.typ.kind, TypeKind::Int); + assert_eq!(types.kind(found.typ), TypeKind::Int); } #[test] fn test_redefinition_error() { + let types = TypeTable::new(); let mut table = SymbolTable::new(); // Declare x - let sym1 = Symbol::variable("x".to_string(), Type::basic(TypeKind::Int), 0); + let sym1 = Symbol::variable("x".to_string(), types.int_id, 0); table.declare(sym1).unwrap(); // Try to redeclare x in the same scope - let sym2 = Symbol::variable("x".to_string(), Type::basic(TypeKind::Char), 0); + let sym2 = Symbol::variable("x".to_string(), types.char_id, 0); let result = table.declare(sym2); assert!(matches!(result, Err(SymbolError::Redefinition(_)))); @@ -498,14 +502,11 @@ mod tests { #[test] fn test_function_symbol() { + let mut types = TypeTable::new(); let mut table = SymbolTable::new(); // Declare a function - let func_type = Type::function( - Type::basic(TypeKind::Int), - vec![Type::basic(TypeKind::Int)], - false, - ); + let func_type = types.intern(Type::function(types.int_id, vec![types.int_id], false)); let func = Symbol::function("foo".to_string(), func_type, 0); table.declare(func).unwrap(); diff --git a/cc/target.rs b/cc/target.rs index 63730c79..ae29cc48 100644 --- a/cc/target.rs +++ b/cc/target.rs @@ -63,6 +63,10 @@ pub struct Target { pub long_width: u32, /// char is signed by default pub char_signed: bool, + /// Maximum size (in bits) for aggregate types (struct/union) that can be + /// passed or returned by value in registers. 
Aggregates larger than this + /// require indirect passing (pointer) or sret (struct return pointer). + pub max_aggregate_register_bits: u32, } impl Target { @@ -91,12 +95,20 @@ impl Target { (Arch::X86_64, _) => true, }; + // Maximum aggregate size that can be returned in registers. + // Both x86-64 SysV ABI and AAPCS64 technically support returning + // 16-byte structs in registers (rax+rdx or x0+x1), but pcc currently + // only supports single-register returns. Use sret for >8 byte structs + // until multi-register returns are implemented. + let max_aggregate_register_bits = 64; + Self { arch, os, pointer_width, long_width, char_signed, + max_aggregate_register_bits, } } diff --git a/cc/tests/datatypes/struct_type.rs b/cc/tests/datatypes/struct_type.rs index d197f7f6..139170ec 100644 --- a/cc/tests/datatypes/struct_type.rs +++ b/cc/tests/datatypes/struct_type.rs @@ -269,3 +269,44 @@ int main(void) { "#; assert_eq!(compile_and_run("struct_pointer", code), 0); } + +// ============================================================================ +// Large Struct Return: Return struct >8 bytes requiring sret (hidden pointer) +// This tests the sret ABI where large structs are returned via a hidden pointer. +// On ARM64, the sret pointer goes in X8 (not X0 like other args). +// On x86-64, the sret pointer goes in RDI (first arg register). 
+// ============================================================================ + +#[test] +fn struct_return_large() { + let code = r#" +struct large { + long first; + long second; +}; + +struct large make_large(long a, long b) { + struct large s; + s.first = a; + s.second = b; + return s; +} + +int main(void) { + struct large result; + result = make_large(300000, 200000); + + // Verify the struct was correctly returned + if (result.first != 300000) return 1; + if (result.second != 200000) return 2; + + // Test with different values to ensure correct member mapping + result = make_large(42, 84); + if (result.first != 42) return 3; + if (result.second != 84) return 4; + + return 0; +} +"#; + assert_eq!(compile_and_run("struct_return_large", code), 0); +} diff --git a/cc/types.rs b/cc/types.rs index b248d20b..4a9aca88 100644 --- a/cc/types.rs +++ b/cc/types.rs @@ -10,8 +10,28 @@ // Based on sparse's compositional type model // +use std::collections::HashMap; use std::fmt; +// ============================================================================ +// Type ID - Unique identifier for interned types +// ============================================================================ + +/// A unique identifier for an interned type (like IdentTable for strings) +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)] +pub struct TypeId(pub u32); + +impl TypeId { + /// Invalid/uninitialized type ID + pub const INVALID: TypeId = TypeId(u32::MAX); + + /// Check if this is a valid type ID + #[cfg(test)] + pub fn is_valid(&self) -> bool { + self.0 != u32::MAX + } +} + // ============================================================================ // Composite Type Components // ============================================================================ @@ -21,8 +41,8 @@ use std::fmt; pub struct StructMember { /// Member name (empty string for unnamed bitfields) pub name: String, - /// Member type - pub typ: Type, + /// Member type (interned TypeId) + pub typ: TypeId, 
/// Byte offset within struct (0 for unions, offset of storage unit for bitfields) pub offset: usize, /// For bitfields: bit offset within storage unit (0 = LSB) @@ -34,12 +54,12 @@ pub struct StructMember { } /// Information about a struct/union member lookup -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Copy)] pub struct MemberInfo { /// Byte offset within struct pub offset: usize, - /// Member type - pub typ: Type, + /// Member type (interned TypeId) + pub typ: TypeId, /// For bitfields: bit offset within storage unit pub bit_offset: Option, /// For bitfields: bit width @@ -87,114 +107,8 @@ impl CompositeType { } } - /// Compute struct layout with natural alignment, including bitfield packing - /// Returns (total_size, alignment) - pub fn compute_struct_layout(members: &mut [StructMember]) -> (usize, usize) { - let mut offset = 0usize; - let mut max_align = 1usize; - let mut current_bit_offset = 0u32; - let mut current_storage_unit_size = 0u32; // 0 means no active bitfield storage - - for member in members.iter_mut() { - if let Some(bit_width) = member.bit_width { - // This is a bitfield - let storage_size = member.typ.size_bytes() as u32; - let storage_bits = storage_size * 8; - - if bit_width == 0 { - // Zero-width bitfield: force alignment to next storage unit - if current_storage_unit_size > 0 { - offset += current_storage_unit_size as usize; - current_bit_offset = 0; - current_storage_unit_size = 0; - } - member.offset = offset; - member.bit_offset = None; - member.storage_unit_size = None; - continue; - } - - // Check if we need a new storage unit: - // - No active storage unit - // - Storage unit type changed - // - Bitfield doesn't fit in remaining space - let need_new_unit = current_storage_unit_size == 0 - || current_storage_unit_size != storage_size - || current_bit_offset + bit_width > storage_bits; - - if need_new_unit { - // Close current storage unit if active - if current_storage_unit_size > 0 { - offset += current_storage_unit_size as usize; 
- } - // Align to storage unit size - let align = storage_size as usize; - offset = (offset + align - 1) & !(align - 1); - max_align = max_align.max(align); - current_bit_offset = 0; - current_storage_unit_size = storage_size; - } - - member.offset = offset; - member.bit_offset = Some(current_bit_offset); - member.storage_unit_size = Some(storage_size); - current_bit_offset += bit_width; - } else { - // Regular member - close any active bitfield storage unit - if current_storage_unit_size > 0 { - offset += current_storage_unit_size as usize; - current_bit_offset = 0; - current_storage_unit_size = 0; - } - - let align = member.typ.alignment(); - max_align = max_align.max(align); - - // Align offset to member's alignment - offset = (offset + align - 1) & !(align - 1); - member.offset = offset; - member.bit_offset = None; - member.storage_unit_size = None; - - // Advance by member size - offset += member.typ.size_bytes(); - } - } - - // Close final bitfield storage unit if active - if current_storage_unit_size > 0 { - offset += current_storage_unit_size as usize; - } - - // Pad struct size to alignment - let size = if max_align > 1 { - (offset + max_align - 1) & !(max_align - 1) - } else { - offset - }; - (size, max_align) - } - - /// Compute union layout: all members at offset 0, size = max member size - /// Returns (total_size, alignment) - pub fn compute_union_layout(members: &mut [StructMember]) -> (usize, usize) { - let mut max_size = 0usize; - let mut max_align = 1usize; - - for member in members.iter_mut() { - member.offset = 0; // All union members at offset 0 - max_size = max_size.max(member.typ.size_bytes()); - max_align = max_align.max(member.typ.alignment()); - } - - // Pad size to alignment - let size = if max_align > 1 { - (max_size + max_align - 1) & !(max_align - 1) - } else { - max_size - }; - (size, max_align) - } + // NOTE: compute_struct_layout and compute_union_layout have been moved to TypeTable + // since they require access to member type sizes 
via TypeId lookup. } // ============================================================================ @@ -236,7 +150,7 @@ bitflags::bitflags! { // ============================================================================ /// Basic type kinds (mirrors sparse's SYM_* for types) -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum TypeKind { // Basic types Void, @@ -298,10 +212,12 @@ impl fmt::Display for TypeKind { /// A C type (compositional structure like sparse) /// -/// Types are built compositionally: -/// - `int *p` -> Pointer { base: Int } -/// - `int arr[10]` -> Array { base: Int, size: 10 } -/// - `int (*fp)(int)` -> Pointer { base: Function { return: Int, params: [Int] } } +/// Types are built compositionally using TypeId references: +/// - `int *p` -> Pointer { base: TypeId(int) } +/// - `int arr[10]` -> Array { base: TypeId(int), size: 10 } +/// - `int (*fp)(int)` -> Pointer { base: Function { return: TypeId(int), params: [TypeId(int)] } } +/// +/// All nested types are referenced by TypeId, which are looked up in a TypeTable. #[derive(Debug, Clone, PartialEq)] pub struct Type { /// The kind of type @@ -310,14 +226,14 @@ pub struct Type { /// Type modifiers (const, volatile, signed, unsigned, etc.) pub modifiers: TypeModifiers, - /// Base type for pointers, arrays, and function return types - pub base: Option>, + /// Base type for pointers, arrays, and function return types (interned TypeId) + pub base: Option, /// Array size (for arrays) pub array_size: Option, - /// Function parameter types (for functions) - pub params: Option>, + /// Function parameter types (interned TypeIds) + pub params: Option>, /// Is this function variadic? 
(for functions) pub variadic: bool, @@ -358,12 +274,12 @@ impl Type { } } - /// Create a pointer type - pub fn pointer(base: Type) -> Self { + /// Create a pointer type (base type is a TypeId) + pub fn pointer(base: TypeId) -> Self { Self { kind: TypeKind::Pointer, modifiers: TypeModifiers::empty(), - base: Some(Box::new(base)), + base: Some(base), array_size: None, params: None, variadic: false, @@ -371,12 +287,12 @@ impl Type { } } - /// Create an array type - pub fn array(base: Type, size: usize) -> Self { + /// Create an array type (element type is a TypeId) + pub fn array(base: TypeId, size: usize) -> Self { Self { kind: TypeKind::Array, modifiers: TypeModifiers::empty(), - base: Some(Box::new(base)), + base: Some(base), array_size: Some(size), params: None, variadic: false, @@ -384,12 +300,12 @@ impl Type { } } - /// Create a function type - pub fn function(return_type: Type, params: Vec, variadic: bool) -> Self { + /// Create a function type (return type and param types are TypeIds) + pub fn function(return_type: TypeId, params: Vec, variadic: bool) -> Self { Self { kind: TypeKind::Function, modifiers: TypeModifiers::empty(), - base: Some(Box::new(return_type)), + base: Some(return_type), array_size: None, params: Some(params), variadic, @@ -451,133 +367,6 @@ impl Type { Self::enum_type(CompositeType::incomplete(Some(tag))) } - /// Check if this is an integer type - pub fn is_integer(&self) -> bool { - matches!( - self.kind, - TypeKind::Bool - | TypeKind::Char - | TypeKind::Short - | TypeKind::Int - | TypeKind::Long - | TypeKind::LongLong - ) - } - - /// Check if this is a floating point type - pub fn is_float(&self) -> bool { - matches!( - self.kind, - TypeKind::Float | TypeKind::Double | TypeKind::LongDouble - ) - } - - /// Check if this is an arithmetic type (integer or float) - pub fn is_arithmetic(&self) -> bool { - self.is_integer() || self.is_float() - } - - /// Check if this is a scalar type (arithmetic or pointer) - pub fn is_scalar(&self) -> bool { 
- self.is_arithmetic() || self.kind == TypeKind::Pointer - } - - /// Check if this type is unsigned - pub fn is_unsigned(&self) -> bool { - self.modifiers.contains(TypeModifiers::UNSIGNED) - } - - /// Check if this is a plain char (no explicit signed/unsigned) - /// Plain char has platform-dependent signedness - pub fn is_plain_char(&self) -> bool { - self.kind == TypeKind::Char - && !self.modifiers.contains(TypeModifiers::SIGNED) - && !self.modifiers.contains(TypeModifiers::UNSIGNED) - } - - /// Get the base type (for pointers, arrays, functions) - pub fn get_base(&self) -> Option<&Type> { - self.base.as_deref() - } - - /// Get the size of this type in bits (for a typical 64-bit target) - /// This is used by the IR to size operations - pub fn size_bits(&self) -> u32 { - match self.kind { - TypeKind::Void => 0, - TypeKind::Bool => 8, - TypeKind::Char => 8, - TypeKind::Short => 16, - TypeKind::Int => 32, - TypeKind::Long => 64, // LP64 model (macOS, Linux x86-64) - TypeKind::LongLong => 64, - TypeKind::Float => 32, - TypeKind::Double => 64, - TypeKind::LongDouble => 128, // x86-64 - TypeKind::Pointer => 64, // 64-bit pointers - TypeKind::Array => { - let base_size = self.base.as_ref().map(|b| b.size_bits()).unwrap_or(0); - let count = self.array_size.unwrap_or(0) as u32; - base_size * count - } - TypeKind::Function => 0, // Functions don't have a size - TypeKind::Struct | TypeKind::Union => (self.size_bytes() * 8) as u32, - TypeKind::Enum => 32, // Enums are int-sized - // va_list size is platform-specific: - // - x86-64: 24 bytes (192 bits) - struct with 4 fields - // - aarch64-macos: 8 bytes (64 bits) - just a char* - // - aarch64-linux: 32 bytes (256 bits) - struct with 5 fields - TypeKind::VaList => { - #[cfg(target_arch = "x86_64")] - { - 192 - } - #[cfg(all(target_arch = "aarch64", target_os = "macos"))] - { - 64 // On Apple ARM64, va_list is just char* - } - #[cfg(all(target_arch = "aarch64", not(target_os = "macos")))] - { - 256 // On Linux ARM64, va_list is 
a 32-byte struct - } - #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] - { - 64 // Fallback to pointer size - } - } - } - } - - /// Get the size of this type in bytes - pub fn size_bytes(&self) -> usize { - match self.kind { - TypeKind::Struct | TypeKind::Union => { - self.composite.as_ref().map(|c| c.size).unwrap_or(0) - } - TypeKind::Enum => 4, - _ => (self.size_bits() / 8) as usize, - } - } - - /// Get natural alignment for this type in bytes - pub fn alignment(&self) -> usize { - match self.kind { - TypeKind::Void => 1, - TypeKind::Bool | TypeKind::Char => 1, - TypeKind::Short => 2, - TypeKind::Int | TypeKind::Float => 4, - TypeKind::Long | TypeKind::LongLong | TypeKind::Double | TypeKind::Pointer => 8, - TypeKind::LongDouble => 16, - TypeKind::Struct | TypeKind::Union => { - self.composite.as_ref().map(|c| c.align).unwrap_or(1) - } - TypeKind::Enum => 4, // Enums are int-aligned - TypeKind::Array => self.base.as_ref().map(|b| b.alignment()).unwrap_or(1), - TypeKind::Function => 1, - TypeKind::VaList => 8, // All platforms use 8-byte alignment for va_list - } - } - /// Find a member in a struct/union type /// Returns MemberInfo with full bitfield details if found pub fn find_member(&self, name: &str) -> Option { @@ -586,7 +375,7 @@ impl Type { if member.name == name { return Some(MemberInfo { offset: member.offset, - typ: member.typ.clone(), + typ: member.typ, // TypeId is Copy, no clone needed bit_offset: member.bit_offset, bit_width: member.bit_width, storage_unit_size: member.storage_unit_size, @@ -602,6 +391,9 @@ impl Type { /// but otherwise requires types to be identical. /// Note: Different enum types are NOT compatible, even if they have /// the same underlying integer type. + /// + /// With TypeId interning, base types are compared by TypeId equality. + /// For full recursive comparison, use TypeTable::types_compatible(). 
pub fn types_compatible(&self, other: &Type) -> bool { // Top-level qualifiers to ignore const QUALIFIERS: TypeModifiers = TypeModifiers::CONST @@ -630,27 +422,20 @@ impl Type { return false; } - // Compare base types (for pointers, arrays, functions) - match (&self.base, &other.base) { - (Some(a), Some(b)) => { - if !a.types_compatible(b) { - return false; - } - } - (None, None) => {} - _ => return false, + // Compare base types by TypeId (interned types with same ID are equal) + if self.base != other.base { + return false; } - // Compare function parameters + // Compare function parameters by TypeId match (&self.params, &other.params) { (Some(a), Some(b)) => { if a.len() != b.len() { return false; } - for (pa, pb) in a.iter().zip(b.iter()) { - if !pa.types_compatible(pb) { - return false; - } + // TypeIds are directly comparable + if a != b { + return false; } } (None, None) => {} @@ -683,34 +468,36 @@ impl fmt::Display for Type { write!(f, "signed ")?; } + // Note: With TypeId, we can't recursively print base types without TypeTable access + // For debugging, we just show the TypeId value match self.kind { TypeKind::Pointer => { - if let Some(base) = &self.base { - write!(f, "{}*", base) + if let Some(base) = self.base { + write!(f, "T{}*", base.0) } else { write!(f, "*") } } TypeKind::Array => { - if let Some(base) = &self.base { + if let Some(base) = self.base { if let Some(size) = self.array_size { - write!(f, "{}[{}]", base, size) + write!(f, "T{}[{}]", base.0, size) } else { - write!(f, "{}[]", base) + write!(f, "T{}[]", base.0) } } else { write!(f, "[]") } } TypeKind::Function => { - if let Some(ret) = &self.base { - write!(f, "{}(", ret)?; + if let Some(ret) = self.base { + write!(f, "T{}(", ret.0)?; if let Some(params) = &self.params { for (i, param) in params.iter().enumerate() { if i > 0 { write!(f, ", ")?; } - write!(f, "{}", param)?; + write!(f, "T{}", param.0)?; } if self.variadic { if !params.is_empty() { @@ -729,6 +516,576 @@ impl fmt::Display for 
Type { } }

// ============================================================================
// Type Table - Interned type storage and query methods
// ============================================================================

/// Hashable identity of a type, used by `TypeTable::intern` for deduplication.
///
/// The key carries every `Type` field that distinguishes two non-composite
/// types (kind, modifier bits, base id, array length, parameter list,
/// variadic flag), so two types only share a `TypeId` when all of them match.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum TypeKey {
    /// Basic type: kind + modifier bits
    Basic(TypeKind, u32),
    /// Pointer to interned type: base id + modifier bits
    Pointer(TypeId, u32),
    /// Array of interned type: element id, optional length, modifier bits
    Array(TypeId, Option<usize>, u32),
    /// Function type
    Function {
        /// Return type
        ret: TypeId,
        /// Parameter types; `None` (no parameter list recorded) is kept
        /// distinct from `Some(vec![])` (an explicitly empty list)
        params: Option<Vec<TypeId>>,
        /// Whether the function takes `...`
        variadic: bool,
        /// Modifier bits of the function type itself
        modifiers: u32,
    },
}

/// Default capacity for type table allocations (reduces reallocation overhead)
const DEFAULT_TYPE_TABLE_CAPACITY: usize = 2048;

/// Type table - stores all types and provides ID-based lookup
/// Pattern follows IdentTable in token/lexer.rs
///
/// A `TypeId` is the index of a type inside `types`. Non-composite types are
/// deduplicated through `lookup`; types carrying composite data are stored
/// once per `intern` call and therefore keep their own identity.
pub struct TypeTable {
    /// All interned types (indexed by TypeId)
    types: Vec<Type>,
    /// Lookup map for deduplication
    lookup: HashMap<TypeKey, TypeId>,

    // Pre-computed common type IDs for fast access
    pub void_id: TypeId,
    pub bool_id: TypeId,
    pub char_id: TypeId,
    pub schar_id: TypeId,
    pub uchar_id: TypeId,
    pub short_id: TypeId,
    pub ushort_id: TypeId,
    pub int_id: TypeId,
    pub uint_id: TypeId,
    pub long_id: TypeId,
    pub ulong_id: TypeId,
    pub longlong_id: TypeId,
    pub ulonglong_id: TypeId,
    pub float_id: TypeId,
    pub double_id: TypeId,
    pub longdouble_id: TypeId,
    pub void_ptr_id: TypeId,
    pub char_ptr_id: TypeId,
}

impl TypeTable {
    /// Create a new type table with common types pre-interned
    ///
    /// Every `*_id` field starts as `TypeId::INVALID` and is overwritten with
    /// a real id before the table is returned.
    pub fn new() -> Self {
        let mut table = Self {
            types: Vec::with_capacity(DEFAULT_TYPE_TABLE_CAPACITY),
            lookup: HashMap::with_capacity(DEFAULT_TYPE_TABLE_CAPACITY),
            void_id: TypeId::INVALID,
            bool_id: TypeId::INVALID,
            char_id: TypeId::INVALID,
            schar_id: TypeId::INVALID,
            uchar_id: TypeId::INVALID,
            short_id: TypeId::INVALID,
            ushort_id: TypeId::INVALID,
            int_id: TypeId::INVALID,
            uint_id: TypeId::INVALID,
            long_id: TypeId::INVALID,
            ulong_id: TypeId::INVALID,
            longlong_id: TypeId::INVALID,
            ulonglong_id: TypeId::INVALID,
            float_id: TypeId::INVALID,
            double_id: TypeId::INVALID,
            longdouble_id: TypeId::INVALID,
            void_ptr_id: TypeId::INVALID,
            char_ptr_id: TypeId::INVALID,
        };

        // Pre-intern common basic types
        table.void_id = table.intern(Type::basic(TypeKind::Void));
        table.bool_id = table.intern(Type::basic(TypeKind::Bool));
        table.char_id = table.intern(Type::basic(TypeKind::Char));
        table.schar_id = table.intern(Type::with_modifiers(TypeKind::Char, TypeModifiers::SIGNED));
        table.uchar_id = table.intern(Type::with_modifiers(
            TypeKind::Char,
            TypeModifiers::UNSIGNED,
        ));
        table.short_id = table.intern(Type::basic(TypeKind::Short));
        table.ushort_id = table.intern(Type::with_modifiers(
            TypeKind::Short,
            TypeModifiers::UNSIGNED,
        ));
        table.int_id = table.intern(Type::basic(TypeKind::Int));
        table.uint_id = table.intern(Type::with_modifiers(TypeKind::Int, TypeModifiers::UNSIGNED));
        table.long_id = table.intern(Type::basic(TypeKind::Long));
        table.ulong_id = table.intern(Type::with_modifiers(
            TypeKind::Long,
            TypeModifiers::UNSIGNED,
        ));
        table.longlong_id = table.intern(Type::basic(TypeKind::LongLong));
        table.ulonglong_id = table.intern(Type::with_modifiers(
            TypeKind::LongLong,
            TypeModifiers::UNSIGNED,
        ));
        table.float_id = table.intern(Type::basic(TypeKind::Float));
        table.double_id = table.intern(Type::basic(TypeKind::Double));
        table.longdouble_id = table.intern(Type::basic(TypeKind::LongDouble));

        // Pre-intern common pointer types
        table.void_ptr_id = table.intern(Type::pointer(table.void_id));
        table.char_ptr_id = table.intern(Type::pointer(table.char_id));

        table
    }

    /// Intern a type, returning its unique ID
    ///
    /// Types for which `make_key` yields a key are deduplicated: interning an
    /// equivalent type again returns the id handed out the first time. Types
    /// without a key (those carrying composite data) always get a fresh id.
    pub fn intern(&mut self, typ: Type) -> TypeId {
        let key = self.make_key(&typ);
        if let Some(&existing_id) = key.as_ref().and_then(|k| self.lookup.get(k)) {
            return existing_id;
        }

        // The new id is the index the type is about to occupy in `types`.
        let id = TypeId(self.types.len() as u32);
        self.types.push(typ);
        if let Some(key) = key {
            self.lookup.insert(key, id);
        }
        id
    }

    /// Create lookup key for deduplication (None for non-deduplicatable types)
    ///
    /// Returns `None` when the type carries composite data, or when a
    /// pointer/array/function type has no base type recorded.
    fn make_key(&self, typ: &Type) -> Option<TypeKey> {
        // Don't deduplicate types with composite data (structs/unions/enums have identity)
        if typ.composite.is_some() {
            return None;
        }

        // Modifier bits are part of every key so that differently qualified
        // types are never collapsed onto one id.
        let modifiers = typ.modifiers.bits();
        let key = match typ.kind {
            TypeKind::Pointer => TypeKey::Pointer(typ.base?, modifiers),
            TypeKind::Array => TypeKey::Array(typ.base?, typ.array_size, modifiers),
            TypeKind::Function => TypeKey::Function {
                ret: typ.base?,
                params: typ.params.clone(),
                variadic: typ.variadic,
                modifiers,
            },
            _ => TypeKey::Basic(typ.kind, modifiers),
        };
        Some(key)
    }

    /// Get a type by ID (returns reference)
    ///
    /// # Panics
    ///
    /// Panics if `id` does not index a type stored in this table.
    #[inline]
    pub fn get(&self, id: TypeId) -> &Type {
        &self.types[id.0 as usize]
    }

    // =========================================================================
    // Type query methods (moved from Type to TypeTable)
    // =========================================================================

    /// Get the type kind
    #[inline]
    pub fn kind(&self, id: TypeId) -> TypeKind {
        self.get(id).kind
    }

    /// Get type modifiers
    #[inline]
    pub fn modifiers(&self, id: TypeId) -> TypeModifiers {
        self.get(id).modifiers
    }

    /// Get the base type ID (for pointers, arrays, functions)
    #[inline]
    pub fn base_type(&self, id: TypeId) -> Option<TypeId> {
        self.get(id).base
    }

    /// Look up an existing pointer type to the given base type
    ///
    /// Only an unqualified pointer (no modifier bits) is searched for, and
    /// nothing is interned. Returns `void_ptr_id` if no such pointer exists;
    /// `size_bits` reports the same size for every pointer, but the pointee
    /// type is lost in that case.
    #[inline]
    pub fn pointer_to(&self, base: TypeId) -> TypeId {
        let key = TypeKey::Pointer(base, 0); // No modifiers
        self.lookup.get(&key).copied().unwrap_or(self.void_ptr_id)
    }

    // =========================================================================
    // Test-only methods (used by tests but not production code)
    // =========================================================================

    /// Get array size
    #[cfg(test)]
    #[inline]
    pub fn array_size(&self, id: TypeId) -> Option<usize> {
        self.get(id).array_size
    }

    /// Get function parameters
    #[cfg(test)]
    #[inline]
    pub fn params(&self, id: TypeId) -> Option<&Vec<TypeId>> {
        self.get(id).params.as_ref()
    }

    /// Check if function is variadic
    #[cfg(test)]
    #[inline]
    pub fn is_variadic(&self, id: TypeId) -> bool {
        self.get(id).variadic
    }

    /// Get composite type data (for struct/union/enum)
    #[cfg(test)]
    pub fn composite(&self, id: TypeId) -> Option<&CompositeType> {
        self.get(id).composite.as_deref()
    }

    /// Format a type for display (with recursive base type printing)
    ///
    /// Unlike `Display for Type`, this resolves base, return and parameter
    /// ids through the table, so nested types are spelled out.
    #[cfg(test)]
    pub fn format_type(&self, id: TypeId) -> String {
        let typ = self.get(id);
        let mut result = String::new();

        // Print modifiers
        if typ.modifiers.contains(TypeModifiers::CONST) {
            result.push_str("const ");
        }
        if typ.modifiers.contains(TypeModifiers::VOLATILE) {
            result.push_str("volatile ");
        }
        if typ.modifiers.contains(TypeModifiers::UNSIGNED) {
            result.push_str("unsigned ");
        } else if typ.modifiers.contains(TypeModifiers::SIGNED) && typ.kind == TypeKind::Char {
            // "signed" is only spelled out for char
            result.push_str("signed ");
        }

        match typ.kind {
            TypeKind::Pointer => {
                if let Some(base) = typ.base {
                    result.push_str(&self.format_type(base));
                }
                result.push('*');
            }
            TypeKind::Array => {
                if let Some(base) = typ.base {
                    result.push_str(&self.format_type(base));
                    match typ.array_size {
                        Some(size) => result.push_str(&format!("[{}]", size)),
                        None => result.push_str("[]"),
                    }
                } else {
                    result.push_str("[]");
                }
            }
            TypeKind::Function => {
                if let Some(ret) = typ.base {
                    result.push_str(&self.format_type(ret));
                    result.push('(');
                    if let Some(params) = &typ.params {
                        for (i, &param) in params.iter().enumerate() {
                            if i > 0 {
                                result.push_str(", ");
                            }
                            result.push_str(&self.format_type(param));
                        }
                        if typ.variadic {
                            if !params.is_empty() {
                                result.push_str(", ");
                            }
                            result.push_str("...");
                        }
                    }
                    result.push(')');
                } else {
                    result.push_str("()");
                }
            }
            _ => {
                result.push_str(&typ.kind.to_string());
            }
        }

        result
    }

    // =========================================================================
    // Production methods (used by compiler proper)
    // =========================================================================

    /// Check if type is an integer type
    #[inline]
    pub fn is_integer(&self, id: TypeId) -> bool {
        matches!(
            self.get(id).kind,
            TypeKind::Bool
                | TypeKind::Char
                | TypeKind::Short
                | TypeKind::Int
                | TypeKind::Long
                | TypeKind::LongLong
        )
    }

    /// Check if type is a floating point type
    #[inline]
    pub fn is_float(&self, id: TypeId) -> bool {
        matches!(
            self.get(id).kind,
            TypeKind::Float | TypeKind::Double | TypeKind::LongDouble
        )
    }

    /// Check if type is an arithmetic type (integer or float)
    #[inline]
    pub fn is_arithmetic(&self, id: TypeId) -> bool {
        self.is_integer(id) || self.is_float(id)
    }

    /// Check if type is a scalar type (arithmetic or pointer)
    #[inline]
    pub fn is_scalar(&self, id: TypeId) -> bool {
        self.is_arithmetic(id) || self.get(id).kind == TypeKind::Pointer
    }

    /// Check if type is unsigned
    #[inline]
    pub fn is_unsigned(&self, id: TypeId) -> bool {
        self.get(id).modifiers.contains(TypeModifiers::UNSIGNED)
    }

    /// Check if type is a plain char (no explicit signed/unsigned)
    #[inline]
    pub fn is_plain_char(&self, id: TypeId) -> bool {
        let typ = self.get(id);
        typ.kind == TypeKind::Char
            && !typ.modifiers.contains(TypeModifiers::SIGNED)
            && !typ.modifiers.contains(TypeModifiers::UNSIGNED)
    }

    /// Get the unsigned version of a type
    ///
    /// Maps char/short/int/long/long long to the pre-interned unsigned id of
    /// the same kind; any other modifiers on `id` are not carried over.
    /// Every other kind is returned unchanged.
    #[inline]
    pub fn unsigned_version(&self, id: TypeId) -> TypeId {
        match self.get(id).kind {
            TypeKind::Char => self.uchar_id,
            TypeKind::Short => self.ushort_id,
            TypeKind::Int => self.uint_id,
            TypeKind::Long => self.ulong_id,
            TypeKind::LongLong => self.ulonglong_id,
            _ => id, // For non-integer types, just return the original
        }
    }

    /// Clamp a 64-bit bit count into the `u32` range returned by `size_bits`.
    fn clamp_bits(bits: u64) -> u32 {
        bits.min(u64::from(u32::MAX)) as u32
    }

    /// Round `offset` up to the next multiple of `align`.
    ///
    /// `align` is expected to be a power of two; an alignment of zero is
    /// treated as 1 so the mask arithmetic cannot underflow.
    fn align_up(offset: usize, align: usize) -> usize {
        let align = align.max(1);
        (offset + align - 1) & !(align - 1)
    }

    /// Get the size of a type in bits
    ///
    /// Sizes are fixed per kind (64-bit `long` and pointers). Array and
    /// struct/union sizes saturate at `u32::MAX` instead of wrapping.
    pub fn size_bits(&self, id: TypeId) -> u32 {
        let typ = self.get(id);
        match typ.kind {
            TypeKind::Void => 0,
            TypeKind::Bool => 8,
            TypeKind::Char => 8,
            TypeKind::Short => 16,
            TypeKind::Int => 32,
            TypeKind::Long => 64,
            TypeKind::LongLong => 64,
            TypeKind::Float => 32,
            TypeKind::Double => 64,
            TypeKind::LongDouble => 128,
            TypeKind::Pointer => 64,
            TypeKind::Array => {
                // Missing element type or length counts as zero.
                let elem_bits = typ.base.map_or(0, |b| self.size_bits(b));
                let count = typ.array_size.unwrap_or(0) as u64;
                Self::clamp_bits(u64::from(elem_bits).saturating_mul(count))
            }
            TypeKind::Struct | TypeKind::Union => {
                let bytes = typ.composite.as_ref().map_or(0, |c| c.size) as u64;
                Self::clamp_bits(bytes.saturating_mul(8))
            }
            TypeKind::Function => 0,
            TypeKind::Enum => 32,
            // NOTE(review): these cfg attributes select on the architecture
            // this compiler is built for, not on a runtime-selected target —
            // confirm that is intended for cross-compilation.
            TypeKind::VaList => {
                #[cfg(target_arch = "x86_64")]
                {
                    192
                }
                #[cfg(all(target_arch = "aarch64", target_os = "macos"))]
                {
                    64
                }
                #[cfg(all(target_arch = "aarch64", not(target_os = "macos")))]
                {
                    256
                }
                #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
                {
                    64
                }
            }
        }
    }

    /// Get the size of a type in bytes
    ///
    /// Struct/union sizes come from the stored composite layout (0 when no
    /// composite data is attached); every other kind is `size_bits / 8`.
    pub fn size_bytes(&self, id: TypeId) -> usize {
        let typ = self.get(id);
        match typ.kind {
            TypeKind::Struct | TypeKind::Union => typ.composite.as_ref().map_or(0, |c| c.size),
            TypeKind::Enum => 4,
            _ => (self.size_bits(id) / 8) as usize,
        }
    }

    /// Get natural alignment for a type in bytes
    ///
    /// Arrays take the alignment of their element type; struct/union
    /// alignment comes from the stored composite layout (1 when absent).
    pub fn alignment(&self, id: TypeId) -> usize {
        let typ = self.get(id);
        match typ.kind {
            TypeKind::Void => 1,
            TypeKind::Bool | TypeKind::Char => 1,
            TypeKind::Short => 2,
            TypeKind::Int | TypeKind::Float => 4,
            TypeKind::Long | TypeKind::LongLong | TypeKind::Double | TypeKind::Pointer => 8,
            TypeKind::LongDouble => 16,
            TypeKind::Struct | TypeKind::Union => typ.composite.as_ref().map_or(1, |c| c.align),
            TypeKind::Enum => 4,
            TypeKind::Array => typ.base.map_or(1, |b| self.alignment(b)),
            TypeKind::Function => 1,
            TypeKind::VaList => 8,
        }
    }

    /// Find a member in a struct/union type
    pub fn find_member(&self, id: TypeId, name: &str) -> Option<MemberInfo> {
        self.get(id).find_member(name)
    }

    /// Check if two types are compatible (for __builtin_types_compatible_p)
    ///
    /// Identical ids are compatible. Otherwise the decision is delegated to
    /// `Type::types_compatible` on the two stored types.
    pub fn types_compatible(&self, id1: TypeId, id2: TypeId) -> bool {
        // Quick check: same TypeId means same type
        id1 == id2 || self.get(id1).types_compatible(self.get(id2))
    }

    /// Compute struct layout with natural alignment
    /// Updates member offsets in place and returns (total_size, alignment)
    ///
    /// Consecutive bitfields share a storage unit while they have the same
    /// storage size and still fit; a zero-width bitfield or a regular member
    /// closes the open unit. The total size is rounded up to the largest
    /// member alignment.
    pub fn compute_struct_layout(&self, members: &mut [StructMember]) -> (usize, usize) {
        let mut offset = 0usize;
        let mut max_align = 1usize;
        // Bitfield unit currently being filled; a size of 0 means none is open.
        let mut current_bit_offset = 0u32;
        let mut current_storage_unit_size = 0u32;

        for member in members.iter_mut() {
            if let Some(bit_width) = member.bit_width {
                let storage_size = self.size_bytes(member.typ) as u32;
                let storage_bits = storage_size * 8;

                if bit_width == 0 {
                    // Zero-width bitfield: close the open unit, occupy nothing.
                    if current_storage_unit_size > 0 {
                        offset += current_storage_unit_size as usize;
                        current_bit_offset = 0;
                        current_storage_unit_size = 0;
                    }
                    member.offset = offset;
                    member.bit_offset = None;
                    member.storage_unit_size = None;
                    continue;
                }

                let need_new_unit = current_storage_unit_size == 0
                    || current_storage_unit_size != storage_size
                    || current_bit_offset + bit_width > storage_bits;

                if need_new_unit {
                    if current_storage_unit_size > 0 {
                        offset += current_storage_unit_size as usize;
                    }
                    // A storage unit is aligned to its own size.
                    let align = (storage_size as usize).max(1);
                    offset = Self::align_up(offset, align);
                    max_align = max_align.max(align);
                    current_bit_offset = 0;
                    current_storage_unit_size = storage_size;
                }

                member.offset = offset;
                member.bit_offset = Some(current_bit_offset);
                member.storage_unit_size = Some(storage_size);
                current_bit_offset += bit_width;
            } else {
                // Regular member: first step past any open bitfield unit.
                if current_storage_unit_size > 0 {
                    offset += current_storage_unit_size as usize;
                    current_bit_offset = 0;
                    current_storage_unit_size = 0;
                }

                let align = self.alignment(member.typ);
                max_align = max_align.max(align);

                offset = Self::align_up(offset, align);
                member.offset = offset;
                member.bit_offset = None;
                member.storage_unit_size = None;

                offset += self.size_bytes(member.typ);
            }
        }

        // Account for a bitfield unit still open at the end of the struct.
        if current_storage_unit_size > 0 {
            offset += current_storage_unit_size as usize;
        }

        (Self::align_up(offset, max_align), max_align)
    }

    /// Compute union layout (all members at offset 0)
    /// Returns (total_size, alignment)
    ///
    /// The size is the largest member size rounded up to the largest member
    /// alignment.
    pub fn compute_union_layout(&self, members: &mut [StructMember]) -> (usize, usize) {
        let mut max_size = 0usize;
        let mut max_align = 1usize;

        for member in members.iter_mut() {
            member.offset = 0;
            max_size = max_size.max(self.size_bytes(member.typ));
            max_align = max_align.max(self.alignment(member.typ));
        }

        (Self::align_up(max_size, max_align), max_align)
    }
}

impl Default for TypeTable {
    fn default() -> Self {
        Self::new()
    }
}

// ============================================================================ // Tests //
============================================================================ @@ -739,94 +1096,93 @@ mod tests { #[test] fn test_basic_types() { - let int_type = Type::basic(TypeKind::Int); - assert!(int_type.is_integer()); - assert!(int_type.is_arithmetic()); - assert!(int_type.is_scalar()); - assert!(!int_type.is_float()); + let types = TypeTable::new(); + assert!(types.is_integer(types.int_id)); + assert!(types.is_arithmetic(types.int_id)); + assert!(types.is_scalar(types.int_id)); + assert!(!types.is_float(types.int_id)); } #[test] fn test_pointer_type() { - let int_ptr = Type::pointer(Type::basic(TypeKind::Int)); - assert_eq!(int_ptr.kind, TypeKind::Pointer); - assert!(int_ptr.is_scalar()); - assert!(!int_ptr.is_integer()); - - let base = int_ptr.get_base().unwrap(); - assert_eq!(base.kind, TypeKind::Int); + let mut types = TypeTable::new(); + let int_ptr_id = types.intern(Type::pointer(types.int_id)); + assert_eq!(types.kind(int_ptr_id), TypeKind::Pointer); + assert!(types.is_scalar(int_ptr_id)); + assert!(!types.is_integer(int_ptr_id)); + + let base_id = types.base_type(int_ptr_id).unwrap(); + assert_eq!(types.kind(base_id), TypeKind::Int); } #[test] fn test_array_type() { - let int_arr = Type::array(Type::basic(TypeKind::Int), 10); - assert_eq!(int_arr.kind, TypeKind::Array); - assert_eq!(int_arr.array_size, Some(10)); + let mut types = TypeTable::new(); + let int_arr_id = types.intern(Type::array(types.int_id, 10)); + assert_eq!(types.kind(int_arr_id), TypeKind::Array); + assert_eq!(types.array_size(int_arr_id), Some(10)); - let base = int_arr.get_base().unwrap(); - assert_eq!(base.kind, TypeKind::Int); + let base_id = types.base_type(int_arr_id).unwrap(); + assert_eq!(types.kind(base_id), TypeKind::Int); } #[test] fn test_function_type() { - let func = Type::function( - Type::basic(TypeKind::Int), - vec![Type::basic(TypeKind::Int), Type::basic(TypeKind::Char)], + let mut types = TypeTable::new(); + let func_id = types.intern(Type::function( + types.int_id, 
+ vec![types.int_id, types.char_id], false, - ); - assert_eq!(func.kind, TypeKind::Function); - assert!(!func.variadic); + )); + assert_eq!(types.kind(func_id), TypeKind::Function); + assert!(!types.is_variadic(func_id)); - let params = func.params.as_ref().unwrap(); + let params = types.params(func_id).unwrap(); assert_eq!(params.len(), 2); - assert_eq!(params[0].kind, TypeKind::Int); - assert_eq!(params[1].kind, TypeKind::Char); + assert_eq!(types.kind(params[0]), TypeKind::Int); + assert_eq!(types.kind(params[1]), TypeKind::Char); } #[test] fn test_unsigned_modifier() { - let uint = Type::with_modifiers(TypeKind::Int, TypeModifiers::UNSIGNED); - assert!(uint.is_unsigned()); + let types = TypeTable::new(); + assert!(types.is_unsigned(types.uint_id)); + assert!(!types.is_unsigned(types.int_id)); } #[test] - fn test_type_display() { - let int_type = Type::basic(TypeKind::Int); - assert_eq!(format!("{}", int_type), "int"); - - let const_int = Type::with_modifiers(TypeKind::Int, TypeModifiers::CONST); - assert_eq!(format!("{}", const_int), "const int"); - - let uint = Type::with_modifiers(TypeKind::Int, TypeModifiers::UNSIGNED); - assert_eq!(format!("{}", uint), "unsigned int"); - - let int_ptr = Type::pointer(Type::basic(TypeKind::Int)); - assert_eq!(format!("{}", int_ptr), "int*"); + fn test_type_format() { + let types = TypeTable::new(); + assert_eq!(types.format_type(types.int_id), "int"); + assert_eq!(types.format_type(types.uint_id), "unsigned int"); } #[test] fn test_nested_pointer() { + let mut types = TypeTable::new(); // int **pp - let int_ptr_ptr = Type::pointer(Type::pointer(Type::basic(TypeKind::Int))); - assert_eq!(int_ptr_ptr.kind, TypeKind::Pointer); + let int_ptr_id = types.intern(Type::pointer(types.int_id)); + let int_ptr_ptr_id = types.intern(Type::pointer(int_ptr_id)); + assert_eq!(types.kind(int_ptr_ptr_id), TypeKind::Pointer); - let inner = int_ptr_ptr.get_base().unwrap(); - assert_eq!(inner.kind, TypeKind::Pointer); + let inner_id = 
types.base_type(int_ptr_ptr_id).unwrap(); + assert_eq!(types.kind(inner_id), TypeKind::Pointer); - let innermost = inner.get_base().unwrap(); - assert_eq!(innermost.kind, TypeKind::Int); + let innermost_id = types.base_type(inner_id).unwrap(); + assert_eq!(types.kind(innermost_id), TypeKind::Int); } #[test] fn test_pointer_to_array() { + let mut types = TypeTable::new(); // int (*p)[10] - pointer to array of 10 ints - let arr_type = Type::array(Type::basic(TypeKind::Int), 10); - let ptr_to_arr = Type::pointer(arr_type); + let arr_id = types.intern(Type::array(types.int_id, 10)); + let ptr_to_arr_id = types.intern(Type::pointer(arr_id)); - assert_eq!(ptr_to_arr.kind, TypeKind::Pointer); - let base = ptr_to_arr.get_base().unwrap(); - assert_eq!(base.kind, TypeKind::Array); - assert_eq!(base.array_size, Some(10)); + assert_eq!(types.kind(ptr_to_arr_id), TypeKind::Pointer); + let base_id = types.base_type(ptr_to_arr_id).unwrap(); + assert_eq!(types.kind(base_id), TypeKind::Array); + assert_eq!(types.array_size(base_id), Some(10)); } #[test] @@ -877,21 +1233,50 @@ mod tests { #[test] fn test_types_compatible_pointers() { - let int_ptr = Type::pointer(Type::basic(TypeKind::Int)); - let int_ptr2 = Type::pointer(Type::basic(TypeKind::Int)); + let types = TypeTable::new(); + let int_ptr = Type::pointer(types.int_id); + let int_ptr2 = Type::pointer(types.int_id); assert!(int_ptr.types_compatible(&int_ptr2)); - let char_ptr = Type::pointer(Type::basic(TypeKind::Char)); + let char_ptr = Type::pointer(types.char_id); assert!(!int_ptr.types_compatible(&char_ptr)); } #[test] fn test_types_compatible_arrays() { - let arr10 = Type::array(Type::basic(TypeKind::Int), 10); - let arr10_2 = Type::array(Type::basic(TypeKind::Int), 10); + let types = TypeTable::new(); + let arr10 = Type::array(types.int_id, 10); + let arr10_2 = Type::array(types.int_id, 10); assert!(arr10.types_compatible(&arr10_2)); - let arr20 = Type::array(Type::basic(TypeKind::Int), 20); + let arr20 = 
Type::array(types.int_id, 20); assert!(!arr10.types_compatible(&arr20)); } + + #[test] + fn test_type_deduplication() { + let mut types = TypeTable::new(); + // Interning the same type should return the same ID + let int_ptr1 = types.intern(Type::pointer(types.int_id)); + let int_ptr2 = types.intern(Type::pointer(types.int_id)); + assert_eq!(int_ptr1, int_ptr2); + + // Different types should have different IDs + let char_ptr = types.intern(Type::pointer(types.char_id)); + assert_ne!(int_ptr1, char_ptr); + } + + #[test] + fn test_type_table_pre_interned() { + let types = TypeTable::new(); + // Pre-interned types should be valid + assert!(types.int_id.is_valid()); + assert!(types.char_id.is_valid()); + assert!(types.void_id.is_valid()); + assert!(types.void_ptr_id.is_valid()); + + // And have correct kinds + assert_eq!(types.kind(types.int_id), TypeKind::Int); + assert_eq!(types.kind(types.void_ptr_id), TypeKind::Pointer); + } }