diff --git a/mediapipe/calculators/audio/BUILD b/mediapipe/calculators/audio/BUILD index d1c69a2ce6..3ec914f002 100644 --- a/mediapipe/calculators/audio/BUILD +++ b/mediapipe/calculators/audio/BUILD @@ -54,6 +54,23 @@ mediapipe_cc_proto_library( deps = [":rational_factor_resample_calculator_proto"], ) +proto_library( + name = "resample_time_series_calculator_proto", + srcs = ["resample_time_series_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_proto", + ], +) + +mediapipe_cc_proto_library( + name = "resample_time_series_calculator_cc_proto", + srcs = ["resample_time_series_calculator.proto"], + cc_deps = ["//mediapipe/framework:calculator_cc_proto"], + visibility = ["//visibility:public"], + deps = [":resample_time_series_calculator_proto"], +) + proto_library( name = "spectrogram_calculator_proto", srcs = ["spectrogram_calculator.proto"], @@ -177,6 +194,33 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "resample_time_series_calculator", + srcs = ["resample_time_series_calculator.cc"], + hdrs = ["resample_time_series_calculator.h"], + visibility = ["//visibility:public"], + deps = [ + ":resample_time_series_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:packet", + "//mediapipe/framework:timestamp", + "//mediapipe/framework/api2:contract", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/api2:packet", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:matrix", + "//mediapipe/framework/formats:time_series_header_cc_proto", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "//mediapipe/util:time_series_util", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_audio_tools//audio/dsp:resampler_q", + "@eigen_archive//:eigen3", + ], + alwayslink = 1, +) + cc_library( name = "stabilized_log_calculator", srcs = ["stabilized_log_calculator.cc"], @@ -377,3 
+421,23 @@ cc_test( "@eigen_archive//:eigen3", ], ) + +cc_test( + name = "resample_time_series_calculator_test", + srcs = ["resample_time_series_calculator_test.cc"], + deps = [ + ":resample_time_series_calculator", + ":resample_time_series_calculator_cc_proto", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework:packet", + "//mediapipe/framework:timestamp", + "//mediapipe/framework/formats:matrix", + "//mediapipe/framework/formats:time_series_header_cc_proto", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/util:time_series_test_util", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:span", + "@com_google_audio_tools//audio/dsp:resampler_q", + "@eigen_archive//:eigen3", + ], +) diff --git a/mediapipe/calculators/audio/resample_time_series_calculator.cc b/mediapipe/calculators/audio/resample_time_series_calculator.cc new file mode 100644 index 0000000000..0ee2b53290 --- /dev/null +++ b/mediapipe/calculators/audio/resample_time_series_calculator.cc @@ -0,0 +1,222 @@ +// Copyright 2025 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "mediapipe/calculators/audio/resample_time_series_calculator.h" + +#include +#include +#include +#include + +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "audio/dsp/resampler_q.h" +#include "mediapipe/calculators/audio/resample_time_series_calculator.pb.h" +#include "mediapipe/framework/api2/packet.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/matrix.h" +#include "mediapipe/framework/formats/time_series_header.pb.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status_macros.h" +#include "mediapipe/framework/timestamp.h" +#include "mediapipe/util/time_series_util.h" + +namespace mediapipe { + +namespace { + +void CopyChannelToVector(const mediapipe::Matrix& matrix, int channel, + std::vector* vec) { + vec->resize(matrix.cols()); + Eigen::Map(vec->data(), vec->size()) = matrix.row(channel); +} + +void CopyVectorToChannel(const std::vector& vec, + mediapipe::Matrix* matrix, int channel) { + if (matrix->cols() == 0) { + matrix->resize(matrix->rows(), vec.size()); + } else { + ABSL_CHECK_EQ(vec.size(), matrix->cols()); + } + ABSL_CHECK_LT(channel, matrix->rows()); + matrix->row(channel) = + Eigen::Map(vec.data(), vec.size()); +} + +Timestamp CalculateOutputTimestamp(Timestamp initial_timestamp, + int64_t cumulative_output_samples, + double target_sample_rate) { + ABSL_DCHECK(initial_timestamp != Timestamp::Unstarted()); + return initial_timestamp + ((cumulative_output_samples / target_sample_rate) * + Timestamp::kTimestampUnitsPerSecond); +} + +} // namespace + +// Defines ResampleTimeSeriesCalculator. 
+absl::Status ResampleTimeSeriesCalculatorImpl::Process(CalculatorContext* cc) { + return ProcessInternal(cc, kInput(cc).Get(), false); +} + +absl::Status ResampleTimeSeriesCalculatorImpl::Close(CalculatorContext* cc) { + if (initial_timestamp_ == Timestamp::Unstarted()) { + return absl::OkStatus(); + } + Matrix empty_input_frame(num_channels_, 0); + return ProcessInternal(cc, empty_input_frame, true); +} + +absl::Status ResampleTimeSeriesCalculatorImpl::Open(CalculatorContext* cc) { + ResampleTimeSeriesCalculatorOptions resample_options; + time_series_util::FillOptionsExtensionOrDie(cc->Options(), &resample_options); + + // Provide target_sample_rate either from static options, or dynamically from + // a side packet. The static option takes precedence; the side packet is only + // consulted when the option is not set. + if (resample_options.has_target_sample_rate()) { + target_sample_rate_ = resample_options.target_sample_rate(); + } else if (!kSideInputTargetSampleRate(cc).IsEmpty()) { + target_sample_rate_ = kSideInputTargetSampleRate(cc).Get(); + } else { + return tool::StatusInvalid( + "target_sample_rate is not provided in resample_options, nor from a " + "side packet."); + } + + double min_source_sample_rate = target_sample_rate_; + if (resample_options.allow_upsampling()) { + min_source_sample_rate = resample_options.min_source_sample_rate(); + } + + TimeSeriesHeader input_header; + MP_RETURN_IF_ERROR(time_series_util::FillTimeSeriesHeaderIfValid( + kInput(cc).Header(), &input_header)); + + source_sample_rate_ = input_header.sample_rate(); + num_channels_ = input_header.num_channels(); + + if (source_sample_rate_ < min_source_sample_rate) { + return ::absl::FailedPreconditionError( + "Resample() failed because upsampling is disabled or source sample " + "rate is lower than min_source_sample_rate."); + } + + // Don't create resamplers for pass-thru (sample rates are equal). 
+ if (source_sample_rate_ != target_sample_rate_) { + resampler_ = ResamplerFromOptions(source_sample_rate_, target_sample_rate_, + num_channels_, resample_options); + RET_CHECK(resampler_) << "Failed to initialize resampler."; + } + + TimeSeriesHeader* output_header = new TimeSeriesHeader(input_header); + output_header->set_sample_rate(target_sample_rate_); + // The resampler doesn't make guarantees about how many samples will + // be in each packet. + output_header->clear_packet_rate(); + output_header->clear_num_samples(); + + kOutput(cc).SetHeader(mediapipe::api2::FromOldPacket(Adopt(output_header))); + cumulative_output_samples_ = 0; + cumulative_input_samples_ = 0; + initial_timestamp_ = Timestamp::Unstarted(); + check_inconsistent_timestamps_ = + resample_options.check_inconsistent_timestamps(); + return absl::OkStatus(); +} + +absl::Status ResampleTimeSeriesCalculatorImpl::ProcessInternal( + CalculatorContext* cc, const Matrix& input_frame, bool should_flush) { + if (initial_timestamp_ == Timestamp::Unstarted()) { + initial_timestamp_ = kInput(cc).timestamp(); + } + + if (check_inconsistent_timestamps_) { + time_series_util::LogWarningIfTimestampIsInconsistent( + kInput(cc).timestamp(), initial_timestamp_, cumulative_input_samples_, + source_sample_rate_); + } + const Timestamp output_timestamp = CalculateOutputTimestamp( + initial_timestamp_, cumulative_output_samples_, target_sample_rate_); + + cumulative_input_samples_ += input_frame.cols(); + std::unique_ptr output_frame(new Matrix(num_channels_, 0)); + if (resampler_ == nullptr) { + // Sample rates were same for input and output; pass-thru. 
+ *output_frame = input_frame; + } else { + resampler_->Resample(input_frame, output_frame.get(), should_flush); + } + cumulative_output_samples_ += output_frame->cols(); + + if (output_frame->cols() > 0) { + kOutput(cc).Send(*output_frame, output_timestamp); + output_frame.reset(); + } + kOutput(cc).SetNextTimestampBound(CalculateOutputTimestamp( + initial_timestamp_, cumulative_output_samples_, target_sample_rate_)); + + return absl::OkStatus(); +} + +// static +std::unique_ptr +ResampleTimeSeriesCalculatorImpl::ResamplerFromOptions( + double source_sample_rate, double target_sample_rate, int num_channels, + const ResampleTimeSeriesCalculatorOptions& options) { + std::unique_ptr resampler; + switch (options.resampler_type()) { + case ResampleTimeSeriesCalculatorOptions::RESAMPLER_RATIONAL_FACTOR: { + const auto& rational_factor_options = + options.resampler_rational_factor_options(); + + // Read resampler parameters from proto. + audio_dsp::QResamplerParams params; + if (rational_factor_options.has_radius_factor()) { + params.filter_radius_factor = rational_factor_options.radius_factor(); + } else if (rational_factor_options.has_radius()) { + // Convert RationalFactorResampler radius to QResampler radius_factor. + params.filter_radius_factor = + rational_factor_options.radius() * + std::min(1.0, target_sample_rate / source_sample_rate); + } + if (rational_factor_options.has_cutoff_proportion()) { + params.cutoff_proportion = rational_factor_options.cutoff_proportion(); + } else if (rational_factor_options.has_cutoff()) { + // Convert RationalFactorResampler cutoff to QResampler + // cutoff_proportion. + params.cutoff_proportion = + 2 * rational_factor_options.cutoff() / + std::min(source_sample_rate, target_sample_rate); + } + if (rational_factor_options.has_kaiser_beta()) { + params.kaiser_beta = rational_factor_options.kaiser_beta(); + } + // Set large enough so that the resampling factor between common sample + // rates (e.g. 
8kHz, 16kHz, 22.05kHz, 32kHz, 44.1kHz, 48kHz) is exact, and + // that any factor is represented with error less than 0.025%. + params.max_denominator = 2000; + + resampler = std::make_unique( + source_sample_rate, target_sample_rate, num_channels, params); + } break; + default: + break; + } + if (resampler != nullptr && !resampler->Valid()) { + resampler.reset(); + } + return resampler; +} +} // namespace mediapipe diff --git a/mediapipe/calculators/audio/resample_time_series_calculator.h b/mediapipe/calculators/audio/resample_time_series_calculator.h new file mode 100644 index 0000000000..317a9aa66c --- /dev/null +++ b/mediapipe/calculators/audio/resample_time_series_calculator.h @@ -0,0 +1,139 @@ +// Copyright 2025 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef MEDIAPIPE_CALCULATORS_AUDIO_RESAMPLE_TIME_SERIES_CALCULATOR_H_ +#define MEDIAPIPE_CALCULATORS_AUDIO_RESAMPLE_TIME_SERIES_CALCULATOR_H_ + +#include +#include + +#include "absl/status/status.h" +#include "audio/dsp/resampler_q.h" +#include "mediapipe/calculators/audio/resample_time_series_calculator.pb.h" +#include "mediapipe/framework/api2/contract.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/matrix.h" +#include "mediapipe/framework/port/status_macros.h" +#include "mediapipe/util/time_series_util.h" + +namespace mediapipe { + +struct ResampleTimeSeriesCalculator : public mediapipe::api2::NodeIntf { + // Sequence of Matrices, each column describing a particular time frame, each + // row a feature dimension, with TimeSeriesHeader. + static constexpr mediapipe::api2::Input kInput{""}; + static constexpr mediapipe::api2::SideInput::Optional + kSideInputTargetSampleRate{"TARGET_SAMPLE_RATE"}; + // Sequence of Matrices, each column describing a particular time frame, each + // row a feature dimension, with TimeSeriesHeader. + static constexpr mediapipe::api2::Output kOutput{""}; + MEDIAPIPE_NODE_INTERFACE(ResampleTimeSeriesCalculator, kInput, kOutput, + kSideInputTargetSampleRate, + mediapipe::api2::TimestampChange::Arbitrary()); +}; + +// MediaPipe Calculator for resampling a (vector-valued) +// input time series with a uniform sample rate. The output +// stream's sampling rate is specified by target_sample_rate in the +// ResampleTimeSeriesCalculatorOptions. The output time series may have +// a varying number of samples per frame. 
+class ResampleTimeSeriesCalculatorImpl + : public mediapipe::api2::NodeImpl { + public: + struct TestAccess; + static absl::Status UpdateContract(CalculatorContract* cc) { + return time_series_util::HasOptionsExtension< + ResampleTimeSeriesCalculatorOptions>(cc->Options()); + } + // Returns FAIL if the input stream header is invalid or if the + // resampler cannot be initialized. + absl::Status Open(CalculatorContext* cc) override; + // Resamples a packet of TimeSeries data. Returns FAIL if the + // resampler state becomes inconsistent. + absl::Status Process(CalculatorContext* cc) override; + // Flushes any remaining state. Returns FAIL if the resampler state + // becomes inconsistent. + absl::Status Close(CalculatorContext* cc) override; + + class ResamplerWrapper { + public: + virtual ~ResamplerWrapper() = default; + virtual bool Valid() const = 0; + virtual void Resample(const Matrix& input_frame, Matrix* output_frame, + bool should_flush) = 0; + }; + + // Wrapper for QResampler. + class QResamplerWrapper + : public ResampleTimeSeriesCalculatorImpl::ResamplerWrapper { + public: + QResamplerWrapper(double source_sample_rate, double target_sample_rate, + int num_channels, audio_dsp::QResamplerParams params) + : impl_(source_sample_rate, target_sample_rate, num_channels, params) {} + + bool Valid() const override { return impl_.Valid(); } + + void Resample(const Matrix& input_frame, Matrix* output_frame, + bool should_flush) override { + if (should_flush) { + impl_.Flush(output_frame); + } else { + impl_.ProcessSamples(input_frame, output_frame); + } + } + + private: + audio_dsp::QResampler impl_; + }; + + protected: + // Returns a ResamplerWrapper implementation specified by the + // ResampleTimeSeriesCalculatorOptions proto. Returns null if the options + // specify an invalid resampler. 
+ static std::unique_ptr ResamplerFromOptions( + double source_sample_rate, double target_sample_rate, int num_channels, + const ResampleTimeSeriesCalculatorOptions& options); + + // Does Timestamp bookkeeping and resampling common to Process() and + // Close(). Returns FAIL if the resampler state becomes + // inconsistent. + absl::Status ProcessInternal(CalculatorContext* cc, const Matrix& input_frame, + bool should_flush); + + double source_sample_rate_; + double target_sample_rate_; + int64_t cumulative_input_samples_; + int64_t cumulative_output_samples_; + Timestamp initial_timestamp_; + bool check_inconsistent_timestamps_; + int num_channels_; + std::unique_ptr resampler_; +}; + +// Test-only access to ResampleTimeSeriesCalculator methods. +struct ResampleTimeSeriesCalculatorImpl::TestAccess { + static std::unique_ptr ResamplerFromOptions( + double source_sample_rate, double target_sample_rate, int num_channels, + const ResampleTimeSeriesCalculatorOptions& options) { + return ResampleTimeSeriesCalculatorImpl::ResamplerFromOptions( + source_sample_rate, target_sample_rate, num_channels, options); + } +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_CALCULATORS_AUDIO_RESAMPLE_TIME_SERIES_CALCULATOR_H_ diff --git a/mediapipe/calculators/audio/resample_time_series_calculator.proto b/mediapipe/calculators/audio/resample_time_series_calculator.proto new file mode 100644 index 0000000000..d053b5fc82 --- /dev/null +++ b/mediapipe/calculators/audio/resample_time_series_calculator.proto @@ -0,0 +1,102 @@ +// Copyright 2025 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message ResampleTimeSeriesCalculatorOptions { + extend CalculatorOptions { + optional ResampleTimeSeriesCalculatorOptions ext = 49296647; + } + + // target_sample_rate is the sample rate, in Hertz, of the output stream. + // Must be > 0. Required unless the TARGET_SAMPLE_RATE side packet is set. + optional double target_sample_rate = 1; + + enum ResamplerType { + UNDEFINED = 0; + + reserved 1; + + // RESAMPLER_RATIONAL_FACTOR selects QResampler which + // replaces RationalFactorResampler. + RESAMPLER_RATIONAL_FACTOR = 2; + } + + optional ResamplerType resampler_type = 2 + [default = RESAMPLER_RATIONAL_FACTOR]; + // Parameters for initializing LibResampleResampler. See LibResampleResampler + // for more details. + message ResamplerLibResampleOptions { + // Whether to use libresample's high-quality resampling mode. + optional bool use_high_quality_resampler = 1 [default = false]; + } + + optional ResamplerLibResampleOptions resampler_libresample_options = 3; + + // Parameters for initializing the RationalFactorResampler. See + // RationalFactorResampler for more details. + message ResamplerRationalFactorOptions { + // Scale factor for the resampling kernel's nonzero support radius. If + // upsampling, the kernel radius is `filter_radius_factor` input samples. If + // downsampling, the kernel radius is `filter_radius_factor` *output* + // samples. Larger radius makes the transition between passband and stopband + // sharper, but proportionally increases computation and memory cost. 
+ // + // The default value 5.0 corresponds to libresample's "low quality" mode + // (which despite the name, is quite good quality). + // + // A value of 17.0 corresponds to libresample's "high quality" mode. + optional double radius_factor = 4 [default = 5.0]; + + // Antialiasing cutoff frequency as a proportion of + // min(input_sample_rate, output_sample_rate) / 2. + // The default is 0.9, meaning the cutoff is at 90% of the input Nyquist + // frequency or the output Nyquist frequency, whichever is smaller. + optional double cutoff_proportion = 5 [default = 0.9]; + + // The Kaiser beta parameter for the kernel window. A larger value implies + // wider transition band and stronger stopband attenuation. + optional double kaiser_beta = 3 [default = 6.0]; + + // The following fields are an older, alternative parameterization of the + // resampling kernel preserved for backward compatibility. + + // Kernel radius in units of input samples. It is related to radius_factor + // by radius_factor = radius * min(1, output_rate / input_rate). + optional double radius = 1; + + // Anti-aliasing cutoff frequency in Hertz. A reasonable setting is + // 0.45 * min(input_rate, output_rate). It is related to cutoff_proportion + // by cutoff_proportion = 2 * cutoff / min(output_rate, input_rate). + optional double cutoff = 2; + } + + optional ResamplerRationalFactorOptions resampler_rational_factor_options = 4; + + // Set to false to disable checks for jitter in timestamp values. Useful with + // live audio input. + optional bool check_inconsistent_timestamps = 5 [default = true]; + + // Set to false to throw an error if the original audio has lower sample rate + // than `target_sample_rate`. + optional bool allow_upsampling = 6 [default = true]; + + // min_source_sample_rate is the minimum allowed sample rate of the input + // stream, in Hertz. Only used when allow_upsampling is set to true. 
+ optional double min_source_sample_rate = 7 [default = 0.0]; +} diff --git a/mediapipe/calculators/audio/resample_time_series_calculator_test.cc b/mediapipe/calculators/audio/resample_time_series_calculator_test.cc new file mode 100644 index 0000000000..f9b8b1d847 --- /dev/null +++ b/mediapipe/calculators/audio/resample_time_series_calculator_test.cc @@ -0,0 +1,277 @@ +#include "mediapipe/calculators/audio/resample_time_series_calculator.h" + +#include + +#include +#include +#include +#include + +#include "Eigen/Core" +#include "absl/status/status.h" +#include "absl/types/span.h" +#include "audio/dsp/resampler_q.h" +#include "mediapipe/calculators/audio/resample_time_series_calculator.pb.h" +#include "mediapipe/framework/formats/matrix.h" +#include "mediapipe/framework/formats/time_series_header.pb.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/timestamp.h" +#include "mediapipe/util/time_series_test_util.h" + +namespace mediapipe { +namespace { + +const int kInitialTimestampOffsetMilliseconds = 4; + +class ResampleTimeSeriesCalculatorTest + : public TimeSeriesCalculatorTest { + protected: + void SetUp() override { + calculator_name_ = "ResampleTimeSeriesCalculator"; + input_sample_rate_ = 4000.0; + num_input_channels_ = 3; + } + + // Expects two vectors whose lengths are almost the same and whose + // elements are equal (for indices that are present in both). + // + // This is useful because the resampler doesn't make precise + // guarantees about its output size. + void ExpectVectorMostlyFloatEq(absl::Span expected, + absl::Span actual) { + // Lengths should be close, but don't have to be equal. + ASSERT_NEAR(expected.size(), actual.size(), 1); + for (int i = 0; i < std::min(expected.size(), actual.size()); ++i) { + EXPECT_NEAR(expected[i], actual[i], 5e-7f) << " where i=" << i << "."; + } + } + + // Caller takes ownership of the returned value. 
+ Matrix* NewTestFrame(int num_channels, int num_samples, int timestamp) { + return new Matrix(Matrix::Random(num_channels, num_samples)); + } + + // Initializes and runs the test graph. + absl::Status Run(double output_sample_rate) { + options_.set_target_sample_rate(output_sample_rate); + InitializeGraph(); + + FillInputHeader(); + concatenated_input_samples_.resize(num_input_channels_, 0); + num_input_samples_ = 0; + for (int i = 0; i < 5; ++i) { + int packet_size = (i + 1) * 10; + int timestamp = kInitialTimestampOffsetMilliseconds + + num_input_samples_ / input_sample_rate_ * + Timestamp::kTimestampUnitsPerSecond; + Matrix* data_frame = + NewTestFrame(num_input_channels_, packet_size, timestamp); + + // Keep a reference copy of the input. + // + // conservativeResize() is needed here to preserve the existing + // data. Eigen's resize() resizes without preserving data. + concatenated_input_samples_.conservativeResize( + num_input_channels_, num_input_samples_ + packet_size); + concatenated_input_samples_.rightCols(packet_size) = *data_frame; + num_input_samples_ += packet_size; + + AppendInputPacket(data_frame, timestamp); + } + + return RunGraph(); + } + + void CheckOutputLength(double output_sample_rate) { + double factor = output_sample_rate / input_sample_rate_; + + int num_output_samples = 0; + for (const Packet& packet : output().packets) { + num_output_samples += packet.Get().cols(); + } + + // The exact number of expected samples may vary based on the implementation + // of the resampler since the exact value is not an integer. + const double expected_num_output_samples = num_input_samples_ * factor; + EXPECT_LE(ceil(expected_num_output_samples), num_output_samples); + EXPECT_GE(ceil(expected_num_output_samples) + 11, num_output_samples); + } + + // Checks that output timestamps are consistent with the + // output_sample_rate and output packet sizes. 
+ void CheckOutputPacketTimestamps(double output_sample_rate) { + int num_output_samples = 0; + for (const Packet& packet : output().packets) { + const int expected_timestamp = kInitialTimestampOffsetMilliseconds + + num_output_samples / output_sample_rate * + Timestamp::kTimestampUnitsPerSecond; + EXPECT_NEAR(expected_timestamp, packet.Timestamp().Value(), 1); + num_output_samples += packet.Get().cols(); + } + } + + // Checks that output values from the calculator (which resamples + // packet-by-packet) are consistent with resampling the entire + // signal at once. + void CheckOutputValues( + double output_sample_rate, + std::unique_ptr + verification_resampler = nullptr) { + if (!verification_resampler) { + verification_resampler = + ResampleTimeSeriesCalculatorImpl::TestAccess::ResamplerFromOptions( + input_sample_rate_, output_sample_rate, num_input_channels_, + options_); + } + + Matrix expected_resampled; + verification_resampler->Resample(concatenated_input_samples_, + &expected_resampled, false); + Matrix flushed; + verification_resampler->Resample({}, &flushed, true); + expected_resampled.conservativeResize( + num_input_channels_, expected_resampled.cols() + flushed.cols()); + expected_resampled.rightCols(flushed.cols()) = flushed; + + for (int i = 0; i < num_input_channels_; ++i) { + std::vector expected_resampled_i(expected_resampled.row(i).begin(), + expected_resampled.row(i).end()); + std::vector actual_resampled; + for (const Packet& packet : output().packets) { + auto output_frame_row = packet.Get().row(i); + actual_resampled.insert(actual_resampled.end(), + output_frame_row.begin(), + output_frame_row.end()); + } + + ExpectVectorMostlyFloatEq(expected_resampled_i, actual_resampled); + } + } + + void CheckOutputHeaders(double output_sample_rate) { + const TimeSeriesHeader& output_header = + output().header.Get(); + TimeSeriesHeader expected_header; + expected_header.set_sample_rate(output_sample_rate); + 
expected_header.set_num_channels(num_input_channels_); + EXPECT_THAT(output_header, mediapipe::EqualsProto(expected_header)); + } + + void CheckOutput(double output_sample_rate) { + CheckOutputLength(output_sample_rate); + CheckOutputPacketTimestamps(output_sample_rate); + CheckOutputValues(output_sample_rate); + CheckOutputHeaders(output_sample_rate); + } + + void CheckOutputUnchanged() { + for (int i = 0; i < num_input_channels_; ++i) { + std::vector expected_resampled_data; + for (int j = 0; j < num_input_samples_; ++j) { + expected_resampled_data.push_back(concatenated_input_samples_(i, j)); + } + std::vector actual_resampled_data; + for (const Packet& packet : output().packets) { + Matrix output_frame_row = packet.Get().row(i); + actual_resampled_data.insert( + actual_resampled_data.end(), &output_frame_row(0), + &output_frame_row(0) + output_frame_row.cols()); + } + ExpectVectorMostlyFloatEq(expected_resampled_data, actual_resampled_data); + } + } + + Matrix concatenated_input_samples_; +}; + +TEST_F(ResampleTimeSeriesCalculatorTest, Upsample) { + const double kUpsampleRate = input_sample_rate_ * 1.9; + MP_ASSERT_OK(Run(kUpsampleRate)); + CheckOutput(kUpsampleRate); +} + +TEST_F(ResampleTimeSeriesCalculatorTest, Downsample) { + const double kDownsampleRate = input_sample_rate_ / 1.9; + MP_ASSERT_OK(Run(kDownsampleRate)); + CheckOutput(kDownsampleRate); +} + +TEST_F(ResampleTimeSeriesCalculatorTest, UsesRationalFactorResampler) { + options_.set_resampler_type( + ResampleTimeSeriesCalculatorOptions::RESAMPLER_RATIONAL_FACTOR); + // Pick an upsample rate so the resample ratio is 2. 
+ const double kUpsampleRate = input_sample_rate_ * 2; + MP_ASSERT_OK(Run(kUpsampleRate)); + CheckOutput(kUpsampleRate); +} + +TEST_F(ResampleTimeSeriesCalculatorTest, PassthroughIfSampleRateUnchanged) { + const double kUpsampleRate = input_sample_rate_; + MP_ASSERT_OK(Run(kUpsampleRate)); + CheckOutputUnchanged(); +} + +TEST_F(ResampleTimeSeriesCalculatorTest, FailsOnBadTargetRate) { + ASSERT_FALSE(Run(-999.9).ok()); // Invalid output sample rate. +} + +TEST_F(ResampleTimeSeriesCalculatorTest, DoesNotDieOnEmptyInput) { + options_.set_target_sample_rate(input_sample_rate_); + InitializeGraph(); + FillInputHeader(); + MP_ASSERT_OK(RunGraph()); + EXPECT_TRUE(output().packets.empty()); +} + +TEST_F(ResampleTimeSeriesCalculatorTest, CustomQResamplerKernel) { + const float kOutputSampleRate = input_sample_rate_ * 0.7; + const float kRadiusFactor = 11.0; + const float kCutoffProportion = 0.85; + options_.set_resampler_type( + ResampleTimeSeriesCalculatorOptions::RESAMPLER_RATIONAL_FACTOR); + auto resampler_options = options_.mutable_resampler_rational_factor_options(); + resampler_options->set_radius_factor(kRadiusFactor); + resampler_options->set_cutoff_proportion(kCutoffProportion); + MP_ASSERT_OK(Run(kOutputSampleRate)); + + audio_dsp::QResamplerParams params; + params.filter_radius_factor = kRadiusFactor; + params.cutoff_proportion = kCutoffProportion; + CheckOutputValues( + kOutputSampleRate, + std::make_unique( + input_sample_rate_, kOutputSampleRate, num_input_channels_, params)); +} + +TEST_F(ResampleTimeSeriesCalculatorTest, CustomLegacyKernel) { + const float kOutputSampleRate = input_sample_rate_ * 0.7; + const float kRadiusFactor = 11.0; + const float kCutoffProportion = 0.85; + // Convert to equivalent legacy parameters. 
+ const float kRadius = + kRadiusFactor * + std::max(1.0f, input_sample_rate_ / kOutputSampleRate); + const float kCutoff = 0.5f * kCutoffProportion * + std::min(input_sample_rate_, kOutputSampleRate); + + options_.set_resampler_type( + ResampleTimeSeriesCalculatorOptions::RESAMPLER_RATIONAL_FACTOR); + auto resampler_options = options_.mutable_resampler_rational_factor_options(); + resampler_options->set_radius(kRadius); + resampler_options->set_cutoff(kCutoff); + MP_ASSERT_OK(Run(kOutputSampleRate)); + + audio_dsp::QResamplerParams params; + params.filter_radius_factor = kRadiusFactor; + params.cutoff_proportion = kCutoffProportion; + + CheckOutputValues( + kOutputSampleRate, + std::make_unique( + input_sample_rate_, kOutputSampleRate, num_input_channels_, params)); +} + +} // anonymous namespace +} // namespace mediapipe diff --git a/mediapipe/util/BUILD b/mediapipe/util/BUILD index d8cccf0b46..ff5d913e05 100644 --- a/mediapipe/util/BUILD +++ b/mediapipe/util/BUILD @@ -418,6 +418,7 @@ cc_test( srcs = ["time_series_util_test.cc"], deps = [ ":time_series_util", + "//mediapipe/calculators/audio:resample_time_series_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:time_series_header_cc_proto", "//mediapipe/framework/port:gtest_main", diff --git a/mediapipe/util/time_series_util_test.cc b/mediapipe/util/time_series_util_test.cc index e8d47dbc6a..4326693ef8 100644 --- a/mediapipe/util/time_series_util_test.cc +++ b/mediapipe/util/time_series_util_test.cc @@ -15,6 +15,7 @@ #include "mediapipe/util/time_series_util.h" #include "Eigen/Core" +#include "mediapipe/calculators/audio/resample_time_series_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/time_series_header.pb.h" #include "mediapipe/framework/port/gmock.h"