正在显示 9 个修改的文件,包含 683 行增加和 0 行删除。
list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules)
list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake)

# Locate the third-party I/O and feature-extraction dependencies before
# descending into the sources that use them.
include(cmake/kaldi_native_io.cmake)
include(cmake/kaldi-native-fbank.cmake)

add_subdirectory(sherpa-onnx)
cmake/kaldi_native_io.cmake
0 → 100644
# Locate kaldi_native_io: either from an explicit install prefix in the
# environment, or from the Python package installed for ${PYTHON_EXECUTABLE}.
if(DEFINED ENV{KALDI_NATIVE_IO_INSTALL_PREFIX})
  message(STATUS "Using environment variable KALDI_NATIVE_IO_INSTALL_PREFIX: $ENV{KALDI_NATIVE_IO_INSTALL_PREFIX}")
  set(KALDI_NATIVE_IO_CMAKE_PREFIX_PATH $ENV{KALDI_NATIVE_IO_INSTALL_PREFIX})
else()
  # PYTHON_EXECUTABLE is set by cmake/pybind11.cmake
  message(STATUS "Python executable: ${PYTHON_EXECUTABLE}")

  execute_process(
    COMMAND "${PYTHON_EXECUTABLE}" -c "import kaldi_native_io; print(kaldi_native_io.cmake_prefix_path)"
    OUTPUT_STRIP_TRAILING_WHITESPACE
    OUTPUT_VARIABLE KALDI_NATIVE_IO_CMAKE_PREFIX_PATH
    RESULT_VARIABLE _kaldi_native_io_status
  )
  # The original ignored the exit status: a failed import left the prefix path
  # empty and find_package failed later with a much less helpful message.
  if(NOT _kaldi_native_io_status EQUAL 0)
    message(FATAL_ERROR
      "Failed to import kaldi_native_io with ${PYTHON_EXECUTABLE}. "
      "Run 'pip install kaldi_native_io' or set KALDI_NATIVE_IO_INSTALL_PREFIX.")
  endif()
endif()

message(STATUS "KALDI_NATIVE_IO_CMAKE_PREFIX_PATH: ${KALDI_NATIVE_IO_CMAKE_PREFIX_PATH}")
list(APPEND CMAKE_PREFIX_PATH "${KALDI_NATIVE_IO_CMAKE_PREFIX_PATH}")

find_package(kaldi_native_io REQUIRED)

