resolver: Improve block scoping with a scope guard

This commit is contained in:
John 2024-01-06 14:32:06 -06:00
parent 53f9ec2356
commit 77f7623041

View File

@ -3,9 +3,50 @@
//! This will hopefully become a fully fledged static resolution pass in the future //! This will hopefully become a fully fledged static resolution pass in the future
use std::collections::HashMap; use std::collections::HashMap;
#[allow(unused_imports)]
use crate::ast::preamble::*; use crate::ast::preamble::*;
use scopeguard::Scoped;
pub mod scopeguard {
//! Implements a generic RAII scope-guard
use std::ops::{Deref, DerefMut};
pub trait Scoped: Sized {
fn frame(&mut self) -> Guard<Self> {
Guard::new(self)
}
///
fn enter_scope(&mut self);
fn exit_scope(&mut self);
}
pub struct Guard<'scope, T: Scoped> {
inner: &'scope mut T,
}
impl<'scope, T: Scoped> Guard<'scope, T> {
pub fn new(inner: &'scope mut T) -> Self {
inner.enter_scope();
Self { inner }
}
}
impl<'scope, T: Scoped> Deref for Guard<'scope, T> {
type Target = T;
fn deref(&self) -> &Self::Target {
self.inner
}
}
impl<'scope, T: Scoped> DerefMut for Guard<'scope, T> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.inner
}
}
impl<'scope, T: Scoped> Drop for Guard<'scope, T> {
fn drop(&mut self) {
self.inner.exit_scope()
}
}
}
/// Prints like [println] if `debug_assertions` are enabled /// Prints like [println] if `debug_assertions` are enabled
macro debugln($($t:tt)*) { macro debugln($($t:tt)*) {
if cfg!(debug_assertions) { if cfg!(debug_assertions) {
@ -309,11 +350,20 @@ pub struct Resolver {
types: Vec<Type>, types: Vec<Type>,
} }
impl Scoped for Resolver {
fn enter_scope(&mut self) {
self.enter_scope();
}
fn exit_scope(&mut self) {
self.exit_scope();
}
}
impl Default for Resolver { impl Default for Resolver {
fn default() -> Self { fn default() -> Self {
let mut new = Self { let mut new = Self {
count: Default::default(), count: Default::default(),
scopes: Default::default(), scopes: vec![Default::default()],
modules: Default::default(), modules: Default::default(),
module: Default::default(), module: Default::default(),
types: Default::default(), types: Default::default(),
@ -419,16 +469,6 @@ impl Resolver {
self.get_mut(name)?.assign(name, ty) self.get_mut(name)?.assign(name, ty)
} }
} }
/// Manages a sub-function lexical scope
/// ```rust,ignore
/// macro scope(self, inner: {...}) -> Result<_, Error>
/// ```
macro scope($self:ident, $inner:tt ) {{
$self.enter_scope();
let scope = (|| $inner)();
$self.exit_scope();
scope
}}
/// Manages a module scope /// Manages a module scope
/// ```rust,ignore /// ```rust,ignore
/// macro module(self, name: &str, inner: {...}) -> Result<_, Error> /// macro module(self, name: &str, inner: {...}) -> Result<_, Error>
@ -508,17 +548,16 @@ impl Resolve for FnDecl {
// create a new lexical scope // create a new lexical scope
let scopes = std::mem::take(&mut resolver.scopes); let scopes = std::mem::take(&mut resolver.scopes);
// type-check the function body // type-check the function body
let out = scope!(resolver, { let out = {
let mut resolver = resolver.frame();
let mut evaluated_args = vec![]; let mut evaluated_args = vec![];
for arg in args { for arg in args {
evaluated_args.push(arg.resolve(resolver)?) evaluated_args.push(arg.resolve(&mut resolver)?)
} }
// TODO: proper typing for return addresses
let fn_decl = Type::Fn { args: evaluated_args.clone(), ret: Box::new(Type::Empty) }; let fn_decl = Type::Fn { args: evaluated_args.clone(), ret: Box::new(Type::Empty) };
resolver.get_mut(name)?.assign(name, &fn_decl)?; resolver.get_mut(name)?.assign(name, &fn_decl)?;
// Enter the new module module!(resolver, name, { body.resolve(&mut resolver) })
module!(resolver, name, { body.resolve(resolver) }) };
});
let _ = std::mem::replace(&mut resolver.scopes, scopes); let _ = std::mem::replace(&mut resolver.scopes, scopes);
out out
} }
@ -531,12 +570,11 @@ impl Resolve for Name {
impl Resolve for Block { impl Resolve for Block {
fn resolve(&mut self, resolver: &mut Resolver) -> TyResult<Type> { fn resolve(&mut self, resolver: &mut Resolver) -> TyResult<Type> {
let Block { let_count: _, statements, expr } = self; let Block { let_count: _, statements, expr } = self;
scope!(resolver, { let mut resolver = resolver.frame();
for stmt in statements { for stmt in statements {
stmt.resolve(resolver)?; stmt.resolve(&mut resolver)?;
} }
expr.resolve(resolver) expr.resolve(&mut resolver)
})
} }
} }
impl Resolve for Expr { impl Resolve for Expr {
@ -774,12 +812,13 @@ impl Resolve for For {
Type::Range(t) => t, Type::Range(t) => t,
got => Err(Error::TypeMismatch { want: Type::Range(Type::Inferred.into()), got })?, got => Err(Error::TypeMismatch { want: Type::Range(Type::Inferred.into()), got })?,
}; };
let body_ty = scope!(resolver, { let body_ty = {
let mut resolver = resolver.frame();
// bind the variable in the loop scope // bind the variable in the loop scope
*index = Some(resolver.insert_scope(name, false)?); *index = Some(resolver.insert_scope(name, false)?);
resolver.get_mut(name)?.assign(name, &ty)?; resolver.get_mut(name)?.assign(name, &ty)?;
body.resolve(resolver) body.resolve(&mut resolver)
})?; }?;
// visit the else block // visit the else block
let else_ty = else_.resolve(resolver)?; let else_ty = else_.resolve(resolver)?;
if body_ty != else_ty { if body_ty != else_ty {