Brad Murray
Committed by GitHub

Add TDT token durations to APIs (#2514)

... ... @@ -59,6 +59,9 @@ target_link_libraries(fire-red-asr-c-api sherpa-onnx-c-api)
# NeMo Canary multilingual ASR example.
add_executable(nemo-canary-c-api nemo-canary-c-api.c)
target_link_libraries(nemo-canary-c-api sherpa-onnx-c-api)
# NeMo Parakeet TDT transducer example (prints per-token timestamps/durations).
add_executable(nemo-parakeet-c-api nemo-parakeet-c-api.c)
target_link_libraries(nemo-parakeet-c-api sherpa-onnx-c-api)
# SenseVoice non-streaming ASR example.
add_executable(sense-voice-c-api sense-voice-c-api.c)
target_link_libraries(sense-voice-c-api sherpa-onnx-c-api)
... ...
// c-api-examples/nemo-parakeet-c-api.c
// Example using the C API and sherpa-onnx-nemo-parakeet-tdt-0.6b-v3-int8 model
// Prints recognized text, per-token timestamps, and durations
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "sherpa-onnx/c-api/c-api.h"
int32_t main() {
  // Paths to the test wave, the int8 transducer model components, and the
  // token table of the sherpa-onnx-nemo-parakeet-tdt-0.6b-v3-int8 model.
  const char *wav_filename =
      "./sherpa-onnx-nemo-parakeet-tdt-0.6b-v3-int8/test_wavs/en.wav";
  const char *encoder_filename =
      "sherpa-onnx-nemo-parakeet-tdt-0.6b-v3-int8/encoder.int8.onnx";
  const char *decoder_filename =
      "sherpa-onnx-nemo-parakeet-tdt-0.6b-v3-int8/decoder.int8.onnx";
  const char *joiner_filename =
      "sherpa-onnx-nemo-parakeet-tdt-0.6b-v3-int8/joiner.int8.onnx";
  const char *tokens_filename =
      "sherpa-onnx-nemo-parakeet-tdt-0.6b-v3-int8/tokens.txt";
  const char *provider = "cpu";

  if (!SherpaOnnxFileExists(wav_filename)) {
    fprintf(stderr, "File not found: %s\n", wav_filename);
    return -1;
  }

  const SherpaOnnxWave *wave = SherpaOnnxReadWave(wav_filename);
  if (wave == NULL) {
    fprintf(stderr,
            "Failed to read or parse %s (not a valid mono 16-bit WAVE file)\n",
            wav_filename);
    return -1;
  }

  // Non-streaming (offline) transducer model configuration. memset() zeroes
  // all unused fields so the C API applies its defaults.
  SherpaOnnxOfflineModelConfig offline_model_config;
  memset(&offline_model_config, 0, sizeof(offline_model_config));
  offline_model_config.debug = 0;
  offline_model_config.num_threads = 1;
  offline_model_config.provider = provider;
  offline_model_config.tokens = tokens_filename;
  offline_model_config.transducer.encoder = encoder_filename;
  offline_model_config.transducer.decoder = decoder_filename;
  offline_model_config.transducer.joiner = joiner_filename;

  SherpaOnnxOfflineRecognizerConfig recognizer_config;
  memset(&recognizer_config, 0, sizeof(recognizer_config));
  recognizer_config.decoding_method = "greedy_search";
  recognizer_config.model_config = offline_model_config;

  const SherpaOnnxOfflineRecognizer *recognizer =
      SherpaOnnxCreateOfflineRecognizer(&recognizer_config);
  if (recognizer == NULL) {
    fprintf(stderr, "Please check your config!\n");
    SherpaOnnxFreeWave(wave);
    return -1;
  }

  const SherpaOnnxOfflineStream *stream =
      SherpaOnnxCreateOfflineStream(recognizer);
  if (stream == NULL) {
    fprintf(stderr, "Failed to create offline stream.\n");
    SherpaOnnxDestroyOfflineRecognizer(recognizer);
    SherpaOnnxFreeWave(wave);
    return -1;
  }

  // Feed the entire wave, then run non-streaming decoding.
  SherpaOnnxAcceptWaveformOffline(stream, wave->sample_rate, wave->samples,
                                  wave->num_samples);
  SherpaOnnxDecodeOfflineStream(recognizer, stream);

  const SherpaOnnxOfflineRecognizerResult *result =
      SherpaOnnxGetOfflineStreamResult(stream);
  // Fix: guard against a NULL result before dereferencing result->text.
  if (result == NULL) {
    fprintf(stderr, "Failed to get the recognition result.\n");
    SherpaOnnxDestroyOfflineStream(stream);
    SherpaOnnxDestroyOfflineRecognizer(recognizer);
    SherpaOnnxFreeWave(wave);
    return -1;
  }

  printf("Recognized text: %s\n", result->text);

  // Per-token timestamps and durations are only populated for models that
  // produce them (e.g., TDT transducers); the pointers are NULL otherwise.
  if (result->tokens_arr && result->timestamps && result->durations) {
    printf("Token\tTimestamp\tDuration\n");
    for (int32_t i = 0; i < result->count; ++i) {
      printf("%s\t%.2f\t%.2f\n", result->tokens_arr[i], result->timestamps[i],
             result->durations[i]);
    }
  } else {
    printf("Timestamps or durations not available.\n");
  }

  // Release everything in reverse order of creation.
  SherpaOnnxDestroyOfflineRecognizerResult(result);
  SherpaOnnxDestroyOfflineStream(stream);
  SherpaOnnxDestroyOfflineRecognizer(recognizer);
  SherpaOnnxFreeWave(wave);
  return 0;
}
... ...
# Example using the sherpa-onnx Python API and sherpa-onnx-nemo-parakeet-tdt-0.6b-v3-int8 model
# Prints recognized text, per-token timestamps, and durations
import os
import sys
import sherpa_onnx
import soundfile as sf
# All model assets live under one directory; build each path from it.
model_dir = "./sherpa-onnx-nemo-parakeet-tdt-0.6b-v3-int8"
wav_path = f"{model_dir}/test_wavs/en.wav"

if not os.path.exists(wav_path):
    print(f"File not found: {wav_path}")
    sys.exit(1)

# Non-streaming transducer recognizer using the int8 NeMo Parakeet TDT model.
recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
    f"{model_dir}/encoder.int8.onnx",
    f"{model_dir}/decoder.int8.onnx",
    f"{model_dir}/joiner.int8.onnx",
    f"{model_dir}/tokens.txt",
    num_threads=1,
    provider="cpu",
    debug=False,
    decoding_method="greedy_search",
    model_type="nemo_transducer",
)

