Fangjun Kuang
Committed by GitHub

Add Python APIs for WeNet CTC models (#428)

... ... @@ -8,6 +8,51 @@ log() {
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
# WeNet CTC model repositories (under https://huggingface.co/csukuangfj) that
# are exercised through the Python example scripts and unit tests below.
wenet_models=(
  sherpa-onnx-zh-wenet-aishell
  sherpa-onnx-zh-wenet-aishell2
  sherpa-onnx-zh-wenet-wenetspeech
  sherpa-onnx-zh-wenet-multi-cn
  sherpa-onnx-en-wenet-librispeech
  sherpa-onnx-en-wenet-gigaspeech
)

dir=/tmp/icefall-models
mkdir -p "$dir"

for name in "${wenet_models[@]}"; do
  repo_url=https://huggingface.co/csukuangfj/$name
  log "Start testing ${repo_url}"
  repo=$dir/$(basename "$repo_url")
  log "Download pretrained model and test-data from $repo_url"

  # Clone with LFS smudging disabled, then pull only the *.onnx LFS files.
  # pushd/popd restores the original working directory for the python calls.
  pushd "$dir"
  GIT_LFS_SKIP_SMUDGE=1 git clone "$repo_url"
  cd "$repo"
  git lfs pull --include "*.onnx"
  ls -lh *.onnx
  popd

  # Non-streaming model.
  python3 ./python-api-examples/offline-decode-files.py \
    --tokens="$repo/tokens.txt" \
    --wenet-ctc="$repo/model.onnx" \
    "$repo/test_wavs/0.wav" \
    "$repo/test_wavs/1.wav" \
    "$repo/test_wavs/8k.wav"

  # Streaming model.
  python3 ./python-api-examples/online-decode-files.py \
    --tokens="$repo/tokens.txt" \
    --wenet-ctc="$repo/model-streaming.onnx" \
    "$repo/test_wavs/0.wav" \
    "$repo/test_wavs/1.wav" \
    "$repo/test_wavs/8k.wav"

  # NOTE(review): the unit tests presumably look for the models under $dir,
  # which is why they run before the checkout is removed -- confirm.
  python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose
  python3 sherpa-onnx/python/tests/test_online_recognizer.py --verbose

  # ${repo:?} aborts instead of expanding to an empty path.
  rm -rf -- "${repo:?}"
done
log "Offline TTS test"
# test waves are saved in ./tts
mkdir ./tts
... ...
... ... @@ -85,10 +85,19 @@ jobs:
arch=${{ matrix.arch }}
cd mfc-examples/$arch/Release
cp StreamingSpeechRecognition.exe sherpa-onnx-streaming-${SHERPA_ONNX_VERSION}.exe
cp NonStreamingSpeechRecognition.exe sherpa-onnx-non-streaming-${SHERPA_ONNX_VERSION}.exe
ls -lh
cp -v StreamingSpeechRecognition.exe sherpa-onnx-streaming-${SHERPA_ONNX_VERSION}.exe
cp -v NonStreamingSpeechRecognition.exe sherpa-onnx-non-streaming-${SHERPA_ONNX_VERSION}.exe
cp -v NonStreamingTextToSpeech.exe ../sherpa-onnx-non-streaming-tts-${SHERPA_ONNX_VERSION}.exe
ls -lh
# Upload the NonStreamingTextToSpeech.exe found in the Release directory of
# the current matrix architecture as a per-architecture workflow artifact.
- name: Upload artifact tts
uses: actions/upload-artifact@v3
with:
name: non-streaming-tts-${{ matrix.arch }}
path: ./mfc-examples/${{ matrix.arch }}/Release/NonStreamingTextToSpeech.exe
- name: Upload artifact
uses: actions/upload-artifact@v3
with:
... ... @@ -116,3 +125,11 @@ jobs:
file_glob: true
overwrite: true
file: ./mfc-examples/${{ matrix.arch }}/Release/sherpa-onnx-non-streaming-*.exe
- name: Release pre-compiled binaries and libs for Windows ${{ matrix.arch }}
  # `&&` binds tighter than `||` in GitHub Actions expressions, so the owner
  # check is parenthesized; otherwise the push/tag conditions would only
  # apply to the 'k2-fsa' owner.
  if: (github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa') && github.event_name == 'push' && contains(github.ref, 'refs/tags/')
  uses: svenstaro/upload-release-action@v2
  with:
    file_glob: true
    overwrite: true
    file: ./mfc-examples/${{ matrix.arch }}/sherpa-onnx-non-streaming-*.exe
... ...
... ... @@ -10,6 +10,7 @@ on:
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
- 'python-api-examples/**'
pull_request:
branches:
- master
... ... @@ -19,6 +20,7 @@ on:
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
- 'python-api-examples/**'
workflow_dispatch:
concurrency:
... ...
cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
project(sherpa-onnx)
set(SHERPA_ONNX_VERSION "1.8.9")
set(SHERPA_ONNX_VERSION "1.8.10")
# Disable warning about
#
... ...
... ... @@ -58,6 +58,15 @@ wget https://github.com/snakers4/silero-vad/raw/master/files/silero_vad.onnx
--num-threads=2 \
/path/to/test.mp4
(4) For WeNet CTC models
./python-api-examples/generate-subtitles.py \
--silero-vad-model=/path/to/silero_vad.onnx \
--wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \
--tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \
--num-threads=2 \
/path/to/test.mp4
Please refer to
https://k2-fsa.github.io/sherpa/onnx/index.html
to install sherpa-onnx and to download non-streaming pre-trained models
... ... @@ -122,6 +131,13 @@ def get_args():
)
parser.add_argument(
"--wenet-ctc",
default="",
type=str,
help="Path to the CTC model.onnx from WeNet",
)
parser.add_argument(
"--num-threads",
type=int,
default=1,
... ... @@ -215,6 +231,7 @@ def assert_file_exists(filename: str):
def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
if args.encoder:
assert len(args.paraformer) == 0, args.paraformer
assert len(args.wenet_ctc) == 0, args.wenet_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
... ... @@ -234,6 +251,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
debug=args.debug,
)
elif args.paraformer:
assert len(args.wenet_ctc) == 0, args.wenet_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
... ... @@ -248,6 +266,21 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
decoding_method=args.decoding_method,
debug=args.debug,
)
elif args.wenet_ctc:
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert_file_exists(args.wenet_ctc)
recognizer = sherpa_onnx.OfflineRecognizer.from_wenet_ctc(
model=args.wenet_ctc,
tokens=args.tokens,
num_threads=args.num_threads,
sample_rate=args.sample_rate,
feature_dim=args.feature_dim,
decoding_method=args.decoding_method,
debug=args.debug,
)
elif args.whisper_encoder:
assert_file_exists(args.whisper_encoder)
assert_file_exists(args.whisper_decoder)
... ...
... ... @@ -58,7 +58,19 @@ python3 ./python-api-examples/non_streaming_server.py \
--nemo-ctc ./sherpa-onnx-nemo-ctc-en-conformer-medium/model.onnx \
--tokens ./sherpa-onnx-nemo-ctc-en-conformer-medium/tokens.txt
(4) Use a Whisper model
(4) Use a non-streaming CTC model from WeNet
cd /path/to/sherpa-onnx
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-zh-wenet-wenetspeech
cd sherpa-onnx-zh-wenet-wenetspeech
git lfs pull --include "*.onnx"
cd ..
python3 ./python-api-examples/non_streaming_server.py \
--wenet-ctc ./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \
--tokens ./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt
(5) Use a Whisper model
cd /path/to/sherpa-onnx
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-tiny.en
... ... @@ -210,6 +222,15 @@ def add_nemo_ctc_model_args(parser: argparse.ArgumentParser):
)
def add_wenet_ctc_model_args(parser: argparse.ArgumentParser):
    """Register the ``--wenet-ctc`` option on ``parser``.

    The option is a string that defaults to the empty string.
    """
    option = "--wenet-ctc"
    parser.add_argument(
        option,
        type=str,
        default="",
        help="Path to the model.onnx from WeNet CTC",
    )
