From 188907d807fb4a24b73a86773c3b6ee6653483da Mon Sep 17 00:00:00 2001 From: Jeremy Wall Date: Tue, 10 Oct 2023 22:01:53 -0400 Subject: [PATCH] feat: Infer tuple field shapes from DOT operator --- src/ast/mod.rs | 30 +++++++++----- src/ast/test.rs | 2 +- src/ast/typecheck/mod.rs | 83 +++++++++++++++++++++++++++++++------- src/ast/typecheck/test.rs | 84 +++++++++++++++++++++++---------------- 4 files changed, 137 insertions(+), 62 deletions(-) diff --git a/src/ast/mod.rs b/src/ast/mod.rs index c9a94eb..7102c77 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -145,6 +145,12 @@ impl Token { } } +impl<'a> From<&'a Token> for PositionedItem> { + fn from(value: &'a Token) -> Self { + Self::new(value.fragment.clone(), value.pos.clone()) + } +} + impl abortable_parser::Positioned for Token { fn line(&self) -> usize { self.pos.line @@ -219,7 +225,7 @@ macro_rules! make_expr { /// This is usually used as the body of a tuple in the UCG AST. pub type FieldList = Vec<(Token, Expression)>; // Token is expected to be a symbol -pub type TupleShape = Vec<(Token, Shape)>; +pub type TupleShape = Vec<(PositionedItem>, Shape)>; pub type ShapeList = Vec; #[derive(PartialEq, Debug, Clone)] @@ -329,7 +335,7 @@ impl Shape { for (lt, ls) in left_slist.val.iter() { let mut found = false; for (rt, rs) in right_slist.val.iter() { - if lt.fragment == rt.fragment && ls.equivalent(rs, symbol_table) { + if lt.val == rt.val && ls.equivalent(rs, symbol_table) { found = true; } } @@ -347,12 +353,14 @@ impl Shape { let right_args: Vec<&Shape> = dbg!(right_opshape.args.values().collect()); for idx in 0..left_args.len() { let shap = left_args[idx]; - if !shap.equivalent(right_args[idx], symbol_table) - { + if !shap.equivalent(right_args[idx], symbol_table) { return false; } } - if !&left_opshape.ret.equivalent(&right_opshape.ret, symbol_table) { + if !&left_opshape + .ret + .equivalent(&right_opshape.ret, symbol_table) + { return false; } true @@ -402,8 +410,8 @@ impl Shape { fn narrow_tuple_shapes( &self, - left_slist: &PositionedItem>, - right_slist: &PositionedItem>, + left_slist: &PositionedItem>, Shape)>>, + right_slist: &PositionedItem>, Shape)>>, right: &Shape, symbol_table: &mut BTreeMap, Shape>, ) -> Shape { @@ -495,15 +503,15 @@ impl Shape { } fn is_tuple_subset( - mut left_iter: std::slice::Iter<(Token, Shape)>, - right_slist: &PositionedItem>, + mut left_iter: std::slice::Iter<(PositionedItem>, Shape)>, + right_slist: &PositionedItem>, Shape)>>, symbol_table: &mut BTreeMap, Shape>, ) -> bool { return loop { if let Some((lt, ls)) = left_iter.next() { let mut matched = false; for (rt, rs) in right_slist.val.iter() { - if rt.fragment == lt.fragment { + if rt.val == lt.val { if let Shape::TypeErr(_, _) = ls.narrow(rs, symbol_table) { // noop } else { @@ -927,7 +935,7 @@ impl ModuleDef { self.out_expr = Some(Box::new(expr)); } - pub fn derive_shape(&mut self, expr: Expression) { + pub fn derive_shape(&self, symbol_table: &mut BTreeMap, Shape>) -> Shape { todo!() } } diff --git a/src/ast/test.rs b/src/ast/test.rs index 0665368..dbbe471 100644 --- a/src/ast/test.rs +++ b/src/ast/test.rs @@ -57,7 +57,7 @@ fn derive_shape_values() { )), Shape::Tuple(PositionedItem::new( vec![( - Token::new("foo", TokenType::BAREWORD, Position::new(0, 0, 0)), + PositionedItem::new("foo".into(), Position::new(0, 0, 0)), Shape::Int(Position::new(0, 0, 0)), )], Position::new(0, 0, 0), diff --git a/src/ast/typecheck/mod.rs b/src/ast/typecheck/mod.rs index 11ecbb6..f467410 100644 --- a/src/ast/typecheck/mod.rs +++ b/src/ast/typecheck/mod.rs @@ -14,7 +14,6 @@ //! Implements typechecking for the parsed ucg AST. use std::collections::BTreeMap; -use std::default; use std::rc::Rc; use crate::ast::walk::Visitor; @@ -24,8 +23,8 @@ use crate::ast::{ use crate::error::{BuildError, ErrorType}; use super::{ - CastType, CopyDef, FuncDef, ImportShape, ModuleShape, NarrowedShape, NotDef, PositionedItem, - SelectDef, Position, + BinaryExprType, BinaryOpDef, CastType, CopyDef, FuncDef, ImportShape, ModuleShape, + NarrowedShape, NotDef, Position, PositionedItem, SelectDef, }; /// Trait for shape derivation. @@ -70,8 +69,16 @@ impl DeriveShape for FuncDef { impl DeriveShape for SelectDef { fn derive_shape(&self, symbol_table: &mut BTreeMap, Shape>) -> Shape { - let SelectDef { val: _, default: _, tuple, pos: _ } = self; - let mut narrowed_shape = NarrowedShape { pos: self.pos.clone(), types: Vec::with_capacity(tuple.len()) }; + let SelectDef { + val: _, + default: _, + tuple, + pos: _, + } = self; + let mut narrowed_shape = NarrowedShape { + pos: self.pos.clone(), + types: Vec::with_capacity(tuple.len()), + }; for (_, expr) in tuple { let shape = expr.derive_shape(symbol_table); narrowed_shape.merge_in_shape(shape, symbol_table); @@ -178,8 +185,8 @@ fn derive_copy_shape(def: &CopyDef, symbol_table: &mut BTreeMap, Shape>) .map(|(tok, expr)| (tok.fragment.clone(), expr.derive_shape(symbol_table))) .collect::, Shape>>(); // 1. Do our copyable fields have the right names and shapes based on mdef.items. - for (tok, shape) in mdef.items.iter() { - if let Some(s) = arg_fields.get(&tok.fragment) { + for (sym, shape) in mdef.items.iter() { + if let Some(s) = arg_fields.get(&sym.val) { if let Shape::TypeErr(pos, msg) = shape.narrow(s, symbol_table) { return Shape::TypeErr(pos, msg); } @@ -193,7 +200,7 @@ fn derive_copy_shape(def: &CopyDef, symbol_table: &mut BTreeMap, Shape>) base_fields.val.extend( def.fields .iter() - .map(|(tok, expr)| (tok.clone(), expr.derive_shape(symbol_table))), + .map(|(tok, expr)| (tok.into(), expr.derive_shape(symbol_table))), ); Shape::Tuple(base_fields).with_pos(def.pos.clone()) } @@ -206,7 +213,7 @@ fn derive_copy_shape(def: &CopyDef, symbol_table: &mut BTreeMap, Shape>) base_fields.extend( def.fields .iter() - .map(|(tok, expr)| (tok.clone(), expr.derive_shape(symbol_table))), + .map(|(tok, expr)| (tok.into(), expr.derive_shape(symbol_table))), ); Shape::Tuple(PositionedItem::new(base_fields, def.pos.clone())) } @@ -237,7 +244,48 @@ impl DeriveShape for Expression { Expression::Binary(def) => { let left_shape = def.left.derive_shape(symbol_table); let right_shape = def.right.derive_shape(symbol_table); - left_shape.narrow(&right_shape, symbol_table) + // We need to do somethig different if it's a ShapeKind::DOT + if def.kind == BinaryExprType::DOT { + dbg!(&def); + // left_shape can be assumed to be of type tuple. + // If left_shape is not known it can be inferred to be a tuple with right + // shapes symbol as a field name. + if let Shape::Hole(p) = left_shape { + dbg!(&p); + if let Shape::Hole(pi) = right_shape { + dbg!(&pi); + let derived_shape = Shape::Tuple(PositionedItem::new( + // TODO(jeremy): This needs to be a token... + vec![( + pi.into(), + Shape::Narrowed(NarrowedShape { + pos: p.pos.clone(), + types: Vec::new(), + }), + )], + p.pos.clone(), + )); + symbol_table.insert(p.val.clone(), derived_shape); + return Shape::Narrowed(NarrowedShape { + pos: p.pos.clone(), + types: Vec::new(), + }); + } + } else if let Shape::Tuple(fields_pi) = left_shape { + dbg!(&fields_pi); + if let Shape::Hole(pi) = right_shape { + dbg!(&pi); + for (sym, shape) in fields_pi.val { + if pi.val == sym.val { + return shape; + } + } + } + } + Shape::TypeErr(def.pos.clone(), "Invalid Tuple field selector".to_owned()) + } else { + left_shape.narrow(&right_shape, symbol_table) + } } Expression::Copy(def) => derive_copy_shape(def, symbol_table), Expression::Include(def) => derive_include_shape(def), @@ -267,9 +315,7 @@ impl DeriveShape for Value { Shape::Hole(p.clone()) } } - Value::Tuple(flds) => { - derive_field_list_shape(&flds.val, &flds.pos, symbol_table) - } + Value::Tuple(flds) => derive_field_list_shape(&flds.val, &flds.pos, symbol_table), Value::List(flds) => { let mut field_shapes = Vec::new(); for f in &flds.elems { @@ -281,10 +327,17 @@ impl DeriveShape for Value { } } -fn derive_field_list_shape(flds: &Vec<(super::Token, Expression)>, pos: &Position, symbol_table: &mut BTreeMap, Shape>) -> Shape { +fn derive_field_list_shape( + flds: &Vec<(super::Token, Expression)>, + pos: &Position, + symbol_table: &mut BTreeMap, Shape>, +) -> Shape { let mut field_shapes = Vec::new(); for &(ref tok, ref expr) in flds { - field_shapes.push((tok.clone(), expr.derive_shape(symbol_table))); + field_shapes.push(( + PositionedItem::new(tok.fragment.clone(), tok.pos.clone()), + expr.derive_shape(symbol_table), + )); } Shape::Tuple(PositionedItem::new(field_shapes, pos.clone())) } diff --git a/src/ast/typecheck/test.rs b/src/ast/typecheck/test.rs index 7fc6fb1..bce719a 100644 --- a/src/ast/typecheck/test.rs +++ b/src/ast/typecheck/test.rs @@ -1,10 +1,9 @@ use std::convert::Into; -use abortable_parser::SliceIter; +use abortable_parser::{Positioned, SliceIter}; use crate::ast::walk::Walker; use crate::ast::{Position, PositionedItem}; -use crate::ast::{Token, TokenType}; use crate::parse::{expression, parse}; use crate::tokenizer::tokenize; @@ -57,11 +56,7 @@ fn simple_binary_typecheck() { "{foo = 1} + {foo = 1};", Shape::Tuple(PositionedItem::new( vec![( - Token { - typ: TokenType::BAREWORD, - fragment: "foo".into(), - pos: Position::new(1, 2, 1) - }, + PositionedItem::new("foo".into(), Position::new(1, 2, 1)), Shape::Int(Position::new(1, 8, 7)) ),], Position::new(1, 1, 0) @@ -192,6 +187,26 @@ fn infer_symbol_type_test() { Shape::Float(Position::new(0, 0, 0)), ], ), + ( + "bar.foo", + vec![( + bar.clone(), + Shape::Tuple(PositionedItem::new( + vec![( + PositionedItem::new(foo.clone(), Position::new(1, 5, 4)), + Shape::Narrowed(NarrowedShape { + pos: Position::new(0, 0, 0), + types: Vec::new(), + }), + )], + Position::new(0, 0, 0), + )), + )], + vec![Shape::Hole(PositionedItem::new( + bar.clone(), + Position::new(0, 0, 0), + ))], + ), ]; for (expr, sym_list, sym_init_list) in table { infer_symbol_test!(expr, sym_list, sym_init_list) @@ -337,16 +352,15 @@ fn infer_select_shape() { offset: 25 }, val: vec![( - Token { - typ: TokenType::BAREWORD, - fragment: "foo".into(), - pos: Position { + PositionedItem::new( + "foo".into(), + Position { file: None, line: 1, column: 28, offset: 27 } - }, + ), Shape::Int(Position { file: None, line: 1, @@ -363,16 +377,15 @@ fn infer_select_shape() { offset: 47 }, val: vec![( - Token { - typ: TokenType::BAREWORD, - fragment: "bar".into(), - pos: Position { + PositionedItem::new( + "bar".into(), + Position { file: None, line: 1, column: 50, offset: 49 } - }, + ), Shape::Str(Position { file: None, line: 1, @@ -403,16 +416,15 @@ fn infer_select_shape() { }, val: vec![ ( - Token { - typ: TokenType::BAREWORD, - fragment: "foo".into(), - pos: Position { + PositionedItem::new( + "foo".into(), + Position { file: None, line: 1, column: 28, offset: 27 } - }, + ), Shape::Int(Position { file: None, line: 1, @@ -421,16 +433,15 @@ fn infer_select_shape() { }) ), ( - Token { - typ: TokenType::BAREWORD, - fragment: "bar".into(), - pos: Position { + PositionedItem::new( + "bar".into(), + Position { file: None, line: 1, column: 37, offset: 36 } - }, + ), Shape::Str(Position { file: None, line: 1, @@ -448,16 +459,15 @@ fn infer_select_shape() { offset: 60 }, val: vec![( - Token { - typ: TokenType::BAREWORD, - fragment: "bar".into(), - pos: Position { + PositionedItem::new( + "bar".into(), + Position { file: None, line: 1, column: 63, offset: 62 } - }, + ), Shape::Str(Position { file: None, line: 1, @@ -517,7 +527,7 @@ fn parse_expression(expr: &str) -> Option { let token_iter = SliceIter::new(&tokens); let expr = expression(token_iter); if let abortable_parser::Result::Complete(_, expr) = expr { - return Some(expr) + return Some(expr); } None } @@ -527,7 +537,11 @@ fn func_type_equivalence() { let mut symbol_table = BTreeMap::new(); let expr1 = "func(arg1) => arg1 + 1;"; let expr2 = "func(arg2) => arg2 + 1;"; - let shape1 = parse_expression(expr1).unwrap().derive_shape(&mut symbol_table); - let shape2 = parse_expression(expr2).unwrap().derive_shape(&mut symbol_table); + let shape1 = parse_expression(expr1) + .unwrap() + .derive_shape(&mut symbol_table); + let shape2 = parse_expression(expr2) + .unwrap() + .derive_shape(&mut symbol_table); assert!(dbg!(shape1.equivalent(&shape2, &mut symbol_table))); }