Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
Fangjun Kuang
2024-06-19 20:51:57 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2024-06-19 20:51:57 +0800
Commit
a11c8599710dbbd2484e93b452acf2bb7196354f
a11c8599
1 parent
656b9fa1
Support clang-tidy (#1034)
显示空白字符变更
内嵌
并排对比
正在显示
63 个修改的文件
包含
362 行增加
和
218 行删除
.clang-tidy
.github/workflows/clang-tidy.yaml
.github/workflows/flutter-macos.yaml
.github/workflows/flutter-windows-x64.yaml
CMakeLists.txt
cmake/openfst.cmake
sherpa-onnx/csrc/CMakeLists.txt
sherpa-onnx/csrc/audio-tagging-label-file.cc
sherpa-onnx/csrc/base64-decode.cc
sherpa-onnx/csrc/cat.cc
sherpa-onnx/csrc/circular-buffer.cc
sherpa-onnx/csrc/context-graph.cc
sherpa-onnx/csrc/endpoint.cc
sherpa-onnx/csrc/endpoint.h
sherpa-onnx/csrc/jieba-lexicon.cc
sherpa-onnx/csrc/keyword-spotter.cc
sherpa-onnx/csrc/lexicon.cc
sherpa-onnx/csrc/lexicon.h
sherpa-onnx/csrc/offline-ct-transformer-model.cc
sherpa-onnx/csrc/offline-ctc-model.cc
sherpa-onnx/csrc/offline-stream.cc
sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.cc
sherpa-onnx/csrc/offline-tts-character-frontend.cc
sherpa-onnx/csrc/offline-tts.cc
sherpa-onnx/csrc/offline-whisper-model.cc
sherpa-onnx/csrc/online-conformer-transducer-model.cc
sherpa-onnx/csrc/online-ctc-fst-decoder.cc
sherpa-onnx/csrc/online-lstm-transducer-model.cc
sherpa-onnx/csrc/online-nemo-ctc-model.cc
sherpa-onnx/csrc/online-recognizer.cc
sherpa-onnx/csrc/online-stream.cc
sherpa-onnx/csrc/online-stream.h
sherpa-onnx/csrc/online-transducer-decoder.cc
sherpa-onnx/csrc/online-transducer-decoder.h
sherpa-onnx/csrc/online-transducer-model.cc
sherpa-onnx/csrc/online-transducer-nemo-model.cc
sherpa-onnx/csrc/online-wenet-ctc-model.cc
sherpa-onnx/csrc/online-zipformer-transducer-model.cc
sherpa-onnx/csrc/online-zipformer2-ctc-model.cc
sherpa-onnx/csrc/online-zipformer2-transducer-model.cc
sherpa-onnx/csrc/onnx-utils.cc
sherpa-onnx/csrc/onnx-utils.h
sherpa-onnx/csrc/packed-sequence.cc
sherpa-onnx/csrc/pad-sequence.cc
sherpa-onnx/csrc/parse-options.cc
sherpa-onnx/csrc/piper-phonemize-lexicon.cc
sherpa-onnx/csrc/resample.cc
sherpa-onnx/csrc/resample.h
sherpa-onnx/csrc/session.cc
sherpa-onnx/csrc/silero-vad-model.cc
sherpa-onnx/csrc/slice.cc
sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc
sherpa-onnx/csrc/speaker-embedding-manager.cc
sherpa-onnx/csrc/spoken-language-identification-impl.cc
sherpa-onnx/csrc/stack.cc
sherpa-onnx/csrc/symbol-table.cc
sherpa-onnx/csrc/text-utils.cc
sherpa-onnx/csrc/transducer-keyword-decoder.cc
sherpa-onnx/csrc/transpose.cc
sherpa-onnx/csrc/unbind.cc
sherpa-onnx/csrc/utils.cc
sherpa-onnx/csrc/wave-reader.cc
sherpa-onnx/csrc/wave-writer.cc
.clang-tidy
0 → 100644
查看文件 @
a11c859
---
# NOTE there must be no spaces before the '-', so put the comma last.
# The check bugprone-unchecked-optional-access is also turned off atm
# because it causes clang-tidy to hang randomly. The tracking issue
# can be found at https://github.com/llvm/llvm-project/issues/69369.
#
# Modified from
# https://github.com/pytorch/pytorch/blob/main/.clang-tidy
InheritParentConfig: true
Checks: '
bugprone-*,
-bugprone-easily-swappable-parameters,
-bugprone-forward-declaration-namespace,
-bugprone-implicit-widening-of-multiplication-result,
-bugprone-macro-parentheses,
-bugprone-lambda-function-name,
-bugprone-narrowing-conversions,
-bugprone-reserved-identifier,
-bugprone-swapped-arguments,
-bugprone-unchecked-optional-access,
clang-diagnostic-missing-prototypes,
cppcoreguidelines-*,
-cppcoreguidelines-avoid-const-or-ref-data-members,
-cppcoreguidelines-avoid-do-while,
-cppcoreguidelines-avoid-magic-numbers,
-cppcoreguidelines-avoid-non-const-global-variables,
-cppcoreguidelines-interfaces-global-init,
-cppcoreguidelines-macro-usage,
-cppcoreguidelines-narrowing-conversions,
-cppcoreguidelines-owning-memory,
-cppcoreguidelines-pro-bounds-array-to-pointer-decay,
-cppcoreguidelines-pro-bounds-constant-array-index,
-cppcoreguidelines-pro-bounds-pointer-arithmetic,
-cppcoreguidelines-pro-type-const-cast,
-cppcoreguidelines-pro-type-cstyle-cast,
-cppcoreguidelines-pro-type-reinterpret-cast,
-cppcoreguidelines-pro-type-static-cast-downcast,
-cppcoreguidelines-pro-type-union-access,
-cppcoreguidelines-pro-type-vararg,
-cppcoreguidelines-special-member-functions,
-cppcoreguidelines-non-private-member-variables-in-classes,
-facebook-hte-RelativeInclude,
hicpp-exception-baseclass,
hicpp-avoid-goto,
misc-*,
-misc-const-correctness,
-misc-include-cleaner,
-misc-use-anonymous-namespace,
-misc-unused-parameters,
-misc-no-recursion,
-misc-non-private-member-variables-in-classes,
-misc-confusable-identifiers,
modernize-*,
-modernize-macro-to-enum,
-modernize-pass-by-value,
-modernize-return-braced-init-list,
-modernize-use-auto,
-modernize-use-default-member-init,
-modernize-use-using,
-modernize-use-trailing-return-type,
-modernize-use-nodiscard,
performance-*,
readability-container-size-empty,
readability-delete-null-pointer,
readability-duplicate-include,
readability-misplaced-array-index,
readability-redundant-function-ptr-dereference,
readability-redundant-smartptr-get,
readability-simplify-subscript-expr,
readability-string-compare,
'
WarningsAsErrors: '*'
...
...
...
.github/workflows/clang-tidy.yaml
0 → 100644
查看文件 @
a11c859
name
:
clang-tidy
on
:
push
:
branches
:
-
master
-
clang-tidy
paths
:
-
'
sherpa-onnx/csrc/**'
pull_request
:
branches
:
-
master
paths
:
-
'
sherpa-onnx/csrc/**'
workflow_dispatch
:
concurrency
:
group
:
clang-tidy-${{ github.ref }}
cancel-in-progress
:
true
jobs
:
clang-tidy
:
runs-on
:
ubuntu-latest
strategy
:
matrix
:
python-version
:
[
3.8
]
fail-fast
:
false
steps
:
-
uses
:
actions/checkout@v4
with
:
fetch-depth
:
0
-
name
:
Setup Python ${{ matrix.python-version }}
uses
:
actions/setup-python@v5
with
:
python-version
:
${{ matrix.python-version }}
-
name
:
Install clang-tidy
shell
:
bash
run
:
|
pip install clang-tidy
-
name
:
Configure
shell
:
bash
run
:
|
mkdir build
cd build
cmake -DSHERPA_ONNX_ENABLE_PYTHON=ON -DCMAKE_EXPORT_COMPILE_COMMANDS=ON ..
-
name
:
Check with clang-tidy
shell
:
bash
run
:
|
cd build
make check
...
...
.github/workflows/flutter-macos.yaml
查看文件 @
a11c859
...
...
@@ -184,6 +184,7 @@ jobs:
path
:
./*.tar.bz2
-
name
:
Publish to huggingface
if
:
(github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa') && github.event_name == 'push' && contains(github.ref, 'refs/tags/') && matrix.build_type == 'Release'
env
:
HF_TOKEN
:
${{ secrets.HF_TOKEN }}
uses
:
nick-fields/retry@v3
...
...
.github/workflows/flutter-windows-x64.yaml
查看文件 @
a11c859
...
...
@@ -133,6 +133,7 @@ jobs:
shell
:
bash
run
:
|
d=$PWD
SHERPA_ONNX_VERSION=v$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2)
pushd sherpa-onnx/flutter
dart pub get
...
...
@@ -159,6 +160,7 @@ jobs:
path
:
./*.tar.bz2
-
name
:
Publish to huggingface
if
:
(github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa') && github.event_name == 'push' && contains(github.ref, 'refs/tags/') && matrix.build_type == 'Release'
env
:
HF_TOKEN
:
${{ secrets.HF_TOKEN }}
uses
:
nick-fields/retry@v3
...
...
CMakeLists.txt
查看文件 @
a11c859
...
...
@@ -167,7 +167,7 @@ if(SHERPA_ONNX_ENABLE_WASM_KWS)
endif
()
if
(
NOT CMAKE_CXX_STANDARD
)
set
(
CMAKE_CXX_STANDARD 14
CACHE STRING
"The C++ version to be used."
)
set
(
CMAKE_CXX_STANDARD 17
CACHE STRING
"The C++ version to be used."
)
endif
()
set
(
CMAKE_CXX_EXTENSIONS OFF
)
message
(
STATUS
"C++ Standard version:
${
CMAKE_CXX_STANDARD
}
"
)
...
...
cmake/openfst.cmake
查看文件 @
a11c859
...
...
@@ -3,18 +3,18 @@
function
(
download_openfst
)
include
(
FetchContent
)
set
(
openfst_URL
"https://github.com/csukuangfj/openfst/archive/refs/tags/sherpa-onnx-2024-06-13.tar.gz"
)
set
(
openfst_URL2
"https://hub.nuaa.cf/csukuangfj/openfst/archive/refs/tags/sherpa-onnx-2024-06-13.tar.gz"
)
set
(
openfst_HASH
"SHA256=f10a71c6b64d89eabdc316d372b956c30c825c7c298e2f20c780320e8181ffb6"
)
set
(
openfst_URL
"https://github.com/csukuangfj/openfst/archive/refs/tags/sherpa-onnx-2024-06-19.tar.gz"
)
set
(
openfst_URL2
"https://hub.nuaa.cf/csukuangfj/openfst/archive/refs/tags/sherpa-onnx-2024-06-19.tar.gz"
)
set
(
openfst_HASH
"SHA256=5c98e82cc509c5618502dde4860b8ea04d843850ed57e6d6b590b644b268853d"
)
# If you don't have access to the Internet,
# please pre-download it
set
(
possible_file_locations
$ENV{HOME}/Downloads/openfst-sherpa-onnx-2024-06-13.tar.gz
${
CMAKE_SOURCE_DIR
}
/openfst-sherpa-onnx-2024-06-13.tar.gz
${
CMAKE_BINARY_DIR
}
/openfst-sherpa-onnx-2024-06-13.tar.gz
/tmp/openfst-sherpa-onnx-2024-06-13.tar.gz
/star-fj/fangjun/download/github/openfst-sherpa-onnx-2024-06-13.tar.gz
$ENV{HOME}/Downloads/openfst-sherpa-onnx-2024-06-19.tar.gz
${
CMAKE_SOURCE_DIR
}
/openfst-sherpa-onnx-2024-06-19.tar.gz
${
CMAKE_BINARY_DIR
}
/openfst-sherpa-onnx-2024-06-19.tar.gz
/tmp/openfst-sherpa-onnx-2024-06-19.tar.gz
/star-fj/fangjun/download/github/openfst-sherpa-onnx-2024-06-19.tar.gz
)
foreach
(
f IN LISTS possible_file_locations
)
...
...
sherpa-onnx/csrc/CMakeLists.txt
查看文件 @
a11c859
...
...
@@ -534,3 +534,17 @@ if(SHERPA_ONNX_ENABLE_TESTS)
sherpa_onnx_add_test
(
${
source
}
)
endforeach
()
endif
()
set
(
srcs_to_check
)
foreach
(
s IN LISTS sources
)
list
(
APPEND srcs_to_check
${
CMAKE_CURRENT_LIST_DIR
}
/
${
s
}
)
endforeach
()
# For clang-tidy
add_custom_target
(
clang-tidy-check
clang-tidy -p
${
CMAKE_BINARY_DIR
}
/compile_commands.json --config-file
${
CMAKE_SOURCE_DIR
}
/.clang-tidy
${
srcs_to_check
}
DEPENDS
${
sources
}
)
add_custom_target
(
check DEPENDS clang-tidy-check
)
...
...
sherpa-onnx/csrc/audio-tagging-label-file.cc
查看文件 @
a11c859
...
...
@@ -60,7 +60,7 @@ void AudioTaggingLabels::Init(std::istream &is) {
std
::
size_t
pos
{};
int32_t
i
=
std
::
stoi
(
index
,
&
pos
);
if
(
index
.
size
()
==
0
||
pos
!=
index
.
size
())
{
if
(
index
.
empty
()
||
pos
!=
index
.
size
())
{
SHERPA_ONNX_LOGE
(
"Invalid line: %s"
,
line
.
c_str
());
exit
(
-
1
);
}
...
...
sherpa-onnx/csrc/base64-decode.cc
查看文件 @
a11c859
...
...
@@ -34,7 +34,7 @@ std::string Base64Decode(const std::string &s) {
exit
(
-
1
);
}
int32_t
n
=
s
.
size
(
)
/
4
*
3
;
int32_t
n
=
static_cast
<
int32_t
>
(
s
.
size
()
)
/
4
*
3
;
std
::
string
ans
;
ans
.
reserve
(
n
);
...
...
@@ -46,16 +46,16 @@ std::string Base64Decode(const std::string &s) {
}
int32_t
first
=
(
Ord
(
s
[
i
])
<<
2
)
+
((
Ord
(
s
[
i
+
1
])
&
0x30
)
>>
4
);
ans
.
push_back
(
first
);
ans
.
push_back
(
static_cast
<
char
>
(
first
)
);
if
(
i
+
2
<
static_cast
<
int32_t
>
(
s
.
size
())
&&
s
[
i
+
2
]
!=
'='
)
{
int32_t
second
=
((
Ord
(
s
[
i
+
1
])
&
0x0f
)
<<
4
)
+
((
Ord
(
s
[
i
+
2
])
&
0x3c
)
>>
2
);
ans
.
push_back
(
second
);
ans
.
push_back
(
static_cast
<
char
>
(
second
)
);
if
(
i
+
3
<
static_cast
<
int32_t
>
(
s
.
size
())
&&
s
[
i
+
3
]
!=
'='
)
{
int32_t
third
=
((
Ord
(
s
[
i
+
2
])
&
0x03
)
<<
6
)
+
Ord
(
s
[
i
+
3
]);
ans
.
push_back
(
third
);
ans
.
push_back
(
static_cast
<
char
>
(
third
)
);
}
}
i
+=
4
;
...
...
sherpa-onnx/csrc/cat.cc
查看文件 @
a11c859
...
...
@@ -82,9 +82,9 @@ Ort::Value Cat(OrtAllocator *allocator,
T
*
dst
=
ans
.
GetTensorMutableData
<
T
>
();
for
(
int32_t
i
=
0
;
i
!=
leading_size
;
++
i
)
{
for
(
int32_t
n
=
0
;
n
!=
static_cast
<
int32_t
>
(
values
.
size
());
++
n
)
{
auto
this_dim
=
values
[
n
]
->
GetTensorTypeAndShapeInfo
().
GetShape
()[
dim
];
const
T
*
src
=
values
[
n
]
->
GetTensorData
<
T
>
();
for
(
auto
value
:
values
)
{
auto
this_dim
=
value
->
GetTensorTypeAndShapeInfo
().
GetShape
()[
dim
];
const
T
*
src
=
value
->
GetTensorData
<
T
>
();
src
+=
i
*
this_dim
*
trailing_size
;
std
::
copy
(
src
,
src
+
this_dim
*
trailing_size
,
dst
);
...
...
sherpa-onnx/csrc/circular-buffer.cc
查看文件 @
a11c859
...
...
@@ -20,7 +20,7 @@ CircularBuffer::CircularBuffer(int32_t capacity) {
}
void
CircularBuffer
::
Resize
(
int32_t
new_capacity
)
{
int32_t
capacity
=
buffer_
.
size
(
);
int32_t
capacity
=
static_cast
<
int32_t
>
(
buffer_
.
size
()
);
if
(
new_capacity
<=
capacity
)
{
SHERPA_ONNX_LOGE
(
"new_capacity (%d) <= original capacity (%d). Skip it."
,
new_capacity
,
capacity
);
...
...
@@ -86,7 +86,7 @@ void CircularBuffer::Resize(int32_t new_capacity) {
}
void
CircularBuffer
::
Push
(
const
float
*
p
,
int32_t
n
)
{
int32_t
capacity
=
buffer_
.
size
(
);
int32_t
capacity
=
static_cast
<
int32_t
>
(
buffer_
.
size
()
);
int32_t
size
=
Size
();
if
(
n
+
size
>
capacity
)
{
int32_t
new_capacity
=
std
::
max
(
capacity
*
2
,
n
+
size
);
...
...
@@ -126,7 +126,7 @@ std::vector<float> CircularBuffer::Get(int32_t start_index, int32_t n) const {
return
{};
}
int32_t
capacity
=
buffer_
.
size
(
);
int32_t
capacity
=
static_cast
<
int32_t
>
(
buffer_
.
size
()
);
if
(
start_index
-
head_
+
n
>
size
)
{
SHERPA_ONNX_LOGE
(
"Invalid start_index: %d and n: %d. head_: %d, size: %d"
,
...
...
sherpa-onnx/csrc/context-graph.cc
查看文件 @
a11c859
...
...
@@ -67,8 +67,8 @@ void ContextGraph::Build(const std::vector<std::vector<int32_t>> &token_ids,
std
::
tuple
<
float
,
const
ContextState
*
,
const
ContextState
*>
ContextGraph
::
ForwardOneStep
(
const
ContextState
*
state
,
int32_t
token
,
bool
strict_mode
/*= true*/
)
const
{
const
ContextState
*
node
;
float
score
;
const
ContextState
*
node
=
nullptr
;
float
score
=
0
;
if
(
1
==
state
->
next
.
count
(
token
))
{
node
=
state
->
next
.
at
(
token
).
get
();
score
=
node
->
token_score
;
...
...
@@ -84,7 +84,10 @@ ContextGraph::ForwardOneStep(const ContextState *state, int32_t token,
score
=
node
->
node_score
-
state
->
node_score
;
}
SHERPA_ONNX_CHECK
(
nullptr
!=
node
);
if
(
!
node
)
{
SHERPA_ONNX_LOGE
(
"Some bad things happened."
);
exit
(
-
1
);
}
const
ContextState
*
matched_node
=
node
->
is_end
?
node
:
(
node
->
output
!=
nullptr
?
node
->
output
:
nullptr
);
...
...
sherpa-onnx/csrc/endpoint.cc
查看文件 @
a11c859
...
...
@@ -73,10 +73,15 @@ std::string EndpointConfig::ToString() const {
return
os
.
str
();
}
bool
Endpoint
::
IsEndpoint
(
int
num_frames_decoded
,
int
trailing_silence_frames
,
bool
Endpoint
::
IsEndpoint
(
int32_t
num_frames_decoded
,
int32_t
trailing_silence_frames
,
float
frame_shift_in_seconds
)
const
{
float
utterance_length
=
num_frames_decoded
*
frame_shift_in_seconds
;
float
trailing_silence
=
trailing_silence_frames
*
frame_shift_in_seconds
;
float
utterance_length
=
static_cast
<
float
>
(
num_frames_decoded
)
*
frame_shift_in_seconds
;
float
trailing_silence
=
static_cast
<
float
>
(
trailing_silence_frames
)
*
frame_shift_in_seconds
;
if
(
RuleActivated
(
config_
.
rule1
,
"rule1"
,
trailing_silence
,
utterance_length
)
||
RuleActivated
(
config_
.
rule2
,
"rule2"
,
trailing_silence
,
...
...
sherpa-onnx/csrc/endpoint.h
查看文件 @
a11c859
...
...
@@ -64,7 +64,7 @@ class Endpoint {
/// This function returns true if this set of endpointing rules thinks we
/// should terminate decoding.
bool
IsEndpoint
(
int
num_frames_decoded
,
int
trailing_silence_frames
,
bool
IsEndpoint
(
int32_t
num_frames_decoded
,
int32_t
trailing_silence_frames
,
float
frame_shift_in_seconds
)
const
;
private
:
...
...
sherpa-onnx/csrc/jieba-lexicon.cc
查看文件 @
a11c859
...
...
@@ -103,6 +103,7 @@ class JiebaLexicon::Impl {
if
(
w
==
"。"
||
w
==
"!"
||
w
==
"?"
||
w
==
","
)
{
ans
.
push_back
(
std
::
move
(
this_sentence
));
this_sentence
=
{};
}
}
// for (const auto &w : words)
...
...
sherpa-onnx/csrc/keyword-spotter.cc
查看文件 @
a11c859
...
...
@@ -4,9 +4,8 @@
#include "sherpa-onnx/csrc/keyword-spotter.h"
#include <assert.h>
#include <algorithm>
#include <cassert>
#include <fstream>
#include <iomanip>
#include <memory>
...
...
sherpa-onnx/csrc/lexicon.cc
查看文件 @
a11c859
...
...
@@ -82,7 +82,7 @@ std::unordered_map<std::string, int32_t> ReadTokens(std::istream &is) {
std
::
string
line
;
std
::
string
sym
;
int32_t
id
;
int32_t
id
=
-
1
;
while
(
std
::
getline
(
is
,
line
))
{
std
::
istringstream
iss
(
line
);
iss
>>
sym
;
...
...
@@ -254,6 +254,7 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsChinese(
this_sentence
.
push_back
(
eos
);
}
ans
.
push_back
(
std
::
move
(
this_sentence
));
this_sentence
=
{};
if
(
sil
!=
-
1
)
{
this_sentence
.
push_back
(
sil
);
...
...
@@ -324,6 +325,7 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsNotChinese(
if
(
w
!=
","
)
{
this_sentence
.
push_back
(
blank
);
ans
.
push_back
(
std
::
move
(
this_sentence
));
this_sentence
=
{};
}
continue
;
...
...
sherpa-onnx/csrc/lexicon.h
查看文件 @
a11c859
...
...
@@ -62,8 +62,8 @@ class Lexicon : public OfflineTtsFrontend {
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
int32_t
>>
word2ids_
;
std
::
unordered_set
<
std
::
string
>
punctuations_
;
std
::
unordered_map
<
std
::
string
,
int32_t
>
token2id_
;
Language
language_
;
bool
debug_
;
Language
language_
=
Language
::
kUnknown
;
bool
debug_
=
false
;
};
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/offline-ct-transformer-model.cc
查看文件 @
a11c859
...
...
@@ -67,7 +67,7 @@ class OfflineCtTransformerModel::Impl {
std
::
vector
<
std
::
string
>
tokens
;
SHERPA_ONNX_READ_META_DATA_VEC_STRING_SEP
(
tokens
,
"tokens"
,
"|"
);
int32_t
vocab_size
;
int32_t
vocab_size
=
0
;
SHERPA_ONNX_READ_META_DATA
(
vocab_size
,
"vocab_size"
);
if
(
static_cast
<
int32_t
>
(
tokens
.
size
())
!=
vocab_size
)
{
SHERPA_ONNX_LOGE
(
"tokens.size() %d != vocab_size %d"
,
...
...
sherpa-onnx/csrc/offline-ctc-model.cc
查看文件 @
a11c859
...
...
@@ -19,7 +19,7 @@
namespace
{
enum
class
ModelType
{
enum
class
ModelType
:
std
::
uint8_t
{
kEncDecCTCModelBPE
,
kEncDecHybridRNNTCTCBPEModel
,
kTdnn
,
...
...
sherpa-onnx/csrc/offline-stream.cc
查看文件 @
a11c859
...
...
@@ -4,11 +4,11 @@
#include "sherpa-onnx/csrc/offline-stream.h"
#include <assert.h>
#include <algorithm>
#include <cassert>
#include <cmath>
#include <iomanip>
#include <utility>
#include "kaldi-native-fbank/csrc/online-feature.h"
#include "sherpa-onnx/csrc/macros.h"
...
...
@@ -56,7 +56,7 @@ class OfflineStream::Impl {
public
:
explicit
Impl
(
const
FeatureExtractorConfig
&
config
,
ContextGraphPtr
context_graph
)
:
config_
(
config
),
context_graph_
(
context_graph
)
{
:
config_
(
config
),
context_graph_
(
std
::
move
(
context_graph
)
)
{
if
(
config
.
is_mfcc
)
{
mfcc_opts_
.
frame_opts
.
dither
=
config_
.
dither
;
mfcc_opts_
.
frame_opts
.
snip_edges
=
config_
.
snip_edges
;
...
...
@@ -266,7 +266,7 @@ class OfflineStream::Impl {
OfflineStream
::
OfflineStream
(
const
FeatureExtractorConfig
&
config
/*= {}*/
,
ContextGraphPtr
context_graph
/*= nullptr*/
)
:
impl_
(
std
::
make_unique
<
Impl
>
(
config
,
context_graph
))
{}
:
impl_
(
std
::
make_unique
<
Impl
>
(
config
,
std
::
move
(
context_graph
)
))
{}
OfflineStream
::
OfflineStream
(
WhisperTag
tag
)
:
impl_
(
std
::
make_unique
<
Impl
>
(
tag
))
{}
...
...
sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.cc
查看文件 @
a11c859
...
...
@@ -42,7 +42,7 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
std
::
vector
<
ContextGraphPtr
>
context_graphs
(
batch_size
,
nullptr
);
for
(
int32_t
i
=
0
;
i
<
batch_size
;
++
i
)
{
const
ContextState
*
context_state
;
const
ContextState
*
context_state
=
nullptr
;
if
(
ss
!=
nullptr
)
{
context_graphs
[
i
]
=
ss
[
packed_encoder_out
.
sorted_indexes
[
i
]]
->
GetContextGraph
();
...
...
sherpa-onnx/csrc/offline-tts-character-frontend.cc
查看文件 @
a11c859
...
...
@@ -30,7 +30,7 @@ static std::unordered_map<char32_t, int32_t> ReadTokens(std::istream &is) {
std
::
string
sym
;
std
::
u32string
s
;
int32_t
id
;
int32_t
id
=
0
;
while
(
std
::
getline
(
is
,
line
))
{
std
::
istringstream
iss
(
line
);
iss
>>
sym
;
...
...
@@ -138,6 +138,7 @@ OfflineTtsCharacterFrontend::ConvertTextToTokenIds(
}
ans
.
push_back
(
std
::
move
(
this_sentence
));
this_sentence
=
{};
// re-initialize this_sentence
if
(
use_eos_bos
)
{
...
...
@@ -172,6 +173,7 @@ OfflineTtsCharacterFrontend::ConvertTextToTokenIds(
}
ans
.
push_back
(
std
::
move
(
this_sentence
));
this_sentence
=
{};
// re-initialize this_sentence
if
(
use_eos_bos
)
{
...
...
sherpa-onnx/csrc/offline-tts.cc
查看文件 @
a11c859
...
...
@@ -5,6 +5,7 @@
#include "sherpa-onnx/csrc/offline-tts.h"
#include <string>
#include <utility>
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
...
...
@@ -87,7 +88,7 @@ OfflineTts::~OfflineTts() = default;
GeneratedAudio
OfflineTts
::
Generate
(
const
std
::
string
&
text
,
int64_t
sid
/*=0*/
,
float
speed
/*= 1.0*/
,
GeneratedAudioCallback
callback
/*= nullptr*/
)
const
{
return
impl_
->
Generate
(
text
,
sid
,
speed
,
callback
);
return
impl_
->
Generate
(
text
,
sid
,
speed
,
std
::
move
(
callback
)
);
}
int32_t
OfflineTts
::
SampleRate
()
const
{
return
impl_
->
SampleRate
();
}
...
...
sherpa-onnx/csrc/offline-whisper-model.cc
查看文件 @
a11c859
...
...
@@ -22,9 +22,9 @@ class OfflineWhisperModel::Impl {
explicit
Impl
(
const
OfflineModelConfig
&
config
)
:
config_
(
config
),
env_
(
ORT_LOGGING_LEVEL_ERROR
),
debug_
(
config
.
debug
),
sess_opts_
(
GetSessionOptions
(
config
)),
allocator_
{}
{
debug_
=
config_
.
debug
;
{
auto
buf
=
ReadFile
(
config
.
whisper
.
encoder
);
InitEncoder
(
buf
.
data
(),
buf
.
size
());
...
...
@@ -39,9 +39,9 @@ class OfflineWhisperModel::Impl {
explicit
Impl
(
const
SpokenLanguageIdentificationConfig
&
config
)
:
lid_config_
(
config
),
env_
(
ORT_LOGGING_LEVEL_ERROR
),
debug_
(
config_
.
debug
),
sess_opts_
(
GetSessionOptions
(
config
)),
allocator_
{}
{
debug_
=
config_
.
debug
;
{
auto
buf
=
ReadFile
(
config
.
whisper
.
encoder
);
InitEncoder
(
buf
.
data
(),
buf
.
size
());
...
...
@@ -148,7 +148,6 @@ class OfflineWhisperModel::Impl {
cross_v
=
std
::
move
(
std
::
get
<
4
>
(
decoder_out
));
const
float
*
p_logits
=
std
::
get
<
0
>
(
decoder_out
).
GetTensorData
<
float
>
();
int32_t
vocab_size
=
VocabSize
();
const
auto
&
all_language_ids
=
GetAllLanguageIDs
();
int32_t
lang_id
=
all_language_ids
[
0
];
...
...
@@ -317,18 +316,18 @@ class OfflineWhisperModel::Impl {
std
::
unordered_map
<
int32_t
,
std
::
string
>
id2lang_
;
// model meta data
int32_t
n_text_layer_
;
int32_t
n_text_ctx_
;
int32_t
n_text_state_
;
int32_t
n_vocab_
;
int32_t
sot_
;
int32_t
eot_
;
int32_t
blank_
;
int32_t
translate_
;
int32_t
transcribe_
;
int32_t
no_timestamps_
;
int32_t
no_speech_
;
int32_t
is_multilingual_
;
int32_t
n_text_layer_
=
0
;
int32_t
n_text_ctx_
=
0
;
int32_t
n_text_state_
=
0
;
int32_t
n_vocab_
=
0
;
int32_t
sot_
=
0
;
int32_t
eot_
=
0
;
int32_t
blank_
=
0
;
int32_t
translate_
=
0
;
int32_t
transcribe_
=
0
;
int32_t
no_timestamps_
=
0
;
int32_t
no_speech_
=
0
;
int32_t
is_multilingual_
=
0
;
std
::
vector
<
int64_t
>
sot_sequence_
;
};
...
...
sherpa-onnx/csrc/online-conformer-transducer-model.cc
查看文件 @
a11c859
...
...
@@ -4,9 +4,8 @@
#include "sherpa-onnx/csrc/online-conformer-transducer-model.h"
#include <assert.h>
#include <algorithm>
#include <cassert>
#include <memory>
#include <sstream>
#include <string>
...
...
sherpa-onnx/csrc/online-ctc-fst-decoder.cc
查看文件 @
a11c859
...
...
@@ -52,8 +52,9 @@ static void DecodeOne(const float *log_probs, int32_t num_rows,
if
(
ok
)
{
std
::
vector
<
int32_t
>
isymbols_out
;
std
::
vector
<
int32_t
>
osymbols_out
;
ok
=
fst
::
GetLinearSymbolSequence
(
fst_out
,
&
isymbols_out
,
&
osymbols_out
,
nullptr
);
/*ok =*/
fst
::
GetLinearSymbolSequence
(
fst_out
,
&
isymbols_out
,
&
osymbols_out
,
nullptr
);
// TODO(fangjun): handle ok is false
std
::
vector
<
int64_t
>
tokens
;
tokens
.
reserve
(
isymbols_out
.
size
());
...
...
sherpa-onnx/csrc/online-lstm-transducer-model.cc
查看文件 @
a11c859
...
...
@@ -3,9 +3,8 @@
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/online-lstm-transducer-model.h"
#include <assert.h>
#include <algorithm>
#include <cassert>
#include <memory>
#include <sstream>
#include <string>
...
...
sherpa-onnx/csrc/online-nemo-ctc-model.cc
查看文件 @
a11c859
...
...
@@ -265,16 +265,16 @@ class OnlineNeMoCtcModel::Impl {
std
::
vector
<
std
::
string
>
output_names_
;
std
::
vector
<
const
char
*>
output_names_ptr_
;
int32_t
window_size_
;
int32_t
chunk_shift_
;
int32_t
subsampling_factor_
;
int32_t
vocab_size_
;
int32_t
cache_last_channel_dim1_
;
int32_t
cache_last_channel_dim2_
;
int32_t
cache_last_channel_dim3_
;
int32_t
cache_last_time_dim1_
;
int32_t
cache_last_time_dim2_
;
int32_t
cache_last_time_dim3_
;
int32_t
window_size_
=
0
;
int32_t
chunk_shift_
=
0
;
int32_t
subsampling_factor_
=
0
;
int32_t
vocab_size_
=
0
;
int32_t
cache_last_channel_dim1_
=
0
;
int32_t
cache_last_channel_dim2_
=
0
;
int32_t
cache_last_channel_dim3_
=
0
;
int32_t
cache_last_time_dim1_
=
0
;
int32_t
cache_last_time_dim2_
=
0
;
int32_t
cache_last_time_dim3_
=
0
;
Ort
::
Value
cache_last_channel_
{
nullptr
};
Ort
::
Value
cache_last_time_
{
nullptr
};
...
...
sherpa-onnx/csrc/online-recognizer.cc
查看文件 @
a11c859
...
...
@@ -5,9 +5,8 @@
#include "sherpa-onnx/csrc/online-recognizer.h"
#include <assert.h>
#include <algorithm>
#include <cassert>
#include <iomanip>
#include <memory>
#include <sstream>
...
...
sherpa-onnx/csrc/online-stream.cc
查看文件 @
a11c859
...
...
@@ -8,6 +8,7 @@
#include <vector>
#include "sherpa-onnx/csrc/features.h"
#include "sherpa-onnx/csrc/transducer-keyword-decoder.h"
namespace
sherpa_onnx
{
...
...
@@ -15,7 +16,7 @@ class OnlineStream::Impl {
public
:
explicit
Impl
(
const
FeatureExtractorConfig
&
config
,
ContextGraphPtr
context_graph
)
:
feat_extractor_
(
config
),
context_graph_
(
context_graph
)
{}
:
feat_extractor_
(
config
),
context_graph_
(
std
::
move
(
context_graph
)
)
{}
void
AcceptWaveform
(
int32_t
sampling_rate
,
const
float
*
waveform
,
int32_t
n
)
{
feat_extractor_
.
AcceptWaveform
(
sampling_rate
,
waveform
,
n
);
...
...
@@ -146,7 +147,7 @@ class OnlineStream::Impl {
OnlineStream
::
OnlineStream
(
const
FeatureExtractorConfig
&
config
/*= {}*/
,
ContextGraphPtr
context_graph
/*= nullptr */
)
:
impl_
(
std
::
make_unique
<
Impl
>
(
config
,
context_graph
))
{}
:
impl_
(
std
::
make_unique
<
Impl
>
(
config
,
std
::
move
(
context_graph
)
))
{}
OnlineStream
::~
OnlineStream
()
=
default
;
...
...
sherpa-onnx/csrc/online-stream.h
查看文件 @
a11c859
...
...
@@ -15,7 +15,6 @@
#include "sherpa-onnx/csrc/online-ctc-decoder.h"
#include "sherpa-onnx/csrc/online-paraformer-decoder.h"
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
#include "sherpa-onnx/csrc/transducer-keyword-decoder.h"
namespace
sherpa_onnx
{
...
...
sherpa-onnx/csrc/online-transducer-decoder.cc
查看文件 @
a11c859
...
...
@@ -45,13 +45,13 @@ OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=(
}
OnlineTransducerDecoderResult
::
OnlineTransducerDecoderResult
(
OnlineTransducerDecoderResult
&&
other
)
OnlineTransducerDecoderResult
&&
other
)
noexcept
:
OnlineTransducerDecoderResult
()
{
*
this
=
std
::
move
(
other
);
}
OnlineTransducerDecoderResult
&
OnlineTransducerDecoderResult
::
operator
=
(
OnlineTransducerDecoderResult
&&
other
)
{
OnlineTransducerDecoderResult
&&
other
)
noexcept
{
if
(
this
==
&
other
)
{
return
*
this
;
}
...
...
sherpa-onnx/csrc/online-transducer-decoder.h
查看文件 @
a11c859
...
...
@@ -44,10 +44,10 @@ struct OnlineTransducerDecoderResult {
OnlineTransducerDecoderResult
&
operator
=
(
const
OnlineTransducerDecoderResult
&
other
);
OnlineTransducerDecoderResult
(
OnlineTransducerDecoderResult
&&
other
);
OnlineTransducerDecoderResult
(
OnlineTransducerDecoderResult
&&
other
)
noexcept
;
OnlineTransducerDecoderResult
&
operator
=
(
OnlineTransducerDecoderResult
&&
other
);
OnlineTransducerDecoderResult
&&
other
)
noexcept
;
};
class
OnlineStream
;
...
...
sherpa-onnx/csrc/online-transducer-model.cc
查看文件 @
a11c859
...
...
@@ -23,7 +23,7 @@
namespace
{
enum
class
ModelType
{
enum
class
ModelType
:
std
::
uint8_t
{
kConformer
,
kLstm
,
kZipformer
,
...
...
sherpa-onnx/csrc/online-transducer-nemo-model.cc
查看文件 @
a11c859
...
...
@@ -5,10 +5,9 @@
#include "sherpa-onnx/csrc/online-transducer-nemo-model.h"
#include <assert.h>
#include <math.h>
#include <algorithm>
#include <cassert>
#include <cmath>
#include <memory>
#include <numeric>
#include <sstream>
...
...
@@ -429,8 +428,8 @@ class OnlineTransducerNeMoModel::Impl {
std
::
vector
<
std
::
string
>
joiner_output_names_
;
std
::
vector
<
const
char
*>
joiner_output_names_ptr_
;
int32_t
window_size_
;
int32_t
chunk_shift_
;
int32_t
window_size_
=
0
;
int32_t
chunk_shift_
=
0
;
int32_t
vocab_size_
=
0
;
int32_t
subsampling_factor_
=
8
;
std
::
string
normalize_type_
;
...
...
@@ -438,12 +437,12 @@ class OnlineTransducerNeMoModel::Impl {
int32_t
pred_hidden_
=
-
1
;
// encoder states
int32_t
cache_last_channel_dim1_
;
int32_t
cache_last_channel_dim2_
;
int32_t
cache_last_channel_dim3_
;
int32_t
cache_last_time_dim1_
;
int32_t
cache_last_time_dim2_
;
int32_t
cache_last_time_dim3_
;
int32_t
cache_last_channel_dim1_
=
0
;
int32_t
cache_last_channel_dim2_
=
0
;
int32_t
cache_last_channel_dim3_
=
0
;
int32_t
cache_last_time_dim1_
=
0
;
int32_t
cache_last_time_dim2_
=
0
;
int32_t
cache_last_time_dim3_
=
0
;
// init encoder states
Ort
::
Value
cache_last_channel_
{
nullptr
};
...
...
sherpa-onnx/csrc/online-wenet-ctc-model.cc
查看文件 @
a11c859
...
...
@@ -192,15 +192,15 @@ class OnlineWenetCtcModel::Impl {
std
::
vector
<
std
::
string
>
output_names_
;
std
::
vector
<
const
char
*>
output_names_ptr_
;
int32_t
head_
;
int32_t
num_blocks_
;
int32_t
output_size_
;
int32_t
cnn_module_kernel_
;
int32_t
right_context_
;
int32_t
subsampling_factor_
;
int32_t
vocab_size_
;
int32_t
required_cache_size_
;
int32_t
head_
=
0
;
int32_t
num_blocks_
=
0
;
int32_t
output_size_
=
0
;
int32_t
cnn_module_kernel_
=
0
;
int32_t
right_context_
=
0
;
int32_t
subsampling_factor_
=
0
;
int32_t
vocab_size_
=
0
;
int32_t
required_cache_size_
=
0
;
Ort
::
Value
attn_cache_
{
nullptr
};
Ort
::
Value
conv_cache_
{
nullptr
};
...
...
sherpa-onnx/csrc/online-zipformer-transducer-model.cc
查看文件 @
a11c859
...
...
@@ -4,9 +4,8 @@
#include "sherpa-onnx/csrc/online-zipformer-transducer-model.h"
#include <assert.h>
#include <algorithm>
#include <cassert>
#include <memory>
#include <sstream>
#include <string>
...
...
sherpa-onnx/csrc/online-zipformer2-ctc-model.cc
查看文件 @
a11c859
...
...
@@ -4,10 +4,8 @@
#include "sherpa-onnx/csrc/online-zipformer2-ctc-model.h"
#include <assert.h>
#include <math.h>
#include <algorithm>
#include <cassert>
#include <cmath>
#include <numeric>
#include <string>
...
...
@@ -90,7 +88,6 @@ class OnlineZipformer2CtcModel::Impl {
std
::
vector
<
Ort
::
Value
>
StackStates
(
std
::
vector
<
std
::
vector
<
Ort
::
Value
>>
states
)
const
{
int32_t
batch_size
=
static_cast
<
int32_t
>
(
states
.
size
());
int32_t
num_encoders
=
static_cast
<
int32_t
>
(
num_encoder_layers_
.
size
());
std
::
vector
<
const
Ort
::
Value
*>
buf
(
batch_size
);
...
...
@@ -168,7 +165,6 @@ class OnlineZipformer2CtcModel::Impl {
assert
(
states
.
size
()
==
m
*
6
+
2
);
int32_t
batch_size
=
states
[
0
].
GetTensorTypeAndShapeInfo
().
GetShape
()[
1
];
int32_t
num_encoders
=
num_encoder_layers_
.
size
();
std
::
vector
<
std
::
vector
<
Ort
::
Value
>>
ans
;
ans
.
resize
(
batch_size
);
...
...
sherpa-onnx/csrc/online-zipformer2-transducer-model.cc
查看文件 @
a11c859
...
...
@@ -4,10 +4,9 @@
#include "sherpa-onnx/csrc/online-zipformer2-transducer-model.h"
#include <assert.h>
#include <math.h>
#include <algorithm>
#include <cassert>
#include <cmath>
#include <memory>
#include <numeric>
#include <sstream>
...
...
sherpa-onnx/csrc/onnx-utils.cc
查看文件 @
a11c859
...
...
@@ -281,11 +281,12 @@ CopyableOrtValue &CopyableOrtValue::operator=(const CopyableOrtValue &other) {
return
*
this
;
}
CopyableOrtValue
::
CopyableOrtValue
(
CopyableOrtValue
&&
other
)
{
CopyableOrtValue
::
CopyableOrtValue
(
CopyableOrtValue
&&
other
)
noexcept
{
*
this
=
std
::
move
(
other
);
}
CopyableOrtValue
&
CopyableOrtValue
::
operator
=
(
CopyableOrtValue
&&
other
)
{
CopyableOrtValue
&
CopyableOrtValue
::
operator
=
(
CopyableOrtValue
&&
other
)
noexcept
{
if
(
this
==
&
other
)
{
return
*
this
;
}
...
...
sherpa-onnx/csrc/onnx-utils.h
查看文件 @
a11c859
...
...
@@ -110,9 +110,9 @@ struct CopyableOrtValue {
CopyableOrtValue
&
operator
=
(
const
CopyableOrtValue
&
other
);
CopyableOrtValue
(
CopyableOrtValue
&&
other
);
CopyableOrtValue
(
CopyableOrtValue
&&
other
)
noexcept
;
CopyableOrtValue
&
operator
=
(
CopyableOrtValue
&&
other
);
CopyableOrtValue
&
operator
=
(
CopyableOrtValue
&&
other
)
noexcept
;
};
std
::
vector
<
CopyableOrtValue
>
Convert
(
std
::
vector
<
Ort
::
Value
>
values
);
...
...
sherpa-onnx/csrc/packed-sequence.cc
查看文件 @
a11c859
...
...
@@ -4,9 +4,8 @@
#include "sherpa-onnx/csrc/packed-sequence.h"
#include <assert.h>
#include <algorithm>
#include <cassert>
#include <numeric>
#include <utility>
...
...
@@ -57,7 +56,7 @@ PackedSequence PackPaddedSequence(OrtAllocator *allocator,
int64_t
max_T
=
p_length
[
indexes
[
0
]];
int32_t
sum_T
=
std
::
accumulate
(
p_length
,
p_length
+
n
,
0
);
auto
sum_T
=
std
::
accumulate
(
p_length
,
p_length
+
n
,
static_cast
<
int64_t
>
(
0
)
);
std
::
array
<
int64_t
,
2
>
data_shape
{
sum_T
,
v_shape
[
2
]};
...
...
sherpa-onnx/csrc/pad-sequence.cc
查看文件 @
a11c859
...
...
@@ -4,9 +4,8 @@
#include "sherpa-onnx/csrc/pad-sequence.h"
#include <assert.h>
#include <algorithm>
#include <cassert>
#include <vector>
namespace
sherpa_onnx
{
...
...
sherpa-onnx/csrc/parse-options.cc
查看文件 @
a11c859
...
...
@@ -11,9 +11,8 @@
#include "sherpa-onnx/csrc/parse-options.h"
#include <ctype.h>
#include <algorithm>
#include <array>
#include <cctype>
#include <cstring>
#include <fstream>
...
...
@@ -33,7 +32,7 @@ ParseOptions::ParseOptions(const std::string &prefix, ParseOptions *po)
}
else
{
other_parser_
=
po
;
}
if
(
po
!=
nullptr
&&
po
->
prefix_
!=
""
)
{
if
(
po
!=
nullptr
&&
!
po
->
prefix_
.
empty
()
)
{
prefix_
=
po
->
prefix_
+
std
::
string
(
"."
)
+
prefix
;
}
else
{
prefix_
=
prefix
;
...
...
@@ -179,10 +178,10 @@ void ParseOptions::DisableOption(const std::string &name) {
string_map_
.
erase
(
name
);
}
int
ParseOptions
::
NumArgs
()
const
{
return
positional_args_
.
size
();
}
int
32_t
ParseOptions
::
NumArgs
()
const
{
return
positional_args_
.
size
();
}
std
::
string
ParseOptions
::
GetArg
(
int
i
)
const
{
if
(
i
<
1
||
i
>
static_cast
<
int
>
(
positional_args_
.
size
()))
{
std
::
string
ParseOptions
::
GetArg
(
int32_t
i
)
const
{
if
(
i
<
1
||
i
>
static_cast
<
int32_t
>
(
positional_args_
.
size
()))
{
SHERPA_ONNX_LOGE
(
"ParseOptions::GetArg, invalid index %d"
,
i
);
exit
(
-
1
);
}
...
...
@@ -191,7 +190,7 @@ std::string ParseOptions::GetArg(int i) const {
}
// We currently do not support any other options.
enum
ShellType
{
kBash
=
0
};
enum
ShellType
:
std
::
uint8_t
{
kBash
=
0
};
// This can be changed in the code if it ever does need to be changed (as it's
// unlikely that one compilation of this tool-set would use both shells).
...
...
@@ -213,7 +212,7 @@ static bool MustBeQuoted(const std::string &str, ShellType st) {
if
(
*
c
==
'\0'
)
{
return
true
;
// Must quote empty string
}
else
{
const
char
*
ok_chars
[
2
]
;
std
::
array
<
const
char
*
,
2
>
ok_chars
{}
;
// These seem not to be interpreted as long as there are no other "bad"
// characters involved (e.g. "," would be interpreted as part of something
...
...
@@ -229,7 +228,7 @@ static bool MustBeQuoted(const std::string &str, ShellType st) {
// are OK. All others are forbidden (this is easier since the shell
// interprets most non-alphanumeric characters).
if
(
!
isalnum
(
*
c
))
{
const
char
*
d
;
const
char
*
d
=
nullptr
;
for
(
d
=
ok_chars
[
st
];
*
d
!=
'\0'
;
++
d
)
{
if
(
*
c
==
*
d
)
break
;
}
...
...
@@ -269,22 +268,22 @@ static std::string QuoteAndEscape(const std::string &str, ShellType /*st*/) {
escape_str
=
"
\\\"
"
;
// should never be accessed.
}
char
buf
[
2
]
;
std
::
array
<
char
,
2
>
buf
{}
;
buf
[
1
]
=
'\0'
;
buf
[
0
]
=
quote_char
;
std
::
string
ans
=
buf
;
std
::
string
ans
=
buf
.
data
()
;
const
char
*
c
=
str
.
c_str
();
for
(;
*
c
!=
'\0'
;
++
c
)
{
if
(
*
c
==
quote_char
)
{
ans
+=
escape_str
;
}
else
{
buf
[
0
]
=
*
c
;
ans
+=
buf
;
ans
+=
buf
.
data
()
;
}
}
buf
[
0
]
=
quote_char
;
ans
+=
buf
;
ans
+=
buf
.
data
()
;
return
ans
;
}
...
...
@@ -293,11 +292,11 @@ std::string ParseOptions::Escape(const std::string &str) {
return
MustBeQuoted
(
str
,
kShellType
)
?
QuoteAndEscape
(
str
,
kShellType
)
:
str
;
}
int
ParseOptions
::
Read
(
int
argc
,
const
char
*
const
argv
[]
)
{
int
32_t
ParseOptions
::
Read
(
int32_t
argc
,
const
char
*
const
*
argv
)
{
argc_
=
argc
;
argv_
=
argv
;
std
::
string
key
,
value
;
int
i
;
int
32_t
i
=
0
;
// first pass: look for config parameter, look for priority
for
(
i
=
1
;
i
<
argc
;
++
i
)
{
...
...
@@ -306,13 +305,13 @@ int ParseOptions::Read(int argc, const char *const argv[]) {
// a lone "--" marks the end of named options
break
;
}
bool
has_equal_sign
;
bool
has_equal_sign
=
false
;
SplitLongArg
(
argv
[
i
],
&
key
,
&
value
,
&
has_equal_sign
);
NormalizeArgName
(
&
key
);
Trim
(
&
value
);
if
(
key
.
compare
(
"config"
)
==
0
)
{
if
(
key
==
"config"
)
{
ReadConfigFile
(
value
);
}
else
if
(
key
.
compare
(
"help"
)
==
0
)
{
}
else
if
(
key
==
"help"
)
{
PrintUsage
();
exit
(
0
);
}
...
...
@@ -330,7 +329,7 @@ int ParseOptions::Read(int argc, const char *const argv[]) {
double_dash_seen
=
true
;
break
;
}
bool
has_equal_sign
;
bool
has_equal_sign
=
false
;
SplitLongArg
(
argv
[
i
],
&
key
,
&
value
,
&
has_equal_sign
);
NormalizeArgName
(
&
key
);
Trim
(
&
value
);
...
...
@@ -349,14 +348,14 @@ int ParseOptions::Read(int argc, const char *const argv[]) {
if
((
std
::
strcmp
(
argv
[
i
],
"--"
)
==
0
)
&&
!
double_dash_seen
)
{
double_dash_seen
=
true
;
}
else
{
positional_args_
.
push_back
(
std
::
string
(
argv
[
i
])
);
positional_args_
.
emplace_back
(
argv
[
i
]
);
}
}
// if the user did not suppress this with --print-args = false....
if
(
print_args_
)
{
std
::
ostringstream
strm
;
for
(
int
j
=
0
;
j
<
argc
;
++
j
)
strm
<<
Escape
(
argv
[
j
])
<<
" "
;
for
(
int
32_t
j
=
0
;
j
<
argc
;
++
j
)
strm
<<
Escape
(
argv
[
j
])
<<
" "
;
strm
<<
'\n'
;
SHERPA_ONNX_LOGE
(
"%s"
,
strm
.
str
().
c_str
());
}
...
...
@@ -368,14 +367,14 @@ void ParseOptions::PrintUsage(bool print_command_line /*=false*/) const {
os
<<
'\n'
<<
usage_
<<
'\n'
;
// first we print application-specific options
bool
app_specific_header_printed
=
false
;
for
(
auto
it
=
doc_map_
.
begin
();
it
!=
doc_map_
.
end
();
++
it
)
{
if
(
it
->
second
.
is_standard_
==
false
)
{
// application-specific option
for
(
const
auto
&
it
:
doc_map_
)
{
if
(
it
.
second
.
is_standard_
==
false
)
{
// application-specific option
if
(
app_specific_header_printed
==
false
)
{
// header was not yet printed
os
<<
"Options:"
<<
'\n'
;
app_specific_header_printed
=
true
;
}
os
<<
" --"
<<
std
::
setw
(
25
)
<<
std
::
left
<<
it
->
second
.
name_
<<
" : "
<<
it
->
second
.
use_msg_
<<
'\n'
;
os
<<
" --"
<<
std
::
setw
(
25
)
<<
std
::
left
<<
it
.
second
.
name_
<<
" : "
<<
it
.
second
.
use_msg_
<<
'\n'
;
}
}
if
(
app_specific_header_printed
==
true
)
{
...
...
@@ -384,17 +383,17 @@ void ParseOptions::PrintUsage(bool print_command_line /*=false*/) const {
// then the standard options
os
<<
"Standard options:"
<<
'\n'
;
for
(
auto
it
=
doc_map_
.
begin
();
it
!=
doc_map_
.
end
();
++
it
)
{
if
(
it
->
second
.
is_standard_
==
true
)
{
// we have standard option
os
<<
" --"
<<
std
::
setw
(
25
)
<<
std
::
left
<<
it
->
second
.
name_
<<
" : "
<<
it
->
second
.
use_msg_
<<
'\n'
;
for
(
const
auto
&
it
:
doc_map_
)
{
if
(
it
.
second
.
is_standard_
==
true
)
{
// we have standard option
os
<<
" --"
<<
std
::
setw
(
25
)
<<
std
::
left
<<
it
.
second
.
name_
<<
" : "
<<
it
.
second
.
use_msg_
<<
'\n'
;
}
}
os
<<
'\n'
;
if
(
print_command_line
)
{
std
::
ostringstream
strm
;
strm
<<
"Command line was: "
;
for
(
int
j
=
0
;
j
<
argc_
;
++
j
)
strm
<<
Escape
(
argv_
[
j
])
<<
" "
;
for
(
int
32_t
j
=
0
;
j
<
argc_
;
++
j
)
strm
<<
Escape
(
argv_
[
j
])
<<
" "
;
strm
<<
'\n'
;
os
<<
strm
.
str
();
}
...
...
@@ -405,9 +404,9 @@ void ParseOptions::PrintUsage(bool print_command_line /*=false*/) const {
void
ParseOptions
::
PrintConfig
(
std
::
ostream
&
os
)
const
{
os
<<
'\n'
<<
"[[ Configuration of UI-Registered options ]]"
<<
'\n'
;
std
::
string
key
;
for
(
auto
it
=
doc_map_
.
begin
();
it
!=
doc_map_
.
end
();
++
it
)
{
key
=
it
->
first
;
os
<<
it
->
second
.
name_
<<
" = "
;
for
(
const
auto
&
it
:
doc_map_
)
{
key
=
it
.
first
;
os
<<
it
.
second
.
name_
<<
" = "
;
if
(
bool_map_
.
end
()
!=
bool_map_
.
find
(
key
))
{
os
<<
(
*
bool_map_
.
at
(
key
)
?
"true"
:
"false"
);
}
else
if
(
int_map_
.
end
()
!=
int_map_
.
find
(
key
))
{
...
...
@@ -442,13 +441,13 @@ void ParseOptions::ReadConfigFile(const std::string &filename) {
while
(
std
::
getline
(
is
,
line
))
{
++
line_number
;
// trim out the comments
size_t
pos
;
if
((
pos
=
line
.
find_first_of
(
'#'
))
!=
std
::
string
::
npos
)
{
size_t
pos
=
line
.
find_first_of
(
'#'
);
if
(
pos
!=
std
::
string
::
npos
)
{
line
.
erase
(
pos
);
}
// skip empty lines
Trim
(
&
line
);
if
(
line
.
length
()
==
0
)
continue
;
if
(
line
.
empty
()
)
continue
;
if
(
line
.
substr
(
0
,
2
)
!=
"--"
)
{
SHERPA_ONNX_LOGE
(
...
...
@@ -461,7 +460,7 @@ void ParseOptions::ReadConfigFile(const std::string &filename) {
}
// parse option
bool
has_equal_sign
;
bool
has_equal_sign
=
false
;
SplitLongArg
(
line
,
&
key
,
&
value
,
&
has_equal_sign
);
NormalizeArgName
(
&
key
);
Trim
(
&
value
);
...
...
@@ -527,7 +526,7 @@ void ParseOptions::Trim(std::string *str) const {
bool
ParseOptions
::
SetOption
(
const
std
::
string
&
key
,
const
std
::
string
&
value
,
bool
has_equal_sign
)
{
if
(
bool_map_
.
end
()
!=
bool_map_
.
find
(
key
))
{
if
(
has_equal_sign
&&
value
==
""
)
{
if
(
has_equal_sign
&&
value
.
empty
()
)
{
SHERPA_ONNX_LOGE
(
"Invalid option --%s="
,
key
.
c_str
());
exit
(
-
1
);
}
...
...
@@ -557,12 +556,10 @@ bool ParseOptions::ToBool(std::string str) const {
std
::
transform
(
str
.
begin
(),
str
.
end
(),
str
.
begin
(),
::
tolower
);
// allow "" as a valid option for "true", so that --x is the same as --x=true
if
((
str
.
compare
(
"true"
)
==
0
)
||
(
str
.
compare
(
"t"
)
==
0
)
||
(
str
.
compare
(
"1"
)
==
0
)
||
(
str
.
compare
(
""
)
==
0
))
{
if
(
str
==
"true"
||
str
==
"t"
||
str
==
"1"
||
str
.
empty
())
{
return
true
;
}
if
((
str
.
compare
(
"false"
)
==
0
)
||
(
str
.
compare
(
"f"
)
==
0
)
||
(
str
.
compare
(
"0"
)
==
0
))
{
if
(
str
==
"false"
||
str
==
"f"
||
str
==
"0"
)
{
return
false
;
}
// if it is neither true nor false:
...
...
@@ -593,7 +590,7 @@ uint32_t ParseOptions::ToUint(const std::string &str) const {
}
float
ParseOptions
::
ToFloat
(
const
std
::
string
&
str
)
const
{
float
ret
;
float
ret
=
0
;
if
(
!
ConvertStringToReal
(
str
,
&
ret
))
{
SHERPA_ONNX_LOGE
(
"Invalid floating-point option
\"
%s
\"
"
,
str
.
c_str
());
exit
(
-
1
);
...
...
@@ -602,7 +599,7 @@ float ParseOptions::ToFloat(const std::string &str) const {
}
double
ParseOptions
::
ToDouble
(
const
std
::
string
&
str
)
const
{
double
ret
;
double
ret
=
0
;
if
(
!
ConvertStringToReal
(
str
,
&
ret
))
{
SHERPA_ONNX_LOGE
(
"Invalid floating-point option
\"
%s
\"
"
,
str
.
c_str
());
exit
(
-
1
);
...
...
sherpa-onnx/csrc/piper-phonemize-lexicon.cc
查看文件 @
a11c859
...
...
@@ -37,7 +37,7 @@ static std::unordered_map<char32_t, int32_t> ReadTokens(std::istream &is) {
std
::
string
sym
;
std
::
u32string
s
;
int32_t
id
;
int32_t
id
=
0
;
while
(
std
::
getline
(
is
,
line
))
{
std
::
istringstream
iss
(
line
);
iss
>>
sym
;
...
...
sherpa-onnx/csrc/resample.cc
查看文件 @
a11c859
...
...
@@ -24,10 +24,9 @@
#include "sherpa-onnx/csrc/resample.h"
#include <assert.h>
#include <math.h>
#include <stdio.h>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <type_traits>
...
...
@@ -54,8 +53,8 @@ I Gcd(I m, I n) {
}
// could use compile-time assertion
// but involves messing with complex template stuff.
static_assert
(
std
::
is_integral
<
I
>::
value
,
""
);
while
(
1
)
{
static_assert
(
std
::
is_integral_v
<
I
>
);
while
(
true
)
{
m
%=
n
;
if
(
m
==
0
)
return
(
n
>
0
?
n
:
-
n
);
n
%=
m
;
...
...
@@ -139,10 +138,10 @@ void LinearResample::SetIndexesAndWeights() {
in the header as h(t) = f(t)g(t), evaluated at t.
*/
float
LinearResample
::
FilterFunc
(
float
t
)
const
{
float
window
,
// raised-cosine (Hanning) window of width
float
window
=
0
,
// raised-cosine (Hanning) window of width
// num_zeros_/2*filter_cutoff_
filter
;
// sinc filter function
if
(
fabs
(
t
)
<
num_zeros_
/
(
2.0
*
filter_cutoff_
))
filter
=
0
;
// sinc filter function
if
(
std
::
fabs
(
t
)
<
num_zeros_
/
(
2.0
*
filter_cutoff_
))
window
=
0.5
*
(
1
+
cos
(
M_2PI
*
filter_cutoff_
/
num_zeros_
*
t
));
else
window
=
0.0
;
// outside support of window function
...
...
@@ -172,15 +171,15 @@ void LinearResample::Resample(const float *input, int32_t input_dim, bool flush,
// of it we are producing here.
for
(
int64_t
samp_out
=
output_sample_offset_
;
samp_out
<
tot_output_samp
;
samp_out
++
)
{
int64_t
first_samp_in
;
int32_t
samp_out_wrapped
;
int64_t
first_samp_in
=
0
;
int32_t
samp_out_wrapped
=
0
;
GetIndexes
(
samp_out
,
&
first_samp_in
,
&
samp_out_wrapped
);
const
std
::
vector
<
float
>
&
weights
=
weights_
[
samp_out_wrapped
];
// first_input_index is the first index into "input" that we have a weight
// for.
int32_t
first_input_index
=
static_cast
<
int32_t
>
(
first_samp_in
-
input_sample_offset_
);
float
this_output
;
float
this_output
=
0
;
if
(
first_input_index
>=
0
&&
first_input_index
+
static_cast
<
int32_t
>
(
weights
.
size
())
<=
input_dim
)
{
this_output
=
...
...
@@ -239,7 +238,7 @@ int64_t LinearResample::GetNumOutputSamples(int64_t input_num_samp,
// largest integer in the interval [ 0, 2 - 0.9 ) are the same (both one).
// So when we're subtracting the window-width we can ignore the fractional
// part.
int32_t
window_width_ticks
=
floor
(
window_width
*
tick_freq
);
int32_t
window_width_ticks
=
std
::
floor
(
window_width
*
tick_freq
);
// The time-period of the output that we can sample gets reduced
// by the window-width (which is actually the distance from the
// center to the edge of the windowing function) if we're not
...
...
@@ -287,7 +286,7 @@ void LinearResample::SetRemainder(const float *input, int32_t input_dim) {
// that are "in the past" relative to the beginning of the latest
// input... anyway, storing more remainder than needed is not harmful.
int32_t
max_remainder_needed
=
ceil
(
samp_rate_in_
*
num_zeros_
/
filter_cutoff_
);
std
::
ceil
(
samp_rate_in_
*
num_zeros_
/
filter_cutoff_
);
input_remainder_
.
resize
(
max_remainder_needed
);
for
(
int32_t
index
=
-
static_cast
<
int32_t
>
(
input_remainder_
.
size
());
index
<
0
;
index
++
)
{
...
...
sherpa-onnx/csrc/resample.h
查看文件 @
a11c859
...
...
@@ -130,10 +130,10 @@ class LinearResample {
// the following variables keep track of where we are in a particular signal,
// if it is being provided over multiple calls to Resample().
int64_t
input_sample_offset_
;
///< The number of input samples we have
int64_t
input_sample_offset_
=
0
;
///< The number of input samples we have
///< already received for this signal
///< (including anything in remainder_)
int64_t
output_sample_offset_
;
///< The number of samples we have already
int64_t
output_sample_offset_
=
0
;
///< The number of samples we have already
///< output for this signal.
std
::
vector
<
float
>
input_remainder_
;
///< A small trailing part of the
///< previously seen input signal.
...
...
sherpa-onnx/csrc/session.cc
查看文件 @
a11c859
...
...
@@ -21,13 +21,13 @@
namespace
sherpa_onnx
{
static
void
OrtStatusFailure
(
OrtStatus
*
status
,
const
char
*
s
)
{
const
auto
&
api
=
Ort
::
GetApi
();
const
char
*
msg
=
api
.
GetErrorMessage
(
status
);
SHERPA_ONNX_LOGE
(
"Failed to enable TensorRT : %s."
"Available providers: %s. Fallback to cuda"
,
msg
,
s
);
"Available providers: %s. Fallback to cuda"
,
msg
,
s
);
api
.
ReleaseStatus
(
status
);
}
...
...
@@ -65,8 +65,8 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
}
case
Provider
:
:
kTRT
:
{
struct
TrtPairs
{
const
char
*
op_keys
;
const
char
*
op_values
;
const
char
*
op_keys
;
const
char
*
op_values
;
};
std
::
vector
<
TrtPairs
>
trt_options
=
{
...
...
@@ -79,15 +79,14 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
{
"trt_engine_cache_enable"
,
"1"
},
{
"trt_engine_cache_path"
,
"."
},
{
"trt_timing_cache_enable"
,
"1"
},
{
"trt_timing_cache_path"
,
"."
}
};
{
"trt_timing_cache_path"
,
"."
}};
// ToDo : Trt configs
// "trt_int8_enable"
// "trt_int8_use_native_calibration_table"
// "trt_dump_subgraphs"
std
::
vector
<
const
char
*>
option_keys
,
option_values
;
for
(
const
TrtPairs
&
pair
:
trt_options
)
{
std
::
vector
<
const
char
*>
option_keys
,
option_values
;
for
(
const
TrtPairs
&
pair
:
trt_options
)
{
option_keys
.
emplace_back
(
pair
.
op_keys
);
option_values
.
emplace_back
(
pair
.
op_values
);
}
...
...
@@ -96,18 +95,22 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
Ort
::
GetAvailableProviders
();
if
(
std
::
find
(
available_providers
.
begin
(),
available_providers
.
end
(),
"TensorrtExecutionProvider"
)
!=
available_providers
.
end
())
{
const
auto
&
api
=
Ort
::
GetApi
();
const
auto
&
api
=
Ort
::
GetApi
();
OrtTensorRTProviderOptionsV2
*
tensorrt_options
;
OrtStatus
*
statusC
=
api
.
CreateTensorRTProviderOptions
(
&
tensorrt_options
);
OrtTensorRTProviderOptionsV2
*
tensorrt_options
=
nullptr
;
OrtStatus
*
statusC
=
api
.
CreateTensorRTProviderOptions
(
&
tensorrt_options
);
OrtStatus
*
statusU
=
api
.
UpdateTensorRTProviderOptions
(
tensorrt_options
,
option_keys
.
data
(),
option_values
.
data
(),
option_keys
.
size
());
sess_opts
.
AppendExecutionProvider_TensorRT_V2
(
*
tensorrt_options
);
if
(
statusC
)
{
OrtStatusFailure
(
statusC
,
os
.
str
().
c_str
());
}
if
(
statusU
)
{
OrtStatusFailure
(
statusU
,
os
.
str
().
c_str
());
}
if
(
statusC
)
{
OrtStatusFailure
(
statusC
,
os
.
str
().
c_str
());
}
if
(
statusU
)
{
OrtStatusFailure
(
statusU
,
os
.
str
().
c_str
());
}
api
.
ReleaseTensorRTProviderOptions
(
tensorrt_options
);
}
...
...
sherpa-onnx/csrc/silero-vad-model.cc
查看文件 @
a11c859
...
...
@@ -20,11 +20,11 @@ class SileroVadModel::Impl {
:
config_
(
config
),
env_
(
ORT_LOGGING_LEVEL_ERROR
),
sess_opts_
(
GetSessionOptions
(
config
)),
allocator_
{}
{
allocator_
{},
sample_rate_
(
config
.
sample_rate
)
{
auto
buf
=
ReadFile
(
config
.
silero_vad
.
model
);
Init
(
buf
.
data
(),
buf
.
size
());
sample_rate_
=
config
.
sample_rate
;
if
(
sample_rate_
!=
16000
)
{
SHERPA_ONNX_LOGE
(
"Expected sample rate 16000. Given: %d"
,
config
.
sample_rate
);
...
...
sherpa-onnx/csrc/slice.cc
查看文件 @
a11c859
...
...
@@ -4,9 +4,8 @@
#include "sherpa-onnx/csrc/slice.h"
#include <assert.h>
#include <algorithm>
#include <cassert>
#include <vector>
namespace
sherpa_onnx
{
...
...
sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc
查看文件 @
a11c859
...
...
@@ -12,7 +12,7 @@ namespace sherpa_onnx {
namespace
{
enum
class
ModelType
{
enum
class
ModelType
:
std
::
uint8_t
{
kWeSpeaker
,
k3dSpeaker
,
kNeMo
,
...
...
sherpa-onnx/csrc/speaker-embedding-manager.cc
查看文件 @
a11c859
...
...
@@ -122,7 +122,7 @@ class SpeakerEmbeddingManager::Impl {
Eigen
::
VectorXf
scores
=
embedding_matrix_
*
v
;
Eigen
::
VectorXf
::
Index
max_index
;
Eigen
::
VectorXf
::
Index
max_index
=
0
;
float
max_score
=
scores
.
maxCoeff
(
&
max_index
);
if
(
max_score
<
threshold
)
{
return
{};
...
...
@@ -178,11 +178,12 @@ class SpeakerEmbeddingManager::Impl {
std
::
vector
<
std
::
string
>
GetAllSpeakers
()
const
{
std
::
vector
<
std
::
string
>
all_speakers
;
all_speakers
.
reserve
(
name2row_
.
size
());
for
(
const
auto
&
p
:
name2row_
)
{
all_speakers
.
push_back
(
p
.
first
);
}
std
::
s
table_s
ort
(
all_speakers
.
begin
(),
all_speakers
.
end
());
std
::
sort
(
all_speakers
.
begin
(),
all_speakers
.
end
());
return
all_speakers
;
}
...
...
sherpa-onnx/csrc/spoken-language-identification-impl.cc
查看文件 @
a11c859
...
...
@@ -18,7 +18,7 @@ namespace sherpa_onnx {
namespace
{
enum
class
ModelType
{
enum
class
ModelType
:
std
::
uint8_t
{
kWhisper
,
kUnknown
,
};
...
...
sherpa-onnx/csrc/stack.cc
查看文件 @
a11c859
...
...
@@ -71,8 +71,8 @@ Ort::Value Stack(OrtAllocator *allocator,
T
*
dst
=
ans
.
GetTensorMutableData
<
T
>
();
for
(
int32_t
i
=
0
;
i
!=
leading_size
;
++
i
)
{
for
(
int32_t
n
=
0
;
n
!=
static_cast
<
int32_t
>
(
values
.
size
());
++
n
)
{
const
T
*
src
=
values
[
n
]
->
GetTensorData
<
T
>
();
for
(
auto
value
:
values
)
{
const
T
*
src
=
value
->
GetTensorData
<
T
>
();
src
+=
i
*
trailing_size
;
std
::
copy
(
src
,
src
+
trailing_size
,
dst
);
...
...
sherpa-onnx/csrc/symbol-table.cc
查看文件 @
a11c859
...
...
@@ -36,7 +36,7 @@ SymbolTable::SymbolTable(AAssetManager *mgr, const std::string &filename) {
void
SymbolTable
::
Init
(
std
::
istream
&
is
)
{
std
::
string
sym
;
int32_t
id
;
int32_t
id
=
0
;
while
(
is
>>
sym
>>
id
)
{
#if 0
// we disable the test here since for some multi-lingual BPE models
...
...
sherpa-onnx/csrc/text-utils.cc
查看文件 @
a11c859
...
...
@@ -5,9 +5,8 @@
#include "sherpa-onnx/csrc/text-utils.h"
#include <assert.h>
#include <algorithm>
#include <cassert>
#include <cctype>
#include <cstdint>
#include <limits>
...
...
sherpa-onnx/csrc/transducer-keyword-decoder.cc
查看文件 @
a11c859
...
...
@@ -151,7 +151,6 @@ void TransducerKeywordDecoder::Decode(
if
(
matched
)
{
float
ys_prob
=
0.0
;
int32_t
length
=
best_hyp
.
ys_probs
.
size
();
for
(
int32_t
i
=
0
;
i
<
matched_state
->
level
;
++
i
)
{
ys_prob
+=
best_hyp
.
ys_probs
[
i
];
}
...
...
sherpa-onnx/csrc/transpose.cc
查看文件 @
a11c859
...
...
@@ -4,9 +4,8 @@
#include "sherpa-onnx/csrc/transpose.h"
#include <assert.h>
#include <algorithm>
#include <cassert>
#include <vector>
namespace
sherpa_onnx
{
...
...
sherpa-onnx/csrc/unbind.cc
查看文件 @
a11c859
...
...
@@ -4,9 +4,8 @@
#include "sherpa-onnx/csrc/unbind.h"
#include <assert.h>
#include <algorithm>
#include <cassert>
#include <functional>
#include <numeric>
#include <utility>
...
...
sherpa-onnx/csrc/utils.cc
查看文件 @
a11c859
...
...
@@ -30,7 +30,6 @@ static bool EncodeBase(const std::vector<std::string> &lines,
std
::
vector
<
float
>
tmp_thresholds
;
std
::
vector
<
std
::
string
>
tmp_phrases
;
std
::
string
line
;
std
::
string
word
;
bool
has_scores
=
false
;
bool
has_thresholds
=
false
;
...
...
@@ -72,6 +71,7 @@ static bool EncodeBase(const std::vector<std::string> &lines,
}
}
ids
->
push_back
(
std
::
move
(
tmp_ids
));
tmp_ids
=
{};
tmp_scores
.
push_back
(
score
);
tmp_phrases
.
push_back
(
phrase
);
tmp_thresholds
.
push_back
(
threshold
);
...
...
sherpa-onnx/csrc/wave-reader.cc
查看文件 @
a11c859
...
...
@@ -100,13 +100,13 @@ struct WaveHeader {
int32_t
subchunk2_id
;
// a tag of this chunk
int32_t
subchunk2_size
;
// size of subchunk2
};
static_assert
(
sizeof
(
WaveHeader
)
==
44
,
""
);
static_assert
(
sizeof
(
WaveHeader
)
==
44
);
// Read a wave file of mono-channel.
// Return its samples normalized to the range [-1, 1).
std
::
vector
<
float
>
ReadWaveImpl
(
std
::
istream
&
is
,
int32_t
*
sampling_rate
,
bool
*
is_ok
)
{
WaveHeader
header
;
WaveHeader
header
{}
;
is
.
read
(
reinterpret_cast
<
char
*>
(
&
header
),
sizeof
(
header
));
if
(
!
is
)
{
*
is_ok
=
false
;
...
...
sherpa-onnx/csrc/wave-writer.cc
查看文件 @
a11c859
...
...
@@ -37,7 +37,7 @@ struct WaveHeader {
bool
WriteWave
(
const
std
::
string
&
filename
,
int32_t
sampling_rate
,
const
float
*
samples
,
int32_t
n
)
{
WaveHeader
header
;
WaveHeader
header
{}
;
header
.
chunk_id
=
0x46464952
;
// FFIR
header
.
format
=
0x45564157
;
// EVAW
header
.
subchunk1_id
=
0x20746d66
;
// "fmt "
...
...
请
注册
或
登录
后发表评论