def add_tdnn_ctc_model_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--tdnn-model",
... ... @@ -261,6 +282,7 @@ def add_model_args(parser: argparse.ArgumentParser):
add_transducer_model_args(parser)
add_paraformer_model_args(parser)
add_nemo_ctc_model_args(parser)
add_wenet_ctc_model_args(parser)
add_tdnn_ctc_model_args(parser)
add_whisper_model_args(parser)
... ... @@ -804,6 +826,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
if args.encoder:
assert len(args.paraformer) == 0, args.paraformer
assert len(args.nemo_ctc) == 0, args.nemo_ctc
assert len(args.wenet_ctc) == 0, args.wenet_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.tdnn_model) == 0, args.tdnn_model
... ... @@ -827,6 +850,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
)
elif args.paraformer:
assert len(args.nemo_ctc) == 0, args.nemo_ctc
assert len(args.wenet_ctc) == 0, args.wenet_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.tdnn_model) == 0, args.tdnn_model
... ... @@ -842,6 +866,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
decoding_method=args.decoding_method,
)
elif args.nemo_ctc:
assert len(args.wenet_ctc) == 0, args.wenet_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.tdnn_model) == 0, args.tdnn_model
... ... @@ -856,6 +881,21 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
feature_dim=args.feat_dim,
decoding_method=args.decoding_method,
)
elif args.wenet_ctc:
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.tdnn_model) == 0, args.tdnn_model
assert_file_exists(args.wenet_ctc)
recognizer = sherpa_onnx.OfflineRecognizer.from_wenet_ctc(
model=args.wenet_ctc,
tokens=args.tokens,
num_threads=args.num_threads,
sample_rate=args.sample_rate,
feature_dim=args.feat_dim,
decoding_method=args.decoding_method,
)
elif args.whisper_encoder:
assert len(args.tdnn_model) == 0, args.tdnn_model
assert_file_exists(args.whisper_encoder)
... ...
... ... @@ -59,7 +59,16 @@ python3 ./python-api-examples/offline-decode-files.py \
./sherpa-onnx-whisper-base.en/test_wavs/1.wav \
./sherpa-onnx-whisper-base.en/test_wavs/8k.wav
(5) For tdnn models of the yesno recipe from icefall
(5) For CTC models from WeNet
python3 ./python-api-examples/offline-decode-files.py \
--wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \
--tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/0.wav \
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/1.wav \
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/8k.wav
(6) For tdnn models of the yesno recipe from icefall
python3 ./python-api-examples/offline-decode-files.py \
--sample-rate=8000 \
... ... @@ -155,6 +164,13 @@ def get_args():
)
parser.add_argument(
"--wenet-ctc",
default="",
type=str,
help="Path to the model.onnx from WeNet CTC",
)
parser.add_argument(
"--tdnn-model",
default="",
type=str,
... ... @@ -254,6 +270,7 @@ def assert_file_exists(filename: str):
"https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
)
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
"""
Args:
... ... @@ -287,6 +304,7 @@ def main():
if args.encoder:
assert len(args.paraformer) == 0, args.paraformer
assert len(args.nemo_ctc) == 0, args.nemo_ctc
assert len(args.wenet_ctc) == 0, args.wenet_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.tdnn_model) == 0, args.tdnn_model
... ... @@ -310,6 +328,7 @@ def main():
)
elif args.paraformer:
assert len(args.nemo_ctc) == 0, args.nemo_ctc
assert len(args.wenet_ctc) == 0, args.wenet_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.tdnn_model) == 0, args.tdnn_model
... ... @@ -326,6 +345,7 @@ def main():
debug=args.debug,
)
elif args.nemo_ctc:
assert len(args.wenet_ctc) == 0, args.wenet_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.tdnn_model) == 0, args.tdnn_model
... ... @@ -341,6 +361,22 @@ def main():
decoding_method=args.decoding_method,
debug=args.debug,
)
elif args.wenet_ctc:
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.tdnn_model) == 0, args.tdnn_model
assert_file_exists(args.wenet_ctc)
recognizer = sherpa_onnx.OfflineRecognizer.from_wenet_ctc(
model=args.wenet_ctc,
tokens=args.tokens,
num_threads=args.num_threads,
sample_rate=args.sample_rate,
feature_dim=args.feature_dim,
decoding_method=args.decoding_method,
debug=args.debug,
)
elif args.whisper_encoder:
assert len(args.tdnn_model) == 0, args.tdnn_model
assert_file_exists(args.whisper_encoder)
... ...
... ... @@ -37,8 +37,25 @@ git lfs pull --include "*.onnx"
./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/3.wav \
./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/8k.wav
(3) Streaming Conformer CTC from WeNet
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-zh-wenet-wenetspeech
cd sherpa-onnx-zh-wenet-wenetspeech
git lfs pull --include "*.onnx"
cd ..
./python-api-examples/online-decode-files.py \
--tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \
--wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model-streaming.onnx \
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/0.wav \
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/1.wav \
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/8k.wav
Please refer to
https://k2-fsa.github.io/sherpa/onnx/index.html
and
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/wenet/index.html
to install sherpa-onnx and to download streaming pre-trained models.
"""
import argparse
... ... @@ -93,6 +110,26 @@ def get_args():
)
parser.add_argument(
    "--wenet-ctc",
    type=str,
    # Fixed duplicated word in the help text ("model model").
    help="Path to the wenet ctc model",
)
parser.add_argument(
"--wenet-ctc-chunk-size",
type=int,
default=16,
help="The --chunk-size parameter for streaming WeNet models",
)
parser.add_argument(
"--wenet-ctc-num-left-chunks",
type=int,
default=4,
help="The --num-left-chunks parameter for streaming WeNet models",
)
parser.add_argument(
"--num-threads",
type=int,
default=1,
... ... @@ -249,6 +286,18 @@ def main():
feature_dim=80,
decoding_method="greedy_search",
)
elif args.wenet_ctc:
recognizer = sherpa_onnx.OnlineRecognizer.from_wenet_ctc(
tokens=args.tokens,
model=args.wenet_ctc,
chunk_size=args.wenet_ctc_chunk_size,
num_left_chunks=args.wenet_ctc_num_left_chunks,
num_threads=args.num_threads,
provider=args.provider,
sample_rate=16000,
feature_dim=80,
decoding_method="greedy_search",
)
else:
raise ValueError("Please provide a model")
... ...
... ... @@ -40,10 +40,17 @@ python3 ./python-api-examples/streaming_server.py \
Please refer to
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/index.html
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/wenet/index.html
to download pre-trained models.
The model in the above help messages is from
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#csukuangfj-sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english
To use a WeNet streaming Conformer CTC model, please use
python3 ./python-api-examples/streaming_server.py \
--tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \
--wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model-streaming.onnx
"""
import argparse
... ... @@ -131,6 +138,12 @@ def add_model_args(parser: argparse.ArgumentParser):
)
parser.add_argument(
"--wenet-ctc",
type=str,
help="Path to the model.onnx from WeNet",
)
parser.add_argument(
"--paraformer-encoder",
type=str,
help="Path to the paraformer encoder model",
... ... @@ -212,7 +225,6 @@ def add_hotwords_args(parser: argparse.ArgumentParser):
)
def add_modified_beam_search_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--num-active-paths",
... ... @@ -393,6 +405,20 @@ def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer:
rule3_min_utterance_length=args.rule3_min_utterance_length,
provider=args.provider,
)
elif args.wenet_ctc:
recognizer = sherpa_onnx.OnlineRecognizer.from_wenet_ctc(
tokens=args.tokens,
model=args.wenet_ctc,
num_threads=args.num_threads,
sample_rate=args.sample_rate,
feature_dim=args.feat_dim,
decoding_method=args.decoding_method,
enable_endpoint_detection=args.use_endpoint != 0,
rule1_min_trailing_silence=args.rule1_min_trailing_silence,
rule2_min_trailing_silence=args.rule2_min_trailing_silence,
rule3_min_utterance_length=args.rule3_min_utterance_length,
provider=args.provider,
)
else:
raise ValueError("Please provide a model")
... ... @@ -727,6 +753,8 @@ def check_args(args):
assert Path(
args.paraformer_decoder
).is_file(), f"{args.paraformer_decoder} does not exist"
elif args.wenet_ctc:
assert Path(args.wenet_ctc).is_file(), f"{args.wenet_ctc} does not exist"
else:
raise ValueError("Please provide a model")
... ...
... ... @@ -9,15 +9,16 @@ from _sherpa_onnx import (
OfflineModelConfig,
OfflineNemoEncDecCtcModelConfig,
OfflineParaformerModelConfig,
OfflineTdnnModelConfig,
OfflineWhisperModelConfig,
OfflineZipformerCtcModelConfig,
)
from _sherpa_onnx import OfflineRecognizer as _Recognizer
from _sherpa_onnx import (
OfflineRecognizerConfig,
OfflineStream,
OfflineTdnnModelConfig,
OfflineTransducerModelConfig,
OfflineWenetCtcModelConfig,
OfflineWhisperModelConfig,
OfflineZipformerCtcModelConfig,
)
... ... @@ -389,6 +390,70 @@ class OfflineRecognizer(object):
self.config = recognizer_config
return self
@classmethod
def from_wenet_ctc(
    cls,
    model: str,
    tokens: str,
    num_threads: int = 1,
    sample_rate: int = 16000,
    feature_dim: int = 80,
    decoding_method: str = "greedy_search",
    debug: bool = False,
    provider: str = "cpu",
):
    """
    Please refer to
    `<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/wenet/index.html>`_
    to download pre-trained models for different languages, e.g., Chinese,
    English, etc.

    Args:
      model:
        Path to ``model.onnx``.
      tokens:
        Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
        columns::

            symbol integer_id

      num_threads:
        Number of threads for neural network computation.
      sample_rate:
        Sample rate of the training data used to train the model.
      feature_dim:
        Dimension of the feature used to train the model.
      decoding_method:
        The only valid value is greedy_search.
      debug:
        True to show debug messages.
      provider:
        onnxruntime execution providers. Valid values are: cpu, cuda, coreml.

    Returns:
      An instance of ``cls`` whose ``recognizer`` and ``config`` attributes
      are set from the arguments above.
    """
    # Bypass __init__: the instance is populated directly below.
    self = cls.__new__(cls)
    model_config = OfflineModelConfig(
        wenet_ctc=OfflineWenetCtcModelConfig(model=model),
        tokens=tokens,
        num_threads=num_threads,
        debug=debug,
        provider=provider,
        model_type="wenet_ctc",
    )

    feat_config = OfflineFeatureExtractorConfig(
        sampling_rate=sample_rate,
        feature_dim=feature_dim,
    )

    recognizer_config = OfflineRecognizerConfig(
        feat_config=feat_config,
        model_config=model_config,
        decoding_method=decoding_method,
    )
    self.recognizer = _Recognizer(recognizer_config)
    self.config = recognizer_config
    return self
def create_stream(self, hotwords: Optional[str] = None):
if hotwords is None:
return self.recognizer.create_stream()
... ...
... ... @@ -12,6 +12,7 @@ from _sherpa_onnx import (
OnlineRecognizerConfig,
OnlineStream,
OnlineTransducerModelConfig,
OnlineWenetCtcModelConfig,
)
... ... @@ -271,6 +272,112 @@ class OnlineRecognizer(object):
self.config = recognizer_config
return self
@classmethod
def from_wenet_ctc(
    cls,
    tokens: str,
    model: str,
    chunk_size: int = 16,
    num_left_chunks: int = 4,
    num_threads: int = 2,
    sample_rate: float = 16000,
    feature_dim: int = 80,
    enable_endpoint_detection: bool = False,
    rule1_min_trailing_silence: float = 2.4,
    rule2_min_trailing_silence: float = 1.2,
    rule3_min_utterance_length: float = 20.0,
    decoding_method: str = "greedy_search",
    provider: str = "cpu",
):
    """Build a recognizer from a streaming WeNet CTC model.

    Pre-trained models for several languages (e.g., Chinese, English) are
    listed at
    `<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/wenet/index.html>`_

    Args:
      tokens:
        Path to ``tokens.txt``. Every line holds two columns::

            symbol integer_id

      model:
        Path to ``model.onnx``.
      chunk_size:
        The --chunk-size parameter from WeNet.
      num_left_chunks:
        The --num-left-chunks parameter from WeNet.
      num_threads:
        Number of threads for neural network computation. Must be positive.
      sample_rate:
        Sample rate of the training data used to train the model.
      feature_dim:
        Dimension of the feature used to train the model.
      enable_endpoint_detection:
        True to enable endpoint detection; False to disable it.
      rule1_min_trailing_silence:
        Only used when enable_endpoint_detection is True. An endpoint is
        assumed once the trailing silence (in seconds) exceeds this value.
      rule2_min_trailing_silence:
        Only used when enable_endpoint_detection is True. After something
        that is nonsilence has been decoded, an endpoint is assumed once
        the trailing silence (in seconds) exceeds this value.
      rule3_min_utterance_length:
        Only used when enable_endpoint_detection is True. An endpoint is
        assumed once the utterance length (in seconds) exceeds this value.
      decoding_method:
        The only valid value is greedy_search.
      provider:
        onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
    """
    self = cls.__new__(cls)

    # Validate the inputs before any configuration object is created.
    _assert_file_exists(tokens)
    _assert_file_exists(model)
    assert num_threads > 0, num_threads

    recognizer_config = OnlineRecognizerConfig(
        feat_config=FeatureExtractorConfig(
            sampling_rate=sample_rate,
            feature_dim=feature_dim,
        ),
        model_config=OnlineModelConfig(
            wenet_ctc=OnlineWenetCtcModelConfig(
                model=model,
                chunk_size=chunk_size,
                num_left_chunks=num_left_chunks,
            ),
            tokens=tokens,
            num_threads=num_threads,
            provider=provider,
            model_type="wenet_ctc",
        ),
        endpoint_config=EndpointConfig(
            rule1_min_trailing_silence=rule1_min_trailing_silence,
            rule2_min_trailing_silence=rule2_min_trailing_silence,
            rule3_min_utterance_length=rule3_min_utterance_length,
        ),
        enable_endpoint=enable_endpoint_detection,
        decoding_method=decoding_method,
    )

    self.recognizer = _Recognizer(recognizer_config)
    self.config = recognizer_config
    return self
def create_stream(self, hotwords: Optional[str] = None):
if hotwords is None:
return self.recognizer.create_stream()
... ...
... ... @@ -267,6 +267,53 @@ class TestOfflineRecognizer(unittest.TestCase):
print(s1.result.text)
print(s2.result.text)
def test_wenet_ctc(self):
    """Decode the three test waves with each non-streaming WeNet CTC model.

    Both the int8 and the float model of every listed repository are tried.
    A model file that is not present is skipped on its own, so the remaining
    models are still exercised.
    """
    models = [
        "sherpa-onnx-zh-wenet-aishell",
        "sherpa-onnx-zh-wenet-aishell2",
        "sherpa-onnx-zh-wenet-wenetspeech",
        "sherpa-onnx-zh-wenet-multi-cn",
        "sherpa-onnx-en-wenet-librispeech",
        "sherpa-onnx-en-wenet-gigaspeech",
    ]
    for m in models:
        for use_int8 in [True, False]:
            name = "model.int8.onnx" if use_int8 else "model.onnx"
            model = f"{d}/{m}/{name}"
            tokens = f"{d}/{m}/tokens.txt"
            wave0 = f"{d}/{m}/test_wavs/0.wav"
            wave1 = f"{d}/{m}/test_wavs/1.wav"
            wave2 = f"{d}/{m}/test_wavs/8k.wav"

            if not Path(model).is_file():
                # Skip only this model file. Returning here would end the
                # whole test as soon as one model is missing.
                print(f"skipping test_wenet_ctc(): {model} does not exist")
                continue

            recognizer = sherpa_onnx.OfflineRecognizer.from_wenet_ctc(
                model=model,
                tokens=tokens,
                num_threads=1,
                provider="cpu",
            )

            # One stream per wave; all streams are decoded in a single call.
            streams = []
            for wave in (wave0, wave1, wave2):
                s = recognizer.create_stream()
                samples, sample_rate = read_wave(wave)
                s.accept_waveform(sample_rate, samples)
                streams.append(s)

            recognizer.decode_streams(streams)
            for s in streams:
                print(s.result.text)
if __name__ == "__main__":
unittest.main()
... ...
... ... @@ -143,6 +143,64 @@ class TestOnlineRecognizer(unittest.TestCase):
print(f"{wave_filename}\n{result}")
print("-" * 10)
def test_wenet_ctc(self):
    """Decode the three test waves with each streaming WeNet CTC model.

    Both the int8 and the float model of every listed repository are tried.
    A model file that is not present is skipped on its own, so the remaining
    models are still exercised.
    """
    models = [
        "sherpa-onnx-zh-wenet-aishell",
        "sherpa-onnx-zh-wenet-aishell2",
        "sherpa-onnx-zh-wenet-wenetspeech",
        "sherpa-onnx-zh-wenet-multi-cn",
        "sherpa-onnx-en-wenet-librispeech",
        "sherpa-onnx-en-wenet-gigaspeech",
    ]
    for m in models:
        for use_int8 in [True, False]:
            name = (
                "model-streaming.int8.onnx" if use_int8 else "model-streaming.onnx"
            )
            model = f"{d}/{m}/{name}"
            tokens = f"{d}/{m}/tokens.txt"
            wave0 = f"{d}/{m}/test_wavs/0.wav"
            wave1 = f"{d}/{m}/test_wavs/1.wav"
            wave2 = f"{d}/{m}/test_wavs/8k.wav"

            if not Path(model).is_file():
                # Skip only this model file. Returning here would end the
                # whole test as soon as one model is missing.
                print(f"skipping test_wenet_ctc(): {model} does not exist")
                continue

            recognizer = sherpa_onnx.OnlineRecognizer.from_wenet_ctc(
                model=model,
                tokens=tokens,
                num_threads=1,
                provider="cpu",
            )

            streams = []
            waves = [wave0, wave1, wave2]
            for wave in waves:
                s = recognizer.create_stream()
                samples, sample_rate = read_wave(wave)
                s.accept_waveform(sample_rate, samples)

                # Append 0.2 s of zeros after the audio, then mark the
                # stream as finished.
                tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32)
                s.accept_waveform(sample_rate, tail_paddings)
                s.input_finished()
                streams.append(s)

            # Keep decoding until no stream reports that it is ready.
            while True:
                ready_list = [s for s in streams if recognizer.is_ready(s)]
                if len(ready_list) == 0:
                    break
                recognizer.decode_streams(ready_list)

            results = [recognizer.get_result(s) for s in streams]
            for wave_filename, result in zip(waves, results):
                print(f"{wave_filename}\n{result}")
                print("-" * 10)
if __name__ == "__main__":
unittest.main()
... ...