Committed by
GitHub
Refactor offline recognizer. (#94)
* Refactor offline recognizer. The purpose is to make it easier to support different types of models.
正在显示
9 个修改的文件
包含
283 行增加
和
134 行删除
| @@ -6,11 +6,12 @@ set(sources | @@ -6,11 +6,12 @@ set(sources | ||
| 6 | features.cc | 6 | features.cc |
| 7 | file-utils.cc | 7 | file-utils.cc |
| 8 | hypothesis.cc | 8 | hypothesis.cc |
| 9 | + offline-recognizer-impl.cc | ||
| 10 | + offline-recognizer.cc | ||
| 9 | offline-stream.cc | 11 | offline-stream.cc |
| 10 | offline-transducer-greedy-search-decoder.cc | 12 | offline-transducer-greedy-search-decoder.cc |
| 11 | offline-transducer-model-config.cc | 13 | offline-transducer-model-config.cc |
| 12 | offline-transducer-model.cc | 14 | offline-transducer-model.cc |
| 13 | - offline-recognizer.cc | ||
| 14 | online-lstm-transducer-model.cc | 15 | online-lstm-transducer-model.cc |
| 15 | online-recognizer.cc | 16 | online-recognizer.cc |
| 16 | online-stream.cc | 17 | online-stream.cc |
| @@ -23,36 +23,55 @@ | @@ -23,36 +23,55 @@ | ||
| 23 | } while (0) | 23 | } while (0) |
| 24 | #endif | 24 | #endif |
| 25 | 25 | ||
| 26 | +// Read an integer | ||
| 26 | #define SHERPA_ONNX_READ_META_DATA(dst, src_key) \ | 27 | #define SHERPA_ONNX_READ_META_DATA(dst, src_key) \ |
| 27 | do { \ | 28 | do { \ |
| 28 | auto value = \ | 29 | auto value = \ |
| 29 | meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ | 30 | meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ |
| 30 | if (!value) { \ | 31 | if (!value) { \ |
| 31 | - fprintf(stderr, "%s does not exist in the metadata\n", src_key); \ | 32 | + SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \ |
| 32 | exit(-1); \ | 33 | exit(-1); \ |
| 33 | } \ | 34 | } \ |
| 34 | \ | 35 | \ |
| 35 | dst = atoi(value.get()); \ | 36 | dst = atoi(value.get()); \ |
| 36 | if (dst <= 0) { \ | 37 | if (dst <= 0) { \ |
| 37 | - fprintf(stderr, "Invalid value %d for %s\n", dst, src_key); \ | 38 | + SHERPA_ONNX_LOGE("Invalid value %d for %s", dst, src_key); \ |
| 38 | exit(-1); \ | 39 | exit(-1); \ |
| 39 | } \ | 40 | } \ |
| 40 | } while (0) | 41 | } while (0) |
| 41 | 42 | ||
| 42 | -#define SHERPA_ONNX_READ_META_DATA_VEC(dst, src_key) \ | ||
| 43 | - do { \ | ||
| 44 | - auto value = \ | ||
| 45 | - meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ | ||
| 46 | - if (!value) { \ | ||
| 47 | - fprintf(stderr, "%s does not exist in the metadata\n", src_key); \ | ||
| 48 | - exit(-1); \ | ||
| 49 | - } \ | ||
| 50 | - \ | ||
| 51 | - bool ret = SplitStringToIntegers(value.get(), ",", true, &dst); \ | ||
| 52 | - if (!ret) { \ | ||
| 53 | - fprintf(stderr, "Invalid value %s for %s\n", value.get(), src_key); \ | ||
| 54 | - exit(-1); \ | ||
| 55 | - } \ | 43 | +// Read a vector of integers |
| 44 | +#define SHERPA_ONNX_READ_META_DATA_VEC(dst, src_key) \ | ||
| 45 | + do { \ | ||
| 46 | + auto value = \ | ||
| 47 | + meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ | ||
| 48 | + if (!value) { \ | ||
| 49 | + SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \ | ||
| 50 | + exit(-1); \ | ||
| 51 | + } \ | ||
| 52 | + \ | ||
| 53 | + bool ret = SplitStringToIntegers(value.get(), ",", true, &dst); \ | ||
| 54 | + if (!ret) { \ | ||
| 55 | + SHERPA_ONNX_LOGE("Invalid value %s for %s", value.get(), src_key); \ | ||
| 56 | + exit(-1); \ | ||
| 57 | + } \ | ||
| 58 | + } while (0) | ||
| 59 | + | ||
| 60 | +// Read a string | ||
| 61 | +#define SHERPA_ONNX_READ_META_DATA_STR(dst, src_key) \ | ||
| 62 | + do { \ | ||
| 63 | + auto value = \ | ||
| 64 | + meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ | ||
| 65 | + if (!value) { \ | ||
| 66 | + SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \ | ||
| 67 | + exit(-1); \ | ||
| 68 | + } \ | ||
| 69 | + \ | ||
| 70 | + dst = value.get(); \ | ||
| 71 | + if (dst.empty()) { \ | ||
| 72 | + SHERPA_ONNX_LOGE("Invalid value for %s\n", src_key); \ | ||
| 73 | + exit(-1); \ | ||
| 74 | + } \ | ||
| 56 | } while (0) | 75 | } while (0) |
| 57 | 76 | ||
| 58 | #endif // SHERPA_ONNX_CSRC_MACROS_H_ | 77 | #endif // SHERPA_ONNX_CSRC_MACROS_H_ |
sherpa-onnx/csrc/offline-recognizer-impl.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-recognizer-impl.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-recognizer-impl.h" | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | + | ||
| 9 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 10 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 11 | +#include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h" | ||
| 12 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 13 | +#include "sherpa-onnx/csrc/text-utils.h" | ||
| 14 | + | ||
| 15 | +namespace sherpa_onnx { | ||
| 16 | + | ||
| 17 | +std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | ||
| 18 | + const OfflineRecognizerConfig &config) { | ||
| 19 | + Ort::Env env; | ||
| 20 | + | ||
| 21 | + Ort::SessionOptions sess_opts; | ||
| 22 | + auto buf = ReadFile(config.model_config.encoder_filename); | ||
| 23 | + | ||
| 24 | + auto encoder_sess = | ||
| 25 | + std::make_unique<Ort::Session>(env, buf.data(), buf.size(), sess_opts); | ||
| 26 | + | ||
| 27 | + Ort::ModelMetadata meta_data = encoder_sess->GetModelMetadata(); | ||
| 28 | + | ||
| 29 | + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below | ||
| 30 | + | ||
| 31 | + std::string model_type; | ||
| 32 | + SHERPA_ONNX_READ_META_DATA_STR(model_type, "model_type"); | ||
| 33 | + | ||
| 34 | + if (model_type == "conformer") { | ||
| 35 | + return std::make_unique<OfflineRecognizerTransducerImpl>(config); | ||
| 36 | + } | ||
| 37 | + | ||
| 38 | + SHERPA_ONNX_LOGE("Unsupported model_type: %s\n", model_type.c_str()); | ||
| 39 | + | ||
| 40 | + exit(-1); | ||
| 41 | +} | ||
| 42 | + | ||
| 43 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/offline-recognizer-impl.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-recognizer-impl.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_IMPL_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_IMPL_H_ | ||
| 7 | + | ||
| 8 | +#include <memory> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/csrc/offline-recognizer.h" | ||
| 11 | +#include "sherpa-onnx/csrc/offline-stream.h" | ||
| 12 | + | ||
| 13 | +namespace sherpa_onnx { | ||
| 14 | + | ||
| 15 | +class OfflineRecognizerImpl { | ||
| 16 | + public: | ||
| 17 | + static std::unique_ptr<OfflineRecognizerImpl> Create( | ||
| 18 | + const OfflineRecognizerConfig &config); | ||
| 19 | + | ||
| 20 | + virtual ~OfflineRecognizerImpl() = default; | ||
| 21 | + | ||
| 22 | + virtual std::unique_ptr<OfflineStream> CreateStream() const = 0; | ||
| 23 | + | ||
| 24 | + virtual void DecodeStreams(OfflineStream **ss, int32_t n) const = 0; | ||
| 25 | +}; | ||
| 26 | + | ||
| 27 | +} // namespace sherpa_onnx | ||
| 28 | + | ||
| 29 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_IMPL_H_ |
| 1 | +// sherpa-onnx/csrc/offline-recognizer-transducer-impl.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2022 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_IMPL_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_IMPL_H_ | ||
| 7 | + | ||
| 8 | +#include <memory> | ||
| 9 | +#include <string> | ||
| 10 | +#include <utility> | ||
| 11 | +#include <vector> | ||
| 12 | + | ||
| 13 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 14 | +#include "sherpa-onnx/csrc/offline-recognizer-impl.h" | ||
| 15 | +#include "sherpa-onnx/csrc/offline-recognizer.h" | ||
| 16 | +#include "sherpa-onnx/csrc/offline-transducer-decoder.h" | ||
| 17 | +#include "sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h" | ||
| 18 | +#include "sherpa-onnx/csrc/offline-transducer-model.h" | ||
| 19 | +#include "sherpa-onnx/csrc/pad-sequence.h" | ||
| 20 | +#include "sherpa-onnx/csrc/symbol-table.h" | ||
| 21 | + | ||
| 22 | +namespace sherpa_onnx { | ||
| 23 | + | ||
| 24 | +static OfflineRecognitionResult Convert( | ||
| 25 | + const OfflineTransducerDecoderResult &src, const SymbolTable &sym_table, | ||
| 26 | + int32_t frame_shift_ms, int32_t subsampling_factor) { | ||
| 27 | + OfflineRecognitionResult r; | ||
| 28 | + r.tokens.reserve(src.tokens.size()); | ||
| 29 | + r.timestamps.reserve(src.timestamps.size()); | ||
| 30 | + | ||
| 31 | + std::string text; | ||
| 32 | + for (auto i : src.tokens) { | ||
| 33 | + auto sym = sym_table[i]; | ||
| 34 | + text.append(sym); | ||
| 35 | + | ||
| 36 | + r.tokens.push_back(std::move(sym)); | ||
| 37 | + } | ||
| 38 | + r.text = std::move(text); | ||
| 39 | + | ||
| 40 | + float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor; | ||
| 41 | + for (auto t : src.timestamps) { | ||
| 42 | + float time = frame_shift_s * t; | ||
| 43 | + r.timestamps.push_back(time); | ||
| 44 | + } | ||
| 45 | + | ||
| 46 | + return r; | ||
| 47 | +} | ||
| 48 | + | ||
| 49 | +class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | ||
| 50 | + public: | ||
| 51 | + explicit OfflineRecognizerTransducerImpl( | ||
| 52 | + const OfflineRecognizerConfig &config) | ||
| 53 | + : config_(config), | ||
| 54 | + symbol_table_(config_.model_config.tokens), | ||
| 55 | + model_(std::make_unique<OfflineTransducerModel>(config_.model_config)) { | ||
| 56 | + if (config_.decoding_method == "greedy_search") { | ||
| 57 | + decoder_ = | ||
| 58 | + std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get()); | ||
| 59 | + } else if (config_.decoding_method == "modified_beam_search") { | ||
| 60 | + SHERPA_ONNX_LOGE("TODO: modified_beam_search is to be implemented"); | ||
| 61 | + exit(-1); | ||
| 62 | + } else { | ||
| 63 | + SHERPA_ONNX_LOGE("Unsupported decoding method: %s", | ||
| 64 | + config_.decoding_method.c_str()); | ||
| 65 | + exit(-1); | ||
| 66 | + } | ||
| 67 | + } | ||
| 68 | + | ||
| 69 | + std::unique_ptr<OfflineStream> CreateStream() const override { | ||
| 70 | + return std::make_unique<OfflineStream>(config_.feat_config); | ||
| 71 | + } | ||
| 72 | + | ||
| 73 | + void DecodeStreams(OfflineStream **ss, int32_t n) const override { | ||
| 74 | + auto memory_info = | ||
| 75 | + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 76 | + | ||
| 77 | + int32_t feat_dim = ss[0]->FeatureDim(); | ||
| 78 | + | ||
| 79 | + std::vector<Ort::Value> features; | ||
| 80 | + | ||
| 81 | + features.reserve(n); | ||
| 82 | + | ||
| 83 | + std::vector<std::vector<float>> features_vec(n); | ||
| 84 | + std::vector<int64_t> features_length_vec(n); | ||
| 85 | + for (int32_t i = 0; i != n; ++i) { | ||
| 86 | + auto f = ss[i]->GetFrames(); | ||
| 87 | + int32_t num_frames = f.size() / feat_dim; | ||
| 88 | + | ||
| 89 | + features_length_vec[i] = num_frames; | ||
| 90 | + features_vec[i] = std::move(f); | ||
| 91 | + | ||
| 92 | + std::array<int64_t, 2> shape = {num_frames, feat_dim}; | ||
| 93 | + | ||
| 94 | + Ort::Value x = Ort::Value::CreateTensor( | ||
| 95 | + memory_info, features_vec[i].data(), features_vec[i].size(), | ||
| 96 | + shape.data(), shape.size()); | ||
| 97 | + features.push_back(std::move(x)); | ||
| 98 | + } | ||
| 99 | + | ||
| 100 | + std::vector<const Ort::Value *> features_pointer(n); | ||
| 101 | + for (int32_t i = 0; i != n; ++i) { | ||
| 102 | + features_pointer[i] = &features[i]; | ||
| 103 | + } | ||
| 104 | + | ||
| 105 | + std::array<int64_t, 1> features_length_shape = {n}; | ||
| 106 | + Ort::Value x_length = Ort::Value::CreateTensor( | ||
| 107 | + memory_info, features_length_vec.data(), n, | ||
| 108 | + features_length_shape.data(), features_length_shape.size()); | ||
| 109 | + | ||
| 110 | + Ort::Value x = PadSequence(model_->Allocator(), features_pointer, | ||
| 111 | + -23.025850929940457f); | ||
| 112 | + | ||
| 113 | + auto t = model_->RunEncoder(std::move(x), std::move(x_length)); | ||
| 114 | + auto results = decoder_->Decode(std::move(t.first), std::move(t.second)); | ||
| 115 | + | ||
| 116 | + int32_t frame_shift_ms = 10; | ||
| 117 | + for (int32_t i = 0; i != n; ++i) { | ||
| 118 | + auto r = Convert(results[i], symbol_table_, frame_shift_ms, | ||
| 119 | + model_->SubsamplingFactor()); | ||
| 120 | + | ||
| 121 | + ss[i]->SetResult(r); | ||
| 122 | + } | ||
| 123 | + } | ||
| 124 | + | ||
| 125 | + private: | ||
| 126 | + OfflineRecognizerConfig config_; | ||
| 127 | + SymbolTable symbol_table_; | ||
| 128 | + std::unique_ptr<OfflineTransducerModel> model_; | ||
| 129 | + std::unique_ptr<OfflineTransducerDecoder> decoder_; | ||
| 130 | +}; | ||
| 131 | + | ||
| 132 | +} // namespace sherpa_onnx | ||
| 133 | + | ||
| 134 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_IMPL_H_ |
| @@ -5,42 +5,11 @@ | @@ -5,42 +5,11 @@ | ||
| 5 | #include "sherpa-onnx/csrc/offline-recognizer.h" | 5 | #include "sherpa-onnx/csrc/offline-recognizer.h" |
| 6 | 6 | ||
| 7 | #include <memory> | 7 | #include <memory> |
| 8 | -#include <utility> | ||
| 9 | 8 | ||
| 10 | -#include "sherpa-onnx/csrc/macros.h" | ||
| 11 | -#include "sherpa-onnx/csrc/offline-transducer-decoder.h" | ||
| 12 | -#include "sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h" | ||
| 13 | -#include "sherpa-onnx/csrc/offline-transducer-model.h" | ||
| 14 | -#include "sherpa-onnx/csrc/pad-sequence.h" | ||
| 15 | -#include "sherpa-onnx/csrc/symbol-table.h" | 9 | +#include "sherpa-onnx/csrc/offline-recognizer-impl.h" |
| 16 | 10 | ||
| 17 | namespace sherpa_onnx { | 11 | namespace sherpa_onnx { |
| 18 | 12 | ||
| 19 | -static OfflineRecognitionResult Convert( | ||
| 20 | - const OfflineTransducerDecoderResult &src, const SymbolTable &sym_table, | ||
| 21 | - int32_t frame_shift_ms, int32_t subsampling_factor) { | ||
| 22 | - OfflineRecognitionResult r; | ||
| 23 | - r.tokens.reserve(src.tokens.size()); | ||
| 24 | - r.timestamps.reserve(src.timestamps.size()); | ||
| 25 | - | ||
| 26 | - std::string text; | ||
| 27 | - for (auto i : src.tokens) { | ||
| 28 | - auto sym = sym_table[i]; | ||
| 29 | - text.append(sym); | ||
| 30 | - | ||
| 31 | - r.tokens.push_back(std::move(sym)); | ||
| 32 | - } | ||
| 33 | - r.text = std::move(text); | ||
| 34 | - | ||
| 35 | - float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor; | ||
| 36 | - for (auto t : src.timestamps) { | ||
| 37 | - float time = frame_shift_s * t; | ||
| 38 | - r.timestamps.push_back(time); | ||
| 39 | - } | ||
| 40 | - | ||
| 41 | - return r; | ||
| 42 | -} | ||
| 43 | - | ||
| 44 | void OfflineRecognizerConfig::Register(ParseOptions *po) { | 13 | void OfflineRecognizerConfig::Register(ParseOptions *po) { |
| 45 | feat_config.Register(po); | 14 | feat_config.Register(po); |
| 46 | model_config.Register(po); | 15 | model_config.Register(po); |
| @@ -65,90 +34,8 @@ std::string OfflineRecognizerConfig::ToString() const { | @@ -65,90 +34,8 @@ std::string OfflineRecognizerConfig::ToString() const { | ||
| 65 | return os.str(); | 34 | return os.str(); |
| 66 | } | 35 | } |
| 67 | 36 | ||
| 68 | -class OfflineRecognizer::Impl { | ||
| 69 | - public: | ||
| 70 | - explicit Impl(const OfflineRecognizerConfig &config) | ||
| 71 | - : config_(config), | ||
| 72 | - symbol_table_(config_.model_config.tokens), | ||
| 73 | - model_(std::make_unique<OfflineTransducerModel>(config_.model_config)) { | ||
| 74 | - if (config_.decoding_method == "greedy_search") { | ||
| 75 | - decoder_ = | ||
| 76 | - std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get()); | ||
| 77 | - } else if (config_.decoding_method == "modified_beam_search") { | ||
| 78 | - SHERPA_ONNX_LOGE("TODO: modified_beam_search is to be implemented"); | ||
| 79 | - exit(-1); | ||
| 80 | - } else { | ||
| 81 | - SHERPA_ONNX_LOGE("Unsupported decoding method: %s", | ||
| 82 | - config_.decoding_method.c_str()); | ||
| 83 | - exit(-1); | ||
| 84 | - } | ||
| 85 | - } | ||
| 86 | - | ||
| 87 | - std::unique_ptr<OfflineStream> CreateStream() const { | ||
| 88 | - return std::make_unique<OfflineStream>(config_.feat_config); | ||
| 89 | - } | ||
| 90 | - | ||
| 91 | - void DecodeStreams(OfflineStream **ss, int32_t n) const { | ||
| 92 | - auto memory_info = | ||
| 93 | - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 94 | - | ||
| 95 | - int32_t feat_dim = ss[0]->FeatureDim(); | ||
| 96 | - | ||
| 97 | - std::vector<Ort::Value> features; | ||
| 98 | - | ||
| 99 | - features.reserve(n); | ||
| 100 | - | ||
| 101 | - std::vector<std::vector<float>> features_vec(n); | ||
| 102 | - std::vector<int64_t> features_length_vec(n); | ||
| 103 | - for (int32_t i = 0; i != n; ++i) { | ||
| 104 | - auto f = ss[i]->GetFrames(); | ||
| 105 | - int32_t num_frames = f.size() / feat_dim; | ||
| 106 | - | ||
| 107 | - features_length_vec[i] = num_frames; | ||
| 108 | - features_vec[i] = std::move(f); | ||
| 109 | - | ||
| 110 | - std::array<int64_t, 2> shape = {num_frames, feat_dim}; | ||
| 111 | - | ||
| 112 | - Ort::Value x = Ort::Value::CreateTensor( | ||
| 113 | - memory_info, features_vec[i].data(), features_vec[i].size(), | ||
| 114 | - shape.data(), shape.size()); | ||
| 115 | - features.push_back(std::move(x)); | ||
| 116 | - } | ||
| 117 | - | ||
| 118 | - std::vector<const Ort::Value *> features_pointer(n); | ||
| 119 | - for (int32_t i = 0; i != n; ++i) { | ||
| 120 | - features_pointer[i] = &features[i]; | ||
| 121 | - } | ||
| 122 | - | ||
| 123 | - std::array<int64_t, 1> features_length_shape = {n}; | ||
| 124 | - Ort::Value x_length = Ort::Value::CreateTensor( | ||
| 125 | - memory_info, features_length_vec.data(), n, | ||
| 126 | - features_length_shape.data(), features_length_shape.size()); | ||
| 127 | - | ||
| 128 | - Ort::Value x = PadSequence(model_->Allocator(), features_pointer, | ||
| 129 | - -23.025850929940457f); | ||
| 130 | - | ||
| 131 | - auto t = model_->RunEncoder(std::move(x), std::move(x_length)); | ||
| 132 | - auto results = decoder_->Decode(std::move(t.first), std::move(t.second)); | ||
| 133 | - | ||
| 134 | - int32_t frame_shift_ms = 10; | ||
| 135 | - for (int32_t i = 0; i != n; ++i) { | ||
| 136 | - auto r = Convert(results[i], symbol_table_, frame_shift_ms, | ||
| 137 | - model_->SubsamplingFactor()); | ||
| 138 | - | ||
| 139 | - ss[i]->SetResult(r); | ||
| 140 | - } | ||
| 141 | - } | ||
| 142 | - | ||
| 143 | - private: | ||
| 144 | - OfflineRecognizerConfig config_; | ||
| 145 | - SymbolTable symbol_table_; | ||
| 146 | - std::unique_ptr<OfflineTransducerModel> model_; | ||
| 147 | - std::unique_ptr<OfflineTransducerDecoder> decoder_; | ||
| 148 | -}; | ||
| 149 | - | ||
| 150 | OfflineRecognizer::OfflineRecognizer(const OfflineRecognizerConfig &config) | 37 | OfflineRecognizer::OfflineRecognizer(const OfflineRecognizerConfig &config) |
| 151 | - : impl_(std::make_unique<Impl>(config)) {} | 38 | + : impl_(OfflineRecognizerImpl::Create(config)) {} |
| 152 | 39 | ||
| 153 | OfflineRecognizer::~OfflineRecognizer() = default; | 40 | OfflineRecognizer::~OfflineRecognizer() = default; |
| 154 | 41 |
| @@ -52,6 +52,8 @@ struct OfflineRecognizerConfig { | @@ -52,6 +52,8 @@ struct OfflineRecognizerConfig { | ||
| 52 | std::string ToString() const; | 52 | std::string ToString() const; |
| 53 | }; | 53 | }; |
| 54 | 54 | ||
| 55 | +class OfflineRecognizerImpl; | ||
| 56 | + | ||
| 55 | class OfflineRecognizer { | 57 | class OfflineRecognizer { |
| 56 | public: | 58 | public: |
| 57 | ~OfflineRecognizer(); | 59 | ~OfflineRecognizer(); |
| @@ -78,8 +80,7 @@ class OfflineRecognizer { | @@ -78,8 +80,7 @@ class OfflineRecognizer { | ||
| 78 | void DecodeStreams(OfflineStream **ss, int32_t n) const; | 80 | void DecodeStreams(OfflineStream **ss, int32_t n) const; |
| 79 | 81 | ||
| 80 | private: | 82 | private: |
| 81 | - class Impl; | ||
| 82 | - std::unique_ptr<Impl> impl_; | 83 | + std::unique_ptr<OfflineRecognizerImpl> impl_; |
| 83 | }; | 84 | }; |
| 84 | 85 | ||
| 85 | } // namespace sherpa_onnx | 86 | } // namespace sherpa_onnx |
| @@ -5,6 +5,8 @@ | @@ -5,6 +5,8 @@ | ||
| 5 | 5 | ||
| 6 | #include "sherpa-onnx/csrc/text-utils.h" | 6 | #include "sherpa-onnx/csrc/text-utils.h" |
| 7 | 7 | ||
| 8 | +#include <assert.h> | ||
| 9 | + | ||
| 8 | #include <string> | 10 | #include <string> |
| 9 | #include <vector> | 11 | #include <vector> |
| 10 | 12 | ||
| @@ -27,4 +29,31 @@ void SplitStringToVector(const std::string &full, const char *delim, | @@ -27,4 +29,31 @@ void SplitStringToVector(const std::string &full, const char *delim, | ||
| 27 | } | 29 | } |
| 28 | } | 30 | } |
| 29 | 31 | ||
| 32 | +template <class F> | ||
| 33 | +bool SplitStringToFloats(const std::string &full, const char *delim, | ||
| 34 | + bool omit_empty_strings, // typically false | ||
| 35 | + std::vector<F> *out) { | ||
| 36 | + assert(out != nullptr); | ||
| 37 | + if (*(full.c_str()) == '\0') { | ||
| 38 | + out->clear(); | ||
| 39 | + return true; | ||
| 40 | + } | ||
| 41 | + std::vector<std::string> split; | ||
| 42 | + SplitStringToVector(full, delim, omit_empty_strings, &split); | ||
| 43 | + out->resize(split.size()); | ||
| 44 | + for (size_t i = 0; i < split.size(); ++i) { | ||
| 45 | + // atof() has no error reporting: unparsable text silently yields 0, so this function always returns true | ||
| 46 | + (*out)[i] = atof(split[i].c_str()); | ||
| 47 | + } | ||
| 48 | + return true; | ||
| 49 | +} | ||
| 50 | + | ||
| 51 | +// Instantiate the template above for float and double. | ||
| 52 | +template bool SplitStringToFloats(const std::string &full, const char *delim, | ||
| 53 | + bool omit_empty_strings, | ||
| 54 | + std::vector<float> *out); | ||
| 55 | +template bool SplitStringToFloats(const std::string &full, const char *delim, | ||
| 56 | + bool omit_empty_strings, | ||
| 57 | + std::vector<double> *out); | ||
| 58 | + | ||
| 30 | } // namespace sherpa_onnx | 59 | } // namespace sherpa_onnx |
| @@ -80,6 +80,12 @@ bool SplitStringToIntegers(const std::string &full, const char *delim, | @@ -80,6 +80,12 @@ bool SplitStringToIntegers(const std::string &full, const char *delim, | ||
| 80 | return true; | 80 | return true; |
| 81 | } | 81 | } |
| 82 | 82 | ||
| 83 | +// This is defined for F = float and double. | ||
| 84 | +template <class F> | ||
| 85 | +bool SplitStringToFloats(const std::string &full, const char *delim, | ||
| 86 | + bool omit_empty_strings, // typically false | ||
| 87 | + std::vector<F> *out); | ||
| 88 | + | ||
| 83 | } // namespace sherpa_onnx | 89 | } // namespace sherpa_onnx |
| 84 | 90 | ||
| 85 | #endif // SHERPA_ONNX_CSRC_TEXT_UTILS_H_ | 91 | #endif // SHERPA_ONNX_CSRC_TEXT_UTILS_H_ |
-
请 注册 或 登录 后发表评论