Brad Murray
Committed by GitHub

Add tdt duration to APIs (#2514)

@@ -59,6 +59,9 @@ target_link_libraries(fire-red-asr-c-api sherpa-onnx-c-api) @@ -59,6 +59,9 @@ target_link_libraries(fire-red-asr-c-api sherpa-onnx-c-api)
59 add_executable(nemo-canary-c-api nemo-canary-c-api.c) 59 add_executable(nemo-canary-c-api nemo-canary-c-api.c)
60 target_link_libraries(nemo-canary-c-api sherpa-onnx-c-api) 60 target_link_libraries(nemo-canary-c-api sherpa-onnx-c-api)
61 61
  62 +add_executable(nemo-parakeet-c-api nemo-parakeet-c-api.c)
  63 +target_link_libraries(nemo-parakeet-c-api sherpa-onnx-c-api)
  64 +
62 add_executable(sense-voice-c-api sense-voice-c-api.c) 65 add_executable(sense-voice-c-api sense-voice-c-api.c)
63 target_link_libraries(sense-voice-c-api sherpa-onnx-c-api) 66 target_link_libraries(sense-voice-c-api sherpa-onnx-c-api)
64 67
  1 +// c-api-examples/nemo-parakeet-c-api.c
  2 +// Example using the C API and sherpa-onnx-nemo-parakeet-tdt-0.6b-v3-int8 model
  3 +// Prints recognized text, per-token timestamps, and durations
  4 +
  5 +#include <stdio.h>
  6 +#include <stdlib.h>
  7 +#include <string.h>
  8 +
  9 +#include "sherpa-onnx/c-api/c-api.h"
  10 +
  11 +int32_t main() {
  12 + const char *wav_filename =
  13 + "./sherpa-onnx-nemo-parakeet-tdt-0.6b-v3-int8/test_wavs/en.wav";
  14 + const char *encoder_filename =
  15 + "sherpa-onnx-nemo-parakeet-tdt-0.6b-v3-int8/encoder.int8.onnx";
  16 + const char *decoder_filename =
  17 + "sherpa-onnx-nemo-parakeet-tdt-0.6b-v3-int8/decoder.int8.onnx";
  18 + const char *joiner_filename =
  19 + "sherpa-onnx-nemo-parakeet-tdt-0.6b-v3-int8/joiner.int8.onnx";
  20 + const char *tokens_filename =
  21 + "sherpa-onnx-nemo-parakeet-tdt-0.6b-v3-int8/tokens.txt";
  22 + const char *provider = "cpu";
  23 +
  24 + if (!SherpaOnnxFileExists(wav_filename)) {
  25 + fprintf(stderr, "File not found: %s\n", wav_filename);
  26 + return -1;
  27 + }
  28 + const SherpaOnnxWave *wave = SherpaOnnxReadWave(wav_filename);
  29 + if (wave == NULL) {
  30 + fprintf(stderr, "Failed to read or parse %s (not a valid mono 16-bit WAVE file)\n", wav_filename);
  31 + return -1;
  32 + }
  33 +
  34 + SherpaOnnxOfflineModelConfig offline_model_config;
  35 + memset(&offline_model_config, 0, sizeof(offline_model_config));
  36 + offline_model_config.debug = 0;
  37 + offline_model_config.num_threads = 1;
  38 + offline_model_config.provider = provider;
  39 + offline_model_config.tokens = tokens_filename;
  40 + offline_model_config.transducer.encoder = encoder_filename;
  41 + offline_model_config.transducer.decoder = decoder_filename;
  42 + offline_model_config.transducer.joiner = joiner_filename;
  43 +
  44 + SherpaOnnxOfflineRecognizerConfig recognizer_config;
  45 + memset(&recognizer_config, 0, sizeof(recognizer_config));
  46 + recognizer_config.decoding_method = "greedy_search";
  47 + recognizer_config.model_config = offline_model_config;
  48 +
  49 + const SherpaOnnxOfflineRecognizer *recognizer =
  50 + SherpaOnnxCreateOfflineRecognizer(&recognizer_config);
  51 + if (recognizer == NULL) {
  52 + fprintf(stderr, "Please check your config!\n");
  53 + SherpaOnnxFreeWave(wave);
  54 + return -1;
  55 + }
  56 +
  57 + const SherpaOnnxOfflineStream *stream =
  58 + SherpaOnnxCreateOfflineStream(recognizer);
  59 + if (stream == NULL) {
  60 + fprintf(stderr, "Failed to create offline stream.\n");
  61 + SherpaOnnxDestroyOfflineRecognizer(recognizer);
  62 + SherpaOnnxFreeWave(wave);
  63 + return -1;
  64 + }
  65 +
  66 + SherpaOnnxAcceptWaveformOffline(stream, wave->sample_rate, wave->samples,
  67 + wave->num_samples);
  68 + SherpaOnnxDecodeOfflineStream(recognizer, stream);
  69 + const SherpaOnnxOfflineRecognizerResult *result =
  70 + SherpaOnnxGetOfflineStreamResult(stream);
  71 +
  72 + printf("Recognized text: %s\n", result->text);
  73 +
  74 + if (result->tokens_arr && result->timestamps && result->durations) {
  75 + printf("Token\tTimestamp\tDuration\n");
  76 + for (int32_t i = 0; i < result->count; ++i) {
  77 + printf("%s\t%.2f\t%.2f\n", result->tokens_arr[i], result->timestamps[i], result->durations[i]);
  78 + }
  79 + } else {
  80 + printf("Timestamps or durations not available.\n");
  81 + }
  82 +
  83 + SherpaOnnxDestroyOfflineRecognizerResult(result);
  84 + SherpaOnnxDestroyOfflineStream(stream);
  85 + SherpaOnnxDestroyOfflineRecognizer(recognizer);
  86 + SherpaOnnxFreeWave(wave);
  87 +
  88 + return 0;
  89 +}
# Example using the sherpa-onnx Python API and sherpa-onnx-nemo-parakeet-tdt-0.6b-v3-int8 model
# Prints recognized text, per-token timestamps, and durations

import os
import sys

import sherpa_onnx
import soundfile as sf

# All model assets live under one directory; build the paths from it.
_MODEL_DIR = "./sherpa-onnx-nemo-parakeet-tdt-0.6b-v3-int8"

wav_filename = f"{_MODEL_DIR}/test_wavs/en.wav"
encoder = f"{_MODEL_DIR}/encoder.int8.onnx"
decoder = f"{_MODEL_DIR}/decoder.int8.onnx"
joiner = f"{_MODEL_DIR}/joiner.int8.onnx"
tokens = f"{_MODEL_DIR}/tokens.txt"

if not os.path.exists(wav_filename):
    print(f"File not found: {wav_filename}")
    sys.exit(1)


recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
    encoder,
    decoder,
    joiner,
    tokens,
    num_threads=1,
    provider="cpu",
    debug=False,
    decoding_method="greedy_search",
    model_type="nemo_transducer",
)


