Fangjun Kuang
Committed by GitHub

Refactor offline recognizer. (#94)

* Refactor offline recognizer.

The purpose is to make it easier to support different types of models.
@@ -6,11 +6,12 @@ set(sources @@ -6,11 +6,12 @@ set(sources
6 features.cc 6 features.cc
7 file-utils.cc 7 file-utils.cc
8 hypothesis.cc 8 hypothesis.cc
  9 + offline-recognizer-impl.cc
  10 + offline-recognizer.cc
9 offline-stream.cc 11 offline-stream.cc
10 offline-transducer-greedy-search-decoder.cc 12 offline-transducer-greedy-search-decoder.cc
11 offline-transducer-model-config.cc 13 offline-transducer-model-config.cc
12 offline-transducer-model.cc 14 offline-transducer-model.cc
13 - offline-recognizer.cc  
14 online-lstm-transducer-model.cc 15 online-lstm-transducer-model.cc
15 online-recognizer.cc 16 online-recognizer.cc
16 online-stream.cc 17 online-stream.cc
@@ -23,36 +23,55 @@ @@ -23,36 +23,55 @@
23 } while (0) 23 } while (0)
24 #endif 24 #endif
25 25
  26 +// Read an integer
26 #define SHERPA_ONNX_READ_META_DATA(dst, src_key) \ 27 #define SHERPA_ONNX_READ_META_DATA(dst, src_key) \
27 do { \ 28 do { \
28 auto value = \ 29 auto value = \
29 meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ 30 meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
30 if (!value) { \ 31 if (!value) { \
31 - fprintf(stderr, "%s does not exist in the metadata\n", src_key); \ 32 + SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \
32 exit(-1); \ 33 exit(-1); \
33 } \ 34 } \
34 \ 35 \
35 dst = atoi(value.get()); \ 36 dst = atoi(value.get()); \
36 if (dst <= 0) { \ 37 if (dst <= 0) { \
37 - fprintf(stderr, "Invalid value %d for %s\n", dst, src_key); \ 38 + SHERPA_ONNX_LOGE("Invalid value %d for %s", dst, src_key); \
38 exit(-1); \ 39 exit(-1); \
39 } \ 40 } \
40 } while (0) 41 } while (0)
41 42
42 -#define SHERPA_ONNX_READ_META_DATA_VEC(dst, src_key) \  
43 - do { \  
44 - auto value = \  
45 - meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \  
46 - if (!value) { \  
47 - fprintf(stderr, "%s does not exist in the metadata\n", src_key); \  
48 - exit(-1); \  
49 - } \  
50 - \  
51 - bool ret = SplitStringToIntegers(value.get(), ",", true, &dst); \  
52 - if (!ret) { \  
53 - fprintf(stderr, "Invalid value %s for %s\n", value.get(), src_key); \  
54 - exit(-1); \  
55 - } \ 43 +// read a vector of integers
// Looks up `src_key` in the model's custom metadata and parses the stored
// string as a ","-separated list of integers, writing the result into `dst`.
//
// The expansion site must have `meta_data` and `allocator` in scope.
// The process exits with -1 when the key is missing or when
// SplitStringToIntegers() reports a failure.
#define SHERPA_ONNX_READ_META_DATA_VEC(dst, src_key)                      \
  do {                                                                    \
    auto meta_value =                                                     \
        meta_data.LookupCustomMetadataMapAllocated(src_key, allocator);   \
    if (!meta_value) {                                                    \
      SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key);     \
      exit(-1);                                                           \
    }                                                                     \
                                                                          \
    if (!SplitStringToIntegers(meta_value.get(), ",", true, &dst)) {      \
      SHERPA_ONNX_LOGE("Invalid value %s for %s", meta_value.get(),       \
                       src_key);                                          \
      exit(-1);                                                           \
    }                                                                     \
  } while (0)
  59 +
// Read a string
//
// Looks up `src_key` in the model's custom metadata and assigns the stored
// text to `dst` (anything assignable from `const char *` that has `empty()`,
// e.g. std::string).
//
// The expansion site must have `meta_data` and `allocator` in scope.
// The process exits with -1 when the key is missing or its value is empty.
//
// Note: the log messages carry no trailing "\n", matching the other
// SHERPA_ONNX_LOGE call sites in this header.
#define SHERPA_ONNX_READ_META_DATA_STR(dst, src_key)                      \
  do {                                                                    \
    auto value =                                                          \
        meta_data.LookupCustomMetadataMapAllocated(src_key, allocator);   \
    if (!value) {                                                         \
      SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key);     \
      exit(-1);                                                           \
    }                                                                     \
                                                                          \
    dst = value.get();                                                    \
    if (dst.empty()) {                                                    \
      SHERPA_ONNX_LOGE("Invalid value for %s", src_key);                  \
      exit(-1);                                                           \
    }                                                                     \
  } while (0)
