Fangjun Kuang
Committed by GitHub

Support streaming zipformer CTC (#496)

* Support streaming zipformer CTC

* test online zipformer2 CTC

* Update doc of sherpa-onnx.cc

* Add Python APIs for streaming zipformer2 ctc

* Add Python API examples for streaming zipformer2 ctc

* Swift API for streaming zipformer2 CTC

* NodeJS API for streaming zipformer2 CTC

* Kotlin API for streaming zipformer2 CTC

* Golang API for streaming zipformer2 CTC

* C# API for streaming zipformer2 CTC

* Release v1.9.6
Showing 70 changed files with 1,480 additions and 174 deletions
... ... @@ -51,6 +51,13 @@ rm sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2
node ./test-online-transducer.js
rm -rf sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20
curl -LS -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
node ./test-online-zipformer2-ctc.js
rm -rf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13
# offline tts
curl -LS -O https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2
... ...
... ... @@ -14,6 +14,37 @@ echo "PATH: $PATH"
which $EXE
log "------------------------------------------------------------"
log "Run streaming Zipformer2 CTC "
log "------------------------------------------------------------"
url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
repo=$(basename -s .tar.bz2 $url)
curl -SL -O $url
tar xvf $repo.tar.bz2
rm $repo.tar.bz2
log "test fp32"
time $EXE \
--debug=1 \
--zipformer2-ctc-model=$repo/ctc-epoch-20-avg-1-chunk-16-left-128.onnx \
--tokens=$repo/tokens.txt \
$repo/test_wavs/DEV_T0000000000.wav \
$repo/test_wavs/DEV_T0000000001.wav \
$repo/test_wavs/DEV_T0000000002.wav
log "test int8"
time $EXE \
--debug=1 \
--zipformer2-ctc-model=$repo/ctc-epoch-20-avg-1-chunk-16-left-128.int8.onnx \
--tokens=$repo/tokens.txt \
$repo/test_wavs/DEV_T0000000000.wav \
$repo/test_wavs/DEV_T0000000001.wav \
$repo/test_wavs/DEV_T0000000002.wav
log "------------------------------------------------------------"
log "Run streaming Conformer CTC from WeNet"
log "------------------------------------------------------------"
wenet_models=(
... ...
... ... @@ -8,6 +8,27 @@ log() {
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
mkdir -p /tmp/icefall-models
dir=/tmp/icefall-models
pushd $dir
wget -qq https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
popd
repo=$dir/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13
python3 ./python-api-examples/online-decode-files.py \
--tokens=$repo/tokens.txt \
--zipformer2-ctc=$repo/ctc-epoch-20-avg-1-chunk-16-left-128.onnx \
$repo/test_wavs/DEV_T0000000000.wav \
$repo/test_wavs/DEV_T0000000001.wav \
$repo/test_wavs/DEV_T0000000002.wav
python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose
rm -rf $dir/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13
wenet_models=(
sherpa-onnx-zh-wenet-aishell
sherpa-onnx-zh-wenet-aishell2
... ... @@ -17,8 +38,6 @@ sherpa-onnx-en-wenet-librispeech
sherpa-onnx-en-wenet-gigaspeech
)
mkdir -p /tmp/icefall-models
dir=/tmp/icefall-models
for name in ${wenet_models[@]}; do
repo_url=https://huggingface.co/csukuangfj/$name
... ...
... ... @@ -22,6 +22,9 @@ cat /Users/fangjun/Desktop/Obama.srt
ls -lh
./run-decode-file.sh
rm decode-file
sed -i.bak '20d' ./decode-file.swift
./run-decode-file.sh
./run-decode-file-non-streaming.sh
... ...
... ... @@ -22,7 +22,7 @@ jobs:
- uses: actions/checkout@v4
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
... ...
... ... @@ -22,7 +22,7 @@ jobs:
- uses: actions/checkout@v4
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
... ...
... ... @@ -24,7 +24,7 @@ jobs:
- uses: actions/checkout@v4
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
... ...
... ... @@ -107,23 +107,23 @@ jobs:
name: release-static
path: build/bin/*
- name: Test offline Whisper
- name: Test online CTC
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline
readelf -d build/bin/sherpa-onnx-offline
export EXE=sherpa-onnx
.github/scripts/test-offline-whisper.sh
.github/scripts/test-online-ctc.sh
- name: Test online CTC
- name: Test offline Whisper
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx
export EXE=sherpa-onnx-offline
.github/scripts/test-online-ctc.sh
readelf -d build/bin/sherpa-onnx-offline
.github/scripts/test-offline-whisper.sh
- name: Test offline CTC
shell: bash
... ...
... ... @@ -25,7 +25,7 @@ jobs:
fetch-depth: 0
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
... ...
... ... @@ -55,7 +55,7 @@ jobs:
key: ${{ matrix.os }}-python-${{ matrix.python-version }}
- name: Setup Python
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
... ...
... ... @@ -49,7 +49,7 @@ jobs:
fetch-depth: 0
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v1
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
... ...
... ... @@ -29,7 +29,7 @@ jobs:
fetch-depth: 0
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
... ...
... ... @@ -61,7 +61,7 @@ jobs:
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest]
os: [ubuntu-latest, macos-latest] #, windows-latest]
python-version: ["3.8"]
steps:
... ... @@ -70,7 +70,7 @@ jobs:
fetch-depth: 0
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
... ... @@ -143,6 +143,7 @@ jobs:
cd dotnet-examples/
cd online-decode-files
./run-zipformer2-ctc.sh
./run-transducer.sh
./run-paraformer.sh
... ...
... ... @@ -53,7 +53,7 @@ jobs:
mkdir build
cd build
cmake -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache -DBUILD_SHARED_LIBS=ON -DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF -DSHERPA_ONNX_ENABLE_WEBSOCKET=OFF ..
make -j
make -j1
cp -v _deps/onnxruntime-src/lib/libonnxruntime*dylib ./lib/
cd ../scripts/go/_internal/
... ... @@ -153,6 +153,14 @@ jobs:
git lfs install
echo "Test zipformer2 CTC"
wget -qq https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
./run-zipformer2-ctc.sh
rm -rf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13
echo "Test transducer"
git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26
./run-transducer.sh
... ...
... ... @@ -34,7 +34,7 @@ jobs:
fetch-depth: 0
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
... ...
... ... @@ -52,7 +52,7 @@ jobs:
ls -lh install/lib
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
... ...
... ... @@ -40,7 +40,7 @@ jobs:
fetch-depth: 0
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
... ...
... ... @@ -38,7 +38,7 @@ jobs:
key: ${{ matrix.os }}-python-${{ matrix.python-version }}
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
... ...
... ... @@ -25,7 +25,7 @@ jobs:
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
model_type: ["transducer", "paraformer"]
model_type: ["transducer", "paraformer", "zipformer2-ctc"]
steps:
- uses: actions/checkout@v4
... ... @@ -38,7 +38,7 @@ jobs:
key: ${{ matrix.os }}-python-${{ matrix.python-version }}
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
... ... @@ -57,6 +57,26 @@ jobs:
python3 -m pip install --no-deps --verbose .
python3 -m pip install websockets
- name: Start server for zipformer2 CTC models
if: matrix.model_type == 'zipformer2-ctc'
shell: bash
run: |
curl -O -L https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
python3 ./python-api-examples/streaming_server.py \
--zipformer2-ctc ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.onnx \
--tokens=./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt &
echo "sleep 10 seconds to wait the server start"
sleep 10
- name: Start client for zipformer2 CTC models
if: matrix.model_type == 'zipformer2-ctc'
shell: bash
run: |
python3 ./python-api-examples/online-websocket-client-decode-file.py \
./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav
- name: Start server for transducer models
if: matrix.model_type == 'transducer'
... ...
cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
project(sherpa-onnx)
set(SHERPA_ONNX_VERSION "1.9.4")
set(SHERPA_ONNX_VERSION "1.9.6")
# Disable warning about
#
... ...
... ... @@ -26,9 +26,14 @@ data class OnlineParaformerModelConfig(
var decoder: String = "",
)
data class OnlineZipformer2CtcModelConfig(
var model: String = "",
)
data class OnlineModelConfig(
var transducer: OnlineTransducerModelConfig = OnlineTransducerModelConfig(),
var paraformer: OnlineParaformerModelConfig = OnlineParaformerModelConfig(),
var zipformer2Ctc: OnlineZipformer2CtcModelConfig = OnlineZipformer2CtcModelConfig(),
var tokens: String,
var numThreads: Int = 1,
var debug: Boolean = false,
... ...
... ... @@ -38,6 +38,9 @@ class OnlineDecodeFiles
[Option("paraformer-decoder", Required = false, HelpText = "Path to paraformer decoder.onnx")]
public string ParaformerDecoder { get; set; }
[Option("zipformer2-ctc", Required = false, HelpText = "Path to zipformer2 CTC onnx model")]
public string Zipformer2Ctc { get; set; }
[Option("num-threads", Required = false, Default = 1, HelpText = "Number of threads for computation")]
public int NumThreads { get; set; }
... ... @@ -107,7 +110,19 @@ dotnet run \
--files ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/0.wav \
./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/1.wav
(2) Streaming Paraformer models
(2) Streaming Zipformer2 Ctc models
dotnet run -c Release \
--tokens ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt \
--zipformer2-ctc ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.onnx \
--files ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav \
./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000001.wav \
./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000002.wav \
./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/TEST_MEETING_T0000000113.wav \
./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/TEST_MEETING_T0000000219.wav \
./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/TEST_MEETING_T0000000351.wav
(3) Streaming Paraformer models
dotnet run \
--tokens=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/tokens.txt \
--paraformer-encoder=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/encoder.int8.onnx \
... ... @@ -121,6 +136,7 @@ dotnet run \
Please refer to
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/index.html
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-paraformer/index.html
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-ctc/index.html
to download pre-trained streaming models.
";
... ... @@ -150,6 +166,8 @@ to download pre-trained streaming models.
config.ModelConfig.Paraformer.Encoder = options.ParaformerEncoder;
config.ModelConfig.Paraformer.Decoder = options.ParaformerDecoder;
config.ModelConfig.Zipformer2Ctc.Model = options.Zipformer2Ctc;
config.ModelConfig.Tokens = options.Tokens;
config.ModelConfig.Provider = options.Provider;
config.ModelConfig.NumThreads = options.NumThreads;
... ...
#!/usr/bin/env bash
# Please refer to
# https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-ctc/zipformer-ctc-models.html#sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13-chinese
# to download the model files
if [ ! -d ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13 ]; then
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
fi
dotnet run -c Release \
--tokens ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt \
--zipformer2-ctc ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.onnx \
--files ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav \
./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000001.wav \
./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000002.wav \
./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/TEST_MEETING_T0000000113.wav \
./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/TEST_MEETING_T0000000219.wav \
./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/TEST_MEETING_T0000000351.wav
... ...
... ... @@ -22,6 +22,7 @@ func main() {
flag.StringVar(&config.ModelConfig.Transducer.Joiner, "joiner", "", "Path to the transducer joiner model")
flag.StringVar(&config.ModelConfig.Paraformer.Encoder, "paraformer-encoder", "", "Path to the paraformer encoder model")
flag.StringVar(&config.ModelConfig.Paraformer.Decoder, "paraformer-decoder", "", "Path to the paraformer decoder model")
flag.StringVar(&config.ModelConfig.Zipformer2Ctc.Model, "zipformer2-ctc", "", "Path to the zipformer2 CTC model")
flag.StringVar(&config.ModelConfig.Tokens, "tokens", "", "Path to the tokens file")
flag.IntVar(&config.ModelConfig.NumThreads, "num-threads", 1, "Number of threads for computing")
flag.IntVar(&config.ModelConfig.Debug, "debug", 0, "Whether to show debug message")
... ...
#!/usr/bin/env bash
# Please refer to
# https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-ctc/zipformer-ctc-models.html#sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13-chinese
# to download the model
# before you run this script.
#
# You can switch to a different online model if you need
./streaming-decode-files \
--zipformer2-ctc ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.onnx \
--tokens ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt \
./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav
... ...
... ... @@ -8,7 +8,8 @@ fun callback(samples: FloatArray): Unit {
fun main() {
testTts()
testAsr()
testAsr("transducer")
testAsr("zipformer2-ctc")
}
fun testTts() {
... ... @@ -30,16 +31,20 @@ fun testTts() {
audio.save(filename="test-en.wav")
}
fun testAsr() {
fun testAsr(type: String) {
var featConfig = FeatureConfig(
sampleRate = 16000,
featureDim = 80,
)
var waveFilename: String
var modelConfig: OnlineModelConfig = when (type) {
"transducer" -> {
waveFilename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/0.wav"
// please refer to
// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
// to download pre-trained models
var modelConfig = OnlineModelConfig(
OnlineModelConfig(
transducer = OnlineTransducerModelConfig(
encoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/encoder-epoch-99-avg-1.onnx",
decoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/decoder-epoch-99-avg-1.onnx",
... ... @@ -49,6 +54,20 @@ fun testAsr() {
numThreads = 1,
debug = false,
)
}
"zipformer2-ctc" -> {
waveFilename = "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav"
OnlineModelConfig(
zipformer2Ctc = OnlineZipformer2CtcModelConfig(
model = "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.onnx",
),
tokens = "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt",
numThreads = 1,
debug = false,
)
}
else -> throw IllegalArgumentException(type)
}
var endpointConfig = EndpointConfig()
... ... @@ -69,7 +88,7 @@ fun testAsr() {
)
var objArray = WaveReader.readWaveFromFile(
filename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/0.wav",
filename = waveFilename,
)
var samples: FloatArray = objArray[0] as FloatArray
var sampleRate: Int = objArray[1] as Int
... ...
... ... @@ -34,6 +34,12 @@ if [ ! -f ./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt ]; then
git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21
fi
if [ ! -d ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13 ]; then
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
fi
if [ ! -f ./vits-piper-en_US-amy-low/en_US-amy-low.onnx ]; then
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2
tar xf vits-piper-en_US-amy-low.tar.bz2
... ...
... ... @@ -85,7 +85,7 @@ npm install wav naudiodon2
how to decode a file with a NeMo CTC model. In the code we use
[stt_en_conformer_ctc_small](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/nemo/english.html#stt-en-conformer-ctc-small).
You can use the following command run it:
You can use the following command to run it:
```bash
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-ctc-en-conformer-small.tar.bz2
... ... @@ -99,7 +99,7 @@ node ./test-offline-nemo-ctc.js
how to decode a file with a non-streaming Paraformer model. In the code we use
[sherpa-onnx-paraformer-zh-2023-03-28](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/paraformer-models.html#csukuangfj-sherpa-onnx-paraformer-zh-2023-03-28-chinese).
You can use the following command run it:
You can use the following command to run it:
```bash
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-03-28.tar.bz2
... ... @@ -113,7 +113,7 @@ node ./test-offline-paraformer.js
how to decode a file with a non-streaming transducer model. In the code we use
[sherpa-onnx-zipformer-en-2023-06-26](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/zipformer-transducer-models.html#csukuangfj-sherpa-onnx-zipformer-en-2023-06-26-english).
You can use the following command run it:
You can use the following command to run it:
```bash
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-zipformer-en-2023-06-26.tar.bz2
... ... @@ -126,7 +126,7 @@ node ./test-offline-transducer.js
how to decode a file with a Whisper model. In the code we use
[sherpa-onnx-whisper-tiny.en](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html).
You can use the following command run it:
You can use the following command to run it:
```bash
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2
... ... @@ -140,7 +140,7 @@ demonstrates how to do real-time speech recognition from microphone
with a streaming Paraformer model. In the code we use
[sherpa-onnx-streaming-paraformer-bilingual-zh-en](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-paraformer/paraformer-models.html#csukuangfj-sherpa-onnx-streaming-paraformer-bilingual-zh-en-chinese-english).
You can use the following command run it:
You can use the following command to run it:
```bash
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2
... ... @@ -153,7 +153,7 @@ node ./test-online-paraformer-microphone.js
how to decode a file using a streaming Paraformer model. In the code we use
[sherpa-onnx-streaming-paraformer-bilingual-zh-en](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-paraformer/paraformer-models.html#csukuangfj-sherpa-onnx-streaming-paraformer-bilingual-zh-en-chinese-english).
You can use the following command run it:
You can use the following command to run it:
```bash
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2
... ... @@ -167,7 +167,7 @@ demonstrates how to do real-time speech recognition with microphone using a stre
we use [sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#csukuangfj-sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english).
You can use the following command run it:
You can use the following command to run it:
```bash
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2
... ... @@ -180,7 +180,7 @@ node ./test-online-transducer-microphone.js
how to decode a file using a streaming transducer model. In the code
we use [sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#csukuangfj-sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english).
You can use the following command run it:
You can use the following command to run it:
```bash
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2
... ... @@ -188,13 +188,26 @@ tar xvf sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2
node ./test-online-transducer.js
```
## ./test-online-zipformer2-ctc.js
[./test-online-zipformer2-ctc.js](./test-online-zipformer2-ctc.js) demonstrates
how to decode a file using a streaming zipformer2 CTC model. In the code
we use [sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-ctc/zipformer-ctc-models.html#sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13-chinese).
You can use the following command to run it:
```bash
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
node ./test-online-zipformer2-ctc.js
```
## ./test-vad-microphone-offline-paraformer.js
[./test-vad-microphone-offline-paraformer.js](./test-vad-microphone-offline-paraformer.js)
demonstrates how to use [silero-vad](https://github.com/snakers4/silero-vad)
with non-streaming Paraformer for speech recognition from microphone.
You can use the following command run it:
You can use the following command to run it:
```bash
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx
... ... @@ -209,7 +222,7 @@ node ./test-vad-microphone-offline-paraformer.js
demonstrates how to use [silero-vad](https://github.com/snakers4/silero-vad)
with a non-streaming transducer model for speech recognition from microphone.
You can use the following command run it:
You can use the following command to run it:
```bash
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx
... ... @@ -224,7 +237,7 @@ node ./test-vad-microphone-offline-transducer.js
demonstrates how to use [silero-vad](https://github.com/snakers4/silero-vad)
with whisper for speech recognition from microphone.
You can use the following command run it:
You can use the following command to run it:
```bash
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx
... ... @@ -238,7 +251,7 @@ node ./test-vad-microphone-offline-whisper.js
[./test-vad-microphone.js](./test-vad-microphone.js)
demonstrates how to use [silero-vad](https://github.com/snakers4/silero-vad).
You can use the following command run it:
You can use the following command to run it:
```bash
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx
... ...
// Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang)
//
// This script decodes a single wave file with a streaming zipformer2 CTC
// model using the sherpa-onnx Node.js API. It feeds 16-bit PCM chunks from
// the wave file into the online recognizer and prints the partial result
// after each chunk.
const fs = require('fs');
const {Readable} = require('stream');
const wav = require('wav');

const sherpa_onnx = require('sherpa-onnx');

// Builds an online recognizer configured for the zipformer2 CTC model.
// Paths are relative to the current working directory; the model tarball
// must be downloaded and extracted before running this script.
function createRecognizer() {
  const featConfig = new sherpa_onnx.FeatureConfig();
  featConfig.sampleRate = 16000;
  featConfig.featureDim = 80;

  // test online recognizer
  const zipformer2Ctc = new sherpa_onnx.OnlineZipformer2CtcModelConfig();
  zipformer2Ctc.model =
      './sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.onnx';

  const tokens =
      './sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt';

  const modelConfig = new sherpa_onnx.OnlineModelConfig();
  modelConfig.zipformer2Ctc = zipformer2Ctc;
  modelConfig.tokens = tokens;

  const recognizerConfig = new sherpa_onnx.OnlineRecognizerConfig();
  recognizerConfig.featConfig = featConfig;
  recognizerConfig.modelConfig = modelConfig;
  recognizerConfig.decodingMethod = 'greedy_search';

  // Declared with const to avoid creating accidental globals
  // (the original assigned to an undeclared `recognizer`).
  const recognizer = new sherpa_onnx.OnlineRecognizer(recognizerConfig);
  return recognizer;
}

const recognizer = createRecognizer();
const stream = recognizer.createStream();

const waveFilename =
    './sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav';

const reader = new wav.Reader();
const readable = new Readable().wrap(reader);

// Pushes one chunk of float samples into the stream, decodes as much as
// is available, and prints the current (possibly partial) result.
function decode(samples) {
  stream.acceptWaveform(recognizer.config.featConfig.sampleRate, samples);

  while (recognizer.isReady(stream)) {
    recognizer.decode(stream);
  }
  const r = recognizer.getResult(stream);
  console.log(r.text);
}

// Validate the wave header: only 16 kHz, mono, 16-bit PCM is supported.
reader.on('format', ({audioFormat, bitDepth, channels, sampleRate}) => {
  if (sampleRate != recognizer.config.featConfig.sampleRate) {
    throw new Error(`Only support sampleRate ${
        recognizer.config.featConfig.sampleRate}. Given ${sampleRate}`);
  }

  if (audioFormat != 1) {
    throw new Error(`Only support PCM format. Given ${audioFormat}`);
  }

  if (channels != 1) {
    // Bug fix: the original interpolated the undefined name `channel`,
    // which would have thrown a ReferenceError instead of this message.
    throw new Error(`Only a single channel. Given ${channels}`);
  }

  if (bitDepth != 16) {
    throw new Error(`Only support 16-bit samples. Given ${bitDepth}`);
  }
});

fs.createReadStream(waveFilename, {'highWaterMark': 4096})
    .pipe(reader)
    .on('finish', function(err) {
      // tail padding: 0.5 s of silence flushes the final frames
      const floatSamples =
          new Float32Array(recognizer.config.featConfig.sampleRate * 0.5);
      decode(floatSamples);
      stream.free();
      recognizer.free();
    });

readable.on('readable', function() {
  let chunk;
  while ((chunk = readable.read()) != null) {
    // Reinterpret the raw bytes as Int16 PCM, then normalize to [-1, 1).
    const int16Samples = new Int16Array(
        chunk.buffer, chunk.byteOffset,
        chunk.length / Int16Array.BYTES_PER_ELEMENT);

    const floatSamples = new Float32Array(int16Samples.length);
    for (let i = 0; i < floatSamples.length; i++) {
      floatSamples[i] = int16Samples[i] / 32768.0;
    }

    decode(floatSamples);
  }
});
... ...
... ... @@ -37,7 +37,20 @@ git lfs pull --include "*.onnx"
./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/3.wav \
./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/8k.wav
(3) Streaming Conformer CTC from WeNet
(3) Streaming Zipformer2 CTC
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
ls -lh sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13
./python-api-examples/online-decode-files.py \
--zipformer2-ctc=./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.onnx \
--tokens=./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt \
./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav \
./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000001.wav
(4) Streaming Conformer CTC from WeNet
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-zh-wenet-wenetspeech
cd sherpa-onnx-zh-wenet-wenetspeech
... ... @@ -51,12 +64,9 @@ git lfs pull --include "*.onnx"
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/8k.wav
Please refer to
https://k2-fsa.github.io/sherpa/onnx/index.html
and
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/wenet/index.html
to install sherpa-onnx and to download streaming pre-trained models.
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
to download streaming pre-trained models.
"""
import argparse
import time
... ... @@ -98,6 +108,12 @@ def get_args():
)
parser.add_argument(
"--zipformer2-ctc",
type=str,
help="Path to the zipformer2 ctc model",
)
parser.add_argument(
"--paraformer-encoder",
type=str,
help="Path to the paraformer encoder model",
... ... @@ -112,7 +128,7 @@ def get_args():
parser.add_argument(
"--wenet-ctc",
type=str,
help="Path to the wenet ctc model model",
help="Path to the wenet ctc model",
)
parser.add_argument(
... ... @@ -275,6 +291,16 @@ def main():
hotwords_file=args.hotwords_file,
hotwords_score=args.hotwords_score,
)
elif args.zipformer2_ctc:
recognizer = sherpa_onnx.OnlineRecognizer.from_zipformer2_ctc(
tokens=args.tokens,
model=args.zipformer2_ctc,
num_threads=args.num_threads,
provider=args.provider,
sample_rate=16000,
feature_dim=80,
decoding_method="greedy_search",
)
elif args.paraformer_encoder:
recognizer = sherpa_onnx.OnlineRecognizer.from_paraformer(
tokens=args.tokens,
... ...
... ... @@ -25,6 +25,7 @@ https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websoc
import argparse
import asyncio
import json
import logging
import wave
... ... @@ -112,7 +113,7 @@ async def receive_results(socket: websockets.WebSocketServerProtocol):
async for message in socket:
if message != "Done!":
last_message = message
logging.info(message)
logging.info(json.loads(message))
else:
break
return last_message
... ... @@ -151,7 +152,7 @@ async def run(
await websocket.send("Done")
decoding_results = await receive_task
logging.info(f"\nFinal result is:\n{decoding_results}")
logging.info(f"\nFinal result is:\n{json.loads(decoding_results)}")
async def main():
... ...
... ... @@ -138,6 +138,12 @@ def add_model_args(parser: argparse.ArgumentParser):
)
parser.add_argument(
"--zipformer2-ctc",
type=str,
help="Path to the model file from zipformer2 ctc",
)
parser.add_argument(
"--wenet-ctc",
type=str,
help="Path to the model.onnx from WeNet",
... ... @@ -405,6 +411,20 @@ def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer:
rule3_min_utterance_length=args.rule3_min_utterance_length,
provider=args.provider,
)
elif args.zipformer2_ctc:
recognizer = sherpa_onnx.OnlineRecognizer.from_zipformer2_ctc(
tokens=args.tokens,
model=args.zipformer2_ctc,
num_threads=args.num_threads,
sample_rate=args.sample_rate,
feature_dim=args.feat_dim,
decoding_method=args.decoding_method,
enable_endpoint_detection=args.use_endpoint != 0,
rule1_min_trailing_silence=args.rule1_min_trailing_silence,
rule2_min_trailing_silence=args.rule2_min_trailing_silence,
rule3_min_utterance_length=args.rule3_min_utterance_length,
provider=args.provider,
)
elif args.wenet_ctc:
recognizer = sherpa_onnx.OnlineRecognizer.from_wenet_ctc(
tokens=args.tokens,
... ... @@ -748,6 +768,8 @@ def check_args(args):
assert args.paraformer_encoder is None, args.paraformer_encoder
assert args.paraformer_decoder is None, args.paraformer_decoder
assert args.zipformer2_ctc is None, args.zipformer2_ctc
assert args.wenet_ctc is None, args.wenet_ctc
elif args.paraformer_encoder:
assert Path(
args.paraformer_encoder
... ... @@ -756,6 +778,10 @@ def check_args(args):
assert Path(
args.paraformer_decoder
).is_file(), f"{args.paraformer_decoder} does not exist"
elif args.zipformer2_ctc:
assert Path(
args.zipformer2_ctc
).is_file(), f"{args.zipformer2_ctc} does not exist"
elif args.wenet_ctc:
assert Path(args.wenet_ctc).is_file(), f"{args.wenet_ctc} does not exist"
else:
... ...
... ... @@ -51,12 +51,25 @@ namespace SherpaOnnx
}
/// Configuration for a streaming (online) zipformer2 CTC model.
/// Field layout must match the native C struct
/// SherpaOnnxOnlineZipformer2CtcModelConfig, hence LayoutKind.Sequential.
[StructLayout(LayoutKind.Sequential)]
public struct OnlineZipformer2CtcModelConfig
{
    /// Initializes Model to an empty string so the native side always
    /// receives a valid (non-null) C string, even when unused.
    public OnlineZipformer2CtcModelConfig()
    {
        Model = "";
    }

    /// Path to the zipformer2 CTC onnx model file.
    [MarshalAs(UnmanagedType.LPStr)]
    public string Model;
}
[StructLayout(LayoutKind.Sequential)]
public struct OnlineModelConfig
{
public OnlineModelConfig()
{
Transducer = new OnlineTransducerModelConfig();
Paraformer = new OnlineParaformerModelConfig();
Zipformer2Ctc = new OnlineZipformer2CtcModelConfig();
Tokens = "";
NumThreads = 1;
Provider = "cpu";
... ... @@ -66,6 +79,7 @@ namespace SherpaOnnx
public OnlineTransducerModelConfig Transducer;
public OnlineParaformerModelConfig Paraformer;
public OnlineZipformer2CtcModelConfig Zipformer2Ctc;
[MarshalAs(UnmanagedType.LPStr)]
public string Tokens;
... ...
../../../../go-api-examples/streaming-decode-files/run-zipformer2-ctc.sh
\ No newline at end of file
... ...
... ... @@ -65,6 +65,13 @@ type OnlineParaformerModelConfig struct {
Decoder string // Path to the decoder model.
}
// OnlineZipformer2CtcModelConfig holds the configuration for a
// streaming (online) zipformer2 CTC model.
//
// Please refer to
// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-ctc/index.html
// to download pre-trained models
type OnlineZipformer2CtcModelConfig struct {
	Model string // Path to the onnx model
}
// Configuration for online/streaming models
//
// Please refer to
... ... @@ -74,6 +81,7 @@ type OnlineParaformerModelConfig struct {
type OnlineModelConfig struct {
Transducer OnlineTransducerModelConfig
Paraformer OnlineParaformerModelConfig
Zipformer2Ctc OnlineZipformer2CtcModelConfig
Tokens string // Path to tokens.txt
NumThreads int // Number of threads to use for neural network computation
Provider string // Optional. Valid values are: cpu, cuda, coreml
... ... @@ -157,6 +165,9 @@ func NewOnlineRecognizer(config *OnlineRecognizerConfig) *OnlineRecognizer {
c.model_config.paraformer.decoder = C.CString(config.ModelConfig.Paraformer.Decoder)
defer C.free(unsafe.Pointer(c.model_config.paraformer.decoder))
c.model_config.zipformer2_ctc.model = C.CString(config.ModelConfig.Zipformer2Ctc.Model)
defer C.free(unsafe.Pointer(c.model_config.zipformer2_ctc.model))
c.model_config.tokens = C.CString(config.ModelConfig.Tokens)
defer C.free(unsafe.Pointer(c.model_config.tokens))
... ...
... ... @@ -41,9 +41,14 @@ const SherpaOnnxOnlineParaformerModelConfig = StructType({
"decoder" : cstring,
});
// Mirrors the C struct SherpaOnnxOnlineZipformer2CtcModelConfig:
// a single C-string field holding the path to the zipformer2 CTC onnx model.
const SherpaOnnxOnlineZipformer2CtcModelConfig = StructType({
  "model" : cstring,
});
const SherpaOnnxOnlineModelConfig = StructType({
"transducer" : SherpaOnnxOnlineTransducerModelConfig,
"paraformer" : SherpaOnnxOnlineParaformerModelConfig,
"zipformer2Ctc" : SherpaOnnxOnlineZipformer2CtcModelConfig,
"tokens" : cstring,
"numThreads" : int32_t,
"provider" : cstring,
... ... @@ -663,6 +668,7 @@ const OnlineModelConfig = SherpaOnnxOnlineModelConfig;
const FeatureConfig = SherpaOnnxFeatureConfig;
const OnlineRecognizerConfig = SherpaOnnxOnlineRecognizerConfig;
const OnlineParaformerModelConfig = SherpaOnnxOnlineParaformerModelConfig;
const OnlineZipformer2CtcModelConfig = SherpaOnnxOnlineZipformer2CtcModelConfig;
// offline asr
const OfflineTransducerModelConfig = SherpaOnnxOfflineTransducerModelConfig;
... ... @@ -692,6 +698,7 @@ module.exports = {
OnlineRecognizer,
OnlineStream,
OnlineParaformerModelConfig,
OnlineZipformer2CtcModelConfig,
// offline asr
OfflineRecognizer,
... ...
... ... @@ -54,6 +54,9 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer(
recognizer_config.model_config.paraformer.decoder =
SHERPA_ONNX_OR(config->model_config.paraformer.decoder, "");
recognizer_config.model_config.zipformer2_ctc.model =
SHERPA_ONNX_OR(config->model_config.zipformer2_ctc.model, "");
recognizer_config.model_config.tokens =
SHERPA_ONNX_OR(config->model_config.tokens, "");
recognizer_config.model_config.num_threads =
... ...
... ... @@ -66,9 +66,17 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineParaformerModelConfig {
const char *decoder;
} SherpaOnnxOnlineParaformerModelConfig;
SHERPA_ONNX_API typedef struct SherpaOnnxModelConfig {
// Please visit
// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-ctc/zipformer-ctc-models.html#
// to download pre-trained streaming zipformer2 ctc models
SHERPA_ONNX_API typedef struct SherpaOnnxOnlineZipformer2CtcModelConfig {
  // Path to the streaming zipformer2 CTC onnx model file.
  const char *model;
} SherpaOnnxOnlineZipformer2CtcModelConfig;
SHERPA_ONNX_API typedef struct SherpaOnnxOnlineModelConfig {
SherpaOnnxOnlineTransducerModelConfig transducer;
SherpaOnnxOnlineParaformerModelConfig paraformer;
SherpaOnnxOnlineZipformer2CtcModelConfig zipformer2_ctc;
const char *tokens;
int32_t num_threads;
const char *provider;
... ...
... ... @@ -70,6 +70,8 @@ set(sources
online-wenet-ctc-model-config.cc
online-wenet-ctc-model.cc
online-zipformer-transducer-model.cc
online-zipformer2-ctc-model-config.cc
online-zipformer2-ctc-model.cc
online-zipformer2-transducer-model.cc
onnx-utils.cc
packed-sequence.cc
... ...
... ... @@ -12,6 +12,9 @@
namespace sherpa_onnx {
struct OnlineCtcDecoderResult {
/// Number of frames after subsampling we have decoded so far
int32_t frame_offset = 0;
/// The decoded token IDs
std::vector<int64_t> tokens;
... ...
... ... @@ -49,12 +49,17 @@ void OnlineCtcGreedySearchDecoder::Decode(
if (y != blank_id_ && y != prev_id) {
r.tokens.push_back(y);
r.timestamps.push_back(t);
r.timestamps.push_back(t + r.frame_offset);
}
prev_id = y;
} // for (int32_t t = 0; t != num_frames; ++t) {
} // for (int32_t b = 0; b != batch_size; ++b)
// Update frame_offset
for (auto &r : *results) {
r.frame_offset += num_frames;
}
}
} // namespace sherpa_onnx
... ...
... ... @@ -11,127 +11,35 @@
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/online-wenet-ctc-model.h"
#include "sherpa-onnx/csrc/online-zipformer2-ctc-model.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
namespace {
enum class ModelType {
kZipformerCtc,
kWenetCtc,
kUnkown,
};
} // namespace
namespace sherpa_onnx {
static ModelType GetModelType(char *model_data, size_t model_data_length,
bool debug) {
Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
Ort::SessionOptions sess_opts;
auto sess = std::make_unique<Ort::Session>(env, model_data, model_data_length,
sess_opts);
Ort::ModelMetadata meta_data = sess->GetModelMetadata();
if (debug) {
std::ostringstream os;
PrintModelMetadata(os, meta_data);
SHERPA_ONNX_LOGE("%s", os.str().c_str());
}
Ort::AllocatorWithDefaultOptions allocator;
auto model_type =
meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
if (!model_type) {
SHERPA_ONNX_LOGE(
"No model_type in the metadata!\n"
"If you are using models from WeNet, please refer to\n"
"https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/"
"run.sh\n"
"\n"
"for how to add metadta to model.onnx\n");
return ModelType::kUnkown;
}
if (model_type.get() == std::string("zipformer2")) {
return ModelType::kZipformerCtc;
} else if (model_type.get() == std::string("wenet_ctc")) {
return ModelType::kWenetCtc;
} else {
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
return ModelType::kUnkown;
}
}
std::unique_ptr<OnlineCtcModel> OnlineCtcModel::Create(
const OnlineModelConfig &config) {
ModelType model_type = ModelType::kUnkown;
std::string filename;
if (!config.wenet_ctc.model.empty()) {
filename = config.wenet_ctc.model;
return std::make_unique<OnlineWenetCtcModel>(config);
} else if (!config.zipformer2_ctc.model.empty()) {
return std::make_unique<OnlineZipformer2CtcModel>(config);
} else {
SHERPA_ONNX_LOGE("Please specify a CTC model");
exit(-1);
}
{
auto buffer = ReadFile(filename);
model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
}
switch (model_type) {
case ModelType::kZipformerCtc:
return nullptr;
// return std::make_unique<OnlineZipformerCtcModel>(config);
break;
case ModelType::kWenetCtc:
return std::make_unique<OnlineWenetCtcModel>(config);
break;
case ModelType::kUnkown:
SHERPA_ONNX_LOGE("Unknown model type in online CTC!");
return nullptr;
}
return nullptr;
}
#if __ANDROID_API__ >= 9
std::unique_ptr<OnlineCtcModel> OnlineCtcModel::Create(
AAssetManager *mgr, const OnlineModelConfig &config) {
ModelType model_type = ModelType::kUnkown;
std::string filename;
if (!config.wenet_ctc.model.empty()) {
filename = config.wenet_ctc.model;
return std::make_unique<OnlineWenetCtcModel>(mgr, config);
} else if (!config.zipformer2_ctc.model.empty()) {
return std::make_unique<OnlineZipformer2CtcModel>(mgr, config);
} else {
SHERPA_ONNX_LOGE("Please specify a CTC model");
exit(-1);
}
{
auto buffer = ReadFile(mgr, filename);
model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
}
switch (model_type) {
case ModelType::kZipformerCtc:
return nullptr;
// return std::make_unique<OnlineZipformerCtcModel>(mgr, config);
break;
case ModelType::kWenetCtc:
return std::make_unique<OnlineWenetCtcModel>(mgr, config);
break;
case ModelType::kUnkown:
SHERPA_ONNX_LOGE("Unknown model type in online CTC!");
return nullptr;
}
return nullptr;
}
#endif
... ...
... ... @@ -33,6 +33,26 @@ class OnlineCtcModel {
// Return a list of tensors containing the initial states
virtual std::vector<Ort::Value> GetInitStates() const = 0;
/** Stack a list of individual states into a batch.
*
* It is the inverse operation of `UnStackStates`.
*
* @param states states[i] contains the state for the i-th utterance.
* @return Return a single value representing the batched state.
*/
virtual std::vector<Ort::Value> StackStates(
std::vector<std::vector<Ort::Value>> states) const = 0;
/** Unstack a batch state into a list of individual states.
*
* It is the inverse operation of `StackStates`.
*
* @param states A batched state.
* @return ans[i] contains the state for the i-th utterance.
*/
virtual std::vector<std::vector<Ort::Value>> UnStackStates(
std::vector<Ort::Value> states) const = 0;
/**
*
* @param x A 3-D tensor of shape (N, T, C). N has to be 1.
... ... @@ -60,6 +80,9 @@ class OnlineCtcModel {
// ChunkLength() frames, we advance by ChunkShift() frames
// before we process the next chunk.
virtual int32_t ChunkShift() const = 0;
// Return true if the model supports batch size > 1
virtual bool SupportBatchProcessing() const { return true; }
};
} // namespace sherpa_onnx
... ...
... ... @@ -14,6 +14,7 @@ void OnlineModelConfig::Register(ParseOptions *po) {
transducer.Register(po);
paraformer.Register(po);
wenet_ctc.Register(po);
zipformer2_ctc.Register(po);
po->Register("tokens", &tokens, "Path to tokens.txt");
... ... @@ -26,9 +27,10 @@ void OnlineModelConfig::Register(ParseOptions *po) {
po->Register("provider", &provider,
"Specify a provider to use: cpu, cuda, coreml");
po->Register("model-type", &model_type,
po->Register(
"model-type", &model_type,
"Specify it to reduce model initialization time. "
"Valid values are: conformer, lstm, zipformer, zipformer2."
"Valid values are: conformer, lstm, zipformer, zipformer2, wenet_ctc"
"All other values lead to loading the model twice.");
}
... ... @@ -51,6 +53,10 @@ bool OnlineModelConfig::Validate() const {
return wenet_ctc.Validate();
}
if (!zipformer2_ctc.model.empty()) {
return zipformer2_ctc.Validate();
}
return transducer.Validate();
}
... ... @@ -61,6 +67,7 @@ std::string OnlineModelConfig::ToString() const {
os << "transducer=" << transducer.ToString() << ", ";
os << "paraformer=" << paraformer.ToString() << ", ";
os << "wenet_ctc=" << wenet_ctc.ToString() << ", ";
os << "zipformer2_ctc=" << zipformer2_ctc.ToString() << ", ";
os << "tokens=\"" << tokens << "\", ";
os << "num_threads=" << num_threads << ", ";
os << "debug=" << (debug ? "True" : "False") << ", ";
... ...
... ... @@ -9,6 +9,7 @@
#include "sherpa-onnx/csrc/online-paraformer-model-config.h"
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
#include "sherpa-onnx/csrc/online-wenet-ctc-model-config.h"
#include "sherpa-onnx/csrc/online-zipformer2-ctc-model-config.h"
namespace sherpa_onnx {
... ... @@ -16,6 +17,7 @@ struct OnlineModelConfig {
OnlineTransducerModelConfig transducer;
OnlineParaformerModelConfig paraformer;
OnlineWenetCtcModelConfig wenet_ctc;
OnlineZipformer2CtcModelConfig zipformer2_ctc;
std::string tokens;
int32_t num_threads = 1;
bool debug = false;
... ... @@ -25,7 +27,8 @@ struct OnlineModelConfig {
// - conformer, conformer transducer from icefall
// - lstm, lstm transducer from icefall
// - zipformer, zipformer transducer from icefall
// - zipformer2, zipformer2 transducer from icefall
// - zipformer2, zipformer2 transducer or CTC from icefall
// - wenet_ctc, wenet CTC model
//
// All other values are invalid and lead to loading the model twice.
std::string model_type;
... ... @@ -34,11 +37,13 @@ struct OnlineModelConfig {
OnlineModelConfig(const OnlineTransducerModelConfig &transducer,
const OnlineParaformerModelConfig &paraformer,
const OnlineWenetCtcModelConfig &wenet_ctc,
const OnlineZipformer2CtcModelConfig &zipformer2_ctc,
const std::string &tokens, int32_t num_threads, bool debug,
const std::string &provider, const std::string &model_type)
: transducer(transducer),
paraformer(paraformer),
wenet_ctc(wenet_ctc),
zipformer2_ctc(zipformer2_ctc),
tokens(tokens),
num_threads(num_threads),
debug(debug),
... ...
... ... @@ -96,9 +96,68 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
}
void DecodeStreams(OnlineStream **ss, int32_t n) const override {
if (n == 1 || !model_->SupportBatchProcessing()) {
for (int32_t i = 0; i != n; ++i) {
DecodeStream(ss[i]);
}
return;
}
// batch processing
int32_t chunk_length = model_->ChunkLength();
int32_t chunk_shift = model_->ChunkShift();
int32_t feat_dim = ss[0]->FeatureDim();
std::vector<OnlineCtcDecoderResult> results(n);
std::vector<float> features_vec(n * chunk_length * feat_dim);
std::vector<std::vector<Ort::Value>> states_vec(n);
std::vector<int64_t> all_processed_frames(n);
for (int32_t i = 0; i != n; ++i) {
const auto num_processed_frames = ss[i]->GetNumProcessedFrames();
std::vector<float> features =
ss[i]->GetFrames(num_processed_frames, chunk_length);
// Question: should num_processed_frames include chunk_shift?
ss[i]->GetNumProcessedFrames() += chunk_shift;
std::copy(features.begin(), features.end(),
features_vec.data() + i * chunk_length * feat_dim);
results[i] = std::move(ss[i]->GetCtcResult());
states_vec[i] = std::move(ss[i]->GetStates());
all_processed_frames[i] = num_processed_frames;
}
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
std::array<int64_t, 3> x_shape{n, chunk_length, feat_dim};
Ort::Value x = Ort::Value::CreateTensor(memory_info, features_vec.data(),
features_vec.size(), x_shape.data(),
x_shape.size());
auto states = model_->StackStates(std::move(states_vec));
int32_t num_states = states.size();
auto out = model_->Forward(std::move(x), std::move(states));
std::vector<Ort::Value> out_states;
out_states.reserve(num_states);
for (int32_t k = 1; k != num_states + 1; ++k) {
out_states.push_back(std::move(out[k]));
}
std::vector<std::vector<Ort::Value>> next_states =
model_->UnStackStates(std::move(out_states));
decoder_->Decode(std::move(out[0]), &results);
for (int32_t k = 0; k != n; ++k) {
ss[k]->SetCtcResult(results[k]);
ss[k]->SetStates(std::move(next_states[k]));
}
}
OnlineRecognizerResult GetResult(OnlineStream *s) const override {
... ...
... ... @@ -20,7 +20,8 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
return std::make_unique<OnlineRecognizerParaformerImpl>(config);
}
if (!config.model_config.wenet_ctc.model.empty()) {
if (!config.model_config.wenet_ctc.model.empty() ||
!config.model_config.zipformer2_ctc.model.empty()) {
return std::make_unique<OnlineRecognizerCtcImpl>(config);
}
... ... @@ -39,7 +40,8 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
return std::make_unique<OnlineRecognizerParaformerImpl>(mgr, config);
}
if (!config.model_config.wenet_ctc.model.empty()) {
if (!config.model_config.wenet_ctc.model.empty() ||
!config.model_config.zipformer2_ctc.model.empty()) {
return std::make_unique<OnlineRecognizerCtcImpl>(mgr, config);
}
... ...
// sherpa-onnx/csrc/online-paraformer-model.cc
// sherpa-onnx/csrc/online-wenet-ctc-model.cc
//
// Copyright (c) 2023 Xiaomi Corporation
... ... @@ -239,4 +239,21 @@ std::vector<Ort::Value> OnlineWenetCtcModel::GetInitStates() const {
return impl_->GetInitStates();
}
// The wenet CTC model does not support batched processing
// (SupportBatchProcessing() returns false), so there is never more than one
// utterance's state to "stack": just return it unchanged.
std::vector<Ort::Value> OnlineWenetCtcModel::StackStates(
    std::vector<std::vector<Ort::Value>> states) const {
  if (states.size() != 1) {
    SHERPA_ONNX_LOGE("wenet CTC model supports only batch_size==1. Given: %d",
                     static_cast<int32_t>(states.size()));
    // NOTE(review): execution falls through to states[0] below, which is
    // undefined behavior when states is empty — confirm callers never pass
    // an empty vector here.
  }

  return std::move(states[0]);
}

// Inverse of StackStates() for batch_size == 1: wrap the single state list.
std::vector<std::vector<Ort::Value>> OnlineWenetCtcModel::UnStackStates(
    std::vector<Ort::Value> states) const {
  std::vector<std::vector<Ort::Value>> ans(1);
  ans[0] = std::move(states);
  return ans;
}
} // namespace sherpa_onnx
... ...
... ... @@ -35,6 +35,12 @@ class OnlineWenetCtcModel : public OnlineCtcModel {
// - offset
std::vector<Ort::Value> GetInitStates() const override;
std::vector<Ort::Value> StackStates(
std::vector<std::vector<Ort::Value>> states) const override;
std::vector<std::vector<Ort::Value>> UnStackStates(
std::vector<Ort::Value> states) const override;
/**
*
* @param x A 3-D tensor of shape (N, T, C). N has to be 1.
... ... @@ -63,6 +69,8 @@ class OnlineWenetCtcModel : public OnlineCtcModel {
// before we process the next chunk.
int32_t ChunkShift() const override;
bool SupportBatchProcessing() const override { return false; }
private:
class Impl;
std::unique_ptr<Impl> impl_;
... ...
// sherpa-onnx/csrc/online-zipformer2-ctc-model-config.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/online-zipformer2-ctc-model-config.h"
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
// Register the command-line option for the zipformer2 CTC model path.
void OnlineZipformer2CtcModelConfig::Register(ParseOptions *po) {
  po->Register(
      "zipformer2-ctc-model", &model,
      "Path to CTC model.onnx. See also "
      "https://github.com/k2-fsa/icefall/pull/1413");
}

// A config is valid iff the model path is non-empty and the file exists.
bool OnlineZipformer2CtcModelConfig::Validate() const {
  if (model.empty()) {
    SHERPA_ONNX_LOGE("--zipformer2-ctc-model is empty!");
    return false;
  }

  if (FileExists(model)) {
    return true;
  }

  SHERPA_ONNX_LOGE("--zipformer2-ctc-model %s does not exist", model.c_str());
  return false;
}

// Human-readable representation used in debug/logging output.
std::string OnlineZipformer2CtcModelConfig::ToString() const {
  std::ostringstream os;
  os << "OnlineZipformer2CtcModelConfig("
     << "model=\"" << model << "\")";
  return os.str();
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/online-zipformer2-ctc-model-config.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER2_CTC_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER2_CTC_MODEL_CONFIG_H_
#include <string>
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
// Configuration for a streaming (online) zipformer2 CTC model.
struct OnlineZipformer2CtcModelConfig {
  // Path to the CTC model.onnx file.
  // See also https://github.com/k2-fsa/icefall/pull/1413
  std::string model;

  OnlineZipformer2CtcModelConfig() = default;

  explicit OnlineZipformer2CtcModelConfig(const std::string &model)
      : model(model) {}

  // Register the --zipformer2-ctc-model command-line option.
  void Register(ParseOptions *po);

  // Return true if `model` is non-empty and points to an existing file.
  bool Validate() const;

  // Human-readable representation for debug/logging output.
  std::string ToString() const;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER2_CTC_MODEL_CONFIG_H_
... ...
// sherpa-onnx/csrc/online-zipformer2-ctc-model.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/online-zipformer2-ctc-model.h"
#include <assert.h>
#include <math.h>
#include <algorithm>
#include <cmath>
#include <numeric>
#include <string>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/cat.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/text-utils.h"
#include "sherpa-onnx/csrc/unbind.h"
namespace sherpa_onnx {
// Pimpl for OnlineZipformer2CtcModel. Owns the onnx session, the model
// metadata, and the cached zero-initialized states.
class OnlineZipformer2CtcModel::Impl {
 public:
  // Load the model from config.zipformer2_ctc.model on disk.
  explicit Impl(const OnlineModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    {
      auto buf = ReadFile(config.zipformer2_ctc.model);
      Init(buf.data(), buf.size());
    }
  }

#if __ANDROID_API__ >= 9
  // Load the model via the Android asset manager.
  //
  // NOTE(review): this constructor uses ORT_LOGGING_LEVEL_WARNING whereas
  // the non-Android constructor uses ORT_LOGGING_LEVEL_ERROR — confirm the
  // difference is intentional.
  Impl(AAssetManager *mgr, const OnlineModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_WARNING),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    {
      auto buf = ReadFile(mgr, config.zipformer2_ctc.model);
      Init(buf.data(), buf.size());
    }
  }
#endif

  /** Run the model on one chunk of features.
   *
   * @param features A 3-D tensor of shape (N, T, C).
   * @param states Model states, either from GetInitStates() or from the
   *               previous invocation's outputs.
   * @return ans[0] contains log_probs; ans[1:] contains the next states.
   */
  std::vector<Ort::Value> Forward(Ort::Value features,
                                  std::vector<Ort::Value> states) {
    // The session inputs are the features followed by all state tensors.
    std::vector<Ort::Value> inputs;
    inputs.reserve(1 + states.size());

    inputs.push_back(std::move(features));
    for (auto &v : states) {
      inputs.push_back(std::move(v));
    }

    return sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
                      output_names_ptr_.data(), output_names_ptr_.size());
  }

  int32_t VocabSize() const { return vocab_size_; }

  int32_t ChunkLength() const { return T_; }

  int32_t ChunkShift() const { return decode_chunk_len_; }

  OrtAllocator *Allocator() const { return allocator_; }

  // Return views (not copies) of the cached zero-initialized states built by
  // InitStates(). There are sum(num_encoder_layers_) * 6 + 2 tensors; see
  // InitStates() for the exact shapes.
  std::vector<Ort::Value> GetInitStates() {
    std::vector<Ort::Value> ans;
    ans.reserve(initial_states_.size());
    for (auto &s : initial_states_) {
      ans.push_back(View(&s));
    }
    return ans;
  }

  /** Stack per-utterance states into a single batched state.
   *
   * Inverse of UnStackStates(). Each utterance owns num_states tensors,
   * where num_states == sum(num_encoder_layers_) * 6 + 2. Of the 6 per-layer
   * tensors, the first four carry their batch axis at dim 1 and the last two
   * at dim 0 (see the shapes created in InitStates()), which is why Cat() is
   * invoked with different dims below. The two trailing tensors are batched
   * along dim 0; the very last one holds int64 values.
   */
  std::vector<Ort::Value> StackStates(
      std::vector<std::vector<Ort::Value>> states) const {
    int32_t batch_size = static_cast<int32_t>(states.size());

    // buf[n] points at the current tensor of the n-th utterance.
    std::vector<const Ort::Value *> buf(batch_size);

    std::vector<Ort::Value> ans;
    int32_t num_states = static_cast<int32_t>(states[0].size());
    ans.reserve(num_states);

    for (int32_t i = 0; i != (num_states - 2) / 6; ++i) {
      {
        for (int32_t n = 0; n != batch_size; ++n) {
          buf[n] = &states[n][6 * i];
        }
        auto v = Cat(allocator_, buf, 1);
        ans.push_back(std::move(v));
      }

      {
        for (int32_t n = 0; n != batch_size; ++n) {
          buf[n] = &states[n][6 * i + 1];
        }
        auto v = Cat(allocator_, buf, 1);
        ans.push_back(std::move(v));
      }

      {
        for (int32_t n = 0; n != batch_size; ++n) {
          buf[n] = &states[n][6 * i + 2];
        }
        auto v = Cat(allocator_, buf, 1);
        ans.push_back(std::move(v));
      }

      {
        for (int32_t n = 0; n != batch_size; ++n) {
          buf[n] = &states[n][6 * i + 3];
        }
        auto v = Cat(allocator_, buf, 1);
        ans.push_back(std::move(v));
      }

      {
        for (int32_t n = 0; n != batch_size; ++n) {
          buf[n] = &states[n][6 * i + 4];
        }
        auto v = Cat(allocator_, buf, 0);
        ans.push_back(std::move(v));
      }

      {
        for (int32_t n = 0; n != batch_size; ++n) {
          buf[n] = &states[n][6 * i + 5];
        }
        auto v = Cat(allocator_, buf, 0);
        ans.push_back(std::move(v));
      }
    }

    // Trailing float tensor, batched along dim 0.
    {
      for (int32_t n = 0; n != batch_size; ++n) {
        buf[n] = &states[n][num_states - 2];
      }
      auto v = Cat(allocator_, buf, 0);
      ans.push_back(std::move(v));
    }

    // Trailing int64 tensor, batched along dim 0.
    {
      for (int32_t n = 0; n != batch_size; ++n) {
        buf[n] = &states[n][num_states - 1];
      }
      auto v = Cat<int64_t>(allocator_, buf, 0);
      ans.push_back(std::move(v));
    }

    return ans;
  }

  /** Split a batched state into per-utterance states.
   *
   * Inverse of StackStates(); the Unbind() dims below mirror the Cat() dims
   * used there.
   *
   * @param states A batched state of sum(num_encoder_layers_) * 6 + 2
   *               tensors.
   * @return ans[i] contains the state tensors of the i-th utterance.
   */
  std::vector<std::vector<Ort::Value>> UnStackStates(
      std::vector<Ort::Value> states) const {
    // m is the total number of encoder layers across all encoder stacks.
    int32_t m = std::accumulate(num_encoder_layers_.begin(),
                                num_encoder_layers_.end(), 0);
    assert(static_cast<int32_t>(states.size()) == m * 6 + 2);

    // The first per-layer tensor has its batch axis at dim 1.
    int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1];

    std::vector<std::vector<Ort::Value>> ans;
    ans.resize(batch_size);

    for (int32_t i = 0; i != m; ++i) {
      {
        auto v = Unbind(allocator_, &states[i * 6], 1);
        assert(v.size() == batch_size);

        for (int32_t n = 0; n != batch_size; ++n) {
          ans[n].push_back(std::move(v[n]));
        }
      }

      {
        auto v = Unbind(allocator_, &states[i * 6 + 1], 1);
        assert(v.size() == batch_size);

        for (int32_t n = 0; n != batch_size; ++n) {
          ans[n].push_back(std::move(v[n]));
        }
      }

      {
        auto v = Unbind(allocator_, &states[i * 6 + 2], 1);
        assert(v.size() == batch_size);

        for (int32_t n = 0; n != batch_size; ++n) {
          ans[n].push_back(std::move(v[n]));
        }
      }

      {
        auto v = Unbind(allocator_, &states[i * 6 + 3], 1);
        assert(v.size() == batch_size);

        for (int32_t n = 0; n != batch_size; ++n) {
          ans[n].push_back(std::move(v[n]));
        }
      }

      {
        auto v = Unbind(allocator_, &states[i * 6 + 4], 0);
        assert(v.size() == batch_size);

        for (int32_t n = 0; n != batch_size; ++n) {
          ans[n].push_back(std::move(v[n]));
        }
      }

      {
        auto v = Unbind(allocator_, &states[i * 6 + 5], 0);
        assert(v.size() == batch_size);

        for (int32_t n = 0; n != batch_size; ++n) {
          ans[n].push_back(std::move(v[n]));
        }
      }
    }

    // Trailing float tensor.
    {
      auto v = Unbind(allocator_, &states[m * 6], 0);
      assert(v.size() == batch_size);

      for (int32_t n = 0; n != batch_size; ++n) {
        ans[n].push_back(std::move(v[n]));
      }
    }

    // Trailing int64 tensor.
    {
      auto v = Unbind<int64_t>(allocator_, &states[m * 6 + 1], 0);
      assert(v.size() == batch_size);

      for (int32_t n = 0; n != batch_size; ++n) {
        ans[n].push_back(std::move(v[n]));
      }
    }

    return ans;
  }

 private:
  // Create the onnx session, read the model metadata, and pre-compute the
  // zero-initialized states.
  void Init(void *model_data, size_t model_data_length) {
    sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
                                           sess_opts_);

    GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);

    GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);

    // get meta data
    Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
    if (config_.debug) {
      std::ostringstream os;
      os << "---zipformer2_ctc---\n";
      PrintModelMetadata(os, meta_data);
      SHERPA_ONNX_LOGE("%s", os.str().c_str());
    }

    Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
    SHERPA_ONNX_READ_META_DATA_VEC(encoder_dims_, "encoder_dims");
    SHERPA_ONNX_READ_META_DATA_VEC(query_head_dims_, "query_head_dims");
    SHERPA_ONNX_READ_META_DATA_VEC(value_head_dims_, "value_head_dims");
    SHERPA_ONNX_READ_META_DATA_VEC(num_heads_, "num_heads");
    SHERPA_ONNX_READ_META_DATA_VEC(num_encoder_layers_, "num_encoder_layers");
    SHERPA_ONNX_READ_META_DATA_VEC(cnn_module_kernels_, "cnn_module_kernels");
    SHERPA_ONNX_READ_META_DATA_VEC(left_context_len_, "left_context_len");

    SHERPA_ONNX_READ_META_DATA(T_, "T");
    SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len");

    // The vocabulary size is the last dimension of the first output,
    // which has shape (N, T, C).
    {
      auto shape =
          sess_->GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape();
      vocab_size_ = shape[2];
    }

    if (config_.debug) {
      auto print = [](const std::vector<int32_t> &v, const char *name) {
        fprintf(stderr, "%s: ", name);
        for (auto i : v) {
          fprintf(stderr, "%d ", i);
        }
        fprintf(stderr, "\n");
      };
      print(encoder_dims_, "encoder_dims");
      print(query_head_dims_, "query_head_dims");
      print(value_head_dims_, "value_head_dims");
      print(num_heads_, "num_heads");
      print(num_encoder_layers_, "num_encoder_layers");
      print(cnn_module_kernels_, "cnn_module_kernels");
      print(left_context_len_, "left_context_len");
      SHERPA_ONNX_LOGE("T: %d", T_);
      SHERPA_ONNX_LOGE("decode_chunk_len_: %d", decode_chunk_len_);
      SHERPA_ONNX_LOGE("vocab_size_: %d", vocab_size_);
    }

    InitStates();
  }

  // Build the zero-initialized states cached in initial_states_.
  //
  // For encoder stack i, each of its num_encoder_layers_[i] layers
  // contributes 6 tensors, in this order:
  //   - (left_context_len, 1, key_dim)              key_dim = query_head_dim * num_heads
  //   - (1, 1, left_context_len, 3*encoder_dim/4)
  //   - (left_context_len, 1, value_dim)            value_dim = value_head_dim * num_heads
  //   - (left_context_len, 1, value_dim)
  //   - (1, encoder_dim, cnn_module_kernel / 2)
  //   - (1, encoder_dim, cnn_module_kernel / 2)
  // followed by two trailing tensors shared by the whole model.
  void InitStates() {
    int32_t n = static_cast<int32_t>(encoder_dims_.size());
    int32_t m = std::accumulate(num_encoder_layers_.begin(),
                                num_encoder_layers_.end(), 0);
    initial_states_.reserve(m * 6 + 2);

    for (int32_t i = 0; i != n; ++i) {
      int32_t num_layers = num_encoder_layers_[i];
      int32_t key_dim = query_head_dims_[i] * num_heads_[i];
      int32_t value_dim = value_head_dims_[i] * num_heads_[i];
      int32_t nonlin_attn_head_dim = 3 * encoder_dims_[i] / 4;

      for (int32_t j = 0; j != num_layers; ++j) {
        {
          std::array<int64_t, 3> s{left_context_len_[i], 1, key_dim};
          auto v =
              Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
          Fill(&v, 0);
          initial_states_.push_back(std::move(v));
        }

        {
          std::array<int64_t, 4> s{1, 1, left_context_len_[i],
                                   nonlin_attn_head_dim};
          auto v =
              Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
          Fill(&v, 0);
          initial_states_.push_back(std::move(v));
        }

        {
          std::array<int64_t, 3> s{left_context_len_[i], 1, value_dim};
          auto v =
              Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
          Fill(&v, 0);
          initial_states_.push_back(std::move(v));
        }

        {
          std::array<int64_t, 3> s{left_context_len_[i], 1, value_dim};
          auto v =
              Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
          Fill(&v, 0);
          initial_states_.push_back(std::move(v));
        }

        {
          std::array<int64_t, 3> s{1, encoder_dims_[i],
                                   cnn_module_kernels_[i] / 2};
          auto v =
              Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
          Fill(&v, 0);
          initial_states_.push_back(std::move(v));
        }

        {
          std::array<int64_t, 3> s{1, encoder_dims_[i],
                                   cnn_module_kernels_[i] / 2};
          auto v =
              Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
          Fill(&v, 0);
          initial_states_.push_back(std::move(v));
        }
      }
    }

    // NOTE(review): the (1, 128, 3, 19) shape is hard-coded; presumably it
    // matches the exported model's embedding state (icefall #1413) —
    // confirm if the export configuration changes.
    {
      std::array<int64_t, 4> s{1, 128, 3, 19};
      auto v = Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
      Fill(&v, 0);
      initial_states_.push_back(std::move(v));
    }

    // A single int64 value initialized to 0 — presumably a processed-length
    // counter; confirm against the export script.
    {
      std::array<int64_t, 1> s{1};
      auto v =
          Ort::Value::CreateTensor<int64_t>(allocator_, s.data(), s.size());
      Fill<int64_t>(&v, 0);
      initial_states_.push_back(std::move(v));
    }
  }

 private:
  OnlineModelConfig config_;
  Ort::Env env_;
  Ort::SessionOptions sess_opts_;
  Ort::AllocatorWithDefaultOptions allocator_;

  std::unique_ptr<Ort::Session> sess_;

  std::vector<std::string> input_names_;
  std::vector<const char *> input_names_ptr_;

  std::vector<std::string> output_names_;
  std::vector<const char *> output_names_ptr_;

  // Zero-initialized states; GetInitStates() hands out views of these.
  std::vector<Ort::Value> initial_states_;

  // Per-encoder-stack metadata read from the model in Init().
  std::vector<int32_t> encoder_dims_;
  std::vector<int32_t> query_head_dims_;
  std::vector<int32_t> value_head_dims_;
  std::vector<int32_t> num_heads_;
  std::vector<int32_t> num_encoder_layers_;
  std::vector<int32_t> cnn_module_kernels_;
  std::vector<int32_t> left_context_len_;

  int32_t T_ = 0;                 // number of input frames per chunk
  int32_t decode_chunk_len_ = 0;  // frame advance between consecutive chunks
  int32_t vocab_size_ = 0;        // last dim of the model's log_probs output
};
OnlineZipformer2CtcModel::OnlineZipformer2CtcModel(
    const OnlineModelConfig &config)
    : impl_(std::make_unique<Impl>(config)) {}

