Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(runtime): add marshalling of value structs #93

Merged
merged 6 commits into from
Mar 7, 2020
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
feat(code_gen): create marshallable wrapper for unmarshallable functions
A function cannot be marshalled when one of its parameters or its return
type are a value struct
Wodann committed Mar 7, 2020
commit 9dc084e443b1723a234e9e1f104df2071f9ee7b3
9 changes: 4 additions & 5 deletions crates/mun_codegen/src/code_gen/symbols.rs
Original file line number Diff line number Diff line change
@@ -5,7 +5,7 @@ use crate::ir::{
};
use crate::type_info::{TypeGroup, TypeInfo};
use crate::values::{BasicValue, GlobalValue};
use crate::IrDatabase;
use crate::{CodeGenParams, IrDatabase};
use hir::Ty;
use inkwell::{
attributes::Attribute,
@@ -262,7 +262,6 @@ fn gen_function_info_array<'a, D: IrDatabase>(
functions: impl Iterator<Item = (&'a hir::Function, &'a FunctionValue)>,
) -> GlobalArrayValue {
let function_infos: Vec<StructValue> = functions
.filter(|(f, _)| f.visibility(db) == hir::Visibility::Public)
.map(|(f, value)| {
// Get the function from the cloned module and modify the linkage of the function.
let value = module
@@ -321,9 +320,9 @@ fn gen_struct_info<D: IrDatabase>(
(0..fields.len()).map(|idx| target_data.offset_of_element(&t, idx as u32).unwrap());
let (field_offsets, _) = gen_u16_array(module, field_offsets);

let field_sizes = fields
.iter()
.map(|field| target_data.get_store_size(&db.type_ir(field.ty(db))));
let field_sizes = fields.iter().map(|field| {
target_data.get_store_size(&db.type_ir(field.ty(db), CodeGenParams { is_extern: false }))
});
let (field_sizes, _) = gen_u16_array(module, field_sizes);

types.struct_info_type.const_named_struct(&[
6 changes: 3 additions & 3 deletions crates/mun_codegen/src/db.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#![allow(clippy::type_repetition_in_bounds)]

use crate::{ir::module::ModuleIR, type_info::TypeInfo, Context};
use crate::{ir::module::ModuleIR, type_info::TypeInfo, CodeGenParams, Context};
use inkwell::types::StructType;
use inkwell::{types::AnyTypeEnum, OptimizationLevel};
use mun_target::spec::Target;
@@ -22,9 +22,9 @@ pub trait IrDatabase: hir::HirDatabase {
#[salsa::input]
fn target(&self) -> Target;

/// Given a type, return the corresponding IR type.
/// Given a type and code generation parameters, return the corresponding IR type.
#[salsa::invoke(crate::ir::ty::ir_query)]
fn type_ir(&self, ty: hir::Ty) -> AnyTypeEnum;
fn type_ir(&self, ty: hir::Ty, params: CodeGenParams) -> AnyTypeEnum;

/// Given a struct, return the corresponding IR type.
#[salsa::invoke(crate::ir::ty::struct_ty_query)]
4 changes: 2 additions & 2 deletions crates/mun_codegen/src/ir/adt.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//use crate::ir::module::Types;
use crate::ir::try_convert_any_to_basic;
use crate::IrDatabase;
use crate::{CodeGenParams, IrDatabase};
use inkwell::types::{BasicTypeEnum, StructType};

pub(super) fn gen_struct_decl(db: &impl IrDatabase, s: hir::Struct) -> StructType {
@@ -11,7 +11,7 @@ pub(super) fn gen_struct_decl(db: &impl IrDatabase, s: hir::Struct) -> StructTyp
.iter()
.map(|field| {
let field_type = field.ty(db);
try_convert_any_to_basic(db.type_ir(field_type))
try_convert_any_to_basic(db.type_ir(field_type, CodeGenParams { is_extern: false }))
.expect("could not convert field type")
})
.collect();
172 changes: 119 additions & 53 deletions crates/mun_codegen/src/ir/body.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
use crate::intrinsics;
use crate::{ir::dispatch_table::DispatchTable, ir::try_convert_any_to_basic, IrDatabase};
use crate::{
ir::dispatch_table::DispatchTable, ir::try_convert_any_to_basic, CodeGenParams, IrDatabase,
};
use hir::{
ArenaId, ArithOp, BinaryOp, Body, CmpOp, Expr, ExprId, HirDisplay, InferenceResult, Literal,
Name, Ordering, Pat, PatId, Path, Resolution, Resolver, Statement, TypeCtor,
};
use inkwell::{
builder::Builder,
module::Module,
values::{BasicValueEnum, CallSiteValue, FloatValue, FunctionValue, IntValue},
values::{BasicValueEnum, CallSiteValue, FloatValue, FunctionValue, IntValue, StructValue},
AddressSpace, FloatPredicate, IntPredicate,
};
use std::{collections::HashMap, mem, sync::Arc};
@@ -37,6 +39,7 @@ pub(crate) struct BodyIrGenerator<'a, 'b, D: IrDatabase> {
dispatch_table: &'b DispatchTable,
active_loop: Option<LoopInfo>,
hir_function: hir::Function,
params: CodeGenParams,
}

impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
@@ -47,6 +50,7 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
ir_function: FunctionValue,
function_map: &'a HashMap<hir::Function, FunctionValue>,
dispatch_table: &'b DispatchTable,
params: CodeGenParams,
) -> Self {
// Get the type information from the `hir::Function`
let body = hir_function.body(db);
@@ -72,6 +76,7 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
dispatch_table,
active_loop: None,
hir_function,
params,
}
}

@@ -127,6 +132,50 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
}
}

pub fn gen_fn_wrapper(&mut self) {
let fn_sig = self.hir_function.ty(self.db).callable_sig(self.db).unwrap();
let args: Vec<BasicValueEnum> = fn_sig
.params()
.iter()
.enumerate()
.map(|(idx, ty)| {
let param = self.fn_value.get_nth_param(idx as u32).unwrap();
self.opt_deref_value(ty.clone(), param)
})
.collect();

let ret_value = self
.gen_call(self.hir_function, &args)
.try_as_basic_value()
.left();

let call_return_type = &self.infer[self.body.body_expr()];
if !call_return_type.is_never() {
let fn_ret_type = self
.hir_function
.ty(self.db)
.callable_sig(self.db)
.unwrap()
.ret()
.clone();

if fn_ret_type.is_empty() {
self.builder.build_return(None);
} else if let Some(value) = ret_value {
let ret_value = if let Some(hir_struct) = fn_ret_type.as_struct() {
if hir_struct.data(self.db).memory_kind == hir::StructMemoryKind::Value {
self.gen_struct_alloc_on_heap(hir_struct, value.into_struct_value())
} else {
value
}
} else {
value
};
self.builder.build_return(Some(&ret_value));
}
}
}

/// Generates IR for the specified expression. Dependending on the type of expression an IR
/// value is returned.
fn gen_expr(&mut self, expr: ExprId) -> Option<inkwell::values::BasicValueEnum> {
@@ -152,6 +201,12 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
// Get the callable definition from the map
match self.infer[*callee].as_callable_def() {
Some(hir::CallableDef::Function(def)) => {
// Get all the arguments
let args: Vec<BasicValueEnum> = args
.iter()
.map(|expr| self.gen_expr(*expr).expect("expected a value"))
.collect();

self.gen_call(def, &args).try_as_basic_value().left()
}
Some(hir::CallableDef::Struct(_)) => Some(self.gen_named_tuple_lit(expr, args)),
@@ -235,37 +290,45 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
hir::StructMemoryKind::Value => struct_lit.into(),
hir::StructMemoryKind::GC => {
// TODO: Root memory in GC
let struct_ir_ty = self.db.struct_ty(hir_struct);
let malloc_fn_ptr = self
.dispatch_table
.gen_intrinsic_lookup(&self.builder, &intrinsics::malloc);
let mem_ptr = self
.builder
.build_call(
malloc_fn_ptr,
&[
struct_ir_ty.size_of().unwrap().into(),
struct_ir_ty.get_alignment().into(),
],
"malloc",
)
.try_as_basic_value()
.left()
.unwrap();
let struct_ptr = self
.builder
.build_bitcast(
mem_ptr,
struct_ir_ty.ptr_type(AddressSpace::Generic),
&hir_struct.name(self.db).to_string(),
)
.into_pointer_value();
self.builder.build_store(struct_ptr, struct_lit);
struct_ptr.into()
self.gen_struct_alloc_on_heap(hir_struct, struct_lit)
}
}
}

fn gen_struct_alloc_on_heap(
&mut self,
hir_struct: hir::Struct,
struct_lit: StructValue,
) -> BasicValueEnum {
let struct_ir_ty = self.db.struct_ty(hir_struct);
let malloc_fn_ptr = self
.dispatch_table
.gen_intrinsic_lookup(&self.builder, &intrinsics::malloc);
let mem_ptr = self
.builder
.build_call(
malloc_fn_ptr,
&[
struct_ir_ty.size_of().unwrap().into(),
struct_ir_ty.get_alignment().into(),
],
"malloc",
)
.try_as_basic_value()
.left()
.unwrap();
let struct_ptr = self
.builder
.build_bitcast(
mem_ptr,
struct_ir_ty.ptr_type(AddressSpace::Generic),
&hir_struct.name(self.db).to_string(),
)
.into_pointer_value();
self.builder.build_store(struct_ptr, struct_lit);
struct_ptr.into()
}

/// Generates IR for a record literal, e.g. `Foo { a: 1.23, b: 4 }`
fn gen_record_lit(
&mut self,
@@ -349,8 +412,11 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
Pat::Bind { name } => {
let builder = self.new_alloca_builder();
let pat_ty = self.infer[pat].clone();
let ty = try_convert_any_to_basic(self.db.type_ir(pat_ty.clone()))
.expect("expected basic type");
let ty = try_convert_any_to_basic(
self.db
.type_ir(pat_ty.clone(), CodeGenParams { is_extern: false }),
)
.expect("expected basic type");
let ptr = builder.build_alloca(ty, &name.to_string());
self.pat_to_local.insert(pat, ptr);
self.pat_to_name.insert(pat, name.to_string());
@@ -394,16 +460,22 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
}

/// Given an expression and the type of the expression, optionally dereference the value.
fn opt_deref_value(&mut self, expr: ExprId, value: BasicValueEnum) -> BasicValueEnum {
match &self.infer[expr] {
fn opt_deref_value(&mut self, ty: hir::Ty, value: BasicValueEnum) -> BasicValueEnum {
match ty {
hir::Ty::Apply(hir::ApplicationTy {
ctor: hir::TypeCtor::Struct(s),
..
}) => match s.data(self.db).memory_kind {
hir::StructMemoryKind::GC => {
self.builder.build_load(value.into_pointer_value(), "deref")
}
hir::StructMemoryKind::Value => value,
hir::StructMemoryKind::Value => {
if self.params.is_extern {
self.builder.build_load(value.into_pointer_value(), "deref")
} else {
value
}
}
},
_ => value,
}
@@ -460,12 +532,12 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
) -> Option<BasicValueEnum> {
let lhs = self
.gen_expr(lhs_expr)
.map(|value| self.opt_deref_value(lhs_expr, value))
.map(|value| self.opt_deref_value(self.infer[lhs_expr].clone(), value))
.expect("no lhs value")
.into_float_value();
let rhs = self
.gen_expr(rhs_expr)
.map(|value| self.opt_deref_value(rhs_expr, value))
.map(|value| self.opt_deref_value(self.infer[rhs_expr].clone(), value))
.expect("no rhs value")
.into_float_value();
match op {
@@ -519,12 +591,12 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
) -> Option<BasicValueEnum> {
let lhs = self
.gen_expr(lhs_expr)
.map(|value| self.opt_deref_value(lhs_expr, value))
.map(|value| self.opt_deref_value(self.infer[lhs_expr].clone(), value))
.expect("no lhs value")
.into_int_value();
let rhs = self
.gen_expr(rhs_expr)
.map(|value| self.opt_deref_value(lhs_expr, value))
.map(|value| self.opt_deref_value(self.infer[lhs_expr].clone(), value))
.expect("no rhs value")
.into_int_value();
match op {
@@ -609,19 +681,13 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
}
}

// TODO: Implement me!
fn should_use_dispatch_table(&self) -> bool {
true
// FIXME: When we use the dispatch table, generated wrappers have infinite recursion
!self.params.is_extern
}

/// Generates IR for a function call.
fn gen_call(&mut self, function: hir::Function, args: &[ExprId]) -> CallSiteValue {
// Get all the arguments
let args: Vec<BasicValueEnum> = args
.iter()
.map(|expr| self.gen_expr(*expr).expect("expected a value"))
.collect();

fn gen_call(&mut self, function: hir::Function, args: &[BasicValueEnum]) -> CallSiteValue {
if self.should_use_dispatch_table() {
let ptr_value =
self.dispatch_table
@@ -649,7 +715,7 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
// Generate IR for the condition
let condition_ir = self
.gen_expr(condition)
.map(|value| self.opt_deref_value(condition, value))?
.map(|value| self.opt_deref_value(self.infer[condition].clone(), value))?
.into_int_value();

// Generate the code blocks to branch to
@@ -787,7 +853,7 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
self.builder.position_at_end(&cond_block);
let condition_ir = self
.gen_expr(condition_expr)
.map(|value| self.opt_deref_value(condition_expr, value));
.map(|value| self.opt_deref_value(self.infer[condition_expr].clone(), value));
if let Some(condition_ir) = condition_ir {
self.builder.build_conditional_branch(
condition_ir.into_int_value(),
@@ -844,11 +910,11 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
}

fn gen_field(&mut self, _expr: ExprId, receiver_expr: ExprId, name: &Name) -> PointerValue {
let receiver_ty = &self.infer[receiver_expr]
let hir_struct = self.infer[receiver_expr]
.as_struct()
.expect("expected a struct");

let field_idx = receiver_ty
let field_idx = hir_struct
.field(self.db, name)
.expect("expected a struct field")
.id()
@@ -857,13 +923,13 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {

let receiver_ptr = self.gen_place_expr(receiver_expr);
let receiver_ptr = self
.opt_deref_value(receiver_expr, receiver_ptr.into())
.opt_deref_value(self.infer[receiver_expr].clone(), receiver_ptr.into())
.into_pointer_value();
unsafe {
self.builder.build_struct_gep(
receiver_ptr,
field_idx,
&format!("{}.{}", receiver_ty.name(self.db), name),
&format!("{}.{}", hir_struct.name(self.db), name),
)
}
}
12 changes: 10 additions & 2 deletions crates/mun_codegen/src/ir/dispatch_table.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::intrinsics;
use crate::values::FunctionValue;
use crate::IrDatabase;
use crate::{CodeGenParams, IrDatabase};
use inkwell::module::Module;
use inkwell::types::{BasicTypeEnum, FunctionType};
use inkwell::values::{BasicValueEnum, PointerValue};
@@ -225,7 +225,10 @@ impl<'a, D: IrDatabase> DispatchTableBuilder<'a, D> {
let name = function.name(self.db).to_string();
let hir_type = function.ty(self.db);
let sig = hir_type.callable_sig(self.db).unwrap();
let ir_type = self.db.type_ir(hir_type).into_function_type();
let ir_type = self
.db
.type_ir(hir_type, CodeGenParams { is_extern: false })
.into_function_type();
let arg_types = sig
.params()
.iter()
@@ -282,6 +285,11 @@ impl<'a, D: IrDatabase> DispatchTableBuilder<'a, D> {
self.collect_expr(body.body_expr(), body, infer);
}

/// Collect the call expression from the body of a wrapper for the specified function.
pub fn collect_wrapper_body(&mut self, _function: hir::Function) {
self.collect_intrinsic(&intrinsics::malloc)
}

/// This creates the final DispatchTable with all *called* functions from within the module
/// # Parameters
/// * **functions**: Mapping of *defined* Mun functions to their respective IR values.
31 changes: 29 additions & 2 deletions crates/mun_codegen/src/ir/function.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::ir::body::BodyIrGenerator;
use crate::ir::dispatch_table::DispatchTable;
use crate::values::FunctionValue;
use crate::{IrDatabase, Module, OptimizationLevel};
use crate::{CodeGenParams, IrDatabase, Module, OptimizationLevel};
use inkwell::passes::{PassManager, PassManagerBuilder};
use inkwell::types::AnyTypeEnum;

@@ -30,9 +30,10 @@ pub(crate) fn gen_signature(
db: &impl IrDatabase,
f: hir::Function,
module: &Module,
params: CodeGenParams,
) -> FunctionValue {
let name = f.name(db).to_string();
if let AnyTypeEnum::FunctionType(ty) = db.type_ir(f.ty(db)) {
if let AnyTypeEnum::FunctionType(ty) = db.type_ir(f.ty(db), params) {
module.add_function(&name, ty, None)
} else {
panic!("not a function type")
@@ -55,9 +56,35 @@ pub(crate) fn gen_body<'a, 'b, D: IrDatabase>(
llvm_function,
llvm_functions,
dispatch_table,
CodeGenParams { is_extern: false },
);

code_gen.gen_fn_body();

llvm_function
}

/// Generates the body of a wrapper around `hir::Function` for its associated
/// `FunctionValue`
pub(crate) fn gen_wrapper_body<'a, 'b, D: IrDatabase>(
db: &'a D,
hir_function: hir::Function,
llvm_function: FunctionValue,
module: &'a Module,
llvm_functions: &'a HashMap<hir::Function, FunctionValue>,
dispatch_table: &'b DispatchTable,
) -> FunctionValue {
let mut code_gen = BodyIrGenerator::new(
db,
module,
hir_function,
llvm_function,
llvm_functions,
dispatch_table,
CodeGenParams { is_extern: true },
);

code_gen.gen_fn_wrapper();

llvm_function
}
48 changes: 45 additions & 3 deletions crates/mun_codegen/src/ir/module.rs
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@ use super::adt;
use crate::ir::dispatch_table::{DispatchTable, DispatchTableBuilder};
use crate::ir::function;
use crate::type_info::TypeInfo;
use crate::IrDatabase;
use crate::{CodeGenParams, IrDatabase};
use hir::{FileId, ModuleDef};
use inkwell::{module::Module, values::FunctionValue};
use std::collections::{HashMap, HashSet};
@@ -47,6 +47,7 @@ pub(crate) fn ir_query(db: &impl IrDatabase, file_id: FileId) -> Arc<ModuleIR> {

// Generate all the function signatures
let mut functions = HashMap::new();
let mut wrappers = HashMap::new();
let mut dispatch_table_builder = DispatchTableBuilder::new(db, &llvm_module);
for def in db.module_data(file_id).definitions() {
// TODO: Remove once we have more ModuleDef variants
@@ -65,13 +66,31 @@ pub(crate) fn ir_query(db: &impl IrDatabase, file_id: FileId) -> Arc<ModuleIR> {
}

// Construct the function signature
let fun = function::gen_signature(db, *f, &llvm_module);
let fun = function::gen_signature(
db,
*f,
&llvm_module,
CodeGenParams { is_extern: false },
);
functions.insert(*f, fun);

// Add calls to the dispatch table
let body = f.body(db);
let infer = f.infer(db);
dispatch_table_builder.collect_body(&body, &infer);

if f.data(db).visibility() != hir::Visibility::Private && !fn_sig.marshallable(db) {
let wrapper_fun = function::gen_signature(
db,
*f,
&llvm_module,
CodeGenParams { is_extern: true },
);
wrappers.insert(*f, wrapper_fun);

// Add calls from the function's wrapper to the dispatch table
dispatch_table_builder.collect_wrapper_body(*f);
}
}
_ => {}
}
@@ -94,6 +113,18 @@ pub(crate) fn ir_query(db: &impl IrDatabase, file_id: FileId) -> Arc<ModuleIR> {
fn_pass_manager.run_on(llvm_function);
}

for (hir_function, llvm_function) in wrappers.iter() {
function::gen_wrapper_body(
db,
*hir_function,
*llvm_function,
&llvm_module,
&functions,
&dispatch_table,
);
fn_pass_manager.run_on(llvm_function);
}

// Dispatch entries can include previously unchecked intrinsics
for entry in dispatch_table.entries().iter() {
// Collect argument types
@@ -106,10 +137,21 @@ pub(crate) fn ir_query(db: &impl IrDatabase, file_id: FileId) -> Arc<ModuleIR> {
}
}

// Filter private methods
let mut api: HashMap<hir::Function, FunctionValue> = functions
.into_iter()
.filter(|(f, _)| f.visibility(db) != hir::Visibility::Private)
.collect();

// Replace non-marshallable functions with their marshallable wrappers
for (hir_function, llvm_function) in wrappers {
api.insert(hir_function, llvm_function);
}

Arc::new(ModuleIR {
file_id,
llvm_module,
functions,
functions: api,
types,
dispatch_table,
})
26 changes: 17 additions & 9 deletions crates/mun_codegen/src/ir/ty.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
use super::try_convert_any_to_basic;
use crate::{
type_info::{TypeGroup, TypeInfo},
IrDatabase,
CodeGenParams, IrDatabase,
};
use hir::{ApplicationTy, CallableDef, Ty, TypeCtor};
use inkwell::types::{AnyTypeEnum, BasicType, BasicTypeEnum, StructType};
use inkwell::AddressSpace;

/// Given a mun type, construct an LLVM IR type
pub(crate) fn ir_query(db: &impl IrDatabase, ty: Ty) -> AnyTypeEnum {
pub(crate) fn ir_query(db: &impl IrDatabase, ty: Ty, params: CodeGenParams) -> AnyTypeEnum {
let context = db.context();
match ty {
Ty::Empty => AnyTypeEnum::StructType(context.struct_type(&[], false)),
@@ -18,17 +18,19 @@ pub(crate) fn ir_query(db: &impl IrDatabase, ty: Ty) -> AnyTypeEnum {
TypeCtor::Bool => AnyTypeEnum::IntType(context.bool_type()),
TypeCtor::FnDef(def @ CallableDef::Function(_)) => {
let ty = db.callable_sig(def);
let params: Vec<BasicTypeEnum> = ty
let param_tys: Vec<BasicTypeEnum> = ty
.params()
.iter()
.map(|p| try_convert_any_to_basic(db.type_ir(p.clone())).unwrap())
.map(|p| {
try_convert_any_to_basic(db.type_ir(p.clone(), params.clone())).unwrap()
})
.collect();

let fn_type = match ty.ret() {
Ty::Empty => context.void_type().fn_type(&params, false),
ty => try_convert_any_to_basic(db.type_ir(ty.clone()))
Ty::Empty => context.void_type().fn_type(&param_tys, false),
ty => try_convert_any_to_basic(db.type_ir(ty.clone(), params))
.expect("could not convert return value")
.fn_type(&params, false),
.fn_type(&param_tys, false),
};

AnyTypeEnum::FunctionType(fn_type)
@@ -37,7 +39,13 @@ pub(crate) fn ir_query(db: &impl IrDatabase, ty: Ty) -> AnyTypeEnum {
let struct_ty = db.struct_ty(s);
match s.data(db).memory_kind {
hir::StructMemoryKind::GC => struct_ty.ptr_type(AddressSpace::Generic).into(),
hir::StructMemoryKind::Value => struct_ty.into(),
hir::StructMemoryKind::Value => {
if params.is_extern {
struct_ty.ptr_type(AddressSpace::Generic).into()
} else {
struct_ty.into()
}
}
}
}
_ => unreachable!(),
@@ -51,7 +59,7 @@ pub fn struct_ty_query(db: &impl IrDatabase, s: hir::Struct) -> StructType {
let name = s.name(db).to_string();
for field in s.fields(db).iter() {
// Ensure that salsa's cached value incorporates the struct fields
let _field_type_ir = db.type_ir(field.ty(db));
let _field_type_ir = db.type_ir(field.ty(db), CodeGenParams { is_extern: false });
}

db.context().opaque_struct_type(&name)
7 changes: 7 additions & 0 deletions crates/mun_codegen/src/lib.rs
Original file line number Diff line number Diff line change
@@ -19,3 +19,10 @@ pub use crate::{
code_gen::write_module_shared_object,
db::{IrDatabase, IrDatabaseStorage},
};

#[derive(Clone, Debug, Default, Eq, Hash, PartialEq)]
pub struct CodeGenParams {
/// Whether generated code should support extern function calls.
/// This allows function parameters with `struct(value)` types to be marshalled.
is_extern: bool,
}
13 changes: 12 additions & 1 deletion crates/mun_hir/src/ty.rs
Original file line number Diff line number Diff line change
@@ -5,7 +5,7 @@ mod op;
use crate::display::{HirDisplay, HirFormatter};
use crate::ty::infer::TypeVarId;
use crate::ty::lower::fn_sig_for_struct_constructor;
use crate::{HirDatabase, Struct};
use crate::{HirDatabase, Struct, StructMemoryKind};
pub(crate) use infer::infer_query;
pub use infer::InferenceResult;
pub(crate) use lower::{callable_item_sig, fn_sig_for_fn, type_for_def, CallableDef, TypableDef};
@@ -172,6 +172,17 @@ impl FnSig {
pub fn ret(&self) -> &Ty {
&self.params_and_return[self.params_and_return.len() - 1]
}

pub fn marshallable(&self, db: &impl HirDatabase) -> bool {
for ty in self.params_and_return.iter() {
if let Some(s) = ty.as_struct() {
if s.data(db).memory_kind == StructMemoryKind::Value {
return false;
}
}
}
true
}
}

impl HirDisplay for Ty {