Ifs and Phis

This commit is contained in:
Mia
2026-03-07 16:24:32 +01:00
parent fb84e09391
commit 168a12b4fc
11 changed files with 612 additions and 150 deletions
+1
View File
@@ -20,6 +20,7 @@ pub enum Kind {
NotAStruct = 0x0208,
FieldNotFound = 0x0206,
InvalidCast = 0x0207,
CannotDereference = 0x0209,
UninitializedField = 0x0300,
+6 -6
View File
@@ -10,7 +10,10 @@ use leaf_backend_llvm::{
};
use leaf_compiler::CompilationContext;
use leaf_parser::SourceCode;
use std::{path::PathBuf, sync::Arc};
use std::{
path::{Path, PathBuf},
sync::Arc,
};
fn main() {
let alloc = SyncArenaAllocator::default();
@@ -54,10 +57,7 @@ fn main() {
module.print_to_stderr();
module.verify().unwrap();
let asm = target_machine
.write_to_memory_buffer(&module, FileType::Assembly)
target_machine
.write_to_file(&module, FileType::Assembly, Path::new("out.asm"))
.unwrap();
let asm = std::str::from_utf8(asm.as_slice()).unwrap();
eprintln!("{asm}");
std::fs::write("out.asm", asm).unwrap();
}
+238 -40
View File
@@ -4,20 +4,20 @@ use leaf_assembly::{
assembly::Assembly,
functions::{
Function,
ir::{Cmp, FunctionBodyBuilder},
ir::{Cmp, FunctionBodyBuilder, PhiValue},
},
types::{
Type,
IntT, Type,
compound::{Field, FieldMap, StructT},
derivations::PtrT,
derivations::{PtrT, RefT},
},
values::{AnyConst, AnyValue, Int, Value, ValueFlags},
};
use leaf_parser::{
SourceCode,
ast::{
self, AccessExpr, BinaryExpr, BinaryOp, ConstDecl, Expr, Ident, IndexingExpr, NamePattern,
While,
self, AccessExpr, BinaryExpr, BinaryOp, Block, ConstDecl, Else, Expr, Ident, If,
IndexingExpr, NamePattern, While,
},
};
use std::{
@@ -109,31 +109,30 @@ impl<'l> Scope<'l> {
pub fn compile_function(
&mut self,
func: &'l Function<'l>,
block: &ast::Block,
block: &Arc<ast::Block>,
fn_queue: &mut FuncQueue<'l>,
) -> Result<(), CompilationError> {
let mut builder = func.create_body().unwrap();
let mut ctx = ExpressionContext {
builder: Some(&mut builder),
decl_names: None,
fn_queue: fn_queue,
};
let mut last_expr = None;
for expr in &block.0 {
last_expr = Some(self.compile_expression(expr, &mut ctx)?);
}
let mut ret = self.compile_block(
block,
&mut ExpressionContext {
builder: Some(&mut builder),
decl_names: None,
fn_queue: fn_queue,
},
)?;
if !builder.current_block().has_termination() {
match func.ty.ret_t {
Type::Void => builder.ret(None).unwrap(),
Type::Void => {
builder.ret(None).unwrap();
}
_ => {
if let Some(expr) = last_expr.as_mut()
&& expr.is_lvalue()
{
*expr = builder.load(*expr).unwrap();
if ret.is_lvalue() {
ret = builder.load(ret).unwrap();
}
builder.ret(last_expr).unwrap()
self.assert_ty_eq(&ret, &Expr::Block(block.clone()), &func.ty.ret_t)?;
builder.ret(Some(ret)).unwrap();
}
};
}
@@ -255,6 +254,28 @@ impl<'l> Scope<'l> {
}
}
}
Type::Ptr(PtrT {
base: Type::Struct(StructT { fields, .. }),
..
}) => {
if let Some(fields) = fields.get() {
if let Some(field) = fields.get(field.0.as_str()) {
let builder = ctx.builder.as_mut().unwrap();
let inst = builder
.get_element_ptr(value, field.name.as_any_value())
.unwrap()
.as_any_value();
unsafe {
if let AnyValue::Instruction(inst) = inst {
inst.edit_flags(|f| f | ValueFlags::LValue);
}
}
return Ok(inst);
}
}
}
_ => {}
};
return Err(CompilationError {
@@ -347,7 +368,9 @@ impl<'l> Scope<'l> {
}
}
if match (lhs.ty(), rhs.ty()) {
let (lhs_ty, rhs_ty) = (lhs.ty(), rhs.ty());
if match (lhs_ty, rhs_ty) {
(Type::Int(a_ty), Type::Int(b_ty)) => a_ty == b_ty,
_ => false,
} {
@@ -367,7 +390,31 @@ impl<'l> Scope<'l> {
});
}
match (lhs.ty(), rhs.ty(), op) {
if match (lhs_ty, rhs_ty) {
(Type::Ptr(a_ty), Type::Ptr(b_ty)) => a_ty == b_ty,
_ => false,
} {
let lhs = builder.ptr_to_int(lhs, IntT::USIZE).unwrap();
let rhs = builder.ptr_to_int(rhs, IntT::USIZE).unwrap();
return Ok(match op {
BinaryOp::Eq(_) => builder.cmp(lhs, rhs, Cmp::Eq).unwrap(),
BinaryOp::Ne(_) => builder.cmp(lhs, rhs, Cmp::Ne).unwrap(),
BinaryOp::Lt(_) => builder.cmp(lhs, rhs, Cmp::Lt).unwrap(),
BinaryOp::Gt(_) => builder.cmp(lhs, rhs, Cmp::Gt).unwrap(),
BinaryOp::Le(_) => builder.cmp(lhs, rhs, Cmp::Le).unwrap(),
BinaryOp::Ge(_) => builder.cmp(lhs, rhs, Cmp::Ge).unwrap(),
_ => todo!("{lhs:?} {op:?} {rhs:?}"),
});
}
match (lhs_ty, rhs_ty, op) {
(Type::Ptr(ptr @ PtrT { base, .. }), Type::USIZE, BinaryOp::Add(_)) => {
let mut value = builder.ptr_to_int(lhs, IntT::USIZE).unwrap();
let add = builder.mul(rhs, AnyConst::SizeOf(*base).into()).unwrap();
value = builder.add(value, add).unwrap();
value = builder.int_to_ptr(value, ptr).unwrap();
Ok(value)
}
(Type::Ptr(PtrT { base, .. }), ty, BinaryOp::Assign(_)) => match *base == ty {
true => Ok(builder.store(lhs, rhs).unwrap()),
false => Err(CompilationError {
@@ -407,6 +454,9 @@ impl<'l> Scope<'l> {
(Type::Int(_), Type::Ptr(dst_ty)) => {
return Ok(builder.int_to_ptr(lhs, dst_ty).unwrap());
}
(Type::Ptr(_), dst_ty @ Type::Ptr(_)) => unsafe {
return Ok(builder.reinterpret(lhs, dst_ty, lhs.flags()).unwrap());
},
_ => todo!("{src_ty} as {dst_ty}"),
}
}
@@ -414,8 +464,10 @@ impl<'l> Scope<'l> {
}
}
Expr::If(expr) => self.compile_if(expr, ctx),
Expr::While(expr) => {
let While { value, block } = &**expr;
let While { cond, block } = &**expr;
let mut builder = ctx.builder.as_mut().unwrap();
let cond_block = builder.create_block();
@@ -424,26 +476,18 @@ impl<'l> Scope<'l> {
builder.jump(cond_block).unwrap();
builder.set_current_block(cond_block);
let condition = self.compile_expression(value, ctx)?;
let condition = self.compile_expression(cond, ctx)?;
builder = ctx.builder.as_mut().unwrap();
builder.branch(condition, exec_block, exit_block).unwrap();
builder.set_current_block(exec_block);
let mut scope = self.clone();
let mut ctx = ExpressionContext {
builder: Some(builder),
decl_names: None,
fn_queue: ctx.fn_queue,
};
let ret = self.compile_block(block, ctx)?;
builder = ctx.builder.as_mut().unwrap();
let mut last_expr = None;
for expr in &block.0 {
last_expr = Some(scope.compile_expression(expr, &mut ctx)?);
}
builder.jump(cond_block).unwrap();
builder.set_current_block(exit_block);
Ok(last_expr.unwrap_or(AnyValue::Constant(AnyConst::Void)))
Ok(ret)
}
Expr::Call {
@@ -586,9 +630,63 @@ impl<'l> Scope<'l> {
todo!()
}
Expr::Deref(expr) => {
let value = self.compile_expression(expr, ctx)?;
let builder = ctx.builder.as_mut().unwrap();
let ty = value.ty();
match value.is_lvalue() {
false => unsafe {
if !matches!(ty, Type::Ptr(_)) {
return Err(CompilationError {
kind: Kind::CannotDereference,
message: format!("Cannot dereference a value of type `{ty}`."),
location: Location::Range {
file: self.source.clone(),
range: expr.range(),
},
cause: None,
});
}
Ok(builder
.reinterpret(value, value.ty(), value.flags() | ValueFlags::LValue)
.unwrap())
},
true => unsafe {
if !matches!(
ty,
Type::Ptr(PtrT {
base: Type::Ptr(_),
..
})
) {
let Type::Ptr(PtrT { base, .. }) = ty else {
unreachable!()
};
return Err(CompilationError {
kind: Kind::CannotDereference,
message: format!("Cannot dereference a value of type `{base}`."),
location: Location::Range {
file: self.source.clone(),
range: expr.range(),
},
cause: None,
});
}
let AnyValue::Instruction(value) = builder.load(value).unwrap() else {
unreachable!()
};
value.edit_flags(|v| v | ValueFlags::LValue);
Ok(AnyValue::Instruction(value))
},
_ => unimplemented!("{}", value.is_lvalue()),
}
}
Expr::Index(expr) => {
let IndexingExpr { value, index } = &**expr;
let value = self.compile_expression(value, ctx)?;
let mut value = self.compile_expression(value, ctx)?;
let mut index = self.compile_expression(index, ctx)?;
let builder = ctx.builder.as_mut().unwrap();
@@ -609,6 +707,21 @@ impl<'l> Scope<'l> {
});
}
// TODO This is probably wrong, make it better.
while value.is_lvalue()
&& !matches!(
value.ty(),
Type::Ptr(PtrT {
base: Type::Array(_),
..
}) | Type::Ref(RefT {
base: Type::Array(_),
..
})
) {
value = builder.load(value).unwrap();
}
if value.is_lvalue() {
let gep = builder.get_element_ptr(value, index).unwrap();
return Ok(gep);
@@ -620,7 +733,7 @@ impl<'l> Scope<'l> {
Expr::Struct(ctor) => {
let ty = self.compile_expression(&ctor.r#type, ctx)?;
let AnyValue::Constant(AnyConst::Type(Type::Struct(
struct_ty @ StructT { name, fields, .. },
struct_ty @ StructT { fields, .. },
))) = ty
else {
return Err(CompilationError {
@@ -734,8 +847,11 @@ impl<'l> Scope<'l> {
fn_queue: ctx.fn_queue,
};
let mut value = self.compile_expression(value, &mut sub_ctx)?;
let builder = sub_ctx.builder.unwrap();
if value.is_lvalue() {
value = builder.load(value).unwrap();
}
if mutable {
let builder = sub_ctx.builder.unwrap();
let variable = builder.stack_alloc(value.ty()).unwrap();
builder.store(variable, value).unwrap();
value = variable;
@@ -756,6 +872,88 @@ impl<'l> Scope<'l> {
Ok(AnyConst::Void.into())
}
fn compile_block(
&mut self,
expr: &Block,
ctx: &mut ExpressionContext<'l, '_>,
) -> Result<AnyValue<'l>, CompilationError> {
let mut scope = self.clone();
let builder = ctx.builder.as_mut().unwrap();
let mut ctx = ExpressionContext {
builder: Some(builder),
decl_names: None,
fn_queue: ctx.fn_queue,
};
let mut last_expr = None;
for expr in &expr.0 {
last_expr = Some(scope.compile_expression(expr, &mut ctx)?);
}
Ok(last_expr.unwrap_or(AnyValue::Constant(AnyConst::Void)))
}
fn compile_if(
&mut self,
expr: &If,
ctx: &mut ExpressionContext<'l, '_>,
) -> Result<AnyValue<'l>, CompilationError> {
let If { cond, block, else_ } = expr;
let condition = self.compile_expression(cond, ctx)?;
self.assert_ty_eq(&condition, cond, &Type::Bool)?;
let builder = ctx.builder.as_mut().unwrap();
let then_block = builder.create_block();
let else_block = builder.create_block();
builder.branch(condition, then_block, else_block).unwrap();
builder.set_current_block(then_block);
let then_val = self.compile_block(block, ctx)?;
let builder = ctx.builder.as_mut().unwrap();
let else_ = match else_ {
None => {
builder.jump(else_block).unwrap();
builder.set_current_block(else_block);
return Ok(then_val);
}
Some(else_) => else_,
};
let continue_block = builder.create_block();
builder.jump(continue_block).unwrap();
builder.set_current_block(else_block);
let else_val = match &**else_ {
Else::Block(block) => self.compile_block(block, ctx)?,
Else::If(if_) => self.compile_if(if_, ctx)?,
};
let builder = ctx.builder.as_mut().unwrap();
builder.jump(continue_block).unwrap();
builder.set_current_block(continue_block);
if then_val.ty() != else_val.ty() {
todo!()
}
Ok(builder
.phi(
[
PhiValue {
value: then_val,
block: then_block.id,
},
PhiValue {
value: else_val,
block: else_block.id,
},
]
.into_iter(),
)
.unwrap()
.into())
}
fn assert_ty(
&self,
val: AnyValue<'l>,