message(STATUS "KALDI_NATIVE_IO_FOUND: ${KALDI_NATIVE_IO_FOUND}")
message(STATUS "KALDI_NATIVE_IO_VERSION: ${KALDI_NATIVE_IO_VERSION}")
message(STATUS "KALDI_NATIVE_IO_INCLUDE_DIRS: ${KALDI_NATIVE_IO_INCLUDE_DIRS}")
message(STATUS "KALDI_NATIVE_IO_CXX_FLAGS: ${KALDI_NATIVE_IO_CXX_FLAGS}")
message(STATUS "KALDI_NATIVE_IO_LIBRARIES: ${KALDI_NATIVE_IO_LIBRARIES}")

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${KALDI_NATIVE_IO_CXX_FLAGS}")
message(STATUS "CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
add_executable(online-fbank-test online-fbank-test.cc)
target_link_libraries(online-fbank-test kaldi-native-fbank-core)

add_executable(capi_test main.cpp)

# Scope the extra include paths to capi_test instead of polluting every target
# in this directory with include_directories().
target_include_directories(capi_test PRIVATE
  ${ONNXRUNTIME_ROOTDIR}/include/onnxruntime/core/session/
  ${ONNXRUNTIME_ROOTDIR}/include/onnxruntime/core/providers/tensorrt/
  # NOTE(review): the original referenced ${KALDINATIVEIO}, which is defined
  # nowhere; cmake/kaldi_native_io.cmake exports KALDI_NATIVE_IO_INCLUDE_DIRS.
  ${KALDI_NATIVE_IO_INCLUDE_DIRS}
)

target_link_libraries(capi_test onnxruntime kaldi-native-fbank-core kaldi_native_io_core)
sherpa-onnx/csrc/fbank_features.h
0 → 100644
| 1 | +#include <iostream> | ||
| 2 | + | ||
| 3 | +#include "kaldi_native_io/csrc/kaldi-io.h" | ||
| 4 | +#include "kaldi_native_io/csrc/wave-reader.h" | ||
| 5 | +#include "kaldi-native-fbank/csrc/online-feature.h" | ||
| 6 | + | ||
| 7 | + | ||
| 8 | +kaldiio::Matrix<float> readWav(std::string filename, bool log = false){ | ||
| 9 | + if (log) | ||
| 10 | + std::cout << "reading " << filename << std::endl; | ||
| 11 | + | ||
| 12 | + bool binary = true; | ||
| 13 | + kaldiio::Input ki(filename, &binary); | ||
| 14 | + kaldiio::WaveHolder wh; | ||
| 15 | + | ||
| 16 | + if (!wh.Read(ki.Stream())) { | ||
| 17 | + std::cerr << "Failed to read " << filename; | ||
| 18 | + exit(EXIT_FAILURE); | ||
| 19 | + } | ||
| 20 | + | ||
| 21 | + auto &wave_data = wh.Value(); | ||
| 22 | + auto &d = wave_data.Data(); | ||
| 23 | + | ||
| 24 | + if (log) | ||
| 25 | + std::cout << "wav shape: " << "(" << d.NumRows() << "," << d.NumCols() << ")" << std::endl; | ||
| 26 | + | ||
| 27 | + return d; | ||
| 28 | +} | ||
| 29 | + | ||
| 30 | + | ||
| 31 | +std::vector<float> ComputeFeatures(knf::OnlineFbank &fbank, knf::FbankOptions opts, kaldiio::Matrix<float> samples, bool log = false){ | ||
| 32 | + int numSamples = samples.NumCols(); | ||
| 33 | + | ||
| 34 | + for (int i = 0; i < numSamples; i++) | ||
| 35 | + { | ||
| 36 | + float currentSample = samples.Row(0).Data()[i] / 32768; | ||
| 37 | + fbank.AcceptWaveform(opts.frame_opts.samp_freq, ¤tSample, 1); | ||
| 38 | + } | ||
| 39 | + | ||
| 40 | + std::vector<float> features; | ||
| 41 | + int32_t num_frames = fbank.NumFramesReady(); | ||
| 42 | + for (int32_t i = 0; i != num_frames; ++i) { | ||
| 43 | + const float *frame = fbank.GetFrame(i); | ||
| 44 | + for (int32_t k = 0; k != opts.mel_opts.num_bins; ++k) { | ||
| 45 | + features.push_back(frame[k]); | ||
| 46 | + } | ||
| 47 | + } | ||
| 48 | + if (log){ | ||
| 49 | + std::cout << "done feature extraction" << std::endl; | ||
| 50 | + std::cout << "extracted fbank shape " << "(" << num_frames << "," << opts.mel_opts.num_bins << ")" << std::endl; | ||
| 51 | + | ||
| 52 | + for (int i=0; i< 20; i++) | ||
| 53 | + std::cout << features.at(i) << std::endl; | ||
| 54 | + } | ||
| 55 | + | ||
| 56 | + return features; | ||
| 57 | +} |
sherpa-onnx/csrc/main.cpp
0 → 100644
#include <algorithm>
#include <chrono>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

#include <time.h>
#include <math.h>

#include "fbank_features.h"
#include "rnnt_beam_search.h"

#include "kaldi-native-fbank/csrc/online-feature.h"
| 13 | + | ||
| 14 | +int main(int argc, char* argv[]) { | ||
| 15 | + char* filename = argv[1]; | ||
| 16 | + std::string search_method = argv[2]; | ||
| 17 | + int num_active_paths = atoi(argv[3]); | ||
| 18 | + | ||
| 19 | + // General parameters | ||
| 20 | + int numberOfThreads = 16; | ||
| 21 | + | ||
| 22 | + // Initialize fbanks | ||
| 23 | + knf::FbankOptions opts; | ||
| 24 | + opts.frame_opts.dither = 0; | ||
| 25 | + opts.frame_opts.samp_freq = 16000; | ||
| 26 | + opts.frame_opts.frame_shift_ms = 10.0f; | ||
| 27 | + opts.frame_opts.frame_length_ms = 25.0f; | ||
| 28 | + opts.mel_opts.num_bins = 80; | ||
| 29 | + opts.frame_opts.window_type = "povey"; | ||
| 30 | + opts.frame_opts.snip_edges = false; | ||
| 31 | + knf::OnlineFbank fbank(opts); | ||
| 32 | + | ||
| 33 | + // set session opts | ||
| 34 | + // https://onnxruntime.ai/docs/performance/tune-performance.html | ||
| 35 | + session_options.SetIntraOpNumThreads(numberOfThreads); | ||
| 36 | + session_options.SetInterOpNumThreads(numberOfThreads); | ||
| 37 | + session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED); | ||
| 38 | + session_options.SetLogSeverityLevel(4); | ||
| 39 | + session_options.SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL); | ||
| 40 | + | ||
| 41 | + api.CreateTensorRTProviderOptions(&tensorrt_options); | ||
| 42 | + std::unique_ptr<OrtTensorRTProviderOptionsV2, decltype(api.ReleaseTensorRTProviderOptions)> rel_trt_options(tensorrt_options, api.ReleaseTensorRTProviderOptions); | ||
| 43 | + api.SessionOptionsAppendExecutionProvider_TensorRT_V2(static_cast<OrtSessionOptions*>(session_options), rel_trt_options.get()); | ||
| 44 | + | ||
| 45 | + // Define model | ||
| 46 | + auto model = get_model( | ||
| 47 | + "/mnt/local4/sr/k2_sherpa/models/exp_en2/encoder_simp.onnx", | ||
| 48 | + "/mnt/local4/sr/k2_sherpa/models/exp_en2/decoder_simp.onnx", | ||
| 49 | + "/mnt/local4/sr/k2_sherpa/models/exp_en2/joiner_simp.onnx", | ||
| 50 | + "/mnt/local4/sr/k2_sherpa/models/exp_en2/joiner_encoder_proj_simp.onnx", | ||
| 51 | + "/mnt/local4/sr/k2_sherpa/models/exp_en2/joiner_decoder_proj_simp.onnx", | ||
| 52 | + "/mnt/local4/sr/k2_sherpa/models/exp_en2/enUS_tokens.txt" | ||
| 53 | + ); | ||
| 54 | + | ||
| 55 | + std::vector<std::string> filename_list { | ||
| 56 | + "/mnt/local4/sr/k2_sherpa/test_wavs/cnn_15sec.wav", | ||
| 57 | + //"/mnt/local4/sr/k2_sherpa/test_wavs/1089-134686-0001.wav" | ||
| 58 | + }; | ||
| 59 | + | ||
| 60 | + for (auto filename : filename_list){ | ||
| 61 | + std::cout << filename << std::endl; | ||
| 62 | + auto samples = readWav(filename, true); | ||
| 63 | + int numSamples = samples.NumCols(); | ||
| 64 | + | ||
| 65 | + auto features = ComputeFeatures(fbank, opts, samples); | ||
| 66 | + | ||
| 67 | + auto tic = std::chrono::high_resolution_clock::now(); | ||
| 68 | + | ||
| 69 | + // # === Encoder Out === # | ||
| 70 | + int num_frames = features.size() / opts.mel_opts.num_bins; | ||
| 71 | + auto encoder_out = model.encoder_forward(features, | ||
| 72 | + std::vector<int64_t> {num_frames}, | ||
| 73 | + std::vector<int64_t> {1, num_frames, 80}, | ||
| 74 | + std::vector<int64_t> {1}, | ||
| 75 | + memory_info); | ||
| 76 | + | ||
| 77 | + // # === Search === # | ||
| 78 | + std::vector<std::vector<int32_t>> hyps; | ||
| 79 | + if (search_method == "greedy") | ||
| 80 | + hyps = GreedySearch(&model, &encoder_out); | ||
| 81 | + else{ | ||
| 82 | + std::cout << "wrong search method!" << std::endl; | ||
| 83 | + exit(0); | ||
| 84 | + } | ||
| 85 | + auto results = hyps2result(model.tokens_map, hyps); | ||
| 86 | + | ||
| 87 | + // # === Print Elapsed Time === # | ||
| 88 | + auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now() - tic); | ||
| 89 | + std::cout << "Elapsed: " << float(elapsed.count()) / 1000 << " seconds" << std::endl; | ||
| 90 | + std::cout << "rtf: " << float(elapsed.count()) / 1000 / (numSamples / 16000) << std::endl; | ||
| 91 | + | ||
| 92 | + print_hyps(hyps); | ||
| 93 | + std::cout << results[0] << std::endl; | ||
| 94 | + } | ||
| 95 | + | ||
| 96 | + return 0; | ||
| 97 | +} |
sherpa-onnx/csrc/models.h
0 → 100644
| 1 | +#include <map> | ||
| 2 | +#include <vector> | ||
| 3 | +#include <iostream> | ||
| 4 | +#include <algorithm> | ||
| 5 | +#include <sys/stat.h> | ||
| 6 | + | ||
| 7 | +#include "utils_onnx.h" | ||
| 8 | + | ||
| 9 | + | ||
/**
 * Bundles the five ONNX sub-models of an RNN-T recognizer (encoder, decoder,
 * joiner, and the joiner's encoder/decoder input projections) plus the
 * id -> token table.
 *
 * Designed for aggregate initialization: the six path members must be given
 * first; the Ort::Session members are then built by their default member
 * initializers, which is safe because members initialize in declaration order
 * (paths before sessions). Uses the globals env, session_options and
 * allocator from utils_onnx.h.
 *
 * NOTE(review): the stored const char* paths are not owned — the caller must
 * keep them alive for the Model's lifetime.
 */
