Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: enhance runtime type cast and check for lambda arguments and return values #1529

Merged
merged 1 commit into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 47 additions & 7 deletions kclvm/compiler/src/codegen/llvm/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2179,9 +2179,21 @@ impl<'ctx> TypedResultWalker<'ctx> for LLVMCodeGenContext<'ctx> {
}
}
self.walk_arguments(&lambda_expr.args, args, kwargs);
let val = self
let mut val = self
.walk_stmts(&lambda_expr.body)
.expect(kcl_error::COMPILE_ERROR_MSG);
if let Some(ty) = &lambda_expr.return_ty {
let type_annotation = self.native_global_string_value(&ty.node.to_string());
val = self.build_call(
&ApiFunc::kclvm_convert_collection_value.name(),
&[
self.current_runtime_ctx_ptr(),
val,
type_annotation,
self.bool_value(false),
],
);
}
self.builder.build_return(Some(&val));
// Exist the function
self.builder.position_at_end(func_before_block);
Expand Down Expand Up @@ -2731,23 +2743,39 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
kwargs: BasicValueEnum<'ctx>,
) {
// Arguments names and defaults
let (arg_names, arg_defaults) = if let Some(args) = &arguments {
let (arg_names, arg_types, arg_defaults) = if let Some(args) = &arguments {
let names = &args.node.args;
let types = &args.node.ty_list;
let defaults = &args.node.defaults;
(
names.iter().map(|identifier| &identifier.node).collect(),
types.iter().collect(),
defaults.iter().collect(),
)
} else {
(vec![], vec![])
(vec![], vec![], vec![])
};
// Default parameter values
for (arg_name, value) in arg_names.iter().zip(arg_defaults.iter()) {
let arg_value = if let Some(value) = value {
for ((arg_name, arg_type), value) in
arg_names.iter().zip(&arg_types).zip(arg_defaults.iter())
{
let mut arg_value = if let Some(value) = value {
self.walk_expr(value).expect(kcl_error::COMPILE_ERROR_MSG)
} else {
self.none_value()
};
if let Some(ty) = arg_type {
let type_annotation = self.native_global_string_value(&ty.node.to_string());
arg_value = self.build_call(
&ApiFunc::kclvm_convert_collection_value.name(),
&[
self.current_runtime_ctx_ptr(),
arg_value,
type_annotation,
self.bool_value(false),
],
);
}
// Arguments are immutable, so we place them in different scopes.
self.store_argument_in_current_scope(&arg_name.get_name());
self.walk_identifier_with_ctx(arg_name, &ast::ExprContext::Store, Some(arg_value))
Expand All @@ -2756,7 +2784,7 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
// for loop in 0..argument_len in LLVM begin
let argument_len = self.build_call(&ApiFunc::kclvm_list_len.name(), &[args]);
let end_block = self.append_block("");
for (i, arg_name) in arg_names.iter().enumerate() {
for (i, (arg_name, arg_type)) in arg_names.iter().zip(arg_types).enumerate() {
// Positional arguments
let is_in_range = self.builder.build_int_compare(
IntPredicate::ULT,
Expand All @@ -2768,14 +2796,26 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
self.builder
.build_conditional_branch(is_in_range, next_block, end_block);
self.builder.position_at_end(next_block);
let arg_value = self.build_call(
let mut arg_value = self.build_call(
&ApiFunc::kclvm_list_get_option.name(),
&[
self.current_runtime_ctx_ptr(),
args,
self.native_int_value(i as i32),
],
);
if let Some(ty) = arg_type {
let type_annotation = self.native_global_string_value(&ty.node.to_string());
arg_value = self.build_call(
&ApiFunc::kclvm_convert_collection_value.name(),
&[
self.current_runtime_ctx_ptr(),
arg_value,
type_annotation,
self.bool_value(false),
],
);
}
self.store_variable(&arg_name.names[0].node, arg_value);
}
// for loop in 0..argument_len in LLVM end
Expand Down
6 changes: 5 additions & 1 deletion kclvm/evaluator/src/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use kclvm_runtime::ValueRef;
use scopeguard::defer;

use crate::proxy::Proxy;
use crate::ty::type_pack_and_check;
use crate::Evaluator;
use crate::{error as kcl_error, EvalContext};

Expand Down Expand Up @@ -125,8 +126,11 @@ pub fn func_body(
}
// Evaluate arguments and keyword arguments and store values to local variables.
s.walk_arguments(&ctx.node.args, args, kwargs);
let result = s
let mut result = s
.walk_stmts(&ctx.node.body)
.expect(kcl_error::RUNTIME_ERROR_MSG);
if let Some(ty) = &ctx.node.return_ty {
result = type_pack_and_check(s, &result, vec![&ty.node.to_string()], false);
}
result
}
24 changes: 18 additions & 6 deletions kclvm/evaluator/src/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1449,23 +1449,31 @@ impl<'ctx> Evaluator<'ctx> {
kwargs: &ValueRef,
) {
// Arguments names and defaults
let (arg_names, arg_defaults) = if let Some(args) = &arguments {
let (arg_names, arg_types, arg_defaults) = if let Some(args) = &arguments {
let names = &args.node.args;
let types = &args.node.ty_list;
let defaults = &args.node.defaults;
(
names.iter().map(|identifier| &identifier.node).collect(),
types.iter().collect(),
defaults.iter().collect(),
)
} else {
(vec![], vec![])
(vec![], vec![], vec![])
};
// Default parameter values
for (arg_name, value) in arg_names.iter().zip(arg_defaults.iter()) {
let arg_value = if let Some(value) = value {
for ((arg_name, arg_type), value) in
arg_names.iter().zip(&arg_types).zip(arg_defaults.iter())
{
let mut arg_value = if let Some(value) = value {
self.walk_expr(value).expect(kcl_error::RUNTIME_ERROR_MSG)
} else {
self.none_value()
};
if let Some(ty) = arg_type {
arg_value =
type_pack_and_check(self, &arg_value, vec![&ty.node.to_string()], false);
}
// Arguments are immutable, so we place them in different scopes.
let name = arg_name.get_name();
self.store_argument_in_current_scope(&name);
Expand All @@ -1477,14 +1485,18 @@ impl<'ctx> Evaluator<'ctx> {
}
// Positional arguments
let argument_len = args.len();
for (i, arg_name) in arg_names.iter().enumerate() {
for (i, (arg_name, arg_type)) in arg_names.iter().zip(arg_types).enumerate() {
// Positional arguments
let is_in_range = i < argument_len;
if is_in_range {
let arg_value = match args.list_get_option(i as isize) {
let mut arg_value = match args.list_get_option(i as isize) {
Some(v) => v,
None => self.undefined_value(),
};
if let Some(ty) = arg_type {
arg_value =
type_pack_and_check(self, &arg_value, vec![&ty.node.to_string()], false);
}
self.store_variable(&arg_name.names[0].node, arg_value);
} else {
break;
Expand Down
12 changes: 12 additions & 0 deletions kclvm/runtime/src/value/val_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,8 @@ pub fn check_type(value: &ValueRef, tpe: &str, strict: bool) -> bool {
// if value type is a built-in type e.g. str, int, float, bool
if match_builtin_type(value, tpe) {
return true;
} else if match_function_type(value, tpe) {
return true;
}
if value.is_schema() {
if strict {
Expand Down Expand Up @@ -532,6 +534,16 @@ pub fn match_builtin_type(value: &ValueRef, tpe: &str) -> bool {
value.type_str() == *tpe || (value.type_str() == BUILTIN_TYPE_INT && tpe == BUILTIN_TYPE_FLOAT)
}

/// match_function_type returns the value wether match the given the function type string
#[inline]
pub fn match_function_type(value: &ValueRef, tpe: &str) -> bool {
value.type_str() == *tpe
|| (value.type_str() == KCL_TYPE_FUNCTION
&& tpe.contains("(")
&& tpe.contains(")")
&& tpe.contains("->"))
}

/// is_literal_type returns the type string whether is a literal type
pub fn is_literal_type(tpe: &str) -> bool {
if KCL_NAME_CONSTANTS.contains(&tpe) {
Expand Down
29 changes: 22 additions & 7 deletions kclvm/sema/src/resolver/ty_erasure.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use kclvm_ast::ast;
use kclvm_ast::walker::MutSelfMutWalker;
use kclvm_ast::{ast, walk_if_mut, walk_list_mut};

#[derive(Default)]
struct TypeErasureTransformer;
Expand All @@ -14,14 +14,14 @@ impl<'ctx> MutSelfMutWalker<'ctx> for TypeErasureTransformer {
schema_index_signature.node.value_ty.node = FUNCTION.to_string().into();
}
}
for item in schema_stmt.body.iter_mut() {
if let kclvm_ast::ast::Stmt::SchemaAttr(attr) = &mut item.node {
self.walk_schema_attr(attr);
}
}
walk_if_mut!(self, walk_arguments, schema_stmt.args);
walk_list_mut!(self, walk_call_expr, schema_stmt.decorators);
walk_list_mut!(self, walk_check_expr, schema_stmt.checks);
walk_list_mut!(self, walk_stmt, schema_stmt.body);
}

fn walk_schema_attr(&mut self, schema_attr: &'ctx mut ast::SchemaAttr) {
walk_list_mut!(self, walk_call_expr, schema_attr.decorators);
walk_if_mut!(self, walk_expr, schema_attr.value);
if let kclvm_ast::ast::Type::Function(_) = schema_attr.ty.as_ref().node {
schema_attr.ty.node = FUNCTION.to_string().into();
}
Expand All @@ -34,6 +34,7 @@ impl<'ctx> MutSelfMutWalker<'ctx> for TypeErasureTransformer {
}
}
}
self.walk_expr(&mut assign_stmt.value.node);
}
fn walk_type_alias_stmt(&mut self, type_alias_stmt: &'ctx mut ast::TypeAliasStmt) {
if let kclvm_ast::ast::Type::Function(_) = type_alias_stmt.ty.as_ref().node {
Expand All @@ -46,6 +47,20 @@ impl<'ctx> MutSelfMutWalker<'ctx> for TypeErasureTransformer {
ty.node = FUNCTION.to_string().into();
}
}
for default in arguments.defaults.iter_mut() {
if let Some(d) = default.as_deref_mut() {
self.walk_expr(&mut d.node)
}
}
}
fn walk_lambda_expr(&mut self, lambda_expr: &'ctx mut ast::LambdaExpr) {
walk_if_mut!(self, walk_arguments, lambda_expr.args);
walk_list_mut!(self, walk_stmt, lambda_expr.body);
if let Some(ty) = lambda_expr.return_ty.as_mut() {
if let kclvm_ast::ast::Type::Function(_) = ty.as_ref().node {
ty.node = FUNCTION.to_string().into();
}
}
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
schema ProviderFamily:
version: str
marketplace: bool = True

providerFamily = lambda family: ProviderFamily -> ProviderFamily {
family
}

v = providerFamily({
version: "1.6.0"
})
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
v:
version: '1.6.0'
marketplace: true
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
schema ProviderFamily:
version: str
marketplace: bool = True

providerFamily = lambda -> ProviderFamily {
{
version: "1.6.0"
}
}

v = providerFamily()
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
v:
version: '1.6.0'
marketplace: true
Loading