Fangjun Kuang
Committed by GitHub

Support streaming paraformer (#263)

Showing 38 changed files with 1,458 additions and 82 deletions
#!/usr/bin/env bash
set -e
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
echo "EXE is $EXE"
echo "PATH: $PATH"
which $EXE
log "------------------------------------------------------------"
log "Run streaming Paraformer"
log "------------------------------------------------------------"
repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-streaming-paraformer-bilingual-zh-en
log "Start testing ${repo_url}"
repo=$(basename $repo_url)
log "Download pretrained model and test-data from $repo_url"
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
pushd $repo
git lfs pull --include "*.onnx"
ls -lh *.onnx
popd
time $EXE \
--tokens=$repo/tokens.txt \
--paraformer-encoder=$repo/encoder.onnx \
--paraformer-decoder=$repo/decoder.onnx \
--num-threads=2 \
$repo/test_wavs/0.wav \
$repo/test_wavs/1.wav \
$repo/test_wavs/2.wav \
$repo/test_wavs/3.wav \
$repo/test_wavs/8k.wav
time $EXE \
--tokens=$repo/tokens.txt \
--paraformer-encoder=$repo/encoder.int8.onnx \
--paraformer-decoder=$repo/decoder.int8.onnx \
--num-threads=2 \
$repo/test_wavs/0.wav \
$repo/test_wavs/1.wav \
$repo/test_wavs/2.wav \
$repo/test_wavs/3.wav \
$repo/test_wavs/8k.wav
rm -rf $repo
... ...
... ... @@ -9,6 +9,7 @@ on:
paths:
- '.github/workflows/linux-gpu.yaml'
- '.github/scripts/test-online-transducer.sh'
- '.github/scripts/test-online-paraformer.sh'
- '.github/scripts/test-offline-transducer.sh'
- '.github/scripts/test-offline-ctc.sh'
- 'CMakeLists.txt'
... ... @@ -22,6 +23,7 @@ on:
paths:
- '.github/workflows/linux-gpu.yaml'
- '.github/scripts/test-online-transducer.sh'
- '.github/scripts/test-online-paraformer.sh'
- '.github/scripts/test-offline-transducer.sh'
- '.github/scripts/test-offline-ctc.sh'
- 'CMakeLists.txt'
... ... @@ -85,6 +87,14 @@ jobs:
file build/bin/sherpa-onnx
readelf -d build/bin/sherpa-onnx
- name: Test online paraformer
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx
.github/scripts/test-online-paraformer.sh
- name: Test offline Whisper
shell: bash
run: |
... ...
... ... @@ -9,6 +9,7 @@ on:
paths:
- '.github/workflows/linux.yaml'
- '.github/scripts/test-online-transducer.sh'
- '.github/scripts/test-online-paraformer.sh'
- '.github/scripts/test-offline-transducer.sh'
- '.github/scripts/test-offline-ctc.sh'
- 'CMakeLists.txt'
... ... @@ -22,6 +23,7 @@ on:
paths:
- '.github/workflows/linux.yaml'
- '.github/scripts/test-online-transducer.sh'
- '.github/scripts/test-online-paraformer.sh'
- '.github/scripts/test-offline-transducer.sh'
- '.github/scripts/test-offline-ctc.sh'
- 'CMakeLists.txt'
... ... @@ -84,6 +86,14 @@ jobs:
file build/bin/sherpa-onnx
readelf -d build/bin/sherpa-onnx
- name: Test online paraformer
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx
.github/scripts/test-online-paraformer.sh
- name: Test offline Whisper
shell: bash
run: |
... ...
... ... @@ -7,6 +7,7 @@ on:
paths:
- '.github/workflows/macos.yaml'
- '.github/scripts/test-online-transducer.sh'
- '.github/scripts/test-online-paraformer.sh'
- '.github/scripts/test-offline-transducer.sh'
- '.github/scripts/test-offline-ctc.sh'
- 'CMakeLists.txt'
... ... @@ -18,6 +19,7 @@ on:
paths:
- '.github/workflows/macos.yaml'
- '.github/scripts/test-online-transducer.sh'
- '.github/scripts/test-online-paraformer.sh'
- '.github/scripts/test-offline-transducer.sh'
- '.github/scripts/test-offline-ctc.sh'
- 'CMakeLists.txt'
... ... @@ -82,6 +84,14 @@ jobs:
otool -L build/bin/sherpa-onnx
otool -l build/bin/sherpa-onnx
- name: Test online paraformer
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx
.github/scripts/test-online-paraformer.sh
- name: Test offline Whisper
shell: bash
run: |
... ...
... ... @@ -58,7 +58,6 @@ jobs:
sherpa-onnx-microphone-offline --help
sherpa-onnx-offline-websocket-server --help
sherpa-onnx-offline-websocket-client --help
sherpa-onnx-online-websocket-server --help
sherpa-onnx-online-websocket-client --help
... ...
... ... @@ -84,14 +84,14 @@ jobs:
if: matrix.model_type == 'paraformer'
shell: bash
run: |
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28
cd sherpa-onnx-paraformer-zh-2023-03-28
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-bilingual-zh-en
cd sherpa-onnx-paraformer-bilingual-zh-en
git lfs pull --include "*.onnx"
cd ..
python3 ./python-api-examples/non_streaming_server.py \
--paraformer ./sherpa-onnx-paraformer-zh-2023-03-28/model.int8.onnx \
--tokens ./sherpa-onnx-paraformer-zh-2023-03-28/tokens.txt &
--paraformer ./sherpa-onnx-paraformer-bilingual-zh-en/model.int8.onnx \
--tokens ./sherpa-onnx-paraformer-bilingual-zh-en/tokens.txt &
echo "sleep 10 seconds to wait the server start"
sleep 10
... ... @@ -101,16 +101,16 @@ jobs:
shell: bash
run: |
python3 ./python-api-examples/offline-websocket-client-decode-files-paralell.py \
./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/0.wav \
./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/1.wav \
./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/2.wav \
./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/8k.wav
./sherpa-onnx-paraformer-bilingual-zh-en/test_wavs/0.wav \
./sherpa-onnx-paraformer-bilingual-zh-en/test_wavs/1.wav \
./sherpa-onnx-paraformer-bilingual-zh-en/test_wavs/2.wav \
./sherpa-onnx-paraformer-bilingual-zh-en/test_wavs/8k.wav
python3 ./python-api-examples/offline-websocket-client-decode-files-sequential.py \
./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/0.wav \
./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/1.wav \
./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/2.wav \
./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/8k.wav
./sherpa-onnx-paraformer-bilingual-zh-en/test_wavs/0.wav \
./sherpa-onnx-paraformer-bilingual-zh-en/test_wavs/1.wav \
./sherpa-onnx-paraformer-bilingual-zh-en/test_wavs/2.wav \
./sherpa-onnx-paraformer-bilingual-zh-en/test_wavs/8k.wav
- name: Start server for nemo_ctc models
if: matrix.model_type == 'nemo_ctc'
... ...
... ... @@ -24,7 +24,7 @@ jobs:
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
model_type: ["transducer"]
model_type: ["transducer", "paraformer"]
steps:
- uses: actions/checkout@v2
... ... @@ -71,3 +71,36 @@ jobs:
run: |
python3 ./python-api-examples/online-websocket-client-decode-file.py \
./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/0.wav
- name: Start server for paraformer models
if: matrix.model_type == 'paraformer'
shell: bash
run: |
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-paraformer-bilingual-zh-en
cd sherpa-onnx-streaming-paraformer-bilingual-zh-en
git lfs pull --include "*.onnx"
cd ..
python3 ./python-api-examples/streaming_server.py \
--tokens ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/tokens.txt \
--paraformer-encoder ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/encoder.int8.onnx \
--paraformer-decoder ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/decoder.int8.onnx &
echo "sleep 10 seconds to wait the server start"
sleep 10
- name: Start client for paraformer models
if: matrix.model_type == 'paraformer'
shell: bash
run: |
python3 ./python-api-examples/online-websocket-client-decode-file.py \
./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/0.wav
python3 ./python-api-examples/online-websocket-client-decode-file.py \
./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/1.wav
python3 ./python-api-examples/online-websocket-client-decode-file.py \
./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/2.wav
python3 ./python-api-examples/online-websocket-client-decode-file.py \
./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/3.wav
... ...
... ... @@ -9,6 +9,7 @@ on:
paths:
- '.github/workflows/windows-x64-cuda.yaml'
- '.github/scripts/test-online-transducer.sh'
- '.github/scripts/test-online-paraformer.sh'
- '.github/scripts/test-offline-transducer.sh'
- '.github/scripts/test-offline-ctc.sh'
- 'CMakeLists.txt'
... ... @@ -20,6 +21,7 @@ on:
paths:
- '.github/workflows/windows-x64-cuda.yaml'
- '.github/scripts/test-online-transducer.sh'
- '.github/scripts/test-online-paraformer.sh'
- '.github/scripts/test-offline-transducer.sh'
- '.github/scripts/test-offline-ctc.sh'
- 'CMakeLists.txt'
... ... @@ -74,6 +76,14 @@ jobs:
ls -lh ./bin/Release/sherpa-onnx.exe
- name: Test online paraformer for windows x64
shell: bash
run: |
export PATH=$PWD/build/bin/Release:$PATH
export EXE=sherpa-onnx.exe
.github/scripts/test-online-paraformer.sh
- name: Test offline Whisper for windows x64
shell: bash
run: |
... ...
... ... @@ -9,6 +9,7 @@ on:
paths:
- '.github/workflows/windows-x64.yaml'
- '.github/scripts/test-online-transducer.sh'
- '.github/scripts/test-online-paraformer.sh'
- '.github/scripts/test-offline-transducer.sh'
- '.github/scripts/test-offline-ctc.sh'
- 'CMakeLists.txt'
... ... @@ -20,6 +21,7 @@ on:
paths:
- '.github/workflows/windows-x64.yaml'
- '.github/scripts/test-online-transducer.sh'
- '.github/scripts/test-online-paraformer.sh'
- '.github/scripts/test-offline-transducer.sh'
- '.github/scripts/test-offline-ctc.sh'
- 'CMakeLists.txt'
... ... @@ -75,6 +77,14 @@ jobs:
ls -lh ./bin/Release/sherpa-onnx.exe
- name: Test online paraformer for windows x64
shell: bash
run: |
export PATH=$PWD/build/bin/Release:$PATH
export EXE=sherpa-onnx.exe
.github/scripts/test-online-paraformer.sh
- name: Test offline Whisper for windows x64
shell: bash
run: |
... ...
... ... @@ -7,6 +7,7 @@ on:
paths:
- '.github/workflows/windows-x86.yaml'
- '.github/scripts/test-online-transducer.sh'
- '.github/scripts/test-online-paraformer.sh'
- '.github/scripts/test-offline-transducer.sh'
- '.github/scripts/test-offline-ctc.sh'
- 'CMakeLists.txt'
... ... @@ -18,6 +19,7 @@ on:
paths:
- '.github/workflows/windows-x86.yaml'
- '.github/scripts/test-online-transducer.sh'
- '.github/scripts/test-online-paraformer.sh'
- '.github/scripts/test-offline-transducer.sh'
- '.github/scripts/test-offline-ctc.sh'
- 'CMakeLists.txt'
... ... @@ -73,6 +75,14 @@ jobs:
ls -lh ./bin/Release/sherpa-onnx.exe
- name: Test online paraformer for windows x86
shell: bash
run: |
export PATH=$PWD/build/bin/Release:$PATH
export EXE=sherpa-onnx.exe
.github/scripts/test-online-paraformer.sh
- name: Test offline Whisper for windows x86
shell: bash
run: |
... ...
cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
project(sherpa-onnx)
set(SHERPA_ONNX_VERSION "1.7.3")
set(SHERPA_ONNX_VERSION "1.7.4")
# Disable warning about
#
... ...
... ... @@ -37,14 +37,14 @@ python3 ./python-api-examples/non_streaming_server.py \
(2) Use a non-streaming paraformer
cd /path/to/sherpa-onnx
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28
cd sherpa-onnx-paraformer-zh-2023-03-28
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-bilingual-zh-en
cd sherpa-onnx-paraformer-bilingual-zh-en/
git lfs pull --include "*.onnx"
cd ..
python3 ./python-api-examples/non_streaming_server.py \
--paraformer ./sherpa-onnx-paraformer-zh-2023-03-28/model.int8.onnx \
--tokens ./sherpa-onnx-paraformer-zh-2023-03-28/tokens.txt
--paraformer ./sherpa-onnx-paraformer-bilingual-zh-en/model.int8.onnx \
--tokens ./sherpa-onnx-paraformer-bilingual-zh-en/tokens.txt
(3) Use a non-streaming CTC model from NeMo
... ...
... ... @@ -5,16 +5,41 @@ This file demonstrates how to use sherpa-onnx Python API to transcribe
file(s) with a streaming model.
Usage:
./online-decode-files.py \
/path/to/foo.wav \
/path/to/bar.wav \
/path/to/16kHz.wav \
/path/to/8kHz.wav
(1) Streaming transducer
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26
cd sherpa-onnx-streaming-zipformer-en-2023-06-26
git lfs pull --include "*.onnx"
./python-api-examples/online-decode-files.py \
--tokens=./sherpa-onnx-streaming-zipformer-en-2023-06-26/tokens.txt \
--encoder=./sherpa-onnx-streaming-zipformer-en-2023-06-26/encoder-epoch-99-avg-1-chunk-16-left-64.onnx \
--decoder=./sherpa-onnx-streaming-zipformer-en-2023-06-26/decoder-epoch-99-avg-1-chunk-16-left-64.onnx \
--joiner=./sherpa-onnx-streaming-zipformer-en-2023-06-26/joiner-epoch-99-avg-1-chunk-16-left-64.onnx \
./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/0.wav \
./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/1.wav \
./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/8k.wav
(2) Streaming paraformer
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-paraformer-bilingual-zh-en
cd sherpa-onnx-streaming-paraformer-bilingual-zh-en
git lfs pull --include "*.onnx"
./python-api-examples/online-decode-files.py \
--tokens=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/tokens.txt \
--paraformer-encoder=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/encoder.int8.onnx \
--paraformer-decoder=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/decoder.int8.onnx \
./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/0.wav \
./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/1.wav \
./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/2.wav \
./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/3.wav \
./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/8k.wav
Please refer to
https://k2-fsa.github.io/sherpa/onnx/index.html
to install sherpa-onnx and to download the pre-trained models
used in this file.
to install sherpa-onnx and to download streaming pre-trained models.
"""
import argparse
import time
... ... @@ -41,19 +66,31 @@ def get_args():
parser.add_argument(
"--encoder",
type=str,
help="Path to the encoder model",
help="Path to the transducer encoder model",
)
parser.add_argument(
"--decoder",
type=str,
help="Path to the decoder model",
help="Path to the transducer decoder model",
)
parser.add_argument(
"--joiner",
type=str,
help="Path to the joiner model",
help="Path to the transducer joiner model",
)
parser.add_argument(
"--paraformer-encoder",
type=str,
help="Path to the paraformer encoder model",
)
parser.add_argument(
"--paraformer-decoder",
type=str,
help="Path to the paraformer decoder model",
)
parser.add_argument(
... ... @@ -200,10 +237,15 @@ def encode_contexts(args, contexts: List[str]) -> List[List[int]]:
def main():
args = get_args()
assert_file_exists(args.tokens)
if args.encoder:
assert_file_exists(args.encoder)
assert_file_exists(args.decoder)
assert_file_exists(args.joiner)
assert_file_exists(args.tokens)
assert not args.paraformer_encoder, args.paraformer_encoder
assert not args.paraformer_decoder, args.paraformer_decoder
recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
tokens=args.tokens,
... ... @@ -218,6 +260,19 @@ def main():
max_active_paths=args.max_active_paths,
context_score=args.context_score,
)
elif args.paraformer_encoder:
recognizer = sherpa_onnx.OnlineRecognizer.from_paraformer(
tokens=args.tokens,
encoder=args.paraformer_encoder,
decoder=args.paraformer_decoder,
num_threads=args.num_threads,
provider=args.provider,
sample_rate=16000,
feature_dim=80,
decoding_method="greedy_search",
)
else:
raise ValueError("Please provide a model")
print("Started!")
start_time = time.time()
... ... @@ -243,7 +298,7 @@ def main():
s.accept_waveform(sample_rate, samples)
tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32)
tail_paddings = np.zeros(int(0.66 * sample_rate), dtype=np.float32)
s.accept_waveform(sample_rate, tail_paddings)
s.input_finished()
... ...
... ... @@ -16,9 +16,9 @@ Example:
(1) Without a certificate
python3 ./python-api-examples/streaming_server.py \
--encoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \
--decoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \
--joiner-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \
--encoder ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \
--decoder ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \
--joiner ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \
--tokens ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt
(2) With a certificate
... ... @@ -32,9 +32,9 @@ python3 ./python-api-examples/streaming_server.py \
(b) Start the server
python3 ./python-api-examples/streaming_server.py \
--encoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \
--decoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \
--joiner-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \
--encoder ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \
--decoder ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \
--joiner ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \
--tokens ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt \
--certificate ./python-api-examples/web/cert.pem
... ... @@ -113,24 +113,33 @@ def setup_logger(
def add_model_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--encoder-model",
"--encoder",
type=str,
required=True,
help="Path to the encoder model",
help="Path to the transducer encoder model",
)
parser.add_argument(
"--decoder-model",
"--decoder",
type=str,
required=True,
help="Path to the decoder model.",
help="Path to the transducer decoder model.",
)
parser.add_argument(
"--joiner-model",
"--joiner",
type=str,
required=True,
help="Path to the joiner model.",
help="Path to the transducer joiner model.",
)
parser.add_argument(
"--paraformer-encoder",
type=str,
help="Path to the paraformer encoder model",
)
parser.add_argument(
"--paraformer-decoder",
type=str,
help="Path to the transducer decoder model.",
)
parser.add_argument(
... ... @@ -323,11 +332,12 @@ def get_args():
def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer:
if args.encoder:
recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
tokens=args.tokens,
encoder=args.encoder_model,
decoder=args.decoder_model,
joiner=args.joiner_model,
encoder=args.encoder,
decoder=args.decoder,
joiner=args.joiner,
num_threads=args.num_threads,
sample_rate=args.sample_rate,
feature_dim=args.feat_dim,
... ... @@ -339,6 +349,23 @@ def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer:
rule3_min_utterance_length=args.rule3_min_utterance_length,
provider=args.provider,
)
elif args.paraformer_encoder:
recognizer = sherpa_onnx.OnlineRecognizer.from_paraformer(
tokens=args.tokens,
encoder=args.paraformer_encoder,
decoder=args.paraformer_decoder,
num_threads=args.num_threads,
sample_rate=args.sample_rate,
feature_dim=args.feat_dim,
decoding_method=args.decoding_method,
enable_endpoint_detection=args.use_endpoint != 0,
rule1_min_trailing_silence=args.rule1_min_trailing_silence,
rule2_min_trailing_silence=args.rule2_min_trailing_silence,
rule3_min_utterance_length=args.rule3_min_utterance_length,
provider=args.provider,
)
else:
raise ValueError("Please provide a model")
return recognizer
... ... @@ -654,11 +681,25 @@ Go back to <a href="/streaming_record.html">/streaming_record.html</a>
def check_args(args):
assert Path(args.encoder_model).is_file(), f"{args.encoder_model} does not exist"
if args.encoder:
assert Path(args.encoder).is_file(), f"{args.encoder} does not exist"
assert Path(args.decoder_model).is_file(), f"{args.decoder_model} does not exist"
assert Path(args.decoder).is_file(), f"{args.decoder} does not exist"
assert Path(args.joiner_model).is_file(), f"{args.joiner_model} does not exist"
assert Path(args.joiner).is_file(), f"{args.joiner} does not exist"
assert args.paraformer_encoder is None, args.paraformer_encoder
assert args.paraformer_decoder is None, args.paraformer_decoder
elif args.paraformer_encoder:
assert Path(
args.paraformer_encoder
).is_file(), f"{args.paraformer_encoder} does not exist"
assert Path(
args.paraformer_decoder
).is_file(), f"{args.paraformer_decoder} does not exist"
else:
raise ValueError("Please provide a model")
if not Path(args.tokens).is_file():
raise ValueError(f"{args.tokens} does not exist")
... ...
... ... @@ -46,6 +46,8 @@ set(sources
online-lm.cc
online-lstm-transducer-model.cc
online-model-config.cc
online-paraformer-model-config.cc
online-paraformer-model.cc
online-recognizer-impl.cc
online-recognizer.cc
online-rnn-lm.cc
... ...
... ... @@ -39,7 +39,7 @@ std::string FeatureExtractorConfig::ToString() const {
class FeatureExtractor::Impl {
public:
explicit Impl(const FeatureExtractorConfig &config) {
explicit Impl(const FeatureExtractorConfig &config) : config_(config) {
opts_.frame_opts.dither = 0;
opts_.frame_opts.snip_edges = false;
opts_.frame_opts.samp_freq = config.sampling_rate;
... ... @@ -50,6 +50,19 @@ class FeatureExtractor::Impl {
}
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
if (config_.normalize_samples) {
AcceptWaveformImpl(sampling_rate, waveform, n);
} else {
std::vector<float> buf(n);
for (int32_t i = 0; i != n; ++i) {
buf[i] = waveform[i] * 32768;
}
AcceptWaveformImpl(sampling_rate, buf.data(), n);
}
}
void AcceptWaveformImpl(int32_t sampling_rate, const float *waveform,
int32_t n) {
std::lock_guard<std::mutex> lock(mutex_);
if (resampler_) {
... ... @@ -146,6 +159,7 @@ class FeatureExtractor::Impl {
private:
std::unique_ptr<knf::OnlineFbank> fbank_;
knf::FbankOptions opts_;
FeatureExtractorConfig config_;
mutable std::mutex mutex_;
std::unique_ptr<LinearResample> resampler_;
int32_t last_frame_index_ = 0;
... ...
... ... @@ -21,6 +21,13 @@ struct FeatureExtractorConfig {
// Feature dimension
int32_t feature_dim = 80;
// Set internally by some models, e.g., paraformer sets it to false.
// This parameter is not exposed to users on the command line.
// If true, the feature extractor expects input samples to be normalized
// to the range [-1, 1].
// If false, we multiply the input samples by 32768.
bool normalize_samples = true;
std::string ToString() const;
void Register(ParseOptions *po);
... ...
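The normalize_samples flag above encodes two input conventions: transducer-style models consume waveforms in [-1, 1], while paraformer models were trained on int16-range samples. A minimal Python sketch of the scaling applied by AcceptWaveform in features.cc (the function name here is hypothetical; only the 32768 factor mirrors the C++ code):

import numpy as np

def to_model_scale(samples: np.ndarray, normalize_samples: bool) -> np.ndarray:
    # samples: float32 waveform in [-1, 1], as produced by typical wav readers
    if normalize_samples:
        return samples            # e.g., transducer models: keep [-1, 1]
    return samples * 32768.0      # e.g., paraformer models: int16 range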
... ... @@ -12,6 +12,7 @@ namespace sherpa_onnx {
void OnlineModelConfig::Register(ParseOptions *po) {
transducer.Register(po);
paraformer.Register(po);
po->Register("tokens", &tokens, "Path to tokens.txt");
... ... @@ -41,6 +42,10 @@ bool OnlineModelConfig::Validate() const {
return false;
}
if (!paraformer.encoder.empty()) {
return paraformer.Validate();
}
return transducer.Validate();
}
... ... @@ -49,6 +54,7 @@ std::string OnlineModelConfig::ToString() const {
os << "OnlineModelConfig(";
os << "transducer=" << transducer.ToString() << ", ";
os << "paraformer=" << paraformer.ToString() << ", ";
os << "tokens=\"" << tokens << "\", ";
os << "num_threads=" << num_threads << ", ";
os << "debug=" << (debug ? "True" : "False") << ", ";
... ...
... ... @@ -6,12 +6,14 @@
#include <string>
#include "sherpa-onnx/csrc/online-paraformer-model-config.h"
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
namespace sherpa_onnx {
struct OnlineModelConfig {
OnlineTransducerModelConfig transducer;
OnlineParaformerModelConfig paraformer;
std::string tokens;
int32_t num_threads = 1;
bool debug = false;
... ... @@ -28,9 +30,11 @@ struct OnlineModelConfig {
OnlineModelConfig() = default;
OnlineModelConfig(const OnlineTransducerModelConfig &transducer,
const OnlineParaformerModelConfig &paraformer,
const std::string &tokens, int32_t num_threads, bool debug,
const std::string &provider, const std::string &model_type)
: transducer(transducer),
paraformer(paraformer),
tokens(tokens),
num_threads(num_threads),
debug(debug),
... ...
// sherpa-onnx/csrc/online-paraformer-decoder.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_DECODER_H_
#define SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_DECODER_H_
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
namespace sherpa_onnx {
struct OnlineParaformerDecoderResult {
/// The decoded token IDs
std::vector<int32_t> tokens;
int32_t last_non_blank_frame_index = 0;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_DECODER_H_
... ...
// sherpa-onnx/csrc/online-paraformer-model-config.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/online-paraformer-model-config.h"
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
void OnlineParaformerModelConfig::Register(ParseOptions *po) {
po->Register("paraformer-encoder", &encoder,
"Path to encoder.onnx of paraformer.");
po->Register("paraformer-decoder", &decoder,
"Path to decoder.onnx of paraformer.");
}
bool OnlineParaformerModelConfig::Validate() const {
if (!FileExists(encoder)) {
SHERPA_ONNX_LOGE("Paraformer encoder %s does not exist", encoder.c_str());
return false;
}
if (!FileExists(decoder)) {
SHERPA_ONNX_LOGE("Paraformer decoder %s does not exist", decoder.c_str());
return false;
}
return true;
}
std::string OnlineParaformerModelConfig::ToString() const {
std::ostringstream os;
os << "OnlineParaformerModelConfig(";
os << "encoder=\"" << encoder << "\", ";
os << "decoder=\"" << decoder << "\")";
return os.str();
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/online-paraformer-model-config.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_MODEL_CONFIG_H_
#include <string>
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
struct OnlineParaformerModelConfig {
std::string encoder;
std::string decoder;
OnlineParaformerModelConfig() = default;
OnlineParaformerModelConfig(const std::string &encoder,
const std::string &decoder)
: encoder(encoder), decoder(decoder) {}
void Register(ParseOptions *po);
bool Validate() const;
std::string ToString() const;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_MODEL_CONFIG_H_
... ...
// sherpa-onnx/csrc/online-paraformer-model.cc
//
// Copyright (c) 2022-2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/online-paraformer-model.h"
#include <algorithm>
#include <cmath>
#include <string>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
class OnlineParaformerModel::Impl {
public:
explicit Impl(const OnlineModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
{
auto buf = ReadFile(config.paraformer.encoder);
InitEncoder(buf.data(), buf.size());
}
{
auto buf = ReadFile(config.paraformer.decoder);
InitDecoder(buf.data(), buf.size());
}
}
#if __ANDROID_API__ >= 9
Impl(AAssetManager *mgr, const OnlineModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_WARNING),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
{
auto buf = ReadFile(mgr, config.paraformer.encoder);
InitEncoder(buf.data(), buf.size());
}
{
auto buf = ReadFile(mgr, config.paraformer.decoder);
InitDecoder(buf.data(), buf.size());
}
}
#endif
std::vector<Ort::Value> ForwardEncoder(Ort::Value features,
Ort::Value features_length) {
std::array<Ort::Value, 2> inputs = {std::move(features),
std::move(features_length)};
return encoder_sess_->Run(
{}, encoder_input_names_ptr_.data(), inputs.data(), inputs.size(),
encoder_output_names_ptr_.data(), encoder_output_names_ptr_.size());
}
std::vector<Ort::Value> ForwardDecoder(Ort::Value encoder_out,
Ort::Value encoder_out_length,
Ort::Value acoustic_embedding,
Ort::Value acoustic_embedding_length,
std::vector<Ort::Value> states) {
std::vector<Ort::Value> decoder_inputs;
decoder_inputs.reserve(4 + states.size());
decoder_inputs.push_back(std::move(encoder_out));
decoder_inputs.push_back(std::move(encoder_out_length));
decoder_inputs.push_back(std::move(acoustic_embedding));
decoder_inputs.push_back(std::move(acoustic_embedding_length));
for (auto &v : states) {
decoder_inputs.push_back(std::move(v));
}
return decoder_sess_->Run({}, decoder_input_names_ptr_.data(),
decoder_inputs.data(), decoder_inputs.size(),
decoder_output_names_ptr_.data(),
decoder_output_names_ptr_.size());
}
int32_t VocabSize() const { return vocab_size_; }
int32_t LfrWindowSize() const { return lfr_window_size_; }
int32_t LfrWindowShift() const { return lfr_window_shift_; }
int32_t EncoderOutputSize() const { return encoder_output_size_; }
int32_t DecoderKernelSize() const { return decoder_kernel_size_; }
int32_t DecoderNumBlocks() const { return decoder_num_blocks_; }
const std::vector<float> &NegativeMean() const { return neg_mean_; }
const std::vector<float> &InverseStdDev() const { return inv_stddev_; }
OrtAllocator *Allocator() const { return allocator_; }
private:
void InitEncoder(void *model_data, size_t model_data_length) {
encoder_sess_ = std::make_unique<Ort::Session>(
env_, model_data, model_data_length, sess_opts_);
GetInputNames(encoder_sess_.get(), &encoder_input_names_,
&encoder_input_names_ptr_);
GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
&encoder_output_names_ptr_);
// get meta data
Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
if (config_.debug) {
std::ostringstream os;
PrintModelMetadata(os, meta_data);
SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
}
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
SHERPA_ONNX_READ_META_DATA(lfr_window_size_, "lfr_window_size");
SHERPA_ONNX_READ_META_DATA(lfr_window_shift_, "lfr_window_shift");
SHERPA_ONNX_READ_META_DATA(encoder_output_size_, "encoder_output_size");
SHERPA_ONNX_READ_META_DATA(decoder_num_blocks_, "decoder_num_blocks");
SHERPA_ONNX_READ_META_DATA(decoder_kernel_size_, "decoder_kernel_size");
SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(neg_mean_, "neg_mean");
SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(inv_stddev_, "inv_stddev");
float scale = std::sqrt(encoder_output_size_);
for (auto &f : inv_stddev_) {
f *= scale;
}
}
void InitDecoder(void *model_data, size_t model_data_length) {
decoder_sess_ = std::make_unique<Ort::Session>(
env_, model_data, model_data_length, sess_opts_);
GetInputNames(decoder_sess_.get(), &decoder_input_names_,
&decoder_input_names_ptr_);
GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
&decoder_output_names_ptr_);
}
private:
OnlineModelConfig config_;
Ort::Env env_;
Ort::SessionOptions sess_opts_;
Ort::AllocatorWithDefaultOptions allocator_;
std::unique_ptr<Ort::Session> encoder_sess_;
std::vector<std::string> encoder_input_names_;
std::vector<const char *> encoder_input_names_ptr_;
std::vector<std::string> encoder_output_names_;
std::vector<const char *> encoder_output_names_ptr_;
std::unique_ptr<Ort::Session> decoder_sess_;
std::vector<std::string> decoder_input_names_;
std::vector<const char *> decoder_input_names_ptr_;
std::vector<std::string> decoder_output_names_;
std::vector<const char *> decoder_output_names_ptr_;
std::vector<float> neg_mean_;
std::vector<float> inv_stddev_;
int32_t vocab_size_ = 0; // initialized in Init
int32_t lfr_window_size_ = 0;
int32_t lfr_window_shift_ = 0;
int32_t encoder_output_size_ = 0;
int32_t decoder_num_blocks_ = 0;
int32_t decoder_kernel_size_ = 0;
};
OnlineParaformerModel::OnlineParaformerModel(const OnlineModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
#if __ANDROID_API__ >= 9
OnlineParaformerModel::OnlineParaformerModel(AAssetManager *mgr,
const OnlineModelConfig &config)
: impl_(std::make_unique<Impl>(mgr, config)) {}
#endif
OnlineParaformerModel::~OnlineParaformerModel() = default;
std::vector<Ort::Value> OnlineParaformerModel::ForwardEncoder(
Ort::Value features, Ort::Value features_length) const {
return impl_->ForwardEncoder(std::move(features), std::move(features_length));
}
std::vector<Ort::Value> OnlineParaformerModel::ForwardDecoder(
Ort::Value encoder_out, Ort::Value encoder_out_length,
Ort::Value acoustic_embedding, Ort::Value acoustic_embedding_length,
std::vector<Ort::Value> states) const {
return impl_->ForwardDecoder(
std::move(encoder_out), std::move(encoder_out_length),
std::move(acoustic_embedding), std::move(acoustic_embedding_length),
std::move(states));
}
int32_t OnlineParaformerModel::VocabSize() const { return impl_->VocabSize(); }
int32_t OnlineParaformerModel::LfrWindowSize() const {
return impl_->LfrWindowSize();
}
int32_t OnlineParaformerModel::LfrWindowShift() const {
return impl_->LfrWindowShift();
}
int32_t OnlineParaformerModel::EncoderOutputSize() const {
return impl_->EncoderOutputSize();
}
int32_t OnlineParaformerModel::DecoderKernelSize() const {
return impl_->DecoderKernelSize();
}
int32_t OnlineParaformerModel::DecoderNumBlocks() const {
return impl_->DecoderNumBlocks();
}
const std::vector<float> &OnlineParaformerModel::NegativeMean() const {
return impl_->NegativeMean();
}
const std::vector<float> &OnlineParaformerModel::InverseStdDev() const {
return impl_->InverseStdDev();
}
OrtAllocator *OnlineParaformerModel::Allocator() const {
return impl_->Allocator();
}
} // namespace sherpa_onnx
... ...
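The encoder metadata loaded in InitEncoder drives feature post-processing: neg_mean and inv_stddev implement CMVN, and inv_stddev is pre-scaled by sqrt(encoder_output_size) so the usual x *= d_model ** 0.5 step can be dropped (see the comment in the recognizer implementation below). A minimal NumPy sketch of that CMVN step, assuming LFR features of dimension 7 * 80 = 560:

import numpy as np

def apply_cmvn(lfr_feats: np.ndarray, neg_mean: np.ndarray,
               inv_stddev: np.ndarray) -> np.ndarray:
    # lfr_feats: (T, 560); neg_mean, inv_stddev: (560,) from model metadata.
    # inv_stddev already carries the sqrt(encoder_output_size) factor.
    return (lfr_feats + neg_mean) * inv_stddev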
// sherpa-onnx/csrc/online-paraformer-model.h
//
// Copyright (c) 2022-2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_MODEL_H_
#define SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_MODEL_H_
#include <memory>
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/online-model-config.h"
namespace sherpa_onnx {
class OnlineParaformerModel {
public:
explicit OnlineParaformerModel(const OnlineModelConfig &config);
#if __ANDROID_API__ >= 9
OnlineParaformerModel(AAssetManager *mgr, const OnlineModelConfig &config);
#endif
~OnlineParaformerModel();
std::vector<Ort::Value> ForwardEncoder(Ort::Value features,
Ort::Value features_length) const;
std::vector<Ort::Value> ForwardDecoder(Ort::Value encoder_out,
Ort::Value encoder_out_length,
Ort::Value acoustic_embedding,
Ort::Value acoustic_embedding_length,
std::vector<Ort::Value> states) const;
/** Return the vocabulary size of the model
*/
int32_t VocabSize() const;
/** It is lfr_m in config.yaml
*/
int32_t LfrWindowSize() const;
/** It is lfr_n in config.yaml
*/
int32_t LfrWindowShift() const;
int32_t EncoderOutputSize() const;
int32_t DecoderKernelSize() const;
int32_t DecoderNumBlocks() const;
/** Return negative mean for CMVN
*/
const std::vector<float> &NegativeMean() const;
/** Return inverse stddev for CMVN
*/
const std::vector<float> &InverseStdDev() const;
/** Return an allocator for allocating memory
*/
OrtAllocator *Allocator() const;
private:
class Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_MODEL_H_
... ...
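LfrWindowSize and LfrWindowShift (lfr_m and lfr_n in config.yaml) describe low frame rate feature stacking: lfr_m consecutive 80-dim fbank frames are concatenated into one 560-dim frame, advancing lfr_n frames per step. A minimal NumPy sketch equivalent to ApplyLFR in the recognizer implementation below:

import numpy as np

def apply_lfr(feats: np.ndarray, lfr_m: int = 7, lfr_n: int = 6) -> np.ndarray:
    # feats: (T, 80) -> ((T - lfr_m) // lfr_n + 1, 80 * lfr_m)
    num_out = (feats.shape[0] - lfr_m) // lfr_n + 1
    out = np.empty((num_out, feats.shape[1] * lfr_m), dtype=feats.dtype)
    for i in range(num_out):
        out[i] = feats[i * lfr_n : i * lfr_n + lfr_m].reshape(-1)
    return out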
... ... @@ -4,6 +4,7 @@
#include "sherpa-onnx/csrc/online-recognizer-impl.h"
#include "sherpa-onnx/csrc/online-recognizer-paraformer-impl.h"
#include "sherpa-onnx/csrc/online-recognizer-transducer-impl.h"
namespace sherpa_onnx {
... ... @@ -14,6 +15,10 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
return std::make_unique<OnlineRecognizerTransducerImpl>(config);
}
if (!config.model_config.paraformer.encoder.empty()) {
return std::make_unique<OnlineRecognizerParaformerImpl>(config);
}
SHERPA_ONNX_LOGE("Please specify a model");
exit(-1);
}
... ... @@ -25,6 +30,10 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
return std::make_unique<OnlineRecognizerTransducerImpl>(mgr, config);
}
if (!config.model_config.paraformer.encoder.empty()) {
return std::make_unique<OnlineRecognizerParaformerImpl>(mgr, config);
}
SHERPA_ONNX_LOGE("Please specify a model");
exit(-1);
}
... ...
... ... @@ -26,8 +26,6 @@ class OnlineRecognizerImpl {
virtual ~OnlineRecognizerImpl() = default;
virtual void InitOnlineStream(OnlineStream *stream) const = 0;
virtual std::unique_ptr<OnlineStream> CreateStream() const = 0;
virtual std::unique_ptr<OnlineStream> CreateStream(
... ...
// sherpa-onnx/csrc/online-recognizer-paraformer-impl.h
//
// Copyright (c) 2022-2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_PARAFORMER_IMPL_H_
#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_PARAFORMER_IMPL_H_
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/online-lm.h"
#include "sherpa-onnx/csrc/online-paraformer-decoder.h"
#include "sherpa-onnx/csrc/online-paraformer-model.h"
#include "sherpa-onnx/csrc/online-recognizer-impl.h"
#include "sherpa-onnx/csrc/online-recognizer.h"
#include "sherpa-onnx/csrc/symbol-table.h"
namespace sherpa_onnx {
static OnlineRecognizerResult Convert(const OnlineParaformerDecoderResult &src,
const SymbolTable &sym_table) {
OnlineRecognizerResult r;
r.tokens.reserve(src.tokens.size());
std::string text;
// When the current token ends with "@@" we set mergeable to true
bool mergeable = false;
for (int32_t i = 0; i != src.tokens.size(); ++i) {
auto sym = sym_table[src.tokens[i]];
r.tokens.push_back(sym);
if ((sym.back() != '@') || (sym.size() > 2 && sym[sym.size() - 2] != '@')) {
// sym does not end with "@@"
const uint8_t *p = reinterpret_cast<const uint8_t *>(sym.c_str());
if (p[0] < 0x80) {
// an ascii
if (mergeable) {
mergeable = false;
text.append(sym);
} else {
text.append(" ");
text.append(sym);
}
} else {
// not an ascii
mergeable = false;
if (i > 0) {
const uint8_t *p = reinterpret_cast<const uint8_t *>(
sym_table[src.tokens[i - 1]].c_str());
if (p[0] < 0x80) {
// put a space between ascii and non-ascii
text.append(" ");
}
}
text.append(sym);
}
} else {
// this sym ends with @@
sym = std::string(sym.data(), sym.size() - 2);
if (mergeable) {
text.append(sym);
} else {
text.append(" ");
text.append(sym);
mergeable = true;
}
}
}
r.text = std::move(text);
return r;
}
// y[i] += x[i] * scale
static void ScaleAddInPlace(const float *x, int32_t n, float scale, float *y) {
for (int32_t i = 0; i != n; ++i) {
y[i] += x[i] * scale;
}
}
// y[i] = x[i] * scale
static void Scale(const float *x, int32_t n, float scale, float *y) {
for (int32_t i = 0; i != n; ++i) {
y[i] = x[i] * scale;
}
}
class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl {
public:
explicit OnlineRecognizerParaformerImpl(const OnlineRecognizerConfig &config)
: config_(config),
model_(config.model_config),
sym_(config.model_config.tokens),
endpoint_(config_.endpoint_config) {
if (config.decoding_method != "greedy_search") {
SHERPA_ONNX_LOGE(
"Unsupported decoding method: %s. Support only greedy_search at "
"present",
config.decoding_method.c_str());
exit(-1);
}
// Paraformer models assume input samples are in the range
// [-32768, 32767], so we set normalize_samples to false
config_.feat_config.normalize_samples = false;
}
#if __ANDROID_API__ >= 9
explicit OnlineRecognizerParaformerImpl(AAssetManager *mgr,
const OnlineRecognizerConfig &config)
: config_(config),
model_(mgr, config.model_config),
sym_(mgr, config.model_config.tokens),
endpoint_(config_.endpoint_config) {
if (config.decoding_method == "greedy_search") {
// add greedy search decoder
// SHERPA_ONNX_LOGE("to be implemented");
// exit(-1);
} else {
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
config.decoding_method.c_str());
exit(-1);
}
// Paraformer models assume input samples are in the range
// [-32768, 32767], so we set normalize_samples to false
config_.feat_config.normalize_samples = false;
}
#endif
OnlineRecognizerParaformerImpl(const OnlineRecognizerParaformerImpl &) =
delete;
OnlineRecognizerParaformerImpl operator=(
const OnlineRecognizerParaformerImpl &) = delete;
std::unique_ptr<OnlineStream> CreateStream() const override {
auto stream = std::make_unique<OnlineStream>(config_.feat_config);
OnlineParaformerDecoderResult r;
stream->SetParaformerResult(r);
return stream;
}
bool IsReady(OnlineStream *s) const override {
return s->GetNumProcessedFrames() + chunk_size_ < s->NumFramesReady();
}
void DecodeStreams(OnlineStream **ss, int32_t n) const override {
// TODO(fangjun): Support batch size > 1
for (int32_t i = 0; i != n; ++i) {
DecodeStream(ss[i]);
}
}
OnlineRecognizerResult GetResult(OnlineStream *s) const override {
auto decoder_result = s->GetParaformerResult();
return Convert(decoder_result, sym_);
}
bool IsEndpoint(OnlineStream *s) const override {
if (!config_.enable_endpoint) {
return false;
}
const auto &result = s->GetParaformerResult();
int32_t num_processed_frames = s->GetNumProcessedFrames();
// frame shift is 10 milliseconds
float frame_shift_in_seconds = 0.01;
int32_t trailing_silence_frames =
num_processed_frames - result.last_non_blank_frame_index;
return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames,
frame_shift_in_seconds);
}
void Reset(OnlineStream *s) const override {
OnlineParaformerDecoderResult r;
s->SetParaformerResult(r);
// the internal model caches are not reset
// Note: We only update counters. The underlying audio samples
// are not discarded.
s->Reset();
}
private:
void DecodeStream(OnlineStream *s) const {
const auto num_processed_frames = s->GetNumProcessedFrames();
std::vector<float> frames = s->GetFrames(num_processed_frames, chunk_size_);
s->GetNumProcessedFrames() += chunk_size_ - 1;
frames = ApplyLFR(frames);
ApplyCMVN(&frames);
PositionalEncoding(&frames, num_processed_frames / model_.LfrWindowShift());
int32_t feat_dim = model_.NegativeMean().size();
// We have scaled inv_stddev by sqrt(encoder_output_size)
// so the following line can be commented out
// frames *= encoder_output_size ** 0.5
// add overlap chunk
std::vector<float> &feat_cache = s->GetParaformerFeatCache();
if (feat_cache.empty()) {
int32_t n = (left_chunk_size_ + right_chunk_size_) * feat_dim;
feat_cache.resize(n, 0);
}
frames.insert(frames.begin(), feat_cache.begin(), feat_cache.end());
std::copy(frames.end() - feat_cache.size(), frames.end(),
feat_cache.begin());
int32_t num_frames = frames.size() / feat_dim;
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
std::array<int64_t, 3> x_shape{1, num_frames, feat_dim};
Ort::Value x =
Ort::Value::CreateTensor(memory_info, frames.data(), frames.size(),
x_shape.data(), x_shape.size());
int64_t x_len_shape = 1;
int32_t x_len_val = num_frames;
Ort::Value x_length =
Ort::Value::CreateTensor(memory_info, &x_len_val, 1, &x_len_shape, 1);
auto encoder_out_vec =
model_.ForwardEncoder(std::move(x), std::move(x_length));
// CIF search
auto &encoder_out = encoder_out_vec[0];
auto &encoder_out_len = encoder_out_vec[1];
auto &alpha = encoder_out_vec[2];
float *p_alpha = alpha.GetTensorMutableData<float>();
std::vector<int64_t> alpha_shape =
alpha.GetTensorTypeAndShapeInfo().GetShape();
std::fill(p_alpha, p_alpha + left_chunk_size_, 0);
std::fill(p_alpha + alpha_shape[1] - right_chunk_size_,
p_alpha + alpha_shape[1], 0);
const float *p_encoder_out = encoder_out.GetTensorData<float>();
std::vector<int64_t> encoder_out_shape =
encoder_out.GetTensorTypeAndShapeInfo().GetShape();
std::vector<float> &initial_hidden = s->GetParaformerEncoderOutCache();
if (initial_hidden.empty()) {
initial_hidden.resize(encoder_out_shape[2]);
}
std::vector<float> &alpha_cache = s->GetParaformerAlphaCache();
if (alpha_cache.empty()) {
alpha_cache.resize(1);
}
std::vector<float> acoustic_embedding;
acoustic_embedding.reserve(encoder_out_shape[1] * encoder_out_shape[2]);
float threshold = 1.0;
float integrate = alpha_cache[0];
for (int32_t i = 0; i != encoder_out_shape[1]; ++i) {
float this_alpha = p_alpha[i];
if (integrate + this_alpha < threshold) {
integrate += this_alpha;
ScaleAddInPlace(p_encoder_out + i * encoder_out_shape[2],
encoder_out_shape[2], this_alpha,
initial_hidden.data());
continue;
}
// fire
ScaleAddInPlace(p_encoder_out + i * encoder_out_shape[2],
encoder_out_shape[2], threshold - integrate,
initial_hidden.data());
acoustic_embedding.insert(acoustic_embedding.end(),
initial_hidden.begin(), initial_hidden.end());
integrate += this_alpha - threshold;
Scale(p_encoder_out + i * encoder_out_shape[2], encoder_out_shape[2],
integrate, initial_hidden.data());
}
alpha_cache[0] = integrate;
if (acoustic_embedding.empty()) {
return;
}
auto &states = s->GetStates();
if (states.empty()) {
states.reserve(model_.DecoderNumBlocks());
std::array<int64_t, 3> shape{1, model_.EncoderOutputSize(),
model_.DecoderKernelSize() - 1};
int32_t num_bytes = sizeof(float) * shape[0] * shape[1] * shape[2];
for (int32_t i = 0; i != model_.DecoderNumBlocks(); ++i) {
Ort::Value this_state = Ort::Value::CreateTensor<float>(
model_.Allocator(), shape.data(), shape.size());
memset(this_state.GetTensorMutableData<float>(), 0, num_bytes);
states.push_back(std::move(this_state));
}
}
int32_t num_tokens = acoustic_embedding.size() / initial_hidden.size();
std::array<int64_t, 3> acoustic_embedding_shape{
1, num_tokens, static_cast<int32_t>(initial_hidden.size())};
Ort::Value acoustic_embedding_tensor = Ort::Value::CreateTensor(
memory_info, acoustic_embedding.data(), acoustic_embedding.size(),
acoustic_embedding_shape.data(), acoustic_embedding_shape.size());
std::array<int64_t, 1> acoustic_embedding_length_shape{1};
Ort::Value acoustic_embedding_length_tensor = Ort::Value::CreateTensor(
memory_info, &num_tokens, 1, acoustic_embedding_length_shape.data(),
acoustic_embedding_length_shape.size());
auto decoder_out_vec = model_.ForwardDecoder(
std::move(encoder_out), std::move(encoder_out_len),
std::move(acoustic_embedding_tensor),
std::move(acoustic_embedding_length_tensor), std::move(states));
states.reserve(model_.DecoderNumBlocks());
for (int32_t i = 2; i != decoder_out_vec.size(); ++i) {
// TODO(fangjun): When we change chunk_size_, we need to
// slice decoder_out_vec[i] accordingly.
states.push_back(std::move(decoder_out_vec[i]));
}
const auto &sample_ids = decoder_out_vec[1];
const int64_t *p_sample_ids = sample_ids.GetTensorData<int64_t>();
bool non_blank_detected = false;
auto &result = s->GetParaformerResult();
for (int32_t i = 0; i != num_tokens; ++i) {
int32_t t = p_sample_ids[i];
if (t == 0) {
continue;
}
non_blank_detected = true;
result.tokens.push_back(t);
}
if (non_blank_detected) {
result.last_non_blank_frame_index = num_processed_frames;
}
}
std::vector<float> ApplyLFR(const std::vector<float> &in) const {
int32_t lfr_window_size = model_.LfrWindowSize();
int32_t lfr_window_shift = model_.LfrWindowShift();
int32_t in_feat_dim = config_.feat_config.feature_dim;
int32_t in_num_frames = in.size() / in_feat_dim;
int32_t out_num_frames =
(in_num_frames - lfr_window_size) / lfr_window_shift + 1;
int32_t out_feat_dim = in_feat_dim * lfr_window_size;
std::vector<float> out(out_num_frames * out_feat_dim);
const float *p_in = in.data();
float *p_out = out.data();
for (int32_t i = 0; i != out_num_frames; ++i) {
std::copy(p_in, p_in + out_feat_dim, p_out);
p_out += out_feat_dim;
p_in += lfr_window_shift * in_feat_dim;
}
return out;
}
void ApplyCMVN(std::vector<float> *v) const {
const std::vector<float> &neg_mean = model_.NegativeMean();
const std::vector<float> &inv_stddev = model_.InverseStdDev();
int32_t dim = neg_mean.size();
int32_t num_frames = v->size() / dim;
float *p = v->data();
for (int32_t i = 0; i != num_frames; ++i) {
for (int32_t k = 0; k != dim; ++k) {
p[k] = (p[k] + neg_mean[k]) * inv_stddev[k];
}
p += dim;
}
}
void PositionalEncoding(std::vector<float> *v, int32_t t_offset) const {
int32_t lfr_window_size = model_.LfrWindowSize();
int32_t in_feat_dim = config_.feat_config.feature_dim;
int32_t feat_dim = in_feat_dim * lfr_window_size;
int32_t T = v->size() / feat_dim;
// log(10000)/(7*80/2-1) == 0.03301197265941284
// 7 is lfr_window_size
// 80 is in_feat_dim
// 7*80 is feat_dim
constexpr float kScale = -0.03301197265941284;
for (int32_t t = 0; t != T; ++t) {
float *p = v->data() + t * feat_dim;
int32_t offset = t + 1 + t_offset;
for (int32_t d = 0; d < feat_dim / 2; ++d) {
float inv_timescale = offset * std::exp(d * kScale);
float sin_d = std::sin(inv_timescale);
float cos_d = std::cos(inv_timescale);
p[d] += sin_d;
p[d + feat_dim / 2] += cos_d;
}
}
}
private:
OnlineRecognizerConfig config_;
OnlineParaformerModel model_;
SymbolTable sym_;
Endpoint endpoint_;
// 0.61 seconds
int32_t chunk_size_ = 61;
// (61 - 7) / 6 + 1 = 10
int32_t left_chunk_size_ = 5;
int32_t right_chunk_size_ = 5;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_PARAFORMER_IMPL_H_
... ...
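The "CIF search" in DecodeStream above is continuous integrate-and-fire: the encoder emits a weight alpha per frame, the weights are integrated until they cross a threshold of 1.0, and at each crossing a weighted sum of encoder frames is emitted as one acoustic embedding (one token). The leftover weight and the partially integrated hidden vector are carried across chunks, which is what GetParaformerAlphaCache and GetParaformerEncoderOutCache store. A minimal NumPy sketch of that loop (names hypothetical):

import numpy as np

def cif_chunk(encoder_out, alpha, hidden, integrate, threshold=1.0):
    # encoder_out: (T, D); alpha: (T,); hidden: (D,) and integrate: float
    # are the caches carried over from the previous chunk.
    fired = []
    for h, a in zip(encoder_out, alpha):
        if integrate + a < threshold:
            integrate += a
            hidden += a * h                        # keep integrating
        else:
            hidden += (threshold - integrate) * h  # fill up to the threshold
            fired.append(hidden.copy())            # fire one embedding
            integrate += a - threshold             # leftover weight
            hidden = integrate * h                 # seed the next embedding
    if fired:
        return np.stack(fired), hidden, integrate
    return np.empty((0, encoder_out.shape[1])), hidden, integrate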
... ... @@ -94,21 +94,6 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
}
#endif
void InitOnlineStream(OnlineStream *stream) const override {
auto r = decoder_->GetEmptyResult();
if (config_.decoding_method == "modified_beam_search" &&
nullptr != stream->GetContextGraph()) {
// r.hyps has only one element.
for (auto it = r.hyps.begin(); it != r.hyps.end(); ++it) {
it->second.context_state = stream->GetContextGraph()->Root();
}
}
stream->SetResult(r);
stream->SetStates(model_->GetEncoderInitStates());
}
std::unique_ptr<OnlineStream> CreateStream() const override {
auto stream = std::make_unique<OnlineStream>(config_.feat_config);
InitOnlineStream(stream.get());
... ... @@ -211,7 +196,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
}
bool IsEndpoint(OnlineStream *s) const override {
if (!config_.enable_endpoint) return false;
if (!config_.enable_endpoint) {
return false;
}
int32_t num_processed_frames = s->GetNumProcessedFrames();
// frame shift is 10 milliseconds
... ... @@ -245,6 +233,22 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
}
private:
void InitOnlineStream(OnlineStream *stream) const {
auto r = decoder_->GetEmptyResult();
if (config_.decoding_method == "modified_beam_search" &&
nullptr != stream->GetContextGraph()) {
// r.hyps has only one element.
for (auto it = r.hyps.begin(); it != r.hyps.end(); ++it) {
it->second.context_state = stream->GetContextGraph()->Root();
}
}
stream->SetResult(r);
stream->SetStates(model_->GetEncoderInitStates());
}
private:
OnlineRecognizerConfig config_;
std::unique_ptr<OnlineTransducerModel> model_;
std::unique_ptr<OnlineLM> lm_;
... ...
... ... @@ -47,6 +47,14 @@ class OnlineStream::Impl {
OnlineTransducerDecoderResult &GetResult() { return result_; }
void SetParaformerResult(const OnlineParaformerDecoderResult &r) {
paraformer_result_ = r;
}
OnlineParaformerDecoderResult &GetParaformerResult() {
return paraformer_result_;
}
int32_t FeatureDim() const { return feat_extractor_.FeatureDim(); }
void SetStates(std::vector<Ort::Value> states) {
... ... @@ -57,6 +65,18 @@ class OnlineStream::Impl {
const ContextGraphPtr &GetContextGraph() const { return context_graph_; }
std::vector<float> &GetParaformerFeatCache() {
return paraformer_feat_cache_;
}
std::vector<float> &GetParaformerEncoderOutCache() {
return paraformer_encoder_out_cache_;
}
std::vector<float> &GetParaformerAlphaCache() {
return paraformer_alpha_cache_;
}
private:
FeatureExtractor feat_extractor_;
/// For contextual-biasing
... ... @@ -65,6 +85,10 @@ class OnlineStream::Impl {
int32_t start_frame_index_ = 0; // never reset
OnlineTransducerDecoderResult result_;
std::vector<Ort::Value> states_;
std::vector<float> paraformer_feat_cache_;
std::vector<float> paraformer_encoder_out_cache_;
std::vector<float> paraformer_alpha_cache_;
OnlineParaformerDecoderResult paraformer_result_;
};
OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/,
... ... @@ -107,6 +131,14 @@ OnlineTransducerDecoderResult &OnlineStream::GetResult() {
return impl_->GetResult();
}
void OnlineStream::SetParaformerResult(const OnlineParaformerDecoderResult &r) {
impl_->SetParaformerResult(r);
}
OnlineParaformerDecoderResult &OnlineStream::GetParaformerResult() {
return impl_->GetParaformerResult();
}
void OnlineStream::SetStates(std::vector<Ort::Value> states) {
impl_->SetStates(std::move(states));
}
... ... @@ -119,4 +151,16 @@ const ContextGraphPtr &OnlineStream::GetContextGraph() const {
return impl_->GetContextGraph();
}
std::vector<float> &OnlineStream::GetParaformerFeatCache() {
return impl_->GetParaformerFeatCache();
}
std::vector<float> &OnlineStream::GetParaformerEncoderOutCache() {
return impl_->GetParaformerEncoderOutCache();
}
std::vector<float> &OnlineStream::GetParaformerAlphaCache() {
return impl_->GetParaformerAlphaCache();
}
} // namespace sherpa_onnx
... ...
... ... @@ -11,6 +11,7 @@
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/context-graph.h"
#include "sherpa-onnx/csrc/features.h"
#include "sherpa-onnx/csrc/online-paraformer-decoder.h"
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
namespace sherpa_onnx {
... ... @@ -70,6 +71,9 @@ class OnlineStream {
void SetResult(const OnlineTransducerDecoderResult &r);
OnlineTransducerDecoderResult &GetResult();
void SetParaformerResult(const OnlineParaformerDecoderResult &r);
OnlineParaformerDecoderResult &GetParaformerResult();
void SetStates(std::vector<Ort::Value> states);
std::vector<Ort::Value> &GetStates();
... ... @@ -80,6 +84,11 @@ class OnlineStream {
*/
const ContextGraphPtr &GetContextGraph() const;
// for streaming paraformer
std::vector<float> &GetParaformerFeatCache();
std::vector<float> &GetParaformerEncoderOutCache();
std::vector<float> &GetParaformerAlphaCache();
private:
class Impl;
std::unique_ptr<Impl> impl_;
... ...
... ... @@ -12,8 +12,8 @@
#include "sherpa-onnx/csrc/online-recognizer.h"
#include "sherpa-onnx/csrc/online-stream.h"
#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/parse-options.h"
#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/wave-reader.h"
typedef struct {
... ... @@ -92,14 +92,14 @@ for a list of pre-trained models to download.
auto s = recognizer.CreateStream();
s->AcceptWaveform(sampling_rate, samples.data(), samples.size());
std::vector<float> tail_paddings(static_cast<int>(0.3 * sampling_rate));
std::vector<float> tail_paddings(static_cast<int>(0.8 * sampling_rate));
// Note: We can call AcceptWaveform() multiple times.
s->AcceptWaveform(
sampling_rate, tail_paddings.data(), tail_paddings.size());
s->AcceptWaveform(sampling_rate, tail_paddings.data(),
tail_paddings.size());
// Call InputFinished() to indicate that no audio samples are available
s->InputFinished();
ss.push_back({ std::move(s), duration, 0 });
ss.push_back({std::move(s), duration, 0});
}
std::vector<sherpa_onnx::OnlineStream *> ready_streams;
... ... @@ -113,7 +113,8 @@ for a list of pre-trained models to download.
const auto end = std::chrono::steady_clock::now();
const float elapsed_seconds =
std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
.count() / 1000.;
.count() /
1000.;
s.elapsed_seconds = elapsed_seconds;
}
}
... ...
... ... @@ -15,6 +15,7 @@ pybind11_add_module(_sherpa_onnx
offline-whisper-model-config.cc
online-lm-config.cc
online-model-config.cc
online-paraformer-model-config.cc
online-recognizer.cc
online-stream.cc
online-transducer-model-config.cc
... ...
// sherpa-onnx/python/csrc/online-model-config.cc
//
// Copyright (c) 2023 by manyeyes
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/python/csrc/online-model-config.h"
... ... @@ -9,21 +9,26 @@
#include "sherpa-onnx/csrc/online-model-config.h"
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
#include "sherpa-onnx/python/csrc/online-paraformer-model-config.h"
#include "sherpa-onnx/python/csrc/online-transducer-model-config.h"
namespace sherpa_onnx {
void PybindOnlineModelConfig(py::module *m) {
PybindOnlineTransducerModelConfig(m);
PybindOnlineParaformerModelConfig(m);
using PyClass = OnlineModelConfig;
py::class_<PyClass>(*m, "OnlineModelConfig")
.def(py::init<const OnlineTransducerModelConfig &, std::string &, int32_t,
.def(py::init<const OnlineTransducerModelConfig &,
const OnlineParaformerModelConfig &, std::string &, int32_t,
bool, const std::string &, const std::string &>(),
py::arg("transducer") = OnlineTransducerModelConfig(),
py::arg("paraformer") = OnlineParaformerModelConfig(),
py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,
py::arg("provider") = "cpu", py::arg("model_type") = "")
.def_readwrite("transducer", &PyClass::transducer)
.def_readwrite("paraformer", &PyClass::paraformer)
.def_readwrite("tokens", &PyClass::tokens)
.def_readwrite("num_threads", &PyClass::num_threads)
.def_readwrite("debug", &PyClass::debug)
... ...
// sherpa-onnx/python/csrc/online-model-config.h
//
// Copyright (c) 2023 by manyeyes
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_MODEL_CONFIG_H_
#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_MODEL_CONFIG_H_
... ...
// sherpa-onnx/python/csrc/online-paraformer-model-config.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/python/csrc/online-paraformer-model-config.h"
#include <string>
#include <vector>
#include "sherpa-onnx/csrc/online-paraformer-model-config.h"
namespace sherpa_onnx {
void PybindOnlineParaformerModelConfig(py::module *m) {
using PyClass = OnlineParaformerModelConfig;
py::class_<PyClass>(*m, "OnlineParaformerModelConfig")
.def(py::init<const std::string &, const std::string &>(),
py::arg("encoder"), py::arg("decoder"))
.def_readwrite("encoder", &PyClass::encoder)
.def_readwrite("decoder", &PyClass::decoder)
.def("__str__", &PyClass::ToString);
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/python/csrc/online-paraformer-model-config.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_PARAFORMER_MODEL_CONFIG_H_
#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_PARAFORMER_MODEL_CONFIG_H_
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
namespace sherpa_onnx {
void PybindOnlineParaformerModelConfig(py::module *m);
}
#endif // SHERPA_ONNX_PYTHON_CSRC_ONLINE_PARAFORMER_MODEL_CONFIG_H_
... ...
... ... @@ -33,7 +33,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
py::arg("feat_config"), py::arg("model_config"),
py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config"),
py::arg("enable_endpoint"), py::arg("decoding_method"),
py::arg("max_active_paths"), py::arg("context_score"))
py::arg("max_active_paths") = 4, py::arg("context_score") = 0)
.def_readwrite("feat_config", &PyClass::feat_config)
.def_readwrite("model_config", &PyClass::model_config)
.def_readwrite("endpoint_config", &PyClass::endpoint_config)
... ...
... ... @@ -6,6 +6,7 @@ from _sherpa_onnx import (
EndpointConfig,
FeatureExtractorConfig,
OnlineModelConfig,
OnlineParaformerModelConfig,
OnlineRecognizer as _Recognizer,
OnlineRecognizerConfig,
OnlineStream,
... ... @@ -32,7 +33,7 @@ class OnlineRecognizer(object):
encoder: str,
decoder: str,
joiner: str,
num_threads: int = 4,
num_threads: int = 2,
sample_rate: float = 16000,
feature_dim: int = 80,
enable_endpoint_detection: bool = False,
... ... @@ -144,6 +145,109 @@ class OnlineRecognizer(object):
self.config = recognizer_config
return self
@classmethod
def from_paraformer(
cls,
tokens: str,
encoder: str,
decoder: str,
num_threads: int = 2,
sample_rate: float = 16000,
feature_dim: int = 80,
enable_endpoint_detection: bool = False,
rule1_min_trailing_silence: float = 2.4,
rule2_min_trailing_silence: float = 1.2,
rule3_min_utterance_length: float = 20.0,
decoding_method: str = "greedy_search",
provider: str = "cpu",
):
"""
Please refer to
`<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html>`_
to download pre-trained models for different languages, e.g., Chinese,
English, etc.
Args:
tokens:
Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
columns::
symbol integer_id
encoder:
Path to ``encoder.onnx``.
decoder:
Path to ``decoder.onnx``.
num_threads:
Number of threads for neural network computation.
sample_rate:
Sample rate of the training data used to train the model.
feature_dim:
Dimension of the feature used to train the model.
enable_endpoint_detection:
True to enable endpoint detection. False to disable endpoint
detection.
rule1_min_trailing_silence:
Used only when enable_endpoint_detection is True. If the duration
of trailing silence in seconds is larger than this value, we assume
an endpoint is detected.
rule2_min_trailing_silence:
Used only when enable_endpoint_detection is True. If we have decoded
something that is nonsilence and if the duration of trailing silence
in seconds is larger than this value, we assume an endpoint is
detected.
rule3_min_utterance_length:
Used only when enable_endpoint_detection is True. If the utterance
length in seconds is larger than this value, we assume an endpoint
is detected.
decoding_method:
The only valid value is greedy_search.
provider:
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
"""
self = cls.__new__(cls)
_assert_file_exists(tokens)
_assert_file_exists(encoder)
_assert_file_exists(decoder)
assert num_threads > 0, num_threads
paraformer_config = OnlineParaformerModelConfig(
encoder=encoder,
decoder=decoder,
)
model_config = OnlineModelConfig(
paraformer=paraformer_config,
tokens=tokens,
num_threads=num_threads,
provider=provider,
model_type="paraformer",
)
feat_config = FeatureExtractorConfig(
sampling_rate=sample_rate,
feature_dim=feature_dim,
)
endpoint_config = EndpointConfig(
rule1_min_trailing_silence=rule1_min_trailing_silence,
rule2_min_trailing_silence=rule2_min_trailing_silence,
rule3_min_utterance_length=rule3_min_utterance_length,
)
recognizer_config = OnlineRecognizerConfig(
feat_config=feat_config,
model_config=model_config,
endpoint_config=endpoint_config,
enable_endpoint=enable_endpoint_detection,
decoding_method=decoding_method,
)
self.recognizer = _Recognizer(recognizer_config)
self.config = recognizer_config
return self
def create_stream(self, contexts_list: Optional[List[List[int]]] = None):
if contexts_list is None:
return self.recognizer.create_stream()
... ...
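For reference, a hedged end-to-end usage sketch of the new from_paraformer API, assuming the model repo from the test script above has been downloaded into the current directory (the 0.66-second tail padding mirrors online-decode-files.py):

import wave
import numpy as np
import sherpa_onnx

repo = "./sherpa-onnx-streaming-paraformer-bilingual-zh-en"
recognizer = sherpa_onnx.OnlineRecognizer.from_paraformer(
    tokens=f"{repo}/tokens.txt",
    encoder=f"{repo}/encoder.int8.onnx",
    decoder=f"{repo}/decoder.int8.onnx",
)

with wave.open(f"{repo}/test_wavs/0.wav") as f:
    sample_rate = f.getframerate()
    samples = np.frombuffer(f.readframes(f.getnframes()), dtype=np.int16)
    samples = samples.astype(np.float32) / 32768  # to [-1, 1]

s = recognizer.create_stream()
s.accept_waveform(sample_rate, samples)
s.accept_waveform(sample_rate, np.zeros(int(0.66 * sample_rate), dtype=np.float32))
s.input_finished()

while recognizer.is_ready(s):
    recognizer.decode_stream(s)
print(recognizer.get_result(s))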