struct Model
{
  public:
    const char* encoder_path;
    const char* decoder_path;
    const char* joiner_path;
    const char* joiner_encoder_proj_path;
    const char* joiner_decoder_proj_path;
    const char* tokens_path;

    // Sessions are created eagerly during aggregate initialization, reading
    // the path members initialized just above.
    Ort::Session encoder = load_model(encoder_path);
    Ort::Session decoder = load_model(decoder_path);
    Ort::Session joiner = load_model(joiner_path);
    Ort::Session joiner_encoder_proj = load_model(joiner_encoder_proj_path);
    Ort::Session joiner_decoder_proj = load_model(joiner_decoder_proj_path);
    std::map<int, std::string> tokens_map = get_token_map(tokens_path);

    // Set by extract_constant_lm_parameters(); uninitialized until then.
    int32_t blank_id;
    int32_t unk_id;
    int32_t context_size;

    // Run the encoder. in_vector is the flattened feature matrix shaped by
    // feature_dims (e.g. {1, T, 80}); in_vector_length/feature_length_dims
    // form the length tensor. Returns the session outputs (encoder_out
    // first).
    //
    // NOTE(review): GetInputName/GetOutputName return allocator-owned strings
    // that are never released here (small per-call leak); newer ORT versions
    // provide GetInputNameAllocated — confirm against the ORT version in use.
    std::vector<Ort::Value> encoder_forward(std::vector<float> in_vector,
                                            std::vector<int64_t> in_vector_length,
                                            std::vector<int64_t> feature_dims,
                                            std::vector<int64_t> feature_length_dims,
                                            Ort::MemoryInfo &memory_info){
        std::vector<Ort::Value> encoder_inputTensors;
        encoder_inputTensors.push_back(Ort::Value::CreateTensor<float>(memory_info, in_vector.data(), in_vector.size(), feature_dims.data(), feature_dims.size()));
        encoder_inputTensors.push_back(Ort::Value::CreateTensor<int64_t>(memory_info, in_vector_length.data(), in_vector_length.size(), feature_length_dims.data(), feature_length_dims.size()));

        std::vector<const char*> encoder_inputNames = {encoder.GetInputName(0, allocator), encoder.GetInputName(1, allocator)};
        std::vector<const char*> encoder_outputNames = {encoder.GetOutputName(0, allocator)};

        auto out = encoder.Run(Ort::RunOptions{nullptr},
                               encoder_inputNames.data(),
                               encoder_inputTensors.data(),
                               encoder_inputTensors.size(),
                               encoder_outputNames.data(),
                               encoder_outputNames.size());
        return out;
    }