def _read_first_channel(path):
    """Read `path` as float32 and return (samples of channel 0, sample rate)."""
    frames, rate = sf.read(path, dtype="float32", always_2d=True)
    return frames[:, 0], rate  # use first channel if multi-channel


audio, sample_rate = _read_first_channel(wav_filename)
stream = recognizer.create_stream()
stream.accept_waveform(sample_rate, audio)
recognizer.decode_stream(stream)
result = stream.result

print(f"Recognized text: {result.text}")

# Durations are only populated for TDT models; fall back to a notice
# when any of the per-token arrays is missing.
has_token_details = all(
    hasattr(result, attr) for attr in ("tokens", "timestamps", "durations")
)
if has_token_details:
    print("Token\tTimestamp\tDuration")
    for token, ts, dur in zip(result.tokens, result.timestamps, result.durations):
        print(f"{token}\t{ts:.2f}\t{dur:.2f}")
else:
    print("Timestamps or durations not available.")
@@ -523,6 +523,7 @@ type OfflineRecognizerResult struct { @@ -523,6 +523,7 @@ type OfflineRecognizerResult struct {
523 Text string 523 Text string
524 Tokens []string 524 Tokens []string
525 Timestamps []float32 525 Timestamps []float32
  526 + Durations []float32
526 Lang string 527 Lang string
527 Emotion string 528 Emotion string
528 Event string 529 Event string
@@ -872,13 +873,19 @@ func (s *OfflineStream) GetResult() *OfflineRecognizerResult { @@ -872,13 +873,19 @@ func (s *OfflineStream) GetResult() *OfflineRecognizerResult {
872 for i := 0; i < n; i++ { 873 for i := 0; i < n; i++ {
873 result.Tokens[i] = C.GoString(tokens[i]) 874 result.Tokens[i] = C.GoString(tokens[i])
874 } 875 }
875 - if p.timestamps == nil {  
876 - return result 876 + if p.timestamps != nil {
  877 + result.Timestamps = make([]float32, n)
  878 + timestamps := unsafe.Slice(p.timestamps, n)
  879 + for i := 0; i < n; i++ {
  880 + result.Timestamps[i] = float32(timestamps[i])
  881 + }
877 } 882 }
878 - result.Timestamps = make([]float32, n)  
879 - timestamps := unsafe.Slice(p.timestamps, n)  
880 - for i := 0; i < n; i++ {  
881 - result.Timestamps[i] = float32(timestamps[i]) 883 + if p.durations != nil {
  884 + result.Durations = make([]float32, n)
  885 + durations := unsafe.Slice(p.durations, n)
  886 + for i := 0; i < n; i++ {
  887 + result.Durations[i] = float32(durations[i])
  888 + }
882 } 889 }
883 return result 890 return result
884 } 891 }
@@ -689,6 +689,14 @@ const SherpaOnnxOfflineRecognizerResult *SherpaOnnxGetOfflineStreamResult( @@ -689,6 +689,14 @@ const SherpaOnnxOfflineRecognizerResult *SherpaOnnxGetOfflineStreamResult(
689 r->timestamps = nullptr; 689 r->timestamps = nullptr;
690 } 690 }
691 691
  692 + if (!result.durations.empty() && result.durations.size() == r->count) {
  693 + r->durations = new float[r->count];
  694 + std::copy(result.durations.begin(), result.durations.end(),
  695 + r->durations);
  696 + } else {
  697 + r->durations = nullptr;
  698 + }
  699 +
692 r->tokens = tokens; 700 r->tokens = tokens;
693 } else { 701 } else {
694 r->count = 0; 702 r->count = 0;
@@ -705,6 +713,7 @@ void SherpaOnnxDestroyOfflineRecognizerResult( @@ -705,6 +713,7 @@ void SherpaOnnxDestroyOfflineRecognizerResult(
705 if (r) { 713 if (r) {
706 delete[] r->text; 714 delete[] r->text;
707 delete[] r->timestamps; 715 delete[] r->timestamps;
  716 + delete[] r->durations;
708 delete[] r->tokens; 717 delete[] r->tokens;
709 delete[] r->tokens_arr; 718 delete[] r->tokens_arr;
710 delete[] r->json; 719 delete[] r->json;
@@ -614,6 +614,10 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult { @@ -614,6 +614,10 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult {
614 // It is NULL if the model does not support timestamps 614 // It is NULL if the model does not support timestamps
615 float *timestamps; 615 float *timestamps;
616 616
  617 + // Pointer to continuous memory which holds durations (in seconds) for each token
  618 + // It is NULL if the model does not support durations
  619 + float *durations;
  620 +
617 // number of entries in timestamps 621 // number of entries in timestamps
618 int32_t count; 622 int32_t count;
619 623
@@ -631,6 +635,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult { @@ -631,6 +635,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult {
631 * "text": "The recognition result", 635 * "text": "The recognition result",
632 * "tokens": [x, x, x], 636 * "tokens": [x, x, x],
633 * "timestamps": [x, x, x], 637 * "timestamps": [x, x, x],
  638 + * "durations": [x, x, x],
634 * "segment": x, 639 * "segment": x,
635 * "start_time": x, 640 * "start_time": x,
636 * "is_final": true|false 641 * "is_final": true|false
@@ -36,6 +36,7 @@ static OfflineRecognitionResult Convert( @@ -36,6 +36,7 @@ static OfflineRecognitionResult Convert(
36 OfflineRecognitionResult r; 36 OfflineRecognitionResult r;
37 r.tokens.reserve(src.tokens.size()); 37 r.tokens.reserve(src.tokens.size());
38 r.timestamps.reserve(src.timestamps.size()); 38 r.timestamps.reserve(src.timestamps.size());
  39 + r.durations.reserve(src.durations.size());
39 40
40 std::string text; 41 std::string text;
41 for (auto i : src.tokens) { 42 for (auto i : src.tokens) {
@@ -66,6 +67,11 @@ static OfflineRecognitionResult Convert( @@ -66,6 +67,11 @@ static OfflineRecognitionResult Convert(
66 r.timestamps.push_back(time); 67 r.timestamps.push_back(time);
67 } 68 }
68 69
  70 + // Copy durations (if present)
  71 + for (auto d : src.durations) {
  72 + r.durations.push_back(d * frame_shift_s);
  73 + }
  74 +
69 return r; 75 return r;
70 } 76 }
71 77
@@ -397,6 +397,18 @@ std::string OfflineRecognitionResult::AsJsonString() const { @@ -397,6 +397,18 @@ std::string OfflineRecognitionResult::AsJsonString() const {
397 os << "], "; 397 os << "], ";
398 398
399 os << "\"" 399 os << "\""
  400 + << "durations"
  401 + << "\""
  402 + << ": ";
  403 + os << "[";
  404 + sep = "";
  405 + for (auto d : durations) {
  406 + os << sep << std::fixed << std::setprecision(2) << d;
  407 + sep = ", ";
  408 + }
  409 + os << "], ";
  410 +
  411 + os << "\""
400 << "tokens" 412 << "tokens"
401 << "\"" 413 << "\""
402 << ":"; 414 << ":";
@@ -38,6 +38,9 @@ struct OfflineRecognitionResult { @@ -38,6 +38,9 @@ struct OfflineRecognitionResult {
38 /// timestamps[i] records the time in seconds when tokens[i] is decoded. 38 /// timestamps[i] records the time in seconds when tokens[i] is decoded.
39 std::vector<float> timestamps; 39 std::vector<float> timestamps;
40 40
  41 + /// durations[i] contains the duration (in seconds) for tokens[i] (TDT models only)
  42 + std::vector<float> durations;
  43 +
41 std::vector<int32_t> words; 44 std::vector<int32_t> words;
42 45
43 std::string AsJsonString() const; 46 std::string AsJsonString() const;
@@ -19,6 +19,11 @@ struct OfflineTransducerDecoderResult { @@ -19,6 +19,11 @@ struct OfflineTransducerDecoderResult {
19 /// timestamps[i] contains the output frame index where tokens[i] is decoded. 19 /// timestamps[i] contains the output frame index where tokens[i] is decoded.
20 /// Note: The index is after subsampling 20 /// Note: The index is after subsampling
21 std::vector<int32_t> timestamps; 21 std::vector<int32_t> timestamps;
  22 +
  23 + /// durations[i] contains the duration for tokens[i] in output frames
  24 + /// (post-subsampling). It is converted to seconds by higher layers
  25 + /// (e.g., Convert() in offline-recognizer-transducer-impl.h).
  26 + std::vector<float> durations;
22 }; 27 };
23 28
24 class OfflineTransducerDecoder { 29 class OfflineTransducerDecoder {
@@ -130,23 +130,31 @@ static OfflineTransducerDecoderResult DecodeOneTDT( @@ -130,23 +130,31 @@ static OfflineTransducerDecoderResult DecodeOneTDT(
130 p_logit[blank_id] -= blank_penalty; 130 p_logit[blank_id] -= blank_penalty;
131 } 131 }
132 132
133 - auto y = static_cast<int32_t>(std::distance(  
134 - static_cast<const float *>(p_logit),  
135 - std::max_element(static_cast<const float *>(p_logit),  
136 - static_cast<const float *>(p_logit) + vocab_size))); 133 + int32_t output_size = shape.back();
  134 + int32_t num_durations = output_size - vocab_size;
137 135
138 - skip = static_cast<int32_t>(std::distance(  
139 - static_cast<const float *>(p_logit) + vocab_size,  
140 - std::max_element(static_cast<const float *>(p_logit) + vocab_size,  
141 - static_cast<const float *>(p_logit) + shape.back()))); 136 + // Split logits into token and duration logits
  137 + const float* token_logits = p_logit;
  138 + const float* duration_logits = p_logit + vocab_size;
142 139
143 - if (skip == 0) {  
144 - skip = 1; 140 + auto y = static_cast<int32_t>(std::distance(
  141 + token_logits,
  142 + std::max_element(token_logits, token_logits + vocab_size)));
  143 +
  144 + skip = 1;
  145 + int32_t duration = 1;
  146 + if (num_durations > 0) {
  147 + duration = static_cast<int32_t>(std::distance(
  148 + duration_logits,
  149 + std::max_element(duration_logits, duration_logits + num_durations)));
  150 + skip = duration;
  151 + if (skip == 0) skip = 1;
145 } 152 }
146 153
147 if (y != blank_id) { 154 if (y != blank_id) {
148 ans.tokens.push_back(y); 155 ans.tokens.push_back(y);
149 ans.timestamps.push_back(t); 156 ans.timestamps.push_back(t);
  157 + ans.durations.push_back(duration);
150 158
151 decoder_input_pair = BuildDecoderInput(y, model->Allocator()); 159 decoder_input_pair = BuildDecoderInput(y, model->Allocator());
152 160
@@ -155,7 +163,7 @@ static OfflineTransducerDecoderResult DecodeOneTDT( @@ -155,7 +163,7 @@ static OfflineTransducerDecoderResult DecodeOneTDT(
155 std::move(decoder_input_pair.second), 163 std::move(decoder_input_pair.second),
156 std::move(decoder_output_pair.second)); 164 std::move(decoder_output_pair.second));
157 } 165 }
158 - } // for (int32_t t = 0; t < num_rows; ++t) { 166 + } // for (int32_t t = 0; t < num_rows; t += skip)
159 167
160 return ans; 168 return ans;
161 } 169 }
@@ -33,17 +33,19 @@ static void PybindOfflineRecognitionResult(py::module *m) { // NOLINT @@ -33,17 +33,19 @@ static void PybindOfflineRecognitionResult(py::module *m) { // NOLINT
33 self.text.size(), "ignore")); 33 self.text.size(), "ignore"));
34 }) 34 })
35 .def_property_readonly("lang", 35 .def_property_readonly("lang",
36 - [](const PyClass &self) { return self.lang; }) 36 + [](const PyClass &self) { return self.lang; })
37 .def_property_readonly("emotion", 37 .def_property_readonly("emotion",
38 - [](const PyClass &self) { return self.emotion; }) 38 + [](const PyClass &self) { return self.emotion; })
39 .def_property_readonly("event", 39 .def_property_readonly("event",
40 - [](const PyClass &self) { return self.event; }) 40 + [](const PyClass &self) { return self.event; })
41 .def_property_readonly("tokens", 41 .def_property_readonly("tokens",
42 - [](const PyClass &self) { return self.tokens; }) 42 + [](const PyClass &self) { return self.tokens; })
43 .def_property_readonly("words", 43 .def_property_readonly("words",
44 - [](const PyClass &self) { return self.words; })  
45 - .def_property_readonly(  
46 - "timestamps", [](const PyClass &self) { return self.timestamps; }); 44 + [](const PyClass &self) { return self.words; })
  45 + .def_property_readonly("timestamps",
  46 + [](const PyClass &self) { return self.timestamps; })
  47 + .def_property_readonly("durations",
  48 + [](const PyClass &self) { return self.durations; });
47 } 49 }
48 50
49 void PybindOfflineStream(py::module *m) { 51 void PybindOfflineStream(py::module *m) {