Fangjun Kuang

Add timestamps for offline paraformer (#310)

@@ -123,3 +123,30 @@ time $EXE \
   $repo/test_wavs/8k.wav
 
 rm -rf $repo
+
+log "------------------------------------------------------------"
+log "Run Paraformer (Chinese) with timestamps"
+log "------------------------------------------------------------"
+
+repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-zh-2023-09-14
+log "Start testing ${repo_url}"
+repo=$(basename $repo_url)
+log "Download pretrained model and test-data from $repo_url"
+
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+pushd $repo
+git lfs pull --include "*.onnx"
+ls -lh *.onnx
+popd
+
+time $EXE \
+  --tokens=$repo/tokens.txt \
+  --paraformer=$repo/model.int8.onnx \
+  --num-threads=2 \
+  --decoding-method=greedy_search \
+  $repo/test_wavs/0.wav \
+  $repo/test_wavs/1.wav \
+  $repo/test_wavs/2.wav \
+  $repo/test_wavs/8k.wav
+
+rm -rf $repo
@@ -353,11 +353,22 @@ SherpaOnnxOfflineRecognizerResult *GetOfflineStreamResult(
   std::copy(text.begin(), text.end(), const_cast<char *>(r->text));
   const_cast<char *>(r->text)[text.size()] = 0;
 
+  if (!result.timestamps.empty()) {
+    r->timestamps = new float[result.timestamps.size()];
+    std::copy(result.timestamps.begin(), result.timestamps.end(),
+              r->timestamps);
+    r->count = result.timestamps.size();
+  } else {
+    r->timestamps = nullptr;
+    r->count = 0;
+  }
+
   return r;
 }
 
 void DestroyOfflineRecognizerResult(
     const SherpaOnnxOfflineRecognizerResult *r) {
   delete[] r->text;
+  delete[] r->timestamps;
   delete r;
 }
@@ -408,6 +408,14 @@ SHERPA_ONNX_API void DecodeMultipleOfflineStreams(
 
 SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult {
   const char *text;
+
+  // Pointer to continuous memory which holds timestamps
+  //
+  // It is NULL if the model does not support timestamps
+  float *timestamps;
+
+  // number of entries in timestamps
+  int32_t count;
   // TODO(fangjun): Add more fields
 } SherpaOnnxOfflineRecognizerResult;
 
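For readers of the C API, here is a minimal usage sketch (it is not part of this commit). It assumes an offline stream has already been created and decoded; the variable name `stream`, the parameter type, and the include path are illustrative, while `text`, `timestamps`, and `count` are exactly the fields declared above.

// Sketch only; assumes `stream` was created and decoded elsewhere.
#include <cstdint>
#include <cstdio>

#include "sherpa-onnx/c-api/c-api.h"  // assumed location of the C API header

void PrintTextAndTimestamps(SherpaOnnxOfflineStream *stream) {
  const SherpaOnnxOfflineRecognizerResult *r = GetOfflineStreamResult(stream);

  std::printf("text: %s\n", r->text);

  // timestamps is NULL when the model cannot produce them
  if (r->timestamps != nullptr) {
    for (int32_t i = 0; i != r->count; ++i) {
      std::printf("token %d starts at %.2f s\n", i, r->timestamps[i]);
    }
  }

  DestroyOfflineRecognizerResult(r);  // also frees r->timestamps (see above)
}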
@@ -14,6 +14,11 @@ namespace sherpa_onnx {
 struct OfflineParaformerDecoderResult {
   /// The decoded token IDs
   std::vector<int64_t> tokens;
+
+  // it contains the start time of each token in seconds
+  //
+  // len(timestamps) == len(tokens)
+  std::vector<float> timestamps;
 };
 
 class OfflineParaformerDecoder {
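A quick sketch (not from this commit) of how a caller might consume the new field. It relies on the invariant documented above, len(timestamps) == len(tokens), but tolerates results whose timestamps vector was left empty; the include path is an assumption.

// Sketch only: pair each decoded token ID with its start time, if available.
#include <cstddef>
#include <cstdio>

#include "sherpa-onnx/csrc/offline-paraformer-decoder.h"  // assumed path

void PrintTokensWithTime(const sherpa_onnx::OfflineParaformerDecoderResult &res) {
  bool has_time = res.timestamps.size() == res.tokens.size();
  for (size_t i = 0; i != res.tokens.size(); ++i) {
    if (has_time) {
      std::printf("token %lld @ %.2f s\n",
                  static_cast<long long>(res.tokens[i]), res.timestamps[i]);
    } else {
      std::printf("token %lld\n", static_cast<long long>(res.tokens[i]));
    }
  }
}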
@@ -28,7 +33,8 @@ class OfflineParaformerDecoder {
    * @return Return a vector of size `N` containing the decoded results.
    */
   virtual std::vector<OfflineParaformerDecoderResult> Decode(
-      Ort::Value log_probs, Ort::Value token_num) = 0;
+      Ort::Value log_probs, Ort::Value token_num,
+      Ort::Value us_cif_peak = Ort::Value(nullptr)) = 0;
 };
 
 }  // namespace sherpa_onnx
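The default argument is what keeps existing call sites compiling: an Ort::Value constructed from nullptr owns no tensor and evaluates to false when tested, which is how the greedy-search decoder below decides whether timestamps can be computed. A tiny standalone sketch (not from this commit; the header name is the usual ONNX Runtime C++ header and is an assumption here):

#include <cstdio>

#include "onnxruntime_cxx_api.h"  // NOLINT

int main() {
  Ort::Value maybe_peak{nullptr};  // same value as the default argument above
  if (maybe_peak) {
    std::printf("a us_cif_peak tensor was supplied\n");
  } else {
    std::printf("no us_cif_peak; decode without timestamps\n");
  }
  return 0;
}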
@@ -5,13 +5,18 @@
 #include "sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h"
 
 #include <algorithm>
+#include <utility>
 #include <vector>
 
+#include "sherpa-onnx/csrc/macros.h"
+
 namespace sherpa_onnx {
 
 std::vector<OfflineParaformerDecoderResult>
-OfflineParaformerGreedySearchDecoder::Decode(Ort::Value log_probs,
-                                             Ort::Value /*token_num*/) {
+OfflineParaformerGreedySearchDecoder::Decode(
+    Ort::Value log_probs, Ort::Value /*token_num*/,
+    Ort::Value us_cif_peak /*=Ort::Value(nullptr)*/
+) {
   std::vector<int64_t> shape = log_probs.GetTensorTypeAndShapeInfo().GetShape();
   int32_t batch_size = shape[0];
   int32_t num_tokens = shape[1];
@@ -25,12 +30,43 @@ OfflineParaformerGreedySearchDecoder::Decode(Ort::Value log_probs,
     for (int32_t k = 0; k != num_tokens; ++k) {
       auto max_idx = static_cast<int64_t>(
           std::distance(p, std::max_element(p, p + vocab_size)));
-      if (max_idx == eos_id_) break;
+      if (max_idx == eos_id_) {
+        break;
+      }
 
       results[i].tokens.push_back(max_idx);
 
       p += vocab_size;
     }
+
+    if (us_cif_peak) {
+      int32_t dim = us_cif_peak.GetTensorTypeAndShapeInfo().GetShape()[1];
+
+      const auto *peak = us_cif_peak.GetTensorData<float>() + i * dim;
+      std::vector<float> timestamps;
+      timestamps.reserve(results[i].tokens.size());
+
+      // 10.0: frameshift is 10 milliseconds
+      // 6: LfrWindowSize
+      // 3: us_cif_peak is upsampled by a factor of 3
+      // 1000: milliseconds to seconds
+      float scale = 10.0 * 6 / 3 / 1000;
+
+      for (int32_t k = 0; k != dim; ++k) {
+        if (peak[k] > 1 - 1e-4) {
+          timestamps.push_back(k * scale);
+        }
+      }
+      timestamps.pop_back();
+
+      if (timestamps.size() == results[i].tokens.size()) {
+        results[i].timestamps = std::move(timestamps);
+      } else {
+        SHERPA_ONNX_LOGE("time stamp for batch: %d, %d vs %d", i,
+                         static_cast<int32_t>(results[i].tokens.size()),
+                         static_cast<int32_t>(timestamps.size()));
+      }
+    }
   }
 
   return results;
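To make the arithmetic above concrete: each index of us_cif_peak covers 10 ms x 6 (the LFR window) / 3 (the upsampling factor) = 20 ms, so scale is 0.02 s, and every peak value close to 1 marks a CIF firing, i.e. one emitted token; the final firing belongs to EOS, which is why the code drops the last entry. The standalone sketch below (not part of the patch, with toy peak values invented for illustration) reproduces that conversion.

// Toy reproduction of the peak-to-timestamp conversion used above.
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  // Invented peak values: firings at indices 1, 4, 7 and a final one at
  // index 9 that stands in for EOS.
  std::vector<float> us_cif_peak = {0.1f, 1.0f, 0.2f, 0.3f, 1.0f,
                                    0.0f, 0.4f, 1.0f, 0.1f, 1.0f};

  // 10 ms frame shift * LFR window 6 / upsampling 3 / 1000 ms per second
  float scale = 10.0 * 6 / 3 / 1000;  // 0.02 s per index

  std::vector<float> timestamps;
  for (size_t k = 0; k != us_cif_peak.size(); ++k) {
    if (us_cif_peak[k] > 1 - 1e-4) {
      timestamps.push_back(k * scale);
    }
  }
  if (!timestamps.empty()) {
    timestamps.pop_back();  // the final firing corresponds to EOS
  }

  for (float t : timestamps) {
    std::printf("%.2f s\n", t);  // prints 0.02, 0.08, 0.14
  }
  return 0;
}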
@@ -17,7 +17,8 @@ class OfflineParaformerGreedySearchDecoder : public OfflineParaformerDecoder {
       : eos_id_(eos_id) {}
 
   std::vector<OfflineParaformerDecoderResult> Decode(
-      Ort::Value log_probs, Ort::Value /*token_num*/) override;
+      Ort::Value log_probs, Ort::Value token_num,
+      Ort::Value us_cif_peak = Ort::Value(nullptr)) override;
 
  private:
   int32_t eos_id_;
@@ -6,6 +6,7 @@
 
 #include <algorithm>
 #include <string>
+#include <utility>
 
 #include "sherpa-onnx/csrc/macros.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"
@@ -36,16 +37,13 @@ class OfflineParaformerModel::Impl {
   }
 #endif
 
-  std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features,
+  std::vector<Ort::Value> Forward(Ort::Value features,
                                   Ort::Value features_length) {
     std::array<Ort::Value, 2> inputs = {std::move(features),
                                         std::move(features_length)};
 
-    auto out =
-        sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
+    return sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
                       output_names_ptr_.data(), output_names_ptr_.size());
-
-    return {std::move(out[0]), std::move(out[1])};
   }
 
   int32_t VocabSize() const { return vocab_size_; }
@@ -119,7 +117,7 @@ OfflineParaformerModel::OfflineParaformerModel(AAssetManager *mgr,
 
 OfflineParaformerModel::~OfflineParaformerModel() = default;
 
-std::pair<Ort::Value, Ort::Value> OfflineParaformerModel::Forward(
+std::vector<Ort::Value> OfflineParaformerModel::Forward(
     Ort::Value features, Ort::Value features_length) {
   return impl_->Forward(std::move(features), std::move(features_length));
 }
@@ -5,7 +5,6 @@
 #define SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_MODEL_H_
 
 #include <memory>
-#include <utility>
 #include <vector>
 
 #if __ANDROID_API__ >= 9
@@ -35,12 +34,16 @@ class OfflineParaformerModel {
    *                 valid frames in `features` before padding.
    *                 Its dtype is int32_t.
    *
-   * @return Return a pair containing:
+   * @return Return a vector containing:
    *           - log_probs: A 3-D tensor of shape (N, T', vocab_size)
    *           - token_num: A 1-D tensor of shape (N, T') containing number
    *                        of valid tokens in each utterance. Its dtype is int64_t.
+   *         If it is a model supporting timestamps, then there are additional two
+   *         outputs:
+   *            - us_alphas
+   *            - us_cif_peak
    */
-  std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features,
+  std::vector<Ort::Value> Forward(Ort::Value features,
                                   Ort::Value features_length);
 
   /** Return the vocabulary size of the model
@@ -31,6 +31,7 @@ static OfflineRecognitionResult Convert(
     const OfflineParaformerDecoderResult &src, const SymbolTable &sym_table) {
   OfflineRecognitionResult r;
   r.tokens.reserve(src.tokens.size());
+  r.timestamps = src.timestamps;
 
   std::string text;
 
@@ -184,7 +185,7 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
     // i.e., -23.025850929940457f
     Ort::Value x = PadSequence(model_->Allocator(), features_pointer, 0);
 
-    std::pair<Ort::Value, Ort::Value> t{nullptr, nullptr};
+    std::vector<Ort::Value> t;
     try {
       t = model_->Forward(std::move(x), std::move(x_length));
     } catch (const Ort::Exception &ex) {
@@ -193,7 +194,13 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
       return;
     }
 
-    auto results = decoder_->Decode(std::move(t.first), std::move(t.second));
+    std::vector<OfflineParaformerDecoderResult> results;
+    if (t.size() == 2) {
+      results = decoder_->Decode(std::move(t[0]), std::move(t[1]));
+    } else {
+      results =
+          decoder_->Decode(std::move(t[0]), std::move(t[1]), std::move(t[3]));
+    }
 
     for (int32_t i = 0; i != n; ++i) {
       auto r = Convert(results[i], symbol_table_);
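The t[3] above is easiest to read with the output order from the model header's comment (earlier in this commit) in mind; us_alphas at index 2 is simply not consumed. The enumerators below are purely illustrative and do not exist in the code base; they only spell out the slot order that the dispatch relies on.

#include <cstdint>

// Hypothetical names for Forward()'s output slots, matching the order that
// the dispatch above assumes when it passes t[3] to the decoder.
enum ParaformerOutputIndex : std::int32_t {
  kLogProbs = 0,   // (N, T', vocab_size)
  kTokenNum = 1,   // (N, T'), dtype int64_t
  kUsAlphas = 2,   // timestamp-capable models only; unused here
  kUsCifPeak = 3,  // timestamp-capable models only; becomes us_cif_peak
};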
@@ -349,6 +349,23 @@ class SherpaOnnxOfflineRecongitionResult {
     return String(cString: result.pointee.text)
   }
 
+  var count: Int32 {
+    return result.pointee.count
+  }
+
+  var timestamps: [Float] {
+    if let p = result.pointee.timestamps {
+      var timestamps: [Float] = []
+      for index in 0..<count {
+        timestamps.append(p[Int(index)])
+      }
+      return timestamps
+    } else {
+      let timestamps: [Float] = []
+      return timestamps
+    }
+  }
+
   init(result: UnsafePointer<SherpaOnnxOfflineRecognizerResult>!) {
     self.result = result
   }
@@ -13,6 +13,13 @@ extension AVAudioPCMBuffer {
 }
 
 func run() {
+
+  var recognizer: SherpaOnnxOfflineRecognizer
+  var modelConfig: SherpaOnnxOfflineModelConfig
+  var modelType = "whisper"
+  // modelType = "paraformer"
+
+  if modelType == "whisper" {
   let encoder = "./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx"
   let decoder = "./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.int8.onnx"
   let tokens = "./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt"
@@ -22,12 +29,29 @@ func run() {
     decoder: decoder
   )
 
-  let modelConfig = sherpaOnnxOfflineModelConfig(
+  modelConfig = sherpaOnnxOfflineModelConfig(
     tokens: tokens,
     whisper: whisperConfig,
     debug: 0,
     modelType: "whisper"
   )
+  } else if modelType == "paraformer" {
+    let model = "./sherpa-onnx-paraformer-zh-2023-09-14/model.int8.onnx"
+    let tokens = "./sherpa-onnx-paraformer-zh-2023-09-14/tokens.txt"
+    let paraformerConfig = sherpaOnnxOfflineParaformerModelConfig(
+      model: model
+    )
+
+    modelConfig = sherpaOnnxOfflineModelConfig(
+      tokens: tokens,
+      paraformer: paraformerConfig,
+      debug: 0,
+      modelType: "paraformer"
+    )
+  } else {
+    print("Please specify a supported modelType \(modelType)")
+    return
+  }
 
   let featConfig = sherpaOnnxFeatureConfig(
     sampleRate: 16000,
@@ -38,7 +62,7 @@ func run() {
     modelConfig: modelConfig
   )
 
-  let recognizer = SherpaOnnxOfflineRecognizer(config: &config)
+  recognizer = SherpaOnnxOfflineRecognizer(config: &config)
 
   let filePath = "./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav"
   let fileURL: NSURL = NSURL(fileURLWithPath: filePath)
@@ -55,6 +79,10 @@ func run() {
   let array: [Float]! = audioFileBuffer?.array()
   let result = recognizer.decode(samples: array, sampleRate: Int(audioFormat.sampleRate))
   print("\nresult is:\n\(result.text)")
+  if result.timestamps.count != 0 {
+    print("\ntimestamps is:\n\(result.timestamps)")
+  }
+
 }
 
 @main