diff --git a/scripts/dotnet/OfflineModelConfig.cs b/scripts/dotnet/OfflineModelConfig.cs
index 58b24dbbe..2dc2347c1 100644
--- a/scripts/dotnet/OfflineModelConfig.cs
+++ b/scripts/dotnet/OfflineModelConfig.cs
@@ -23,6 +23,8 @@ public OfflineModelConfig()
             Debug = 0;
             Provider = "cpu";
             ModelType = "";
+            ModelingUnit = "cjkchar";
+            BpeVocab = "";
         }
         public OfflineTransducerModelConfig Transducer;
         public OfflineParaformerModelConfig Paraformer;
@@ -42,5 +44,11 @@ public OfflineModelConfig()
 
         [MarshalAs(UnmanagedType.LPStr)]
         public string ModelType;
+
+        [MarshalAs(UnmanagedType.LPStr)]
+        public string ModelingUnit;
+
+        [MarshalAs(UnmanagedType.LPStr)]
+        public string BpeVocab;
     }
 }
diff --git a/scripts/dotnet/OnlineModelConfig.cs b/scripts/dotnet/OnlineModelConfig.cs
index 1471959d8..dcba23cf8 100644
--- a/scripts/dotnet/OnlineModelConfig.cs
+++ b/scripts/dotnet/OnlineModelConfig.cs
@@ -23,6 +23,8 @@ public OnlineModelConfig()
             Provider = "cpu";
             Debug = 0;
             ModelType = "";
+            ModelingUnit = "cjkchar";
+            BpeVocab = "";
         }
 
         public OnlineTransducerModelConfig Transducer;
@@ -43,5 +45,11 @@ public OnlineModelConfig()
 
         [MarshalAs(UnmanagedType.LPStr)]
        public string ModelType;
+
+        [MarshalAs(UnmanagedType.LPStr)]
+        public string ModelingUnit;
+
+        [MarshalAs(UnmanagedType.LPStr)]
+        public string BpeVocab;
     }
 }
diff --git a/scripts/go/sherpa_onnx.go b/scripts/go/sherpa_onnx.go
index af60d959f..e89787da9 100644
--- a/scripts/go/sherpa_onnx.go
+++ b/scripts/go/sherpa_onnx.go
@@ -87,6 +87,8 @@ type OnlineModelConfig struct {
 	Provider     string // Optional. Valid values are: cpu, cuda, coreml
 	Debug        int    // 1 to show model meta information while loading it.
 	ModelType    string // Optional. You can specify it for faster model initialization
+	ModelingUnit string // Optional. cjkchar, bpe, cjkchar+bpe
+	BpeVocab     string // Optional.
 }
 
 // Configuration for the feature extractor
@@ -187,6 +189,12 @@ func NewOnlineRecognizer(config *OnlineRecognizerConfig) *OnlineRecognizer {
 	c.model_config.model_type = C.CString(config.ModelConfig.ModelType)
 	defer C.free(unsafe.Pointer(c.model_config.model_type))
 
+	c.model_config.modeling_unit = C.CString(config.ModelConfig.ModelingUnit)
+	defer C.free(unsafe.Pointer(c.model_config.modeling_unit))
+
+	c.model_config.bpe_vocab = C.CString(config.ModelConfig.BpeVocab)
+	defer C.free(unsafe.Pointer(c.model_config.bpe_vocab))
+
 	c.decoding_method = C.CString(config.DecodingMethod)
 	defer C.free(unsafe.Pointer(c.decoding_method))
 
@@ -372,6 +380,9 @@ type OfflineModelConfig struct {
 
 	// Optional. Specify it for faster model initialization.
 	ModelType string
+
+	ModelingUnit string // Optional. cjkchar, bpe, cjkchar+bpe
+	BpeVocab     string // Optional.
 }
 
 // Configuration for the offline/non-streaming recognizer.
@@ -460,6 +471,12 @@ func NewOfflineRecognizer(config *OfflineRecognizerConfig) *OfflineRecognizer {
 	c.model_config.model_type = C.CString(config.ModelConfig.ModelType)
 	defer C.free(unsafe.Pointer(c.model_config.model_type))
 
+	c.model_config.modeling_unit = C.CString(config.ModelConfig.ModelingUnit)
+	defer C.free(unsafe.Pointer(c.model_config.modeling_unit))
+
+	c.model_config.bpe_vocab = C.CString(config.ModelConfig.BpeVocab)
+	defer C.free(unsafe.Pointer(c.model_config.bpe_vocab))
+
 	c.lm_config.model = C.CString(config.LmConfig.Model)
 	defer C.free(unsafe.Pointer(c.lm_config.model))
 
diff --git a/scripts/node-addon-api/src/non-streaming-asr.cc b/scripts/node-addon-api/src/non-streaming-asr.cc
index a1749a47e..d101c7eb6 100644
--- a/scripts/node-addon-api/src/non-streaming-asr.cc
+++ b/scripts/node-addon-api/src/non-streaming-asr.cc
@@ -126,6 +126,8 @@ static SherpaOnnxOfflineModelConfig GetOfflineModelConfig(Napi::Object obj) {
 
   SHERPA_ONNX_ASSIGN_ATTR_STR(provider, provider);
   SHERPA_ONNX_ASSIGN_ATTR_STR(model_type, modelType);
+  SHERPA_ONNX_ASSIGN_ATTR_STR(modeling_unit, modelingUnit);
+  SHERPA_ONNX_ASSIGN_ATTR_STR(bpe_vocab, bpeVocab);
 
   return c;
 }
@@ -232,6 +234,14 @@ CreateOfflineRecognizerWrapper(const Napi::CallbackInfo &info) {
     delete[] c.model_config.model_type;
   }
 
+  if (c.model_config.modeling_unit) {
+    delete[] c.model_config.modeling_unit;
+  }
+
+  if (c.model_config.bpe_vocab) {
+    delete[] c.model_config.bpe_vocab;
+  }
+
   if (c.lm_config.model) {
     delete[] c.lm_config.model;
   }
diff --git a/scripts/node-addon-api/src/streaming-asr.cc b/scripts/node-addon-api/src/streaming-asr.cc
index fec4a46fc..59312a230 100644
--- a/scripts/node-addon-api/src/streaming-asr.cc
+++ b/scripts/node-addon-api/src/streaming-asr.cc
@@ -118,6 +118,8 @@ SherpaOnnxOnlineModelConfig GetOnlineModelConfig(Napi::Object obj) {
   }
 
   SHERPA_ONNX_ASSIGN_ATTR_STR(model_type, modelType);
+  SHERPA_ONNX_ASSIGN_ATTR_STR(modeling_unit, modelingUnit);
+  SHERPA_ONNX_ASSIGN_ATTR_STR(bpe_vocab, bpeVocab);
 
   return c;
 }
@@ -228,6 +230,14 @@ static Napi::External CreateOnlineRecognizerWrapper(
     delete[] c.model_config.model_type;
   }
 
+  if (c.model_config.modeling_unit) {
+    delete[] c.model_config.modeling_unit;
+  }
+
+  if (c.model_config.bpe_vocab) {
+    delete[] c.model_config.bpe_vocab;
+  }
+
   if (c.decoding_method) {
     delete[] c.decoding_method;
   }
diff --git a/swift-api-examples/SherpaOnnx.swift b/swift-api-examples/SherpaOnnx.swift
index f39d5ebee..0c8d22f3d 100644
--- a/swift-api-examples/SherpaOnnx.swift
+++ b/swift-api-examples/SherpaOnnx.swift
@@ -88,7 +88,9 @@ func sherpaOnnxOnlineModelConfig(
   numThreads: Int = 1,
   provider: String = "cpu",
   debug: Int = 0,
-  modelType: String = ""
+  modelType: String = "",
+  modelingUnit: String = "cjkchar",
+  bpeVocab: String = ""
 ) -> SherpaOnnxOnlineModelConfig {
   return SherpaOnnxOnlineModelConfig(
     transducer: transducer,
@@ -98,7 +100,9 @@ func sherpaOnnxOnlineModelConfig(
     num_threads: Int32(numThreads),
     provider: toCPointer(provider),
     debug: Int32(debug),
-    model_type: toCPointer(modelType)
+    model_type: toCPointer(modelType),
+    modeling_unit: toCPointer(modelingUnit),
+    bpe_vocab: toCPointer(bpeVocab)
   )
 }
 
@@ -354,7 +358,9 @@ func sherpaOnnxOfflineModelConfig(
   numThreads: Int = 1,
   provider: String = "cpu",
   debug: Int = 0,
-  modelType: String = ""
+  modelType: String = "",
+  modelingUnit: String = "cjkchar",
+  bpeVocab: String = ""
 ) -> SherpaOnnxOfflineModelConfig {
   return SherpaOnnxOfflineModelConfig(
     transducer: transducer,
@@ -366,7 +372,9 @@ func sherpaOnnxOfflineModelConfig(
     num_threads: Int32(numThreads),
     debug: Int32(debug),
     provider: toCPointer(provider),
-    model_type: toCPointer(modelType)
+    model_type: toCPointer(modelType),
+    modeling_unit: toCPointer(modelingUnit),
+    bpe_vocab: toCPointer(bpeVocab)
   )
 }
 
diff --git a/wasm/asr/sherpa-onnx-asr.js b/wasm/asr/sherpa-onnx-asr.js
index d68b22e20..c77794a68 100644
--- a/wasm/asr/sherpa-onnx-asr.js
+++ b/wasm/asr/sherpa-onnx-asr.js
@@ -137,7 +137,7 @@ function initSherpaOnnxOnlineModelConfig(config, Module) {
   const ctc = initSherpaOnnxOnlineZipformer2CtcModelConfig(
       config.zipformer2Ctc, Module);
 
-  const len = transducer.len + paraformer.len + ctc.len + 5 * 4;
+  const len = transducer.len + paraformer.len + ctc.len + 7 * 4;
   const ptr = Module._malloc(len);
 
   let offset = 0;
@@ -153,7 +153,11 @@ function initSherpaOnnxOnlineModelConfig(config, Module) {
   const tokensLen = Module.lengthBytesUTF8(config.tokens) + 1;
   const providerLen = Module.lengthBytesUTF8(config.provider) + 1;
   const modelTypeLen = Module.lengthBytesUTF8(config.modelType) + 1;
-  const bufferLen = tokensLen + providerLen + modelTypeLen;
+  const modelingUnitLen = Module.lengthBytesUTF8(config.modelingUnit || '') + 1;
+  const bpeVocabLen = Module.lengthBytesUTF8(config.bpeVocab || '') + 1;
+
+  const bufferLen =
+      tokensLen + providerLen + modelTypeLen + modelingUnitLen + bpeVocabLen;
   const buffer = Module._malloc(bufferLen);
 
   offset = 0;
@@ -164,6 +168,14 @@ function initSherpaOnnxOnlineModelConfig(config, Module) {
   offset += providerLen;
 
   Module.stringToUTF8(config.modelType, buffer + offset, modelTypeLen);
+  offset += modelTypeLen;
+
+  Module.stringToUTF8(
+      config.modelingUnit || '', buffer + offset, modelingUnitLen);
+  offset += modelingUnitLen;
+
+  Module.stringToUTF8(config.bpeVocab || '', buffer + offset, bpeVocabLen);
+  offset += bpeVocabLen;
 
   offset = transducer.len + paraformer.len + ctc.len;
   Module.setValue(ptr + offset, buffer, 'i8*');  // tokens
@@ -182,6 +194,17 @@ function initSherpaOnnxOnlineModelConfig(config, Module) {
       ptr + offset, buffer + tokensLen + providerLen, 'i8*');  // modelType
   offset += 4;
 
+  Module.setValue(
+      ptr + offset, buffer + tokensLen + providerLen + modelTypeLen,
+      'i8*');  // modelingUnit
+  offset += 4;
+
+  Module.setValue(
+      ptr + offset,
+      buffer + tokensLen + providerLen + modelTypeLen + modelingUnitLen,
+      'i8*');  // bpeVocab
+  offset += 4;
+
   return {
     buffer: buffer, ptr: ptr, len: len, transducer: transducer,
     paraformer: paraformer, ctc: ctc
@@ -317,6 +340,8 @@ function createOnlineRecognizer(Module, myConfig) {
     provider: 'cpu',
     debug: 1,
     modelType: '',
+    modelingUnit: 'cjkchar',
+    bpeVocab: '',
   };
 
   const featureConfig = {
@@ -504,7 +529,7 @@ function initSherpaOnnxOfflineModelConfig(config, Module) {
   const tdnn = initSherpaOnnxOfflineTdnnModelConfig(config.tdnn, Module);
 
   const len = transducer.len + paraformer.len + nemoCtc.len + whisper.len +
-      tdnn.len + 5 * 4;
+      tdnn.len + 7 * 4;
   const ptr = Module._malloc(len);
 
   let offset = 0;
@@ -526,7 +551,11 @@ function initSherpaOnnxOfflineModelConfig(config, Module) {
   const tokensLen = Module.lengthBytesUTF8(config.tokens) + 1;
   const providerLen = Module.lengthBytesUTF8(config.provider) + 1;
   const modelTypeLen = Module.lengthBytesUTF8(config.modelType) + 1;
-  const bufferLen = tokensLen + providerLen + modelTypeLen;
+  const modelingUnitLen = Module.lengthBytesUTF8(config.modelingUnit || '') + 1;
+  const bpeVocabLen = Module.lengthBytesUTF8(config.bpeVocab || '') + 1;
+
+  const bufferLen =
+      tokensLen + providerLen + modelTypeLen + modelingUnitLen + bpeVocabLen;
   const buffer = Module._malloc(bufferLen);
 
   offset = 0;
@@ -537,6 +566,14 @@ function initSherpaOnnxOfflineModelConfig(config, Module) {
   offset += providerLen;
 
   Module.stringToUTF8(config.modelType, buffer + offset, modelTypeLen);
+  offset += modelTypeLen;
+
+  Module.stringToUTF8(
+      config.modelingUnit || '', buffer + offset, modelingUnitLen);
+  offset += modelingUnitLen;
+
+  Module.stringToUTF8(config.bpeVocab || '', buffer + offset, bpeVocabLen);
+  offset += bpeVocabLen;
 
   offset = transducer.len + paraformer.len + nemoCtc.len + whisper.len +
       tdnn.len;
@@ -556,6 +593,17 @@ function initSherpaOnnxOfflineModelConfig(config, Module) {
       ptr + offset, buffer + tokensLen + providerLen, 'i8*');  // modelType
   offset += 4;
 
+  Module.setValue(
+      ptr + offset, buffer + tokensLen + providerLen + modelTypeLen,
+      'i8*');  // modelingUnit
+  offset += 4;
+
+  Module.setValue(
+      ptr + offset,
+      buffer + tokensLen + providerLen + modelTypeLen + modelingUnitLen,
+      'i8*');  // bpeVocab
+  offset += 4;
+
   return {
     buffer: buffer, ptr: ptr, len: len, transducer: transducer,
     paraformer: paraformer, nemoCtc: nemoCtc, whisper: whisper, tdnn: tdnn
diff --git a/wasm/asr/sherpa-onnx-wasm-main-asr.cc b/wasm/asr/sherpa-onnx-wasm-main-asr.cc
index 70d13f1c4..de0cf1430 100644
--- a/wasm/asr/sherpa-onnx-wasm-main-asr.cc
+++ b/wasm/asr/sherpa-onnx-wasm-main-asr.cc
@@ -19,7 +19,7 @@ static_assert(sizeof(SherpaOnnxOnlineZipformer2CtcModelConfig) == 1 * 4, "");
 static_assert(sizeof(SherpaOnnxOnlineModelConfig) ==
                   sizeof(SherpaOnnxOnlineTransducerModelConfig) +
                       sizeof(SherpaOnnxOnlineParaformerModelConfig) +
-                      sizeof(SherpaOnnxOnlineZipformer2CtcModelConfig) + 5 * 4,
+                      sizeof(SherpaOnnxOnlineZipformer2CtcModelConfig) + 7 * 4,
               "");
 static_assert(sizeof(SherpaOnnxFeatureConfig) == 2 * 4, "");
 static_assert(sizeof(SherpaOnnxOnlineCtcFstDecoderConfig) == 2 * 4, "");
@@ -52,6 +52,8 @@ void MyPrint(SherpaOnnxOnlineRecognizerConfig *config) {
   fprintf(stdout, "provider: %s\n", model_config->provider);
   fprintf(stdout, "debug: %d\n", model_config->debug);
   fprintf(stdout, "model type: %s\n", model_config->model_type);
+  fprintf(stdout, "modeling unit: %s\n", model_config->modeling_unit);
+  fprintf(stdout, "bpe vocab: %s\n", model_config->bpe_vocab);
 
   fprintf(stdout, "----------feat config----------\n");
   fprintf(stdout, "sample rate: %d\n", feat->sample_rate);
diff --git a/wasm/nodejs/sherpa-onnx-wasm-nodejs.cc b/wasm/nodejs/sherpa-onnx-wasm-nodejs.cc
index 539699cc4..ceb5a2442 100644
--- a/wasm/nodejs/sherpa-onnx-wasm-nodejs.cc
+++ b/wasm/nodejs/sherpa-onnx-wasm-nodejs.cc
@@ -23,7 +23,7 @@ static_assert(sizeof(SherpaOnnxOfflineModelConfig) ==
                       sizeof(SherpaOnnxOfflineParaformerModelConfig) +
                       sizeof(SherpaOnnxOfflineNemoEncDecCtcModelConfig) +
                       sizeof(SherpaOnnxOfflineWhisperModelConfig) +
-                      sizeof(SherpaOnnxOfflineTdnnModelConfig) + 5 * 4,
+                      sizeof(SherpaOnnxOfflineTdnnModelConfig) + 7 * 4,
               "");
 static_assert(sizeof(SherpaOnnxFeatureConfig) == 2 * 4, "");
 static_assert(sizeof(SherpaOnnxOfflineRecognizerConfig) ==
@@ -90,6 +90,8 @@ void PrintOfflineRecognizerConfig(SherpaOnnxOfflineRecognizerConfig *config) {
   fprintf(stdout, "provider: %s\n", model_config->provider);
   fprintf(stdout, "debug: %d\n", model_config->debug);
   fprintf(stdout, "model type: %s\n", model_config->model_type);
+  fprintf(stdout, "modeling unit: %s\n", model_config->modeling_unit);
+  fprintf(stdout, "bpe vocab: %s\n", model_config->bpe_vocab);
 
   fprintf(stdout, "----------feat config----------\n");
   fprintf(stdout, "sample rate: %d\n", feat->sample_rate);
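
Usage note (not part of the patch): the sketch below is a hypothetical example of how a caller of the Go bindings changed above might set the two new fields. The import path, the file paths, and everything not visible in this diff (Tokens, Transducer, NumThreads, DeleteOnlineRecognizer) are assumptions taken from the existing bindings, not something this patch adds.

    package main

    import (
        sherpa "github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx" // assumed import path
    )

    func main() {
        config := sherpa.OnlineRecognizerConfig{}

        // Pre-existing fields; the paths are placeholders.
        config.ModelConfig.Tokens = "./tokens.txt"
        config.ModelConfig.Transducer.Encoder = "./encoder.onnx"
        config.ModelConfig.Transducer.Decoder = "./decoder.onnx"
        config.ModelConfig.Transducer.Joiner = "./joiner.onnx"
        config.ModelConfig.NumThreads = 1
        config.ModelConfig.Provider = "cpu"
        config.DecodingMethod = "modified_beam_search"

        // New fields from this diff. ModelingUnit accepts cjkchar, bpe, or
        // cjkchar+bpe; BpeVocab is presumably the path to the BPE vocabulary file.
        config.ModelConfig.ModelingUnit = "cjkchar+bpe"
        config.ModelConfig.BpeVocab = "./bpe.vocab"

        recognizer := sherpa.NewOnlineRecognizer(&config)
        defer sherpa.DeleteOnlineRecognizer(recognizer)

        // ... create streams and decode as usual ...
    }

The defaults introduced by the patch (ModelingUnit = "cjkchar", BpeVocab = "") keep existing callers that never touch the new fields working unchanged.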