Skip to content

Commit

Permalink
[vulkan] Fix SPIR-V IR references causing leaks (halide#7739)
Browse files Browse the repository at this point in the history
* Remove unnecessary parent refs and owning function/block refs.
Add explicit clear methods for contents structs and destructors.

* Move objects when changing ownership

---------

Co-authored-by: Derek Gerstmann <dgerstmann@adobe.com>
  • Loading branch information
2 people authored and ardier committed Mar 3, 2024
1 parent 1f12fdf commit ce2cef5
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 78 deletions.
112 changes: 56 additions & 56 deletions src/SpirvIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,8 @@ SpvInstruction SpvInstruction::make(SpvOp op_code) {
return instance;
}

void SpvInstruction::set_block(SpvBlock block) {
check_defined();
contents->block = std::move(block);
SpvInstruction::~SpvInstruction() {
clear();
}

void SpvInstruction::set_result_id(SpvId result_id) {
Expand Down Expand Up @@ -270,6 +269,10 @@ bool SpvInstruction::is_defined() const {
return contents.defined();
}

void SpvInstruction::clear() {
contents = nullptr;
}

bool SpvInstruction::is_immediate(uint32_t index) const {
check_defined();
return (contents->value_types[index] != SpvOperandId);
Expand All @@ -280,11 +283,6 @@ uint32_t SpvInstruction::length() const {
return (uint32_t)contents->operands.size();
}

SpvBlock SpvInstruction::block() const {
check_defined();
return contents->block;
}

void SpvInstruction::add_data(uint32_t bytes, const void *data, SpvValueType value_type) {
check_defined();

Expand Down Expand Up @@ -346,34 +344,25 @@ void SpvInstruction::encode(SpvBinary &binary) const {

// --

SpvBlock SpvBlock::make(SpvFunction func, SpvId block_id) {
SpvBlock SpvBlock::make(SpvId block_id) {
SpvBlock instance;
instance.contents = SpvBlockContentsPtr(new SpvBlockContents());
instance.contents->parent = std::move(func);
instance.contents->block_id = block_id;
return instance;
}

void SpvBlock::add_instruction(SpvInstruction inst) {
check_defined();
inst.set_block(*this);
contents->instructions.push_back(inst);
}

void SpvBlock::add_variable(SpvInstruction var) {
check_defined();
var.set_block(*this);
contents->variables.push_back(var);
SpvBlock::~SpvBlock() {
clear();
}

void SpvBlock::set_function(SpvFunction func) {
void SpvBlock::add_instruction(SpvInstruction inst) {
check_defined();
contents->parent = std::move(func);
contents->instructions.emplace_back(std::move(inst));
}

SpvFunction SpvBlock::function() const {
void SpvBlock::add_variable(SpvInstruction var) {
check_defined();
return contents->parent;
contents->variables.emplace_back(std::move(var));
}

const SpvBlock::Instructions &SpvBlock::instructions() const {
Expand Down Expand Up @@ -423,6 +412,10 @@ void SpvBlock::check_defined() const {
user_assert(is_defined()) << "An SpvBlock must be defined before accessing its properties\n";
}

void SpvBlock::clear() {
contents = nullptr;
}

void SpvBlock::encode(SpvBinary &binary) const {
check_defined();

Expand Down Expand Up @@ -453,10 +446,18 @@ SpvFunction SpvFunction::make(SpvId func_type_id, SpvId func_id, SpvId return_ty
return instance;
}

SpvFunction::~SpvFunction() {
clear();
}

bool SpvFunction::is_defined() const {
return contents.defined();
}

void SpvFunction::clear() {
contents = nullptr;
}

SpvBlock SpvFunction::create_block(SpvId block_id) {
check_defined();
if (!contents->blocks.empty()) {
Expand All @@ -465,25 +466,25 @@ SpvBlock SpvFunction::create_block(SpvId block_id) {
last_block.add_instruction(SpvFactory::branch(block_id));
}
}
SpvBlock block = SpvBlock::make(*this, block_id);
SpvBlock block = SpvBlock::make(block_id);
contents->blocks.push_back(block);
return block;
}

void SpvFunction::add_block(const SpvBlock &block) {
void SpvFunction::add_block(SpvBlock block) {
check_defined();
if (!contents->blocks.empty()) {
SpvBlock last_block = tail_block();
if (!last_block.is_terminated()) {
last_block.add_instruction(SpvFactory::branch(block.id()));
}
}
contents->blocks.push_back(block);
contents->blocks.emplace_back(std::move(block));
}

void SpvFunction::add_parameter(const SpvInstruction &param) {
void SpvFunction::add_parameter(SpvInstruction param) {
check_defined();
contents->parameters.push_back(param);
contents->parameters.emplace_back(std::move(param));
}

uint32_t SpvFunction::parameter_count() const {
Expand Down Expand Up @@ -535,11 +536,6 @@ SpvPrecision SpvFunction::parameter_precision(uint32_t index) const {
}
}

void SpvFunction::set_module(SpvModule module) {
check_defined();
contents->parent = std::move(module);
}

SpvInstruction SpvFunction::declaration() const {
check_defined();
return contents->declaration;
Expand All @@ -555,11 +551,6 @@ const SpvFunction::Parameters &SpvFunction::parameters() const {
return contents->parameters;
}

SpvModule SpvFunction::module() const {
check_defined();
return contents->parent;
}

SpvId SpvFunction::return_type_id() const {
check_defined();
return contents->return_type_id;
Expand Down Expand Up @@ -608,54 +599,63 @@ SpvModule SpvModule::make(SpvId module_id,
return instance;
}

SpvModule::~SpvModule() {
clear();
}

bool SpvModule::is_defined() const {
return contents.defined();
}

void SpvModule::clear() {
contents = nullptr;
}

void SpvModule::add_debug_string(SpvId result_id, const std::string &string) {
check_defined();
contents->debug_source.push_back(SpvFactory::debug_string(result_id, string));
SpvInstruction inst = SpvFactory::debug_string(result_id, string);
contents->debug_source.emplace_back(std::move(inst));
}

void SpvModule::add_debug_symbol(SpvId id, const std::string &symbol) {
check_defined();
contents->debug_symbols.push_back(SpvFactory::debug_symbol(id, symbol));
SpvInstruction inst = SpvFactory::debug_symbol(id, symbol);
contents->debug_symbols.emplace_back(std::move(inst));
}

void SpvModule::add_annotation(const SpvInstruction &val) {
void SpvModule::add_annotation(SpvInstruction val) {
check_defined();
contents->annotations.push_back(val);
contents->annotations.emplace_back(std::move(val));
}

void SpvModule::add_type(const SpvInstruction &val) {
void SpvModule::add_type(SpvInstruction val) {
check_defined();
contents->types.push_back(val);
contents->types.emplace_back(std::move(val));
}

void SpvModule::add_constant(const SpvInstruction &val) {
void SpvModule::add_constant(SpvInstruction val) {
check_defined();
contents->constants.push_back(val);
contents->constants.emplace_back(std::move(val));
}

void SpvModule::add_global(const SpvInstruction &val) {
void SpvModule::add_global(SpvInstruction val) {
check_defined();
contents->globals.push_back(val);
contents->globals.emplace_back(std::move(val));
}

void SpvModule::add_execution_mode(const SpvInstruction &val) {
void SpvModule::add_execution_mode(SpvInstruction val) {
check_defined();
contents->execution_modes.push_back(val);
contents->execution_modes.emplace_back(std::move(val));
}

void SpvModule::add_instruction(const SpvInstruction &val) {
void SpvModule::add_instruction(SpvInstruction val) {
check_defined();
contents->instructions.push_back(val);
contents->instructions.emplace_back(std::move(val));
}

void SpvModule::add_function(SpvFunction val) {
check_defined();
val.set_module(*this);
contents->functions.emplace_back(val);
contents->functions.emplace_back(std::move(val));
}

void SpvModule::add_entry_point(const std::string &name, SpvInstruction inst) {
Expand Down Expand Up @@ -1297,7 +1297,7 @@ SpvId SpvBuilder::add_function(const std::string &name, SpvId return_type_id, co
func.add_parameter(param_inst);
}
SpvId block_id = make_id(SpvBlockId);
SpvBlock entry_block = SpvBlock::make(func, block_id);
SpvBlock entry_block = SpvBlock::make(block_id);
func.add_block(entry_block);
module.add_function(func);
function_map[func_id] = func;
Expand Down
Loading

0 comments on commit ce2cef5

Please sign in to comment.