use crate::{ instruction_cache::{InstructionCache, InstructionCacheEntry, OpCode}, layout_cache::LayoutCache, }; use fxhash::FxHashMap; use leaf_assembly::{ context::Ctx, functions::Function, types::{Type, intrinsics::IntT}, values::{Const, Int, Value}, }; use smallvec::SmallVec; use std::{ops::Range, sync::Arc}; #[derive(Debug)] pub enum Error { InvalidParameterCount, InvalidParameterType, FunctionHasNoBody, } #[derive(Debug)] pub enum AnyValue { Void, Int(Int), Ptr(*const u8), } pub struct Interpreter<'l> { ctx: Ctx<'l>, stack: Vec, layouts: Arc>, instructions: Arc>, native_funcs: FxHashMap< *const Function<'l>, Arc]) -> Result + 'l>, >, } impl<'l> Interpreter<'l> { pub fn new(ctx: Ctx<'l>) -> Self { let layouts = Arc::new(LayoutCache::default()); Self { ctx, stack: vec![0; 1024], layouts: layouts.clone(), instructions: Arc::new(InstructionCache::new(layouts)), native_funcs: FxHashMap::default(), } } pub fn register_function( &mut self, func: &'l Function<'l>, fn_impl: impl Fn(&mut Self, &[Range]) -> Result + 'l, ) { self.native_funcs.insert(func, Arc::new(fn_impl)); } pub fn run(&mut self, func: &'l Function<'l>, args: Vec>) -> Result { if func.body().is_none() { return Err(Error::FunctionHasNoBody); }; if func.ty.par_t.len() != args.len() { return Err(Error::InvalidParameterCount); } if func .ty .par_t .iter() .zip(&args) .any(|(ty, arg)| *ty != arg.ty()) { return Err(Error::InvalidParameterType); } let mut sp = 0; let mut args_ranges: SmallVec<[Range; 4]> = args .into_iter() .map(|v| self.push_ctx_val(&mut sp, &v)) .collect(); let func_entry = self.instructions.get(func); let ret_range = self.call(sp, &mut args_ranges, &func_entry); match func.ty.ret_t { Type::Int(IntT { signed: false, precision: 32, .. }) => { let mem = &self.stack[ret_range]; let val = u32::from_ne_bytes(mem.try_into().unwrap()); Ok(AnyValue::Int(Int::U32(val))) } _ => todo!("Unsupported type `{}`", func.ty.ret_t), } } fn call( &mut self, sp: usize, _args: &[Range], function: &InstructionCacheEntry<'l>, ) -> Range { unsafe { let l = function.layout; let sp = sp + (self.stack.as_ptr().add(sp)).align_offset(l.align()); assert!(sp + l.size() < self.stack.len()); // Since the earlier assertion guarantees the stack won't overflow, we can skip the range checks. let mut i = 0; while let Some(opcode) = function.opcodes.get(i) { match opcode { OpCode::Store_CL_U32(v, offset) => { let start = sp + *offset; self.stack .get_unchecked_mut(start..start + 4) .copy_from_slice(&v.to_ne_bytes()); } OpCode::CopyRange(src, dst) => { let len = src.len(); let src = self.stack.as_ptr().add(sp + src.start); let dst = self.stack.as_mut_ptr().add(sp + dst.start); std::ptr::copy_nonoverlapping(src, dst, len); } OpCode::Add_LL_U32(a, b, dst) => { let [a, b, dst] = [ sp + a..sp + a + 4, sp + b..sp + b + 4, sp + dst..sp + dst + 4, ]; let a = u32::from_ne_bytes(self.stack.get_unchecked(a).try_into().unwrap()); let b = u32::from_ne_bytes(self.stack.get_unchecked(b).try_into().unwrap()); self.stack .get_unchecked_mut(dst) .copy_from_slice(&a.wrapping_add(b).to_ne_bytes()); } OpCode::Add_LC_U32(a, b, dst) => { let [a, dst] = [sp + a..sp + a + 4, sp + dst..sp + dst + 4]; let a = u32::from_ne_bytes(self.stack.get_unchecked(a).try_into().unwrap()); self.stack .get_unchecked_mut(dst) .copy_from_slice(&a.wrapping_add(*b).to_ne_bytes()); } OpCode::CmpLt_LC_U32(a, b, dst) => { let a = sp + a..sp + a + 4; let a = u32::from_ne_bytes(self.stack.get_unchecked(a).try_into().unwrap()); *self.stack.get_unchecked_mut(*dst) = (a < *b) as u8; } OpCode::Jump(target) => { i = *target; continue; } OpCode::Branch { cond, true_case, false_case, } => { i = match self.stack().get_unchecked(*cond) { 0 => *false_case, _ => *true_case, }; continue; } OpCode::Call(func, args, out) => { let key = *func as *const Function<'l>; match self.native_funcs.get(&key).cloned() { Some(native_func) => { let res = native_func(self, args).unwrap(); self.write_any_val(out, &res); } None => { let func = self.instructions.get(func); let res = self.call(sp, args, &func); self.stack.copy_within(res, out.start); } } } OpCode::ReturnRange(range) => { return sp + range.start..sp + range.end; } _ => todo!("Unimplemented opcode `{opcode:?}`"), } i += 1; } } unreachable!("Execution has produced no results"); } pub fn stack(&self) -> &[u8] { &self.stack } fn push_ctx_val(&mut self, sp: &mut usize, value: &Value) -> Range { let ptr = &self.stack[*sp] as *const u8; match value { Value::Const(Const::Int(Int::U32(v), _)) => { *sp += ptr.align_offset(align_of::()); let range = *sp..*sp + size_of::(); self.stack[range.clone()].copy_from_slice(&v.to_ne_bytes()); range } _ => todo!("Unsupported type `{}`", value.ty()), } } fn write_any_val(&mut self, range: &Range, value: &AnyValue) { match value { AnyValue::Void => assert_eq!(range.len(), 0), AnyValue::Int(Int::U32(v)) => { self.stack[range.clone()].copy_from_slice(&v.to_ne_bytes()); } _ => todo!("Unsupported value `{:?}`", value), } } }