    // Run the decoder (prediction network) on a token-id tensor shaped by
    // dims (e.g. {batch, context_size}).
    std::vector<Ort::Value> decoder_forward(std::vector<int64_t> in_vector,
                                            std::vector<int64_t> dims,
                                            Ort::MemoryInfo &memory_info){
        std::vector<Ort::Value> inputTensors;
        inputTensors.push_back(Ort::Value::CreateTensor<int64_t>(memory_info, in_vector.data(), in_vector.size(), dims.data(), dims.size()));

        std::vector<const char*> inputNames {decoder.GetInputName(0, allocator)};
        std::vector<const char*> outputNames {decoder.GetOutputName(0, allocator)};

        auto out = decoder.Run(Ort::RunOptions{nullptr},
                               inputNames.data(),
                               inputTensors.data(),
                               inputTensors.size(),
                               outputNames.data(),
                               outputNames.size());

        return out;
    }

    // Run the joiner on pre-projected encoder and decoder outputs; returns
    // the logits tensor.
    std::vector<Ort::Value> joiner_forward(std::vector<float> projected_encoder_out,
                                           std::vector<float> decoder_out,
                                           std::vector<int64_t> projected_encoder_out_dims,
                                           std::vector<int64_t> decoder_out_dims,
                                           Ort::MemoryInfo &memory_info){
        std::vector<Ort::Value> inputTensors;
        inputTensors.push_back(Ort::Value::CreateTensor<float>(memory_info, projected_encoder_out.data(), projected_encoder_out.size(), projected_encoder_out_dims.data(), projected_encoder_out_dims.size()));
        inputTensors.push_back(Ort::Value::CreateTensor<float>(memory_info, decoder_out.data(), decoder_out.size(), decoder_out_dims.data(), decoder_out_dims.size()));
        std::vector<const char*> inputNames = {joiner.GetInputName(0, allocator), joiner.GetInputName(1, allocator)};
        std::vector<const char*> outputNames = {joiner.GetOutputName(0, allocator)};

        auto out = joiner.Run(Ort::RunOptions{nullptr},
                              inputNames.data(),
                              inputTensors.data(),
                              inputTensors.size(),
                              outputNames.data(),
                              outputNames.size());

        return out;
    }

    // Project raw encoder output into the joiner's input space.
    std::vector<Ort::Value> joiner_encoder_proj_forward(std::vector<float> in_vector,
                                                        std::vector<int64_t> dims,
                                                        Ort::MemoryInfo &memory_info){
        std::vector<Ort::Value> inputTensors;
        inputTensors.push_back(Ort::Value::CreateTensor<float>(memory_info, in_vector.data(), in_vector.size(), dims.data(), dims.size()));

        std::vector<const char*> inputNames {joiner_encoder_proj.GetInputName(0, allocator)};
        std::vector<const char*> outputNames {joiner_encoder_proj.GetOutputName(0, allocator)};

        auto out = joiner_encoder_proj.Run(Ort::RunOptions{nullptr},
                                           inputNames.data(),
                                           inputTensors.data(),
                                           inputTensors.size(),
                                           outputNames.data(),
                                           outputNames.size());

        return out;
    }

    // Project raw decoder output into the joiner's input space.
    std::vector<Ort::Value> joiner_decoder_proj_forward(std::vector<float> in_vector,
                                                        std::vector<int64_t> dims,
                                                        Ort::MemoryInfo &memory_info){
        std::vector<Ort::Value> inputTensors;
        inputTensors.push_back(Ort::Value::CreateTensor<float>(memory_info, in_vector.data(), in_vector.size(), dims.data(), dims.size()));

        std::vector<const char*> inputNames {joiner_decoder_proj.GetInputName(0, allocator)};
        std::vector<const char*> outputNames {joiner_decoder_proj.GetOutputName(0, allocator)};

        auto out = joiner_decoder_proj.Run(Ort::RunOptions{nullptr},
                                           inputNames.data(),
                                           inputTensors.data(),
                                           inputTensors.size(),
                                           outputNames.data(),
                                           outputNames.size());

        return out;
    }

