Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement TensorFlow Decoder, NodeContext, and additional API for integration with OVTF #37

Merged
merged 9 commits into from
Sep 27, 2021
41 changes: 41 additions & 0 deletions ngraph/frontend/tensorflow/include/tensorflow_frontend/decoder.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <ngraph/variant.hpp>

namespace ngraph {
namespace frontend {

class DecoderBase {
rkazants marked this conversation as resolved.
Show resolved Hide resolved
public:
/// \brief Get attribute value by name and requested type
///
/// \param name Attribute name
/// \param type_info Attribute type information
/// \return Shared pointer to appropriate value if it exists, 'nullptr' otherwise
virtual std::shared_ptr<Variant> get_attribute(const std::string& name, const VariantTypeInfo& type_info) const = 0;

/// \brief Get a number of inputs
virtual size_t get_input_size() const = 0;

/// \brief Get a producer name and its output port index
///
/// \param input_port_idx Input port index by which data is consumed
/// \param producer_name A producer name
/// \return producer_output_port_index Output port index from which data is generated
virtual void get_input_node(const size_t input_port_idx,
std::string& producer_name,
size_t& producer_output_port_index) const = 0;

/// \brief Get operation type
virtual std::string get_op_type() const = 0;

/// \brief Get node name
virtual std::string get_op_name() const = 0;
};

} // namespace frontend
} // namespace ngraph
Original file line number Diff line number Diff line change
@@ -1,31 +1,24 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <frontend_manager/frontend_exceptions.hpp>
#include <ngraph/node.hpp>

#include "node_context.hpp"

using namespace ngraph::frontend::tensorflow::detail;

namespace ngraph {
namespace frontend {
namespace tf {

class NodeContext;

class OpValidationFailureTF : public OpValidationFailure {
public:
    /// \brief Exception raised when validation of a TensorFlow node fails.
    ///
    /// \param check_loc_info Location info of the failed check (file/line/expression)
    /// \param node           Node context being validated when the check failed
    /// \param explanation    Human-readable description of the failure
    OpValidationFailureTF(const CheckLocInfo& check_loc_info, const NodeContext& node, const std::string& explanation)
        : OpValidationFailure(check_loc_info, get_error_msg_prefix_tf(node), explanation) {}

private:
    // Implemented out of line (in the .cpp) so this header only needs a
    // forward declaration of NodeContext, not its full definition.
    static std::string get_error_msg_prefix_tf(const NodeContext& node);
};
} // namespace tf
} // namespace frontend
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <tensorflow_frontend/decoder.hpp>

namespace ngraph {
namespace frontend {
/// Abstract representation for an input model graph that gives nodes in topologically sorted order
class GraphIterator {
public:
    /// \brief Get a number of operation nodes in the graph
    virtual size_t size() const = 0;

    /// \brief Set iterator to the start position
    virtual void reset() = 0;

    /// \brief Move to the next node in the graph
    virtual void next() = 0;

    /// \brief Returns true if iterator goes out of the range of available nodes
    virtual bool is_end() const = 0;

    /// \brief Return a pointer to a decoder of the current node
    virtual std::shared_ptr<::ngraph::frontend::DecoderBase> get_decoder() const = 0;

