Fangjun Kuang
Committed by GitHub

Refactor the code (#15)

* code refactoring

* Remove reference files

* Update README and CI

* small fixes

* fix style issues

* add style check for CI

* fix style issues

* remove kaldi-native-io
---
BasedOnStyle: Google
---
Language: Cpp
Cpp11BracedListStyle: true
Standard: Cpp11
DerivePointerAlignment: false
PointerAlignment: Right
---
... ...
# Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: style_check

on:
  push:
    branches:
      - master
    paths:
      - '.github/workflows/style_check.yaml'
      - 'sherpa-onnx/**'
  pull_request:
    branches:
      - master
    paths:
      - '.github/workflows/style_check.yaml'
      - 'sherpa-onnx/**'

concurrency:
  group: style_check-${{ github.ref }}
  cancel-in-progress: true

jobs:
  style_check:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: [3.8]
      fail-fast: false
    steps:
      - uses: actions/checkout@v2
        with:
          fetch-depth: 0
      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v1
        with:
          python-version: ${{ matrix.python-version }}
      - name: Check style with cpplint
        shell: bash
        working-directory: ${{github.workspace}}
        run: ./scripts/check_style_cpplint.sh
... ...
... ... @@ -59,31 +59,28 @@ jobs:
- name: Run tests for ubuntu/macos (English)
run: |
time ./build/bin/sherpa-onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav
time ./build/bin/sherpa-onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0001.wav
time ./build/bin/sherpa-onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0002.wav
... ...
... ... @@ -38,7 +38,6 @@ set(CMAKE_CXX_EXTENSIONS OFF)
list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules)
list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake)
include(kaldi-native-fbank)
include(onnxruntime)
... ...
... ... @@ -14,6 +14,9 @@ the following links:
**NOTE**: We provide only non-streaming models at present.
**HINT**: The script for exporting the English model can be found at
<https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless3/export.py>
# Usage
```bash
... ... @@ -34,13 +37,14 @@ cd ..
git lfs install
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
./build/bin/sherpa-onnx --help
./build/bin/sherpa-onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav
```
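You can optionally append the number of threads as the last argument; when it is omitted, `main()` in `sherpa-onnx/csrc/sherpa-onnx.cc` defaults to 4 threads. For example:

```bash
./build/bin/sherpa-onnx \
  ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
  ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
  ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
  ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
  ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
  ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
  ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav \
  4
```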
... ...
function(download_kaldi_native_io)
if(CMAKE_VERSION VERSION_LESS 3.11)
# FetchContent is available since 3.11,
# we've copied it to ${CMAKE_SOURCE_DIR}/cmake/Modules
# so that it can be used in lower CMake versions.
message(STATUS "Use FetchContent provided by sherpa-onnx")
list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules)
endif()
include(FetchContent)
set(kaldi_native_io_URL "https://github.com/csukuangfj/kaldi_native_io/archive/refs/tags/v1.15.1.tar.gz")
set(kaldi_native_io_HASH "SHA256=97377e1d61e99d8fc1d6037a418d3037522dfa46337e06162e24b1d97f3d70a6")
set(KALDI_NATIVE_IO_BUILD_TESTS OFF CACHE BOOL "" FORCE)
set(KALDI_NATIVE_IO_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
FetchContent_Declare(kaldi_native_io
URL ${kaldi_native_io_URL}
URL_HASH ${kaldi_native_io_HASH}
)
FetchContent_GetProperties(kaldi_native_io)
if(NOT kaldi_native_io_POPULATED)
message(STATUS "Downloading kaldi_native_io ${kaldi_native_io_URL}")
FetchContent_Populate(kaldi_native_io)
endif()
message(STATUS "kaldi_native_io is downloaded to ${kaldi_native_io_SOURCE_DIR}")
message(STATUS "kaldi_native_io's binary dir is ${kaldi_native_io_BINARY_DIR}")
add_subdirectory(${kaldi_native_io_SOURCE_DIR} ${kaldi_native_io_BINARY_DIR} EXCLUDE_FROM_ALL)
target_include_directories(kaldi_native_io_core
PUBLIC
${kaldi_native_io_SOURCE_DIR}/
)
endfunction()
download_kaldi_native_io()
... ... @@ -10,7 +10,7 @@ function(download_onnxruntime)
include(FetchContent)
if(UNIX AND NOT APPLE)
set(onnxruntime_URL "http://github.com/microsoft/onnxruntime/releases/download/v1.12.1/onnxruntime-linux-x64-1.12.1.tgz")
set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.12.1/onnxruntime-linux-x64-1.12.1.tgz")
# If you don't have access to the internet, you can first download onnxruntime to some directory, and then use
# set(onnxruntime_URL "file:///ceph-fj/fangjun/open-source/sherpa-onnx/onnxruntime-linux-x64-1.12.1.tgz")
... ...
#!/bin/bash
#
# Copyright 2020 Mobvoi Inc. (authors: Fangjun Kuang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Usage:
#
# (1) To check files of the last commit
# ./scripts/check_style_cpplint.sh
#
# (2) To check changed files not committed yet
# ./scripts/check_style_cpplint.sh 1
#
# (3) To check all files in the project
# ./scripts/check_style_cpplint.sh 2
cpplint_version="1.5.4"
cur_dir=$(cd $(dirname $BASH_SOURCE) && pwd)
sherpa_onnx_dir=$(cd $cur_dir/.. && pwd)
build_dir=$sherpa_onnx_dir/build
mkdir -p $build_dir
cpplint_src=$build_dir/cpplint-${cpplint_version}/cpplint.py
if [ ! -d "$build_dir/cpplint-${cpplint_version}" ]; then
pushd $build_dir
if command -v wget &> /dev/null; then
wget https://github.com/cpplint/cpplint/archive/${cpplint_version}.tar.gz
elif command -v curl &> /dev/null; then
curl -O -SL https://github.com/cpplint/cpplint/archive/${cpplint_version}.tar.gz
else
echo "Please install wget or curl to download cpplint"
exit 1
fi
tar xf ${cpplint_version}.tar.gz
rm ${cpplint_version}.tar.gz
# cpplint will report the following error for: __host__ __device__ (
#
# Extra space before ( in function call [whitespace/parens] [4]
#
# the following patch disables the above error
sed -i "3490i\ not Search(r'__host__ __device__\\\s+\\\(', fncall) and" $cpplint_src
popd
fi
source $sherpa_onnx_dir/scripts/utils.sh
# return true if the given file is a c++ source file
# return false otherwise
function is_source_code_file() {
case "$1" in
*.cc|*.h|*.cu)
echo true;;
*)
echo false;;
esac
}
function check_style() {
python3 $cpplint_src $1 || abort $1
}
function check_last_commit() {
files=$(git diff HEAD^1 --name-only --diff-filter=ACDMRUXB)
echo $files
}
function check_current_dir() {
files=$(git status -s -uno --porcelain | awk '{
if (NF == 4) {
# a file has been renamed
print $NF
} else {
print $2
}}')
echo $files
}
function do_check() {
case "$1" in
1)
echo "Check changed files"
files=$(check_current_dir)
;;
2)
echo "Check all files"
files=$(find $sherpa_onnx_dir/sherpa-onnx -name "*.h" -o -name "*.cc")
;;
*)
echo "Check last commit"
files=$(check_last_commit)
;;
esac
for f in $files; do
need_check=$(is_source_code_file $f)
if $need_check; then
[[ -f $f ]] && check_style $f
fi
done
}
function main() {
do_check $1
ok "Great! Style check passed!"
}
cd $sherpa_onnx_dir
main $1
... ...
#!/bin/bash
default='\033[0m'
bold='\033[1m'
red='\033[31m'
green='\033[32m'
function ok() {
printf "${bold}${green}[OK]${default} $1\n"
}
function error() {
printf "${bold}${red}[FAILED]${default} $1\n"
}
function abort() {
printf "${bold}${red}[FAILED]${default} $1\n"
exit 1
}
... ...
include_directories(${CMAKE_SOURCE_DIR})
add_executable(sherpa-onnx
decode.cc
rnnt-model.cc
sherpa-onnx.cc
symbol-table.cc
wave-reader.cc
)
target_link_libraries(sherpa-onnx
onnxruntime
kaldi-native-fbank-core
)
add_executable(sherpa-show-onnx-info show-onnx-info.cc)
target_link_libraries(sherpa-show-onnx-info onnxruntime)
... ...
/**
* Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "sherpa-onnx/csrc/decode.h"
#include <assert.h>
#include <algorithm>
#include <vector>
namespace sherpa_onnx {
std::vector<int32_t> GreedySearch(RnntModel &model, // NOLINT
const Ort::Value &encoder_out) {
std::vector<int64_t> encoder_out_shape =
encoder_out.GetTensorTypeAndShapeInfo().GetShape();
assert(encoder_out_shape[0] == 1 && "Only batch_size=1 is implemented");
Ort::Value projected_encoder_out =
model.RunJoinerEncoderProj(encoder_out.GetTensorData<float>(),
encoder_out_shape[1], encoder_out_shape[2]);
const float *p_projected_encoder_out =
projected_encoder_out.GetTensorData<float>();
int32_t context_size = 2; // hard-code it to 2
int32_t blank_id = 0; // hard-code it to 0
std::vector<int32_t> hyp(context_size, blank_id);
std::array<int64_t, 2> decoder_input{blank_id, blank_id};
Ort::Value decoder_out = model.RunDecoder(decoder_input.data(), context_size);
std::vector<int64_t> decoder_out_shape =
decoder_out.GetTensorTypeAndShapeInfo().GetShape();
Ort::Value projected_decoder_out = model.RunJoinerDecoderProj(
decoder_out.GetTensorData<float>(), decoder_out_shape[2]);
int32_t joiner_dim =
projected_decoder_out.GetTensorTypeAndShapeInfo().GetShape()[1];
int32_t T = encoder_out_shape[1];
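// For each acoustic frame: feed the current projected encoder frame and the
// projected decoder output to the joiner, take the arg-max over the
// vocabulary, and, if it is not blank, append it to the hypothesis and
// re-run the decoder (and its projection) on the last context_size tokens.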
for (int32_t t = 0; t != T; ++t) {
Ort::Value logit = model.RunJoiner(
p_projected_encoder_out + t * joiner_dim,
projected_decoder_out.GetTensorData<float>(), joiner_dim);
int32_t vocab_size = logit.GetTensorTypeAndShapeInfo().GetShape()[1];
const float *p_logit = logit.GetTensorData<float>();
auto y = static_cast<int32_t>(std::distance(
static_cast<const float *>(p_logit),
std::max_element(static_cast<const float *>(p_logit),
static_cast<const float *>(p_logit) + vocab_size)));
if (y != blank_id) {
decoder_input[0] = hyp.back();
decoder_input[1] = y;
hyp.push_back(y);
decoder_out = model.RunDecoder(decoder_input.data(), context_size);
projected_decoder_out = model.RunJoinerDecoderProj(
decoder_out.GetTensorData<float>(), decoder_out_shape[2]);
}
}
return {hyp.begin() + context_size, hyp.end()};
}
} // namespace sherpa_onnx
... ...
/**
* Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef SHERPA_ONNX_CSRC_DECODE_H_
#define SHERPA_ONNX_CSRC_DECODE_H_
#include <vector>
#include "sherpa-onnx/csrc/rnnt-model.h"
namespace sherpa_onnx {
/** Greedy search for non-streaming ASR.
*
* @TODO(fangjun) Support batch size > 1
*
* @param model The RnntModel
* @param encoder_out Its shape is (1, num_frames, encoder_out_dim).
*/
std::vector<int32_t> GreedySearch(RnntModel &model, // NOLINT
const Ort::Value &encoder_out);
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_DECODE_H_
... ...
#include <iostream>
#include "kaldi_native_io/csrc/kaldi-io.h"
#include "kaldi_native_io/csrc/wave-reader.h"
#include "kaldi-native-fbank/csrc/online-feature.h"
kaldiio::Matrix<float> readWav(std::string filename, bool log = false){
if (log)
std::cout << "reading " << filename << std::endl;
bool binary = true;
kaldiio::Input ki(filename, &binary);
kaldiio::WaveHolder wh;
if (!wh.Read(ki.Stream())) {
std::cerr << "Failed to read " << filename;
exit(EXIT_FAILURE);
}
auto &wave_data = wh.Value();
auto &d = wave_data.Data();
if (log)
std::cout << "wav shape: " << "(" << d.NumRows() << "," << d.NumCols() << ")" << std::endl;
return d;
}
std::vector<float> ComputeFeatures(knf::OnlineFbank &fbank, knf::FbankOptions opts, kaldiio::Matrix<float> samples, bool log = false){
int numSamples = samples.NumCols();
for (int i = 0; i < numSamples; i++)
{
float currentSample = samples.Row(0).Data()[i] / 32768;
fbank.AcceptWaveform(opts.frame_opts.samp_freq, &currentSample, 1);
}
std::vector<float> features;
int32_t num_frames = fbank.NumFramesReady();
for (int32_t i = 0; i != num_frames; ++i) {
const float *frame = fbank.GetFrame(i);
for (int32_t k = 0; k != opts.mel_opts.num_bins; ++k) {
features.push_back(frame[k]);
}
}
if (log){
std::cout << "done feature extraction" << std::endl;
std::cout << "extracted fbank shape " << "(" << num_frames << "," << opts.mel_opts.num_bins << ")" << std::endl;
for (int i=0; i< 20; i++)
std::cout << features.at(i) << std::endl;
}
return features;
}
\ No newline at end of file
#include <algorithm>
#include <fstream>
#include <iostream>
#include <math.h>
#include <time.h>
#include <vector>
#include "sherpa-onnx/csrc/fbank_features.h"
#include "sherpa-onnx/csrc/rnnt_beam_search.h"
#include "kaldi-native-fbank/csrc/online-feature.h"
int main(int argc, char *argv[]) {
char *encoder_path = argv[1];
char *decoder_path = argv[2];
char *joiner_path = argv[3];
char *joiner_encoder_proj_path = argv[4];
char *joiner_decoder_proj_path = argv[5];
char *token_path = argv[6];
std::string search_method = argv[7];
char *filename = argv[8];
// General parameters
int numberOfThreads = 16;
// Initialize fbanks
knf::FbankOptions opts;
opts.frame_opts.dither = 0;
opts.frame_opts.samp_freq = 16000;
opts.frame_opts.frame_shift_ms = 10.0f;
opts.frame_opts.frame_length_ms = 25.0f;
opts.mel_opts.num_bins = 80;
opts.frame_opts.window_type = "povey";
opts.frame_opts.snip_edges = false;
knf::OnlineFbank fbank(opts);
// set session opts
// https://onnxruntime.ai/docs/performance/tune-performance.html
session_options.SetIntraOpNumThreads(numberOfThreads);
session_options.SetInterOpNumThreads(numberOfThreads);
session_options.SetGraphOptimizationLevel(
GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
session_options.SetLogSeverityLevel(4);
session_options.SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL);
api.CreateTensorRTProviderOptions(&tensorrt_options);
std::unique_ptr<OrtTensorRTProviderOptionsV2,
decltype(api.ReleaseTensorRTProviderOptions)>
rel_trt_options(tensorrt_options, api.ReleaseTensorRTProviderOptions);
api.SessionOptionsAppendExecutionProvider_TensorRT_V2(
static_cast<OrtSessionOptions *>(session_options), rel_trt_options.get());
// Define model
auto model =
get_model(encoder_path, decoder_path, joiner_path,
joiner_encoder_proj_path, joiner_decoder_proj_path, token_path);
std::vector<std::string> filename_list{filename};
for (auto filename : filename_list) {
std::cout << filename << std::endl;
auto samples = readWav(filename, true);
int numSamples = samples.NumCols();
auto features = ComputeFeatures(fbank, opts, samples);
auto tic = std::chrono::high_resolution_clock::now();
// # === Encoder Out === #
int num_frames = features.size() / opts.mel_opts.num_bins;
auto encoder_out =
model.encoder_forward(features, std::vector<int64_t>{num_frames},
std::vector<int64_t>{1, num_frames, 80},
std::vector<int64_t>{1}, memory_info);
// # === Search === #
std::vector<std::vector<int32_t>> hyps;
if (search_method == "greedy")
hyps = GreedySearch(&model, &encoder_out);
else {
std::cout << "wrong search method!" << std::endl;
exit(0);
}
auto results = hyps2result(model.tokens_map, hyps);
// # === Print Elapsed Time === #
auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::high_resolution_clock::now() - tic);
std::cout << "Elapsed: " << float(elapsed.count()) / 1000 << " seconds"
<< std::endl;
std::cout << "rtf: " << float(elapsed.count()) / 1000 / (numSamples / 16000)
<< std::endl;
print_hyps(hyps);
std::cout << results[0] << std::endl;
}
return 0;
}
#include <map>
#include <vector>
#include <iostream>
#include <algorithm>
#include <sys/stat.h>
#include "utils_onnx.h"
struct Model
{
public:
const char* encoder_path;
const char* decoder_path;
const char* joiner_path;
const char* joiner_encoder_proj_path;
const char* joiner_decoder_proj_path;
const char* tokens_path;
Ort::Session encoder = load_model(encoder_path);
Ort::Session decoder = load_model(decoder_path);
Ort::Session joiner = load_model(joiner_path);
Ort::Session joiner_encoder_proj = load_model(joiner_encoder_proj_path);
Ort::Session joiner_decoder_proj = load_model(joiner_decoder_proj_path);
std::map<int, std::string> tokens_map = get_token_map(tokens_path);
int32_t blank_id;
int32_t unk_id;
int32_t context_size;
std::vector<Ort::Value> encoder_forward(std::vector<float> in_vector,
std::vector<int64_t> in_vector_length,
std::vector<int64_t> feature_dims,
std::vector<int64_t> feature_length_dims,
Ort::MemoryInfo &memory_info){
std::vector<Ort::Value> encoder_inputTensors;
encoder_inputTensors.push_back(Ort::Value::CreateTensor<float>(memory_info, in_vector.data(), in_vector.size(), feature_dims.data(), feature_dims.size()));
encoder_inputTensors.push_back(Ort::Value::CreateTensor<int64_t>(memory_info, in_vector_length.data(), in_vector_length.size(), feature_length_dims.data(), feature_length_dims.size()));
std::vector<const char*> encoder_inputNames = {encoder.GetInputName(0, allocator), encoder.GetInputName(1, allocator)};
std::vector<const char*> encoder_outputNames = {encoder.GetOutputName(0, allocator)};
auto out = encoder.Run(Ort::RunOptions{nullptr},
encoder_inputNames.data(),
encoder_inputTensors.data(),
encoder_inputTensors.size(),
encoder_outputNames.data(),
encoder_outputNames.size());
return out;
}
std::vector<Ort::Value> decoder_forward(std::vector<int64_t> in_vector,
std::vector<int64_t> dims,
Ort::MemoryInfo &memory_info){
std::vector<Ort::Value> inputTensors;
inputTensors.push_back(Ort::Value::CreateTensor<int64_t>(memory_info, in_vector.data(), in_vector.size(), dims.data(), dims.size()));
std::vector<const char*> inputNames {decoder.GetInputName(0, allocator)};
std::vector<const char*> outputNames {decoder.GetOutputName(0, allocator)};
auto out = decoder.Run(Ort::RunOptions{nullptr},
inputNames.data(),
inputTensors.data(),
inputTensors.size(),
outputNames.data(),
outputNames.size());
return out;
}
std::vector<Ort::Value> joiner_forward(std::vector<float> projected_encoder_out,
std::vector<float> decoder_out,
std::vector<int64_t> projected_encoder_out_dims,
std::vector<int64_t> decoder_out_dims,
Ort::MemoryInfo &memory_info){
std::vector<Ort::Value> inputTensors;
inputTensors.push_back(Ort::Value::CreateTensor<float>(memory_info, projected_encoder_out.data(), projected_encoder_out.size(), projected_encoder_out_dims.data(), projected_encoder_out_dims.size()));
inputTensors.push_back(Ort::Value::CreateTensor<float>(memory_info, decoder_out.data(), decoder_out.size(), decoder_out_dims.data(), decoder_out_dims.size()));
std::vector<const char*> inputNames = {joiner.GetInputName(0, allocator), joiner.GetInputName(1, allocator)};
std::vector<const char*> outputNames = {joiner.GetOutputName(0, allocator)};
auto out = joiner.Run(Ort::RunOptions{nullptr},
inputNames.data(),
inputTensors.data(),
inputTensors.size(),
outputNames.data(),
outputNames.size());
return out;
}
std::vector<Ort::Value> joiner_encoder_proj_forward(std::vector<float> in_vector,
std::vector<int64_t> dims,
Ort::MemoryInfo &memory_info){
std::vector<Ort::Value> inputTensors;
inputTensors.push_back(Ort::Value::CreateTensor<float>(memory_info, in_vector.data(), in_vector.size(), dims.data(), dims.size()));
std::vector<const char*> inputNames {joiner_encoder_proj.GetInputName(0, allocator)};
std::vector<const char*> outputNames {joiner_encoder_proj.GetOutputName(0, allocator)};
auto out = joiner_encoder_proj.Run(Ort::RunOptions{nullptr},
inputNames.data(),
inputTensors.data(),
inputTensors.size(),
outputNames.data(),
outputNames.size());
return out;
}
std::vector<Ort::Value> joiner_decoder_proj_forward(std::vector<float> in_vector,
std::vector<int64_t> dims,
Ort::MemoryInfo &memory_info){
std::vector<Ort::Value> inputTensors;
inputTensors.push_back(Ort::Value::CreateTensor<float>(memory_info, in_vector.data(), in_vector.size(), dims.data(), dims.size()));
std::vector<const char*> inputNames {joiner_decoder_proj.GetInputName(0, allocator)};
std::vector<const char*> outputNames {joiner_decoder_proj.GetOutputName(0, allocator)};
auto out = joiner_decoder_proj.Run(Ort::RunOptions{nullptr},
inputNames.data(),
inputTensors.data(),
inputTensors.size(),
outputNames.data(),
outputNames.size());
return out;
}
Ort::Session load_model(const char* path){
struct stat buffer;
if (stat(path, &buffer) != 0){
std::cout << "File does not exist!: " << path << std::endl;
exit(0);
}
std::cout << "loading " << path << std::endl;
Ort::Session onnx_model(env, path, session_options);
return onnx_model;
}
void extract_constant_lm_parameters(){
/*
all_in_one contains these params. We should trace all_in_one and find 'constants_lm' nodes to extract them.
For now, these params are set statically.
in: Ort::Session &all_in_one
out: {blank_id, unk_id, context_size}
should return std::vector<int32_t>
*/
blank_id = 0;
unk_id = 0;
context_size = 2;
}
std::map<int, std::string> get_token_map(const char* token_path){
std::ifstream inFile;
inFile.open(token_path);
if (inFile.fail())
std::cerr << "Could not find token file" << std::endl;
std::map<int, std::string> token_map;
std::string line;
while (std::getline(inFile, line))
{
int id;
std::string token;
std::istringstream iss(line);
iss >> token;
iss >> id;
token_map[id] = token;
}
return token_map;
}
};
Model get_model(std::string exp_path, char* tokens_path){
Model model{
(exp_path + "/encoder_simp.onnx").c_str(),
(exp_path + "/decoder_simp.onnx").c_str(),
(exp_path + "/joiner_simp.onnx").c_str(),
(exp_path + "/joiner_encoder_proj_simp.onnx").c_str(),
(exp_path + "/joiner_decoder_proj_simp.onnx").c_str(),
tokens_path,
};
model.extract_constant_lm_parameters();
return model;
}
Model get_model(char* encoder_path,
char* decoder_path,
char* joiner_path,
char* joiner_encoder_proj_path,
char* joiner_decoder_proj_path,
char* tokens_path){
Model model{
encoder_path,
decoder_path,
joiner_path,
joiner_encoder_proj_path,
joiner_decoder_proj_path,
tokens_path,
};
model.extract_constant_lm_parameters();
return model;
}
void doWarmup(Model *model, int numWarmup = 5){
std::cout << "Warmup is started" << std::endl;
std::vector<float> encoder_warmup_sample (500 * 80, 1.0);
for (int i=0; i<numWarmup; i++)
auto encoder_out = model->encoder_forward(encoder_warmup_sample,
std::vector<int64_t> {500},
std::vector<int64_t> {1, 500, 80},
std::vector<int64_t> {1},
memory_info);
std::vector<int64_t> decoder_warmup_sample {1, 1};
for (int i=0; i<numWarmup; i++)
auto decoder_out = model->decoder_forward(decoder_warmup_sample,
std::vector<int64_t> {1, 2},
memory_info);
std::vector<float> joiner_warmup_sample1 (512, 1.0);
std::vector<float> joiner_warmup_sample2 (512, 1.0);
for (int i=0; i<numWarmup; i++)
auto logits = model->joiner_forward(joiner_warmup_sample1,
joiner_warmup_sample2,
std::vector<int64_t> {1, 1, 1, 512},
std::vector<int64_t> {1, 1, 1, 512},
memory_info);
std::vector<float> joiner_encoder_proj_warmup_sample (100 * 512, 1.0);
for (int i=0; i<numWarmup; i++)
auto projected_encoder_out = model->joiner_encoder_proj_forward(joiner_encoder_proj_warmup_sample,
std::vector<int64_t> {100, 512},
memory_info);
std::vector<float> joiner_decoder_proj_warmup_sample (512, 1.0);
for (int i=0; i<numWarmup; i++)
auto projected_decoder_out = model->joiner_decoder_proj_forward(joiner_decoder_proj_warmup_sample,
std::vector<int64_t> {1, 512},
memory_info);
std::cout << "Warmup is done" << std::endl;
}
/**
* Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "sherpa-onnx/csrc/rnnt-model.h"
#include <array>
#include <utility>
#include <vector>
namespace sherpa_onnx {
/**
* Get the input names of a model.
*
* @param sess An onnxruntime session.
* @param input_names. On return, it contains the input names of the model.
* @param input_names_ptr. On return, input_names_ptr[i] contains
* input_names[i].c_str()
*/
static void GetInputNames(Ort::Session *sess,
std::vector<std::string> *input_names,
std::vector<const char *> *input_names_ptr) {
Ort::AllocatorWithDefaultOptions allocator;
size_t node_count = sess->GetInputCount();
input_names->resize(node_count);
input_names_ptr->resize(node_count);
for (size_t i = 0; i != node_count; ++i) {
auto tmp = sess->GetInputNameAllocated(i, allocator);
(*input_names)[i] = tmp.get();
(*input_names_ptr)[i] = (*input_names)[i].c_str();
}
}
/**
* Get the output names of a model.
*
* @param sess An onnxruntime session.
* @param output_names. On return, it contains the output names of the model.
* @param output_names_ptr. On return, output_names_ptr[i] contains
* output_names[i].c_str()
*/
static void GetOutputNames(Ort::Session *sess,
std::vector<std::string> *output_names,
std::vector<const char *> *output_names_ptr) {
Ort::AllocatorWithDefaultOptions allocator;
size_t node_count = sess->GetOutputCount();
output_names->resize(node_count);
output_names_ptr->resize(node_count);
for (size_t i = 0; i != node_count; ++i) {
auto tmp = sess->GetOutputNameAllocated(i, allocator);
(*output_names)[i] = tmp.get();
(*output_names_ptr)[i] = (*output_names)[i].c_str();
}
}
RnntModel::RnntModel(const std::string &encoder_filename,
const std::string &decoder_filename,
const std::string &joiner_filename,
const std::string &joiner_encoder_proj_filename,
const std::string &joiner_decoder_proj_filename,
int32_t num_threads)
: env_(ORT_LOGGING_LEVEL_WARNING) {
sess_opts_.SetIntraOpNumThreads(num_threads);
sess_opts_.SetInterOpNumThreads(num_threads);
InitEncoder(encoder_filename);
InitDecoder(decoder_filename);
InitJoiner(joiner_filename);
InitJoinerEncoderProj(joiner_encoder_proj_filename);
InitJoinerDecoderProj(joiner_decoder_proj_filename);
}
void RnntModel::InitEncoder(const std::string &filename) {
encoder_sess_ =
std::make_unique<Ort::Session>(env_, filename.c_str(), sess_opts_);
GetInputNames(encoder_sess_.get(), &encoder_input_names_,
&encoder_input_names_ptr_);
GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
&encoder_output_names_ptr_);
}
void RnntModel::InitDecoder(const std::string &filename) {
decoder_sess_ =
std::make_unique<Ort::Session>(env_, filename.c_str(), sess_opts_);
GetInputNames(decoder_sess_.get(), &decoder_input_names_,
&decoder_input_names_ptr_);
GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
&decoder_output_names_ptr_);
}
void RnntModel::InitJoiner(const std::string &filename) {
joiner_sess_ =
std::make_unique<Ort::Session>(env_, filename.c_str(), sess_opts_);
GetInputNames(joiner_sess_.get(), &joiner_input_names_,
&joiner_input_names_ptr_);
GetOutputNames(joiner_sess_.get(), &joiner_output_names_,
&joiner_output_names_ptr_);
}
void RnntModel::InitJoinerEncoderProj(const std::string &filename) {
joiner_encoder_proj_sess_ =
std::make_unique<Ort::Session>(env_, filename.c_str(), sess_opts_);
GetInputNames(joiner_encoder_proj_sess_.get(),
&joiner_encoder_proj_input_names_,
&joiner_encoder_proj_input_names_ptr_);
GetOutputNames(joiner_encoder_proj_sess_.get(),
&joiner_encoder_proj_output_names_,
&joiner_encoder_proj_output_names_ptr_);
}
void RnntModel::InitJoinerDecoderProj(const std::string &filename) {
joiner_decoder_proj_sess_ =
std::make_unique<Ort::Session>(env_, filename.c_str(), sess_opts_);
GetInputNames(joiner_decoder_proj_sess_.get(),
&joiner_decoder_proj_input_names_,
&joiner_decoder_proj_input_names_ptr_);
GetOutputNames(joiner_decoder_proj_sess_.get(),
&joiner_decoder_proj_output_names_,
&joiner_decoder_proj_output_names_ptr_);
}
Ort::Value RnntModel::RunEncoder(const float *features, int32_t T,
int32_t feature_dim) {
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
std::array<int64_t, 3> x_shape{1, T, feature_dim};
Ort::Value x =
Ort::Value::CreateTensor(memory_info, const_cast<float *>(features),
T * feature_dim, x_shape.data(), x_shape.size());
std::array<int64_t, 1> x_lens_shape{1};
int64_t x_lens_tmp = T;
Ort::Value x_lens = Ort::Value::CreateTensor(
memory_info, &x_lens_tmp, 1, x_lens_shape.data(), x_lens_shape.size());
std::array<Ort::Value, 2> encoder_inputs{std::move(x), std::move(x_lens)};
// Note: We discard encoder_out_lens since we only implement
// batch==1.
auto encoder_out = encoder_sess_->Run(
{}, encoder_input_names_ptr_.data(), encoder_inputs.data(),
encoder_inputs.size(), encoder_output_names_ptr_.data(),
encoder_output_names_ptr_.size());
return std::move(encoder_out[0]);
}
Ort::Value RnntModel::RunJoinerEncoderProj(const float *encoder_out, int32_t T,
int32_t encoder_out_dim) {
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
std::array<int64_t, 2> in_shape{T, encoder_out_dim};
Ort::Value in = Ort::Value::CreateTensor(
memory_info, const_cast<float *>(encoder_out), T * encoder_out_dim,
in_shape.data(), in_shape.size());
auto encoder_proj_out = joiner_encoder_proj_sess_->Run(
{}, joiner_encoder_proj_input_names_ptr_.data(), &in, 1,
joiner_encoder_proj_output_names_ptr_.data(),
joiner_encoder_proj_output_names_ptr_.size());
return std::move(encoder_proj_out[0]);
}
Ort::Value RnntModel::RunDecoder(const int64_t *decoder_input,
int32_t context_size) {
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
int32_t batch_size = 1; // TODO(fangjun): handle the case when it's > 1
std::array<int64_t, 2> shape{batch_size, context_size};
Ort::Value in = Ort::Value::CreateTensor(
memory_info, const_cast<int64_t *>(decoder_input),
batch_size * context_size, shape.data(), shape.size());
auto decoder_out = decoder_sess_->Run(
{}, decoder_input_names_ptr_.data(), &in, 1,
decoder_output_names_ptr_.data(), decoder_output_names_ptr_.size());
return std::move(decoder_out[0]);
}
Ort::Value RnntModel::RunJoinerDecoderProj(const float *decoder_out,
int32_t decoder_out_dim) {
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
int32_t batch_size = 1; // TODO(fangjun): handle the case when it's > 1
std::array<int64_t, 2> shape{batch_size, decoder_out_dim};
Ort::Value in = Ort::Value::CreateTensor(
memory_info, const_cast<float *>(decoder_out),
batch_size * decoder_out_dim, shape.data(), shape.size());
auto decoder_proj_out = joiner_decoder_proj_sess_->Run(
{}, joiner_decoder_proj_input_names_ptr_.data(), &in, 1,
joiner_decoder_proj_output_names_ptr_.data(),
joiner_decoder_proj_output_names_ptr_.size());
return std::move(decoder_proj_out[0]);
}
Ort::Value RnntModel::RunJoiner(const float *projected_encoder_out,
const float *projected_decoder_out,
int32_t joiner_dim) {
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
int32_t batch_size = 1; // TODO(fangjun): handle the case when it's > 1
std::array<int64_t, 2> shape{batch_size, joiner_dim};
Ort::Value enc = Ort::Value::CreateTensor(
memory_info, const_cast<float *>(projected_encoder_out),
batch_size * joiner_dim, shape.data(), shape.size());
Ort::Value dec = Ort::Value::CreateTensor(
memory_info, const_cast<float *>(projected_decoder_out),
batch_size * joiner_dim, shape.data(), shape.size());
std::array<Ort::Value, 2> inputs{std::move(enc), std::move(dec)};
auto logit = joiner_sess_->Run(
{}, joiner_input_names_ptr_.data(), inputs.data(), inputs.size(),
joiner_output_names_ptr_.data(), joiner_output_names_ptr_.size());
return std::move(logit[0]);
}
} // namespace sherpa_onnx
... ...
/**
* Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef SHERPA_ONNX_CSRC_RNNT_MODEL_H_
#define SHERPA_ONNX_CSRC_RNNT_MODEL_H_
#include <memory>
#include <string>
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
namespace sherpa_onnx {
class RnntModel {
public:
/**
* @param encoder_filename Path to the encoder model
* @param decoder_filename Path to the decoder model
* @param joiner_filename Path to the joiner model
* @param joiner_encoder_proj_filename Path to the joiner encoder_proj model
* @param joiner_decoder_proj_filename Path to the joiner decoder_proj model
* @param num_threads Number of threads to use to run the models
*/
RnntModel(const std::string &encoder_filename,
const std::string &decoder_filename,
const std::string &joiner_filename,
const std::string &joiner_encoder_proj_filename,
const std::string &joiner_decoder_proj_filename,
int32_t num_threads);
/** Run the encoder model.
*
* @TODO(fangjun): Support batch_size > 1
*
* @param features A tensor of shape (batch_size, T, feature_dim)
* @param T Number of feature frames
* @param feature_dim Dimension of the feature.
*
* @return Return a tensor of shape (batch_size, T', encoder_out_dim)
*/
Ort::Value RunEncoder(const float *features, int32_t T, int32_t feature_dim);
/** Run the joiner encoder_proj model.
*
* @param encoder_out A tensor of shape (T, encoder_out_dim)
* @param T Number of frames in encoder_out.
* @param encoder_out_dim Dimension of encoder_out.
*
* @return Return a tensor of shape (T, joiner_dim)
*
*/
Ort::Value RunJoinerEncoderProj(const float *encoder_out, int32_t T,
int32_t encoder_out_dim);
/** Run the decoder model.
*
* @TODO(fangjun): Support batch_size > 1
*
* @param decoder_input A tensor of shape (batch_size, context_size).
* @return Return a tensor of shape (batch_size, 1, decoder_out_dim)
*/
Ort::Value RunDecoder(const int64_t *decoder_input, int32_t context_size);
/** Run joiner decoder_proj model.
*
* @TODO(fangjun): Support batch_size > 1
*
* @param decoder_out A tensor of shape (batch_size, decoder_out_dim)
* @param decoder_out_dim Output dimension of the decoder_out.
*
* @return Return a tensor of shape (batch_size, joiner_dim);
*/
Ort::Value RunJoinerDecoderProj(const float *decoder_out,
int32_t decoder_out_dim);
/** Run the joiner model.
*
* @TODO(fangjun): Support batch_size > 1
*
* @param projected_encoder_out A tensor of shape (batch_size, joiner_dim).
* @param projected_decoder_out A tensor of shape (batch_size, joiner_dim).
*
* @return Return a tensor of shape (batch_size, vocab_size)
*/
Ort::Value RunJoiner(const float *projected_encoder_out,
const float *projected_decoder_out, int32_t joiner_dim);
private:
void InitEncoder(const std::string &encoder_filename);
void InitDecoder(const std::string &decoder_filename);
void InitJoiner(const std::string &joiner_filename);
void InitJoinerEncoderProj(const std::string &joiner_encoder_proj_filename);
void InitJoinerDecoderProj(const std::string &joiner_decoder_proj_filename);
private:
Ort::Env env_;
Ort::SessionOptions sess_opts_;
std::unique_ptr<Ort::Session> encoder_sess_;
std::unique_ptr<Ort::Session> decoder_sess_;
std::unique_ptr<Ort::Session> joiner_sess_;
std::unique_ptr<Ort::Session> joiner_encoder_proj_sess_;
std::unique_ptr<Ort::Session> joiner_decoder_proj_sess_;
std::vector<std::string> encoder_input_names_;
std::vector<const char *> encoder_input_names_ptr_;
std::vector<std::string> encoder_output_names_;
std::vector<const char *> encoder_output_names_ptr_;
std::vector<std::string> decoder_input_names_;
std::vector<const char *> decoder_input_names_ptr_;
std::vector<std::string> decoder_output_names_;
std::vector<const char *> decoder_output_names_ptr_;
std::vector<std::string> joiner_input_names_;
std::vector<const char *> joiner_input_names_ptr_;
std::vector<std::string> joiner_output_names_;
std::vector<const char *> joiner_output_names_ptr_;
std::vector<std::string> joiner_encoder_proj_input_names_;
std::vector<const char *> joiner_encoder_proj_input_names_ptr_;
std::vector<std::string> joiner_encoder_proj_output_names_;
std::vector<const char *> joiner_encoder_proj_output_names_ptr_;
std::vector<std::string> joiner_decoder_proj_input_names_;
std::vector<const char *> joiner_decoder_proj_input_names_ptr_;
std::vector<std::string> joiner_decoder_proj_output_names_;
std::vector<const char *> joiner_decoder_proj_output_names_ptr_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_RNNT_MODEL_H_
... ...
#include <vector>
#include <iostream>
#include <algorithm>
#include <time.h>
#include "models.h"
#include "utils.h"
std::vector<float> getEncoderCol(Ort::Value &tensor, int start, int length){
float* floatarr = tensor.GetTensorMutableData<float>();
std::vector<float> vector {floatarr + start, floatarr + length};
return vector;
}
/**
* Assume batch size = 1
*/
std::vector<int64_t> BuildDecoderInput(const std::vector<std::vector<int32_t>> &hyps,
std::vector<int64_t> &decoder_input) {
int32_t context_size = decoder_input.size();
int32_t hyps_length = hyps[0].size();
for (int i=0; i < context_size; i++)
decoder_input[i] = hyps[0][hyps_length-context_size+i];
return decoder_input;
}
std::vector<std::vector<int32_t>> GreedySearch(
Model *model, // NOLINT
std::vector<Ort::Value> *encoder_out){
Ort::Value &encoder_out_tensor = encoder_out->at(0);
int encoder_out_dim1 = encoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[1];
int encoder_out_dim2 = encoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[2];
auto encoder_out_vector = ortVal2Vector(encoder_out_tensor, encoder_out_dim1 * encoder_out_dim2);
// # === Greedy Search === #
int32_t batch_size = 1;
std::vector<int32_t> blanks(model->context_size, model->blank_id);
std::vector<std::vector<int32_t>> hyps(batch_size, blanks);
std::vector<int64_t> decoder_input(model->context_size, model->blank_id);
auto decoder_out = model->decoder_forward(decoder_input,
std::vector<int64_t> {batch_size, model->context_size},
memory_info);
Ort::Value &decoder_out_tensor = decoder_out[0];
int decoder_out_dim = decoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[2];
auto decoder_out_vector = ortVal2Vector(decoder_out_tensor, decoder_out_dim);
decoder_out = model->joiner_decoder_proj_forward(decoder_out_vector,
std::vector<int64_t> {1, decoder_out_dim},
memory_info);
Ort::Value &projected_decoder_out_tensor = decoder_out[0];
auto projected_decoder_out_dim = projected_decoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[1];
auto projected_decoder_out_vector = ortVal2Vector(projected_decoder_out_tensor, projected_decoder_out_dim);
auto projected_encoder_out = model->joiner_encoder_proj_forward(encoder_out_vector,
std::vector<int64_t> {encoder_out_dim1, encoder_out_dim2},
memory_info);
Ort::Value &projected_encoder_out_tensor = projected_encoder_out[0];
int projected_encoder_out_dim1 = projected_encoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[0];
int projected_encoder_out_dim2 = projected_encoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[1];
auto projected_encoder_out_vector = ortVal2Vector(projected_encoder_out_tensor, projected_encoder_out_dim1 * projected_encoder_out_dim2);
int32_t offset = 0;
for (int i=0; i< projected_encoder_out_dim1; i++){
int32_t cur_batch_size = 1;
int32_t start = offset;
int32_t end = start + cur_batch_size;
offset = end;
auto cur_encoder_out = getEncoderCol(projected_encoder_out_tensor, start * projected_encoder_out_dim2, end * projected_encoder_out_dim2);
auto logits = model->joiner_forward(cur_encoder_out,
projected_decoder_out_vector,
std::vector<int64_t> {1, projected_encoder_out_dim2},
std::vector<int64_t> {1, projected_decoder_out_dim},
memory_info);
Ort::Value &logits_tensor = logits[0];
int logits_dim = logits_tensor.GetTensorTypeAndShapeInfo().GetShape()[1];
auto logits_vector = ortVal2Vector(logits_tensor, logits_dim);
int max_indices = static_cast<int>(std::distance(logits_vector.begin(), std::max_element(logits_vector.begin(), logits_vector.end())));
bool emitted = false;
for (int32_t k = 0; k != cur_batch_size; ++k) {
auto index = max_indices;
if (index != model->blank_id && index != model->unk_id) {
emitted = true;
hyps[k].push_back(index);
}
}
if (emitted) {
decoder_input = BuildDecoderInput(hyps, decoder_input);
decoder_out = model->decoder_forward(decoder_input,
std::vector<int64_t> {batch_size, model->context_size},
memory_info);
decoder_out_dim = decoder_out[0].GetTensorTypeAndShapeInfo().GetShape()[2];
decoder_out_vector = ortVal2Vector(decoder_out[0], decoder_out_dim);
decoder_out = model->joiner_decoder_proj_forward(decoder_out_vector,
std::vector<int64_t> {1, decoder_out_dim},
memory_info);
projected_decoder_out_dim = decoder_out[0].GetTensorTypeAndShapeInfo().GetShape()[1];
projected_decoder_out_vector = ortVal2Vector(decoder_out[0], projected_decoder_out_dim);
}
}
return hyps;
}
/**
* Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <string>
#include <vector>
#include "kaldi-native-fbank/csrc/online-feature.h"
#include "sherpa-onnx/csrc/decode.h"
#include "sherpa-onnx/csrc/rnnt-model.h"
#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/wave-reader.h"
/** Compute fbank features of the input wave filename.
*
* @param wav_filename. Path to a mono wave file.
* @param expected_sampling_rate Expected sampling rate of the input wave file.
* @param num_frames On return, it contains the number of feature frames.
* @return Return the computed feature of shape (num_frames, feature_dim)
* stored in row-major.
*/
static std::vector<float> ComputeFeatures(const std::string &wav_filename,
float expected_sampling_rate,
int32_t *num_frames) {
std::vector<float> samples =
sherpa_onnx::ReadWave(wav_filename, expected_sampling_rate);
float duration = samples.size() / expected_sampling_rate;
std::cout << "wav filename: " << wav_filename << "\n";
std::cout << "wav duration (s): " << duration << "\n";
knf::FbankOptions opts;
opts.frame_opts.dither = 0;
opts.frame_opts.snip_edges = false;
opts.frame_opts.samp_freq = expected_sampling_rate;
int32_t feature_dim = 80;
opts.mel_opts.num_bins = feature_dim;
knf::OnlineFbank fbank(opts);
fbank.AcceptWaveform(expected_sampling_rate, samples.data(), samples.size());
fbank.InputFinished();
*num_frames = fbank.NumFramesReady();
std::vector<float> features(*num_frames * feature_dim);
float *p = features.data();
for (int32_t i = 0; i != fbank.NumFramesReady(); ++i, p += feature_dim) {
const float *f = fbank.GetFrame(i);
std::copy(f, f + feature_dim, p);
}
return features;
}
int main(int32_t argc, char *argv[]) {
if (argc < 8 || argc > 9) {
const char *usage = R"usage(
Usage:
./bin/sherpa-onnx \
/path/to/tokens.txt \
/path/to/encoder.onnx \
/path/to/decoder.onnx \
/path/to/joiner.onnx \
/path/to/joiner_encoder_proj.onnx \
/path/to/joiner_decoder_proj.onnx \
/path/to/foo.wav [num_threads]
You can download pre-trained models from the following repository:
https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
)usage";
std::cerr << usage << "\n";
return 0;
}
std::string tokens = argv[1];
std::string encoder = argv[2];
std::string decoder = argv[3];
std::string joiner = argv[4];
std::string joiner_encoder_proj = argv[5];
std::string joiner_decoder_proj = argv[6];
std::string wav_filename = argv[7];
int32_t num_threads = 4;
if (argc == 9) {
num_threads = atoi(argv[8]);
}
sherpa_onnx::SymbolTable sym(tokens);
int32_t num_frames;
auto features = ComputeFeatures(wav_filename, 16000, &num_frames);
int32_t feature_dim = features.size() / num_frames;
sherpa_onnx::RnntModel model(encoder, decoder, joiner, joiner_encoder_proj,
joiner_decoder_proj, num_threads);
Ort::Value encoder_out =
model.RunEncoder(features.data(), num_frames, feature_dim);
auto hyp = sherpa_onnx::GreedySearch(model, encoder_out);
std::string text;
for (auto i : hyp) {
text += sym[i];
}
std::cout << "Recognition result for " << wav_filename << "\n"
<< text << "\n";
return 0;
}
... ...
... ... @@ -15,34 +15,21 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <sstream>
#include "kaldi-native-fbank/csrc/online-feature.h"
#include "onnxruntime_cxx_api.h" // NOLINT
int main() {
std::cout << "ORT_API_VERSION: " << ORT_API_VERSION << "\n";
std::vector<std::string> providers = Ort::GetAvailableProviders();
std::ostringstream os;
os << "Available providers: ";
std::string sep = "";
for (const auto &p : providers) {
os << sep << p;
sep = ", ";
}
std::cout << os.str() << "\n";
return 0;
}
... ...
/**
* Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "sherpa-onnx/csrc/symbol-table.h"
#include <cassert>
#include <fstream>
#include <sstream>
namespace sherpa_onnx {
SymbolTable::SymbolTable(const std::string &filename) {
std::ifstream is(filename);
std::string sym;
int32_t id;
while (is >> sym >> id) {
if (sym.size() >= 3) {
// For BPE-based models, we replace ▁ with a space
// Unicode 9601, hex 0x2581, utf8 0xe29681
const uint8_t *p = reinterpret_cast<const uint8_t *>(sym.c_str());
if (p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) {
sym = sym.replace(0, 3, " ");
}
}
assert(!sym.empty());
assert(sym2id_.count(sym) == 0);
assert(id2sym_.count(id) == 0);
sym2id_.insert({sym, id});
id2sym_.insert({id, sym});
}
assert(is.eof());
}
std::string SymbolTable::ToString() const {
std::ostringstream os;
char sep = ' ';
for (const auto &p : sym2id_) {
os << p.first << sep << p.second << "\n";
}
return os.str();
}
const std::string &SymbolTable::operator[](int32_t id) const {
return id2sym_.at(id);
}
int32_t SymbolTable::operator[](const std::string &sym) const {
return sym2id_.at(sym);
}
bool SymbolTable::contains(int32_t id) const { return id2sym_.count(id) != 0; }
bool SymbolTable::contains(const std::string &sym) const {
return sym2id_.count(sym) != 0;
}
std::ostream &operator<<(std::ostream &os, const SymbolTable &symbol_table) {
return os << symbol_table.ToString();
}
} // namespace sherpa_onnx
... ...
/**
* Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef SHERPA_ONNX_CSRC_SYMBOL_TABLE_H_
#define SHERPA_ONNX_CSRC_SYMBOL_TABLE_H_
#include <string>
#include <unordered_map>
namespace sherpa_onnx {
/// It manages mapping between symbols and integer IDs.
class SymbolTable {
public:
SymbolTable() = default;
/// Construct a symbol table from a file.
/// Each line in the file contains two fields:
///
/// sym ID
///
/// Fields are separated by space(s).
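///
/// For example (the entries below are illustrative only, not taken from a
/// real tokens.txt):
///
///   <blk> 0
///   ▁HELLO 42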
explicit SymbolTable(const std::string &filename);
/// Return a string representation of this symbol table
std::string ToString() const;
/// Return the symbol corresponding to the given ID.
const std::string &operator[](int32_t id) const;
/// Return the ID corresponding to the given symbol.
int32_t operator[](const std::string &sym) const;
/// Return true if there is a symbol with the given ID.
bool contains(int32_t id) const;
/// Return true if there is a given symbol in the symbol table.
bool contains(const std::string &sym) const;
private:
std::unordered_map<std::string, int32_t> sym2id_;
std::unordered_map<int32_t, std::string> id2sym_;
};
std::ostream &operator<<(std::ostream &os, const SymbolTable &symbol_table);
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_SYMBOL_TABLE_H_
... ...
#include <iostream>
#include <fstream>
void vector2file(std::vector<float> vector, std::string saveFileName){
std::ofstream f(saveFileName);
for(std::vector<float>::const_iterator i = vector.begin(); i != vector.end(); ++i) {
f << *i << '\n';
}
}
std::vector<std::string> hyps2result(std::map<int, std::string> token_map, std::vector<std::vector<int32_t>> hyps, int context_size = 2){
std::vector<std::string> results;
for (int k=0; k < hyps.size(); k++){
std::string result = token_map[hyps[k][context_size]];
for (int i=context_size+1; i < hyps[k].size(); i++){
std::string token = token_map[hyps[k][i]];
// TODO: recognising '_' is not working
if (token.at(0) == '_')
result += " " + token;
else
result += token;
}
results.push_back(result);
}
return results;
}
void print_hyps(std::vector<std::vector<int32_t>> hyps, int context_size = 2){
std::cout << "Hyps:" << std::endl;
for (int i=context_size; i<hyps[0].size(); i++)
std::cout << hyps[0][i] << "-";
std::cout << "|" << std::endl;
}
#include <iostream>
#include "onnxruntime_cxx_api.h"
Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "test");
const auto& api = Ort::GetApi();
OrtTensorRTProviderOptionsV2* tensorrt_options;
Ort::SessionOptions session_options;
Ort::AllocatorWithDefaultOptions allocator;
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
std::vector<float> ortVal2Vector(Ort::Value &tensor, int tensor_length){
/**
* convert ort tensor to vector
*/
float* floatarr = tensor.GetTensorMutableData<float>();
std::vector<float> vector {floatarr, floatarr + tensor_length};
return vector;
}
void print_onnx_forward_output(std::vector<Ort::Value> &output_tensors, int num){
float* floatarr = output_tensors.front().GetTensorMutableData<float>();
for (int i = 0; i < num; i++)
printf("[%d] = %f\n", i, floatarr[i]);
}
void print_shape_of_ort_val(std::vector<Ort::Value> &tensor){
auto out_shape = tensor.front().GetTensorTypeAndShapeInfo().GetShape();
auto out_size = out_shape.size();
std::cout << "(";
for (int i=0; i<out_size; i++){
std::cout << out_shape[i];
if (i < out_size-1)
std::cout << ",";
}
std::cout << ")" << std::endl;
}
void print_model_info(Ort::Session &session, std::string title){
std::cout << "=== Printing '" << title << "' model ===" << std::endl;
Ort::AllocatorWithDefaultOptions allocator;
// print number of model input nodes
size_t num_input_nodes = session.GetInputCount();
std::vector<const char*> input_node_names(num_input_nodes);
std::vector<int64_t> input_node_dims;
printf("Number of inputs = %zu\n", num_input_nodes);
char* output_name = session.GetOutputName(0, allocator);
printf("output name: %s\n", output_name);
// iterate over all input nodes
for (int i = 0; i < num_input_nodes; i++) {
// print input node names
char* input_name = session.GetInputName(i, allocator);
printf("Input %d : name=%s\n", i, input_name);
input_node_names[i] = input_name;
// print input node types
Ort::TypeInfo type_info = session.GetInputTypeInfo(i);
auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
ONNXTensorElementDataType type = tensor_info.GetElementType();
printf("Input %d : type=%d\n", i, type);
// print input shapes/dims
input_node_dims = tensor_info.GetShape();
printf("Input %d : num_dims=%zu\n", i, input_node_dims.size());
for (size_t j = 0; j < input_node_dims.size(); j++)
printf("Input %d : dim %zu=%jd\n", i, j, input_node_dims[j]);
}
std::cout << "=======================================" << std::endl;
}
/**
* Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "sherpa-onnx/csrc/wave-reader.h"
#include <cassert>
#include <fstream>
#include <iostream>
#include <utility>
#include <vector>
namespace sherpa_onnx {
namespace {
// see http://soundfile.sapp.org/doc/WaveFormat/
//
// Note: We assume little endian here
// TODO(fangjun): Support big endian
struct WaveHeader {
void Validate() const {
// F F I R
assert(chunk_id == 0x46464952);
assert(chunk_size == 36 + subchunk2_size);
// E V A W
assert(format == 0x45564157);
assert(subchunk1_id == 0x20746d66);
assert(subchunk1_size == 16); // 16 for PCM
assert(audio_format == 1); // 1 for PCM
assert(num_channels == 1); // we support only single channel for now
assert(byte_rate == sample_rate * num_channels * bits_per_sample / 8);
assert(block_align == num_channels * bits_per_sample / 8);
assert(bits_per_sample == 16); // we support only 16 bits per sample
}
int32_t chunk_id;
int32_t chunk_size;
int32_t format;
int32_t subchunk1_id;
int32_t subchunk1_size;
int16_t audio_format;
int16_t num_channels;
int32_t sample_rate;
int32_t byte_rate;
int16_t block_align;
int16_t bits_per_sample;
int32_t subchunk2_id;
int32_t subchunk2_size;
};
static_assert(sizeof(WaveHeader) == 44, "");
// Read a wave file of mono-channel.
// Return its samples normalized to the range [-1, 1).
std::vector<float> ReadWaveImpl(std::istream &is, float *sample_rate) {
WaveHeader header;
is.read(reinterpret_cast<char *>(&header), sizeof(header));
assert(static_cast<bool>(is));
header.Validate();
*sample_rate = header.sample_rate;
// header.subchunk2_size contains the number of bytes in the data.
// Since we assume each sample occupies two bytes, we divide it by 2 here.
std::vector<int16_t> samples(header.subchunk2_size / 2);
is.read(reinterpret_cast<char *>(samples.data()), header.subchunk2_size);
assert(static_cast<bool>(is));
std::vector<float> ans(samples.size());
for (int32_t i = 0; i != ans.size(); ++i) {
ans[i] = samples[i] / 32768.;
}
return ans;
}
} // namespace
std::vector<float> ReadWave(const std::string &filename,
float expected_sample_rate) {
std::ifstream is(filename, std::ifstream::binary);
float sample_rate;
auto samples = ReadWaveImpl(is, &sample_rate);
if (expected_sample_rate != sample_rate) {
std::cerr << "Expected sample rate: " << expected_sample_rate
<< ". Given: " << sample_rate << ".\n";
exit(-1);
}
return samples;
}
} // namespace sherpa_onnx
... ...
/**
* Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef SHERPA_ONNX_CSRC_WAVE_READER_H_
#define SHERPA_ONNX_CSRC_WAVE_READER_H_
#include <istream>
#include <string>
#include <vector>
namespace sherpa_onnx {
/** Read a wave file with expected sample rate.
@param filename Path to a wave file. It MUST be single channel, PCM encoded.
@param expected_sample_rate Expected sample rate of the wave file. If the
sample rate does not match, the program prints an error and exits.
@return Return wave samples normalized to the range [-1, 1).
*/
std::vector<float> ReadWave(const std::string &filename,
float expected_sample_rate);
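//
// A minimal usage sketch (the path is a placeholder):
//
//   std::vector<float> samples = ReadWave("foo.wav", 16000);
//   // samples.size() / 16000 gives the duration in seconds.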
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_WAVE_READER_H_
... ...