    // Create an Ort::Session for the model file at `path`; exits the process
    // if the file does not exist (Ort::Session would otherwise throw a less
    // obvious error). Uses the global env and session_options.
    Ort::Session load_model(const char* path){
        struct stat buffer;
        if (stat(path, &buffer) != 0){
            std::cout << "File does not exist!: " << path << std::endl;
            exit(0);
        }
        std::cout << "loading " << path << std::endl;
        Ort::Session onnx_model(env, path, session_options);
        return onnx_model;
    }

    void extract_constant_lm_parameters(){
        /*
        all_in_one contains these params. We should trace all_in_one and find 'constants_lm' nodes to extract these params.
        For now, these params are set statically.
        in: Ort::Session &all_in_one
        out: {blank_id, unk_id, context_size}
        should return std::vector<int32_t>
        */
        blank_id = 0;
        unk_id = 0;
        context_size = 2;
    }

    // Parse a "token id" per-line token file into an id -> token map.
    // Prints a warning (but continues with an empty map) if the file is
    // missing.
    std::map<int, std::string> get_token_map(const char* token_path){
        std::ifstream inFile;
        inFile.open(token_path);
        if (inFile.fail())
            std::cerr << "Could not find token file" << std::endl;

        std::map<int, std::string> token_map;

        std::string line;
        while (std::getline(inFile, line))
        {
            int id;
            std::string token;

            std::istringstream iss(line);
            iss >> token;
            iss >> id;

            token_map[id] = token;
        }

        return token_map;
    }

};
| 179 | + | ||
| 180 | + | ||
| 181 | +Model get_model(std::string exp_path, char* tokens_path){ | ||
| 182 | + Model model{ | ||
| 183 | + (exp_path + "/encoder_simp.onnx").c_str(), | ||
| 184 | + (exp_path + "/decoder_simp.onnx").c_str(), | ||
| 185 | + (exp_path + "/joiner_simp.onnx").c_str(), | ||
| 186 | + (exp_path + "/joiner_encoder_proj_simp.onnx").c_str(), | ||
| 187 | + (exp_path + "/joiner_decoder_proj_simp.onnx").c_str(), | ||
| 188 | + tokens_path, | ||
| 189 | + }; | ||
| 190 | + model.extract_constant_lm_parameters(); | ||
| 191 | + | ||
| 192 | + return model; | ||
| 193 | +} | ||
| 194 | + | ||
| 195 | +Model get_model(char* encoder_path, | ||
| 196 | + char* decoder_path, | ||
| 197 | + char* joiner_path, | ||
| 198 | + char* joiner_encoder_proj_path, | ||
| 199 | + char* joiner_decoder_proj_path, | ||
| 200 | + char* tokens_path){ | ||
| 201 | + Model model{ | ||
| 202 | + encoder_path, | ||
| 203 | + decoder_path, | ||
| 204 | + joiner_path, | ||
| 205 | + joiner_encoder_proj_path, | ||
| 206 | + joiner_decoder_proj_path, | ||
| 207 | + tokens_path, | ||
| 208 | + }; | ||
| 209 | + model.extract_constant_lm_parameters(); | ||
| 210 | + | ||
| 211 | + return model; | ||
| 212 | +} | ||
| 213 | + | ||
| 214 | + | ||
| 215 | +void doWarmup(Model *model, int numWarmup = 5){ | ||
| 216 | + std::cout << "Warmup is started" << std::endl; | ||
| 217 | + | ||
| 218 | + std::vector<float> encoder_warmup_sample (500 * 80, 1.0); | ||
| 219 | + for (int i=0; i<numWarmup; i++) | ||
| 220 | + auto encoder_out = model->encoder_forward(encoder_warmup_sample, | ||
| 221 | + std::vector<int64_t> {500}, | ||
| 222 | + std::vector<int64_t> {1, 500, 80}, | ||
| 223 | + std::vector<int64_t> {1}, | ||
| 224 | + memory_info); | ||
| 225 | + | ||
| 226 | + std::vector<int64_t> decoder_warmup_sample {1, 1}; | ||
| 227 | + for (int i=0; i<numWarmup; i++) | ||
| 228 | + auto decoder_out = model->decoder_forward(decoder_warmup_sample, | ||
| 229 | + std::vector<int64_t> {1, 2}, | ||
| 230 | + memory_info); | ||
| 231 | + | ||
| 232 | + std::vector<float> joiner_warmup_sample1 (512, 1.0); | ||
| 233 | + std::vector<float> joiner_warmup_sample2 (512, 1.0); | ||
| 234 | + for (int i=0; i<numWarmup; i++) | ||
| 235 | + auto logits = model->joiner_forward(joiner_warmup_sample1, | ||
| 236 | + joiner_warmup_sample2, | ||
| 237 | + std::vector<int64_t> {1, 1, 1, 512}, | ||
| 238 | + std::vector<int64_t> {1, 1, 1, 512}, | ||
| 239 | + memory_info); | ||
| 240 | + | ||
| 241 | + std::vector<float> joiner_encoder_proj_warmup_sample (100 * 512, 1.0); | ||
| 242 | + for (int i=0; i<numWarmup; i++) | ||
| 243 | + auto projected_encoder_out = model->joiner_encoder_proj_forward(joiner_encoder_proj_warmup_sample, | ||
| 244 | + std::vector<int64_t> {100, 512}, | ||
| 245 | + memory_info); | ||
| 246 | + | ||
| 247 | + std::vector<float> joiner_decoder_proj_warmup_sample (512, 1.0); | ||
| 248 | + for (int i=0; i<numWarmup; i++) | ||
| 249 | + auto projected_decoder_out = model->joiner_decoder_proj_forward(joiner_decoder_proj_warmup_sample, | ||
| 250 | + std::vector<int64_t> {1, 512}, | ||
| 251 | + memory_info); | ||
| 252 | + std::cout << "Warmup is done" << std::endl; | ||
| 253 | +} |
sherpa-onnx/csrc/rnnt_beam_search.h
0 → 100644
| 1 | +#include <vector> | ||
| 2 | +#include <iostream> | ||
| 3 | +#include <algorithm> | ||
| 4 | +#include <time.h> | ||
| 5 | + | ||
| 6 | +#include "models.h" | ||
| 7 | +#include "utils.h" | ||
| 8 | + | ||
| 9 | + | ||
// Copy a sub-range of an ORT float tensor into a std::vector.
//
// CAUTION: despite its name, `length` is the EXCLUSIVE END OFFSET into the
// tensor data, not an element count — the copied range is
// [data + start, data + length). The caller in GreedySearch passes
// start = row * stride and length = (row + 1) * stride, which is correct for
// these semantics.
std::vector<float> getEncoderCol(Ort::Value &tensor, int start, int length){
    float* floatarr = tensor.GetTensorMutableData<float>();
    std::vector<float> vector {floatarr + start, floatarr + length};
    return vector;
}
| 15 | + | ||
| 16 | + | ||
/**
 * Copy the last `context_size` tokens of hyps[0] into decoder_input, in
 * place, and return the updated vector. context_size is taken from
 * decoder_input's current size. Assumes batch size = 1.
 */
