From ce928b7bd22ee581a34b5fe2f03655dd9949d156 Mon Sep 17 00:00:00 2001 From: Jeremy Wall Date: Thu, 26 Oct 2023 20:02:21 -0400 Subject: [PATCH] feat: let statement inference and also module defs --- flake.nix | 9 ++--- src/ast/mod.rs | 4 -- src/ast/test.rs | 16 ++++---- src/ast/typecheck/mod.rs | 63 ++++++++++++++++++++++++++++++-- src/ast/typecheck/simple_mod.ucg | 3 ++ src/ast/typecheck/test.rs | 48 +++++++++++++++++++++++- 6 files changed, 121 insertions(+), 22 deletions(-) create mode 100644 src/ast/typecheck/simple_mod.ucg diff --git a/flake.nix b/flake.nix index e715988..176aa29 100644 --- a/flake.nix +++ b/flake.nix @@ -10,12 +10,12 @@ }; naersk.url = "github:nix-community/naersk"; flake-compat = { - url = github:edolstra/flake-compat; + url = "github:edolstra/flake-compat"; flake = false; }; }; - outputs = {self, nixpkgs, flake-utils, rust-overlay, naersk, flake-compat}: + outputs = {nixpkgs, flake-utils, rust-overlay, naersk, ...}: flake-utils.lib.eachDefaultSystem (system: let overlays = [ rust-overlay.overlays.default ]; @@ -29,8 +29,7 @@ rustc = rust-bin; cargo = rust-bin; }; - ucg = with pkgs; - naersk-lib.buildPackage rec { + ucg = naersk-lib.buildPackage rec { pname = "ucg"; version = "0.7.3"; src = ./.; @@ -47,4 +46,4 @@ program = "${ucg}/bin/ucg"; }; }); -} \ No newline at end of file +} diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 7102c77..e86fe4d 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -934,10 +934,6 @@ impl ModuleDef { pub fn set_out_expr(&mut self, expr: Expression) { self.out_expr = Some(Box::new(expr)); } - - pub fn derive_shape(&self, symbol_table: &mut BTreeMap, Shape>) -> Shape { - todo!() - } } /// RangeDef defines a range with optional step. diff --git a/src/ast/test.rs b/src/ast/test.rs index dbbe471..5e6ad3b 100644 --- a/src/ast/test.rs +++ b/src/ast/test.rs @@ -86,10 +86,10 @@ fn derive_shape_values() { #[test] fn derive_shape_expressions() { let expr_cases = vec![ - ("3;", Shape::Int(Position::new(0, 0, 0))), - ("(3);", Shape::Int(Position::new(0, 0, 0))), - ("\"foo {}\" % (1);", Shape::Str(Position::new(0, 0, 0))), - ("not true;", Shape::Boolean(Position::new(1, 0, 0))), + ("3;", Shape::Int(Position::new(1, 1, 0))), + ("(3);", Shape::Int(Position::new(1, 2, 1))), + ("\"foo {}\" % (1);", Shape::Str(Position::new(1, 1, 0))), + ("not true;", Shape::Boolean(Position::new(1, 1, 0))), ( "0:1;", Shape::List(NarrowedShape::new_with_pos( @@ -97,10 +97,10 @@ fn derive_shape_expressions() { Position::new(1, 1, 0), )), ), - ("int(\"1\");", Shape::Int(Position::new(0, 0, 0))), - ("float(1);", Shape::Float(Position::new(0, 0, 0))), - ("str(1);", Shape::Str(Position::new(0, 0, 0))), - ("bool(\"true\");", Shape::Boolean(Position::new(0, 0, 0))), + ("int(\"1\");", Shape::Int(Position::new(1, 1, 0))), + ("float(1);", Shape::Float(Position::new(1, 1, 0))), + ("str(1);", Shape::Str(Position::new(1, 1, 0))), + ("bool(\"true\");", Shape::Boolean(Position::new(1, 1, 0))), ("1 + 1;", Shape::Int(Position::new(1, 1, 0))), ]; diff --git a/src/ast/typecheck/mod.rs b/src/ast/typecheck/mod.rs index f467410..5107cc0 100644 --- a/src/ast/typecheck/mod.rs +++ b/src/ast/typecheck/mod.rs @@ -13,17 +13,17 @@ // limitations under the License. //! Implements typechecking for the parsed ucg AST. -use std::collections::BTreeMap; +use std::collections::{BTreeMap, BTreeSet}; use std::rc::Rc; -use crate::ast::walk::Visitor; +use crate::ast::walk::{Visitor, Walker}; use crate::ast::{ Expression, FailDef, FuncShapeDef, ImportDef, IncludeDef, Shape, Statement, Value, }; use crate::error::{BuildError, ErrorType}; use super::{ - BinaryExprType, BinaryOpDef, CastType, CopyDef, FuncDef, ImportShape, ModuleShape, + BinaryExprType, BinaryOpDef, CastType, CopyDef, FuncDef, ImportShape, ModuleDef, ModuleShape, NarrowedShape, NotDef, Position, PositionedItem, SelectDef, }; @@ -67,6 +67,48 @@ impl DeriveShape for FuncDef { } } +impl DeriveShape for ModuleDef { + fn derive_shape(&self, symbol_table: &mut BTreeMap, Shape>) -> Shape { + let sym_table: BTreeMap, Shape> = self + .arg_set + .iter() + .map(|(tok, expr)| (tok.fragment.clone(), expr.derive_shape(symbol_table))) + .collect(); + let sym_positions: BTreeSet>> = + self.arg_set.iter().map(|(tok, _)| tok.into()).collect(); + let mut checker = Checker::new().with_symbol_table(sym_table); + checker.walk_statement_list(self.statements.clone().iter_mut().collect()); + if let Some(mut expr) = self.out_expr.clone() { + checker.walk_expression(&mut expr); + } else { + // TODO(jwall): We need to construct a tuple from the let statements here. + } + let ret = Box::new( + checker + .pop_shape() + .expect("There should always be a return type here"), + ); + let mut items = Vec::new(); + let sym_table = checker + .result() + .expect("There should aways be a symbol_table here"); + for pos_key in sym_positions { + let key = pos_key.val.clone(); + items.push(( + pos_key, + sym_table + .get(&key) + .expect("This should always have a valid shape") + .clone(), + )); + } + Shape::Module(ModuleShape { + items, + ret, + }) + } +} + impl DeriveShape for SelectDef { fn derive_shape(&self, symbol_table: &mut BTreeMap, Shape>) -> Shape { let SelectDef { @@ -455,7 +497,20 @@ impl Visitor for Checker { } fn visit_statement(&mut self, _stmt: &mut Statement) { - // noop by default + if let Statement::Let(def) = _stmt { + let name = def.name.fragment.clone(); + let shape = def.value.derive_shape(&mut self.symbol_table); + if let Shape::TypeErr(pos, msg) = &shape { + self.err_stack.push(BuildError::with_pos( + msg.clone(), + ErrorType::TypeFail, + pos.clone(), + )); + } else { + self.symbol_table.insert(name.clone(), shape.clone()); + self.shape_stack.push(shape); + } + } } fn leave_statement(&mut self, stmt: &Statement) { diff --git a/src/ast/typecheck/simple_mod.ucg b/src/ast/typecheck/simple_mod.ucg new file mode 100644 index 0000000..4b88ffb --- /dev/null +++ b/src/ast/typecheck/simple_mod.ucg @@ -0,0 +1,3 @@ +module{ +} => (1) { +}; diff --git a/src/ast/typecheck/test.rs b/src/ast/typecheck/test.rs index bce719a..ecfaabe 100644 --- a/src/ast/typecheck/test.rs +++ b/src/ast/typecheck/test.rs @@ -1,6 +1,6 @@ use std::convert::Into; -use abortable_parser::{Positioned, SliceIter}; +use abortable_parser::SliceIter; use crate::ast::walk::Walker; use crate::ast::{Position, PositionedItem}; @@ -36,6 +36,17 @@ macro_rules! assert_type_success { assert!(maybe_shape.is_some(), "We got a shape out of it"); assert_eq!(maybe_shape.unwrap(), $shape); }}; + ($e:expr, $shape:expr, $sym_table:expr, $expected_sym:expr) => {{ + let mut checker = Checker::new().with_symbol_table($sym_table); + let mut expr = parse($e.into(), None).unwrap(); + checker.walk_statement_list(expr.iter_mut().collect()); + let maybe_shape = checker.pop_shape(); + assert_eq!(checker.symbol_table[$expected_sym], $shape); + let result = checker.result(); + assert!(result.is_ok(), "We expect this to typecheck successfully."); + assert!(maybe_shape.is_some(), "We got a shape out of it"); + assert_eq!(maybe_shape.unwrap(), $shape); + }}; } #[test] @@ -545,3 +556,38 @@ fn func_type_equivalence() { .derive_shape(&mut symbol_table); assert!(dbg!(shape1.equivalent(&shape2, &mut symbol_table))); } + +#[test] +fn let_stmt_inference() { + let int_stmt = "let foo = 1;"; + assert_type_success!( + int_stmt, + Shape::Int(Position::new(1, 11, 10)), + BTreeMap::new(), + "foo".into() + ); + let float_stmt = "let foo = 1.0;"; + assert_type_success!( + float_stmt, + Shape::Float(Position::new(1, 11, 10)), + BTreeMap::new(), + "foo".into() + ); +} + +#[test] +fn test_module_inference() { + let simple_mod_stmt = include_str!("simple_mod.ucg"); + assert_type_success!( + simple_mod_stmt, + Shape::Module(ModuleShape { + items: vec![], + ret: Box::new(Shape::Int(Position { + file: None, + line: 2, + column: 7, + offset: 14 + })) + }) + ) +}