[STFT][Op][Ref] Extend STFT op and ref to support 1D signal input (openvinotoolkit#27274)

### Details:
- Extend STFT op and ref to support 1D signal input (including CPU, which
currently uses the ref impl)
- Supporting this case natively takes less code than building a subgraph with
the proper Squeeze/Unsqueeze logic (see the sketch below)
- Enabling the 1D signal case required only minor changes in shape_infer and
the ref impl (all of the other updated files are tests):
   * src/core/reference/src/op/stft.cpp
   * src/core/shape_inference/include/stft_shape_inference.hpp
- The xfail for the 1D STFT case in the PT FE tests has been removed; the test
passes now
  
  ---------------------------------
  PR for spec update will be provided separately (in progress).
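
For illustration, here is a minimal sketch of the Squeeze/Unsqueeze subgraph that native 1D support makes unnecessary. This is a hypothetical helper, not code from this PR; header paths and op versions are assumed from the public OpenVINO C++ API:

```cpp
// Sketch only: the workaround that native 1D support replaces — wrap the
// 1D signal into a [1, signal_size] batch, run STFT, then drop the batch dim.
#include <memory>

#include "openvino/op/constant.hpp"
#include "openvino/op/squeeze.hpp"
#include "openvino/op/stft.hpp"
#include "openvino/op/unsqueeze.hpp"

std::shared_ptr<ov::Node> stft_1d_via_subgraph(const ov::Output<ov::Node>& signal_1d,
                                               const ov::Output<ov::Node>& window,
                                               const ov::Output<ov::Node>& frame_size,
                                               const ov::Output<ov::Node>& frame_step) {
    const auto batch_axis = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {0});
    // [signal_size] -> [1, signal_size]
    const auto batched = std::make_shared<ov::op::v0::Unsqueeze>(signal_1d, batch_axis);
    const auto stft = std::make_shared<ov::op::v15::STFT>(batched, window, frame_size, frame_step,
                                                          /*transpose_frames=*/true);
    // [1, fft_samples, frames, 2] -> [fft_samples, frames, 2]
    return std::make_shared<ov::op::v0::Squeeze>(stft, batch_axis);
}
```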
 
### Tickets:
 - 155996
mitruska authored Oct 28, 2024
1 parent 33493b4 commit 664ffc6
Showing 7 changed files with 104 additions and 27 deletions.
5 changes: 3 additions & 2 deletions src/core/reference/src/op/stft.cpp
@@ -21,8 +21,9 @@ void stft(const float* signal,
const int64_t frame_size,
const int64_t frame_step,
const bool transpose_frames) {
-constexpr size_t signal_axis = 1;
-const auto batch_size = signal_shape[0];
+const auto is_signal_1D = signal_shape.size() == 1;
+const size_t batch_size = is_signal_1D ? 1 : signal_shape[0];
+const size_t signal_axis = is_signal_1D ? 0 : 1;
const auto signal_length = signal_shape[signal_axis];
const auto num_frames = static_cast<size_t>((signal_length - frame_size) / frame_step) + 1;
const auto frame_size_dim = static_cast<size_t>(frame_size);
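For clarity, a standalone sketch (not the actual reference code) of the 1D dispatch introduced above — a 1D signal behaves as a batch of one, with samples on axis 0 instead of axis 1, and the frame count follows num_frames = (signal_length - frame_size) / frame_step + 1:

```cpp
#include <cstdint>
#include <vector>

// Sketch of the 1D dispatch: a 1D signal acts as batch_size == 1,
// with its samples on axis 0 instead of axis 1.
size_t stft_num_frames(const std::vector<size_t>& signal_shape,
                       int64_t frame_size,
                       int64_t frame_step) {
    const bool is_signal_1D = signal_shape.size() == 1;
    const size_t signal_axis = is_signal_1D ? 0 : 1;
    const size_t signal_length = signal_shape[signal_axis];
    return static_cast<size_t>((signal_length - frame_size) / frame_step) + 1;
}
// e.g. stft_num_frames({48}, 16, 16) == 3, same as for a [1, 48] signal
```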
39 changes: 25 additions & 14 deletions src/core/shape_inference/include/stft_shape_inference.hpp
@@ -25,10 +25,11 @@ std::vector<TRShape> shape_infer(const STFT* op,
const auto& frame_size_shape = input_shapes[2];
const auto& frame_step_shape = input_shapes[3];

+const auto signal_shape_rank = signal_shape.rank();
NODE_SHAPE_INFER_CHECK(op,
input_shapes,
-signal_shape.rank().compatible(2),
-"The shape of signal must be 2D [batch, signal_size].");
+signal_shape_rank.compatible(1) || signal_shape_rank.compatible(2),
+"The shape of signal must be 1D [signal_size] or 2D [batch, signal_size].");
NODE_SHAPE_INFER_CHECK(op,
input_shapes,
window_shape.rank().compatible(1),
@@ -42,29 +43,36 @@ std::vector<TRShape> shape_infer(const STFT* op,
frame_step_shape.rank().compatible(0),
"The shape of frame_step must be a scalar.");

+if (signal_shape_rank.is_dynamic()) {
+return {signal_shape};
+}
+
const auto frame_size = get_input_const_data_as<TRShape, int64_t>(op, 2, ta);
const auto frame_step = get_input_const_data_as<TRShape, int64_t>(op, 3, ta);

-if (signal_shape.rank().is_dynamic()) {
-return {signal_shape};
-} else if (!frame_size || !frame_step) {
-return {TRShape{signal_shape[0], TDim(ov::util::dim::inf_bound), TDim(ov::util::dim::inf_bound), 2}};
+const auto is_signal_1D = signal_shape.size() == 1;
+if (!frame_size || !frame_step) {
+if (is_signal_1D) {
+return {TRShape{TDim(ov::util::dim::inf_bound), TDim(ov::util::dim::inf_bound), 2}};
+} else {
+return {TRShape{signal_shape[0], TDim(ov::util::dim::inf_bound), TDim(ov::util::dim::inf_bound), 2}};
+}
}

const auto& frame_size_val = (*frame_size)[0];
const auto& frame_step_val = (*frame_step)[0];

+const TDim& signal_dim = is_signal_1D ? signal_shape[0] : signal_shape[1];
const bool is_frame_size_in_range =
-0 < frame_size_val &&
-(signal_shape[1].is_static() ? static_cast<TDimVal>(frame_size_val) <= signal_shape[1].get_length()
-: frame_size_val <= signal_shape[1].get_interval().get_max_val());
+0 < frame_size_val && (signal_dim.is_static() ? static_cast<TDimVal>(frame_size_val) <= signal_dim.get_length()
+: frame_size_val <= signal_dim.get_interval().get_max_val());
NODE_SHAPE_INFER_CHECK(op,
input_shapes,
is_frame_size_in_range,
"Provided frame size is ",
frame_size_val,
" but must be in range [1, ",
-signal_shape[1],
+signal_dim,
"].");

NODE_SHAPE_INFER_CHECK(op,
@@ -84,9 +92,8 @@ std::vector<TRShape> shape_infer(const STFT* op,
frame_size_val,
"].");

-const auto& batch_dim = signal_shape[0];
const TDim frame_size_dim = static_cast<TDim>(frame_size_val);
-const TDim signal_frame_size_diff = signal_shape[1] - frame_size_dim;
+const TDim signal_frame_size_diff = signal_dim - frame_size_dim;
TDim fft_samples_dim = (frame_size_val / 2) + 1;

// Division operator for static Dimension of PartialShape can return non-static dimension and ceil instead of floor
@@ -95,9 +102,13 @@

std::vector<TRShape> output_shapes;
if (op->get_transpose_frames()) {
-output_shapes.emplace_back(TRShape{batch_dim, std::move(fft_samples_dim), std::move(frames_dim), 2});
+output_shapes.emplace_back(TRShape{std::move(fft_samples_dim), std::move(frames_dim), 2});
} else {
-output_shapes.emplace_back(TRShape{batch_dim, std::move(frames_dim), std::move(fft_samples_dim), 2});
+output_shapes.emplace_back(TRShape{std::move(frames_dim), std::move(fft_samples_dim), 2});
}
+if (!is_signal_1D) {
+const auto& batch_dim = signal_shape[0];
+output_shapes[0].insert(output_shapes[0].begin(), batch_dim);
+}
return output_shapes;
}
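To make the shape arithmetic above concrete, a small self-contained check, with values taken from the tests in this PR:

```cpp
#include <cassert>
#include <cstdint>

int main() {
    // signal [48], frame_size = 16, frame_step = 16, transpose_frames = true
    const int64_t signal_length = 48, frame_size = 16, frame_step = 16;
    const int64_t frames_dim = (signal_length - frame_size) / frame_step + 1;  // 3
    const int64_t fft_samples_dim = frame_size / 2 + 1;                        // 9
    assert(frames_dim == 3 && fft_samples_dim == 9);
    // 1D signal: no batch dim is inserted  -> output [9, 3, 2]
    // 2D signal [1, 48]: batch dim prepended -> output [1, 9, 3, 2]
    return 0;
}
```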
13 changes: 10 additions & 3 deletions src/core/tests/type_prop/stft.cpp
@@ -104,6 +104,13 @@ INSTANTIATE_TEST_SUITE_P(
type_prop_stft_shape,
TypePropSTFTTestP,
testing::Values(
+std::make_tuple(PartialShape{16}, PartialShape{16}, 16, 16, true, PartialShape{9, 1, 2}),
+std::make_tuple(PartialShape{48}, PartialShape{16}, 16, 16, false, PartialShape{3, 9, 2}),
+std::make_tuple(PartialShape{56}, PartialShape{7}, 11, 3, false, PartialShape{16, 6, 2}),
+std::make_tuple(PartialShape{56}, PartialShape{7}, 11, 3, true, PartialShape{6, 16, 2}),
+std::make_tuple(PartialShape{48}, PartialShape{8}, 16, 4, true, PartialShape{9, 9, 2}),
+std::make_tuple(PartialShape{{48, 56}}, PartialShape{7}, 11, 3, true, PartialShape{6, {13, 16}, 2}),
+std::make_tuple(PartialShape{-1}, PartialShape{7}, 11, 3, true, PartialShape{6, {1, -1}, 2}),
std::make_tuple(PartialShape{1, 16}, PartialShape{16}, 16, 16, true, PartialShape{1, 9, 1, 2}),
std::make_tuple(PartialShape{1, 48}, PartialShape{16}, 16, 16, true, PartialShape{1, 9, 3, 2}),
std::make_tuple(PartialShape{1, 48}, PartialShape{16}, 16, 16, false, PartialShape{1, 3, 9, 2}),
@@ -139,16 +146,16 @@ TEST_F(TypePropSTFTTest, signal_incompatible_shape) {
const auto frame_size = std::make_shared<Parameter>(element::i64, PartialShape{});
const auto frame_step = std::make_shared<Parameter>(element::i64, PartialShape{});
{
-const auto signal = std::make_shared<Parameter>(element::f32, PartialShape{48});
+const auto signal = std::make_shared<Parameter>(element::f32, PartialShape{});
OV_EXPECT_THROW(std::ignore = make_op(signal, window, frame_size, frame_step, transform_frames),
NodeValidationFailure,
-HasSubstr("The shape of signal must be 2D [batch, signal_size]"));
+HasSubstr("The shape of signal must be 1D [signal_size] or 2D [batch, signal_size]"));
}
{
const auto signal = std::make_shared<Parameter>(element::f32, PartialShape{-1, 4, 48});
OV_EXPECT_THROW(std::ignore = make_op(signal, window, frame_size, frame_step, transform_frames),
NodeValidationFailure,
-HasSubstr("The shape of signal must be 2D [batch, signal_size]"));
+HasSubstr("The shape of signal must be 1D [signal_size] or 2D [batch, signal_size]"));
}
}
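
Conversely, a rank-1 signal now validates; a minimal sketch reusing the fixture objects assumed from the test above:

```cpp
// Sketch only: with this PR a rank-1 signal passes validation instead of
// throwing. frame_size/frame_step are non-const Parameters here, so the
// frame dims stay dynamic in the inferred output shape.
const auto signal = std::make_shared<Parameter>(element::f32, PartialShape{48});
const auto op = make_op(signal, window, frame_size, frame_step, transform_frames);
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{-1, -1, 2}));
```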

@@ -16,6 +16,12 @@ const std::vector<ov::element::Type> data_type = {ov::element::f32, ov::element:
const std::vector<ov::element::Type> step_size_type = {ov::element::i32, ov::element::i64};

const std::vector<std::vector<InputShape>> input_shapes = {
+{ // Static shapes
+{{}, {{128}}}, // 1st input
+{{}, {{8}}}, // 2nd input
+{{}, {{}}}, // 3rd input
+{{}, {{}}} // 4th input
+},
{ // Static shapes
{{}, {{1, 128}}}, // 1st input
{{}, {{8}}}, // 2nd input
@@ -34,6 +40,12 @@ const std::vector<std::vector<InputShape>> input_shapes = {
{{}, {{}}}, // 3rd input
{{}, {{}}} // 4th input
},
+{ // Dynamic dims in the first and second input shape
+{{-1}, {{128}}}, // 1st input
+{{-1}, {{8}}}, // 2nd input
+{{}, {{}}}, // 3rd input
+{{}, {{}}} // 4th input
+},
{ // Dynamic dims in the first and second input shape
{{-1, -1}, {{1, 128}, {2, 226}}}, // 1st input
{{-1}, {{8}, {16}}}, // 2nd input
@@ -17,6 +17,26 @@ using testing::HasSubstr;

class STFTShapeInferenceTest : public OpStaticShapeInferenceTest<op::v15::STFT> {};

+TEST_F(STFTShapeInferenceTest, all_input_as_params_1D_signal) {
+const auto data_type = element::f32;
+const auto step_size_type = element::i32;
+const auto in_signal = std::make_shared<Parameter>(data_type, ov::PartialShape{-1});
+const auto in_window = std::make_shared<Parameter>(data_type, ov::PartialShape{-1});
+const auto in_frame_size = std::make_shared<Parameter>(step_size_type, ov::Shape{});
+const auto in_frame_step = std::make_shared<Parameter>(step_size_type, ov::Shape{});
+const auto op = make_op(in_signal, in_window, in_frame_size, in_frame_step, true);
+
+std::vector<StaticShape> static_input_shapes = {StaticShape{48}, StaticShape{16}, StaticShape{}, StaticShape{}};
+int32_t frame_size = 16;
+int32_t frame_step = 16;
+
+auto const_data = std::unordered_map<size_t, Tensor>{{2, {element::i32, Shape{}, &frame_size}},
+{3, {element::i32, Shape{}, &frame_step}}};
+auto acc = make_tensor_accessor(const_data);
+auto static_output_shapes = shape_infer(op.get(), static_input_shapes, acc);
+ASSERT_EQ(static_output_shapes[0], StaticShape({9, 3, 2}));
+}
+
TEST_F(STFTShapeInferenceTest, all_input_as_params) {
const auto data_type = element::f32;
const auto step_size_type = element::i32;
40 changes: 34 additions & 6 deletions src/plugins/template/tests/functional/op_reference/stft.cpp
@@ -93,7 +93,8 @@ std::vector<STFTParams> generateSTFTParams() {
using VT = typename ov::element_type_traits<ET>::value_type;
using INT_T = typename ov::element_type_traits<IT>::value_type;

-const ov::Shape signal_48_shape{1, 48};
+const ov::Shape signal_48_shape{48};
+const ov::Shape signal_1_48_shape{1, 48};
const ov::Shape signal_2_48_shape{2, 48};
const ov::Shape signal_256_shape{1, 256};

@@ -107,6 +108,16 @@ std::vector<STFTParams> generateSTFTParams() {
-2.43477, 0.11273, 0.37044, 1.35963, 0.50186, -0.84421, 0.00001, 0.54235,
-0.31351, 0.77101, -1.86809, 1.73118, 1.46768, -0.33568, 0.61134, 0.04797});

+reference_tests::Tensor signal_1_48(
+signal_1_48_shape,
+ET,
+std::vector<VT>{-0.41676, -0.05627, -2.1362, 1.64027, -1.79344, -0.84175, 0.50288, -1.24529,
+-1.05795, -0.90901, 0.55145, 2.29221, 0.04154, -1.11793, 0.53906, -0.59616,
+-0.01913, 1.175, -0.74787, 0.00903, -0.87811, -0.15643, 0.25657, -0.98878,
+-0.33882, -0.23618, -0.63766, -1.18761, -1.42122, -0.1535, -0.26906, 2.23137,
+-2.43477, 0.11273, 0.37044, 1.35963, 0.50186, -0.84421, 0.00001, 0.54235,
+-0.31351, 0.77101, -1.86809, 1.73118, 1.46768, -0.33568, 0.61134, 0.04797});
+
reference_tests::Tensor signal_2_48(
signal_2_48_shape,
ET,
@@ -255,6 +266,16 @@ std::vector<STFTParams> generateSTFTParams() {
-0.05574, 1.01868, -0.7169, 0.52739, 4.39323, -0.92417, 1.39751, 0.37859, 1.30337,
0., 0.2294, 0., 0.82838, 0., -4.56982, 0., -1.47752, 0.});

+reference_tests::Tensor output_9_3_2_transp(
+Shape{9, 3, 2},
+ET,
+std::vector<VT>{-2.52411, 0., -3.6289, 0., 1.1366, 0., 1.99743, 2.45799, 1.84867,
+-0.67991, 0.26235, 0.25725, -2.243, -1.74288, 0.39666, 0.60667, -0.73965, -0.24622,
+2.91255, -0.82545, 0.03844, 0.45931, -1.29728, -1.50822, -2.56084, 2.24181, -0.92956,
+-1.32518, 1.78749, 1.94867, 0.87525, 0.70978, 0.47508, 1.29318, -0.18799, 0.98232,
+2.10241, -2.57882, 0.88504, -1.03814, -1.44897, -2.97866, -1.59965, -0.02599, -1.02171,
+0.17824, 2.46326, 1.82815, -0.44417, 0., 0.24368, 0., -2.81501, 0.});
+
reference_tests::Tensor output_1_9_3_2_transp(
Shape{1, 9, 3, 2},
ET,
@@ -309,6 +330,13 @@ std::vector<STFTParams> generateSTFTParams() {

std::vector<STFTParams> params;
params.emplace_back(signal_48,
+hann_window_16,
+frame_size_16,
+frame_step_16,
+transpose_frames_true,
+output_9_3_2_transp,
+"basic_1D_transp");
+params.emplace_back(signal_1_48,
hann_window_16,
frame_size_16,
frame_step_16,
@@ -329,35 +357,35 @@ std::vector<STFTParams> generateSTFTParams() {
transpose_frames_false,
output_2_3_9_2_no_transp,
"basic_batch_2_no_transp");
-params.emplace_back(signal_48,
+params.emplace_back(signal_1_48,
hann_window_16,
frame_size_16,
frame_step_4,
transpose_frames_true,
output_1_9_9_2_transp,
"step_1/4_frame_transp");
-params.emplace_back(signal_48,
+params.emplace_back(signal_1_48,
hann_window_8,
frame_size_16,
frame_step_8,
transpose_frames_true,
output_1_9_5_2_transp,
"win_size_<_frame_size_transp");
-params.emplace_back(signal_48,
+params.emplace_back(signal_1_48,
hann_window_8,
frame_size_16,
frame_step_4,
transpose_frames_true,
output_1_9_9_2_transp_win_pad,
"step_1/4_frame_&_win_size_<_frame_size_transp");
-params.emplace_back(signal_48,
+params.emplace_back(signal_1_48,
hann_window_7,
frame_size_11,
frame_step_3,
transpose_frames_true,
output_1_6_13_2_transp,
"odd_sizes_transp");
-params.emplace_back(signal_48,
+params.emplace_back(signal_1_48,
hann_window_5,
frame_size_9,
frame_step_100,
2 changes: 0 additions & 2 deletions tests/layer_tests/pytorch_tests/test_stft.py
@@ -67,8 +67,6 @@ def forward(self, x, window):
def test_stft(self, n_fft, hop_length, window_size, signal_shape, ie_device, precision, ir_version, trace_model):
if ie_device == "GPU":
pytest.xfail(reason="STFT op is not supported on GPU yet")
-if signal_shape == (128,):
-pytest.xfail(reason="STFT op is doesn't support 1D signal yet, please unsqueeze the input.")
self._test(*self.create_model(n_fft, hop_length, window_size), ie_device, precision,
ir_version, kwargs_to_prepare_input={"win_length": window_size, "signal_shape": signal_shape}, trace_model=trace_model)
