Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove parameter-constant arrays from import_ref #4360

Merged
merged 2 commits into from
Oct 3, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 50 additions & 55 deletions toolchain/check/import_ref.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -328,12 +328,6 @@ class ImportRefResolver {
llvm::SmallVector<SemIR::ImportIRInst> indirect_insts = {};
};

// Local information associated with an imported parameter.
struct ParamData {
SemIR::ConstantId type_const_id;
SemIR::ConstantId bind_const_id;
};

// Local information associated with an imported generic.
struct GenericData {
llvm::SmallVector<SemIR::InstId> bindings;
Expand Down Expand Up @@ -469,6 +463,13 @@ class ImportRefResolver {
return GetLocalConstantId(import_ir_.types().GetConstantId(type_id));
}

template <typename Id>
auto GetLocalConstantIdChecked(Id id) {
auto result = GetLocalConstantId(id);
CARBON_CHECK(result.is_valid());
return result;
}

// Gets the local constant values corresponding to an imported inst block.
auto GetLocalInstBlockContents(SemIR::InstBlockId import_block_id)
-> llvm::SmallVector<SemIR::InstId> {
Expand Down Expand Up @@ -643,21 +644,16 @@ class ImportRefResolver {
return specific_id;
}

// Returns the ConstantId for each parameter's type. Adds unresolved constants
// to work_stack_.
auto GetLocalParamConstantIds(SemIR::InstBlockId param_refs_id)
-> llvm::SmallVector<ParamData> {
llvm::SmallVector<ParamData> param_data;
// Adds unresolved constants for each parameter's type to work_stack_.
auto LoadLocalParamConstantIds(SemIR::InstBlockId param_refs_id) -> void {
if (!param_refs_id.is_valid() ||
param_refs_id == SemIR::InstBlockId::Empty) {
return param_data;
return;
}

const auto& param_refs = import_ir_.inst_blocks().Get(param_refs_id);
param_data.reserve(param_refs.size());
for (auto inst_id : param_refs) {
auto type_const_id =
GetLocalConstantId(import_ir_.insts().Get(inst_id).type_id());
GetLocalConstantId(import_ir_.insts().Get(inst_id).type_id());

// If the parameter is a symbolic binding, build the BindSymbolicName
// constant.
Expand All @@ -667,27 +663,33 @@ class ImportRefResolver {
bind_id = addr->inner_id;
bind_inst = import_ir_.insts().Get(bind_id);
}
auto bind_const_id = bind_inst.Is<SemIR::BindSymbolicName>()
? GetLocalConstantId(bind_id)
: SemIR::ConstantId::Invalid;
param_data.push_back(
{.type_const_id = type_const_id, .bind_const_id = bind_const_id});
if (bind_inst.Is<SemIR::BindSymbolicName>()) {
GetLocalConstantId(bind_id);
}
}
return param_data;
}

// Given a param_refs_id and const_ids from GetLocalParamConstantIds, returns
// a version of param_refs_id localized to the current IR.
auto GetLocalParamRefsId(SemIR::InstBlockId param_refs_id,
const llvm::SmallVector<ParamData>& params_data)
// Returns a version of param_refs_id localized to the current IR.
//
// Must only be called after a call to GetLocalParamConstantIds(param_refs_id)
// has completed without adding any new work to work_stack_.
//
// TODO: This is inconsistent with the rest of this class, which expects
// the relevant constants to be explicitly passed in. That makes it
// easier to statically detect when an input isn't loaded, but makes it
// harder to support importing more complex inst structures. We should
// take a holistic look at how to balance those concerns. For example,
// could the same function be used to load the constants and use them, with
// a parameter to select between the two?
auto GetLocalParamRefsId(SemIR::InstBlockId param_refs_id)
-> SemIR::InstBlockId {
if (!param_refs_id.is_valid() ||
param_refs_id == SemIR::InstBlockId::Empty) {
return param_refs_id;
}
const auto& param_refs = import_ir_.inst_blocks().Get(param_refs_id);
llvm::SmallVector<SemIR::InstId> new_param_refs;
for (auto [ref_id, param_data] : llvm::zip(param_refs, params_data)) {
for (auto ref_id : param_refs) {
// Figure out the param structure. This echoes
// Function::GetParamFromParamRefId.
// TODO: Consider a different parameter handling to simplify import logic.
Expand All @@ -712,8 +714,8 @@ class ImportRefResolver {

// Rebuild the param instruction.
auto name_id = GetLocalNameId(param_inst.name_id);
auto type_id =
context_.GetTypeIdForTypeConstant(param_data.type_const_id);
auto type_id = context_.GetTypeIdForTypeConstant(
GetLocalConstantIdChecked(param_inst.type_id));

auto new_param_id = context_.AddInstInNoBlock<SemIR::Param>(
AddImportIRInst(param_id),
Expand All @@ -740,12 +742,12 @@ class ImportRefResolver {
auto new_bind_inst =
context_.insts().GetAs<SemIR::BindSymbolicName>(
context_.constant_values().GetInstId(
param_data.bind_const_id));
GetLocalConstantIdChecked(bind_id)));
new_bind_inst.value_id = new_param_id;
new_param_id = context_.AddInstInNoBlock(AddImportIRInst(bind_id),
new_bind_inst);
context_.constant_values().Set(new_param_id,
param_data.bind_const_id);
GetLocalConstantIdChecked(bind_id));
break;
}
default: {
Expand Down Expand Up @@ -1322,9 +1324,8 @@ class ImportRefResolver {

// Load constants for the definition.
auto parent_scope_id = GetLocalNameScopeId(import_class.parent_scope_id);
auto implicit_param_const_ids =
GetLocalParamConstantIds(import_class.implicit_param_refs_id);
auto param_const_ids = GetLocalParamConstantIds(import_class.param_refs_id);
LoadLocalParamConstantIds(import_class.implicit_param_refs_id);
LoadLocalParamConstantIds(import_class.param_refs_id);
auto generic_data = GetLocalGenericData(import_class.generic_id);
auto self_const_id = GetLocalConstantId(import_class.self_type_id);
auto complete_type_witness_id =
Expand All @@ -1341,10 +1342,9 @@ class ImportRefResolver {

auto& new_class = context_.classes().Get(class_id);
new_class.parent_scope_id = parent_scope_id;
new_class.implicit_param_refs_id = GetLocalParamRefsId(
import_class.implicit_param_refs_id, implicit_param_const_ids);
new_class.param_refs_id =
GetLocalParamRefsId(import_class.param_refs_id, param_const_ids);
new_class.implicit_param_refs_id =
GetLocalParamRefsId(import_class.implicit_param_refs_id);
new_class.param_refs_id = GetLocalParamRefsId(import_class.param_refs_id);
SetGenericData(import_class.generic_id, new_class.generic_id, generic_data);
new_class.self_type_id = context_.GetTypeIdForTypeConstant(self_const_id);

Expand Down Expand Up @@ -1495,10 +1495,8 @@ class ImportRefResolver {
import_ir_.insts().Get(import_function.return_storage_id).type_id());
}
auto parent_scope_id = GetLocalNameScopeId(import_function.parent_scope_id);
auto implicit_param_const_ids =
GetLocalParamConstantIds(import_function.implicit_param_refs_id);
auto param_const_ids =
GetLocalParamConstantIds(import_function.param_refs_id);
LoadLocalParamConstantIds(import_function.implicit_param_refs_id);
LoadLocalParamConstantIds(import_function.param_refs_id);
auto generic_data = GetLocalGenericData(import_function.generic_id);

if (HasNewWork()) {
Expand All @@ -1508,10 +1506,10 @@ class ImportRefResolver {
// Add the function declaration.
auto& new_function = context_.functions().Get(function_id);
new_function.parent_scope_id = parent_scope_id;
new_function.implicit_param_refs_id = GetLocalParamRefsId(
import_function.implicit_param_refs_id, implicit_param_const_ids);
new_function.implicit_param_refs_id =
GetLocalParamRefsId(import_function.implicit_param_refs_id);
new_function.param_refs_id =
GetLocalParamRefsId(import_function.param_refs_id, param_const_ids);
GetLocalParamRefsId(import_function.param_refs_id);
SetGenericData(import_function.generic_id, new_function.generic_id,
generic_data);

Expand Down Expand Up @@ -1654,8 +1652,7 @@ class ImportRefResolver {

// Load constants for the definition.
auto parent_scope_id = GetLocalNameScopeId(import_impl.parent_scope_id);
auto implicit_param_const_ids =
GetLocalParamConstantIds(import_impl.implicit_param_refs_id);
LoadLocalParamConstantIds(import_impl.implicit_param_refs_id);
auto generic_data = GetLocalGenericData(import_impl.generic_id);
auto self_const_id = GetLocalConstantId(import_impl.self_id);
auto constraint_const_id = GetLocalConstantId(import_impl.constraint_id);
Expand All @@ -1666,8 +1663,8 @@ class ImportRefResolver {

auto& new_impl = context_.impls().Get(impl_id);
new_impl.parent_scope_id = parent_scope_id;
new_impl.implicit_param_refs_id = GetLocalParamRefsId(
import_impl.implicit_param_refs_id, implicit_param_const_ids);
new_impl.implicit_param_refs_id =
GetLocalParamRefsId(import_impl.implicit_param_refs_id);
CARBON_CHECK(!import_impl.param_refs_id.is_valid() &&
!new_impl.param_refs_id.is_valid());
SetGenericData(import_impl.generic_id, new_impl.generic_id, generic_data);
Expand Down Expand Up @@ -1811,10 +1808,8 @@ class ImportRefResolver {

auto parent_scope_id =
GetLocalNameScopeId(import_interface.parent_scope_id);
auto implicit_param_const_ids =
GetLocalParamConstantIds(import_interface.implicit_param_refs_id);
auto param_const_ids =
GetLocalParamConstantIds(import_interface.param_refs_id);
LoadLocalParamConstantIds(import_interface.implicit_param_refs_id);
LoadLocalParamConstantIds(import_interface.param_refs_id);
auto generic_data = GetLocalGenericData(import_interface.generic_id);

std::optional<SemIR::InstId> self_param_id;
Expand All @@ -1828,10 +1823,10 @@ class ImportRefResolver {

auto& new_interface = context_.interfaces().Get(interface_id);
new_interface.parent_scope_id = parent_scope_id;
new_interface.implicit_param_refs_id = GetLocalParamRefsId(
import_interface.implicit_param_refs_id, implicit_param_const_ids);
new_interface.implicit_param_refs_id =
GetLocalParamRefsId(import_interface.implicit_param_refs_id);
new_interface.param_refs_id =
GetLocalParamRefsId(import_interface.param_refs_id, param_const_ids);
GetLocalParamRefsId(import_interface.param_refs_id);
SetGenericData(import_interface.generic_id, new_interface.generic_id,
generic_data);

Expand Down
Loading