cl-typeck: Get some serious type inference going!

This commit is contained in:
John 2025-07-18 05:30:23 -04:00
parent 74220d3bff
commit 8732cca3f9
11 changed files with 680 additions and 215 deletions

View File

@ -72,6 +72,7 @@ fn main_menu(prj: &mut Table) -> Result<(), RlError> {
match line {
"c" | "code" => enter_code(prj)?,
"clear" => clear()?,
"dump" => dump(prj)?,
"d" | "desugar" => live_desugar()?,
"e" | "exit" => return Ok(Response::Break),
"f" | "file" => import_files(prj)?,
@ -254,7 +255,7 @@ fn infer_all(table: &mut Table) -> Result<(), Box<dyn Error>> {
}
e => eprint!("{e}"),
}
eprintln!(" in {} ({id})", id.to_entry(table))
eprintln!(" in {id}\n({})\n", id.to_entry(table).source().unwrap())
}
println!("...Inferred!");
@ -404,6 +405,30 @@ fn inline_modules(code: cl_ast::File, path: impl AsRef<path::Path>) -> cl_ast::F
}
}
fn dump(table: &Table) -> Result<(), Box<dyn Error>> {
fn dump_recursive(
name: cl_ast::Sym,
entry: Entry,
depth: usize,
to_file: &mut std::fs::File,
) -> std::io::Result<()> {
use std::io::Write;
write!(to_file, "{:w$}{name}: {entry}", "", w = depth)?;
if let Some(children) = entry.children() {
writeln!(to_file, " {{")?;
for (name, child) in children {
dump_recursive(*name, entry.with_id(*child), depth + 2, to_file)?;
}
write!(to_file, "{:w$}}}", "", w = depth)?;
}
writeln!(to_file)
}
let mut file = std::fs::File::create("typeck-table.ron")?;
dump_recursive("root".into(), table.root_entry(), 0, &mut file)?;
Ok(())
}
fn clear() -> Result<(), Box<dyn Error>> {
println!("\x1b[H\x1b[2J");
banner();

View File

@ -20,6 +20,7 @@ use crate::{
type_kind::TypeKind,
};
mod debug;
mod display;
impl Handle {
@ -31,24 +32,65 @@ impl Handle {
}
}
#[derive(Debug)]
pub struct Entry<'t, 'a> {
table: &'t Table<'a>,
id: Handle,
}
macro_rules! impl_entry_ {
() => {
pub const fn id(&self) -> Handle {
self.id
}
pub const fn inner(&'t self) -> &'t Table<'a> {
self.table
}
pub fn kind(&self) -> Option<&NodeKind> {
self.table.kind(self.id)
}
pub const fn root(&self) -> Handle {
self.table.root()
}
pub fn children(&self) -> Option<&HashMap<Sym, Handle>> {
self.table.children(self.id)
}
pub fn imports(&self) -> Option<&HashMap<Sym, Handle>> {
self.table.imports(self.id)
}
pub fn bodies(&self) -> Option<&'a Expr> {
self.table.body(self.id)
}
pub fn span(&self) -> Option<&Span> {
self.table.span(self.id)
}
pub fn meta(&self) -> Option<&[Meta]> {
self.table.meta(self.id)
}
pub fn source(&self) -> Option<&Source<'a>> {
self.table.source(self.id)
}
pub fn name(&self) -> Option<Sym> {
self.table.name(self.id)
}
};
}
impl<'t, 'a> Entry<'t, 'a> {
pub const fn new(table: &'t Table<'a>, id: Handle) -> Self {
Self { table, id }
}
pub const fn id(&self) -> Handle {
self.id
}
pub fn inner(&self) -> &'t Table<'a> {
self.table
}
impl_entry_!();
pub const fn with_id(&self, id: Handle) -> Entry<'t, 'a> {
Self { table: self.table, id }
@ -58,46 +100,14 @@ impl<'t, 'a> Entry<'t, 'a> {
Some(Entry { id: self.table.nav(self.id, path)?, table: self.table })
}
pub const fn root(&self) -> Handle {
self.table.root()
}
pub fn kind(&self) -> Option<&'t NodeKind> {
self.table.kind(self.id)
}
pub fn parent(&self) -> Option<Entry<'t, 'a>> {
Some(Entry { id: *self.table.parent(self.id)?, ..*self })
}
pub fn children(&self) -> Option<&'t HashMap<Sym, Handle>> {
self.table.children(self.id)
}
pub fn imports(&self) -> Option<&'t HashMap<Sym, Handle>> {
self.table.imports(self.id)
}
pub fn bodies(&self) -> Option<&'a Expr> {
self.table.body(self.id)
}
pub fn ty(&self) -> Option<&'t TypeKind> {
self.table.ty(self.id)
}
pub fn span(&self) -> Option<&'t Span> {
self.table.span(self.id)
}
pub fn meta(&self) -> Option<&'a [Meta]> {
self.table.meta(self.id)
}
pub fn source(&self) -> Option<&'t Source<'a>> {
self.table.source(self.id)
}
pub fn impl_target(&self) -> Option<Entry<'_, 'a>> {
Some(Entry { id: self.table.impl_target(self.id)?, ..*self })
}
@ -105,10 +115,6 @@ impl<'t, 'a> Entry<'t, 'a> {
pub fn selfty(&self) -> Option<Entry<'_, 'a>> {
Some(Entry { id: self.table.selfty(self.id)?, ..*self })
}
pub fn name(&self) -> Option<Sym> {
self.table.name(self.id)
}
}
#[derive(Debug)]
@ -122,12 +128,18 @@ impl<'t, 'a> EntryMut<'t, 'a> {
Self { table, id }
}
pub fn as_ref(&self) -> Entry<'_, 'a> {
Entry { table: self.table, id: self.id }
impl_entry_!();
pub fn ty(&self) -> Option<&TypeKind> {
self.table.ty(self.id)
}
pub const fn id(&self) -> Handle {
self.id
pub fn inner_mut(&mut self) -> &mut Table<'a> {
self.table
}
pub fn as_ref(&self) -> Entry<'_, 'a> {
Entry { table: self.table, id: self.id }
}
/// Evaluates a [TypeExpression] in this entry's context
@ -182,6 +194,10 @@ impl<'t, 'a> EntryMut<'t, 'a> {
self.table.set_impl_target(self.id, target)
}
pub fn mark_unchecked(&mut self) {
self.table.mark_unchecked(self.id)
}
pub fn mark_use_item(&mut self) {
self.table.mark_use_item(self.id)
}

View File

@ -0,0 +1,33 @@
//! [std::fmt::Debug] implementation for [Entry]
use super::Entry;
impl std::fmt::Debug for Entry<'_, '_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// virtual fields
let mut ds = f.debug_struct("Entry");
if let Some(name) = self.name() {
ds.field("name", &name.to_ref());
}
ds.field("kind", &self.kind());
if let Some(ty) = self.ty() {
ds.field("type", ty);
}
if let Some(meta) = self.meta() {
ds.field("meta", &meta);
}
if let Some(body) = self.bodies() {
ds.field("body", body);
}
if let Some(children) = self.children() {
ds.field("children", children);
}
if let Some(imports) = self.imports() {
ds.field("imports", imports);
}
// if let Some(source) = self.source() {
// ds.field("source", source);
// }
ds.field("implements", &self.impl_target()).finish()
}
}

