diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 6aa56b3..36b87aa 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -271,6 +271,81 @@ value_enum!( Module(ModuleShapeDef), ); +impl Shape { + pub fn merge(&self, compare: &Shape) -> Option { + match (self, compare) { + (Shape::Str(_), Shape::Str(_)) + | (Shape::Symbol(_), Shape::Symbol(_)) + | (Shape::Boolean(_), Shape::Boolean(_)) + | (Shape::Empty(_), Shape::Empty(_)) + | (Shape::Int(_), Shape::Int(_)) + | (Shape::Float(_), Shape::Float(_)) => Some(self.clone()), + (Shape::List(left_slist), Shape::List(right_slist)) => { + // TODO + unimplemented!("Can't merge these yet.") + } + (Shape::Tuple(left_slist), Shape::Tuple(right_slist)) => { + // TODO + unimplemented!("Can't merge these yet.") + } + (Shape::Func(left_opshape), Shape::Func(right_opshape)) => { + // TODO + unimplemented!("Can't merge these yet.") + } + (Shape::Module(left_opshape), Shape::Module(right_opshape)) => { + // TODO + unimplemented!("Can't merge these yet.") + } + _ => None, + } + } + + pub fn type_name(&self) -> &'static str { + match self { + Shape::Str(s) => "str", + Shape::Symbol(s) => "symbol", + Shape::Int(s) => "int", + Shape::Float(s) => "float", + Shape::Boolean(b) => "boolean", + Shape::Empty(p) => "nil", + // TODO(jwall): make these type names account for what they contain. + Shape::List(lst) => "list", + Shape::Tuple(flds) => "tuple", + Shape::Func(_) => "func", + Shape::Module(_) => "module", + } + } + + pub fn pos(&self) -> &Position { + match self { + Shape::Str(s) => &s.pos, + Shape::Symbol(s) => &s.pos, + Shape::Int(s) => &s.pos, + Shape::Float(s) => &s.pos, + Shape::Boolean(b) => &b.pos, + Shape::Empty(p) => p, + Shape::List(lst) => &lst.pos, + Shape::Tuple(flds) => &flds.pos, + Shape::Func(def) => def.ret.pos(), + Shape::Module(def) => def.ret.pos(), + } + } + + pub fn with_pos(self, pos: Position) -> Self { + match self { + Shape::Str(s) => Shape::Str(PositionedItem::new(s.val, pos)), + Shape::Symbol(s) => Shape::Symbol(PositionedItem::new(s.val, pos)), + Shape::Int(s) => Shape::Int(PositionedItem::new(s.val, pos)), + Shape::Float(s) => Shape::Float(PositionedItem::new(s.val, pos)), + Shape::Boolean(b) => Shape::Boolean(PositionedItem::new(b.val, pos)), + Shape::Empty(p) => Shape::Empty(pos), + Shape::List(lst) => Shape::List(PositionedItem::new(lst.val, pos)), + Shape::Tuple(flds) => Shape::Tuple(PositionedItem::new(flds.val, pos)), + Shape::Func(_) | Shape::Module(_) => self.clone(), + } + } +} + impl Value { /// Returns the type name of the Value it is called on as a string. pub fn type_name(&self) -> String { diff --git a/src/ast/rewrite.rs b/src/ast/rewrite.rs index ab0b37a..9fd33c8 100644 --- a/src/ast/rewrite.rs +++ b/src/ast/rewrite.rs @@ -53,5 +53,3 @@ impl Visitor for Rewriter { } } } - -impl Walker for Rewriter {} diff --git a/src/ast/typecheck.rs b/src/ast/typecheck.rs deleted file mode 100644 index 8fc5c39..0000000 --- a/src/ast/typecheck.rs +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright 2020 Jeremy Wall -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//! Implements typechecking for the parsed ucg AST. - -use std::collections::BTreeMap; - -use crate::ast::walk::Visitor; -use crate::ast::{Expression, Shape, Statement, Value}; - -use Expression::{ - Binary, Call, Cast, Copy, Debug, Fail, Format, Func, FuncOp, Grouped, Import, Include, Module, - Not, Range, Select, Simple, -}; -use Statement::Let; -use Value::{Boolean, Empty, Float, Int, List, Str, Symbol, Tuple}; - -pub struct Checker { - symbol_table: BTreeMap, -} - -impl Visitor for Checker { - fn visit_import(&mut self, _i: &mut super::ImportDef) { - // noop by default; - } - fn visit_include(&mut self, _i: &mut super::IncludeDef) { - // noop by default; - } - fn visit_fail(&mut self, _f: &mut super::FailDef) { - // noop by default; - } - fn visit_value(&mut self, _val: &mut Value) { - // noop by default - } - fn visit_expression(&mut self, _expr: &mut Expression) { - // noop by default - } - fn visit_statement(&mut self, _stmt: &mut Statement) { - // noop by default - } -} diff --git a/src/ast/typecheck/mod.rs b/src/ast/typecheck/mod.rs new file mode 100644 index 0000000..a2b8717 --- /dev/null +++ b/src/ast/typecheck/mod.rs @@ -0,0 +1,148 @@ +// Copyright 2020 Jeremy Wall +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Implements typechecking for the parsed ucg AST. + +use std::collections::BTreeMap; + +use crate::ast::walk::Visitor; +use crate::ast::{Expression, FailDef, ImportDef, IncludeDef, Position, Shape, Statement, Value}; +use crate::error::{BuildError, ErrorType}; + +use Expression::{ + Binary, Call, Cast, Copy, Debug, Fail, Format, Func, FuncOp, Grouped, Import, Include, Module, + Not, Range, Select, Simple, +}; +use Statement::Let; +use Value::{Boolean, Empty, Float, Int, List, Str, Symbol, Tuple}; + +pub struct Checker { + symbol_table: BTreeMap, + err_stack: Vec, + shape_stack: Vec, +} + +impl Checker { + pub fn new() -> Self { + return Self { + symbol_table: BTreeMap::new(), + err_stack: Vec::new(), + shape_stack: Vec::new(), + }; + } + + pub fn result(mut self) -> Result, BuildError> { + if let Some(err) = self.err_stack.pop() { + Err(err) + } else { + Ok(self.symbol_table) + } + } +} + +impl Visitor for Checker { + fn visit_import(&mut self, _i: &mut ImportDef) { + // noop by default; + } + + fn leave_import(&mut self) { + // noop by default + } + + fn visit_include(&mut self, _i: &mut IncludeDef) { + // noop by default; + } + + fn leave_include(&mut self) { + // noop by default + } + + fn visit_fail(&mut self, _f: &mut FailDef) { + // noop by default; + } + + fn leave_fail(&mut self) { + // noop by default + } + + fn visit_value(&mut self, val: &mut Value) { + // noop by default + // TODO(jwall): Some values can contain expressions. Handle those here. + match val { + Value::Empty(p) => self.shape_stack.push(Shape::Empty(p.clone())), + Value::Boolean(p) => self.shape_stack.push(Shape::Boolean(p.clone())), + Value::Int(p) => self.shape_stack.push(Shape::Int(p.clone())), + Value::Float(p) => self.shape_stack.push(Shape::Float(p.clone())), + Value::Str(p) => self.shape_stack.push(Shape::Str(p.clone())), + // Symbols in a shape are placeholders. They allow a form of genericity + // in the shape. They can be any type and are only refined down. + // by their presence in an expression. + Value::Symbol(p) => self.shape_stack.push(Shape::Symbol(p.clone())), + Value::List(_) => { + // noop + } + Value::Tuple(_) => { + // noop + } + } + } + + fn leave_value(&mut self, _val: &Value) { + // noop by default + } + + fn visit_expression(&mut self, _expr: &mut Expression) { + // noop by default + } + + fn leave_expression(&mut self, expr: &Expression) { + match expr { + Expression::Binary(_) => { + // Collapse the two shapes in the stack into one shape for this expression. + if let Some(right) = self.shape_stack.pop() { + if let Some(left) = self.shape_stack.pop() { + if let Some(shape) = left.merge(&right) { + // Then give them a new position + self.shape_stack.push(shape.with_pos(expr.pos().clone())); + } else { + self.err_stack.push(BuildError::with_pos( + format!( + "Expected {} but got {}", + left.type_name(), + right.type_name() + ), + ErrorType::TypeFail, + right.pos().clone(), + )); + } + } + } + } + _ => { + // TODO + } + } + } + + fn visit_statement(&mut self, _stmt: &mut Statement) { + // noop by default + } + + fn leave_statement(&mut self, stmt: &Statement) { + // noop by default + } +} + +#[cfg(test)] +mod test; diff --git a/src/ast/typecheck/test.rs b/src/ast/typecheck/test.rs new file mode 100644 index 0000000..2368b13 --- /dev/null +++ b/src/ast/typecheck/test.rs @@ -0,0 +1,47 @@ +use std::convert::Into; + +use crate::ast::walk::Walker; +use crate::ast::Position; +use crate::parse; + +use super::Checker; + +#[test] +fn simple_binary_typecheck() { + let mut checker = Checker::new(); + let expr_str = "1 + 1;"; + let mut expr = parse(expr_str.into(), None).unwrap(); + checker.walk_statement_list(expr.iter_mut().collect()); + let result = checker.result(); + assert!(result.is_ok(), "We expect this to typecheck successfully."); + assert!( + result.unwrap().is_empty(), + "We don't expect a symbol table entry." + ); +} + +#[test] +fn simple_binary_typefail() { + let mut checker = Checker::new(); + let expr_str = "1 + true;"; + let mut expr = parse(expr_str.into(), None).unwrap(); + checker.walk_statement_list(expr.iter_mut().collect()); + let result = checker.result(); + assert!(result.is_err(), "We expect this to fail a typecheck."); + let err = result.unwrap_err(); + assert_eq!(err.msg, "Expected int but got boolean"); + assert_eq!(err.pos.unwrap(), Position::new(1, 5, 4)); +} + +#[test] +fn multiple_binary_typefail() { + let mut checker = Checker::new(); + let expr_str = "1 + 1 + true;"; + let mut expr = parse(expr_str.into(), None).unwrap(); + checker.walk_statement_list(expr.iter_mut().collect()); + let result = checker.result(); + assert!(result.is_err(), "We expect this to fail a typecheck."); + let err = result.unwrap_err(); + assert_eq!(err.msg, "Expected int but got boolean"); + assert_eq!(err.pos.unwrap(), Position::new(1, 9, 8)); +} diff --git a/src/ast/walk.rs b/src/ast/walk.rs index f76c0fb..4175741 100644 --- a/src/ast/walk.rs +++ b/src/ast/walk.rs @@ -1,30 +1,53 @@ use crate::ast::*; pub trait Visitor { - // TODO(jwall): Should this have exit versions as well? fn visit_import(&mut self, _i: &mut ImportDef) { // noop by default; } + fn leave_import(&mut self) { + // noop by default + } + fn visit_include(&mut self, _i: &mut IncludeDef) { // noop by default; } + fn leave_include(&mut self) { + // noop by default + } + fn visit_fail(&mut self, _f: &mut FailDef) { // noop by default; } + fn leave_fail(&mut self) { + // noop by default + } + fn visit_value(&mut self, _val: &mut Value) { // noop by default } + fn leave_value(&mut self, _val: &Value) { + // noop by default + } + fn visit_expression(&mut self, _expr: &mut Expression) { // noop by default } + fn leave_expression(&mut self, _expr: &Expression) { + // noop by default + } + fn visit_statement(&mut self, _stmt: &mut Statement) { // noop by default } + + fn leave_statement(&mut self, _stmt: &Statement) { + // noop by default + } } pub trait Walker: Visitor { @@ -53,6 +76,7 @@ pub trait Walker: Visitor { self.walk_expression(expr); } } + self.leave_statement(stmt); } fn walk_fieldset(&mut self, fs: &mut FieldList) { @@ -136,12 +160,15 @@ pub trait Walker: Visitor { Expression::Import(i) => { self.visit_import(i); + self.leave_import(); } Expression::Include(i) => { self.visit_include(i); + self.leave_include(); } Expression::Fail(f) => { self.visit_fail(f); + self.leave_fail(); } Expression::Not(ref mut def) => { self.walk_expression(def.expr.as_mut()); @@ -150,6 +177,7 @@ pub trait Walker: Visitor { self.walk_expression(&mut def.expr); } } + self.leave_expression(expr); } fn walk_value(&mut self, val: &mut Value) { @@ -159,7 +187,10 @@ pub trait Walker: Visitor { | Value::Boolean(_) | Value::Int(_) | Value::Float(_) - | Value::Str(_) => self.visit_value(val), + | Value::Str(_) => { + self.visit_value(val); + self.leave_value(val); + } Value::Tuple(fs) => self.walk_fieldset(&mut fs.val), Value::List(vs) => { for e in &mut vs.elems { @@ -170,6 +201,8 @@ pub trait Walker: Visitor { } } +impl Walker for T where T: Visitor {} + pub struct ChainedWalk { pub visitor_1: Visitor1, pub visitor_2: Visitor2, @@ -197,31 +230,53 @@ where self.visitor_1.visit_import(i); self.visitor_2.visit_import(i); } + fn leave_import(&mut self) { + self.visitor_1.leave_import(); + self.visitor_2.leave_import(); + } + fn visit_include(&mut self, i: &mut IncludeDef) { self.visitor_1.visit_include(i); self.visitor_2.visit_include(i); } + fn leave_include(&mut self) { + self.visitor_1.leave_include(); + self.visitor_2.leave_include(); + } + fn visit_fail(&mut self, f: &mut FailDef) { self.visitor_1.visit_fail(f); self.visitor_2.visit_fail(f); } + fn leave_fail(&mut self) { + self.visitor_1.leave_fail(); + self.visitor_2.leave_fail(); + } + fn visit_value(&mut self, val: &mut Value) { self.visitor_1.visit_value(val); self.visitor_2.visit_value(val); } + fn leave_value(&mut self, val: &Value) { + self.visitor_1.leave_value(val); + self.visitor_2.leave_value(val); + } + fn visit_expression(&mut self, expr: &mut Expression) { self.visitor_1.visit_expression(expr); self.visitor_2.visit_expression(expr); } + fn leave_expression(&mut self, expr: &Expression) { + self.visitor_1.leave_expression(expr); + self.visitor_2.leave_expression(expr); + } + fn visit_statement(&mut self, stmt: &mut Statement) { self.visitor_1.visit_statement(stmt); self.visitor_2.visit_statement(stmt); } -} - -impl Walker for ChainedWalk -where - Visitor1: Visitor, - Visitor2: Visitor, -{ + fn leave_statement(&mut self, stmt: &Statement) { + self.visitor_1.leave_statement(stmt); + self.visitor_2.leave_statement(stmt); + } }