Commit 77ccd625b83200827e8f0dd4cd93dc56f224af62 (77ccd625), 1 parent d9b84d55
Authored by Fangjun Kuang, 2022-10-12 11:27:05 +0800
Committed by GitHub, 2022-10-12 11:27:05 +0800

code refactoring and add CI (#11)
Showing 9 changed files with 278 additions and 132 deletions.
.github/workflows/test-linux.yaml
.gitignore
CMakeLists.txt
cmake/kaldi_native_io.cmake
cmake/onnxruntime.cmake
sherpa-onnx/csrc/CMakeLists.txt
sherpa-onnx/csrc/main.cpp
sherpa-onnx/csrc/rnnt_beam_search.h
sherpa-onnx/csrc/utils_onnx.h
.github/workflows/test-linux.yaml
0 → 100644
name: test-linux

on:
  push:
    branches:
      - master
    paths:
      - '.github/workflows/test-linux.yaml'
      - 'CMakeLists.txt'
      - 'cmake/**'
      - 'sherpa-onnx/csrc/*'
  pull_request:
    branches:
      - master
    paths:
      - '.github/workflows/test-linux.yaml'
      - 'CMakeLists.txt'
      - 'cmake/**'
      - 'sherpa-onnx/csrc/*'

concurrency:
  group: test-linux-${{ github.ref }}
  cancel-in-progress: true

permissions:
  contents: read

jobs:
  test-linux:
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-latest]

    steps:
      - uses: actions/checkout@v2
        with:
          fetch-depth: 0

      - name: Download pretrained model and test-data (English)
        shell: bash
        run: |
          git lfs install
          git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13

      - name: Configure Cmake
        shell: bash
        run: |
          mkdir build
          cd build
          cmake -D CMAKE_BUILD_TYPE=Release ..

      - name: Build sherpa-onnx for ubuntu
        run: |
          cd build
          make VERBOSE=1 -j3

      - name: Run tests for ubuntu (English)
        run: |
          time ./build/bin/sherpa-onnx \
            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav

          time ./build/bin/sherpa-onnx \
            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0001.wav

          time ./build/bin/sherpa-onnx \
            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
            ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0002.wav
...
...
.gitignore
0 → 100644
build
...
...
CMakeLists.txt
...
...
@@ -38,7 +38,8 @@ set(CMAKE_CXX_EXTENSIONS OFF)
list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules)
list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake)

include(cmake/kaldi_native_io.cmake)
include(cmake/kaldi-native-fbank.cmake)
include(kaldi_native_io)
include(kaldi-native-fbank)
include(onnxruntime)

add_subdirectory(sherpa-onnx)
...
...
cmake/kaldi_native_io.cmake
if(DEFINED ENV{KALDI_NATIVE_IO_INSTALL_PREFIX})
  message(STATUS "Using environment variable KALDI_NATIVE_IO_INSTALL_PREFIX: $ENV{KALDI_NATIVE_IO_INSTALL_PREFIX}")
  set(KALDI_NATIVE_IO_CMAKE_PREFIX_PATH $ENV{KALDI_NATIVE_IO_INSTALL_PREFIX})
else()
  # PYTHON_EXECUTABLE is set by cmake/pybind11.cmake
  message(STATUS "Python executable: ${PYTHON_EXECUTABLE}")
  execute_process(
    COMMAND "${PYTHON_EXECUTABLE}" -c "import kaldi_native_io; print(kaldi_native_io.cmake_prefix_path)"
    OUTPUT_STRIP_TRAILING_WHITESPACE
    OUTPUT_VARIABLE KALDI_NATIVE_IO_CMAKE_PREFIX_PATH

function(download_kaldi_native_io)
  if(CMAKE_VERSION VERSION_LESS 3.11)
    # FetchContent is available since 3.11,
    # we've copied it to ${CMAKE_SOURCE_DIR}/cmake/Modules
    # so that it can be used in lower CMake versions.
    message(STATUS "Use FetchContent provided by sherpa-onnx")
    list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules)
  endif()

  include(FetchContent)

  set(kaldi_native_io_URL "https://github.com/csukuangfj/kaldi_native_io/archive/refs/tags/v1.15.1.tar.gz")
  set(kaldi_native_io_HASH "SHA256=97377e1d61e99d8fc1d6037a418d3037522dfa46337e06162e24b1d97f3d70a6")

  set(KALDI_NATIVE_IO_BUILD_TESTS OFF CACHE BOOL "" FORCE)
  set(KALDI_NATIVE_IO_BUILD_PYTHON OFF CACHE BOOL "" FORCE)

  FetchContent_Declare(kaldi_native_io
    URL      ${kaldi_native_io_URL}
    URL_HASH ${kaldi_native_io_HASH}
  )
endif()

  message(STATUS "KALDI_NATIVE_IO_CMAKE_PREFIX_PATH: ${KALDI_NATIVE_IO_CMAKE_PREFIX_PATH}")
  list(APPEND CMAKE_PREFIX_PATH "${KALDI_NATIVE_IO_CMAKE_PREFIX_PATH}")

  FetchContent_GetProperties(kaldi_native_io)
  if(NOT kaldi_native_io_POPULATED)
    message(STATUS "Downloading kaldi_native_io ${kaldi_native_io_URL}")
    FetchContent_Populate(kaldi_native_io)
  endif()
  message(STATUS "kaldi_native_io is downloaded to ${kaldi_native_io_SOURCE_DIR}")
  message(STATUS "kaldi_native_io's binary dir is ${kaldi_native_io_BINARY_DIR}")

  find_package(kaldi_native_io REQUIRED)
  add_subdirectory(${kaldi_native_io_SOURCE_DIR} ${kaldi_native_io_BINARY_DIR} EXCLUDE_FROM_ALL)

  message(STATUS "KALDI_NATIVE_IO_FOUND: ${KALDI_NATIVE_IO_FOUND}")
  message(STATUS "KALDI_NATIVE_IO_VERSION: ${KALDI_NATIVE_IO_VERSION}")
  message(STATUS "KALDI_NATIVE_IO_INCLUDE_DIRS: ${KALDI_NATIVE_IO_INCLUDE_DIRS}")
  message(STATUS "KALDI_NATIVE_IO_CXX_FLAGS: ${KALDI_NATIVE_IO_CXX_FLAGS}")
  message(STATUS "KALDI_NATIVE_IO_LIBRARIES: ${KALDI_NATIVE_IO_LIBRARIES}")

  target_include_directories(kaldi_native_io_core
    PUBLIC
      ${kaldi_native_io_SOURCE_DIR}/
  )
endfunction()

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${KALDI_NATIVE_IO_CXX_FLAGS}")
message(STATUS "CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
\ No newline at end of file
download_kaldi_native_io()
...
...
cmake/onnxruntime.cmake
0 → 100644
function(download_onnxruntime)
  if(CMAKE_VERSION VERSION_LESS 3.11)
    # FetchContent is available since 3.11,
    # we've copied it to ${CMAKE_SOURCE_DIR}/cmake/Modules
    # so that it can be used in lower CMake versions.
    message(STATUS "Use FetchContent provided by sherpa-onnx")
    list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules)
  endif()

  include(FetchContent)

  if(UNIX AND NOT APPLE)
    # set(onnxruntime_URL "http://github.com/microsoft/onnxruntime/releases/download/v1.12.1/onnxruntime-linux-x64-1.12.1.tgz")
    # If you don't have access to the internet, you can first download onnxruntime to some directory, and then use
    # set(onnxruntime_URL "file:///ceph-fj/fangjun/open-source/sherpa-onnx/onnxruntime-linux-x64-1.12.1.tgz")
    set(onnxruntime_HASH "SHA256=8f6eb9e2da9cf74e7905bf3fc687ef52e34cc566af7af2f92dafe5a5d106aa3d")

    # After downloading, it contains:
    #   ./lib/libonnxruntime.so.1.12.1
    #   ./lib/libonnxruntime.so, which is a symlink to lib/libonnxruntime.so.1.12.1
    #
    #   ./include
    #     It contains all the needed header files
  else()
    message(FATAL_ERROR "Only support Linux at present. Will support other OSes later")
  endif()

  FetchContent_Declare(onnxruntime
    URL      ${onnxruntime_URL}
    URL_HASH ${onnxruntime_HASH}
  )

  FetchContent_GetProperties(onnxruntime)
  if(NOT onnxruntime_POPULATED)
    message(STATUS "Downloading onnxruntime ${onnxruntime_URL}")
    FetchContent_Populate(onnxruntime)
  endif()
  message(STATUS "onnxruntime is downloaded to ${onnxruntime_SOURCE_DIR}")

  find_library(location_onnxruntime onnxruntime
    PATHS
    "${onnxruntime_SOURCE_DIR}/lib"
  )

  message(STATUS "location_onnxruntime: ${location_onnxruntime}")

  add_library(onnxruntime SHARED IMPORTED)

  set_target_properties(onnxruntime PROPERTIES
    IMPORTED_LOCATION ${location_onnxruntime}
    INTERFACE_INCLUDE_DIRECTORIES "${onnxruntime_SOURCE_DIR}/include"
  )
endfunction()

download_onnxruntime()
...
...
sherpa-onnx/csrc/CMakeLists.txt
add_executable(online-fbank-test online-fbank-test.cc)
target_link_libraries(online-fbank-test kaldi-native-fbank-core)

include_directories(
  ${ONNXRUNTIME_ROOTDIR}/include/onnxruntime/core/session/
  ${ONNXRUNTIME_ROOTDIR}/include/onnxruntime/core/providers/tensorrt/
)
include_directories(${CMAKE_SOURCE_DIR})
add_executable(sherpa-onnx main.cpp)
include_directories(${KALDINATIVEIO})
target_link_libraries(sherpa-onnx
  onnxruntime
  kaldi-native-fbank-core
  kaldi_native_io_core
)

add_executable(sherpa-onnx main.cpp)
target_link_libraries(sherpa-onnx onnxruntime kaldi-native-fbank-core kaldi_native_io_core)
...
...
sherpa-onnx/csrc/main.cpp
#include <vector>
#include <iostream>
#include <algorithm>
#include <time.h>
#include <math.h>
#include <fstream>
#include <iostream>
#include <math.h>
#include <time.h>
#include <vector>
#include "fbank_features.h"
#include "rnnt_beam_search.h"
#include "sherpa-onnx/csrc/fbank_features.h"
#include "sherpa-onnx/csrc/rnnt_beam_search.h"
#include "kaldi-native-fbank/csrc/online-feature.h"
int main(int argc, char *argv[]) {
  char *encoder_path = argv[1];
  char *decoder_path = argv[2];
  char *joiner_path = argv[3];
  char *joiner_encoder_proj_path = argv[4];
  char *joiner_decoder_proj_path = argv[5];
  char *token_path = argv[6];
  std::string search_method = argv[7];
  char *filename = argv[8];

  // General parameters
  int numberOfThreads = 16;

  // Initialize fbanks
  knf::FbankOptions opts;
  opts.frame_opts.dither = 0;
  opts.frame_opts.samp_freq = 16000;
  opts.frame_opts.frame_shift_ms = 10.0f;
  opts.frame_opts.frame_length_ms = 25.0f;
  opts.mel_opts.num_bins = 80;
  opts.frame_opts.window_type = "povey";
  opts.frame_opts.snip_edges = false;
  knf::OnlineFbank fbank(opts);

  // set session opts
  // https://onnxruntime.ai/docs/performance/tune-performance.html
  session_options.SetIntraOpNumThreads(numberOfThreads);
  session_options.SetInterOpNumThreads(numberOfThreads);
  session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
  session_options.SetLogSeverityLevel(4);
  session_options.SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL);
  api.CreateTensorRTProviderOptions(&tensorrt_options);
  std::unique_ptr<OrtTensorRTProviderOptionsV2, decltype(api.ReleaseTensorRTProviderOptions)>
      rel_trt_options(tensorrt_options, api.ReleaseTensorRTProviderOptions);
  api.SessionOptionsAppendExecutionProvider_TensorRT_V2(
      static_cast<OrtSessionOptions *>(session_options), rel_trt_options.get());

  // Define model
  auto model = get_model(encoder_path, decoder_path, joiner_path,
                         joiner_encoder_proj_path, joiner_decoder_proj_path,
                         token_path);

  std::vector<std::string> filename_list{filename};
  for (auto filename : filename_list) {
    std::cout << filename << std::endl;
    auto samples = readWav(filename, true);
    int numSamples = samples.NumCols();
    auto features = ComputeFeatures(fbank, opts, samples);

    auto tic = std::chrono::high_resolution_clock::now();

    // # === Encoder Out === #
    int num_frames = features.size() / opts.mel_opts.num_bins;
    auto encoder_out = model.encoder_forward(
        features, std::vector<int64_t>{num_frames},
        std::vector<int64_t>{1, num_frames, 80}, std::vector<int64_t>{1},
        memory_info);

    // # === Search === #
    std::vector<std::vector<int32_t>> hyps;
    if (search_method == "greedy")
      hyps = GreedySearch(&model, &encoder_out);
    else {
      std::cout << "wrong search method!" << std::endl;
      exit(0);
    }
    auto results = hyps2result(model.tokens_map, hyps);

    // # === Print Elapsed Time === #
    auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
        std::chrono::high_resolution_clock::now() - tic);
    std::cout << "Elapsed: " << float(elapsed.count()) / 1000 << " seconds" << std::endl;
    std::cout << "rtf: " << float(elapsed.count()) / 1000 / (numSamples / 16000) << std::endl;

    print_hyps(hyps);
    std::cout << results[0] << std::endl;
  }

  return 0;
}
...
...
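main.cpp reports a real-time factor (rtf) as decoding time in seconds divided by audio duration, where the duration is taken as the number of 16 kHz samples over 16000. A worked sketch of that arithmetic with made-up numbers (the values below are illustrative only, not taken from this commit):

#include <cstdint>
#include <iostream>

int main() {
  // Illustrative values: 5.5 s of 16 kHz audio decoded in 1234 ms.
  int64_t num_samples = 88000;
  int64_t elapsed_ms = 1234;

  float audio_seconds = static_cast<float>(num_samples) / 16000.0f;  // 5.5
  float elapsed_seconds = static_cast<float>(elapsed_ms) / 1000.0f;  // 1.234
  float rtf = elapsed_seconds / audio_seconds;                       // ~0.224

  std::cout << "rtf: " << rtf << std::endl;
  return 0;
}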
sherpa-onnx/csrc/rnnt_beam_search.h
...
...
@@ -61,7 +61,6 @@ std::vector<std::vector<int32_t>> GreedySearch(
auto projected_encoder_out = model->joiner_encoder_proj_forward(
    encoder_out_vector,
    std::vector<int64_t>{encoder_out_dim1, encoder_out_dim2}, memory_info);
Ort::Value &projected_encoder_out_tensor = projected_encoder_out[0];
int projected_encoder_out_dim1 =
    projected_encoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[0];
int projected_encoder_out_dim2 =
    projected_encoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[1];
...
...
@@ -78,12 +77,12 @@ std::vector<std::vector<int32_t>> GreedySearch(
auto logits = model->joiner_forward(
    cur_encoder_out, projected_decoder_out_vector,
    std::vector<int64_t>{1, 1, 1, projected_encoder_out_dim2},
    std::vector<int64_t>{1, 1, 1, projected_decoder_out_dim},
    std::vector<int64_t>{1, projected_encoder_out_dim2},
    std::vector<int64_t>{1, projected_decoder_out_dim},
    memory_info);
Ort::Value &logits_tensor = logits[0];
int logits_dim = logits_tensor.GetTensorTypeAndShapeInfo().GetShape()[3];
int logits_dim = logits_tensor.GetTensorTypeAndShapeInfo().GetShape()[1];
auto logits_vector = ortVal2Vector(logits_tensor, logits_dim);
int max_indices = static_cast<int>(std::distance(
    logits_vector.begin(),
    std::max_element(logits_vector.begin(), logits_vector.end())));
...
...
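The last statement of the hunk above picks the highest-scoring token with std::max_element and std::distance, which is the core of the greedy search. A self-contained sketch of that argmax pattern on a plain std::vector<float> (the logit values are made up; only the standard library is used):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Pretend these are the joiner logits for a 5-token vocabulary.
  std::vector<float> logits = {-1.2f, 0.3f, 2.7f, 0.1f, -0.5f};

  // Index of the largest element, i.e. the greedily chosen token ID.
  int32_t max_index = static_cast<int32_t>(std::distance(
      logits.begin(), std::max_element(logits.begin(), logits.end())));

  std::cout << "argmax: " << max_index << std::endl;  // prints 2
  return 0;
}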
sherpa-onnx/csrc/utils_onnx.h
#include <iostream>

#include <onnxruntime_cxx_api.h>
#include "onnxruntime_cxx_api.h"

Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "test");
const auto &api = Ort::GetApi();
...
...
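utils_onnx.h only declares the global Ort::Env and a handle to the ONNX Runtime C API; the sessions themselves are built elsewhere in the code. A minimal sketch of how such an environment is typically paired with a session in the ONNX Runtime C++ API (CPU-only; the model path and thread count are placeholders, not values from this commit):

#include <onnxruntime_cxx_api.h>

int main() {
  // Process-wide runtime environment, as declared in utils_onnx.h.
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "sherpa-onnx");

  // Session options mirroring the kind of settings used in main.cpp.
  Ort::SessionOptions session_options;
  session_options.SetIntraOpNumThreads(4);
  session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);

  // Load a model (placeholder path) and prepare a CPU memory-info object
  // that can later be used when building input tensors.
  Ort::Session session(env, "encoder.onnx", session_options);
  auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);

  return 0;
}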