From 6b24980fc77566674f87b21a4b524632f5262083 Mon Sep 17 00:00:00 2001 From: John Date: Sun, 19 Oct 2025 18:56:32 -0400 Subject: [PATCH] engine: borrow the bset and rset instead of RCing them --- compiler/cl-typeck/src/stage/infer/engine.rs | 65 +++-- compiler/cl-typeck/src/stage/infer/error.rs | 4 + .../cl-typeck/src/stage/infer/inference.rs | 261 ++++++++++-------- 3 files changed, 184 insertions(+), 146 deletions(-) diff --git a/compiler/cl-typeck/src/stage/infer/engine.rs b/compiler/cl-typeck/src/stage/infer/engine.rs index e0a15bc..55b8a19 100644 --- a/compiler/cl-typeck/src/stage/infer/engine.rs +++ b/compiler/cl-typeck/src/stage/infer/engine.rs @@ -1,4 +1,4 @@ -use std::{cell::Cell, collections::HashSet, rc::Rc}; +use std::collections::HashSet; use super::error::InferenceError; use crate::{ @@ -31,19 +31,19 @@ use cl_ast::Sym; - for type T -> R // on a per-case basis! */ -type HandleSet = Rc>>; +type HandleSet<'h> = Option<&'h mut Option>; -pub struct InferenceEngine<'table, 'a> { +pub struct InferenceEngine<'table, 'a, 'b, 'r> { pub(super) table: &'table mut Table<'a>, /// The current working node pub(crate) at: Handle, /// The current breakset - pub(crate) bset: HandleSet, + pub(crate) bset: HandleSet<'b>, /// The current returnset - pub(crate) rset: HandleSet, + pub(crate) rset: HandleSet<'r>, } -impl<'table, 'a> InferenceEngine<'table, 'a> { +impl<'table, 'a, 'b, 'r> InferenceEngine<'table, 'a, 'b, 'r> { /// Infers the type of an object by deferring to [`Inference::infer()`] pub fn infer(&mut self, inferrable: &'a impl Inference<'a>) -> Result { inferrable.infer(self) @@ -56,12 +56,12 @@ impl<'table, 'a> InferenceEngine<'table, 'a> { /// Constructs an [`InferenceEngine`] that borrows the same table as `self`, /// but with a shortened lifetime. - pub fn scoped(&mut self) -> InferenceEngine<'_, 'a> { + pub fn scoped(&mut self) -> InferenceEngine<'_, 'a, '_, '_> { InferenceEngine { at: self.at, table: self.table, - bset: self.bset.clone(), - rset: self.rset.clone(), + bset: self.bset.as_deref_mut(), + rset: self.rset.as_deref_mut(), } } @@ -92,7 +92,7 @@ impl<'table, 'a> InferenceEngine<'table, 'a> { }; match &ret { - &Ok(handle) => println!("=> {}", eng.entry(handle)), + Ok(handle) => println!("=> {}", eng.entry(*handle)), Err(err @ InferenceError::AnnotationEval(_)) => eprintln!("=> ERROR: {err}"), Err(InferenceError::FieldCount(h, want, got)) => { eprintln!("=> ERROR: Field count {want} != {got} in {}", eng.entry(*h)) @@ -101,13 +101,14 @@ impl<'table, 'a> InferenceEngine<'table, 'a> { Err(InferenceError::Mismatch(h1, h2)) => eprintln!( "=> ERROR: Type mismatch {} != {}", eng.entry(*h1), - eng.entry(*h2) + eng.entry(*h2), ), Err(InferenceError::Recursive(h1, h2)) => eprintln!( "=> ERROR: Cycle found in types {}, {}", eng.entry(*h1), - eng.entry(*h2) + eng.entry(*h2), ), + Err(InferenceError::NoBreak | InferenceError::NoReturn) => {} } println!(); @@ -120,35 +121,43 @@ impl<'table, 'a> InferenceEngine<'table, 'a> { } /// Constructs a new InferenceEngine with the - pub fn at(&mut self, at: Handle) -> InferenceEngine<'_, 'a> { + pub fn at(&mut self, at: Handle) -> InferenceEngine<'_, 'a, '_, '_> { InferenceEngine { at, ..self.scoped() } } - pub fn open_bset(&mut self) -> InferenceEngine<'_, 'a> { - InferenceEngine { bset: Default::default(), ..self.scoped() } + pub fn open_bset<'ob>( + &mut self, + bset: &'ob mut Option, + ) -> InferenceEngine<'_, 'a, 'ob, '_> { + InferenceEngine { bset: Some(bset), ..self.scoped() } } - pub fn open_rset(&mut self) -> InferenceEngine<'_, 'a> { - InferenceEngine { rset: Default::default(), ..self.scoped() } + pub fn open_rset<'or>( + &mut self, + rset: &'or mut Option, + ) -> InferenceEngine<'_, 'a, '_, 'or> { + InferenceEngine { rset: Some(rset), ..self.scoped() } } pub fn bset(&mut self, ty: Handle) -> Result<(), InferenceError> { - match self.bset.get() { - Some(bset) => self.unify(ty, bset), - None => { - self.bset.set(Some(ty)); + match self.bset.as_mut() { + Some(&mut &mut Some(bset)) => self.unify(ty, bset), + Some(none) => { + let _ = none.insert(ty); Ok(()) } + None => Err(InferenceError::NoBreak), } } pub fn rset(&mut self, ty: Handle) -> Result<(), InferenceError> { - match self.rset.get() { - Some(rset) => self.unify(ty, rset), - None => { - self.rset.set(Some(ty)); + match self.rset.as_mut() { + Some(&mut &mut Some(rset)) => self.unify(ty, rset), + Some(none) => { + let _ = none.insert(ty); Ok(()) } + None => Err(InferenceError::NoReturn), } } @@ -277,7 +286,7 @@ impl<'table, 'a> InferenceEngine<'table, 'a> { } /// Creates a new locally-scoped InferenceEngine. - pub fn block_scope(&mut self) -> InferenceEngine<'_, 'a> { + pub fn block_scope(&mut self) -> InferenceEngine<'_, 'a, '_, '_> { let scope = self.table.new_entry(self.at, NodeKind::Scope); self.table.add_child(self.at, "".into(), scope); self.at(scope) @@ -459,8 +468,8 @@ impl<'table, 'a> InferenceEngine<'table, 'a> { TypeKind::Adt(Adt::Union(items)) => { items.iter().any(|(_, other)| self.occurs_in(this, *other)) } - TypeKind::Ref(other) => self.occurs_in(this, *other), - TypeKind::Ptr(other) => self.occurs_in(this, *other), + TypeKind::Ref(_) => false, + TypeKind::Ptr(_) => false, TypeKind::Slice(other) => self.occurs_in(this, *other), TypeKind::Array(other, _) => self.occurs_in(this, *other), TypeKind::Tuple(handles) => handles.iter().any(|&other| self.occurs_in(this, other)), diff --git a/compiler/cl-typeck/src/stage/infer/error.rs b/compiler/cl-typeck/src/stage/infer/error.rs index 4ae9fd9..4a6d0bb 100644 --- a/compiler/cl-typeck/src/stage/infer/error.rs +++ b/compiler/cl-typeck/src/stage/infer/error.rs @@ -11,6 +11,8 @@ pub enum InferenceError { NotFound(Path), Mismatch(Handle, Handle), Recursive(Handle, Handle), + NoBreak, + NoReturn, } impl std::error::Error for InferenceError {} @@ -28,6 +30,8 @@ impl fmt::Display for InferenceError { InferenceError::NotFound(p) => write!(f, "Path not visible in scope: {p}"), InferenceError::Mismatch(a, b) => write!(f, "Type mismatch: {a:?} != {b:?}"), InferenceError::Recursive(_, _) => write!(f, "Recursive type!"), + InferenceError::NoBreak => write!(f, "Encountered break outside loop!"), + InferenceError::NoReturn => write!(f, "Encountered return outside function!"), } } } diff --git a/compiler/cl-typeck/src/stage/infer/inference.rs b/compiler/cl-typeck/src/stage/infer/inference.rs index fa59420..97f8503 100644 --- a/compiler/cl-typeck/src/stage/infer/inference.rs +++ b/compiler/cl-typeck/src/stage/infer/inference.rs @@ -18,11 +18,11 @@ type IfResult = Result; pub trait Inference<'a> { /// Performs type inference - fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult; + fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult; } impl<'a> Inference<'a> for File { - fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { + fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult { let Self { name: _, items } = self; for item in items { item.infer(e)?; @@ -32,14 +32,14 @@ impl<'a> Inference<'a> for File { } impl<'a> Inference<'a> for Item { - fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { + fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult { let Self { span: _, attrs: _, vis: _, kind } = self; kind.infer(e) } } impl<'a> Inference<'a> for ItemKind { - fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { + fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult { match self { ItemKind::Module(v) => v.infer(e), ItemKind::Alias(v) => v.infer(e), @@ -55,7 +55,7 @@ impl<'a> Inference<'a> for ItemKind { } impl<'a> Inference<'a> for Generics { - fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { + fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult { // bind names for name in &self.vars { let ty = e.new_var(); @@ -66,7 +66,7 @@ impl<'a> Inference<'a> for Generics { } impl<'a> Inference<'a> for Module { - fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { + fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult { let Self { name, file } = self; let Some(file) = file else { return Err(InferenceError::NotFound((*name).into())); @@ -77,7 +77,7 @@ impl<'a> Inference<'a> for Module { } impl<'a> Inference<'a> for Alias { - fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { + fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult { let Self { name: _, from } = self; // let this = e.by_name(name)?; let alias = if let Some(from) = from { @@ -99,7 +99,7 @@ impl<'a> Inference<'a> for Alias { impl<'a> Inference<'a> for Const { #[allow(unused)] - fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { + fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult { let Self { name: _, ty, init } = self; // Same as static let node = e.at; //.by_name(name)?; @@ -116,7 +116,7 @@ impl<'a> Inference<'a> for Const { impl<'a> Inference<'a> for Static { #[allow(unused)] - fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { + fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult { let Static { mutable, name, ty, init } = self; let node = e.at; //e.by_name(name)?; let ty = e.infer(ty)?; @@ -132,7 +132,7 @@ impl<'a> Inference<'a> for Static { impl<'a> Inference<'a> for Function { #[allow(unused)] - fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { + fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult { let Self { name, gens, sign, bind, body } = self; // bind name to signature let node = e.at; // e.by_name(name)?; @@ -149,7 +149,9 @@ impl<'a> Inference<'a> for Function { let arg = scope.by_name(sign.args.as_ref())?; scope.unify(pat, arg); - let mut retscope = scope.open_rset(); + let mut rset = None; + + let mut retscope = scope.open_rset(&mut rset); // infer body let bodty = retscope.infer(body)?; @@ -157,7 +159,8 @@ impl<'a> Inference<'a> for Function { // unify body with rety retscope.unify(bodty, rety)?; // unify rset with rety - if let Some(rset) = retscope.rset.get() { + + if let Some(rset) = rset { scope.unify(rset, rety)?; } Ok(node) @@ -168,7 +171,7 @@ impl<'a> Inference<'a> for Function { // there are no bodies impl<'a> Inference<'a> for Enum { - fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { + fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult { let Self { name: _, gens, variants } = self; let node = e.at; //e.by_name(name)?; let mut scope = e.at(node); @@ -184,7 +187,7 @@ impl<'a> Inference<'a> for Enum { } impl<'a> Inference<'a> for Variant { - fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { + fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult { let Self { name, kind: _, body } = self; let node = e.by_name(name)?; @@ -210,7 +213,7 @@ impl<'a> Inference<'a> for Variant { } impl<'a> Inference<'a> for Struct { - fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { + fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult { let Self { name, gens, kind: _ } = self; let node = e.by_name(name)?; let mut e = e.at(node); @@ -221,7 +224,7 @@ impl<'a> Inference<'a> for Struct { } impl<'a> Inference<'a> for Impl { - fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { + fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult { let Self { gens, target, body } = self; // TODO: match gens to target gens gens.infer(e)?; @@ -233,7 +236,7 @@ impl<'a> Inference<'a> for Impl { } impl<'a> Inference<'a> for ImplKind { - fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { + fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult { match self { ImplKind::Type(ty) => ty.infer(e), ImplKind::Trait { impl_trait: _, for_type } => for_type.infer(e), @@ -242,13 +245,13 @@ impl<'a> Inference<'a> for ImplKind { } impl<'a> Inference<'a> for Ty { - fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { + fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult { Ok(e.by_name(self)?) } } impl<'a> Inference<'a> for cl_ast::Stmt { - fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { + fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult { let Self { span: _, kind, semi } = self; let out = kind.infer(e)?; Ok(match semi { @@ -259,7 +262,7 @@ impl<'a> Inference<'a> for cl_ast::Stmt { } impl<'a> Inference<'a> for cl_ast::StmtKind { - fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { + fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult { match self { StmtKind::Empty => Ok(e.empty()), StmtKind::Item(item) => item.infer(e), @@ -269,15 +272,15 @@ impl<'a> Inference<'a> for cl_ast::StmtKind { } impl<'a> Inference<'a> for cl_ast::Expr { - fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { + fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult { let out = self.kind.infer(e)?; - println!("expr ({self}) -> {}", e.entry(out)); + // println!("expr ({self}) -> {}", e.entry(out)); Ok(out) } } impl<'a> Inference<'a> for cl_ast::ExprKind { - fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { + fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult { match self { ExprKind::Empty => Ok(e.empty()), ExprKind::Closure(closure) => closure.infer(e), @@ -314,15 +317,17 @@ impl<'a> Inference<'a> for cl_ast::ExprKind { } impl<'a> Inference<'a> for Closure { - fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { + fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult { let Self { arg, body } = self; let args = arg.infer(e)?; let mut scope = e.block_scope(); - let mut scope = scope.open_rset(); - let rety = scope.infer(body)?; - if let Some(rset) = scope.rset.get() { + let mut rset = None; + let mut retscope = scope.open_rset(&mut rset); + let rety = retscope.infer(body)?; + + if let Some(rset) = rset { e.unify(rety, rset)?; } @@ -331,7 +336,7 @@ impl<'a> Inference<'a> for Closure { } impl<'a> Inference<'a> for Tuple { - fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { + fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult { let Tuple { exprs } = self; exprs .iter() @@ -345,7 +350,7 @@ impl<'a> Inference<'a> for Tuple { } impl<'a> Inference<'a> for Structor { - fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { + fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult { let Structor { to, init } = self; // Evaluate the path in the current context let to = to.infer(e)?; @@ -386,7 +391,7 @@ impl<'a> Inference<'a> for Structor { } impl<'a> Inference<'a> for Array { - fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { + fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult { let Array { values } = self; let out = e.new_inferred(); for value in values { @@ -398,7 +403,7 @@ impl<'a> Inference<'a> for Array { } impl<'a> Inference<'a> for ArrayRep { - fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { + fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult { let ArrayRep { value, repeat } = self; let ty = value.infer(e)?; let rep = repeat.infer(e)?; @@ -414,7 +419,7 @@ impl<'a> Inference<'a> for ArrayRep { } impl<'a> Inference<'a> for AddrOf { - fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { + fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult { let AddrOf { mutable: _, expr } = self; // TODO: mut ref let ty = expr.infer(e)?; @@ -423,13 +428,13 @@ impl<'a> Inference<'a> for AddrOf { } impl<'a> Inference<'a> for Quote { - fn infer(&'a self, _e: &mut InferenceEngine<'_, 'a>) -> IfResult { + fn infer(&'a self, _e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult { todo!("Quote: {self}") } } impl<'a> Inference<'a> for Literal { - fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { + fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult { let ty = match self { Literal::Bool(_) => e.bool(), Literal::Char(_) => e.char(), @@ -445,14 +450,14 @@ impl<'a> Inference<'a> for Literal { } impl<'a> Inference<'a> for Group { - fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { + fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult { let Group { expr } = self; expr.infer(e) } } impl<'a> Inference<'a> for Block { - fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { + fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult { let Block { stmts } = self; let mut e = e.block_scope(); let empty = e.empty(); @@ -483,7 +488,7 @@ impl<'a> Inference<'a> for Block { } impl<'a> Inference<'a> for Assign { - fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { + fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult { let Assign { parts } = self; let (head, tail) = parts.as_ref(); // Infer the tail expression @@ -498,7 +503,7 @@ impl<'a> Inference<'a> for Assign { } impl<'a> Inference<'a> for Modify { - fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { + fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult { let Modify { kind: _, parts } = self; let (head, tail) = parts.as_ref(); // Infer the tail expression @@ -513,7 +518,7 @@ impl<'a> Inference<'a> for Modify { } impl<'a> Inference<'a> for Binary { - fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { + fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult { use BinaryKind as Bk; let Binary { kind, parts } = self; let (head, tail) = parts.as_ref(); @@ -583,7 +588,7 @@ impl<'a> Inference<'a> for Binary { } impl<'a> Inference<'a> for Unary { - fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { + fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult { let Unary { kind, tail } = self; match kind { UnaryKind::Deref => { @@ -596,18 +601,18 @@ impl<'a> Inference<'a> for Unary { } } UnaryKind::Loop => { - let mut e = e.block_scope(); + let mut scope = e.block_scope(); // Enter a new breakset - let mut e = e.open_bset(); - - // Infer the fail branch - let tail = tail.infer(&mut e)?; - // Unify the pass branch with Empty - let empt = e.empty(); - e.unify(tail, empt)?; + let mut bset = None; + let mut bscope = scope.open_bset(&mut bset); + // Infer the tail + let body = tail.infer(&mut bscope)?; + // Unify the loop body with `empty` + let unit = bscope.empty(); + bscope.unify(unit, body)?; // Return breakset - match e.bset.get() { + match bset { Some(bset) => Ok(bset), None => Ok(e.never()), } @@ -623,7 +628,7 @@ impl<'a> Inference<'a> for Unary { } impl<'a> Inference<'a> for Member { - fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { + fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult { let Member { head, kind } = self; // Infer the head expression let head = head.infer(e)?; @@ -647,10 +652,19 @@ impl<'a> Inference<'a> for Member { Err(InferenceError::NotFound(Path::from(*name))) } } - MemberKind::Struct(name) => match ty.nav(&[PathPart::Ident(*name)]) { - Some(ty) => Ok(ty.id()), - None => Err(InferenceError::NotFound(Path::from(*name))), - }, + MemberKind::Struct(name) => { + if let Some(TypeKind::Adt(Adt::Struct(members))) = ty.ty() { + for member in members { + if member.0 == *name { + return Ok(member.2); + } + } + }; + match ty.nav(&[PathPart::Ident(*name)]) { + Some(ty) => Ok(ty.id()), + None => Err(InferenceError::NotFound(Path::from(*name))), + } + } MemberKind::Tuple(Literal::Int(idx)) => match ty.ty() { Some(TypeKind::Tuple(tys)) => tys .get(*idx as usize) @@ -669,7 +683,7 @@ impl<'a> Inference<'a> for Member { } impl<'a> Inference<'a> for Index { - fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { + fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult { let Index { head, indices } = self; let usize = e.usize(); // Infer the head expression @@ -703,7 +717,7 @@ impl<'a> Inference<'a> for Index { } impl<'a> Inference<'a> for Cast { - fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { + fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult { let Cast { head, ty } = self; // Infer the head expression let _head = head.infer(e)?; @@ -719,14 +733,14 @@ impl<'a> Inference<'a> for Cast { } impl<'a> Inference<'a> for Path { - fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { + fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult { e.by_name(self) .map_err(|_| InferenceError::NotFound(self.clone())) } } impl<'a> Inference<'a> for Let { - fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { + fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult { let Let { mutable: _, name, ty, init } = self; let ty = match ty { Some(ty) => ty @@ -752,7 +766,7 @@ impl<'a> Inference<'a> for Let { } impl<'a> Inference<'a> for Match { - fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { + fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult { let Match { scrutinee, arms } = self; // Infer the scrutinee let scrutinee = scrutinee.infer(e)?; @@ -783,7 +797,7 @@ impl<'a> Inference<'a> for Match { impl<'a> Inference<'a> for Pattern { // TODO: This is the wrong way to typeck pattern matching. - fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { + fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult { match self { Pattern::Name(name) => { // Evaluating a pattern creates and enters a new scope. @@ -867,47 +881,61 @@ impl<'a> Inference<'a> for Pattern { } impl<'a> Inference<'a> for While { - fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { + fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult { let While { cond, pass, fail } = self; - // Infer the condition - let cond = cond.infer(e)?; - // Unify the condition with bool - let bool = e.bool(); - e.unify(bool, cond)?; + let mut bset = None; + + // Open a block scope so the loop condition doesn't escape + { + let mut scope = e.block_scope(); + // Infer the condition + let cond = cond.infer(&mut scope)?; + // Unify the condition with bool + let bool = scope.bool(); + scope.unify(bool, cond)?; + + // Open a breakset for the pass branch + { + let mut body_scope = scope.open_bset(&mut bset); + // Infer the pass branch + let pass = pass.infer(&mut body_scope)?; + // Unify the pass branch with Empty + let empt = body_scope.empty(); + body_scope.unify(empt, pass)?; + } + } // Infer the fail branch let fail = fail.infer(e)?; + // Unify the fail branch with breakset - let mut e = e.open_bset(); - - // Infer the pass branch - let pass = pass.infer(&mut e)?; - // Unify the pass branch with Empty - let empt = e.empty(); - e.unify(pass, empt)?; - - match e.bset.get() { - None => Ok(e.empty()), - Some(bset) => { - e.unify(fail, bset)?; - Ok(fail) - } + if let Some(bset) = bset { + println!("bset: {}", e.entry(bset)); + e.unify(bset, fail)?; } + Ok(fail) } } impl<'a> Inference<'a> for If { - fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { + fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult { let If { cond, pass, fail } = self; - // Do inference on the condition' - let cond = cond.infer(e)?; - // Unify the condition with bool - let bool = e.bool(); - e.unify(bool, cond)?; - // Do inference on the pass branch - let pass = pass.infer(e)?; + + // Open a block scope so the condition doesn't escape + let pass = { + let mut scope = e.block_scope(); + // Do inference on the condition' + let cond = cond.infer(&mut scope)?; + // Unify the condition with bool + let bool = scope.bool(); + scope.unify(bool, cond)?; + // Do inference on the pass branch + pass.infer(&mut scope)? + }; + // Do inference on the fail branch - let fail = fail.infer(e)?; + let fail = fail.infer(&mut e.block_scope())?; + // Unify pass and fail e.unify(pass, fail)?; // Return the result @@ -916,55 +944,52 @@ impl<'a> Inference<'a> for If { } impl<'a> Inference<'a> for For { - fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { + fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult { let For { bind, cond, pass, fail } = self; - let mut e = e.block_scope(); + let mut scope = e.block_scope(); - let bind = bind.infer(&mut e)?; + let bind = bind.infer(&mut scope)?; // What does it mean to be iterable? Why, `next()`, of course! - let cond = cond.infer(&mut e)?; - let cond = e.prune(cond); - if let Some((args, rety)) = e.get_fn(cond, "next".into()) { + let cond = cond.infer(&mut scope)?; + let cond = scope.prune(cond); + if let Some((args, rety)) = scope.get_fn(cond, "next".into()) { // Check that the args are correct - let params = vec![e.new_ref(cond)]; - let params = e.new_tuple(params); - e.unify(args, params)?; - e.unify(rety, bind)?; + let params = vec![scope.new_ref(cond)]; + let params = scope.new_tuple(params); + scope.unify(args, params)?; + scope.unify(rety, bind)?; } - // Enter a new breakset - let mut e = e.open_bset(); - - // Infer the fail branch - let fail = fail.infer(&mut e)?; // Open a breakset - let mut e = e.open_bset(); + let mut bset = None; + let mut bscope = scope.open_bset(&mut bset); // Infer the pass branch - let pass = pass.infer(&mut e)?; + let pass = pass.infer(&mut bscope)?; // Unify the pass branch with Empty - let empt = e.empty(); - e.unify(pass, empt)?; + let empt = bscope.empty(); + bscope.unify(pass, empt)?; - // Return breakset - if let Some(bset) = e.bset.get() { + // Infer the fail branch + let fail = fail.infer(&mut e.block_scope())?; + + // Unify the fail branch with the breakset + if let Some(bset) = bset { e.unify(fail, bset)?; - Ok(fail) - } else { - Ok(e.empty()) } + Ok(fail) } } impl<'a> Inference<'a> for Else { - fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { + fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult { self.body.infer(e) } } impl<'a> Inference<'a> for Break { - fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { + fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult { let Break { body } = self; // Infer the body of the break let ty = body.infer(e)?; @@ -976,7 +1001,7 @@ impl<'a> Inference<'a> for Break { } impl<'a> Inference<'a> for Return { - fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { + fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult { let Return { body } = self; // Infer the body of the return let ty = body.infer(e)?; @@ -988,7 +1013,7 @@ impl<'a> Inference<'a> for Return { } impl<'a, I: Inference<'a>> Inference<'a> for Option { - fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { + fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult { match self { Some(expr) => expr.infer(e), None => Ok(e.empty()), @@ -996,7 +1021,7 @@ impl<'a, I: Inference<'a>> Inference<'a> for Option { } } impl<'a, I: Inference<'a>> Inference<'a> for Box { - fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { + fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult { self.as_ref().infer(e) } }