57 76
58 #endif // SHERPA_ONNX_CSRC_MACROS_H_ 77 #endif // SHERPA_ONNX_CSRC_MACROS_H_
  1 +// sherpa-onnx/csrc/offline-recognizer-impl.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
  6 +
  7 +#include <string>
  8 +
  9 +#include "onnxruntime_cxx_api.h" // NOLINT
  10 +#include "sherpa-onnx/csrc/macros.h"
  11 +#include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h"
  12 +#include "sherpa-onnx/csrc/onnx-utils.h"
  13 +#include "sherpa-onnx/csrc/text-utils.h"
  14 +
  15 +namespace sherpa_onnx {
  16 +
  17 +std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
  18 + const OfflineRecognizerConfig &config) {
  19 + Ort::Env env;
  20 +
  21 + Ort::SessionOptions sess_opts;
  22 + auto buf = ReadFile(config.model_config.encoder_filename);
  23 +
  24 + auto encoder_sess =
  25 + std::make_unique<Ort::Session>(env, buf.data(), buf.size(), sess_opts);
  26 +
  27 + Ort::ModelMetadata meta_data = encoder_sess->GetModelMetadata();
  28 +
  29 + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
  30 +
  31 + std::string model_type;
  32 + SHERPA_ONNX_READ_META_DATA_STR(model_type, "model_type");
  33 +
  34 + if (model_type == "conformer") {
  35 + return std::make_unique<OfflineRecognizerTransducerImpl>(config);
  36 + }
  37 +
  38 + SHERPA_ONNX_LOGE("Unsupported model_type: %s\n", model_type.c_str());
  39 +
  40 + exit(-1);
  41 +}
  42 +
  43 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/offline-recognizer-impl.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_IMPL_H_
  6 +#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_IMPL_H_
  7 +
  8 +#include <memory>
  9 +
  10 +#include "sherpa-onnx/csrc/offline-recognizer.h"
  11 +#include "sherpa-onnx/csrc/offline-stream.h"
  12 +
  13 +namespace sherpa_onnx {
  14 +
  15 +class OfflineRecognizerImpl {
  16 + public:
  17 + static std::unique_ptr<OfflineRecognizerImpl> Create(
  18 + const OfflineRecognizerConfig &config);
  19 +
  20 + virtual ~OfflineRecognizerImpl() = default;
  21 +
  22 + virtual std::unique_ptr<OfflineStream> CreateStream() const = 0;
  23 +
  24 + virtual void DecodeStreams(OfflineStream **ss, int32_t n) const = 0;
  25 +};
  26 +
  27 +} // namespace sherpa_onnx
  28 +
  29 +#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_IMPL_H_
  1 +// sherpa-onnx/csrc/offline-recognizer-transducer-impl.h
  2 +//
  3 +// Copyright (c) 2022 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_IMPL_H_
  6 +#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_IMPL_H_
  7 +
  8 +#include <memory>
  9 +#include <string>
  10 +#include <utility>
  11 +#include <vector>
  12 +
  13 +#include "sherpa-onnx/csrc/macros.h"
  14 +#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
  15 +#include "sherpa-onnx/csrc/offline-recognizer.h"
  16 +#include "sherpa-onnx/csrc/offline-transducer-decoder.h"
  17 +#include "sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h"
  18 +#include "sherpa-onnx/csrc/offline-transducer-model.h"
  19 +#include "sherpa-onnx/csrc/pad-sequence.h"
  20 +#include "sherpa-onnx/csrc/symbol-table.h"
  21 +
  22 +namespace sherpa_onnx {
  23 +
  24 +static OfflineRecognitionResult Convert(
  25 + const OfflineTransducerDecoderResult &src, const SymbolTable &sym_table,
  26 + int32_t frame_shift_ms, int32_t subsampling_factor) {
  27 + OfflineRecognitionResult r;
  28 + r.tokens.reserve(src.tokens.size());
  29 + r.timestamps.reserve(src.timestamps.size());
  30 +
  31 + std::string text;
  32 + for (auto i : src.tokens) {
  33 + auto sym = sym_table[i];
  34 + text.append(sym);
  35 +
  36 + r.tokens.push_back(std::move(sym));
  37 + }
  38 + r.text = std::move(text);
  39 +
  40 + float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor;
  41 + for (auto t : src.timestamps) {
  42 + float time = frame_shift_s * t;
  43 + r.timestamps.push_back(time);
  44 + }
  45 +
  46 + return r;
  47 +}
  48 +
  49 +class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
  50 + public:
  51 + explicit OfflineRecognizerTransducerImpl(
  52 + const OfflineRecognizerConfig &config)
  53 + : config_(config),
  54 + symbol_table_(config_.model_config.tokens),
  55 + model_(std::make_unique<OfflineTransducerModel>(config_.model_config)) {
  56 + if (config_.decoding_method == "greedy_search") {
  57 + decoder_ =
  58 + std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get());
  59 + } else if (config_.decoding_method == "modified_beam_search") {
  60 + SHERPA_ONNX_LOGE("TODO: modified_beam_search is to be implemented");
  61 + exit(-1);
  62 + } else {
  63 + SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
  64 + config_.decoding_method.c_str());
  65 + exit(-1);
  66 + }
  67 + }
  68 +
  69 + std::unique_ptr<OfflineStream> CreateStream() const override {
  70 + return std::make_unique<OfflineStream>(config_.feat_config);
  71 + }
  72 +
  73 + void DecodeStreams(OfflineStream **ss, int32_t n) const override {
  74 + auto memory_info =
  75 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  76 +
  77 + int32_t feat_dim = ss[0]->FeatureDim();
  78 +
  79 + std::vector<Ort::Value> features;
  80 +
  81 + features.reserve(n);
  82 +
  83 + std::vector<std::vector<float>> features_vec(n);
  84 + std::vector<int64_t> features_length_vec(n);
  85 + for (int32_t i = 0; i != n; ++i) {
  86 + auto f = ss[i]->GetFrames();
  87 + int32_t num_frames = f.size() / feat_dim;
  88 +
  89 + features_length_vec[i] = num_frames;
  90 + features_vec[i] = std::move(f);
  91 +
  92 + std::array<int64_t, 2> shape = {num_frames, feat_dim};
  93 +
  94 + Ort::Value x = Ort::Value::CreateTensor(
  95 + memory_info, features_vec[i].data(), features_vec[i].size(),
  96 + shape.data(), shape.size());
  97 + features.push_back(std::move(x));
  98 + }
  99 +
  100 + std::vector<const Ort::Value *> features_pointer(n);
  101 + for (int32_t i = 0; i != n; ++i) {
  102 + features_pointer[i] = &features[i];
  103 + }
  104 +
  105 + std::array<int64_t, 1> features_length_shape = {n};
  106 + Ort::Value x_length = Ort::Value::CreateTensor(
  107 + memory_info, features_length_vec.data(), n,
  108 + features_length_shape.data(), features_length_shape.size());
  109 +
  110 + Ort::Value x = PadSequence(model_->Allocator(), features_pointer,
  111 + -23.025850929940457f);
  112 +
  113 + auto t = model_->RunEncoder(std::move(x), std::move(x_length));
  114 + auto results = decoder_->Decode(std::move(t.first), std::move(t.second));
  115 +
  116 + int32_t frame_shift_ms = 10;
  117 + for (int32_t i = 0; i != n; ++i) {
  118 + auto r = Convert(results[i], symbol_table_, frame_shift_ms,
  119 + model_->SubsamplingFactor());
  120 +
  121 + ss[i]->SetResult(r);
  122 + }
  123 + }
  124 +
  125 + private:
  126 + OfflineRecognizerConfig config_;
  127 + SymbolTable symbol_table_;
  128 + std::unique_ptr<OfflineTransducerModel> model_;
  129 + std::unique_ptr<OfflineTransducerDecoder> decoder_;
  130 +};
  131 +
  132 +} // namespace sherpa_onnx
  133 +
  134 +#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_IMPL_H_
