Fangjun Kuang
Committed by GitHub

Add online LSTM transducer model (#25)

#!/usr/bin/env bash
set -e
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
echo "EXE is $EXE"
echo "PATH: $PATH"
which $EXE
log "------------------------------------------------------------"
log "Run LSTM transducer (English)"
log "------------------------------------------------------------"
repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-lstm-en-2023-02-17
log "Start testing ${repo_url}"
repo=$(basename $repo_url)
log "Download pretrained model and test-data from $repo_url"
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
pushd $repo
git lfs pull --include "*.onnx"
popd
waves=(
$repo/test_wavs/1089-134686-0001.wav
$repo/test_wavs/1221-135766-0001.wav
$repo/test_wavs/1221-135766-0002.wav
)
for wave in ${waves[@]}; do
time $EXE \
$repo/tokens.txt \
$repo/encoder-epoch-99-avg-1.onnx \
$repo/decoder-epoch-99-avg-1.onnx \
$repo/joiner-epoch-99-avg-1.onnx \
$wave \
4
done
rm -rf $repo
... ...
name: linux
on:
push:
branches:
- master
paths:
- '.github/workflows/linux.yaml'
- '.github/scripts/test-online-transducer.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
pull_request:
branches:
- master
paths:
- '.github/workflows/linux.yaml'
- '.github/scripts/test-online-transducer.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
concurrency:
group: linux-${{ github.ref }}
cancel-in-progress: true
permissions:
contents: read
jobs:
linux:
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Configure CMake
shell: bash
run: |
mkdir build
cd build
cmake -D CMAKE_BUILD_TYPE=Release ..
- name: Build sherpa-onnx for ubuntu
shell: bash
run: |
cd build
make -j2
ls -lh lib
ls -lh bin
- name: Display dependencies of sherpa-onnx for linux
shell: bash
run: |
file build/bin/sherpa-onnx
readelf -d build/bin/sherpa-onnx
- name: Test online transducer
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx
.github/scripts/test-online-transducer.sh
... ...
name: macos
on:
push:
branches:
- master
paths:
- '.github/workflows/macos.yaml'
- '.github/scripts/test-online-transducer.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
pull_request:
branches:
- master
paths:
- '.github/workflows/macos.yaml'
- '.github/scripts/test-online-transducer.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
concurrency:
group: macos-${{ github.ref }}
cancel-in-progress: true
permissions:
contents: read
jobs:
macos:
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [macos-latest]
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Configure CMake
shell: bash
run: |
mkdir build
cd build
cmake -D CMAKE_BUILD_TYPE=Release ..
- name: Build sherpa for macos
shell: bash
run: |
cd build
make -j2
ls -lh lib
ls -lh bin
- name: Display dependencies of sherpa-onnx for macos
shell: bash
run: |
file bin/sherpa-onnx
otool -L build/bin/sherpa-onnx
otool -l build/bin/sherpa-onnx
- name: Test online transducer
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx
.github/scripts/test-online-transducer.sh
... ...
name: test-linux-macos-windows
on:
push:
branches:
- master
paths:
- '.github/workflows/test-linux-macos-windows.yaml'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
pull_request:
branches:
- master
paths:
- '.github/workflows/test-linux-macos-windows.yaml'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
concurrency:
group: test-linux-macos-windows-${{ github.ref }}
cancel-in-progress: true
permissions:
contents: read
jobs:
test-linux-macos-windows:
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
# see https://github.com/microsoft/setup-msbuild
- name: Add msbuild to PATH
if: startsWith(matrix.os, 'windows')
uses: microsoft/setup-msbuild@v1.0.2
- name: Download pretrained model and test-data (English)
shell: bash
run: |
git lfs install
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
cd icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
ls -lh exp/onnx/*.onnx
git lfs pull --include "exp/onnx/*.onnx"
ls -lh exp/onnx/*.onnx
- name: Download pretrained model and test-data (Chinese)
shell: bash
run: |
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2
cd icefall_asr_wenetspeech_pruned_transducer_stateless2
ls -lh exp/*.onnx
git lfs pull --include "exp/*.onnx"
ls -lh exp/*.onnx
- name: Configure CMake
shell: bash
run: |
mkdir build
cd build
cmake -D CMAKE_BUILD_TYPE=Release ..
- name: Build sherpa-onnx for ubuntu/macos
if: startsWith(matrix.os, 'ubuntu') || startsWith(matrix.os, 'macos')
shell: bash
run: |
cd build
make VERBOSE=1 -j3
- name: Build sherpa-onnx for Windows
if: startsWith(matrix.os, 'windows')
shell: bash
run: |
cmake --build ./build --config Release
- name: Run tests for ubuntu/macos (English)
if: startsWith(matrix.os, 'ubuntu') || startsWith(matrix.os, 'macos')
shell: bash
run: |
time ./build/bin/sherpa-onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav
time ./build/bin/sherpa-onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0001.wav
time ./build/bin/sherpa-onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0002.wav
- name: Run tests for Windows (English)
if: startsWith(matrix.os, 'windows')
shell: bash
run: |
./build/bin/Release/sherpa-onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav
./build/bin/Release/sherpa-onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0001.wav
./build/bin/Release/sherpa-onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0002.wav
- name: Run tests for ubuntu/macos (Chinese)
if: startsWith(matrix.os, 'ubuntu') || startsWith(matrix.os, 'macos')
shell: bash
run: |
time ./build/bin/sherpa-onnx \
./icefall_asr_wenetspeech_pruned_transducer_stateless2/data/lang_char/tokens.txt \
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/encoder-epoch-10-avg-2.onnx \
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/decoder-epoch-10-avg-2.onnx \
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner-epoch-10-avg-2.onnx \
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_encoder_proj-epoch-10-avg-2.onnx \
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_decoder_proj-epoch-10-avg-2.onnx \
./icefall_asr_wenetspeech_pruned_transducer_stateless2/test_wavs/DEV_T0000000000.wav
time ./build/bin/sherpa-onnx \
./icefall_asr_wenetspeech_pruned_transducer_stateless2/data/lang_char/tokens.txt \
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/encoder-epoch-10-avg-2.onnx \
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/decoder-epoch-10-avg-2.onnx \
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner-epoch-10-avg-2.onnx \
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_encoder_proj-epoch-10-avg-2.onnx \
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_decoder_proj-epoch-10-avg-2.onnx \
./icefall_asr_wenetspeech_pruned_transducer_stateless2/test_wavs/DEV_T0000000001.wav
time ./build/bin/sherpa-onnx \
./icefall_asr_wenetspeech_pruned_transducer_stateless2/data/lang_char/tokens.txt \
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/encoder-epoch-10-avg-2.onnx \
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/decoder-epoch-10-avg-2.onnx \
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner-epoch-10-avg-2.onnx \
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_encoder_proj-epoch-10-avg-2.onnx \
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_decoder_proj-epoch-10-avg-2.onnx \
./icefall_asr_wenetspeech_pruned_transducer_stateless2/test_wavs/DEV_T0000000002.wav
- name: Run tests for windows (Chinese)
if: startsWith(matrix.os, 'windows')
shell: bash
run: |
./build/bin/Release/sherpa-onnx \
./icefall_asr_wenetspeech_pruned_transducer_stateless2/data/lang_char/tokens.txt \
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/encoder-epoch-10-avg-2.onnx \
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/decoder-epoch-10-avg-2.onnx \
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner-epoch-10-avg-2.onnx \
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_encoder_proj-epoch-10-avg-2.onnx \
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_decoder_proj-epoch-10-avg-2.onnx \
./icefall_asr_wenetspeech_pruned_transducer_stateless2/test_wavs/DEV_T0000000000.wav
./build/bin/Release/sherpa-onnx \
./icefall_asr_wenetspeech_pruned_transducer_stateless2/data/lang_char/tokens.txt \
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/encoder-epoch-10-avg-2.onnx \
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/decoder-epoch-10-avg-2.onnx \
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner-epoch-10-avg-2.onnx \
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_encoder_proj-epoch-10-avg-2.onnx \
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_decoder_proj-epoch-10-avg-2.onnx \
./icefall_asr_wenetspeech_pruned_transducer_stateless2/test_wavs/DEV_T0000000001.wav
./build/bin/Release/sherpa-onnx \
./icefall_asr_wenetspeech_pruned_transducer_stateless2/data/lang_char/tokens.txt \
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/encoder-epoch-10-avg-2.onnx \
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/decoder-epoch-10-avg-2.onnx \
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner-epoch-10-avg-2.onnx \
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_encoder_proj-epoch-10-avg-2.onnx \
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_decoder_proj-epoch-10-avg-2.onnx \
./icefall_asr_wenetspeech_pruned_transducer_stateless2/test_wavs/DEV_T0000000002.wav
name: windows-x64
on:
push:
branches:
- master
paths:
- '.github/workflows/windows-x64.yaml'
- '.github/scripts/test-online-transducer.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
pull_request:
branches:
- master
paths:
- '.github/workflows/windows-x64.yaml'
- '.github/scripts/test-online-transducer.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
concurrency:
group: windows-x64-${{ github.ref }}
cancel-in-progress: true
permissions:
contents: read
jobs:
windows_x64:
runs-on: ${{ matrix.os }}
name: ${{ matrix.vs-version }}
strategy:
fail-fast: false
matrix:
include:
- vs-version: vs2015
toolset-version: v140
os: windows-2019
- vs-version: vs2017
toolset-version: v141
os: windows-2019
- vs-version: vs2019
toolset-version: v142
os: windows-2022
- vs-version: vs2022
toolset-version: v143
os: windows-2022
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Configure CMake
shell: bash
run: |
mkdir build
cd build
cmake -T ${{ matrix.toolset-version}},host=x64 -A x64 -D CMAKE_BUILD_TYPE=Release ..
- name: Build sherpa-onnx for windows
shell: bash
run: |
cd build
cmake --build . --config Release -- -m:2
ls -lh ./bin/Release/sherpa-onnx.exe
- name: Test sherpa-onnx for Windows x64
shell: bash
run: |
export PATH=$PWD/build/bin/Release:$PATH
export EXE=sherpa-onnx.exe
.github/scripts/test-online-transducer.sh
... ...
... ... @@ -4,3 +4,4 @@ build
onnxruntime-*
icefall-*
run.sh
sherpa-onnx-*
... ...
... ... @@ -2,89 +2,7 @@
Documentation: <https://k2-fsa.github.io/sherpa/onnx/index.html>
Try it in colab:
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1tmQbdlYeTl_klmtaGiUb7a7ZPz-AkBSH?usp=sharing)
See <https://github.com/k2-fsa/sherpa>
This repo uses [onnxruntime](https://github.com/microsoft/onnxruntime) and
does not depend on libtorch.
We provide exported models in onnx format and they can be downloaded using
the following links:
- English: <https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13>
- Chinese: <https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2>
**NOTE**: We provide only non-streaming models at present.
**HINT**: The script for exporting the English model can be found at
<https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless3/export.py>
**HINT**: The script for exporting the Chinese model can be found at
<https://github.com/k2-fsa/icefall/blob/master/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py>
## Build for Linux/macOS
```bash
git clone https://github.com/k2-fsa/sherpa-onnx
cd sherpa-onnx
mkdir build
cd build
cmake -DCMAKE_BUILD_TYPE=Release ..
make -j6
cd ..
```
## Build for Windows
```bash
git clone https://github.com/k2-fsa/sherpa-onnx
cd sherpa-onnx
mkdir build
cd build
cmake -DCMAKE_BUILD_TYPE=Release ..
cmake --build . --config Release
cd ..
```
## Download the pretrained model (English)
```bash
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
cd icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
git lfs pull --include "exp/onnx/*.onnx"
cd ..
./build/bin/sherpa-onnx --help
./build/bin/sherpa-onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav
```
## Download the pretrained model (Chinese)
```bash
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2
cd icefall_asr_wenetspeech_pruned_transducer_stateless2
git lfs pull --include "exp/*.onnx"
cd ..
./build/bin/sherpa-onnx --help
./build/bin/sherpa-onnx \
./icefall_asr_wenetspeech_pruned_transducer_stateless2/data/lang_char/tokens.txt \
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/encoder-epoch-10-avg-2.onnx \
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/decoder-epoch-10-avg-2.onnx \
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner-epoch-10-avg-2.onnx \
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_encoder_proj-epoch-10-avg-2.onnx \
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_decoder_proj-epoch-10-avg-2.onnx \
./icefall_asr_wenetspeech_pruned_transducer_stateless2/test_wavs/DEV_T0000000000.wav
```
... ...
... ... @@ -2,7 +2,11 @@ include_directories(${CMAKE_SOURCE_DIR})
add_executable(sherpa-onnx
decode.cc
rnnt-model.cc
features.cc
online-lstm-transducer-model.cc
online-transducer-model-config.cc
online-transducer-model.cc
onnx-utils.cc
sherpa-onnx.cc
symbol-table.cc
wave-reader.cc
... ... @@ -13,5 +17,5 @@ target_link_libraries(sherpa-onnx
kaldi-native-fbank-core
)
# add_executable(sherpa-show-onnx-info show-onnx-info.cc)
# target_link_libraries(sherpa-show-onnx-info onnxruntime)
add_executable(sherpa-onnx-show-info show-onnx-info.cc)
target_link_libraries(sherpa-onnx-show-info onnxruntime)
... ...
/**
* Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// sherpa/csrc/decode.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/decode.h"
#include <assert.h>
#include <algorithm>
#include <utility>
#include <vector>
namespace sherpa_onnx {
std::vector<int32_t> GreedySearch(RnntModel &model, // NOLINT
const Ort::Value &encoder_out) {
static Ort::Value Clone(Ort::Value *v) {
auto type_and_shape = v->GetTensorTypeAndShapeInfo();
std::vector<int64_t> shape = type_and_shape.GetShape();
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
return Ort::Value::CreateTensor(memory_info, v->GetTensorMutableData<float>(),
type_and_shape.GetElementCount(),
shape.data(), shape.size());
}
static Ort::Value GetFrame(Ort::Value *encoder_out, int32_t t) {
std::vector<int64_t> encoder_out_shape =
encoder_out.GetTensorTypeAndShapeInfo().GetShape();
assert(encoder_out_shape[0] == 1 && "Only batch_size=1 is implemented");
Ort::Value projected_encoder_out =
model.RunJoinerEncoderProj(encoder_out.GetTensorData<float>(),
encoder_out_shape[1], encoder_out_shape[2]);
encoder_out->GetTensorTypeAndShapeInfo().GetShape();
assert(encoder_out_shape[0] == 1);
const float *p_projected_encoder_out =
projected_encoder_out.GetTensorData<float>();
int32_t encoder_out_dim = encoder_out_shape[2];
int32_t context_size = 2; // hard-code it to 2
int32_t blank_id = 0; // hard-code it to 0
std::vector<int32_t> hyp(context_size, blank_id);
std::array<int64_t, 2> decoder_input{blank_id, blank_id};
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
Ort::Value decoder_out = model.RunDecoder(decoder_input.data(), context_size);
std::array<int64_t, 2> shape{1, encoder_out_dim};
std::vector<int64_t> decoder_out_shape =
decoder_out.GetTensorTypeAndShapeInfo().GetShape();
return Ort::Value::CreateTensor(
memory_info,
encoder_out->GetTensorMutableData<float>() + t * encoder_out_dim,
encoder_out_dim, shape.data(), shape.size());
}
Ort::Value projected_decoder_out = model.RunJoinerDecoderProj(
decoder_out.GetTensorData<float>(), decoder_out_shape[2]);
void GreedySearch(OnlineTransducerModel *model, Ort::Value encoder_out,
std::vector<int64_t> *hyp) {
std::vector<int64_t> encoder_out_shape =
encoder_out.GetTensorTypeAndShapeInfo().GetShape();
int32_t joiner_dim =
projected_decoder_out.GetTensorTypeAndShapeInfo().GetShape()[1];
if (encoder_out_shape[0] > 1) {
fprintf(stderr, "Only batch_size=1 is implemented. Given: %d\n",
static_cast<int32_t>(encoder_out_shape[0]));
}
int32_t T = encoder_out_shape[1];
for (int32_t t = 0; t != T; ++t) {
Ort::Value logit = model.RunJoiner(
p_projected_encoder_out + t * joiner_dim,
projected_decoder_out.GetTensorData<float>(), joiner_dim);
int32_t num_frames = encoder_out_shape[1];
int32_t vocab_size = model->VocabSize();
int32_t vocab_size = logit.GetTensorTypeAndShapeInfo().GetShape()[1];
Ort::Value decoder_input = model->BuildDecoderInput(*hyp);
Ort::Value decoder_out = model->RunDecoder(std::move(decoder_input));
for (int32_t t = 0; t != num_frames; ++t) {
Ort::Value cur_encoder_out = GetFrame(&encoder_out, t);
Ort::Value logit =
model->RunJoiner(std::move(cur_encoder_out), Clone(&decoder_out));
const float *p_logit = logit.GetTensorData<float>();
auto y = static_cast<int32_t>(std::distance(
static_cast<const float *>(p_logit),
std::max_element(static_cast<const float *>(p_logit),
static_cast<const float *>(p_logit) + vocab_size)));
if (y != blank_id) {
decoder_input[0] = hyp.back();
decoder_input[1] = y;
hyp.push_back(y);
decoder_out = model.RunDecoder(decoder_input.data(), context_size);
projected_decoder_out = model.RunJoinerDecoderProj(
decoder_out.GetTensorData<float>(), decoder_out_shape[2]);
if (y != 0) {
hyp->push_back(y);
decoder_input = model->BuildDecoderInput(*hyp);
decoder_out = model->RunDecoder(std::move(decoder_input));
}
}
return {hyp.begin() + context_size, hyp.end()};
}
} // namespace sherpa_onnx
... ...
/**
* Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// sherpa/csrc/decode.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_DECODE_H_
#define SHERPA_ONNX_CSRC_DECODE_H_
#include <vector>
#include "sherpa-onnx/csrc/rnnt-model.h"
#include "sherpa-onnx/csrc/online-transducer-model.h"
namespace sherpa_onnx {
... ... @@ -32,8 +18,8 @@ namespace sherpa_onnx {
* @param model The RnntModel
* @param encoder_out Its shape is (1, num_frames, encoder_out_dim).
*/
std::vector<int32_t> GreedySearch(RnntModel &model, // NOLINT
const Ort::Value &encoder_out);
void GreedySearch(OnlineTransducerModel *model, Ort::Value encoder_out,
std::vector<int64_t> *hyp);
} // namespace sherpa_onnx
... ...
// sherpa/csrc/features.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/features.h"
#include <algorithm>
#include <memory>
#include <vector>
namespace sherpa_onnx {
FeatureExtractor::FeatureExtractor() {
opts_.frame_opts.dither = 0;
opts_.frame_opts.snip_edges = false;
opts_.frame_opts.samp_freq = 16000;
// cache 100 seconds of feature frames, which is more than enough
// for real needs
opts_.frame_opts.max_feature_vectors = 100 * 100;
opts_.mel_opts.num_bins = 80; // feature dim
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
}
FeatureExtractor::FeatureExtractor(const knf::FbankOptions &opts)
: opts_(opts) {
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
}
void FeatureExtractor::AcceptWaveform(float sampling_rate,
const float *waveform, int32_t n) {
std::lock_guard<std::mutex> lock(mutex_);
fbank_->AcceptWaveform(sampling_rate, waveform, n);
}
void FeatureExtractor::InputFinished() {
std::lock_guard<std::mutex> lock(mutex_);
fbank_->InputFinished();
}
int32_t FeatureExtractor::NumFramesReady() const {
std::lock_guard<std::mutex> lock(mutex_);
return fbank_->NumFramesReady();
}
bool FeatureExtractor::IsLastFrame(int32_t frame) const {
std::lock_guard<std::mutex> lock(mutex_);
return fbank_->IsLastFrame(frame);
}
std::vector<float> FeatureExtractor::GetFrames(int32_t frame_index,
int32_t n) const {
if (frame_index + n > NumFramesReady()) {
fprintf(stderr, "%d + %d > %d\n", frame_index, n, NumFramesReady());
exit(-1);
}
std::lock_guard<std::mutex> lock(mutex_);
int32_t feature_dim = fbank_->Dim();
std::vector<float> features(feature_dim * n);
float *p = features.data();
for (int32_t i = 0; i != n; ++i) {
const float *f = fbank_->GetFrame(i + frame_index);
std::copy(f, f + feature_dim, p);
p += feature_dim;
}
return features;
}
void FeatureExtractor::Reset() {
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
}
} // namespace sherpa_onnx
... ...
// sherpa/csrc/features.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_FEATURES_H_
#define SHERPA_ONNX_CSRC_FEATURES_H_
#include <memory>
#include <mutex> // NOLINT
#include <vector>
#include "kaldi-native-fbank/csrc/online-feature.h"
namespace sherpa_onnx {
class FeatureExtractor {
public:
FeatureExtractor();
explicit FeatureExtractor(const knf::FbankOptions &fbank_opts);
/**
@param sampling_rate The sampling_rate of the input waveform. Should match
the one expected by the feature extractor.
@param waveform Pointer to a 1-D array of size n
@param n Number of entries in waveform
*/
void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n);
// InputFinished() tells the class you won't be providing any
// more waveform. This will help flush out the last frame or two
// of features, in the case where snip-edges == false; it also
// affects the return value of IsLastFrame().
void InputFinished();
int32_t NumFramesReady() const;
// Note: IsLastFrame() will only ever return true if you have called
// InputFinished() (and this frame is the last frame).
bool IsLastFrame(int32_t frame) const;
/** Get n frames starting from the given frame index.
*
* @param frame_index The starting frame index
* @param n Number of frames to get.
* @return Return a 2-D tensor of shape (n, feature_dim).
* which is flattened into a 1-D vector (flattened in in row major)
*/
std::vector<float> GetFrames(int32_t frame_index, int32_t n) const;
void Reset();
int32_t FeatureDim() const { return opts_.mel_opts.num_bins; }
private:
std::unique_ptr<knf::OnlineFbank> fbank_;
knf::FbankOptions opts_;
mutable std::mutex mutex_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_FEATURES_H_
... ...
// sherpa/csrc/online-lstm-transducer-model.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/online-lstm-transducer-model.h"
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/onnx-utils.h"
#define SHERPA_ONNX_READ_META_DATA(dst, src_key) \
do { \
auto value = \
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
if (!value) { \
fprintf(stderr, "%s does not exist in the metadata\n", src_key); \
exit(-1); \
} \
dst = atoi(value.get()); \
if (dst <= 0) { \
fprintf(stderr, "Invalud value %d for %s\n", dst, src_key); \
exit(-1); \
} \
} while (0)
namespace sherpa_onnx {
OnlineLstmTransducerModel::OnlineLstmTransducerModel(
const OnlineTransducerModelConfig &config)
: env_(ORT_LOGGING_LEVEL_WARNING),
config_(config),
sess_opts_{},
allocator_{} {
sess_opts_.SetIntraOpNumThreads(config.num_threads);
sess_opts_.SetInterOpNumThreads(config.num_threads);
InitEncoder(config.encoder_filename);
InitDecoder(config.decoder_filename);
InitJoiner(config.joiner_filename);
}
void OnlineLstmTransducerModel::InitEncoder(const std::string &filename) {
encoder_sess_ = std::make_unique<Ort::Session>(
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
GetInputNames(encoder_sess_.get(), &encoder_input_names_,
&encoder_input_names_ptr_);
GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
&encoder_output_names_ptr_);
// get meta data
Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
if (config_.debug) {
std::ostringstream os;
os << "---encoder---\n";
PrintModelMetadata(os, meta_data);
fprintf(stderr, "%s\n", os.str().c_str());
}
Ort::AllocatorWithDefaultOptions allocator;
SHERPA_ONNX_READ_META_DATA(num_encoder_layers_, "num_encoder_layers");
SHERPA_ONNX_READ_META_DATA(T_, "T");
SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len");
SHERPA_ONNX_READ_META_DATA(rnn_hidden_size_, "rnn_hidden_size");
SHERPA_ONNX_READ_META_DATA(d_model_, "d_model");
}
void OnlineLstmTransducerModel::InitDecoder(const std::string &filename) {
decoder_sess_ = std::make_unique<Ort::Session>(
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
GetInputNames(decoder_sess_.get(), &decoder_input_names_,
&decoder_input_names_ptr_);
GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
&decoder_output_names_ptr_);
// get meta data
Ort::ModelMetadata meta_data = decoder_sess_->GetModelMetadata();
if (config_.debug) {
std::ostringstream os;
os << "---decoder---\n";
PrintModelMetadata(os, meta_data);
fprintf(stderr, "%s\n", os.str().c_str());
}
Ort::AllocatorWithDefaultOptions allocator;
SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
SHERPA_ONNX_READ_META_DATA(context_size_, "context_size");
}
void OnlineLstmTransducerModel::InitJoiner(const std::string &filename) {
joiner_sess_ = std::make_unique<Ort::Session>(
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
GetInputNames(joiner_sess_.get(), &joiner_input_names_,
&joiner_input_names_ptr_);
GetOutputNames(joiner_sess_.get(), &joiner_output_names_,
&joiner_output_names_ptr_);
// get meta data
Ort::ModelMetadata meta_data = joiner_sess_->GetModelMetadata();
if (config_.debug) {
std::ostringstream os;
os << "---joiner---\n";
PrintModelMetadata(os, meta_data);
fprintf(stderr, "%s\n", os.str().c_str());
}
}
Ort::Value OnlineLstmTransducerModel::StackStates(
const std::vector<Ort::Value> &states) const {
fprintf(stderr, "implement me: %s:%d!\n", __func__,
static_cast<int>(__LINE__));
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
int64_t a;
std::array<int64_t, 3> x_shape{1, 1, 1};
Ort::Value x = Ort::Value::CreateTensor(memory_info, &a, 0, &a, 0);
return x;
}
std::vector<Ort::Value> OnlineLstmTransducerModel::UnStackStates(
Ort::Value states) const {
fprintf(stderr, "implement me: %s:%d!\n", __func__,
static_cast<int>(__LINE__));
return {};
}
std::vector<Ort::Value> OnlineLstmTransducerModel::GetEncoderInitStates() {
// Please see
// https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py#L185
// for details
constexpr int32_t kBatchSize = 1;
std::array<int64_t, 3> h_shape{num_encoder_layers_, kBatchSize, d_model_};
Ort::Value h = Ort::Value::CreateTensor<float>(allocator_, h_shape.data(),
h_shape.size());
std::fill(h.GetTensorMutableData<float>(),
h.GetTensorMutableData<float>() +
num_encoder_layers_ * kBatchSize * d_model_,
0);
std::array<int64_t, 3> c_shape{num_encoder_layers_, kBatchSize,
rnn_hidden_size_};
Ort::Value c = Ort::Value::CreateTensor<float>(allocator_, c_shape.data(),
c_shape.size());
std::fill(c.GetTensorMutableData<float>(),
c.GetTensorMutableData<float>() +
num_encoder_layers_ * kBatchSize * rnn_hidden_size_,
0);
std::vector<Ort::Value> states;
states.reserve(2);
states.push_back(std::move(h));
states.push_back(std::move(c));
return states;
}
std::pair<Ort::Value, std::vector<Ort::Value>>
OnlineLstmTransducerModel::RunEncoder(Ort::Value features,
std::vector<Ort::Value> &states) {
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
std::array<Ort::Value, 3> encoder_inputs = {
std::move(features), std::move(states[0]), std::move(states[1])};
auto encoder_out = encoder_sess_->Run(
{}, encoder_input_names_ptr_.data(), encoder_inputs.data(),
encoder_inputs.size(), encoder_output_names_ptr_.data(),
encoder_output_names_ptr_.size());
std::vector<Ort::Value> next_states;
next_states.reserve(2);
next_states.push_back(std::move(encoder_out[1]));
next_states.push_back(std::move(encoder_out[2]));
return {std::move(encoder_out[0]), std::move(next_states)};
}
Ort::Value OnlineLstmTransducerModel::BuildDecoderInput(
const std::vector<int64_t> &hyp) {
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
std::array<int64_t, 2> shape{1, context_size_};
return Ort::Value::CreateTensor(
memory_info,
const_cast<int64_t *>(hyp.data() + hyp.size() - context_size_),
context_size_, shape.data(), shape.size());
}
Ort::Value OnlineLstmTransducerModel::RunDecoder(Ort::Value decoder_input) {
auto decoder_out = decoder_sess_->Run(
{}, decoder_input_names_ptr_.data(), &decoder_input, 1,
decoder_output_names_ptr_.data(), decoder_output_names_ptr_.size());
return std::move(decoder_out[0]);
}
Ort::Value OnlineLstmTransducerModel::RunJoiner(Ort::Value encoder_out,
Ort::Value decoder_out) {
std::array<Ort::Value, 2> joiner_input = {std::move(encoder_out),
std::move(decoder_out)};
auto logit =
joiner_sess_->Run({}, joiner_input_names_ptr_.data(), joiner_input.data(),
joiner_input.size(), joiner_output_names_ptr_.data(),
joiner_output_names_ptr_.size());
return std::move(logit[0]);
}
} // namespace sherpa_onnx
... ...
// sherpa/csrc/online-lstm-transducer-model.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_ONLINE_LSTM_TRANSDUCER_MODEL_H_
#define SHERPA_ONNX_CSRC_ONLINE_LSTM_TRANSDUCER_MODEL_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
#include "sherpa-onnx/csrc/online-transducer-model.h"
namespace sherpa_onnx {
class OnlineLstmTransducerModel : public OnlineTransducerModel {
public:
explicit OnlineLstmTransducerModel(const OnlineTransducerModelConfig &config);
Ort::Value StackStates(const std::vector<Ort::Value> &states) const override;
std::vector<Ort::Value> UnStackStates(Ort::Value states) const override;
std::vector<Ort::Value> GetEncoderInitStates() override;
std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
Ort::Value features, std::vector<Ort::Value> &states) override;
Ort::Value BuildDecoderInput(const std::vector<int64_t> &hyp) override;
Ort::Value RunDecoder(Ort::Value decoder_input) override;
Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) override;
int32_t ContextSize() const override { return context_size_; }
int32_t ChunkSize() const override { return T_; }
int32_t ChunkShift() const override { return decode_chunk_len_; }
int32_t VocabSize() const override { return vocab_size_; }
private:
void InitEncoder(const std::string &encoder_filename);
void InitDecoder(const std::string &decoder_filename);
void InitJoiner(const std::string &joiner_filename);
private:
Ort::Env env_;
Ort::SessionOptions sess_opts_;
Ort::AllocatorWithDefaultOptions allocator_;
std::unique_ptr<Ort::Session> encoder_sess_;
std::unique_ptr<Ort::Session> decoder_sess_;
std::unique_ptr<Ort::Session> joiner_sess_;
std::vector<std::string> encoder_input_names_;
std::vector<const char *> encoder_input_names_ptr_;
std::vector<std::string> encoder_output_names_;
std::vector<const char *> encoder_output_names_ptr_;
std::vector<std::string> decoder_input_names_;
std::vector<const char *> decoder_input_names_ptr_;
std::vector<std::string> decoder_output_names_;
std::vector<const char *> decoder_output_names_ptr_;
std::vector<std::string> joiner_input_names_;
std::vector<const char *> joiner_input_names_ptr_;
std::vector<std::string> joiner_output_names_;
std::vector<const char *> joiner_output_names_ptr_;
OnlineTransducerModelConfig config_;
int32_t num_encoder_layers_ = 0;
int32_t T_ = 0;
int32_t decode_chunk_len_ = 0;
int32_t rnn_hidden_size_ = 0;
int32_t d_model_ = 0;
int32_t context_size_ = 0;
int32_t vocab_size_ = 0;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONLINE_LSTM_TRANSDUCER_MODEL_H_
... ...
// sherpa/csrc/online-transducer-model-config.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
#include <sstream>
namespace sherpa_onnx {
std::string OnlineTransducerModelConfig::ToString() const {
std::ostringstream os;
os << "OnlineTransducerModelConfig(";
os << "encoder_filename=\"" << encoder_filename << "\", ";
os << "decoder_filename=\"" << decoder_filename << "\", ";
os << "joiner_filename=\"" << joiner_filename << "\", ";
os << "num_threads=" << num_threads << ", ";
os << "debug=" << (debug ? "True" : "False") << ")";
return os.str();
}
} // namespace sherpa_onnx
... ...
// sherpa/csrc/online-transducer-model-config.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODEL_CONFIG_H_
#include <string>
namespace sherpa_onnx {
struct OnlineTransducerModelConfig {
std::string encoder_filename;
std::string decoder_filename;
std::string joiner_filename;
int32_t num_threads;
bool debug = false;
std::string ToString() const;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODEL_CONFIG_H_
... ...
// sherpa/csrc/online-transducer-model.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/online-transducer-model.h"
#include <memory>
#include <sstream>
#include <string>
#include "sherpa-onnx/csrc/online-lstm-transducer-model.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
namespace sherpa_onnx {
enum class ModelType {
kLstm,
kUnkown,
};
static ModelType GetModelType(const OnlineTransducerModelConfig &config) {
Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
Ort::SessionOptions sess_opts;
auto sess = std::make_unique<Ort::Session>(
env, SHERPA_MAYBE_WIDE(config.encoder_filename).c_str(), sess_opts);
Ort::ModelMetadata meta_data = sess->GetModelMetadata();
if (config.debug) {
std::ostringstream os;
PrintModelMetadata(os, meta_data);
fprintf(stderr, "%s\n", os.str().c_str());
}
Ort::AllocatorWithDefaultOptions allocator;
auto model_type =
meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
if (!model_type) {
fprintf(stderr, "No model_type in the metadata!\n");
return ModelType::kUnkown;
}
if (model_type.get() == std::string("lstm")) {
return ModelType::kLstm;
} else {
fprintf(stderr, "Unsupported model_type: %s\n", model_type.get());
return ModelType::kUnkown;
}
}
std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
const OnlineTransducerModelConfig &config) {
auto model_type = GetModelType(config);
switch (model_type) {
case ModelType::kLstm:
return std::make_unique<OnlineLstmTransducerModel>(config);
case ModelType::kUnkown:
return nullptr;
}
// unreachable code
return nullptr;
}
} // namespace sherpa_onnx
... ...
// sherpa/csrc/online-transducer-model.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODEL_H_
#define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODEL_H_
#include <memory>
#include <utility>
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
namespace sherpa_onnx {
class OnlineTransducerModel {
public:
virtual ~OnlineTransducerModel() = default;
static std::unique_ptr<OnlineTransducerModel> Create(
const OnlineTransducerModelConfig &config);
/** Stack a list of individual states into a batch.
*
* It is the inverse operation of `UnStackStates`.
*
* @param states states[i] contains the state for the i-th utterance.
* @return Return a single value representing the batched state.
*/
virtual Ort::Value StackStates(
const std::vector<Ort::Value> &states) const = 0;
/** Unstack a batch state into a list of individual states.
*
* It is the inverse operation of `StackStates`.
*
* @param states A batched state.
* @return ans[i] contains the state for the i-th utterance.
*/
virtual std::vector<Ort::Value> UnStackStates(Ort::Value states) const = 0;
/** Get the initial encoder states.
*
* @return Return the initial encoder state.
*/
virtual std::vector<Ort::Value> GetEncoderInitStates() = 0;
/** Run the encoder.
*
* @param features A tensor of shape (N, T, C). It is changed in-place.
* @param states Encoder state of the previous chunk. It is changed in-place.
*
* @return Return a tuple containing:
* - encoder_out, a tensor of shape (N, T', encoder_out_dim)
* - next_states Encoder state for the next chunk.
*/
virtual std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
Ort::Value features,
std::vector<Ort::Value> &states) = 0; // NOLINT
virtual Ort::Value BuildDecoderInput(const std::vector<int64_t> &hyp) = 0;
/** Run the decoder network.
*
* Caution: We assume there are no recurrent connections in the decoder and
* the decoder is stateless. See
* https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py
* for an example
*
* @param decoder_input It is usually of shape (N, context_size)
* @return Return a tensor of shape (N, decoder_dim).
*/
virtual Ort::Value RunDecoder(Ort::Value decoder_input) = 0;
/** Run the joint network.
*
* @param encoder_out Output of the encoder network. A tensor of shape
* (N, joiner_dim).
* @param decoder_out Output of the decoder network. A tensor of shape
* (N, joiner_dim).
* @return Return a tensor of shape (N, vocab_size). In icefall, the last
* last layer of the joint network is `nn.Linear`,
* not `nn.LogSoftmax`.
*/
virtual Ort::Value RunJoiner(Ort::Value encoder_out,
Ort::Value decoder_out) = 0;
/** If we are using a stateless decoder and if it contains a
* Conv1D, this function returns the kernel size of the convolution layer.
*/
virtual int32_t ContextSize() const = 0;
/** We send this number of feature frames to the encoder at a time. */
virtual int32_t ChunkSize() const = 0;
/** Number of input frames to discard after each call to RunEncoder.
*
* For instance, if we have 30 frames, chunk_size=8, chunk_shift=6.
*
* In the first call of RunEncoder, we use frames 0~7 since chunk_size is 8.
* Then we discard frame 0~5 since chunk_shift is 6.
* In the second call of RunEncoder, we use frames 6~13; and then we discard
* frames 6~11.
* In the third call of RunEncoder, we use frames 12~19; and then we discard
* frames 12~16.
*
* Note: ChunkSize() - ChunkShift() == right context size
*/
virtual int32_t ChunkShift() const = 0;
virtual int32_t VocabSize() const = 0;
virtual int32_t SubsamplingFactor() const { return 4; }
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODEL_H_
... ...
// sherpa/csrc/onnx-utils.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/onnx-utils.h"
#include <string>
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
namespace sherpa_onnx {
void GetInputNames(Ort::Session *sess, std::vector<std::string> *input_names,
std::vector<const char *> *input_names_ptr) {
Ort::AllocatorWithDefaultOptions allocator;
size_t node_count = sess->GetInputCount();
input_names->resize(node_count);
input_names_ptr->resize(node_count);
for (size_t i = 0; i != node_count; ++i) {
auto tmp = sess->GetInputNameAllocated(i, allocator);
(*input_names)[i] = tmp.get();
(*input_names_ptr)[i] = (*input_names)[i].c_str();
}
}
void GetOutputNames(Ort::Session *sess, std::vector<std::string> *output_names,
std::vector<const char *> *output_names_ptr) {
Ort::AllocatorWithDefaultOptions allocator;
size_t node_count = sess->GetOutputCount();
output_names->resize(node_count);
output_names_ptr->resize(node_count);
for (size_t i = 0; i != node_count; ++i) {
auto tmp = sess->GetOutputNameAllocated(i, allocator);
(*output_names)[i] = tmp.get();
(*output_names_ptr)[i] = (*output_names)[i].c_str();
}
}
void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data) {
Ort::AllocatorWithDefaultOptions allocator;
std::vector<Ort::AllocatedStringPtr> v =
meta_data.GetCustomMetadataMapKeysAllocated(allocator);
for (const auto &key : v) {
auto p = meta_data.LookupCustomMetadataMapAllocated(key.get(), allocator);
os << key.get() << "=" << p.get() << "\n";
}
}
} // namespace sherpa_onnx
... ...
// sherpa/csrc/onnx-utils.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_ONNX_UTILS_H_
#define SHERPA_ONNX_CSRC_ONNX_UTILS_H_
#ifdef _MSC_VER
// For ToWide() below
#include <codecvt>
#include <locale>
#endif
#include <ostream>
#include <string>
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
namespace sherpa_onnx {
#ifdef _MSC_VER
// See
// https://stackoverflow.com/questions/2573834/c-convert-string-or-char-to-wstring-or-wchar-t
static std::wstring ToWide(const std::string &s) {
std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter;
return converter.from_bytes(s);
}
#define SHERPA_MAYBE_WIDE(s) ToWide(s)
#else
#define SHERPA_MAYBE_WIDE(s) s
#endif
/**
* Get the input names of a model.
*
* @param sess An onnxruntime session.
* @param input_names. On return, it contains the input names of the model.
* @param input_names_ptr. On return, input_names_ptr[i] contains
* input_names[i].c_str()
*/
void GetInputNames(Ort::Session *sess, std::vector<std::string> *input_names,
std::vector<const char *> *input_names_ptr);
/**
* Get the output names of a model.
*
* @param sess An onnxruntime session.
* @param output_names. On return, it contains the output names of the model.
* @param output_names_ptr. On return, output_names_ptr[i] contains
* output_names[i].c_str()
*/
void GetOutputNames(Ort::Session *sess, std::vector<std::string> *output_names,
std::vector<const char *> *output_names_ptr);
void PrintModelMetadata(std::ostream &os,
const Ort::ModelMetadata &meta_data); // NOLINT
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONNX_UTILS_H_
... ...
/**
* Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "sherpa-onnx/csrc/rnnt-model.h"
#include <array>
#include <utility>
#include <vector>
#ifdef _MSC_VER
// For ToWide() below
#include <codecvt>
#include <locale>
#endif
namespace sherpa_onnx {
#ifdef _MSC_VER
// See
// https://stackoverflow.com/questions/2573834/c-convert-string-or-char-to-wstring-or-wchar-t
static std::wstring ToWide(const std::string &s) {
std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter;
return converter.from_bytes(s);
}
#define SHERPA_MAYBE_WIDE(s) ToWide(s)
#else
#define SHERPA_MAYBE_WIDE(s) s
#endif
/**
* Get the input names of a model.
*
* @param sess An onnxruntime session.
* @param input_names. On return, it contains the input names of the model.
* @param input_names_ptr. On return, input_names_ptr[i] contains
* input_names[i].c_str()
*/
static void GetInputNames(Ort::Session *sess,
std::vector<std::string> *input_names,
std::vector<const char *> *input_names_ptr) {
Ort::AllocatorWithDefaultOptions allocator;
size_t node_count = sess->GetInputCount();
input_names->resize(node_count);
input_names_ptr->resize(node_count);
for (size_t i = 0; i != node_count; ++i) {
auto tmp = sess->GetInputNameAllocated(i, allocator);
(*input_names)[i] = tmp.get();
(*input_names_ptr)[i] = (*input_names)[i].c_str();
}
}
/**
* Get the output names of a model.
*
* @param sess An onnxruntime session.
* @param output_names. On return, it contains the output names of the model.
* @param output_names_ptr. On return, output_names_ptr[i] contains
* output_names[i].c_str()
*/
static void GetOutputNames(Ort::Session *sess,
std::vector<std::string> *output_names,
std::vector<const char *> *output_names_ptr) {
Ort::AllocatorWithDefaultOptions allocator;
size_t node_count = sess->GetOutputCount();
output_names->resize(node_count);
output_names_ptr->resize(node_count);
for (size_t i = 0; i != node_count; ++i) {
auto tmp = sess->GetOutputNameAllocated(i, allocator);
(*output_names)[i] = tmp.get();
(*output_names_ptr)[i] = (*output_names)[i].c_str();
}
}
RnntModel::RnntModel(const std::string &encoder_filename,
const std::string &decoder_filename,
const std::string &joiner_filename,
const std::string &joiner_encoder_proj_filename,
const std::string &joiner_decoder_proj_filename,
int32_t num_threads)
: env_(ORT_LOGGING_LEVEL_WARNING) {
sess_opts_.SetIntraOpNumThreads(num_threads);
sess_opts_.SetInterOpNumThreads(num_threads);
InitEncoder(encoder_filename);
InitDecoder(decoder_filename);
InitJoiner(joiner_filename);
InitJoinerEncoderProj(joiner_encoder_proj_filename);
InitJoinerDecoderProj(joiner_decoder_proj_filename);
}
void RnntModel::InitEncoder(const std::string &filename) {
encoder_sess_ = std::make_unique<Ort::Session>(
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
GetInputNames(encoder_sess_.get(), &encoder_input_names_,
&encoder_input_names_ptr_);
GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
&encoder_output_names_ptr_);
}
void RnntModel::InitDecoder(const std::string &filename) {
decoder_sess_ = std::make_unique<Ort::Session>(
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
GetInputNames(decoder_sess_.get(), &decoder_input_names_,
&decoder_input_names_ptr_);
GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
&decoder_output_names_ptr_);
}
void RnntModel::InitJoiner(const std::string &filename) {
joiner_sess_ = std::make_unique<Ort::Session>(
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
GetInputNames(joiner_sess_.get(), &joiner_input_names_,
&joiner_input_names_ptr_);
GetOutputNames(joiner_sess_.get(), &joiner_output_names_,
&joiner_output_names_ptr_);
}
void RnntModel::InitJoinerEncoderProj(const std::string &filename) {
joiner_encoder_proj_sess_ = std::make_unique<Ort::Session>(
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
GetInputNames(joiner_encoder_proj_sess_.get(),
&joiner_encoder_proj_input_names_,
&joiner_encoder_proj_input_names_ptr_);
GetOutputNames(joiner_encoder_proj_sess_.get(),
&joiner_encoder_proj_output_names_,
&joiner_encoder_proj_output_names_ptr_);
}
void RnntModel::InitJoinerDecoderProj(const std::string &filename) {
joiner_decoder_proj_sess_ = std::make_unique<Ort::Session>(
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
GetInputNames(joiner_decoder_proj_sess_.get(),
&joiner_decoder_proj_input_names_,
&joiner_decoder_proj_input_names_ptr_);
GetOutputNames(joiner_decoder_proj_sess_.get(),
&joiner_decoder_proj_output_names_,
&joiner_decoder_proj_output_names_ptr_);
}
Ort::Value RnntModel::RunEncoder(const float *features, int32_t T,
int32_t feature_dim) {
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
std::array<int64_t, 3> x_shape{1, T, feature_dim};
Ort::Value x =
Ort::Value::CreateTensor(memory_info, const_cast<float *>(features),
T * feature_dim, x_shape.data(), x_shape.size());
std::array<int64_t, 1> x_lens_shape{1};
int64_t x_lens_tmp = T;
Ort::Value x_lens = Ort::Value::CreateTensor(
memory_info, &x_lens_tmp, 1, x_lens_shape.data(), x_lens_shape.size());
std::array<Ort::Value, 2> encoder_inputs{std::move(x), std::move(x_lens)};
// Note: We discard encoder_out_lens since we only implement
// batch==1.
auto encoder_out = encoder_sess_->Run(
{}, encoder_input_names_ptr_.data(), encoder_inputs.data(),
encoder_inputs.size(), encoder_output_names_ptr_.data(),
encoder_output_names_ptr_.size());
return std::move(encoder_out[0]);
}
Ort::Value RnntModel::RunJoinerEncoderProj(const float *encoder_out, int32_t T,
int32_t encoder_out_dim) {
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
std::array<int64_t, 2> in_shape{T, encoder_out_dim};
Ort::Value in = Ort::Value::CreateTensor(
memory_info, const_cast<float *>(encoder_out), T * encoder_out_dim,
in_shape.data(), in_shape.size());
auto encoder_proj_out = joiner_encoder_proj_sess_->Run(
{}, joiner_encoder_proj_input_names_ptr_.data(), &in, 1,
joiner_encoder_proj_output_names_ptr_.data(),
joiner_encoder_proj_output_names_ptr_.size());
return std::move(encoder_proj_out[0]);
}
Ort::Value RnntModel::RunDecoder(const int64_t *decoder_input,
int32_t context_size) {
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
int32_t batch_size = 1; // TODO(fangjun): handle the case when it's > 1
std::array<int64_t, 2> shape{batch_size, context_size};
Ort::Value in = Ort::Value::CreateTensor(
memory_info, const_cast<int64_t *>(decoder_input),
batch_size * context_size, shape.data(), shape.size());
auto decoder_out = decoder_sess_->Run(
{}, decoder_input_names_ptr_.data(), &in, 1,
decoder_output_names_ptr_.data(), decoder_output_names_ptr_.size());
return std::move(decoder_out[0]);
}
Ort::Value RnntModel::RunJoinerDecoderProj(const float *decoder_out,
int32_t decoder_out_dim) {
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
int32_t batch_size = 1; // TODO(fangjun): handle the case when it's > 1
std::array<int64_t, 2> shape{batch_size, decoder_out_dim};
Ort::Value in = Ort::Value::CreateTensor(
memory_info, const_cast<float *>(decoder_out),
batch_size * decoder_out_dim, shape.data(), shape.size());
auto decoder_proj_out = joiner_decoder_proj_sess_->Run(
{}, joiner_decoder_proj_input_names_ptr_.data(), &in, 1,
joiner_decoder_proj_output_names_ptr_.data(),
joiner_decoder_proj_output_names_ptr_.size());
return std::move(decoder_proj_out[0]);
}
Ort::Value RnntModel::RunJoiner(const float *projected_encoder_out,
const float *projected_decoder_out,
int32_t joiner_dim) {
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
int32_t batch_size = 1; // TODO(fangjun): handle the case when it's > 1
std::array<int64_t, 2> shape{batch_size, joiner_dim};
Ort::Value enc = Ort::Value::CreateTensor(
memory_info, const_cast<float *>(projected_encoder_out),
batch_size * joiner_dim, shape.data(), shape.size());
Ort::Value dec = Ort::Value::CreateTensor(
memory_info, const_cast<float *>(projected_decoder_out),
batch_size * joiner_dim, shape.data(), shape.size());
std::array<Ort::Value, 2> inputs{std::move(enc), std::move(dec)};
auto logit = joiner_sess_->Run(
{}, joiner_input_names_ptr_.data(), inputs.data(), inputs.size(),
joiner_output_names_ptr_.data(), joiner_output_names_ptr_.size());
return std::move(logit[0]);
}
} // namespace sherpa_onnx
/**
* Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef SHERPA_ONNX_CSRC_RNNT_MODEL_H_
#define SHERPA_ONNX_CSRC_RNNT_MODEL_H_
#include <memory>
#include <string>
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
namespace sherpa_onnx {
class RnntModel {
public:
/**
* @param encoder_filename Path to the encoder model
* @param decoder_filename Path to the decoder model
* @param joiner_filename Path to the joiner model
* @param joiner_encoder_proj_filename Path to the joiner encoder_proj model
* @param joiner_decoder_proj_filename Path to the joiner decoder_proj model
* @param num_threads Number of threads to use to run the models
*/
RnntModel(const std::string &encoder_filename,
const std::string &decoder_filename,
const std::string &joiner_filename,
const std::string &joiner_encoder_proj_filename,
const std::string &joiner_decoder_proj_filename,
int32_t num_threads);
/** Run the encoder model.
*
* @TODO(fangjun): Support batch_size > 1
*
* @param features A tensor of shape (batch_size, T, feature_dim)
* @param T Number of feature frames
* @param feature_dim Dimension of the feature.
*
* @return Return a tensor of shape (batch_size, T', encoder_out_dim)
*/
Ort::Value RunEncoder(const float *features, int32_t T, int32_t feature_dim);
/** Run the joiner encoder_proj model.
*
* @param encoder_out A tensor of shape (T, encoder_out_dim)
* @param T Number of frames in encoder_out.
* @param encoder_out_dim Dimension of encoder_out.
*
* @return Return a tensor of shape (T, joiner_dim)
*
*/
Ort::Value RunJoinerEncoderProj(const float *encoder_out, int32_t T,
int32_t encoder_out_dim);
/** Run the decoder model.
*
* @TODO(fangjun): Support batch_size > 1
*
* @param decoder_input A tensor of shape (batch_size, context_size).
* @return Return a tensor of shape (batch_size, 1, decoder_out_dim)
*/
Ort::Value RunDecoder(const int64_t *decoder_input, int32_t context_size);
/** Run joiner decoder_proj model.
*
* @TODO(fangjun): Support batch_size > 1
*
* @param decoder_out A tensor of shape (batch_size, decoder_out_dim)
* @param decoder_out_dim Output dimension of the decoder_out.
*
* @return Return a tensor of shape (batch_size, joiner_dim);
*/
Ort::Value RunJoinerDecoderProj(const float *decoder_out,
int32_t decoder_out_dim);
/** Run the joiner model.
*
* @TODO(fangjun): Support batch_size > 1
*
* @param projected_encoder_out A tensor of shape (batch_size, joiner_dim).
* @param projected_decoder_out A tensor of shape (batch_size, joiner_dim).
*
* @return Return a tensor of shape (batch_size, vocab_size)
*/
Ort::Value RunJoiner(const float *projected_encoder_out,
const float *projected_decoder_out, int32_t joiner_dim);
private:
void InitEncoder(const std::string &encoder_filename);
void InitDecoder(const std::string &decoder_filename);
void InitJoiner(const std::string &joiner_filename);
void InitJoinerEncoderProj(const std::string &joiner_encoder_proj_filename);
void InitJoinerDecoderProj(const std::string &joiner_decoder_proj_filename);
private:
Ort::Env env_;
Ort::SessionOptions sess_opts_;
std::unique_ptr<Ort::Session> encoder_sess_;
std::unique_ptr<Ort::Session> decoder_sess_;
std::unique_ptr<Ort::Session> joiner_sess_;
std::unique_ptr<Ort::Session> joiner_encoder_proj_sess_;
std::unique_ptr<Ort::Session> joiner_decoder_proj_sess_;
std::vector<std::string> encoder_input_names_;
std::vector<const char *> encoder_input_names_ptr_;
std::vector<std::string> encoder_output_names_;
std::vector<const char *> encoder_output_names_ptr_;
std::vector<std::string> decoder_input_names_;
std::vector<const char *> decoder_input_names_ptr_;
std::vector<std::string> decoder_output_names_;
std::vector<const char *> decoder_output_names_ptr_;
std::vector<std::string> joiner_input_names_;
std::vector<const char *> joiner_input_names_ptr_;
std::vector<std::string> joiner_output_names_;
std::vector<const char *> joiner_output_names_ptr_;
std::vector<std::string> joiner_encoder_proj_input_names_;
std::vector<const char *> joiner_encoder_proj_input_names_ptr_;
std::vector<std::string> joiner_encoder_proj_output_names_;
std::vector<const char *> joiner_encoder_proj_output_names_ptr_;
std::vector<std::string> joiner_decoder_proj_input_names_;
std::vector<const char *> joiner_decoder_proj_input_names_ptr_;
std::vector<std::string> joiner_decoder_proj_output_names_;
std::vector<const char *> joiner_decoder_proj_output_names_ptr_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_RNNT_MODEL_H_
/**
* Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// sherpa-onnx/csrc/sherpa-onnx.cc
//
// Copyright (c) 2022-2023 Xiaomi Corporation
#include <chrono> // NOLINT
#include <iostream>
#include <string>
#include <vector>
#include "kaldi-native-fbank/csrc/online-feature.h"
#include "sherpa-onnx/csrc/decode.h"
#include "sherpa-onnx/csrc/rnnt-model.h"
#include "sherpa-onnx/csrc/features.h"
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
#include "sherpa-onnx/csrc/online-transducer-model.h"
#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/wave-reader.h"
/** Compute fbank features of the input wave filename.
*
* @param wav_filename. Path to a mono wave file.
* @param expected_sampling_rate Expected sampling rate of the input wave file.
* @param num_frames On return, it contains the number of feature frames.
* @return Return the computed feature of shape (num_frames, feature_dim)
* stored in row-major.
*/
static std::vector<float> ComputeFeatures(const std::string &wav_filename,
float expected_sampling_rate,
int32_t *num_frames) {
std::vector<float> samples =
sherpa_onnx::ReadWave(wav_filename, expected_sampling_rate);
float duration = samples.size() / expected_sampling_rate;
std::cout << "wav filename: " << wav_filename << "\n";
std::cout << "wav duration (s): " << duration << "\n";
knf::FbankOptions opts;
opts.frame_opts.dither = 0;
opts.frame_opts.snip_edges = false;
opts.frame_opts.samp_freq = expected_sampling_rate;
int32_t feature_dim = 80;
opts.mel_opts.num_bins = feature_dim;
knf::OnlineFbank fbank(opts);
fbank.AcceptWaveform(expected_sampling_rate, samples.data(), samples.size());
fbank.InputFinished();
*num_frames = fbank.NumFramesReady();
std::vector<float> features(*num_frames * feature_dim);
float *p = features.data();
for (int32_t i = 0; i != fbank.NumFramesReady(); ++i, p += feature_dim) {
const float *f = fbank.GetFrame(i);
std::copy(f, f + feature_dim, p);
}
return features;
}
int main(int32_t argc, char *argv[]) {
if (argc < 8 || argc > 9) {
if (argc < 6 || argc > 7) {
const char *usage = R"usage(
Usage:
./bin/sherpa-onnx \
... ... @@ -80,12 +24,11 @@ Usage:
/path/to/encoder.onnx \
/path/to/decoder.onnx \
/path/to/joiner.onnx \
/path/to/joiner_encoder_proj.onnx \
/path/to/joiner_decoder_proj.onnx \
/path/to/foo.wav [num_threads]
You can download pre-trained models from the following repository:
https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
Please refer to
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
for a list of pre-trained models to download.
)usage";
std::cerr << usage << "\n";
... ... @@ -93,37 +36,102 @@ https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stat
}
std::string tokens = argv[1];
std::string encoder = argv[2];
std::string decoder = argv[3];
std::string joiner = argv[4];
std::string joiner_encoder_proj = argv[5];
std::string joiner_decoder_proj = argv[6];
std::string wav_filename = argv[7];
int32_t num_threads = 4;
if (argc == 9) {
num_threads = atoi(argv[8]);
sherpa_onnx::OnlineTransducerModelConfig config;
config.debug = true;
config.encoder_filename = argv[2];
config.decoder_filename = argv[3];
config.joiner_filename = argv[4];
std::string wav_filename = argv[5];
config.num_threads = 2;
if (argc == 7) {
config.num_threads = atoi(argv[6]);
}
std::cout << config.ToString().c_str() << "\n";
auto model = sherpa_onnx::OnlineTransducerModel::Create(config);
sherpa_onnx::SymbolTable sym(tokens);
int32_t num_frames;
auto features = ComputeFeatures(wav_filename, 16000, &num_frames);
int32_t feature_dim = features.size() / num_frames;
Ort::AllocatorWithDefaultOptions allocator;
int32_t chunk_size = model->ChunkSize();
int32_t chunk_shift = model->ChunkShift();
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
std::vector<Ort::Value> states = model->GetEncoderInitStates();
sherpa_onnx::RnntModel model(encoder, decoder, joiner, joiner_encoder_proj,
joiner_decoder_proj, num_threads);
Ort::Value encoder_out =
model.RunEncoder(features.data(), num_frames, feature_dim);
std::vector<int64_t> hyp(model->ContextSize(), 0);
auto hyp = sherpa_onnx::GreedySearch(model, encoder_out);
int32_t expected_sampling_rate = 16000;
bool is_ok = false;
std::vector<float> samples =
sherpa_onnx::ReadWave(wav_filename, expected_sampling_rate, &is_ok);
if (!is_ok) {
std::cerr << "Failed to read " << wav_filename << "\n";
return -1;
}
const float duration = samples.size() / expected_sampling_rate;
std::cout << "wav filename: " << wav_filename << "\n";
std::cout << "wav duration (s): " << duration << "\n";
auto begin = std::chrono::steady_clock::now();
std::cout << "Started!\n";
sherpa_onnx::FeatureExtractor feat_extractor;
feat_extractor.AcceptWaveform(expected_sampling_rate, samples.data(),
samples.size());
std::vector<float> tail_paddings(
static_cast<int>(0.2 * expected_sampling_rate));
feat_extractor.AcceptWaveform(expected_sampling_rate, tail_paddings.data(),
tail_paddings.size());
feat_extractor.InputFinished();
int32_t num_frames = feat_extractor.NumFramesReady();
int32_t feature_dim = feat_extractor.FeatureDim();
std::array<int64_t, 3> x_shape{1, chunk_size, feature_dim};
for (int32_t start = 0; start + chunk_size < num_frames;
start += chunk_shift) {
std::vector<float> features = feat_extractor.GetFrames(start, chunk_size);
Ort::Value x =
Ort::Value::CreateTensor(memory_info, features.data(), features.size(),
x_shape.data(), x_shape.size());
auto pair = model->RunEncoder(std::move(x), states);
states = std::move(pair.second);
sherpa_onnx::GreedySearch(model.get(), std::move(pair.first), &hyp);
}
std::string text;
for (auto i : hyp) {
text += sym[i];
for (size_t i = model->ContextSize(); i != hyp.size(); ++i) {
text += sym[hyp[i]];
}
std::cout << "Done!\n";
std::cout << "Recognition result for " << wav_filename << "\n"
<< text << "\n";
auto end = std::chrono::steady_clock::now();
float elapsed_seconds =
std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
.count() /
1000.;
std::cout << "num threads: " << config.num_threads << "\n";
fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
float rtf = elapsed_seconds / duration;
fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n",
elapsed_seconds, duration, rtf);
return 0;
}
... ...
/**
* Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// sherpa-onnx/csrc/show-onnx-info.cc
//
// Copyright (c) 2022-2023 Xiaomi Corporation
#include <iostream>
#include <sstream>
... ...
/**
* Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// sherpa-onnx/csrc/symbol-table.cc
//
// Copyright (c) 2022-2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/symbol-table.h"
... ...
/**
* Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// sherpa-onnx/csrc/symbol-table.cc
//
// Copyright (c) 2022-2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_SYMBOL_TABLE_H_
#define SHERPA_ONNX_CSRC_SYMBOL_TABLE_H_
... ...
/**
* Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// sherpa/csrc/wave-reader.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/wave-reader.h"
... ... @@ -31,19 +17,44 @@ namespace {
// Note: We assume little endian here
// TODO(fangjun): Support big endian
struct WaveHeader {
void Validate() const {
bool Validate() const {
// F F I R
assert(chunk_id == 0x46464952);
assert(chunk_size == 36 + subchunk2_size);
if (chunk_id != 0x46464952) {
return false;
}
// E V A W
assert(format == 0x45564157);
assert(subchunk1_id == 0x20746d66);
assert(subchunk1_size == 16); // 16 for PCM
assert(audio_format == 1); // 1 for PCM
assert(num_channels == 1); // we support only single channel for now
assert(byte_rate == sample_rate * num_channels * bits_per_sample / 8);
assert(block_align == num_channels * bits_per_sample / 8);
assert(bits_per_sample == 16); // we support only 16 bits per sample
if (format != 0x45564157) {
return false;
}
if (subchunk1_id != 0x20746d66) {
return false;
}
if (subchunk1_size != 16) { // 16 for PCM
return false;
}
if (audio_format != 1) { // 1 for PCM
return false;
}
if (num_channels != 1) { // we support only single channel for now
return false;
}
if (byte_rate != (sample_rate * num_channels * bits_per_sample / 8)) {
return false;
}
if (block_align != (num_channels * bits_per_sample / 8)) {
return false;
}
if (bits_per_sample != 16) { // we support only 16 bits per sample
return false;
}
return true;
}
// See
... ... @@ -52,7 +63,7 @@ struct WaveHeader {
// https://www.robotplanet.dk/audio/wav_meta_data/riff_mci.pdf
void SeekToDataChunk(std::istream &is) {
// a t a d
while (subchunk2_id != 0x61746164) {
while (is && subchunk2_id != 0x61746164) {
// const char *p = reinterpret_cast<const char *>(&subchunk2_id);
// printf("Skip chunk (%x): %c%c%c%c of size: %d\n", subchunk2_id, p[0],
// p[1], p[2], p[3], subchunk2_size);
... ... @@ -80,44 +91,61 @@ static_assert(sizeof(WaveHeader) == 44, "");
// Read a wave file of mono-channel.
// Return its samples normalized to the range [-1, 1).
std::vector<float> ReadWaveImpl(std::istream &is, float *sample_rate) {
std::vector<float> ReadWaveImpl(std::istream &is, float expected_sample_rate,
bool *is_ok) {
WaveHeader header;
is.read(reinterpret_cast<char *>(&header), sizeof(header));
assert(static_cast<bool>(is));
header.Validate();
if (!is) {
*is_ok = false;
return {};
}
if (!header.Validate()) {
*is_ok = false;
return {};
}
header.SeekToDataChunk(is);
if (!is) {
*is_ok = false;
return {};
}
*sample_rate = header.sample_rate;
if (expected_sample_rate != header.sample_rate) {
*is_ok = false;
return {};
}
// header.subchunk2_size contains the number of bytes in the data.
// As we assume each sample contains two bytes, so it is divided by 2 here
std::vector<int16_t> samples(header.subchunk2_size / 2);
is.read(reinterpret_cast<char *>(samples.data()), header.subchunk2_size);
assert(static_cast<bool>(is));
if (!is) {
*is_ok = false;
return {};
}
std::vector<float> ans(samples.size());
for (int32_t i = 0; i != ans.size(); ++i) {
ans[i] = samples[i] / 32768.;
}
*is_ok = true;
return ans;
}
} // namespace
std::vector<float> ReadWave(const std::string &filename,
float expected_sample_rate) {
float expected_sample_rate, bool *is_ok) {
std::ifstream is(filename, std::ifstream::binary);
float sample_rate;
auto samples = ReadWaveImpl(is, &sample_rate);
if (expected_sample_rate != sample_rate) {
std::cerr << "Expected sample rate: " << expected_sample_rate
<< ". Given: " << sample_rate << ".\n";
exit(-1);
}
return ReadWave(is, expected_sample_rate, is_ok);
}
std::vector<float> ReadWave(std::istream &is, float expected_sample_rate,
bool *is_ok) {
auto samples = ReadWaveImpl(is, expected_sample_rate, is_ok);
return samples;
}
... ...
/**
* Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// sherpa/csrc/wave-reader.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_WAVE_READER_H_
#define SHERPA_ONNX_CSRC_WAVE_READER_H_
... ... @@ -30,11 +16,15 @@ namespace sherpa_onnx {
@param filename Path to a wave file. It MUST be single channel, PCM encoded.
@param expected_sample_rate Expected sample rate of the wave file. If the
sample rate don't match, it throws an exception.
@param is_ok On return it is true if the reading succeeded; false otherwise.
@return Return wave samples normalized to the range [-1, 1).
*/
std::vector<float> ReadWave(const std::string &filename,
float expected_sample_rate);
float expected_sample_rate, bool *is_ok);
std::vector<float> ReadWave(std::istream &is, float expected_sample_rate,
bool *is_ok);
} // namespace sherpa_onnx
... ...