Skip to content

Commit

Permalink
[GNA] Use OV thread_local implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
tadamowicz committed Dec 1, 2023
1 parent 9ecebdd commit 8f2d9c1
Show file tree
Hide file tree
Showing 12 changed files with 64 additions and 7 deletions.
10 changes: 8 additions & 2 deletions src/plugins/intel_gna/src/backend/gna_limitations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,7 @@ constexpr uint32_t Limitations::kBytesPerCropElement;
constexpr uint32_t Limitations::kBytesPerConcatElement;
constexpr uint32_t Limitations::kMemoryPageSize;

// Per-thread singleton holder (declared in gna_limitations.hpp); starts empty
// until init() is called on the thread.
InferenceEngine::ThreadLocal<std::shared_ptr<Limitations>> Limitations::kInstance{nullptr};

Limitations::Limitations(const DeviceVersion& target) {
m_use_only_16bit_conv_weights =
Expand All @@ -689,7 +689,13 @@ Limitations::Limitations(const DeviceVersion& target) {
}

void Limitations::init(const DeviceVersion& compile_target) {
k_instance = std::shared_ptr<Limitations>(new Limitations(compile_target));
auto& localInstance = kInstance.local();
localInstance.reset(new Limitations(compile_target));
}

void Limitations::deinit() {
auto& localInstance = kInstance.local();
localInstance.reset();
}

size_t Limitations::get_min_batch_to_fit_in_buffer(InferenceEngine::DataPtr input) {
Expand Down
18 changes: 13 additions & 5 deletions src/plugins/intel_gna/src/backend/gna_limitations.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "ngraph/opsets/opset9.hpp"
#include "ops/gna_convolution.hpp"
#include "ops/gna_max_pool.hpp"
#include "threading/ie_thread_local.hpp"

namespace ov {
namespace intel_gna {
Expand Down Expand Up @@ -164,12 +165,17 @@ class AbstractValidator {
class Limitations {
public:
/**
* @brief Create an instance of the Limitations class. Since Limitations is designed as a singleton, multiple
* instances of the plugin with different compilation targets cannot coexist simultaneously for the same thread.
* @param compile_target GNA compile target
*/
static void init(const target::DeviceVersion& compile_target);

/**
* @brief Delete the instance of the Limitations class for the currently running thread.
*/
static void deinit();

/**
* @brief Returns the instance of Limitations object. Requires an Init call before the first usage
*/
Expand Down Expand Up @@ -309,14 +315,16 @@ class Limitations {
bool m_use_only_16bit_conv_weights = false;
size_t m_mem_alignment = 0;
std::shared_ptr<cnn2d::AbstractValidator> m_cnn_validator;
// Per-thread singleton holder; each thread accesses its own instance via
// kInstance.local(). Defined in gna_limitations.cpp.
static InferenceEngine::ThreadLocal<std::shared_ptr<Limitations>> kInstance;
};

// Return the calling thread's Limitations instance. init() must have been
// called on this thread first; otherwise an exception is thrown.
inline std::shared_ptr<Limitations> Limitations::get_instance() {
    auto& instance = kInstance.local();
    if (!instance) {
        THROW_GNA_EXCEPTION << "Limitations instance is not initialized.\n";
    }
    return instance;
}

inline bool Limitations::is_crop_affined_offset(size_t numberOfElements) const {
Expand Down
2 changes: 2 additions & 0 deletions src/plugins/intel_gna/src/gna_plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1429,4 +1429,6 @@ InferenceEngine::QueryNetworkResult GNAPlugin::QueryNetwork(
// Plugin teardown: close the GNA device (if one was opened) and release
// this thread's Limitations singleton so a later plugin instance can
// re-initialize it for a different compile target.
GNAPlugin::~GNAPlugin() {
    if (gnadevice)
        gnadevice->close();

    Limitations::deinit();
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ class I8QuantisationTest : public GNATest<> {
// Create this thread's Limitations singleton for the default GNA device
// before each test runs.
void SetUp() override {
    Limitations::init(target::DeviceVersion::Default);
}

// Release this thread's Limitations singleton so the next test starts
// from a clean, uninitialized state.
void TearDown() override {
    Limitations::deinit();
}
};

// TODO: add test for FC weights after quantization
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ class I16QuantisationTest : public GNATest<> {
// Initialize the per-thread Limitations singleton for the default GNA
// device before each test.
void SetUp() override {
    Limitations::init(target::DeviceVersion::Default);
}

// Tear down the per-thread Limitations singleton created in SetUp().
void TearDown() override {
    Limitations::deinit();
}
};

template <class T>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,10 @@ class GNAcnn2dValidatorTest : public ::testing::TestWithParam<GNACnn2DValidatorT
ASSERT_TRUE(validator);
}

// Release any per-thread Limitations instance so later tests (possibly
// targeting a different device version) can re-initialize it.
void TearDown() override {
    Limitations::deinit();
}

std::shared_ptr<cnn2d::AbstractValidator> validator;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ void RunVariadicSplitSupportedTest(DeviceVersion device_version, std::vector<Var
split_lengths));
ASSERT_TRUE(Limitations::is_split_supported(split, false) == result);
}
Limitations::deinit();
}

TEST(CheckSplitSupported, CheckVariadicSplitSupported_GNA3_5) {
Expand Down Expand Up @@ -108,6 +109,7 @@ void RunSplitSupportedTest(DeviceVersion device_version, std::vector<SplitParame
num_splits);
ASSERT_TRUE(Limitations::is_split_supported(split, false) == result);
}
Limitations::deinit();
}

TEST(CheckSplitSupported, CheckSplitSupported_GNA3_5) {
Expand Down
2 changes: 2 additions & 0 deletions src/plugins/intel_gna/tests/unit/gna_memory_alignment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,13 @@ class MemoryAlignmentTest : public ::testing::Test {};
// GNA 3.5 targets are expected to report 64-byte memory alignment.
TEST(MemoryAlignmentTest, getMemoryAlignmentBytes_Expect64ByteAlignmentWhenTargetIsGNA3_5) {
    Limitations::init(DeviceVersion::GNA3_5);
    const auto limitations = Limitations::get_instance();
    EXPECT_EQ(limitations->get_memory_alignment(), 64);
    Limitations::deinit();  // release the per-thread singleton for subsequent tests
}

// GNA 3.6 targets are expected to report 16-byte memory alignment.
TEST(MemoryAlignmentTest, getMemoryAlignmentBytes_Expect16ByteAlignmentWhenTargetIsGNA3_6) {
    Limitations::init(DeviceVersion::GNA3_6);
    const auto limitations = Limitations::get_instance();
    EXPECT_EQ(limitations->get_memory_alignment(), 16);
    Limitations::deinit();  // release the per-thread singleton for subsequent tests
}

} // namespace testing
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ class Decompose2DConvTestInvalidFixture : public ov::test::TestsCommon,
public ::testing::WithParamInterface<fqDecompose2DConvParams> {
public:
void SetUp() override;
void TearDown() override;

public:
std::shared_ptr<ngraph::Function> function, reference_function;
Expand Down Expand Up @@ -339,12 +340,17 @@ void Decompose2DConvTestInvalidFixture::SetUp() {
conv_params);
}

// Release any per-thread Limitations instance after each test so later
// tests can re-initialize it for their own target.
void Decompose2DConvTestInvalidFixture::TearDown() {
    Limitations::deinit();
}

// ---------------------------------------------------------------------------------------------------------------------

class Decompose2DConvTestFixture : public ov::test::TestsCommon,
public ::testing::WithParamInterface<fqDecompose2DConvParams> {
public:
void SetUp() override;
void TearDown() override;

std::shared_ptr<ngraph::Function> get_reference(const bool& fq,
const modelType& model,
Expand Down Expand Up @@ -385,6 +391,10 @@ void Decompose2DConvTestFixture::SetUp() {
reference_function = get_reference(fq, model, input_shape, graph_data, conv_params);
}

// Release any per-thread Limitations instance after each test so later
// tests can re-initialize it for their own target.
void Decompose2DConvTestFixture::TearDown() {
    Limitations::deinit();
}

std::shared_ptr<ngraph::Node> ReshapeBiasConst(std::shared_ptr<ngraph::opset7::Add> conv_bias,
const ConvParams& conv_params) {
auto add_const =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class InsertCopyLayerTest : public ov::test::TestsCommon, public ::testing::With
return result.str();
}
void SetUp() override;
void TearDown() override;
virtual void Validate();
virtual void Run();

Expand All @@ -64,6 +65,9 @@ void InsertCopyLayerTest::SetUp() {
std::tie(m_device_ver, m_axis, m_inputs_num) = this->GetParam();
Limitations::init(m_device_ver);
}
// Release the per-thread Limitations instance initialized in SetUp()
// (which used the parameterized device version).
void InsertCopyLayerTest::TearDown() {
    Limitations::deinit();
}

void InsertCopyLayerTest::Run() {
Validate();
Expand Down Expand Up @@ -212,6 +216,7 @@ class TransformationTestsBase : public ov::test::TestsCommon,

// Drop the cached test function and this thread's Limitations singleton
// after each test.
void TearDown() override {
    m_func.reset();
    Limitations::deinit();
}

void RunPasses(ngraph::pass::Manager& m) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ class SplitConvolutionFixture : public ov::test::TestsCommon,
public ::testing::WithParamInterface<std::tuple<DeviceVersion, TestParams>> {
public:
void SetUp() override;
void TearDown() override;

public:
std::shared_ptr<ngraph::Function> function, reference_function;
Expand All @@ -290,6 +291,10 @@ void SplitConvolutionFixture::SetUp() {
reference_function = reference_graph.createFunction();
}

// Release any per-thread Limitations instance so the next parameterized
// run (possibly a different DeviceVersion) starts clean.
void SplitConvolutionFixture::TearDown() {
    Limitations::deinit();
}

void execute_test(std::shared_ptr<ngraph::Function> function,
std::shared_ptr<ngraph::Function> reference_function,
ngraph::pass::Manager& pass_manager) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ class SplitEltwiseTestSuiteFixture : public ov::test::TestsCommon,
public ::testing::WithParamInterface<EltwiseSplitParams> {
public:
void SetUp() override;
void TearDown() override;

public:
std::shared_ptr<ngraph::Function> function, reference_function;
Expand All @@ -151,6 +152,10 @@ void SplitEltwiseTestSuiteFixture::SetUp() {
reference_function = createFunction(shape, with_const, with_fq, type, true);
}

// Release any per-thread Limitations instance after each test so later
// tests can re-initialize it.
void SplitEltwiseTestSuiteFixture::TearDown() {
    Limitations::deinit();
}

void execute_test(std::shared_ptr<ngraph::Function> function, std::shared_ptr<ngraph::Function> reference_function) {
ngraph::pass::Manager manager;
manager.register_pass<ov::pass::InitNodeInfo>();
Expand Down

0 comments on commit 8f2d9c1

Please sign in to comment.