Fangjun Kuang
Committed by GitHub

code refactoring and add CI (#11)

.github/workflows/test-linux.yaml (new file)

+name: test-linux
+
+on:
+  push:
+    branches:
+      - master
+    paths:
+      - '.github/workflows/test-linux.yaml'
+      - 'CMakeLists.txt'
+      - 'cmake/**'
+      - 'sherpa-onnx/csrc/*'
+  pull_request:
+    branches:
+      - master
+    paths:
+      - '.github/workflows/test-linux.yaml'
+      - 'CMakeLists.txt'
+      - 'cmake/**'
+      - 'sherpa-onnx/csrc/*'
+
+concurrency:
+  group: test-linux-${{ github.ref }}
+  cancel-in-progress: true
+
+permissions:
+  contents: read
+
+jobs:
+  test-linux:
+    runs-on: ${{ matrix.os }}
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [ubuntu-latest]
+
+    steps:
+      - uses: actions/checkout@v2
+        with:
+          fetch-depth: 0
+
+      - name: Download pretrained model and test-data (English)
+        shell: bash
+        run: |
+          git lfs install
+          git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
+
+      - name: Configure CMake
+        shell: bash
+        run: |
+          mkdir build
+          cd build
+          cmake -D CMAKE_BUILD_TYPE=Release ..
+
+      - name: Build sherpa-onnx for ubuntu
+        run: |
+          cd build
+          make VERBOSE=1 -j3
+
+      - name: Run tests for ubuntu (English)
+        run: |
+          # main.cpp expects: encoder decoder joiner joiner_encoder_proj
+          # joiner_decoder_proj tokens search-method wav
+          time ./build/bin/sherpa-onnx \
+            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
+            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
+            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
+            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
+            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
+            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
+            greedy \
+            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav
+
+          time ./build/bin/sherpa-onnx \
+            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
+            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
+            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
+            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
+            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
+            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
+            greedy \
+            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0001.wav
+
+          time ./build/bin/sherpa-onnx \
+            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
+            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
+            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
+            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
+            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
+            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
+            greedy \
+            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0002.wav

.gitignore (new file)

+build

CMakeLists.txt

@@ -38,7 +38,8 @@ set(CMAKE_CXX_EXTENSIONS OFF)
 list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules)
 list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake)
 
-include(cmake/kaldi_native_io.cmake)
-include(cmake/kaldi-native-fbank.cmake)
+include(kaldi_native_io)
+include(kaldi-native-fbank)
+include(onnxruntime)
 
 add_subdirectory(sherpa-onnx)

cmake/kaldi_native_io.cmake

-if(DEFINED ENV{KALDI_NATIVE_IO_INSTALL_PREFIX})
-  message(STATUS "Using environment variable KALDI_NATIVE_IO_INSTALL_PREFIX: $ENV{KALDI_NATIVE_IO_INSTALL_PREFIX}")
-  set(KALDI_NATIVE_IO_CMAKE_PREFIX_PATH $ENV{KALDI_NATIVE_IO_INSTALL_PREFIX})
-else()
-  # PYTHON_EXECUTABLE is set by cmake/pybind11.cmake
-  message(STATUS "Python executable: ${PYTHON_EXECUTABLE}")
-
-  execute_process(
-    COMMAND "${PYTHON_EXECUTABLE}" -c "import kaldi_native_io; print(kaldi_native_io.cmake_prefix_path)"
-    OUTPUT_STRIP_TRAILING_WHITESPACE
-    OUTPUT_VARIABLE KALDI_NATIVE_IO_CMAKE_PREFIX_PATH
-  )
-endif()
-
-message(STATUS "KALDI_NATIVE_IO_CMAKE_PREFIX_PATH: ${KALDI_NATIVE_IO_CMAKE_PREFIX_PATH}")
-list(APPEND CMAKE_PREFIX_PATH "${KALDI_NATIVE_IO_CMAKE_PREFIX_PATH}")
-
-find_package(kaldi_native_io REQUIRED)
-
-message(STATUS "KALDI_NATIVE_IO_FOUND: ${KALDI_NATIVE_IO_FOUND}")
-message(STATUS "KALDI_NATIVE_IO_VERSION: ${KALDI_NATIVE_IO_VERSION}")
-message(STATUS "KALDI_NATIVE_IO_INCLUDE_DIRS: ${KALDI_NATIVE_IO_INCLUDE_DIRS}")
-message(STATUS "KALDI_NATIVE_IO_CXX_FLAGS: ${KALDI_NATIVE_IO_CXX_FLAGS}")
-message(STATUS "KALDI_NATIVE_IO_LIBRARIES: ${KALDI_NATIVE_IO_LIBRARIES}")
-
-set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${KALDI_NATIVE_IO_CXX_FLAGS}")
-message(STATUS "CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
+function(download_kaldi_native_io)
+  if(CMAKE_VERSION VERSION_LESS 3.11)
+    # FetchContent is available since 3.11,
+    # we've copied it to ${CMAKE_SOURCE_DIR}/cmake/Modules
+    # so that it can be used in lower CMake versions.
+    message(STATUS "Use FetchContent provided by sherpa-onnx")
+    list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules)
+  endif()
+
+  include(FetchContent)
+
+  set(kaldi_native_io_URL "https://github.com/csukuangfj/kaldi_native_io/archive/refs/tags/v1.15.1.tar.gz")
+  set(kaldi_native_io_HASH "SHA256=97377e1d61e99d8fc1d6037a418d3037522dfa46337e06162e24b1d97f3d70a6")
+
+  set(KALDI_NATIVE_IO_BUILD_TESTS OFF CACHE BOOL "" FORCE)
+  set(KALDI_NATIVE_IO_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
+
+  FetchContent_Declare(kaldi_native_io
+    URL ${kaldi_native_io_URL}
+    URL_HASH ${kaldi_native_io_HASH}
+  )
+
+  FetchContent_GetProperties(kaldi_native_io)
+  if(NOT kaldi_native_io_POPULATED)
+    message(STATUS "Downloading kaldi_native_io ${kaldi_native_io_URL}")
+    FetchContent_Populate(kaldi_native_io)
+  endif()
+  message(STATUS "kaldi_native_io is downloaded to ${kaldi_native_io_SOURCE_DIR}")
+  message(STATUS "kaldi_native_io's binary dir is ${kaldi_native_io_BINARY_DIR}")
+
+  add_subdirectory(${kaldi_native_io_SOURCE_DIR} ${kaldi_native_io_BINARY_DIR} EXCLUDE_FROM_ALL)
+
+  target_include_directories(kaldi_native_io_core
+    PUBLIC
+      ${kaldi_native_io_SOURCE_DIR}/
+  )
+endfunction()
+
+download_kaldi_native_io()

cmake/onnxruntime.cmake (new file)

+function(download_onnxruntime)
+  if(CMAKE_VERSION VERSION_LESS 3.11)
+    # FetchContent is available since 3.11,
+    # we've copied it to ${CMAKE_SOURCE_DIR}/cmake/Modules
+    # so that it can be used in lower CMake versions.
+    message(STATUS "Use FetchContent provided by sherpa-onnx")
+    list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules)
+  endif()
+
+  include(FetchContent)
+
+  if(UNIX AND NOT APPLE)
+    set(onnxruntime_URL "http://github.com/microsoft/onnxruntime/releases/download/v1.12.1/onnxruntime-linux-x64-1.12.1.tgz")
+
+    # If you don't have access to the internet, you can first download
+    # onnxruntime to some directory and then use
+    # set(onnxruntime_URL "file:///ceph-fj/fangjun/open-source/sherpa-onnx/onnxruntime-linux-x64-1.12.1.tgz")
+
+    set(onnxruntime_HASH "SHA256=8f6eb9e2da9cf74e7905bf3fc687ef52e34cc566af7af2f92dafe5a5d106aa3d")
+    # After downloading, it contains:
+    #   ./lib/libonnxruntime.so.1.12.1
+    #   ./lib/libonnxruntime.so, which is a symlink to lib/libonnxruntime.so.1.12.1
+    #
+    #   ./include
+    #   It contains all the needed header files
+  else()
+    message(FATAL_ERROR "Only Linux is supported at present. Other OSes will be supported later.")
+  endif()
+
+  FetchContent_Declare(onnxruntime
+    URL ${onnxruntime_URL}
+    URL_HASH ${onnxruntime_HASH}
+  )
+
+  FetchContent_GetProperties(onnxruntime)
+  if(NOT onnxruntime_POPULATED)
+    message(STATUS "Downloading onnxruntime ${onnxruntime_URL}")
+    FetchContent_Populate(onnxruntime)
+  endif()
+  message(STATUS "onnxruntime is downloaded to ${onnxruntime_SOURCE_DIR}")
+
+  find_library(location_onnxruntime onnxruntime
+    PATHS
+      "${onnxruntime_SOURCE_DIR}/lib"
+  )
+
+  message(STATUS "location_onnxruntime: ${location_onnxruntime}")
+
+  add_library(onnxruntime SHARED IMPORTED)
+  set_target_properties(onnxruntime PROPERTIES
+    IMPORTED_LOCATION ${location_onnxruntime}
+    INTERFACE_INCLUDE_DIRECTORIES "${onnxruntime_SOURCE_DIR}/include"
+  )
+endfunction()
+
+download_onnxruntime()
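
The imported onnxruntime target above only wires up a prebuilt shared library and its headers. As a quick sanity check that the download, find_library, and include path all resolved, something like the following can be compiled against that target; this is a minimal sketch using only the standard ONNX Runtime C++ API, not code from this commit:

#include <iostream>

#include "onnxruntime_cxx_api.h"

int main() {
  // Constructing an Env forces libonnxruntime.so to load and initialize.
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "link-check");

  Ort::SessionOptions session_options;
  session_options.SetIntraOpNumThreads(1);

  // The version string comes from the C API base and needs no model/session.
  std::cout << "onnxruntime " << OrtGetApiBase()->GetVersionString()
            << std::endl;
  return 0;
}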

sherpa-onnx/csrc/CMakeLists.txt

-add_executable(online-fbank-test online-fbank-test.cc)
-target_link_libraries(online-fbank-test kaldi-native-fbank-core)
-
-include_directories(
-  ${ONNXRUNTIME_ROOTDIR}/include/onnxruntime/core/session/
-  ${ONNXRUNTIME_ROOTDIR}/include/onnxruntime/core/providers/tensorrt/
-)
+include_directories(${CMAKE_SOURCE_DIR})
+add_executable(sherpa-onnx main.cpp)
 
-include_directories(
-  ${KALDINATIVEIO}
-)
-add_executable(sherpa-onnx main.cpp)
-target_link_libraries(sherpa-onnx onnxruntime kaldi-native-fbank-core kaldi_native_io_core)
+target_link_libraries(sherpa-onnx
+  onnxruntime
+  kaldi-native-fbank-core
+  kaldi_native_io_core
+)
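
kaldi-native-fbank-core, linked above, is what provides the knf::OnlineFbank feature extractor used by main.cpp below. A self-contained sketch of that API on dummy audio (same options as main.cpp; the one-second silence buffer is just a stand-in for real samples):

#include <iostream>
#include <vector>

#include "kaldi-native-fbank/csrc/online-feature.h"

int main() {
  knf::FbankOptions opts;
  opts.frame_opts.dither = 0;
  opts.frame_opts.samp_freq = 16000;
  opts.mel_opts.num_bins = 80;
  knf::OnlineFbank fbank(opts);

  // One second of silence stands in for real 16 kHz audio samples.
  std::vector<float> samples(16000, 0.0f);
  fbank.AcceptWaveform(16000, samples.data(), samples.size());
  fbank.InputFinished();

  std::cout << "frames ready: " << fbank.NumFramesReady() << std::endl;
  const float *frame0 = fbank.GetFrame(0);  // pointer to 80 fbank values
  std::cout << "bin 0 of frame 0: " << frame0[0] << std::endl;
  return 0;
}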

sherpa-onnx/csrc/main.cpp

-#include <vector>
-#include <iostream>
 #include <algorithm>
-#include <time.h>
-#include <math.h>
 #include <fstream>
+#include <iostream>
+#include <math.h>
+#include <time.h>
+#include <vector>
 
-#include "fbank_features.h"
-#include "rnnt_beam_search.h"
+#include "sherpa-onnx/csrc/fbank_features.h"
+#include "sherpa-onnx/csrc/rnnt_beam_search.h"
 
 #include "kaldi-native-fbank/csrc/online-feature.h"
 
-
-int main(int argc, char* argv[]) {
-  char* encoder_path = argv[1];
-  char* decoder_path = argv[2];
-  char* joiner_path = argv[3];
-  char* joiner_encoder_proj_path = argv[4];
-  char* joiner_decoder_proj_path = argv[5];
-  char* token_path = argv[6];
-  std::string search_method = argv[7];
-  char* filename = argv[8];
-
-  // General parameters
-  int numberOfThreads = 16;
-
-  // Initialize fbanks
-  knf::FbankOptions opts;
-  opts.frame_opts.dither = 0;
-  opts.frame_opts.samp_freq = 16000;
-  opts.frame_opts.frame_shift_ms = 10.0f;
-  opts.frame_opts.frame_length_ms = 25.0f;
-  opts.mel_opts.num_bins = 80;
-  opts.frame_opts.window_type = "povey";
-  opts.frame_opts.snip_edges = false;
-  knf::OnlineFbank fbank(opts);
-
-  // set session opts
-  // https://onnxruntime.ai/docs/performance/tune-performance.html
-  session_options.SetIntraOpNumThreads(numberOfThreads);
-  session_options.SetInterOpNumThreads(numberOfThreads);
-  session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
-  session_options.SetLogSeverityLevel(4);
-  session_options.SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL);
-
-  api.CreateTensorRTProviderOptions(&tensorrt_options);
-  std::unique_ptr<OrtTensorRTProviderOptionsV2, decltype(api.ReleaseTensorRTProviderOptions)> rel_trt_options(tensorrt_options, api.ReleaseTensorRTProviderOptions);
-  api.SessionOptionsAppendExecutionProvider_TensorRT_V2(static_cast<OrtSessionOptions*>(session_options), rel_trt_options.get());
-
-  // Define model
-  auto model = get_model(
-    encoder_path,
-    decoder_path,
-    joiner_path,
-    joiner_encoder_proj_path,
-    joiner_decoder_proj_path,
-    token_path
-  );
-
-  std::vector<std::string> filename_list {
-    filename
-  };
-
-  for (auto filename : filename_list){
-    std::cout << filename << std::endl;
-    auto samples = readWav(filename, true);
-    int numSamples = samples.NumCols();
-
-    auto features = ComputeFeatures(fbank, opts, samples);
-
-    auto tic = std::chrono::high_resolution_clock::now();
-
-    // # === Encoder Out === #
-    int num_frames = features.size() / opts.mel_opts.num_bins;
-    auto encoder_out = model.encoder_forward(features,
-      std::vector<int64_t> {num_frames},
-      std::vector<int64_t> {1, num_frames, 80},
-      std::vector<int64_t> {1},
-      memory_info);
-
-    // # === Search === #
-    std::vector<std::vector<int32_t>> hyps;
-    if (search_method == "greedy")
-      hyps = GreedySearch(&model, &encoder_out);
-    else{
-      std::cout << "wrong search method!" << std::endl;
-      exit(0);
-    }
-    auto results = hyps2result(model.tokens_map, hyps);
-
-    // # === Print Elapsed Time === #
-    auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now() - tic);
-    std::cout << "Elapsed: " << float(elapsed.count()) / 1000 << " seconds" << std::endl;
-    std::cout << "rtf: " << float(elapsed.count()) / 1000 / (numSamples / 16000) << std::endl;
-
-    print_hyps(hyps);
-    std::cout << results[0] << std::endl;
-  }
+int main(int argc, char *argv[]) {
+  char *encoder_path = argv[1];
+  char *decoder_path = argv[2];
+  char *joiner_path = argv[3];
+  char *joiner_encoder_proj_path = argv[4];
+  char *joiner_decoder_proj_path = argv[5];
+  char *token_path = argv[6];
+  std::string search_method = argv[7];
+  char *filename = argv[8];
+
+  // General parameters
+  int numberOfThreads = 16;
+
+  // Initialize fbanks
+  knf::FbankOptions opts;
+  opts.frame_opts.dither = 0;
+  opts.frame_opts.samp_freq = 16000;
+  opts.frame_opts.frame_shift_ms = 10.0f;
+  opts.frame_opts.frame_length_ms = 25.0f;
+  opts.mel_opts.num_bins = 80;
+  opts.frame_opts.window_type = "povey";
+  opts.frame_opts.snip_edges = false;
+  knf::OnlineFbank fbank(opts);
+
+  // set session opts
+  // https://onnxruntime.ai/docs/performance/tune-performance.html
+  session_options.SetIntraOpNumThreads(numberOfThreads);
+  session_options.SetInterOpNumThreads(numberOfThreads);
+  session_options.SetGraphOptimizationLevel(
+      GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
+  session_options.SetLogSeverityLevel(4);
+  session_options.SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL);
+
+  api.CreateTensorRTProviderOptions(&tensorrt_options);
+  std::unique_ptr<OrtTensorRTProviderOptionsV2,
+                  decltype(api.ReleaseTensorRTProviderOptions)>
+      rel_trt_options(tensorrt_options, api.ReleaseTensorRTProviderOptions);
+  api.SessionOptionsAppendExecutionProvider_TensorRT_V2(
+      static_cast<OrtSessionOptions *>(session_options), rel_trt_options.get());
+
+  // Define model
+  auto model =
+      get_model(encoder_path, decoder_path, joiner_path,
+                joiner_encoder_proj_path, joiner_decoder_proj_path, token_path);
+
+  std::vector<std::string> filename_list{filename};
+
+  for (auto filename : filename_list) {
+    std::cout << filename << std::endl;
+    auto samples = readWav(filename, true);
+    int numSamples = samples.NumCols();
+
+    auto features = ComputeFeatures(fbank, opts, samples);
+
+    auto tic = std::chrono::high_resolution_clock::now();
+
+    // # === Encoder Out === #
+    int num_frames = features.size() / opts.mel_opts.num_bins;
+    auto encoder_out =
+        model.encoder_forward(features, std::vector<int64_t>{num_frames},
+                              std::vector<int64_t>{1, num_frames, 80},
+                              std::vector<int64_t>{1}, memory_info);
+
+    // # === Search === #
+    std::vector<std::vector<int32_t>> hyps;
+    if (search_method == "greedy")
+      hyps = GreedySearch(&model, &encoder_out);
+    else {
+      std::cout << "wrong search method!" << std::endl;
+      exit(0);
+    }
+    auto results = hyps2result(model.tokens_map, hyps);
+
+    // # === Print Elapsed Time === #
+    auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
+        std::chrono::high_resolution_clock::now() - tic);
+    std::cout << "Elapsed: " << float(elapsed.count()) / 1000 << " seconds"
+              << std::endl;
+    std::cout << "rtf: " << float(elapsed.count()) / 1000 / (numSamples / 16000)
+              << std::endl;
+
+    print_hyps(hyps);
+    std::cout << results[0] << std::endl;
+  }
 
-  return 0;
+  return 0;
 }
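
Note that main.cpp uses env, api, session_options, tensorrt_options, and memory_info without declaring them. env and api come from the shared header whose diff appears at the end of this commit; the rest presumably live alongside them. A hypothetical sketch of those remaining globals (names and placement are assumptions, not part of the diff):

#include "onnxruntime_cxx_api.h"

// Shown in the last hunk of this commit:
Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "test");
const auto &api = Ort::GetApi();

// Hypothetical: the remaining globals main.cpp relies on.
Ort::SessionOptions session_options;
OrtTensorRTProviderOptionsV2 *tensorrt_options = nullptr;
auto memory_info =
    Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);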

sherpa-onnx/csrc/rnnt_beam_search.cc

@@ -61,7 +61,6 @@ std::vector<std::vector<int32_t>> GreedySearch(
   auto projected_encoder_out = model->joiner_encoder_proj_forward(encoder_out_vector,
     std::vector<int64_t> {encoder_out_dim1, encoder_out_dim2},
     memory_info);
-
   Ort::Value &projected_encoder_out_tensor = projected_encoder_out[0];
   int projected_encoder_out_dim1 = projected_encoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[0];
   int projected_encoder_out_dim2 = projected_encoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[1];
@@ -78,12 +77,12 @@ std::vector<std::vector<int32_t>> GreedySearch(
 
   auto logits = model->joiner_forward(cur_encoder_out,
     projected_decoder_out_vector,
-    std::vector<int64_t> {1, 1, 1, projected_encoder_out_dim2},
-    std::vector<int64_t> {1, 1, 1, projected_decoder_out_dim},
+    std::vector<int64_t> {1, projected_encoder_out_dim2},
+    std::vector<int64_t> {1, projected_decoder_out_dim},
     memory_info);
 
   Ort::Value &logits_tensor = logits[0];
-  int logits_dim = logits_tensor.GetTensorTypeAndShapeInfo().GetShape()[3];
+  int logits_dim = logits_tensor.GetTensorTypeAndShapeInfo().GetShape()[1];
   auto logits_vector = ortVal2Vector(logits_tensor, logits_dim);
 
   int max_indices = static_cast<int>(std::distance(logits_vector.begin(), std::max_element(logits_vector.begin(), logits_vector.end())));
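
The second hunk reflects a joiner whose output is now 2-D (N, vocab_size) instead of 4-D (N, T, U, vocab_size), so the vocabulary dimension moves from shape index 3 to index 1; the argmax over the logits is unchanged. A toy version of that argmax on a plain std::vector (stand-in data, none of the Ort::Value plumbing):

#include <algorithm>
#include <iostream>
#include <iterator>
#include <vector>

int main() {
  // Fake logits for one decoding step, flattened from shape (1, vocab_size).
  std::vector<float> logits_vector = {0.1f, 2.5f, 0.3f, 1.7f};

  // Same pattern GreedySearch uses to pick the next token id.
  int max_indices = static_cast<int>(std::distance(
      logits_vector.begin(),
      std::max_element(logits_vector.begin(), logits_vector.end())));

  std::cout << "argmax token id: " << max_indices << std::endl;  // prints 1
  return 0;
}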

 #include <iostream>
-#include <onnxruntime_cxx_api.h>
+#include "onnxruntime_cxx_api.h"
 
 Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "test");
 const auto& api = Ort::GetApi();