Ifs and Phis
This commit is contained in:
@@ -20,6 +20,7 @@ pub enum Kind {
|
||||
NotAStruct = 0x0208,
|
||||
FieldNotFound = 0x0206,
|
||||
InvalidCast = 0x0207,
|
||||
CannotDereference = 0x0209,
|
||||
|
||||
UninitializedField = 0x0300,
|
||||
|
||||
|
||||
@@ -10,7 +10,10 @@ use leaf_backend_llvm::{
|
||||
};
|
||||
use leaf_compiler::CompilationContext;
|
||||
use leaf_parser::SourceCode;
|
||||
use std::{path::PathBuf, sync::Arc};
|
||||
use std::{
|
||||
path::{Path, PathBuf},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
fn main() {
|
||||
let alloc = SyncArenaAllocator::default();
|
||||
@@ -54,10 +57,7 @@ fn main() {
|
||||
module.print_to_stderr();
|
||||
module.verify().unwrap();
|
||||
|
||||
let asm = target_machine
|
||||
.write_to_memory_buffer(&module, FileType::Assembly)
|
||||
target_machine
|
||||
.write_to_file(&module, FileType::Assembly, Path::new("out.asm"))
|
||||
.unwrap();
|
||||
let asm = std::str::from_utf8(asm.as_slice()).unwrap();
|
||||
eprintln!("{asm}");
|
||||
std::fs::write("out.asm", asm).unwrap();
|
||||
}
|
||||
|
||||
+238
-40
@@ -4,20 +4,20 @@ use leaf_assembly::{
|
||||
assembly::Assembly,
|
||||
functions::{
|
||||
Function,
|
||||
ir::{Cmp, FunctionBodyBuilder},
|
||||
ir::{Cmp, FunctionBodyBuilder, PhiValue},
|
||||
},
|
||||
types::{
|
||||
Type,
|
||||
IntT, Type,
|
||||
compound::{Field, FieldMap, StructT},
|
||||
derivations::PtrT,
|
||||
derivations::{PtrT, RefT},
|
||||
},
|
||||
values::{AnyConst, AnyValue, Int, Value, ValueFlags},
|
||||
};
|
||||
use leaf_parser::{
|
||||
SourceCode,
|
||||
ast::{
|
||||
self, AccessExpr, BinaryExpr, BinaryOp, ConstDecl, Expr, Ident, IndexingExpr, NamePattern,
|
||||
While,
|
||||
self, AccessExpr, BinaryExpr, BinaryOp, Block, ConstDecl, Else, Expr, Ident, If,
|
||||
IndexingExpr, NamePattern, While,
|
||||
},
|
||||
};
|
||||
use std::{
|
||||
@@ -109,31 +109,30 @@ impl<'l> Scope<'l> {
|
||||
pub fn compile_function(
|
||||
&mut self,
|
||||
func: &'l Function<'l>,
|
||||
block: &ast::Block,
|
||||
block: &Arc<ast::Block>,
|
||||
fn_queue: &mut FuncQueue<'l>,
|
||||
) -> Result<(), CompilationError> {
|
||||
let mut builder = func.create_body().unwrap();
|
||||
let mut ctx = ExpressionContext {
|
||||
builder: Some(&mut builder),
|
||||
decl_names: None,
|
||||
fn_queue: fn_queue,
|
||||
};
|
||||
|
||||
let mut last_expr = None;
|
||||
for expr in &block.0 {
|
||||
last_expr = Some(self.compile_expression(expr, &mut ctx)?);
|
||||
}
|
||||
let mut ret = self.compile_block(
|
||||
block,
|
||||
&mut ExpressionContext {
|
||||
builder: Some(&mut builder),
|
||||
decl_names: None,
|
||||
fn_queue: fn_queue,
|
||||
},
|
||||
)?;
|
||||
|
||||
if !builder.current_block().has_termination() {
|
||||
match func.ty.ret_t {
|
||||
Type::Void => builder.ret(None).unwrap(),
|
||||
Type::Void => {
|
||||
builder.ret(None).unwrap();
|
||||
}
|
||||
_ => {
|
||||
if let Some(expr) = last_expr.as_mut()
|
||||
&& expr.is_lvalue()
|
||||
{
|
||||
*expr = builder.load(*expr).unwrap();
|
||||
if ret.is_lvalue() {
|
||||
ret = builder.load(ret).unwrap();
|
||||
}
|
||||
builder.ret(last_expr).unwrap()
|
||||
self.assert_ty_eq(&ret, &Expr::Block(block.clone()), &func.ty.ret_t)?;
|
||||
builder.ret(Some(ret)).unwrap();
|
||||
}
|
||||
};
|
||||
}
|
||||
@@ -255,6 +254,28 @@ impl<'l> Scope<'l> {
|
||||
}
|
||||
}
|
||||
}
|
||||
Type::Ptr(PtrT {
|
||||
base: Type::Struct(StructT { fields, .. }),
|
||||
..
|
||||
}) => {
|
||||
if let Some(fields) = fields.get() {
|
||||
if let Some(field) = fields.get(field.0.as_str()) {
|
||||
let builder = ctx.builder.as_mut().unwrap();
|
||||
let inst = builder
|
||||
.get_element_ptr(value, field.name.as_any_value())
|
||||
.unwrap()
|
||||
.as_any_value();
|
||||
|
||||
unsafe {
|
||||
if let AnyValue::Instruction(inst) = inst {
|
||||
inst.edit_flags(|f| f | ValueFlags::LValue);
|
||||
}
|
||||
}
|
||||
|
||||
return Ok(inst);
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
};
|
||||
return Err(CompilationError {
|
||||
@@ -347,7 +368,9 @@ impl<'l> Scope<'l> {
|
||||
}
|
||||
}
|
||||
|
||||
if match (lhs.ty(), rhs.ty()) {
|
||||
let (lhs_ty, rhs_ty) = (lhs.ty(), rhs.ty());
|
||||
|
||||
if match (lhs_ty, rhs_ty) {
|
||||
(Type::Int(a_ty), Type::Int(b_ty)) => a_ty == b_ty,
|
||||
_ => false,
|
||||
} {
|
||||
@@ -367,7 +390,31 @@ impl<'l> Scope<'l> {
|
||||
});
|
||||
}
|
||||
|
||||
match (lhs.ty(), rhs.ty(), op) {
|
||||
if match (lhs_ty, rhs_ty) {
|
||||
(Type::Ptr(a_ty), Type::Ptr(b_ty)) => a_ty == b_ty,
|
||||
_ => false,
|
||||
} {
|
||||
let lhs = builder.ptr_to_int(lhs, IntT::USIZE).unwrap();
|
||||
let rhs = builder.ptr_to_int(rhs, IntT::USIZE).unwrap();
|
||||
return Ok(match op {
|
||||
BinaryOp::Eq(_) => builder.cmp(lhs, rhs, Cmp::Eq).unwrap(),
|
||||
BinaryOp::Ne(_) => builder.cmp(lhs, rhs, Cmp::Ne).unwrap(),
|
||||
BinaryOp::Lt(_) => builder.cmp(lhs, rhs, Cmp::Lt).unwrap(),
|
||||
BinaryOp::Gt(_) => builder.cmp(lhs, rhs, Cmp::Gt).unwrap(),
|
||||
BinaryOp::Le(_) => builder.cmp(lhs, rhs, Cmp::Le).unwrap(),
|
||||
BinaryOp::Ge(_) => builder.cmp(lhs, rhs, Cmp::Ge).unwrap(),
|
||||
_ => todo!("{lhs:?} {op:?} {rhs:?}"),
|
||||
});
|
||||
}
|
||||
|
||||
match (lhs_ty, rhs_ty, op) {
|
||||
(Type::Ptr(ptr @ PtrT { base, .. }), Type::USIZE, BinaryOp::Add(_)) => {
|
||||
let mut value = builder.ptr_to_int(lhs, IntT::USIZE).unwrap();
|
||||
let add = builder.mul(rhs, AnyConst::SizeOf(*base).into()).unwrap();
|
||||
value = builder.add(value, add).unwrap();
|
||||
value = builder.int_to_ptr(value, ptr).unwrap();
|
||||
Ok(value)
|
||||
}
|
||||
(Type::Ptr(PtrT { base, .. }), ty, BinaryOp::Assign(_)) => match *base == ty {
|
||||
true => Ok(builder.store(lhs, rhs).unwrap()),
|
||||
false => Err(CompilationError {
|
||||
@@ -407,6 +454,9 @@ impl<'l> Scope<'l> {
|
||||
(Type::Int(_), Type::Ptr(dst_ty)) => {
|
||||
return Ok(builder.int_to_ptr(lhs, dst_ty).unwrap());
|
||||
}
|
||||
(Type::Ptr(_), dst_ty @ Type::Ptr(_)) => unsafe {
|
||||
return Ok(builder.reinterpret(lhs, dst_ty, lhs.flags()).unwrap());
|
||||
},
|
||||
_ => todo!("{src_ty} as {dst_ty}"),
|
||||
}
|
||||
}
|
||||
@@ -414,8 +464,10 @@ impl<'l> Scope<'l> {
|
||||
}
|
||||
}
|
||||
|
||||
Expr::If(expr) => self.compile_if(expr, ctx),
|
||||
|
||||
Expr::While(expr) => {
|
||||
let While { value, block } = &**expr;
|
||||
let While { cond, block } = &**expr;
|
||||
|
||||
let mut builder = ctx.builder.as_mut().unwrap();
|
||||
let cond_block = builder.create_block();
|
||||
@@ -424,26 +476,18 @@ impl<'l> Scope<'l> {
|
||||
|
||||
builder.jump(cond_block).unwrap();
|
||||
builder.set_current_block(cond_block);
|
||||
let condition = self.compile_expression(value, ctx)?;
|
||||
let condition = self.compile_expression(cond, ctx)?;
|
||||
|
||||
builder = ctx.builder.as_mut().unwrap();
|
||||
builder.branch(condition, exec_block, exit_block).unwrap();
|
||||
builder.set_current_block(exec_block);
|
||||
|
||||
let mut scope = self.clone();
|
||||
let mut ctx = ExpressionContext {
|
||||
builder: Some(builder),
|
||||
decl_names: None,
|
||||
fn_queue: ctx.fn_queue,
|
||||
};
|
||||
let ret = self.compile_block(block, ctx)?;
|
||||
builder = ctx.builder.as_mut().unwrap();
|
||||
|
||||
let mut last_expr = None;
|
||||
for expr in &block.0 {
|
||||
last_expr = Some(scope.compile_expression(expr, &mut ctx)?);
|
||||
}
|
||||
builder.jump(cond_block).unwrap();
|
||||
builder.set_current_block(exit_block);
|
||||
Ok(last_expr.unwrap_or(AnyValue::Constant(AnyConst::Void)))
|
||||
Ok(ret)
|
||||
}
|
||||
|
||||
Expr::Call {
|
||||
@@ -586,9 +630,63 @@ impl<'l> Scope<'l> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
Expr::Deref(expr) => {
|
||||
let value = self.compile_expression(expr, ctx)?;
|
||||
let builder = ctx.builder.as_mut().unwrap();
|
||||
let ty = value.ty();
|
||||
|
||||
match value.is_lvalue() {
|
||||
false => unsafe {
|
||||
if !matches!(ty, Type::Ptr(_)) {
|
||||
return Err(CompilationError {
|
||||
kind: Kind::CannotDereference,
|
||||
message: format!("Cannot dereference a value of type `{ty}`."),
|
||||
location: Location::Range {
|
||||
file: self.source.clone(),
|
||||
range: expr.range(),
|
||||
},
|
||||
cause: None,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(builder
|
||||
.reinterpret(value, value.ty(), value.flags() | ValueFlags::LValue)
|
||||
.unwrap())
|
||||
},
|
||||
true => unsafe {
|
||||
if !matches!(
|
||||
ty,
|
||||
Type::Ptr(PtrT {
|
||||
base: Type::Ptr(_),
|
||||
..
|
||||
})
|
||||
) {
|
||||
let Type::Ptr(PtrT { base, .. }) = ty else {
|
||||
unreachable!()
|
||||
};
|
||||
return Err(CompilationError {
|
||||
kind: Kind::CannotDereference,
|
||||
message: format!("Cannot dereference a value of type `{base}`."),
|
||||
location: Location::Range {
|
||||
file: self.source.clone(),
|
||||
range: expr.range(),
|
||||
},
|
||||
cause: None,
|
||||
});
|
||||
}
|
||||
let AnyValue::Instruction(value) = builder.load(value).unwrap() else {
|
||||
unreachable!()
|
||||
};
|
||||
value.edit_flags(|v| v | ValueFlags::LValue);
|
||||
Ok(AnyValue::Instruction(value))
|
||||
},
|
||||
_ => unimplemented!("{}", value.is_lvalue()),
|
||||
}
|
||||
}
|
||||
|
||||
Expr::Index(expr) => {
|
||||
let IndexingExpr { value, index } = &**expr;
|
||||
let value = self.compile_expression(value, ctx)?;
|
||||
let mut value = self.compile_expression(value, ctx)?;
|
||||
let mut index = self.compile_expression(index, ctx)?;
|
||||
let builder = ctx.builder.as_mut().unwrap();
|
||||
|
||||
@@ -609,6 +707,21 @@ impl<'l> Scope<'l> {
|
||||
});
|
||||
}
|
||||
|
||||
// TODO This is probably wrong, make it better.
|
||||
while value.is_lvalue()
|
||||
&& !matches!(
|
||||
value.ty(),
|
||||
Type::Ptr(PtrT {
|
||||
base: Type::Array(_),
|
||||
..
|
||||
}) | Type::Ref(RefT {
|
||||
base: Type::Array(_),
|
||||
..
|
||||
})
|
||||
) {
|
||||
value = builder.load(value).unwrap();
|
||||
}
|
||||
|
||||
if value.is_lvalue() {
|
||||
let gep = builder.get_element_ptr(value, index).unwrap();
|
||||
return Ok(gep);
|
||||
@@ -620,7 +733,7 @@ impl<'l> Scope<'l> {
|
||||
Expr::Struct(ctor) => {
|
||||
let ty = self.compile_expression(&ctor.r#type, ctx)?;
|
||||
let AnyValue::Constant(AnyConst::Type(Type::Struct(
|
||||
struct_ty @ StructT { name, fields, .. },
|
||||
struct_ty @ StructT { fields, .. },
|
||||
))) = ty
|
||||
else {
|
||||
return Err(CompilationError {
|
||||
@@ -734,8 +847,11 @@ impl<'l> Scope<'l> {
|
||||
fn_queue: ctx.fn_queue,
|
||||
};
|
||||
let mut value = self.compile_expression(value, &mut sub_ctx)?;
|
||||
let builder = sub_ctx.builder.unwrap();
|
||||
if value.is_lvalue() {
|
||||
value = builder.load(value).unwrap();
|
||||
}
|
||||
if mutable {
|
||||
let builder = sub_ctx.builder.unwrap();
|
||||
let variable = builder.stack_alloc(value.ty()).unwrap();
|
||||
builder.store(variable, value).unwrap();
|
||||
value = variable;
|
||||
@@ -756,6 +872,88 @@ impl<'l> Scope<'l> {
|
||||
Ok(AnyConst::Void.into())
|
||||
}
|
||||
|
||||
fn compile_block(
|
||||
&mut self,
|
||||
expr: &Block,
|
||||
ctx: &mut ExpressionContext<'l, '_>,
|
||||
) -> Result<AnyValue<'l>, CompilationError> {
|
||||
let mut scope = self.clone();
|
||||
let builder = ctx.builder.as_mut().unwrap();
|
||||
let mut ctx = ExpressionContext {
|
||||
builder: Some(builder),
|
||||
decl_names: None,
|
||||
fn_queue: ctx.fn_queue,
|
||||
};
|
||||
|
||||
let mut last_expr = None;
|
||||
for expr in &expr.0 {
|
||||
last_expr = Some(scope.compile_expression(expr, &mut ctx)?);
|
||||
}
|
||||
Ok(last_expr.unwrap_or(AnyValue::Constant(AnyConst::Void)))
|
||||
}
|
||||
|
||||
fn compile_if(
|
||||
&mut self,
|
||||
expr: &If,
|
||||
ctx: &mut ExpressionContext<'l, '_>,
|
||||
) -> Result<AnyValue<'l>, CompilationError> {
|
||||
let If { cond, block, else_ } = expr;
|
||||
let condition = self.compile_expression(cond, ctx)?;
|
||||
self.assert_ty_eq(&condition, cond, &Type::Bool)?;
|
||||
|
||||
let builder = ctx.builder.as_mut().unwrap();
|
||||
let then_block = builder.create_block();
|
||||
let else_block = builder.create_block();
|
||||
|
||||
builder.branch(condition, then_block, else_block).unwrap();
|
||||
builder.set_current_block(then_block);
|
||||
let then_val = self.compile_block(block, ctx)?;
|
||||
let builder = ctx.builder.as_mut().unwrap();
|
||||
|
||||
let else_ = match else_ {
|
||||
None => {
|
||||
builder.jump(else_block).unwrap();
|
||||
builder.set_current_block(else_block);
|
||||
return Ok(then_val);
|
||||
}
|
||||
Some(else_) => else_,
|
||||
};
|
||||
|
||||
let continue_block = builder.create_block();
|
||||
builder.jump(continue_block).unwrap();
|
||||
builder.set_current_block(else_block);
|
||||
|
||||
let else_val = match &**else_ {
|
||||
Else::Block(block) => self.compile_block(block, ctx)?,
|
||||
Else::If(if_) => self.compile_if(if_, ctx)?,
|
||||
};
|
||||
|
||||
let builder = ctx.builder.as_mut().unwrap();
|
||||
builder.jump(continue_block).unwrap();
|
||||
builder.set_current_block(continue_block);
|
||||
|
||||
if then_val.ty() != else_val.ty() {
|
||||
todo!()
|
||||
}
|
||||
|
||||
Ok(builder
|
||||
.phi(
|
||||
[
|
||||
PhiValue {
|
||||
value: then_val,
|
||||
block: then_block.id,
|
||||
},
|
||||
PhiValue {
|
||||
value: else_val,
|
||||
block: else_block.id,
|
||||
},
|
||||
]
|
||||
.into_iter(),
|
||||
)
|
||||
.unwrap()
|
||||
.into())
|
||||
}
|
||||
|
||||
fn assert_ty(
|
||||
&self,
|
||||
val: AnyValue<'l>,
|
||||
|
||||
Reference in New Issue
Block a user