#if __ANDROID_API__ >= 9
// Android: the model is read through the asset manager instead of the
// filesystem.
OnlineZipformer2CtcModel::OnlineZipformer2CtcModel(
    AAssetManager *mgr, const OnlineModelConfig &config)
    : impl_(std::make_unique<Impl>(mgr, config)) {}
#endif

OnlineZipformer2CtcModel::~OnlineZipformer2CtcModel() = default;

// All public methods below simply forward to the pimpl.

std::vector<Ort::Value> OnlineZipformer2CtcModel::Forward(
    Ort::Value x, std::vector<Ort::Value> states) const {
  return impl_->Forward(std::move(x), std::move(states));
}

int32_t OnlineZipformer2CtcModel::VocabSize() const {
  return impl_->VocabSize();
}

int32_t OnlineZipformer2CtcModel::ChunkLength() const {
  return impl_->ChunkLength();
}

int32_t OnlineZipformer2CtcModel::ChunkShift() const {
  return impl_->ChunkShift();
}

OrtAllocator *OnlineZipformer2CtcModel::Allocator() const {
  return impl_->Allocator();
}

std::vector<Ort::Value> OnlineZipformer2CtcModel::GetInitStates() const {
  return impl_->GetInitStates();
}

std::vector<Ort::Value> OnlineZipformer2CtcModel::StackStates(
    std::vector<std::vector<Ort::Value>> states) const {
  return impl_->StackStates(std::move(states));
}