std::vector<int64_t> BuildDecoderInput(const std::vector<std::vector<int32_t>> &hyps,
                                       std::vector<int64_t> &decoder_input) {

    const int32_t context_size = decoder_input.size();
    const int32_t hyps_length = hyps[0].size();
    const int32_t first = hyps_length - context_size;
    for (int32_t pos = 0; pos < context_size; ++pos) {
        decoder_input[pos] = hyps[0][first + pos];
    }

    return decoder_input;
}
| 30 | + | ||
| 31 | + | ||
/**
 * Greedy (argmax) RNN-T decoding for a single utterance (batch size 1).
 *
 * encoder_out: the output vector of Model::encoder_forward(); the first
 * tensor's dims [1] and [2] are read as (frames, encoder_dim) below.
 *
 * Returns one token-id sequence per batch element; the first
 * model->context_size entries are the initial blank prefix.
 */
std::vector<std::vector<int32_t>> GreedySearch(
    Model *model, // NOLINT
    std::vector<Ort::Value> *encoder_out){
    Ort::Value &encoder_out_tensor = encoder_out->at(0);
    int encoder_out_dim1 = encoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[1];
    int encoder_out_dim2 = encoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[2];
    auto encoder_out_vector = ortVal2Vector(encoder_out_tensor, encoder_out_dim1 * encoder_out_dim2);

    // # === Greedy Search === #
    // Hypotheses start as context_size blanks; the decoder input mirrors them.
    int32_t batch_size = 1;
    std::vector<int32_t> blanks(model->context_size, model->blank_id);
    std::vector<std::vector<int32_t>> hyps(batch_size, blanks);
    std::vector<int64_t> decoder_input(model->context_size, model->blank_id);

    // Initial decoder pass on the all-blank context.
    auto decoder_out = model->decoder_forward(decoder_input,
                                              std::vector<int64_t> {batch_size, model->context_size},
                                              memory_info);

    Ort::Value &decoder_out_tensor = decoder_out[0];
    int decoder_out_dim = decoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[2];
    auto decoder_out_vector = ortVal2Vector(decoder_out_tensor, decoder_out_dim);

    // Project the decoder output into the joiner input space. Note:
    // decoder_out is reused to hold the projection result from here on.
    decoder_out = model->joiner_decoder_proj_forward(decoder_out_vector,
                                                     std::vector<int64_t> {1, decoder_out_dim},
                                                     memory_info);
    Ort::Value &projected_decoder_out_tensor = decoder_out[0];
    auto projected_decoder_out_dim = projected_decoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[1];
    auto projected_decoder_out_vector = ortVal2Vector(projected_decoder_out_tensor, projected_decoder_out_dim);

    // Project the whole encoder output once up front; frames are then sliced
    // out of the projected tensor per time step.
    auto projected_encoder_out = model->joiner_encoder_proj_forward(encoder_out_vector,
                                                                    std::vector<int64_t> {encoder_out_dim1, encoder_out_dim2},
                                                                    memory_info);

    Ort::Value &projected_encoder_out_tensor = projected_encoder_out[0];
    int projected_encoder_out_dim1 = projected_encoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[0];
    int projected_encoder_out_dim2 = projected_encoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[1];
    auto projected_encoder_out_vector = ortVal2Vector(projected_encoder_out_tensor, projected_encoder_out_dim1 * projected_encoder_out_dim2);

    // One joiner step per encoder frame; offsets select the current frame.
    int32_t offset = 0;
    for (int i=0; i< projected_encoder_out_dim1; i++){
        int32_t cur_batch_size = 1;
        int32_t start = offset;
        int32_t end = start + cur_batch_size;
        offset = end;

        // getEncoderCol's second/third args are begin/end OFFSETS into the
        // flat tensor, scaled by the frame stride.
        auto cur_encoder_out = getEncoderCol(projected_encoder_out_tensor, start * projected_encoder_out_dim2, end * projected_encoder_out_dim2);

        auto logits = model->joiner_forward(cur_encoder_out,
                                            projected_decoder_out_vector,
                                            std::vector<int64_t> {1, 1, 1, projected_encoder_out_dim2},
                                            std::vector<int64_t> {1, 1, 1, projected_decoder_out_dim},
                                            memory_info);

        Ort::Value &logits_tensor = logits[0];
        int logits_dim = logits_tensor.GetTensorTypeAndShapeInfo().GetShape()[3];
        auto logits_vector = ortVal2Vector(logits_tensor, logits_dim);

        // Argmax over the vocabulary — no softmax needed for greedy search.
        int max_indices = static_cast<int>(std::distance(logits_vector.begin(), std::max_element(logits_vector.begin(), logits_vector.end())));
        bool emitted = false;

        for (int32_t k = 0; k != cur_batch_size; ++k) {
            auto index = max_indices;
            // Blank/unknown emissions do not extend the hypothesis.
            if (index != model->blank_id && index != model->unk_id) {
                emitted = true;
                hyps[k].push_back(index);
            }
        }

        // Only rerun the decoder (and its projection) when a real token was
        // emitted; otherwise the previous decoder state is still valid.
        if (emitted) {
            decoder_input = BuildDecoderInput(hyps, decoder_input);

            decoder_out = model->decoder_forward(decoder_input,
                                                 std::vector<int64_t> {batch_size, model->context_size},
                                                 memory_info);

            decoder_out_dim = decoder_out[0].GetTensorTypeAndShapeInfo().GetShape()[2];
            decoder_out_vector = ortVal2Vector(decoder_out[0], decoder_out_dim);

            decoder_out = model->joiner_decoder_proj_forward(decoder_out_vector,
                                                             std::vector<int64_t> {1, decoder_out_dim},
                                                             memory_info);

            projected_decoder_out_dim = decoder_out[0].GetTensorTypeAndShapeInfo().GetShape()[1];
            projected_decoder_out_vector = ortVal2Vector(decoder_out[0], projected_decoder_out_dim);
        }
    }

    return hyps;
}
| 121 | + |
sherpa-onnx/csrc/utils.h
0 → 100644
| 1 | +#include <iostream> | ||
| 2 | +#include <fstream> | ||
| 3 | + | ||
| 4 | + | ||
// Write one float per line to `saveFileName`, using default ostream
// formatting. The file is truncated if it already exists.
void vector2file(std::vector<float> vector, std::string saveFileName){
    std::ofstream out(saveFileName);
    for (float value : vector) {
        out << value << '\n';
    }
}
| 11 | + | ||
| 12 | + | ||
// Convert token-id hypotheses to text. The first `context_size` entries of
// each hypothesis are the blank/context prefix and are skipped.
//
// FIX (the original's own TODO): sentencepiece marks word starts with "▁"
// (U+2581, three bytes in UTF-8), but the original compared token.at(0)
// against ASCII '_', which never matches — and .at(0) threw on the empty
// string that token_map[] default-inserts for unknown ids. We now detect the
// real marker, replace it with a separating space, skip empty tokens, and
// keep the old behavior for genuine ASCII-underscore vocabularies.
std::vector<std::string> hyps2result(std::map<int, std::string> token_map, std::vector<std::vector<int32_t>> hyps, int context_size = 2){
    static const std::string kWordStart = "\xE2\x96\x81";  // UTF-8 bytes of '▁'

    std::vector<std::string> results;

    for (size_t k = 0; k < hyps.size(); k++){
        std::string result;

        for (size_t i = context_size; i < hyps[k].size(); i++){
            const std::string &token = token_map[hyps[k][i]];
            if (token.empty())
                continue;  // unknown id: token_map[] default-inserted ""

            if (token.compare(0, kWordStart.size(), kWordStart) == 0){
                // Word-start marker: emit a space separator (not at the very
                // beginning) and drop the marker itself.
                if (!result.empty())
                    result += ' ';
                result += token.substr(kWordStart.size());
            } else if (token[0] == '_') {
                // Original behavior preserved for '_'-prefixed vocabularies.
                result += " " + token;
            } else {
                result += token;
            }
        }
        results.push_back(result);
    }
    return results;
}
| 32 | + | ||
| 33 | + | ||
// Print the token ids of the first hypothesis (skipping the context prefix),
// separated by '-' and terminated with "|".
void print_hyps(std::vector<std::vector<int32_t>> hyps, int context_size = 2){
    std::cout << "Hyps:" << std::endl;
    for (size_t idx = context_size; idx < hyps[0].size(); ++idx) {
        std::cout << hyps[0][idx] << "-";
    }
    std::cout << "|" << std::endl;
}
sherpa-onnx/csrc/utils_onnx.h
0 → 100644
| 1 | +#include <iostream> | ||
| 2 | +#include <onnxruntime_cxx_api.h> | ||
| 3 | + | ||
// Global ONNX Runtime state shared by everything that includes this header.
//
// NOTE(review): these are non-inline definitions in a header — if more than
// one .cpp ever includes this file, the program violates the ODR (duplicate
// symbols at link time). Works only while a single TU includes it; consider
// `inline` variables (C++17) or moving the definitions into a .cpp.
Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "test");
const auto& api = Ort::GetApi();
// Filled in by api.CreateTensorRTProviderOptions() in main(); static storage
// zero-initializes it to nullptr until then.
OrtTensorRTProviderOptionsV2* tensorrt_options;
Ort::SessionOptions session_options;
Ort::AllocatorWithDefaultOptions allocator;
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
| 10 | + | ||
| 11 | + | ||
| 12 | +std::vector<float> ortVal2Vector(Ort::Value &tensor, int tensor_length){ | ||
| 13 | + /** | ||
| 14 | + * convert ort tensor to vector | ||
| 15 | + */ | ||
| 16 | + float* floatarr = tensor.GetTensorMutableData<float>(); | ||
| 17 | + std::vector<float> vector {floatarr, floatarr + tensor_length}; | ||
| 18 | + return vector; | ||
| 19 | +} | ||
| 20 | + | ||
| 21 | + | ||
| 22 | +void print_onnx_forward_output(std::vector<Ort::Value> &output_tensors, int num){ | ||
| 23 | + float* floatarr = output_tensors.front().GetTensorMutableData<float>(); | ||
| 24 | + for (int i = 0; i < num; i++) | ||
| 25 | + printf("[%d] = %f\n", i, floatarr[i]); | ||
| 26 | +} | ||
| 27 | + | ||
| 28 | + | ||
| 29 | +void print_shape_of_ort_val(std::vector<Ort::Value> &tensor){ | ||
| 30 | + auto out_shape = tensor.front().GetTensorTypeAndShapeInfo().GetShape(); | ||
| 31 | + auto out_size = out_shape.size(); | ||
| 32 | + std::cout << "("; | ||
| 33 | + for (int i=0; i<out_size; i++){ | ||
| 34 | + std::cout << out_shape[i]; | ||
| 35 | + if (i < out_size-1) | ||
| 36 | + std::cout << ","; | ||
| 37 | + } | ||
| 38 | + std::cout << ")" << std::endl; | ||
| 39 | +} | ||
| 40 | + | ||
| 41 | + | ||
// Debug helper: print the input count, the first output name, and for every
// input its name, element type and shape.
//
// NOTE(review): GetInputName/GetOutputName return strings allocated by the
// allocator that are never released here — a small per-call leak; newer ORT
// versions offer GetInputNameAllocated, which returns a managed pointer.
// Confirm against the ORT version in use.
void print_model_info(Ort::Session &session, std::string title){
    std::cout << "=== Printing '" << title << "' model ===" << std::endl;
    Ort::AllocatorWithDefaultOptions allocator;

    // print number of model input nodes
    size_t num_input_nodes = session.GetInputCount();
    std::vector<const char*> input_node_names(num_input_nodes);
    std::vector<int64_t> input_node_dims;

    printf("Number of inputs = %zu\n", num_input_nodes);

    char* output_name = session.GetOutputName(0, allocator);
    printf("output name: %s\n", output_name);

    // iterate over all input nodes
    for (int i = 0; i < num_input_nodes; i++) {
        // print input node names
        char* input_name = session.GetInputName(i, allocator);
        printf("Input %d : name=%s\n", i, input_name);
        input_node_names[i] = input_name;

        // print input node types
        Ort::TypeInfo type_info = session.GetInputTypeInfo(i);
        auto tensor_info = type_info.GetTensorTypeAndShapeInfo();

        ONNXTensorElementDataType type = tensor_info.GetElementType();
        printf("Input %d : type=%d\n", i, type);

        // print input shapes/dims (symbolic dimensions appear as -1)
        input_node_dims = tensor_info.GetShape();
        printf("Input %d : num_dims=%zu\n", i, input_node_dims.size());
        for (size_t j = 0; j < input_node_dims.size(); j++)
            printf("Input %d : dim %zu=%jd\n", i, j, input_node_dims[j]);
    }
    std::cout << "=======================================" << std::endl;
}
请注册或登录后发表评论。