Fangjun Kuang
Committed by GitHub

Add online transducer decoder (#27)

1 function(download_kaldi_native_fbank) 1 function(download_kaldi_native_fbank)
2 include(FetchContent) 2 include(FetchContent)
3 3
4 - set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.11.tar.gz")  
5 - set(kaldi_native_fbank_HASH "SHA256=e69ae25ef6f30566ef31ca949dd1b0b8ec3a827caeba93a61d82bb848dac5d69") 4 + set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.12.tar.gz")
  5 + set(kaldi_native_fbank_HASH "SHA256=8f4dfc3f6ddb1adcd9ac0ae87743ebc6cbcae147aacf9d46e76fa54134e12b44")
6 6
7 set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE) 7 set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE)
8 set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE) 8 set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
@@ -11,10 +11,11 @@ function(download_kaldi_native_fbank) @@ -11,10 +11,11 @@ function(download_kaldi_native_fbank)
11 # If you don't have access to the Internet, 11 # If you don't have access to the Internet,
12 # please pre-download kaldi-native-fbank 12 # please pre-download kaldi-native-fbank
13 set(possible_file_locations 13 set(possible_file_locations
14 - $ENV{HOME}/Downloads/kaldi-native-fbank-1.11.tar.gz  
15 - ${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.11.tar.gz  
16 - ${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.11.tar.gz  
17 - /tmp/kaldi-native-fbank-1.11.tar.gz 14 + $ENV{HOME}/Downloads/kaldi-native-fbank-1.12.tar.gz
  15 + ${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.12.tar.gz
  16 + ${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.12.tar.gz
  17 + /tmp/kaldi-native-fbank-1.12.tar.gz
  18 + /star-fj/fangjun/download/github/kaldi-native-fbank-1.12.tar.gz
18 ) 19 )
19 20
20 foreach(f IN LISTS possible_file_locations) 21 foreach(f IN LISTS possible_file_locations)
@@ -9,6 +9,7 @@ function(download_onnxruntime) @@ -9,6 +9,7 @@ function(download_onnxruntime)
9 ${PROJECT_SOURCE_DIR}/onnxruntime-linux-x64-1.14.0.tgz 9 ${PROJECT_SOURCE_DIR}/onnxruntime-linux-x64-1.14.0.tgz
10 ${PROJECT_BINARY_DIR}/onnxruntime-linux-x64-1.14.0.tgz 10 ${PROJECT_BINARY_DIR}/onnxruntime-linux-x64-1.14.0.tgz
11 /tmp/onnxruntime-linux-x64-1.14.0.tgz 11 /tmp/onnxruntime-linux-x64-1.14.0.tgz
  12 + /star-fj/fangjun/download/github/onnxruntime-linux-x64-1.14.0.tgz
12 ) 13 )
13 14
14 set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-linux-x64-1.14.0.tgz") 15 set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-linux-x64-1.14.0.tgz")
1 include_directories(${CMAKE_SOURCE_DIR}) 1 include_directories(${CMAKE_SOURCE_DIR})
2 2
3 add_executable(sherpa-onnx 3 add_executable(sherpa-onnx
4 - decode.cc  
5 features.cc 4 features.cc
6 online-lstm-transducer-model.cc 5 online-lstm-transducer-model.cc
  6 + online-transducer-greedy-search-decoder.cc
7 online-transducer-model-config.cc 7 online-transducer-model-config.cc
8 online-transducer-model.cc 8 online-transducer-model.cc
9 onnx-utils.cc 9 onnx-utils.cc
1 -// sherpa/csrc/decode.h  
2 -//  
3 -// Copyright (c) 2023 Xiaomi Corporation  
4 -  
5 -#ifndef SHERPA_ONNX_CSRC_DECODE_H_  
6 -#define SHERPA_ONNX_CSRC_DECODE_H_  
7 -  
8 -#include <vector>  
9 -  
10 -#include "sherpa-onnx/csrc/online-transducer-model.h"  
11 -  
12 -namespace sherpa_onnx {  
13 -  
14 -/** Greedy search for non-streaming ASR.  
15 - *  
16 - * @TODO(fangjun) Support batch size > 1  
17 - *  
18 - * @param model The RnntModel  
19 - * @param encoder_out Its shape is (1, num_frames, encoder_out_dim).  
20 - */  
21 -void GreedySearch(OnlineTransducerModel *model, Ort::Value encoder_out,  
22 - std::vector<int64_t> *hyp);  
23 -  
24 -} // namespace sherpa_onnx  
25 -  
26 -#endif // SHERPA_ONNX_CSRC_DECODE_H_  
@@ -15,16 +15,16 @@ namespace sherpa_onnx { @@ -15,16 +15,16 @@ namespace sherpa_onnx {
15 15
16 class FeatureExtractor::Impl { 16 class FeatureExtractor::Impl {
17 public: 17 public:
18 - Impl(int32_t sampling_rate, int32_t feature_dim) { 18 + explicit Impl(const FeatureExtractorConfig &config) {
19 opts_.frame_opts.dither = 0; 19 opts_.frame_opts.dither = 0;
20 opts_.frame_opts.snip_edges = false; 20 opts_.frame_opts.snip_edges = false;
21 - opts_.frame_opts.samp_freq = sampling_rate; 21 + opts_.frame_opts.samp_freq = config.sampling_rate;
22 22
23 // cache 100 seconds of feature frames, which is more than enough 23 // cache 100 seconds of feature frames, which is more than enough
24 // for real needs 24 // for real needs
25 opts_.frame_opts.max_feature_vectors = 100 * 100; 25 opts_.frame_opts.max_feature_vectors = 100 * 100;
26 26
27 - opts_.mel_opts.num_bins = feature_dim; 27 + opts_.mel_opts.num_bins = config.feature_dim;
28 28
29 fbank_ = std::make_unique<knf::OnlineFbank>(opts_); 29 fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
30 } 30 }
@@ -80,9 +80,8 @@ class FeatureExtractor::Impl { @@ -80,9 +80,8 @@ class FeatureExtractor::Impl {
80 mutable std::mutex mutex_; 80 mutable std::mutex mutex_;
81 }; 81 };
82 82
83 -FeatureExtractor::FeatureExtractor(int32_t sampling_rate /*=16000*/,  
84 - int32_t feature_dim /*=80*/)  
85 - : impl_(std::make_unique<Impl>(sampling_rate, feature_dim)) {} 83 +FeatureExtractor::FeatureExtractor(const FeatureExtractorConfig &config /*={}*/)
  84 + : impl_(std::make_unique<Impl>(config)) {}
86 85
87 FeatureExtractor::~FeatureExtractor() = default; 86 FeatureExtractor::~FeatureExtractor() = default;
88 87
@@ -10,14 +10,18 @@ @@ -10,14 +10,18 @@
10 10
11 namespace sherpa_onnx { 11 namespace sherpa_onnx {
12 12
  13 +struct FeatureExtractorConfig {
  14 + int32_t sampling_rate = 16000;
  15 + int32_t feature_dim = 80;
  16 +};
  17 +
13 class FeatureExtractor { 18 class FeatureExtractor {
14 public: 19 public:
15 /** 20 /**
16 * @param sampling_rate Sampling rate of the data used to train the model. 21 * @param sampling_rate Sampling rate of the data used to train the model.
17 * @param feature_dim Dimension of the features used to train the model. 22 * @param feature_dim Dimension of the features used to train the model.
18 */ 23 */
19 - explicit FeatureExtractor(int32_t sampling_rate = 16000,  
20 - int32_t feature_dim = 80); 24 + explicit FeatureExtractor(const FeatureExtractorConfig &config = {});
21 ~FeatureExtractor(); 25 ~FeatureExtractor();
22 26
23 /** 27 /**
  1 +// sherpa/csrc/online-transducer-decoder.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_DECODER_H_
  6 +#define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_DECODER_H_
  7 +
  8 +#include <vector>
  9 +
  10 +#include "onnxruntime_cxx_api.h" // NOLINT
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +struct OnlineTransducerDecoderResult {
  15 + /// The decoded token IDs so far
  16 + std::vector<int64_t> tokens;
  17 +};
  18 +
  19 +class OnlineTransducerDecoder {
  20 + public:
  21 + virtual ~OnlineTransducerDecoder() = default;
  22 +
  23 + /* Return an empty result.
  24 + *
  25 + * To simplify the decoding code, we add `context_size` blanks
  26 + * to the beginning of the decoding result, which will be
  27 + * stripped by calling `StripLeadingBlanks()`.
  28 + */
  29 + virtual OnlineTransducerDecoderResult GetEmptyResult() = 0;
  30 +
  31 + /** Strip blanks added by `GetEmptyResult()`.
  32 + *
  33 + * @param r It is changed in-place.
  34 + */
  35 + virtual void StripLeadingBlanks(OnlineTransducerDecoderResult * /*r*/) {}
  36 +
  37 + /** Run transducer beam search given the output from the encoder model.
  38 + *
  39 + * @param encoder_out A 3-D tensor of shape (N, T, joiner_dim)
  40 + * @param result It is modified in-place.
  41 + *
  42 + * @note There is no need to pass encoder_out_length here since for the
  43 + * online decoding case, each utterance has the same number of frames
  44 + * and there are no paddings.
  45 + */
  46 + virtual void Decode(Ort::Value encoder_out,
  47 + std::vector<OnlineTransducerDecoderResult> *result) = 0;
  48 +};
  49 +
  50 +} // namespace sherpa_onnx
  51 +
  52 +#endif // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_DECODER_H_
1 -// sherpa/csrc/decode.cc 1 +// sherpa/csrc/online-transducer-greedy-search-decoder.cc
2 // 2 //
3 // Copyright (c) 2023 Xiaomi Corporation 3 // Copyright (c) 2023 Xiaomi Corporation
4 4
5 -#include "sherpa-onnx/csrc/decode.h" 5 +#include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h"
6 6
7 #include <assert.h> 7 #include <assert.h>
8 8
@@ -10,19 +10,9 @@ @@ -10,19 +10,9 @@
10 #include <utility> 10 #include <utility>
11 #include <vector> 11 #include <vector>
12 12
13 -namespace sherpa_onnx {  
14 -  
15 -static Ort::Value Clone(Ort::Value *v) {  
16 - auto type_and_shape = v->GetTensorTypeAndShapeInfo();  
17 - std::vector<int64_t> shape = type_and_shape.GetShape();  
18 -  
19 - auto memory_info =  
20 - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); 13 +#include "sherpa-onnx/csrc/onnx-utils.h"
21 14
22 - return Ort::Value::CreateTensor(memory_info, v->GetTensorMutableData<float>(),  
23 - type_and_shape.GetElementCount(),  
24 - shape.data(), shape.size());  
25 -} 15 +namespace sherpa_onnx {
26 16
27 static Ort::Value GetFrame(Ort::Value *encoder_out, int32_t t) { 17 static Ort::Value GetFrame(Ort::Value *encoder_out, int32_t t) {
28 std::vector<int64_t> encoder_out_shape = 18 std::vector<int64_t> encoder_out_shape =
@@ -42,26 +32,58 @@ static Ort::Value GetFrame(Ort::Value *encoder_out, int32_t t) { @@ -42,26 +32,58 @@ static Ort::Value GetFrame(Ort::Value *encoder_out, int32_t t) {
42 encoder_out_dim, shape.data(), shape.size()); 32 encoder_out_dim, shape.data(), shape.size());
43 } 33 }
44 34
45 -void GreedySearch(OnlineTransducerModel *model, Ort::Value encoder_out,  
46 - std::vector<int64_t> *hyp) { 35 +OnlineTransducerDecoderResult
  36 +OnlineTransducerGreedySearchDecoder::GetEmptyResult() {
  37 + int32_t context_size = model_->ContextSize();
  38 + int32_t blank_id = 0; // always 0
  39 + OnlineTransducerDecoderResult r;
  40 + r.tokens.resize(context_size, blank_id);
  41 +
  42 + return r;
  43 +}
  44 +
  45 +void OnlineTransducerGreedySearchDecoder::StripLeadingBlanks(
  46 + OnlineTransducerDecoderResult *r) {
  47 + int32_t context_size = model_->ContextSize();
  48 +
  49 + auto start = r->tokens.begin() + context_size;
  50 + auto end = r->tokens.end();
  51 +
  52 + r->tokens = std::vector<int64_t>(start, end);
  53 +}
  54 +
  55 +void OnlineTransducerGreedySearchDecoder::Decode(
  56 + Ort::Value encoder_out,
  57 + std::vector<OnlineTransducerDecoderResult> *result) {
47 std::vector<int64_t> encoder_out_shape = 58 std::vector<int64_t> encoder_out_shape =
48 encoder_out.GetTensorTypeAndShapeInfo().GetShape(); 59 encoder_out.GetTensorTypeAndShapeInfo().GetShape();
49 60
50 - if (encoder_out_shape[0] > 1) {  
51 - fprintf(stderr, "Only batch_size=1 is implemented. Given: %d\n",  
52 - static_cast<int32_t>(encoder_out_shape[0])); 61 + if (encoder_out_shape[0] != result->size()) {
  62 + fprintf(stderr,
  63 + "Size mismatch! encoder_out.size(0) %d, result.size(0): %d\n",
  64 + static_cast<int32_t>(encoder_out_shape[0]),
  65 + static_cast<int32_t>(result->size()));
  66 + exit(-1);
  67 + }
  68 +
  69 + if (result->size() != 1) {
  70 + fprintf(stderr, "only batch size == 1 is implemented. Given: %d",
  71 + static_cast<int32_t>(result->size()));
  72 + exit(-1);
53 } 73 }
54 74
  75 + auto &hyp = (*result)[0].tokens;
  76 +
55 int32_t num_frames = encoder_out_shape[1]; 77 int32_t num_frames = encoder_out_shape[1];
56 - int32_t vocab_size = model->VocabSize(); 78 + int32_t vocab_size = model_->VocabSize();
57 79
58 - Ort::Value decoder_input = model->BuildDecoderInput(*hyp);  
59 - Ort::Value decoder_out = model->RunDecoder(std::move(decoder_input)); 80 + Ort::Value decoder_input = model_->BuildDecoderInput(hyp);
  81 + Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
60 82
61 for (int32_t t = 0; t != num_frames; ++t) { 83 for (int32_t t = 0; t != num_frames; ++t) {
62 Ort::Value cur_encoder_out = GetFrame(&encoder_out, t); 84 Ort::Value cur_encoder_out = GetFrame(&encoder_out, t);
63 Ort::Value logit = 85 Ort::Value logit =
64 - model->RunJoiner(std::move(cur_encoder_out), Clone(&decoder_out)); 86 + model_->RunJoiner(std::move(cur_encoder_out), Clone(&decoder_out));
65 const float *p_logit = logit.GetTensorData<float>(); 87 const float *p_logit = logit.GetTensorData<float>();
66 88
67 auto y = static_cast<int32_t>(std::distance( 89 auto y = static_cast<int32_t>(std::distance(
@@ -69,9 +91,9 @@ void GreedySearch(OnlineTransducerModel *model, Ort::Value encoder_out, @@ -69,9 +91,9 @@ void GreedySearch(OnlineTransducerModel *model, Ort::Value encoder_out,
69 std::max_element(static_cast<const float *>(p_logit), 91 std::max_element(static_cast<const float *>(p_logit),
70 static_cast<const float *>(p_logit) + vocab_size))); 92 static_cast<const float *>(p_logit) + vocab_size)));
71 if (y != 0) { 93 if (y != 0) {
72 - hyp->push_back(y);  
73 - decoder_input = model->BuildDecoderInput(*hyp);  
74 - decoder_out = model->RunDecoder(std::move(decoder_input)); 94 + hyp.push_back(y);
  95 + decoder_input = model_->BuildDecoderInput(hyp);
  96 + decoder_out = model_->RunDecoder(std::move(decoder_input));
75 } 97 }
76 } 98 }
77 } 99 }
  1 +// sherpa/csrc/online-transducer-greedy-search-decoder.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_DECODER_H_
  6 +#define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_DECODER_H_
  7 +
  8 +#include <vector>
  9 +
  10 +#include "sherpa-onnx/csrc/online-transducer-decoder.h"
  11 +#include "sherpa-onnx/csrc/online-transducer-model.h"
  12 +
  13 +namespace sherpa_onnx {
  14 +
  15 +class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder {
  16 + public:
  17 + explicit OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model)
  18 + : model_(model) {}
  19 +
  20 + OnlineTransducerDecoderResult GetEmptyResult() override;
  21 +
  22 + void StripLeadingBlanks(OnlineTransducerDecoderResult *r) override;
  23 +
  24 + void Decode(Ort::Value encoder_out,
  25 + std::vector<OnlineTransducerDecoderResult> *result) override;
  26 +
  27 + private:
  28 + OnlineTransducerModel *model_; // Not owned
  29 +};
  30 +
  31 +} // namespace sherpa_onnx
  32 +
  33 +#endif // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_DECODER_H_
@@ -46,4 +46,16 @@ void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data) { @@ -46,4 +46,16 @@ void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data) {
46 } 46 }
47 } 47 }
48 48
  49 +Ort::Value Clone(Ort::Value *v) {
  50 + auto type_and_shape = v->GetTensorTypeAndShapeInfo();
  51 + std::vector<int64_t> shape = type_and_shape.GetShape();
  52 +
  53 + auto memory_info =
  54 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  55 +
  56 + return Ort::Value::CreateTensor(memory_info, v->GetTensorMutableData<float>(),
  57 + type_and_shape.GetElementCount(),
  58 + shape.data(), shape.size());
  59 +}
  60 +
49 } // namespace sherpa_onnx 61 } // namespace sherpa_onnx
@@ -55,6 +55,9 @@ void GetOutputNames(Ort::Session *sess, std::vector<std::string> *output_names, @@ -55,6 +55,9 @@ void GetOutputNames(Ort::Session *sess, std::vector<std::string> *output_names,
55 void PrintModelMetadata(std::ostream &os, 55 void PrintModelMetadata(std::ostream &os,
56 const Ort::ModelMetadata &meta_data); // NOLINT 56 const Ort::ModelMetadata &meta_data); // NOLINT
57 57
  58 +// Return a shallow copy of v
  59 +Ort::Value Clone(Ort::Value *v);
  60 +
58 } // namespace sherpa_onnx 61 } // namespace sherpa_onnx
59 62
60 #endif // SHERPA_ONNX_CSRC_ONNX_UTILS_H_ 63 #endif // SHERPA_ONNX_CSRC_ONNX_UTILS_H_
@@ -9,8 +9,8 @@ @@ -9,8 +9,8 @@
9 #include <vector> 9 #include <vector>
10 10
11 #include "kaldi-native-fbank/csrc/online-feature.h" 11 #include "kaldi-native-fbank/csrc/online-feature.h"
12 -#include "sherpa-onnx/csrc/decode.h"  
13 #include "sherpa-onnx/csrc/features.h" 12 #include "sherpa-onnx/csrc/features.h"
  13 +#include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h"
14 #include "sherpa-onnx/csrc/online-transducer-model-config.h" 14 #include "sherpa-onnx/csrc/online-transducer-model-config.h"
15 #include "sherpa-onnx/csrc/online-transducer-model.h" 15 #include "sherpa-onnx/csrc/online-transducer-model.h"
16 #include "sherpa-onnx/csrc/symbol-table.h" 16 #include "sherpa-onnx/csrc/symbol-table.h"
@@ -64,8 +64,6 @@ for a list of pre-trained models to download. @@ -64,8 +64,6 @@ for a list of pre-trained models to download.
64 64
65 std::vector<Ort::Value> states = model->GetEncoderInitStates(); 65 std::vector<Ort::Value> states = model->GetEncoderInitStates();
66 66
67 - std::vector<int64_t> hyp(model->ContextSize(), 0);  
68 -  
69 int32_t expected_sampling_rate = 16000; 67 int32_t expected_sampling_rate = 16000;
70 68
71 bool is_ok = false; 69 bool is_ok = false;
@@ -100,6 +98,10 @@ for a list of pre-trained models to download. @@ -100,6 +98,10 @@ for a list of pre-trained models to download.
100 98
101 std::array<int64_t, 3> x_shape{1, chunk_size, feature_dim}; 99 std::array<int64_t, 3> x_shape{1, chunk_size, feature_dim};
102 100
  101 + sherpa_onnx::OnlineTransducerGreedySearchDecoder decoder(model.get());
  102 + std::vector<sherpa_onnx::OnlineTransducerDecoderResult> result = {
  103 + decoder.GetEmptyResult()};
  104 +
103 for (int32_t start = 0; start + chunk_size < num_frames; 105 for (int32_t start = 0; start + chunk_size < num_frames;
104 start += chunk_shift) { 106 start += chunk_shift) {
105 std::vector<float> features = feat_extractor.GetFrames(start, chunk_size); 107 std::vector<float> features = feat_extractor.GetFrames(start, chunk_size);
@@ -109,8 +111,10 @@ for a list of pre-trained models to download. @@ -109,8 +111,10 @@ for a list of pre-trained models to download.
109 x_shape.data(), x_shape.size()); 111 x_shape.data(), x_shape.size());
110 auto pair = model->RunEncoder(std::move(x), states); 112 auto pair = model->RunEncoder(std::move(x), states);
111 states = std::move(pair.second); 113 states = std::move(pair.second);
112 - sherpa_onnx::GreedySearch(model.get(), std::move(pair.first), &hyp); 114 + decoder.Decode(std::move(pair.first), &result);
113 } 115 }
  116 + decoder.StripLeadingBlanks(&result[0]);
  117 + const auto &hyp = result[0].tokens;
114 std::string text; 118 std::string text;
115 for (size_t i = model->ContextSize(); i != hyp.size(); ++i) { 119 for (size_t i = model->ContextSize(); i != hyp.size(); ++i) {
116 text += sym[hyp[i]]; 120 text += sym[hyp[i]];