From 2aababbbe19b6108ac55a2ff9f3fd97b8f4984e8 Mon Sep 17 00:00:00 2001 From: Mia Date: Sat, 7 Mar 2026 19:16:29 +0100 Subject: [PATCH] Added fairly limited type inference --- backends/llvm/src/lib.rs | 3 ++ compiler/src/error.rs | 1 + compiler/src/scope.rs | 94 +++++++++++++++++++++++++++++++++------- parser/src/ast.rs | 9 +++- parser/src/parser.rs | 13 +++++- 5 files changed, 102 insertions(+), 18 deletions(-) diff --git a/backends/llvm/src/lib.rs b/backends/llvm/src/lib.rs index 87f2dfc..eceded1 100644 --- a/backends/llvm/src/lib.rs +++ b/backends/llvm/src/lib.rs @@ -492,6 +492,9 @@ impl<'l> CompilationContext<'l> { AnyConst::Int(Int::USize(v)) => { Some(self.native_int_ty.const_int(*v as u64, false).into()) } + AnyConst::Int(Int::ISize(v)) => { + Some(self.native_int_ty.const_int(*v as u64, true).into()) + } AnyConst::Array([]) => todo!("{val:?}"), AnyConst::Array(array) => { diff --git a/compiler/src/error.rs b/compiler/src/error.rs index 0bcac6c..1871524 100644 --- a/compiler/src/error.rs +++ b/compiler/src/error.rs @@ -21,6 +21,7 @@ pub enum Kind { FieldNotFound = 0x0206, InvalidCast = 0x0207, CannotDereference = 0x0209, + CannotInferType = 0x020A, UninitializedField = 0x0300, diff --git a/compiler/src/scope.rs b/compiler/src/scope.rs index ccc2c68..58cbf28 100644 --- a/compiler/src/scope.rs +++ b/compiler/src/scope.rs @@ -27,10 +27,22 @@ use std::{ struct ExpressionContext<'l, 'r> { decl_names: Option<&'r NamePattern>, + type_hint: Option>, builder: Option<&'r mut FunctionBodyBuilder<'l>>, fn_queue: &'r mut FuncQueue<'l>, } +impl<'l, 'r, 'a> ExpressionContext<'l, 'r> { + pub fn with_type_hit(&'a mut self, type_hint: Option>) -> ExpressionContext<'l, 'a> { + ExpressionContext { + decl_names: self.decl_names, + type_hint, + builder: self.builder.as_mut().map(|b| &mut **b), + fn_queue: self.fn_queue, + } + } +} + #[derive(Clone)] struct Variable<'l> { value: Arc>>, @@ -88,6 +100,7 @@ impl<'l> Scope<'l> { &mut ExpressionContext { decl_names: Some(&val.names), builder: None, + type_hint: None, fn_queue, }, )?; @@ -118,6 +131,7 @@ impl<'l> Scope<'l> { &mut ExpressionContext { builder: Some(&mut builder), decl_names: None, + type_hint: Some(func.ty.ret_t), fn_queue: fn_queue, }, )?; @@ -208,7 +222,20 @@ impl<'l> Scope<'l> { }; } match n.r#type.as_ref().map(|v| v.as_str()) { + None if ctx.type_hint == Some(Type::I8) => parse_number!(i8, I8), + None if ctx.type_hint == Some(Type::I16) => parse_number!(i16, I16), + None if ctx.type_hint == Some(Type::I32) => parse_number!(i32, I32), + None if ctx.type_hint == Some(Type::I64) => parse_number!(i64, I64), + None if ctx.type_hint == Some(Type::I128) => parse_number!(i128, I128), + None if ctx.type_hint == Some(Type::ISIZE) => parse_number!(i64, ISize), + None if ctx.type_hint == Some(Type::U8) => parse_number!(u8, U8), + None if ctx.type_hint == Some(Type::U16) => parse_number!(u16, U16), + None if ctx.type_hint == Some(Type::U32) => parse_number!(u32, U32), + None if ctx.type_hint == Some(Type::U64) => parse_number!(u64, U64), + None if ctx.type_hint == Some(Type::U128) => parse_number!(u128, U128), + None if ctx.type_hint == Some(Type::USIZE) => parse_number!(u64, USize), None => parse_number!(i64, ISize), + Some("i8") => parse_number!(i8, I8), Some("i16") => parse_number!(i16, I16), Some("i32") => parse_number!(i32, I32), @@ -297,11 +324,20 @@ impl<'l> Scope<'l> { } = &**bin_expr; let mut lhs = self.compile_expression(lhs_expr, ctx)?; - if lhs.is_lvalue() && !matches!(op, BinaryOp::Assign(_)) { - lhs = ctx.builder.as_mut().unwrap().load(lhs).unwrap(); - } + let type_hint = if lhs.is_lvalue() { + let Type::Ptr(PtrT { base, .. }) = lhs.ty() else { + unreachable!(); + }; + if !matches!(op, BinaryOp::Assign(_)) { + lhs = ctx.builder.as_mut().unwrap().load(lhs).unwrap(); + } + *base + } else { + lhs.ty() + }; - let mut rhs = self.compile_expression(rhs_expr, ctx)?; + let mut rhs = + self.compile_expression(rhs_expr, &mut ctx.with_type_hit(Some(type_hint)))?; if rhs.is_lvalue() { rhs = ctx.builder.as_mut().unwrap().load(rhs).unwrap(); } @@ -508,9 +544,11 @@ impl<'l> Scope<'l> { }); } }; + let mut arg_ty = func.ty.par_t.iter().cloned(); let mut args = Vec::with_capacity(args_exprs.len()); for expr in args_exprs { - let mut arg = self.compile_expression(expr, ctx)?; + let mut arg = + self.compile_expression(expr, &mut ctx.with_type_hit(arg_ty.next()))?; if arg.is_lvalue() { arg = ctx.builder.as_mut().unwrap().load(arg).unwrap(); } @@ -531,6 +569,7 @@ impl<'l> Scope<'l> { let mut expr_ctx = ExpressionContext { builder: None, decl_names: None, + type_hint: Some(Type::Type), fn_queue: ctx.fn_queue, }; let ctx = self.assembly.ctx(); @@ -731,7 +770,23 @@ impl<'l> Scope<'l> { } Expr::Struct(ctor) => { - let ty = self.compile_expression(&ctor.r#type, ctx)?; + let ty = match &ctor.r#type { + Some(ty) => self.compile_expression(ty, ctx)?, + None => match ctx.type_hint { + Some(ty) => AnyValue::Constant(AnyConst::Type(ty)), + None => { + return Err(CompilationError { + kind: Kind::CannotInferType, + message: "Type cannot be inferred.".into(), + location: Location::Range { + file: self.source.clone(), + range: ctor.range(), + }, + cause: None, + }); + } + }, + }; let AnyValue::Constant(AnyConst::Type(Type::Struct( struct_ty @ StructT { fields, .. }, ))) = ty @@ -741,7 +796,10 @@ impl<'l> Scope<'l> { message: format!("Expected struct type, got value of type `{}`.", ty.ty()), location: Location::Range { file: self.source.clone(), - range: ctor.r#type.range(), + range: match &ctor.r#type { + None => ctor.range(), + Some(ty) => ty.range(), + }, }, cause: None, }); @@ -765,7 +823,8 @@ impl<'l> Scope<'l> { cause: None, }); }; - let value = self.compile_expression(&name_value_pair.value, ctx)?; + let mut ctx = ctx.with_type_hit(Some(*ty)); + let value = self.compile_expression(&name_value_pair.value, &mut ctx)?; self.assert_ty_eq(&value, &name_value_pair.value, ty)?; non_const |= !value.flags().contains(ValueFlags::Const); values.push(value); @@ -844,6 +903,7 @@ impl<'l> Scope<'l> { let mut sub_ctx = ExpressionContext { decl_names: Some(names), builder: ctx.builder.as_deref_mut(), + type_hint: None, //TODO fn_queue: ctx.fn_queue, }; let mut value = self.compile_expression(value, &mut sub_ctx)?; @@ -874,19 +934,23 @@ impl<'l> Scope<'l> { fn compile_block( &mut self, - expr: &Block, + block: &Block, ctx: &mut ExpressionContext<'l, '_>, ) -> Result, CompilationError> { let mut scope = self.clone(); let builder = ctx.builder.as_mut().unwrap(); - let mut ctx = ExpressionContext { - builder: Some(builder), - decl_names: None, - fn_queue: ctx.fn_queue, - }; let mut last_expr = None; - for expr in &expr.0 { + for (i, expr) in block.0.iter().enumerate() { + let mut ctx = ExpressionContext { + builder: Some(builder), + decl_names: None, + type_hint: match i == block.0.len() - 1 { + true => ctx.type_hint, + false => None, + }, + fn_queue: ctx.fn_queue, + }; last_expr = Some(scope.compile_expression(expr, &mut ctx)?); } Ok(last_expr.unwrap_or(AnyValue::Constant(AnyConst::Void))) diff --git a/parser/src/ast.rs b/parser/src/ast.rs index 6606fae..0a27d68 100644 --- a/parser/src/ast.rs +++ b/parser/src/ast.rs @@ -178,8 +178,15 @@ pub struct Field { #[derive(Debug)] pub struct StructCtor { - pub r#type: Expr, + pub r#type: Option, pub r#values: IndexMap, + pub(crate) range: Range, +} + +impl StructCtor { + pub fn range(&self) -> Range { + self.range.clone() + } } #[derive(Debug)] diff --git a/parser/src/parser.rs b/parser/src/parser.rs index 740dbb9..e004aad 100644 --- a/parser/src/parser.rs +++ b/parser/src/parser.rs @@ -81,9 +81,18 @@ peg::parser! { lhs:@ "(" __ args:(expr() ** list_separator()) __ ")" { Expr::Call { func: lhs.into(), args } } value:@ "[" __ index:expr() __ "]" { Expr::Index(IndexingExpr { value, index }.into()) } - r#type:@ __ "#{" __ values:name_value_pairs() __ "}" { Expr::Struct( + ty:@ __ "#{" __ values:name_value_pairs() __ "}" e:position!() { Expr::Struct( StructCtor { - r#type, values: values.into_iter().map(|v| (v.name.0.clone(), v)).collect() + range: ty.range().start..e, + r#type: Some(ty), + values: values.into_iter().map(|v| (v.name.0.clone(), v)).collect(), + }.into() + ) } + s:position!() "#{" __ values:name_value_pairs() __ "}" e:position!() { Expr::Struct( + StructCtor { + r#type: None, + values: values.into_iter().map(|v| (v.name.0.clone(), v)).collect(), + range: s..e, }.into() ) } --