Fangjun Kuang
Committed by GitHub

Refactor the code (#15)

* code refactoring

* Remove reference files

* Update README and CI

* small fixes

* fix style issues

* add style check for CI

* fix style issues

* remove kaldi-native-io
  1 +---
  2 +BasedOnStyle: Google
  3 +---
  4 +Language: Cpp
  5 +Cpp11BracedListStyle: true
  6 +Standard: Cpp11
  7 +DerivePointerAlignment: false
  8 +PointerAlignment: Right
  9 +---
  1 +# Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
  2 +#
  3 +# See ../../LICENSE for clarification regarding multiple authors
  4 +#
  5 +# Licensed under the Apache License, Version 2.0 (the "License");
  6 +# you may not use this file except in compliance with the License.
  7 +# You may obtain a copy of the License at
  8 +#
  9 +# http://www.apache.org/licenses/LICENSE-2.0
  10 +#
  11 +# Unless required by applicable law or agreed to in writing, software
  12 +# distributed under the License is distributed on an "AS IS" BASIS,
  13 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14 +# See the License for the specific language governing permissions and
  15 +# limitations under the License.
  16 +
  17 +name: style_check
  18 +
  19 +on:
  20 +  push:
  21 +    branches:
  22 +      - master
  23 +    paths:
  24 +      - '.github/workflows/style_check.yaml'
  25 +      - 'sherpa-onnx/**'
  26 +  pull_request:
  27 +    branches:
  28 +      - master
  29 +    paths:
  30 +      - '.github/workflows/style_check.yaml'
  31 +      - 'sherpa-onnx/**'
  32 +
  33 +concurrency:
  34 +  group: style_check-${{ github.ref }}
  35 +  cancel-in-progress: true
  36 +
  37 +jobs:
  38 +  style_check:
  39 +    runs-on: ubuntu-latest
  40 +    strategy:
  41 +      matrix:
  42 +        python-version: [3.8]
  43 +      fail-fast: false
  44 +
  45 +    steps:
  46 +      - uses: actions/checkout@v2
  47 +        with:
  48 +          fetch-depth: 0
  49 +
  50 +      - name: Setup Python ${{ matrix.python-version }}
  51 +        uses: actions/setup-python@v1
  52 +        with:
  53 +          python-version: ${{ matrix.python-version }}
  54 +
  55 +      - name: Check style with cpplint
  56 +        shell: bash
  57 +        working-directory: ${{github.workspace}}
  58 +        run: ./scripts/check_style_cpplint.sh
@@ -59,31 +59,28 @@ jobs:
59 59       - name: Run tests for ubuntu/macos (English)
60 60         run: |
61 61           time ./build/bin/sherpa-onnx \
   62 +          ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
62 63           ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
63 64           ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
64 65           ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
65 66           ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
66 67           ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
67    -          ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
68    -          greedy \
69 68           ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav
70 69
71 70           time ./build/bin/sherpa-onnx \
   71 +          ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
72 72           ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
73 73           ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
74 74           ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
75 75           ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
76 76           ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
77    -          ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
78    -          greedy \
79 77           ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0001.wav
80 78
81 79           time ./build/bin/sherpa-onnx \
   80 +          ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
82 81           ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
83 82           ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
84 83           ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
85 84           ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
86 85           ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
87    -          ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
88    -          greedy \
89 86           ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0002.wav
@@ -38,7 +38,6 @@ set(CMAKE_CXX_EXTENSIONS OFF)
38 38 list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules)
39 39 list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake)
40 40
41    -include(kaldi_native_io)
42 41 include(kaldi-native-fbank)
43 42 include(onnxruntime)
44 43
@@ -14,6 +14,9 @@ the following links:
14 14 **NOTE**: We provide only non-streaming models at present.
15 15
16 16
   17 +**HINT**: The script for exporting the English model can be found at
   18 +<https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless3/export.py>
   19 +
17 20 # Usage
18 21
19 22 ```bash
@@ -34,13 +37,14 @@ cd ..
34 37 git lfs install
35 38 git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
36 39
   40 +./build/bin/sherpa-onnx --help
   41 +
37 42 ./build/bin/sherpa-onnx \
   43 +  ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
38 44   ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
39 45   ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
40 46   ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
41 47   ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
42 48   ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
43    -  ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
44    -  greedy \
45 49   ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav
46 50 ```
1 -function(download_kaldi_native_io)  
2 - if(CMAKE_VERSION VERSION_LESS 3.11)  
3 - # FetchContent is available since 3.11,  
4 - # we've copied it to ${CMAKE_SOURCE_DIR}/cmake/Modules  
5 - # so that it can be used in lower CMake versions.  
6 - message(STATUS "Use FetchContent provided by sherpa-onnx")  
7 - list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules)  
8 - endif()  
9 -  
10 - include(FetchContent)  
11 -  
12 - set(kaldi_native_io_URL "https://github.com/csukuangfj/kaldi_native_io/archive/refs/tags/v1.15.1.tar.gz")  
13 - set(kaldi_native_io_HASH "SHA256=97377e1d61e99d8fc1d6037a418d3037522dfa46337e06162e24b1d97f3d70a6")  
14 -  
15 - set(KALDI_NATIVE_IO_BUILD_TESTS OFF CACHE BOOL "" FORCE)  
16 - set(KALDI_NATIVE_IO_BUILD_PYTHON OFF CACHE BOOL "" FORCE)  
17 -  
18 - FetchContent_Declare(kaldi_native_io  
19 - URL ${kaldi_native_io_URL}  
20 - URL_HASH ${kaldi_native_io_HASH}  
21 - )  
22 -  
23 - FetchContent_GetProperties(kaldi_native_io)  
24 - if(NOT kaldi_native_io_POPULATED)  
25 - message(STATUS "Downloading kaldi_native_io ${kaldi_native_io_URL}")  
26 - FetchContent_Populate(kaldi_native_io)  
27 - endif()  
28 - message(STATUS "kaldi_native_io is downloaded to ${kaldi_native_io_SOURCE_DIR}")  
29 - message(STATUS "kaldi_native_io's binary dir is ${kaldi_native_io_BINARY_DIR}")  
30 -  
31 - add_subdirectory(${kaldi_native_io_SOURCE_DIR} ${kaldi_native_io_BINARY_DIR} EXCLUDE_FROM_ALL)  
32 -  
33 - target_include_directories(kaldi_native_io_core  
34 - PUBLIC  
35 - ${kaldi_native_io_SOURCE_DIR}/  
36 - )  
37 -endfunction()  
38 -  
39 -download_kaldi_native_io()  
@@ -10,7 +10,7 @@ function(download_onnxruntime)
10 10   include(FetchContent)
11 11
12 12   if(UNIX AND NOT APPLE)
13    -    set(onnxruntime_URL "http://github.com/microsoft/onnxruntime/releases/download/v1.12.1/onnxruntime-linux-x64-1.12.1.tgz")
   13 +    set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.12.1/onnxruntime-linux-x64-1.12.1.tgz")
14 14
15 15   # If you don't have access to the internet, you can first download onnxruntime to some directory, and the use
16 16   # set(onnxruntime_URL "file:///ceph-fj/fangjun/open-source/sherpa-onnx/onnxruntime-linux-x64-1.12.1.tgz")
  1 +#!/bin/bash
  2 +#
  3 +# Copyright 2020 Mobvoi Inc. (authors: Fangjun Kuang)
  4 +#
  5 +# Licensed under the Apache License, Version 2.0 (the "License");
  6 +# you may not use this file except in compliance with the License.
  7 +# You may obtain a copy of the License at
  8 +#
  9 +# http://www.apache.org/licenses/LICENSE-2.0
  10 +#
  11 +# Unless required by applicable law or agreed to in writing, software
  12 +# distributed under the License is distributed on an "AS IS" BASIS,
  13 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14 +# See the License for the specific language governing permissions and
  15 +# limitations under the License.
  16 +#
  17 +# Usage:
  18 +#
  19 +# (1) To check files of the last commit
  20 +# ./scripts/check_style_cpplint.sh
  21 +#
  22 +# (2) To check changed files not committed yet
  23 +# ./scripts/check_style_cpplint.sh 1
  24 +#
  25 +# (3) To check all files in the project
  26 +# ./scripts/check_style_cpplint.sh 2
  27 +
  28 +
  29 +cpplint_version="1.5.4"
  30 +cur_dir=$(cd $(dirname $BASH_SOURCE) && pwd)
  31 +sherpa_onnx_dir=$(cd $cur_dir/.. && pwd)
  32 +
  33 +build_dir=$sherpa_onnx_dir/build
  34 +mkdir -p $build_dir
  35 +
  36 +cpplint_src=$build_dir/cpplint-${cpplint_version}/cpplint.py
  37 +
  38 +if [ ! -d "$build_dir/cpplint-${cpplint_version}" ]; then
  39 + pushd $build_dir
  40 + if command -v wget &> /dev/null; then
  41 + wget https://github.com/cpplint/cpplint/archive/${cpplint_version}.tar.gz
  42 + elif command -v curl &> /dev/null; then
  43 + curl -O -SL https://github.com/cpplint/cpplint/archive/${cpplint_version}.tar.gz
  44 + else
  45 + echo "Please install wget or curl to download cpplint"
  46 + exit 1
  47 + fi
  48 + tar xf ${cpplint_version}.tar.gz
  49 + rm ${cpplint_version}.tar.gz
  50 +
  51 + # cpplint will report the following error for: __host__ __device__ (
  52 + #
  53 + # Extra space before ( in function call [whitespace/parens] [4]
  54 + #
  55 + # the following patch disables the above error
  56 + sed -i "3490i\ not Search(r'__host__ __device__\\\s+\\\(', fncall) and" $cpplint_src
  57 + popd
  58 +fi
  59 +
  60 +source $sherpa_onnx_dir/scripts/utils.sh
  61 +
  62 +# return true if the given file is a c++ source file
  63 +# return false otherwise
  64 +function is_source_code_file() {
  65 + case "$1" in
  66 + *.cc|*.h|*.cu)
  67 + echo true;;
  68 + *)
  69 + echo false;;
  70 + esac
  71 +}
  72 +
  73 +function check_style() {
  74 + python3 $cpplint_src $1 || abort $1
  75 +}
  76 +
  77 +function check_last_commit() {
  78 + files=$(git diff HEAD^1 --name-only --diff-filter=ACDMRUXB)
  79 + echo $files
  80 +}
  81 +
  82 +function check_current_dir() {
  83 + files=$(git status -s -uno --porcelain | awk '{
  84 + if (NF == 4) {
  85 + # a file has been renamed
  86 + print $NF
  87 + } else {
  88 + print $2
  89 + }}')
  90 +
  91 + echo $files
  92 +}
  93 +
  94 +function do_check() {
  95 + case "$1" in
  96 + 1)
  97 + echo "Check changed files"
  98 + files=$(check_current_dir)
  99 + ;;
  100 + 2)
  101 + echo "Check all files"
  102 + files=$(find $sherpa_onnx_dir/sherpa-onnx -name "*.h" -o -name "*.cc")
  103 + ;;
  104 + *)
  105 + echo "Check last commit"
  106 + files=$(check_last_commit)
  107 + ;;
  108 + esac
  109 +
  110 + for f in $files; do
  111 + need_check=$(is_source_code_file $f)
  112 + if $need_check; then
  113 + [[ -f $f ]] && check_style $f
  114 + fi
  115 + done
  116 +}
  117 +
  118 +function main() {
  119 + do_check $1
  120 +
  121 + ok "Great! Style check passed!"
  122 +}
  123 +
  124 +cd $sherpa_onnx_dir
  125 +
  126 +main $1
  1 +#!/bin/bash
  2 +
  3 +default='\033[0m'
  4 +bold='\033[1m'
  5 +red='\033[31m'
  6 +green='\033[32m'
  7 +
  8 +function ok() {
  9 + printf "${bold}${green}[OK]${default} $1\n"
  10 +}
  11 +
  12 +function error() {
  13 + printf "${bold}${red}[FAILED]${default} $1\n"
  14 +}
  15 +
  16 +function abort() {
  17 + printf "${bold}${red}[FAILED]${default} $1\n"
  18 + exit 1
  19 +}
1 1  include_directories(${CMAKE_SOURCE_DIR})
2    -add_executable(sherpa-onnx main.cpp)
   2 +
   3 +add_executable(sherpa-onnx
   4 +  decode.cc
   5 +  rnnt-model.cc
   6 +  sherpa-onnx.cc
   7 +  symbol-table.cc
   8 +  wave-reader.cc
   9 +)
3 10
4 11 target_link_libraries(sherpa-onnx
5 12   onnxruntime
6 13   kaldi-native-fbank-core
7    -  kaldi_native_io_core
8 14 )
  15 +
  16 +add_executable(sherpa-show-onnx-info show-onnx-info.cc)
  17 +target_link_libraries(sherpa-show-onnx-info onnxruntime)
  1 +/**
  2 + * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
  3 + *
  4 + * See LICENSE for clarification regarding multiple authors
  5 + *
  6 + * Licensed under the Apache License, Version 2.0 (the "License");
  7 + * you may not use this file except in compliance with the License.
  8 + * You may obtain a copy of the License at
  9 + *
  10 + * http://www.apache.org/licenses/LICENSE-2.0
  11 + *
  12 + * Unless required by applicable law or agreed to in writing, software
  13 + * distributed under the License is distributed on an "AS IS" BASIS,
  14 + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15 + * See the License for the specific language governing permissions and
  16 + * limitations under the License.
  17 + */
  18 +
  19 +#include "sherpa-onnx/csrc/decode.h"
  20 +
  21 +#include <assert.h>
  22 +
  23 +#include <algorithm>
  24 +#include <vector>
  25 +
  26 +namespace sherpa_onnx {
  27 +
  28 +std::vector<int32_t> GreedySearch(RnntModel &model, // NOLINT
  29 + const Ort::Value &encoder_out) {
  30 + std::vector<int64_t> encoder_out_shape =
  31 + encoder_out.GetTensorTypeAndShapeInfo().GetShape();
  32 + assert(encoder_out_shape[0] == 1 && "Only batch_size=1 is implemented");
  33 + Ort::Value projected_encoder_out =
  34 + model.RunJoinerEncoderProj(encoder_out.GetTensorData<float>(),
  35 + encoder_out_shape[1], encoder_out_shape[2]);
  36 +
  37 + const float *p_projected_encoder_out =
  38 + projected_encoder_out.GetTensorData<float>();
  39 +
  40 + int32_t context_size = 2; // hard-code it to 2
  41 + int32_t blank_id = 0; // hard-code it to 0
  42 + std::vector<int32_t> hyp(context_size, blank_id);
  43 + std::array<int64_t, 2> decoder_input{blank_id, blank_id};
  44 +
  45 + Ort::Value decoder_out = model.RunDecoder(decoder_input.data(), context_size);
  46 +
  47 + std::vector<int64_t> decoder_out_shape =
  48 + decoder_out.GetTensorTypeAndShapeInfo().GetShape();
  49 +
  50 + Ort::Value projected_decoder_out = model.RunJoinerDecoderProj(
  51 + decoder_out.GetTensorData<float>(), decoder_out_shape[2]);
  52 +
  53 + int32_t joiner_dim =
  54 + projected_decoder_out.GetTensorTypeAndShapeInfo().GetShape()[1];
  55 +
  56 + int32_t T = encoder_out_shape[1];
  57 + for (int32_t t = 0; t != T; ++t) {
  58 + Ort::Value logit = model.RunJoiner(
  59 + p_projected_encoder_out + t * joiner_dim,
  60 + projected_decoder_out.GetTensorData<float>(), joiner_dim);
  61 +
  62 + int32_t vocab_size = logit.GetTensorTypeAndShapeInfo().GetShape()[1];
  63 +
  64 + const float *p_logit = logit.GetTensorData<float>();
  65 +
  66 + auto y = static_cast<int32_t>(std::distance(
  67 + static_cast<const float *>(p_logit),
  68 + std::max_element(static_cast<const float *>(p_logit),
  69 + static_cast<const float *>(p_logit) + vocab_size)));
  70 +
  71 + if (y != blank_id) {
  72 + decoder_input[0] = hyp.back();
  73 + decoder_input[1] = y;
  74 + hyp.push_back(y);
  75 + decoder_out = model.RunDecoder(decoder_input.data(), context_size);
  76 + projected_decoder_out = model.RunJoinerDecoderProj(
  77 + decoder_out.GetTensorData<float>(), decoder_out_shape[2]);
  78 + }
  79 + }
  80 +
  81 + return {hyp.begin() + context_size, hyp.end()};
  82 +}
  83 +
  84 +} // namespace sherpa_onnx
  1 +/**
  2 + * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
  3 + *
  4 + * See LICENSE for clarification regarding multiple authors
  5 + *
  6 + * Licensed under the Apache License, Version 2.0 (the "License");
  7 + * you may not use this file except in compliance with the License.
  8 + * You may obtain a copy of the License at
  9 + *
  10 + * http://www.apache.org/licenses/LICENSE-2.0
  11 + *
  12 + * Unless required by applicable law or agreed to in writing, software
  13 + * distributed under the License is distributed on an "AS IS" BASIS,
  14 + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15 + * See the License for the specific language governing permissions and
  16 + * limitations under the License.
  17 + */
  18 +
  19 +#ifndef SHERPA_ONNX_CSRC_DECODE_H_
  20 +#define SHERPA_ONNX_CSRC_DECODE_H_
  21 +
  22 +#include <vector>
  23 +
  24 +#include "sherpa-onnx/csrc/rnnt-model.h"
  25 +
  26 +namespace sherpa_onnx {
  27 +
  28 +/** Greedy search for non-streaming ASR.
  29 + *
  30 + * @TODO(fangjun) Support batch size > 1
  31 + *
  32 + * @param model The RnntModel
  33 + * @param encoder_out Its shape is (1, num_frames, encoder_out_dim).
  34 + */
  35 +std::vector<int32_t> GreedySearch(RnntModel &model, // NOLINT
  36 + const Ort::Value &encoder_out);
  37 +
  38 +} // namespace sherpa_onnx
  39 +
  40 +#endif // SHERPA_ONNX_CSRC_DECODE_H_
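
For reference, here is a minimal sketch of how `GreedySearch` is driven from caller code, based only on the signatures added in this commit (`RnntModel::RunEncoder` and `GreedySearch`). The fbank computation and the token-id-to-text mapping are left out because `symbol-table.h` is not shown in this diff; the row-major feature layout is the one documented in `sherpa-onnx.cc` below.

```cpp
// Sketch only: assumes `features` holds row-major fbank frames of shape
// (num_frames, feature_dim) and that an RnntModel has already been created.
#include <cstdint>
#include <vector>

#include "sherpa-onnx/csrc/decode.h"
#include "sherpa-onnx/csrc/rnnt-model.h"

std::vector<int32_t> Decode(sherpa_onnx::RnntModel &model,
                            const std::vector<float> &features,
                            int32_t feature_dim) {
  int32_t num_frames = static_cast<int32_t>(features.size()) / feature_dim;

  // Encoder output has shape (1, num_frames', encoder_out_dim).
  Ort::Value encoder_out =
      model.RunEncoder(features.data(), num_frames, feature_dim);

  // Returns the decoded token ids without the leading context_size blanks.
  return sherpa_onnx::GreedySearch(model, encoder_out);
}
```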
1 -#include <iostream>  
2 -  
3 -#include "kaldi_native_io/csrc/kaldi-io.h"  
4 -#include "kaldi_native_io/csrc/wave-reader.h"  
5 -#include "kaldi-native-fbank/csrc/online-feature.h"  
6 -  
7 -  
8 -kaldiio::Matrix<float> readWav(std::string filename, bool log = false){  
9 - if (log)  
10 - std::cout << "reading " << filename << std::endl;  
11 -  
12 - bool binary = true;  
13 - kaldiio::Input ki(filename, &binary);  
14 - kaldiio::WaveHolder wh;  
15 -  
16 - if (!wh.Read(ki.Stream())) {  
17 - std::cerr << "Failed to read " << filename;  
18 - exit(EXIT_FAILURE);  
19 - }  
20 -  
21 - auto &wave_data = wh.Value();  
22 - auto &d = wave_data.Data();  
23 -  
24 - if (log)  
25 - std::cout << "wav shape: " << "(" << d.NumRows() << "," << d.NumCols() << ")" << std::endl;  
26 -  
27 - return d;  
28 -}  
29 -  
30 -  
31 -std::vector<float> ComputeFeatures(knf::OnlineFbank &fbank, knf::FbankOptions opts, kaldiio::Matrix<float> samples, bool log = false){  
32 - int numSamples = samples.NumCols();  
33 -  
34 - for (int i = 0; i < numSamples; i++)  
35 - {  
36 - float currentSample = samples.Row(0).Data()[i] / 32768;  
37 - fbank.AcceptWaveform(opts.frame_opts.samp_freq, &currentSample, 1);  
38 - }  
39 -  
40 - std::vector<float> features;  
41 - int32_t num_frames = fbank.NumFramesReady();  
42 - for (int32_t i = 0; i != num_frames; ++i) {  
43 - const float *frame = fbank.GetFrame(i);  
44 - for (int32_t k = 0; k != opts.mel_opts.num_bins; ++k) {  
45 - features.push_back(frame[k]);  
46 - }  
47 - }  
48 - if (log){  
49 - std::cout << "done feature extraction" << std::endl;  
50 - std::cout << "extracted fbank shape " << "(" << num_frames << "," << opts.mel_opts.num_bins << ")" << std::endl;  
51 -  
52 - for (int i=0; i< 20; i++)  
53 - std::cout << features.at(i) << std::endl;  
54 - }  
55 -  
56 - return features;  
57 -}  
1 -#include <algorithm>  
2 -#include <fstream>  
3 -#include <iostream>  
4 -#include <math.h>  
5 -#include <time.h>  
6 -#include <vector>  
7 -  
8 -#include "sherpa-onnx/csrc/fbank_features.h"  
9 -#include "sherpa-onnx/csrc/rnnt_beam_search.h"  
10 -  
11 -#include "kaldi-native-fbank/csrc/online-feature.h"  
12 -  
13 -int main(int argc, char *argv[]) {  
14 - char *encoder_path = argv[1];  
15 - char *decoder_path = argv[2];  
16 - char *joiner_path = argv[3];  
17 - char *joiner_encoder_proj_path = argv[4];  
18 - char *joiner_decoder_proj_path = argv[5];  
19 - char *token_path = argv[6];  
20 - std::string search_method = argv[7];  
21 - char *filename = argv[8];  
22 -  
23 - // General parameters  
24 - int numberOfThreads = 16;  
25 -  
26 - // Initialize fbanks  
27 - knf::FbankOptions opts;  
28 - opts.frame_opts.dither = 0;  
29 - opts.frame_opts.samp_freq = 16000;  
30 - opts.frame_opts.frame_shift_ms = 10.0f;  
31 - opts.frame_opts.frame_length_ms = 25.0f;  
32 - opts.mel_opts.num_bins = 80;  
33 - opts.frame_opts.window_type = "povey";  
34 - opts.frame_opts.snip_edges = false;  
35 - knf::OnlineFbank fbank(opts);  
36 -  
37 - // set session opts  
38 - // https://onnxruntime.ai/docs/performance/tune-performance.html  
39 - session_options.SetIntraOpNumThreads(numberOfThreads);  
40 - session_options.SetInterOpNumThreads(numberOfThreads);  
41 - session_options.SetGraphOptimizationLevel(  
42 - GraphOptimizationLevel::ORT_ENABLE_EXTENDED);  
43 - session_options.SetLogSeverityLevel(4);  
44 - session_options.SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL);  
45 -  
46 - api.CreateTensorRTProviderOptions(&tensorrt_options);  
47 - std::unique_ptr<OrtTensorRTProviderOptionsV2,  
48 - decltype(api.ReleaseTensorRTProviderOptions)>  
49 - rel_trt_options(tensorrt_options, api.ReleaseTensorRTProviderOptions);  
50 - api.SessionOptionsAppendExecutionProvider_TensorRT_V2(  
51 - static_cast<OrtSessionOptions *>(session_options), rel_trt_options.get());  
52 -  
53 - // Define model  
54 - auto model =  
55 - get_model(encoder_path, decoder_path, joiner_path,  
56 - joiner_encoder_proj_path, joiner_decoder_proj_path, token_path);  
57 -  
58 - std::vector<std::string> filename_list{filename};  
59 -  
60 - for (auto filename : filename_list) {  
61 - std::cout << filename << std::endl;  
62 - auto samples = readWav(filename, true);  
63 - int numSamples = samples.NumCols();  
64 -  
65 - auto features = ComputeFeatures(fbank, opts, samples);  
66 -  
67 - auto tic = std::chrono::high_resolution_clock::now();  
68 -  
69 - // # === Encoder Out === #  
70 - int num_frames = features.size() / opts.mel_opts.num_bins;  
71 - auto encoder_out =  
72 - model.encoder_forward(features, std::vector<int64_t>{num_frames},  
73 - std::vector<int64_t>{1, num_frames, 80},  
74 - std::vector<int64_t>{1}, memory_info);  
75 -  
76 - // # === Search === #  
77 - std::vector<std::vector<int32_t>> hyps;  
78 - if (search_method == "greedy")  
79 - hyps = GreedySearch(&model, &encoder_out);  
80 - else {  
81 - std::cout << "wrong search method!" << std::endl;  
82 - exit(0);  
83 - }  
84 - auto results = hyps2result(model.tokens_map, hyps);  
85 -  
86 - // # === Print Elapsed Time === #  
87 - auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(  
88 - std::chrono::high_resolution_clock::now() - tic);  
89 - std::cout << "Elapsed: " << float(elapsed.count()) / 1000 << " seconds"  
90 - << std::endl;  
91 - std::cout << "rtf: " << float(elapsed.count()) / 1000 / (numSamples / 16000)  
92 - << std::endl;  
93 -  
94 - print_hyps(hyps);  
95 - std::cout << results[0] << std::endl;  
96 - }  
97 -  
98 - return 0;  
99 -}  
1 -#include <map>  
2 -#include <vector>  
3 -#include <iostream>  
4 -#include <algorithm>  
5 -#include <sys/stat.h>  
6 -  
7 -#include "utils_onnx.h"  
8 -  
9 -  
10 -struct Model  
11 -{  
12 - public:  
13 - const char* encoder_path;  
14 - const char* decoder_path;  
15 - const char* joiner_path;  
16 - const char* joiner_encoder_proj_path;  
17 - const char* joiner_decoder_proj_path;  
18 - const char* tokens_path;  
19 -  
20 - Ort::Session encoder = load_model(encoder_path);  
21 - Ort::Session decoder = load_model(decoder_path);  
22 - Ort::Session joiner = load_model(joiner_path);  
23 - Ort::Session joiner_encoder_proj = load_model(joiner_encoder_proj_path);  
24 - Ort::Session joiner_decoder_proj = load_model(joiner_decoder_proj_path);  
25 - std::map<int, std::string> tokens_map = get_token_map(tokens_path);  
26 -  
27 - int32_t blank_id;  
28 - int32_t unk_id;  
29 - int32_t context_size;  
30 -  
31 - std::vector<Ort::Value> encoder_forward(std::vector<float> in_vector,  
32 - std::vector<int64_t> in_vector_length,  
33 - std::vector<int64_t> feature_dims,  
34 - std::vector<int64_t> feature_length_dims,  
35 - Ort::MemoryInfo &memory_info){  
36 - std::vector<Ort::Value> encoder_inputTensors;  
37 - encoder_inputTensors.push_back(Ort::Value::CreateTensor<float>(memory_info, in_vector.data(), in_vector.size(), feature_dims.data(), feature_dims.size()));  
38 - encoder_inputTensors.push_back(Ort::Value::CreateTensor<int64_t>(memory_info, in_vector_length.data(), in_vector_length.size(), feature_length_dims.data(), feature_length_dims.size()));  
39 -  
40 - std::vector<const char*> encoder_inputNames = {encoder.GetInputName(0, allocator), encoder.GetInputName(1, allocator)};  
41 - std::vector<const char*> encoder_outputNames = {encoder.GetOutputName(0, allocator)};  
42 -  
43 - auto out = encoder.Run(Ort::RunOptions{nullptr},  
44 - encoder_inputNames.data(),  
45 - encoder_inputTensors.data(),  
46 - encoder_inputTensors.size(),  
47 - encoder_outputNames.data(),  
48 - encoder_outputNames.size());  
49 - return out;  
50 - }  
51 -  
52 - std::vector<Ort::Value> decoder_forward(std::vector<int64_t> in_vector,  
53 - std::vector<int64_t> dims,  
54 - Ort::MemoryInfo &memory_info){  
55 - std::vector<Ort::Value> inputTensors;  
56 - inputTensors.push_back(Ort::Value::CreateTensor<int64_t>(memory_info, in_vector.data(), in_vector.size(), dims.data(), dims.size()));  
57 -  
58 - std::vector<const char*> inputNames {decoder.GetInputName(0, allocator)};  
59 - std::vector<const char*> outputNames {decoder.GetOutputName(0, allocator)};  
60 -  
61 - auto out = decoder.Run(Ort::RunOptions{nullptr},  
62 - inputNames.data(),  
63 - inputTensors.data(),  
64 - inputTensors.size(),  
65 - outputNames.data(),  
66 - outputNames.size());  
67 -  
68 - return out;  
69 - }  
70 -  
71 - std::vector<Ort::Value> joiner_forward(std::vector<float> projected_encoder_out,  
72 - std::vector<float> decoder_out,  
73 - std::vector<int64_t> projected_encoder_out_dims,  
74 - std::vector<int64_t> decoder_out_dims,  
75 - Ort::MemoryInfo &memory_info){  
76 - std::vector<Ort::Value> inputTensors;  
77 - inputTensors.push_back(Ort::Value::CreateTensor<float>(memory_info, projected_encoder_out.data(), projected_encoder_out.size(), projected_encoder_out_dims.data(), projected_encoder_out_dims.size()));  
78 - inputTensors.push_back(Ort::Value::CreateTensor<float>(memory_info, decoder_out.data(), decoder_out.size(), decoder_out_dims.data(), decoder_out_dims.size()));  
79 - std::vector<const char*> inputNames = {joiner.GetInputName(0, allocator), joiner.GetInputName(1, allocator)};  
80 - std::vector<const char*> outputNames = {joiner.GetOutputName(0, allocator)};  
81 -  
82 - auto out = joiner.Run(Ort::RunOptions{nullptr},  
83 - inputNames.data(),  
84 - inputTensors.data(),  
85 - inputTensors.size(),  
86 - outputNames.data(),  
87 - outputNames.size());  
88 -  
89 - return out;  
90 - }  
91 -  
92 - std::vector<Ort::Value> joiner_encoder_proj_forward(std::vector<float> in_vector,  
93 - std::vector<int64_t> dims,  
94 - Ort::MemoryInfo &memory_info){  
95 - std::vector<Ort::Value> inputTensors;  
96 - inputTensors.push_back(Ort::Value::CreateTensor<float>(memory_info, in_vector.data(), in_vector.size(), dims.data(), dims.size()));  
97 -  
98 - std::vector<const char*> inputNames {joiner_encoder_proj.GetInputName(0, allocator)};  
99 - std::vector<const char*> outputNames {joiner_encoder_proj.GetOutputName(0, allocator)};  
100 -  
101 - auto out = joiner_encoder_proj.Run(Ort::RunOptions{nullptr},  
102 - inputNames.data(),  
103 - inputTensors.data(),  
104 - inputTensors.size(),  
105 - outputNames.data(),  
106 - outputNames.size());  
107 -  
108 - return out;  
109 - }  
110 -  
111 - std::vector<Ort::Value> joiner_decoder_proj_forward(std::vector<float> in_vector,  
112 - std::vector<int64_t> dims,  
113 - Ort::MemoryInfo &memory_info){  
114 - std::vector<Ort::Value> inputTensors;  
115 - inputTensors.push_back(Ort::Value::CreateTensor<float>(memory_info, in_vector.data(), in_vector.size(), dims.data(), dims.size()));  
116 -  
117 - std::vector<const char*> inputNames {joiner_decoder_proj.GetInputName(0, allocator)};  
118 - std::vector<const char*> outputNames {joiner_decoder_proj.GetOutputName(0, allocator)};  
119 -  
120 - auto out = joiner_decoder_proj.Run(Ort::RunOptions{nullptr},  
121 - inputNames.data(),  
122 - inputTensors.data(),  
123 - inputTensors.size(),  
124 - outputNames.data(),  
125 - outputNames.size());  
126 -  
127 - return out;  
128 - }  
129 -  
130 - Ort::Session load_model(const char* path){  
131 - struct stat buffer;  
132 - if (stat(path, &buffer) != 0){  
133 - std::cout << "File does not exist!: " << path << std::endl;  
134 - exit(0);  
135 - }  
136 - std::cout << "loading " << path << std::endl;  
137 - Ort::Session onnx_model(env, path, session_options);  
138 - return onnx_model;  
139 - }  
140 -  
141 - void extract_constant_lm_parameters(){  
142 - /*  
143 - all_in_one contains these params. We should trace all_in_one and find 'constants_lm' nodes to extract these params  
144 - For now, these params are set staticaly.  
145 - in: Ort::Session &all_in_one  
146 - out: {blank_id, unk_id, context_size}  
147 - should return std::vector<int32_t>  
148 - */  
149 - blank_id = 0;  
150 - unk_id = 0;  
151 - context_size = 2;  
152 - }  
153 -  
154 - std::map<int, std::string> get_token_map(const char* token_path){  
155 - std::ifstream inFile;  
156 - inFile.open(token_path);  
157 - if (inFile.fail())  
158 - std::cerr << "Could not find token file" << std::endl;  
159 -  
160 - std::map<int, std::string> token_map;  
161 -  
162 - std::string line;  
163 - while (std::getline(inFile, line))  
164 - {  
165 - int id;  
166 - std::string token;  
167 -  
168 - std::istringstream iss(line);  
169 - iss >> token;  
170 - iss >> id;  
171 -  
172 - token_map[id] = token;  
173 - }  
174 -  
175 - return token_map;  
176 - }  
177 -  
178 -};  
179 -  
180 -  
181 -Model get_model(std::string exp_path, char* tokens_path){  
182 - Model model{  
183 - (exp_path + "/encoder_simp.onnx").c_str(),  
184 - (exp_path + "/decoder_simp.onnx").c_str(),  
185 - (exp_path + "/joiner_simp.onnx").c_str(),  
186 - (exp_path + "/joiner_encoder_proj_simp.onnx").c_str(),  
187 - (exp_path + "/joiner_decoder_proj_simp.onnx").c_str(),  
188 - tokens_path,  
189 - };  
190 - model.extract_constant_lm_parameters();  
191 -  
192 - return model;  
193 -}  
194 -  
195 -Model get_model(char* encoder_path,  
196 - char* decoder_path,  
197 - char* joiner_path,  
198 - char* joiner_encoder_proj_path,  
199 - char* joiner_decoder_proj_path,  
200 - char* tokens_path){  
201 - Model model{  
202 - encoder_path,  
203 - decoder_path,  
204 - joiner_path,  
205 - joiner_encoder_proj_path,  
206 - joiner_decoder_proj_path,  
207 - tokens_path,  
208 - };  
209 - model.extract_constant_lm_parameters();  
210 -  
211 - return model;  
212 -}  
213 -  
214 -  
215 -void doWarmup(Model *model, int numWarmup = 5){  
216 - std::cout << "Warmup is started" << std::endl;  
217 -  
218 - std::vector<float> encoder_warmup_sample (500 * 80, 1.0);  
219 - for (int i=0; i<numWarmup; i++)  
220 - auto encoder_out = model->encoder_forward(encoder_warmup_sample,  
221 - std::vector<int64_t> {500},  
222 - std::vector<int64_t> {1, 500, 80},  
223 - std::vector<int64_t> {1},  
224 - memory_info);  
225 -  
226 - std::vector<int64_t> decoder_warmup_sample {1, 1};  
227 - for (int i=0; i<numWarmup; i++)  
228 - auto decoder_out = model->decoder_forward(decoder_warmup_sample,  
229 - std::vector<int64_t> {1, 2},  
230 - memory_info);  
231 -  
232 - std::vector<float> joiner_warmup_sample1 (512, 1.0);  
233 - std::vector<float> joiner_warmup_sample2 (512, 1.0);  
234 - for (int i=0; i<numWarmup; i++)  
235 - auto logits = model->joiner_forward(joiner_warmup_sample1,  
236 - joiner_warmup_sample2,  
237 - std::vector<int64_t> {1, 1, 1, 512},  
238 - std::vector<int64_t> {1, 1, 1, 512},  
239 - memory_info);  
240 -  
241 - std::vector<float> joiner_encoder_proj_warmup_sample (100 * 512, 1.0);  
242 - for (int i=0; i<numWarmup; i++)  
243 - auto projected_encoder_out = model->joiner_encoder_proj_forward(joiner_encoder_proj_warmup_sample,  
244 - std::vector<int64_t> {100, 512},  
245 - memory_info);  
246 -  
247 - std::vector<float> joiner_decoder_proj_warmup_sample (512, 1.0);  
248 - for (int i=0; i<numWarmup; i++)  
249 - auto projected_decoder_out = model->joiner_decoder_proj_forward(joiner_decoder_proj_warmup_sample,  
250 - std::vector<int64_t> {1, 512},  
251 - memory_info);  
252 - std::cout << "Warmup is done" << std::endl;  
253 -}  
  1 +/**
  2 + * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
  3 + *
  4 + * See LICENSE for clarification regarding multiple authors
  5 + *
  6 + * Licensed under the Apache License, Version 2.0 (the "License");
  7 + * you may not use this file except in compliance with the License.
  8 + * You may obtain a copy of the License at
  9 + *
  10 + * http://www.apache.org/licenses/LICENSE-2.0
  11 + *
  12 + * Unless required by applicable law or agreed to in writing, software
  13 + * distributed under the License is distributed on an "AS IS" BASIS,
  14 + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15 + * See the License for the specific language governing permissions and
  16 + * limitations under the License.
  17 + */
  18 +#include "sherpa-onnx/csrc/rnnt-model.h"
  19 +
  20 +#include <array>
  21 +#include <utility>
  22 +#include <vector>
  23 +
  24 +namespace sherpa_onnx {
  25 +
  26 +/**
  27 + * Get the input names of a model.
  28 + *
  29 + * @param sess An onnxruntime session.
  30 + * @param input_names. On return, it contains the input names of the model.
  31 + * @param input_names_ptr. On return, input_names_ptr[i] contains
  32 + * input_names[i].c_str()
  33 + */
  34 +static void GetInputNames(Ort::Session *sess,
  35 + std::vector<std::string> *input_names,
  36 + std::vector<const char *> *input_names_ptr) {
  37 + Ort::AllocatorWithDefaultOptions allocator;
  38 + size_t node_count = sess->GetInputCount();
  39 + input_names->resize(node_count);
  40 + input_names_ptr->resize(node_count);
  41 + for (size_t i = 0; i != node_count; ++i) {
  42 + auto tmp = sess->GetInputNameAllocated(i, allocator);
  43 + (*input_names)[i] = tmp.get();
  44 + (*input_names_ptr)[i] = (*input_names)[i].c_str();
  45 + }
  46 +}
  47 +
  48 +/**
  49 + * Get the output names of a model.
  50 + *
  51 + * @param sess An onnxruntime session.
  52 + * @param output_names. On return, it contains the output names of the model.
  53 + * @param output_names_ptr. On return, output_names_ptr[i] contains
  54 + * output_names[i].c_str()
  55 + */
  56 +static void GetOutputNames(Ort::Session *sess,
  57 + std::vector<std::string> *output_names,
  58 + std::vector<const char *> *output_names_ptr) {
  59 + Ort::AllocatorWithDefaultOptions allocator;
  60 + size_t node_count = sess->GetOutputCount();
  61 + output_names->resize(node_count);
  62 + output_names_ptr->resize(node_count);
  63 + for (size_t i = 0; i != node_count; ++i) {
  64 + auto tmp = sess->GetOutputNameAllocated(i, allocator);
  65 + (*output_names)[i] = tmp.get();
  66 + (*output_names_ptr)[i] = (*output_names)[i].c_str();
  67 + }
  68 +}
  69 +
  70 +RnntModel::RnntModel(const std::string &encoder_filename,
  71 + const std::string &decoder_filename,
  72 + const std::string &joiner_filename,
  73 + const std::string &joiner_encoder_proj_filename,
  74 + const std::string &joiner_decoder_proj_filename,
  75 + int32_t num_threads)
  76 + : env_(ORT_LOGGING_LEVEL_WARNING) {
  77 + sess_opts_.SetIntraOpNumThreads(num_threads);
  78 + sess_opts_.SetInterOpNumThreads(num_threads);
  79 +
  80 + InitEncoder(encoder_filename);
  81 + InitDecoder(decoder_filename);
  82 + InitJoiner(joiner_filename);
  83 + InitJoinerEncoderProj(joiner_encoder_proj_filename);
  84 + InitJoinerDecoderProj(joiner_decoder_proj_filename);
  85 +}
  86 +
  87 +void RnntModel::InitEncoder(const std::string &filename) {
  88 + encoder_sess_ =
  89 + std::make_unique<Ort::Session>(env_, filename.c_str(), sess_opts_);
  90 + GetInputNames(encoder_sess_.get(), &encoder_input_names_,
  91 + &encoder_input_names_ptr_);
  92 +
  93 + GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
  94 + &encoder_output_names_ptr_);
  95 +}
  96 +
  97 +void RnntModel::InitDecoder(const std::string &filename) {
  98 + decoder_sess_ =
  99 + std::make_unique<Ort::Session>(env_, filename.c_str(), sess_opts_);
  100 +
  101 + GetInputNames(decoder_sess_.get(), &decoder_input_names_,
  102 + &decoder_input_names_ptr_);
  103 +
  104 + GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
  105 + &decoder_output_names_ptr_);
  106 +}
  107 +
  108 +void RnntModel::InitJoiner(const std::string &filename) {
  109 + joiner_sess_ =
  110 + std::make_unique<Ort::Session>(env_, filename.c_str(), sess_opts_);
  111 +
  112 + GetInputNames(joiner_sess_.get(), &joiner_input_names_,
  113 + &joiner_input_names_ptr_);
  114 +
  115 + GetOutputNames(joiner_sess_.get(), &joiner_output_names_,
  116 + &joiner_output_names_ptr_);
  117 +}
  118 +
  119 +void RnntModel::InitJoinerEncoderProj(const std::string &filename) {
  120 + joiner_encoder_proj_sess_ =
  121 + std::make_unique<Ort::Session>(env_, filename.c_str(), sess_opts_);
  122 +
  123 + GetInputNames(joiner_encoder_proj_sess_.get(),
  124 + &joiner_encoder_proj_input_names_,
  125 + &joiner_encoder_proj_input_names_ptr_);
  126 +
  127 + GetOutputNames(joiner_encoder_proj_sess_.get(),
  128 + &joiner_encoder_proj_output_names_,
  129 + &joiner_encoder_proj_output_names_ptr_);
  130 +}
  131 +
  132 +void RnntModel::InitJoinerDecoderProj(const std::string &filename) {
  133 + joiner_decoder_proj_sess_ =
  134 + std::make_unique<Ort::Session>(env_, filename.c_str(), sess_opts_);
  135 +
  136 + GetInputNames(joiner_decoder_proj_sess_.get(),
  137 + &joiner_decoder_proj_input_names_,
  138 + &joiner_decoder_proj_input_names_ptr_);
  139 +
  140 + GetOutputNames(joiner_decoder_proj_sess_.get(),
  141 + &joiner_decoder_proj_output_names_,
  142 + &joiner_decoder_proj_output_names_ptr_);
  143 +}
  144 +
  145 +Ort::Value RnntModel::RunEncoder(const float *features, int32_t T,
  146 + int32_t feature_dim) {
  147 + auto memory_info =
  148 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  149 + std::array<int64_t, 3> x_shape{1, T, feature_dim};
  150 + Ort::Value x =
  151 + Ort::Value::CreateTensor(memory_info, const_cast<float *>(features),
  152 + T * feature_dim, x_shape.data(), x_shape.size());
  153 +
  154 + std::array<int64_t, 1> x_lens_shape{1};
  155 + int64_t x_lens_tmp = T;
  156 +
  157 + Ort::Value x_lens = Ort::Value::CreateTensor(
  158 + memory_info, &x_lens_tmp, 1, x_lens_shape.data(), x_lens_shape.size());
  159 +
  160 + std::array<Ort::Value, 2> encoder_inputs{std::move(x), std::move(x_lens)};
  161 +
  162 + // Note: We discard encoder_out_lens since we only implement
  163 + // batch==1.
  164 + auto encoder_out = encoder_sess_->Run(
  165 + {}, encoder_input_names_ptr_.data(), encoder_inputs.data(),
  166 + encoder_inputs.size(), encoder_output_names_ptr_.data(),
  167 + encoder_output_names_ptr_.size());
  168 + return std::move(encoder_out[0]);
  169 +}
  170 +Ort::Value RnntModel::RunJoinerEncoderProj(const float *encoder_out, int32_t T,
  171 + int32_t encoder_out_dim) {
  172 + auto memory_info =
  173 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  174 +
  175 + std::array<int64_t, 2> in_shape{T, encoder_out_dim};
  176 + Ort::Value in = Ort::Value::CreateTensor(
  177 + memory_info, const_cast<float *>(encoder_out), T * encoder_out_dim,
  178 + in_shape.data(), in_shape.size());
  179 +
  180 + auto encoder_proj_out = joiner_encoder_proj_sess_->Run(
  181 + {}, joiner_encoder_proj_input_names_ptr_.data(), &in, 1,
  182 + joiner_encoder_proj_output_names_ptr_.data(),
  183 + joiner_encoder_proj_output_names_ptr_.size());
  184 + return std::move(encoder_proj_out[0]);
  185 +}
  186 +
  187 +Ort::Value RnntModel::RunDecoder(const int64_t *decoder_input,
  188 + int32_t context_size) {
  189 + auto memory_info =
  190 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  191 +
  192 + int32_t batch_size = 1; // TODO(fangjun): handle the case when it's > 1
  193 + std::array<int64_t, 2> shape{batch_size, context_size};
  194 + Ort::Value in = Ort::Value::CreateTensor(
  195 + memory_info, const_cast<int64_t *>(decoder_input),
  196 + batch_size * context_size, shape.data(), shape.size());
  197 +
  198 + auto decoder_out = decoder_sess_->Run(
  199 + {}, decoder_input_names_ptr_.data(), &in, 1,
  200 + decoder_output_names_ptr_.data(), decoder_output_names_ptr_.size());
  201 + return std::move(decoder_out[0]);
  202 +}
  203 +
  204 +Ort::Value RnntModel::RunJoinerDecoderProj(const float *decoder_out,
  205 + int32_t decoder_out_dim) {
  206 + auto memory_info =
  207 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  208 +
  209 + int32_t batch_size = 1; // TODO(fangjun): handle the case when it's > 1
  210 + std::array<int64_t, 2> shape{batch_size, decoder_out_dim};
  211 + Ort::Value in = Ort::Value::CreateTensor(
  212 + memory_info, const_cast<float *>(decoder_out),
  213 + batch_size * decoder_out_dim, shape.data(), shape.size());
  214 +
  215 + auto decoder_proj_out = joiner_decoder_proj_sess_->Run(
  216 + {}, joiner_decoder_proj_input_names_ptr_.data(), &in, 1,
  217 + joiner_decoder_proj_output_names_ptr_.data(),
  218 + joiner_decoder_proj_output_names_ptr_.size());
  219 + return std::move(decoder_proj_out[0]);
  220 +}
  221 +
  222 +Ort::Value RnntModel::RunJoiner(const float *projected_encoder_out,
  223 + const float *projected_decoder_out,
  224 + int32_t joiner_dim) {
  225 + auto memory_info =
  226 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  227 + int32_t batch_size = 1; // TODO(fangjun): handle the case when it's > 1
  228 + std::array<int64_t, 2> shape{batch_size, joiner_dim};
  229 +
  230 + Ort::Value enc = Ort::Value::CreateTensor(
  231 + memory_info, const_cast<float *>(projected_encoder_out),
  232 + batch_size * joiner_dim, shape.data(), shape.size());
  233 +
  234 + Ort::Value dec = Ort::Value::CreateTensor(
  235 + memory_info, const_cast<float *>(projected_decoder_out),
  236 + batch_size * joiner_dim, shape.data(), shape.size());
  237 +
  238 + std::array<Ort::Value, 2> inputs{std::move(enc), std::move(dec)};
  239 +
  240 + auto logit = joiner_sess_->Run(
  241 + {}, joiner_input_names_ptr_.data(), inputs.data(), inputs.size(),
  242 + joiner_output_names_ptr_.data(), joiner_output_names_ptr_.size());
  243 +
  244 + return std::move(logit[0]);
  245 +}
  246 +
  247 +} // namespace sherpa_onnx
  1 +/**
  2 + * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
  3 + *
  4 + * See LICENSE for clarification regarding multiple authors
  5 + *
  6 + * Licensed under the Apache License, Version 2.0 (the "License");
  7 + * you may not use this file except in compliance with the License.
  8 + * You may obtain a copy of the License at
  9 + *
  10 + * http://www.apache.org/licenses/LICENSE-2.0
  11 + *
  12 + * Unless required by applicable law or agreed to in writing, software
  13 + * distributed under the License is distributed on an "AS IS" BASIS,
  14 + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15 + * See the License for the specific language governing permissions and
  16 + * limitations under the License.
  17 + */
  18 +
  19 +#ifndef SHERPA_ONNX_CSRC_RNNT_MODEL_H_
  20 +#define SHERPA_ONNX_CSRC_RNNT_MODEL_H_
  21 +
  22 +#include <memory>
  23 +#include <string>
  24 +#include <vector>
  25 +
  26 +#include "onnxruntime_cxx_api.h" // NOLINT
  27 +
  28 +namespace sherpa_onnx {
  29 +
  30 +class RnntModel {
  31 + public:
  32 + /**
  33 + * @param encoder_filename Path to the encoder model
  34 + * @param decoder_filename Path to the decoder model
  35 + * @param joiner_filename Path to the joiner model
  36 + * @param joiner_encoder_proj_filename Path to the joiner encoder_proj model
  37 + * @param joiner_decoder_proj_filename Path to the joiner decoder_proj model
  38 + * @param num_threads Number of threads to use to run the models
  39 + */
  40 + RnntModel(const std::string &encoder_filename,
  41 + const std::string &decoder_filename,
  42 + const std::string &joiner_filename,
  43 + const std::string &joiner_encoder_proj_filename,
  44 + const std::string &joiner_decoder_proj_filename,
  45 + int32_t num_threads);
  46 +
  47 + /** Run the encoder model.
  48 + *
  49 + * @TODO(fangjun): Support batch_size > 1
  50 + *
  51 + * @param features A tensor of shape (batch_size, T, feature_dim)
  52 + * @param T Number of feature frames
  53 + * @param feature_dim Dimension of the feature.
  54 + *
  55 + * @return Return a tensor of shape (batch_size, T', encoder_out_dim)
  56 + */
  57 + Ort::Value RunEncoder(const float *features, int32_t T, int32_t feature_dim);
  58 +
  59 + /** Run the joiner encoder_proj model.
  60 + *
  61 + * @param encoder_out A tensor of shape (T, encoder_out_dim)
  62 + * @param T Number of frames in encoder_out.
  63 + * @param encoder_out_dim Dimension of encoder_out.
  64 + *
  65 + * @return Return a tensor of shape (T, joiner_dim)
  66 + *
  67 + */
  68 + Ort::Value RunJoinerEncoderProj(const float *encoder_out, int32_t T,
  69 + int32_t encoder_out_dim);
  70 +
  71 + /** Run the decoder model.
  72 + *
  73 + * @TODO(fangjun): Support batch_size > 1
  74 + *
  75 + * @param decoder_input A tensor of shape (batch_size, context_size).
  76 + * @return Return a tensor of shape (batch_size, 1, decoder_out_dim)
  77 + */
  78 + Ort::Value RunDecoder(const int64_t *decoder_input, int32_t context_size);
  79 +
  80 + /** Run joiner decoder_proj model.
  81 + *
  82 + * @TODO(fangjun): Support batch_size > 1
  83 + *
  84 + * @param decoder_out A tensor of shape (batch_size, decoder_out_dim)
  85 + * @param decoder_out_dim Output dimension of the decoder_out.
  86 + *
  87 + * @return Return a tensor of shape (batch_size, joiner_dim);
  88 + */
  89 + Ort::Value RunJoinerDecoderProj(const float *decoder_out,
  90 + int32_t decoder_out_dim);
  91 +
  92 + /** Run the joiner model.
  93 + *
  94 + * @TODO(fangjun): Support batch_size > 1
  95 + *
  96 + * @param projected_encoder_out A tensor of shape (batch_size, joiner_dim).
  97 + * @param projected_decoder_out A tensor of shape (batch_size, joiner_dim).
  98 + *
  99 + * @return Return a tensor of shape (batch_size, vocab_size)
  100 + */
  101 + Ort::Value RunJoiner(const float *projected_encoder_out,
  102 + const float *projected_decoder_out, int32_t joiner_dim);
  103 +
  104 + private:
  105 + void InitEncoder(const std::string &encoder_filename);
  106 + void InitDecoder(const std::string &decoder_filename);
  107 + void InitJoiner(const std::string &joiner_filename);
  108 + void InitJoinerEncoderProj(const std::string &joiner_encoder_proj_filename);
  109 + void InitJoinerDecoderProj(const std::string &joiner_decoder_proj_filename);
  110 +
  111 + private:
  112 + Ort::Env env_;
  113 + Ort::SessionOptions sess_opts_;
  114 + std::unique_ptr<Ort::Session> encoder_sess_;
  115 + std::unique_ptr<Ort::Session> decoder_sess_;
  116 + std::unique_ptr<Ort::Session> joiner_sess_;
  117 + std::unique_ptr<Ort::Session> joiner_encoder_proj_sess_;
  118 + std::unique_ptr<Ort::Session> joiner_decoder_proj_sess_;
  119 +
  120 + std::vector<std::string> encoder_input_names_;
  121 + std::vector<const char *> encoder_input_names_ptr_;
  122 + std::vector<std::string> encoder_output_names_;
  123 + std::vector<const char *> encoder_output_names_ptr_;
  124 +
  125 + std::vector<std::string> decoder_input_names_;
  126 + std::vector<const char *> decoder_input_names_ptr_;
  127 + std::vector<std::string> decoder_output_names_;
  128 + std::vector<const char *> decoder_output_names_ptr_;
  129 +
  130 + std::vector<std::string> joiner_input_names_;
  131 + std::vector<const char *> joiner_input_names_ptr_;
  132 + std::vector<std::string> joiner_output_names_;
  133 + std::vector<const char *> joiner_output_names_ptr_;
  134 +
  135 + std::vector<std::string> joiner_encoder_proj_input_names_;
  136 + std::vector<const char *> joiner_encoder_proj_input_names_ptr_;
  137 + std::vector<std::string> joiner_encoder_proj_output_names_;
  138 + std::vector<const char *> joiner_encoder_proj_output_names_ptr_;
  139 +
  140 + std::vector<std::string> joiner_decoder_proj_input_names_;
  141 + std::vector<const char *> joiner_decoder_proj_input_names_ptr_;
  142 + std::vector<std::string> joiner_decoder_proj_output_names_;
  143 + std::vector<const char *> joiner_decoder_proj_output_names_ptr_;
  144 +};
  145 +
  146 +} // namespace sherpa_onnx
  147 +
  148 +#endif // SHERPA_ONNX_CSRC_RNNT_MODEL_H_
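
As a rough illustration of the constructor declared above, the sketch below builds an `RnntModel` from the exported ONNX files used in the README. The directory layout follows the Hugging Face repository cloned there, and the thread count of 4 is an assumption for the example, not a value mandated by the code.

```cpp
#include <string>

#include "sherpa-onnx/csrc/rnnt-model.h"

int main() {
  // Hypothetical paths; they mirror the model repository used in the README.
  std::string dir =
      "./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx";

  sherpa_onnx::RnntModel model(dir + "/encoder.onnx",
                               dir + "/decoder.onnx",
                               dir + "/joiner.onnx",
                               dir + "/joiner_encoder_proj.onnx",
                               dir + "/joiner_decoder_proj.onnx",
                               /*num_threads=*/4);  // assumed value

  // RunEncoder()/GreedySearch() would follow, as sketched after decode.h above.
  return 0;
}
```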
1 -#include <vector>  
2 -#include <iostream>  
3 -#include <algorithm>  
4 -#include <time.h>  
5 -  
6 -#include "models.h"  
7 -#include "utils.h"  
8 -  
9 -  
10 -std::vector<float> getEncoderCol(Ort::Value &tensor, int start, int length){  
11 - float* floatarr = tensor.GetTensorMutableData<float>();  
12 - std::vector<float> vector {floatarr + start, floatarr + length};  
13 - return vector;  
14 -}  
15 -  
16 -  
17 -/**  
18 - * Assume batch size = 1  
19 - */  
20 -std::vector<int64_t> BuildDecoderInput(const std::vector<std::vector<int32_t>> &hyps,  
21 - std::vector<int64_t> &decoder_input) {  
22 -  
23 - int32_t context_size = decoder_input.size();  
24 - int32_t hyps_length = hyps[0].size();  
25 - for (int i=0; i < context_size; i++)  
26 - decoder_input[i] = hyps[0][hyps_length-context_size+i];  
27 -  
28 - return decoder_input;  
29 -}  
30 -  
31 -  
32 -std::vector<std::vector<int32_t>> GreedySearch(  
33 - Model *model, // NOLINT  
34 - std::vector<Ort::Value> *encoder_out){  
35 - Ort::Value &encoder_out_tensor = encoder_out->at(0);  
36 - int encoder_out_dim1 = encoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[1];  
37 - int encoder_out_dim2 = encoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[2];  
38 - auto encoder_out_vector = ortVal2Vector(encoder_out_tensor, encoder_out_dim1 * encoder_out_dim2);  
39 -  
40 - // # === Greedy Search === #  
41 - int32_t batch_size = 1;  
42 - std::vector<int32_t> blanks(model->context_size, model->blank_id);  
43 - std::vector<std::vector<int32_t>> hyps(batch_size, blanks);  
44 - std::vector<int64_t> decoder_input(model->context_size, model->blank_id);  
45 -  
46 - auto decoder_out = model->decoder_forward(decoder_input,  
47 - std::vector<int64_t> {batch_size, model->context_size},  
48 - memory_info);  
49 -  
50 - Ort::Value &decoder_out_tensor = decoder_out[0];  
51 - int decoder_out_dim = decoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[2];  
52 - auto decoder_out_vector = ortVal2Vector(decoder_out_tensor, decoder_out_dim);  
53 -  
54 - decoder_out = model->joiner_decoder_proj_forward(decoder_out_vector,  
55 - std::vector<int64_t> {1, decoder_out_dim},  
56 - memory_info);  
57 - Ort::Value &projected_decoder_out_tensor = decoder_out[0];  
58 - auto projected_decoder_out_dim = projected_decoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[1];  
59 - auto projected_decoder_out_vector = ortVal2Vector(projected_decoder_out_tensor, projected_decoder_out_dim);  
60 -  
61 - auto projected_encoder_out = model->joiner_encoder_proj_forward(encoder_out_vector,  
62 - std::vector<int64_t> {encoder_out_dim1, encoder_out_dim2},  
63 - memory_info);  
64 - Ort::Value &projected_encoder_out_tensor = projected_encoder_out[0];  
65 - int projected_encoder_out_dim1 = projected_encoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[0];  
66 - int projected_encoder_out_dim2 = projected_encoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[1];  
67 - auto projected_encoder_out_vector = ortVal2Vector(projected_encoder_out_tensor, projected_encoder_out_dim1 * projected_encoder_out_dim2);  
68 -  
69 - int32_t offset = 0;  
70 - for (int i=0; i< projected_encoder_out_dim1; i++){  
71 - int32_t cur_batch_size = 1;  
72 - int32_t start = offset;  
73 - int32_t end = start + cur_batch_size;  
74 - offset = end;  
75 -  
76 - auto cur_encoder_out = getEncoderCol(projected_encoder_out_tensor, start * projected_encoder_out_dim2, end * projected_encoder_out_dim2);  
77 -  
78 - auto logits = model->joiner_forward(cur_encoder_out,  
79 - projected_decoder_out_vector,  
80 - std::vector<int64_t> {1, projected_encoder_out_dim2},  
81 - std::vector<int64_t> {1, projected_decoder_out_dim},  
82 - memory_info);  
83 -  
84 - Ort::Value &logits_tensor = logits[0];  
85 - int logits_dim = logits_tensor.GetTensorTypeAndShapeInfo().GetShape()[1];  
86 - auto logits_vector = ortVal2Vector(logits_tensor, logits_dim);  
87 -  
88 - int max_indices = static_cast<int>(std::distance(logits_vector.begin(), std::max_element(logits_vector.begin(), logits_vector.end())));  
89 - bool emitted = false;  
90 -  
91 - for (int32_t k = 0; k != cur_batch_size; ++k) {  
92 - auto index = max_indices;  
93 - if (index != model->blank_id && index != model->unk_id) {  
94 - emitted = true;  
95 - hyps[k].push_back(index);  
96 - }  
97 - }  
98 -  
99 - if (emitted) {  
100 - decoder_input = BuildDecoderInput(hyps, decoder_input);  
101 -  
102 - decoder_out = model->decoder_forward(decoder_input,  
103 - std::vector<int64_t> {batch_size, model->context_size},  
104 - memory_info);  
105 -  
106 - decoder_out_dim = decoder_out[0].GetTensorTypeAndShapeInfo().GetShape()[2];  
107 - decoder_out_vector = ortVal2Vector(decoder_out[0], decoder_out_dim);  
108 -  
109 - decoder_out = model->joiner_decoder_proj_forward(decoder_out_vector,  
110 - std::vector<int64_t> {1, decoder_out_dim},  
111 - memory_info);  
112 -  
113 - projected_decoder_out_dim = decoder_out[0].GetTensorTypeAndShapeInfo().GetShape()[1];  
114 - projected_decoder_out_vector = ortVal2Vector(decoder_out[0], projected_decoder_out_dim);  
115 - }  
116 - }  
117 -  
118 - return hyps;  
119 -}  
120 -  
  1 +/**
  2 + * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
  3 + *
  4 + * See LICENSE for clarification regarding multiple authors
  5 + *
  6 + * Licensed under the Apache License, Version 2.0 (the "License");
  7 + * you may not use this file except in compliance with the License.
  8 + * You may obtain a copy of the License at
  9 + *
  10 + * http://www.apache.org/licenses/LICENSE-2.0
  11 + *
  12 + * Unless required by applicable law or agreed to in writing, software
  13 + * distributed under the License is distributed on an "AS IS" BASIS,
  14 + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15 + * See the License for the specific language governing permissions and
  16 + * limitations under the License.
  17 + */
  18 +
  19 +#include <iostream>
  20 +#include <string>
  21 +#include <vector>
  22 +
  23 +#include "kaldi-native-fbank/csrc/online-feature.h"
  24 +#include "sherpa-onnx/csrc/decode.h"
  25 +#include "sherpa-onnx/csrc/rnnt-model.h"
  26 +#include "sherpa-onnx/csrc/symbol-table.h"
  27 +#include "sherpa-onnx/csrc/wave-reader.h"
  28 +
  29 +/** Compute fbank features of the input wave file.
  30 + *
  31 + * @param wav_filename Path to a mono wave file.
  32 + * @param expected_sampling_rate Expected sampling rate of the input wave file.
  33 + * @param num_frames On return, it contains the number of feature frames.
  34 + * @return Return the computed features of shape (num_frames, feature_dim)
  35 + * stored in row-major.
  36 + */
  37 +static std::vector<float> ComputeFeatures(const std::string &wav_filename,
  38 + float expected_sampling_rate,
  39 + int32_t *num_frames) {
  40 + std::vector<float> samples =
  41 + sherpa_onnx::ReadWave(wav_filename, expected_sampling_rate);
  42 +
  43 + float duration = samples.size() / expected_sampling_rate;
  44 +
  45 + std::cout << "wav filename: " << wav_filename << "\n";
  46 + std::cout << "wav duration (s): " << duration << "\n";
  47 +
  48 + knf::FbankOptions opts;
  49 + opts.frame_opts.dither = 0;
  50 + opts.frame_opts.snip_edges = false;
  51 + opts.frame_opts.samp_freq = expected_sampling_rate;
  52 +
  53 + int32_t feature_dim = 80;
  54 +
  55 + opts.mel_opts.num_bins = feature_dim;
  56 +
  57 + knf::OnlineFbank fbank(opts);
  58 + fbank.AcceptWaveform(expected_sampling_rate, samples.data(), samples.size());
  59 + fbank.InputFinished();
  60 +
  61 + *num_frames = fbank.NumFramesReady();
  62 +
  63 + std::vector<float> features(*num_frames * feature_dim);
  64 + float *p = features.data();
  65 +
  66 + for (int32_t i = 0; i != fbank.NumFramesReady(); ++i, p += feature_dim) {
  67 + const float *f = fbank.GetFrame(i);
  68 + std::copy(f, f + feature_dim, p);
  69 + }
  70 +
  71 + return features;
  72 +}
  73 +
  74 +int main(int32_t argc, char *argv[]) {
  75 + if (argc < 8 || argc > 9) {
  76 + const char *usage = R"usage(
  77 +Usage:
  78 + ./bin/sherpa-onnx \
  79 + /path/to/tokens.txt \
  80 + /path/to/encoder.onnx \
  81 + /path/to/decoder.onnx \
  82 + /path/to/joiner.onnx \
  83 + /path/to/joiner_encoder_proj.onnx \
  84 + /path/to/joiner_decoder_proj.onnx \
  85 + /path/to/foo.wav [num_threads]
  86 +
  87 +You can download pre-trained models from the following repository:
  88 +https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
  89 +)usage";
  90 + std::cerr << usage << "\n";
  91 +
  92 + return 0;
  93 + }
  94 +
  95 + std::string tokens = argv[1];
  96 + std::string encoder = argv[2];
  97 + std::string decoder = argv[3];
  98 + std::string joiner = argv[4];
  99 + std::string joiner_encoder_proj = argv[5];
  100 + std::string joiner_decoder_proj = argv[6];
  101 + std::string wav_filename = argv[7];
  102 + int32_t num_threads = 4;
  103 + if (argc == 9) {
  104 + num_threads = atoi(argv[8]);
  105 + }
  106 +
  107 + sherpa_onnx::SymbolTable sym(tokens);
  108 +
  109 + int32_t num_frames;
  110 + auto features = ComputeFeatures(wav_filename, 16000, &num_frames);
  111 + int32_t feature_dim = features.size() / num_frames;
  112 +
  113 + sherpa_onnx::RnntModel model(encoder, decoder, joiner, joiner_encoder_proj,
  114 + joiner_decoder_proj, num_threads);
  115 + Ort::Value encoder_out =
  116 + model.RunEncoder(features.data(), num_frames, feature_dim);
  117 +
  118 + auto hyp = sherpa_onnx::GreedySearch(model, encoder_out);
  119 +
  120 + std::string text;
  121 + for (auto i : hyp) {
  122 + text += sym[i];
  123 + }
  124 +
  125 + std::cout << "Recognition result for " << wav_filename << "\n"
  126 + << text << "\n";
  127 +
  128 + return 0;
  129 +}
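
For readers new to kaldi-native-fbank, the feature extraction performed by ComputeFeatures() above reduces to the following minimal standalone sketch. It is an illustration, not part of this commit: the synthetic 440 Hz tone and the 16 kHz rate are assumptions standing in for a real wave file.

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

#include "kaldi-native-fbank/csrc/online-feature.h"

int main() {
  knf::FbankOptions opts;
  opts.frame_opts.dither = 0;
  opts.frame_opts.snip_edges = false;
  opts.frame_opts.samp_freq = 16000;
  opts.mel_opts.num_bins = 80;  // feature_dim expected by the models above

  // One second of a synthetic 440 Hz tone stands in for real speech.
  std::vector<float> samples(16000);
  for (std::size_t i = 0; i != samples.size(); ++i) {
    samples[i] = 0.5f * std::sin(2.0f * 3.1415926f * 440.0f * i / 16000.0f);
  }

  knf::OnlineFbank fbank(opts);
  fbank.AcceptWaveform(16000, samples.data(), samples.size());
  fbank.InputFinished();

  std::printf("num frames: %d\n", fbank.NumFramesReady());
  const float *first_frame = fbank.GetFrame(0);  // 80 floats per frame
  std::printf("first bin of frame 0: %f\n", first_frame[0]);
  return 0;
}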
@@ -15,34 +15,21 @@ @@ -15,34 +15,21 @@
15 * See the License for the specific language governing permissions and 15 * See the License for the specific language governing permissions and
16 * limitations under the License. 16 * limitations under the License.
17 */ 17 */
18 -  
19 #include <iostream> 18 #include <iostream>
  19 +#include <sstream>
20 20
21 -#include "kaldi-native-fbank/csrc/online-feature.h" 21 +#include "onnxruntime_cxx_api.h" // NOLINT
22 22
23 int main() { 23 int main() {
24 - knf::FbankOptions opts;  
25 - opts.frame_opts.dither = 0;  
26 - opts.mel_opts.num_bins = 10;  
27 -  
28 - knf::OnlineFbank fbank(opts);  
29 - for (int32_t i = 0; i < 1600; ++i) {  
30 - float s = (i * i - i / 2) / 32767.;  
31 - fbank.AcceptWaveform(16000, &s, 1);  
32 - }  
33 - 24 + std::cout << "ORT_API_VERSION: " << ORT_API_VERSION << "\n";
  25 + std::vector<std::string> providers = Ort::GetAvailableProviders();
34 std::ostringstream os; 26 std::ostringstream os;
35 -  
36 - int32_t n = fbank.NumFramesReady();  
37 - for (int32_t i = 0; i != n; ++i) {  
38 - const float *frame = fbank.GetFrame(i);  
39 - for (int32_t k = 0; k != opts.mel_opts.num_bins; ++k) {  
40 - os << frame[k] << ", ";  
41 - }  
42 - os << "\n"; 27 + os << "Available providers: ";
  28 + std::string sep = "";
  29 + for (const auto &p : providers) {
  30 + os << sep << p;
  31 + sep = ", ";
43 } 32 }
44 -  
45 std::cout << os.str() << "\n"; 33 std::cout << os.str() << "\n";
46 -  
47 return 0; 34 return 0;
48 } 35 }
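
The smoke test above only prints the provider list. As a hedged follow-up sketch, one way to consume that list is to check whether a particular provider name is present before configuring a session; the name "CUDAExecutionProvider" below is an assumption about the onnxruntime build, not something this commit relies on.

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

#include "onnxruntime_cxx_api.h"  // NOLINT

int main() {
  std::vector<std::string> providers = Ort::GetAvailableProviders();
  bool has_cuda =
      std::find(providers.begin(), providers.end(), "CUDAExecutionProvider") !=
      providers.end();
  std::cout << (has_cuda ? "CUDA provider available\n"
                         : "CPU only; CUDA provider not found\n");
  return 0;
}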
  1 +/**
  2 + * Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang)
  3 + *
  4 + * See LICENSE for clarification regarding multiple authors
  5 + *
  6 + * Licensed under the Apache License, Version 2.0 (the "License");
  7 + * you may not use this file except in compliance with the License.
  8 + * You may obtain a copy of the License at
  9 + *
  10 + * http://www.apache.org/licenses/LICENSE-2.0
  11 + *
  12 + * Unless required by applicable law or agreed to in writing, software
  13 + * distributed under the License is distributed on an "AS IS" BASIS,
  14 + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15 + * See the License for the specific language governing permissions and
  16 + * limitations under the License.
  17 + */
  18 +
  19 +#include "sherpa-onnx/csrc/symbol-table.h"
  20 +
  21 +#include <cassert>
  22 +#include <fstream>
  23 +#include <sstream>
  24 +
  25 +namespace sherpa_onnx {
  26 +
  27 +SymbolTable::SymbolTable(const std::string &filename) {
  28 + std::ifstream is(filename);
  29 + std::string sym;
  30 + int32_t id;
  31 + while (is >> sym >> id) {
  32 + if (sym.size() >= 3) {
  33 + // For BPE-based models, we replace ▁ with a space
  34 + // Unicode 9601, hex 0x2581, utf8 0xe29681
  35 + const uint8_t *p = reinterpret_cast<const uint8_t *>(sym.c_str());
  36 + if (p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) {
  37 + sym = sym.replace(0, 3, " ");
  38 + }
  39 + }
  40 +
  41 + assert(!sym.empty());
  42 + assert(sym2id_.count(sym) == 0);
  43 + assert(id2sym_.count(id) == 0);
  44 +
  45 + sym2id_.insert({sym, id});
  46 + id2sym_.insert({id, sym});
  47 + }
  48 + assert(is.eof());
  49 +}
  50 +
  51 +std::string SymbolTable::ToString() const {
  52 + std::ostringstream os;
  53 + char sep = ' ';
  54 + for (const auto &p : sym2id_) {
  55 + os << p.first << sep << p.second << "\n";
  56 + }
  57 + return os.str();
  58 +}
  59 +
  60 +const std::string &SymbolTable::operator[](int32_t id) const {
  61 + return id2sym_.at(id);
  62 +}
  63 +
  64 +int32_t SymbolTable::operator[](const std::string &sym) const {
  65 + return sym2id_.at(sym);
  66 +}
  67 +
  68 +bool SymbolTable::contains(int32_t id) const { return id2sym_.count(id) != 0; }
  69 +
  70 +bool SymbolTable::contains(const std::string &sym) const {
  71 + return sym2id_.count(sym) != 0;
  72 +}
  73 +
  74 +std::ostream &operator<<(std::ostream &os, const SymbolTable &symbol_table) {
  75 + return os << symbol_table.ToString();
  76 +}
  77 +
  78 +} // namespace sherpa_onnx
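
The three-byte comparison in the constructor relies on ▁ (U+2581) encoding to 0xe2 0x96 0x81 in UTF-8. A tiny standalone check of that assumption, using a hypothetical BPE piece:

#include <cassert>
#include <cstdint>
#include <string>

int main() {
  std::string sym = "▁THE";  // hypothetical BPE token; this source file is UTF-8
  const uint8_t *p = reinterpret_cast<const uint8_t *>(sym.c_str());
  assert(p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81);
  sym.replace(0, 3, " ");  // same replacement as in SymbolTable's constructor
  assert(sym == " THE");
  return 0;
}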
  1 +/**
  2 + * Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang)
  3 + *
  4 + * See LICENSE for clarification regarding multiple authors
  5 + *
  6 + * Licensed under the Apache License, Version 2.0 (the "License");
  7 + * you may not use this file except in compliance with the License.
  8 + * You may obtain a copy of the License at
  9 + *
  10 + * http://www.apache.org/licenses/LICENSE-2.0
  11 + *
  12 + * Unless required by applicable law or agreed to in writing, software
  13 + * distributed under the License is distributed on an "AS IS" BASIS,
  14 + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15 + * See the License for the specific language governing permissions and
  16 + * limitations under the License.
  17 + */
  18 +
  19 +#ifndef SHERPA_ONNX_CSRC_SYMBOL_TABLE_H_
  20 +#define SHERPA_ONNX_CSRC_SYMBOL_TABLE_H_
  21 +
  22 +#include <string>
  23 +#include <unordered_map>
  24 +
  25 +namespace sherpa_onnx {
  26 +
  27 +/// Manages the mapping between symbols and integer IDs.
  28 +class SymbolTable {
  29 + public:
  30 + SymbolTable() = default;
  31 + /// Construct a symbol table from a file.
  32 + /// Each line in the file contains two fields:
  33 + ///
  34 + /// sym ID
  35 + ///
  36 + /// Fields are separated by space(s).
  37 + explicit SymbolTable(const std::string &filename);
  38 +
  39 + /// Return a string representation of this symbol table
  40 + std::string ToString() const;
  41 +
  42 + /// Return the symbol corresponding to the given ID.
  43 + const std::string &operator[](int32_t id) const;
  44 + /// Return the ID corresponding to the given symbol.
  45 + int32_t operator[](const std::string &sym) const;
  46 +
  47 + /// Return true if there is a symbol with the given ID.
  48 + bool contains(int32_t id) const;
  49 +
  50 + /// Return true if the given symbol is in the symbol table.
  51 + bool contains(const std::string &sym) const;
  52 +
  53 + private:
  54 + std::unordered_map<std::string, int32_t> sym2id_;
  55 + std::unordered_map<int32_t, std::string> id2sym_;
  56 +};
  57 +
  58 +std::ostream &operator<<(std::ostream &os, const SymbolTable &symbol_table);
  59 +
  60 +} // namespace sherpa_onnx
  61 +
  62 +#endif // SHERPA_ONNX_CSRC_SYMBOL_TABLE_H_
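
For illustration only (the symbols and IDs below are hypothetical, not taken from a released model), a tokens.txt consumed by this class looks like

<blk> 0
<unk> 1
▁THE 2
S 3

and a typical lookup using only the interface declared above would be

#include <cstdint>
#include <iostream>

#include "sherpa-onnx/csrc/symbol-table.h"

int main() {
  sherpa_onnx::SymbolTable sym("tokens.txt");
  std::cout << sym[2] << "\n";  // prints " THE"; the leading ▁ became a space
  int32_t id = sym[" THE"];     // reverse lookup on the normalized symbol, i.e. 2
  bool ok = sym.contains(3);    // true for the hypothetical table above
  return (id == 2 && ok) ? 0 : 1;
}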
1 -#include <iostream>  
2 -#include <fstream>  
3 -  
4 -  
5 -void vector2file(std::vector<float> vector, std::string saveFileName){  
6 - std::ofstream f(saveFileName);  
7 - for(std::vector<float>::const_iterator i = vector.begin(); i != vector.end(); ++i) {  
8 - f << *i << '\n';  
9 - }  
10 -}  
11 -  
12 -  
13 -std::vector<std::string> hyps2result(std::map<int, std::string> token_map, std::vector<std::vector<int32_t>> hyps, int context_size = 2){  
14 - std::vector<std::string> results;  
15 -  
16 - for (int k=0; k < hyps.size(); k++){  
17 - std::string result = token_map[hyps[k][context_size]];  
18 -  
19 - for (int i=context_size+1; i < hyps[k].size(); i++){  
20 - std::string token = token_map[hyps[k][i]];  
21 -  
22 - // TODO: recognising '_' is not working  
23 - if (token.at(0) == '_')  
24 - result += " " + token;  
25 - else  
26 - result += token;  
27 - }  
28 - results.push_back(result);  
29 - }  
30 - return results;  
31 -}  
32 -  
33 -  
34 -void print_hyps(std::vector<std::vector<int32_t>> hyps, int context_size = 2){  
35 - std::cout << "Hyps:" << std::endl;  
36 - for (int i=context_size; i<hyps[0].size(); i++)  
37 - std::cout << hyps[0][i] << "-";  
38 - std::cout << "|" << std::endl;  
39 -}  
1 -#include <iostream>  
2 -#include "onnxruntime_cxx_api.h"  
3 -  
4 -Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "test");  
5 -const auto& api = Ort::GetApi();  
6 -OrtTensorRTProviderOptionsV2* tensorrt_options;  
7 -Ort::SessionOptions session_options;  
8 -Ort::AllocatorWithDefaultOptions allocator;  
9 -auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);  
10 -  
11 -  
12 -std::vector<float> ortVal2Vector(Ort::Value &tensor, int tensor_length){  
13 - /**  
14 - * convert ort tensor to vector  
15 - */  
16 - float* floatarr = tensor.GetTensorMutableData<float>();  
17 - std::vector<float> vector {floatarr, floatarr + tensor_length};  
18 - return vector;  
19 -}  
20 -  
21 -  
22 -void print_onnx_forward_output(std::vector<Ort::Value> &output_tensors, int num){  
23 - float* floatarr = output_tensors.front().GetTensorMutableData<float>();  
24 - for (int i = 0; i < num; i++)  
25 - printf("[%d] = %f\n", i, floatarr[i]);  
26 -}  
27 -  
28 -  
29 -void print_shape_of_ort_val(std::vector<Ort::Value> &tensor){  
30 - auto out_shape = tensor.front().GetTensorTypeAndShapeInfo().GetShape();  
31 - auto out_size = out_shape.size();  
32 - std::cout << "(";  
33 - for (int i=0; i<out_size; i++){  
34 - std::cout << out_shape[i];  
35 - if (i < out_size-1)  
36 - std::cout << ",";  
37 - }  
38 - std::cout << ")" << std::endl;  
39 -}  
40 -  
41 -  
42 -void print_model_info(Ort::Session &session, std::string title){  
43 - std::cout << "=== Printing '" << title << "' model ===" << std::endl;  
44 - Ort::AllocatorWithDefaultOptions allocator;  
45 -  
46 - // print number of model input nodes  
47 - size_t num_input_nodes = session.GetInputCount();  
48 - std::vector<const char*> input_node_names(num_input_nodes);  
49 - std::vector<int64_t> input_node_dims;  
50 -  
51 - printf("Number of inputs = %zu\n", num_input_nodes);  
52 -  
53 - char* output_name = session.GetOutputName(0, allocator);  
54 - printf("output name: %s\n", output_name);  
55 -  
56 - // iterate over all input nodes  
57 - for (int i = 0; i < num_input_nodes; i++) {  
58 - // print input node names  
59 - char* input_name = session.GetInputName(i, allocator);  
60 - printf("Input %d : name=%s\n", i, input_name);  
61 - input_node_names[i] = input_name;  
62 -  
63 - // print input node types  
64 - Ort::TypeInfo type_info = session.GetInputTypeInfo(i);  
65 - auto tensor_info = type_info.GetTensorTypeAndShapeInfo();  
66 -  
67 - ONNXTensorElementDataType type = tensor_info.GetElementType();  
68 - printf("Input %d : type=%d\n", i, type);  
69 -  
70 - // print input shapes/dims  
71 - input_node_dims = tensor_info.GetShape();  
72 - printf("Input %d : num_dims=%zu\n", i, input_node_dims.size());  
73 - for (size_t j = 0; j < input_node_dims.size(); j++)  
74 - printf("Input %d : dim %zu=%jd\n", i, j, input_node_dims[j]);  
75 - }  
76 - std::cout << "=======================================" << std::endl;  
77 -}  
  1 +/**
  2 + * Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang)
  3 + *
  4 + * See LICENSE for clarification regarding multiple authors
  5 + *
  6 + * Licensed under the Apache License, Version 2.0 (the "License");
  7 + * you may not use this file except in compliance with the License.
  8 + * You may obtain a copy of the License at
  9 + *
  10 + * http://www.apache.org/licenses/LICENSE-2.0
  11 + *
  12 + * Unless required by applicable law or agreed to in writing, software
  13 + * distributed under the License is distributed on an "AS IS" BASIS,
  14 + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15 + * See the License for the specific language governing permissions and
  16 + * limitations under the License.
  17 + */
  18 +
  19 +#include "sherpa-onnx/csrc/wave-reader.h"
  20 +
  21 +#include <cassert>
  22 +#include <fstream>
  23 +#include <iostream>
  24 +#include <utility>
  25 +#include <vector>
  26 +
  27 +namespace sherpa_onnx {
  28 +
  29 +namespace {
  30 +// see http://soundfile.sapp.org/doc/WaveFormat/
  31 +//
  32 +// Note: We assume little endian here
  33 +// TODO(fangjun): Support big endian
  34 +struct WaveHeader {
  35 + void Validate() const {
  36 + // F F I R
  37 + assert(chunk_id == 0x46464952);
  38 + assert(chunk_size == 36 + subchunk2_size);
  39 + // E V A W
  40 + assert(format == 0x45564157);
  41 + assert(subchunk1_id == 0x20746d66);
  42 + assert(subchunk1_size == 16); // 16 for PCM
  43 + assert(audio_format == 1); // 1 for PCM
  44 + assert(num_channels == 1); // we support only single channel for now
  45 + assert(byte_rate == sample_rate * num_channels * bits_per_sample / 8);
  46 + assert(block_align == num_channels * bits_per_sample / 8);
  47 + assert(bits_per_sample == 16); // we support only 16 bits per sample
  48 + }
  49 +
  50 + int32_t chunk_id;
  51 + int32_t chunk_size;
  52 + int32_t format;
  53 + int32_t subchunk1_id;
  54 + int32_t subchunk1_size;
  55 + int16_t audio_format;
  56 + int16_t num_channels;
  57 + int32_t sample_rate;
  58 + int32_t byte_rate;
  59 + int16_t block_align;
  60 + int16_t bits_per_sample;
  61 + int32_t subchunk2_id;
  62 + int32_t subchunk2_size;
  63 +};
  64 +static_assert(sizeof(WaveHeader) == 44, "");
  65 +
  66 +// Read a single-channel (mono) wave file.
  67 +// Return its samples normalized to the range [-1, 1).
  68 +std::vector<float> ReadWaveImpl(std::istream &is, float *sample_rate) {
  69 + WaveHeader header;
  70 + is.read(reinterpret_cast<char *>(&header), sizeof(header));
  71 + assert(static_cast<bool>(is));
  72 +
  73 + header.Validate();
  74 +
  75 + *sample_rate = header.sample_rate;
  76 +
  77 + // header.subchunk2_size contains the number of bytes in the data.
  78 + // Since each sample occupies two bytes, we divide by 2 here.
  79 + std::vector<int16_t> samples(header.subchunk2_size / 2);
  80 +
  81 + is.read(reinterpret_cast<char *>(samples.data()), header.subchunk2_size);
  82 +
  83 + assert(static_cast<bool>(is));
  84 +
  85 + std::vector<float> ans(samples.size());
  86 + for (int32_t i = 0; i != ans.size(); ++i) {
  87 + ans[i] = samples[i] / 32768.;
  88 + }
  89 +
  90 + return ans;
  91 +}
  92 +
  93 +} // namespace
  94 +
  95 +std::vector<float> ReadWave(const std::string &filename,
  96 + float expected_sample_rate) {
  97 + std::ifstream is(filename, std::ifstream::binary);
  98 + float sample_rate;
  99 + auto samples = ReadWaveImpl(is, &sample_rate);
  100 + if (expected_sample_rate != sample_rate) {
  101 + std::cerr << "Expected sample rate: " << expected_sample_rate
  102 + << ". Given: " << sample_rate << ".\n";
  103 + exit(-1);
  104 + }
  105 + return samples;
  106 +}
  107 +
  108 +} // namespace sherpa_onnx
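
As a worked, compile-time example of the relations that WaveHeader::Validate() asserts, consider a hypothetical 1-second, 16 kHz, mono, 16-bit PCM file; all values below are assumed for illustration.

#include <cstdint>

constexpr int32_t sample_rate = 16000;
constexpr int32_t num_channels = 1;
constexpr int32_t bits_per_sample = 16;
constexpr int32_t num_samples = 16000;               // 1 second of audio
constexpr int32_t subchunk2_size = num_samples * 2;  // 32000 data bytes
static_assert(36 + subchunk2_size == 32036, "chunk_size for this file");
static_assert(sample_rate * num_channels * bits_per_sample / 8 == 32000,
              "byte_rate");
static_assert(num_channels * bits_per_sample / 8 == 2, "block_align");
// ReadWaveImpl() normalizes a raw 16-bit sample of 16384 to 16384 / 32768. = 0.5.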
  1 +/**
  2 + * Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang)
  3 + *
  4 + * See LICENSE for clarification regarding multiple authors
  5 + *
  6 + * Licensed under the Apache License, Version 2.0 (the "License");
  7 + * you may not use this file except in compliance with the License.
  8 + * You may obtain a copy of the License at
  9 + *
  10 + * http://www.apache.org/licenses/LICENSE-2.0
  11 + *
  12 + * Unless required by applicable law or agreed to in writing, software
  13 + * distributed under the License is distributed on an "AS IS" BASIS,
  14 + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15 + * See the License for the specific language governing permissions and
  16 + * limitations under the License.
  17 + */
  18 +
  19 +#ifndef SHERPA_ONNX_CSRC_WAVE_READER_H_
  20 +#define SHERPA_ONNX_CSRC_WAVE_READER_H_
  21 +
  22 +#include <istream>
  23 +#include <string>
  24 +#include <vector>
  25 +
  26 +namespace sherpa_onnx {
  27 +
  28 +/** Read a wave file with expected sample rate.
  29 +
  30 + @param filename Path to a wave file. It MUST be single channel, PCM encoded.
  31 + @param expected_sample_rate Expected sample rate of the wave file. If the
  32 + sample rate doesn't match, the program exits with an error.
  33 +
  34 + @return Return wave samples normalized to the range [-1, 1).
  35 + */
  36 +std::vector<float> ReadWave(const std::string &filename,
  37 + float expected_sample_rate);
  38 +
  39 +} // namespace sherpa_onnx
  40 +
  41 +#endif // SHERPA_ONNX_CSRC_WAVE_READER_H_
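
A minimal calling sketch for this helper, mirroring its use in main.cc earlier in this commit (the file name foo.wav is a placeholder):

#include <iostream>
#include <vector>

#include "sherpa-onnx/csrc/wave-reader.h"

int main() {
  // foo.wav must be a 16 kHz, mono, 16-bit PCM file; otherwise ReadWave exits.
  std::vector<float> samples = sherpa_onnx::ReadWave("foo.wav", 16000);
  std::cout << "duration (s): " << samples.size() / 16000.0f << "\n";
  return 0;
}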