View File

@ -60,7 +60,14 @@ impl fmt::Display for Entry<'_, '_> {
TypeKind::Module => write!(f, "module?"),
}
} else {
write!(f, "{kind}")
match kind {
NodeKind::Type
| NodeKind::Const
| NodeKind::Static
| NodeKind::Temporary
| NodeKind::Let => write!(f, "WARNING: NO TYPE ASSIGNED FOR {}", self.id),
_ => write!(f, "{kind}"),
}
}
}
}

View File

@ -1,6 +1,7 @@
//! Categorizes an entry in a table according to its embedded type information
#![allow(unused)]
use crate::{
entry::EntryMut,
handle::Handle,
source::Source,
table::{NodeKind, Table},
@ -11,39 +12,37 @@ use cl_ast::*;
/// Ensures a type entry exists for the provided handle in the table
pub fn categorize(table: &mut Table, node: Handle) -> CatResult<()> {
if let Some(meta) = table.meta(node) {
for meta @ Meta { name, kind } in meta {
if let ("lang", MetaKind::Equals(Literal::String(s))) = (&**name, kind) {
let kind =
TypeKind::Primitive(s.parse().map_err(|_| Error::BadMeta(meta.clone()))?);
table.set_ty(node, kind);
return Ok(());
}
}
}
let Some(source) = table.source(node) else {
return Ok(());
};
match source {
Source::Root => Ok(()),
Source::Module(_) => Ok(()),
Source::Alias(a) => cat_alias(table, node, a),
Source::Enum(e) => cat_enum(table, node, e),
Source::Variant(v) => cat_variant(table, node, v),
Source::Struct(s) => cat_struct(table, node, s),
Source::Const(c) => cat_const(table, node, c),
Source::Static(s) => cat_static(table, node, s),
Source::Function(f) => cat_function(table, node, f),
Source::Local(l) => cat_local(table, node, l),
Source::Impl(i) => cat_impl(table, node, i),
Source::Use(_) => Ok(()),
Source::Ty(ty) => ty
.evaluate(table, node)
.map_err(|e| Error::TypeEval(e, " while categorizing a type"))
.map(drop),
Source::Alias(a) => cat_alias(table, node, a)?,
Source::Enum(e) => cat_enum(table, node, e)?,
Source::Variant(v) => cat_variant(table, node, v)?,
Source::Struct(s) => cat_struct(table, node, s)?,
Source::Const(c) => cat_const(table, node, c)?,
Source::Static(s) => cat_static(table, node, s)?,
Source::Function(f) => cat_function(table, node, f)?,
Source::Local(l) => cat_local(table, node, l)?,
Source::Impl(i) => cat_impl(table, node, i)?,
_ => {}
}
if let Some(meta) = table.meta(node) {
for meta @ Meta { name, kind } in meta {
if let ("lang", MetaKind::Equals(Literal::String(s))) = (&**name, kind) {
if let Ok(prim) = s.parse() {
table.set_ty(node, TypeKind::Primitive(prim));
} else {
table.mark_lang_item(s.into(), node);
continue;
}
return Ok(());
}
}
}
Ok(())
}
fn parent(table: &Table, node: Handle) -> Handle {
@ -108,7 +107,6 @@ fn cat_enum<'a>(_table: &mut Table<'a>, _node: Handle, e: &'a Enum) -> CatResult
fn cat_variant<'a>(table: &mut Table<'a>, node: Handle, v: &'a Variant) -> CatResult<()> {
let Variant { name, kind, body } = v;
let parent = table.parent(node).copied().unwrap_or(table.root());
table.add_child(parent, *name, node);
match (kind, body) {
(StructKind::Empty, None) => {
table.set_ty(node, TypeKind::Adt(Adt::UnitStruct));
@ -195,7 +193,7 @@ fn cat_local(table: &mut Table, node: Handle, l: &Let) -> CatResult<()> {
fn cat_impl(table: &mut Table, node: Handle, i: &Impl) -> CatResult<()> {
let parent = parent(table, node);
let Impl { target, body: _ } = i;
let Impl { gens, target, body: _ } = i;
let target = match target {
ImplKind::Type(t) => t.evaluate(table, parent),
ImplKind::Trait { impl_trait: _, for_type: t } => t.evaluate(table, parent),

View File

@ -15,9 +15,9 @@ pub fn impl_one(table: &mut Table, node: Handle) -> Result<(), Handle> {
let Some(target) = table.impl_target(node) else {
Err(node)?
};
let Table { children, imports, .. } = table;
if let Some(children) = children.get(&node) {
imports.entry(target).or_default().extend(children);
if let Some(children) = table.children.get_mut(&node) {
let children = children.clone();
table.children.entry(target).or_default().extend(children);
}
Ok(())
}

View File

@ -1,7 +1,10 @@
use std::{cell::Cell, collections::HashSet, rc::Rc};
use super::error::InferenceError;
use crate::{
entry::Entry,
handle::Handle,
source::Source,
stage::infer::inference::Inference,
table::{NodeKind, Table},
type_expression::TypeExpression,
@ -28,14 +31,16 @@ use cl_ast::Sym;
- for<T, R> type T -> R // on a per-case basis!
*/
type HandleSet = Rc<Cell<Option<Handle>>>;
pub struct InferenceEngine<'table, 'a> {
pub(super) table: &'table mut Table<'a>,
/// The current working node
pub(crate) at: Handle,
/// The current breakset
pub(crate) bset: Handle,
pub(crate) bset: HandleSet,
/// The current returnset
pub(crate) rset: Handle,
pub(crate) rset: HandleSet,
}
impl<'table, 'a> InferenceEngine<'table, 'a> {
@ -46,48 +51,69 @@ impl<'table, 'a> InferenceEngine<'table, 'a> {
/// Constructs a new [`InferenceEngine`], scoped around a [`Handle`] in a [`Table`].
pub fn new(table: &'table mut Table<'a>, at: Handle) -> Self {
let never = table.anon_type(TypeKind::Never);
Self { at, table, bset: never, rset: never }
Self { at, table, bset: Default::default(), rset: Default::default() }
}
/// Constructs an [`InferenceEngine`] that borrows the same table as `self`,
/// but with a shortened lifetime.
pub fn scoped(&mut self) -> InferenceEngine<'_, 'a> {
InferenceEngine { at: self.at, table: self.table, bset: self.bset, rset: self.rset }
InferenceEngine {
at: self.at,
table: self.table,
bset: self.bset.clone(),
rset: self.rset.clone(),
}
}
pub fn infer_all(&mut self) -> Vec<(Handle, InferenceError)> {
let iter = self.table.handle_iter();
let queue = std::mem::take(&mut self.table.unchecked);
let mut res = Vec::new();
for handle in iter {
for handle in queue {
let mut eng = self.at(handle);
// TODO: use sources instead of bodies, and infer the type globally
let Some(body) = eng.table.body(handle) else {
let Some(source) = eng.table.source(handle) else {
eprintln!("No source found for {handle}");
continue;
};
eprintln!("Evaluating body {body}");
match body.infer(&mut eng) {
Ok(ty) => println!("=> {}", eng.table.entry(ty)),
Err(e) => {
match &e {
&InferenceError::Mismatch(a, b) => {
eprintln!(
"=> Mismatched types: {}, {}",
eng.table.entry(a),
eng.table.entry(b)
);
}
&InferenceError::Recursive(a, b) => {
eprintln!(
"=> Recursive types: {}, {}",
eng.table.entry(a),
eng.table.entry(b)
);
}
e => eprintln!("=> {e}"),
}
res.push((handle, e))
println!("Inferring {source}");
let ret = match source {
Source::Module(v) => v.infer(&mut eng),
Source::Alias(v) => v.infer(&mut eng),
Source::Enum(v) => v.infer(&mut eng),
Source::Variant(v) => v.infer(&mut eng),
Source::Struct(v) => v.infer(&mut eng),
Source::Const(v) => v.infer(&mut eng),
Source::Static(v) => v.infer(&mut eng),
Source::Function(v) => v.infer(&mut eng),
Source::Local(v) => v.infer(&mut eng),
Source::Impl(v) => v.infer(&mut eng),
_ => Ok(eng.empty()),
};
match &ret {
&Ok(handle) => println!("=> {}", eng.entry(handle)),
Err(err @ InferenceError::AnnotationEval(_)) => eprintln!("=> ERROR: {err}"),
Err(InferenceError::FieldCount(h, want, got)) => {
eprintln!("=> ERROR: Field count {want} != {got} in {}", eng.entry(*h))
}
Err(err @ InferenceError::NotFound(_)) => eprintln!("=> ERROR: {err}"),
Err(InferenceError::Mismatch(h1, h2)) => eprintln!(
"=> ERROR: Type mismatch {} != {}",
eng.entry(*h1),
eng.entry(*h2)
),
Err(InferenceError::Recursive(h1, h2)) => eprintln!(
"=> ERROR: Cycle found in types {}, {}",
eng.entry(*h1),
eng.entry(*h2)
),
}
println!();
if let Err(err) = ret {
res.push((handle, err));
eng.table.mark_unchecked(handle);
}
}
res
@ -99,11 +125,31 @@ impl<'table, 'a> InferenceEngine<'table, 'a> {
}
pub fn open_bset(&mut self) -> InferenceEngine<'_, 'a> {
InferenceEngine { bset: self.new_var(), ..self.scoped() }
InferenceEngine { bset: Default::default(), ..self.scoped() }
}
pub fn open_rset(&mut self) -> InferenceEngine<'_, 'a> {
InferenceEngine { rset: self.new_var(), ..self.scoped() }
InferenceEngine { rset: Default::default(), ..self.scoped() }
}
pub fn bset(&mut self, ty: Handle) -> Result<(), InferenceError> {
match self.bset.get() {
Some(bset) => self.unify(ty, bset),
None => {
self.bset.set(Some(ty));
Ok(())
}
}
}
pub fn rset(&mut self, ty: Handle) -> Result<(), InferenceError> {
match self.rset.get() {
Some(rset) => self.unify(ty, rset),
None => {
self.rset.set(Some(ty));
Ok(())
}
}
}
/// Constructs an [Entry] out of a [Handle], for ease of use
@ -111,12 +157,6 @@ impl<'table, 'a> InferenceEngine<'table, 'a> {
self.table.entry(of)
}
#[deprecated = "Use dedicated methods instead."]
pub fn from_type_kind(&mut self, kind: TypeKind) -> Handle {
// TODO: preserve type heirarchy (for, i.e., reference types)
self.table.anon_type(kind)
}
pub fn by_name<Out, N: TypeExpression<Out>>(
&mut self,
name: &N,
@ -129,6 +169,10 @@ impl<'table, 'a> InferenceEngine<'table, 'a> {
self.table.type_variable()
}
pub fn new_inferred(&mut self) -> Handle {
self.table.inferred_type()
}
/// Creates a variable that is a new instance of another [Type](Handle)
pub fn new_inst(&mut self, of: Handle) -> Handle {
self.table.anon_type(TypeKind::Instance(of))
@ -217,7 +261,7 @@ impl<'table, 'a> InferenceEngine<'table, 'a> {
/// Creates a new inferred-integer literal
pub fn integer_literal(&mut self) -> Handle {
let h = self.table.new_entry(self.at, NodeKind::Local);
let h = self.table.new_entry(self.at, NodeKind::Temporary);
self.table
.set_ty(h, TypeKind::Primitive(Primitive::Integer));
h
@ -225,20 +269,22 @@ impl<'table, 'a> InferenceEngine<'table, 'a> {
/// Creates a new inferred-float literal
pub fn float_literal(&mut self) -> Handle {
let h = self.table.new_entry(self.at, NodeKind::Local);
let h = self.table.new_entry(self.at, NodeKind::Temporary);
self.table.set_ty(h, TypeKind::Primitive(Primitive::Float));
h
}
/// Enters a new scope
pub fn local_scope(&mut self) {
let scope = self.table.new_entry(self.at, NodeKind::Local);
pub fn local_scope(&mut self, name: Sym) {
let scope = self.table.new_entry(self.at, NodeKind::Scope);
self.table.add_child(self.at, name, scope);
self.at = scope;
}
/// Creates a new locally-scoped InferenceEngine.
pub fn block_scope(&mut self) -> InferenceEngine<'_, 'a> {
let scope = self.table.new_entry(self.at, NodeKind::Local);
let scope = self.table.new_entry(self.at, NodeKind::Scope);
self.table.add_child(self.at, "".into(), scope);
self.at(scope)
}
@ -264,27 +310,44 @@ impl<'table, 'a> InferenceEngine<'table, 'a> {
/// Checks whether there are any unbound type variables in this type
pub fn is_generic(&self, ty: Handle) -> bool {
let entry = self.table.entry(ty);
let Some(ty) = entry.ty() else {
return false;
};
match ty {
TypeKind::Inferred => false,
TypeKind::Variable => true,
&TypeKind::Array(h, _) => self.is_generic(h),
&TypeKind::Instance(h) => self.is_generic(h),
TypeKind::Primitive(_) => false,
TypeKind::Adt(Adt::Enum(tys)) => tys.iter().any(|(_, ty)| self.is_generic(*ty)),
TypeKind::Adt(Adt::Struct(tys)) => tys.iter().any(|&(_, _, ty)| self.is_generic(ty)),
TypeKind::Adt(Adt::TupleStruct(tys)) => tys.iter().any(|&(_, ty)| self.is_generic(ty)),
TypeKind::Adt(Adt::UnitStruct) => false,
TypeKind::Adt(Adt::Union(tys)) => tys.iter().any(|&(_, ty)| self.is_generic(ty)),
&TypeKind::Ref(h) => self.is_generic(h),
&TypeKind::Slice(h) => self.is_generic(h),
TypeKind::Tuple(handles) => handles.iter().any(|&ty| self.is_generic(ty)),
&TypeKind::FnSig { args, rety } => self.is_generic(args) || self.is_generic(rety),
TypeKind::Empty | TypeKind::Never | TypeKind::Module => false,
fn is_generic_rec(this: &InferenceEngine, ty: Handle, seen: &mut HashSet<Handle>) -> bool {
if !seen.insert(ty) {
return false;
}
let entry = this.table.entry(ty);
let Some(ty) = entry.ty() else {
return false;
};
match ty {
TypeKind::Inferred => false,
TypeKind::Variable => true,
&TypeKind::Array(ty, _) => is_generic_rec(this, ty, seen),
&TypeKind::Instance(ty) => is_generic_rec(this, ty, seen),
TypeKind::Primitive(_) => false,
TypeKind::Adt(Adt::Enum(tys)) => {
tys.iter().any(|&(_, ty)| is_generic_rec(this, ty, seen))
}
TypeKind::Adt(Adt::Struct(tys)) => {
tys.iter().any(|&(_, _, ty)| is_generic_rec(this, ty, seen))
}
TypeKind::Adt(Adt::TupleStruct(tys)) => {
tys.iter().any(|&(_, ty)| is_generic_rec(this, ty, seen))
}
TypeKind::Adt(Adt::UnitStruct) => false,
TypeKind::Adt(Adt::Union(tys)) => {
tys.iter().any(|&(_, ty)| is_generic_rec(this, ty, seen))
}
&TypeKind::Ref(ty) => is_generic_rec(this, ty, seen),
&TypeKind::Ptr(ty) => is_generic_rec(this, ty, seen),
&TypeKind::Slice(ty) => is_generic_rec(this, ty, seen),
TypeKind::Tuple(tys) => tys.iter().any(|&ty| is_generic_rec(this, ty, seen)),
&TypeKind::FnSig { args, rety } => {
is_generic_rec(this, args, seen) || is_generic_rec(this, rety, seen)
}
TypeKind::Empty | TypeKind::Never | TypeKind::Module => false,
}
}
is_generic_rec(self, ty, &mut HashSet::new())
}
/// Makes a deep copy of a type expression.
@ -298,8 +361,10 @@ impl<'table, 'a> InferenceEngine<'table, 'a> {
let Some(ty) = entry.ty().cloned() else {
return ty;
};
// TODO: Parent the deep clone into a new "monomorphs" branch of tree
match ty {
TypeKind::Variable => self.new_var(),
TypeKind::Variable => self.new_inferred(),
TypeKind::Array(h, s) => {
let ty = self.deep_clone(h);
self.table.anon_type(TypeKind::Array(ty, s))
@ -396,6 +461,7 @@ impl<'table, 'a> InferenceEngine<'table, 'a> {
items.iter().any(|(_, other)| self.occurs_in(this, *other))
}
TypeKind::Ref(other) => self.occurs_in(this, *other),
TypeKind::Ptr(other) => self.occurs_in(this, *other),
TypeKind::Slice(other) => self.occurs_in(this, *other),
TypeKind::Array(other, _) => self.occurs_in(this, *other),
TypeKind::Tuple(handles) => handles.iter().any(|&other| self.occurs_in(this, other)),
@ -415,6 +481,9 @@ impl<'table, 'a> InferenceEngine<'table, 'a> {
/// Unifies two types
pub fn unify(&mut self, this: Handle, other: Handle) -> Result<(), InferenceError> {
let (ah, bh) = (self.prune(this), self.prune(other));
if ah == bh {
return Ok(());
}
let (a, b) = (self.table.entry(ah), self.table.entry(bh));
let (Some(a), Some(b)) = (a.ty(), b.ty()) else {
return Err(InferenceError::Mismatch(ah, bh));
@ -427,10 +496,7 @@ impl<'table, 'a> InferenceEngine<'table, 'a> {
}
(_, TypeKind::Inferred) => self.unify(bh, ah),
(TypeKind::Variable, _) => {
self.set_instance(ah, bh);
Ok(())
}
(TypeKind::Variable, _) => Err(InferenceError::Mismatch(ah, bh)),
(TypeKind::Instance(a), TypeKind::Instance(b)) if !self.occurs_in(*a, *b) => {
self.set_instance(*a, *b);
Ok(())
@ -467,6 +533,26 @@ impl<'table, 'a> InferenceEngine<'table, 'a> {
}
Ok(())
}
(TypeKind::Adt(Adt::Enum(en)), TypeKind::Adt(_)) => {
#[allow(unused)]
let Some(other_parent) = self.table.parent(bh) else {
Err(InferenceError::Mismatch(ah, bh))?
};
if ah != *other_parent {
Err(InferenceError::Mismatch(ah, *other_parent))?
}
#[allow(unused)]
for (sym, handle) in en {
let handle = self.def_usage(*handle);
if handle == bh {
return Ok(());
}
}
Err(InferenceError::Mismatch(ah, bh))
}
(TypeKind::Adt(Adt::Struct(ia)), TypeKind::Adt(Adt::Struct(ib)))
if ia.len() == ib.len() =>
{

View File

@ -31,3 +31,9 @@ impl fmt::Display for InferenceError {
}
}
}
impl From<crate::type_expression::Error> for InferenceError {
fn from(value: crate::type_expression::Error) -> Self {
Self::AnnotationEval(value)
}
}

View File

@ -7,7 +7,6 @@ use std::iter;
use super::{engine::InferenceEngine, error::InferenceError};
use crate::{
handle::Handle,
table::NodeKind,
type_expression::TypeExpression,
type_kind::{Adt, TypeKind},
};
@ -22,9 +21,227 @@ pub trait Inference<'a> {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult;
}
impl<'a> Inference<'a> for File {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult {
let Self { name: _, items } = self;
for item in items {
item.infer(e)?;
}
Ok(e.empty())
}
}
impl<'a> Inference<'a> for Item {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult {
let Self { span: _, attrs: _, vis: _, kind } = self;
kind.infer(e)
}
}
impl<'a> Inference<'a> for ItemKind {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult {
match self {
ItemKind::Module(v) => v.infer(e),
ItemKind::Alias(v) => v.infer(e),
ItemKind::Enum(v) => v.infer(e),
ItemKind::Struct(v) => v.infer(e),
ItemKind::Const(v) => v.infer(e),
ItemKind::Static(v) => v.infer(e),
ItemKind::Function(v) => v.infer(e),
ItemKind::Impl(v) => v.infer(e),
ItemKind::Use(_v) => Ok(e.empty()),
}
}
}
impl<'a> Inference<'a> for Generics {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult {
// bind names
for name in &self.vars {
let ty = e.new_var();
e.table.add_child(e.at, *name, ty);
}
Ok(e.empty())
}
}
impl<'a> Inference<'a> for Module {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult {
let Self { name, file } = self;
let Some(file) = file else {
return Err(InferenceError::NotFound((*name).into()));
};
let module = e.by_name(name)?;
e.at(module).infer(file)
}
}
impl<'a> Inference<'a> for Alias {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult {
Ok(e.empty())
}
}
impl<'a> Inference<'a> for Const {
#[allow(unused)]
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult {
let Self { name, ty, init } = self;
// Same as static
let node = e.by_name(name)?;
let ty = e.infer(ty)?;
let mut scope = e.at(node);
// infer body
let body = scope.infer(init)?;
// unify with ty
e.unify(body, ty)?;
Ok(node)
}
}
impl<'a> Inference<'a> for Static {
#[allow(unused)]
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult {
let Static { mutable, name, ty, init } = self;
let node = e.by_name(name)?;
let ty = e.infer(ty)?;
let mut scope = e.at(node);
// infer body
let body = scope.infer(init)?;
// unify with ty
e.unify(body, ty)?;
Ok(node)
}
}
impl<'a> Inference<'a> for Function {
#[allow(unused)]
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult {
let Self { name, gens, sign, bind, body } = self;
// bind name to signature
let node = e.by_name(name)?;
let node = e.deep_clone(node);
let fnty = e.by_name(sign)?;
e.unify(node, fnty)?;
// bind gens to new variables at function scope
let mut scope = e.at(node);
scope.infer(gens)?;
// bind binds to args
let pat = scope.infer(bind)?;
let arg = scope.by_name(sign.args.as_ref())?;
scope.unify(pat, arg);
let mut retscope = scope.open_rset();
// infer body
let bodty = retscope.infer(body)?;
let rety = sign.rety.infer(&mut retscope)?;
// unify body with rety
retscope.unify(bodty, rety)?;
// unify rset with rety
if let Some(rset) = retscope.rset.get() {
scope.unify(rset, rety)?;
}
Ok(node)
}
}
// TODO: do we need type inference/checking in struct definitions?
// there are no bodies
impl<'a> Inference<'a> for Enum {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult {
let Self { name, gens, variants } = self;
let node = e.by_name(name)?;
let mut scope = e.at(node);
scope.infer(gens)?;
for variant in variants {
let var_ty = scope.infer(variant)?;
scope.unify(var_ty, node)?;
}
Ok(node)
}
}
impl<'a> Inference<'a> for Variant {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult {
let Self { name: _, kind: _, body } = self;
let ty = e.new_inferred();
// TODO: evaluate kind
if let Some(body) = body {
let value = body.infer(e)?;
e.unify(ty, value)?;
}
Ok(ty)
}
}
impl<'a> Inference<'a> for Struct {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult {
Ok(e.new_inferred())
}
}
impl<'a> Inference<'a> for Impl {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult {
let Self { gens: _, target, body } = self;
// TODO: match gens to target gens
// gens.infer(e)?;
let instance = target.infer(e)?;
let instance = e.def_usage(instance);
let mut scope = e.at(instance);
scope.infer(body)
}
}
impl<'a> Inference<'a> for ImplKind {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult {
match self {
ImplKind::Type(ty) => ty.infer(e),
ImplKind::Trait { impl_trait: _, for_type } => for_type.infer(e),
}
}
}
impl<'a> Inference<'a> for Ty {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult {
Ok(e.by_name(self)?)
}
}
impl<'a> Inference<'a> for cl_ast::Stmt {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult {
let Self { span: _, kind, semi } = self;
let out = kind.infer(e)?;
Ok(match semi {
Semi::Terminated => e.empty(),
Semi::Unterminated => out,
})
}
}
impl<'a> Inference<'a> for cl_ast::StmtKind {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult {
match self {
StmtKind::Empty => Ok(e.empty()),
StmtKind::Item(item) => item.infer(e),
StmtKind::Expr(expr) => expr.infer(e),
}
}
}
impl<'a> Inference<'a> for cl_ast::Expr {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult {
self.kind.infer(e)
let out = self.kind.infer(e)?;
println!("expr ({self}) -> {}", e.entry(out));
Ok(out)
}
}
@ -32,7 +249,7 @@ impl<'a> Inference<'a> for cl_ast::ExprKind {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult {
match self {
ExprKind::Empty => Ok(e.empty()),
ExprKind::Closure(_) => todo!("Infer the type of a closure"),
ExprKind::Closure(closure) => closure.infer(e),
ExprKind::Tuple(tuple) => tuple.infer(e),
ExprKind::Structor(structor) => structor.infer(e),
ExprKind::Array(array) => array.infer(e),
@ -65,6 +282,23 @@ impl<'a> Inference<'a> for cl_ast::ExprKind {
}
}
impl<'a> Inference<'a> for Closure {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult {
let Self { arg, body } = self;
let args = arg.infer(e)?;
let mut scope = e.block_scope();
let mut scope = scope.open_rset();
let rety = scope.infer(body)?;
if let Some(rset) = scope.rset.get() {
e.unify(rety, rset)?;
}
Ok(e.table.anon_type(TypeKind::FnSig { args, rety }))
}
}
impl<'a> Inference<'a> for Tuple {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult {
let Tuple { exprs } = self;
@ -123,7 +357,7 @@ impl<'a> Inference<'a> for Structor {
impl<'a> Inference<'a> for Array {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult {
let Array { values } = self;
let out = e.new_var();
let out = e.new_inferred();
for value in values {
let ty = value.infer(e)?;
e.unify(out, ty)?;
@ -136,7 +370,15 @@ impl<'a> Inference<'a> for ArrayRep {
fn infer(&'a self, e: &mut InferenceEngine<'_, 'a>) -> IfResult {
let ArrayRep { value, repeat } = self;
let ty = value.infer(e)?;
Ok(e.new_array(ty, *repeat))
let rep = repeat.infer(e)?;
let usize_ty = e.usize();
e.unify(rep, usize_ty)?;
match &repeat.kind {
ExprKind::Literal(Literal::Int(repeat)) => Ok(e.new_array(ty, *repeat as usize)),
_ => {
todo!("TODO: constant folding before type checking?");
}
}
}
}
@ -196,14 +438,13 @@ impl<'a> Inference<'a> for Block {
_ => {}
}
}
match (&ret.kind, &ret.semi) {
(StmtKind::Expr(expr), Semi::Terminated) => {
expr.infer(&mut e)?;
}
(StmtKind::Expr(expr), Semi::Unterminated) => {
return expr.infer(&mut e);
}
_ => {}
let out = if let StmtKind::Expr(expr) = &ret.kind {
expr.infer(&mut e)?
} else {
empty
};
if Semi::Unterminated == ret.semi {
return Ok(out);
}
}
Ok(empty)
@ -315,10 +556,11 @@ impl<'a> Inference<'a> for Unary {
match kind {
UnaryKind::Deref => {
let tail = tail.infer(e)?;
let tail = e.def_usage(tail);
// TODO: get the base type
match e.entry(tail).ty() {
Some(&TypeKind::Ref(h)) => Ok(h),
other => todo!("Deref {other:?}"),
_ => todo!("Deref {}", e.entry(tail)),
}
}
UnaryKind::Loop => {
@ -333,7 +575,10 @@ impl<'a> Inference<'a> for Unary {
e.unify(tail, empt)?;
// Return breakset
Ok(e.bset)
match e.bset.get() {
Some(bset) => Ok(bset),
None => Ok(e.never()),
}
}
_op => {
// Infer the tail expression
@ -455,7 +700,7 @@ impl<'a> Inference<'a> for Let {
Some(ty) => ty
.evaluate(e.table, e.at)
.map_err(InferenceError::AnnotationEval)?,
None => e.new_var(),
None => e.new_inferred(),
};
// Infer the initializer
if let Some(init) = init {
@ -465,8 +710,6 @@ impl<'a> Inference<'a> for Let {
}
// Deep copy the ty, if it exists
let ty = e.deep_clone(ty);
// Enter a local scope (modifies the current scope)
e.local_scope();
// Infer the pattern
let patty = name.infer(e)?;
// Unify the pattern and the ty
@ -513,11 +756,9 @@ impl<'a> Inference<'a> for Pattern {
Pattern::Name(name) => {
// Evaluating a pattern creates and enters a new scope.
// Surely this will cause zero problems.
let node = e.table.new_entry(e.at, NodeKind::Local);
e.table.set_ty(node, TypeKind::Variable);
e.table.add_child(e.at, *name, node);
e.at = node;
Ok(node)
e.local_scope(*name);
e.table.set_ty(e.at, TypeKind::Inferred);
Ok(e.at)
}
Pattern::Path(path) => {
// Evaluating a path pattern puts type constraints on the scrutinee
@ -525,8 +766,12 @@ impl<'a> Inference<'a> for Pattern {
.map_err(|_| InferenceError::NotFound(path.clone()))
}
Pattern::Literal(literal) => literal.infer(e),
Pattern::Rest(Some(pat)) => pat.infer(e), // <-- glaring soundness holes
Pattern::Rest(_) => todo!("Fix glaring soundness holes in pattern"),
Pattern::Rest(Some(pat)) => {
eprintln!("TODO: Rest patterns in tuples?");
let ty = pat.infer(e)?;
Ok(e.new_slice(ty))
}
Pattern::Rest(_) => Ok(e.new_inferred()),
Pattern::Ref(_, pattern) => {
let ty = pattern.infer(e)?;
Ok(e.new_ref(ty))
@ -561,12 +806,30 @@ impl<'a> Inference<'a> for Pattern {
Ok(e.new_slice(ty))
}
[] => {
let ty = e.new_var();
let ty = e.new_inferred();
Ok(e.new_slice(ty))
}
},
Pattern::Struct(_path, _items) => todo!("Struct patterns"),
Pattern::TupleStruct(_path, _patterns) => todo!("Tuple struct patterns"),
Pattern::Struct(_path, _items) => {
eprintln!("TODO: struct patterns: {self}");
Ok(e.empty())
}
Pattern::TupleStruct(path, patterns) => {
eprintln!("TODO: tuple struct patterns: {self}");
let struc = e.by_name(path)?;
let Some(TypeKind::Adt(Adt::TupleStruct(ts))) = e.entry(struc).ty() else {
Err(InferenceError::Mismatch(struc, e.never()))?
};
let ts: Vec<_> = ts.iter().map(|(_v, h)| *h).collect();
let tys = patterns
.iter()
.map(|pat| pat.infer(e))
.collect::<Result<Vec<Handle>, InferenceError>>()?;
let ts = e.new_tuple(ts);
let tup = e.new_tuple(tys);
e.unify(ts, tup)?;
Ok(struc)
}
}
}
}
@ -583,7 +846,7 @@ impl<'a> Inference<'a> for While {
// Infer the fail branch
let fail = fail.infer(e)?;
// Unify the fail branch with breakset
let mut e = InferenceEngine { bset: fail, ..e.scoped() };
let mut e = e.open_bset();
// Infer the pass branch
let pass = pass.infer(&mut e)?;
@ -591,8 +854,13 @@ impl<'a> Inference<'a> for While {
let empt = e.empty();
e.unify(pass, empt)?;
// Return breakset
Ok(e.bset)
match e.bset.get() {
None => Ok(e.empty()),
Some(bset) => {
e.unify(fail, bset)?;
Ok(fail)
}
}
}
}
@ -638,9 +906,8 @@ impl<'a> Inference<'a> for For {
// Infer the fail branch
let fail = fail.infer(&mut e)?;
// Unify the fail branch with breakset
let mut e = InferenceEngine { bset: fail, ..e.scoped() };
e.bset = fail;
// Open a breakset
let mut e = e.open_bset();
// Infer the pass branch
let pass = pass.infer(&mut e)?;
@ -649,7 +916,12 @@ impl<'a> Inference<'a> for For {
e.unify(pass, empt)?;
// Return breakset
Ok(e.bset)
if let Some(bset) = e.bset.get() {
e.unify(fail, bset)?;
Ok(fail)
} else {
Ok(e.empty())
}
}
}
@ -665,7 +937,7 @@ impl<'a> Inference<'a> for Break {
// Infer the body of the break
let ty = body.infer(e)?;
// Unify it with the breakset of the loop
e.unify(ty, e.bset)?;
e.bset(ty)?;
// Return never
Ok(e.never())
}
@ -677,7 +949,7 @@ impl<'a> Inference<'a> for Return {
// Infer the body of the return
let ty = body.infer(e)?;
// Unify it with the return-set of the function
e.unify(ty, e.rset)?;
e.rset(ty)?;
// Return never
Ok(e.never())
}

View File

@ -94,27 +94,16 @@ impl<'a> Visit<'a> for Populator<'_, 'a> {
}
fn visit_static(&mut self, s: &'a cl_ast::Static) {
let cl_ast::Static { mutable, name, ty, init } = s;
let cl_ast::Static { name, init, .. } = s;
self.inner.set_source(Source::Static(s));
self.inner.set_body(init);
self.set_name(*name);
self.visit(mutable);
self.visit(ty);
self.visit(init);
}
fn visit_module(&mut self, m: &'a cl_ast::Module) {
let cl_ast::Module { name, file } = m;
self.inner.set_source(Source::Module(m));
self.set_name(*name);
self.visit(file);
s.children(self);
}
fn visit_function(&mut self, f: &'a cl_ast::Function) {
let cl_ast::Function { name, gens, sign, bind, body } = f;
// TODO: populate generics?
self.inner.set_source(Source::Function(f));
self.set_name(*name);
@ -128,6 +117,14 @@ impl<'a> Visit<'a> for Populator<'_, 'a> {
}
}
fn visit_module(&mut self, m: &'a cl_ast::Module) {
let cl_ast::Module { name, file } = m;
self.inner.set_source(Source::Module(m));
self.set_name(*name);
self.visit(file);
}
fn visit_struct(&mut self, s: &'a cl_ast::Struct) {
let cl_ast::Struct { name, gens, kind } = s;
self.inner.set_source(Source::Struct(s));
@ -143,12 +140,14 @@ impl<'a> Visit<'a> for Populator<'_, 'a> {
self.set_name(*name);
self.visit(gens);
self.visit(variants);
let mut children = Vec::new();
for variant in variants.iter() {
let mut entry = self.new_entry(NodeKind::Type);
variant.visit_in(&mut entry);
children.push((variant.name, entry.inner.id()));
let child = entry.inner.id();
children.push((variant.name, child));
self.inner.add_child(variant.name, child);
}
self.inner
.set_ty(TypeKind::Adt(crate::type_kind::Adt::Enum(children)));
@ -156,23 +155,27 @@ impl<'a> Visit<'a> for Populator<'_, 'a> {
fn visit_variant(&mut self, value: &'a cl_ast::Variant) {
let cl_ast::Variant { name, kind, body } = value;
let mut entry = self.new_entry(NodeKind::Type);
entry.inner.set_source(Source::Variant(value));
entry.visit(kind);
self.inner.set_source(Source::Variant(value));
self.set_name(*name);
self.visit(kind);
if let Some(body) = body {
entry.inner.set_body(body);
self.inner.set_body(body);
}
let child = entry.inner.id();
self.inner.add_child(*name, child);
}
fn visit_impl(&mut self, i: &'a cl_ast::Impl) {
let cl_ast::Impl { target, body } = i;
let cl_ast::Impl { gens, target: _, body } = i;
self.inner.set_source(Source::Impl(i));
self.inner.mark_impl_item();
self.visit(target);
// We don't know if target is generic yet -- that's checked later.
for generic in &gens.vars {
let mut entry = self.new_entry(NodeKind::Type);
entry.inner.set_ty(TypeKind::Inferred);
let child = entry.inner.id();
self.inner.add_child(*generic, child);
}
self.visit(body);
}

View File

@ -57,8 +57,10 @@ pub struct Table<'a> {
sources: HashMap<Handle, Source<'a>>,
impl_targets: HashMap<Handle, Handle>,
anon_types: HashMap<TypeKind, Handle>,
lang_items: HashMap<Sym, Handle>,
// --- Queues for algorithms ---
pub(crate) unchecked: Vec<Handle>,
pub(crate) impls: Vec<Handle>,
pub(crate) uses: Vec<Handle>,
}
@ -84,6 +86,8 @@ impl<'a> Table<'a> {
sources: HashMap::new(),
impl_targets: HashMap::new(),
anon_types: HashMap::new(),
lang_items: HashMap::new(),
unchecked: Vec::new(),
impls: Vec::new(),
uses: Vec::new(),
}
@ -111,6 +115,10 @@ impl<'a> Table<'a> {
self.imports.entry(parent).or_default().insert(name, import)
}
pub fn mark_unchecked(&mut self, item: Handle) {
self.unchecked.push(item);
}
pub fn mark_use_item(&mut self, item: Handle) {
let parent = self.parents[item];
self.use_items.entry(parent).or_default().push(item);
@ -121,6 +129,10 @@ impl<'a> Table<'a> {
self.impls.push(item);
}
pub fn mark_lang_item(&mut self, name: Sym, item: Handle) {
self.lang_items.insert(name, item);
}
pub fn handle_iter(&self) -> impl Iterator<Item = Handle> + use<> {
self.kinds.keys()
}
@ -209,7 +221,12 @@ impl<'a> Table<'a> {
self.impl_targets.get(&node).copied()
}
pub fn reparent(&mut self, node: Handle, parent: Handle) -> Handle {
self.parents.replace(node, parent)
}
pub fn set_body(&mut self, node: Handle, body: &'a Expr) -> Option<&'a Expr> {
self.mark_unchecked(node);
self.bodies.insert(node, body)
}
@ -311,7 +328,8 @@ pub enum NodeKind {
Static,
Function,
Temporary,
Local,
Let,
Scope,
Impl,
Use,
}
@ -329,7 +347,8 @@ mod display {
NodeKind::Static => write!(f, "static"),
NodeKind::Function => write!(f, "fn"),
NodeKind::Temporary => write!(f, "temp"),
NodeKind::Local => write!(f, "local"),
NodeKind::Let => write!(f, "let"),
NodeKind::Scope => write!(f, "scope"),
NodeKind::Use => write!(f, "use"),
NodeKind::Impl => write!(f, "impl"),
}