Fangjun Kuang
Committed by GitHub

Add C++ runtime for SenseVoice models (#1148)

正在显示 34 个修改的文件 包含 1160 行增加39 行删除
... ... @@ -15,7 +15,30 @@ echo "PATH: $PATH"
which $EXE
if false; then
log "------------------------------------------------------------"
log "Run SenseVoice models"
log "------------------------------------------------------------"
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
tar xvf sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
rm sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
repo=sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17
for m in model.onnx model.int8.onnx; do
for w in zh en yue ja ko; do
for use_itn in 0 1; do
echo "$m $w $use_itn"
time $EXE \
--tokens=$repo/tokens.txt \
--sense-voice-model=$repo/$m \
--sense-voice-use-itn=$use_itn \
$repo/test_wavs/$w.wav
done
done
done
rm -rf $repo
if true; then
# It has problems with onnxruntime 1.18
log "------------------------------------------------------------"
log "Run Wenet models"
... ...
... ... @@ -10,6 +10,18 @@ log() {
export GIT_CLONE_PROTECTION_ACTIVE=false
log "test offline SenseVoice CTC"
url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
name=$(basename $url)
repo=$(basename -s .tar.bz2 $name)
curl -SL -O $url
tar xvf $name
rm $name
ls -lh $repo
python3 ./python-api-examples/offline-sense-voice-ctc-decode-files.py
rm -rf $repo
log "test offline TeleSpeech CTC"
url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04.tar.bz2
name=$(basename $url)
... ...
... ... @@ -73,7 +73,7 @@ jobs:
echo "pwd: $PWD"
ls -lh ../scripts/sense-voice
rm -rf ./
rm -rf ./*
cp -v ../scripts/sense-voice/*.onnx .
cp -v ../scripts/sense-voice/tokens.txt .
... ...
... ... @@ -111,3 +111,4 @@ sherpa-onnx-telespeech-ctc-*
*.fst
.ccache
lib*.a
sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17
... ...
## 1.10.17
* Support SenseVoice CTC models.
## 1.10.16
* Support zh-en TTS model from MeloTTS.
... ...
... ... @@ -11,7 +11,7 @@ project(sherpa-onnx)
# ./nodejs-addon-examples
# ./dart-api-examples/
# ./CHANGELOG.md
set(SHERPA_ONNX_VERSION "1.10.16")
set(SHERPA_ONNX_VERSION "1.10.17")
# Disable warning about
#
... ...
#!/usr/bin/env python3
"""
This file shows how to use a non-streaming SenseVoice CTC model from
https://github.com/FunAudioLLM/SenseVoice
to decode files.
Please download model files from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
For instance,
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
tar xvf sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
rm sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
"""
from pathlib import Path
import sherpa_onnx
import soundfile as sf
def create_recognizer():
model = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/model.int8.onnx"
tokens = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/tokens.txt"
test_wav = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/test_wavs/zh.wav"
# test_wav = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/test_wavs/en.wav"
# test_wav = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/test_wavs/ja.wav"
# test_wav = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/test_wavs/ko.wav"
# test_wav = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/test_wavs/yue.wav"
if not Path(model).is_file() or not Path(test_wav).is_file():
raise ValueError(
"""Please download model files from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
"""
)
return (
sherpa_onnx.OfflineRecognizer.from_sense_voice(
model=model,
tokens=tokens,
use_itn=True,
debug=True,
),
test_wav,
)
def main():
recognizer, wave_filename = create_recognizer()
audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
audio = audio[:, 0] # only use the first channel
# audio is a 1-D float32 numpy array normalized to the range [-1, 1]
# sample_rate does not need to be 16000 Hz
stream = recognizer.create_stream()
stream.accept_waveform(sample_rate, audio)
recognizer.decode_stream(stream)
print(wave_filename)
print(stream.result)
if __name__ == "__main__":
main()
... ...
... ... @@ -162,7 +162,9 @@ def main():
"neg_mean": neg_mean,
"inv_stddev": inv_stddev,
"model_type": "sense_voice_ctc",
"version": "1",
# version 1: Use QInt8
# version 2: Use QUInt8
"version": "2",
"model_author": "iic",
"maintainer": "k2-fsa",
"vocab_size": vocab_size,
... ... @@ -185,7 +187,10 @@ def main():
model_input=filename,
model_output=filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
# Note that we have to use QUInt8 here.
#
# When QInt8 is used, C++ onnxruntime produces incorrect results
weight_type=QuantType.QUInt8,
)
... ...
... ... @@ -310,6 +310,7 @@ struct SherpaOnnxOfflineStream {
static sherpa_onnx::OfflineRecognizerConfig convertConfig(
const SherpaOnnxOfflineRecognizerConfig *config);
SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer(
const SherpaOnnxOfflineRecognizerConfig *config) {
sherpa_onnx::OfflineRecognizerConfig recognizer_config =
... ... @@ -391,6 +392,15 @@ sherpa_onnx::OfflineRecognizerConfig convertConfig(
recognizer_config.model_config.telespeech_ctc =
SHERPA_ONNX_OR(config->model_config.telespeech_ctc, "");
recognizer_config.model_config.sense_voice.model =
SHERPA_ONNX_OR(config->model_config.sense_voice.model, "");
recognizer_config.model_config.sense_voice.language =
SHERPA_ONNX_OR(config->model_config.sense_voice.language, "");
recognizer_config.model_config.sense_voice.use_itn =
config->model_config.sense_voice.use_itn;
recognizer_config.lm_config.model =
SHERPA_ONNX_OR(config->lm_config.model, "");
recognizer_config.lm_config.scale =
... ...
... ... @@ -379,6 +379,12 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineLMConfig {
float scale;
} SherpaOnnxOfflineLMConfig;
SHERPA_ONNX_API typedef struct SherpaOnnxOfflineSenseVoiceModelConfig {
const char *model;
const char *language;
int32_t use_itn;
} SherpaOnnxOfflineSenseVoiceModelConfig;
SHERPA_ONNX_API typedef struct SherpaOnnxOfflineModelConfig {
SherpaOnnxOfflineTransducerModelConfig transducer;
SherpaOnnxOfflineParaformerModelConfig paraformer;
... ... @@ -398,6 +404,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineModelConfig {
const char *modeling_unit;
const char *bpe_vocab;
const char *telespeech_ctc;
SherpaOnnxOfflineSenseVoiceModelConfig sense_voice;
} SherpaOnnxOfflineModelConfig;
SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerConfig {
... ...
... ... @@ -36,6 +36,8 @@ set(sources
offline-recognizer-impl.cc
offline-recognizer.cc
offline-rnn-lm.cc
offline-sense-voice-model-config.cc
offline-sense-voice-model.cc
offline-stream.cc
offline-tdnn-ctc-model.cc
offline-tdnn-model-config.cc
... ...
// sherpa-onnx/csrc/offline-ct-transformer-model-meta_data.h
// sherpa-onnx/csrc/offline-ct-transformer-model-meta-data.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_CT_TRANSFORMER_MODEL_META_DATA_H_
... ...
... ... @@ -93,6 +93,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
const OfflineModelConfig &config) {
// TODO(fangjun): Refactor it. We don't need to use model_type here
ModelType model_type = ModelType::kUnknown;
std::string filename;
... ... @@ -148,6 +149,7 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
AAssetManager *mgr, const OfflineModelConfig &config) {
// TODO(fangjun): Refactor it. We don't need to use model_type here
ModelType model_type = ModelType::kUnknown;
std::string filename;
... ...
... ... @@ -18,6 +18,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
tdnn.Register(po);
zipformer_ctc.Register(po);
wenet_ctc.Register(po);
sense_voice.Register(po);
po->Register("telespeech-ctc", &telespeech_ctc,
"Path to model.onnx for telespeech ctc");
... ... @@ -94,15 +95,21 @@ bool OfflineModelConfig::Validate() const {
return wenet_ctc.Validate();
}
if (!sense_voice.model.empty()) {
return sense_voice.Validate();
}
if (!telespeech_ctc.empty() && !FileExists(telespeech_ctc)) {
SHERPA_ONNX_LOGE("telespeech_ctc: '%s' does not exist",
telespeech_ctc.c_str());
return false;
} else {
return true;
}
return transducer.Validate();
if (!transducer.encoder_filename.empty()) {
return transducer.Validate();
}
return true;
}
std::string OfflineModelConfig::ToString() const {
... ... @@ -116,6 +123,7 @@ std::string OfflineModelConfig::ToString() const {
os << "tdnn=" << tdnn.ToString() << ", ";
os << "zipformer_ctc=" << zipformer_ctc.ToString() << ", ";
os << "wenet_ctc=" << wenet_ctc.ToString() << ", ";
os << "sense_voice=" << sense_voice.ToString() << ", ";
os << "telespeech_ctc=\"" << telespeech_ctc << "\", ";
os << "tokens=\"" << tokens << "\", ";
os << "num_threads=" << num_threads << ", ";
... ...
... ... @@ -8,6 +8,7 @@
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h"
#include "sherpa-onnx/csrc/offline-paraformer-model-config.h"
#include "sherpa-onnx/csrc/offline-sense-voice-model-config.h"
#include "sherpa-onnx/csrc/offline-tdnn-model-config.h"
#include "sherpa-onnx/csrc/offline-transducer-model-config.h"
#include "sherpa-onnx/csrc/offline-wenet-ctc-model-config.h"
... ... @@ -24,6 +25,7 @@ struct OfflineModelConfig {
OfflineTdnnModelConfig tdnn;
OfflineZipformerCtcModelConfig zipformer_ctc;
OfflineWenetCtcModelConfig wenet_ctc;
OfflineSenseVoiceModelConfig sense_voice;
std::string telespeech_ctc;
std::string tokens;
... ... @@ -53,6 +55,7 @@ struct OfflineModelConfig {
const OfflineTdnnModelConfig &tdnn,
const OfflineZipformerCtcModelConfig &zipformer_ctc,
const OfflineWenetCtcModelConfig &wenet_ctc,
const OfflineSenseVoiceModelConfig &sense_voice,
const std::string &telespeech_ctc,
const std::string &tokens, int32_t num_threads, bool debug,
const std::string &provider, const std::string &model_type,
... ... @@ -65,6 +68,7 @@ struct OfflineModelConfig {
tdnn(tdnn),
zipformer_ctc(zipformer_ctc),
wenet_ctc(wenet_ctc),
sense_voice(sense_voice),
telespeech_ctc(telespeech_ctc),
tokens(tokens),
num_threads(num_threads),
... ...
... ... @@ -212,10 +212,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
}
}
OfflineRecognizerConfig GetConfig() const override {
return config_;
}
OfflineRecognizerConfig GetConfig() const override { return config_; }
private:
// Decode a single stream.
... ...
... ... @@ -21,6 +21,7 @@
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer-whisper-impl.h"
... ... @@ -31,6 +32,28 @@ namespace sherpa_onnx {
std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
const OfflineRecognizerConfig &config) {
if (!config.model_config.sense_voice.model.empty()) {
return std::make_unique<OfflineRecognizerSenseVoiceImpl>(config);
}
if (!config.model_config.paraformer.model.empty()) {
return std::make_unique<OfflineRecognizerParaformerImpl>(config);
}
if (!config.model_config.nemo_ctc.model.empty() ||
!config.model_config.zipformer_ctc.model.empty() ||
!config.model_config.tdnn.model.empty() ||
!config.model_config.wenet_ctc.model.empty()) {
return std::make_unique<OfflineRecognizerCtcImpl>(config);
}
if (!config.model_config.whisper.encoder.empty()) {
return std::make_unique<OfflineRecognizerWhisperImpl>(config);
}
// TODO(fangjun): Refactor it. We only need to use model type for the
// following models:
// 1. transducer and nemo_transducer
if (!config.model_config.model_type.empty()) {
const auto &model_type = config.model_config.model_type;
if (model_type == "transducer") {
... ... @@ -180,6 +203,28 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
#if __ANDROID_API__ >= 9
std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
AAssetManager *mgr, const OfflineRecognizerConfig &config) {
if (!config.model_config.sense_voice.model.empty()) {
return std::make_unique<OfflineRecognizerSenseVoiceImpl>(mgr, config);
}
if (!config.model_config.paraformer.model.empty()) {
return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config);
}
if (!config.model_config.nemo_ctc.model.empty() ||
!config.model_config.zipformer_ctc.model.empty() ||
!config.model_config.tdnn.model.empty() ||
!config.model_type.wenet_ctc.model.empty()) {
return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
}
if (!config.model_config.whisper.encoder.empty()) {
return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config);
}
// TODO(fangjun): Refactor it. We only need to use model type for the
// following models:
// 1. transducer and nemo_transducer
if (!config.model_config.model_type.empty()) {
const auto &model_type = config.model_config.model_type;
if (model_type == "transducer") {
... ...
... ... @@ -102,9 +102,7 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
exit(-1);
}
// Paraformer models assume input samples are in the range
// [-32768, 32767], so we set normalize_samples to false
config_.feat_config.normalize_samples = false;
InitFeatConfig();
}
#if __ANDROID_API__ >= 9
... ... @@ -124,9 +122,7 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
exit(-1);
}
// Paraformer models assume input samples are in the range
// [-32768, 32767], so we set normalize_samples to false
config_.feat_config.normalize_samples = false;
InitFeatConfig();
}
#endif
... ... @@ -211,11 +207,18 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
}
}
OfflineRecognizerConfig GetConfig() const override {
return config_;
}
OfflineRecognizerConfig GetConfig() const override { return config_; }
private:
void InitFeatConfig() {
// Paraformer models assume input samples are in the range
// [-32768, 32767], so we set normalize_samples to false
config_.feat_config.normalize_samples = false;
config_.feat_config.window_type = "hamming";
config_.feat_config.high_freq = 0;
config_.feat_config.snip_edges = true;
}
std::vector<float> ApplyLFR(const std::vector<float> &in) const {
int32_t lfr_window_size = model_->LfrWindowSize();
int32_t lfr_window_shift = model_->LfrWindowShift();
... ...
// sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h
//
// Copyright (c) 2022-2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_SENSE_VOICE_IMPL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_SENSE_VOICE_IMPL_H_
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h"
#include "sherpa-onnx/csrc/offline-model-config.h"
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer.h"
#include "sherpa-onnx/csrc/offline-sense-voice-model.h"
#include "sherpa-onnx/csrc/pad-sequence.h"
#include "sherpa-onnx/csrc/symbol-table.h"
namespace sherpa_onnx {
static OfflineRecognitionResult ConvertSenseVoiceResult(
const OfflineCtcDecoderResult &src, const SymbolTable &sym_table,
int32_t frame_shift_ms, int32_t subsampling_factor) {
OfflineRecognitionResult r;
r.tokens.reserve(src.tokens.size());
r.timestamps.reserve(src.timestamps.size());
std::string text;
for (int32_t i = 4; i < src.tokens.size(); ++i) {
auto sym = sym_table[src.tokens[i]];
text.append(sym);
r.tokens.push_back(std::move(sym));
}
r.text = std::move(text);
float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor;
for (int32_t i = 4; i < src.timestamps.size(); ++i) {
float time = frame_shift_s * (src.timestamps[i] - 4);
r.timestamps.push_back(time);
}
r.words = std::move(src.words);
return r;
}
class OfflineRecognizerSenseVoiceImpl : public OfflineRecognizerImpl {
public:
explicit OfflineRecognizerSenseVoiceImpl(
const OfflineRecognizerConfig &config)
: OfflineRecognizerImpl(config),
config_(config),
symbol_table_(config_.model_config.tokens),
model_(std::make_unique<OfflineSenseVoiceModel>(config.model_config)) {
const auto &meta_data = model_->GetModelMetadata();
if (config.decoding_method == "greedy_search") {
decoder_ =
std::make_unique<OfflineCtcGreedySearchDecoder>(meta_data.blank_id);
} else {
SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s",
config.decoding_method.c_str());
exit(-1);
}
InitFeatConfig();
}
#if __ANDROID_API__ >= 9
OfflineRecognizerSenseVoiceImpl(AAssetManager *mgr,
const OfflineRecognizerConfig &config)
: OfflineRecognizerImpl(mgr, config),
config_(config),
symbol_table_(mgr, config_.model_config.tokens),
model_(std::make_unique<OfflineSenseVoiceModel>(mgr,
config.model_config)) {
const auto &meta_data = model_->GetModelMetadata();
if (config.decoding_method == "greedy_search") {
decoder_ =
std::make_unique<OfflineCtcGreedySearchDecoder>(meta_data.blank_id);
} else {
SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s",
config.decoding_method.c_str());
exit(-1);
}
InitFeatConfig();
}
#endif
std::unique_ptr<OfflineStream> CreateStream() const override {
return std::make_unique<OfflineStream>(config_.feat_config);
}
void DecodeStreams(OfflineStream **ss, int32_t n) const override {
if (n == 1) {
DecodeOneStream(ss[0]);
return;
}
const auto &meta_data = model_->GetModelMetadata();
// 1. Apply LFR
// 2. Apply CMVN
//
// Please refer to
// https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/45555.pdf
// for what LFR means
//
// "Lower Frame Rate Neural Network Acoustic Models"
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
std::vector<Ort::Value> features;
features.reserve(n);
int32_t feat_dim = config_.feat_config.feature_dim * meta_data.window_size;
std::vector<std::vector<float>> features_vec(n);
std::vector<int32_t> features_length_vec(n);
for (int32_t i = 0; i != n; ++i) {
std::vector<float> f = ss[i]->GetFrames();
f = ApplyLFR(f);
ApplyCMVN(&f);
int32_t num_frames = f.size() / feat_dim;
features_vec[i] = std::move(f);
features_length_vec[i] = num_frames;
std::array<int64_t, 2> shape = {num_frames, feat_dim};
Ort::Value x = Ort::Value::CreateTensor(
memory_info, features_vec[i].data(), features_vec[i].size(),
shape.data(), shape.size());
features.push_back(std::move(x));
}
std::vector<const Ort::Value *> features_pointer(n);
for (int32_t i = 0; i != n; ++i) {
features_pointer[i] = &features[i];
}
std::array<int64_t, 1> features_length_shape = {n};
Ort::Value x_length = Ort::Value::CreateTensor(
memory_info, features_length_vec.data(), n,
features_length_shape.data(), features_length_shape.size());
// Caution(fangjun): We cannot pad it with log(eps),
// i.e., -23.025850929940457f
Ort::Value x = PadSequence(model_->Allocator(), features_pointer, 0);
int32_t language = 0;
if (config_.model_config.sense_voice.language.empty()) {
language = 0;
} else if (meta_data.lang2id.count(
config_.model_config.sense_voice.language)) {
language =
meta_data.lang2id.at(config_.model_config.sense_voice.language);
} else {
SHERPA_ONNX_LOGE("Unknown language: %s. Use 0 instead.",
config_.model_config.sense_voice.language.c_str());
}
std::vector<int32_t> language_array(n);
std::fill(language_array.begin(), language_array.end(), language);
std::vector<int32_t> text_norm_array(n);
std::fill(text_norm_array.begin(), text_norm_array.end(),
config_.model_config.sense_voice.use_itn
? meta_data.with_itn_id
: meta_data.without_itn_id);
Ort::Value language_tensor = Ort::Value::CreateTensor(
memory_info, language_array.data(), n, features_length_shape.data(),
features_length_shape.size());
Ort::Value text_norm_tensor = Ort::Value::CreateTensor(
memory_info, text_norm_array.data(), n, features_length_shape.data(),
features_length_shape.size());
Ort::Value logits{nullptr};
try {
logits = model_->Forward(std::move(x), std::move(x_length),
std::move(language_tensor),
std::move(text_norm_tensor));
} catch (const Ort::Exception &ex) {
SHERPA_ONNX_LOGE("\n\nCaught exception:\n\n%s\n\nReturn an empty result",
ex.what());
return;
}
// decoder_->Decode() requires that logits_length is of dtype int64
std::vector<int64_t> features_length_vec_64;
features_length_vec_64.reserve(n);
for (auto i : features_length_vec) {
i += 4;
features_length_vec_64.push_back(i);
}
Ort::Value logits_length = Ort::Value::CreateTensor(
memory_info, features_length_vec_64.data(), n,
features_length_shape.data(), features_length_shape.size());
auto results =
decoder_->Decode(std::move(logits), std::move(logits_length));
int32_t frame_shift_ms = 10;
int32_t subsampling_factor = meta_data.window_shift;
for (int32_t i = 0; i != n; ++i) {
auto r = ConvertSenseVoiceResult(results[i], symbol_table_,
frame_shift_ms, subsampling_factor);
r.text = ApplyInverseTextNormalization(std::move(r.text));
ss[i]->SetResult(r);
}
}
OfflineRecognizerConfig GetConfig() const override { return config_; }
private:
void DecodeOneStream(OfflineStream *s) const {
const auto &meta_data = model_->GetModelMetadata();
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
int32_t feat_dim = config_.feat_config.feature_dim * meta_data.window_size;
std::vector<float> f = s->GetFrames();
f = ApplyLFR(f);
ApplyCMVN(&f);
int32_t num_frames = f.size() / feat_dim;
std::array<int64_t, 3> shape = {1, num_frames, feat_dim};
Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(),
shape.data(), shape.size());
int64_t scale_shape = 1;
Ort::Value x_length =
Ort::Value::CreateTensor(memory_info, &num_frames, 1, &scale_shape, 1);
int32_t language = 0;
if (config_.model_config.sense_voice.language.empty()) {
language = 0;
} else if (meta_data.lang2id.count(
config_.model_config.sense_voice.language)) {
language =
meta_data.lang2id.at(config_.model_config.sense_voice.language);
} else {
SHERPA_ONNX_LOGE("Unknown language: %s. Use 0 instead.",
config_.model_config.sense_voice.language.c_str());
}
int32_t text_norm = config_.model_config.sense_voice.use_itn
? meta_data.with_itn_id
: meta_data.without_itn_id;
Ort::Value language_tensor =
Ort::Value::CreateTensor(memory_info, &language, 1, &scale_shape, 1);
Ort::Value text_norm_tensor =
Ort::Value::CreateTensor(memory_info, &text_norm, 1, &scale_shape, 1);
Ort::Value logits{nullptr};
try {
logits = model_->Forward(std::move(x), std::move(x_length),
std::move(language_tensor),
std::move(text_norm_tensor));
} catch (const Ort::Exception &ex) {
SHERPA_ONNX_LOGE("\n\nCaught exception:\n\n%s\n\nReturn an empty result",
ex.what());
return;
}
int64_t new_num_frames = num_frames + 4;
Ort::Value logits_length = Ort::Value::CreateTensor(
memory_info, &new_num_frames, 1, &scale_shape, 1);
auto results =
decoder_->Decode(std::move(logits), std::move(logits_length));
int32_t frame_shift_ms = 10;
int32_t subsampling_factor = meta_data.window_shift;
auto r = ConvertSenseVoiceResult(results[0], symbol_table_, frame_shift_ms,
subsampling_factor);
r.text = ApplyInverseTextNormalization(std::move(r.text));
s->SetResult(r);
}
void InitFeatConfig() {
const auto &meta_data = model_->GetModelMetadata();
config_.feat_config.normalize_samples = meta_data.normalize_samples;
config_.feat_config.window_type = "hamming";
config_.feat_config.high_freq = 0;
config_.feat_config.snip_edges = true;
}
std::vector<float> ApplyLFR(const std::vector<float> &in) const {
const auto &meta_data = model_->GetModelMetadata();
int32_t lfr_window_size = meta_data.window_size;
int32_t lfr_window_shift = meta_data.window_shift;
int32_t in_feat_dim = config_.feat_config.feature_dim;
int32_t in_num_frames = in.size() / in_feat_dim;
int32_t out_num_frames =
(in_num_frames - lfr_window_size) / lfr_window_shift + 1;
int32_t out_feat_dim = in_feat_dim * lfr_window_size;
std::vector<float> out(out_num_frames * out_feat_dim);
const float *p_in = in.data();
float *p_out = out.data();
for (int32_t i = 0; i != out_num_frames; ++i) {
std::copy(p_in, p_in + out_feat_dim, p_out);
p_out += out_feat_dim;
p_in += lfr_window_shift * in_feat_dim;
}
return out;
}
void ApplyCMVN(std::vector<float> *v) const {
const auto &meta_data = model_->GetModelMetadata();
const std::vector<float> &neg_mean = meta_data.neg_mean;
const std::vector<float> &inv_stddev = meta_data.inv_stddev;
int32_t dim = neg_mean.size();
int32_t num_frames = v->size() / dim;
float *p = v->data();
for (int32_t i = 0; i != num_frames; ++i) {
for (int32_t k = 0; k != dim; ++k) {
p[k] = (p[k] + neg_mean[k]) * inv_stddev[k];
}
p += dim;
}
}
OfflineRecognizerConfig config_;
SymbolTable symbol_table_;
std::unique_ptr<OfflineSenseVoiceModel> model_;
std::unique_ptr<OfflineCtcDecoder> decoder_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_SENSE_VOICE_IMPL_H_
... ...
// sherpa-onnx/csrc/offline-sense-voice-model-config.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-sense-voice-model-config.h"
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
void OfflineSenseVoiceModelConfig::Register(ParseOptions *po) {
po->Register("sense-voice-model", &model,
"Path to model.onnx of SenseVoice.");
po->Register(
"sense-voice-language", &language,
"Valid values: auto, zh, en, ja, ko, yue. If left empty, auto is used");
po->Register(
"sense-voice-use-itn", &use_itn,
"True to enable inverse text normalization. False to disable it.");
}
bool OfflineSenseVoiceModelConfig::Validate() const {
if (!FileExists(model)) {
SHERPA_ONNX_LOGE("SenseVoice model '%s' does not exist", model.c_str());
return false;
}
if (!language.empty()) {
if (language != "auto" && language != "zh" && language != "en" &&
language != "ja" && language != "ko" && language != "yue") {
SHERPA_ONNX_LOGE(
"Invalid sense-voice-language: '%s'. Valid values are: auto, zh, en, "
"ja, ko, yue. Or you can leave it empty to use 'auto'",
language.c_str());
return false;
}
}
return true;
}
std::string OfflineSenseVoiceModelConfig::ToString() const {
std::ostringstream os;
os << "OfflineSenseVoiceModelConfig(";
os << "model=\"" << model << "\", ";
os << "language=\"" << language << "\", ";
os << "use_itn=" << (use_itn ? "True" : "False") << ")";
return os.str();
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-sense-voice-model-config.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_CONFIG_H_
#include <string>
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
struct OfflineSenseVoiceModelConfig {
std::string model;
// "" or "auto" to let the model recognize the language
// valid values:
// zh, en, ja, ko, yue, auto
std::string language = "auto";
// true to use inverse text normalization
// false to not use inverse text normalization
bool use_itn = false;
OfflineSenseVoiceModelConfig() = default;
explicit OfflineSenseVoiceModelConfig(const std::string &model,
const std::string &language,
bool use_itn)
: model(model), language(language), use_itn(use_itn) {}
void Register(ParseOptions *po);
bool Validate() const;
std::string ToString() const;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_CONFIG_H_
... ...
// sherpa-onnx/csrc/offline-sense-voice-model-meta-data.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_META_DATA_H_
#define SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_META_DATA_H_
#include <string>
#include <unordered_map>
#include <vector>
namespace sherpa_onnx {
struct OfflineSenseVoiceModelMetaData {
// ID for using inverse text normalization
int32_t with_itn_id;
// ID for not using inverse text normalization
int32_t without_itn_id;
int32_t window_size; // lfr_m
int32_t window_shift; // lfr_n
int32_t vocab_size;
int32_t subsampling_factor = 1;
// Usually 0 for SenseVoice models.
// 0 means samples are scaled to [-32768, 32767] before are sent to the
// feature extractor
int32_t normalize_samples = 0;
int32_t blank_id = 0;
// possible values:
// zh, en, ja, ko, yue, auto
// where
// zh is Chinese (Mandarin)
// en is English
// ja is Japanese
// ko is Korean
// yue is Cantonese
// auto is to let the model recognize the language
std::unordered_map<std::string, int32_t> lang2id;
std::vector<float> neg_mean;
std::vector<float> inv_stddev;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_META_DATA_H_
... ...
// sherpa-onnx/csrc/offline-sense-voice-model.cc
//
// Copyright (c) 2022-2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-sense-voice-model.h"
#include <algorithm>
#include <string>
#include <utility>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
class OfflineSenseVoiceModel::Impl {
public:
explicit Impl(const OfflineModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
auto buf = ReadFile(config_.sense_voice.model);
Init(buf.data(), buf.size());
}
#if __ANDROID_API__ >= 9
Impl(AAssetManager *mgr, const OfflineModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
auto buf = ReadFile(mgr, config_.sense_voice.model);
Init(buf.data(), buf.size());
}
#endif
Ort::Value Forward(Ort::Value features, Ort::Value features_length,
Ort::Value language, Ort::Value text_norm) {
std::array<Ort::Value, 4> inputs = {
std::move(features),
std::move(features_length),
std::move(language),
std::move(text_norm),
};
auto ans =
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
output_names_ptr_.data(), output_names_ptr_.size());
return std::move(ans[0]);
}
const OfflineSenseVoiceModelMetaData &GetModelMetadata() const {
return meta_data_;
}
OrtAllocator *Allocator() const { return allocator_; }
private:
void Init(void *model_data, size_t model_data_length) {
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
sess_opts_);
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
// get meta data
Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
if (config_.debug) {
std::ostringstream os;
PrintModelMetadata(os, meta_data);
SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
}
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
SHERPA_ONNX_READ_META_DATA(meta_data_.vocab_size, "vocab_size");
SHERPA_ONNX_READ_META_DATA(meta_data_.window_size, "lfr_window_size");
SHERPA_ONNX_READ_META_DATA(meta_data_.window_shift, "lfr_window_shift");
SHERPA_ONNX_READ_META_DATA(meta_data_.normalize_samples,
"normalize_samples");
SHERPA_ONNX_READ_META_DATA(meta_data_.with_itn_id, "with_itn");
SHERPA_ONNX_READ_META_DATA(meta_data_.without_itn_id, "without_itn");
int32_t lang_auto = 0;
int32_t lang_zh = 0;
int32_t lang_en = 0;
int32_t lang_ja = 0;
int32_t lang_ko = 0;
int32_t lang_yue = 0;
SHERPA_ONNX_READ_META_DATA(lang_auto, "lang_auto");
SHERPA_ONNX_READ_META_DATA(lang_zh, "lang_zh");
SHERPA_ONNX_READ_META_DATA(lang_en, "lang_en");
SHERPA_ONNX_READ_META_DATA(lang_ja, "lang_ja");
SHERPA_ONNX_READ_META_DATA(lang_ko, "lang_ko");
SHERPA_ONNX_READ_META_DATA(lang_yue, "lang_yue");
meta_data_.lang2id = {
{"auto", lang_auto}, {"zh", lang_zh}, {"ja", lang_ja},
{"ko", lang_ko}, {"yue", lang_yue},
};
SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(meta_data_.neg_mean, "neg_mean");
SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(meta_data_.inv_stddev, "inv_stddev");
}
private:
OfflineModelConfig config_;
Ort::Env env_;
Ort::SessionOptions sess_opts_;
Ort::AllocatorWithDefaultOptions allocator_;
std::unique_ptr<Ort::Session> sess_;
std::vector<std::string> input_names_;
std::vector<const char *> input_names_ptr_;
std::vector<std::string> output_names_;
std::vector<const char *> output_names_ptr_;
OfflineSenseVoiceModelMetaData meta_data_;
};
OfflineSenseVoiceModel::OfflineSenseVoiceModel(const OfflineModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
#if __ANDROID_API__ >= 9
OfflineSenseVoiceModel::OfflineSenseVoiceModel(AAssetManager *mgr,
const OfflineModelConfig &config)
: impl_(std::make_unique<Impl>(mgr, config)) {}
#endif
OfflineSenseVoiceModel::~OfflineSenseVoiceModel() = default;
Ort::Value OfflineSenseVoiceModel::Forward(Ort::Value features,
Ort::Value features_length,
Ort::Value language,
Ort::Value text_norm) const {
return impl_->Forward(std::move(features), std::move(features_length),
std::move(language), std::move(text_norm));
}
const OfflineSenseVoiceModelMetaData &OfflineSenseVoiceModel::GetModelMetadata()
const {
return impl_->GetModelMetadata();
}
OrtAllocator *OfflineSenseVoiceModel::Allocator() const {
return impl_->Allocator();
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-sense-voice-model.h
//
// Copyright (c) 2022-2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_H_
#include <memory>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-model-config.h"
#include "sherpa-onnx/csrc/offline-sense-voice-model-meta-data.h"
namespace sherpa_onnx {
class OfflineSenseVoiceModel {
public:
explicit OfflineSenseVoiceModel(const OfflineModelConfig &config);
#if __ANDROID_API__ >= 9
OfflineSenseVoiceModel(AAssetManager *mgr, const OfflineModelConfig &config);
#endif
~OfflineSenseVoiceModel();
/** Run the forward method of the model.
*
* @param features A tensor of shape (N, T, C). It is changed in-place.
* @param features_length A 1-D tensor of shape (N,) containing number of
* valid frames in `features` before padding.
* Its dtype is int32_t.
* @param language A 1-D tensor of shape (N,) with dtype int32_t
* @param text_norm A 1-D tensor of shape (N,) with dtype int32_t
*
* @return Return logits of shape (N, T, C) with dtype float
*
* Note: The subsampling factor is 1 for SenseVoice, so there is
* no need to output logits_length.
*/
Ort::Value Forward(Ort::Value features, Ort::Value features_length,
Ort::Value language, Ort::Value text_norm) const;
const OfflineSenseVoiceModelMetaData &GetModelMetadata() const;
/** Return an allocator for allocating memory
*/
OrtAllocator *Allocator() const;
private:
class Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_H_
... ...
... ... @@ -6,6 +6,8 @@
#include <algorithm>
#include <fstream>
#include <functional>
#include <numeric>
#include <sstream>
#include <string>
... ... @@ -153,23 +155,60 @@ Ort::Value View(Ort::Value *v) {
}
}
float ComputeSum(const Ort::Value *v, int32_t n /*= -1*/) {
std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
auto size = static_cast<int32_t>(std::accumulate(
shape.begin(), shape.end(), 1, std::multiplies<int64_t>()));
if (n != -1 && n < size && n > 0) {
size = n;
}
const float *p = v->GetTensorData<float>();
return std::accumulate(p, p + size, 1.0f);
}
float ComputeMean(const Ort::Value *v, int32_t n /*= -1*/) {
std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
auto size = static_cast<int32_t>(std::accumulate(
shape.begin(), shape.end(), 1, std::multiplies<int64_t>()));
if (n != -1 && n < size && n > 0) {
size = n;
}
auto sum = ComputeSum(v, n);
return sum / size;
}
void PrintShape(const Ort::Value *v) {
std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
std::ostringstream os;
for (auto i : shape) {
os << i << ", ";
}
os << "\n";
fprintf(stderr, "%s", os.str().c_str());
}
template <typename T /*= float*/>
void Print1D(Ort::Value *v) {
void Print1D(const Ort::Value *v) {
std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
const T *d = v->GetTensorData<T>();
std::ostringstream os;
for (int32_t i = 0; i != static_cast<int32_t>(shape[0]); ++i) {
os << *d << " ";
os << d[i] << " ";
}
os << "\n";
fprintf(stderr, "%s\n", os.str().c_str());
}
template void Print1D<int64_t>(Ort::Value *v);
template void Print1D<float>(Ort::Value *v);
template void Print1D<int64_t>(const Ort::Value *v);
template void Print1D<int32_t>(const Ort::Value *v);
template void Print1D<float>(const Ort::Value *v);
template <typename T /*= float*/>
void Print2D(Ort::Value *v) {
void Print2D(const Ort::Value *v) {
std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
const T *d = v->GetTensorData<T>();
... ... @@ -183,10 +222,10 @@ void Print2D(Ort::Value *v) {
fprintf(stderr, "%s\n", os.str().c_str());
}
template void Print2D<int64_t>(Ort::Value *v);
template void Print2D<float>(Ort::Value *v);
template void Print2D<int64_t>(const Ort::Value *v);
template void Print2D<float>(const Ort::Value *v);
void Print3D(Ort::Value *v) {
void Print3D(const Ort::Value *v) {
std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
const float *d = v->GetTensorData<float>();
... ... @@ -202,7 +241,7 @@ void Print3D(Ort::Value *v) {
fprintf(stderr, "\n");
}
void Print4D(Ort::Value *v) {
void Print4D(const Ort::Value *v) {
std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
const float *d = v->GetTensorData<float>();
... ...
... ... @@ -68,19 +68,24 @@ Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v);
// Return a shallow copy
Ort::Value View(Ort::Value *v);
float ComputeSum(const Ort::Value *v, int32_t n = -1);
float ComputeMean(const Ort::Value *v, int32_t n = -1);
// Print a 1-D tensor to stderr
template <typename T = float>
void Print1D(Ort::Value *v);
void Print1D(const Ort::Value *v);
// Print a 2-D tensor to stderr
template <typename T = float>
void Print2D(Ort::Value *v);
void Print2D(const Ort::Value *v);
// Print a 3-D tensor to stderr
void Print3D(Ort::Value *v);
void Print3D(const Ort::Value *v);
// Print a 4-D tensor to stderr
void Print4D(Ort::Value *v);
void Print4D(const Ort::Value *v);
void PrintShape(const Ort::Value *v);
template <typename T = float>
void Fill(Ort::Value *tensor, T value) {
... ...
... ... @@ -15,6 +15,7 @@ set(srcs
offline-paraformer-model-config.cc
offline-punctuation.cc
offline-recognizer.cc
offline-sense-voice-model-config.cc
offline-stream.cc
offline-tdnn-model-config.cc
offline-transducer-model-config.cc
... ...
... ... @@ -10,6 +10,7 @@
#include "sherpa-onnx/csrc/offline-model-config.h"
#include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h"
#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"
#include "sherpa-onnx/python/csrc/offline-sense-voice-model-config.h"
#include "sherpa-onnx/python/csrc/offline-tdnn-model-config.h"
#include "sherpa-onnx/python/csrc/offline-transducer-model-config.h"
#include "sherpa-onnx/python/csrc/offline-wenet-ctc-model-config.h"
... ... @@ -26,6 +27,7 @@ void PybindOfflineModelConfig(py::module *m) {
PybindOfflineTdnnModelConfig(m);
PybindOfflineZipformerCtcModelConfig(m);
PybindOfflineWenetCtcModelConfig(m);
PybindOfflineSenseVoiceModelConfig(m);
using PyClass = OfflineModelConfig;
py::class_<PyClass>(*m, "OfflineModelConfig")
... ... @@ -36,7 +38,8 @@ void PybindOfflineModelConfig(py::module *m) {
const OfflineNemoEncDecCtcModelConfig &,
const OfflineWhisperModelConfig &, const OfflineTdnnModelConfig &,
const OfflineZipformerCtcModelConfig &,
const OfflineWenetCtcModelConfig &, const std::string &,
const OfflineWenetCtcModelConfig &,
const OfflineSenseVoiceModelConfig &, const std::string &,
const std::string &, int32_t, bool, const std::string &,
const std::string &, const std::string &, const std::string &>(),
py::arg("transducer") = OfflineTransducerModelConfig(),
... ... @@ -46,6 +49,7 @@ void PybindOfflineModelConfig(py::module *m) {
py::arg("tdnn") = OfflineTdnnModelConfig(),
py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(),
py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(),
py::arg("sense_voice") = OfflineSenseVoiceModelConfig(),
py::arg("telespeech_ctc") = "", py::arg("tokens"),
py::arg("num_threads"), py::arg("debug") = false,
py::arg("provider") = "cpu", py::arg("model_type") = "",
... ... @@ -57,6 +61,7 @@ void PybindOfflineModelConfig(py::module *m) {
.def_readwrite("tdnn", &PyClass::tdnn)
.def_readwrite("zipformer_ctc", &PyClass::zipformer_ctc)
.def_readwrite("wenet_ctc", &PyClass::wenet_ctc)
.def_readwrite("sense_voice", &PyClass::sense_voice)
.def_readwrite("telespeech_ctc", &PyClass::telespeech_ctc)
.def_readwrite("tokens", &PyClass::tokens)
.def_readwrite("num_threads", &PyClass::num_threads)
... ...
... ... @@ -14,6 +14,7 @@ namespace sherpa_onnx {
void PybindOfflineParaformerModelConfig(py::module *m) {
using PyClass = OfflineParaformerModelConfig;
py::class_<PyClass>(*m, "OfflineParaformerModelConfig")
.def(py::init<>())
.def(py::init<const std::string &>(), py::arg("model"))
.def_readwrite("model", &PyClass::model)
.def("__str__", &PyClass::ToString);
... ...
// sherpa-onnx/python/csrc/offline-sense-voice-model-config.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-sense-voice-model-config.h"
#include <string>
#include <vector>
#include "sherpa-onnx/python/csrc/offline-sense-voice-model-config.h"
namespace sherpa_onnx {
void PybindOfflineSenseVoiceModelConfig(py::module *m) {
using PyClass = OfflineSenseVoiceModelConfig;
py::class_<PyClass>(*m, "OfflineSenseVoiceModelConfig")
.def(py::init<>())
.def(py::init<const std::string &, const std::string &, bool>(),
py::arg("model"), py::arg("language"), py::arg("use_itn"))
.def_readwrite("model", &PyClass::model)
.def_readwrite("language", &PyClass::language)
.def_readwrite("use_itn", &PyClass::use_itn)
.def("__str__", &PyClass::ToString);
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/python/csrc/offline-sense-voice-model-config.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SENSE_VOICE_MODEL_CONFIG_H_
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SENSE_VOICE_MODEL_CONFIG_H_
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
namespace sherpa_onnx {
void PybindOfflineSenseVoiceModelConfig(py::module *m);
}
#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SENSE_VOICE_MODEL_CONFIG_H_
... ...
... ... @@ -10,6 +10,7 @@ from _sherpa_onnx import (
OfflineModelConfig,
OfflineNemoEncDecCtcModelConfig,
OfflineParaformerModelConfig,
OfflineSenseVoiceModelConfig,
)
from _sherpa_onnx import OfflineRecognizer as _Recognizer
from _sherpa_onnx import (
... ... @@ -174,6 +175,88 @@ class OfflineRecognizer(object):
return self
@classmethod
def from_sense_voice(
cls,
model: str,
tokens: str,
num_threads: int = 1,
sample_rate: int = 16000,
feature_dim: int = 80,
decoding_method: str = "greedy_search",
debug: bool = False,
provider: str = "cpu",
language: str = "",
use_itn: bool = False,
rule_fsts: str = "",
rule_fars: str = "",
):
"""
Please refer to
`<https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models>`_
to download pre-trained models.
Args:
tokens:
Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
columns::
symbol integer_id
model:
Path to ``model.onnx``.
num_threads:
Number of threads for neural network computation.
sample_rate:
Sample rate of the training data used to train the model.
feature_dim:
Dimension of the feature used to train the model.
decoding_method:
Valid values are greedy_search.
debug:
True to show debug messages.
provider:
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
language:
If not empty, then valid values are: auto, zh, en, ja, ko, yue
use_itn:
True to enable inverse text normalization; False to disable it.
rule_fsts:
If not empty, it specifies fsts for inverse text normalization.
If there are multiple fsts, they are separated by a comma.
rule_fars:
If not empty, it specifies fst archives for inverse text normalization.
If there are multiple archives, they are separated by a comma.
"""
self = cls.__new__(cls)
model_config = OfflineModelConfig(
sense_voice=OfflineSenseVoiceModelConfig(
model=model,
language=language,
use_itn=use_itn,
),
tokens=tokens,
num_threads=num_threads,
debug=debug,
provider=provider,
)
feat_config = FeatureExtractorConfig(
sampling_rate=sample_rate,
feature_dim=feature_dim,
)
recognizer_config = OfflineRecognizerConfig(
feat_config=feat_config,
model_config=model_config,
decoding_method=decoding_method,
rule_fsts=rule_fsts,
rule_fars=rule_fars,
)
self.recognizer = _Recognizer(recognizer_config)
self.config = recognizer_config
return self
@classmethod
def from_paraformer(
cls,
paraformer: str,
... ...
... ... @@ -355,6 +355,18 @@ func sherpaOnnxOfflineTdnnModelConfig(
)
}
func sherpaOnnxOfflineSenseVoiceModelConfig(
model: String = "",
language: String = "",
useInverseTextNormalization: Bool = false
) -> SherpaOnnxOfflineSenseVoiceModelConfig {
return SherpaOnnxOfflineSenseVoiceModelConfig(
model: toCPointer(model),
language: toCPointer(language),
use_itn: useInverseTextNormalization ? 1 : 0
)
}
func sherpaOnnxOfflineLMConfig(
model: String = "",
scale: Float = 1.0
... ... @@ -378,7 +390,8 @@ func sherpaOnnxOfflineModelConfig(
modelType: String = "",
modelingUnit: String = "cjkchar",
bpeVocab: String = "",
teleSpeechCtc: String = ""
teleSpeechCtc: String = "",
senseVoice: SherpaOnnxOfflineSenseVoiceModelConfig = sherpaOnnxOfflineSenseVoiceModelConfig()
) -> SherpaOnnxOfflineModelConfig {
return SherpaOnnxOfflineModelConfig(
transducer: transducer,
... ... @@ -393,7 +406,8 @@ func sherpaOnnxOfflineModelConfig(
model_type: toCPointer(modelType),
modeling_unit: toCPointer(modelingUnit),
bpe_vocab: toCPointer(bpeVocab),
telespeech_ctc: toCPointer(teleSpeechCtc)
telespeech_ctc: toCPointer(teleSpeechCtc),
sense_voice: senseVoice
)
}
... ...
... ... @@ -17,6 +17,7 @@ func run() {
var modelConfig: SherpaOnnxOfflineModelConfig
var modelType = "whisper"
// modelType = "paraformer"
// modelType = "sense_voice"
if modelType == "whisper" {
let encoder = "./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx"
... ... @@ -47,6 +48,19 @@ func run() {
debug: 0,
modelType: "paraformer"
)
} else if modelType == "sense_voice" {
let model = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/model.int8.onnx"
let tokens = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/tokens.txt"
let senseVoiceConfig = sherpaOnnxOfflineSenseVoiceModelConfig(
model: model,
useInverseTextNormalization: true
)
modelConfig = sherpaOnnxOfflineModelConfig(
tokens: tokens,
debug: 0,
senseVoice: senseVoiceConfig
)
} else {
print("Please specify a supported modelType \(modelType)")
return
... ... @@ -63,7 +77,10 @@ func run() {
recognizer = SherpaOnnxOfflineRecognizer(config: &config)
let filePath = "./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav"
var filePath = "./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav"
if modelType == "sense_voice" {
filePath = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/test_wavs/zh.wav"
}
let fileURL: NSURL = NSURL(fileURLWithPath: filePath)
let audioFile = try! AVAudioFile(forReading: fileURL as URL)
... ...