Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Configurable low_freq high_freq, dithering #664

Merged
merged 4 commits into from
Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions cmake/kaldi-native-fbank.cmake
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
function(download_kaldi_native_fbank)
include(FetchContent)

set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.18.7.tar.gz")
set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.18.7.tar.gz")
set(kaldi_native_fbank_HASH "SHA256=e78fd9d481d83d7d6d1be0012752e6531cb614e030558a3491e3c033cb8e0e4e")
set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.19.1.tar.gz")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please also change lines 15-19 in this file.

set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.19.1.tar.gz")
set(kaldi_native_fbank_HASH "SHA256=0cae8cbb9ea42916b214e088912f9e8f2f648f54756b305f93f552382f31f904")

set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE)
set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
Expand Down
28 changes: 20 additions & 8 deletions sherpa-onnx/csrc/features.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,22 +25,38 @@ void FeatureExtractorConfig::Register(ParseOptions *po) {

po->Register("feat-dim", &feature_dim,
"Feature dimension. Must match the one expected by the model.");

po->Register("low-freq", &low_freq,
"Low cutoff frequency for mel bins");

po->Register("high-freq", &high_freq,
"High cutoff frequency for mel bins "
"(if <= 0, offset from Nyquist)");

po->Register("dither", &dither,
"Dithering constant (0.0 means no dither). "
"By default the audio samples are in range [-1,+1], "
"so 0.00003 is a good value, "
"equivalent to the default 1.0 from kaldi");
}

// Render this configuration as a one-line, human-readable string of the form
//
//   FeatureExtractorConfig(sampling_rate=..., feature_dim=..., low_freq=...,
//   high_freq=..., dither=...)
//
// Each field is written exactly once, separated by ", ", and the closing
// parenthesis is emitted only after the last field (dither).
std::string FeatureExtractorConfig::ToString() const {
  std::ostringstream os;

  os << "FeatureExtractorConfig(";
  os << "sampling_rate=" << sampling_rate << ", ";
  os << "feature_dim=" << feature_dim << ", ";
  os << "low_freq=" << low_freq << ", ";
  os << "high_freq=" << high_freq << ", ";
  os << "dither=" << dither << ")";

  return os.str();
}

