Fangjun Kuang
Committed by GitHub

Support clang-tidy (#1034)

正在显示 63 个修改的文件 包含 382 行增加238 行删除
---
# NOTE there must be no spaces before the '-', so put the comma last.
# The check bugprone-unchecked-optional-access is also turned off atm
# because it causes clang-tidy to hang randomly. The tracking issue
# can be found at https://github.com/llvm/llvm-project/issues/69369.
#
# Modified from
# https://github.com/pytorch/pytorch/blob/main/.clang-tidy
InheritParentConfig: true
Checks: '
bugprone-*,
-bugprone-easily-swappable-parameters,
-bugprone-forward-declaration-namespace,
-bugprone-implicit-widening-of-multiplication-result,
-bugprone-macro-parentheses,
-bugprone-lambda-function-name,
-bugprone-narrowing-conversions,
-bugprone-reserved-identifier,
-bugprone-swapped-arguments,
-bugprone-unchecked-optional-access,
clang-diagnostic-missing-prototypes,
cppcoreguidelines-*,
-cppcoreguidelines-avoid-const-or-ref-data-members,
-cppcoreguidelines-avoid-do-while,
-cppcoreguidelines-avoid-magic-numbers,
-cppcoreguidelines-avoid-non-const-global-variables,
-cppcoreguidelines-interfaces-global-init,
-cppcoreguidelines-macro-usage,
-cppcoreguidelines-narrowing-conversions,
-cppcoreguidelines-owning-memory,
-cppcoreguidelines-pro-bounds-array-to-pointer-decay,
-cppcoreguidelines-pro-bounds-constant-array-index,
-cppcoreguidelines-pro-bounds-pointer-arithmetic,
-cppcoreguidelines-pro-type-const-cast,
-cppcoreguidelines-pro-type-cstyle-cast,
-cppcoreguidelines-pro-type-reinterpret-cast,
-cppcoreguidelines-pro-type-static-cast-downcast,
-cppcoreguidelines-pro-type-union-access,
-cppcoreguidelines-pro-type-vararg,
-cppcoreguidelines-special-member-functions,
-cppcoreguidelines-non-private-member-variables-in-classes,
-facebook-hte-RelativeInclude,
hicpp-exception-baseclass,
hicpp-avoid-goto,
misc-*,
-misc-const-correctness,
-misc-include-cleaner,
-misc-use-anonymous-namespace,
-misc-unused-parameters,
-misc-no-recursion,
-misc-non-private-member-variables-in-classes,
-misc-confusable-identifiers,
modernize-*,
-modernize-macro-to-enum,
-modernize-pass-by-value,
-modernize-return-braced-init-list,
-modernize-use-auto,
-modernize-use-default-member-init,
-modernize-use-using,
-modernize-use-trailing-return-type,
-modernize-use-nodiscard,
performance-*,
readability-container-size-empty,
readability-delete-null-pointer,
readability-duplicate-include
readability-misplaced-array-index,
readability-redundant-function-ptr-dereference,
readability-redundant-smartptr-get,
readability-simplify-subscript-expr,
readability-string-compare,
'
WarningsAsErrors: '*'
...
... ...
name: clang-tidy
on:
push:
branches:
- master
- clang-tidy
paths:
- 'sherpa-onnx/csrc/**'
pull_request:
branches:
- master
paths:
- 'sherpa-onnx/csrc/**'
workflow_dispatch:
concurrency:
group: clang-tidy-${{ github.ref }}
cancel-in-progress: true
jobs:
clang-tidy:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.8]
fail-fast: false
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install clang-tidy
shell: bash
run: |
pip install clang-tidy
- name: Configure
shell: bash
run: |
mkdir build
cd build
cmake -DSHERPA_ONNX_ENABLE_PYTHON=ON -DCMAKE_EXPORT_COMPILE_COMMANDS=ON ..
- name: Check with clang-tidy
shell: bash
run: |
cd build
make check
... ...
... ... @@ -184,6 +184,7 @@ jobs:
path: ./*.tar.bz2
- name: Publish to huggingface
if: (github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa') && github.event_name == 'push' && contains(github.ref, 'refs/tags/') && matrix.build_type == 'Release'
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
uses: nick-fields/retry@v3
... ...
... ... @@ -133,6 +133,7 @@ jobs:
shell: bash
run: |
d=$PWD
SHERPA_ONNX_VERSION=v$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2)
pushd sherpa-onnx/flutter
dart pub get
... ... @@ -159,6 +160,7 @@ jobs:
path: ./*.tar.bz2
- name: Publish to huggingface
if: (github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa') && github.event_name == 'push' && contains(github.ref, 'refs/tags/') && matrix.build_type == 'Release'
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
uses: nick-fields/retry@v3
... ...
... ... @@ -167,7 +167,7 @@ if(SHERPA_ONNX_ENABLE_WASM_KWS)
endif()
if(NOT CMAKE_CXX_STANDARD)
set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
set(CMAKE_CXX_STANDARD 17 CACHE STRING "The C++ version to be used.")
endif()
set(CMAKE_CXX_EXTENSIONS OFF)
message(STATUS "C++ Standard version: ${CMAKE_CXX_STANDARD}")
... ...
... ... @@ -3,18 +3,18 @@
function(download_openfst)
include(FetchContent)
set(openfst_URL "https://github.com/csukuangfj/openfst/archive/refs/tags/sherpa-onnx-2024-06-13.tar.gz")
set(openfst_URL2 "https://hub.nuaa.cf/csukuangfj/openfst/archive/refs/tags/sherpa-onnx-2024-06-13.tar.gz")
set(openfst_HASH "SHA256=f10a71c6b64d89eabdc316d372b956c30c825c7c298e2f20c780320e8181ffb6")
set(openfst_URL "https://github.com/csukuangfj/openfst/archive/refs/tags/sherpa-onnx-2024-06-19.tar.gz")
set(openfst_URL2 "https://hub.nuaa.cf/csukuangfj/openfst/archive/refs/tags/sherpa-onnx-2024-06-19.tar.gz")
set(openfst_HASH "SHA256=5c98e82cc509c5618502dde4860b8ea04d843850ed57e6d6b590b644b268853d")
# If you don't have access to the Internet,
# please pre-download it
set(possible_file_locations
$ENV{HOME}/Downloads/openfst-sherpa-onnx-2024-06-13.tar.gz
${CMAKE_SOURCE_DIR}/openfst-sherpa-onnx-2024-06-13.tar.gz
${CMAKE_BINARY_DIR}/openfst-sherpa-onnx-2024-06-13.tar.gz
/tmp/openfst-sherpa-onnx-2024-06-13.tar.gz
/star-fj/fangjun/download/github/openfst-sherpa-onnx-2024-06-13.tar.gz
$ENV{HOME}/Downloads/openfst-sherpa-onnx-2024-06-19.tar.gz
${CMAKE_SOURCE_DIR}/openfst-sherpa-onnx-2024-06-19.tar.gz
${CMAKE_BINARY_DIR}/openfst-sherpa-onnx-2024-06-19.tar.gz
/tmp/openfst-sherpa-onnx-2024-06-19.tar.gz
/star-fj/fangjun/download/github/openfst-sherpa-onnx-2024-06-19.tar.gz
)
foreach(f IN LISTS possible_file_locations)
... ...
... ... @@ -534,3 +534,17 @@ if(SHERPA_ONNX_ENABLE_TESTS)
sherpa_onnx_add_test(${source})
endforeach()
endif()
set(srcs_to_check)
foreach(s IN LISTS sources)
list(APPEND srcs_to_check ${CMAKE_CURRENT_LIST_DIR}/${s})
endforeach()
# For clang-tidy
add_custom_target(
clang-tidy-check
clang-tidy -p ${CMAKE_BINARY_DIR}/compile_commands.json --config-file ${CMAKE_SOURCE_DIR}/.clang-tidy ${srcs_to_check}
DEPENDS ${sources})
add_custom_target(check DEPENDS clang-tidy-check)
... ...
... ... @@ -60,7 +60,7 @@ void AudioTaggingLabels::Init(std::istream &is) {
std::size_t pos{};
int32_t i = std::stoi(index, &pos);
if (index.size() == 0 || pos != index.size()) {
if (index.empty() || pos != index.size()) {
SHERPA_ONNX_LOGE("Invalid line: %s", line.c_str());
exit(-1);
}
... ...
... ... @@ -34,7 +34,7 @@ std::string Base64Decode(const std::string &s) {
exit(-1);
}
int32_t n = s.size() / 4 * 3;
int32_t n = static_cast<int32_t>(s.size()) / 4 * 3;
std::string ans;
ans.reserve(n);
... ... @@ -46,16 +46,16 @@ std::string Base64Decode(const std::string &s) {
}
int32_t first = (Ord(s[i]) << 2) + ((Ord(s[i + 1]) & 0x30) >> 4);
ans.push_back(first);
ans.push_back(static_cast<char>(first));
if (i + 2 < static_cast<int32_t>(s.size()) && s[i + 2] != '=') {
int32_t second =
((Ord(s[i + 1]) & 0x0f) << 4) + ((Ord(s[i + 2]) & 0x3c) >> 2);
ans.push_back(second);
ans.push_back(static_cast<char>(second));
if (i + 3 < static_cast<int32_t>(s.size()) && s[i + 3] != '=') {
int32_t third = ((Ord(s[i + 2]) & 0x03) << 6) + Ord(s[i + 3]);
ans.push_back(third);
ans.push_back(static_cast<char>(third));
}
}
i += 4;
... ...
... ... @@ -82,9 +82,9 @@ Ort::Value Cat(OrtAllocator *allocator,
T *dst = ans.GetTensorMutableData<T>();
for (int32_t i = 0; i != leading_size; ++i) {
for (int32_t n = 0; n != static_cast<int32_t>(values.size()); ++n) {
auto this_dim = values[n]->GetTensorTypeAndShapeInfo().GetShape()[dim];
const T *src = values[n]->GetTensorData<T>();
for (auto value : values) {
auto this_dim = value->GetTensorTypeAndShapeInfo().GetShape()[dim];
const T *src = value->GetTensorData<T>();
src += i * this_dim * trailing_size;
std::copy(src, src + this_dim * trailing_size, dst);
... ...
... ... @@ -20,7 +20,7 @@ CircularBuffer::CircularBuffer(int32_t capacity) {
}
void CircularBuffer::Resize(int32_t new_capacity) {
int32_t capacity = buffer_.size();
int32_t capacity = static_cast<int32_t>(buffer_.size());
if (new_capacity <= capacity) {
SHERPA_ONNX_LOGE("new_capacity (%d) <= original capacity (%d). Skip it.",
new_capacity, capacity);
... ... @@ -86,7 +86,7 @@ void CircularBuffer::Resize(int32_t new_capacity) {
}
void CircularBuffer::Push(const float *p, int32_t n) {
int32_t capacity = buffer_.size();
int32_t capacity = static_cast<int32_t>(buffer_.size());
int32_t size = Size();
if (n + size > capacity) {
int32_t new_capacity = std::max(capacity * 2, n + size);
... ... @@ -126,7 +126,7 @@ std::vector<float> CircularBuffer::Get(int32_t start_index, int32_t n) const {
return {};
}
int32_t capacity = buffer_.size();
int32_t capacity = static_cast<int32_t>(buffer_.size());
if (start_index - head_ + n > size) {
SHERPA_ONNX_LOGE("Invalid start_index: %d and n: %d. head_: %d, size: %d",
... ...
... ... @@ -67,8 +67,8 @@ void ContextGraph::Build(const std::vector<std::vector<int32_t>> &token_ids,
std::tuple<float, const ContextState *, const ContextState *>
ContextGraph::ForwardOneStep(const ContextState *state, int32_t token,
bool strict_mode /*= true*/) const {
const ContextState *node;
float score;
const ContextState *node = nullptr;
float score = 0;
if (1 == state->next.count(token)) {
node = state->next.at(token).get();
score = node->token_score;
... ... @@ -84,7 +84,10 @@ ContextGraph::ForwardOneStep(const ContextState *state, int32_t token,
score = node->node_score - state->node_score;
}
SHERPA_ONNX_CHECK(nullptr != node);
if (!node) {
SHERPA_ONNX_LOGE("Some bad things happened.");
exit(-1);
}
const ContextState *matched_node =
node->is_end ? node : (node->output != nullptr ? node->output : nullptr);
... ...
... ... @@ -73,10 +73,15 @@ std::string EndpointConfig::ToString() const {
return os.str();
}
bool Endpoint::IsEndpoint(int num_frames_decoded, int trailing_silence_frames,
bool Endpoint::IsEndpoint(int32_t num_frames_decoded,
int32_t trailing_silence_frames,
float frame_shift_in_seconds) const {
float utterance_length = num_frames_decoded * frame_shift_in_seconds;
float trailing_silence = trailing_silence_frames * frame_shift_in_seconds;
float utterance_length =
static_cast<float>(num_frames_decoded) * frame_shift_in_seconds;
float trailing_silence =
static_cast<float>(trailing_silence_frames) * frame_shift_in_seconds;
if (RuleActivated(config_.rule1, "rule1", trailing_silence,
utterance_length) ||
RuleActivated(config_.rule2, "rule2", trailing_silence,
... ...
... ... @@ -64,7 +64,7 @@ class Endpoint {
/// This function returns true if this set of endpointing rules thinks we
/// should terminate decoding.
bool IsEndpoint(int num_frames_decoded, int trailing_silence_frames,
bool IsEndpoint(int32_t num_frames_decoded, int32_t trailing_silence_frames,
float frame_shift_in_seconds) const;
private:
... ...
... ... @@ -103,6 +103,7 @@ class JiebaLexicon::Impl {
if (w == "。" || w == "!" || w == "?" || w == ",") {
ans.push_back(std::move(this_sentence));
this_sentence = {};
}
} // for (const auto &w : words)
... ...
... ... @@ -4,9 +4,8 @@
#include "sherpa-onnx/csrc/keyword-spotter.h"
#include <assert.h>
#include <algorithm>
#include <cassert>
#include <fstream>
#include <iomanip>
#include <memory>
... ...
... ... @@ -82,7 +82,7 @@ std::unordered_map<std::string, int32_t> ReadTokens(std::istream &is) {
std::string line;
std::string sym;
int32_t id;
int32_t id = -1;
while (std::getline(is, line)) {
std::istringstream iss(line);
iss >> sym;
... ... @@ -254,6 +254,7 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsChinese(
this_sentence.push_back(eos);
}
ans.push_back(std::move(this_sentence));
this_sentence = {};
if (sil != -1) {
this_sentence.push_back(sil);
... ... @@ -324,6 +325,7 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsNotChinese(
if (w != ",") {
this_sentence.push_back(blank);
ans.push_back(std::move(this_sentence));
this_sentence = {};
}
continue;
... ...
... ... @@ -62,8 +62,8 @@ class Lexicon : public OfflineTtsFrontend {
std::unordered_map<std::string, std::vector<int32_t>> word2ids_;
std::unordered_set<std::string> punctuations_;
std::unordered_map<std::string, int32_t> token2id_;
Language language_;
bool debug_;
Language language_ = Language::kUnknown;
bool debug_ = false;
};
} // namespace sherpa_onnx
... ...
... ... @@ -67,7 +67,7 @@ class OfflineCtTransformerModel::Impl {
std::vector<std::string> tokens;
SHERPA_ONNX_READ_META_DATA_VEC_STRING_SEP(tokens, "tokens", "|");
int32_t vocab_size;
int32_t vocab_size = 0;
SHERPA_ONNX_READ_META_DATA(vocab_size, "vocab_size");
if (static_cast<int32_t>(tokens.size()) != vocab_size) {
SHERPA_ONNX_LOGE("tokens.size() %d != vocab_size %d",
... ...
... ... @@ -19,7 +19,7 @@
namespace {
enum class ModelType {
enum class ModelType : std::uint8_t {
kEncDecCTCModelBPE,
kEncDecHybridRNNTCTCBPEModel,
kTdnn,
... ...
... ... @@ -4,11 +4,11 @@
#include "sherpa-onnx/csrc/offline-stream.h"
#include <assert.h>
#include <algorithm>
#include <cassert>
#include <cmath>
#include <iomanip>
#include <utility>
#include "kaldi-native-fbank/csrc/online-feature.h"
#include "sherpa-onnx/csrc/macros.h"
... ... @@ -56,7 +56,7 @@ class OfflineStream::Impl {
public:
explicit Impl(const FeatureExtractorConfig &config,
ContextGraphPtr context_graph)
: config_(config), context_graph_(context_graph) {
: config_(config), context_graph_(std::move(context_graph)) {
if (config.is_mfcc) {
mfcc_opts_.frame_opts.dither = config_.dither;
mfcc_opts_.frame_opts.snip_edges = config_.snip_edges;
... ... @@ -266,7 +266,7 @@ class OfflineStream::Impl {
OfflineStream::OfflineStream(const FeatureExtractorConfig &config /*= {}*/,
ContextGraphPtr context_graph /*= nullptr*/)
: impl_(std::make_unique<Impl>(config, context_graph)) {}
: impl_(std::make_unique<Impl>(config, std::move(context_graph))) {}
OfflineStream::OfflineStream(WhisperTag tag)
: impl_(std::make_unique<Impl>(tag)) {}
... ...
... ... @@ -42,7 +42,7 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
std::vector<ContextGraphPtr> context_graphs(batch_size, nullptr);
for (int32_t i = 0; i < batch_size; ++i) {
const ContextState *context_state;
const ContextState *context_state = nullptr;
if (ss != nullptr) {
context_graphs[i] =
ss[packed_encoder_out.sorted_indexes[i]]->GetContextGraph();
... ...
... ... @@ -30,7 +30,7 @@ static std::unordered_map<char32_t, int32_t> ReadTokens(std::istream &is) {
std::string sym;
std::u32string s;
int32_t id;
int32_t id = 0;
while (std::getline(is, line)) {
std::istringstream iss(line);
iss >> sym;
... ... @@ -138,6 +138,7 @@ OfflineTtsCharacterFrontend::ConvertTextToTokenIds(
}
ans.push_back(std::move(this_sentence));
this_sentence = {};
// re-initialize this_sentence
if (use_eos_bos) {
... ... @@ -172,6 +173,7 @@ OfflineTtsCharacterFrontend::ConvertTextToTokenIds(
}
ans.push_back(std::move(this_sentence));
this_sentence = {};
// re-initialize this_sentence
if (use_eos_bos) {
... ...
... ... @@ -5,6 +5,7 @@
#include "sherpa-onnx/csrc/offline-tts.h"
#include <string>
#include <utility>
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
... ... @@ -87,7 +88,7 @@ OfflineTts::~OfflineTts() = default;
GeneratedAudio OfflineTts::Generate(
const std::string &text, int64_t sid /*=0*/, float speed /*= 1.0*/,
GeneratedAudioCallback callback /*= nullptr*/) const {
return impl_->Generate(text, sid, speed, callback);
return impl_->Generate(text, sid, speed, std::move(callback));
}
int32_t OfflineTts::SampleRate() const { return impl_->SampleRate(); }
... ...
... ... @@ -22,9 +22,9 @@ class OfflineWhisperModel::Impl {
explicit Impl(const OfflineModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
debug_(config.debug),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
debug_ = config_.debug;
{
auto buf = ReadFile(config.whisper.encoder);
InitEncoder(buf.data(), buf.size());
... ... @@ -39,9 +39,9 @@ class OfflineWhisperModel::Impl {
explicit Impl(const SpokenLanguageIdentificationConfig &config)
: lid_config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
debug_(config_.debug),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
debug_ = config_.debug;
{
auto buf = ReadFile(config.whisper.encoder);
InitEncoder(buf.data(), buf.size());
... ... @@ -148,7 +148,6 @@ class OfflineWhisperModel::Impl {
cross_v = std::move(std::get<4>(decoder_out));
const float *p_logits = std::get<0>(decoder_out).GetTensorData<float>();
int32_t vocab_size = VocabSize();
const auto &all_language_ids = GetAllLanguageIDs();
int32_t lang_id = all_language_ids[0];
... ... @@ -317,18 +316,18 @@ class OfflineWhisperModel::Impl {
std::unordered_map<int32_t, std::string> id2lang_;
// model meta data
int32_t n_text_layer_;
int32_t n_text_ctx_;
int32_t n_text_state_;
int32_t n_vocab_;
int32_t sot_;
int32_t eot_;
int32_t blank_;
int32_t translate_;
int32_t transcribe_;
int32_t no_timestamps_;
int32_t no_speech_;
int32_t is_multilingual_;
int32_t n_text_layer_ = 0;
int32_t n_text_ctx_ = 0;
int32_t n_text_state_ = 0;
int32_t n_vocab_ = 0;
int32_t sot_ = 0;
int32_t eot_ = 0;
int32_t blank_ = 0;
int32_t translate_ = 0;
int32_t transcribe_ = 0;
int32_t no_timestamps_ = 0;
int32_t no_speech_ = 0;
int32_t is_multilingual_ = 0;
std::vector<int64_t> sot_sequence_;
};
... ...
... ... @@ -4,9 +4,8 @@
#include "sherpa-onnx/csrc/online-conformer-transducer-model.h"
#include <assert.h>
#include <algorithm>
#include <cassert>
#include <memory>
#include <sstream>
#include <string>
... ...
... ... @@ -52,8 +52,9 @@ static void DecodeOne(const float *log_probs, int32_t num_rows,
if (ok) {
std::vector<int32_t> isymbols_out;
std::vector<int32_t> osymbols_out;
ok = fst::GetLinearSymbolSequence(fst_out, &isymbols_out, &osymbols_out,
nullptr);
/*ok =*/fst::GetLinearSymbolSequence(fst_out, &isymbols_out,
&osymbols_out, nullptr);
// TODO(fangjun): handle ok is false
std::vector<int64_t> tokens;
tokens.reserve(isymbols_out.size());
... ...
... ... @@ -3,9 +3,8 @@
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/online-lstm-transducer-model.h"
#include <assert.h>
#include <algorithm>
#include <cassert>
#include <memory>
#include <sstream>
#include <string>
... ...
... ... @@ -265,16 +265,16 @@ class OnlineNeMoCtcModel::Impl {
std::vector<std::string> output_names_;
std::vector<const char *> output_names_ptr_;
int32_t window_size_;
int32_t chunk_shift_;
int32_t subsampling_factor_;
int32_t vocab_size_;
int32_t cache_last_channel_dim1_;
int32_t cache_last_channel_dim2_;
int32_t cache_last_channel_dim3_;
int32_t cache_last_time_dim1_;
int32_t cache_last_time_dim2_;
int32_t cache_last_time_dim3_;
int32_t window_size_ = 0;
int32_t chunk_shift_ = 0;
int32_t subsampling_factor_ = 0;
int32_t vocab_size_ = 0;
int32_t cache_last_channel_dim1_ = 0;
int32_t cache_last_channel_dim2_ = 0;
int32_t cache_last_channel_dim3_ = 0;
int32_t cache_last_time_dim1_ = 0;
int32_t cache_last_time_dim2_ = 0;
int32_t cache_last_time_dim3_ = 0;
Ort::Value cache_last_channel_{nullptr};
Ort::Value cache_last_time_{nullptr};
... ...
... ... @@ -5,9 +5,8 @@
#include "sherpa-onnx/csrc/online-recognizer.h"
#include <assert.h>
#include <algorithm>
#include <cassert>
#include <iomanip>
#include <memory>
#include <sstream>
... ...
... ... @@ -8,6 +8,7 @@
#include <vector>
#include "sherpa-onnx/csrc/features.h"
#include "sherpa-onnx/csrc/transducer-keyword-decoder.h"
namespace sherpa_onnx {
... ... @@ -15,7 +16,7 @@ class OnlineStream::Impl {
public:
explicit Impl(const FeatureExtractorConfig &config,
ContextGraphPtr context_graph)
: feat_extractor_(config), context_graph_(context_graph) {}
: feat_extractor_(config), context_graph_(std::move(context_graph)) {}
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
feat_extractor_.AcceptWaveform(sampling_rate, waveform, n);
... ... @@ -146,7 +147,7 @@ class OnlineStream::Impl {
OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/,
ContextGraphPtr context_graph /*= nullptr */)
: impl_(std::make_unique<Impl>(config, context_graph)) {}
: impl_(std::make_unique<Impl>(config, std::move(context_graph))) {}
OnlineStream::~OnlineStream() = default;
... ...
... ... @@ -15,7 +15,6 @@
#include "sherpa-onnx/csrc/online-ctc-decoder.h"
#include "sherpa-onnx/csrc/online-paraformer-decoder.h"
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
#include "sherpa-onnx/csrc/transducer-keyword-decoder.h"
namespace sherpa_onnx {
... ...
... ... @@ -45,13 +45,13 @@ OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=(
}
OnlineTransducerDecoderResult::OnlineTransducerDecoderResult(
OnlineTransducerDecoderResult &&other)
OnlineTransducerDecoderResult &&other) noexcept
: OnlineTransducerDecoderResult() {
*this = std::move(other);
}
OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=(
OnlineTransducerDecoderResult &&other) {
OnlineTransducerDecoderResult &&other) noexcept {
if (this == &other) {
return *this;
}
... ...
... ... @@ -44,10 +44,10 @@ struct OnlineTransducerDecoderResult {
OnlineTransducerDecoderResult &operator=(
const OnlineTransducerDecoderResult &other);
OnlineTransducerDecoderResult(OnlineTransducerDecoderResult &&other);
OnlineTransducerDecoderResult(OnlineTransducerDecoderResult &&other) noexcept;
OnlineTransducerDecoderResult &operator=(
OnlineTransducerDecoderResult &&other);
OnlineTransducerDecoderResult &&other) noexcept;
};
class OnlineStream;
... ...
... ... @@ -23,7 +23,7 @@
namespace {
enum class ModelType {
enum class ModelType : std::uint8_t {
kConformer,
kLstm,
kZipformer,
... ...
... ... @@ -5,10 +5,9 @@
#include "sherpa-onnx/csrc/online-transducer-nemo-model.h"
#include <assert.h>
#include <math.h>
#include <algorithm>
#include <cassert>
#include <cmath>
#include <memory>
#include <numeric>
#include <sstream>
... ... @@ -429,8 +428,8 @@ class OnlineTransducerNeMoModel::Impl {
std::vector<std::string> joiner_output_names_;
std::vector<const char *> joiner_output_names_ptr_;
int32_t window_size_;
int32_t chunk_shift_;
int32_t window_size_ = 0;
int32_t chunk_shift_ = 0;
int32_t vocab_size_ = 0;
int32_t subsampling_factor_ = 8;
std::string normalize_type_;
... ... @@ -438,12 +437,12 @@ class OnlineTransducerNeMoModel::Impl {
int32_t pred_hidden_ = -1;
// encoder states
int32_t cache_last_channel_dim1_;
int32_t cache_last_channel_dim2_;
int32_t cache_last_channel_dim3_;
int32_t cache_last_time_dim1_;
int32_t cache_last_time_dim2_;
int32_t cache_last_time_dim3_;
int32_t cache_last_channel_dim1_ = 0;
int32_t cache_last_channel_dim2_ = 0;
int32_t cache_last_channel_dim3_ = 0;
int32_t cache_last_time_dim1_ = 0;
int32_t cache_last_time_dim2_ = 0;
int32_t cache_last_time_dim3_ = 0;
// init encoder states
Ort::Value cache_last_channel_{nullptr};
... ...
... ... @@ -192,15 +192,15 @@ class OnlineWenetCtcModel::Impl {
std::vector<std::string> output_names_;
std::vector<const char *> output_names_ptr_;
int32_t head_;
int32_t num_blocks_;
int32_t output_size_;
int32_t cnn_module_kernel_;
int32_t right_context_;
int32_t subsampling_factor_;
int32_t vocab_size_;
int32_t required_cache_size_;
int32_t head_ = 0;
int32_t num_blocks_ = 0;
int32_t output_size_ = 0;
int32_t cnn_module_kernel_ = 0;
int32_t right_context_ = 0;
int32_t subsampling_factor_ = 0;
int32_t vocab_size_ = 0;
int32_t required_cache_size_ = 0;
Ort::Value attn_cache_{nullptr};
Ort::Value conv_cache_{nullptr};
... ...
... ... @@ -4,9 +4,8 @@
#include "sherpa-onnx/csrc/online-zipformer-transducer-model.h"
#include <assert.h>
#include <algorithm>
#include <cassert>
#include <memory>
#include <sstream>
#include <string>
... ...
... ... @@ -4,10 +4,8 @@
#include "sherpa-onnx/csrc/online-zipformer2-ctc-model.h"
#include <assert.h>
#include <math.h>
#include <algorithm>
#include <cassert>
#include <cmath>
#include <numeric>
#include <string>
... ... @@ -90,7 +88,6 @@ class OnlineZipformer2CtcModel::Impl {
std::vector<Ort::Value> StackStates(
std::vector<std::vector<Ort::Value>> states) const {
int32_t batch_size = static_cast<int32_t>(states.size());
int32_t num_encoders = static_cast<int32_t>(num_encoder_layers_.size());
std::vector<const Ort::Value *> buf(batch_size);
... ... @@ -168,7 +165,6 @@ class OnlineZipformer2CtcModel::Impl {
assert(states.size() == m * 6 + 2);
int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1];
int32_t num_encoders = num_encoder_layers_.size();
std::vector<std::vector<Ort::Value>> ans;
ans.resize(batch_size);
... ...
... ... @@ -4,10 +4,9 @@
#include "sherpa-onnx/csrc/online-zipformer2-transducer-model.h"
#include <assert.h>
#include <math.h>
#include <algorithm>
#include <cassert>
#include <cmath>
#include <memory>
#include <numeric>
#include <sstream>
... ...
... ... @@ -281,11 +281,12 @@ CopyableOrtValue &CopyableOrtValue::operator=(const CopyableOrtValue &other) {
return *this;
}
CopyableOrtValue::CopyableOrtValue(CopyableOrtValue &&other) {
CopyableOrtValue::CopyableOrtValue(CopyableOrtValue &&other) noexcept {
*this = std::move(other);
}
CopyableOrtValue &CopyableOrtValue::operator=(CopyableOrtValue &&other) {
CopyableOrtValue &CopyableOrtValue::operator=(
CopyableOrtValue &&other) noexcept {
if (this == &other) {
return *this;
}
... ...
... ... @@ -110,9 +110,9 @@ struct CopyableOrtValue {
CopyableOrtValue &operator=(const CopyableOrtValue &other);
CopyableOrtValue(CopyableOrtValue &&other);
CopyableOrtValue(CopyableOrtValue &&other) noexcept;
CopyableOrtValue &operator=(CopyableOrtValue &&other);
CopyableOrtValue &operator=(CopyableOrtValue &&other) noexcept;
};
std::vector<CopyableOrtValue> Convert(std::vector<Ort::Value> values);
... ...
... ... @@ -4,9 +4,8 @@
#include "sherpa-onnx/csrc/packed-sequence.h"
#include <assert.h>
#include <algorithm>
#include <cassert>
#include <numeric>
#include <utility>
... ... @@ -57,7 +56,7 @@ PackedSequence PackPaddedSequence(OrtAllocator *allocator,
int64_t max_T = p_length[indexes[0]];
int32_t sum_T = std::accumulate(p_length, p_length + n, 0);
auto sum_T = std::accumulate(p_length, p_length + n, static_cast<int64_t>(0));
std::array<int64_t, 2> data_shape{sum_T, v_shape[2]};
... ...
... ... @@ -4,9 +4,8 @@
#include "sherpa-onnx/csrc/pad-sequence.h"
#include <assert.h>
#include <algorithm>
#include <cassert>
#include <vector>
namespace sherpa_onnx {
... ...
... ... @@ -11,9 +11,8 @@
#include "sherpa-onnx/csrc/parse-options.h"
#include <ctype.h>
#include <algorithm>
#include <array>
#include <cctype>
#include <cstring>
#include <fstream>
... ... @@ -33,7 +32,7 @@ ParseOptions::ParseOptions(const std::string &prefix, ParseOptions *po)
} else {
other_parser_ = po;
}
if (po != nullptr && po->prefix_ != "") {
if (po != nullptr && !po->prefix_.empty()) {
prefix_ = po->prefix_ + std::string(".") + prefix;
} else {
prefix_ = prefix;
... ... @@ -179,10 +178,10 @@ void ParseOptions::DisableOption(const std::string &name) {
string_map_.erase(name);
}
int ParseOptions::NumArgs() const { return positional_args_.size(); }
int32_t ParseOptions::NumArgs() const { return positional_args_.size(); }
std::string ParseOptions::GetArg(int i) const {
if (i < 1 || i > static_cast<int>(positional_args_.size())) {
std::string ParseOptions::GetArg(int32_t i) const {
if (i < 1 || i > static_cast<int32_t>(positional_args_.size())) {
SHERPA_ONNX_LOGE("ParseOptions::GetArg, invalid index %d", i);
exit(-1);
}
... ... @@ -191,7 +190,7 @@ std::string ParseOptions::GetArg(int i) const {
}
// We currently do not support any other options.
enum ShellType { kBash = 0 };
enum ShellType : std::uint8_t { kBash = 0 };
// This can be changed in the code if it ever does need to be changed (as it's
// unlikely that one compilation of this tool-set would use both shells).
... ... @@ -213,7 +212,7 @@ static bool MustBeQuoted(const std::string &str, ShellType st) {
if (*c == '\0') {
return true; // Must quote empty string
} else {
const char *ok_chars[2];
std::array<const char *, 2> ok_chars{};
// These seem not to be interpreted as long as there are no other "bad"
// characters involved (e.g. "," would be interpreted as part of something
... ... @@ -229,7 +228,7 @@ static bool MustBeQuoted(const std::string &str, ShellType st) {
// are OK. All others are forbidden (this is easier since the shell
// interprets most non-alphanumeric characters).
if (!isalnum(*c)) {
const char *d;
const char *d = nullptr;
for (d = ok_chars[st]; *d != '\0'; ++d) {
if (*c == *d) break;
}
... ... @@ -269,22 +268,22 @@ static std::string QuoteAndEscape(const std::string &str, ShellType /*st*/) {
escape_str = "\\\""; // should never be accessed.
}
char buf[2];
std::array<char, 2> buf{};
buf[1] = '\0';
buf[0] = quote_char;
std::string ans = buf;
std::string ans = buf.data();
const char *c = str.c_str();
for (; *c != '\0'; ++c) {
if (*c == quote_char) {
ans += escape_str;
} else {
buf[0] = *c;
ans += buf;
ans += buf.data();
}
}
buf[0] = quote_char;
ans += buf;
ans += buf.data();
return ans;
}
... ... @@ -293,11 +292,11 @@ std::string ParseOptions::Escape(const std::string &str) {
return MustBeQuoted(str, kShellType) ? QuoteAndEscape(str, kShellType) : str;
}
int ParseOptions::Read(int argc, const char *const argv[]) {
int32_t ParseOptions::Read(int32_t argc, const char *const *argv) {
argc_ = argc;
argv_ = argv;
std::string key, value;
int i;
int32_t i = 0;
// first pass: look for config parameter, look for priority
for (i = 1; i < argc; ++i) {
... ... @@ -306,13 +305,13 @@ int ParseOptions::Read(int argc, const char *const argv[]) {
// a lone "--" marks the end of named options
break;
}
bool has_equal_sign;
bool has_equal_sign = false;
SplitLongArg(argv[i], &key, &value, &has_equal_sign);
NormalizeArgName(&key);
Trim(&value);
if (key.compare("config") == 0) {
if (key == "config") {
ReadConfigFile(value);
} else if (key.compare("help") == 0) {
} else if (key == "help") {
PrintUsage();
exit(0);
}
... ... @@ -330,7 +329,7 @@ int ParseOptions::Read(int argc, const char *const argv[]) {
double_dash_seen = true;
break;
}
bool has_equal_sign;
bool has_equal_sign = false;
SplitLongArg(argv[i], &key, &value, &has_equal_sign);
NormalizeArgName(&key);
Trim(&value);
... ... @@ -349,14 +348,14 @@ int ParseOptions::Read(int argc, const char *const argv[]) {
if ((std::strcmp(argv[i], "--") == 0) && !double_dash_seen) {
double_dash_seen = true;
} else {
positional_args_.push_back(std::string(argv[i]));
positional_args_.emplace_back(argv[i]);
}
}
// if the user did not suppress this with --print-args = false....
if (print_args_) {
std::ostringstream strm;
for (int j = 0; j < argc; ++j) strm << Escape(argv[j]) << " ";
for (int32_t j = 0; j < argc; ++j) strm << Escape(argv[j]) << " ";
strm << '\n';
SHERPA_ONNX_LOGE("%s", strm.str().c_str());
}
... ... @@ -368,14 +367,14 @@ void ParseOptions::PrintUsage(bool print_command_line /*=false*/) const {
os << '\n' << usage_ << '\n';
// first we print application-specific options
bool app_specific_header_printed = false;
for (auto it = doc_map_.begin(); it != doc_map_.end(); ++it) {
if (it->second.is_standard_ == false) { // application-specific option
for (const auto &it : doc_map_) {
if (it.second.is_standard_ == false) { // application-specific option
if (app_specific_header_printed == false) { // header was not yet printed
os << "Options:" << '\n';
app_specific_header_printed = true;
}
os << " --" << std::setw(25) << std::left << it->second.name_ << " : "
<< it->second.use_msg_ << '\n';
os << " --" << std::setw(25) << std::left << it.second.name_ << " : "
<< it.second.use_msg_ << '\n';
}
}
if (app_specific_header_printed == true) {
... ... @@ -384,17 +383,17 @@ void ParseOptions::PrintUsage(bool print_command_line /*=false*/) const {
// then the standard options
os << "Standard options:" << '\n';
for (auto it = doc_map_.begin(); it != doc_map_.end(); ++it) {
if (it->second.is_standard_ == true) { // we have standard option
os << " --" << std::setw(25) << std::left << it->second.name_ << " : "
<< it->second.use_msg_ << '\n';
for (const auto &it : doc_map_) {
if (it.second.is_standard_ == true) { // we have standard option
os << " --" << std::setw(25) << std::left << it.second.name_ << " : "
<< it.second.use_msg_ << '\n';
}
}
os << '\n';
if (print_command_line) {
std::ostringstream strm;
strm << "Command line was: ";
for (int j = 0; j < argc_; ++j) strm << Escape(argv_[j]) << " ";
for (int32_t j = 0; j < argc_; ++j) strm << Escape(argv_[j]) << " ";
strm << '\n';
os << strm.str();
}
... ... @@ -405,9 +404,9 @@ void ParseOptions::PrintUsage(bool print_command_line /*=false*/) const {
void ParseOptions::PrintConfig(std::ostream &os) const {
os << '\n' << "[[ Configuration of UI-Registered options ]]" << '\n';
std::string key;
for (auto it = doc_map_.begin(); it != doc_map_.end(); ++it) {
key = it->first;
os << it->second.name_ << " = ";
for (const auto &it : doc_map_) {
key = it.first;
os << it.second.name_ << " = ";
if (bool_map_.end() != bool_map_.find(key)) {
os << (*bool_map_.at(key) ? "true" : "false");
} else if (int_map_.end() != int_map_.find(key)) {
... ... @@ -442,13 +441,13 @@ void ParseOptions::ReadConfigFile(const std::string &filename) {
while (std::getline(is, line)) {
++line_number;
// trim out the comments
size_t pos;
if ((pos = line.find_first_of('#')) != std::string::npos) {
size_t pos = line.find_first_of('#');
if (pos != std::string::npos) {
line.erase(pos);
}
// skip empty lines
Trim(&line);
if (line.length() == 0) continue;
if (line.empty()) continue;
if (line.substr(0, 2) != "--") {
SHERPA_ONNX_LOGE(
... ... @@ -461,7 +460,7 @@ void ParseOptions::ReadConfigFile(const std::string &filename) {
}
// parse option
bool has_equal_sign;
bool has_equal_sign = false;
SplitLongArg(line, &key, &value, &has_equal_sign);
NormalizeArgName(&key);
Trim(&value);
... ... @@ -527,7 +526,7 @@ void ParseOptions::Trim(std::string *str) const {
bool ParseOptions::SetOption(const std::string &key, const std::string &value,
bool has_equal_sign) {
if (bool_map_.end() != bool_map_.find(key)) {
if (has_equal_sign && value == "") {
if (has_equal_sign && value.empty()) {
SHERPA_ONNX_LOGE("Invalid option --%s=", key.c_str());
exit(-1);
}
... ... @@ -557,12 +556,10 @@ bool ParseOptions::ToBool(std::string str) const {
std::transform(str.begin(), str.end(), str.begin(), ::tolower);
// allow "" as a valid option for "true", so that --x is the same as --x=true
if ((str.compare("true") == 0) || (str.compare("t") == 0) ||
(str.compare("1") == 0) || (str.compare("") == 0)) {
if (str == "true" || str == "t" || str == "1" || str.empty()) {
return true;
}
if ((str.compare("false") == 0) || (str.compare("f") == 0) ||
(str.compare("0") == 0)) {
if (str == "false" || str == "f" || str == "0") {
return false;
}
// if it is neither true nor false:
... ... @@ -593,7 +590,7 @@ uint32_t ParseOptions::ToUint(const std::string &str) const {
}
float ParseOptions::ToFloat(const std::string &str) const {
float ret;
float ret = 0;
if (!ConvertStringToReal(str, &ret)) {
SHERPA_ONNX_LOGE("Invalid floating-point option \"%s\"", str.c_str());
exit(-1);
... ... @@ -602,7 +599,7 @@ float ParseOptions::ToFloat(const std::string &str) const {
}
double ParseOptions::ToDouble(const std::string &str) const {
double ret;
double ret = 0;
if (!ConvertStringToReal(str, &ret)) {
SHERPA_ONNX_LOGE("Invalid floating-point option \"%s\"", str.c_str());
exit(-1);
... ...
... ... @@ -37,7 +37,7 @@ static std::unordered_map<char32_t, int32_t> ReadTokens(std::istream &is) {
std::string sym;
std::u32string s;
int32_t id;
int32_t id = 0;
while (std::getline(is, line)) {
std::istringstream iss(line);
iss >> sym;
... ...
... ... @@ -24,10 +24,9 @@
#include "sherpa-onnx/csrc/resample.h"
#include <assert.h>
#include <math.h>
#include <stdio.h>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <type_traits>
... ... @@ -54,8 +53,8 @@ I Gcd(I m, I n) {
}
// could use compile-time assertion
// but involves messing with complex template stuff.
static_assert(std::is_integral<I>::value, "");
while (1) {
static_assert(std::is_integral_v<I>);
while (true) {
m %= n;
if (m == 0) return (n > 0 ? n : -n);
n %= m;
... ... @@ -139,10 +138,10 @@ void LinearResample::SetIndexesAndWeights() {
in the header as h(t) = f(t)g(t), evaluated at t.
*/
float LinearResample::FilterFunc(float t) const {
float window, // raised-cosine (Hanning) window of width
// num_zeros_/2*filter_cutoff_
filter; // sinc filter function
if (fabs(t) < num_zeros_ / (2.0 * filter_cutoff_))
float window = 0, // raised-cosine (Hanning) window of width
// num_zeros_/2*filter_cutoff_
filter = 0; // sinc filter function
if (std::fabs(t) < num_zeros_ / (2.0 * filter_cutoff_))
window = 0.5 * (1 + cos(M_2PI * filter_cutoff_ / num_zeros_ * t));
else
window = 0.0; // outside support of window function
... ... @@ -172,15 +171,15 @@ void LinearResample::Resample(const float *input, int32_t input_dim, bool flush,
// of it we are producing here.
for (int64_t samp_out = output_sample_offset_; samp_out < tot_output_samp;
samp_out++) {
int64_t first_samp_in;
int32_t samp_out_wrapped;
int64_t first_samp_in = 0;
int32_t samp_out_wrapped = 0;
GetIndexes(samp_out, &first_samp_in, &samp_out_wrapped);
const std::vector<float> &weights = weights_[samp_out_wrapped];
// first_input_index is the first index into "input" that we have a weight
// for.
int32_t first_input_index =
static_cast<int32_t>(first_samp_in - input_sample_offset_);
float this_output;
float this_output = 0;
if (first_input_index >= 0 &&
first_input_index + static_cast<int32_t>(weights.size()) <= input_dim) {
this_output =
... ... @@ -239,7 +238,7 @@ int64_t LinearResample::GetNumOutputSamples(int64_t input_num_samp,
// largest integer in the interval [ 0, 2 - 0.9 ) are the same (both one).
// So when we're subtracting the window-width we can ignore the fractional
// part.
int32_t window_width_ticks = floor(window_width * tick_freq);
int32_t window_width_ticks = std::floor(window_width * tick_freq);
// The time-period of the output that we can sample gets reduced
// by the window-width (which is actually the distance from the
// center to the edge of the windowing function) if we're not
... ... @@ -287,7 +286,7 @@ void LinearResample::SetRemainder(const float *input, int32_t input_dim) {
// that are "in the past" relative to the beginning of the latest
// input... anyway, storing more remainder than needed is not harmful.
int32_t max_remainder_needed =
ceil(samp_rate_in_ * num_zeros_ / filter_cutoff_);
std::ceil(samp_rate_in_ * num_zeros_ / filter_cutoff_);
input_remainder_.resize(max_remainder_needed);
for (int32_t index = -static_cast<int32_t>(input_remainder_.size());
index < 0; index++) {
... ...
... ... @@ -130,11 +130,11 @@ class LinearResample {
// the following variables keep track of where we are in a particular signal,
// if it is being provided over multiple calls to Resample().
int64_t input_sample_offset_; ///< The number of input samples we have
///< already received for this signal
///< (including anything in remainder_)
int64_t output_sample_offset_; ///< The number of samples we have already
///< output for this signal.
int64_t input_sample_offset_ = 0; ///< The number of input samples we have
///< already received for this signal
///< (including anything in remainder_)
int64_t output_sample_offset_ = 0; ///< The number of samples we have already
///< output for this signal.
std::vector<float> input_remainder_; ///< A small trailing part of the
///< previously seen input signal.
};
... ...
... ... @@ -21,14 +21,14 @@
namespace sherpa_onnx {
static void OrtStatusFailure(OrtStatus *status, const char *s) {
const auto &api = Ort::GetApi();
const char *msg = api.GetErrorMessage(status);
SHERPA_ONNX_LOGE(
const auto &api = Ort::GetApi();
const char *msg = api.GetErrorMessage(status);
SHERPA_ONNX_LOGE(
"Failed to enable TensorRT : %s."
"Available providers: %s. Fallback to cuda", msg, s);
api.ReleaseStatus(status);
"Available providers: %s. Fallback to cuda",
msg, s);
api.ReleaseStatus(status);
}
static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
... ... @@ -65,29 +65,28 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
}
case Provider::kTRT: {
struct TrtPairs {
const char* op_keys;
const char* op_values;
const char *op_keys;
const char *op_values;
};
std::vector<TrtPairs> trt_options = {
{"device_id", "0"},
{"trt_max_workspace_size", "2147483648"},
{"trt_max_partition_iterations", "10"},
{"trt_min_subgraph_size", "5"},
{"trt_fp16_enable", "0"},
{"trt_detailed_build_log", "0"},
{"trt_engine_cache_enable", "1"},
{"trt_engine_cache_path", "."},
{"trt_timing_cache_enable", "1"},
{"trt_timing_cache_path", "."}
};
{"device_id", "0"},
{"trt_max_workspace_size", "2147483648"},
{"trt_max_partition_iterations", "10"},
{"trt_min_subgraph_size", "5"},
{"trt_fp16_enable", "0"},
{"trt_detailed_build_log", "0"},
{"trt_engine_cache_enable", "1"},
{"trt_engine_cache_path", "."},
{"trt_timing_cache_enable", "1"},
{"trt_timing_cache_path", "."}};
// ToDo : Trt configs
// "trt_int8_enable"
// "trt_int8_use_native_calibration_table"
// "trt_dump_subgraphs"
std::vector<const char*> option_keys, option_values;
for (const TrtPairs& pair : trt_options) {
std::vector<const char *> option_keys, option_values;
for (const TrtPairs &pair : trt_options) {
option_keys.emplace_back(pair.op_keys);
option_values.emplace_back(pair.op_values);
}
... ... @@ -95,19 +94,23 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
std::vector<std::string> available_providers =
Ort::GetAvailableProviders();
if (std::find(available_providers.begin(), available_providers.end(),
"TensorrtExecutionProvider") != available_providers.end()) {
const auto& api = Ort::GetApi();
"TensorrtExecutionProvider") != available_providers.end()) {
const auto &api = Ort::GetApi();
OrtTensorRTProviderOptionsV2* tensorrt_options;
OrtStatus *statusC = api.CreateTensorRTProviderOptions(
&tensorrt_options);
OrtTensorRTProviderOptionsV2 *tensorrt_options = nullptr;
OrtStatus *statusC =
api.CreateTensorRTProviderOptions(&tensorrt_options);
OrtStatus *statusU = api.UpdateTensorRTProviderOptions(
tensorrt_options, option_keys.data(), option_values.data(),
option_keys.size());
tensorrt_options, option_keys.data(), option_values.data(),
option_keys.size());
sess_opts.AppendExecutionProvider_TensorRT_V2(*tensorrt_options);
if (statusC) { OrtStatusFailure(statusC, os.str().c_str()); }
if (statusU) { OrtStatusFailure(statusU, os.str().c_str()); }
if (statusC) {
OrtStatusFailure(statusC, os.str().c_str());
}
if (statusU) {
OrtStatusFailure(statusU, os.str().c_str());
}
api.ReleaseTensorRTProviderOptions(tensorrt_options);
}
... ...
... ... @@ -20,11 +20,11 @@ class SileroVadModel::Impl {
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
allocator_{},
sample_rate_(config.sample_rate) {
auto buf = ReadFile(config.silero_vad.model);
Init(buf.data(), buf.size());
sample_rate_ = config.sample_rate;
if (sample_rate_ != 16000) {
SHERPA_ONNX_LOGE("Expected sample rate 16000. Given: %d",
config.sample_rate);
... ...
... ... @@ -4,9 +4,8 @@
#include "sherpa-onnx/csrc/slice.h"
#include <assert.h>
#include <algorithm>
#include <cassert>
#include <vector>
namespace sherpa_onnx {
... ...
... ... @@ -12,7 +12,7 @@ namespace sherpa_onnx {
namespace {
enum class ModelType {
enum class ModelType : std::uint8_t {
kWeSpeaker,
k3dSpeaker,
kNeMo,
... ...
... ... @@ -122,7 +122,7 @@ class SpeakerEmbeddingManager::Impl {
Eigen::VectorXf scores = embedding_matrix_ * v;
Eigen::VectorXf::Index max_index;
Eigen::VectorXf::Index max_index = 0;
float max_score = scores.maxCoeff(&max_index);
if (max_score < threshold) {
return {};
... ... @@ -178,11 +178,12 @@ class SpeakerEmbeddingManager::Impl {
std::vector<std::string> GetAllSpeakers() const {
std::vector<std::string> all_speakers;
all_speakers.reserve(name2row_.size());
for (const auto &p : name2row_) {
all_speakers.push_back(p.first);
}
std::stable_sort(all_speakers.begin(), all_speakers.end());
std::sort(all_speakers.begin(), all_speakers.end());
return all_speakers;
}
... ...
... ... @@ -18,7 +18,7 @@ namespace sherpa_onnx {
namespace {
enum class ModelType {
enum class ModelType : std::uint8_t {
kWhisper,
kUnknown,
};
... ...
... ... @@ -71,8 +71,8 @@ Ort::Value Stack(OrtAllocator *allocator,
T *dst = ans.GetTensorMutableData<T>();
for (int32_t i = 0; i != leading_size; ++i) {
for (int32_t n = 0; n != static_cast<int32_t>(values.size()); ++n) {
const T *src = values[n]->GetTensorData<T>();
for (auto value : values) {
const T *src = value->GetTensorData<T>();
src += i * trailing_size;
std::copy(src, src + trailing_size, dst);
... ...
... ... @@ -36,7 +36,7 @@ SymbolTable::SymbolTable(AAssetManager *mgr, const std::string &filename) {
void SymbolTable::Init(std::istream &is) {
std::string sym;
int32_t id;
int32_t id = 0;
while (is >> sym >> id) {
#if 0
// we disable the test here since for some multi-lingual BPE models
... ...
... ... @@ -5,9 +5,8 @@
#include "sherpa-onnx/csrc/text-utils.h"
#include <assert.h>
#include <algorithm>
#include <cassert>
#include <cctype>
#include <cstdint>
#include <limits>
... ...
... ... @@ -151,7 +151,6 @@ void TransducerKeywordDecoder::Decode(
if (matched) {
float ys_prob = 0.0;
int32_t length = best_hyp.ys_probs.size();
for (int32_t i = 0; i < matched_state->level; ++i) {
ys_prob += best_hyp.ys_probs[i];
}
... ...
... ... @@ -4,9 +4,8 @@
#include "sherpa-onnx/csrc/transpose.h"
#include <assert.h>
#include <algorithm>
#include <cassert>
#include <vector>
namespace sherpa_onnx {
... ...
... ... @@ -4,9 +4,8 @@
#include "sherpa-onnx/csrc/unbind.h"
#include <assert.h>
#include <algorithm>
#include <cassert>
#include <functional>
#include <numeric>
#include <utility>
... ...
... ... @@ -30,7 +30,6 @@ static bool EncodeBase(const std::vector<std::string> &lines,
std::vector<float> tmp_thresholds;
std::vector<std::string> tmp_phrases;
std::string line;
std::string word;
bool has_scores = false;
bool has_thresholds = false;
... ... @@ -72,6 +71,7 @@ static bool EncodeBase(const std::vector<std::string> &lines,
}
}
ids->push_back(std::move(tmp_ids));
tmp_ids = {};
tmp_scores.push_back(score);
tmp_phrases.push_back(phrase);
tmp_thresholds.push_back(threshold);
... ...
... ... @@ -100,13 +100,13 @@ struct WaveHeader {
int32_t subchunk2_id; // a tag of this chunk
int32_t subchunk2_size; // size of subchunk2
};
static_assert(sizeof(WaveHeader) == 44, "");
static_assert(sizeof(WaveHeader) == 44);
// Read a wave file of mono-channel.
// Return its samples normalized to the range [-1, 1).
std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
bool *is_ok) {
WaveHeader header;
WaveHeader header{};
is.read(reinterpret_cast<char *>(&header), sizeof(header));
if (!is) {
*is_ok = false;
... ...
... ... @@ -37,7 +37,7 @@ struct WaveHeader {
bool WriteWave(const std::string &filename, int32_t sampling_rate,
const float *samples, int32_t n) {
WaveHeader header;
WaveHeader header{};
header.chunk_id = 0x46464952; // FFIR
header.format = 0x45564157; // EVAW
header.subchunk1_id = 0x20746d66; // "fmt "
... ...