正在显示
9 个修改的文件
包含
278 行增加
和
132 行删除
.github/workflows/test-linux.yaml
0 → 100644
| 1 | +name: test-linux | ||
| 2 | + | ||
| 3 | +on: | ||
| 4 | + push: | ||
| 5 | + branches: | ||
| 6 | + - master | ||
| 7 | + paths: | ||
| 8 | + - '.github/workflows/test-linux.yaml' | ||
| 9 | + - 'CMakeLists.txt' | ||
| 10 | + - 'cmake/**' | ||
| 11 | + - 'sherpa-onnx/csrc/*' | ||
| 12 | + pull_request: | ||
| 13 | + branches: | ||
| 14 | + - master | ||
| 15 | + paths: | ||
| 16 | + - '.github/workflows/test-linux.yaml' | ||
| 17 | + - 'CMakeLists.txt' | ||
| 18 | + - 'cmake/**' | ||
| 19 | + - 'sherpa-onnx/csrc/*' | ||
| 20 | + | ||
| 21 | +concurrency: | ||
| 22 | + group: test-linux-${{ github.ref }} | ||
| 23 | + cancel-in-progress: true | ||
| 24 | + | ||
| 25 | +permissions: | ||
| 26 | + contents: read | ||
| 27 | + | ||
| 28 | +jobs: | ||
| 29 | + test-linux: | ||
| 30 | + runs-on: ${{ matrix.os }} | ||
| 31 | + strategy: | ||
| 32 | + fail-fast: false | ||
| 33 | + matrix: | ||
| 34 | + os: [ubuntu-latest] | ||
| 35 | + | ||
| 36 | + steps: | ||
| 37 | + - uses: actions/checkout@v2 | ||
| 38 | + with: | ||
| 39 | + fetch-depth: 0 | ||
| 40 | + | ||
| 41 | + - name: Download pretrained model and test-data (English) | ||
| 42 | + shell: bash | ||
| 43 | + run: | | ||
| 44 | + git lfs install | ||
| 45 | + git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 | ||
| 46 | + | ||
| 47 | + - name: Configure Cmake | ||
| 48 | + shell: bash | ||
| 49 | + run: | | ||
| 50 | + mkdir build | ||
| 51 | + cd build | ||
| 52 | + cmake -D CMAKE_BUILD_TYPE=Release .. | ||
| 53 | + | ||
| 54 | + - name: Build sherpa-onnx for ubuntu | ||
| 55 | + run: | | ||
| 56 | + cd build | ||
| 57 | + make VERBOSE=1 -j3 | ||
| 58 | + | ||
| 59 | + - name: Run tests for ubuntu (English) | ||
| 60 | + run: | | ||
| 61 | + time ./build/bin/sherpa-onnx \ | ||
| 62 | + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \ | ||
| 63 | + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \ | ||
| 64 | + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \ | ||
| 65 | + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \ | ||
| 66 | + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \ | ||
| 67 | + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \ | ||
| 68 | + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav | ||
| 69 | + | ||
| 70 | + time ./build/bin/sherpa-onnx \ | ||
| 71 | + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \ | ||
| 72 | + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \ | ||
| 73 | + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \ | ||
| 74 | + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \ | ||
| 75 | + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \ | ||
| 76 | + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \ | ||
| 77 | + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0001.wav | ||
| 78 | + | ||
| 79 | + time ./build/bin/sherpa-onnx \ | ||
| 80 | + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \ | ||
| 81 | + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \ | ||
| 82 | + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \ | ||
| 83 | + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \ | ||
| 84 | + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \ | ||
| 85 | + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \ | ||
| 86 | + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0002.wav |
.gitignore
0 → 100644
| 1 | +build |
| @@ -38,7 +38,8 @@ set(CMAKE_CXX_EXTENSIONS OFF) | @@ -38,7 +38,8 @@ set(CMAKE_CXX_EXTENSIONS OFF) | ||
| 38 | list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules) | 38 | list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules) |
| 39 | list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake) | 39 | list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake) |
| 40 | 40 | ||
| 41 | -include(cmake/kaldi_native_io.cmake) | ||
| 42 | -include(cmake/kaldi-native-fbank.cmake) | 41 | +include(kaldi_native_io) |
| 42 | +include(kaldi-native-fbank) | ||
| 43 | +include(onnxruntime) | ||
| 43 | 44 | ||
| 44 | add_subdirectory(sherpa-onnx) | 45 | add_subdirectory(sherpa-onnx) |
| 1 | -if(DEFINED ENV{KALDI_NATIVE_IO_INSTALL_PREFIX}) | ||
| 2 | - message(STATUS "Using environment variable KALDI_NATIVE_IO_INSTALL_PREFIX: $ENV{KALDI_NATIVE_IO_INSTALL_PREFIX}") | ||
| 3 | - set(KALDI_NATIVE_IO_CMAKE_PREFIX_PATH $ENV{KALDI_NATIVE_IO_INSTALL_PREFIX}) | ||
| 4 | -else() | ||
| 5 | - # PYTHON_EXECUTABLE is set by cmake/pybind11.cmake | ||
| 6 | - message(STATUS "Python executable: ${PYTHON_EXECUTABLE}") | ||
| 7 | - | ||
| 8 | - execute_process( | ||
| 9 | - COMMAND "${PYTHON_EXECUTABLE}" -c "import kaldi_native_io; print(kaldi_native_io.cmake_prefix_path)" | ||
| 10 | - OUTPUT_STRIP_TRAILING_WHITESPACE | ||
| 11 | - OUTPUT_VARIABLE KALDI_NATIVE_IO_CMAKE_PREFIX_PATH | 1 | +function(download_kaldi_native_io) |
| 2 | + if(CMAKE_VERSION VERSION_LESS 3.11) | ||
| 3 | + # FetchContent is available since 3.11, | ||
| 4 | + # we've copied it to ${CMAKE_SOURCE_DIR}/cmake/Modules | ||
| 5 | + # so that it can be used in lower CMake versions. | ||
| 6 | + message(STATUS "Use FetchContent provided by sherpa-onnx") | ||
| 7 | + list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules) | ||
| 8 | + endif() | ||
| 9 | + | ||
| 10 | + include(FetchContent) | ||
| 11 | + | ||
| 12 | + set(kaldi_native_io_URL "https://github.com/csukuangfj/kaldi_native_io/archive/refs/tags/v1.15.1.tar.gz") | ||
| 13 | + set(kaldi_native_io_HASH "SHA256=97377e1d61e99d8fc1d6037a418d3037522dfa46337e06162e24b1d97f3d70a6") | ||
| 14 | + | ||
| 15 | + set(KALDI_NATIVE_IO_BUILD_TESTS OFF CACHE BOOL "" FORCE) | ||
| 16 | + set(KALDI_NATIVE_IO_BUILD_PYTHON OFF CACHE BOOL "" FORCE) | ||
| 17 | + | ||
| 18 | + FetchContent_Declare(kaldi_native_io | ||
| 19 | + URL ${kaldi_native_io_URL} | ||
| 20 | + URL_HASH ${kaldi_native_io_HASH} | ||
| 12 | ) | 21 | ) |
| 13 | -endif() | ||
| 14 | 22 | ||
| 15 | -message(STATUS "KALDI_NATIVE_IO_CMAKE_PREFIX_PATH: ${KALDI_NATIVE_IO_CMAKE_PREFIX_PATH}") | ||
| 16 | -list(APPEND CMAKE_PREFIX_PATH "${KALDI_NATIVE_IO_CMAKE_PREFIX_PATH}") | 23 | + FetchContent_GetProperties(kaldi_native_io) |
| 24 | + if(NOT kaldi_native_io_POPULATED) | ||
| 25 | + message(STATUS "Downloading kaldi_native_io ${kaldi_native_io_URL}") | ||
| 26 | + FetchContent_Populate(kaldi_native_io) | ||
| 27 | + endif() | ||
| 28 | + message(STATUS "kaldi_native_io is downloaded to ${kaldi_native_io_SOURCE_DIR}") | ||
| 29 | + message(STATUS "kaldi_native_io's binary dir is ${kaldi_native_io_BINARY_DIR}") | ||
| 17 | 30 | ||
| 18 | -find_package(kaldi_native_io REQUIRED) | 31 | + add_subdirectory(${kaldi_native_io_SOURCE_DIR} ${kaldi_native_io_BINARY_DIR} EXCLUDE_FROM_ALL) |
| 19 | 32 | ||
| 20 | -message(STATUS "KALDI_NATIVE_IO_FOUND: ${KALDI_NATIVE_IO_FOUND}") | ||
| 21 | -message(STATUS "KALDI_NATIVE_IO_VERSION: ${KALDI_NATIVE_IO_VERSION}") | ||
| 22 | -message(STATUS "KALDI_NATIVE_IO_INCLUDE_DIRS: ${KALDI_NATIVE_IO_INCLUDE_DIRS}") | ||
| 23 | -message(STATUS "KALDI_NATIVE_IO_CXX_FLAGS: ${KALDI_NATIVE_IO_CXX_FLAGS}") | ||
| 24 | -message(STATUS "KALDI_NATIVE_IO_LIBRARIES: ${KALDI_NATIVE_IO_LIBRARIES}") | 33 | + target_include_directories(kaldi_native_io_core |
| 34 | + PUBLIC | ||
| 35 | + ${kaldi_native_io_SOURCE_DIR}/ | ||
| 36 | + ) | ||
| 37 | +endfunction() | ||
| 25 | 38 | ||
| 26 | -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${KALDI_NATIVE_IO_CXX_FLAGS}") | ||
| 27 | -message(STATUS "CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") | ||
| 39 | +download_kaldi_native_io() |
cmake/onnxruntime.cmake
0 → 100644
| 1 | +function(download_onnxruntime) | ||
| 2 | + if(CMAKE_VERSION VERSION_LESS 3.11) | ||
| 3 | + # FetchContent is available since 3.11, | ||
| 4 | + # we've copied it to ${CMAKE_SOURCE_DIR}/cmake/Modules | ||
| 5 | + # so that it can be used in lower CMake versions. | ||
| 6 | + message(STATUS "Use FetchContent provided by sherpa-onnx") | ||
| 7 | + list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules) | ||
| 8 | + endif() | ||
| 9 | + | ||
| 10 | + include(FetchContent) | ||
| 11 | + | ||
| 12 | + if(UNIX AND NOT APPLE) | ||
| 13 | + set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.12.1/onnxruntime-linux-x64-1.12.1.tgz") | ||
| 14 | + | ||
| 15 | + # If you don't have access to the internet, you can first download onnxruntime to some directory, and then use | ||
| 16 | + # set(onnxruntime_URL "file:///ceph-fj/fangjun/open-source/sherpa-onnx/onnxruntime-linux-x64-1.12.1.tgz") | ||
| 17 | + | ||
| 18 | + set(onnxruntime_HASH "SHA256=8f6eb9e2da9cf74e7905bf3fc687ef52e34cc566af7af2f92dafe5a5d106aa3d") | ||
| 19 | + # After downloading, it contains: | ||
| 20 | + # ./lib/libonnxruntime.so.1.12.1 | ||
| 21 | + # ./lib/libonnxruntime.so, which is a symlink to lib/libonnxruntime.so.1.12.1 | ||
| 22 | + # | ||
| 23 | + # ./include | ||
| 24 | + # It contains all the needed header files | ||
| 25 | + else() | ||
| 26 | + message(FATAL_ERROR "Only support Linux at present. Will support other OSes later") | ||
| 27 | + endif() | ||
| 28 | + | ||
| 29 | + FetchContent_Declare(onnxruntime | ||
| 30 | + URL ${onnxruntime_URL} | ||
| 31 | + URL_HASH ${onnxruntime_HASH} | ||
| 32 | + ) | ||
| 33 | + | ||
| 34 | + FetchContent_GetProperties(onnxruntime) | ||
| 35 | + if(NOT onnxruntime_POPULATED) | ||
| 36 | + message(STATUS "Downloading onnxruntime ${onnxruntime_URL}") | ||
| 37 | + FetchContent_Populate(onnxruntime) | ||
| 38 | + endif() | ||
| 39 | + message(STATUS "onnxruntime is downloaded to ${onnxruntime_SOURCE_DIR}") | ||
| 40 | + | ||
| 41 | + find_library(location_onnxruntime onnxruntime | ||
| 42 | + PATHS | ||
| 43 | + "${onnxruntime_SOURCE_DIR}/lib" | ||
| 44 | + ) | ||
| 45 | + | ||
| 46 | + message(STATUS "location_onnxruntime: ${location_onnxruntime}") | ||
| 47 | + | ||
| 48 | + add_library(onnxruntime SHARED IMPORTED) | ||
| 49 | + set_target_properties(onnxruntime PROPERTIES | ||
| 50 | + IMPORTED_LOCATION ${location_onnxruntime} | ||
| 51 | + INTERFACE_INCLUDE_DIRECTORIES "${onnxruntime_SOURCE_DIR}/include" | ||
| 52 | + ) | ||
| 53 | +endfunction() | ||
| 54 | + | ||
| 55 | +download_onnxruntime() |
| 1 | -add_executable(online-fbank-test online-fbank-test.cc) | ||
| 2 | -target_link_libraries(online-fbank-test kaldi-native-fbank-core) | ||
| 3 | - | ||
| 4 | -include_directories( | ||
| 5 | - ${ONNXRUNTIME_ROOTDIR}/include/onnxruntime/core/session/ | ||
| 6 | - ${ONNXRUNTIME_ROOTDIR}/include/onnxruntime/core/providers/tensorrt/ | ||
| 7 | -) | 1 | +include_directories(${CMAKE_SOURCE_DIR}) |
| 2 | +add_executable(sherpa-onnx main.cpp) | ||
| 8 | 3 | ||
| 9 | -include_directories( | ||
| 10 | - ${KALDINATIVEIO} | 4 | +target_link_libraries(sherpa-onnx |
| 5 | + onnxruntime | ||
| 6 | + kaldi-native-fbank-core | ||
| 7 | + kaldi_native_io_core | ||
| 11 | ) | 8 | ) |
| 12 | -add_executable(sherpa-onnx main.cpp) | ||
| 13 | -target_link_libraries(sherpa-onnx onnxruntime kaldi-native-fbank-core kaldi_native_io_core) |
| 1 | -#include <vector> | ||
| 2 | -#include <iostream> | ||
| 3 | #include <algorithm> | 1 | #include <algorithm> |
| 4 | -#include <time.h> | ||
| 5 | -#include <math.h> | ||
| 6 | #include <fstream> | 2 | #include <fstream> |
| 3 | +#include <iostream> | ||
| 4 | +#include <math.h> | ||
| 5 | +#include <time.h> | ||
| 6 | +#include <vector> | ||
| 7 | 7 | ||
| 8 | -#include "fbank_features.h" | ||
| 9 | -#include "rnnt_beam_search.h" | 8 | +#include "sherpa-onnx/csrc/fbank_features.h" |
| 9 | +#include "sherpa-onnx/csrc/rnnt_beam_search.h" | ||
| 10 | 10 | ||
| 11 | #include "kaldi-native-fbank/csrc/online-feature.h" | 11 | #include "kaldi-native-fbank/csrc/online-feature.h" |
| 12 | 12 | ||
| 13 | - | ||
| 14 | -int main(int argc, char* argv[]) { | ||
| 15 | - char* encoder_path = argv[1]; | ||
| 16 | - char* decoder_path = argv[2]; | ||
| 17 | - char* joiner_path = argv[3]; | ||
| 18 | - char* joiner_encoder_proj_path = argv[4]; | ||
| 19 | - char* joiner_decoder_proj_path = argv[5]; | ||
| 20 | - char* token_path = argv[6]; | ||
| 21 | - std::string search_method = argv[7]; | ||
| 22 | - char* filename = argv[8]; | ||
| 23 | - | ||
| 24 | - // General parameters | ||
| 25 | - int numberOfThreads = 16; | ||
| 26 | - | ||
| 27 | - // Initialize fbanks | ||
| 28 | - knf::FbankOptions opts; | ||
| 29 | - opts.frame_opts.dither = 0; | ||
| 30 | - opts.frame_opts.samp_freq = 16000; | ||
| 31 | - opts.frame_opts.frame_shift_ms = 10.0f; | ||
| 32 | - opts.frame_opts.frame_length_ms = 25.0f; | ||
| 33 | - opts.mel_opts.num_bins = 80; | ||
| 34 | - opts.frame_opts.window_type = "povey"; | ||
| 35 | - opts.frame_opts.snip_edges = false; | ||
| 36 | - knf::OnlineFbank fbank(opts); | ||
| 37 | - | ||
| 38 | - // set session opts | ||
| 39 | - // https://onnxruntime.ai/docs/performance/tune-performance.html | ||
| 40 | - session_options.SetIntraOpNumThreads(numberOfThreads); | ||
| 41 | - session_options.SetInterOpNumThreads(numberOfThreads); | ||
| 42 | - session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED); | ||
| 43 | - session_options.SetLogSeverityLevel(4); | ||
| 44 | - session_options.SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL); | ||
| 45 | - | ||
| 46 | - api.CreateTensorRTProviderOptions(&tensorrt_options); | ||
| 47 | - std::unique_ptr<OrtTensorRTProviderOptionsV2, decltype(api.ReleaseTensorRTProviderOptions)> rel_trt_options(tensorrt_options, api.ReleaseTensorRTProviderOptions); | ||
| 48 | - api.SessionOptionsAppendExecutionProvider_TensorRT_V2(static_cast<OrtSessionOptions*>(session_options), rel_trt_options.get()); | ||
| 49 | - | ||
| 50 | - // Define model | ||
| 51 | - auto model = get_model( | ||
| 52 | - encoder_path, | ||
| 53 | - decoder_path, | ||
| 54 | - joiner_path, | ||
| 55 | - joiner_encoder_proj_path, | ||
| 56 | - joiner_decoder_proj_path, | ||
| 57 | - token_path | ||
| 58 | - ); | ||
| 59 | - | ||
| 60 | - std::vector<std::string> filename_list { | ||
| 61 | - filename | ||
| 62 | - }; | ||
| 63 | - | ||
| 64 | - for (auto filename : filename_list){ | ||
| 65 | - std::cout << filename << std::endl; | ||
| 66 | - auto samples = readWav(filename, true); | ||
| 67 | - int numSamples = samples.NumCols(); | ||
| 68 | - | ||
| 69 | - auto features = ComputeFeatures(fbank, opts, samples); | ||
| 70 | - | ||
| 71 | - auto tic = std::chrono::high_resolution_clock::now(); | ||
| 72 | - | ||
| 73 | - // # === Encoder Out === # | ||
| 74 | - int num_frames = features.size() / opts.mel_opts.num_bins; | ||
| 75 | - auto encoder_out = model.encoder_forward(features, | ||
| 76 | - std::vector<int64_t> {num_frames}, | ||
| 77 | - std::vector<int64_t> {1, num_frames, 80}, | ||
| 78 | - std::vector<int64_t> {1}, | ||
| 79 | - memory_info); | ||
| 80 | - | ||
| 81 | - // # === Search === # | ||
| 82 | - std::vector<std::vector<int32_t>> hyps; | ||
| 83 | - if (search_method == "greedy") | ||
| 84 | - hyps = GreedySearch(&model, &encoder_out); | ||
| 85 | - else{ | ||
| 86 | - std::cout << "wrong search method!" << std::endl; | ||
| 87 | - exit(0); | ||
| 88 | - } | ||
| 89 | - auto results = hyps2result(model.tokens_map, hyps); | ||
| 90 | - | ||
| 91 | - // # === Print Elapsed Time === # | ||
| 92 | - auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now() - tic); | ||
| 93 | - std::cout << "Elapsed: " << float(elapsed.count()) / 1000 << " seconds" << std::endl; | ||
| 94 | - std::cout << "rtf: " << float(elapsed.count()) / 1000 / (numSamples / 16000) << std::endl; | ||
| 95 | - | ||
| 96 | - print_hyps(hyps); | ||
| 97 | - std::cout << results[0] << std::endl; | 13 | +int main(int argc, char *argv[]) { |
| 14 | + char *encoder_path = argv[1]; | ||
| 15 | + char *decoder_path = argv[2]; | ||
| 16 | + char *joiner_path = argv[3]; | ||
| 17 | + char *joiner_encoder_proj_path = argv[4]; | ||
| 18 | + char *joiner_decoder_proj_path = argv[5]; | ||
| 19 | + char *token_path = argv[6]; | ||
| 20 | + std::string search_method = argv[7]; | ||
| 21 | + char *filename = argv[8]; | ||
| 22 | + | ||
| 23 | + // General parameters | ||
| 24 | + int numberOfThreads = 16; | ||
| 25 | + | ||
| 26 | + // Initialize fbanks | ||
| 27 | + knf::FbankOptions opts; | ||
| 28 | + opts.frame_opts.dither = 0; | ||
| 29 | + opts.frame_opts.samp_freq = 16000; | ||
| 30 | + opts.frame_opts.frame_shift_ms = 10.0f; | ||
| 31 | + opts.frame_opts.frame_length_ms = 25.0f; | ||
| 32 | + opts.mel_opts.num_bins = 80; | ||
| 33 | + opts.frame_opts.window_type = "povey"; | ||
| 34 | + opts.frame_opts.snip_edges = false; | ||
| 35 | + knf::OnlineFbank fbank(opts); | ||
| 36 | + | ||
| 37 | + // set session opts | ||
| 38 | + // https://onnxruntime.ai/docs/performance/tune-performance.html | ||
| 39 | + session_options.SetIntraOpNumThreads(numberOfThreads); | ||
| 40 | + session_options.SetInterOpNumThreads(numberOfThreads); | ||
| 41 | + session_options.SetGraphOptimizationLevel( | ||
| 42 | + GraphOptimizationLevel::ORT_ENABLE_EXTENDED); | ||
| 43 | + session_options.SetLogSeverityLevel(4); | ||
| 44 | + session_options.SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL); | ||
| 45 | + | ||
| 46 | + api.CreateTensorRTProviderOptions(&tensorrt_options); | ||
| 47 | + std::unique_ptr<OrtTensorRTProviderOptionsV2, | ||
| 48 | + decltype(api.ReleaseTensorRTProviderOptions)> | ||
| 49 | + rel_trt_options(tensorrt_options, api.ReleaseTensorRTProviderOptions); | ||
| 50 | + api.SessionOptionsAppendExecutionProvider_TensorRT_V2( | ||
| 51 | + static_cast<OrtSessionOptions *>(session_options), rel_trt_options.get()); | ||
| 52 | + | ||
| 53 | + // Define model | ||
| 54 | + auto model = | ||
| 55 | + get_model(encoder_path, decoder_path, joiner_path, | ||
| 56 | + joiner_encoder_proj_path, joiner_decoder_proj_path, token_path); | ||
| 57 | + | ||
| 58 | + std::vector<std::string> filename_list{filename}; | ||
| 59 | + | ||
| 60 | + for (auto filename : filename_list) { | ||
| 61 | + std::cout << filename << std::endl; | ||
| 62 | + auto samples = readWav(filename, true); | ||
| 63 | + int numSamples = samples.NumCols(); | ||
| 64 | + | ||
| 65 | + auto features = ComputeFeatures(fbank, opts, samples); | ||
| 66 | + | ||
| 67 | + auto tic = std::chrono::high_resolution_clock::now(); | ||
| 68 | + | ||
| 69 | + // # === Encoder Out === # | ||
| 70 | + int num_frames = features.size() / opts.mel_opts.num_bins; | ||
| 71 | + auto encoder_out = | ||
| 72 | + model.encoder_forward(features, std::vector<int64_t>{num_frames}, | ||
| 73 | + std::vector<int64_t>{1, num_frames, 80}, | ||
| 74 | + std::vector<int64_t>{1}, memory_info); | ||
| 75 | + | ||
| 76 | + // # === Search === # | ||
| 77 | + std::vector<std::vector<int32_t>> hyps; | ||
| 78 | + if (search_method == "greedy") | ||
| 79 | + hyps = GreedySearch(&model, &encoder_out); | ||
| 80 | + else { | ||
| 81 | + std::cout << "wrong search method!" << std::endl; | ||
| 82 | + exit(0); | ||
| 98 | } | 83 | } |
| 84 | + auto results = hyps2result(model.tokens_map, hyps); | ||
| 85 | + | ||
| 86 | + // # === Print Elapsed Time === # | ||
| 87 | + auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>( | ||
| 88 | + std::chrono::high_resolution_clock::now() - tic); | ||
| 89 | + std::cout << "Elapsed: " << float(elapsed.count()) / 1000 << " seconds" | ||
| 90 | + << std::endl; | ||
| 91 | + std::cout << "rtf: " << float(elapsed.count()) / 1000 / (numSamples / 16000) | ||
| 92 | + << std::endl; | ||
| 93 | + | ||
| 94 | + print_hyps(hyps); | ||
| 95 | + std::cout << results[0] << std::endl; | ||
| 96 | + } | ||
| 99 | 97 | ||
| 100 | - return 0; | 98 | + return 0; |
| 101 | } | 99 | } |
| @@ -61,7 +61,6 @@ std::vector<std::vector<int32_t>> GreedySearch( | @@ -61,7 +61,6 @@ std::vector<std::vector<int32_t>> GreedySearch( | ||
| 61 | auto projected_encoder_out = model->joiner_encoder_proj_forward(encoder_out_vector, | 61 | auto projected_encoder_out = model->joiner_encoder_proj_forward(encoder_out_vector, |
| 62 | std::vector<int64_t> {encoder_out_dim1, encoder_out_dim2}, | 62 | std::vector<int64_t> {encoder_out_dim1, encoder_out_dim2}, |
| 63 | memory_info); | 63 | memory_info); |
| 64 | - | ||
| 65 | Ort::Value &projected_encoder_out_tensor = projected_encoder_out[0]; | 64 | Ort::Value &projected_encoder_out_tensor = projected_encoder_out[0]; |
| 66 | int projected_encoder_out_dim1 = projected_encoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[0]; | 65 | int projected_encoder_out_dim1 = projected_encoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[0]; |
| 67 | int projected_encoder_out_dim2 = projected_encoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[1]; | 66 | int projected_encoder_out_dim2 = projected_encoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[1]; |
| @@ -78,12 +77,12 @@ std::vector<std::vector<int32_t>> GreedySearch( | @@ -78,12 +77,12 @@ std::vector<std::vector<int32_t>> GreedySearch( | ||
| 78 | 77 | ||
| 79 | auto logits = model->joiner_forward(cur_encoder_out, | 78 | auto logits = model->joiner_forward(cur_encoder_out, |
| 80 | projected_decoder_out_vector, | 79 | projected_decoder_out_vector, |
| 81 | - std::vector<int64_t> {1, 1, 1, projected_encoder_out_dim2}, | ||
| 82 | - std::vector<int64_t> {1, 1, 1, projected_decoder_out_dim}, | 80 | + std::vector<int64_t> {1, projected_encoder_out_dim2}, |
| 81 | + std::vector<int64_t> {1, projected_decoder_out_dim}, | ||
| 83 | memory_info); | 82 | memory_info); |
| 84 | 83 | ||
| 85 | Ort::Value &logits_tensor = logits[0]; | 84 | Ort::Value &logits_tensor = logits[0]; |
| 86 | - int logits_dim = logits_tensor.GetTensorTypeAndShapeInfo().GetShape()[3]; | 85 | + int logits_dim = logits_tensor.GetTensorTypeAndShapeInfo().GetShape()[1]; |
| 87 | auto logits_vector = ortVal2Vector(logits_tensor, logits_dim); | 86 | auto logits_vector = ortVal2Vector(logits_tensor, logits_dim); |
| 88 | 87 | ||
| 89 | int max_indices = static_cast<int>(std::distance(logits_vector.begin(), std::max_element(logits_vector.begin(), logits_vector.end()))); | 88 | int max_indices = static_cast<int>(std::distance(logits_vector.begin(), std::max_element(logits_vector.begin(), logits_vector.end()))); |
-
请 注册 或 登录 后发表评论