Fangjun Kuang
Committed by GitHub

Add C++ runtime for Tele-AI/TeleSpeech-ASR (#970)

Showing 52 changed files with 1,050 additions and 143 deletions
... ... @@ -2,7 +2,16 @@
cd dotnet-examples/
cd vad-non-streaming-asr-paraformer
cd ./offline-decode-files
./run-telespeech-ctc.sh
./run-nemo-ctc.sh
./run-paraformer.sh
./run-zipformer.sh
./run-hotwords.sh
./run-whisper.sh
./run-tdnn-yesno.sh
cd ../vad-non-streaming-asr-paraformer
./run.sh
cd ../offline-punctuation
... ... @@ -22,14 +31,6 @@ cd ../online-decode-files
./run-transducer.sh
./run-paraformer.sh
cd ../offline-decode-files
./run-nemo-ctc.sh
./run-paraformer.sh
./run-zipformer.sh
./run-hotwords.sh
./run-whisper.sh
./run-tdnn-yesno.sh
cd ../offline-tts
./run-aishell3.sh
./run-piper.sh
... ...
... ... @@ -15,6 +15,39 @@ echo "PATH: $PATH"
which $EXE
log "test offline TeleSpeech CTC"
url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04.tar.bz2
name=$(basename $url)
repo=$(basename -s .tar.bz2 $name)
curl -SL -O $url
tar xvf $name
rm $name
ls -lh $repo
test_wavs=(
3-sichuan.wav
4-tianjin.wav
5-henan.wav
)
for w in "${test_wavs[@]}"; do
time $EXE \
--tokens=$repo/tokens.txt \
--telespeech-ctc=$repo/model.int8.onnx \
--debug=1 \
$repo/test_wavs/$w
done
time $EXE \
--tokens=$repo/tokens.txt \
--telespeech-ctc=$repo/model.int8.onnx \
--debug=1 \
$repo/test_wavs/3-sichuan.wav \
$repo/test_wavs/4-tianjin.wav \
$repo/test_wavs/5-henan.wav
rm -rf $repo
log "-----------------------------------------------------------------"
log "Run Nemo fast conformer hybrid transducer ctc models (CTC branch)"
log "-----------------------------------------------------------------"
... ...
... ... @@ -10,6 +10,18 @@ log() {
export GIT_CLONE_PROTECTION_ACTIVE=false
log "test offline TeleSpeech CTC"
url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04.tar.bz2
name=$(basename $url)
repo=$(basename -s .tar.bz2 $name)
curl -SL -O $url
tar xvf $name
rm $name
ls -lh $repo
python3 ./python-api-examples/offline-telespeech-ctc-decode-files.py
rm -rf $repo
log "test online NeMo CTC"
url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms.tar.bz2
... ...
... ... @@ -82,7 +82,7 @@ jobs:
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
run: |
python3 -m pip install --upgrade pip
python3 -m pip install wheel twine setuptools
python3 -m pip install --break-system-packages --upgrade pip
python3 -m pip install --break-system-packages wheel twine setuptools
twine upload ./wheelhouse/*.whl
... ...
name: build-wheels-macos-universal2
on:
push:
branches:
- wheel
tags:
- '*'
workflow_dispatch:
env:
SHERPA_ONNX_IS_IN_GITHUB_ACTIONS: 1
concurrency:
group: build-wheels-macos-universal2-${{ github.ref }}
cancel-in-progress: true
jobs:
build_wheels_macos_universal2:
name: ${{ matrix.python-version }}
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [macos-latest]
python-version: ["cp38", "cp39", "cp310", "cp311", "cp312"]
steps:
- uses: actions/checkout@v4
- name: Build wheels
uses: pypa/cibuildwheel@v2.15.0
env:
CIBW_BUILD: "${{ matrix.python-version }}-*"
CIBW_ENVIRONMENT: SHERPA_ONNX_CMAKE_ARGS="-DCMAKE_OSX_ARCHITECTURES='arm64;x86_64'"
CIBW_ARCHS: "universal2"
CIBW_BUILD_VERBOSITY: 3
# Don't repair macOS wheels
CIBW_REPAIR_WHEEL_COMMAND_MACOS: ""
- name: Display wheels
shell: bash
run: |
ls -lh ./wheelhouse/
- uses: actions/upload-artifact@v4
with:
name: wheel-${{ matrix.python-version }}
path: ./wheelhouse/*.whl
- name: Publish to huggingface
if: matrix.python-version == 'cp38'
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
uses: nick-fields/retry@v3
with:
max_attempts: 20
timeout_seconds: 200
shell: bash
command: |
git config --global user.email "csukuangfj@gmail.com"
git config --global user.name "Fangjun Kuang"
rm -rf huggingface
export GIT_LFS_SKIP_SMUDGE=1
export GIT_CLONE_PROTECTION_ACTIVE=false
git clone https://huggingface.co/csukuangfj/sherpa-onnx-wheels huggingface
cd huggingface
git fetch
git pull
git merge -m "merge remote" --ff origin main
cp -v ../wheelhouse/*.whl .
git status
git add .
git commit -m "add more wheels"
git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-wheels main
- name: Publish wheels to PyPI
env:
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
run: |
python3 -m pip install --break-system-packages --upgrade pip
python3 -m pip install --break-system-packages wheel twine setuptools
twine upload ./wheelhouse/*.whl
... ...
... ... @@ -99,7 +99,7 @@ jobs:
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
run: |
python3 -m pip install --upgrade pip
python3 -m pip install wheel twine setuptools
python3 -m pip install --break-system-packages --upgrade pip
python3 -m pip install --break-system-packages wheel twine setuptools
twine upload ./wheelhouse/*.whl
... ...
... ... @@ -48,3 +48,49 @@ jobs:
repo_name: k2-fsa/sherpa-onnx
repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
tag: asr-models
- name: Publish float32 model to huggingface
shell: bash
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: |
src=scripts/tele-speech/sherpa-onnx-telespeech-ctc-zh-2024-06-04
git config --global user.email "csukuangfj@gmail.com"
git config --global user.name "Fangjun Kuang"
export GIT_CLONE_PROTECTION_ACTIVE=false
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-telespeech-ctc-zh-2024-06-04 hf
cp -a $src/* hf/
cd hf
git lfs track "*.pdf"
git lfs track "*.onnx"
git add .
git commit -m 'add model files' || true
git status
ls -lh
git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-telespeech-ctc-zh-2024-06-04 main || true
rm -rf hf
- name: Publish int8 model to huggingface
shell: bash
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: |
src=scripts/tele-speech/sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04
git config --global user.email "csukuangfj@gmail.com"
git config --global user.name "Fangjun Kuang"
export GIT_CLONE_PROTECTION_ACTIVE=false
rm -rf hf
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04 hf
cp -a $src/* hf/
cd hf
git lfs track "*.pdf"
git lfs track "*.onnx"
git add .
git commit -m 'add model files' || true
git status
ls -lh
git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04 main || true
... ...
... ... @@ -130,34 +130,34 @@ jobs:
name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
path: install/*
- name: Test online transducer
- name: Test offline CTC
shell: bash
run: |
du -h -d1 .
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx
export EXE=sherpa-onnx-offline
.github/scripts/test-online-transducer.sh
.github/scripts/test-offline-ctc.sh
du -h -d1 .
- name: Test online transducer (C API)
- name: Test online transducer
shell: bash
run: |
du -h -d1 .
export PATH=$PWD/build/bin:$PATH
export EXE=decode-file-c-api
export EXE=sherpa-onnx
.github/scripts/test-online-transducer.sh
du -h -d1 .
- name: Test offline CTC
- name: Test online transducer (C API)
shell: bash
run: |
du -h -d1 .
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline
export EXE=decode-file-c-api
.github/scripts/test-offline-ctc.sh
.github/scripts/test-online-transducer.sh
du -h -d1 .
- name: Test spoken language identification (C++ API)
... ...
... ... @@ -107,6 +107,14 @@ jobs:
otool -L build/bin/sherpa-onnx
otool -l build/bin/sherpa-onnx
- name: Test offline CTC
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline
.github/scripts/test-offline-ctc.sh
- name: Test offline transducer
shell: bash
run: |
... ... @@ -192,13 +200,7 @@ jobs:
.github/scripts/test-offline-whisper.sh
- name: Test offline CTC
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline
.github/scripts/test-offline-ctc.sh
- name: Test online transducer
shell: bash
... ...
... ... @@ -39,7 +39,7 @@ jobs:
strategy:
fail-fast: false
matrix:
os: [macos-13]
os: [macos-latest, macos-14]
steps:
- uses: actions/checkout@v4
... ...
... ... @@ -30,14 +30,12 @@ concurrency:
jobs:
test-go:
name: ${{ matrix.os }} ${{matrix.arch }}
name: ${{ matrix.os }}
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
include:
- os: macos-latest
arch: amd64
os: [macos-latest, macos-14]
steps:
- uses: actions/checkout@v4
... ... @@ -47,7 +45,7 @@ jobs:
- name: ccache
uses: hendrikmuhs/ccache-action@v1.2
with:
key: ${{ matrix.os }}-${{ matrix.arch }}
key: ${{ matrix.os }}-go
- uses: actions/setup-go@v5
with:
... ... @@ -109,8 +107,6 @@ jobs:
go build
ls -lh
git lfs install
echo "Test vits-ljs"
./run-vits-ljs.sh
rm -rf vits-ljs
... ... @@ -144,7 +140,13 @@ jobs:
go build
ls -lh
git lfs install
echo "Test telespeech ctc"
./run-telespeech-ctc.sh
rm -rf sherpa-onnx-telespeech-ctc-*
echo "Test transducer"
./run-transducer.sh
rm -rf sherpa-onnx-zipformer-en-2023-06-26
echo "Test transducer"
./run-transducer.sh
... ...
... ... @@ -57,7 +57,7 @@ jobs:
mkdir build
cd build
cmake -DCMAKE_VERBOSE_MAKEFILE=ON -D SHERPA_ONNX_ENABLE_TESTS=ON -D CMAKE_BUILD_TYPE=${{ matrix.build_type }} -D BUILD_SHARED_LIBS=${{ matrix.shared_lib }} -DCMAKE_INSTALL_PREFIX=./install ..
cmake -DSHERPA_ONNX_ENABLE_ESPEAK_NG_EXE=ON -DBUILD_ESPEAK_NG_EXE=ON -DCMAKE_VERBOSE_MAKEFILE=ON -D SHERPA_ONNX_ENABLE_TESTS=ON -D CMAKE_BUILD_TYPE=${{ matrix.build_type }} -D BUILD_SHARED_LIBS=${{ matrix.shared_lib }} -DCMAKE_INSTALL_PREFIX=./install ..
- name: Build
shell: bash
... ...
... ... @@ -106,3 +106,4 @@ node_modules
package-lock.json
sherpa-onnx-nemo-*
sherpa-onnx-vits-*
sherpa-onnx-telespeech-ctc-*
... ...
... ... @@ -6,7 +6,7 @@ set(CMAKE_OSX_DEPLOYMENT_TARGET "10.14" CACHE STRING "Minimum OS X deployment ve
project(sherpa-onnx)
set(SHERPA_ONNX_VERSION "1.9.26")
set(SHERPA_ONNX_VERSION "1.9.27")
# Disable warning about
#
... ...
... ... @@ -14,7 +14,9 @@ function(download_espeak_ng_for_piper)
set(USE_SPEECHPLAYER OFF CACHE BOOL "" FORCE)
set(EXTRA_cmn ON CACHE BOOL "" FORCE)
set(EXTRA_ru ON CACHE BOOL "" FORCE)
set(BUILD_ESPEAK_NG_EXE OFF CACHE BOOL "" FORCE)
if (NOT SHERPA_ONNX_ENABLE_ESPEAK_NG_EXE)
set(BUILD_ESPEAK_NG_EXE OFF CACHE BOOL "" FORCE)
endif()
# If you don't have access to the Internet,
# please pre-download kaldi-decoder
... ...
function(download_kaldi_native_fbank)
include(FetchContent)
set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.19.1.tar.gz")
set(kaldi_native_fbank_URL2 "https://hub.nuaa.cf/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.19.1.tar.gz")
set(kaldi_native_fbank_HASH "SHA256=0cae8cbb9ea42916b214e088912f9e8f2f648f54756b305f93f552382f31f904")
set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.19.3.tar.gz")
set(kaldi_native_fbank_URL2 "https://hub.nuaa.cf/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.19.3.tar.gz")
set(kaldi_native_fbank_HASH "SHA256=335fe1daf1b9bfb2a7b6bf03b64c4c4686c39077c57fb8058c02611981676638")
set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE)
set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
... ... @@ -12,11 +12,11 @@ function(download_kaldi_native_fbank)
# If you don't have access to the Internet,
# please pre-download kaldi-native-fbank
set(possible_file_locations
$ENV{HOME}/Downloads/kaldi-native-fbank-1.19.1.tar.gz
${CMAKE_SOURCE_DIR}/kaldi-native-fbank-1.19.1.tar.gz
${CMAKE_BINARY_DIR}/kaldi-native-fbank-1.19.1.tar.gz
/tmp/kaldi-native-fbank-1.19.1.tar.gz
/star-fj/fangjun/download/github/kaldi-native-fbank-1.19.1.tar.gz
$ENV{HOME}/Downloads/kaldi-native-fbank-1.19.3.tar.gz
${CMAKE_SOURCE_DIR}/kaldi-native-fbank-1.19.3.tar.gz
${CMAKE_BINARY_DIR}/kaldi-native-fbank-1.19.3.tar.gz
/tmp/kaldi-native-fbank-1.19.3.tar.gz
/star-fj/fangjun/download/github/kaldi-native-fbank-1.19.3.tar.gz
)
foreach(f IN LISTS possible_file_locations)
... ...
... ... @@ -34,6 +34,9 @@ class OfflineDecodeFiles
[Option(Required = false, Default = "", HelpText = "Path to transducer joiner.onnx. Used only for transducer models")]
public string Joiner { get; set; }
[Option("model-type", Required = false, Default = "", HelpText = "model type")]
public string ModelType { get; set; }
[Option("whisper-encoder", Required = false, Default = "", HelpText = "Path to whisper encoder.onnx. Used only for whisper models")]
public string WhisperEncoder { get; set; }
... ... @@ -56,6 +59,9 @@ class OfflineDecodeFiles
[Option("nemo-ctc", Required = false, HelpText = "Path to model.onnx. Used only for NeMo CTC models")]
public string NeMoCtc { get; set; }
[Option("telespeech-ctc", Required = false, HelpText = "Path to model.onnx. Used only for TeleSpeech CTC models")]
public string TeleSpeechCtc { get; set; }
[Option("num-threads", Required = false, Default = 1, HelpText = "Number of threads for computation")]
public int NumThreads { get; set; }
... ... @@ -201,6 +207,10 @@ to download pre-trained Tdnn models.
{
config.ModelConfig.NeMoCtc.Model = options.NeMoCtc;
}
else if (!String.IsNullOrEmpty(options.TeleSpeechCtc))
{
config.ModelConfig.TeleSpeechCtc = options.TeleSpeechCtc;
}
else if (!String.IsNullOrEmpty(options.WhisperEncoder))
{
config.ModelConfig.Whisper.Encoder = options.WhisperEncoder;
... ... @@ -218,6 +228,7 @@ to download pre-trained Tdnn models.
return;
}
config.ModelConfig.ModelType = options.ModelType;
config.DecodingMethod = options.DecodingMethod;
config.MaxActivePaths = options.MaxActivePaths;
config.HotwordsFile = options.HotwordsFile;
... ...
#!/usr/bin/env bash
set -ex
if [ ! -d sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04 ]; then
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04.tar.bz2
tar xvf sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04.tar.bz2
rm sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04.tar.bz2
fi
dotnet run \
--telespeech-ctc=./sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04/model.int8.onnx \
--tokens=./sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04/tokens.txt \
--model-type=telespeech-ctc \
--files ./sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04/test_wavs/3-sichuan.wav
... ...
... ... @@ -40,6 +40,9 @@ func main() {
flag.IntVar(&config.ModelConfig.Debug, "debug", 0, "Whether to show debug message")
flag.StringVar(&config.ModelConfig.ModelType, "model-type", "", "Optional. Used for loading the model in a faster way")
flag.StringVar(&config.ModelConfig.Provider, "provider", "cpu", "Provider to use")
flag.StringVar(&config.ModelConfig.ModelingUnit, "modeling-unit", "cjkchar", "cjkchar, bpe, cjkchar+bpe, or leave it to empty")
flag.StringVar(&config.ModelConfig.BpeVocab, "bpe-vocab", "", "")
flag.StringVar(&config.ModelConfig.TeleSpeechCtc, "telespeech-ctc", "", "Used for TeleSpeechCtc model")
flag.StringVar(&config.LmConfig.Model, "lm-model", "", "Optional. Path to the LM model")
flag.Float32Var(&config.LmConfig.Scale, "lm-scale", 1.0, "Optional. Scale for the LM model")
... ...
#!/usr/bin/env bash
set -ex
if [ ! -d sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04 ]; then
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04.tar.bz2
tar xvf sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04.tar.bz2
rm sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04.tar.bz2
fi
go mod tidy
go build
./non-streaming-decode-files \
--telespeech-ctc ./sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04/model.int8.onnx \
--tokens ./sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04/tokens.txt \
--model-type telespeech-ctc \
--debug 0 \
./sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04/test_wavs/3-sichuan.wav
... ...
... ... @@ -4,7 +4,7 @@
// to decode files.
import com.k2fsa.sherpa.onnx.*;
public class NonStreamingDecodeFileTransducer {
public class NonStreamingDecodeFileParaformer {
public static void main(String[] args) {
// please refer to
// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/paraformer-models.html#csukuangfj-sherpa-onnx-paraformer-zh-2023-03-28-chinese-english
... ...
// Copyright 2024 Xiaomi Corporation
// This file shows how to use an offline TeleSpeech CTC model
// to decode files.
import com.k2fsa.sherpa.onnx.*;
public class NonStreamingDecodeFileTeleSpeechCtc {
public static void main(String[] args) {
// please refer to
// https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
// to download model files
String model = "./sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04/model.int8.onnx";
String tokens = "./sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04/tokens.txt";
String waveFilename = "./sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04/test_wavs/3-sichuan.wav";
WaveReader reader = new WaveReader(waveFilename);
OfflineModelConfig modelConfig =
OfflineModelConfig.builder()
.setTeleSpeech(model)
.setTokens(tokens)
.setNumThreads(1)
.setDebug(true)
.setModelType("telespeech_ctc")
.build();
OfflineRecognizerConfig config =
OfflineRecognizerConfig.builder()
.setOfflineModelConfig(modelConfig)
.setDecodingMethod("greedy_search")
.build();
OfflineRecognizer recognizer = new OfflineRecognizer(config);
OfflineStream stream = recognizer.createStream();
stream.acceptWaveform(reader.getSamples(), reader.getSampleRate());
recognizer.decode(stream);
String text = recognizer.getResult(stream).getText();
System.out.printf("filename:%s\nresult:%s\n", waveFilename, text);
stream.release();
recognizer.release();
}
}
... ...
#!/usr/bin/env bash
set -ex
if [[ ! -f ../build/lib/libsherpa-onnx-jni.dylib && ! -f ../build/lib/libsherpa-onnx-jni.so ]]; then
mkdir -p ../build
pushd ../build
cmake \
-DSHERPA_ONNX_ENABLE_PYTHON=OFF \
-DSHERPA_ONNX_ENABLE_TESTS=OFF \
-DSHERPA_ONNX_ENABLE_CHECK=OFF \
-DBUILD_SHARED_LIBS=ON \
-DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF \
-DSHERPA_ONNX_ENABLE_JNI=ON \
..
make -j4
ls -lh lib
popd
fi
if [ ! -f ../sherpa-onnx/java-api/build/sherpa-onnx.jar ]; then
pushd ../sherpa-onnx/java-api
make
popd
fi
if [ ! -f ./sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04/tokens.txt ]; then
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04.tar.bz2
tar xvf sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04.tar.bz2
rm sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04.tar.bz2
fi
java \
-Djava.library.path=$PWD/../build/lib \
-cp ../sherpa-onnx/java-api/build/sherpa-onnx.jar \
./NonStreamingDecodeFileTeleSpeechCtc.java
... ...
#!/usr/bin/env python3
"""
This file shows how to use a non-streaming CTC model from
https://github.com/Tele-AI/TeleSpeech-ASR
to decode files.
Please download model files from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
"""
from pathlib import Path
import sherpa_onnx
import soundfile as sf
def create_recognizer():
model = "./sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04/model.int8.onnx"
tokens = "./sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04/tokens.txt"
test_wav = "./sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04/test_wavs/3-sichuan.wav"
# test_wav = "./sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04/test_wavs/4-tianjin.wav"
# test_wav = "./sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04/test_wavs/5-henan.wav"
if not Path(model).is_file() or not Path(test_wav).is_file():
raise ValueError(
"""Please download model files from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
"""
)
return (
sherpa_onnx.OfflineRecognizer.from_telespeech_ctc(
model=model,
tokens=tokens,
debug=True,
),
test_wav,
)
def main():
recognizer, wave_filename = create_recognizer()
audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
audio = audio[:, 0] # only use the first channel
# audio is a 1-D float32 numpy array normalized to the range [-1, 1]
# sample_rate does not need to be 16000 Hz
stream = recognizer.create_stream()
stream.accept_waveform(sample_rate, audio)
recognizer.decode_stream(stream)
print(wave_filename)
print(stream.result)
if __name__ == "__main__":
main()
... ...
... ... @@ -166,6 +166,22 @@ def get_models():
popd
""",
),
Model(
model_name="sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04",
idx=11,
lang="zh",
short_name="telespeech",
cmd="""
pushd $model_name
rm -rfv test_wavs
rm test.py
ls -lh
popd
""",
),
]
return models
... ...
... ... @@ -25,6 +25,7 @@ namespace SherpaOnnx
ModelType = "";
ModelingUnit = "cjkchar";
BpeVocab = "";
TeleSpeechCtc = "";
}
public OfflineTransducerModelConfig Transducer;
public OfflineParaformerModelConfig Paraformer;
... ... @@ -50,5 +51,8 @@ namespace SherpaOnnx
[MarshalAs(UnmanagedType.LPStr)]
public string BpeVocab;
[MarshalAs(UnmanagedType.LPStr)]
public string TeleSpeechCtc;
}
}
... ...
... ... @@ -30,7 +30,7 @@ mkdir -p linux macos windows-x64 windows-x86
linux_wheel_filename=sherpa_onnx-${SHERPA_ONNX_VERSION}-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
linux_wheel=$src_dir/$linux_wheel_filename
macos_wheel_filename=sherpa_onnx-${SHERPA_ONNX_VERSION}-cp38-cp38-macosx_11_0_x86_64.whl
macos_wheel_filename=sherpa_onnx-${SHERPA_ONNX_VERSION}-cp38-cp38-macosx_11_0_universal2.whl
macos_wheel=$src_dir/$macos_wheel_filename
windows_x64_wheel_filename=sherpa_onnx-${SHERPA_ONNX_VERSION}-cp38-cp38-win_amd64.whl
... ... @@ -61,7 +61,7 @@ if [ ! -f $src_dir/linux/libsherpa-onnx-core.so ]; then
fi
if [ ! -f $src_dir/macos/libsherpa-onnx-core.dylib ]; then
echo "---macOS x86_64---"
echo "--- macOS x86_64/arm64 universal2---"
cd macos
mkdir -p wheel
cd wheel
... ...
../../../../go-api-examples/non-streaming-decode-files/run-telespeech-ctc.sh
\ No newline at end of file
... ...
... ... @@ -381,8 +381,9 @@ type OfflineModelConfig struct {
// Optional. Specify it for faster model initialization.
ModelType string
ModelingUnit string // Optional. cjkchar, bpe, cjkchar+bpe
BpeVocab string // Optional.
ModelingUnit string // Optional. cjkchar, bpe, cjkchar+bpe
BpeVocab string // Optional.
TeleSpeechCtc string // Optional.
}
// Configuration for the offline/non-streaming recognizer.
... ... @@ -477,6 +478,9 @@ func NewOfflineRecognizer(config *OfflineRecognizerConfig) *OfflineRecognizer {
c.model_config.bpe_vocab = C.CString(config.ModelConfig.BpeVocab)
defer C.free(unsafe.Pointer(c.model_config.bpe_vocab))
c.model_config.telespeech_ctc = C.CString(config.ModelConfig.TeleSpeechCtc)
defer C.free(unsafe.Pointer(c.model_config.telespeech_ctc))
c.lm_config.model = C.CString(config.LmConfig.Model)
defer C.free(unsafe.Pointer(c.lm_config.model))
... ...
... ... @@ -128,6 +128,7 @@ static SherpaOnnxOfflineModelConfig GetOfflineModelConfig(Napi::Object obj) {
SHERPA_ONNX_ASSIGN_ATTR_STR(model_type, modelType);
SHERPA_ONNX_ASSIGN_ATTR_STR(modeling_unit, modelingUnit);
SHERPA_ONNX_ASSIGN_ATTR_STR(bpe_vocab, bpeVocab);
SHERPA_ONNX_ASSIGN_ATTR_STR(telespeech_ctc, teleSpeechCtc);
return c;
}
... ... @@ -242,6 +243,10 @@ CreateOfflineRecognizerWrapper(const Napi::CallbackInfo &info) {
delete[] c.model_config.bpe_vocab;
}
if (c.model_config.telespeech_ctc) {
delete[] c.model_config.telespeech_ctc;
}
if (c.lm_config.model) {
delete[] c.lm_config.model;
}
... ...
... ... @@ -366,6 +366,9 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer(
recognizer_config.model_config.bpe_vocab =
SHERPA_ONNX_OR(config->model_config.bpe_vocab, "");
recognizer_config.model_config.telespeech_ctc =
SHERPA_ONNX_OR(config->model_config.telespeech_ctc, "");
recognizer_config.lm_config.model =
SHERPA_ONNX_OR(config->lm_config.model, "");
recognizer_config.lm_config.scale =
... ...
... ... @@ -395,6 +395,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineModelConfig {
// - cjkchar+bpe
const char *modeling_unit;
const char *bpe_vocab;
const char *telespeech_ctc;
} SherpaOnnxOfflineModelConfig;
SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerConfig {
... ...
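For reference, a minimal sketch (not part of this commit) of consuming the new telespeech_ctc field through the C API. CreateOfflineRecognizer is the function shown above; the paths are the test model used throughout this PR, and the stream creation/decoding/cleanup calls, which this PR leaves unchanged, are elided.

// hedged example, not repo code
#include <cstring>

#include "sherpa-onnx/c-api/c-api.h"

int main() {
  SherpaOnnxOfflineRecognizerConfig config;
  std::memset(&config, 0, sizeof(config));

  // New in this commit: point the CTC branch at a TeleSpeech model.
  config.model_config.telespeech_ctc =
      "./sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04/model.int8.onnx";
  config.model_config.tokens =
      "./sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04/tokens.txt";
  config.model_config.num_threads = 1;
  config.decoding_method = "greedy_search";

  SherpaOnnxOfflineRecognizer *recognizer = CreateOfflineRecognizer(&config);
  // ... create an offline stream, feed 16 kHz float samples, decode,
  // then free the stream and the recognizer ...
  return 0;
}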
... ... @@ -39,6 +39,7 @@ set(sources
offline-stream.cc
offline-tdnn-ctc-model.cc
offline-tdnn-model-config.cc
offline-telespeech-ctc-model.cc
offline-transducer-greedy-search-decoder.cc
offline-transducer-greedy-search-nemo-decoder.cc
offline-transducer-model-config.cc
... ...
... ... @@ -56,22 +56,11 @@ std::string FeatureExtractorConfig::ToString() const {
class FeatureExtractor::Impl {
public:
explicit Impl(const FeatureExtractorConfig &config) : config_(config) {
opts_.frame_opts.dither = config.dither;
opts_.frame_opts.snip_edges = config.snip_edges;
opts_.frame_opts.samp_freq = config.sampling_rate;
opts_.frame_opts.frame_shift_ms = config.frame_shift_ms;
opts_.frame_opts.frame_length_ms = config.frame_length_ms;
opts_.frame_opts.remove_dc_offset = config.remove_dc_offset;
opts_.frame_opts.window_type = config.window_type;
opts_.mel_opts.num_bins = config.feature_dim;
opts_.mel_opts.high_freq = config.high_freq;
opts_.mel_opts.low_freq = config.low_freq;
opts_.mel_opts.is_librosa = config.is_librosa;
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
if (config_.is_mfcc) {
InitMfcc();
} else {
InitFbank();
}
}
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
... ... @@ -101,35 +90,48 @@ class FeatureExtractor::Impl {
std::vector<float> samples;
resampler_->Resample(waveform, n, false, &samples);
fbank_->AcceptWaveform(opts_.frame_opts.samp_freq, samples.data(),
samples.size());
if (fbank_) {
fbank_->AcceptWaveform(config_.sampling_rate, samples.data(),
samples.size());
} else {
mfcc_->AcceptWaveform(config_.sampling_rate, samples.data(),
samples.size());
}
return;
}
if (sampling_rate != opts_.frame_opts.samp_freq) {
if (sampling_rate != config_.sampling_rate) {
SHERPA_ONNX_LOGE(
"Creating a resampler:\n"
" in_sample_rate: %d\n"
" output_sample_rate: %d\n",
sampling_rate, static_cast<int32_t>(opts_.frame_opts.samp_freq));
sampling_rate, static_cast<int32_t>(config_.sampling_rate));
float min_freq =
std::min<int32_t>(sampling_rate, opts_.frame_opts.samp_freq);
float min_freq = std::min<int32_t>(sampling_rate, config_.sampling_rate);
float lowpass_cutoff = 0.99 * 0.5 * min_freq;
int32_t lowpass_filter_width = 6;
resampler_ = std::make_unique<LinearResample>(
sampling_rate, opts_.frame_opts.samp_freq, lowpass_cutoff,
sampling_rate, config_.sampling_rate, lowpass_cutoff,
lowpass_filter_width);
std::vector<float> samples;
resampler_->Resample(waveform, n, false, &samples);
fbank_->AcceptWaveform(opts_.frame_opts.samp_freq, samples.data(),
samples.size());
if (fbank_) {
fbank_->AcceptWaveform(config_.sampling_rate, samples.data(),
samples.size());
} else {
mfcc_->AcceptWaveform(config_.sampling_rate, samples.data(),
samples.size());
}
return;
}
fbank_->AcceptWaveform(sampling_rate, waveform, n);
if (fbank_) {
fbank_->AcceptWaveform(sampling_rate, waveform, n);
} else {
mfcc_->AcceptWaveform(sampling_rate, waveform, n);
}
}
void InputFinished() const {
... ... @@ -179,11 +181,56 @@ class FeatureExtractor::Impl {
return features;
}
int32_t FeatureDim() const { return opts_.mel_opts.num_bins; }
int32_t FeatureDim() const {
return mfcc_ ? mfcc_opts_.num_ceps : opts_.mel_opts.num_bins;
}
private:
void InitFbank() {
opts_.frame_opts.dither = config_.dither;
opts_.frame_opts.snip_edges = config_.snip_edges;
opts_.frame_opts.samp_freq = config_.sampling_rate;
opts_.frame_opts.frame_shift_ms = config_.frame_shift_ms;
opts_.frame_opts.frame_length_ms = config_.frame_length_ms;
opts_.frame_opts.remove_dc_offset = config_.remove_dc_offset;
opts_.frame_opts.window_type = config_.window_type;
opts_.mel_opts.num_bins = config_.feature_dim;
opts_.mel_opts.high_freq = config_.high_freq;
opts_.mel_opts.low_freq = config_.low_freq;
opts_.mel_opts.is_librosa = config_.is_librosa;
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
}
void InitMfcc() {
mfcc_opts_.frame_opts.dither = config_.dither;
mfcc_opts_.frame_opts.snip_edges = config_.snip_edges;
mfcc_opts_.frame_opts.samp_freq = config_.sampling_rate;
mfcc_opts_.frame_opts.frame_shift_ms = config_.frame_shift_ms;
mfcc_opts_.frame_opts.frame_length_ms = config_.frame_length_ms;
mfcc_opts_.frame_opts.remove_dc_offset = config_.remove_dc_offset;
mfcc_opts_.frame_opts.window_type = config_.window_type;
mfcc_opts_.mel_opts.num_bins = config_.feature_dim;
mfcc_opts_.mel_opts.high_freq = config_.high_freq;
mfcc_opts_.mel_opts.low_freq = config_.low_freq;
mfcc_opts_.mel_opts.is_librosa = config_.is_librosa;
mfcc_opts_.num_ceps = config_.num_ceps;
mfcc_opts_.use_energy = config_.use_energy;
mfcc_ = std::make_unique<knf::OnlineMfcc>(mfcc_opts_);
}
private:
std::unique_ptr<knf::OnlineFbank> fbank_;
std::unique_ptr<knf::OnlineMfcc> mfcc_;
knf::FbankOptions opts_;
knf::MfccOptions mfcc_opts_;
FeatureExtractorConfig config_;
mutable std::mutex mutex_;
std::unique_ptr<LinearResample> resampler_;
... ...
... ... @@ -18,7 +18,10 @@ struct FeatureExtractorConfig {
// the sampling rate of the input waveform, we will do resampling inside.
int32_t sampling_rate = 16000;
// Feature dimension
// num_mel_bins
//
// Note: for MFCC, this value is used as num_mel_bins.
// The actual feature dimension is num_ceps.
int32_t feature_dim = 80;
// minimal frequency for Mel-filterbank, in Hz
... ... @@ -69,6 +72,12 @@ struct FeatureExtractorConfig {
// for details
std::string nemo_normalize_type;
// for MFCC
int32_t num_ceps = 13;
bool use_energy = true;
bool is_mfcc = false;
std::string ToString() const;
void Register(ParseOptions *po);
... ...
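Back-of-the-envelope arithmetic for the MFCC front end these new options enable; a hedged sketch assuming the default 25 ms / 10 ms kaldi-style framing from this config and the num_ceps = 40, snip_edges = true values that the CTC recognizer Init() sets later in this diff.

#include <cstdio>

int main() {
  int sample_rate = 16000;
  int frame_length = sample_rate * 25 / 1000;  // 400 samples per window
  int frame_shift = sample_rate * 10 / 1000;   // 160 samples per hop
  int num_samples = sample_rate * 5;           // a 5-second wave

  // snip_edges = true keeps only windows that fit entirely in the signal.
  int num_frames = (num_samples - frame_length) / frame_shift + 1;
  int feature_dim = 40;  // num_ceps, i.e. FeatureDim() for the MFCC branch

  std::printf("%d frames -> feature matrix %d x %d\n",
              num_frames, num_frames, feature_dim);  // 498 x 40
  return 0;
}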
... ... @@ -12,6 +12,7 @@
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h"
#include "sherpa-onnx/csrc/offline-tdnn-ctc-model.h"
#include "sherpa-onnx/csrc/offline-telespeech-ctc-model.h"
#include "sherpa-onnx/csrc/offline-wenet-ctc-model.h"
#include "sherpa-onnx/csrc/offline-zipformer-ctc-model.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
... ... @@ -24,6 +25,7 @@ enum class ModelType {
kTdnn,
kZipformerCtc,
kWenetCtc,
kTeleSpeechCtc,
kUnknown,
};
... ... @@ -63,6 +65,9 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
"If you are using models from WeNet, please refer to\n"
"https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/"
"run.sh\n"
"If you are using models from TeleSpeech, please refer to\n"
"https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/tele-speech/"
"add-metadata.py"
"\n"
"for how to add metadta to model.onnx\n");
return ModelType::kUnknown;
... ... @@ -78,6 +83,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
return ModelType::kZipformerCtc;
} else if (model_type.get() == std::string("wenet_ctc")) {
return ModelType::kWenetCtc;
} else if (model_type.get() == std::string("telespeech_ctc")) {
return ModelType::kTeleSpeechCtc;
} else {
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
return ModelType::kUnknown;
... ... @@ -97,6 +104,8 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
filename = config.zipformer_ctc.model;
} else if (!config.wenet_ctc.model.empty()) {
filename = config.wenet_ctc.model;
} else if (!config.telespeech_ctc.empty()) {
filename = config.telespeech_ctc;
} else {
SHERPA_ONNX_LOGE("Please specify a CTC model");
exit(-1);
... ... @@ -124,6 +133,9 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
case ModelType::kWenetCtc:
return std::make_unique<OfflineWenetCtcModel>(config);
break;
case ModelType::kTeleSpeechCtc:
return std::make_unique<OfflineTeleSpeechCtcModel>(config);
break;
case ModelType::kUnknown:
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
return nullptr;
... ... @@ -147,6 +159,8 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
filename = config.zipformer_ctc.model;
} else if (!config.wenet_ctc.model.empty()) {
filename = config.wenet_ctc.model;
} else if (!config.telespeech_ctc.empty()) {
filename = config.telespeech_ctc;
} else {
SHERPA_ONNX_LOGE("Please specify a CTC model");
exit(-1);
... ... @@ -175,6 +189,9 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
case ModelType::kWenetCtc:
return std::make_unique<OfflineWenetCtcModel>(mgr, config);
break;
case ModelType::kTeleSpeechCtc:
return std::make_unique<OfflineTeleSpeechCtcModel>(mgr, config);
break;
case ModelType::kUnknown:
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
return nullptr;
... ...
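A hedged sketch (not in the diff) of driving the extended factory directly. Note that Create() reads the model file and probes its metadata, so model.onnx must carry model_type=telespeech_ctc, which is what scripts/tele-speech/add-metadata.py adds.

#include "sherpa-onnx/csrc/offline-ctc-model.h"
#include "sherpa-onnx/csrc/offline-model-config.h"

int main() {
  sherpa_onnx::OfflineModelConfig config;
  config.telespeech_ctc =
      "./sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04/model.int8.onnx";

  // Dispatches to OfflineTeleSpeechCtcModel via the new kTeleSpeechCtc branch.
  auto model = sherpa_onnx::OfflineCtcModel::Create(config);
  // model->VocabSize(), model->SubsamplingFactor(), model->Forward(...)
  return 0;
}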
... ... @@ -19,6 +19,9 @@ void OfflineModelConfig::Register(ParseOptions *po) {
zipformer_ctc.Register(po);
wenet_ctc.Register(po);
po->Register("telespeech-ctc", &telespeech_ctc,
"Path to model.onnx for telespeech ctc");
po->Register("tokens", &tokens, "Path to tokens.txt");
po->Register("num-threads", &num_threads,
... ... @@ -33,7 +36,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
po->Register("model-type", &model_type,
"Specify it to reduce model initialization time. "
"Valid values are: transducer, paraformer, nemo_ctc, whisper, "
"tdnn, zipformer2_ctc"
"tdnn, zipformer2_ctc, telespeech_ctc."
"All other values lead to loading the model twice.");
po->Register("modeling-unit", &modeling_unit,
"The modeling unit of the model, commonly used units are bpe, "
... ... @@ -55,14 +58,14 @@ bool OfflineModelConfig::Validate() const {
}
if (!FileExists(tokens)) {
SHERPA_ONNX_LOGE("tokens: %s does not exist", tokens.c_str());
SHERPA_ONNX_LOGE("tokens: '%s' does not exist", tokens.c_str());
return false;
}
if (!modeling_unit.empty() &&
(modeling_unit == "bpe" || modeling_unit == "cjkchar+bpe")) {
if (!FileExists(bpe_vocab)) {
SHERPA_ONNX_LOGE("bpe_vocab: %s does not exist", bpe_vocab.c_str());
SHERPA_ONNX_LOGE("bpe_vocab: '%s' does not exist", bpe_vocab.c_str());
return false;
}
}
... ... @@ -91,6 +94,14 @@ bool OfflineModelConfig::Validate() const {
return wenet_ctc.Validate();
}
if (!telespeech_ctc.empty()) {
if (!FileExists(telespeech_ctc)) {
SHERPA_ONNX_LOGE("telespeech_ctc: '%s' does not exist",
telespeech_ctc.c_str());
return false;
}
return true;
}
return transducer.Validate();
}
... ... @@ -105,6 +116,7 @@ std::string OfflineModelConfig::ToString() const {
os << "tdnn=" << tdnn.ToString() << ", ";
os << "zipformer_ctc=" << zipformer_ctc.ToString() << ", ";
os << "wenet_ctc=" << wenet_ctc.ToString() << ", ";
os << "telespeech_ctc=\"" << telespeech_ctc << "\", ";
os << "tokens=\"" << tokens << "\", ";
os << "num_threads=" << num_threads << ", ";
os << "debug=" << (debug ? "True" : "False") << ", ";
... ...
... ... @@ -24,6 +24,7 @@ struct OfflineModelConfig {
OfflineTdnnModelConfig tdnn;
OfflineZipformerCtcModelConfig zipformer_ctc;
OfflineWenetCtcModelConfig wenet_ctc;
std::string telespeech_ctc;
std::string tokens;
int32_t num_threads = 2;
... ... @@ -52,6 +53,7 @@ struct OfflineModelConfig {
const OfflineTdnnModelConfig &tdnn,
const OfflineZipformerCtcModelConfig &zipformer_ctc,
const OfflineWenetCtcModelConfig &wenet_ctc,
const std::string &telespeech_ctc,
const std::string &tokens, int32_t num_threads, bool debug,
const std::string &provider, const std::string &model_type,
const std::string &modeling_unit,
... ... @@ -63,6 +65,7 @@ struct OfflineModelConfig {
tdnn(tdnn),
zipformer_ctc(zipformer_ctc),
wenet_ctc(wenet_ctc),
telespeech_ctc(telespeech_ctc),
tokens(tokens),
num_threads(num_threads),
debug(debug),
... ...
... ... @@ -88,6 +88,17 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
#endif
void Init() {
if (!config_.model_config.telespeech_ctc.empty()) {
config_.feat_config.snip_edges = true;
config_.feat_config.num_ceps = 40;
config_.feat_config.feature_dim = 40;
config_.feat_config.low_freq = 40;
config_.feat_config.high_freq = -200;
config_.feat_config.use_energy = false;
config_.feat_config.normalize_samples = false;
config_.feat_config.is_mfcc = true;
}
if (!config_.model_config.wenet_ctc.model.empty()) {
// WeNet CTC models assume input samples are in the range
// [-32768, 32767], so we set normalize_samples to false
... ...
... ... @@ -29,7 +29,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
} else if (model_type == "paraformer") {
return std::make_unique<OfflineRecognizerParaformerImpl>(config);
} else if (model_type == "nemo_ctc" || model_type == "tdnn" ||
model_type == "zipformer2_ctc" || model_type == "wenet_ctc") {
model_type == "zipformer2_ctc" || model_type == "wenet_ctc" ||
model_type == "telespeech_ctc") {
return std::make_unique<OfflineRecognizerCtcImpl>(config);
} else if (model_type == "whisper") {
return std::make_unique<OfflineRecognizerWhisperImpl>(config);
... ... @@ -53,6 +54,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
model_filename = config.model_config.paraformer.model;
} else if (!config.model_config.nemo_ctc.model.empty()) {
model_filename = config.model_config.nemo_ctc.model;
} else if (!config.model_config.telespeech_ctc.empty()) {
model_filename = config.model_config.telespeech_ctc;
} else if (!config.model_config.tdnn.model.empty()) {
model_filename = config.model_config.tdnn.model;
} else if (!config.model_config.zipformer_ctc.model.empty()) {
... ... @@ -111,6 +114,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
"\n "
"https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/run.sh"
"\n"
"(7) CTC models from TeleSpeech"
"\n "
"https://github.com/Tele-AI/TeleSpeech-ASR"
"\n"
"\n");
exit(-1);
}
... ... @@ -133,7 +140,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
if (model_type == "EncDecCTCModelBPE" ||
model_type == "EncDecHybridRNNTCTCBPEModel" || model_type == "tdnn" ||
model_type == "zipformer2_ctc" || model_type == "wenet_ctc") {
model_type == "zipformer2_ctc" || model_type == "wenet_ctc" ||
model_type == "telespeech_ctc") {
return std::make_unique<OfflineRecognizerCtcImpl>(config);
}
... ... @@ -151,7 +159,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
" - Whisper models\n"
" - Tdnn models\n"
" - Zipformer CTC models\n"
" - WeNet CTC models\n",
" - WeNet CTC models\n"
" - TeleSpeech CTC models\n",
model_type.c_str());
exit(-1);
... ... @@ -169,7 +178,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
} else if (model_type == "paraformer") {
return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config);
} else if (model_type == "nemo_ctc" || model_type == "tdnn" ||
model_type == "zipformer2_ctc" || model_type == "wenet_ctc") {
model_type == "zipformer2_ctc" || model_type == "wenet_ctc" ||
model_type == "telespeech_ctc") {
return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
} else if (model_type == "whisper") {
return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config);
... ... @@ -199,6 +209,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
model_filename = config.model_config.zipformer_ctc.model;
} else if (!config.model_config.wenet_ctc.model.empty()) {
model_filename = config.model_config.wenet_ctc.model;
} else if (!config.model_config.telespeech_ctc.empty()) {
model_filename = config.model_config.telespeech_ctc;
} else if (!config.model_config.whisper.encoder.empty()) {
model_filename = config.model_config.whisper.encoder;
} else {
... ... @@ -251,6 +263,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
"\n "
"https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/run.sh"
"\n"
"(7) CTC models from TeleSpeech"
"\n "
"https://github.com/Tele-AI/TeleSpeech-ASR"
"\n"
"\n");
exit(-1);
}
... ... @@ -273,7 +289,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
if (model_type == "EncDecCTCModelBPE" ||
model_type == "EncDecHybridRNNTCTCBPEModel" || model_type == "tdnn" ||
model_type == "zipformer2_ctc" || model_type == "wenet_ctc") {
model_type == "zipformer2_ctc" || model_type == "wenet_ctc" ||
model_type == "telespeech_ctc") {
return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
}
... ... @@ -291,7 +308,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
" - Whisper models\n"
" - Tdnn models\n"
" - Zipformer CTC models\n"
" - WeNet CTC models\n",
" - WeNet CTC models\n"
" - TeleSpeech CTC models\n",
model_type.c_str());
exit(-1);
... ...
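The same path from the C++ API, as a minimal sketch assuming the existing OfflineRecognizer/OfflineStream interfaces; only the telespeech_ctc field and the "telespeech_ctc" model_type string are new in this PR.

#include "sherpa-onnx/csrc/offline-recognizer.h"

int main() {
  sherpa_onnx::OfflineRecognizerConfig config;
  config.model_config.telespeech_ctc =
      "./sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04/model.int8.onnx";
  config.model_config.tokens =
      "./sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04/tokens.txt";
  // Optional fast path handled above: avoids loading the model once
  // just to read its metadata.
  config.model_config.model_type = "telespeech_ctc";

  sherpa_onnx::OfflineRecognizer recognizer(config);
  auto stream = recognizer.CreateStream();
  // stream->AcceptWaveform(sample_rate, samples, n);
  // recognizer.DecodeStream(stream.get());
  return 0;
}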
... ... @@ -57,22 +57,44 @@ class OfflineStream::Impl {
explicit Impl(const FeatureExtractorConfig &config,
ContextGraphPtr context_graph)
: config_(config), context_graph_(context_graph) {
opts_.frame_opts.dither = config.dither;
opts_.frame_opts.snip_edges = config.snip_edges;
opts_.frame_opts.samp_freq = config.sampling_rate;
opts_.frame_opts.frame_shift_ms = config.frame_shift_ms;
opts_.frame_opts.frame_length_ms = config.frame_length_ms;
opts_.frame_opts.remove_dc_offset = config.remove_dc_offset;
opts_.frame_opts.window_type = config.window_type;
if (config.is_mfcc) {
mfcc_opts_.frame_opts.dither = config_.dither;
mfcc_opts_.frame_opts.snip_edges = config_.snip_edges;
mfcc_opts_.frame_opts.samp_freq = config_.sampling_rate;
mfcc_opts_.frame_opts.frame_shift_ms = config_.frame_shift_ms;
mfcc_opts_.frame_opts.frame_length_ms = config_.frame_length_ms;
mfcc_opts_.frame_opts.remove_dc_offset = config_.remove_dc_offset;
mfcc_opts_.frame_opts.window_type = config_.window_type;
opts_.mel_opts.num_bins = config.feature_dim;
mfcc_opts_.mel_opts.num_bins = config_.feature_dim;
opts_.mel_opts.high_freq = config.high_freq;
opts_.mel_opts.low_freq = config.low_freq;
mfcc_opts_.mel_opts.high_freq = config_.high_freq;
mfcc_opts_.mel_opts.low_freq = config_.low_freq;
opts_.mel_opts.is_librosa = config.is_librosa;
mfcc_opts_.mel_opts.is_librosa = config_.is_librosa;
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
mfcc_opts_.num_ceps = config_.num_ceps;
mfcc_opts_.use_energy = config_.use_energy;
mfcc_ = std::make_unique<knf::OnlineMfcc>(mfcc_opts_);
} else {
opts_.frame_opts.dither = config.dither;
opts_.frame_opts.snip_edges = config.snip_edges;
opts_.frame_opts.samp_freq = config.sampling_rate;
opts_.frame_opts.frame_shift_ms = config.frame_shift_ms;
opts_.frame_opts.frame_length_ms = config.frame_length_ms;
opts_.frame_opts.remove_dc_offset = config.remove_dc_offset;
opts_.frame_opts.window_type = config.window_type;
opts_.mel_opts.num_bins = config.feature_dim;
opts_.mel_opts.high_freq = config.high_freq;
opts_.mel_opts.low_freq = config.low_freq;
opts_.mel_opts.is_librosa = config.is_librosa;
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
}
}
explicit Impl(WhisperTag /*tag*/) {
... ... @@ -81,6 +103,7 @@ class OfflineStream::Impl {
opts_.mel_opts.num_bins = 80; // not used
whisper_fbank_ =
std::make_unique<knf::OnlineWhisperFbank>(opts_.frame_opts);
config_.sampling_rate = opts_.frame_opts.samp_freq;
}
explicit Impl(CEDTag /*tag*/) {
... ... @@ -98,6 +121,8 @@ class OfflineStream::Impl {
opts_.mel_opts.num_bins = 64;
opts_.mel_opts.high_freq = 8000;
config_.sampling_rate = opts_.frame_opts.samp_freq;
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
}
... ... @@ -115,52 +140,60 @@ class OfflineStream::Impl {
void AcceptWaveformImpl(int32_t sampling_rate, const float *waveform,
int32_t n) {
if (sampling_rate != opts_.frame_opts.samp_freq) {
if (sampling_rate != config_.sampling_rate) {
SHERPA_ONNX_LOGE(
"Creating a resampler:\n"
" in_sample_rate: %d\n"
" output_sample_rate: %d\n",
sampling_rate, static_cast<int32_t>(opts_.frame_opts.samp_freq));
sampling_rate, static_cast<int32_t>(config_.sampling_rate));
float min_freq =
std::min<int32_t>(sampling_rate, opts_.frame_opts.samp_freq);
float min_freq = std::min<int32_t>(sampling_rate, config_.sampling_rate);
float lowpass_cutoff = 0.99 * 0.5 * min_freq;
int32_t lowpass_filter_width = 6;
auto resampler = std::make_unique<LinearResample>(
sampling_rate, opts_.frame_opts.samp_freq, lowpass_cutoff,
sampling_rate, config_.sampling_rate, lowpass_cutoff,
lowpass_filter_width);
std::vector<float> samples;
resampler->Resample(waveform, n, true, &samples);
if (fbank_) {
fbank_->AcceptWaveform(opts_.frame_opts.samp_freq, samples.data(),
fbank_->AcceptWaveform(config_.sampling_rate, samples.data(),
samples.size());
fbank_->InputFinished();
} else if (mfcc_) {
mfcc_->AcceptWaveform(config_.sampling_rate, samples.data(),
samples.size());
mfcc_->InputFinished();
} else {
whisper_fbank_->AcceptWaveform(opts_.frame_opts.samp_freq,
samples.data(), samples.size());
whisper_fbank_->AcceptWaveform(config_.sampling_rate, samples.data(),
samples.size());
whisper_fbank_->InputFinished();
}
return;
} // if (sampling_rate != opts_.frame_opts.samp_freq)
} // if (sampling_rate != config_.sampling_rate)
if (fbank_) {
fbank_->AcceptWaveform(sampling_rate, waveform, n);
fbank_->InputFinished();
} else if (mfcc_) {
mfcc_->AcceptWaveform(sampling_rate, waveform, n);
mfcc_->InputFinished();
} else {
whisper_fbank_->AcceptWaveform(sampling_rate, waveform, n);
whisper_fbank_->InputFinished();
}
}
int32_t FeatureDim() const { return opts_.mel_opts.num_bins; }
int32_t FeatureDim() const {
return mfcc_ ? mfcc_opts_.num_ceps : opts_.mel_opts.num_bins;
}
std::vector<float> GetFrames() const {
int32_t n =
fbank_ ? fbank_->NumFramesReady() : whisper_fbank_->NumFramesReady();
int32_t n = fbank_ ? fbank_->NumFramesReady()
: mfcc_ ? mfcc_->NumFramesReady()
: whisper_fbank_->NumFramesReady();
assert(n > 0 && "Please first call AcceptWaveform()");
int32_t feature_dim = FeatureDim();
... ... @@ -170,8 +203,9 @@ class OfflineStream::Impl {
float *p = features.data();
for (int32_t i = 0; i != n; ++i) {
const float *f =
fbank_ ? fbank_->GetFrame(i) : whisper_fbank_->GetFrame(i);
const float *f = fbank_ ? fbank_->GetFrame(i)
: mfcc_ ? mfcc_->GetFrame(i)
: whisper_fbank_->GetFrame(i);
std::copy(f, f + feature_dim, p);
p += feature_dim;
}
... ... @@ -222,8 +256,10 @@ class OfflineStream::Impl {
private:
FeatureExtractorConfig config_;
std::unique_ptr<knf::OnlineFbank> fbank_;
std::unique_ptr<knf::OnlineMfcc> mfcc_;
std::unique_ptr<knf::OnlineWhisperFbank> whisper_fbank_;
knf::FbankOptions opts_;
knf::MfccOptions mfcc_opts_;
OfflineRecognitionResult r_;
ContextGraphPtr context_graph_;
};
... ...
// sherpa-onnx/csrc/offline-telespeech-ctc-model.cc
//
// Copyright (c) 2023-2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-telespeech-ctc-model.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/text-utils.h"
#include "sherpa-onnx/csrc/transpose.h"
namespace sherpa_onnx {
class OfflineTeleSpeechCtcModel::Impl {
public:
explicit Impl(const OfflineModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
auto buf = ReadFile(config_.telespeech_ctc);
Init(buf.data(), buf.size());
}
#if __ANDROID_API__ >= 9
Impl(AAssetManager *mgr, const OfflineModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
auto buf = ReadFile(mgr, config_.telespeech_ctc);
Init(buf.data(), buf.size());
}
#endif
std::vector<Ort::Value> Forward(Ort::Value features,
Ort::Value /*features_length*/) {
std::vector<int64_t> shape =
features.GetTensorTypeAndShapeInfo().GetShape();
if (static_cast<int32_t>(shape[0]) != 1) {
SHERPA_ONNX_LOGE("This model supports only batch size 1. Given %d",
static_cast<int32_t>(shape[0]));
}
auto out = sess_->Run({}, input_names_ptr_.data(), &features, 1,
output_names_ptr_.data(), output_names_ptr_.size());
std::vector<int64_t> logits_shape = {1};
Ort::Value logits_length = Ort::Value::CreateTensor<int64_t>(
allocator_, logits_shape.data(), logits_shape.size());
int64_t *dst = logits_length.GetTensorMutableData<int64_t>();
dst[0] = out[0].GetTensorTypeAndShapeInfo().GetShape()[0];
// (T, B, C) -> (B, T, C)
Ort::Value logits = Transpose01(allocator_, &out[0]);
std::vector<Ort::Value> ans;
ans.reserve(2);
ans.push_back(std::move(logits));
ans.push_back(std::move(logits_length));
return ans;
}
int32_t VocabSize() const { return vocab_size_; }
int32_t SubsamplingFactor() const { return subsampling_factor_; }
OrtAllocator *Allocator() const { return allocator_; }
private:
void Init(void *model_data, size_t model_data_length) {
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
sess_opts_);
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
// get meta data
Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
if (config_.debug) {
std::ostringstream os;
PrintModelMetadata(os, meta_data);
SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
}
{
auto shape =
sess_->GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape();
vocab_size_ = shape[2];
}
}
private:
OfflineModelConfig config_;
Ort::Env env_;
Ort::SessionOptions sess_opts_;
Ort::AllocatorWithDefaultOptions allocator_;
std::unique_ptr<Ort::Session> sess_;
std::vector<std::string> input_names_;
std::vector<const char *> input_names_ptr_;
std::vector<std::string> output_names_;
std::vector<const char *> output_names_ptr_;
int32_t vocab_size_ = 0;
int32_t subsampling_factor_ = 4;
};
OfflineTeleSpeechCtcModel::OfflineTeleSpeechCtcModel(
const OfflineModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
#if __ANDROID_API__ >= 9
OfflineTeleSpeechCtcModel::OfflineTeleSpeechCtcModel(
AAssetManager *mgr, const OfflineModelConfig &config)
: impl_(std::make_unique<Impl>(mgr, config)) {}
#endif
OfflineTeleSpeechCtcModel::~OfflineTeleSpeechCtcModel() = default;
std::vector<Ort::Value> OfflineTeleSpeechCtcModel::Forward(
Ort::Value features, Ort::Value features_length) {
return impl_->Forward(std::move(features), std::move(features_length));
}
int32_t OfflineTeleSpeechCtcModel::VocabSize() const {
return impl_->VocabSize();
}
int32_t OfflineTeleSpeechCtcModel::SubsamplingFactor() const {
return impl_->SubsamplingFactor();
}
OrtAllocator *OfflineTeleSpeechCtcModel::Allocator() const {
return impl_->Allocator();
}
} // namespace sherpa_onnx
... ...
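To make the shape bookkeeping in Forward() above concrete, a standalone toy (not repo code): the model emits logits as (T, B, C); the wrapper transposes to (B, T, C) and synthesizes a one-element logits_length tensor holding T. With B fixed at 1 the transpose is a layout-preserving copy.

#include <cstdio>
#include <vector>

int main() {
  const int64_t T = 5, B = 1, C = 3;  // (T, B, C) as produced by the model
  std::vector<float> tbc(T * B * C);
  for (size_t i = 0; i != tbc.size(); ++i) tbc[i] = static_cast<float>(i);

  // What Transpose01 does: (T, B, C) -> (B, T, C).
  std::vector<float> btc(B * T * C);
  for (int64_t t = 0; t != T; ++t)
    for (int64_t b = 0; b != B; ++b)
      for (int64_t c = 0; c != C; ++c)
        btc[(b * T + t) * C + c] = tbc[(t * B + b) * C + c];

  int64_t logits_length = T;  // what Forward() stores in the length tensor
  std::printf("logits_length = %lld\n",
              static_cast<long long>(logits_length));
  return 0;
}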
// sherpa-onnx/csrc/offline-telespeech-ctc-model.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TELESPEECH_CTC_MODEL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_TELESPEECH_CTC_MODEL_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-ctc-model.h"
#include "sherpa-onnx/csrc/offline-model-config.h"
namespace sherpa_onnx {
/** This class implements the CTC model from
* https://github.com/Tele-AI/TeleSpeech-ASR.
*
* See
* https://github.com/lovemefan/telespeech-asr-python/blob/main/telespeechasr/onnx/onnx_infer.py
* and
* https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/tele-speech/test.py
*/
class OfflineTeleSpeechCtcModel : public OfflineCtcModel {
public:
explicit OfflineTeleSpeechCtcModel(const OfflineModelConfig &config);
#if __ANDROID_API__ >= 9
OfflineTeleSpeechCtcModel(AAssetManager *mgr,
const OfflineModelConfig &config);
#endif
~OfflineTeleSpeechCtcModel() override;
/** Run the forward method of the model.
*
* @param features A tensor of shape (N, T, C).
* @param features_length A 1-D tensor of shape (N,) containing number of
* valid frames in `features` before padding.
* Its dtype is int64_t.
*
* @return Return a vector containing:
* - log_probs: A 3-D tensor of shape (N, T', vocab_size).
* - log_probs_length A 1-D tensor of shape (N,). Its dtype is int64_t
*/
std::vector<Ort::Value> Forward(Ort::Value features,
Ort::Value features_length) override;
/** Return the vocabulary size of the model
*/
int32_t VocabSize() const override;
/** SubsamplingFactor of the model
*/
int32_t SubsamplingFactor() const override;
/** Return an allocator for allocating memory
*/
OrtAllocator *Allocator() const override;
// TeleSpeech CTC models do not support batch size > 1
bool SupportBatchProcessing() const override { return false; }
std::string FeatureNormalizationMethod() const override {
return "per_feature";
}
private:
class Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_TELESPEECH_CTC_MODEL_H_
... ...
... ... @@ -66,7 +66,7 @@ bool OnlineModelConfig::Validate() const {
if (!modeling_unit.empty() &&
(modeling_unit == "bpe" || modeling_unit == "cjkchar+bpe")) {
if (!FileExists(bpe_vocab)) {
SHERPA_ONNX_LOGE("bpe_vocab: %s does not exist", bpe_vocab.c_str());
SHERPA_ONNX_LOGE("bpe_vocab: '%s' does not exist", bpe_vocab.c_str());
return false;
}
}
... ...
... ... @@ -7,6 +7,7 @@ public class OfflineModelConfig {
private final OfflineParaformerModelConfig paraformer;
private final OfflineWhisperModelConfig whisper;
private final OfflineNemoEncDecCtcModelConfig nemo;
private final String teleSpeech;
private final String tokens;
private final int numThreads;
private final boolean debug;
... ... @@ -21,6 +22,7 @@ public class OfflineModelConfig {
this.paraformer = builder.paraformer;
this.whisper = builder.whisper;
this.nemo = builder.nemo;
this.teleSpeech = builder.teleSpeech;
this.tokens = builder.tokens;
this.numThreads = builder.numThreads;
this.debug = builder.debug;
... ... @@ -74,11 +76,16 @@ public class OfflineModelConfig {
return bpeVocab;
}
public String getTeleSpeech() {
return teleSpeech;
}
public static class Builder {
private OfflineParaformerModelConfig paraformer = OfflineParaformerModelConfig.builder().build();
private OfflineTransducerModelConfig transducer = OfflineTransducerModelConfig.builder().build();
private OfflineWhisperModelConfig whisper = OfflineWhisperModelConfig.builder().build();
private OfflineNemoEncDecCtcModelConfig nemo = OfflineNemoEncDecCtcModelConfig.builder().build();
private String teleSpeech = "";
private String tokens = "";
private int numThreads = 1;
private boolean debug = true;
... ... @@ -106,6 +113,12 @@ public class OfflineModelConfig {
return this;
}
public Builder setTeleSpeech(String teleSpeech) {
this.teleSpeech = teleSpeech;
return this;
}
public Builder setWhisper(OfflineWhisperModelConfig whisper) {
this.whisper = whisper;
return this;
... ...
... ... @@ -172,6 +172,12 @@ static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) {
ans.model_config.nemo_ctc.model = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "teleSpeech", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.telespeech_ctc = p;
env->ReleaseStringUTFChars(s, p);
return ans;
}
... ...
... ... @@ -35,6 +35,7 @@ data class OfflineModelConfig(
var paraformer: OfflineParaformerModelConfig = OfflineParaformerModelConfig(),
var whisper: OfflineWhisperModelConfig = OfflineWhisperModelConfig(),
var nemo: OfflineNemoEncDecCtcModelConfig = OfflineNemoEncDecCtcModelConfig(),
var teleSpeech: String = "",
var numThreads: Int = 1,
var debug: Boolean = false,
var provider: String = "cpu",
... ... @@ -272,6 +273,15 @@ fun getOfflineModelConfig(type: Int): OfflineModelConfig? {
tokens = "$modelDir/tokens.txt",
)
}
11 -> {
val modelDir = "sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04"
return OfflineModelConfig(
teleSpeech = "$modelDir/model.int8.onnx",
tokens = "$modelDir/tokens.txt",
modelType = "tele_speech",
)
}
}
return null
}
... ...
... ... @@ -29,25 +29,27 @@ void PybindOfflineModelConfig(py::module *m) {
using PyClass = OfflineModelConfig;
py::class_<PyClass>(*m, "OfflineModelConfig")
.def(py::init<const OfflineTransducerModelConfig &,
const OfflineParaformerModelConfig &,
const OfflineNemoEncDecCtcModelConfig &,
const OfflineWhisperModelConfig &,
const OfflineTdnnModelConfig &,
const OfflineZipformerCtcModelConfig &,
const OfflineWenetCtcModelConfig &, const std::string &,
int32_t, bool, const std::string &, const std::string &,
const std::string &, const std::string &>(),
py::arg("transducer") = OfflineTransducerModelConfig(),
py::arg("paraformer") = OfflineParaformerModelConfig(),
py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(),
py::arg("whisper") = OfflineWhisperModelConfig(),
py::arg("tdnn") = OfflineTdnnModelConfig(),
py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(),
py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(),
py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,
py::arg("provider") = "cpu", py::arg("model_type") = "",
py::arg("modeling_unit") = "cjkchar", py::arg("bpe_vocab") = "")
.def(
py::init<
const OfflineTransducerModelConfig &,
const OfflineParaformerModelConfig &,
const OfflineNemoEncDecCtcModelConfig &,
const OfflineWhisperModelConfig &, const OfflineTdnnModelConfig &,
const OfflineZipformerCtcModelConfig &,
const OfflineWenetCtcModelConfig &, const std::string &,
const std::string &, int32_t, bool, const std::string &,
const std::string &, const std::string &, const std::string &>(),
py::arg("transducer") = OfflineTransducerModelConfig(),
py::arg("paraformer") = OfflineParaformerModelConfig(),
py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(),
py::arg("whisper") = OfflineWhisperModelConfig(),
py::arg("tdnn") = OfflineTdnnModelConfig(),
py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(),
py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(),
py::arg("telespeech_ctc") = "", py::arg("tokens"),
py::arg("num_threads"), py::arg("debug") = false,
py::arg("provider") = "cpu", py::arg("model_type") = "",
py::arg("modeling_unit") = "cjkchar", py::arg("bpe_vocab") = "")
.def_readwrite("transducer", &PyClass::transducer)
.def_readwrite("paraformer", &PyClass::paraformer)
.def_readwrite("nemo_ctc", &PyClass::nemo_ctc)
... ... @@ -55,6 +57,7 @@ void PybindOfflineModelConfig(py::module *m) {
.def_readwrite("tdnn", &PyClass::tdnn)
.def_readwrite("zipformer_ctc", &PyClass::zipformer_ctc)
.def_readwrite("wenet_ctc", &PyClass::wenet_ctc)
.def_readwrite("telespeech_ctc", &PyClass::telespeech_ctc)
.def_readwrite("tokens", &PyClass::tokens)
.def_readwrite("num_threads", &PyClass::num_threads)
.def_readwrite("debug", &PyClass::debug)
... ...
... ... @@ -212,6 +212,71 @@ class OfflineRecognizer(object):
return self
@classmethod
def from_telespeech_ctc(
cls,
model: str,
tokens: str,
num_threads: int = 1,
sample_rate: int = 16000,
feature_dim: int = 40,
decoding_method: str = "greedy_search",
debug: bool = False,
provider: str = "cpu",
):
"""
Please refer to
`<https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models>`_
to download pre-trained models.
Args:
model:
Path to ``model.onnx``.
tokens:
Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
columns::
symbol integer_id
num_threads:
Number of threads for neural network computation.
sample_rate:
Sample rate of the training data used to train the model. It is
ignored; the feature extractor configuration is hard-coded in C++.
feature_dim:
Dimension of the feature used to train the model. It is ignored
and is hard-coded in C++ to 40.
decoding_method:
The only valid value is greedy_search.
debug:
True to show debug messages.
provider:
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
"""
self = cls.__new__(cls)
model_config = OfflineModelConfig(
telespeech_ctc=model,
tokens=tokens,
num_threads=num_threads,
debug=debug,
provider=provider,
model_type="nemo_ctc",
)
feat_config = FeatureExtractorConfig(
sampling_rate=sample_rate,
feature_dim=feature_dim,
)
recognizer_config = OfflineRecognizerConfig(
feat_config=feat_config,
model_config=model_config,
decoding_method=decoding_method,
)
self.recognizer = _Recognizer(recognizer_config)
self.config = recognizer_config
return self
@classmethod
def from_nemo_ctc(
cls,
model: str,
... ...
... ... @@ -102,7 +102,7 @@ func sherpaOnnxOnlineModelConfig(
debug: Int32(debug),
model_type: toCPointer(modelType),
modeling_unit: toCPointer(modelingUnit),
bpeVocab: toCPointer(bpeVocab)
bpe_vocab: toCPointer(bpeVocab)
)
}
... ... @@ -360,7 +360,8 @@ func sherpaOnnxOfflineModelConfig(
debug: Int = 0,
modelType: String = "",
modelingUnit: String = "cjkchar",
bpeVocab: String = ""
bpeVocab: String = "",
teleSpeechCtc: String = ""
) -> SherpaOnnxOfflineModelConfig {
return SherpaOnnxOfflineModelConfig(
transducer: transducer,
... ... @@ -374,7 +375,8 @@ func sherpaOnnxOfflineModelConfig(
provider: toCPointer(provider),
model_type: toCPointer(modelType),
modeling_unit: toCPointer(modelingUnit),
bpeVocab: toCPointer(bpeVocab)
bpe_vocab: toCPointer(bpeVocab),
telespeech_ctc: toCPointer(teleSpeechCtc)
)
}
... ...
... ... @@ -529,7 +529,7 @@ function initSherpaOnnxOfflineModelConfig(config, Module) {
const tdnn = initSherpaOnnxOfflineTdnnModelConfig(config.tdnn, Module);
const len = transducer.len + paraformer.len + nemoCtc.len + whisper.len +
tdnn.len + 7 * 4;
tdnn.len + 8 * 4;
const ptr = Module._malloc(len);
let offset = 0;
... ... @@ -553,9 +553,11 @@ function initSherpaOnnxOfflineModelConfig(config, Module) {
const modelTypeLen = Module.lengthBytesUTF8(config.modelType) + 1;
const modelingUnitLen = Module.lengthBytesUTF8(config.modelingUnit || '') + 1;
const bpeVocabLen = Module.lengthBytesUTF8(config.bpeVocab || '') + 1;
const teleSpeechCtcLen =
Module.lengthBytesUTF8(config.teleSpeechCtc || '') + 1;
const bufferLen =
tokensLen + providerLen + modelTypeLen + modelingUnitLen + bpeVocabLen;
const bufferLen = tokensLen + providerLen + modelTypeLen + modelingUnitLen +
bpeVocabLen + teleSpeechCtcLen;
const buffer = Module._malloc(bufferLen);
offset = 0;
... ... @@ -575,6 +577,10 @@ function initSherpaOnnxOfflineModelConfig(config, Module) {
Module.stringToUTF8(config.bpeVocab || '', buffer + offset, bpeVocabLen);
offset += bpeVocabLen;
Module.stringToUTF8(
config.teleSpeechCtc || '', buffer + offset, teleSpeechCtcLen);
offset += teleSpeechCtcLen;
offset =
transducer.len + paraformer.len + nemoCtc.len + whisper.len + tdnn.len;
Module.setValue(ptr + offset, buffer, 'i8*'); // tokens
... ... @@ -604,6 +610,13 @@ function initSherpaOnnxOfflineModelConfig(config, Module) {
'i8*'); // bpeVocab
offset += 4;
Module.setValue(
ptr + offset,
buffer + tokensLen + providerLen + modelTypeLen + modelingUnitLen +
bpeVocabLen,
'i8*'); // teleSpeechCtc
offset += 4;
return {
buffer: buffer, ptr: ptr, len: len, transducer: transducer,
paraformer: paraformer, nemoCtc: nemoCtc, whisper: whisper, tdnn: tdnn
... ...
... ... @@ -23,7 +23,7 @@ static_assert(sizeof(SherpaOnnxOfflineModelConfig) ==
sizeof(SherpaOnnxOfflineParaformerModelConfig) +
sizeof(SherpaOnnxOfflineNemoEncDecCtcModelConfig) +
sizeof(SherpaOnnxOfflineWhisperModelConfig) +
sizeof(SherpaOnnxOfflineTdnnModelConfig) + 7 * 4,
sizeof(SherpaOnnxOfflineTdnnModelConfig) + 8 * 4,
"");
static_assert(sizeof(SherpaOnnxFeatureConfig) == 2 * 4, "");
static_assert(sizeof(SherpaOnnxOfflineRecognizerConfig) ==
... ... @@ -92,6 +92,7 @@ void PrintOfflineRecognizerConfig(SherpaOnnxOfflineRecognizerConfig *config) {
fprintf(stdout, "model type: %s\n", model_config->model_type);
fprintf(stdout, "modeling unit: %s\n", model_config->modeling_unit);
fprintf(stdout, "bpe vocab: %s\n", model_config->bpe_vocab);
fprintf(stdout, "telespeech_ctc: %s\n", model_config->telespeech_ctc);
fprintf(stdout, "----------feat config----------\n");
fprintf(stdout, "sample rate: %d\n", feat->sample_rate);
... ...