class FeatureExtractor::Impl {
public:
explicit Impl(const FeatureExtractorConfig &config) : config_(config) {
opts_.frame_opts.dither = 0;
opts_.frame_opts.dither = config.dither;
opts_.frame_opts.snip_edges = config.snip_edges;
opts_.frame_opts.samp_freq = config.sampling_rate;
opts_.frame_opts.frame_shift_ms = config.frame_shift_ms;
Expand All @@ -50,13 +66,9 @@ class FeatureExtractor::Impl {

opts_.mel_opts.num_bins = config.feature_dim;

// Please see
// https://github.com/lhotse-speech/lhotse/blob/master/lhotse/features/fbank.py#L27
// and
// https://github.com/k2-fsa/sherpa-onnx/issues/514
opts_.mel_opts.high_freq = -400;
opts_.mel_opts.high_freq = config.high_freq;
opts_.mel_opts.low_freq = config.low_freq;

opts_.mel_opts.low_freq = config.low_freq;
opts_.mel_opts.is_librosa = config.is_librosa;

fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
Expand Down
22 changes: 21 additions & 1 deletion sherpa-onnx/csrc/features.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,27 @@ struct FeatureExtractorConfig {
// Feature dimension
int32_t feature_dim = 80;

// minimal frequency for Mel-filterbank, in Hz
float low_freq = 20.0f;

// maximal frequency of Mel-filterbank
// in Hz; negative value is subtracted from Nyquist freq.:
// i.e. for sampling_rate 16000 / 2 - 400 = 7600Hz
//
// Please see
// https://github.com/lhotse-speech/lhotse/blob/master/lhotse/features/fbank.py#L27
// and
// https://github.com/k2-fsa/sherpa-onnx/issues/514
float high_freq = -400.0f;

// dithering constant, useful for signals with hard-zeroes in non-speech parts
// this prevents large negative values in log-mel filterbanks
//
// In k2, audio samples are in range [-1..+1], in kaldi the range was
// [-32k..+32k], so the value 0.00003 is equivalent to kaldi default 1.0
//
float dither = 0.0f; // dithering disabled by default

// Set internally by some models, e.g., paraformer sets it to false.
// This parameter is not exposed to users from the commandline
// If true, the feature extractor expects inputs to be normalized to
Expand All @@ -31,7 +52,6 @@ struct FeatureExtractorConfig {
bool snip_edges = false;
float frame_shift_ms = 10.0f; // in milliseconds.
float frame_length_ms = 25.0f; // in milliseconds.
int32_t low_freq = 20;
bool is_librosa = false;
bool remove_dc_offset = true; // Subtract mean of wave before FFT.
std::string window_type = "povey"; // e.g. Hamming window
Expand Down
4 changes: 4 additions & 0 deletions sherpa-onnx/csrc/keyword-spotter-transducer-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {
unk_id_ = sym_["<unk>"];
}

model_->SetFeatureDim(config.feat_config.feature_dim);

InitKeywords();

decoder_ = std::make_unique<TransducerKeywordDecoder>(
Expand All @@ -89,6 +91,8 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {
unk_id_ = sym_["<unk>"];
}

model_->SetFeatureDim(config.feat_config.feature_dim);

InitKeywords(mgr);

decoder_ = std::make_unique<TransducerKeywordDecoder>(
Expand Down
4 changes: 4 additions & 0 deletions sherpa-onnx/csrc/online-recognizer-transducer-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
unk_id_ = sym_["<unk>"];
}

model_->SetFeatureDim(config.feat_config.feature_dim);

if (config.decoding_method == "modified_beam_search") {
if (!config_.hotwords_file.empty()) {
InitHotwords();
Expand Down Expand Up @@ -123,6 +125,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
unk_id_ = sym_["<unk>"];
}

model_->SetFeatureDim(config.feat_config.feature_dim);

if (config.decoding_method == "modified_beam_search") {
#if 0
// TODO(fangjun): Implement it
Expand Down
10 changes: 10 additions & 0 deletions sherpa-onnx/csrc/online-transducer-model.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,16 @@ class OnlineTransducerModel {
*/
virtual std::vector<Ort::Value> GetEncoderInitStates() = 0;

/** Tell the model which feature dimension is in use.
 *
 * `OnlineZipformer2TransducerModel` overrides this hook to receive
 * `feature_dim`, which it needs in `GetEncoderInitStates()`.
 *
 * Call this before `GetEncoderInitStates()`; otherwise the `encoder_embed`
 * init state cannot be created with the correct `embed_dim` of its output.
 *
 * The default implementation does nothing with the value.
 *
 * @param feature_dim Feature dimension to hand to the model.
 */
virtual void SetFeatureDim(int32_t feature_dim) {
  (void)feature_dim;  // deliberately ignored by the default no-op
}

/** Run the encoder.
*
* @param features A tensor of shape (N, T, C). It is changed in-place.
Expand Down
5 changes: 4 additions & 1 deletion sherpa-onnx/csrc/online-zipformer2-transducer-model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,10 @@ OnlineZipformer2TransducerModel::GetEncoderInitStates() {
}

{
std::array<int64_t, 4> s{1, 128, 3, 19};
SHERPA_ONNX_CHECK_NE(feature_dim_, 0);
int32_t embed_dim = (((feature_dim_ - 1) / 2) - 1) / 2;
std::array<int64_t, 4> s{1, 128, 3, embed_dim};

auto v = Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
Fill(&v, 0);
ans.push_back(std::move(v));
Expand Down
5 changes: 5 additions & 0 deletions sherpa-onnx/csrc/online-zipformer2-transducer-model.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ class OnlineZipformer2TransducerModel : public OnlineTransducerModel {

std::vector<Ort::Value> GetEncoderInitStates() override;

// Store the feature dimension in feature_dim_.
void SetFeatureDim(int32_t feature_dim) override { feature_dim_ = feature_dim; }

std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
Ort::Value features, std::vector<Ort::Value> states,
Ort::Value processed_frames) override;
Expand Down Expand Up @@ -101,6 +105,7 @@ class OnlineZipformer2TransducerModel : public OnlineTransducerModel {

int32_t context_size_ = 0;
int32_t vocab_size_ = 0;
int32_t feature_dim_ = 0;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you change the default value to 80?

};

} // namespace sherpa_onnx
Expand Down
11 changes: 9 additions & 2 deletions sherpa-onnx/python/csrc/features.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,17 @@ namespace sherpa_onnx {
// Expose FeatureExtractorConfig to Python under the same name.
//
// Registers the constructors, the read/write attributes sampling_rate,
// feature_dim, low_freq, high_freq and dither, and __str__ (backed by
// FeatureExtractorConfig::ToString).
//
// @param m  The Python module that receives the class definition.
static void PybindFeatureExtractorConfig(py::module *m) {
  using PyClass = FeatureExtractorConfig;
  py::class_<PyClass>(*m, "FeatureExtractorConfig")
      .def(py::init<int32_t, int32_t>(), py::arg("sampling_rate") = 16000,
           py::arg("feature_dim") = 80)
      .def(py::init<int32_t, int32_t, float, float, float>(),
           py::arg("sampling_rate") = 16000,
           py::arg("feature_dim") = 80,
           py::arg("low_freq") = 20.0f,
           py::arg("high_freq") = -400.0f,
           py::arg("dither") = 0.0f)
      .def_readwrite("sampling_rate", &PyClass::sampling_rate)
      .def_readwrite("feature_dim", &PyClass::feature_dim)
      .def_readwrite("low_freq", &PyClass::low_freq)
      .def_readwrite("high_freq", &PyClass::high_freq)
      // Each Python attribute must map to the C++ member of the same name:
      // "dither" is backed by PyClass::dither, so reading or assigning it
      // from Python never touches high_freq.
      .def_readwrite("dither", &PyClass::dither)
      .def("__str__", &PyClass::ToString);
}

Expand Down
16 changes: 16 additions & 0 deletions sherpa-onnx/python/sherpa_onnx/online_recognizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ def from_transducer(
num_threads: int = 2,
sample_rate: float = 16000,
feature_dim: int = 80,
low_freq: float = 20.0,
high_freq: float = -400.0,
dither: float = 0.0,
enable_endpoint_detection: bool = False,
rule1_min_trailing_silence: float = 2.4,
rule2_min_trailing_silence: float = 1.2,
Expand Down Expand Up @@ -80,6 +83,16 @@ def from_transducer(
Sample rate of the training data used to train the model.
feature_dim:
Dimension of the feature used to train the model.
low_freq:
Low cutoff frequency for mel bins in feature extraction.
high_freq:
High cutoff frequency for mel bins in feature extraction
(if <= 0, offset from Nyquist)
dither:
Dithering constant (0.0 means no dither).
By default the audio samples are in range [-1,+1],
so dithering constant 0.00003 is a good value,
equivalent to the default 1.0 from kaldi
enable_endpoint_detection:
True to enable endpoint detection. False to disable endpoint
detection.
Expand Down Expand Up @@ -140,6 +153,9 @@ def from_transducer(
feat_config = FeatureExtractorConfig(
sampling_rate=sample_rate,
feature_dim=feature_dim,
low_freq=low_freq,
high_freq=high_freq,
dither=dither,
)

endpoint_config = EndpointConfig(
Expand Down
Loading