Fangjun Kuang
Committed by GitHub

Add timestamps for offline paraformer (#310)

@@ -123,3 +123,30 @@ time $EXE \
$repo/test_wavs/8k.wav
rm -rf $repo
log "------------------------------------------------------------"
log "Run Paraformer (Chinese) with timestamps"
log "------------------------------------------------------------"
repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-zh-2023-09-14
log "Start testing ${repo_url}"
repo=$(basename $repo_url)
log "Download pretrained model and test-data from $repo_url"
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
pushd $repo
git lfs pull --include "*.onnx"
ls -lh *.onnx
popd
time $EXE \
--tokens=$repo/tokens.txt \
--paraformer=$repo/model.int8.onnx \
--num-threads=2 \
--decoding-method=greedy_search \
$repo/test_wavs/0.wav \
$repo/test_wavs/1.wav \
$repo/test_wavs/2.wav \
$repo/test_wavs/8k.wav
rm -rf $repo
@@ -353,11 +353,22 @@ SherpaOnnxOfflineRecognizerResult *GetOfflineStreamResult(
std::copy(text.begin(), text.end(), const_cast<char *>(r->text));
const_cast<char *>(r->text)[text.size()] = 0;
if (!result.timestamps.empty()) {
r->timestamps = new float[result.timestamps.size()];
std::copy(result.timestamps.begin(), result.timestamps.end(),
r->timestamps);
r->count = result.timestamps.size();
} else {
r->timestamps = nullptr;
r->count = 0;
}
return r;
}
void DestroyOfflineRecognizerResult(
const SherpaOnnxOfflineRecognizerResult *r) {
delete[] r->text;
delete[] r->timestamps;
delete r;
}
@@ -408,6 +408,14 @@ SHERPA_ONNX_API void DecodeMultipleOfflineStreams(
SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult {
const char *text;
// Pointer to a contiguous block of memory that holds the timestamps
//
// It is NULL if the model does not support timestamps
float *timestamps;
// Number of entries in timestamps
int32_t count;
// TODO(fangjun): Add more fields
} SherpaOnnxOfflineRecognizerResult;
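For orientation, here is a minimal consumer-side sketch of the two new fields, assuming `stream` is an offline stream that has already been decoded via the C API (an illustration only, not part of the commit; error handling omitted):

const SherpaOnnxOfflineRecognizerResult *r = GetOfflineStreamResult(stream);
printf("text: %s\n", r->text);
if (r->timestamps != NULL) {  // NULL when the model lacks timestamp support
  for (int32_t i = 0; i != r->count; ++i) {
    printf("token %d starts at %.2f s\n", i, r->timestamps[i]);
  }
}
DestroyOfflineRecognizerResult(r);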
@@ -14,6 +14,11 @@ namespace sherpa_onnx {
struct OfflineParaformerDecoderResult {
/// The decoded token IDs
std::vector<int64_t> tokens;
/// It contains the start time of each token, in seconds.
///
/// len(timestamps) == len(tokens)
std::vector<float> timestamps;
};
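Since `len(timestamps) == len(tokens)` is the invariant callers rely on, here is a tiny self-contained sketch of how the two vectors line up (the struct is a stand-in for OfflineParaformerDecoderResult; all values are fabricated):

#include <cstdint>
#include <cstdio>
#include <vector>

struct Result {  // stand-in for OfflineParaformerDecoderResult
  std::vector<int64_t> tokens;
  std::vector<float> timestamps;
};

int main() {
  Result r;
  r.tokens = {12, 7, 42};                // decoded token IDs
  r.timestamps = {0.02f, 0.36f, 0.88f};  // start time of each token, seconds
  for (size_t i = 0; i != r.tokens.size(); ++i) {
    printf("token %lld starts at %.2f s\n",
           static_cast<long long>(r.tokens[i]), r.timestamps[i]);
  }
  return 0;
}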
class OfflineParaformerDecoder {
@@ -28,7 +33,8 @@ class OfflineParaformerDecoder {
* @return Return a vector of size `N` containing the decoded results.
*/
virtual std::vector<OfflineParaformerDecoderResult> Decode(
-Ort::Value log_probs, Ort::Value token_num) = 0;
+Ort::Value log_probs, Ort::Value token_num,
+Ort::Value us_cif_peak = Ort::Value(nullptr)) = 0;
};
} // namespace sherpa_onnx
@@ -5,13 +5,18 @@
#include "sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h"
#include <algorithm>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
std::vector<OfflineParaformerDecoderResult>
-OfflineParaformerGreedySearchDecoder::Decode(Ort::Value log_probs,
-Ort::Value /*token_num*/) {
+OfflineParaformerGreedySearchDecoder::Decode(
+Ort::Value log_probs, Ort::Value /*token_num*/,
+Ort::Value us_cif_peak /*=Ort::Value(nullptr)*/
+) {
std::vector<int64_t> shape = log_probs.GetTensorTypeAndShapeInfo().GetShape();
int32_t batch_size = shape[0];
int32_t num_tokens = shape[1];
@@ -25,12 +30,43 @@ OfflineParaformerGreedySearchDecoder::Decode(Ort::Value log_probs,
for (int32_t k = 0; k != num_tokens; ++k) {
auto max_idx = static_cast<int64_t>(
std::distance(p, std::max_element(p, p + vocab_size)));
-if (max_idx == eos_id_) break;
+if (max_idx == eos_id_) {
+break;
+}
results[i].tokens.push_back(max_idx);
p += vocab_size;
}
if (us_cif_peak) {
int32_t dim = us_cif_peak.GetTensorTypeAndShapeInfo().GetShape()[1];
const auto *peak = us_cif_peak.GetTensorData<float>() + i * dim;
std::vector<float> timestamps;
timestamps.reserve(results[i].tokens.size());
// 10.0: frameshift is 10 milliseconds
// 6: LfrWindowSize
// 3: us_cif_peak is upsampled by a factor of 3
// 1000: milliseconds to seconds
float scale = 10.0 * 6 / 3 / 1000;
for (int32_t k = 0; k != dim; ++k) {
if (peak[k] > 1 - 1e-4) {
timestamps.push_back(k * scale);
}
}
// The model fires one extra trailing peak (for <eos>), so drop the
// last entry to align timestamps with tokens.
if (!timestamps.empty()) {
timestamps.pop_back();
}
if (timestamps.size() == results[i].tokens.size()) {
results[i].timestamps = std::move(timestamps);
} else {
SHERPA_ONNX_LOGE("Number of tokens and timestamps differ for batch %d: %d vs %d",
i, static_cast<int32_t>(results[i].tokens.size()),
static_cast<int32_t>(timestamps.size()));
}
}
}
return results;
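To make the peak-to-time conversion above concrete, here is a standalone sketch using the constants from the comments (10 ms frameshift, LFR window size 6, 3x upsampling); the peak values are fabricated:

#include <cstdio>
#include <vector>

int main() {
  // A made-up us_cif_peak row; values at 1.0 mark fired peaks.
  std::vector<float> peak = {0.1f, 1.0f, 0.3f, 0.2f, 1.0f, 0.0f, 1.0f};
  const float scale = 10.0f * 6 / 3 / 1000;  // 0.02 s per upsampled frame

  std::vector<float> timestamps;
  for (size_t k = 0; k != peak.size(); ++k) {
    if (peak[k] > 1 - 1e-4) {
      timestamps.push_back(k * scale);
    }
  }

  // Peaks fire at indices 1, 4, 6 -> 0.02 s, 0.08 s, 0.12 s; the decoder
  // then drops the last one, which belongs to the trailing <eos>.
  timestamps.pop_back();
  for (float t : timestamps) {
    printf("%.2f\n", t);  // prints 0.02 and 0.08
  }
  return 0;
}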
@@ -17,7 +17,8 @@ class OfflineParaformerGreedySearchDecoder : public OfflineParaformerDecoder {
: eos_id_(eos_id) {}
std::vector<OfflineParaformerDecoderResult> Decode(
-Ort::Value log_probs, Ort::Value /*token_num*/) override;
+Ort::Value log_probs, Ort::Value token_num,
+Ort::Value us_cif_peak = Ort::Value(nullptr)) override;
private:
int32_t eos_id_;
@@ -6,6 +6,7 @@
#include <algorithm>
#include <string>
#include <utility>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
@@ -36,16 +37,13 @@ class OfflineParaformerModel::Impl {
}
#endif
-std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features,
-Ort::Value features_length) {
+std::vector<Ort::Value> Forward(Ort::Value features,
+Ort::Value features_length) {
std::array<Ort::Value, 2> inputs = {std::move(features),
std::move(features_length)};
-auto out =
-sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
-output_names_ptr_.data(), output_names_ptr_.size());
-return {std::move(out[0]), std::move(out[1])};
+return sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
+output_names_ptr_.data(), output_names_ptr_.size());
}
int32_t VocabSize() const { return vocab_size_; }
@@ -119,7 +117,7 @@ OfflineParaformerModel::OfflineParaformerModel(AAssetManager *mgr,
OfflineParaformerModel::~OfflineParaformerModel() = default;
-std::pair<Ort::Value, Ort::Value> OfflineParaformerModel::Forward(
+std::vector<Ort::Value> OfflineParaformerModel::Forward(
Ort::Value features, Ort::Value features_length) {
return impl_->Forward(std::move(features), std::move(features_length));
}
@@ -5,7 +5,6 @@
#define SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_MODEL_H_
#include <memory>
-#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
@@ -35,13 +34,17 @@ class OfflineParaformerModel {
* valid frames in `features` before padding.
* Its dtype is int32_t.
*
-* @return Return a pair containing:
+* @return Return a vector containing:
* - log_probs: A 3-D tensor of shape (N, T', vocab_size)
* - token_num: A 1-D tensor of shape (N,) containing the number
* of valid tokens in each utterance. Its dtype is int64_t.
* If the model supports timestamps, there are two additional
* outputs:
* - us_alphas
* - us_cif_peak
*/
-std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features,
-Ort::Value features_length);
+std::vector<Ort::Value> Forward(Ort::Value features,
+Ort::Value features_length);
/** Return the vocabulary size of the model
*/
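Since the number of outputs now encodes the model flavor, callers branch on the size of the returned vector (the recognizer hunk further below does exactly this). A hypothetical stand-in sketch of that dispatch, with `int` in place of `Ort::Value` so it runs on its own:

#include <cstdio>
#include <vector>

// Stand-in for OfflineParaformerModel::Forward(); a timestamp-capable
// model appends us_alphas and us_cif_peak to the two usual outputs.
std::vector<int> Forward(bool supports_timestamps) {
  return supports_timestamps ? std::vector<int>{0, 1, 2, 3}
                             : std::vector<int>{0, 1};
}

int main() {
  auto out = Forward(true);
  if (out.size() == 2) {
    printf("outputs: log_probs, token_num\n");
  } else {
    printf("outputs: log_probs, token_num, us_alphas, us_cif_peak\n");
  }
  return 0;
}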
@@ -31,6 +31,7 @@ static OfflineRecognitionResult Convert(
const OfflineParaformerDecoderResult &src, const SymbolTable &sym_table) {
OfflineRecognitionResult r;
r.tokens.reserve(src.tokens.size());
r.timestamps = src.timestamps;
std::string text;
@@ -184,7 +185,7 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
// i.e., -23.025850929940457f
Ort::Value x = PadSequence(model_->Allocator(), features_pointer, 0);
-std::pair<Ort::Value, Ort::Value> t{nullptr, nullptr};
+std::vector<Ort::Value> t;
try {
t = model_->Forward(std::move(x), std::move(x_length));
} catch (const Ort::Exception &ex) {
@@ -193,7 +194,13 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
return;
}
-auto results = decoder_->Decode(std::move(t.first), std::move(t.second));
+std::vector<OfflineParaformerDecoderResult> results;
+if (t.size() == 2) {
+// outputs: log_probs, token_num
+results = decoder_->Decode(std::move(t[0]), std::move(t[1]));
+} else {
+// outputs: log_probs, token_num, us_alphas, us_cif_peak,
+// so the timestamp tensor is t[3]
+results =
+decoder_->Decode(std::move(t[0]), std::move(t[1]), std::move(t[3]));
+}
for (int32_t i = 0; i != n; ++i) {
auto r = Convert(results[i], symbol_table_);
@@ -349,6 +349,23 @@ class SherpaOnnxOfflineRecongitionResult {
return String(cString: result.pointee.text)
}
var count: Int32 {
return result.pointee.count
}
var timestamps: [Float] {
if let p = result.pointee.timestamps {
var timestamps: [Float] = []
for index in 0..<count {
timestamps.append(p[Int(index)])
}
return timestamps
} else {
return []
}
}
init(result: UnsafePointer<SherpaOnnxOfflineRecognizerResult>!) {
self.result = result
}
@@ -13,21 +13,45 @@ extension AVAudioPCMBuffer {
}
func run() {
-let encoder = "./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx"
-let decoder = "./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.int8.onnx"
-let tokens = "./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt"
-let whisperConfig = sherpaOnnxOfflineWhisperModelConfig(
-encoder: encoder,
-decoder: decoder
-)
+var recognizer: SherpaOnnxOfflineRecognizer
+var modelConfig: SherpaOnnxOfflineModelConfig
+var modelType = "whisper"
+// modelType = "paraformer"
-let modelConfig = sherpaOnnxOfflineModelConfig(
-tokens: tokens,
-whisper: whisperConfig,
-debug: 0,
-modelType: "whisper"
-)
+if modelType == "whisper" {
+let encoder = "./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx"
+let decoder = "./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.int8.onnx"
+let tokens = "./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt"
+let whisperConfig = sherpaOnnxOfflineWhisperModelConfig(
+encoder: encoder,
+decoder: decoder
+)
+modelConfig = sherpaOnnxOfflineModelConfig(
+tokens: tokens,
+whisper: whisperConfig,
+debug: 0,
+modelType: "whisper"
+)
+} else if modelType == "paraformer" {
+let model = "./sherpa-onnx-paraformer-zh-2023-09-14/model.int8.onnx"
+let tokens = "./sherpa-onnx-paraformer-zh-2023-09-14/tokens.txt"
+let paraformerConfig = sherpaOnnxOfflineParaformerModelConfig(
+model: model
+)
+modelConfig = sherpaOnnxOfflineModelConfig(
+tokens: tokens,
+paraformer: paraformerConfig,
+debug: 0,
+modelType: "paraformer"
+)
+} else {
+print("Unsupported modelType: \(modelType)")
+return
+}
let featConfig = sherpaOnnxFeatureConfig(
sampleRate: 16000,
@@ -38,7 +62,7 @@ func run() {
modelConfig: modelConfig
)
-let recognizer = SherpaOnnxOfflineRecognizer(config: &config)
+recognizer = SherpaOnnxOfflineRecognizer(config: &config)
let filePath = "./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav"
let fileURL: NSURL = NSURL(fileURLWithPath: filePath)
@@ -55,6 +79,10 @@ func run() {
let array: [Float]! = audioFileBuffer?.array()
let result = recognizer.decode(samples: array, sampleRate: Int(audioFormat.sampleRate))
print("\nresult is:\n\(result.text)")
if !result.timestamps.isEmpty {
print("\ntimestamps:\n\(result.timestamps)")
}
}
@main