From 93fd656b602d6314b75c5360693478dedb4c6e91 Mon Sep 17 00:00:00 2001 From: peefy Date: Thu, 1 Aug 2024 18:24:12 +0800 Subject: [PATCH] feat: enhance runtime type cast and check for lambda arguments and return values Signed-off-by: peefy --- kclvm/compiler/src/codegen/llvm/node.rs | 54 ++++++++++++++++--- kclvm/evaluator/src/func.rs | 6 ++- kclvm/evaluator/src/node.rs | 24 ++++++--- kclvm/runtime/src/value/val_type.rs | 12 +++++ kclvm/sema/src/resolver/ty_erasure.rs | 29 +++++++--- .../type_annotation_schema_2/main.k | 11 ++++ .../type_annotation_schema_2/stdout.golden | 3 ++ .../type_annotation_schema_3/main.k | 11 ++++ .../type_annotation_schema_3/stdout.golden | 3 ++ 9 files changed, 132 insertions(+), 21 deletions(-) create mode 100644 test/grammar/schema/type_annotation/type_annotation_schema_2/main.k create mode 100644 test/grammar/schema/type_annotation/type_annotation_schema_2/stdout.golden create mode 100644 test/grammar/schema/type_annotation/type_annotation_schema_3/main.k create mode 100644 test/grammar/schema/type_annotation/type_annotation_schema_3/stdout.golden diff --git a/kclvm/compiler/src/codegen/llvm/node.rs b/kclvm/compiler/src/codegen/llvm/node.rs index 24f392f43..e3353a184 100644 --- a/kclvm/compiler/src/codegen/llvm/node.rs +++ b/kclvm/compiler/src/codegen/llvm/node.rs @@ -2179,9 +2179,21 @@ impl<'ctx> TypedResultWalker<'ctx> for LLVMCodeGenContext<'ctx> { } } self.walk_arguments(&lambda_expr.args, args, kwargs); - let val = self + let mut val = self .walk_stmts(&lambda_expr.body) .expect(kcl_error::COMPILE_ERROR_MSG); + if let Some(ty) = &lambda_expr.return_ty { + let type_annotation = self.native_global_string_value(&ty.node.to_string()); + val = self.build_call( + &ApiFunc::kclvm_convert_collection_value.name(), + &[ + self.current_runtime_ctx_ptr(), + val, + type_annotation, + self.bool_value(false), + ], + ); + } self.builder.build_return(Some(&val)); // Exist the function self.builder.position_at_end(func_before_block); @@ -2731,23 +2743,39 @@ impl<'ctx> LLVMCodeGenContext<'ctx> { kwargs: BasicValueEnum<'ctx>, ) { // Arguments names and defaults - let (arg_names, arg_defaults) = if let Some(args) = &arguments { + let (arg_names, arg_types, arg_defaults) = if let Some(args) = &arguments { let names = &args.node.args; + let types = &args.node.ty_list; let defaults = &args.node.defaults; ( names.iter().map(|identifier| &identifier.node).collect(), + types.iter().collect(), defaults.iter().collect(), ) } else { - (vec![], vec![]) + (vec![], vec![], vec![]) }; // Default parameter values - for (arg_name, value) in arg_names.iter().zip(arg_defaults.iter()) { - let arg_value = if let Some(value) = value { + for ((arg_name, arg_type), value) in + arg_names.iter().zip(&arg_types).zip(arg_defaults.iter()) + { + let mut arg_value = if let Some(value) = value { self.walk_expr(value).expect(kcl_error::COMPILE_ERROR_MSG) } else { self.none_value() }; + if let Some(ty) = arg_type { + let type_annotation = self.native_global_string_value(&ty.node.to_string()); + arg_value = self.build_call( + &ApiFunc::kclvm_convert_collection_value.name(), + &[ + self.current_runtime_ctx_ptr(), + arg_value, + type_annotation, + self.bool_value(false), + ], + ); + } // Arguments are immutable, so we place them in different scopes. self.store_argument_in_current_scope(&arg_name.get_name()); self.walk_identifier_with_ctx(arg_name, &ast::ExprContext::Store, Some(arg_value)) @@ -2756,7 +2784,7 @@ impl<'ctx> LLVMCodeGenContext<'ctx> { // for loop in 0..argument_len in LLVM begin let argument_len = self.build_call(&ApiFunc::kclvm_list_len.name(), &[args]); let end_block = self.append_block(""); - for (i, arg_name) in arg_names.iter().enumerate() { + for (i, (arg_name, arg_type)) in arg_names.iter().zip(arg_types).enumerate() { // Positional arguments let is_in_range = self.builder.build_int_compare( IntPredicate::ULT, @@ -2768,7 +2796,7 @@ impl<'ctx> LLVMCodeGenContext<'ctx> { self.builder .build_conditional_branch(is_in_range, next_block, end_block); self.builder.position_at_end(next_block); - let arg_value = self.build_call( + let mut arg_value = self.build_call( &ApiFunc::kclvm_list_get_option.name(), &[ self.current_runtime_ctx_ptr(), @@ -2776,6 +2804,18 @@ impl<'ctx> LLVMCodeGenContext<'ctx> { self.native_int_value(i as i32), ], ); + if let Some(ty) = arg_type { + let type_annotation = self.native_global_string_value(&ty.node.to_string()); + arg_value = self.build_call( + &ApiFunc::kclvm_convert_collection_value.name(), + &[ + self.current_runtime_ctx_ptr(), + arg_value, + type_annotation, + self.bool_value(false), + ], + ); + } self.store_variable(&arg_name.names[0].node, arg_value); } // for loop in 0..argument_len in LLVM end diff --git a/kclvm/evaluator/src/func.rs b/kclvm/evaluator/src/func.rs index 978130765..f6e9ef65e 100644 --- a/kclvm/evaluator/src/func.rs +++ b/kclvm/evaluator/src/func.rs @@ -8,6 +8,7 @@ use kclvm_runtime::ValueRef; use scopeguard::defer; use crate::proxy::Proxy; +use crate::ty::type_pack_and_check; use crate::Evaluator; use crate::{error as kcl_error, EvalContext}; @@ -125,8 +126,11 @@ pub fn func_body( } // Evaluate arguments and keyword arguments and store values to local variables. s.walk_arguments(&ctx.node.args, args, kwargs); - let result = s + let mut result = s .walk_stmts(&ctx.node.body) .expect(kcl_error::RUNTIME_ERROR_MSG); + if let Some(ty) = &ctx.node.return_ty { + result = type_pack_and_check(s, &result, vec![&ty.node.to_string()], false); + } result } diff --git a/kclvm/evaluator/src/node.rs b/kclvm/evaluator/src/node.rs index 70ea3b52b..1e02094ec 100644 --- a/kclvm/evaluator/src/node.rs +++ b/kclvm/evaluator/src/node.rs @@ -1449,23 +1449,31 @@ impl<'ctx> Evaluator<'ctx> { kwargs: &ValueRef, ) { // Arguments names and defaults - let (arg_names, arg_defaults) = if let Some(args) = &arguments { + let (arg_names, arg_types, arg_defaults) = if let Some(args) = &arguments { let names = &args.node.args; + let types = &args.node.ty_list; let defaults = &args.node.defaults; ( names.iter().map(|identifier| &identifier.node).collect(), + types.iter().collect(), defaults.iter().collect(), ) } else { - (vec![], vec![]) + (vec![], vec![], vec![]) }; // Default parameter values - for (arg_name, value) in arg_names.iter().zip(arg_defaults.iter()) { - let arg_value = if let Some(value) = value { + for ((arg_name, arg_type), value) in + arg_names.iter().zip(&arg_types).zip(arg_defaults.iter()) + { + let mut arg_value = if let Some(value) = value { self.walk_expr(value).expect(kcl_error::RUNTIME_ERROR_MSG) } else { self.none_value() }; + if let Some(ty) = arg_type { + arg_value = + type_pack_and_check(self, &arg_value, vec![&ty.node.to_string()], false); + } // Arguments are immutable, so we place them in different scopes. let name = arg_name.get_name(); self.store_argument_in_current_scope(&name); @@ -1477,14 +1485,18 @@ impl<'ctx> Evaluator<'ctx> { } // Positional arguments let argument_len = args.len(); - for (i, arg_name) in arg_names.iter().enumerate() { + for (i, (arg_name, arg_type)) in arg_names.iter().zip(arg_types).enumerate() { // Positional arguments let is_in_range = i < argument_len; if is_in_range { - let arg_value = match args.list_get_option(i as isize) { + let mut arg_value = match args.list_get_option(i as isize) { Some(v) => v, None => self.undefined_value(), }; + if let Some(ty) = arg_type { + arg_value = + type_pack_and_check(self, &arg_value, vec![&ty.node.to_string()], false); + } self.store_variable(&arg_name.names[0].node, arg_value); } else { break; diff --git a/kclvm/runtime/src/value/val_type.rs b/kclvm/runtime/src/value/val_type.rs index 94ebb14bf..27c3b31b4 100644 --- a/kclvm/runtime/src/value/val_type.rs +++ b/kclvm/runtime/src/value/val_type.rs @@ -413,6 +413,8 @@ pub fn check_type(value: &ValueRef, tpe: &str, strict: bool) -> bool { // if value type is a built-in type e.g. str, int, float, bool if match_builtin_type(value, tpe) { return true; + } else if match_function_type(value, tpe) { + return true; } if value.is_schema() { if strict { @@ -532,6 +534,16 @@ pub fn match_builtin_type(value: &ValueRef, tpe: &str) -> bool { value.type_str() == *tpe || (value.type_str() == BUILTIN_TYPE_INT && tpe == BUILTIN_TYPE_FLOAT) } +/// match_function_type returns the value wether match the given the function type string +#[inline] +pub fn match_function_type(value: &ValueRef, tpe: &str) -> bool { + value.type_str() == *tpe + || (value.type_str() == KCL_TYPE_FUNCTION + && tpe.contains("(") + && tpe.contains(")") + && tpe.contains("->")) +} + /// is_literal_type returns the type string whether is a literal type pub fn is_literal_type(tpe: &str) -> bool { if KCL_NAME_CONSTANTS.contains(&tpe) { diff --git a/kclvm/sema/src/resolver/ty_erasure.rs b/kclvm/sema/src/resolver/ty_erasure.rs index 0d0f269ef..f65d7d867 100644 --- a/kclvm/sema/src/resolver/ty_erasure.rs +++ b/kclvm/sema/src/resolver/ty_erasure.rs @@ -1,5 +1,5 @@ -use kclvm_ast::ast; use kclvm_ast::walker::MutSelfMutWalker; +use kclvm_ast::{ast, walk_if_mut, walk_list_mut}; #[derive(Default)] struct TypeErasureTransformer; @@ -14,14 +14,14 @@ impl<'ctx> MutSelfMutWalker<'ctx> for TypeErasureTransformer { schema_index_signature.node.value_ty.node = FUNCTION.to_string().into(); } } - for item in schema_stmt.body.iter_mut() { - if let kclvm_ast::ast::Stmt::SchemaAttr(attr) = &mut item.node { - self.walk_schema_attr(attr); - } - } + walk_if_mut!(self, walk_arguments, schema_stmt.args); + walk_list_mut!(self, walk_call_expr, schema_stmt.decorators); + walk_list_mut!(self, walk_check_expr, schema_stmt.checks); + walk_list_mut!(self, walk_stmt, schema_stmt.body); } - fn walk_schema_attr(&mut self, schema_attr: &'ctx mut ast::SchemaAttr) { + walk_list_mut!(self, walk_call_expr, schema_attr.decorators); + walk_if_mut!(self, walk_expr, schema_attr.value); if let kclvm_ast::ast::Type::Function(_) = schema_attr.ty.as_ref().node { schema_attr.ty.node = FUNCTION.to_string().into(); } @@ -34,6 +34,7 @@ impl<'ctx> MutSelfMutWalker<'ctx> for TypeErasureTransformer { } } } + self.walk_expr(&mut assign_stmt.value.node); } fn walk_type_alias_stmt(&mut self, type_alias_stmt: &'ctx mut ast::TypeAliasStmt) { if let kclvm_ast::ast::Type::Function(_) = type_alias_stmt.ty.as_ref().node { @@ -46,6 +47,20 @@ impl<'ctx> MutSelfMutWalker<'ctx> for TypeErasureTransformer { ty.node = FUNCTION.to_string().into(); } } + for default in arguments.defaults.iter_mut() { + if let Some(d) = default.as_deref_mut() { + self.walk_expr(&mut d.node) + } + } + } + fn walk_lambda_expr(&mut self, lambda_expr: &'ctx mut ast::LambdaExpr) { + walk_if_mut!(self, walk_arguments, lambda_expr.args); + walk_list_mut!(self, walk_stmt, lambda_expr.body); + if let Some(ty) = lambda_expr.return_ty.as_mut() { + if let kclvm_ast::ast::Type::Function(_) = ty.as_ref().node { + ty.node = FUNCTION.to_string().into(); + } + } } } diff --git a/test/grammar/schema/type_annotation/type_annotation_schema_2/main.k b/test/grammar/schema/type_annotation/type_annotation_schema_2/main.k new file mode 100644 index 000000000..8a577f18a --- /dev/null +++ b/test/grammar/schema/type_annotation/type_annotation_schema_2/main.k @@ -0,0 +1,11 @@ +schema ProviderFamily: + version: str + marketplace: bool = True + +providerFamily = lambda family: ProviderFamily -> ProviderFamily { + family +} + +v = providerFamily({ + version: "1.6.0" +}) diff --git a/test/grammar/schema/type_annotation/type_annotation_schema_2/stdout.golden b/test/grammar/schema/type_annotation/type_annotation_schema_2/stdout.golden new file mode 100644 index 000000000..67d7078bb --- /dev/null +++ b/test/grammar/schema/type_annotation/type_annotation_schema_2/stdout.golden @@ -0,0 +1,3 @@ +v: + version: '1.6.0' + marketplace: true diff --git a/test/grammar/schema/type_annotation/type_annotation_schema_3/main.k b/test/grammar/schema/type_annotation/type_annotation_schema_3/main.k new file mode 100644 index 000000000..544f1cb23 --- /dev/null +++ b/test/grammar/schema/type_annotation/type_annotation_schema_3/main.k @@ -0,0 +1,11 @@ +schema ProviderFamily: + version: str + marketplace: bool = True + +providerFamily = lambda -> ProviderFamily { + { + version: "1.6.0" + } +} + +v = providerFamily() diff --git a/test/grammar/schema/type_annotation/type_annotation_schema_3/stdout.golden b/test/grammar/schema/type_annotation/type_annotation_schema_3/stdout.golden new file mode 100644 index 000000000..67d7078bb --- /dev/null +++ b/test/grammar/schema/type_annotation/type_annotation_schema_3/stdout.golden @@ -0,0 +1,3 @@ +v: + version: '1.6.0' + marketplace: true