@@ -5,42 +5,11 @@ @@ -5,42 +5,11 @@
5 #include "sherpa-onnx/csrc/offline-recognizer.h" 5 #include "sherpa-onnx/csrc/offline-recognizer.h"
6 6
7 #include <memory> 7 #include <memory>
8 -#include <utility>  
9 8
10 -#include "sherpa-onnx/csrc/macros.h"  
11 -#include "sherpa-onnx/csrc/offline-transducer-decoder.h"  
12 -#include "sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h"  
13 -#include "sherpa-onnx/csrc/offline-transducer-model.h"  
14 -#include "sherpa-onnx/csrc/pad-sequence.h"  
15 -#include "sherpa-onnx/csrc/symbol-table.h" 9 +#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
16 10
17 namespace sherpa_onnx { 11 namespace sherpa_onnx {
18 12
19 -static OfflineRecognitionResult Convert(  
20 - const OfflineTransducerDecoderResult &src, const SymbolTable &sym_table,  
21 - int32_t frame_shift_ms, int32_t subsampling_factor) {  
22 - OfflineRecognitionResult r;  
23 - r.tokens.reserve(src.tokens.size());  
24 - r.timestamps.reserve(src.timestamps.size());  
25 -  
26 - std::string text;  
27 - for (auto i : src.tokens) {  
28 - auto sym = sym_table[i];  
29 - text.append(sym);  
30 -  
31 - r.tokens.push_back(std::move(sym));  
32 - }  
33 - r.text = std::move(text);  
34 -  
35 - float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor;  
36 - for (auto t : src.timestamps) {  
37 - float time = frame_shift_s * t;  
38 - r.timestamps.push_back(time);  
39 - }  
40 -  
41 - return r;  
42 -}  
43 -  
44 void OfflineRecognizerConfig::Register(ParseOptions *po) { 13 void OfflineRecognizerConfig::Register(ParseOptions *po) {
45 feat_config.Register(po); 14 feat_config.Register(po);
46 model_config.Register(po); 15 model_config.Register(po);
@@ -65,90 +34,8 @@ std::string OfflineRecognizerConfig::ToString() const { @@ -65,90 +34,8 @@ std::string OfflineRecognizerConfig::ToString() const {
65 return os.str(); 34 return os.str();
66 } 35 }
67 36
68 -class OfflineRecognizer::Impl {  
69 - public:  
70 - explicit Impl(const OfflineRecognizerConfig &config)  
71 - : config_(config),  
72 - symbol_table_(config_.model_config.tokens),  
73 - model_(std::make_unique<OfflineTransducerModel>(config_.model_config)) {  
74 - if (config_.decoding_method == "greedy_search") {  
75 - decoder_ =  
76 - std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get());  
77 - } else if (config_.decoding_method == "modified_beam_search") {  
78 - SHERPA_ONNX_LOGE("TODO: modified_beam_search is to be implemented");  
79 - exit(-1);  
80 - } else {  
81 - SHERPA_ONNX_LOGE("Unsupported decoding method: %s",  
82 - config_.decoding_method.c_str());  
83 - exit(-1);  
84 - }  
85 - }  
86 -  
87 - std::unique_ptr<OfflineStream> CreateStream() const {  
88 - return std::make_unique<OfflineStream>(config_.feat_config);  
89 - }  
90 -  
91 - void DecodeStreams(OfflineStream **ss, int32_t n) const {  
92 - auto memory_info =  
93 - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);  
94 -  
95 - int32_t feat_dim = ss[0]->FeatureDim();  
96 -  
97 - std::vector<Ort::Value> features;  
98 -  
99 - features.reserve(n);  
100 -  
101 - std::vector<std::vector<float>> features_vec(n);  
102 - std::vector<int64_t> features_length_vec(n);  
103 - for (int32_t i = 0; i != n; ++i) {  
104 - auto f = ss[i]->GetFrames();  
105 - int32_t num_frames = f.size() / feat_dim;  
106 -  
107 - features_length_vec[i] = num_frames;  
108 - features_vec[i] = std::move(f);  
109 -  
110 - std::array<int64_t, 2> shape = {num_frames, feat_dim};  
111 -  
112 - Ort::Value x = Ort::Value::CreateTensor(  
113 - memory_info, features_vec[i].data(), features_vec[i].size(),  
114 - shape.data(), shape.size());  
115 - features.push_back(std::move(x));  
116 - }  
117 -  
118 - std::vector<const Ort::Value *> features_pointer(n);  
119 - for (int32_t i = 0; i != n; ++i) {  
120 - features_pointer[i] = &features[i];  
121 - }  
122 -  
123 - std::array<int64_t, 1> features_length_shape = {n};  
124 - Ort::Value x_length = Ort::Value::CreateTensor(  
125 - memory_info, features_length_vec.data(), n,  
126 - features_length_shape.data(), features_length_shape.size());  
127 -  
128 - Ort::Value x = PadSequence(model_->Allocator(), features_pointer,  
129 - -23.025850929940457f);  
130 -  
131 - auto t = model_->RunEncoder(std::move(x), std::move(x_length));  
132 - auto results = decoder_->Decode(std::move(t.first), std::move(t.second));  
133 -  
134 - int32_t frame_shift_ms = 10;  
135 - for (int32_t i = 0; i != n; ++i) {  
136 - auto r = Convert(results[i], symbol_table_, frame_shift_ms,  
137 - model_->SubsamplingFactor());  
138 -  
139 - ss[i]->SetResult(r);  
140 - }  
141 - }  
142 -  
143 - private:  
144 - OfflineRecognizerConfig config_;  
145 - SymbolTable symbol_table_;  
146 - std::unique_ptr<OfflineTransducerModel> model_;  
147 - std::unique_ptr<OfflineTransducerDecoder> decoder_;  
148 -};  
149 -  
150 OfflineRecognizer::OfflineRecognizer(const OfflineRecognizerConfig &config) 37 OfflineRecognizer::OfflineRecognizer(const OfflineRecognizerConfig &config)
151 - : impl_(std::make_unique<Impl>(config)) {} 38 + : impl_(OfflineRecognizerImpl::Create(config)) {}
152 39
153 OfflineRecognizer::~OfflineRecognizer() = default; 40 OfflineRecognizer::~OfflineRecognizer() = default;
154 41
@@ -52,6 +52,8 @@ struct OfflineRecognizerConfig { @@ -52,6 +52,8 @@ struct OfflineRecognizerConfig {
52 std::string ToString() const; 52 std::string ToString() const;
53 }; 53 };
54 54
  55 +class OfflineRecognizerImpl;
  56 +
55 class OfflineRecognizer { 57 class OfflineRecognizer {
56 public: 58 public:
57 ~OfflineRecognizer(); 59 ~OfflineRecognizer();
@@ -78,8 +80,7 @@ class OfflineRecognizer { @@ -78,8 +80,7 @@ class OfflineRecognizer {
78 void DecodeStreams(OfflineStream **ss, int32_t n) const; 80 void DecodeStreams(OfflineStream **ss, int32_t n) const;
79 81
80 private: 82 private:
81 - class Impl;  
82 - std::unique_ptr<Impl> impl_; 83 + std::unique_ptr<OfflineRecognizerImpl> impl_;
83 }; 84 };
84 85
85 } // namespace sherpa_onnx 86 } // namespace sherpa_onnx
@@ -5,6 +5,8 @@ @@ -5,6 +5,8 @@
5 5
6 #include "sherpa-onnx/csrc/text-utils.h" 6 #include "sherpa-onnx/csrc/text-utils.h"
7 7
  8 +#include <assert.h>
  9 +
8 #include <string> 10 #include <string>
9 #include <vector> 11 #include <vector>
10 12
@@ -27,4 +29,31 @@ void SplitStringToVector(const std::string &full, const char *delim, @@ -27,4 +29,31 @@ void SplitStringToVector(const std::string &full, const char *delim,
27 } 29 }
28 } 30 }
29 31
  32 +template <class F>
  33 +bool SplitStringToFloats(const std::string &full, const char *delim,
  34 + bool omit_empty_strings, // typically false
  35 + std::vector<F> *out) {
  36 + assert(out != nullptr);
  37 + if (*(full.c_str()) == '\0') {
  38 + out->clear();
  39 + return true;
  40 + }
  41 + std::vector<std::string> split;
  42 + SplitStringToVector(full, delim, omit_empty_strings, &split);
  43 + out->resize(split.size());
  44 + for (size_t i = 0; i < split.size(); ++i) {
  45 + // assume atof never fails
  46 + (*out)[i] = atof(split[i].c_str());
  47 + }
  48 + return true;
  49 +}
  50 +
  51 +// Instantiate the template above for float and double.
  52 +template bool SplitStringToFloats(const std::string &full, const char *delim,
  53 + bool omit_empty_strings,
  54 + std::vector<float> *out);
  55 +template bool SplitStringToFloats(const std::string &full, const char *delim,
  56 + bool omit_empty_strings,
  57 + std::vector<double> *out);
  58 +
30 } // namespace sherpa_onnx 59 } // namespace sherpa_onnx
@@ -80,6 +80,12 @@ bool SplitStringToIntegers(const std::string &full, const char *delim, @@ -80,6 +80,12 @@ bool SplitStringToIntegers(const std::string &full, const char *delim,
80 return true; 80 return true;
81 } 81 }
82 82
  83 +// This is defined for F = float and double.
  84 +template <class F>
  85 +bool SplitStringToFloats(const std::string &full, const char *delim,
  86 + bool omit_empty_strings, // typically false
  87 + std::vector<F> *out);
  88 +
83 } // namespace sherpa_onnx 89 } // namespace sherpa_onnx
84 90
85 #endif // SHERPA_ONNX_CSRC_TEXT_UTILS_H_ 91 #endif // SHERPA_ONNX_CSRC_TEXT_UTILS_H_