diff --git a/.gitignore b/.gitignore index eb8f31e..ed6a96a 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ target/ test.leaf out.asm a.out +.zed/ diff --git a/Cargo.lock b/Cargo.lock index c87227f..a1695ed 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -98,6 +98,12 @@ version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +[[package]] +name = "equivalent" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" + [[package]] name = "find-msvc-tools" version = "0.1.9" @@ -124,6 +130,22 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "hashbrown" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" + +[[package]] +name = "indexmap" +version = "2.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" +dependencies = [ + "equivalent", + "hashbrown", +] + [[package]] name = "inkwell" version = "0.8.0" @@ -180,6 +202,7 @@ dependencies = [ "derive_more", "fxhash", "half", + "indexmap", "leaf_allocators", "scc", ] @@ -216,6 +239,7 @@ version = "0.1.0" dependencies = [ "arcstr", "derive_more", + "indexmap", "peg", ] diff --git a/assembly/Cargo.toml b/assembly/Cargo.toml index 25da544..9ef651a 100644 --- a/assembly/Cargo.toml +++ b/assembly/Cargo.toml @@ -9,6 +9,6 @@ boxcar = "0.2.14" derive_more = { version = "2.0.1", features = ["deref", "deref_mut", "debug", "display", "try_from", "from", "try_into", "into"] } half = "2.7.1" scc = "3.3.7" - leaf_allocators = { path = "../allocators" } fxhash = "0.2.1" +indexmap = "2.13.0" diff --git a/assembly/src/assembly.rs b/assembly/src/assembly.rs index 55cd382..ce1ad3c 100644 --- a/assembly/src/assembly.rs +++ b/assembly/src/assembly.rs @@ -1,6 +1,9 @@ use crate::{ functions::Function, - types::derivations::{FuncT, TypeDerivations}, + types::{ + compound::StructT, + derivations::{FuncT, TypeDerivations}, + }, }; use derive_more::{Debug, Display}; use fxhash::FxBuildHasher; @@ -76,6 +79,9 @@ impl<'l> Context<'l> { } pub fn intern_str(&'l self, str: &str) -> &'l str { + if str.is_empty() { + return ""; + } if let Some(value) = self.strings.get_sync(str) { return *value; } @@ -124,10 +130,18 @@ impl<'l> Assembly<'l> { &self.ident } - pub fn create_function(&'l self, ty: &'l FuncT<'l>) -> &'l Function<'l> { + pub fn create_struct(&'l self, name: &str) -> &'l StructT<'l> { + self.ctx().alloc.alloc(StructT { + name: self.ctx().intern_str(name), + fields: OnceLock::new(), + declaring_assembly: self, + }) + } + + pub fn create_function(&'l self, ty: &'l FuncT<'l>, name: &str) -> &'l Function<'l> { let func = self.ctx.alloc.alloc(Function { ty, - name: OnceLock::new(), + name: self.ctx().intern_str(name), body: OnceLock::new(), declaring_assembly: self, }); diff --git a/assembly/src/functions/ir.rs b/assembly/src/functions/ir.rs index 6c543d4..e248424 100644 --- a/assembly/src/functions/ir.rs +++ b/assembly/src/functions/ir.rs @@ -1,8 +1,8 @@ use crate::{ assembly::Ctx, functions::{Function, FunctionBody}, - types::{IntT, Type, derivations::*}, - values::{Value, ValueFlags}, + types::{IntT, Type, compound::StructT, derivations::*}, + values::{AnyConst, AnyValue, Value, ValueFlags, default_associated_values}, }; use derive_more::{Debug, Display}; use std::{borrow::Cow, cell::UnsafeCell, hash::Hash, ops::Deref, sync::OnceLock}; @@ -56,19 +56,10 @@ impl<'l> Instruction<'l> { pub fn id(&self) -> u32 { unsafe { *self.id.0.get() } } +} - pub fn value_flags(&self) -> ValueFlags { - match self.variant { - InstructionVariant::StackAlloc(_) => ValueFlags::LValue, - InstructionVariant::GetElementPtr(v, _) if v.flags().contains(ValueFlags::LValue) => { - ValueFlags::LValue - } - InstructionVariant::Reinterpret(_, _, f) => f, - _ => ValueFlags::empty(), - } - } - - pub fn value_ty(&'l self) -> Type<'l> { +impl<'l> Value<'l> for &'l Instruction<'l> { + fn ty(&self) -> Type<'l> { match self.variant { InstructionVariant::Return(_) => Type::Void, InstructionVariant::Store(_, _) => Type::Void, @@ -88,6 +79,15 @@ impl<'l> Instruction<'l> { }) => base.make_ref(*mutable).into(), _ => unreachable!(), }, + InstructionVariant::GetElementVal(v, i) => match v.ty() { + Type::Struct(StructT { fields, .. }) => { + let AnyValue::Constant(AnyConst::Str(name)) = i else { + unreachable!() + }; + fields.get().unwrap()[name].ty + } + _ => unreachable!(), + }, InstructionVariant::IAdd(a, _) => a.ty(), InstructionVariant::ISub(a, _) => a.ty(), InstructionVariant::IMul(a, _) => a.ty(), @@ -96,12 +96,15 @@ impl<'l> Instruction<'l> { InstructionVariant::SExt(_, t) => Type::Int(t), InstructionVariant::ZExt(_, t) => Type::Int(t), InstructionVariant::Trunc(_, t) => Type::Int(t), + InstructionVariant::IntToPtr(_, t) => Type::Ptr(t), + InstructionVariant::PtrToInt(_, t) => Type::Int(t), InstructionVariant::FAdd(a, _) => a.ty(), InstructionVariant::FSub(a, _) => a.ty(), InstructionVariant::FMul(a, _) => a.ty(), InstructionVariant::FDiv(a, _) => a.ty(), InstructionVariant::FMod(a, _) => a.ty(), InstructionVariant::ICmp(_, _, _) => Type::Bool, + InstructionVariant::MakeStruct(t, _) => Type::Struct(t), InstructionVariant::Call(f, _) => f.ty.ret_t, InstructionVariant::Jump(_) => Type::Void, InstructionVariant::Branch { .. } => Type::Void, @@ -109,6 +112,23 @@ impl<'l> Instruction<'l> { _ => todo!("{self:?}"), } } + + fn flags(&self) -> ValueFlags { + match self.variant { + InstructionVariant::StackAlloc(_) => ValueFlags::LValue, + InstructionVariant::GetElementPtr(v, _) if v.is_lvalue() => ValueFlags::LValue, + InstructionVariant::Reinterpret(_, _, f) => f, + _ => ValueFlags::empty(), + } + } + + fn get_associated_value(&self, name: &str) -> Option> { + default_associated_values(self, name) + } + + fn as_any_value(&self) -> AnyValue<'l> { + AnyValue::Instruction(self) + } } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -126,37 +146,42 @@ pub enum InstructionVariant<'l> { StackAlloc(Type<'l>), GCAlloc(Type<'l>), - Load(Value<'l>), - Store(Value<'l>, Value<'l>), - GetElementPtr(Value<'l>, Value<'l>), + Load(AnyValue<'l>), + Store(AnyValue<'l>, AnyValue<'l>), + GetElementVal(AnyValue<'l>, AnyValue<'l>), + GetElementPtr(AnyValue<'l>, AnyValue<'l>), - IAdd(Value<'l>, Value<'l>), - ISub(Value<'l>, Value<'l>), - IMul(Value<'l>, Value<'l>), - IDiv(Value<'l>, Value<'l>), - IMod(Value<'l>, Value<'l>), - SExt(Value<'l>, IntT), - ZExt(Value<'l>, IntT), - Trunc(Value<'l>, IntT), - FAdd(Value<'l>, Value<'l>), - FSub(Value<'l>, Value<'l>), - FMul(Value<'l>, Value<'l>), - FDiv(Value<'l>, Value<'l>), - FMod(Value<'l>, Value<'l>), + IAdd(AnyValue<'l>, AnyValue<'l>), + ISub(AnyValue<'l>, AnyValue<'l>), + IMul(AnyValue<'l>, AnyValue<'l>), + IDiv(AnyValue<'l>, AnyValue<'l>), + IMod(AnyValue<'l>, AnyValue<'l>), + SExt(AnyValue<'l>, IntT), + ZExt(AnyValue<'l>, IntT), + Trunc(AnyValue<'l>, IntT), + IntToPtr(AnyValue<'l>, &'l PtrT<'l>), + PtrToInt(AnyValue<'l>, IntT), + FAdd(AnyValue<'l>, AnyValue<'l>), + FSub(AnyValue<'l>, AnyValue<'l>), + FMul(AnyValue<'l>, AnyValue<'l>), + FDiv(AnyValue<'l>, AnyValue<'l>), + FMod(AnyValue<'l>, AnyValue<'l>), - ICmp(Value<'l>, Value<'l>, Cmp), - FCmp(Value<'l>, Value<'l>, Cmp), + ICmp(AnyValue<'l>, AnyValue<'l>, Cmp), + FCmp(AnyValue<'l>, AnyValue<'l>, Cmp), - Call(&'l Function<'l>, Vec>), + MakeStruct(&'l StructT<'l>, &'l [AnyValue<'l>]), + + Call(&'l Function<'l>, Vec>), Jump(&'l Block<'l>), Branch { - cond: Value<'l>, + cond: AnyValue<'l>, true_case: &'l Block<'l>, false_case: &'l Block<'l>, }, - Return(Option>), + Return(Option>), - Reinterpret(Value<'l>, Type<'l>, ValueFlags), + Reinterpret(AnyValue<'l>, Type<'l>, ValueFlags), } impl InstructionVariant<'_> { @@ -172,7 +197,10 @@ impl std::fmt::Debug for InstructionVariant<'_> { Self::GCAlloc(ty) => write!(f, "gcalloc {ty}"), Self::Load(v) => write!(f, "load {v}"), Self::Store(t, v) => write!(f, "store {t}, {v}"), - Self::GetElementPtr(t, v) => write!(f, "gep {t}, {v}"), + Self::GetElementVal(v, i) => write!(f, "gev {v}, {i}"), + Self::GetElementPtr(v, i) => write!(f, "gep {v}, {i}"), + Self::IntToPtr(v, t) => write!(f, "itp {v}, {t}"), + Self::PtrToInt(v, t) => write!(f, "pti {v}, {t}"), Self::IAdd(a, b) => write!(f, "iadd {a}, {b}"), Self::ISub(a, b) => write!(f, "isub {a}, {b}"), Self::IMul(a, b) => write!(f, "imul {a}, {b}"), @@ -188,11 +216,32 @@ impl std::fmt::Debug for InstructionVariant<'_> { Self::FMod(a, b) => write!(f, "fmod {a}, {b}"), Self::ICmp(a, b, c) => write!(f, "icmp {c:?} {a}, {b}"), Self::FCmp(a, b, c) => write!(f, "fcmp {c:?} {a}, {b}"), + + Self::MakeStruct(ty, vals) => { + write!( + f, + "struct {} {{", + match ty.name { + "" => "", + _ => ty.name, + } + )?; + let mut separator = ""; + for val in *vals { + write!(f, "{separator}{val}")?; + separator = ", "; + } + write!(f, "}}") + } + Self::Call(func, args) => { write!( f, "call {}(", - func.name.get().unwrap_or(&"") + match func.name { + "" => "", + _ => func.name, + } )?; let mut separator = ""; for arg in args { @@ -247,16 +296,16 @@ pub type BlockBuilderError<'l> = Cow<'l, str>; pub type BlockBuilderResult<'l, T> = Result>; impl<'l> BlockBuilder<'l> { - pub fn stack_alloc(&mut self, ty: Type<'l>) -> BlockBuilderResult<'l, Value<'l>> { + pub fn stack_alloc(&mut self, ty: Type<'l>) -> BlockBuilderResult<'l, AnyValue<'l>> { let inst = self.push_instruction(InstructionVariant::StackAlloc(ty))?; Ok(inst.into()) } pub fn store( &mut self, - target: Value<'l>, - value: Value<'l>, - ) -> BlockBuilderResult<'l, Value<'l>> { + target: AnyValue<'l>, + value: AnyValue<'l>, + ) -> BlockBuilderResult<'l, AnyValue<'l>> { let value_ty = value.ty(); let target_ty = target.ty(); match target_ty { @@ -281,7 +330,7 @@ impl<'l> BlockBuilder<'l> { Ok(inst.into()) } - pub fn load(&mut self, value: Value<'l>) -> BlockBuilderResult<'l, Value<'l>> { + pub fn load(&mut self, value: AnyValue<'l>) -> BlockBuilderResult<'l, AnyValue<'l>> { let value_ty = value.ty(); match value_ty { Type::Ptr(PtrT { .. }) => {} @@ -294,7 +343,11 @@ impl<'l> BlockBuilder<'l> { Ok(inst.into()) } - pub fn add(&mut self, a: Value<'l>, b: Value<'l>) -> BlockBuilderResult<'l, Value<'l>> { + pub fn add( + &mut self, + a: AnyValue<'l>, + b: AnyValue<'l>, + ) -> BlockBuilderResult<'l, AnyValue<'l>> { let [a_ty, b_ty] = [a.ty(), b.ty()]; match (a_ty, b_ty) { @@ -310,7 +363,11 @@ impl<'l> BlockBuilder<'l> { } } - pub fn sub(&mut self, a: Value<'l>, b: Value<'l>) -> BlockBuilderResult<'l, Value<'l>> { + pub fn sub( + &mut self, + a: AnyValue<'l>, + b: AnyValue<'l>, + ) -> BlockBuilderResult<'l, AnyValue<'l>> { let [a_ty, b_ty] = [a.ty(), b.ty()]; match (a_ty, b_ty) { @@ -326,7 +383,11 @@ impl<'l> BlockBuilder<'l> { } } - pub fn mul(&mut self, a: Value<'l>, b: Value<'l>) -> BlockBuilderResult<'l, Value<'l>> { + pub fn mul( + &mut self, + a: AnyValue<'l>, + b: AnyValue<'l>, + ) -> BlockBuilderResult<'l, AnyValue<'l>> { let [a_ty, b_ty] = [a.ty(), b.ty()]; match (a_ty, b_ty) { @@ -342,7 +403,11 @@ impl<'l> BlockBuilder<'l> { } } - pub fn div(&mut self, a: Value<'l>, b: Value<'l>) -> BlockBuilderResult<'l, Value<'l>> { + pub fn div( + &mut self, + a: AnyValue<'l>, + b: AnyValue<'l>, + ) -> BlockBuilderResult<'l, AnyValue<'l>> { let [a_ty, b_ty] = [a.ty(), b.ty()]; match (a_ty, b_ty) { @@ -358,7 +423,11 @@ impl<'l> BlockBuilder<'l> { } } - pub fn modulo(&mut self, a: Value<'l>, b: Value<'l>) -> BlockBuilderResult<'l, Value<'l>> { + pub fn modulo( + &mut self, + a: AnyValue<'l>, + b: AnyValue<'l>, + ) -> BlockBuilderResult<'l, AnyValue<'l>> { let [a_ty, b_ty] = [a.ty(), b.ty()]; match (a_ty, b_ty) { @@ -374,7 +443,11 @@ impl<'l> BlockBuilder<'l> { } } - pub fn trunc(&mut self, val: Value<'l>, target: IntT) -> BlockBuilderResult<'l, Value<'l>> { + pub fn trunc( + &mut self, + val: AnyValue<'l>, + target: IntT, + ) -> BlockBuilderResult<'l, AnyValue<'l>> { let ty = val.ty(); match ty { Type::Int(a_ty) if a_ty.precision > target.precision => { @@ -387,12 +460,29 @@ impl<'l> BlockBuilder<'l> { } } + pub fn int_to_ptr( + &mut self, + val: AnyValue<'l>, + target: &'l PtrT<'l>, + ) -> BlockBuilderResult<'l, AnyValue<'l>> { + let ty = val.ty(); + match ty { + Type::Int(_) => { + let inst = self.push_instruction(InstructionVariant::IntToPtr(val, target))?; + Ok(inst.into()) + } + _ => Err( + format!("Cannot convert value of type `{ty}` to pointer type `{target}`.").into(), + ), + } + } + pub fn cmp( &mut self, - a: Value<'l>, - b: Value<'l>, + a: AnyValue<'l>, + b: AnyValue<'l>, cmp: Cmp, - ) -> BlockBuilderResult<'l, Value<'l>> { + ) -> BlockBuilderResult<'l, AnyValue<'l>> { let [a_ty, b_ty] = [a.ty(), b.ty()]; match (a_ty, b_ty) { @@ -408,7 +498,11 @@ impl<'l> BlockBuilder<'l> { } } - pub fn gep(&mut self, value: Value<'l>, index: Value<'l>) -> BlockBuilderResult<'l, Value<'l>> { + pub fn get_element_ptr( + &mut self, + value: AnyValue<'l>, + index: AnyValue<'l>, + ) -> BlockBuilderResult<'l, AnyValue<'l>> { let v_ty = value.ty(); let i_ty = index.ty(); @@ -432,7 +526,62 @@ impl<'l> BlockBuilder<'l> { Ok(inst.into()) } - pub fn jump(&mut self, block: &'l Block<'l>) -> BlockBuilderResult<'l, Value<'l>> { + pub fn get_element_value( + &mut self, + value: AnyValue<'l>, + index: AnyValue<'l>, + ) -> BlockBuilderResult<'l, AnyValue<'l>> { + let v_ty = value.ty(); + + match v_ty { + Type::Struct(StructT { fields, .. }) => 'val: { + let AnyValue::Constant(AnyConst::Str(field)) = index else { + return Err( + format!("Expeted index type `const str`, found {}.", index.ty()).into(), + ); + }; + if let Some(fields) = fields.get() { + if fields.contains_key(field) { + let inst = + self.push_instruction(InstructionVariant::GetElementVal(value, index))?; + break 'val Ok(inst.into()); + } + } + Err(format!("Struct does not contain field `{field}`.").into()) + } + _ => Err(format!("Cannot index a value of type `{}`.", v_ty).into()), + } + } + + pub fn make_struct( + &mut self, + struct_ty: &'l StructT<'l>, + values: &'l [AnyValue<'l>], + ) -> BlockBuilderResult<'l, AnyValue<'l>> { + let fields = struct_ty.fields.get().unwrap(); + if fields.len() != values.len() { + return Err( + format!("Expected {} values, found {}.", fields.len(), values.len()).into(), + ); + } + if let Some((i, (a, b))) = fields + .values() + .zip(values) + .enumerate() + .find(|(_, (a, b))| a.ty != b.ty()) + { + return Err(format!( + "Invalid valua at position {i}. Expected type `{}`, found `{}`.", + a.ty, + b.ty() + ) + .into()); + } + let inst = self.push_instruction(InstructionVariant::MakeStruct(struct_ty, values))?; + return Ok(inst.into()); + } + + pub fn jump(&mut self, block: &'l Block<'l>) -> BlockBuilderResult<'l, AnyValue<'l>> { if !std::ptr::eq(block.func, self.block.func) { return Err("Block does not belong to this function.".into()); } @@ -442,10 +591,10 @@ impl<'l> BlockBuilder<'l> { pub fn branch( &mut self, - cond: Value<'l>, + cond: AnyValue<'l>, true_case: &'l Block<'l>, false_case: &'l Block<'l>, - ) -> BlockBuilderResult<'l, Value<'l>> { + ) -> BlockBuilderResult<'l, AnyValue<'l>> { if !std::ptr::eq(true_case.func, self.block.func) { return Err("Block does not belong to this function.".into()); } @@ -466,8 +615,8 @@ impl<'l> BlockBuilder<'l> { pub fn call( &mut self, func: &'l Function<'l>, - args: Vec>, - ) -> BlockBuilderResult<'l, Value<'l>> { + args: Vec>, + ) -> BlockBuilderResult<'l, AnyValue<'l>> { let par_t = &*func.ty.par_t; if par_t.len() != args.len() { return Err(format!( @@ -491,7 +640,7 @@ impl<'l> BlockBuilder<'l> { Ok(inst.into()) } - pub fn ret(&mut self, value: Option>) -> BlockBuilderResult<'l, Value<'l>> { + pub fn ret(&mut self, value: Option>) -> BlockBuilderResult<'l, AnyValue<'l>> { let ret_t = self.block.func.ty.ret_t; let value_ty = match value { Some(v) => v.ty(), @@ -514,10 +663,10 @@ impl<'l> BlockBuilder<'l> { /// The target reinterpretation must maintain the IR's invariants. pub unsafe fn reinterpret( &mut self, - value: Value<'l>, + value: AnyValue<'l>, ty: Type<'l>, flags: ValueFlags, - ) -> BlockBuilderResult<'l, Value<'l>> { + ) -> BlockBuilderResult<'l, AnyValue<'l>> { let inst = self.push_instruction(InstructionVariant::Reinterpret(value, ty, flags))?; Ok(inst.into()) } @@ -620,81 +769,133 @@ impl<'l> FunctionBodyBuilder<'l> { Ok(self.func.body.get().unwrap()) } - pub fn stack_alloc(&mut self, ty: Type<'l>) -> BlockBuilderResult<'l, Value<'l>> { + pub fn stack_alloc(&mut self, ty: Type<'l>) -> BlockBuilderResult<'l, AnyValue<'l>> { self.current_builder().stack_alloc(ty) } pub fn store( &mut self, - target: Value<'l>, - value: Value<'l>, - ) -> BlockBuilderResult<'l, Value<'l>> { + target: AnyValue<'l>, + value: AnyValue<'l>, + ) -> BlockBuilderResult<'l, AnyValue<'l>> { self.current_builder().store(target, value) } - pub fn load(&mut self, value: Value<'l>) -> BlockBuilderResult<'l, Value<'l>> { + pub fn load(&mut self, value: AnyValue<'l>) -> BlockBuilderResult<'l, AnyValue<'l>> { self.current_builder().load(value) } - pub fn add(&mut self, a: Value<'l>, b: Value<'l>) -> BlockBuilderResult<'l, Value<'l>> { + pub fn add( + &mut self, + a: AnyValue<'l>, + b: AnyValue<'l>, + ) -> BlockBuilderResult<'l, AnyValue<'l>> { self.current_builder().add(a, b) } - pub fn sub(&mut self, a: Value<'l>, b: Value<'l>) -> BlockBuilderResult<'l, Value<'l>> { + pub fn sub( + &mut self, + a: AnyValue<'l>, + b: AnyValue<'l>, + ) -> BlockBuilderResult<'l, AnyValue<'l>> { self.current_builder().sub(a, b) } - pub fn mul(&mut self, a: Value<'l>, b: Value<'l>) -> BlockBuilderResult<'l, Value<'l>> { + pub fn mul( + &mut self, + a: AnyValue<'l>, + b: AnyValue<'l>, + ) -> BlockBuilderResult<'l, AnyValue<'l>> { self.current_builder().mul(a, b) } - pub fn div(&mut self, a: Value<'l>, b: Value<'l>) -> BlockBuilderResult<'l, Value<'l>> { + pub fn div( + &mut self, + a: AnyValue<'l>, + b: AnyValue<'l>, + ) -> BlockBuilderResult<'l, AnyValue<'l>> { self.current_builder().div(a, b) } - pub fn modulo(&mut self, a: Value<'l>, b: Value<'l>) -> BlockBuilderResult<'l, Value<'l>> { + pub fn modulo( + &mut self, + a: AnyValue<'l>, + b: AnyValue<'l>, + ) -> BlockBuilderResult<'l, AnyValue<'l>> { self.current_builder().modulo(a, b) } - pub fn trunc(&mut self, val: Value<'l>, target: IntT) -> BlockBuilderResult<'l, Value<'l>> { + pub fn trunc( + &mut self, + val: AnyValue<'l>, + target: IntT, + ) -> BlockBuilderResult<'l, AnyValue<'l>> { self.current_builder().trunc(val, target) } + pub fn int_to_ptr( + &mut self, + val: AnyValue<'l>, + target: &'l PtrT<'l>, + ) -> BlockBuilderResult<'l, AnyValue<'l>> { + self.current_builder().int_to_ptr(val, target) + } + pub fn cmp( &mut self, - a: Value<'l>, - b: Value<'l>, + a: AnyValue<'l>, + b: AnyValue<'l>, cmp: Cmp, - ) -> BlockBuilderResult<'l, Value<'l>> { + ) -> BlockBuilderResult<'l, AnyValue<'l>> { self.current_builder().cmp(a, b, cmp) } - pub fn gep(&mut self, value: Value<'l>, index: Value<'l>) -> BlockBuilderResult<'l, Value<'l>> { - self.current_builder().gep(value, index) + pub fn get_element_ptr( + &mut self, + value: AnyValue<'l>, + index: AnyValue<'l>, + ) -> BlockBuilderResult<'l, AnyValue<'l>> { + self.current_builder().get_element_ptr(value, index) } - pub fn jump(&mut self, block: &'l Block<'l>) -> BlockBuilderResult<'l, Value<'l>> { + pub fn get_element_value( + &mut self, + value: AnyValue<'l>, + index: AnyValue<'l>, + ) -> BlockBuilderResult<'l, AnyValue<'l>> { + self.current_builder().get_element_value(value, index) + } + + pub fn make_struct( + &mut self, + struct_ty: &'l StructT<'l>, + values: &'l [AnyValue<'l>], + ) -> BlockBuilderResult<'l, AnyValue<'l>> { + self.current_builder().make_struct(struct_ty, values) + } + + pub fn jump(&mut self, block: &'l Block<'l>) -> BlockBuilderResult<'l, AnyValue<'l>> { self.current_builder().jump(block) } pub fn branch( &mut self, - cond: Value<'l>, + cond: AnyValue<'l>, true_case: &'l Block<'l>, false_case: &'l Block<'l>, - ) -> BlockBuilderResult<'l, Value<'l>> { + ) -> BlockBuilderResult<'l, AnyValue<'l>> { self.current_builder().branch(cond, true_case, false_case) } pub fn call( &mut self, func: &'l Function<'l>, - args: Vec>, - ) -> BlockBuilderResult<'l, Value<'l>> { + args: Vec>, + ) -> BlockBuilderResult<'l, AnyValue<'l>> { self.current_builder().call(func, args) } - pub fn ret(&mut self, value: Option>) -> BlockBuilderResult<'l, Value<'l>> { + pub fn ret(&mut self, value: Option>) -> BlockBuilderResult<'l, AnyValue<'l>> { self.current_builder().ret(value) } @@ -705,10 +906,10 @@ impl<'l> FunctionBodyBuilder<'l> { /// The target reinterpretation must maintain the IR's invariants. pub unsafe fn reinterpret( &mut self, - value: Value<'l>, + value: AnyValue<'l>, ty: Type<'l>, flags: ValueFlags, - ) -> BlockBuilderResult<'l, Value<'l>> { + ) -> BlockBuilderResult<'l, AnyValue<'l>> { unsafe { self.current_builder().reinterpret(value, ty, flags) } } diff --git a/assembly/src/functions/mod.rs b/assembly/src/functions/mod.rs index 94adff6..71cae3d 100644 --- a/assembly/src/functions/mod.rs +++ b/assembly/src/functions/mod.rs @@ -1,7 +1,8 @@ use crate::{ assembly::{Assembly, Ctx}, functions::ir::{Block, FunctionBodyBuilder}, - types::derivations::FuncT, + types::{Type, derivations::FuncT}, + values::{AnyConst, AnyValue, Value, ValueFlags, default_associated_values}, }; use std::{ fmt::{Debug as FmtDebug, Display}, @@ -14,7 +15,7 @@ pub mod ir; #[non_exhaustive] pub struct Function<'l> { pub ty: &'l FuncT<'l>, - pub name: OnceLock<&'l str>, + pub name: &'l str, pub(crate) body: OnceLock>, pub(crate) declaring_assembly: &'l Assembly<'l>, } @@ -40,6 +41,28 @@ impl<'l> Function<'l> { } } +impl<'l> Value<'l> for &'l Function<'l> { + #[inline] + fn ty(&self) -> Type<'l> { + Type::Func(self.ty) + } + + #[inline] + fn flags(&self) -> ValueFlags { + ValueFlags::Function + } + + #[inline] + fn get_associated_value(&self, name: &str) -> Option> { + default_associated_values(self, name) + } + + #[inline] + fn as_any_value(&self) -> AnyValue<'l> { + AnyValue::Constant(AnyConst::Function(self)) + } +} + impl Eq for Function<'_> {} impl PartialEq for Function<'_> { @@ -50,9 +73,9 @@ impl PartialEq for Function<'_> { impl Display for Function<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self.name.get() { - Some(n) => write!(f, "{} @ {}", self.ty, n), - None => Display::fmt(&self.ty, f), + match self.name { + "" => Display::fmt(&self.ty, f), + _ => write!(f, "{} @ {}", self.ty, self.name), } } } @@ -66,7 +89,7 @@ impl FmtDebug for Function<'_> { f.debug_struct("Function") .field("ty", &format_args!("{}", self.ty)) - .field("name", &self.name.get()) + .field("name", &self.name) .field("body", body) .finish_non_exhaustive() } diff --git a/assembly/src/types/compound.rs b/assembly/src/types/compound.rs new file mode 100644 index 0000000..ac2b3f3 --- /dev/null +++ b/assembly/src/types/compound.rs @@ -0,0 +1,71 @@ +use crate::{ + assembly::Assembly, + types::Type, + values::{AnyConst, AnyValue, Value, ValueFlags}, +}; +use derive_more::{Debug, Display}; +use fxhash::FxBuildHasher; +use indexmap::IndexMap; +use std::{hash::Hash, sync::OnceLock}; + +pub type FieldMap<'l> = IndexMap<&'l str, Field<'l>, FxBuildHasher>; + +#[derive(Debug, Display, Clone)] +#[display("{}", if name.is_empty() { "" } else { name })] +pub struct StructT<'l> { + pub name: &'l str, + pub fields: OnceLock>, + #[debug(ignore)] + pub declaring_assembly: &'l Assembly<'l>, +} + +impl Eq for StructT<'_> {} + +impl PartialEq for StructT<'_> { + fn eq(&self, other: &Self) -> bool { + std::ptr::eq(self, other) + } +} + +impl Hash for StructT<'_> { + fn hash(&self, state: &mut H) { + std::ptr::hash(self, state); + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct Field<'l> { + pub name: &'l str, + pub ty: Type<'l>, + pub public: bool, + pub mutable: bool, +} + +impl<'l> Value<'l> for &'l StructT<'l> { + fn ty(&self) -> Type<'l> { + Type::Type + } + + fn flags(&self) -> ValueFlags { + ValueFlags::Type + } + + fn get_associated_value(&self, name: &str) -> Option> { + todo!() + } + + fn as_any_value(&self) -> AnyValue<'l> { + AnyValue::Constant(AnyConst::Type(Type::Struct(self))) + } +} + +struct DebugFields<'l, 'r>(&'r OnceLock>); + +impl std::fmt::Debug for DebugFields<'_, '_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.0.get() { + Some(v) => std::fmt::Debug::fmt(v, f), + None => f.write_str("{}"), + } + } +} diff --git a/assembly/src/types/derivations.rs b/assembly/src/types/derivations.rs index 4b2013b..6d0fb00 100644 --- a/assembly/src/types/derivations.rs +++ b/assembly/src/types/derivations.rs @@ -61,13 +61,12 @@ impl<'l> TypeDerivations<'l> { .fun_t .entry_sync((ret_t, par_t.clone())) .or_insert_with(|| { - let ctx = ret_t.ctx(); - for ty in par_t.iter() { - if ty.ctx() != ctx { - panic!("All types must share the same context."); - } - } - (&*self.alloc.alloc(FuncT { ret_t, par_t })).into() + (&*self.alloc.alloc(FuncT { + ret_t, + par_t, + par_t_const: OnceLock::new(), + })) + .into() }) else { unreachable!() @@ -99,6 +98,32 @@ pub struct PtrT<'l> { pub mutable: bool, } +impl<'l> Value<'l> for &'l PtrT<'l> { + #[inline] + fn ty(&self) -> Type<'l> { + Type::Type + } + + #[inline] + fn flags(&self) -> ValueFlags { + ValueFlags::Type + } + + #[inline] + fn get_associated_value(&self, name: &str) -> Option> { + match name { + "base" => Some(self.base.as_any_value()), + "mutable" => Some(self.mutable.as_any_value()), + _ => default_associated_values(self, name), + } + } + + #[inline] + fn as_any_value(&self) -> AnyValue<'l> { + AnyValue::Constant(AnyConst::Type(Type::Ptr(self))) + } +} + #[non_exhaustive] #[derive(Debug, Display, Clone, Copy, PartialEq, Eq, Hash)] #[display("&{}{}", if *mutable { "mut " } else { "" }, *base)] @@ -108,6 +133,32 @@ pub struct RefT<'l> { pub mutable: bool, } +impl<'l> Value<'l> for &'l RefT<'l> { + #[inline] + fn ty(&self) -> Type<'l> { + Type::Type + } + + #[inline] + fn flags(&self) -> ValueFlags { + ValueFlags::Type + } + + #[inline] + fn get_associated_value(&self, name: &str) -> Option> { + match name { + "base" => Some(self.base.as_any_value()), + "mutable" => Some(self.mutable.as_any_value()), + _ => default_associated_values(self, name), + } + } + + #[inline] + fn as_any_value(&self) -> AnyValue<'l> { + AnyValue::Constant(AnyConst::Type(Type::Ref(self))) + } +} + #[non_exhaustive] #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct ArrayT<'l> { @@ -116,6 +167,35 @@ pub struct ArrayT<'l> { pub length: Option, } +impl<'l> Value<'l> for &'l ArrayT<'l> { + #[inline] + fn ty(&self) -> Type<'l> { + Type::Type + } + + #[inline] + fn flags(&self) -> ValueFlags { + ValueFlags::Type + } + + #[inline] + fn get_associated_value(&self, name: &str) -> Option> { + match name { + "base" => Some(self.base.as_any_value()), + "length" => match self.length { + None => None, + Some(len) => Some(len.as_any_value()), + }, + _ => default_associated_values(self, name), + } + } + + #[inline] + fn as_any_value(&self) -> AnyValue<'l> { + AnyValue::Constant(AnyConst::Type(Type::Array(self))) + } +} + impl Display for ArrayT<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self.length { @@ -126,10 +206,55 @@ impl Display for ArrayT<'_> { } #[non_exhaustive] -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone)] pub struct FuncT<'l> { pub ret_t: Type<'l>, pub par_t: Arc<[Type<'l>]>, + par_t_const: OnceLock>>, +} + +impl Eq for FuncT<'_> {} + +impl PartialEq for FuncT<'_> { + fn eq(&self, other: &Self) -> bool { + self.ret_t == other.ret_t && self.par_t == other.par_t + } +} + +impl Hash for FuncT<'_> { + fn hash(&self, state: &mut H) { + self.ret_t.hash(state); + self.par_t.hash(state); + } +} + +impl<'l> Value<'l> for &'l FuncT<'l> { + #[inline] + fn ty(&self) -> Type<'l> { + Type::Type + } + + #[inline] + fn flags(&self) -> ValueFlags { + ValueFlags::Type + } + + #[inline] + fn get_associated_value(&self, name: &str) -> Option> { + match name { + "ret_t" => Some(self.ret_t.as_any_value()), + "par_t" => Some(AnyValue::Constant(AnyConst::Array( + self.par_t_const + .get_or_init(|| self.par_t.iter().map(|t| AnyConst::Type(*t)).collect()), + ))), + _ => default_associated_values(self, name), + } + } + + #[inline] + fn as_any_value(&self) -> AnyValue<'l> { + AnyValue::Constant(AnyConst::Type(Type::Func(self))) + } } impl Display for FuncT<'_> { diff --git a/assembly/src/types/mod.rs b/assembly/src/types/mod.rs index aa7a63a..e61f94c 100644 --- a/assembly/src/types/mod.rs +++ b/assembly/src/types/mod.rs @@ -1,8 +1,13 @@ -use crate::{assembly::Context, types::derivations::*}; +use crate::{ + assembly::Context, + types::{compound::*, derivations::*}, + values::{AnyConst, AnyValue, Value, ValueFlags, default_associated_values}, +}; use derive_more::{Debug, Display, From, TryInto}; use leaf_allocators::SyncArenaAllocator; use std::{fmt::Display, sync::OnceLock}; +pub mod compound; pub mod derivations; #[non_exhaustive] @@ -18,6 +23,32 @@ pub struct IntT { pub precision: u32, } +impl<'l> Value<'l> for IntT { + #[inline] + fn ty(&self) -> Type<'l> { + Type::Type + } + + #[inline] + fn flags(&self) -> ValueFlags { + ValueFlags::Type + } + + #[inline] + fn get_associated_value(&self, name: &str) -> Option> { + match name { + "signed" => Some(self.signed.as_any_value()), + "precision" => Some(self.precision.as_any_value()), + _ => default_associated_values(self, name), + } + } + + #[inline] + fn as_any_value(&self) -> AnyValue<'l> { + AnyValue::Constant(AnyConst::Type(Type::Int(*self))) + } +} + #[non_exhaustive] #[derive(Debug, Display, Clone, Copy, PartialEq, Eq, Hash)] #[display("f{}", precision)] @@ -26,6 +57,31 @@ pub struct FloatT { pub precision: u32, } +impl<'l> Value<'l> for FloatT { + #[inline] + fn ty(&self) -> Type<'l> { + Type::Type + } + + #[inline] + fn flags(&self) -> ValueFlags { + ValueFlags::Type + } + + #[inline] + fn get_associated_value(&self, name: &str) -> Option> { + match name { + "precision" => Some(self.precision.as_any_value()), + _ => default_associated_values(self, name), + } + } + + #[inline] + fn as_any_value(&self) -> AnyValue<'l> { + AnyValue::Constant(AnyConst::Type(Type::Float(*self))) + } +} + #[non_exhaustive] #[derive(Debug, Display, Clone, Copy, From, TryInto, PartialEq, Eq, Hash)] pub enum Type<'l> { @@ -58,6 +114,10 @@ pub enum Type<'l> { #[debug("{_0:?}")] #[display("{_0}")] Func(&'l FuncT<'l>), + + #[debug("{_0:?}")] + #[display("{_0}")] + Struct(&'l StructT<'l>), } impl<'l> Type<'l> { @@ -65,16 +125,7 @@ impl<'l> Type<'l> { pub fn ctx(&self) -> &'l Context<'l> { match self.non_default_ctx() { Some(ctx) => ctx, - None => unsafe { - static DEFAULT: OnceLock<&'static Context> = OnceLock::new(); - static ALLOCATOR: OnceLock = OnceLock::new(); - let ctx: &'static Context = DEFAULT.get_or_init(|| { - let allocator: &'static SyncArenaAllocator = - ALLOCATOR.get_or_init(SyncArenaAllocator::default); - Context::new(allocator) - }); - std::mem::transmute(ctx) - }, + None => Self::default_ctx(), } } @@ -94,8 +145,55 @@ impl<'l> Type<'l> { Some(ctx) => Some(ctx), None => f.par_t.iter().find_map(|t| t.non_default_ctx()), }, + Type::Struct(s) => Some(s.declaring_assembly.ctx()), } } + + fn default_ctx() -> &'l Context<'l> { + static ALLOCATOR: OnceLock = OnceLock::new(); + static DEFAULT: OnceLock<&'static Context> = OnceLock::new(); + let ctx: &'static Context = DEFAULT.get_or_init(|| { + let allocator: &'static SyncArenaAllocator = + ALLOCATOR.get_or_init(SyncArenaAllocator::default); + Context::new(allocator) + }); + unsafe { std::mem::transmute(ctx) } + } +} + +impl<'l> Value<'l> for Type<'l> { + #[inline] + fn ty(&self) -> Type<'l> { + Type::Type + } + + #[inline] + fn flags(&self) -> ValueFlags { + ValueFlags::Type + } + + #[inline] + fn get_associated_value(&self, name: &str) -> Option> { + match self { + Type::Void => default_associated_values(self, name), + Type::Char => default_associated_values(self, name), + Type::Bool => default_associated_values(self, name), + Type::Type => default_associated_values(self, name), + Type::ConstStr => default_associated_values(self, name), + Type::Int(t) => t.get_associated_value(name), + Type::Float(t) => t.get_associated_value(name), + Type::Ptr(t) => t.get_associated_value(name), + Type::Ref(t) => t.get_associated_value(name), + Type::Array(t) => t.get_associated_value(name), + Type::Func(t) => t.get_associated_value(name), + Type::Struct(t) => t.get_associated_value(name), + } + } + + #[inline] + fn as_any_value(&self) -> AnyValue<'l> { + AnyValue::Constant(AnyConst::Type(*self)) + } } #[rustfmt::skip] diff --git a/assembly/src/values/constants.rs b/assembly/src/values/constants.rs index af36bb6..bcb9d82 100644 --- a/assembly/src/values/constants.rs +++ b/assembly/src/values/constants.rs @@ -1,7 +1,7 @@ use crate::{ functions::Function, - types::{Type, derivations::ArrayT}, - values::ValueFlags, + types::{Type, compound::StructT, derivations::ArrayT}, + values::{AnyValue, Value, ValueFlags, default_associated_values}, }; use derive_more::{Debug, *}; use half::f16; @@ -12,18 +12,66 @@ pub enum Int { I8(i8), I16(i16), I32(i32), - I64(i64), #[from] + I64(i64), I128(i128), - ISize(i128), + ISize(i64), U8(u8), U16(u16), U32(u32), - U64(u64), #[from] + U64(u64), U128(u128), - USize(u128), + USize(u64), +} + +impl<'l> Value<'l> for Int { + #[inline] + fn ty(&self) -> Type<'l> { + match self { + Int::I8(v) => v.ty(), + Int::I16(v) => v.ty(), + Int::I32(v) => v.ty(), + Int::I64(v) => v.ty(), + Int::I128(v) => v.ty(), + Int::ISize(_) => Type::ISIZE, + Int::U8(v) => v.ty(), + Int::U16(v) => v.ty(), + Int::U32(v) => v.ty(), + Int::U64(v) => v.ty(), + Int::U128(v) => v.ty(), + Int::USize(_) => Type::USIZE, + } + } + + #[inline] + fn flags(&self) -> ValueFlags { + default_value_flags(self) + } + + #[inline] + fn get_associated_value(&self, name: &str) -> Option> { + match self { + Int::I8(v) => v.get_associated_value(name), + Int::I16(v) => v.get_associated_value(name), + Int::I32(v) => v.get_associated_value(name), + Int::I64(v) => v.get_associated_value(name), + Int::I128(v) => v.get_associated_value(name), + Int::ISize(v) => v.get_associated_value(name), + Int::U8(v) => v.get_associated_value(name), + Int::U16(v) => v.get_associated_value(name), + Int::U32(v) => v.get_associated_value(name), + Int::U64(v) => v.get_associated_value(name), + Int::U128(v) => v.get_associated_value(name), + Int::USize(v) => v.get_associated_value(name), + } + } + + #[inline] + fn as_any_value(&self) -> AnyValue<'l> { + AnyValue::Constant(AnyConst::Int(*self)) + } } #[derive(Debug, Display, Clone, Copy, From, TryInto, PartialEq)] @@ -33,6 +81,33 @@ pub enum Float { F64(f64), } +impl<'l> Value<'l> for Float { + #[inline] + fn ty(&self) -> Type<'l> { + match self { + Float::F16(v) => v.ty(), + Float::F32(v) => v.ty(), + Float::F64(v) => v.ty(), + _ => unreachable!(), + } + } + + #[inline] + fn flags(&self) -> ValueFlags { + ValueFlags::Const + } + + #[inline] + fn get_associated_value(&self, _name: &str) -> Option> { + None + } + + #[inline] + fn as_any_value(&self) -> AnyValue<'l> { + AnyValue::Constant(AnyConst::Float(*self)) + } +} + impl Eq for Float {} impl Hash for Float { @@ -45,20 +120,6 @@ impl Hash for Float { } } -#[derive(Debug, Display, Deref, DerefMut, Clone, Copy, From, PartialEq, Eq, Hash)] -#[debug("{:?}", _0)] -#[display("{}", _0)] -pub struct Const(T); - -impl<'l, T: 'l> From<&'l Const> for AnyConst<'l> -where - for<'a> &'a T: Into>, -{ - fn from(val: &'l Const) -> Self { - (&val.0).into() - } -} - struct ListDisplay<'l>(&'l [AnyConst<'l>]); impl std::fmt::Display for ListDisplay<'_> { @@ -71,6 +132,19 @@ impl std::fmt::Display for ListDisplay<'_> { } } +struct StructDisplay<'l>(&'l StructT<'l>, &'l [AnyConst<'l>]); + +impl std::fmt::Display for StructDisplay<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{} ", self.0)?; + let mut list = f.debug_list(); + for ele in self.1 { + list.entry(&format_args!("{ele}")); + } + list.finish() + } +} + #[derive(Debug, Display, Clone, Copy, From, TryInto, PartialEq, Eq, Hash)] #[from(forward)] pub enum AnyConst<'l> { @@ -98,45 +172,115 @@ pub enum AnyConst<'l> { #[debug("{:?}", _0)] #[display("{}", ListDisplay(_0))] Array(&'l [AnyConst<'l>]), + #[debug("{:?}", _0)] + #[display("{}", StructDisplay(_0, _1))] + Struct(&'l StructT<'l>, &'l [AnyConst<'l>]), Type(Type<'l>), } -impl<'l> AnyConst<'l> { - pub fn ty(&self) -> Type<'l> { +impl<'l> Value<'l> for AnyConst<'l> { + fn ty(&self) -> Type<'l> { match self { Self::Void => Type::Void, - Self::Bool(_) => Type::Bool, - Self::Char(_) => Type::Char, - Self::Type(_) => Type::Type, - Self::Int(Int::I8(_)) => Type::I8, - Self::Int(Int::I16(_)) => Type::I16, - Self::Int(Int::I32(_)) => Type::I32, - Self::Int(Int::I64(_)) => Type::I64, - Self::Int(Int::I128(_)) => Type::I128, - Self::Int(Int::ISize(_)) => Type::ISIZE, - Self::Int(Int::U8(_)) => Type::U8, - Self::Int(Int::U16(_)) => Type::U16, - Self::Int(Int::U32(_)) => Type::U32, - Self::Int(Int::U64(_)) => Type::U64, - Self::Int(Int::U128(_)) => Type::U128, - Self::Int(Int::USize(_)) => Type::USIZE, - Self::Float(Float::F16(_)) => Type::F16, - Self::Float(Float::F32(_)) => Type::F32, - Self::Float(Float::F64(_)) => Type::F64, + Self::Bool(v) => v.ty(), + Self::Char(v) => v.ty(), + Self::Type(v) => v.ty(), + Self::Int(v) => v.ty(), + Self::Float(v) => v.ty(), Self::Array([]) => Type::Array(&ArrayT { base: Type::Void, length: Some(0), }), Self::Array(a @ [v, ..]) => Type::Array(v.ty().make_array(Some(a.len() as u32))), + Self::Str(s) => s.ty(), + Self::Struct(t, _) => Type::Struct(t), _ => todo!("{self:?}"), } } - pub fn flags(&self) -> super::ValueFlags { + fn flags(&self) -> super::ValueFlags { match self { - AnyConst::Function(_) => ValueFlags::Function, + AnyConst::Function(f) => ValueFlags::Function, AnyConst::Type(_) => ValueFlags::Type, _ => ValueFlags::Const, } } + + fn get_associated_value(&self, name: &str) -> Option> { + match self { + AnyConst::Void => ().get_associated_value(name), + AnyConst::Bool(v) => v.get_associated_value(name), + AnyConst::Char(v) => v.get_associated_value(name), + AnyConst::Int(v) => v.get_associated_value(name), + AnyConst::Float(v) => v.get_associated_value(name), + AnyConst::Str(v) => v.get_associated_value(name), + AnyConst::Function(f) => f.get_associated_value(name), + AnyConst::Array(_) => default_associated_values(self, name), + AnyConst::Struct(t, _) => t.get_associated_value(name), + AnyConst::Type(t) => t.get_associated_value(name), + } + } + + fn as_any_value(&self) -> super::AnyValue<'l> { + AnyValue::Constant(*self) + } +} + +#[inline] +fn default_value_flags(_: &dyn Value) -> ValueFlags { + ValueFlags::Const +} + +macro_rules! impl_value { + ($( + $variant:ident : $ty:ty | $leaf_ty:expr => |$ident:ident| $expr:expr, + )*) => { + $( + impl<'l> Value<'l> for $ty { + #[inline] + fn ty(&self) -> Type<'l> { + $leaf_ty + } + + #[inline] + fn flags(&self) -> ValueFlags { + default_value_flags(self) + } + + #[inline] + fn get_associated_value(&self, name: &str) -> Option> { + default_associated_values(self, name) + } + + #[inline] + fn as_any_value(&self) -> AnyValue<'l> { + let $ident = self; + $expr + } + } + )* + }; +} + +impl_value! { + Void: () | Type::Void => |_v| AnyValue::Constant(AnyConst::Void), + Char: char | Type::Char => |v| AnyValue::Constant(AnyConst::Char(*v)), + Bool: bool | Type::Bool => |v| AnyValue::Constant(AnyConst::Bool(*v)), + + I8 : i8 | Type::I8 => |v| AnyValue::Constant(AnyConst::Int(Int::I8(*v))), + I16 : i16 | Type::I16 => |v| AnyValue::Constant(AnyConst::Int(Int::I16(*v))), + I32 : i32 | Type::I32 => |v| AnyValue::Constant(AnyConst::Int(Int::I32(*v))), + I64 : i64 | Type::I64 => |v| AnyValue::Constant(AnyConst::Int(Int::I64(*v))), + I128: i128 | Type::I128 => |v| AnyValue::Constant(AnyConst::Int(Int::I128(*v))), + U8 : u8 | Type::U8 => |v| AnyValue::Constant(AnyConst::Int(Int::U8(*v))), + U16 : u16 | Type::U16 => |v| AnyValue::Constant(AnyConst::Int(Int::U16(*v))), + U32 : u32 | Type::U32 => |v| AnyValue::Constant(AnyConst::Int(Int::U32(*v))), + U64 : u64 | Type::U64 => |v| AnyValue::Constant(AnyConst::Int(Int::U64(*v))), + U128: u128 | Type::U128 => |v| AnyValue::Constant(AnyConst::Int(Int::U128(*v))), + + F16 : f16 | Type::F16 => |v| AnyValue::Constant(AnyConst::Float(Float::F16(*v))), + F32 : f32 | Type::F32 => |v| AnyValue::Constant(AnyConst::Float(Float::F32(*v))), + F64 : f64 | Type::F64 => |v| AnyValue::Constant(AnyConst::Float(Float::F64(*v))), + + ConstStr: &'l str | Type::ConstStr => |v| AnyValue::Constant(AnyConst::Str(*v)), } diff --git a/assembly/src/values/mod.rs b/assembly/src/values/mod.rs index e5ae463..077d860 100644 --- a/assembly/src/values/mod.rs +++ b/assembly/src/values/mod.rs @@ -23,7 +23,7 @@ bitflags! { } #[derive(Debug, Display, Clone, Copy, From, TryInto, PartialEq, Eq, Hash)] -pub enum Value<'l> { +pub enum AnyValue<'l> { #[display("{_0}")] Constant(AnyConst<'l>), #[display("{_0}")] @@ -34,27 +34,62 @@ pub enum Value<'l> { Parameter(usize, &'l Function<'l>), } -impl From for Value<'_> { +impl From for AnyValue<'_> { fn from(value: Int) -> Self { - Value::Constant(AnyConst::Int(value)) + AnyValue::Constant(AnyConst::Int(value)) } } -impl<'l> Value<'l> { - pub fn ty(&self) -> Type<'l> { +pub trait Value<'l> { + fn ty(&self) -> Type<'l>; + fn flags(&self) -> ValueFlags; + fn get_associated_value(&self, name: &str) -> Option>; + fn as_any_value(&self) -> AnyValue<'l>; + + #[inline] + fn is_const(&self) -> bool { + self.flags().contains(ValueFlags::Const) + } + + #[inline] + fn is_lvalue(&self) -> bool { + self.flags().contains(ValueFlags::LValue) + } +} + +impl<'l> Value<'l> for AnyValue<'l> { + fn ty(&self) -> Type<'l> { match self { - Value::Constant(v) => v.ty(), - Value::Instruction(v) => v.value_ty(), - Value::Parameter(i, f) => f.ty.par_t[*i], + AnyValue::Constant(v) => v.ty(), + AnyValue::Instruction(v) => v.ty(), + AnyValue::Parameter(i, f) => f.ty.par_t[*i], } } - pub fn flags(&self) -> ValueFlags { + fn flags(&self) -> ValueFlags { match self { - Value::Instruction(v) => v.value_flags(), - Value::Parameter(_, _) => ValueFlags::empty(), - Value::Constant(c) => c.flags(), - _ => todo!("{self:?}"), + AnyValue::Instruction(v) => v.flags(), + AnyValue::Parameter(_, _) => ValueFlags::empty(), + AnyValue::Constant(c) => c.flags(), } } + + fn get_associated_value(&self, name: &str) -> Option> { + match self { + AnyValue::Constant(v) => v.get_associated_value(name), + AnyValue::Instruction(v) => todo!(), + AnyValue::Parameter(_, _) => default_associated_values(self, name), + } + } + + fn as_any_value(&self) -> AnyValue<'l> { + *self + } +} + +pub(crate) fn default_associated_values<'l>(v: &dyn Value<'l>, name: &str) -> Option> { + match name { + "#type" => Some(AnyValue::Constant(AnyConst::Type(v.ty()))), + _ => None, + } } diff --git a/backends/llvm/src/lib.rs b/backends/llvm/src/lib.rs index e48a2d8..f4cb661 100644 --- a/backends/llvm/src/lib.rs +++ b/backends/llvm/src/lib.rs @@ -5,7 +5,7 @@ use inkwell::{ module::Module, targets::TargetMachine, types::{AnyTypeEnum, BasicMetadataTypeEnum, BasicTypeEnum, IntType}, - values::{AnyValue, BasicValue, BasicValueEnum}, + values::{AggregateValueEnum, AnyValue as LlvmAnyValue, BasicValue, BasicValueEnum, IntValue}, }; use leaf_assembly::{ assembly::Assembly, @@ -15,9 +15,10 @@ use leaf_assembly::{ }, types::{ Type, + compound::{Field, StructT}, derivations::{ArrayT, FuncT, PtrT, RefT}, }, - values::{AnyConst, Int, Value}, + values::{AnyConst, AnyValue, Int, Value}, }; use scc::HashMap; @@ -32,6 +33,7 @@ pub struct CompilationContext<'l> { types: HashMap, AnyTypeEnum<'l>, FxBuildHasher>, modules: HashMap<&'l Assembly<'l>, Module<'l>, FxBuildHasher>, functions: HashMap<(&'l Assembly<'l>, &'l Function<'l>), LlvmFunction<'l>, FxBuildHasher>, + fields: HashMap<(&'l StructT<'l>, &'l str), u32, FxBuildHasher>, } impl<'l> CompilationContext<'l> { @@ -41,6 +43,7 @@ impl<'l> CompilationContext<'l> { types: HashMap::default(), modules: HashMap::default(), functions: HashMap::default(), + fields: HashMap::default(), native_int_ty: match target.get_target_data().get_pointer_byte_size(None) { 8 => ctx.i8_type(), 16 => ctx.i16_type(), @@ -70,9 +73,9 @@ impl<'l> CompilationContext<'l> { for (i, func) in assembly.functions() { let ty = self.get_type(Type::Func(func.ty)).into_function_type(); - let name = match func.name.get() { - Some(n) => *n, - None => &format!(""), + let name = match func.name { + "" => &format!(""), + _ => func.name, }; self.functions @@ -87,12 +90,12 @@ impl<'l> CompilationContext<'l> { continue; }; - let mut values = FxHashMap::>::default(); + let mut values = FxHashMap::>::default(); for (i, ty) in func.ty.par_t.iter().enumerate() { let ty = self.get_type(*ty); if BasicMetadataTypeEnum::try_from(ty).is_ok() { values.insert( - Value::Parameter(i, func), + AnyValue::Parameter(i, func), Some(llvm_func.get_nth_param(values.len() as u32).unwrap()), ); } @@ -143,7 +146,7 @@ impl<'l> CompilationContext<'l> { None } InstructionVariant::GetElementPtr(ptr, idx) => 'val: { - let pointee_ty = self.get_type(match inst.value_ty() { + let pointee_ty = self.get_type(match inst.ty() { Type::Ptr(PtrT { base, .. }) => *base, Type::Ref(RefT { base, .. }) => *base, _ => unreachable!(), @@ -159,6 +162,23 @@ impl<'l> CompilationContext<'l> { Some(ptr.into()) } } + InstructionVariant::GetElementVal( + val, + AnyValue::Constant(AnyConst::Str(fld)), + ) => 'val: { + let Type::Struct(ty) = val.ty() else { + unreachable!() + }; + let idx = match self.fields.get_sync(&(ty, *fld)) { + None => break 'val None, + Some(idx) => *idx, + }; + let Some(val) = self.get_value(&values, val) else { + break 'val None; + }; + let val = val.into_struct_value(); + Some(builder.build_extract_value(val, idx, "").unwrap().into()) + } InstructionVariant::IAdd(lhs, rhs) => { let lhs = self.get_value(&values, lhs).unwrap().into_int_value(); @@ -200,6 +220,11 @@ impl<'l> CompilationContext<'l> { let target = self.get_type(Type::Int(*target)).into_int_type(); Some(builder.build_int_truncate(val, target, "").unwrap().into()) } + InstructionVariant::IntToPtr(v, ty) => { + let val = self.get_value(&values, v).unwrap().into_int_value(); + let ptr = self.get_type(Type::Ptr(ty)).into_pointer_type(); + Some(builder.build_int_to_ptr(val, ptr, "").unwrap().into()) + } InstructionVariant::ICmp(lhs, rhs, cmp) => { let u = !is_signed(lhs.ty()); let cmp = match (cmp, u) { @@ -219,6 +244,23 @@ impl<'l> CompilationContext<'l> { Some(builder.build_int_compare(cmp, lhs, rhs, "").unwrap().into()) } + InstructionVariant::MakeStruct(ty, fields) => { + let ty = self.get_type(Type::Struct(ty)).into_struct_type(); + let mut i = 0; + let mut val = ty.get_undef(); + for field in *fields { + let Some(field) = self.get_value(&values, field) else { + continue; + }; + val = builder + .build_insert_value(val, field, i, "") + .unwrap() + .into_struct_value(); + i += 1; + } + Some(val.into()) + } + InstructionVariant::Call(func, args) => { // TODO This will fail with external assemblies. Fix this. let func = *self.functions.get_sync(&(assembly, *func)).unwrap(); @@ -290,6 +332,7 @@ impl<'l> CompilationContext<'l> { return *ty; } + let mut post_insertion_action = None::>; let llvm_ty = match ty { Type::Void => self.ctx.void_type().into(), Type::I8 | Type::U8 => self.ctx.i8_type().into(), @@ -312,6 +355,7 @@ impl<'l> CompilationContext<'l> { match ret_t { AnyTypeEnum::VoidType(ty) => ty.fn_type(&par_t, false).into(), AnyTypeEnum::IntType(ty) => ty.fn_type(&par_t, false).into(), + AnyTypeEnum::StructType(ty) => ty.fn_type(&par_t, false).into(), _ => todo!("{ret_t:?}"), } } @@ -329,23 +373,45 @@ impl<'l> CompilationContext<'l> { _ => todo!("{ty:#?}"), } } + + Type::Struct(s @ StructT { fields, name, .. }) => { + let ty = self.ctx.opaque_struct_type(name); + post_insertion_action = Some(Box::new(move || { + let mut types = vec![]; + let fields = fields.get().unwrap(); + for Field { ty, name, .. } in fields.values() { + if let Ok(ty) = self.get_type(*ty).try_into() { + self.fields + .insert_sync((s, name), types.len() as _) + .unwrap(); + types.push(ty); + } + } + ty.set_body(&types, false); + })); + ty.into() + } + _ => todo!("{ty:#?}"), }; self.types.entry_sync(ty).or_insert(llvm_ty); + if let Some(post_insertion_action) = post_insertion_action { + post_insertion_action(); + } llvm_ty } #[allow(clippy::mutable_key_type)] fn get_value( &self, - map: &FxHashMap, Option>>, - val: &Value<'l>, + map: &FxHashMap, Option>>, + val: &AnyValue<'l>, ) -> Option> { if let Some(value) = map.get(val) { return *value; } match val { - Value::Constant(val) => match val { + AnyValue::Constant(val) => match val { AnyConst::Int(Int::U8(v)) => { Some(self.ctx.i8_type().const_int(*v as u64, false).into()) } @@ -359,6 +425,7 @@ impl<'l> CompilationContext<'l> { AnyConst::Int(Int::USize(v)) => { Some(self.native_int_ty.const_int(*v as u64, false).into()) } + AnyConst::Array([]) => todo!("{val:?}"), AnyConst::Array(array) => { let ty = self.get_type(array[0].ty()); @@ -367,7 +434,7 @@ impl<'l> CompilationContext<'l> { let mut values = vec![]; for v in *array { let Some(BasicValueEnum::IntValue(v)) = - self.get_value(map, &Value::Constant(*v)) + self.get_value(map, &AnyValue::Constant(*v)) else { unreachable!(); }; @@ -378,6 +445,15 @@ impl<'l> CompilationContext<'l> { _ => todo!("{ty:?}"), } } + + AnyConst::Struct(s_ty, vals) => { + let ty = self.get_type(Type::Struct(s_ty)).into_struct_type(); + let vals: Vec<_> = vals + .iter() + .filter_map(|v| self.get_value(map, &v.as_any_value())) + .collect(); + Some(ty.const_named_struct(&vals).into()) + } _ => todo!("{val:?}"), }, _ => unreachable!("{val:#?}"), diff --git a/compiler/src/error.rs b/compiler/src/error.rs index b2bbf44..a3618f8 100644 --- a/compiler/src/error.rs +++ b/compiler/src/error.rs @@ -14,9 +14,14 @@ pub enum Kind { SymbolNotFound = 0x0200, UninitializedSymbol = 0x0201, NotAType = 0x0202, - NotAFunction = 0x0205, InvalidIntegerType = 0x0203, InvalidType = 0x0204, + NotAFunction = 0x0205, + NotAStruct = 0x0208, + FieldNotFound = 0x0206, + InvalidCast = 0x0207, + + UninitializedField = 0x0300, FunctionCompilationFailed = 0x0301, } diff --git a/compiler/src/lib.rs b/compiler/src/lib.rs index 282e5e6..adb3abb 100644 --- a/compiler/src/lib.rs +++ b/compiler/src/lib.rs @@ -5,7 +5,7 @@ use leaf_assembly::{ assembly::{Assembly, AssemblyIdentifier, Context}, functions::Function, types::Type, - values::{AnyConst, Value}, + values::{AnyConst, AnyValue}, }; use leaf_parser::{SourceCode, ast}; use std::{collections::VecDeque, ops::Deref, sync::Arc}; @@ -66,7 +66,7 @@ impl<'l> CompilationContext<'l> { $( scope.insert( literal_substr!($id), - Value::Constant(AnyConst::Type($ty.into())), + AnyValue::Constant(AnyConst::Type($ty.into())), false, ); )* diff --git a/compiler/src/scope.rs b/compiler/src/scope.rs index d127b67..dbf7bb4 100644 --- a/compiler/src/scope.rs +++ b/compiler/src/scope.rs @@ -1,17 +1,24 @@ use crate::{FuncQueue, error::*}; -use arcstr::Substr; +use arcstr::{Substr, literal_substr}; use leaf_assembly::{ assembly::Assembly, functions::{ Function, ir::{Cmp, FunctionBodyBuilder}, }, - types::{Type, derivations::PtrT}, - values::{AnyConst, Int, Value, ValueFlags}, + types::{ + Type, + compound::{Field, FieldMap, StructT}, + derivations::PtrT, + }, + values::{AnyConst, AnyValue, Int, Value, ValueFlags}, }; use leaf_parser::{ SourceCode, - ast::{self, BinaryExpr, BinaryOp, ConstDecl, Expr, Ident, IndexingExpr, NamePattern, While}, + ast::{ + self, AccessExpr, BinaryExpr, BinaryOp, ConstDecl, Expr, Ident, IndexingExpr, NamePattern, + While, + }, }; use std::{ collections::HashMap, @@ -26,7 +33,7 @@ struct ExpressionContext<'l, 'r> { #[derive(Clone)] struct Variable<'l> { - value: Arc>>, + value: Arc>>, mutable: bool, } @@ -46,7 +53,7 @@ impl<'l> Scope<'l> { } } - pub fn insert(&mut self, name: Substr, value: Value<'l>, mutable: bool) { + pub fn insert(&mut self, name: Substr, value: AnyValue<'l>, mutable: bool) { self.values.insert( name, Variable { @@ -122,7 +129,7 @@ impl<'l> Scope<'l> { Type::Void => builder.ret(None).unwrap(), _ => { if let Some(expr) = last_expr.as_mut() - && expr.flags().contains(ValueFlags::LValue) + && expr.is_lvalue() { *expr = builder.load(*expr).unwrap(); } @@ -139,7 +146,7 @@ impl<'l> Scope<'l> { &mut self, expr: &Expr, ctx: &mut ExpressionContext<'l, '_>, - ) -> Result, CompilationError> { + ) -> Result, CompilationError> { match expr { Expr::Ident(Ident(name)) => match self.values.get(name) { None => Err(CompilationError { @@ -167,7 +174,7 @@ impl<'l> Scope<'l> { Expr::Func(func) => self .make_function(func, ctx) - .map(|f| Value::Constant(AnyConst::Function(f))) + .map(|f| AnyValue::Constant(AnyConst::Function(f))) .map_err(|err| CompilationError { kind: Kind::FunctionCompilationFailed, message: "Could not compile function.".to_string(), @@ -189,7 +196,7 @@ impl<'l> Scope<'l> { Some(("0x", value)) => <$ty>::from_str_radix(value, 16), _ => n.text.parse::<$ty>(), } - .map(|v| Value::Constant(AnyConst::Int(Int::$id(v)))) + .map(|v| AnyValue::Constant(AnyConst::Int(Int::$id(v)))) .map_err(|_| CompilationError { kind: Kind::InvalidInteger, message: format!("`{}` is not a valid integer.", n.text), @@ -202,19 +209,19 @@ impl<'l> Scope<'l> { }; } match n.r#type.as_ref().map(|v| v.as_str()) { - None => parse_number!(i128, ISize), + None => parse_number!(i64, ISize), Some("i8") => parse_number!(i8, I8), Some("i16") => parse_number!(i16, I16), Some("i32") => parse_number!(i32, I32), Some("i64") => parse_number!(i64, I64), Some("i128") => parse_number!(i128, I128), - Some("isize") => parse_number!(i128, ISize), + Some("isize") => parse_number!(i64, ISize), Some("u8") => parse_number!(u8, U8), Some("u16") => parse_number!(u16, U16), Some("u32") => parse_number!(u32, U32), Some("u64") => parse_number!(u64, U64), Some("u128") => parse_number!(u128, U128), - Some("usize") => parse_number!(u128, USize), + Some("usize") => parse_number!(u64, USize), Some(ty) => Err(CompilationError { kind: Kind::InvalidIntegerType, message: format!("`{ty}` is not a valid integer type."), @@ -227,6 +234,40 @@ impl<'l> Scope<'l> { } } + Expr::Access(expr) => { + let AccessExpr { + value: value_expr, + field, + } = &**expr; + let value = self.compile_expression(value_expr, ctx)?; + if let Some(value) = value.get_associated_value(&field.0) { + return Ok(value); + } + match value.ty() { + Type::Struct(StructT { fields, .. }) => { + if let Some(fields) = fields.get() { + if let Some(field) = fields.get(field.0.as_str()) { + let builder = ctx.builder.as_mut().unwrap(); + return Ok(builder + .get_element_value(value, field.name.as_any_value()) + .unwrap() + .as_any_value()); + } + } + } + _ => {} + }; + return Err(CompilationError { + kind: Kind::FieldNotFound, + message: format!("Value does not contain field `{}`.", field.0), + location: Location::Range { + file: self.source.clone(), + range: value_expr.range(), + }, + cause: None, + }); + } + Expr::Binary(bin_expr) => { let BinaryExpr { lhs: lhs_expr, @@ -235,50 +276,98 @@ impl<'l> Scope<'l> { } = &**bin_expr; let mut lhs = self.compile_expression(lhs_expr, ctx)?; - if lhs.flags().contains(ValueFlags::LValue) && !matches!(op, BinaryOp::Assign(_)) { + if lhs.is_lvalue() && !matches!(op, BinaryOp::Assign(_)) { lhs = ctx.builder.as_mut().unwrap().load(lhs).unwrap(); } let mut rhs = self.compile_expression(rhs_expr, ctx)?; - if rhs.flags().contains(ValueFlags::LValue) { + if rhs.is_lvalue() { rhs = ctx.builder.as_mut().unwrap().load(rhs).unwrap(); } let builder = ctx.builder.as_mut().unwrap(); + macro_rules! int_bin_ops { + ( + const exact + $([$($ty:ident),*] $op:pat => |$a:ident, $b:ident| $expr:expr,)* + ) => { + match op { + $( + $op => match (lhs, rhs) { + $( + ( + AnyValue::Constant(AnyConst::Int(Int::$ty($a))), + AnyValue::Constant(AnyConst::Int(Int::$ty($b))), + ) => return Ok(AnyValue::Constant(AnyConst::Int(Int::$ty($expr)))), + )* + _ => {} + } + )* + _ => {} + } + }; + ( + const auto + $([$($ty:ident),*] $op:pat => |$a:ident, $b:ident| $expr:expr,)* + ) => { + match op { + $( + $op => match (lhs, rhs) { + $( + ( + AnyValue::Constant(AnyConst::Int(Int::$ty($a))), + AnyValue::Constant(AnyConst::Int(Int::$ty($b))), + ) => return Ok($expr.as_any_value()), + )* + _ => {} + } + )* + _ => {} + } + }; + } + + if lhs.is_const() && rhs.is_const() { + int_bin_ops! { + const exact + [I8, I16, I32, I64, I128, U8, U16, U32, U64, U128, ISize, USize] BinaryOp::Add(_) => |a, b| a + b, + [I8, I16, I32, I64, I128, U8, U16, U32, U64, U128, ISize, USize] BinaryOp::Sub(_) => |a, b| a - b, + [I8, I16, I32, I64, I128, U8, U16, U32, U64, U128, ISize, USize] BinaryOp::Mul(_) => |a, b| a - b, + [I8, I16, I32, I64, I128, U8, U16, U32, U64, U128, ISize, USize] BinaryOp::Div(_) => |a, b| a - b, + [I8, I16, I32, I64, I128, U8, U16, U32, U64, U128, ISize, USize] BinaryOp::Mod(_) => |a, b| a - b, + } + int_bin_ops! { + const auto + [I8, I16, I32, I64, I128, U8, U16, U32, U64, U128, ISize, USize] BinaryOp::Eq(_) => |a, b| a == b, + [I8, I16, I32, I64, I128, U8, U16, U32, U64, U128, ISize, USize] BinaryOp::Ne(_) => |a, b| a == b, + [I8, I16, I32, I64, I128, U8, U16, U32, U64, U128, ISize, USize] BinaryOp::Lt(_) => |a, b| a == b, + [I8, I16, I32, I64, I128, U8, U16, U32, U64, U128, ISize, USize] BinaryOp::Gt(_) => |a, b| a == b, + [I8, I16, I32, I64, I128, U8, U16, U32, U64, U128, ISize, USize] BinaryOp::Le(_) => |a, b| a == b, + [I8, I16, I32, I64, I128, U8, U16, U32, U64, U128, ISize, USize] BinaryOp::Ge(_) => |a, b| a == b, + } + } + + if match (lhs.ty(), rhs.ty()) { + (Type::Int(a_ty), Type::Int(b_ty)) => a_ty == b_ty, + _ => false, + } { + return Ok(match op { + BinaryOp::Add(_) => builder.add(lhs, rhs).unwrap(), + BinaryOp::Sub(_) => builder.sub(lhs, rhs).unwrap(), + BinaryOp::Mul(_) => builder.mul(lhs, rhs).unwrap(), + BinaryOp::Div(_) => builder.div(lhs, rhs).unwrap(), + BinaryOp::Mod(_) => builder.modulo(lhs, rhs).unwrap(), + BinaryOp::Eq(_) => builder.cmp(lhs, rhs, Cmp::Eq).unwrap(), + BinaryOp::Ne(_) => builder.cmp(lhs, rhs, Cmp::Ne).unwrap(), + BinaryOp::Lt(_) => builder.cmp(lhs, rhs, Cmp::Lt).unwrap(), + BinaryOp::Gt(_) => builder.cmp(lhs, rhs, Cmp::Gt).unwrap(), + BinaryOp::Le(_) => builder.cmp(lhs, rhs, Cmp::Le).unwrap(), + BinaryOp::Ge(_) => builder.cmp(lhs, rhs, Cmp::Ge).unwrap(), + _ => todo!("{lhs:?} {op:?} {rhs:?}"), + }); + } + match (lhs.ty(), rhs.ty(), op) { - (Type::Int(a_ty), Type::Int(b_ty), BinaryOp::Add(_)) if a_ty == b_ty => { - Ok(builder.add(lhs, rhs).unwrap()) - } - (Type::Int(a_ty), Type::Int(b_ty), BinaryOp::Sub(_)) if a_ty == b_ty => { - Ok(builder.sub(lhs, rhs).unwrap()) - } - (Type::Int(a_ty), Type::Int(b_ty), BinaryOp::Mul(_)) if a_ty == b_ty => { - Ok(builder.mul(lhs, rhs).unwrap()) - } - (Type::Int(a_ty), Type::Int(b_ty), BinaryOp::Div(_)) if a_ty == b_ty => { - Ok(builder.div(lhs, rhs).unwrap()) - } - (Type::Int(a_ty), Type::Int(b_ty), BinaryOp::Mod(_)) if a_ty == b_ty => { - Ok(builder.modulo(lhs, rhs).unwrap()) - } - (Type::Int(a_ty), Type::Int(b_ty), BinaryOp::Eq(_)) if a_ty == b_ty => { - Ok(builder.cmp(lhs, rhs, Cmp::Eq).unwrap()) - } - (Type::Int(a_ty), Type::Int(b_ty), BinaryOp::Ne(_)) if a_ty == b_ty => { - Ok(builder.cmp(lhs, rhs, Cmp::Ne).unwrap()) - } - (Type::Int(a_ty), Type::Int(b_ty), BinaryOp::Lt(_)) if a_ty == b_ty => { - Ok(builder.cmp(lhs, rhs, Cmp::Lt).unwrap()) - } - (Type::Int(a_ty), Type::Int(b_ty), BinaryOp::Gt(_)) if a_ty == b_ty => { - Ok(builder.cmp(lhs, rhs, Cmp::Gt).unwrap()) - } - (Type::Int(a_ty), Type::Int(b_ty), BinaryOp::Le(_)) if a_ty == b_ty => { - Ok(builder.cmp(lhs, rhs, Cmp::Le).unwrap()) - } - (Type::Int(a_ty), Type::Int(b_ty), BinaryOp::Ge(_)) if a_ty == b_ty => { - Ok(builder.cmp(lhs, rhs, Cmp::Ge).unwrap()) - } (Type::Ptr(PtrT { base, .. }), ty, BinaryOp::Assign(_)) => match *base == ty { true => Ok(builder.store(lhs, rhs).unwrap()), false => Err(CompilationError { @@ -294,25 +383,20 @@ impl<'l> Scope<'l> { }), }, (src_ty, Type::Type, BinaryOp::Cast(_)) => { - let Value::Constant(AnyConst::Type(dst_ty)) = rhs else { - return Err(CompilationError { - kind: Kind::NotAType, - message: "Cannot perform cast.".to_string(), - location: Location::Range { - file: self.source.clone(), - range: expr.range(), - }, - cause: Some(Box::new(CompilationError { - kind: Kind::NotAType, - message: "Cast target is not a type.".to_string(), + let dst_ty = + self.assert_ty(rhs, rhs_expr) + .map_err(|err| CompilationError { + kind: Kind::InvalidCast, + message: "Cannot perform cast.".to_string(), location: Location::Range { file: self.source.clone(), - range: rhs_expr.range(), + range: expr.range(), }, - cause: None, - })), - }); - }; + cause: Some(Box::new(err)), + })?; + if src_ty == dst_ty { + return Ok(lhs); + } match (src_ty, dst_ty) { (Type::Int(src_ty), Type::Int(dst_ty)) => { if dst_ty.precision < src_ty.precision { @@ -320,6 +404,9 @@ impl<'l> Scope<'l> { } todo!("{src_ty} as {dst_ty}"); } + (Type::Int(_), Type::Ptr(dst_ty)) => { + return Ok(builder.int_to_ptr(lhs, dst_ty).unwrap()); + } _ => todo!("{src_ty} as {dst_ty}"), } } @@ -356,7 +443,7 @@ impl<'l> Scope<'l> { } builder.jump(cond_block).unwrap(); builder.set_current_block(exit_block); - Ok(last_expr.unwrap_or(Value::Constant(AnyConst::Void))) + Ok(last_expr.unwrap_or(AnyValue::Constant(AnyConst::Void))) } Expr::Call { @@ -364,7 +451,7 @@ impl<'l> Scope<'l> { args: args_exprs, } => { let func = match self.compile_expression(func, ctx)? { - Value::Constant(AnyConst::Function(func)) => func, + AnyValue::Constant(AnyConst::Function(func)) => func, _ => { return Err(CompilationError { kind: Kind::NotAFunction, @@ -380,7 +467,7 @@ impl<'l> Scope<'l> { let mut args = Vec::with_capacity(args_exprs.len()); for expr in args_exprs { let mut arg = self.compile_expression(expr, ctx)?; - if arg.flags().contains(ValueFlags::LValue) { + if arg.is_lvalue() { arg = ctx.builder.as_mut().unwrap().load(arg).unwrap(); } args.push(arg); @@ -390,20 +477,58 @@ impl<'l> Scope<'l> { } Expr::Type(ty_expr) => match &**ty_expr { - ast::Type::Ptr { base, mutable } => match self.compile_expression(base, ctx)? { - Value::Constant(AnyConst::Type(ty)) => { - Ok(AnyConst::Type(Type::Ptr(ty.make_ptr(*mutable))).into()) + ast::Type::Struct(ast::Struct { fields }) => { + let name = match &ctx.decl_names { + Some(NamePattern::Single(func_name)) => func_name.0.as_str(), + _ => "", + }; + let struct_ty = self.assembly.create_struct(name); + let mut scope = self.clone(); + let mut expr_ctx = ExpressionContext { + builder: None, + decl_names: None, + fn_queue: ctx.fn_queue, + }; + let ctx = self.assembly.ctx(); + + scope.insert(literal_substr!("Self"), struct_ty.as_any_value(), false); + let mut field_map = FieldMap::default(); + for ast::Field { + name, + ty: ty_expr, + public, + mutable, + } in fields + { + let ty = scope.compile_expression(ty_expr, &mut expr_ctx)?; + let name = ctx.intern_str(&name.0); + field_map.insert( + name, + Field { + name, + ty: self.assert_ty(ty, ty_expr)?, + public: public.is_some(), + mutable: mutable.is_some(), + }, + ); } - Value::Instruction(inst) if inst.value_flags().contains(ValueFlags::LValue) => { + struct_ty.fields.set(field_map).unwrap(); + Ok(struct_ty.as_any_value()) + } + ast::Type::Ptr { base, mutable } => match self.compile_expression(base, ctx)? { + AnyValue::Constant(AnyConst::Type(ty)) => { + Ok(AnyConst::Type(Type::Ptr(ty.make_ptr(mutable.is_some()))).into()) + } + AnyValue::Instruction(inst) if inst.is_lvalue() => { let Type::Ptr(PtrT { base, mutable: is_mut, .. - }) = inst.value_ty() + }) = inst.ty() else { unreachable!() }; - if *mutable && !*is_mut { + if mutable.is_some() && !*is_mut { return Err(CompilationError { kind: Kind::NotAFunction, message: "Cannot obtain a mutable pointer to an immutable value." @@ -415,27 +540,28 @@ impl<'l> Scope<'l> { cause: None, }); } - let mut flags = inst.value_flags(); + let mut flags = inst.flags(); let builder = ctx.builder.as_mut().unwrap(); flags.remove( ValueFlags::Mutable | ValueFlags::Volatile | ValueFlags::LValue, ); - let ptr = Type::Ptr(base.make_ptr(*mutable)); + let ptr = Type::Ptr(base.make_ptr(mutable.is_some())); unsafe { Ok(builder - .reinterpret(Value::Instruction(inst), ptr, flags) + .reinterpret(AnyValue::Instruction(inst), ptr, flags) .unwrap()) } } v => todo!("{v:?}"), }, + v => todo!("{v:#?}"), }, Expr::List(expr) => { let mut expr = expr.iter(); let mut values = Vec::with_capacity(expr.len()); match expr.next() { - None => return Ok(Value::Constant(AnyConst::Array(&[]))), + None => return Ok(AnyValue::Constant(AnyConst::Array(&[]))), Some(expr) => { let value = self.compile_expression(expr, ctx)?; // TODO Check if it matches the ctx type hint @@ -445,28 +571,14 @@ impl<'l> Scope<'l> { let element_ty = values[0].ty(); for expr in expr { let value = self.compile_expression(expr, ctx)?; - if value.ty() != element_ty { - return Err(CompilationError { - kind: Kind::InvalidType, - message: format!( - "Expected type `{}`, found `{}`", - element_ty, - value.ty() - ), - location: Location::Range { - file: self.source.clone(), - range: expr.range(), - }, - cause: None, - }); - } + self.assert_ty_eq(&value, expr, &element_ty)?; values.push(value); } - if values.iter().all(|v| matches!(v, Value::Constant(_))) { + if values.iter().all(|v| matches!(v, AnyValue::Constant(_))) { let alloc = self.assembly.ctx().alloc(); - return Ok(Value::Constant(AnyConst::Array(alloc.alloc_slice( + return Ok(AnyValue::Constant(AnyConst::Array(alloc.alloc_slice( values.into_iter().map(|v| match v { - Value::Constant(c) => c, + AnyValue::Constant(c) => c, _ => unreachable!(), }), )))); @@ -480,7 +592,7 @@ impl<'l> Scope<'l> { let mut index = self.compile_expression(index, ctx)?; let builder = ctx.builder.as_mut().unwrap(); - if index.flags().contains(ValueFlags::LValue) { + if index.is_lvalue() { index = builder.load(index).unwrap(); } @@ -497,48 +609,111 @@ impl<'l> Scope<'l> { }); } - if value.flags().contains(ValueFlags::LValue) { - let gep = builder.gep(value, index).unwrap(); + if value.is_lvalue() { + let gep = builder.get_element_ptr(value, index).unwrap(); return Ok(gep); } todo!("{:#?}", value.ty()); } + Expr::Struct(ctor) => { + let ty = self.compile_expression(&ctor.r#type, ctx)?; + let AnyValue::Constant(AnyConst::Type(Type::Struct( + struct_ty @ StructT { name, fields, .. }, + ))) = ty + else { + return Err(CompilationError { + kind: Kind::NotAStruct, + message: format!("Expected struct type, got value of type `{}`.", ty.ty()), + location: Location::Range { + file: self.source.clone(), + range: ctor.r#type.range(), + }, + cause: None, + }); + }; + + let mut non_const = false; + let fields = fields.get().unwrap(); + let mut values = Vec::with_capacity(fields.len()); + for Field { + name: fld_name, ty, .. + } in fields.values() + { + let Some(name_value_pair) = ctor.values.get(*fld_name) else { + return Err(CompilationError { + kind: Kind::UninitializedField, + message: format!("Uninitialized field `{fld_name}`."), + location: Location::Range { + file: self.source.clone(), + range: expr.range(), + }, + cause: None, + }); + }; + let value = self.compile_expression(&name_value_pair.value, ctx)?; + self.assert_ty_eq(&value, &name_value_pair.value, ty)?; + non_const |= !value.flags().contains(ValueFlags::Const); + values.push(value); + } + + Ok(match non_const { + true => { + let builder = ctx.builder.as_mut().unwrap(); + let values = self.assembly.ctx().alloc().alloc_slice(values.into_iter()); + builder.make_struct(struct_ty, values).unwrap() + } + false => AnyValue::Constant(AnyConst::Struct( + struct_ty, + self.assembly + .ctx() + .alloc() + .alloc_slice(values.into_iter().map(|v| { + let AnyValue::Constant(c) = v else { + unreachable!() + }; + c + })), + )), + }) + } + _ => todo!("{expr:#?}"), } } fn make_function( &mut self, - ast: &ast::Function, + ast: &Arc, ctx: &mut ExpressionContext<'l, '_>, ) -> Result<&'l Function<'l>, CompilationError> { let ret_ty = match ast.ret.as_ref() { - None => Value::Constant(AnyConst::Type(Type::Void)), + None => AnyValue::Constant(AnyConst::Type(Type::Void)), Some(ty) => self.compile_expression(ty, ctx)?, }; - let ret_ty = self.assert_type(ret_ty)?; + let ast_as_expr = Expr::Func(ast.clone()); + let ret_ty = self.assert_ty(ret_ty, ast.ret.as_ref().unwrap_or(&ast_as_expr))?; let mut par_ty = Vec::with_capacity(ast.args.len()); for arg in &ast.args { - let ty = self.compile_expression(&arg.r#type, ctx)?; - let ty = self.assert_type(ty)?; + let ty = self.compile_expression(&arg.value, ctx)?; + let ty = self.assert_ty(ty, &arg.value)?; par_ty.push(ty); } let fn_ty = ret_ty.make_fn(par_ty); - let func = self.assembly.create_function(fn_ty); - if let Some(NamePattern::Single(func_name)) = &ctx.decl_names { - func.name.set(func.ctx().intern_str(&func_name.0)).unwrap(); + let name = match &ctx.decl_names { + Some(NamePattern::Single(func_name)) => func_name.0.as_str(), + _ => "", }; - + let func = self.assembly.create_function(fn_ty, name); let Some(block) = &ast.block else { return Ok(func); }; let mut scope = self.clone(); for (i, arg) in ast.args.iter().enumerate() { - scope.insert(arg.name.0.clone(), Value::Parameter(i, func), false); + scope.insert(arg.name.0.clone(), AnyValue::Parameter(i, func), false); } ctx.fn_queue.push_back((func, block.clone(), scope)); @@ -552,7 +727,7 @@ impl<'l> Scope<'l> { value: &Expr, mutable: bool, ctx: &mut ExpressionContext<'l, '_>, - ) -> Result, CompilationError> { + ) -> Result, CompilationError> { let mut sub_ctx = ExpressionContext { decl_names: Some(names), builder: ctx.builder.as_deref_mut(), @@ -581,15 +756,45 @@ impl<'l> Scope<'l> { Ok(AnyConst::Void.into()) } - fn assert_type(&self, val: Value<'l>) -> Result, CompilationError> { + fn assert_ty( + &self, + val: AnyValue<'l>, + value_expr: &Expr, + ) -> Result, CompilationError> { match val { - Value::Constant(AnyConst::Type(ty)) => Ok(ty), + AnyValue::Constant(AnyConst::Type(ty)) => Ok(ty), _ => Err(CompilationError { kind: Kind::NotAType, message: "Value is not a type.".to_string(), - location: Location::None, + location: Location::Range { + file: self.source.clone(), + range: value_expr.range(), + }, cause: None, }), } } + + pub fn assert_ty_eq( + &self, + value: &AnyValue<'l>, + value_expr: &Expr, + expected: &Type<'l>, + ) -> Result, CompilationError> { + let value_ty = value.ty(); + match value_ty == *expected { + true => Ok(value_ty), + false => { + return Err(CompilationError { + kind: Kind::InvalidType, + message: format!("Expected value of type `{expected}`, found `{value_ty}`."), + location: Location::Range { + file: self.source.clone(), + range: value_expr.range(), + }, + cause: None, + }); + } + } + } } diff --git a/parser/Cargo.toml b/parser/Cargo.toml index 15c2704..6d4105f 100644 --- a/parser/Cargo.toml +++ b/parser/Cargo.toml @@ -6,4 +6,5 @@ edition = "2024" [dependencies] arcstr = "1.2.0" derive_more = { version = "2.1.0", features = ["deref", "debug", "display"] } +indexmap = "2.13.0" peg = "0.8.5" diff --git a/parser/src/ast.rs b/parser/src/ast.rs index 894a282..a37a279 100644 --- a/parser/src/ast.rs +++ b/parser/src/ast.rs @@ -1,5 +1,6 @@ use arcstr::Substr; use derive_more::Deref; +use indexmap::IndexMap; use std::ops::Range; use std::sync::Arc; @@ -27,6 +28,7 @@ pub enum Expr { #[debug("{_0:?}")] Binary(Arc), Index(Arc), + Access(Arc), Tuple(Vec), List(Vec), Struct(Arc), @@ -56,6 +58,8 @@ impl Expr { pub fn range(&self) -> Range { match self { Self::Ident(e) => e.range(), + Self::Access(e) => e.range(), + Self::Number(e) => e.text.range(), _ => todo!("{self:?}"), } } @@ -74,6 +78,18 @@ pub struct IndexingExpr { pub index: Expr, } +#[derive(Debug)] +pub struct AccessExpr { + pub value: Expr, + pub field: Ident, +} + +impl AccessExpr { + pub fn range(&self) -> Range { + self.value.range().start..self.field.0.range().end + } +} + #[rustfmt::skip] #[derive(derive_more::Debug)] pub enum BinaryOp { @@ -82,7 +98,6 @@ pub enum BinaryOp { #[debug("{_0}")] Mul(Substr), #[debug("{_0}")] Div(Substr), #[debug("{_0}")] Mod(Substr), - #[debug("{_0}")] Dot(Substr), #[debug("{_0}")] Eq(Substr), #[debug("{_0}")] Ne(Substr), #[debug("{_0}")] Lt(Substr), @@ -96,7 +111,8 @@ pub enum BinaryOp { #[derive(Debug)] pub enum Type { - Ptr { base: Expr, mutable: bool }, + Ptr { base: Expr, mutable: Option }, + Struct(Struct), } #[derive(Debug)] @@ -141,13 +157,26 @@ pub struct Function { #[derive(Debug)] pub struct NameValuePair { pub name: Ident, - pub r#type: Expr, + pub value: Expr, +} + +#[derive(Debug)] +pub struct Struct { + pub fields: Vec, +} + +#[derive(Debug)] +pub struct Field { + pub name: Ident, + pub ty: Expr, + pub public: Option, + pub mutable: Option, } #[derive(Debug)] pub struct StructCtor { pub r#type: Expr, - pub r#values: Vec, + pub r#values: IndexMap, } #[derive(Debug)] diff --git a/parser/src/parser.rs b/parser/src/parser.rs index 7297d72..9e8ce08 100644 --- a/parser/src/parser.rs +++ b/parser/src/parser.rs @@ -60,6 +60,9 @@ peg::parser! { rule ident() -> Ident = text:$(['_'|'a'..='z'|'A'..='Z']['_'|'a'..='z'|'A'..='Z'|'0'..='9']*) { Ident(text) } + rule ident2() -> Ident + = text:$("#"? ['_'|'a'..='z'|'A'..='Z']['_'|'a'..='z'|'A'..='Z'|'0'..='9']*) { Ident(text) } + rule string() -> Substr = str:$("\"" char()* "\"") { str } @@ -70,29 +73,32 @@ peg::parser! { // ### EXPRESSIONS #### rule expr() -> Expr = precedence! { - lhs:(@) __ op:$("as") __ rhs:@ { Expr::Binary(Arc::new(BinaryExpr { lhs, op: BinaryOp::Cast(op), rhs })) } + lhs:(@) __ op:$("as") __ rhs:expr() { Expr::Binary(Arc::new(BinaryExpr { lhs, op: BinaryOp::Cast(op), rhs })) } -- lhs:@ __ op:$("=") __ rhs:expr() { Expr::Binary(Arc::new(BinaryExpr { lhs, op: BinaryOp::Assign(op), rhs })) } - lhs:(@) __ op:$(".") __ rhs:@ { Expr::Binary(Arc::new(BinaryExpr { lhs, op: BinaryOp::Dot(op), rhs })) } - lhs:(@) "(" __ args:(expr() ** ("," __)) __ ")" { Expr::Call { func: Arc::new(lhs), args } } - value:(@) "[" __ index:expr() __ "]" { Expr::Index(Arc::new(IndexingExpr { value, index })) } - r#type:(@) _ "{" __ values:name_value_pairs() __ "}" { Expr::Struct(Arc::new(StructCtor { r#type, values })) } + value:@ __ op:$(".") __ field:ident2() { Expr::Access(AccessExpr { value, field }.into()) } + lhs:@ "(" __ args:(expr() ** list_separator()) __ ")" { Expr::Call { func: Arc::new(lhs), args } } + value:@ "[" __ index:expr() __ "]" { Expr::Index(Arc::new(IndexingExpr { value, index })) } + + r#type:@ __ "{" __ values:name_value_pairs() __ "}" { Expr::Struct(Arc::new(StructCtor { + r#type, values: values.into_iter().map(|v| (v.name.0.clone(), v)).collect() + })) } -- - lhs:(@) __ op:$("+") __ rhs:@ { Expr::Binary(Arc::new(BinaryExpr { lhs, op: BinaryOp::Add(op), rhs })) } - lhs:(@) __ op:$("-") __ rhs:@ { Expr::Binary(Arc::new(BinaryExpr { lhs, op: BinaryOp::Sub(op), rhs })) } + lhs:@ __ op:$("+") __ rhs:expr() { Expr::Binary(Arc::new(BinaryExpr { lhs, op: BinaryOp::Add(op), rhs })) } + lhs:@ __ op:$("-") __ rhs:expr() { Expr::Binary(Arc::new(BinaryExpr { lhs, op: BinaryOp::Sub(op), rhs })) } -- - lhs:(@) __ op:$("*") __ rhs:@ { Expr::Binary(Arc::new(BinaryExpr { lhs, op: BinaryOp::Mul(op), rhs })) } - lhs:(@) __ op:$("/") __ rhs:@ { Expr::Binary(Arc::new(BinaryExpr { lhs, op: BinaryOp::Div(op), rhs })) } - lhs:(@) __ op:$("%") __ rhs:@ { Expr::Binary(Arc::new(BinaryExpr { lhs, op: BinaryOp::Mod(op), rhs })) } + lhs:@ __ op:$("*") __ rhs:expr() { Expr::Binary(Arc::new(BinaryExpr { lhs, op: BinaryOp::Mul(op), rhs })) } + lhs:@ __ op:$("/") __ rhs:expr() { Expr::Binary(Arc::new(BinaryExpr { lhs, op: BinaryOp::Div(op), rhs })) } + lhs:@ __ op:$("%") __ rhs:expr() { Expr::Binary(Arc::new(BinaryExpr { lhs, op: BinaryOp::Mod(op), rhs })) } -- - lhs:(@) __ op:$("..") __ rhs:@ { Expr::Binary(Arc::new(BinaryExpr { lhs, op: BinaryOp::Range(op), rhs })) } + lhs:@ __ op:$("..") __ rhs:expr() { Expr::Binary(Arc::new(BinaryExpr { lhs, op: BinaryOp::Range(op), rhs })) } -- - lhs:(@) __ op:$("==") __ rhs:@ { Expr::Binary(Arc::new(BinaryExpr { lhs, op: BinaryOp::Eq(op), rhs })) } - lhs:(@) __ op:$("!=") __ rhs:@ { Expr::Binary(Arc::new(BinaryExpr { lhs, op: BinaryOp::Ne(op), rhs })) } - lhs:(@) __ op:$("<") __ rhs:@ { Expr::Binary(Arc::new(BinaryExpr { lhs, op: BinaryOp::Lt(op), rhs })) } - lhs:(@) __ op:$(">") __ rhs:@ { Expr::Binary(Arc::new(BinaryExpr { lhs, op: BinaryOp::Gt(op), rhs })) } - lhs:(@) __ op:$("<=") __ rhs:@ { Expr::Binary(Arc::new(BinaryExpr { lhs, op: BinaryOp::Le(op), rhs })) } - lhs:(@) __ op:$(">=") __ rhs:@ { Expr::Binary(Arc::new(BinaryExpr { lhs, op: BinaryOp::Ge(op), rhs })) } + lhs:@ __ op:$("==") __ rhs:expr() { Expr::Binary(Arc::new(BinaryExpr { lhs, op: BinaryOp::Eq(op), rhs })) } + lhs:@ __ op:$("!=") __ rhs:expr() { Expr::Binary(Arc::new(BinaryExpr { lhs, op: BinaryOp::Ne(op), rhs })) } + lhs:@ __ op:$("<") __ rhs:expr() { Expr::Binary(Arc::new(BinaryExpr { lhs, op: BinaryOp::Lt(op), rhs })) } + lhs:@ __ op:$(">") __ rhs:expr() { Expr::Binary(Arc::new(BinaryExpr { lhs, op: BinaryOp::Gt(op), rhs })) } + lhs:@ __ op:$("<=") __ rhs:expr() { Expr::Binary(Arc::new(BinaryExpr { lhs, op: BinaryOp::Le(op), rhs })) } + lhs:@ __ op:$(">=") __ rhs:expr() { Expr::Binary(Arc::new(BinaryExpr { lhs, op: BinaryOp::Ge(op), rhs })) } -- block:block() { Expr::Block(block)} for_loop:for_loop() { Expr::For(Arc::new(for_loop))} @@ -103,7 +109,8 @@ peg::parser! { "(" __ tuple:(expr() **<2,> ("," __)) __ ")" { Expr::Tuple(tuple) } "[" __ list:(expr() ** ("," __)) __ "]" { Expr::List(list) } "(" __ v:expr() __ ")" { v } - "*" __ m:"mut"? __ v:expr() { Expr::Type(Arc::new(Type::Ptr { base:v, mutable: m.is_some() })) } + "*" __ m:$"mut"? __ v:expr() { Expr::Type(Arc::new(Type::Ptr { base:v, mutable: m })) } + v:struct_t() { Expr::Type(Arc::new(Type::Struct(v))) } v:string() { Expr::String(v) } v:number() { Expr::Number(v) } v:ident() { Expr::Ident(v) } @@ -113,14 +120,32 @@ peg::parser! { = "{" __ exprs:(i:expr() statement_separator() {i})* __ "}" { Block(exprs) } rule func() -> Function - = s:position!() t:$"fn" __ "(" __ args:name_value_pairs() __ ")" __ ret:("->" __ e:expr() {e})? __ block:block()? e:position!() + = s:position!() t:$"fn" __ "(" __ args:name_type_pairs() __ ")" __ ret:("->" __ e:expr() {e})? __ block:block()? e:position!() { Function { args, ret, block: block.map(Arc::new), text: t.parent().substr(s..e), } } + rule name_type_pair() -> NameValuePair + = name:ident() __ ":" __ value:expr() { NameValuePair { name, value } } + + rule name_type_pairs() -> Vec + = v:(name_type_pair() **<1,> list_separator()) list_separator()? { v } + / { vec![] } + rule name_value_pair() -> NameValuePair - = name:ident() __ ":" __ r#type:expr() { NameValuePair { name, r#type } } + = name:ident() __ "=" __ value:expr() { NameValuePair { name, value } } rule name_value_pairs() -> Vec - = v:(name_value_pair() ** ("," __)) { v } + = v:(name_value_pair() **<1,> list_separator()) list_separator()? { v } + / { vec![] } + + rule struct_t() -> Struct + = "struct" __ "{" __ fields:fields() __ "}" { Struct { fields } } + + rule field() -> Field + = public:$"pub"? __ mutable:$"mut"? __ name:ident() __ ":" __ ty:expr() { Field { name, ty, public, mutable } } + + rule fields() -> Vec + = v:(field() **<1,> list_separator()) list_separator()? { v } + / { vec![] } rule import() -> Import = "import" _ expr:expr() { Import(expr) } @@ -156,6 +181,7 @@ peg::parser! { rule _ = quiet! { [' '|'\t']+ } rule __ = quiet! { [' '|'\t'|'\n']* } rule statement_separator() = quiet! { [';'|'\n'] __ } + rule list_separator() = quiet! { [','|'\n'] __ } } }