samples, rate = sf.read(wav_path, dtype="float32", always_2d=True)
samples = samples[:, 0]  # keep only the first channel if multi-channel

s = recognizer.create_stream()
s.accept_waveform(rate, samples)
recognizer.decode_stream(s)
res = s.result

print(f"Recognized text: {res.text}")

# Token-level timing info is optional; print it only when all fields exist.
if all(hasattr(res, a) for a in ("tokens", "timestamps", "durations")):
    print("Token\tTimestamp\tDuration")
    for tok, ts, dur in zip(res.tokens, res.timestamps, res.durations):
        print(f"{tok}\t{ts:.2f}\t{dur:.2f}")
else:
    print("Timestamps or durations not available.")
... ...
... ... @@ -523,6 +523,7 @@ type OfflineRecognizerResult struct {
Text string
Tokens []string
Timestamps []float32
Durations []float32
Lang string
Emotion string
Event string
... ... @@ -872,14 +873,20 @@ func (s *OfflineStream) GetResult() *OfflineRecognizerResult {
for i := 0; i < n; i++ {
result.Tokens[i] = C.GoString(tokens[i])
}
if p.timestamps == nil {
return result
}
if p.timestamps != nil {
result.Timestamps = make([]float32, n)
timestamps := unsafe.Slice(p.timestamps, n)
for i := 0; i < n; i++ {
result.Timestamps[i] = float32(timestamps[i])
}
}
if p.durations != nil {
result.Durations = make([]float32, n)
durations := unsafe.Slice(p.durations, n)
for i := 0; i < n; i++ {
result.Durations[i] = float32(durations[i])
}
}
return result
}
... ...
... ... @@ -689,6 +689,14 @@ const SherpaOnnxOfflineRecognizerResult *SherpaOnnxGetOfflineStreamResult(
r->timestamps = nullptr;
}
if (!result.durations.empty() && result.durations.size() == r->count) {
r->durations = new float[r->count];
std::copy(result.durations.begin(), result.durations.end(),
r->durations);
} else {
r->durations = nullptr;
}
r->tokens = tokens;
} else {
r->count = 0;
... ... @@ -705,6 +713,7 @@ void SherpaOnnxDestroyOfflineRecognizerResult(
if (r) {
delete[] r->text;
delete[] r->timestamps;
delete[] r->durations;
delete[] r->tokens;
delete[] r->tokens_arr;
delete[] r->json;
... ...
... ... @@ -614,6 +614,10 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult {
// It is NULL if the model does not support timestamps
float *timestamps;
// Pointer to continuous memory which holds durations (in seconds) for each token
// It is NULL if the model does not support durations
float *durations;
// number of entries in timestamps
int32_t count;
... ... @@ -631,6 +635,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult {
* "text": "The recognition result",
* "tokens": [x, x, x],
* "timestamps": [x, x, x],
* "durations": [x, x, x],
* "segment": x,
* "start_time": x,
* "is_final": true|false
... ...
... ... @@ -36,6 +36,7 @@ static OfflineRecognitionResult Convert(
OfflineRecognitionResult r;
r.tokens.reserve(src.tokens.size());
r.timestamps.reserve(src.timestamps.size());
r.durations.reserve(src.durations.size());
std::string text;
for (auto i : src.tokens) {
... ... @@ -66,6 +67,11 @@ static OfflineRecognitionResult Convert(
r.timestamps.push_back(time);
}
// Copy durations (if present)
for (auto d : src.durations) {
r.durations.push_back(d * frame_shift_s);
}
return r;
}
... ...
... ... @@ -397,6 +397,18 @@ std::string OfflineRecognitionResult::AsJsonString() const {
os << "], ";
os << "\""
<< "durations"
<< "\""
<< ": ";
os << "[";
sep = "";
for (auto d : durations) {
os << sep << std::fixed << std::setprecision(2) << d;
sep = ", ";
}
os << "], ";
os << "\""
<< "tokens"
<< "\""
<< ":";
... ...
... ... @@ -38,6 +38,9 @@ struct OfflineRecognitionResult {
/// timestamps[i] records the time in seconds when tokens[i] is decoded.
std::vector<float> timestamps;
/// durations[i] contains the duration (in seconds) for tokens[i] (TDT models only)
std::vector<float> durations;
std::vector<int32_t> words;
std::string AsJsonString() const;
... ...
... ... @@ -19,6 +19,11 @@ struct OfflineTransducerDecoderResult {
/// timestamps[i] contains the output frame index where tokens[i] is decoded.
/// Note: The index is after subsampling
std::vector<int32_t> timestamps;
/// durations[i] contains the duration for tokens[i] in output frames
/// (post-subsampling). It is converted to seconds by higher layers
/// (e.g., Convert() in offline-recognizer-transducer-impl.h).
std::vector<float> durations;
};
class OfflineTransducerDecoder {
... ...
... ... @@ -130,23 +130,31 @@ static OfflineTransducerDecoderResult DecodeOneTDT(
p_logit[blank_id] -= blank_penalty;
}
auto y = static_cast<int32_t>(std::distance(
static_cast<const float *>(p_logit),
std::max_element(static_cast<const float *>(p_logit),
static_cast<const float *>(p_logit) + vocab_size)));
int32_t output_size = shape.back();
int32_t num_durations = output_size - vocab_size;
// Split logits into token and duration logits
const float* token_logits = p_logit;
const float* duration_logits = p_logit + vocab_size;
skip = static_cast<int32_t>(std::distance(
static_cast<const float *>(p_logit) + vocab_size,
std::max_element(static_cast<const float *>(p_logit) + vocab_size,
static_cast<const float *>(p_logit) + shape.back())));
auto y = static_cast<int32_t>(std::distance(
token_logits,
std::max_element(token_logits, token_logits + vocab_size)));
if (skip == 0) {
skip = 1;
int32_t duration = 1;
if (num_durations > 0) {
duration = static_cast<int32_t>(std::distance(
duration_logits,
std::max_element(duration_logits, duration_logits + num_durations)));
skip = duration;
if (skip == 0) skip = 1;
}
if (y != blank_id) {
ans.tokens.push_back(y);
ans.timestamps.push_back(t);
ans.durations.push_back(duration);
decoder_input_pair = BuildDecoderInput(y, model->Allocator());
... ... @@ -155,7 +163,7 @@ static OfflineTransducerDecoderResult DecodeOneTDT(
std::move(decoder_input_pair.second),
std::move(decoder_output_pair.second));
}
} // for (int32_t t = 0; t < num_rows; ++t) {
} // for (int32_t t = 0; t < num_rows; t += skip)
return ans;
}
... ...
... ... @@ -42,8 +42,10 @@ static void PybindOfflineRecognitionResult(py::module *m) { // NOLINT
[](const PyClass &self) { return self.tokens; })
.def_property_readonly("words",
[](const PyClass &self) { return self.words; })
.def_property_readonly(
"timestamps", [](const PyClass &self) { return self.timestamps; });
.def_property_readonly("timestamps",
[](const PyClass &self) { return self.timestamps; })
.def_property_readonly("durations",
[](const PyClass &self) { return self.durations; });
}
void PybindOfflineStream(py::module *m) {
... ...