diff --git a/src/ast/mod.rs b/src/ast/mod.rs index af05769..f72340d 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -224,7 +224,7 @@ pub type ShapeList = Vec; #[derive(PartialEq, Debug, Clone)] pub struct FuncShapeDef { - args: Vec, + args: BTreeMap, Shape>, ret: Box, } @@ -278,7 +278,6 @@ impl NarrowedShape { /// Shapes represent the types that UCG values or expressions can have. #[derive(PartialEq, Debug, Clone)] pub enum Shape { - Empty(Position), Boolean(PositionedItem), Int(PositionedItem), Float(PositionedItem), @@ -294,20 +293,28 @@ pub enum Shape { } impl Shape { - pub fn narrow(&self, right: &Shape) -> Self { + pub fn narrow(&self, right: &Shape, symbol_table: &mut BTreeMap, Shape>) -> Self { + dbg!((self, right)); match (self, right) { (Shape::Str(_), Shape::Str(_)) | (Shape::Boolean(_), Shape::Boolean(_)) - | (Shape::Empty(_), Shape::Empty(_)) | (Shape::Int(_), Shape::Int(_)) | (Shape::Float(_), Shape::Float(_)) => self.clone(), - (Shape::Hole(_), other) | (other, Shape::Hole(_)) => other.clone(), + (Shape::Hole(sym), other) | (other, Shape::Hole(sym)) => { + if symbol_table.contains_key(&sym.val) { + symbol_table.insert(sym.val.clone(), other.clone().with_pos(sym.pos.clone())); + } else { + // TODO(jwall): Is this an error? + todo!(); + } + other.clone() + }, (Shape::Narrowed(left_slist), Shape::Narrowed(right_slist)) | (Shape::List(left_slist), Shape::List(right_slist)) => { - self.narrow_list_shapes(left_slist, right_slist, right) + self.narrow_list_shapes(left_slist, right_slist, right, symbol_table) } (Shape::Tuple(left_slist), Shape::Tuple(right_slist)) => { - self.narrow_tuple_shapes(left_slist, right_slist, right) + self.narrow_tuple_shapes(left_slist, right_slist, right, symbol_table) } (Shape::Func(left_opshape), Shape::Func(right_opshape)) => { todo!(); @@ -331,12 +338,13 @@ impl Shape { left_slist: &PositionedItem>, right_slist: &PositionedItem>, right: &Shape, + symbol_table: &mut BTreeMap, Shape>, ) -> Shape { let left_iter = left_slist.val.iter(); let right_iter = right_slist.val.iter(); - if is_tuple_subset(left_iter, right_slist) { + if is_tuple_subset(left_iter, right_slist, symbol_table) { self.clone() - } else if is_tuple_subset(right_iter, left_slist) { + } else if is_tuple_subset(right_iter, left_slist, symbol_table) { right.clone() } else { Shape::TypeErr(right.pos().clone(), "Incompatible Tuple Shapes".to_owned()) @@ -348,12 +356,13 @@ impl Shape { left_slist: &NarrowedShape, right_slist: &NarrowedShape, right: &Shape, + symbol_table: &mut BTreeMap, Shape>, ) -> Shape { let left_iter = left_slist.types.iter(); let right_iter = right_slist.types.iter(); - if is_list_subset(left_iter, right_slist) { + if is_list_subset(left_iter, right_slist, symbol_table) { self.clone() - } else if is_list_subset(right_iter, left_slist) { + } else if is_list_subset(right_iter, left_slist, symbol_table) { right.clone() } else { Shape::TypeErr(right.pos().clone(), "Incompatible List Shapes".to_owned()) @@ -366,7 +375,6 @@ impl Shape { Shape::Int(s) => "int", Shape::Float(s) => "float", Shape::Boolean(b) => "boolean", - Shape::Empty(p) => "nil", // TODO(jwall): make these type names account for what they contain. Shape::List(lst) => "list", Shape::Tuple(flds) => "tuple", @@ -385,7 +393,6 @@ impl Shape { Shape::Int(s) => &s.pos, Shape::Float(s) => &s.pos, Shape::Boolean(b) => &b.pos, - Shape::Empty(p) => p, Shape::List(lst) => &lst.pos, Shape::Tuple(flds) => &flds.pos, Shape::Func(def) => def.ret.pos(), @@ -404,7 +411,6 @@ impl Shape { Shape::Int(s) => Shape::Int(PositionedItem::new(s.val, pos)), Shape::Float(s) => Shape::Float(PositionedItem::new(s.val, pos)), Shape::Boolean(b) => Shape::Boolean(PositionedItem::new(b.val, pos)), - Shape::Empty(p) => Shape::Empty(pos), Shape::List(lst) => Shape::List(NarrowedShape::new_with_pos(lst.types, pos)), Shape::Tuple(flds) => Shape::Tuple(PositionedItem::new(flds.val, pos)), Shape::Func(_) | Shape::Module(_) => self.clone(), @@ -424,13 +430,14 @@ impl Shape { fn is_tuple_subset( mut left_iter: std::slice::Iter<(Token, Shape)>, right_slist: &PositionedItem>, + symbol_table: &mut BTreeMap, Shape>, ) -> bool { return loop { if let Some((lt, ls)) = left_iter.next() { let mut matched = false; for (rt, rs) in right_slist.val.iter() { if rt.fragment == lt.fragment { - if let Shape::TypeErr(_, _) = ls.narrow(rs) { + if let Shape::TypeErr(_, _) = ls.narrow(rs, symbol_table) { // noop } else { matched = true; @@ -448,7 +455,11 @@ fn is_tuple_subset( }; } -fn is_list_subset(mut right_iter: std::slice::Iter, left_slist: &NarrowedShape) -> bool { +fn is_list_subset( + mut right_iter: std::slice::Iter, + left_slist: &NarrowedShape, + symbol_table: &mut BTreeMap, Shape>, +) -> bool { let right_subset = loop { let mut matches = false; let ls = if let Some(ls) = right_iter.next() { @@ -457,7 +468,7 @@ fn is_list_subset(mut right_iter: std::slice::Iter, left_slist: &Narrowed break true; }; for rs in left_slist.types.iter() { - let s = ls.narrow(rs); + let s = ls.narrow(rs, symbol_table); if let Shape::TypeErr(_, _) = s { // noop } else { diff --git a/src/ast/test.rs b/src/ast/test.rs index 9bb97a1..d529c3d 100644 --- a/src/ast/test.rs +++ b/src/ast/test.rs @@ -24,10 +24,6 @@ use crate::tokenizer::tokenize; #[test] fn derive_shape_values() { let value_cases = vec![ - ( - Value::Empty(Position::new(0, 0, 0)), - Shape::Empty(Position::new(0, 0, 0)), - ), ( Value::Boolean(PositionedItem::new(false, Position::new(0, 1, 2))), Shape::Boolean(PositionedItem::new(false, Position::new(0, 1, 2))), diff --git a/src/ast/typecheck/mod.rs b/src/ast/typecheck/mod.rs index 0fa7ab9..bb5975b 100644 --- a/src/ast/typecheck/mod.rs +++ b/src/ast/typecheck/mod.rs @@ -17,7 +17,9 @@ use std::collections::BTreeMap; use std::rc::Rc; use crate::ast::walk::Visitor; -use crate::ast::{Expression, FailDef, ImportDef, IncludeDef, Shape, Statement, Value}; +use crate::ast::{ + Expression, FailDef, FuncShapeDef, ImportDef, IncludeDef, Shape, Statement, Value, +}; use crate::error::{BuildError, ErrorType}; use super::{ @@ -31,16 +33,20 @@ pub trait DeriveShape { } impl DeriveShape for FuncDef { - fn derive_shape(&self, symbol_table: &mut BTreeMap, Shape>) -> Shape { + fn derive_shape(&self, _symbol_table: &mut BTreeMap, Shape>) -> Shape { // 1. First set up our symbols. - let _table = self + let mut table = self .argdefs .iter() .map(|sym| (sym.val.clone(), Shape::Hole(sym.clone()))) .collect::, Shape>>(); // 2.Then determine the shapes of those symbols in our expression. + let _shape = self.fields.derive_shape(&mut table); // 3. Finally determine what the return shape can be. - todo!(); + Shape::Func(FuncShapeDef { + args: table, + ret: todo!(), + }) } } @@ -87,8 +93,7 @@ fn derive_copy_shape(def: &CopyDef, symbol_table: &mut BTreeMap, Shape>) match &base_shape { // TODO(jwall): Should we allow a stack of these? Shape::TypeErr(_, _) => base_shape, - Shape::Empty(_) - | Shape::Boolean(_) + Shape::Boolean(_) | Shape::Int(_) | Shape::Float(_) | Shape::Str(_) @@ -103,7 +108,10 @@ fn derive_copy_shape(def: &CopyDef, symbol_table: &mut BTreeMap, Shape>) Shape::Tuple(PositionedItem::new(vec![], pi.pos.clone())), Shape::Module(ModuleShape { items: vec![], - ret: Box::new(Shape::Empty(pi.pos.clone())), + ret: Box::new(Shape::Narrowed(NarrowedShape { + pos: pi.pos.clone(), + types: vec![], + })), }), Shape::Import(ImportShape::Unresolved(pi.clone())), ], @@ -142,7 +150,7 @@ fn derive_copy_shape(def: &CopyDef, symbol_table: &mut BTreeMap, Shape>) // 1. Do our copyable fields have the right names and shapes based on mdef.items. for (tok, shape) in mdef.items.iter() { if let Some(s) = arg_fields.get(&tok.fragment) { - if let Shape::TypeErr(pos, msg) = shape.narrow(s) { + if let Shape::TypeErr(pos, msg) = shape.narrow(s, symbol_table) { return Shape::TypeErr(pos, msg); } } @@ -199,7 +207,7 @@ impl DeriveShape for Expression { Expression::Binary(def) => { let left_shape = def.left.derive_shape(symbol_table); let right_shape = def.right.derive_shape(symbol_table); - left_shape.narrow(&right_shape) + left_shape.narrow(&right_shape, symbol_table) } Expression::Copy(def) => derive_copy_shape(def, symbol_table), Expression::Include(def) => derive_include_shape(def), @@ -217,7 +225,7 @@ impl DeriveShape for Expression { impl DeriveShape for Value { fn derive_shape(&self, symbol_table: &mut BTreeMap, Shape>) -> Shape { match self { - Value::Empty(p) => Shape::Empty(p.clone()), + Value::Empty(p) => Shape::Narrowed(NarrowedShape::new_with_pos(vec![], p.clone())), Value::Boolean(p) => Shape::Boolean(p.clone()), Value::Int(p) => Shape::Int(p.clone()), Value::Float(p) => Shape::Float(p.clone()), @@ -310,7 +318,12 @@ impl Visitor for Checker { fn visit_value(&mut self, val: &mut Value) { match val { - Value::Empty(p) => self.shape_stack.push(Shape::Empty(p.clone())), + Value::Empty(p) => self + .shape_stack + .push(Shape::Narrowed(NarrowedShape::new_with_pos( + vec![], + p.clone(), + ))), Value::Boolean(p) => self.shape_stack.push(Shape::Boolean(p.clone())), Value::Int(p) => self.shape_stack.push(Shape::Int(p.clone())), Value::Float(p) => self.shape_stack.push(Shape::Float(p.clone())), diff --git a/src/ast/typecheck/test.rs b/src/ast/typecheck/test.rs index 2b1fb9d..24ca475 100644 --- a/src/ast/typecheck/test.rs +++ b/src/ast/typecheck/test.rs @@ -1,9 +1,12 @@ use std::convert::Into; -use crate::ast::{Token, TokenType}; +use abortable_parser::SliceIter; + use crate::ast::walk::Walker; use crate::ast::{Position, PositionedItem}; -use crate::parse; +use crate::ast::{Token, TokenType}; +use crate::parse::{expression, parse}; +use crate::tokenizer::tokenize; use super::*; @@ -63,10 +66,14 @@ fn simple_binary_typecheck() { assert_type_success!( "{foo = 1} + {foo = 1};", Shape::Tuple(PositionedItem::new( - vec![ - (Token { typ: TokenType::BAREWORD, fragment: "foo".into(), pos: Position::new(1, 2, 1)}, - Shape::Int(PositionedItem::new_with_pos(1, Position::new(1, 8, 7)))), - ], + vec![( + Token { + typ: TokenType::BAREWORD, + fragment: "foo".into(), + pos: Position::new(1, 2, 1) + }, + Shape::Int(PositionedItem::new_with_pos(1, Position::new(1, 8, 7))) + ),], Position::new(1, 1, 0) )) ); @@ -145,3 +152,25 @@ fn multiple_binary_typefail() { Position::new(1, 9, 8) ); } + +#[test] +fn infer_symbol_type_test() { + // foo should be determined to be int + let expr = "1 + foo".into(); + let symbol: Rc = "foo".into(); + let mut checker = Checker::new(); + checker + .symbol_table + .insert(symbol.clone(), Shape::Hole(PositionedItem::new(symbol.clone(), Position::new(0, 0, 0)))); + let tokens = tokenize(expr, None).unwrap(); + let token_iter = SliceIter::new(&tokens); + let expr = expression(token_iter); + if let abortable_parser::Result::Complete(_, mut expr) = expr { + checker.walk_expression(&mut expr); + dbg!(&checker.symbol_table); + assert_eq!( + checker.symbol_table[&symbol], + Shape::Int(PositionedItem::new(1, Position::new(0, 0, 0))) + ); + } +}