std::vector<std::vector<Ort::Value>> OnlineZipformer2CtcModel::UnStackStates(
    std::vector<Ort::Value> states) const {
  return impl_->UnStackStates(std::move(states));
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/online-zipformer2-ctc-model.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER2_CTC_MODEL_H_
#define SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER2_CTC_MODEL_H_
#include <memory>
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/online-ctc-model.h"
#include "sherpa-onnx/csrc/online-model-config.h"
namespace sherpa_onnx {
// Streaming (chunk-wise) zipformer2 model with a CTC head.
// It wraps an ONNX Runtime session behind a pimpl (Impl).
class OnlineZipformer2CtcModel : public OnlineCtcModel {
 public:
  explicit OnlineZipformer2CtcModel(const OnlineModelConfig &config);

#if __ANDROID_API__ >= 9
  // Android-only: load the model via the asset manager instead of a file path.
  OnlineZipformer2CtcModel(AAssetManager *mgr, const OnlineModelConfig &config);
#endif

  ~OnlineZipformer2CtcModel() override;

  // Return the initial encoder states for a new stream.
  // A list of tensors.
  // See also
  // https://github.com/k2-fsa/icefall/pull/1413
  // and
  // https://github.com/k2-fsa/icefall/pull/1415
  std::vector<Ort::Value> GetInitStates() const override;

  // Combine the state lists of several streams into one batched state list,
  // so multiple streams can share a single Forward() call.
  // UnStackStates() is the inverse operation.
  std::vector<Ort::Value> StackStates(
      std::vector<std::vector<Ort::Value>> states) const override;

  // Split a batched state list (as produced by Forward()/StackStates())
  // back into one state list per stream.
  std::vector<std::vector<Ort::Value>> UnStackStates(
      std::vector<Ort::Value> states) const override;

  /**
   *
   * @param x A 3-D tensor of shape (N, T, C). N has to be 1.
   * @param states It is from GetInitStates() or returned from this method.
   *
   * @return Return a list of tensors
   *  - ans[0] contains log_probs, of shape (N, T, C)
   *  - ans[1:] contains next_states
   */
  std::vector<Ort::Value> Forward(
      Ort::Value x, std::vector<Ort::Value> states) const override;

  /** Return the vocabulary size of the model
   */
  int32_t VocabSize() const override;

  /** Return an allocator for allocating memory
   */
  OrtAllocator *Allocator() const override;

  // The model accepts this number of frames before subsampling as input
  int32_t ChunkLength() const override;

  // Similar to frame_shift in feature extractor, after processing
  // ChunkLength() frames, we advance by ChunkShift() frames
  // before we process the next chunk.
  int32_t ChunkShift() const override;

 private:
  class Impl;
  std::unique_ptr<Impl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER2_CTC_MODEL_H_
... ...
... ... @@ -26,6 +26,8 @@ int main(int32_t argc, char *argv[]) {
const char *kUsageMessage = R"usage(
Usage:
(1) Streaming transducer
./bin/sherpa-onnx \
--tokens=/path/to/tokens.txt \
--encoder=/path/to/encoder.onnx \
... ... @@ -36,6 +38,30 @@ Usage:
--decoding-method=greedy_search \
/path/to/foo.wav [bar.wav foobar.wav ...]
(2) Streaming zipformer2 CTC
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
./bin/sherpa-onnx \
--debug=1 \
--zipformer2-ctc-model=./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.int8.onnx \
--tokens=./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt \
./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav \
./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000001.wav \
./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000002.wav
(3) Streaming paraformer
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2
tar xvf sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2
./bin/sherpa-onnx \
--tokens=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/tokens.txt \
--paraformer-encoder=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/encoder.onnx \
--paraformer-decoder=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/decoder.onnx \
./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/0.wav
Note: It supports decoding multiple files in batches
Default value for num_threads is 2.
... ...
... ... @@ -8,9 +8,6 @@
#include <fstream>
#include <sstream>
#include "sherpa-onnx/csrc/base64-decode.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#if __ANDROID_API__ >= 9
#include <strstream>
... ... @@ -18,6 +15,9 @@
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/base64-decode.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
namespace sherpa_onnx {
SymbolTable::SymbolTable(const std::string &filename) {
... ...
... ... @@ -262,22 +262,34 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
fid = env->GetFieldID(model_config_cls, "paraformer",
"Lcom/k2fsa/sherpa/onnx/OnlineParaformerModelConfig;");
jobject paraformer_config = env->GetObjectField(model_config, fid);
jclass paraformer_config_config_cls = env->GetObjectClass(paraformer_config);
jclass paraformer_config_cls = env->GetObjectClass(paraformer_config);
fid = env->GetFieldID(paraformer_config_config_cls, "encoder",
"Ljava/lang/String;");
fid = env->GetFieldID(paraformer_config_cls, "encoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(paraformer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.paraformer.encoder = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(paraformer_config_config_cls, "decoder",
"Ljava/lang/String;");
fid = env->GetFieldID(paraformer_config_cls, "decoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(paraformer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.paraformer.decoder = p;
env->ReleaseStringUTFChars(s, p);
// streaming zipformer2 CTC
fid =
env->GetFieldID(model_config_cls, "zipformer2Ctc",
"Lcom/k2fsa/sherpa/onnx/OnlineZipformer2CtcModelConfig;");
jobject zipformer2_ctc_config = env->GetObjectField(model_config, fid);
jclass zipformer2_ctc_config_cls = env->GetObjectClass(zipformer2_ctc_config);
fid =
env->GetFieldID(zipformer2_ctc_config_cls, "model", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(zipformer2_ctc_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.zipformer2_ctc.model = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
... ...
... ... @@ -27,6 +27,7 @@ pybind11_add_module(_sherpa_onnx
online-stream.cc
online-transducer-model-config.cc
online-wenet-ctc-model-config.cc
online-zipformer2-ctc-model-config.cc
sherpa-onnx.cc
silero-vad-model-config.cc
vad-model-config.cc
... ...
... ... @@ -58,6 +58,7 @@ void PybindOfflineModelConfig(py::module *m) {
.def_readwrite("debug", &PyClass::debug)
.def_readwrite("provider", &PyClass::provider)
.def_readwrite("model_type", &PyClass::model_type)
.def("validate", &PyClass::Validate)
.def("__str__", &PyClass::ToString);
}
... ...
... ... @@ -12,6 +12,7 @@
#include "sherpa-onnx/python/csrc/online-paraformer-model-config.h"
#include "sherpa-onnx/python/csrc/online-transducer-model-config.h"
#include "sherpa-onnx/python/csrc/online-wenet-ctc-model-config.h"
#include "sherpa-onnx/python/csrc/online-zipformer2-ctc-model-config.h"
namespace sherpa_onnx {
... ... @@ -19,26 +20,31 @@ void PybindOnlineModelConfig(py::module *m) {
PybindOnlineTransducerModelConfig(m);
PybindOnlineParaformerModelConfig(m);
PybindOnlineWenetCtcModelConfig(m);
PybindOnlineZipformer2CtcModelConfig(m);
using PyClass = OnlineModelConfig;
py::class_<PyClass>(*m, "OnlineModelConfig")
.def(py::init<const OnlineTransducerModelConfig &,
const OnlineParaformerModelConfig &,
const OnlineWenetCtcModelConfig &, const std::string &,
const OnlineWenetCtcModelConfig &,
const OnlineZipformer2CtcModelConfig &, const std::string &,
int32_t, bool, const std::string &, const std::string &>(),
py::arg("transducer") = OnlineTransducerModelConfig(),
py::arg("paraformer") = OnlineParaformerModelConfig(),
py::arg("wenet_ctc") = OnlineWenetCtcModelConfig(),
py::arg("zipformer2_ctc") = OnlineZipformer2CtcModelConfig(),
py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,
py::arg("provider") = "cpu", py::arg("model_type") = "")
.def_readwrite("transducer", &PyClass::transducer)
.def_readwrite("paraformer", &PyClass::paraformer)
.def_readwrite("wenet_ctc", &PyClass::wenet_ctc)
.def_readwrite("zipformer2_ctc", &PyClass::zipformer2_ctc)
.def_readwrite("tokens", &PyClass::tokens)
.def_readwrite("num_threads", &PyClass::num_threads)
.def_readwrite("debug", &PyClass::debug)
.def_readwrite("provider", &PyClass::provider)
.def_readwrite("model_type", &PyClass::model_type)
.def("validate", &PyClass::Validate)
.def("__str__", &PyClass::ToString);
}
... ...
// sherpa-onnx/python/csrc/online-zipformer2-ctc-model-config.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/python/csrc/online-zipformer2-ctc-model-config.h"
#include <string>
#include <vector>
#include "sherpa-onnx/csrc/online-zipformer2-ctc-model-config.h"
namespace sherpa_onnx {
// Expose sherpa_onnx::OnlineZipformer2CtcModelConfig to Python as
// _sherpa_onnx.OnlineZipformer2CtcModelConfig.
void PybindOnlineZipformer2CtcModelConfig(py::module *m) {
  using PyClass = OnlineZipformer2CtcModelConfig;
  auto cls = py::class_<PyClass>(*m, "OnlineZipformer2CtcModelConfig");
  cls.def(py::init<const std::string &>(), py::arg("model"));
  cls.def_readwrite("model", &PyClass::model);
  cls.def("__str__", &PyClass::ToString);
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/python/csrc/online-zipformer2-ctc-model-config.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_ZIPFORMER2_CTC_MODEL_CONFIG_H_
#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_ZIPFORMER2_CTC_MODEL_CONFIG_H_
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
namespace sherpa_onnx {

// Register the OnlineZipformer2CtcModelConfig binding on module m.
void PybindOnlineZipformer2CtcModelConfig(py::module *m);

}  // namespace sherpa_onnx
#endif // SHERPA_ONNX_PYTHON_CSRC_ONLINE_ZIPFORMER2_CTC_MODEL_CONFIG_H_
... ...
... ... @@ -8,11 +8,14 @@ from _sherpa_onnx import (
OnlineLMConfig,
OnlineModelConfig,
OnlineParaformerModelConfig,
OnlineRecognizer as _Recognizer,
)
from _sherpa_onnx import OnlineRecognizer as _Recognizer
from _sherpa_onnx import (
OnlineRecognizerConfig,
OnlineStream,
OnlineTransducerModelConfig,
OnlineWenetCtcModelConfig,
OnlineZipformer2CtcModelConfig,
)
... ... @@ -273,6 +276,101 @@ class OnlineRecognizer(object):
return self
@classmethod
def from_zipformer2_ctc(
cls,
tokens: str,
model: str,
num_threads: int = 2,
sample_rate: float = 16000,
feature_dim: int = 80,
enable_endpoint_detection: bool = False,
rule1_min_trailing_silence: float = 2.4,
rule2_min_trailing_silence: float = 1.2,
rule3_min_utterance_length: float = 20.0,
decoding_method: str = "greedy_search",
provider: str = "cpu",
):
"""
Please refer to
`<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-ctc/index.html>`_
to download pre-trained models for different languages, e.g., Chinese,
English, etc.
Args:
tokens:
Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
columns::
symbol integer_id
model:
Path to ``model.onnx``.
num_threads:
Number of threads for neural network computation.
sample_rate:
Sample rate of the training data used to train the model.
feature_dim:
Dimension of the feature used to train the model.
enable_endpoint_detection:
True to enable endpoint detection. False to disable endpoint
detection.
rule1_min_trailing_silence:
Used only when enable_endpoint_detection is True. If the duration
of trailing silence in seconds is larger than this value, we assume
an endpoint is detected.
rule2_min_trailing_silence:
Used only when enable_endpoint_detection is True. If we have decoded
something that is nonsilence and if the duration of trailing silence
in seconds is larger than this value, we assume an endpoint is
detected.
rule3_min_utterance_length:
Used only when enable_endpoint_detection is True. If the utterance
length in seconds is larger than this value, we assume an endpoint
is detected.
decoding_method:
The only valid value is greedy_search.
provider:
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
"""
self = cls.__new__(cls)
_assert_file_exists(tokens)
_assert_file_exists(model)
assert num_threads > 0, num_threads
zipformer2_ctc_config = OnlineZipformer2CtcModelConfig(model=model)
model_config = OnlineModelConfig(
zipformer2_ctc=zipformer2_ctc_config,
tokens=tokens,
num_threads=num_threads,
provider=provider,
)
feat_config = FeatureExtractorConfig(
sampling_rate=sample_rate,
feature_dim=feature_dim,
)
endpoint_config = EndpointConfig(
rule1_min_trailing_silence=rule1_min_trailing_silence,
rule2_min_trailing_silence=rule2_min_trailing_silence,
rule3_min_utterance_length=rule3_min_utterance_length,
)
recognizer_config = OnlineRecognizerConfig(
feat_config=feat_config,
model_config=model_config,
endpoint_config=endpoint_config,
enable_endpoint=enable_endpoint_detection,
decoding_method=decoding_method,
)
self.recognizer = _Recognizer(recognizer_config)
self.config = recognizer_config
return self
@classmethod
def from_wenet_ctc(
cls,
tokens: str,
... ... @@ -352,7 +450,6 @@ class OnlineRecognizer(object):
tokens=tokens,
num_threads=num_threads,
provider=provider,
model_type="wenet_ctc",
)
feat_config = FeatureExtractorConfig(
... ...
... ... @@ -143,6 +143,57 @@ class TestOnlineRecognizer(unittest.TestCase):
print(f"{wave_filename}\n{result}")
print("-" * 10)
def test_zipformer2_ctc(self):
m = "sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13"
for use_int8 in [True, False]:
name = (
"ctc-epoch-20-avg-1-chunk-16-left-128.int8.onnx"
if use_int8
else "ctc-epoch-20-avg-1-chunk-16-left-128.onnx"
)
model = f"{d}/{m}/{name}"
tokens = f"{d}/{m}/tokens.txt"
wave0 = f"{d}/{m}/test_wavs/DEV_T0000000000.wav"
wave1 = f"{d}/{m}/test_wavs/DEV_T0000000001.wav"
wave2 = f"{d}/{m}/test_wavs/DEV_T0000000002.wav"
if not Path(model).is_file():
print("skipping test_zipformer2_ctc()")
return
print(f"testing {model}")
recognizer = sherpa_onnx.OnlineRecognizer.from_zipformer2_ctc(
model=model,
tokens=tokens,
num_threads=1,
provider="cpu",
)
streams = []
waves = [wave0, wave1, wave2]
for wave in waves:
s = recognizer.create_stream()
samples, sample_rate = read_wave(wave)
s.accept_waveform(sample_rate, samples)
tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32)
s.accept_waveform(sample_rate, tail_paddings)
s.input_finished()
streams.append(s)
while True:
ready_list = []
for s in streams:
if recognizer.is_ready(s):
ready_list.append(s)
if len(ready_list) == 0:
break
recognizer.decode_streams(ready_list)
results = [recognizer.get_result(s) for s in streams]
for wave_filename, result in zip(waves, results):
print(f"{wave_filename}\n{result}")
print("-" * 10)
def test_wenet_ctc(self):
models = [
"sherpa-onnx-zh-wenet-aishell",
... ...
... ... @@ -5,3 +5,4 @@ tts
vits-vctk
sherpa-onnx-paraformer-zh-2023-09-14
!*.sh
*.bak
... ...
... ... @@ -60,6 +60,14 @@ func sherpaOnnxOnlineParaformerModelConfig(
)
}
/// Build a `SherpaOnnxOnlineZipformer2CtcModelConfig`.
///
/// - Parameter model: Path to the streaming zipformer2 CTC onnx model.
///   Defaults to an empty string, i.e., no model.
func sherpaOnnxOnlineZipformer2CtcModelConfig(
  model: String = ""
) -> SherpaOnnxOnlineZipformer2CtcModelConfig {
  SherpaOnnxOnlineZipformer2CtcModelConfig(model: toCPointer(model))
}
/// Return an instance of SherpaOnnxOnlineModelConfig.
///
/// Please refer to
... ... @@ -75,6 +83,8 @@ func sherpaOnnxOnlineModelConfig(
tokens: String,
transducer: SherpaOnnxOnlineTransducerModelConfig = sherpaOnnxOnlineTransducerModelConfig(),
paraformer: SherpaOnnxOnlineParaformerModelConfig = sherpaOnnxOnlineParaformerModelConfig(),
zipformer2Ctc: SherpaOnnxOnlineZipformer2CtcModelConfig =
sherpaOnnxOnlineZipformer2CtcModelConfig(),
numThreads: Int = 1,
provider: String = "cpu",
debug: Int = 0,
... ... @@ -83,6 +93,7 @@ func sherpaOnnxOnlineModelConfig(
return SherpaOnnxOnlineModelConfig(
transducer: transducer,
paraformer: paraformer,
zipformer2_ctc: zipformer2Ctc,
tokens: toCPointer(tokens),
num_threads: Int32(numThreads),
provider: toCPointer(provider),
... ...
... ... @@ -13,6 +13,14 @@ extension AVAudioPCMBuffer {
}
func run() {
var modelConfig: SherpaOnnxOnlineModelConfig
var modelType = "zipformer2-ctc"
var filePath: String
modelType = "transducer"
if modelType == "transducer" {
filePath = "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/1.wav"
let encoder =
"./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx"
let decoder =
... ... @@ -27,10 +35,25 @@ func run() {
joiner: joiner
)
let modelConfig = sherpaOnnxOnlineModelConfig(
modelConfig = sherpaOnnxOnlineModelConfig(
tokens: tokens,
transducer: transducerConfig
)
} else {
filePath =
"./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav"
let model =
"./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.onnx"
let tokens = "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt"
let zipfomer2CtcModelConfig = sherpaOnnxOnlineZipformer2CtcModelConfig(
model: model
)
modelConfig = sherpaOnnxOnlineModelConfig(
tokens: tokens,
zipformer2Ctc: zipfomer2CtcModelConfig
)
}
let featConfig = sherpaOnnxFeatureConfig(
sampleRate: 16000,
... ... @@ -43,7 +66,6 @@ func run() {
let recognizer = SherpaOnnxRecognizer(config: &config)
let filePath = "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/1.wav"
let fileURL: NSURL = NSURL(fileURLWithPath: filePath)
let audioFile = try! AVAudioFile(forReading: fileURL as URL)
... ...
... ... @@ -20,6 +20,12 @@ if [ ! -d ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 ]; then
rm sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2
fi
# Download and unpack the streaming zipformer2 CTC test model unless the
# extracted directory is already present (e.g., from a previous CI run).
if [ ! -d ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13 ]; then
  wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
  tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
  rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
fi
if [ ! -e ./decode-file ]; then
# Note: We use -lc++ to link against libc++ instead of libstdc++
swiftc \
... ...
... ... @@ -22,7 +22,7 @@ if [ ! -d ./sherpa-onnx-whisper-tiny.en ]; then
fi
if [ ! -f ./silero_vad.onnx ]; then
echo "downloading silero_vad"
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx
fi
if [ ! -e ./generate-subtitles ]; then
... ...