FP8 fixes
gkrivor committed Jun 27, 2024
1 parent 5a50c44 commit b75b243
Showing 3 changed files with 18 additions and 13 deletions.
src/frontends/onnx/frontend/src/core/tensor.hpp (12 changes: 6 additions & 6 deletions)
@@ -79,10 +79,8 @@ class Tensor {
         bfloat16 = TensorProto_DataType::TensorProto_DataType_BFLOAT16,
         complex64 = TensorProto_DataType::TensorProto_DataType_COMPLEX64,
         complex128 = TensorProto_DataType::TensorProto_DataType_COMPLEX128,
-        float8e4m3fn = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN,
-        float8e4m3fnuz = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ,
-        float8e5m2 = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2,
-        float8e5m2fnuz = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ,
+        float8e4m3fn = TensorProto_DataType::TensorProto_DataType_FLOAT8E4M3FN,
+        float8e5m2 = TensorProto_DataType::TensorProto_DataType_FLOAT8E5M2,
     };
 
     Tensor() = delete;
@@ -170,7 +168,8 @@ class Tensor {
         default:
             ONNX_UNSUPPORTED_DATA_TYPE(
                 m_tensor_proto->data_type(),
-                "BOOL, BFLOAT16, FLOAT, FLOAT16, DOUBLE, INT8, INT16, INT32, INT64, UINT8, UINT16, UINT32, UINT64");
+                "BOOL, BFLOAT16, FLOAT8E4M3FN, FLOAT8E5M2, FLOAT, FLOAT16, DOUBLE, INT8, INT16, INT32, INT64, "
+                "UINT8, UINT16, UINT32, UINT64");
         }
     }

@@ -216,7 +215,8 @@ class Tensor {
         default:
             ONNX_UNSUPPORTED_DATA_TYPE(
                 m_tensor_proto->data_type(),
-                "BOOL, BFLOAT16, FLOAT, FLOAT16, DOUBLE, INT8, INT16, INT32, INT64, UINT8, UINT16, UINT32, UINT64");
+                "BOOL, BFLOAT16, FLOAT8E4M3FN, FLOAT8E5M2, FLOAT, FLOAT16, DOUBLE, INT8, INT16, INT32, INT64, "
+                "UINT8, UINT16, UINT32, UINT64");
         }
     }

src/frontends/onnx/frontend/src/utils/common.cpp (13 changes: 6 additions & 7 deletions)
@@ -6,6 +6,7 @@
 
 #include <onnx/onnx_pb.h> // onnx types
 
+#include "core/tensor.hpp"
 #include "onnx_framework_node.hpp"
 #include "openvino/frontend/exception.hpp"
 #include "openvino/op/add.hpp"
@@ -60,16 +61,14 @@ const ov::element::Type& get_ov_element_type(int64_t onnx_type) {
         return ov::element::dynamic;
     case TensorProto_DataType::TensorProto_DataType_BFLOAT16:
         return ov::element::bf16;
-    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN:
+    case TensorProto_DataType::TensorProto_DataType_FLOAT8E4M3FN:
         return ov::element::f8e4m3;
-    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ:
-        return ov::element::f8e4m3;
-    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2:
-        return ov::element::f8e5m2;
-    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ:
+    case TensorProto_DataType::TensorProto_DataType_FLOAT8E5M2:
         return ov::element::f8e5m2;
     }
-    OPENVINO_THROW("unsupported element type");
+    ONNX_UNSUPPORTED_DATA_TYPE(onnx_type,
+                               "BOOL, BFLOAT16, FLOAT8E4M3FN, FLOAT8E5M2, FLOAT, FLOAT16, DOUBLE, INT8, INT16, "
+                               "INT32, INT64, UINT8, UINT16, UINT32, UINT64, STRING, UNDEFINED");
 }
 
 std::shared_ptr<ov::Node> get_monotonic_range_along_node_rank(const ov::Output<ov::Node>& value,
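For orientation, a minimal standalone sketch of the FP8 mapping this hunk settles on: FLOAT8E4M3FN and FLOAT8E5M2 map to the OpenVINO f8e4m3 and f8e5m2 element types, while the FNUZ variants now fall through to the unsupported-type path. The enum and helper below are illustrative stand-ins, not the frontend's actual API.

#include <iostream>
#include <stdexcept>
#include <string>

// Stand-in for the ONNX FP8 data-type tags handled above (hypothetical enum).
enum class OnnxFp8Type { FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ };

// Mirrors the FP8 cases of the switch in get_ov_element_type, as a sketch.
std::string to_ov_element_name(OnnxFp8Type type) {
    switch (type) {
    case OnnxFp8Type::FLOAT8E4M3FN:
        return "f8e4m3";  // corresponds to ov::element::f8e4m3
    case OnnxFp8Type::FLOAT8E5M2:
        return "f8e5m2";  // corresponds to ov::element::f8e5m2
    default:
        // FNUZ variants are no longer mapped; they report an unsupported data type.
        throw std::runtime_error("unsupported FP8 data type");
    }
}

int main() {
    std::cout << to_ov_element_name(OnnxFp8Type::FLOAT8E4M3FN) << '\n';  // f8e4m3
    std::cout << to_ov_element_name(OnnxFp8Type::FLOAT8E5M2) << '\n';    // f8e5m2
    return 0;
}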
src/frontends/onnx/onnx_common/src/utils.cpp (6 changes: 6 additions & 0 deletions)
@@ -26,6 +26,10 @@ size_t get_onnx_data_size(int32_t onnx_type) {
         return 2;
     case TensorProto_DataType_FLOAT:
         return sizeof(float);
+    case TensorProto_DataType_FLOAT8E4M3FN:
+        return sizeof(int8_t);
+    case TensorProto_DataType_FLOAT8E5M2:
+        return sizeof(int8_t);
     case TensorProto_DataType_INT8:
         return sizeof(int8_t);
     case TensorProto_DataType_INT16:
@@ -49,6 +53,8 @@
 }
 const std::map<ov::element::Type_t, TensorProto_DataType> OV_2_ONNX_TYPES = {
     {ov::element::Type_t::bf16, TensorProto_DataType::TensorProto_DataType_BFLOAT16},
+    {ov::element::Type_t::f8e4m3, TensorProto_DataType::TensorProto_DataType_FLOAT8E4M3FN},
+    {ov::element::Type_t::f8e5m2, TensorProto_DataType::TensorProto_DataType_FLOAT8E5M2},
     {ov::element::Type_t::f16, TensorProto_DataType::TensorProto_DataType_FLOAT16},
     {ov::element::Type_t::f32, TensorProto_DataType::TensorProto_DataType_FLOAT},
     {ov::element::Type_t::f64, TensorProto_DataType::TensorProto_DataType_DOUBLE},
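Since both FP8 formats occupy a single byte (as the first hunk in this file reports), a raw FP8 tensor buffer is simply element-count bytes. A small sketch under that assumption; the helper names below are illustrative, not the onnx_common API.

#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Both FLOAT8E4M3FN and FLOAT8E5M2 are stored in one byte, matching the hunk above.
size_t fp8_element_size() {
    return sizeof(int8_t);
}

// Raw-buffer size for a tensor: product of the shape times the per-element size.
size_t raw_buffer_bytes(const std::vector<size_t>& shape, size_t element_size) {
    const size_t elements =
        std::accumulate(shape.begin(), shape.end(), size_t{1}, std::multiplies<size_t>());
    return elements * element_size;
}

int main() {
    // A hypothetical 2x3x4 FP8 initializer needs 24 bytes of raw data.
    std::cout << raw_buffer_bytes({2, 3, 4}, fp8_element_size()) << " bytes\n";
    return 0;
}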
