Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add integer data type #7

Merged
merged 1 commit into from
Feb 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use ratatui::{
use tui_textarea::{Input, TextArea};

use crate::{
interpreter::{Interpreter, Stmt, Value},
interpreter::{Interpreter, Stmt},
parse::{Expr, Parser},
token::Tokenizer,
};
Expand Down Expand Up @@ -75,9 +75,9 @@ impl<'ta> App<'ta> {
Stmt::Assign(name, expr) => {
self.interpreter.define(
name,
Value::Num(self.interpreter.interpret_expr(&expr).unwrap_or_else(
|_| panic!("RC file: {} not found", &self.rc_file.display()),
)),
self.interpreter.interpret_expr(&expr).unwrap_or_else(|_| {
panic!("RC file: {} not found", &self.rc_file.display())
}),
);
}
_ => {}
Expand Down
184 changes: 149 additions & 35 deletions src/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::{
error::Error,
f64::consts::{E, PI},
fmt::Display,
ops::Neg,
ops::{Add, Div, Mul, Neg, Rem, Sub},
};

use crate::{
Expand All @@ -25,16 +25,18 @@ pub enum Stmt {
Undef(Vec<String>),
}

#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq)]
pub enum Value {
Fn(Function),
Num(f64),
Float(f64),
Int(i64),
}

impl Value {
pub fn to_input(&self, name: &str) -> String {
match self {
Self::Num(num) => format!("let {} = {}", name, num),
Self::Float(float) => format!("let {} = {}", name, float),
Self::Int(int) => format!("let {} = {}", name, int),
Self::Fn(func) => format!(
"fn {}({}) {}",
name,
Expand All @@ -43,9 +45,27 @@ impl Value {
),
}
}

fn abs(&mut self) -> Self {
match self {
Value::Int(int) => Value::Int(int.abs()),
Value::Float(float) => Value::Float(float.abs()),
_ => unreachable!(),
}
}

fn pow(&self, exponent: Value) -> Value {
match (self, exponent) {
(Value::Int(lhs), Value::Int(rhs)) => Value::Int(lhs.pow(rhs as u32)),
(Value::Float(lhs), Value::Float(rhs)) => Value::Float(lhs.powf(rhs)),
(Value::Int(lhs), Value::Float(rhs)) => Value::Float((*lhs as f64).powf(rhs)),
(Value::Float(lhs), Value::Int(rhs)) => Value::Float(lhs.powi(rhs as i32)),
_ => panic!("Cannot sub non numeric types"),
}
}
}

#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq)]
pub struct Function {
closure: HashMap<String, Value>,
parameters: Vec<String>,
Expand All @@ -71,7 +91,7 @@ impl Function {
}
}

fn call(&self, args: Vec<Value>) -> Result<f64, InterpretError> {
fn call(&self, args: Vec<Value>) -> Result<Value, InterpretError> {
let mut interpreter = Interpreter::with_env(self.closure.clone());
args.into_iter()
.enumerate()
Expand All @@ -87,16 +107,16 @@ impl Interpreter {
}
}

pub fn interpret(&mut self, stmt: Stmt) -> Result<Option<f64>, InterpretError> {
pub fn interpret(&mut self, stmt: Stmt) -> Result<Option<Value>, InterpretError> {
match stmt {
Stmt::Assign(name, expr) => {
let val = self.interpret_expr(&expr)?;
self.env.insert(name, Value::Num(val));
self.env.insert(name, val.clone()); // Some way to remove this clone?
Ok(Some(val))
}
Stmt::Expr(expr) => {
let ans = self.interpret_expr(&expr)?;
self.env.insert("ans".to_string(), Value::Num(ans));
self.env.insert("ans".to_string(), ans.clone()); // Some way to remove this clone?
Ok(Some(ans))
}
Stmt::Fn(name, params, body) => {
Expand All @@ -117,8 +137,8 @@ impl Interpreter {

fn default_env() -> HashMap<String, Value> {
HashMap::from_iter([
("pi".to_string(), Value::Num(PI)),
("e".to_string(), Value::Num(E)),
("pi".to_string(), Value::Float(PI)),
("e".to_string(), Value::Float(E)),
])
}

Expand All @@ -133,9 +153,10 @@ impl Interpreter {
);
}

pub fn interpret_expr(&self, expr: &Expr) -> Result<f64, InterpretError> {
pub fn interpret_expr(&self, expr: &Expr) -> Result<Value, InterpretError> {
match expr {
Expr::Num(num) => Ok(*num),
Expr::Float(float) => Ok(Value::Float(*float)),
Expr::Int(int) => Ok(Value::Int(*int)),
Expr::Binary(left, operator, right) => {
let left = self.interpret_expr(left)?;
let right = self.interpret_expr(right)?;
Expand All @@ -152,9 +173,11 @@ impl Interpreter {
Expr::Abs(expr) => Ok(self.interpret_expr(expr)?.abs()),
Expr::Grouping(expr) => self.interpret_expr(expr),
Expr::Negative(expr) => Ok(self.interpret_expr(expr)?.neg()),
Expr::Exponent(base, exponent) => Ok(self
.interpret_expr(base)?
.powf(self.interpret_expr(exponent)?)),
Expr::Exponent(base, exponent) => {
let base = self.interpret_expr(base)?;
let exponent = self.interpret_expr(exponent)?;
Ok(base.pow(exponent))
}
Expr::Call(name, args) => {
if let Some(Value::Fn(func)) = self.env.get(name) {
if args.len() != func.arity {
Expand All @@ -166,7 +189,7 @@ impl Interpreter {
} else {
let mut vals = vec![];
for arg in args.iter() {
vals.push(Value::Num(self.interpret_expr(arg)?));
vals.push(self.interpret_expr(arg)?);
}
func.call(vals)
}
Expand All @@ -178,14 +201,18 @@ impl Interpreter {
if let Some(val) = self.env.get(var) {
match val {
Value::Fn(_) => Err(InterpretError::UnInvokedFunction(var.clone())),
Value::Num(num) => Ok(*num),
Value::Float(_) | Value::Int(_) => Ok(val.to_owned()),
}
} else {
Err(InterpretError::UnknownVariable(var.clone()))
}
}
Expr::Func(func, arg) => {
let arg = self.interpret_expr(arg)?;
let arg = match self.interpret_expr(arg)? {
Value::Float(float) => float,
Value::Int(int) => int as f64,
_ => unreachable!(),
};
let val = match func {
Func::Sin => arg.sin(),
Func::Sinh => arg.sinh(),
Expand Down Expand Up @@ -215,12 +242,16 @@ impl Interpreter {
Func::Fract => arg.fract(),
Func::Recip => arg.recip(),
};
Ok(val)
Ok(Value::Float(val))
}
}
.map(|n| {
if (n.round() - n).abs() < 1e-10 {
n.round()
if let Value::Float(n) = n {
if (n.round() - n).abs() < 1e-10 {
Value::Float(n.round())
} else {
Value::Float(n)
}
} else {
n
}
Expand Down Expand Up @@ -263,12 +294,95 @@ impl Display for InterpretError {
impl Display for Value {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Num(num) => inner_write(num, f),
Self::Float(float) => inner_write(float, f),
Self::Int(int) => inner_write(int, f),
Self::Fn(func) => inner_write(func, f),
}
}
}

impl Add for Value {
type Output = Value;

fn add(self, rhs: Self) -> Self::Output {
match (self, rhs) {
(Value::Int(lhs), Value::Int(rhs)) => Value::Int(lhs + rhs),
(Value::Float(lhs), Value::Float(rhs)) => Value::Float(lhs + rhs),
(Value::Int(lhs), Value::Float(rhs)) => Value::Float(lhs as f64 + rhs),
(Value::Float(lhs), Value::Int(rhs)) => Value::Float(lhs + rhs as f64),
_ => panic!("Cannot add non numeric types"),
}
}
}

impl Sub for Value {
type Output = Value;

fn sub(self, rhs: Self) -> Self::Output {
match (self, rhs) {
(Value::Int(lhs), Value::Int(rhs)) => Value::Int(lhs - rhs),
(Value::Float(lhs), Value::Float(rhs)) => Value::Float(lhs - rhs),
(Value::Int(lhs), Value::Float(rhs)) => Value::Float(lhs as f64 - rhs),
(Value::Float(lhs), Value::Int(rhs)) => Value::Float(lhs - rhs as f64),
_ => panic!("Cannot sub non numeric types"),
}
}
}

impl Mul for Value {
type Output = Value;

fn mul(self, rhs: Self) -> Self::Output {
match (self, rhs) {
(Value::Int(lhs), Value::Int(rhs)) => Value::Int(lhs * rhs),
(Value::Float(lhs), Value::Float(rhs)) => Value::Float(lhs * rhs),
(Value::Int(lhs), Value::Float(rhs)) => Value::Float(lhs as f64 * rhs),
(Value::Float(lhs), Value::Int(rhs)) => Value::Float(lhs * rhs as f64),
_ => panic!("Cannot mul non numeric types"),
}
}
}

impl Div for Value {
type Output = Value;

fn div(self, rhs: Self) -> Self::Output {
match (self, rhs) {
(Value::Int(lhs), Value::Int(rhs)) => Value::Float(lhs as f64 / rhs as f64),
(Value::Float(lhs), Value::Float(rhs)) => Value::Float(lhs / rhs),
(Value::Int(lhs), Value::Float(rhs)) => Value::Float(lhs as f64 / rhs),
(Value::Float(lhs), Value::Int(rhs)) => Value::Float(lhs / rhs as f64),
_ => panic!("Cannot divide non numeric types"),
}
}
}

impl Neg for Value {
type Output = Value;

fn neg(self) -> Self::Output {
match &self {
Value::Int(int) => Value::Int(-*int),
Value::Float(float) => Value::Float(-*float),
_ => panic!("Cannot negate non numeric types"),
}
}
}

impl Rem for Value {
type Output = Value;

fn rem(self, rhs: Self) -> Self::Output {
match (self, rhs) {
(Value::Int(lhs), Value::Int(rhs)) => Value::Float(lhs as f64 % rhs as f64),
(Value::Float(lhs), Value::Float(rhs)) => Value::Float(lhs % rhs),
(Value::Int(lhs), Value::Float(rhs)) => Value::Float(lhs as f64 % rhs),
(Value::Float(lhs), Value::Int(rhs)) => Value::Float(lhs % rhs as f64),
_ => panic!("Cannot calculate rem of non numeric types"),
}
}
}

impl Display for Function {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "({}) {}", self.parameters.join(", "), self.body)
Expand All @@ -279,15 +393,15 @@ impl Display for Function {
mod tests {
use super::*;

fn check(expr: Expr, expected: Result<f64, InterpretError>) {
fn check(expr: Expr, expected: Result<Value, InterpretError>) {
let interpreter = Interpreter::new();
let res = interpreter.interpret_expr(&expr);
assert_eq!(res, expected);
}

fn check_with_vars(
expr: Expr,
expected: Result<f64, InterpretError>,
expected: Result<Value, InterpretError>,
env: HashMap<String, Value>,
) {
let interpreter = Interpreter::with_env(env);
Expand All @@ -299,32 +413,32 @@ mod tests {
fn simple_add() {
check(
Expr::Binary(
Box::new(Expr::Num(1.1)),
Box::new(Expr::Float(1.1)),
Token::Plus,
Box::new(Expr::Num(5.0)),
Box::new(Expr::Float(5.0)),
),
Ok(6.1),
Ok(Value::Float(6.1)),
);
}

#[test]
fn simple_variable() {
check_with_vars(
Expr::Binary(
Box::new(Expr::Num(12.0)),
Box::new(Expr::Float(12.0)),
Token::Mult,
Box::new(Expr::Var("foo".to_string())),
),
Ok(144.0),
HashMap::from_iter([("foo".to_string(), Value::Num(12.0))]),
Ok(Value::Float(144.0)),
HashMap::from_iter([("foo".to_string(), Value::Float(12.0))]),
);
}

#[test]
fn function_with_closure() {
check_with_vars(
Expr::Call("foo".to_string(), vec![Expr::Var("bar".to_string())]),
Ok(27.0),
Ok(Value::Float(27.0)),
HashMap::from_iter([
(
"foo".to_string(),
Expand All @@ -333,15 +447,15 @@ mod tests {
Expr::Binary(
Box::new(Expr::Exponent(
Box::new(Expr::Var("x".to_string())),
Box::new(Expr::Num(2.0)),
Box::new(Expr::Float(2.0)),
)),
Token::Plus,
Box::new(Expr::Var("bar".to_string())),
),
HashMap::from_iter([("bar".to_string(), Value::Num(2.0))]),
HashMap::from_iter([("bar".to_string(), Value::Float(2.0))]),
)),
),
("bar".to_string(), Value::Num(5.0)), // The function uses the closure value
("bar".to_string(), Value::Int(5)), // The function uses the closure value
]),
);
}
Expand Down
6 changes: 3 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
use interpreter::Interpreter;
use interpreter::{Interpreter, Value};
use parse::Parser;
use ratatui::{prelude::CrosstermBackend, Terminal};
use std::{error::Error, io};
Expand Down Expand Up @@ -44,15 +44,15 @@ pub fn tui() -> Result<(), Box<dyn Error>> {
Ok(())
}

pub fn eval(input: &str) -> Result<f64, Box<dyn Error>> {
pub fn eval(input: &str) -> Result<Value, Box<dyn Error>> {
let mut tokenizer = Tokenizer::new(input.chars().peekable()).peekable();
let current = tokenizer.next().ok_or("Expected expression")?;
let stmt = Parser::new(tokenizer, current).parse()?;
let res = Interpreter::new().interpret(stmt)?;
if let Some(ans) = res {
Ok(ans)
} else {
Ok(1_f64)
Ok(Value::Float(1_f64))
}
}

Expand Down
Loading