    /// \brief Virtual destructor: concrete iterators are destroyed through
    /// this interface, which is undefined behavior without a virtual destructor.
    virtual ~GraphIterator() = default;
};

} // namespace frontend
} // namespace ngraph

This file was deleted.

25 changes: 5 additions & 20 deletions ngraph/frontend/tensorflow/include/tensorflow_frontend/place.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,14 @@
#pragma once

#include <frontend_manager/frontend.hpp>

namespace tensorflow {
class OpDef;
class TensorProto;
} // namespace tensorflow
#include <tensorflow_frontend/decoder.hpp>

namespace ngraph {
namespace frontend {

class TensorPlaceTF;
class OpPlaceTF;

namespace tensorflow {
namespace detail {
class TFNodeDecoder;
} // namespace detail
} // namespace tensorflow

class PlaceTF : public Place {
public:
PlaceTF(const InputModel& input_model, const std::vector<std::string>& names)
Expand Down Expand Up @@ -98,12 +89,7 @@ class OutPortPlaceTF : public PlaceTF {

class OpPlaceTF : public PlaceTF {
public:
OpPlaceTF(const InputModel& input_model,
std::shared_ptr<ngraph::frontend::tensorflow::detail::TFNodeDecoder> op_def,
const std::vector<std::string>& names);

OpPlaceTF(const InputModel& input_model,
std::shared_ptr<ngraph::frontend::tensorflow::detail::TFNodeDecoder> op_def);
OpPlaceTF(const InputModel& input_model, std::shared_ptr<DecoderBase> op_decoder);

void add_in_port(const std::shared_ptr<InPortPlaceTF>& input, const std::string& name);
void add_out_port(const std::shared_ptr<OutPortPlaceTF>& output, int idx);
Expand All @@ -112,7 +98,7 @@ class OpPlaceTF : public PlaceTF {
const std::vector<std::shared_ptr<OutPortPlaceTF>>& get_output_ports() const;
const std::map<std::string, std::vector<std::shared_ptr<InPortPlaceTF>>>& get_input_ports() const;
std::shared_ptr<InPortPlaceTF> get_input_port_tf(const std::string& inputName, int inputPortIndex) const;
std::shared_ptr<ngraph::frontend::tensorflow::detail::TFNodeDecoder> get_desc() const;
std::shared_ptr<DecoderBase> get_decoder() const;

// External API methods
std::vector<Place::Ptr> get_consuming_ports() const override;
Expand Down Expand Up @@ -148,7 +134,7 @@ class OpPlaceTF : public PlaceTF {
// Ptr get_target_tensor(int outputPortIndex) const override;

private:
std::shared_ptr<ngraph::frontend::tensorflow::detail::TFNodeDecoder> m_op_def;
std::shared_ptr<DecoderBase> m_op_decoder;
std::map<std::string, std::vector<std::shared_ptr<InPortPlaceTF>>> m_input_ports;
std::vector<std::shared_ptr<OutPortPlaceTF>> m_output_ports;
};
Expand Down Expand Up @@ -185,7 +171,6 @@ class TensorPlaceTF : public PlaceTF {
bool is_equal_data(Ptr another) const override;

private:
// const ::tensorflow::TensorProto& m_tensor;
rkazants marked this conversation as resolved.
Show resolved Hide resolved
PartialShape m_pshape;
element::Type m_type;

Expand Down
122 changes: 122 additions & 0 deletions ngraph/frontend/tensorflow/src/decoder_proto.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "decoder_proto.hpp"

#include "node_context.hpp"

namespace ngraph {
namespace frontend {
// Translation table from TensorFlow proto tensor element types (DT_*) to
// nGraph element types.
// NOTE(review): left non-const because get_attribute() reads it via
// operator[], which default-inserts on an unknown key — confirm that an
// unmapped DT_* value is intended to yield a default (undefined) element type.
std::map<::tensorflow::DataType, ngraph::element::Type> TYPE_MAP{
{::tensorflow::DataType::DT_BOOL, ngraph::element::boolean},
{::tensorflow::DataType::DT_INT16, ngraph::element::i16},
{::tensorflow::DataType::DT_INT32, ngraph::element::i32},
{::tensorflow::DataType::DT_INT64, ngraph::element::i64},
{::tensorflow::DataType::DT_HALF, ngraph::element::f16},
{::tensorflow::DataType::DT_FLOAT, ngraph::element::f32},
{::tensorflow::DataType::DT_DOUBLE, ngraph::element::f64},
{::tensorflow::DataType::DT_UINT8, ngraph::element::u8},
{::tensorflow::DataType::DT_INT8, ngraph::element::i8},
{::tensorflow::DataType::DT_BFLOAT16, ngraph::element::bf16}};

/// \brief Decode the named NodeDef attribute into a Variant of the requested type.
///
/// \param name      Attribute name in the proto attr map
/// \param type_info Requested wrapper type
/// \return Wrapped value, or nullptr when the attribute is absent or the
///         requested type is not supported by this decoder.
std::shared_ptr<Variant> DecoderTFProto::get_attribute(const std::string& name,
                                                       const VariantTypeInfo& type_info) const {
    auto attrs = decode_attribute_helper(name);
    if (attrs.empty()) {
        return nullptr;
    }

    if (type_info == VariantWrapper<std::string>::type_info) {
        return std::make_shared<VariantWrapper<std::string>>(attrs[0].s());
    } else if (type_info == VariantWrapper<int64_t>::type_info) {
        return std::make_shared<VariantWrapper<int64_t>>(attrs[0].i());
    } else if (type_info == VariantWrapper<std::vector<int64_t>>::type_info) {
        std::vector<int64_t> longs;
        longs.reserve(attrs[0].list().i_size());
        // protobuf repeated-field sizes/indices are int, so iterate with int.
        for (int idx = 0; idx < attrs[0].list().i_size(); ++idx) {
            longs.push_back(attrs[0].list().i(idx));
        }
        return std::make_shared<VariantWrapper<std::vector<int64_t>>>(longs);
    } else if (type_info == VariantWrapper<int32_t>::type_info) {
        return std::make_shared<VariantWrapper<int32_t>>(static_cast<int32_t>(attrs[0].i()));
    } else if (type_info == VariantWrapper<std::vector<int32_t>>::type_info) {
        std::vector<int32_t> ints;
        ints.reserve(attrs[0].list().i_size());
        for (int idx = 0; idx < attrs[0].list().i_size(); ++idx) {
            ints.push_back(static_cast<int32_t>(attrs[0].list().i(idx)));
        }
        return std::make_shared<VariantWrapper<std::vector<int32_t>>>(ints);
    } else if (type_info == VariantWrapper<float>::type_info) {
        return std::make_shared<VariantWrapper<float>>(attrs[0].f());
    } else if (type_info == VariantWrapper<std::vector<float>>::type_info) {
        std::vector<float> floats;
        // BUG FIX: the original sized and bounded this loop with i_size()
        // (length of the int64 list) while reading f(idx) (the float list),
        // so float-list attributes were truncated or read out of range.
        floats.reserve(attrs[0].list().f_size());
        for (int idx = 0; idx < attrs[0].list().f_size(); ++idx) {
            floats.push_back(attrs[0].list().f(idx));
        }
        return std::make_shared<VariantWrapper<std::vector<float>>>(floats);
    } else if (type_info == VariantWrapper<ngraph::element::Type>::type_info) {
        auto data_type = attrs[0].type();
        // operator[] default-inserts for unmapped DT_* values (original behavior kept).
        return std::make_shared<VariantWrapper<ngraph::element::Type>>(TYPE_MAP[data_type]);
    } else if (type_info == VariantWrapper<bool>::type_info) {
        return std::make_shared<VariantWrapper<bool>>(attrs[0].b());
    } else if (type_info == VariantWrapper<::tensorflow::DataType>::type_info) {
        return std::make_shared<VariantWrapper<::tensorflow::DataType>>(attrs[0].type());
    } else if (type_info == VariantWrapper<::tensorflow::TensorProto>::type_info) {
        return std::make_shared<VariantWrapper<::tensorflow::TensorProto>>(attrs[0].tensor());
    } else if (type_info == VariantWrapper<::ngraph::PartialShape>::type_info) {
        std::vector<ngraph::Dimension> dims;
        auto tf_shape = attrs[0].shape();
        for (int i = 0; i < tf_shape.dim_size(); i++) {
            dims.push_back(tf_shape.dim(i).size());
        }
        auto pshape = ngraph::PartialShape(dims);
        return std::make_shared<VariantWrapper<::ngraph::PartialShape>>(pshape);
    }

    // Requested type is not supported by this decoder.
    return nullptr;
}

/// \brief Number of inputs of the wrapped NodeDef.
size_t DecoderTFProto::get_input_size() const {
    const auto input_count = m_node_def->input_size();
    return static_cast<size_t>(input_count);
}

/// \brief Split the NodeDef input string "producer[:port]" into its parts.
///
/// \param input_port_idx             Input port index by which data is consumed
/// \param producer_name              [out] Producer node name
/// \param producer_output_port_index [out] Producer output port index (0 when
///                                   the input string carries no ":port" suffix)
void DecoderTFProto::get_input_node(const size_t input_port_idx,
                                    std::string& producer_name,
                                    size_t& producer_output_port_index) const {
    // TODO: handle body graph nodes with a couple of columns
    const std::string producer_port_name = m_node_def->input(static_cast<int>(input_port_idx));
    const auto delim_pos = producer_port_name.find(':');
    if (delim_pos != std::string::npos) {
        producer_name = producer_port_name.substr(0, delim_pos);
        // BUG FIX: the original passed substr(delim_pos), whose first character
        // is ':' — std::stoi throws std::invalid_argument on that; the port
        // number starts one character past the delimiter.
        producer_output_port_index = std::stoi(producer_port_name.substr(delim_pos + 1));
        return;
    }
    producer_name = producer_port_name;
    producer_output_port_index = 0;
}

/// \brief Operation type (the NodeDef "op" field).
std::string DecoderTFProto::get_op_type() const {
    const auto& op_type = m_node_def->op();
    return op_type;
}

/// \brief Node name (the NodeDef "name" field).
std::string DecoderTFProto::get_op_name() const {
    const auto& node_name = m_node_def->name();
    return node_name;
}

/// \brief Look up the named attribute in the NodeDef attr map.
///
/// \param name Attribute name
/// \return Single-element vector holding the AttrValue
/// \throws frontend exception (via FRONT_END_GENERAL_CHECK) when absent
std::vector<::tensorflow::AttrValue> DecoderTFProto::decode_attribute_helper(const std::string& name) const {
    // BUG FIX: `auto attr_map = m_node_def->attr()` copied the whole protobuf
    // map on every call (attr() returns a const reference); bind by reference
    // and reuse it for the second lookup instead of calling attr() again.
    const auto& attr_map = m_node_def->attr();
    FRONT_END_GENERAL_CHECK(attr_map.contains(name),
                            "An error occurred while parsing the ",
                            name,
                            " attribute of ",
                            this->get_op_type(),
                            " node");  // BUG FIX: was "node" — produced e.g. "of Addnode"
    return {attr_map.at(name)};
}

} // namespace frontend
} // namespace ngraph
Loading