use crate::layout_cache::{GetLayout, LayoutCache}; use fxhash::{FxBuildHasher, FxHashMap}; use leaf_assembly::{ functions::{ Function, ir::{Cmp, Instruction, InstructionVariant}, }, types::{Type, intrinsics::IntT}, values::{Const, ConstData, Int, Value}, }; use scc::HashMap; use std::{alloc::Layout, fmt::Debug, ops::Range, sync::Arc}; #[derive(Debug)] #[allow(non_camel_case_types)] pub enum OpCode<'l> { Store_CL_U8(u8, usize), Store_CL_U16(u16, usize), Store_CL_U32(u32, usize), Store_CL_U64(u64, usize), Add_LL_U32(usize, usize, usize), Add_LC_U32(usize, u32, usize), CmpEq_LC_U32(usize, u32, usize), CmpLt_LC_U32(usize, u32, usize), CmpGt_LC_U32(usize, u32, usize), CmpLe_LC_U32(usize, u32, usize), CmpGe_LC_U32(usize, u32, usize), Call(&'l Function<'l>, Vec>, Range), Jump(usize), Branch { cond: usize, true_case: usize, false_case: usize, }, CopyRange(Range, Range), ReturnRange(Range), } #[non_exhaustive] pub struct InstructionCacheEntry<'l> { pub layout: Layout, pub opcodes: Vec>, } impl Debug for InstructionCacheEntry<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { struct OpCodes<'a, 'b>(&'a [OpCode<'b>]); impl Debug for OpCodes<'_, '_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let mut dbg = f.debug_list(); for op in self.0 { dbg.entry(&format_args!("{op:?}")); } dbg.finish() } } f.debug_struct("InstructionCacheEntry") .field("layout", &self.layout) .field("opcodes", &OpCodes(&self.opcodes)) .finish() } } impl<'l> InstructionCacheEntry<'l> { pub fn new(func: &'l Function<'l>, layouts: &LayoutCache<'l>) -> Self { let mut opcodes = Vec::new(); let mut layout = Layout::new::<()>(); let mut memory_ranges = FxHashMap::default(); macro_rules! alloc_range { ($ty:expr, $id:expr) => {{ &*memory_ranges.entry($id).or_insert_with(|| { let l = layouts.get_layout($ty); let (new_layout, offset) = layout.extend(l.layout).unwrap(); layout = new_layout; offset..offset + l.layout.size() }) }}; } let Some(body) = func.body() else { unreachable!("InstructionCacheEntry::new called with a function without a body."); }; let mut backpatch = vec![]; let mut block_starts = Vec::with_capacity(body.blocks.len()); for inst in body.blocks.iter().flat_map(|b| b.instructions()) { match &inst.variant { InstructionVariant::StackAlloc(ty) => { let _ = alloc_range!(*ty, inst.id()); } _ => {} } } for block in &body.blocks { block_starts.push(opcodes.len()); for inst in block.instructions() { match &inst.variant { InstructionVariant::StackAlloc(_) => {} InstructionVariant::Store( Value::Instruction( t @ Instruction { variant: InstructionVariant::StackAlloc(ty), .. }, ), Value::Const(Const { data: ConstData::Int(Int::U8(v)), ctx, }), ) => { assert_eq!(ctx.u8_t(), *ty); opcodes.push(OpCode::Store_CL_U8(*v, memory_ranges[&t.id()].start)); } InstructionVariant::Store( Value::Instruction( t @ Instruction { variant: InstructionVariant::StackAlloc(ty), .. }, ), Value::Const(Const { data: ConstData::Int(Int::U16(v)), ctx, }), ) => { assert_eq!(ctx.u16_t(), *ty); opcodes.push(OpCode::Store_CL_U16(*v, memory_ranges[&t.id()].start)); } InstructionVariant::Store( Value::Instruction( t @ Instruction { variant: InstructionVariant::StackAlloc(ty), .. }, ), Value::Const(Const { data: ConstData::Int(Int::U32(v)), ctx, }), ) => { assert_eq!(ctx.u32_t(), *ty); opcodes.push(OpCode::Store_CL_U32(*v, memory_ranges[&t.id()].start)); } InstructionVariant::Store( Value::Instruction( t @ Instruction { variant: InstructionVariant::StackAlloc(ty), .. }, ), Value::Const(Const { data: ConstData::Int(Int::U64(v)), ctx, }), ) => { assert_eq!(ctx.u64_t(), *ty); opcodes.push(OpCode::Store_CL_U64(*v, memory_ranges[&t.id()].start)); } InstructionVariant::Store( Value::Instruction( t @ Instruction { variant: InstructionVariant::StackAlloc(ty), .. }, ), Value::Instruction(v), ) => { let src = alloc_range!(*ty, v.id()).clone(); let dst = alloc_range!(*ty, t.id()).clone(); opcodes.push(OpCode::CopyRange(src, dst)); } InstructionVariant::Load(Value::Instruction( t @ Instruction { variant: InstructionVariant::StackAlloc(ty), .. }, )) => { let src = alloc_range!(*ty, t.id()).clone(); let dst = alloc_range!(*ty, inst.id()).clone(); opcodes.push(OpCode::CopyRange(src, dst)); } InstructionVariant::IAdd(Value::Instruction(a), Value::Instruction(b)) => { match a.value_ty() { ty @ Type::Int(IntT { signed: false, precision: 32, .. }) => { let a = memory_ranges[&a.id()].start; let b = memory_ranges[&b.id()].start; let c = alloc_range!(ty, inst.id()).start; opcodes.push(OpCode::Add_LL_U32(a, b, c)); } _ => todo!("Unimplemented type `{}`", a.value_ty()), } } InstructionVariant::IAdd( Value::Instruction(a), Value::Const(Const { data: ConstData::Int(Int::U32(b)), ctx, }), ) => { assert!(matches!( a.value_ty(), Type::Int(IntT { signed: false, precision: 32, .. }) )); let a = memory_ranges[&a.id()].start; let c = alloc_range!(ctx.u32_t(), inst.id()).start; opcodes.push(OpCode::Add_LC_U32(a, *b, c)); } InstructionVariant::ICmp( Value::Instruction(a), Value::Const(Const { data: ConstData::Int(Int::U32(b)), ctx, }), cmp, ) => { assert!(matches!( a.value_ty(), Type::Int(IntT { signed: false, precision: 32, .. }) )); let a = memory_ranges[&a.id()].start; let target = alloc_range!(ctx.bool_t(), inst.id()).start; match *cmp { Cmp::Eq => opcodes.push(OpCode::CmpEq_LC_U32(a, *b, target)), Cmp::Lt => opcodes.push(OpCode::CmpLt_LC_U32(a, *b, target)), Cmp::Gt => opcodes.push(OpCode::CmpGt_LC_U32(a, *b, target)), Cmp::Le => opcodes.push(OpCode::CmpLe_LC_U32(a, *b, target)), Cmp::Ge => opcodes.push(OpCode::CmpGe_LC_U32(a, *b, target)), }; } InstructionVariant::Call(func, args) => { let mut ranges = Vec::with_capacity(args.len()); for arg in args { match arg { Value::Instruction(i) => { ranges.push(memory_ranges[&i.id()].clone()); } _ => todo!("Unimplemented variant `{arg}`."), } } opcodes.push(OpCode::Call( func, ranges, alloc_range!(func.ty.ret_t, inst.id()).clone(), )); } InstructionVariant::Jump(target) => { backpatch.push(opcodes.len()); opcodes.push(OpCode::Jump(target.id as usize)); } InstructionVariant::Branch { cond: Value::Instruction(i), true_case, false_case, } => { assert!(matches!(i.value_ty(), Type::Bool(_))); backpatch.push(opcodes.len()); opcodes.push(OpCode::Branch { cond: memory_ranges[&i.id()].start, true_case: true_case.id as usize, false_case: false_case.id as usize, }); } InstructionVariant::Return(Some(Value::Instruction(t))) => { assert_eq!(func.ty.ret_t, t.value_ty()); opcodes.push(OpCode::ReturnRange(memory_ranges[&t.id()].clone())); } _ => todo!("Unimplemented instruction `{inst:?}`"), } } } for idx in backpatch { match &mut opcodes[idx] { OpCode::Jump(target) => { *target = block_starts[*target]; } OpCode::Branch { true_case, false_case, .. } => { *true_case = block_starts[*true_case]; *false_case = block_starts[*false_case]; } _ => unreachable!(), } } Self { layout, opcodes } } } pub struct InstructionCache<'l> { layouts: Arc>, entries: HashMap<*const Function<'l>, Arc>, FxBuildHasher>, } impl<'l> InstructionCache<'l> { pub fn new(layouts: Arc>) -> Self { Self { layouts, entries: HashMap::default(), } } pub fn get(&self, func: &'l Function<'l>) -> Arc> { self.entries .entry_sync(func as *const Function<'l>) .or_insert_with(|| Arc::new(InstructionCacheEntry::new(func, &self.layouts))) .clone() } }