Fangjun Kuang
Committed by GitHub

Support spoken language identification with whisper (#694)

正在显示 36 个修改的文件 包含 1173 行增加200 行删除
#!/usr/bin/env bash
set -e
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
echo "EXE is $EXE"
echo "PATH: $PATH"
which $EXE
names=(
tiny
base
small
medium
)
# all_language_codes=bo,ml,tt,fa,sl,bg,sn,sr,tl,km,ln,mr,hr,eu,ro,ba,bs,pl,as,nn,sk,ko,oc,ar,uz,pa,tg,mk,kk,hi,ha,uk,is,de,el,ja,yo,be,so,tk,id,sa,ru,yi,en,am,cs,ne,la,sv,su,pt,mi,ca,sd,hy,haw,fi,et,kn,da,lt,it,nl,he,mg,ur,tr,af,br,bn,ta,no,my,si,mt,th,gl,sw,mn,jw,ms,ps,fo,ka,hu,zh,ht,az,fr,lo,sq,gu,cy,lv,es,lb,te,vi
log "Download test waves"
waves=(
ar-arabic.wav
bg-bulgarian.wav
cs-czech.wav
da-danish.wav
de-german.wav
el-greek.wav
en-english.wav
es-spanish.wav
fa-persian.wav
fi-finnish.wav
fr-french.wav
hi-hindi.wav
hr-croatian.wav
id-indonesian.wav
it-italian.wav
ja-japanese.wav
ko-korean.wav
nl-dutch.wav
no-norwegian.wav
po-polish.wav
pt-portuguese.wav
ro-romanian.wav
ru-russian.wav
sk-slovak.wav
sv-swedish.wav
ta-tamil.wav
tl-tagalog.wav
tr-turkish.wav
uk-ukrainian.wav
zh-chinese.wav
)
for wav in ${waves[@]}; do
echo "Downloading $wav"
curl -SL -O https://hf-mirror.com/spaces/k2-fsa/spoken-language-identification/resolve/main/test_wavs/$wav
ls -lh *.wav
done
for name in ${names[@]}; do
log "------------------------------------------------------------"
log "Run $name"
log "------------------------------------------------------------"
repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-whisper-$name
log "Start testing ${repo_url}"
repo=$(basename $repo_url)
log "Download pretrained model and test-data from $repo_url"
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
pushd $repo
git lfs pull --include "*.onnx"
# git lfs pull --include "*.ort"
ls -lh *.onnx
popd
for wav in ${waves[@]}; do
log "test fp32 onnx"
time $EXE \
--whisper-encoder=$repo/${name}-encoder.onnx \
--whisper-decoder=$repo/${name}-decoder.onnx \
$wav
log "test int8 onnx"
time $EXE \
--whisper-encoder=$repo/${name}-encoder.int8.onnx \
--whisper-decoder=$repo/${name}-decoder.int8.onnx \
$wav
done
rm -rf $repo
done
... ...
... ... @@ -82,7 +82,6 @@ jobs:
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
uses: nick-fields/retry@v3
shell: bash
with:
max_attempts: 20
timeout_seconds: 200
... ...
... ... @@ -21,27 +21,12 @@ jobs:
fail-fast: false
matrix:
os: [macos-latest]
python-version: ["cp37", "cp38", "cp39", "cp310", "cp311", "cp312"]
python-version: ["cp38", "cp39", "cp310", "cp311", "cp312"]
steps:
- uses: actions/checkout@v4
# see https://cibuildwheel.readthedocs.io/en/stable/changelog/
# for a list of versions
- name: Build wheels
if: matrix.python-version == 'cp37'
uses: pypa/cibuildwheel@v2.11.4
env:
CIBW_BUILD: "${{ matrix.python-version}}-* "
CIBW_ENVIRONMENT: SHERPA_ONNX_CMAKE_ARGS="-DCMAKE_OSX_ARCHITECTURES='arm64'"
CIBW_ARCHS: "arm64"
CIBW_BUILD_VERBOSITY: 3
# Don't repair macOS wheels
CIBW_REPAIR_WHEEL_COMMAND_MACOS: ""
- name: Build wheels
if: matrix.python-version != 'cp37'
uses: pypa/cibuildwheel@v2.15.0
env:
CIBW_BUILD: "${{ matrix.python-version}}-* "
... ...
... ... @@ -92,6 +92,14 @@ jobs:
file build/bin/sherpa-onnx
readelf -d build/bin/sherpa-onnx
- name: Test spoken language identification
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline-language-identification
.github/scripts/test-spoken-language-identification.sh
- name: Test online CTC
shell: bash
run: |
... ... @@ -116,6 +124,7 @@ jobs:
.github/scripts/test-online-paraformer.sh
- name: Test offline Whisper
shell: bash
run: |
... ...
... ... @@ -123,6 +123,15 @@ jobs:
name: release-${{ matrix.build_type }}-${{ matrix.shared_lib }}
path: build/bin/*
- name: Test spoken language identification
if: matrix.build_type != 'Debug'
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline-language-identification
.github/scripts/test-spoken-language-identification.sh
- name: Test transducer kws
shell: bash
run: |
... ... @@ -140,6 +149,7 @@ jobs:
.github/scripts/test-online-ctc.sh
- name: Test offline Whisper
if: matrix.build_type != 'Debug'
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
... ...
... ... @@ -102,6 +102,15 @@ jobs:
otool -L build/bin/sherpa-onnx
otool -l build/bin/sherpa-onnx
- name: Test spoken language identification
if: matrix.build_type != 'Debug'
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline-language-identification
.github/scripts/test-spoken-language-identification.sh
- name: Test transducer kws
shell: bash
run: |
... ... @@ -135,6 +144,7 @@ jobs:
.github/scripts/test-online-paraformer.sh
- name: Test offline Whisper
if: matrix.build_type != 'Debug'
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
... ...
... ... @@ -68,6 +68,14 @@ jobs:
ls -lh ./bin/Release/sherpa-onnx.exe
- name: Test spoken language identification
shell: bash
run: |
export PATH=$PWD/build/bin/Release:$PATH
export EXE=sherpa-onnx-offline-language-identification.exe
.github/scripts/test-spoken-language-identification.sh
- name: Test online CTC
shell: bash
run: |
... ...
... ... @@ -68,6 +68,14 @@ jobs:
ls -lh ./bin/Release/sherpa-onnx.exe
- name: Test spoken language identification
shell: bash
run: |
export PATH=$PWD/build/bin/Release:$PATH
export EXE=sherpa-onnx-offline-language-identification.exe
.github/scripts/test-spoken-language-identification.sh
- name: Test online CTC
shell: bash
run: |
... ...
... ... @@ -69,6 +69,14 @@ jobs:
ls -lh ./bin/Release/sherpa-onnx.exe
# - name: Test spoken language identification
# shell: bash
# run: |
# export PATH=$PWD/build/bin/Release:$PATH
# export EXE=sherpa-onnx-offline-language-identification.exe
#
# .github/scripts/test-spoken-language-identification.sh
- name: Test online CTC
shell: bash
run: |
... ...
cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
project(sherpa-onnx)
set(SHERPA_ONNX_VERSION "1.9.13")
set(SHERPA_ONNX_VERSION "1.9.14")
# Disable warning about
#
... ...
... ... @@ -43,6 +43,50 @@ def enable_alsa():
return build_alsa and is_linux() and (is_arm64() or is_x86())
def get_binaries():
binaries = [
"sherpa-onnx",
"sherpa-onnx-keyword-spotter",
"sherpa-onnx-microphone",
"sherpa-onnx-microphone-offline",
"sherpa-onnx-microphone-offline-speaker-identification",
"sherpa-onnx-offline",
"sherpa-onnx-offline-language-identification",
"sherpa-onnx-offline-tts",
"sherpa-onnx-offline-tts-play",
"sherpa-onnx-offline-websocket-server",
"sherpa-onnx-online-websocket-client",
"sherpa-onnx-online-websocket-server",
"sherpa-onnx-vad-microphone",
"sherpa-onnx-vad-microphone-offline-asr",
]
if enable_alsa():
binaries += [
"sherpa-onnx-alsa",
"sherpa-onnx-alsa-offline",
"sherpa-onnx-alsa-offline-speaker-identification",
"sherpa-onnx-offline-tts-play-alsa",
]
if is_windows():
binaries += [
"espeak-ng.dll",
"kaldi-decoder-core.dll",
"kaldi-native-fbank-core.dll",
"onnxruntime.dll",
"piper_phonemize.dll",
"sherpa-onnx-c-api.dll",
"sherpa-onnx-core.dll",
"sherpa-onnx-fst.lib",
"sherpa-onnx-kaldifst-core.lib",
"sherpa-onnx-portaudio.dll",
"ucd.dll",
]
return binaries
try:
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
... ... @@ -150,38 +194,7 @@ class BuildExtension(build_ext):
suffix = ".exe" if is_windows() else ""
# Remember to also change setup.py
binaries = ["sherpa-onnx"]
binaries += ["sherpa-onnx-keyword-spotter"]
binaries += ["sherpa-onnx-offline"]
binaries += ["sherpa-onnx-microphone"]
binaries += ["sherpa-onnx-microphone-offline"]
binaries += ["sherpa-onnx-microphone-offline-speaker-identification"]
binaries += ["sherpa-onnx-online-websocket-server"]
binaries += ["sherpa-onnx-offline-websocket-server"]
binaries += ["sherpa-onnx-online-websocket-client"]
binaries += ["sherpa-onnx-vad-microphone"]
binaries += ["sherpa-onnx-vad-microphone-offline-asr"]
binaries += ["sherpa-onnx-offline-tts"]
binaries += ["sherpa-onnx-offline-tts-play"]
if enable_alsa():
binaries += ["sherpa-onnx-alsa"]
binaries += ["sherpa-onnx-alsa-offline"]
binaries += ["sherpa-onnx-offline-tts-play-alsa"]
binaries += ["sherpa-onnx-alsa-offline-speaker-identification"]
if is_windows():
binaries += ["kaldi-native-fbank-core.dll"]
binaries += ["sherpa-onnx-c-api.dll"]
binaries += ["sherpa-onnx-core.dll"]
binaries += ["sherpa-onnx-portaudio.dll"]
binaries += ["onnxruntime.dll"]
binaries += ["piper_phonemize.dll"]
binaries += ["espeak-ng.dll"]
binaries += ["ucd.dll"]
binaries += ["kaldi-decoder-core.dll"]
binaries += ["sherpa-onnx-fst.lib"]
binaries += ["sherpa-onnx-kaldifst-core.lib"]
binaries = get_binaries()
for f in binaries:
suffix = "" if (".dll" in f or ".lib" in f) else suffix
... ...
#!/usr/bin/env python3
"""
This script shows how to use Python APIs for spoken languge identification.
It detects the language spoken in the given wave file.
Usage:
1. Download a whisper multilingual model. We use a tiny model below.
Please refer to https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
to download more models.
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.tar.bz2
tar xvf sherpa-onnx-whisper-tiny.tar.bz2
rm sherpa-onnx-whisper-tiny.tar.bz2
We only use the int8.onnx models below.
2. Download a test wave.
You can find many wave files for different languages at
https://hf-mirror.com/spaces/k2-fsa/spoken-language-identification/tree/main/test_wavs
wget https://hf-mirror.com/spaces/k2-fsa/spoken-language-identification/resolve/main/test_wavs/de-german.wav
python3 ./python-api-examples/spoken-language-identification.py
--whisper-encoder=sherpa-onnx-whisper-tiny/tiny-encoder.int8.onnx \
--whisper-decoder=sherpa-onnx-whisper-tiny/tiny-decoder.int8.onnx \
--num-threads=1 \
./de-german.wav
"""
import argparse
import logging
import time
import wave
from pathlib import Path
from typing import Tuple
import numpy as np
import sherpa_onnx
def get_args():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--whisper-encoder",
required=True,
type=str,
help="Path to a multilingual whisper encoder model",
)
parser.add_argument(
"--whisper-decoder",
required=True,
type=str,
help="Path to a multilingual whisper decoder model",
)
parser.add_argument(
"--num-threads",
type=int,
default=1,
help="Number of threads for neural network computation",
)
parser.add_argument(
"--debug",
type=bool,
default=False,
help="True to show debug messages",
)
parser.add_argument(
"--provider",
type=str,
default="cpu",
help="Valid values: cpu, cuda, coreml",
)
parser.add_argument(
"sound_file",
type=str,
help="The input sound file to identify. It must be of WAVE"
"format with a single channel, and each sample has 16-bit, "
"i.e., int16_t. "
"The sample rate of the file can be arbitrary and does not need to "
"be 16 kHz",
)
return parser.parse_args()
def assert_file_exists(filename: str):
assert Path(filename).is_file(), (
f"{filename} does not exist!\n"
"Please refer to "
"https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/index.html to download it"
)
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
"""
Args:
wave_filename:
Path to a wave file. It should be single channel and each sample should
be 16-bit. Its sample rate does not need to be 16kHz.
Returns:
Return a tuple containing:
- A 1-D array of dtype np.float32 containing the samples, which are
normalized to the range [-1, 1].
- sample rate of the wave file
"""
with wave.open(wave_filename) as f:
assert f.getnchannels() == 1, f.getnchannels()
assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes
num_samples = f.getnframes()
samples = f.readframes(num_samples)
samples_int16 = np.frombuffer(samples, dtype=np.int16)
samples_float32 = samples_int16.astype(np.float32)
samples_float32 = samples_float32 / 32768
return samples_float32, f.getframerate()
def main():
args = get_args()
assert_file_exists(args.whisper_encoder)
assert_file_exists(args.whisper_decoder)
assert args.num_threads > 0, args.num_threads
config = sherpa_onnx.SpokenLanguageIdentificationConfig(
whisper=sherpa_onnx.SpokenLanguageIdentificationWhisperConfig(
encoder=args.whisper_encoder,
decoder=args.whisper_decoder,
),
num_threads=args.num_threads,
debug=args.debug,
provider=args.provider,
)
slid = sherpa_onnx.SpokenLanguageIdentification(config)
samples, sample_rate = read_wave(args.sound_file)
start_time = time.time()
stream = slid.create_stream()
stream.accept_waveform(sample_rate=sample_rate, waveform=samples)
lang = slid.compute(stream)
end_time = time.time()
elapsed_seconds = end_time - start_time
audio_duration = len(samples) / sample_rate
real_time_factor = elapsed_seconds / audio_duration
logging.info(f"File: {args.sound_file}")
logging.info(f"Detected language: {lang}")
logging.info(f"Elapsed seconds: {elapsed_seconds:.3f}")
logging.info(f"Audio duration in seconds: {audio_duration:.3f}")
logging.info(
f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}"
)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
... ...
#!/usr/bin/env python3
import os
import re
import sys
from pathlib import Path
import setuptools
... ... @@ -11,7 +9,7 @@ from cmake.cmake_extension import (
BuildExtension,
bdist_wheel,
cmake_extension,
enable_alsa,
get_binaries,
is_windows,
)
... ... @@ -42,39 +40,7 @@ def get_binaries_to_install():
bin_dir.mkdir(parents=True, exist_ok=True)
suffix = ".exe" if is_windows() else ""
# Remember to also change cmake/cmake_extension.py
binaries = ["sherpa-onnx"]
binaries += ["sherpa-onnx-keyword-spotter"]
binaries += ["sherpa-onnx-offline"]
binaries += ["sherpa-onnx-microphone"]
binaries += ["sherpa-onnx-microphone-offline"]
binaries += ["sherpa-onnx-microphone-offline-speaker-identification"]
binaries += ["sherpa-onnx-online-websocket-server"]
binaries += ["sherpa-onnx-offline-websocket-server"]
binaries += ["sherpa-onnx-online-websocket-client"]
binaries += ["sherpa-onnx-vad-microphone"]
binaries += ["sherpa-onnx-vad-microphone-offline-asr"]
binaries += ["sherpa-onnx-offline-tts"]
binaries += ["sherpa-onnx-offline-tts-play"]
if enable_alsa():
binaries += ["sherpa-onnx-alsa"]
binaries += ["sherpa-onnx-alsa-offline"]
binaries += ["sherpa-onnx-offline-tts-play-alsa"]
binaries += ["sherpa-onnx-alsa-offline-speaker-identification"]
if is_windows():
binaries += ["kaldi-native-fbank-core.dll"]
binaries += ["sherpa-onnx-c-api.dll"]
binaries += ["sherpa-onnx-core.dll"]
binaries += ["sherpa-onnx-portaudio.dll"]
binaries += ["onnxruntime.dll"]
binaries += ["piper_phonemize.dll"]
binaries += ["espeak-ng.dll"]
binaries += ["ucd.dll"]
binaries += ["kaldi-decoder-core.dll"]
binaries += ["sherpa-onnx-fst.lib"]
binaries += ["sherpa-onnx-kaldifst-core.lib"]
binaries = get_binaries()
exe = []
for f in binaries:
... ...
... ... @@ -86,6 +86,8 @@ set(sources
silero-vad-model-config.cc
silero-vad-model.cc
slice.cc
spoken-language-identification-impl.cc
spoken-language-identification.cc
stack.cc
symbol-table.cc
text-utils.cc
... ... @@ -184,6 +186,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
add_executable(sherpa-onnx-offline sherpa-onnx-offline.cc)
add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc)
add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc)
add_executable(sherpa-onnx-offline-language-identification sherpa-onnx-offline-language-identification.cc)
set(main_exes
sherpa-onnx
... ... @@ -191,6 +194,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
sherpa-onnx-offline
sherpa-onnx-offline-parallel
sherpa-onnx-offline-tts
sherpa-onnx-offline-language-identification
)
foreach(exe IN LISTS main_exes)
... ...
... ... @@ -23,7 +23,7 @@ enum class ModelType {
kTdnn,
kZipformerCtc,
kWenetCtc,
kUnkown,
kUnknown,
};
} // namespace
... ... @@ -59,7 +59,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
"run.sh\n"
"\n"
"for how to add metadta to model.onnx\n");
return ModelType::kUnkown;
return ModelType::kUnknown;
}
if (model_type.get() == std::string("EncDecCTCModelBPE")) {
... ... @@ -72,13 +72,13 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
return ModelType::kWenetCtc;
} else {
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
return ModelType::kUnkown;
return ModelType::kUnknown;
}
}
std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
const OfflineModelConfig &config) {
ModelType model_type = ModelType::kUnkown;
ModelType model_type = ModelType::kUnknown;
std::string filename;
if (!config.nemo_ctc.model.empty()) {
... ... @@ -113,7 +113,7 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
case ModelType::kWenetCtc:
return std::make_unique<OfflineWenetCtcModel>(config);
break;
case ModelType::kUnkown:
case ModelType::kUnknown:
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
return nullptr;
}
... ... @@ -125,7 +125,7 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
AAssetManager *mgr, const OfflineModelConfig &config) {
ModelType model_type = ModelType::kUnkown;
ModelType model_type = ModelType::kUnknown;
std::string filename;
if (!config.nemo_ctc.model.empty()) {
... ... @@ -160,7 +160,7 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
case ModelType::kWenetCtc:
return std::make_unique<OfflineWenetCtcModel>(mgr, config);
break;
case ModelType::kUnkown:
case ModelType::kUnknown:
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
return nullptr;
}
... ...
... ... @@ -114,7 +114,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
num_frames = max_num_frames - 50;
}
NormalizeFeatures(f.data(), num_frames, feat_dim);
model_->NormalizeFeatures(f.data(), num_frames, feat_dim);
// note that 1000 is an experience-value.
// You can replace 1000 by other values, say, 100.
... ... @@ -163,38 +163,6 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
}
private:
static void NormalizeFeatures(float *features, int32_t num_frames,
int32_t feat_dim) {
// log_spec = torch.clamp(features, min=1e-10).log10()
// log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
// mel = (log_spec + 4.0) / 4.0
int32_t n = num_frames * feat_dim;
float max_v = -1e20;
for (int32_t i = 0; i != n; ++i) {
float f = features[i];
f = std::max<float>(f, 1e-10);
f = std::log10(f);
max_v = std::max(f, max_v);
features[i] = f;
}
max_v -= 8;
for (int32_t i = 0; i != n; ++i) {
float f = features[i];
f = std::max(f, max_v);
f = (f + 4) / 4;
features[i] = f;
}
}
private:
OfflineRecognizerConfig config_;
SymbolTable symbol_table_;
std::unique_ptr<OfflineWhisperModel> model_;
... ...
... ... @@ -12,56 +12,6 @@
namespace sherpa_onnx {
int32_t OfflineWhisperGreedySearchDecoder::DetectLanguage(
Ort::Value &cross_k, Ort::Value &cross_v) const { // NOLINT
int64_t token_val = model_->SOT();
std::array<int64_t, 2> token_shape{1, 1};
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
Ort::Value tokens = Ort::Value::CreateTensor(
memory_info, &token_val, 1, token_shape.data(), token_shape.size());
auto self_kv_cache = model_->GetInitialSelfKVCache();
std::array<int64_t, 1> offset_shape{1};
Ort::Value offset = Ort::Value::CreateTensor<int64_t>(
model_->Allocator(), offset_shape.data(), offset_shape.size());
*(offset.GetTensorMutableData<int64_t>()) = 0;
auto decoder_out = model_->ForwardDecoder(
std::move(tokens), std::move(self_kv_cache.first),
std::move(self_kv_cache.second), std::move(cross_k), std::move(cross_v),
std::move(offset));
cross_k = std::move(std::get<3>(decoder_out));
cross_v = std::move(std::get<4>(decoder_out));
const float *p_logits = std::get<0>(decoder_out).GetTensorData<float>();
int32_t vocab_size = model_->VocabSize();
const auto &all_language_ids = model_->GetAllLanguageIDs();
int32_t lang_id = all_language_ids[0];
float this_logit = p_logits[lang_id];
for (int32_t i = 1; i != all_language_ids.size(); ++i) {
int32_t id = all_language_ids[i];
float p = p_logits[id];
if (p > this_logit) {
this_logit = p;
lang_id = id;
}
}
#if 1
SHERPA_ONNX_LOGE("Detected language: %s",
model_->GetID2Lang().at(lang_id).c_str());
#endif
return lang_id;
}
std::vector<OfflineWhisperDecoderResult>
OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
Ort::Value cross_v) {
... ... @@ -89,7 +39,7 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
// 0: sot, 1: lang_id, 2: task, 3: no_timestamps
initial_tokens[1] = lang_id;
} else {
int32_t lang_id = DetectLanguage(cross_k, cross_v);
int32_t lang_id = model_->DetectLanguage(cross_k, cross_v);
// 0: sot, 1: lang_id, 2: task, 3: no_timestamps
initial_tokens[1] = lang_id;
... ...
... ... @@ -22,9 +22,6 @@ class OfflineWhisperGreedySearchDecoder : public OfflineWhisperDecoder {
std::vector<OfflineWhisperDecoderResult> Decode(Ort::Value cross_k,
Ort::Value cross_v) override;
int32_t DetectLanguage(Ort::Value &cross_k, // NOLINT
Ort::Value &cross_v) const; // NOLINT
private:
OfflineWhisperModelConfig config_;
OfflineWhisperModel *model_; // not owned
... ...
... ... @@ -35,19 +35,28 @@ void OfflineWhisperModelConfig::Register(ParseOptions *po) {
po->Register(
"whisper-tail-paddings", &tail_paddings,
"Suggest value: 50 for English models. 300 for multilingual models. "
"Suggested value: 50 for English models. 300 for multilingual models. "
"Since we have removed the 30-second constraint, we need to add some "
"tail padding frames "
"so that whisper can detect the eot token. Leave it to -1 to use 50 for "
"English models and 300 for multilingual models.");
"so that whisper can detect the eot token. Leave it to -1 to use 1000.");
}
bool OfflineWhisperModelConfig::Validate() const {
if (encoder.empty()) {
SHERPA_ONNX_LOGE("Please provide --whisper-encoder");
return false;
}
if (!FileExists(encoder)) {
SHERPA_ONNX_LOGE("whisper encoder file %s does not exist", encoder.c_str());
return false;
}
if (decoder.empty()) {
SHERPA_ONNX_LOGE("Please provide --whisper-decoder");
return false;
}
if (!FileExists(decoder)) {
SHERPA_ONNX_LOGE("whisper decoder file %s does not exist", decoder.c_str());
return false;
... ...
... ... @@ -24,6 +24,24 @@ class OfflineWhisperModel::Impl {
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
debug_ = config_.debug;
{
auto buf = ReadFile(config.whisper.encoder);
InitEncoder(buf.data(), buf.size());
}
{
auto buf = ReadFile(config.whisper.decoder);
InitDecoder(buf.data(), buf.size());
}
}
explicit Impl(const SpokenLanguageIdentificationConfig &config)
: lid_config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
debug_ = config_.debug;
{
auto buf = ReadFile(config.whisper.encoder);
InitEncoder(buf.data(), buf.size());
... ... @@ -41,6 +59,7 @@ class OfflineWhisperModel::Impl {
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
debug_ = config_.debug;
{
auto buf = ReadFile(mgr, config.whisper.encoder);
InitEncoder(buf.data(), buf.size());
... ... @@ -85,6 +104,57 @@ class OfflineWhisperModel::Impl {
std::move(decoder_input[4]), std::move(decoder_input[5])};
}
int32_t DetectLanguage(Ort::Value &cross_k, // NOLINT
Ort::Value &cross_v) { // NOLINT
int64_t token_val = SOT();
std::array<int64_t, 2> token_shape{1, 1};
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
Ort::Value tokens = Ort::Value::CreateTensor(
memory_info, &token_val, 1, token_shape.data(), token_shape.size());
auto self_kv_cache = GetInitialSelfKVCache();
std::array<int64_t, 1> offset_shape{1};
Ort::Value offset = Ort::Value::CreateTensor<int64_t>(
Allocator(), offset_shape.data(), offset_shape.size());
*(offset.GetTensorMutableData<int64_t>()) = 0;
auto decoder_out =
ForwardDecoder(std::move(tokens), std::move(self_kv_cache.first),
std::move(self_kv_cache.second), std::move(cross_k),
std::move(cross_v), std::move(offset));
cross_k = std::move(std::get<3>(decoder_out));
cross_v = std::move(std::get<4>(decoder_out));
const float *p_logits = std::get<0>(decoder_out).GetTensorData<float>();
int32_t vocab_size = VocabSize();
const auto &all_language_ids = GetAllLanguageIDs();
int32_t lang_id = all_language_ids[0];
float this_logit = p_logits[lang_id];
for (int32_t i = 1; i != all_language_ids.size(); ++i) {
int32_t id = all_language_ids[i];
float p = p_logits[id];
if (p > this_logit) {
this_logit = p;
lang_id = id;
}
}
if (debug_) {
SHERPA_ONNX_LOGE("Detected language: %s",
GetID2Lang().at(lang_id).c_str());
}
return lang_id;
}
std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache() {
std::array<int64_t, 4> shape{n_text_layer_, 1, n_text_ctx_, n_text_state_};
... ... @@ -148,7 +218,7 @@ class OfflineWhisperModel::Impl {
// get meta data
Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
if (config_.debug) {
if (debug_) {
std::ostringstream os;
os << "---encoder---\n";
PrintModelMetadata(os, meta_data);
... ... @@ -203,6 +273,8 @@ class OfflineWhisperModel::Impl {
private:
OfflineModelConfig config_;
SpokenLanguageIdentificationConfig lid_config_;
bool debug_ = false;
Ort::Env env_;
Ort::SessionOptions sess_opts_;
Ort::AllocatorWithDefaultOptions allocator_;
... ... @@ -246,6 +318,10 @@ class OfflineWhisperModel::Impl {
OfflineWhisperModel::OfflineWhisperModel(const OfflineModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
OfflineWhisperModel::OfflineWhisperModel(
const SpokenLanguageIdentificationConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
#if __ANDROID_API__ >= 9
OfflineWhisperModel::OfflineWhisperModel(AAssetManager *mgr,
const OfflineModelConfig &config)
... ... @@ -273,6 +349,11 @@ OfflineWhisperModel::ForwardDecoder(Ort::Value tokens,
std::move(n_layer_cross_v), std::move(offset));
}
int32_t OfflineWhisperModel::DetectLanguage(Ort::Value &cross_k, // NOLINT
Ort::Value &cross_v) { // NOLINT
return impl_->DetectLanguage(cross_k, cross_v);
}
std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::GetInitialSelfKVCache()
const {
return impl_->GetInitialSelfKVCache();
... ... @@ -318,4 +399,35 @@ bool OfflineWhisperModel::IsMultiLingual() const {
return impl_->IsMultiLingual();
}
void OfflineWhisperModel::NormalizeFeatures(float *features, int32_t num_frames,
int32_t feat_dim) {
// log_spec = torch.clamp(features, min=1e-10).log10()
// log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
// mel = (log_spec + 4.0) / 4.0
int32_t n = num_frames * feat_dim;
float max_v = -1e20;
for (int32_t i = 0; i != n; ++i) {
float f = features[i];
f = std::max<float>(f, 1e-10);
f = std::log10(f);
max_v = std::max(f, max_v);
features[i] = f;
}
max_v -= 8;
for (int32_t i = 0; i != n; ++i) {
float f = features[i];
f = std::max(f, max_v);
f = (f + 4) / 4;
features[i] = f;
}
}
} // namespace sherpa_onnx
... ...
... ... @@ -18,6 +18,7 @@
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-model-config.h"
#include "sherpa-onnx/csrc/spoken-language-identification.h"
namespace sherpa_onnx {
... ... @@ -25,6 +26,9 @@ class OfflineWhisperModel {
public:
explicit OfflineWhisperModel(const OfflineModelConfig &config);
explicit OfflineWhisperModel(
const SpokenLanguageIdentificationConfig &config);
#if __ANDROID_API__ >= 9
OfflineWhisperModel(AAssetManager *mgr, const OfflineModelConfig &config);
#endif
... ... @@ -72,7 +76,8 @@ class OfflineWhisperModel {
Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k,
Ort::Value n_layer_cross_v, Ort::Value offset) const;
int32_t DetectLanguage() const;
int32_t DetectLanguage(Ort::Value &cross_k, // NOLINT
Ort::Value &cross_v); // NOLINT
/** Return the initial self kv cache in a pair
* - n_layer_self_k_cache A 4-D tensor of shape
... ... @@ -98,6 +103,9 @@ class OfflineWhisperModel {
int32_t Translate() const;
bool IsMultiLingual() const;
static void NormalizeFeatures(float *features, int32_t num_frames,
int32_t feat_dim);
private:
class Impl;
std::unique_ptr<Impl> impl_;
... ...
... ... @@ -28,7 +28,7 @@ enum class ModelType {
kLstm,
kZipformer,
kZipformer2,
kUnkown,
kUnknown,
};
} // namespace
... ... @@ -58,7 +58,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
"No model_type in the metadata!\n"
"Please make sure you are using the latest export-onnx.py from icefall "
"to export your transducer models");
return ModelType::kUnkown;
return ModelType::kUnknown;
}
if (model_type.get() == std::string("conformer")) {
... ... @@ -71,7 +71,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
return ModelType::kZipformer2;
} else {
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
return ModelType::kUnkown;
return ModelType::kUnknown;
}
}
... ... @@ -93,7 +93,7 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
model_type.c_str());
}
}
ModelType model_type = ModelType::kUnkown;
ModelType model_type = ModelType::kUnknown;
{
auto buffer = ReadFile(config.transducer.encoder);
... ... @@ -110,7 +110,7 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
return std::make_unique<OnlineZipformerTransducerModel>(config);
case ModelType::kZipformer2:
return std::make_unique<OnlineZipformer2TransducerModel>(config);
case ModelType::kUnkown:
case ModelType::kUnknown:
SHERPA_ONNX_LOGE("Unknown model type in online transducer!");
return nullptr;
}
... ... @@ -185,7 +185,7 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
return std::make_unique<OnlineZipformerTransducerModel>(mgr, config);
case ModelType::kZipformer2:
return std::make_unique<OnlineZipformer2TransducerModel>(mgr, config);
case ModelType::kUnkown:
case ModelType::kUnknown:
SHERPA_ONNX_LOGE("Unknown model type in online transducer!");
return nullptr;
}
... ...
... ... @@ -149,4 +149,9 @@ Ort::SessionOptions GetSessionOptions(
return GetSessionOptionsImpl(config.num_threads, config.provider);
}
Ort::SessionOptions GetSessionOptions(
const SpokenLanguageIdentificationConfig &config) {
return GetSessionOptionsImpl(config.num_threads, config.provider);
}
} // namespace sherpa_onnx
... ...
... ... @@ -12,6 +12,7 @@
#include "sherpa-onnx/csrc/online-lm-config.h"
#include "sherpa-onnx/csrc/online-model-config.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
#include "sherpa-onnx/csrc/spoken-language-identification.h"
#include "sherpa-onnx/csrc/vad-model-config.h"
namespace sherpa_onnx {
... ... @@ -30,6 +31,10 @@ Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config);
Ort::SessionOptions GetSessionOptions(
const SpeakerEmbeddingExtractorConfig &config);
Ort::SessionOptions GetSessionOptions(
const SpokenLanguageIdentificationConfig &config);
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_SESSION_H_
... ...
// sherpa-onnx/csrc/sherpa-onnx-offline-language-identification.cc
//
// Copyright (c) 2022-2024 Xiaomi Corporation
#include <stdio.h>
#include <chrono> // NOLINT
#include <string>
#include <vector>
#include "sherpa-onnx/csrc/parse-options.h"
#include "sherpa-onnx/csrc/spoken-language-identification.h"
#include "sherpa-onnx/csrc/wave-reader.h"
int main(int32_t argc, char *argv[]) {
const char *kUsageMessage = R"usage(
Spoken language identification with sherpa-onnx.
Usage:
(1) Use a whisper multilingual model
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.tar.bz2
tar xvf sherpa-onnx-whisper-tiny.tar.bz2
rm sherpa-onnx-whisper-tiny.tar.bz2
We only use the int8.onnx models below.
./bin/sherpa-onnx-offline-spoken-language-identification \
--whisper-encoder=sherpa-onnx-whisper-tiny/tiny-encoder.int8.onnx \
--whisper-decoder=sherpa-onnx-whisper-tiny/tiny-decoder.int8.onnx \
--num-threads=1 \
/path/to/foo.wav
foo.wav should be of single channel, 16-bit PCM encoded wave file; its
sampling rate can be arbitrary and does not need to be 16kHz.
You can find test waves for different languages at
https://hf-mirror.com/spaces/k2-fsa/spoken-language-identification/tree/main/test_wavs
Please refer to
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/index.html
Note that only whisper multilingual models are supported. For instance,
"tiny" is supported but "tiny.en" is not.
for a list of pre-trained models to download.
)usage";
sherpa_onnx::ParseOptions po(kUsageMessage);
sherpa_onnx::SpokenLanguageIdentificationConfig config;
config.Register(&po);
po.Read(argc, argv);
if (po.NumArgs() != 1) {
fprintf(stderr, "Error: Please provide 1 wave file.\n\n");
po.PrintUsage();
exit(EXIT_FAILURE);
}
fprintf(stderr, "%s\n", config.ToString().c_str());
if (!config.Validate()) {
fprintf(stderr, "Errors in config!\n");
return -1;
}
fprintf(stderr, "Creating spoken language identifier ...\n");
sherpa_onnx::SpokenLanguageIdentification slid(config);
fprintf(stderr, "Started\n");
const std::string wav_filename = po.GetArg(1);
int32_t sampling_rate = -1;
bool is_ok = false;
const std::vector<float> samples =
sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok);
if (!is_ok) {
fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
return -1;
}
float duration = samples.size() / static_cast<float>(sampling_rate);
const auto begin = std::chrono::steady_clock::now();
auto s = slid.CreateStream();
s->AcceptWaveform(sampling_rate, samples.data(), samples.size());
auto language = slid.Compute(s.get());
const auto end = std::chrono::steady_clock::now();
fprintf(stderr, "Done!\n\n");
fprintf(stderr, "%s\nDetected language: %s\n", wav_filename.c_str(),
language.c_str());
float elapsed_seconds =
std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
.count() /
1000.;
fprintf(stderr, "num threads: %d\n", config.num_threads);
fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
float rtf = elapsed_seconds / duration;
fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n",
elapsed_seconds, duration, rtf);
return 0;
}
... ...
... ... @@ -16,7 +16,7 @@ enum class ModelType {
kWeSpeaker,
k3dSpeaker,
kNeMo,
kUnkown,
kUnknown,
};
} // namespace
... ... @@ -47,7 +47,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
"https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wespeaker/"
"add_meta_data.py"
"to add metadata to models from WeSpeaker\n");
return ModelType::kUnkown;
return ModelType::kUnknown;
}
if (model_type.get() == std::string("wespeaker")) {
... ... @@ -58,14 +58,14 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
return ModelType::kNeMo;
} else {
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
return ModelType::kUnkown;
return ModelType::kUnknown;
}
}
std::unique_ptr<SpeakerEmbeddingExtractorImpl>
SpeakerEmbeddingExtractorImpl::Create(
const SpeakerEmbeddingExtractorConfig &config) {
ModelType model_type = ModelType::kUnkown;
ModelType model_type = ModelType::kUnknown;
{
auto buffer = ReadFile(config.model);
... ... @@ -80,9 +80,8 @@ SpeakerEmbeddingExtractorImpl::Create(
return std::make_unique<SpeakerEmbeddingExtractorGeneralImpl>(config);
case ModelType::kNeMo:
return std::make_unique<SpeakerEmbeddingExtractorNeMoImpl>(config);
case ModelType::kUnkown:
SHERPA_ONNX_LOGE(
"Unknown model type in for speaker embedding extractor!");
case ModelType::kUnknown:
SHERPA_ONNX_LOGE("Unknown model type for speaker embedding extractor!");
return nullptr;
}
... ... @@ -94,7 +93,7 @@ SpeakerEmbeddingExtractorImpl::Create(
std::unique_ptr<SpeakerEmbeddingExtractorImpl>
SpeakerEmbeddingExtractorImpl::Create(
AAssetManager *mgr, const SpeakerEmbeddingExtractorConfig &config) {
ModelType model_type = ModelType::kUnkown;
ModelType model_type = ModelType::kUnknown;
{
auto buffer = ReadFile(mgr, config.model);
... ... @@ -110,7 +109,7 @@ SpeakerEmbeddingExtractorImpl::Create(
config);
case ModelType::kNeMo:
return std::make_unique<SpeakerEmbeddingExtractorNeMoImpl>(mgr, config);
case ModelType::kUnkown:
case ModelType::kUnknown:
SHERPA_ONNX_LOGE(
"Unknown model type in for speaker embedding extractor!");
return nullptr;
... ...
// sherpa-onnx/csrc/spoken-language-identification-impl.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/spoken-language-identification-impl.h"
#include <memory>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/spoken-language-identification-whisper-impl.h"
namespace sherpa_onnx {
namespace {
enum class ModelType {
kWhisper,
kUnknown,
};
}
static ModelType GetModelType(char *model_data, size_t model_data_length,
bool debug) {
Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
Ort::SessionOptions sess_opts;
auto sess = std::make_unique<Ort::Session>(env, model_data, model_data_length,
sess_opts);
Ort::ModelMetadata meta_data = sess->GetModelMetadata();
if (debug) {
std::ostringstream os;
PrintModelMetadata(os, meta_data);
SHERPA_ONNX_LOGE("%s", os.str().c_str());
}
Ort::AllocatorWithDefaultOptions allocator;
auto model_type =
meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
if (!model_type) {
SHERPA_ONNX_LOGE(
"No model_type in the metadata!\n"
"Please make sure you have added metadata to the model.\n\n"
"For instance, you can use\n"
"https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/whisper/"
"export-onnx.py "
"to add metadata to models from whisper\n");
return ModelType::kUnknown;
}
auto model_type_str = std::string(model_type.get());
if (model_type_str.find("whisper") == 0) {
return ModelType::kWhisper;
} else {
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
return ModelType::kUnknown;
}
}
std::unique_ptr<SpokenLanguageIdentificationImpl>
SpokenLanguageIdentificationImpl::Create(
const SpokenLanguageIdentificationConfig &config) {
ModelType model_type = ModelType::kUnknown;
{
if (config.whisper.encoder.empty()) {
SHERPA_ONNX_LOGE("Only whisper models are supported at present");
exit(-1);
}
auto buffer = ReadFile(config.whisper.encoder);
model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
}
switch (model_type) {
case ModelType::kWhisper:
return std::make_unique<SpokenLanguageIdentificationWhisperImpl>(config);
case ModelType::kUnknown:
SHERPA_ONNX_LOGE(
"Unknown model type for spoken language identification!");
return nullptr;
}
// unreachable code
return nullptr;
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/spoken-language-identification-impl.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_IMPL_H_
#define SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_IMPL_H_
#include <memory>
#include <string>
#include "sherpa-onnx/csrc/spoken-language-identification.h"
namespace sherpa_onnx {
class SpokenLanguageIdentificationImpl {
public:
virtual ~SpokenLanguageIdentificationImpl() = default;
static std::unique_ptr<SpokenLanguageIdentificationImpl> Create(
const SpokenLanguageIdentificationConfig &config);
virtual std::unique_ptr<OfflineStream> CreateStream() const = 0;
virtual std::string Compute(OfflineStream *s) const = 0;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_IMPL_H_
... ...
// sherpa-onnx/csrc/spoken-language-identification-whisper-impl.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_WHISPER_IMPL_H_
#define SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_WHISPER_IMPL_H_
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/offline-whisper-model.h"
#include "sherpa-onnx/csrc/spoken-language-identification-impl.h"
#include "sherpa-onnx/csrc/transpose.h"
namespace sherpa_onnx {
class SpokenLanguageIdentificationWhisperImpl
: public SpokenLanguageIdentificationImpl {
public:
explicit SpokenLanguageIdentificationWhisperImpl(
const SpokenLanguageIdentificationConfig &config)
: config_(config), model_(std::make_unique<OfflineWhisperModel>(config)) {
Check();
}
std::unique_ptr<OfflineStream> CreateStream() const override {
return std::make_unique<OfflineStream>(WhisperTag{});
}
std::string Compute(OfflineStream *s) const override {
int32_t max_num_frames = 3000;
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
int32_t feat_dim = s->FeatureDim();
std::vector<float> f = s->GetFrames();
int32_t num_frames = f.size() / feat_dim;
// we use 50 here so that there will be some zero tail paddings
if (num_frames >= max_num_frames - 50) {
SHERPA_ONNX_LOGE(
"Only waves less than 30 seconds are supported. We process only the "
"first 30 seconds and discard the remaining data");
num_frames = max_num_frames - 50;
}
model_->NormalizeFeatures(f.data(), num_frames, feat_dim);
// note that 1000 is an experience-value.
// You can replace 1000 by other values, say, 100.
//
// Since we have removed the 30 seconds constraint, we need
// tail_padding_frames so that whisper is able to detect the eot token.
int32_t tail_padding_frames = 1000;
if (config_.whisper.tail_paddings > 0) {
tail_padding_frames = config_.whisper.tail_paddings;
}
int32_t actual_frames =
std::min(num_frames + tail_padding_frames, max_num_frames);
std::array<int64_t, 3> shape{1, actual_frames, feat_dim};
Ort::Value mel = Ort::Value::CreateTensor<float>(
model_->Allocator(), shape.data(), shape.size());
float *p_mel = mel.GetTensorMutableData<float>();
std::copy(f.data(), f.data() + num_frames * feat_dim, p_mel);
std::fill_n(p_mel + num_frames * feat_dim,
(actual_frames - num_frames) * feat_dim, 0);
mel = Transpose12(model_->Allocator(), &mel);
try {
auto cross_kv = model_->ForwardEncoder(std::move(mel));
int32_t lang_id = model_->DetectLanguage(cross_kv.first, cross_kv.second);
const auto &id2lang = model_->GetID2Lang();
if (id2lang.count(lang_id)) {
return id2lang.at(lang_id);
} else {
SHERPA_ONNX_LOGE("Unknown language ID: %d. Return an empty string.",
lang_id);
return "";
}
} catch (const Ort::Exception &ex) {
SHERPA_ONNX_LOGE(
"\n\nCaught exception:\n\n%s\n\nReturn an empty result. Number of "
"input frames: %d, Current tail "
"paddings: %d. If you see a lot of such exceptions, please consider "
"using a larger --whisper-tail-paddings",
ex.what(), num_frames, tail_padding_frames);
return "";
}
}
private:
void Check() const {
if (!model_->IsMultiLingual()) {
SHERPA_ONNX_LOGE(
"Only whisper multilingual models can be used for spoken language "
"identification. Given: %s,%s",
config_.whisper.encoder.c_str(), config_.whisper.decoder.c_str());
exit(-1);
}
}
private:
SpokenLanguageIdentificationConfig config_;
std::unique_ptr<OfflineWhisperModel> model_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_WHISPER_IMPL_H_
... ...
// sherpa-onnx/csrc/spoken-language-identification.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/spoken-language-identification.h"
#include <string>
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/spoken-language-identification-impl.h"
namespace sherpa_onnx {
void SpokenLanguageIdentificationWhisperConfig::Register(ParseOptions *po) {
po->Register(
"whisper-encoder", &encoder,
"Path to then encoder of a whisper multilingual model. Support only "
"tiny, base, small, medium, large.");
po->Register(
"whisper-decoder", &decoder,
"Path to the decoder of a whisper multilingual model. Support only "
"tiny, base, small, medium, large.");
po->Register(
"whisper-tail-paddings", &tail_paddings,
"Suggested value: 300 for multilingual models. "
"Since we have removed the 30-second constraint, we need to add some "
"tail padding frames "
"so that whisper can detect the eot token. Leave it to -1 to use 1000");
}
bool SpokenLanguageIdentificationWhisperConfig::Validate() const {
if (encoder.empty()) {
SHERPA_ONNX_LOGE("Please provide --whisper-encoder");
return false;
}
if (!FileExists(encoder)) {
SHERPA_ONNX_LOGE("whisper encoder file %s does not exist", encoder.c_str());
return false;
}
if (decoder.empty()) {
SHERPA_ONNX_LOGE("Please provide --whisper-decoder");
return false;
}
if (!FileExists(decoder)) {
SHERPA_ONNX_LOGE("whisper decoder file %s does not exist", decoder.c_str());
return false;
}
return true;
}
std::string SpokenLanguageIdentificationWhisperConfig::ToString() const {
std::ostringstream os;
os << "SpokenLanguageIdentificationWhisperConfig(";
os << "encoder=\"" << encoder << "\", ";
os << "decoder=\"" << decoder << "\", ";
os << "tail_paddings=" << tail_paddings << ")";
return os.str();
}
void SpokenLanguageIdentificationConfig::Register(ParseOptions *po) {
whisper.Register(po);
po->Register("num-threads", &num_threads,
"Number of threads to run the neural network");
po->Register("debug", &debug,
"true to print model information while loading it.");
po->Register("provider", &provider,
"Specify a provider to use: cpu, cuda, coreml");
}
bool SpokenLanguageIdentificationConfig::Validate() const {
if (!whisper.Validate()) {
return false;
}
return true;
}
std::string SpokenLanguageIdentificationConfig::ToString() const {
std::ostringstream os;
os << "SpokenLanguageIdentificationConfig(";
os << "whisper=\"" << whisper.ToString() << "\", ";
os << "num_threads=" << num_threads << ", ";
os << "debug=" << (debug ? "True" : "False") << ", ";
os << "provider=\"" << provider << "\")";
return os.str();
}
SpokenLanguageIdentification::SpokenLanguageIdentification(
const SpokenLanguageIdentificationConfig &config)
: impl_(SpokenLanguageIdentificationImpl::Create(config)) {}
SpokenLanguageIdentification::~SpokenLanguageIdentification() = default;
std::unique_ptr<OfflineStream> SpokenLanguageIdentification::CreateStream()
const {
return impl_->CreateStream();
}
std::string SpokenLanguageIdentification::Compute(OfflineStream *s) const {
return impl_->Compute(s);
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/spoken-language-identification.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_
#define SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_
#include <memory>
#include <string>
#include "sherpa-onnx/csrc/offline-stream.h"
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
struct SpokenLanguageIdentificationWhisperConfig {
// Requires a multi-lingual whisper model.
// That is, it supports only tiny, base, small, medium, large.
// Note: It does NOT support tiny.en, base.en, small.en, medium.en
std::string encoder;
std::string decoder;
// Number of tail padding frames.
//
// Since we remove the 30-second constraint, we need to add some paddings
// at the end.
//
// Recommended values:
// - 50 for English models
// - 300 for multilingual models
int32_t tail_paddings = -1;
SpokenLanguageIdentificationWhisperConfig() = default;
SpokenLanguageIdentificationWhisperConfig(const std::string &encoder,
const std::string &decoder,
int32_t tail_paddings)
: encoder(encoder), decoder(decoder), tail_paddings(tail_paddings) {}
void Register(ParseOptions *po);
bool Validate() const;
std::string ToString() const;
};
struct SpokenLanguageIdentificationConfig {
SpokenLanguageIdentificationWhisperConfig whisper;
int32_t num_threads = 1;
bool debug = false;
std::string provider = "cpu";
SpokenLanguageIdentificationConfig() = default;
SpokenLanguageIdentificationConfig(
const SpokenLanguageIdentificationWhisperConfig &whisper,
int32_t num_threads, bool debug, const std::string &provider)
: whisper(whisper),
num_threads(num_threads),
debug(debug),
provider(provider) {}
void Register(ParseOptions *po);
bool Validate() const;
std::string ToString() const;
};
class SpokenLanguageIdentificationImpl;
class SpokenLanguageIdentification {
public:
explicit SpokenLanguageIdentification(
const SpokenLanguageIdentificationConfig &config);
~SpokenLanguageIdentification();
// Create a stream to accept audio samples and compute features
std::unique_ptr<OfflineStream> CreateStream() const;
// Return a string containing the language, e.g., en, zh, de,
// etc.
// Note: en is for English, zh is for Chinese, de is for German, etc.
std::string Compute(OfflineStream *s) const;
private:
std::unique_ptr<SpokenLanguageIdentificationImpl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_
... ...
... ... @@ -33,6 +33,7 @@ set(srcs
silero-vad-model-config.cc
speaker-embedding-extractor.cc
speaker-embedding-manager.cc
spoken-language-identification.cc
vad-model-config.cc
vad-model.cc
voice-activity-detector.cc
... ...
... ... @@ -22,6 +22,7 @@
#include "sherpa-onnx/python/csrc/online-stream.h"
#include "sherpa-onnx/python/csrc/speaker-embedding-extractor.h"
#include "sherpa-onnx/python/csrc/speaker-embedding-manager.h"
#include "sherpa-onnx/python/csrc/spoken-language-identification.h"
#include "sherpa-onnx/python/csrc/vad-model-config.h"
#include "sherpa-onnx/python/csrc/vad-model.h"
#include "sherpa-onnx/python/csrc/voice-activity-detector.h"
... ... @@ -55,6 +56,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
PybindOfflineTts(&m);
PybindSpeakerEmbeddingExtractor(&m);
PybindSpeakerEmbeddingManager(&m);
PybindSpokenLanguageIdentification(&m);
PybindAlsa(&m);
}
... ...
// sherpa-onnx/python/csrc/spoken-language-identification.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/python/csrc/spoken-language-identification.h"
#include <string>
#include "sherpa-onnx/csrc/spoken-language-identification.h"
namespace sherpa_onnx {
static void PybindSpokenLanguageIdentificationWhisperConfig(py::module *m) {
using PyClass = SpokenLanguageIdentificationWhisperConfig;
py::class_<PyClass>(*m, "SpokenLanguageIdentificationWhisperConfig")
.def(py::init<>())
.def(py::init<const std::string &, const std::string &, int32_t>(),
py::arg("encoder"), py::arg("decoder"),
py::arg("tail_paddings") = -1)
.def_readwrite("encoder", &PyClass::encoder)
.def_readwrite("decoder", &PyClass::decoder)
.def_readwrite("tail_paddings", &PyClass::tail_paddings)
.def("validate", &PyClass::Validate)
.def("__str__", &PyClass::ToString);
}
static void PybindSpokenLanguageIdentificationConfig(py::module *m) {
PybindSpokenLanguageIdentificationWhisperConfig(m);
using PyClass = SpokenLanguageIdentificationConfig;
py::class_<PyClass>(*m, "SpokenLanguageIdentificationConfig")
.def(py::init<>())
.def(py::init<const SpokenLanguageIdentificationWhisperConfig &, int32_t,
bool, const std::string>(),
py::arg("whisper"), py::arg("num_threads") = 1,
py::arg("debug") = false, py::arg("provider") = "cpu")
.def_readwrite("whisper", &PyClass::whisper)
.def_readwrite("num_threads", &PyClass::num_threads)
.def_readwrite("debug", &PyClass::debug)
.def_readwrite("provider", &PyClass::provider)
.def("validate", &PyClass::Validate)
.def("__str__", &PyClass::ToString);
}
void PybindSpokenLanguageIdentification(py::module *m) {
PybindSpokenLanguageIdentificationConfig(m);
using PyClass = SpokenLanguageIdentification;
py::class_<PyClass>(*m, "SpokenLanguageIdentification")
.def(py::init<const SpokenLanguageIdentificationConfig &>(),
py::arg("config"), py::call_guard<py::gil_scoped_release>())
.def("create_stream", &PyClass::CreateStream,
py::call_guard<py::gil_scoped_release>())
.def("compute", &PyClass::Compute,
py::call_guard<py::gil_scoped_release>());
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/python/csrc/spoken-language-identification.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_PYTHON_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_
#define SHERPA_ONNX_PYTHON_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
namespace sherpa_onnx {
void PybindSpokenLanguageIdentification(py::module *m);
}
#endif // SHERPA_ONNX_PYTHON_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_
... ...
... ... @@ -13,6 +13,9 @@ from _sherpa_onnx import (
SpeakerEmbeddingExtractorConfig,
SpeakerEmbeddingManager,
SpeechSegment,
SpokenLanguageIdentification,
SpokenLanguageIdentificationConfig,
SpokenLanguageIdentificationWhisperConfig,
VadModel,
VadModelConfig,
VoiceActivityDetector,
... ...