engine: borrow the bset and rset instead of RCing them

This commit is contained in:
John 2025-10-19 18:56:32 -04:00
parent 55324af358
commit 6b24980fc7
3 changed files with 184 additions and 146 deletions

View File

@ -1,4 +1,4 @@
use std::{cell::Cell, collections::HashSet, rc::Rc}; use std::collections::HashSet;
use super::error::InferenceError; use super::error::InferenceError;
use crate::{ use crate::{
@ -31,19 +31,19 @@ use cl_ast::Sym;
- for<T, R> type T -> R // on a per-case basis! - for<T, R> type T -> R // on a per-case basis!
*/ */
type HandleSet = Rc<Cell<Option<Handle>>>; type HandleSet<'h> = Option<&'h mut Option<Handle>>;
pub struct InferenceEngine<'table, 'a> { pub struct InferenceEngine<'table, 'a, 'b, 'r> {
pub(super) table: &'table mut Table<'a>, pub(super) table: &'table mut Table<'a>,
/// The current working node /// The current working node
pub(crate) at: Handle, pub(crate) at: Handle,
/// The current breakset /// The current breakset
pub(crate) bset: HandleSet, pub(crate) bset: HandleSet<'b>,
/// The current returnset /// The current returnset
pub(crate) rset: HandleSet, pub(crate) rset: HandleSet<'r>,
} }
impl<'table, 'a> InferenceEngine<'table, 'a> { impl<'table, 'a, 'b, 'r> InferenceEngine<'table, 'a, 'b, 'r> {
/// Infers the type of an object by deferring to [`Inference::infer()`] /// Infers the type of an object by deferring to [`Inference::infer()`]
pub fn infer(&mut self, inferrable: &'a impl Inference<'a>) -> Result<Handle, InferenceError> { pub fn infer(&mut self, inferrable: &'a impl Inference<'a>) -> Result<Handle, InferenceError> {
inferrable.infer(self) inferrable.infer(self)
@ -56,12 +56,12 @@ impl<'table, 'a> InferenceEngine<'table, 'a> {
/// Constructs an [`InferenceEngine`] that borrows the same table as `self`, /// Constructs an [`InferenceEngine`] that borrows the same table as `self`,
/// but with a shortened lifetime. /// but with a shortened lifetime.
pub fn scoped(&mut self) -> InferenceEngine<'_, 'a> { pub fn scoped(&mut self) -> InferenceEngine<'_, 'a, '_, '_> {
InferenceEngine { InferenceEngine {
at: self.at, at: self.at,
table: self.table, table: self.table,
bset: self.bset.clone(), bset: self.bset.as_deref_mut(),
rset: self.rset.clone(), rset: self.rset.as_deref_mut(),
} }
} }
@ -92,7 +92,7 @@ impl<'table, 'a> InferenceEngine<'table, 'a> {
}; };
match &ret { match &ret {
&Ok(handle) => println!("=> {}", eng.entry(handle)), Ok(handle) => println!("=> {}", eng.entry(*handle)),
Err(err @ InferenceError::AnnotationEval(_)) => eprintln!("=> ERROR: {err}"), Err(err @ InferenceError::AnnotationEval(_)) => eprintln!("=> ERROR: {err}"),
Err(InferenceError::FieldCount(h, want, got)) => { Err(InferenceError::FieldCount(h, want, got)) => {
eprintln!("=> ERROR: Field count {want} != {got} in {}", eng.entry(*h)) eprintln!("=> ERROR: Field count {want} != {got} in {}", eng.entry(*h))
@ -101,13 +101,14 @@ impl<'table, 'a> InferenceEngine<'table, 'a> {
Err(InferenceError::Mismatch(h1, h2)) => eprintln!( Err(InferenceError::Mismatch(h1, h2)) => eprintln!(
"=> ERROR: Type mismatch {} != {}", "=> ERROR: Type mismatch {} != {}",
eng.entry(*h1), eng.entry(*h1),
eng.entry(*h2) eng.entry(*h2),
), ),
Err(InferenceError::Recursive(h1, h2)) => eprintln!( Err(InferenceError::Recursive(h1, h2)) => eprintln!(
"=> ERROR: Cycle found in types {}, {}", "=> ERROR: Cycle found in types {}, {}",
eng.entry(*h1), eng.entry(*h1),
eng.entry(*h2) eng.entry(*h2),
), ),
Err(InferenceError::NoBreak | InferenceError::NoReturn) => {}
} }
println!(); println!();
@ -120,35 +121,43 @@ impl<'table, 'a> InferenceEngine<'table, 'a> {
} }
/// Constructs a new InferenceEngine with the /// Constructs a new InferenceEngine with the
pub fn at(&mut self, at: Handle) -> InferenceEngine<'_, 'a> { pub fn at(&mut self, at: Handle) -> InferenceEngine<'_, 'a, '_, '_> {
InferenceEngine { at, ..self.scoped() } InferenceEngine { at, ..self.scoped() }
} }
pub fn open_bset(&mut self) -> InferenceEngine<'_, 'a> { pub fn open_bset<'ob>(
InferenceEngine { bset: Default::default(), ..self.scoped() } &mut self,
bset: &'ob mut Option<Handle>,
) -> InferenceEngine<'_, 'a, 'ob, '_> {
InferenceEngine { bset: Some(bset), ..self.scoped() }
} }
pub fn open_rset(&mut self) -> InferenceEngine<'_, 'a> { pub fn open_rset<'or>(
InferenceEngine { rset: Default::default(), ..self.scoped() } &mut self,
rset: &'or mut Option<Handle>,
) -> InferenceEngine<'_, 'a, '_, 'or> {
InferenceEngine { rset: Some(rset), ..self.scoped() }
} }
pub fn bset(&mut self, ty: Handle) -> Result<(), InferenceError> { pub fn bset(&mut self, ty: Handle) -> Result<(), InferenceError> {
match self.bset.get() { match self.bset.as_mut() {
Some(bset) => self.unify(ty, bset), Some(&mut &mut Some(bset)) => self.unify(ty, bset),
None => { Some(none) => {
self.bset.set(Some(ty)); let _ = none.insert(ty);
Ok(()) Ok(())
} }
None => Err(InferenceError::NoBreak),
} }
} }
pub fn rset(&mut self, ty: Handle) -> Result<(), InferenceError> { pub fn rset(&mut self, ty: Handle) -> Result<(), InferenceError> {
match self.rset.get() { match self.rset.as_mut() {
Some(rset) => self.unify(ty, rset), Some(&mut &mut Some(rset)) => self.unify(ty, rset),
None => { Some(none) => {
self.rset.set(Some(ty)); let _ = none.insert(ty);
Ok(()) Ok(())
} }
None => Err(InferenceError::NoReturn),
} }
} }
@ -277,7 +286,7 @@ impl<'table, 'a> InferenceEngine<'table, 'a> {
} }
/// Creates a new locally-scoped InferenceEngine. /// Creates a new locally-scoped InferenceEngine.
pub fn block_scope(&mut self) -> InferenceEngine<'_, 'a> { pub fn block_scope(&mut self) -> InferenceEngine<'_, 'a, '_, '_> {
let scope = self.table.new_entry(self.at, NodeKind::Scope); let scope = self.table.new_entry(self.at, NodeKind::Scope);
self.table.add_child(self.at, "".into(), scope); self.table.add_child(self.at, "".into(), scope);
self.at(scope) self.at(scope)
@ -459,8 +468,8 @@ impl<'table, 'a> InferenceEngine<'table, 'a> {
TypeKind::Adt(Adt::Union(items)) => { TypeKind::Adt(Adt::Union(items)) => {
items.iter().any(|(_, other)| self.occurs_in(this, *other)) items.iter().any(|(_, other)| self.occurs_in(this, *other))
} }
TypeKind::Ref(other) => self.occurs_in(this, *other), TypeKind::Ref(_) => false,
TypeKind::Ptr(other) => self.occurs_in(this, *other), TypeKind::Ptr(_) => false,
TypeKind::Slice(other) => self.occurs_in(this, *other), TypeKind::Slice(other) => self.occurs_in(this, *other),
TypeKind::Array(other, _) => self.occurs_in(this, *other), TypeKind::Array(other, _) => self.occurs_in(this, *other),
TypeKind::Tuple(handles) => handles.iter().any(|&other| self.occurs_in(this, other)), TypeKind::Tuple(handles) => handles.iter().any(|&other| self.occurs_in(this, other)),

View File

@ -11,6 +11,8 @@ pub enum InferenceError {
NotFound(Path), NotFound(Path),
Mismatch(Handle, Handle), Mismatch(Handle, Handle),
Recursive(Handle, Handle), Recursive(Handle, Handle),
NoBreak,
NoReturn,
} }
impl std::error::Error for InferenceError {} impl std::error::Error for InferenceError {}
@ -28,6 +30,8 @@ impl fmt::Display for InferenceError {
InferenceError::NotFound(p) => write!(f, "Path not visible in scope: {p}"), InferenceError::NotFound(p) => write!(f, "Path not visible in scope: {p}"),
InferenceError::Mismatch(a, b) => write!(f, "Type mismatch: {a:?} != {b:?}"), InferenceError::Mismatch(a, b) => write!(f, "Type mismatch: {a:?} != {b:?}"),
InferenceError::Recursive(_, _) => write!(f, "Recursive type!"), InferenceError::Recursive(_, _) => write!(f, "Recursive type!"),
InferenceError::NoBreak => write!(f, "Encountered break outside loop!"),
InferenceError::NoReturn => write!(f, "Encountered return outside function!"),
} }
} }
} }

View File

@ -18,11 +18,11 @@ type IfResult = Result<Handle, InferenceError>;
pub trait Inference<'a> { pub trait Inference<'a> {
/// Performs type inference /// Performs type inference
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult; fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult;
} }
impl<'a> Inference<'a> for File { impl<'a> Inference<'a> for File {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult {
let Self { name: _, items } = self; let Self { name: _, items } = self;
for item in items { for item in items {
item.infer(e)?; item.infer(e)?;
@ -32,14 +32,14 @@ impl<'a> Inference<'a> for File {
} }
impl<'a> Inference<'a> for Item { impl<'a> Inference<'a> for Item {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult {
let Self { span: _, attrs: _, vis: _, kind } = self; let Self { span: _, attrs: _, vis: _, kind } = self;
kind.infer(e) kind.infer(e)
} }
} }
impl<'a> Inference<'a> for ItemKind { impl<'a> Inference<'a> for ItemKind {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult {
match self { match self {
ItemKind::Module(v) => v.infer(e), ItemKind::Module(v) => v.infer(e),
ItemKind::Alias(v) => v.infer(e), ItemKind::Alias(v) => v.infer(e),
@ -55,7 +55,7 @@ impl<'a> Inference<'a> for ItemKind {
} }
impl<'a> Inference<'a> for Generics { impl<'a> Inference<'a> for Generics {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult {
// bind names // bind names
for name in &self.vars { for name in &self.vars {
let ty = e.new_var(); let ty = e.new_var();
@ -66,7 +66,7 @@ impl<'a> Inference<'a> for Generics {
} }
impl<'a> Inference<'a> for Module { impl<'a> Inference<'a> for Module {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult {
let Self { name, file } = self; let Self { name, file } = self;
let Some(file) = file else { let Some(file) = file else {
return Err(InferenceError::NotFound((*name).into())); return Err(InferenceError::NotFound((*name).into()));
@ -77,7 +77,7 @@ impl<'a> Inference<'a> for Module {
} }
impl<'a> Inference<'a> for Alias { impl<'a> Inference<'a> for Alias {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult {
let Self { name: _, from } = self; let Self { name: _, from } = self;
// let this = e.by_name(name)?; // let this = e.by_name(name)?;
let alias = if let Some(from) = from { let alias = if let Some(from) = from {
@ -99,7 +99,7 @@ impl<'a> Inference<'a> for Alias {
impl<'a> Inference<'a> for Const { impl<'a> Inference<'a> for Const {
#[allow(unused)] #[allow(unused)]
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult {
let Self { name: _, ty, init } = self; let Self { name: _, ty, init } = self;
// Same as static // Same as static
let node = e.at; //.by_name(name)?; let node = e.at; //.by_name(name)?;
@ -116,7 +116,7 @@ impl<'a> Inference<'a> for Const {
impl<'a> Inference<'a> for Static { impl<'a> Inference<'a> for Static {
#[allow(unused)] #[allow(unused)]
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult {
let Static { mutable, name, ty, init } = self; let Static { mutable, name, ty, init } = self;
let node = e.at; //e.by_name(name)?; let node = e.at; //e.by_name(name)?;
let ty = e.infer(ty)?; let ty = e.infer(ty)?;
@ -132,7 +132,7 @@ impl<'a> Inference<'a> for Static {
impl<'a> Inference<'a> for Function { impl<'a> Inference<'a> for Function {
#[allow(unused)] #[allow(unused)]
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult {
let Self { name, gens, sign, bind, body } = self; let Self { name, gens, sign, bind, body } = self;
// bind name to signature // bind name to signature
let node = e.at; // e.by_name(name)?; let node = e.at; // e.by_name(name)?;
@ -149,7 +149,9 @@ impl<'a> Inference<'a> for Function {
let arg = scope.by_name(sign.args.as_ref())?; let arg = scope.by_name(sign.args.as_ref())?;
scope.unify(pat, arg); scope.unify(pat, arg);
let mut retscope = scope.open_rset(); let mut rset = None;
let mut retscope = scope.open_rset(&mut rset);
// infer body // infer body
let bodty = retscope.infer(body)?; let bodty = retscope.infer(body)?;
@ -157,7 +159,8 @@ impl<'a> Inference<'a> for Function {
// unify body with rety // unify body with rety
retscope.unify(bodty, rety)?; retscope.unify(bodty, rety)?;
// unify rset with rety // unify rset with rety
if let Some(rset) = retscope.rset.get() {
if let Some(rset) = rset {
scope.unify(rset, rety)?; scope.unify(rset, rety)?;
} }
Ok(node) Ok(node)
@ -168,7 +171,7 @@ impl<'a> Inference<'a> for Function {
// there are no bodies // there are no bodies
impl<'a> Inference<'a> for Enum { impl<'a> Inference<'a> for Enum {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult {
let Self { name: _, gens, variants } = self; let Self { name: _, gens, variants } = self;
let node = e.at; //e.by_name(name)?; let node = e.at; //e.by_name(name)?;
let mut scope = e.at(node); let mut scope = e.at(node);
@ -184,7 +187,7 @@ impl<'a> Inference<'a> for Enum {
} }
impl<'a> Inference<'a> for Variant { impl<'a> Inference<'a> for Variant {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult {
let Self { name, kind: _, body } = self; let Self { name, kind: _, body } = self;
let node = e.by_name(name)?; let node = e.by_name(name)?;
@ -210,7 +213,7 @@ impl<'a> Inference<'a> for Variant {
} }
impl<'a> Inference<'a> for Struct { impl<'a> Inference<'a> for Struct {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult {
let Self { name, gens, kind: _ } = self; let Self { name, gens, kind: _ } = self;
let node = e.by_name(name)?; let node = e.by_name(name)?;
let mut e = e.at(node); let mut e = e.at(node);
@ -221,7 +224,7 @@ impl<'a> Inference<'a> for Struct {
} }
impl<'a> Inference<'a> for Impl { impl<'a> Inference<'a> for Impl {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult {
let Self { gens, target, body } = self; let Self { gens, target, body } = self;
// TODO: match gens to target gens // TODO: match gens to target gens
gens.infer(e)?; gens.infer(e)?;
@ -233,7 +236,7 @@ impl<'a> Inference<'a> for Impl {
} }
impl<'a> Inference<'a> for ImplKind { impl<'a> Inference<'a> for ImplKind {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult {
match self { match self {
ImplKind::Type(ty) => ty.infer(e), ImplKind::Type(ty) => ty.infer(e),
ImplKind::Trait { impl_trait: _, for_type } => for_type.infer(e), ImplKind::Trait { impl_trait: _, for_type } => for_type.infer(e),
@ -242,13 +245,13 @@ impl<'a> Inference<'a> for ImplKind {
} }
impl<'a> Inference<'a> for Ty { impl<'a> Inference<'a> for Ty {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult {
Ok(e.by_name(self)?) Ok(e.by_name(self)?)
} }
} }
impl<'a> Inference<'a> for cl_ast::Stmt { impl<'a> Inference<'a> for cl_ast::Stmt {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult {
let Self { span: _, kind, semi } = self; let Self { span: _, kind, semi } = self;
let out = kind.infer(e)?; let out = kind.infer(e)?;
Ok(match semi { Ok(match semi {
@ -259,7 +262,7 @@ impl<'a> Inference<'a> for cl_ast::Stmt {
} }
impl<'a> Inference<'a> for cl_ast::StmtKind { impl<'a> Inference<'a> for cl_ast::StmtKind {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult {
match self { match self {
StmtKind::Empty => Ok(e.empty()), StmtKind::Empty => Ok(e.empty()),
StmtKind::Item(item) => item.infer(e), StmtKind::Item(item) => item.infer(e),
@ -269,15 +272,15 @@ impl<'a> Inference<'a> for cl_ast::StmtKind {
} }
impl<'a> Inference<'a> for cl_ast::Expr { impl<'a> Inference<'a> for cl_ast::Expr {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult {
let out = self.kind.infer(e)?; let out = self.kind.infer(e)?;
println!("expr ({self}) -> {}", e.entry(out)); // println!("expr ({self}) -> {}", e.entry(out));
Ok(out) Ok(out)
} }
} }
impl<'a> Inference<'a> for cl_ast::ExprKind { impl<'a> Inference<'a> for cl_ast::ExprKind {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult {
match self { match self {
ExprKind::Empty => Ok(e.empty()), ExprKind::Empty => Ok(e.empty()),
ExprKind::Closure(closure) => closure.infer(e), ExprKind::Closure(closure) => closure.infer(e),
@ -314,15 +317,17 @@ impl<'a> Inference<'a> for cl_ast::ExprKind {
} }
impl<'a> Inference<'a> for Closure { impl<'a> Inference<'a> for Closure {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult {
let Self { arg, body } = self; let Self { arg, body } = self;
let args = arg.infer(e)?; let args = arg.infer(e)?;
let mut scope = e.block_scope(); let mut scope = e.block_scope();
let mut scope = scope.open_rset();
let rety = scope.infer(body)?;
if let Some(rset) = scope.rset.get() { let mut rset = None;
let mut retscope = scope.open_rset(&mut rset);
let rety = retscope.infer(body)?;
if let Some(rset) = rset {
e.unify(rety, rset)?; e.unify(rety, rset)?;
} }
@ -331,7 +336,7 @@ impl<'a> Inference<'a> for Closure {
} }
impl<'a> Inference<'a> for Tuple { impl<'a> Inference<'a> for Tuple {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult {
let Tuple { exprs } = self; let Tuple { exprs } = self;
exprs exprs
.iter() .iter()
@ -345,7 +350,7 @@ impl<'a> Inference<'a> for Tuple {
} }
impl<'a> Inference<'a> for Structor { impl<'a> Inference<'a> for Structor {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult {
let Structor { to, init } = self; let Structor { to, init } = self;
// Evaluate the path in the current context // Evaluate the path in the current context
let to = to.infer(e)?; let to = to.infer(e)?;
@ -386,7 +391,7 @@ impl<'a> Inference<'a> for Structor {
} }
impl<'a> Inference<'a> for Array { impl<'a> Inference<'a> for Array {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult {
let Array { values } = self; let Array { values } = self;
let out = e.new_inferred(); let out = e.new_inferred();
for value in values { for value in values {
@ -398,7 +403,7 @@ impl<'a> Inference<'a> for Array {
} }
impl<'a> Inference<'a> for ArrayRep { impl<'a> Inference<'a> for ArrayRep {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult {
let ArrayRep { value, repeat } = self; let ArrayRep { value, repeat } = self;
let ty = value.infer(e)?; let ty = value.infer(e)?;
let rep = repeat.infer(e)?; let rep = repeat.infer(e)?;
@ -414,7 +419,7 @@ impl<'a> Inference<'a> for ArrayRep {
} }
impl<'a> Inference<'a> for AddrOf { impl<'a> Inference<'a> for AddrOf {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult {
let AddrOf { mutable: _, expr } = self; let AddrOf { mutable: _, expr } = self;
// TODO: mut ref // TODO: mut ref
let ty = expr.infer(e)?; let ty = expr.infer(e)?;
@ -423,13 +428,13 @@ impl<'a> Inference<'a> for AddrOf {
} }
impl<'a> Inference<'a> for Quote { impl<'a> Inference<'a> for Quote {
fn infer(&'a self, _e: &mut InferenceEngine<'_, 'a>) -> IfResult { fn infer(&'a self, _e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult {
todo!("Quote: {self}") todo!("Quote: {self}")
} }
} }
impl<'a> Inference<'a> for Literal { impl<'a> Inference<'a> for Literal {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult {
let ty = match self { let ty = match self {
Literal::Bool(_) => e.bool(), Literal::Bool(_) => e.bool(),
Literal::Char(_) => e.char(), Literal::Char(_) => e.char(),
@ -445,14 +450,14 @@ impl<'a> Inference<'a> for Literal {
} }
impl<'a> Inference<'a> for Group { impl<'a> Inference<'a> for Group {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult {
let Group { expr } = self; let Group { expr } = self;
expr.infer(e) expr.infer(e)
} }
} }
impl<'a> Inference<'a> for Block { impl<'a> Inference<'a> for Block {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult {
let Block { stmts } = self; let Block { stmts } = self;
let mut e = e.block_scope(); let mut e = e.block_scope();
let empty = e.empty(); let empty = e.empty();
@ -483,7 +488,7 @@ impl<'a> Inference<'a> for Block {
} }
impl<'a> Inference<'a> for Assign { impl<'a> Inference<'a> for Assign {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult {
let Assign { parts } = self; let Assign { parts } = self;
let (head, tail) = parts.as_ref(); let (head, tail) = parts.as_ref();
// Infer the tail expression // Infer the tail expression
@ -498,7 +503,7 @@ impl<'a> Inference<'a> for Assign {
} }
impl<'a> Inference<'a> for Modify { impl<'a> Inference<'a> for Modify {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult {
let Modify { kind: _, parts } = self; let Modify { kind: _, parts } = self;
let (head, tail) = parts.as_ref(); let (head, tail) = parts.as_ref();
// Infer the tail expression // Infer the tail expression
@ -513,7 +518,7 @@ impl<'a> Inference<'a> for Modify {
} }
impl<'a> Inference<'a> for Binary { impl<'a> Inference<'a> for Binary {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult {
use BinaryKind as Bk; use BinaryKind as Bk;
let Binary { kind, parts } = self; let Binary { kind, parts } = self;
let (head, tail) = parts.as_ref(); let (head, tail) = parts.as_ref();
@ -583,7 +588,7 @@ impl<'a> Inference<'a> for Binary {
} }
impl<'a> Inference<'a> for Unary { impl<'a> Inference<'a> for Unary {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult {
let Unary { kind, tail } = self; let Unary { kind, tail } = self;
match kind { match kind {
UnaryKind::Deref => { UnaryKind::Deref => {
@ -596,18 +601,18 @@ impl<'a> Inference<'a> for Unary {
} }
} }
UnaryKind::Loop => { UnaryKind::Loop => {
let mut e = e.block_scope(); let mut scope = e.block_scope();
// Enter a new breakset // Enter a new breakset
let mut e = e.open_bset(); let mut bset = None;
let mut bscope = scope.open_bset(&mut bset);
// Infer the fail branch // Infer the tail
let tail = tail.infer(&mut e)?; let body = tail.infer(&mut bscope)?;
// Unify the pass branch with Empty // Unify the loop body with `empty`
let empt = e.empty(); let unit = bscope.empty();
e.unify(tail, empt)?; bscope.unify(unit, body)?;
// Return breakset // Return breakset
match e.bset.get() { match bset {
Some(bset) => Ok(bset), Some(bset) => Ok(bset),
None => Ok(e.never()), None => Ok(e.never()),
} }
@ -623,7 +628,7 @@ impl<'a> Inference<'a> for Unary {
} }
impl<'a> Inference<'a> for Member { impl<'a> Inference<'a> for Member {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult {
let Member { head, kind } = self; let Member { head, kind } = self;
// Infer the head expression // Infer the head expression
let head = head.infer(e)?; let head = head.infer(e)?;
@ -647,10 +652,19 @@ impl<'a> Inference<'a> for Member {
Err(InferenceError::NotFound(Path::from(*name))) Err(InferenceError::NotFound(Path::from(*name)))
} }
} }
MemberKind::Struct(name) => match ty.nav(&[PathPart::Ident(*name)]) { MemberKind::Struct(name) => {
if let Some(TypeKind::Adt(Adt::Struct(members))) = ty.ty() {
for member in members {
if member.0 == *name {
return Ok(member.2);
}
}
};
match ty.nav(&[PathPart::Ident(*name)]) {
Some(ty) => Ok(ty.id()), Some(ty) => Ok(ty.id()),
None => Err(InferenceError::NotFound(Path::from(*name))), None => Err(InferenceError::NotFound(Path::from(*name))),
}, }
}
MemberKind::Tuple(Literal::Int(idx)) => match ty.ty() { MemberKind::Tuple(Literal::Int(idx)) => match ty.ty() {
Some(TypeKind::Tuple(tys)) => tys Some(TypeKind::Tuple(tys)) => tys
.get(*idx as usize) .get(*idx as usize)
@ -669,7 +683,7 @@ impl<'a> Inference<'a> for Member {
} }
impl<'a> Inference<'a> for Index { impl<'a> Inference<'a> for Index {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult {
let Index { head, indices } = self; let Index { head, indices } = self;
let usize = e.usize(); let usize = e.usize();
// Infer the head expression // Infer the head expression
@ -703,7 +717,7 @@ impl<'a> Inference<'a> for Index {
} }
impl<'a> Inference<'a> for Cast { impl<'a> Inference<'a> for Cast {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult {
let Cast { head, ty } = self; let Cast { head, ty } = self;
// Infer the head expression // Infer the head expression
let _head = head.infer(e)?; let _head = head.infer(e)?;
@ -719,14 +733,14 @@ impl<'a> Inference<'a> for Cast {
} }
impl<'a> Inference<'a> for Path { impl<'a> Inference<'a> for Path {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult {
e.by_name(self) e.by_name(self)
.map_err(|_| InferenceError::NotFound(self.clone())) .map_err(|_| InferenceError::NotFound(self.clone()))
} }
} }
impl<'a> Inference<'a> for Let { impl<'a> Inference<'a> for Let {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult {
let Let { mutable: _, name, ty, init } = self; let Let { mutable: _, name, ty, init } = self;
let ty = match ty { let ty = match ty {
Some(ty) => ty Some(ty) => ty
@ -752,7 +766,7 @@ impl<'a> Inference<'a> for Let {
} }
impl<'a> Inference<'a> for Match { impl<'a> Inference<'a> for Match {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult {
let Match { scrutinee, arms } = self; let Match { scrutinee, arms } = self;
// Infer the scrutinee // Infer the scrutinee
let scrutinee = scrutinee.infer(e)?; let scrutinee = scrutinee.infer(e)?;
@ -783,7 +797,7 @@ impl<'a> Inference<'a> for Match {
impl<'a> Inference<'a> for Pattern { impl<'a> Inference<'a> for Pattern {
// TODO: This is the wrong way to typeck pattern matching. // TODO: This is the wrong way to typeck pattern matching.
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult {
match self { match self {
Pattern::Name(name) => { Pattern::Name(name) => {
// Evaluating a pattern creates and enters a new scope. // Evaluating a pattern creates and enters a new scope.
@ -867,47 +881,61 @@ impl<'a> Inference<'a> for Pattern {
} }
impl<'a> Inference<'a> for While { impl<'a> Inference<'a> for While {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult {
let While { cond, pass, fail } = self; let While { cond, pass, fail } = self;
let mut bset = None;
// Open a block scope so the loop condition doesn't escape
{
let mut scope = e.block_scope();
// Infer the condition // Infer the condition
let cond = cond.infer(e)?; let cond = cond.infer(&mut scope)?;
// Unify the condition with bool // Unify the condition with bool
let bool = e.bool(); let bool = scope.bool();
e.unify(bool, cond)?; scope.unify(bool, cond)?;
// Open a breakset for the pass branch
{
let mut body_scope = scope.open_bset(&mut bset);
// Infer the pass branch
let pass = pass.infer(&mut body_scope)?;
// Unify the pass branch with Empty
let empt = body_scope.empty();
body_scope.unify(empt, pass)?;
}
}
// Infer the fail branch // Infer the fail branch
let fail = fail.infer(e)?; let fail = fail.infer(e)?;
// Unify the fail branch with breakset // Unify the fail branch with breakset
let mut e = e.open_bset(); if let Some(bset) = bset {
println!("bset: {}", e.entry(bset));
// Infer the pass branch e.unify(bset, fail)?;
let pass = pass.infer(&mut e)?; }
// Unify the pass branch with Empty
let empt = e.empty();
e.unify(pass, empt)?;
match e.bset.get() {
None => Ok(e.empty()),
Some(bset) => {
e.unify(fail, bset)?;
Ok(fail) Ok(fail)
} }
}
}
} }
impl<'a> Inference<'a> for If { impl<'a> Inference<'a> for If {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult {
let If { cond, pass, fail } = self; let If { cond, pass, fail } = self;
// Open a block scope so the condition doesn't escape
let pass = {
let mut scope = e.block_scope();
// Do inference on the condition' // Do inference on the condition'
let cond = cond.infer(e)?; let cond = cond.infer(&mut scope)?;
// Unify the condition with bool // Unify the condition with bool
let bool = e.bool(); let bool = scope.bool();
e.unify(bool, cond)?; scope.unify(bool, cond)?;
// Do inference on the pass branch // Do inference on the pass branch
let pass = pass.infer(e)?; pass.infer(&mut scope)?
};
// Do inference on the fail branch // Do inference on the fail branch
let fail = fail.infer(e)?; let fail = fail.infer(&mut e.block_scope())?;
// Unify pass and fail // Unify pass and fail
e.unify(pass, fail)?; e.unify(pass, fail)?;
// Return the result // Return the result
@ -916,55 +944,52 @@ impl<'a> Inference<'a> for If {
} }
impl<'a> Inference<'a> for For { impl<'a> Inference<'a> for For {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult {
let For { bind, cond, pass, fail } = self; let For { bind, cond, pass, fail } = self;
let mut e = e.block_scope(); let mut scope = e.block_scope();
let bind = bind.infer(&mut e)?; let bind = bind.infer(&mut scope)?;
// What does it mean to be iterable? Why, `next()`, of course! // What does it mean to be iterable? Why, `next()`, of course!
let cond = cond.infer(&mut e)?; let cond = cond.infer(&mut scope)?;
let cond = e.prune(cond); let cond = scope.prune(cond);
if let Some((args, rety)) = e.get_fn(cond, "next".into()) { if let Some((args, rety)) = scope.get_fn(cond, "next".into()) {
// Check that the args are correct // Check that the args are correct
let params = vec![e.new_ref(cond)]; let params = vec![scope.new_ref(cond)];
let params = e.new_tuple(params); let params = scope.new_tuple(params);
e.unify(args, params)?; scope.unify(args, params)?;
e.unify(rety, bind)?; scope.unify(rety, bind)?;
} }
// Enter a new breakset
let mut e = e.open_bset();
// Infer the fail branch
let fail = fail.infer(&mut e)?;
// Open a breakset // Open a breakset
let mut e = e.open_bset(); let mut bset = None;
let mut bscope = scope.open_bset(&mut bset);
// Infer the pass branch // Infer the pass branch
let pass = pass.infer(&mut e)?; let pass = pass.infer(&mut bscope)?;
// Unify the pass branch with Empty // Unify the pass branch with Empty
let empt = e.empty(); let empt = bscope.empty();
e.unify(pass, empt)?; bscope.unify(pass, empt)?;
// Return breakset // Infer the fail branch
if let Some(bset) = e.bset.get() { let fail = fail.infer(&mut e.block_scope())?;
// Unify the fail branch with the breakset
if let Some(bset) = bset {
e.unify(fail, bset)?; e.unify(fail, bset)?;
Ok(fail)
} else {
Ok(e.empty())
} }
Ok(fail)
} }
} }
impl<'a> Inference<'a> for Else { impl<'a> Inference<'a> for Else {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult {
self.body.infer(e) self.body.infer(e)
} }
} }
impl<'a> Inference<'a> for Break { impl<'a> Inference<'a> for Break {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult {
let Break { body } = self; let Break { body } = self;
// Infer the body of the break // Infer the body of the break
let ty = body.infer(e)?; let ty = body.infer(e)?;
@ -976,7 +1001,7 @@ impl<'a> Inference<'a> for Break {
} }
impl<'a> Inference<'a> for Return { impl<'a> Inference<'a> for Return {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult {
let Return { body } = self; let Return { body } = self;
// Infer the body of the return // Infer the body of the return
let ty = body.infer(e)?; let ty = body.infer(e)?;
@ -988,7 +1013,7 @@ impl<'a> Inference<'a> for Return {
} }
impl<'a, I: Inference<'a>> Inference<'a> for Option<I> { impl<'a, I: Inference<'a>> Inference<'a> for Option<I> {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult {
match self { match self {
Some(expr) => expr.infer(e), Some(expr) => expr.infer(e),
None => Ok(e.empty()), None => Ok(e.empty()),
@ -996,7 +1021,7 @@ impl<'a, I: Inference<'a>> Inference<'a> for Option<I> {
} }
} }
impl<'a, I: Inference<'a>> Inference<'a> for Box<I> { impl<'a, I: Inference<'a>> Inference<'a> for Box<I> {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult { fn infer(&'a self, e: &mut InferenceEngine<'_, 'a, '_, '_>) -> IfResult {
self.as_ref().infer(e) self.as_ref().infer(e)
} }
} }