From 3a6a646e556af8feb040df0e16c5231fc7f654bb Mon Sep 17 00:00:00 2001 From: Jeremy Wall Date: Mon, 2 Oct 2023 19:59:47 -0400 Subject: [PATCH] feat: tests and code to infer select expression shapes --- src/ast/mod.rs | 72 ++++++++- src/ast/typecheck/mod.rs | 30 +++- src/ast/typecheck/test.rs | 307 +++++++++++++++++++++++++++++++++++--- 3 files changed, 380 insertions(+), 29 deletions(-) diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 08a9034..4d8847c 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -272,6 +272,15 @@ impl NarrowedShape { self.pos = pos; self } + + pub fn merge_in_shape(&mut self, shape: Shape, symbol_table: &mut BTreeMap, Shape>) { + for s in self.types.iter() { + if s.equivalent(&shape, symbol_table) { + return; + } + } + self.types.push(shape) + } } // TODO(jwall): Display implementations for shapes. @@ -293,8 +302,69 @@ pub enum Shape { } impl Shape { + pub fn equivalent(&self, right: &Shape, symbol_table: &BTreeMap, Shape>) -> bool { + match (self, right) { + (Shape::Str(_), Shape::Str(_)) + | (Shape::Boolean(_), Shape::Boolean(_)) + | (Shape::Int(_), Shape::Int(_)) + | (Shape::Hole(_), Shape::Hole(_)) + | (Shape::Float(_), Shape::Float(_)) => true, + (Shape::Narrowed(left_slist), Shape::Narrowed(right_slist)) + | (Shape::List(left_slist), Shape::List(right_slist)) => { + for ls in left_slist.types.iter() { + let mut found = false; + for rs in right_slist.types.iter() { + if ls.equivalent(rs, symbol_table) { + found = true; + break; + } + } + if !found { + return false; + } + } + true + } + (Shape::Tuple(left_slist), Shape::Tuple(right_slist)) => { + for (lt, ls) in left_slist.val.iter() { + let mut found = false; + for (rt, rs) in right_slist.val.iter() { + if lt.fragment == rt.fragment && ls.equivalent(rs, symbol_table) { + found = true; + } + } + if !found { + return false; + } + } + true + } + (Shape::Func(left_opshape), Shape::Func(right_opshape)) => { + if left_opshape.args.len() != right_opshape.args.len() { + return false; + } + let left_args: Vec<&Shape> = dbg!(left_opshape.args.values().collect()); + let right_args: Vec<&Shape> = dbg!(right_opshape.args.values().collect()); + for idx in 0..left_args.len() { + let shap = left_args[idx]; + if !shap.equivalent(right_args[idx], symbol_table) + { + return false; + } + } + if !&left_opshape.ret.equivalent(&right_opshape.ret, symbol_table) { + return false; + } + true + } + (Shape::Module(left_opshape), Shape::Module(right_opshape)) => { + todo!(); + } + _ => false, + } + } + pub fn narrow(&self, right: &Shape, symbol_table: &mut BTreeMap, Shape>) -> Self { - dbg!((self, right)); match (self, right) { (Shape::Str(_), Shape::Str(_)) | (Shape::Boolean(_), Shape::Boolean(_)) diff --git a/src/ast/typecheck/mod.rs b/src/ast/typecheck/mod.rs index ea6434b..13555dc 100644 --- a/src/ast/typecheck/mod.rs +++ b/src/ast/typecheck/mod.rs @@ -14,6 +14,7 @@ //! Implements typechecking for the parsed ucg AST. use std::collections::BTreeMap; +use std::default; use std::rc::Rc; use crate::ast::walk::Visitor; @@ -24,6 +25,7 @@ use crate::error::{BuildError, ErrorType}; use super::{ CastType, CopyDef, FuncDef, ImportShape, ModuleShape, NarrowedShape, NotDef, PositionedItem, + SelectDef, Position, }; /// Trait for shape derivation. @@ -66,6 +68,18 @@ impl DeriveShape for FuncDef { } } +impl DeriveShape for SelectDef { + fn derive_shape(&self, symbol_table: &mut BTreeMap, Shape>) -> Shape { + let SelectDef { val: _, default: _, tuple, pos: _ } = self; + let mut narrowed_shape = NarrowedShape { pos: self.pos.clone(), types: Vec::with_capacity(tuple.len()) }; + for (_, expr) in tuple { + let shape = expr.derive_shape(symbol_table); + narrowed_shape.merge_in_shape(shape, symbol_table); + } + Shape::Narrowed(narrowed_shape) + } +} + fn derive_include_shape( IncludeDef { pos, @@ -229,7 +243,7 @@ impl DeriveShape for Expression { Expression::Include(def) => derive_include_shape(def), Expression::Call(_) => todo!(), Expression::Func(def) => def.derive_shape(symbol_table), - Expression::Select(_) => todo!(), + Expression::Select(def) => def.derive_shape(symbol_table), Expression::FuncOp(_) => todo!(), Expression::Module(_) => todo!(), Expression::Fail(_) => todo!(), @@ -254,11 +268,7 @@ impl DeriveShape for Value { } } Value::Tuple(flds) => { - let mut field_shapes = Vec::new(); - for &(ref tok, ref expr) in &flds.val { - field_shapes.push((tok.clone(), expr.derive_shape(symbol_table))); - } - Shape::Tuple(PositionedItem::new(field_shapes, flds.pos.clone())) + derive_field_list_shape(&flds.val, &flds.pos, symbol_table) } Value::List(flds) => { let mut field_shapes = Vec::new(); @@ -271,6 +281,14 @@ impl DeriveShape for Value { } } +fn derive_field_list_shape(flds: &Vec<(super::Token, Expression)>, pos: &Position, symbol_table: &mut BTreeMap, Shape>) -> Shape { + let mut field_shapes = Vec::new(); + for &(ref tok, ref expr) in flds { + field_shapes.push((tok.clone(), expr.derive_shape(symbol_table))); + } + Shape::Tuple(PositionedItem::new(field_shapes, pos.clone())) +} + pub struct Checker { symbol_table: BTreeMap, Shape>, err_stack: Vec, diff --git a/src/ast/typecheck/test.rs b/src/ast/typecheck/test.rs index f4a3a70..7fc6fb1 100644 --- a/src/ast/typecheck/test.rs +++ b/src/ast/typecheck/test.rs @@ -15,7 +15,7 @@ macro_rules! assert_type_fail { let mut checker = Checker::new(); let mut expr = parse($e.into(), None).unwrap(); checker.walk_statement_list(expr.iter_mut().collect()); - let result = dbg!(checker.result()); + let result = checker.result(); assert!(result.is_err(), "We expect this to fail a typecheck."); let err = result.unwrap_err(); assert_eq!(err.msg, $msg); @@ -155,7 +155,6 @@ macro_rules! infer_symbol_test { let expr = expression(token_iter); if let abortable_parser::Result::Complete(_, mut expr) = expr { checker.walk_expression(&mut expr); - dbg!(&checker.symbol_table); for (sym, shape) in $sym_list { assert_eq!( checker.symbol_table[&sym], @@ -228,20 +227,20 @@ fn infer_func_type_test() { symbol_table.insert( bar.clone(), Shape::Int(Position { - file: None, - line: 1, - column: 20, - offset: 19, + file: None, + line: 1, + column: 20, + offset: 19, }), ); let mut args = BTreeMap::new(); args.insert( foo.clone(), Shape::Int(Position { - file: None, - line: 1, - column: 6, - offset: 5, + file: None, + line: 1, + column: 6, + offset: 5, }), ); assert_type_success!( @@ -249,10 +248,10 @@ fn infer_func_type_test() { Shape::Func(FuncShapeDef { args: args, ret: Shape::Int(Position { - file: None, - line: 1, - column: 1, - offset: 0, + file: None, + line: 1, + column: 1, + offset: 0, }) .into() }), @@ -260,11 +259,275 @@ fn infer_func_type_test() { ); } -//#[test] -//fn infer_select_shape() { -// assert_type_success!( -// r#"select () => { true = "foo", false = 1 }"#, -// Shape::Narrowed(NarrowedShape { pos: Position { file: None, line: 0, column: 0, offset: 0 }, types: vec![ -// Shape::Str(PositionedItem { pos: , val: () }) -// ] })) -//} +#[test] +fn infer_select_shape() { + assert_type_success!( + r#"select (foo) => { true = "foo", false = 1 };"#, + Shape::Narrowed(NarrowedShape { + pos: Position { + file: None, + line: 1, + column: 1, + offset: 0 + }, + types: vec![ + Shape::Str(Position::new(1, 26, 25)), + Shape::Int(Position::new(1, 41, 40)), + ] + }) + ); + assert_type_success!( + r#"select (foo) => { true = "foo", false = { } };"#, + Shape::Narrowed(NarrowedShape { + pos: Position { + file: None, + line: 1, + column: 1, + offset: 0 + }, + types: vec![ + Shape::Str(Position::new(1, 26, 25)), + Shape::Tuple(PositionedItem { + pos: Position { + file: None, + line: 1, + column: 41, + offset: 40 + }, + val: vec![], + }), + ], + }) + ); + assert_type_success!( + r#"select (foo) => { true = { }, false = { } };"#, + Shape::Narrowed(NarrowedShape { + pos: Position { + file: None, + line: 1, + column: 1, + offset: 0 + }, + types: vec![Shape::Tuple(PositionedItem { + pos: Position { + file: None, + line: 1, + column: 26, + offset: 25 + }, + val: vec![], + }),], + }) + ); + assert_type_success!( + r#"select (foo) => { true = { foo = 1, }, false = { bar = "foo" } };"#, + Shape::Narrowed(NarrowedShape { + pos: Position { + file: None, + line: 1, + column: 1, + offset: 0 + }, + types: vec![ + Shape::Tuple(PositionedItem { + pos: Position { + file: None, + line: 1, + column: 26, + offset: 25 + }, + val: vec![( + Token { + typ: TokenType::BAREWORD, + fragment: "foo".into(), + pos: Position { + file: None, + line: 1, + column: 28, + offset: 27 + } + }, + Shape::Int(Position { + file: None, + line: 1, + column: 34, + offset: 33 + }) + )] + }), + Shape::Tuple(PositionedItem { + pos: Position { + file: None, + line: 1, + column: 48, + offset: 47 + }, + val: vec![( + Token { + typ: TokenType::BAREWORD, + fragment: "bar".into(), + pos: Position { + file: None, + line: 1, + column: 50, + offset: 49 + } + }, + Shape::Str(Position { + file: None, + line: 1, + column: 56, + offset: 55 + }) + )] + }) + ] + }) + ); + assert_type_success!( + r#"select (foo) => { true = { foo = 1, bar = "quux" }, false = { bar = "foo" } };"#, + Shape::Narrowed(NarrowedShape { + pos: Position { + file: None, + line: 1, + column: 1, + offset: 0 + }, + types: vec![ + Shape::Tuple(PositionedItem { + pos: Position { + file: None, + line: 1, + column: 26, + offset: 25 + }, + val: vec![ + ( + Token { + typ: TokenType::BAREWORD, + fragment: "foo".into(), + pos: Position { + file: None, + line: 1, + column: 28, + offset: 27 + } + }, + Shape::Int(Position { + file: None, + line: 1, + column: 34, + offset: 33 + }) + ), + ( + Token { + typ: TokenType::BAREWORD, + fragment: "bar".into(), + pos: Position { + file: None, + line: 1, + column: 37, + offset: 36 + } + }, + Shape::Str(Position { + file: None, + line: 1, + column: 43, + offset: 42 + }) + ), + ] + }), + Shape::Tuple(PositionedItem { + pos: Position { + file: None, + line: 1, + column: 61, + offset: 60 + }, + val: vec![( + Token { + typ: TokenType::BAREWORD, + fragment: "bar".into(), + pos: Position { + file: None, + line: 1, + column: 63, + offset: 62 + } + }, + Shape::Str(Position { + file: None, + line: 1, + column: 69, + offset: 68 + }) + )] + }) + ] + }) + ); + assert_type_success!( + r#"select (foo) => { true = [ "quux" ], false = [ 1 ] };"#, + Shape::Narrowed(NarrowedShape { + pos: Position { + file: None, + line: 1, + column: 1, + offset: 0 + }, + types: vec![ + Shape::List(NarrowedShape { + pos: Position { + file: None, + line: 1, + column: 26, + offset: 25 + }, + types: vec![Shape::Str(Position { + file: None, + line: 1, + column: 28, + offset: 27 + })] + }), + Shape::List(NarrowedShape { + pos: Position { + file: None, + line: 1, + column: 46, + offset: 45 + }, + types: vec![Shape::Int(Position { + file: None, + line: 1, + column: 48, + offset: 47 + })] + }) + ] + }) + ); +} + +fn parse_expression(expr: &str) -> Option { + let tokens = tokenize(expr.into(), None).unwrap(); + let token_iter = SliceIter::new(&tokens); + let expr = expression(token_iter); + if let abortable_parser::Result::Complete(_, expr) = expr { + return Some(expr) + } + None +} + +#[test] +fn func_type_equivalence() { + let mut symbol_table = BTreeMap::new(); + let expr1 = "func(arg1) => arg1 + 1;"; + let expr2 = "func(arg2) => arg2 + 1;"; + let shape1 = parse_expression(expr1).unwrap().derive_shape(&mut symbol_table); + let shape2 = parse_expression(expr2).unwrap().derive_shape(&mut symbol_table); + assert!(dbg!(shape1.equivalent(&shape2, &mut symbol_table))); +}