Fangjun Kuang
Committed by GitHub

Code refactoring (#74)

* Don't reset model state and feature extractor on endpointing

* Support passing decoding_method from the command line

* Add modified_beam_search to Python API

* Fix the C API example

* Fix style issues
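Taken together, the C API changes below add two config fields (`decoding_method`, `max_active_paths`) and a small console-display helper. The sketch below condenses the updated decode-file example under stated assumptions: model paths are placeholders, a short buffer of silence stands in for real audio, and `IsOnlineStreamReady()` is part of the existing C API rather than of this diff.

```c
/* Hedged sketch of the flow this PR enables; not the full example. */
#include <string.h>

#include "sherpa-onnx/c-api/c-api.h"

int main(void) {
  SherpaOnnxOnlineRecognizerConfig config;
  memset(&config, 0, sizeof(config));

  config.model_config.tokens = "tokens.txt"; /* placeholder paths */
  config.model_config.encoder = "encoder.onnx";
  config.model_config.decoder = "decoder.onnx";
  config.model_config.joiner = "joiner.onnx";
  config.model_config.num_threads = 1;

  config.feat_config.sample_rate = 16000;
  config.feat_config.feature_dim = 80;

  /* New in this PR: choose the decoding method at run time. */
  config.decoding_method = "modified_beam_search";
  config.max_active_paths = 4; /* used only by modified_beam_search */
  config.enable_endpoint = 1;

  SherpaOnnxOnlineRecognizer *recognizer = CreateOnlineRecognizer(&config);
  SherpaOnnxOnlineStream *stream = CreateOnlineStream(recognizer);

  /* New in this PR: an in-place console display, wrapping at 50 words. */
  SherpaOnnxDisplay *display = CreateDisplay(50);
  int32_t segment_id = 0;

  float samples[1600] = {0}; /* 0.1 s of silence as a stand-in for audio */
  AcceptWaveform(stream, 16000, samples, 1600);
  while (IsOnlineStreamReady(recognizer, stream)) {
    DecodeOnlineStream(recognizer, stream);
  }

  SherpaOnnxOnlineRecognizerResult *r =
      GetOnlineStreamResult(recognizer, stream);
  if (strlen(r->text)) {
    SherpaOnnxPrint(display, segment_id, r->text);
  }
  if (IsEndpoint(recognizer, stream)) {
    if (strlen(r->text)) {
      ++segment_id;
    }
    Reset(recognizer, stream); /* now keeps model state and features */
  }
  DestroyOnlineRecognizerResult(r);

  DestroyDisplay(display);
  DestoryOnlineStream(stream); /* sic: this is how the C API spells it */
  DestroyOnlineRecognizer(recognizer);
  return 0;
}
```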
@@ -36,4 +36,5 @@ tokens.txt
 *.onnx
 log.txt
 tags
+run-decode-file-python.sh
 android/SherpaOnnx/app/src/main/assets/
@@ -19,14 +19,16 @@ const char *kUsage =
 " /path/to/encoder.onnx \\\n"
 " /path/to/decoder.onnx \\\n"
 " /path/to/joiner.onnx \\\n"
-" /path/to/foo.wav [num_threads]\n"
+" /path/to/foo.wav [num_threads [decoding_method]]\n"
 "\n\n"
+"Default num_threads is 1.\n"
+"Valid decoding_method: greedy_search (default), modified_beam_search\n\n"
 "Please refer to \n"
 "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html\n"
 "for a list of pre-trained models to download.\n";

 int32_t main(int32_t argc, char *argv[]) {
-  if (argc < 6 || argc > 7) {
+  if (argc < 6 || argc > 8) {
     fprintf(stderr, "%s\n", kUsage);
     return -1;
   }
@@ -36,13 +38,20 @@ int32_t main(int32_t argc, char *argv[]) {
   config.model_config.decoder = argv[3];
   config.model_config.joiner = argv[4];

-  int32_t num_threads = 4;
+  int32_t num_threads = 1;
   if (argc == 7 && atoi(argv[6]) > 0) {
     num_threads = atoi(argv[6]);
   }
   config.model_config.num_threads = num_threads;
   config.model_config.debug = 0;

+  config.decoding_method = "greedy_search";
+  if (argc == 8) {
+    config.decoding_method = argv[7];
+  }
+
+  config.max_active_paths = 4;
+
   config.feat_config.sample_rate = 16000;
   config.feat_config.feature_dim = 80;

@@ -54,6 +63,9 @@ int32_t main(int32_t argc, char *argv[]) {
   SherpaOnnxOnlineRecognizer *recognizer = CreateOnlineRecognizer(&config);
   SherpaOnnxOnlineStream *stream = CreateOnlineStream(recognizer);

+  SherpaOnnxDisplay *display = CreateDisplay(50);
+  int32_t segment_id = 0;
+
   const char *wav_filename = argv[5];
   FILE *fp = fopen(wav_filename, "rb");
   if (!fp) {
@@ -84,9 +96,18 @@ int32_t main(int32_t argc, char *argv[]) {

     SherpaOnnxOnlineRecognizerResult *r =
         GetOnlineStreamResult(recognizer, stream);
+
     if (strlen(r->text)) {
-      fprintf(stderr, "%s\n", r->text);
+      SherpaOnnxPrint(display, segment_id, r->text);
     }
+
+    if (IsEndpoint(recognizer, stream)) {
+      if (strlen(r->text)) {
+        ++segment_id;
+      }
+      Reset(recognizer, stream);
+    }
+
     DestroyOnlineRecognizerResult(r);
   }
 }
@@ -103,14 +124,17 @@ int32_t main(int32_t argc, char *argv[]) {

   SherpaOnnxOnlineRecognizerResult *r =
       GetOnlineStreamResult(recognizer, stream);
+
   if (strlen(r->text)) {
-    fprintf(stderr, "%s\n", r->text);
+    SherpaOnnxPrint(display, segment_id, r->text);
   }

   DestroyOnlineRecognizerResult(r);

+  DestroyDisplay(display);
   DestoryOnlineStream(stream);
   DestroyOnlineRecognizer(recognizer);
+  fprintf(stderr, "\n");

   return 0;
 }
@@ -26,12 +26,17 @@ if [ ! -f ./sherpa-onnx-ffmpeg ]; then
   make
 fi

-../ffmpeg-examples/sherpa-onnx-ffmpeg \
+for method in greedy_search modified_beam_search; do
+  echo "test method: $method"
+  ../ffmpeg-examples/sherpa-onnx-ffmpeg \
     ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt \
     ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \
     ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \
     ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \
-    ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/4.wav
+    ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/0.wav \
+    2 \
+    $method
+done

 echo "Decoding a URL"

@@ -7,7 +7,6 @@

 #include "sherpa-onnx/c-api/c-api.h"

-
 /*
  * Copyright (c) 2010 Nicolas George
  * Copyright (c) 2011 Stefano Sabatini
@@ -43,14 +42,15 @@
 #include <unistd.h>
 extern "C" {
 #include <libavcodec/avcodec.h>
-#include <libavformat/avformat.h>
 #include <libavfilter/buffersink.h>
 #include <libavfilter/buffersrc.h>
+#include <libavformat/avformat.h>
 #include <libavutil/channel_layout.h>
 #include <libavutil/opt.h>
 }

-static const char *filter_descr = "aresample=16000,aformat=sample_fmts=s16:channel_layouts=mono";
+static const char *filter_descr =
+    "aresample=16000,aformat=sample_fmts=s16:channel_layouts=mono";

 static AVFormatContext *fmt_ctx;
 static AVCodecContext *dec_ctx;
@@ -59,8 +59,7 @@ AVFilterContext *buffersrc_ctx;
 AVFilterGraph *filter_graph;
 static int audio_stream_index = -1;

-static int open_input_file(const char *filename)
-{
+static int open_input_file(const char *filename) {
   const AVCodec *dec;
   int ret;

@@ -77,16 +76,17 @@ static int open_input_file(const char *filename)
   /* select the audio stream */
   ret = av_find_best_stream(fmt_ctx, AVMEDIA_TYPE_AUDIO, -1, -1, &dec, 0);
   if (ret < 0) {
-    av_log(NULL, AV_LOG_ERROR, "Cannot find an audio stream in the input file\n");
+    av_log(NULL, AV_LOG_ERROR,
+           "Cannot find an audio stream in the input file\n");
     return ret;
   }
   audio_stream_index = ret;

   /* create decoding context */
   dec_ctx = avcodec_alloc_context3(dec);
-  if (!dec_ctx)
-    return AVERROR(ENOMEM);
-  avcodec_parameters_to_context(dec_ctx, fmt_ctx->streams[audio_stream_index]->codecpar);
+  if (!dec_ctx) return AVERROR(ENOMEM);
+  avcodec_parameters_to_context(dec_ctx,
+                                fmt_ctx->streams[audio_stream_index]->codecpar);

   /* init the audio decoder */
   if ((ret = avcodec_open2(dec_ctx, dec, NULL)) < 0) {
@@ -97,16 +97,16 @@ static int open_input_file(const char *filename)
   return 0;
 }

-static int init_filters(const char *filters_descr)
-{
+static int init_filters(const char *filters_descr) {
   char args[512];
   int ret = 0;
   const AVFilter *abuffersrc = avfilter_get_by_name("abuffer");
   const AVFilter *abuffersink = avfilter_get_by_name("abuffersink");
   AVFilterInOut *outputs = avfilter_inout_alloc();
   AVFilterInOut *inputs = avfilter_inout_alloc();
-  static const enum AVSampleFormat out_sample_fmts[] = { AV_SAMPLE_FMT_S16, AV_SAMPLE_FMT_NONE };
-  static const int out_sample_rates[] = { 16000, -1 };
+  static const enum AVSampleFormat out_sample_fmts[] = {AV_SAMPLE_FMT_S16,
+                                                        AV_SAMPLE_FMT_NONE};
+  static const int out_sample_rates[] = {16000, -1};
   const AVFilterLink *outlink;
   AVRational time_base = fmt_ctx->streams[audio_stream_index]->time_base;

@@ -116,24 +116,27 @@ static int init_filters(const char *filters_descr)
     goto end;
   }

-  /* buffer audio source: the decoded frames from the decoder will be inserted here. */
+  /* buffer audio source: the decoded frames from the decoder will be inserted
+   * here. */
   if (dec_ctx->ch_layout.order == AV_CHANNEL_ORDER_UNSPEC)
-    av_channel_layout_default(&dec_ctx->ch_layout, dec_ctx->ch_layout.nb_channels);
+    av_channel_layout_default(&dec_ctx->ch_layout,
+                              dec_ctx->ch_layout.nb_channels);
   ret = snprintf(args, sizeof(args),
                  "time_base=%d/%d:sample_rate=%d:sample_fmt=%s:channel_layout=",
                  time_base.num, time_base.den, dec_ctx->sample_rate,
                  av_get_sample_fmt_name(dec_ctx->sample_fmt));
-  av_channel_layout_describe(&dec_ctx->ch_layout, args + ret, sizeof(args) - ret);
-  ret = avfilter_graph_create_filter(&buffersrc_ctx, abuffersrc, "in",
-                                     args, NULL, filter_graph);
+  av_channel_layout_describe(&dec_ctx->ch_layout, args + ret,
+                             sizeof(args) - ret);
+  ret = avfilter_graph_create_filter(&buffersrc_ctx, abuffersrc, "in", args,
+                                     NULL, filter_graph);
   if (ret < 0) {
     av_log(NULL, AV_LOG_ERROR, "Cannot create audio buffer source\n");
     goto end;
   }

   /* buffer audio sink: to terminate the filter chain. */
-  ret = avfilter_graph_create_filter(&buffersink_ctx, abuffersink, "out",
-                                     NULL, NULL, filter_graph);
+  ret = avfilter_graph_create_filter(&buffersink_ctx, abuffersink, "out", NULL,
+                                     NULL, filter_graph);
   if (ret < 0) {
     av_log(NULL, AV_LOG_ERROR, "Cannot create audio buffer sink\n");
     goto end;
@@ -146,15 +149,15 @@ static int init_filters(const char *filters_descr)
     goto end;
   }

-  ret = av_opt_set(buffersink_ctx, "ch_layouts", "mono",
-                   AV_OPT_SEARCH_CHILDREN);
+  ret =
+      av_opt_set(buffersink_ctx, "ch_layouts", "mono", AV_OPT_SEARCH_CHILDREN);
   if (ret < 0) {
     av_log(NULL, AV_LOG_ERROR, "Cannot set output channel layout\n");
     goto end;
   }

-  ret = av_opt_set_int_list(buffersink_ctx, "sample_rates", out_sample_rates, -1,
-                            AV_OPT_SEARCH_CHILDREN);
+  ret = av_opt_set_int_list(buffersink_ctx, "sample_rates", out_sample_rates,
+                            -1, AV_OPT_SEARCH_CHILDREN);
   if (ret < 0) {
     av_log(NULL, AV_LOG_ERROR, "Cannot set output sample rate\n");
     goto end;
@@ -187,12 +190,11 @@ static int init_filters(const char *filters_descr)
   inputs->pad_idx = 0;
   inputs->next = NULL;

-  if ((ret = avfilter_graph_parse_ptr(filter_graph, filters_descr,
-                                      &inputs, &outputs, NULL)) < 0)
+  if ((ret = avfilter_graph_parse_ptr(filter_graph, filters_descr, &inputs,
+                                      &outputs, NULL)) < 0)
     goto end;

-  if ((ret = avfilter_graph_config(filter_graph, NULL)) < 0)
-    goto end;
+  if ((ret = avfilter_graph_config(filter_graph, NULL)) < 0) goto end;

   /* Print summary of the sink buffer
    * Note: args buffer is reused to store channel layout string */
@@ -200,7 +202,8 @@ static int init_filters(const char *filters_descr)
   av_channel_layout_describe(&outlink->ch_layout, args, sizeof(args));
   av_log(NULL, AV_LOG_INFO, "Output: srate:%dHz fmt:%s chlayout:%s\n",
          (int)outlink->sample_rate,
-         (char *)av_x_if_null(av_get_sample_fmt_name((AVSampleFormat)outlink->format), "?"),
+         (char *)av_x_if_null(
+             av_get_sample_fmt_name((AVSampleFormat)outlink->format), "?"),
          args);

 end:
@@ -210,13 +213,15 @@ end:
   return ret;
 }

-static void sherpa_decode_frame(const AVFrame *frame, SherpaOnnxOnlineRecognizer *recognizer,
-                                SherpaOnnxOnlineStream* stream)
-{
+static void sherpa_decode_frame(const AVFrame *frame,
+                                SherpaOnnxOnlineRecognizer *recognizer,
+                                SherpaOnnxOnlineStream *stream,
+                                SherpaOnnxDisplay *display,
+                                int32_t *segment_id) {
 #define N 3200  // 0.2 s. Sample rate is fixed to 16 kHz
   static float samples[N];
   static int nb_samples = 0;
-  const int16_t *p = (int16_t*)frame->data[0];
+  const int16_t *p = (int16_t *)frame->data[0];

   if (frame->nb_samples + nb_samples > N) {
     AcceptWaveform(stream, 16000, samples, nb_samples);
@@ -224,17 +229,20 @@ static void sherpa_decode_frame(const AVFrame *frame, SherpaOnnxOnlineRecognizer
     DecodeOnlineStream(recognizer, stream);
   }

-
-  if (IsEndpoint(recognizer, stream)) {
   SherpaOnnxOnlineRecognizerResult *r =
       GetOnlineStreamResult(recognizer, stream);
   if (strlen(r->text)) {
-    fprintf(stderr, "%s\n", r->text);
+    SherpaOnnxPrint(display, *segment_id, r->text);
   }
-  DestroyOnlineRecognizerResult(r);

+  if (IsEndpoint(recognizer, stream)) {
+    if (strlen(r->text)) {
+      ++*segment_id;
+    }
     Reset(recognizer, stream);
   }
+
+  DestroyOnlineRecognizerResult(r);
   nb_samples = 0;
 }
@@ -243,17 +251,15 @@ static void sherpa_decode_frame(const AVFrame *frame, SherpaOnnxOnlineRecognizer
   }
 }

-static inline char *__av_err2str(int errnum)
-{
+static inline char *__av_err2str(int errnum) {
   static char str[AV_ERROR_MAX_STRING_SIZE];
   memset(str, 0, sizeof(str));
   return av_make_error_string(str, AV_ERROR_MAX_STRING_SIZE, errnum);
 }

-int main(int argc, char **argv)
-{
+int main(int argc, char **argv) {
   int ret;
-  int num_threads = 4;
+  int num_threads = 1;
   AVPacket *packet = av_packet_alloc();
   AVFrame *frame = av_frame_alloc();
   AVFrame *filt_frame = av_frame_alloc();
@@ -265,19 +271,20 @@ int main(int argc, char **argv)
 " /path/to/encoder.onnx\\\n"
 " /path/to/decoder.onnx\\\n"
 " /path/to/joiner.onnx\\\n"
-" /path/to/foo.wav [num_threads]"
+" /path/to/foo.wav [num_threads [decoding_method]]"
 "\n\n"
+"Default num_threads is 1.\n"
+"Valid decoding_method: greedy_search (default), modified_beam_search\n\n"
 "Please refer to \n"
 "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html\n"
 "for a list of pre-trained models to download.\n";

-
   if (!packet || !frame || !filt_frame) {
     fprintf(stderr, "Could not allocate frame or packet\n");
     exit(1);
   }

-  if (argc < 6 || argc > 7) {
+  if (argc < 6 || argc > 8) {
     fprintf(stderr, "%s\n", kUsage);
     return -1;
   }
@@ -291,12 +298,20 @@ int main(int argc, char **argv)
   if (argc == 7 && atoi(argv[6]) > 0) {
     num_threads = atoi(argv[6]);
   }
+
   config.model_config.num_threads = num_threads;
   config.model_config.debug = 0;

   config.feat_config.sample_rate = 16000;
   config.feat_config.feature_dim = 80;

+  config.decoding_method = "greedy_search";
+  if (argc == 8) {
+    config.decoding_method = argv[7];
+  }
+
+  config.max_active_paths = 4;
+
   config.enable_endpoint = 1;
   config.rule1_min_trailing_silence = 2.4;
   config.rule2_min_trailing_silence = 1.2;
@@ -304,22 +319,22 @@ int main(int argc, char **argv)

   SherpaOnnxOnlineRecognizer *recognizer = CreateOnlineRecognizer(&config);
   SherpaOnnxOnlineStream *stream = CreateOnlineStream(recognizer);
+  SherpaOnnxDisplay *display = CreateDisplay(50);
+  int32_t segment_id = 0;

-  if ((ret = open_input_file(argv[5])) < 0)
-    exit(1);
+  if ((ret = open_input_file(argv[5])) < 0) exit(1);

-  if ((ret = init_filters(filter_descr)) < 0)
-    exit(1);
+  if ((ret = init_filters(filter_descr)) < 0) exit(1);

   /* read all packets */
   while (1) {
-    if ((ret = av_read_frame(fmt_ctx, packet)) < 0)
-      break;
+    if ((ret = av_read_frame(fmt_ctx, packet)) < 0) break;

     if (packet->stream_index == audio_stream_index) {
       ret = avcodec_send_packet(dec_ctx, packet);
       if (ret < 0) {
-        av_log(NULL, AV_LOG_ERROR, "Error while sending a packet to the decoder\n");
+        av_log(NULL, AV_LOG_ERROR,
+               "Error while sending a packet to the decoder\n");
         break;
       }

@@ -328,25 +343,27 @@ int main(int argc, char **argv)
       if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) {
         break;
       } else if (ret < 0) {
-        av_log(NULL, AV_LOG_ERROR, "Error while receiving a frame from the decoder\n");
+        av_log(NULL, AV_LOG_ERROR,
+               "Error while receiving a frame from the decoder\n");
         exit(1);
       }

       if (ret >= 0) {
         /* push the audio data from decoded frame into the filtergraph */
-        if (av_buffersrc_add_frame_flags(buffersrc_ctx, frame, AV_BUFFERSRC_FLAG_KEEP_REF) < 0) {
-          av_log(NULL, AV_LOG_ERROR, "Error while feeding the audio filtergraph\n");
+        if (av_buffersrc_add_frame_flags(buffersrc_ctx, frame,
+                                         AV_BUFFERSRC_FLAG_KEEP_REF) < 0) {
+          av_log(NULL, AV_LOG_ERROR,
+                 "Error while feeding the audio filtergraph\n");
           break;
         }

         /* pull filtered audio from the filtergraph */
         while (1) {
           ret = av_buffersink_get_frame(buffersink_ctx, filt_frame);
-          if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF)
-            break;
-          if (ret < 0)
-            exit(1);
-          sherpa_decode_frame(filt_frame, recognizer, stream);
+          if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) break;
+          if (ret < 0) exit(1);
+          sherpa_decode_frame(filt_frame, recognizer, stream, display,
+                              &segment_id);
           av_frame_unref(filt_frame);
         }
         av_frame_unref(frame);
@@ -368,11 +385,12 @@ int main(int argc, char **argv)
   SherpaOnnxOnlineRecognizerResult *r =
       GetOnlineStreamResult(recognizer, stream);
   if (strlen(r->text)) {
-    fprintf(stderr, "%s\n", r->text);
+    SherpaOnnxPrint(display, segment_id, r->text);
   }

   DestroyOnlineRecognizerResult(r);

+  DestroyDisplay(display);
   DestoryOnlineStream(stream);
   DestroyOnlineRecognizer(recognizer);

@@ -387,6 +405,7 @@ int main(int argc, char **argv)
     fprintf(stderr, "Error occurred: %s\n", __av_err2str(ret));
     exit(1);
   }
+  fprintf(stderr, "\n");

   return 0;
 }
@@ -54,6 +54,20 @@ def get_args():
     )

     parser.add_argument(
+        "--num-threads",
+        type=int,
+        default=1,
+        help="Number of threads for neural network computation",
+    )
+
+    parser.add_argument(
+        "--decoding-method",
+        type=str,
+        default="greedy_search",
+        help="Valid values are greedy_search and modified_beam_search",
+    )
+
+    parser.add_argument(
         "--wave-filename",
         type=str,
         help="""Path to the wave filename. Must be 16 kHz,
@@ -65,7 +79,6 @@ def get_args():

 def main():
     sample_rate = 16000
-    num_threads = 2

     args = get_args()
     assert_file_exists(args.encoder)
@@ -81,9 +94,10 @@ def main():
         encoder=args.encoder,
         decoder=args.decoder,
         joiner=args.joiner,
-        num_threads=num_threads,
+        num_threads=args.num_threads,
         sample_rate=sample_rate,
         feature_dim=80,
+        decoding_method=args.decoding_method,
     )
     with wave.open(args.wave_filename) as f:
         assert f.getframerate() == sample_rate, f.getframerate()
@@ -119,7 +133,8 @@ def main():
     end_time = time.time()
     elapsed_seconds = end_time - start_time
     rtf = elapsed_seconds / duration
-    print(f"num_threads: {num_threads}")
+    print(f"num_threads: {args.num_threads}")
+    print(f"decoding_method: {args.decoding_method}")
     print(f"Wave duration: {duration:.3f} s")
     print(f"Elapsed time: {elapsed_seconds:.3f} s")
     print(f"Real time factor (RTF): {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}")
@@ -60,10 +60,10 @@ def get_args():
     )

     parser.add_argument(
-        "--wave-filename",
+        "--decoding-method",
         type=str,
-        help="""Path to the wave filename. Must be 16 kHz,
-        mono with 16-bit samples""",
+        default="greedy_search",
+        help="Valid values are greedy_search and modified_beam_search",
     )

     return parser.parse_args()
@@ -83,17 +83,23 @@ def create_recognizer():
         encoder=args.encoder,
         decoder=args.decoder,
         joiner=args.joiner,
+        num_threads=1,
+        sample_rate=16000,
+        feature_dim=80,
         enable_endpoint_detection=True,
         rule1_min_trailing_silence=2.4,
         rule2_min_trailing_silence=1.2,
         rule3_min_utterance_length=300,  # it essentially disables this rule
+        decoding_method=args.decoding_method,
+        max_feature_vectors=100,  # 1 second
     )
     return recognizer


 def main():
-    print("Started! Please speak")
     recognizer = create_recognizer()
+    print("Started! Please speak")
+
     sample_rate = 16000
     samples_per_read = int(0.1 * sample_rate)  # 0.1 second = 100 ms
     last_result = ""
@@ -101,6 +107,7 @@ def main():

     last_result = ""
     segment_id = 0
+    display = sherpa_onnx.Display(max_word_per_line=30)
    with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s:
        while True:
            samples, _ = s.read(samples_per_read)  # a blocking read
@@ -115,7 +122,7 @@ def main():

             if result and (last_result != result):
                 last_result = result
-                print(f"{segment_id}: {result}")
+                display.print(segment_id, result)

             if is_endpoint:
                 if result:
@@ -59,10 +59,10 @@ def get_args():
     )

     parser.add_argument(
-        "--wave-filename",
+        "--decoding-method",
         type=str,
-        help="""Path to the wave filename. Must be 16 kHz,
-        mono with 16-bit samples""",
+        default="greedy_search",
+        help="Valid values are greedy_search and modified_beam_search",
     )

     return parser.parse_args()
@@ -82,9 +82,11 @@ def create_recognizer():
         encoder=args.encoder,
         decoder=args.decoder,
         joiner=args.joiner,
-        num_threads=4,
+        num_threads=1,
         sample_rate=16000,
         feature_dim=80,
+        decoding_method=args.decoding_method,
+        max_feature_vectors=100,  # 1 second
     )
     return recognizer

@@ -96,6 +98,7 @@ def main():
     samples_per_read = int(0.1 * sample_rate)  # 0.1 second = 100 ms
     last_result = ""
     stream = recognizer.create_stream()
+    display = sherpa_onnx.Display(max_word_per_line=40)
    with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s:
        while True:
            samples, _ = s.read(samples_per_read)  # a blocking read
@@ -106,7 +109,7 @@ def main():
             result = recognizer.get_result(stream)
             if last_result != result:
                 last_result = result
-                print(result)
+                display.print(-1, result)


 if __name__ == "__main__":
@@ -9,6 +9,7 @@
 #include <utility>
 #include <vector>

+#include "sherpa-onnx/csrc/display.h"
 #include "sherpa-onnx/csrc/online-recognizer.h"

 struct SherpaOnnxOnlineRecognizer {
@@ -21,6 +22,10 @@ struct SherpaOnnxOnlineStream {
       : impl(std::move(p)) {}
 };

+struct SherpaOnnxDisplay {
+  std::unique_ptr<sherpa_onnx::Display> impl;
+};
+
 SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer(
     const SherpaOnnxOnlineRecognizerConfig *config) {
   sherpa_onnx::OnlineRecognizerConfig recognizer_config;
@@ -37,6 +42,9 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer(
   recognizer_config.model_config.num_threads = config->model_config.num_threads;
   recognizer_config.model_config.debug = config->model_config.debug;

+  recognizer_config.decoding_method = config->decoding_method;
+  recognizer_config.max_active_paths = config->max_active_paths;
+
   recognizer_config.enable_endpoint = config->enable_endpoint;

   recognizer_config.endpoint_config.rule1.min_trailing_silence =
@@ -124,3 +132,15 @@ int32_t IsEndpoint(SherpaOnnxOnlineRecognizer *recognizer,
                    SherpaOnnxOnlineStream *stream) {
   return recognizer->impl->IsEndpoint(stream->impl.get());
 }
+
+SherpaOnnxDisplay *CreateDisplay(int32_t max_word_per_line) {
+  SherpaOnnxDisplay *ans = new SherpaOnnxDisplay;
+  ans->impl = std::make_unique<sherpa_onnx::Display>(max_word_per_line);
+  return ans;
+}
+
+void DestroyDisplay(SherpaOnnxDisplay *display) { delete display; }
+
+void SherpaOnnxPrint(SherpaOnnxDisplay *display, int32_t idx, const char *s) {
+  display->impl->Print(idx, s);
+}
@@ -48,6 +48,13 @@ typedef struct SherpaOnnxOnlineRecognizerConfig {
   SherpaOnnxFeatureConfig feat_config;
   SherpaOnnxOnlineTransducerModelConfig model_config;

+  /// Possible values are: greedy_search, modified_beam_search
+  const char *decoding_method;
+
+  /// Used only when decoding_method is modified_beam_search
+  /// Example value: 4
+  int32_t max_active_paths;
+
   /// 0 to disable endpoint detection.
   /// A non-zero value to enable endpoint detection.
   int32_t enable_endpoint;
@@ -187,6 +194,18 @@ void InputFinished(SherpaOnnxOnlineStream *stream);
 int32_t IsEndpoint(SherpaOnnxOnlineRecognizer *recognizer,
                    SherpaOnnxOnlineStream *stream);

+// for displaying results on Linux/macOS.
+typedef struct SherpaOnnxDisplay SherpaOnnxDisplay;
+
+/// Create a display object. Must be freed using DestroyDisplay to avoid
+/// memory leak.
+SherpaOnnxDisplay *CreateDisplay(int32_t max_word_per_line);
+
+void DestroyDisplay(SherpaOnnxDisplay *display);
+
+/// Print the result.
+void SherpaOnnxPrint(SherpaOnnxDisplay *display, int32_t idx, const char *s);
+
 #ifdef __cplusplus
 } /* extern "C" */
 #endif
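A minimal, hypothetical usage of the display API declared above (the strings are made up). On Linux/macOS each segment id owns one line, and reprinting the current segment redraws it in place rather than appending:

```c
#include "sherpa-onnx/c-api/c-api.h"

int main(void) {
  /* Wrap lines after at most 50 words. */
  SherpaOnnxDisplay *display = CreateDisplay(50);

  SherpaOnnxPrint(display, 0, "hello");          /* starts segment 0 */
  SherpaOnnxPrint(display, 0, "hello world");    /* redraws segment 0 */
  SherpaOnnxPrint(display, 1, "a new segment");  /* moves to a new line */

  DestroyDisplay(display); /* must be freed to avoid a leak */
  return 0;
}
```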
@@ -9,10 +9,11 @@ set(sources
   online-lstm-transducer-model.cc
   online-recognizer.cc
   online-stream.cc
+  online-transducer-decoder.cc
   online-transducer-greedy-search-decoder.cc
   online-transducer-model-config.cc
-  online-transducer-modified-beam-search-decoder.cc
   online-transducer-model.cc
+  online-transducer-modified-beam-search-decoder.cc
   online-zipformer-transducer-model.cc
   onnx-utils.cc
   parse-options.cc
@@ -12,9 +12,16 @@ namespace sherpa_onnx {

 class Display {
  public:
+  explicit Display(int32_t max_word_per_line = 60)
+      : max_word_per_line_(max_word_per_line) {}
+
   void Print(int32_t segment_id, const std::string &s) {
 #ifdef _MSC_VER
+    if (segment_id != -1) {
       fprintf(stderr, "%d:%s\n", segment_id, s.c_str());
+    } else {
+      fprintf(stderr, "%s\n", s.c_str());
+    }
     return;
 #endif
     if (last_segment_ == segment_id) {
@@ -27,7 +34,9 @@ class Display {
       num_previous_lines_ = 0;
     }

+    if (segment_id != -1) {
       fprintf(stderr, "\r%d:", segment_id);
+    }

     int32_t i = 0;
     for (size_t n = 0; n < s.size();) {
@@ -69,7 +78,7 @@ class Display {
   void GoUpOneLine() const { fprintf(stderr, "\033[1A\r"); }

  private:
-  int32_t max_word_per_line_ = 60;
+  int32_t max_word_per_line_;
   int32_t num_previous_lines_ = 0;
   int32_t last_segment_ = -1;
 };
@@ -28,7 +28,8 @@ std::string FeatureExtractorConfig::ToString() const {

   os << "FeatureExtractorConfig(";
   os << "sampling_rate=" << sampling_rate << ", ";
-  os << "feature_dim=" << feature_dim << ")";
+  os << "feature_dim=" << feature_dim << ", ";
+  os << "max_feature_vectors=" << max_feature_vectors << ")";

   return os.str();
 }
@@ -40,9 +41,7 @@ class FeatureExtractor::Impl {
     opts_.frame_opts.snip_edges = false;
     opts_.frame_opts.samp_freq = config.sampling_rate;

-    // cache 100 seconds of feature frames, which is more than enough
-    // for real needs
-    opts_.frame_opts.max_feature_vectors = 100 * 100;
+    opts_.frame_opts.max_feature_vectors = config.max_feature_vectors;

     opts_.mel_opts.num_bins = config.feature_dim;

@@ -16,6 +16,7 @@ namespace sherpa_onnx {
 struct FeatureExtractorConfig {
   float sampling_rate = 16000;
   int32_t feature_dim = 80;
+  int32_t max_feature_vectors = -1;

   std::string ToString() const;

@@ -18,7 +18,7 @@ namespace sherpa_onnx {

 struct Hypothesis {
   // The predicted tokens so far. Newly predicted tokens are appended.
-  std::vector<int32_t> ys;
+  std::vector<int64_t> ys;

   // timestamps[i] contains the frame number after subsampling
   // on which ys[i] is decoded.
@@ -30,7 +30,7 @@ struct Hypothesis {
   int32_t num_trailing_blanks = 0;

   Hypothesis() = default;
-  Hypothesis(const std::vector<int32_t> &ys, double log_prob)
+  Hypothesis(const std::vector<int64_t> &ys, double log_prob)
       : ys(ys), log_prob(log_prob) {}

   // If two Hypotheses have the same `Key`, then they contain
@@ -43,7 +43,7 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
                "True to enable endpoint detection. False to disable it.");
   po->Register("max-active-paths", &max_active_paths,
                "beam size used in modified beam search.");
-  po->Register("decoding-mothod", &decoding_method,
+  po->Register("decoding-method", &decoding_method,
                "decoding method,"
                "now support greedy_search and modified_beam_search.");
 }
@@ -59,8 +59,8 @@ std::string OnlineRecognizerConfig::ToString() const {
   os << "feat_config=" << feat_config.ToString() << ", ";
   os << "model_config=" << model_config.ToString() << ", ";
   os << "endpoint_config=" << endpoint_config.ToString() << ", ";
-  os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ",";
-  os << "max_active_paths=" << max_active_paths << ",";
+  os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", ";
+  os << "max_active_paths=" << max_active_paths << ", ";
   os << "decoding_method=\"" << decoding_method << "\")";

   return os.str();
@@ -187,16 +187,14 @@ class OnlineRecognizer::Impl {
   }

   void Reset(OnlineStream *s) const {
-    // reset result, neural network model state, and
-    // the feature extractor state
-
-    // reset result
+    // we keep the decoder_out
+    decoder_->UpdateDecoderOut(&s->GetResult());
+    Ort::Value decoder_out = std::move(s->GetResult().decoder_out);
     s->SetResult(decoder_->GetEmptyResult());
+    s->GetResult().decoder_out = std::move(decoder_out);

-    // reset neural network model state
-    s->SetStates(model_->GetEncoderInitStates());
-
-    // reset feature extractor
+    // Note: We only update counters. The underlying audio samples
+    // are not discarded.
     s->Reset();
   }

@@ -33,21 +33,26 @@ struct OnlineRecognizerConfig {
   OnlineTransducerModelConfig model_config;
   EndpointConfig endpoint_config;
   bool enable_endpoint = true;
-  int32_t max_active_paths = 4;

-  std::string decoding_method = "modified_beam_search";
+  std::string decoding_method = "greedy_search";
   // now support modified_beam_search and greedy_search

+  int32_t max_active_paths = 4;  // used only for modified_beam_search
+
   OnlineRecognizerConfig() = default;

   OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config,
                          const OnlineTransducerModelConfig &model_config,
                          const EndpointConfig &endpoint_config,
-                         bool enable_endpoint)
+                         bool enable_endpoint,
+                         const std::string &decoding_method,
+                         int32_t max_active_paths)
       : feat_config(feat_config),
         model_config(model_config),
        endpoint_config(endpoint_config),
-        enable_endpoint(enable_endpoint) {}
+        enable_endpoint(enable_endpoint),
+        decoding_method(decoding_method),
+        max_active_paths(max_active_paths) {}

   void Register(ParseOptions *po);
   bool Validate() const;
@@ -22,18 +22,21 @@ class OnlineStream::Impl {

   void InputFinished() { feat_extractor_.InputFinished(); }

-  int32_t NumFramesReady() const { return feat_extractor_.NumFramesReady(); }
+  int32_t NumFramesReady() const {
+    return feat_extractor_.NumFramesReady() - start_frame_index_;
+  }

   bool IsLastFrame(int32_t frame) const {
     return feat_extractor_.IsLastFrame(frame);
   }

   std::vector<float> GetFrames(int32_t frame_index, int32_t n) const {
-    return feat_extractor_.GetFrames(frame_index, n);
+    return feat_extractor_.GetFrames(frame_index + start_frame_index_, n);
   }

   void Reset() {
-    feat_extractor_.Reset();
+    // we don't reset the feature extractor
+    start_frame_index_ += num_processed_frames_;
     num_processed_frames_ = 0;
   }

@@ -41,7 +44,7 @@ class OnlineStream::Impl {

   void SetResult(const OnlineTransducerDecoderResult &r) { result_ = r; }

-  const OnlineTransducerDecoderResult &GetResult() const { return result_; }
+  OnlineTransducerDecoderResult &GetResult() { return result_; }

   int32_t FeatureDim() const { return feat_extractor_.FeatureDim(); }

@@ -54,6 +57,7 @@ class OnlineStream::Impl {
  private:
   FeatureExtractor feat_extractor_;
   int32_t num_processed_frames_ = 0;  // before subsampling
+  int32_t start_frame_index_ = 0;     // never reset
   OnlineTransducerDecoderResult result_;
   std::vector<Ort::Value> states_;
 };
@@ -93,7 +97,7 @@ void OnlineStream::SetResult(const OnlineTransducerDecoderResult &r) {
   impl_->SetResult(r);
 }

-const OnlineTransducerDecoderResult &OnlineStream::GetResult() const {
+OnlineTransducerDecoderResult &OnlineStream::GetResult() {
   return impl_->GetResult();
 }

@@ -63,7 +63,7 @@ class OnlineStream {
   int32_t &GetNumProcessedFrames();

   void SetResult(const OnlineTransducerDecoderResult &r);
-  const OnlineTransducerDecoderResult &GetResult() const;
+  OnlineTransducerDecoderResult &GetResult();

   void SetStates(std::vector<Ort::Value> states);
   std::vector<Ort::Value> &GetStates();
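The `start_frame_index_` bookkeeping above is the heart of "don't reset the feature extractor on endpointing": `Reset()` no longer clears cached features, it just slides the window of visible frames forward. A stripped-down model of the arithmetic (the field names follow the diff; everything else is illustrative):

```c
#include <stdint.h>
#include <stdio.h>

/* Illustrative model of the frame counters in OnlineStream::Impl. */
typedef struct {
  int32_t total_frames_ready;   /* what the feature extractor reports */
  int32_t num_processed_frames; /* consumed since the last Reset() */
  int32_t start_frame_index;    /* never reset; offset into the features */
} StreamCounters;

/* Frames visible to the decoder are counted from start_frame_index. */
static int32_t NumFramesReady(const StreamCounters *s) {
  return s->total_frames_ready - s->start_frame_index;
}

/* On an endpoint, advance the window instead of dropping the features. */
static void Reset(StreamCounters *s) {
  s->start_frame_index += s->num_processed_frames;
  s->num_processed_frames = 0;
}

int main(void) {
  StreamCounters s = {200, 200, 0}; /* 200 frames decoded, endpoint hit */
  Reset(&s);
  s.total_frames_ready = 230;          /* 30 new frames arrive */
  printf("%d\n", NumFramesReady(&s));  /* prints 30 */
  return 0;
}
```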
+// sherpa-onnx/csrc/online-transducer-decoder.cc
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/online-transducer-decoder.h"
+
+#include <utility>
+#include <vector>
+
+#include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/onnx-utils.h"
+
+namespace sherpa_onnx {
+
+OnlineTransducerDecoderResult::OnlineTransducerDecoderResult(
+    const OnlineTransducerDecoderResult &other)
+    : OnlineTransducerDecoderResult() {
+  *this = other;
+}
+
+OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=(
+    const OnlineTransducerDecoderResult &other) {
+  if (this == &other) {
+    return *this;
+  }
+
+  tokens = other.tokens;
+  num_trailing_blanks = other.num_trailing_blanks;
+
+  Ort::AllocatorWithDefaultOptions allocator;
+  if (other.decoder_out) {
+    decoder_out = Clone(allocator, &other.decoder_out);
+  }
+
+  hyps = other.hyps;
+
+  return *this;
+}
+
+OnlineTransducerDecoderResult::OnlineTransducerDecoderResult(
+    OnlineTransducerDecoderResult &&other)
+    : OnlineTransducerDecoderResult() {
+  *this = std::move(other);
+}
+
+OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=(
+    OnlineTransducerDecoderResult &&other) {
+  if (this == &other) {
+    return *this;
+  }
+
+  tokens = std::move(other.tokens);
+  num_trailing_blanks = other.num_trailing_blanks;
+  decoder_out = std::move(other.decoder_out);
+  hyps = std::move(other.hyps);
+
+  return *this;
+}
+
+}  // namespace sherpa_onnx
@@ -19,8 +19,24 @@ struct OnlineTransducerDecoderResult {
   /// number of trailing blank frames decoded so far
   int32_t num_trailing_blanks = 0;

+  // Cache decoder_out for endpointing
+  Ort::Value decoder_out;
+
   // used only in modified beam_search
   Hypotheses hyps;
+
+  OnlineTransducerDecoderResult()
+      : tokens{}, num_trailing_blanks(0), decoder_out{nullptr}, hyps{} {}
+
+  OnlineTransducerDecoderResult(const OnlineTransducerDecoderResult &other);
+
+  OnlineTransducerDecoderResult &operator=(
+      const OnlineTransducerDecoderResult &other);
+
+  OnlineTransducerDecoderResult(OnlineTransducerDecoderResult &&other);
+
+  OnlineTransducerDecoderResult &operator=(
+      OnlineTransducerDecoderResult &&other);
 };

 class OnlineTransducerDecoder {
@@ -53,6 +69,9 @@ class OnlineTransducerDecoder {
    */
   virtual void Decode(Ort::Value encoder_out,
                       std::vector<OnlineTransducerDecoderResult> *result) = 0;
+
+  // used for endpointing. We need to keep decoder_out after reset
+  virtual void UpdateDecoderOut(OnlineTransducerDecoderResult *result) {}
 };

 }  // namespace sherpa_onnx
@@ -13,6 +13,43 @@

 namespace sherpa_onnx {

+static void UseCachedDecoderOut(
+    const std::vector<OnlineTransducerDecoderResult> &results,
+    Ort::Value *decoder_out) {
+  std::vector<int64_t> shape =
+      decoder_out->GetTensorTypeAndShapeInfo().GetShape();
+  float *dst = decoder_out->GetTensorMutableData<float>();
+  for (const auto &r : results) {
+    if (r.decoder_out) {
+      const float *src = r.decoder_out.GetTensorData<float>();
+      std::copy(src, src + shape[1], dst);
+    }
+    dst += shape[1];
+  }
+}
+
+static void UpdateCachedDecoderOut(
+    OrtAllocator *allocator, const Ort::Value *decoder_out,
+    std::vector<OnlineTransducerDecoderResult> *results) {
+  std::vector<int64_t> shape =
+      decoder_out->GetTensorTypeAndShapeInfo().GetShape();
+  auto memory_info =
+      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+  std::array<int64_t, 2> v_shape{1, shape[1]};
+
+  const float *src = decoder_out->GetTensorData<float>();
+  for (auto &r : *results) {
+    if (!r.decoder_out) {
+      r.decoder_out = Ort::Value::CreateTensor<float>(allocator, v_shape.data(),
+                                                      v_shape.size());
+    }
+
+    float *dst = r.decoder_out.GetTensorMutableData<float>();
+    std::copy(src, src + shape[1], dst);
+    src += shape[1];
+  }
+}
+
 OnlineTransducerDecoderResult
 OnlineTransducerGreedySearchDecoder::GetEmptyResult() const {
   int32_t context_size = model_->ContextSize();
@@ -53,6 +90,7 @@ void OnlineTransducerGreedySearchDecoder::Decode(

   Ort::Value decoder_input = model_->BuildDecoderInput(*result);
   Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
+  UseCachedDecoderOut(*result, &decoder_out);

   for (int32_t t = 0; t != num_frames; ++t) {
     Ort::Value cur_encoder_out =
@@ -77,10 +115,12 @@ void OnlineTransducerGreedySearchDecoder::Decode(
       }
     }
     if (emitted) {
-      decoder_input = model_->BuildDecoderInput(*result);
+      Ort::Value decoder_input = model_->BuildDecoderInput(*result);
       decoder_out = model_->RunDecoder(std::move(decoder_input));
     }
   }
+
+  UpdateCachedDecoderOut(model_->Allocator(), &decoder_out, result);
 }

 }  // namespace sherpa_onnx
@@ -13,6 +13,29 @@

 namespace sherpa_onnx {

+static void UseCachedDecoderOut(
+    const std::vector<int32_t> &hyps_num_split,
+    const std::vector<OnlineTransducerDecoderResult> &results,
+    int32_t context_size, Ort::Value *decoder_out) {
+  std::vector<int64_t> shape =
+      decoder_out->GetTensorTypeAndShapeInfo().GetShape();
+
+  float *dst = decoder_out->GetTensorMutableData<float>();
+
+  int32_t batch_size = static_cast<int32_t>(results.size());
+  for (int32_t i = 0; i != batch_size; ++i) {
+    int32_t num_hyps = hyps_num_split[i + 1] - hyps_num_split[i];
+    if (num_hyps > 1 || !results[i].decoder_out) {
+      dst += num_hyps * shape[1];
+      continue;
+    }
+
+    const float *src = results[i].decoder_out.GetTensorData<float>();
+    std::copy(src, src + shape[1], dst);
+    dst += shape[1];
+  }
+}
+
 static Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out,
                          const std::vector<int32_t> &hyps_num_split) {
   std::vector<int64_t> cur_encoder_out_shape =
@@ -50,7 +73,7 @@ OnlineTransducerModifiedBeamSearchDecoder::GetEmptyResult() const { @@ -50,7 +73,7 @@ OnlineTransducerModifiedBeamSearchDecoder::GetEmptyResult() const {
50 int32_t context_size = model_->ContextSize(); 73 int32_t context_size = model_->ContextSize();
51 int32_t blank_id = 0; // always 0 74 int32_t blank_id = 0; // always 0
52 OnlineTransducerDecoderResult r; 75 OnlineTransducerDecoderResult r;
53 - std::vector<int32_t> blanks(context_size, blank_id); 76 + std::vector<int64_t> blanks(context_size, blank_id);
54 Hypotheses blank_hyp({{blanks, 0}}); 77 Hypotheses blank_hyp({{blanks, 0}});
55 r.hyps = std::move(blank_hyp); 78 r.hyps = std::move(blank_hyp);
56 return r; 79 return r;
@@ -110,6 +133,10 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( @@ -110,6 +133,10 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
110 133
111 Ort::Value decoder_input = model_->BuildDecoderInput(prev); 134 Ort::Value decoder_input = model_->BuildDecoderInput(prev);
112 Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input)); 135 Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
  136 + if (t == 0) {
  137 + UseCachedDecoderOut(hyps_num_split, *result, model_->ContextSize(),
  138 + &decoder_out);
  139 + }
113 140
114 Ort::Value cur_encoder_out = 141 Ort::Value cur_encoder_out =
115 GetEncoderOutFrame(model_->Allocator(), &encoder_out, t); 142 GetEncoderOutFrame(model_->Allocator(), &encoder_out, t);
@@ -147,8 +174,23 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( @@ -147,8 +174,23 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
147 } 174 }
148 175
149 for (int32_t b = 0; b != batch_size; ++b) { 176 for (int32_t b = 0; b != batch_size; ++b) {
150 - (*result)[b].hyps = std::move(cur[b]); 177 + auto &hyps = cur[b];
  178 + auto best_hyp = hyps.GetMostProbable(true);
  179 +
  180 + (*result)[b].hyps = std::move(hyps);
  181 + (*result)[b].tokens = std::move(best_hyp.ys);
  182 + (*result)[b].num_trailing_blanks = best_hyp.num_trailing_blanks;
  183 + }
  184 +}
  185 +
  186 +void OnlineTransducerModifiedBeamSearchDecoder::UpdateDecoderOut(
  187 + OnlineTransducerDecoderResult *result) {
  188 + if (result->tokens.size() == model_->ContextSize()) {
  189 + result->decoder_out = Ort::Value{nullptr};
  190 + return;
151 } 191 }
  192 + Ort::Value decoder_input = model_->BuildDecoderInput({*result});
  193 + result->decoder_out = model_->RunDecoder(std::move(decoder_input));
152 } 194 }
153 195
154 } // namespace sherpa_onnx 196 } // namespace sherpa_onnx
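
In the modified_beam_search variant above, each stream may carry several hypotheses laid out contiguously in decoder_out; hyps_num_split is an exclusive prefix sum, so stream i owns rows hyps_num_split[i] through hyps_num_split[i+1] - 1. The cache is applied only at t == 0 and only for a stream that still has a single hypothesis, since one saved row cannot stand in for several diverged hypotheses. UpdateDecoderOut, in turn, clears the cache (a null Ort::Value) while a stream has emitted nothing beyond its context-size blank prefix, and otherwise recomputes and stores the decoder output for the best hypothesis. An illustrative numpy sketch, under the same assumptions as the greedy-search sketch above:

    def use_cached_decoder_out(hyps_num_split, results, decoder_out):
        # hyps_num_split: exclusive prefix sum of hypothesis counts per stream.
        for i, r in enumerate(results):
            start = hyps_num_split[i]
            num_hyps = hyps_num_split[i + 1] - start
            # Skip streams whose hypotheses have diverged or that have no cache.
            if num_hyps == 1 and r.decoder_out is not None:
                decoder_out[start, :] = r.decoder_out
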
@@ -27,6 +27,8 @@ class OnlineTransducerModifiedBeamSearchDecoder @@ -27,6 +27,8 @@ class OnlineTransducerModifiedBeamSearchDecoder
27 void Decode(Ort::Value encoder_out, 27 void Decode(Ort::Value encoder_out,
28 std::vector<OnlineTransducerDecoderResult> *result) override; 28 std::vector<OnlineTransducerDecoderResult> *result) override;
29 29
  30 + void UpdateDecoderOut(OnlineTransducerDecoderResult *result) override;
  31 +
30 private: 32 private:
31 OnlineTransducerModel *model_; // Not owned 33 OnlineTransducerModel *model_; // Not owned
32 int32_t max_active_paths_; 34 int32_t max_active_paths_;
@@ -21,7 +21,7 @@ static void Handler(int sig) { @@ -21,7 +21,7 @@ static void Handler(int sig) {
21 } 21 }
22 22
23 int main(int32_t argc, char *argv[]) { 23 int main(int32_t argc, char *argv[]) {
24 - if (argc < 6 || argc > 7) { 24 + if (argc < 6 || argc > 8) {
25 const char *usage = R"usage( 25 const char *usage = R"usage(
26 Usage: 26 Usage:
27 ./bin/sherpa-onnx-alsa \ 27 ./bin/sherpa-onnx-alsa \
@@ -30,7 +30,10 @@ Usage: @@ -30,7 +30,10 @@ Usage:
30 /path/to/decoder.onnx \ 30 /path/to/decoder.onnx \
31 /path/to/joiner.onnx \ 31 /path/to/joiner.onnx \
32 device_name \ 32 device_name \
33 - [num_threads] 33 + [num_threads [decoding_method]]
  34 +
  35 +Default value for num_threads is 2.
  36 +Valid values for decoding_method: greedy_search (default), modified_beam_search.
34 37
35 Please refer to 38 Please refer to
36 https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html 39 https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
@@ -79,6 +82,11 @@ as the device_name. @@ -79,6 +82,11 @@ as the device_name.
79 config.model_config.num_threads = atoi(argv[6]); 82 config.model_config.num_threads = atoi(argv[6]);
80 } 83 }
81 84
  85 + if (argc == 8) {
  86 + config.decoding_method = argv[7];
  87 + }
  88 + config.max_active_paths = 4;
  89 +
82 config.enable_endpoint = true; 90 config.enable_endpoint = true;
83 91
84 config.endpoint_config.rule1.min_trailing_silence = 2.4; 92 config.endpoint_config.rule1.min_trailing_silence = 2.4;
@@ -36,7 +36,7 @@ static void Handler(int32_t sig) { @@ -36,7 +36,7 @@ static void Handler(int32_t sig) {
36 } 36 }
37 37
38 int32_t main(int32_t argc, char *argv[]) { 38 int32_t main(int32_t argc, char *argv[]) {
39 - if (argc < 5 || argc > 6) { 39 + if (argc < 5 || argc > 7) {
40 const char *usage = R"usage( 40 const char *usage = R"usage(
41 Usage: 41 Usage:
42 ./bin/sherpa-onnx-microphone \ 42 ./bin/sherpa-onnx-microphone \
@@ -44,7 +44,10 @@ Usage: @@ -44,7 +44,10 @@ Usage:
44 /path/to/encoder.onnx\ 44 /path/to/encoder.onnx\
45 /path/to/decoder.onnx\ 45 /path/to/decoder.onnx\
46 /path/to/joiner.onnx\ 46 /path/to/joiner.onnx\
47 - [num_threads] 47 + [num_threads [decoding_method]]
  48 +
  49 +Default value for num_threads is 2.
  50 +Valid values for decoding_method: greedy_search (default), modified_beam_search.
48 51
49 Please refer to 52 Please refer to
50 https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html 53 https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
@@ -70,6 +73,11 @@ for a list of pre-trained models to download. @@ -70,6 +73,11 @@ for a list of pre-trained models to download.
70 config.model_config.num_threads = atoi(argv[5]); 73 config.model_config.num_threads = atoi(argv[5]);
71 } 74 }
72 75
  76 + if (argc == 7) {
  77 + config.decoding_method = argv[6];
  78 + }
  79 + config.max_active_paths = 4;
  80 +
73 config.enable_endpoint = true; 81 config.enable_endpoint = true;
74 82
75 config.endpoint_config.rule1.min_trailing_silence = 2.4; 83 config.endpoint_config.rule1.min_trailing_silence = 2.4;
@@ -14,7 +14,7 @@ @@ -14,7 +14,7 @@
14 #include "sherpa-onnx/csrc/wave-reader.h" 14 #include "sherpa-onnx/csrc/wave-reader.h"
15 15
16 int main(int32_t argc, char *argv[]) { 16 int main(int32_t argc, char *argv[]) {
17 - if (argc < 6 || argc > 7) { 17 + if (argc < 6 || argc > 8) {
18 const char *usage = R"usage( 18 const char *usage = R"usage(
19 Usage: 19 Usage:
20 ./bin/sherpa-onnx \ 20 ./bin/sherpa-onnx \
@@ -22,7 +22,10 @@ Usage: @@ -22,7 +22,10 @@ Usage:
22 /path/to/encoder.onnx \ 22 /path/to/encoder.onnx \
23 /path/to/decoder.onnx \ 23 /path/to/decoder.onnx \
24 /path/to/joiner.onnx \ 24 /path/to/joiner.onnx \
25 - /path/to/foo.wav [num_threads] 25 + /path/to/foo.wav [num_threads [decoding_method]]
  26 +
  27 +Default value for num_threads is 2.
  28 +Valid values for decoding_method: greedy_search (default), modified_beam_search.
26 29
27 Please refer to 30 Please refer to
28 https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html 31 https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
@@ -45,9 +48,15 @@ for a list of pre-trained models to download. @@ -45,9 +48,15 @@ for a list of pre-trained models to download.
45 std::string wav_filename = argv[5]; 48 std::string wav_filename = argv[5];
46 49
47 config.model_config.num_threads = 2; 50 config.model_config.num_threads = 2;
48 - if (argc == 7) { 51 + if (argc >= 7 && atoi(argv[6]) > 0) {
49 config.model_config.num_threads = atoi(argv[6]); 52 config.model_config.num_threads = atoi(argv[6]);
50 } 53 }
  54 +
  55 + if (argc == 8) {
  56 + config.decoding_method = argv[7];
  57 + }
  58 + config.max_active_paths = 4;
  59 +
51 fprintf(stderr, "%s\n", config.ToString().c_str()); 60 fprintf(stderr, "%s\n", config.ToString().c_str());
52 61
53 sherpa_onnx::OnlineRecognizer recognizer(config); 62 sherpa_onnx::OnlineRecognizer recognizer(config);
@@ -98,6 +107,7 @@ for a list of pre-trained models to download. @@ -98,6 +107,7 @@ for a list of pre-trained models to download.
98 1000.; 107 1000.;
99 108
100 fprintf(stderr, "num threads: %d\n", config.model_config.num_threads); 109 fprintf(stderr, "num threads: %d\n", config.model_config.num_threads);
  110 + fprintf(stderr, "decoding method: %s\n", config.decoding_method.c_str());
101 111
102 fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); 112 fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
103 float rtf = elapsed_seconds / duration; 113 float rtf = elapsed_seconds / duration;
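
All three binaries (sherpa-onnx, sherpa-onnx-alsa, sherpa-onnx-microphone) gain the same nested trailing arguments: decoding_method may only be supplied after num_threads, which is why argc ranges over 6..8 here. A hypothetical Python sketch of the same positional scheme (argv indices mirror the C++ decode-file example above):

    import sys

    def parse_tail(argv, default_threads=2):
        # argv[6] is num_threads, argv[7] is decoding_method; each is
        # optional, but decoding_method can only appear after num_threads.
        num_threads = default_threads
        decoding_method = "greedy_search"
        if len(argv) >= 7 and int(argv[6]) > 0:
            num_threads = int(argv[6])
        if len(argv) == 8:
            decoding_method = argv[7]
        return num_threads, decoding_method
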
1 include_directories(${CMAKE_SOURCE_DIR}) 1 include_directories(${CMAKE_SOURCE_DIR})
2 2
3 pybind11_add_module(_sherpa_onnx 3 pybind11_add_module(_sherpa_onnx
  4 + display.cc
  5 + endpoint.cc
4 features.cc 6 features.cc
  7 + online-recognizer.cc
  8 + online-stream.cc
5 online-transducer-model-config.cc 9 online-transducer-model-config.cc
6 sherpa-onnx.cc 10 sherpa-onnx.cc
7 - endpoint.cc  
8 - online-stream.cc  
9 - online-recognizer.cc  
10 ) 11 )
11 12
12 if(APPLE) 13 if(APPLE)
  1 +// sherpa-onnx/python/csrc/display.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/python/csrc/display.h"
  6 +
  7 +#include "sherpa-onnx/csrc/display.h"
  8 +
  9 +namespace sherpa_onnx {
  10 +
  11 +void PybindDisplay(py::module *m) {
  12 + using PyClass = Display;
  13 + py::class_<PyClass>(*m, "Display")
  14 + .def(py::init<int32_t>(), py::arg("max_word_per_line") = 60)
  15 + .def("print", &PyClass::Print, py::arg("idx"), py::arg("s"));
  16 +}
  17 +
  18 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/python/csrc/display.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_PYTHON_CSRC_DISPLAY_H_
  6 +#define SHERPA_ONNX_PYTHON_CSRC_DISPLAY_H_
  7 +
  8 +#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void PybindDisplay(py::module *m);
  13 +
  14 +} // namespace sherpa_onnx
  15 +
  16 +#endif // SHERPA_ONNX_PYTHON_CSRC_DISPLAY_H_
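
With the binding above (and the Display re-export in the package __init__ further below), the rolling console display can be driven from Python. A minimal usage sketch:

    import sherpa_onnx

    display = sherpa_onnx.Display(max_word_per_line=50)
    # Reprint the current segment as its partial result grows; pass a larger
    # idx to start a new line after an endpoint is detected.
    display.print(0, "hello")
    display.print(0, "hello world")
    display.print(1, "a new segment")
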
@@ -11,10 +11,12 @@ namespace sherpa_onnx { @@ -11,10 +11,12 @@ namespace sherpa_onnx {
11 static void PybindFeatureExtractorConfig(py::module *m) { 11 static void PybindFeatureExtractorConfig(py::module *m) {
12 using PyClass = FeatureExtractorConfig; 12 using PyClass = FeatureExtractorConfig;
13 py::class_<PyClass>(*m, "FeatureExtractorConfig") 13 py::class_<PyClass>(*m, "FeatureExtractorConfig")
14 - .def(py::init<float, int32_t>(), py::arg("sampling_rate") = 16000,  
15 - py::arg("feature_dim") = 80) 14 + .def(py::init<float, int32_t, int32_t>(),
  15 + py::arg("sampling_rate") = 16000, py::arg("feature_dim") = 80,
  16 + py::arg("max_feature_vectors") = -1)
16 .def_readwrite("sampling_rate", &PyClass::sampling_rate) 17 .def_readwrite("sampling_rate", &PyClass::sampling_rate)
17 .def_readwrite("feature_dim", &PyClass::feature_dim) 18 .def_readwrite("feature_dim", &PyClass::feature_dim)
  19 + .def_readwrite("max_feature_vectors", &PyClass::max_feature_vectors)
18 .def("__str__", &PyClass::ToString); 20 .def("__str__", &PyClass::ToString);
19 } 21 }
20 22
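
The extra max_feature_vectors field caps how many computed feature vectors the extractor keeps; -1, the default, keeps all of them. A sketch of constructing the config directly from the extension module:

    from _sherpa_onnx import FeatureExtractorConfig

    feat_config = FeatureExtractorConfig(
        sampling_rate=16000,
        feature_dim=80,
        max_feature_vectors=-1,  # cache every feature vector computed so far
    )
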
@@ -22,13 +22,16 @@ static void PybindOnlineRecognizerConfig(py::module *m) { @@ -22,13 +22,16 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
22 py::class_<PyClass>(*m, "OnlineRecognizerConfig") 22 py::class_<PyClass>(*m, "OnlineRecognizerConfig")
23 .def(py::init<const FeatureExtractorConfig &, 23 .def(py::init<const FeatureExtractorConfig &,
24 const OnlineTransducerModelConfig &, const EndpointConfig &, 24 const OnlineTransducerModelConfig &, const EndpointConfig &,
25 - bool>(), 25 + bool, const std::string &, int32_t>(),
26 py::arg("feat_config"), py::arg("model_config"), 26 py::arg("feat_config"), py::arg("model_config"),
27 - py::arg("endpoint_config"), py::arg("enable_endpoint")) 27 + py::arg("endpoint_config"), py::arg("enable_endpoint"),
  28 + py::arg("decoding_method"), py::arg("max_active_paths"))
28 .def_readwrite("feat_config", &PyClass::feat_config) 29 .def_readwrite("feat_config", &PyClass::feat_config)
29 .def_readwrite("model_config", &PyClass::model_config) 30 .def_readwrite("model_config", &PyClass::model_config)
30 .def_readwrite("endpoint_config", &PyClass::endpoint_config) 31 .def_readwrite("endpoint_config", &PyClass::endpoint_config)
31 .def_readwrite("enable_endpoint", &PyClass::enable_endpoint) 32 .def_readwrite("enable_endpoint", &PyClass::enable_endpoint)
  33 + .def_readwrite("decoding_method", &PyClass::decoding_method)
  34 + .def_readwrite("max_active_paths", &PyClass::max_active_paths)
32 .def("__str__", &PyClass::ToString); 35 .def("__str__", &PyClass::ToString);
33 } 36 }
34 37
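
OnlineRecognizerConfig now also takes decoding_method and max_active_paths, both exposed as read-write attributes as well. A sketch, assuming feat_config, model_config, and endpoint_config have already been built:

    from _sherpa_onnx import OnlineRecognizerConfig

    recognizer_config = OnlineRecognizerConfig(
        feat_config=feat_config,
        model_config=model_config,
        endpoint_config=endpoint_config,
        enable_endpoint=True,
        decoding_method="modified_beam_search",
        max_active_paths=4,  # used only by modified_beam_search
    )
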
@@ -4,6 +4,7 @@ @@ -4,6 +4,7 @@
4 4
5 #include "sherpa-onnx/python/csrc/sherpa-onnx.h" 5 #include "sherpa-onnx/python/csrc/sherpa-onnx.h"
6 6
  7 +#include "sherpa-onnx/python/csrc/display.h"
7 #include "sherpa-onnx/python/csrc/endpoint.h" 8 #include "sherpa-onnx/python/csrc/endpoint.h"
8 #include "sherpa-onnx/python/csrc/features.h" 9 #include "sherpa-onnx/python/csrc/features.h"
9 #include "sherpa-onnx/python/csrc/online-recognizer.h" 10 #include "sherpa-onnx/python/csrc/online-recognizer.h"
@@ -19,6 +20,8 @@ PYBIND11_MODULE(_sherpa_onnx, m) { @@ -19,6 +20,8 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
19 PybindOnlineStream(&m); 20 PybindOnlineStream(&m);
20 PybindEndpoint(&m); 21 PybindEndpoint(&m);
21 PybindOnlineRecognizer(&m); 22 PybindOnlineRecognizer(&m);
  23 +
  24 + PybindDisplay(&m);
22 } 25 }
23 26
24 } // namespace sherpa_onnx 27 } // namespace sherpa_onnx
1 -from _sherpa_onnx import (  
2 - EndpointConfig,  
3 - FeatureExtractorConfig,  
4 - OnlineRecognizerConfig,  
5 - OnlineStream,  
6 - OnlineTransducerModelConfig,  
7 -) 1 +from _sherpa_onnx import Display
8 2
9 from .online_recognizer import OnlineRecognizer 3 from .online_recognizer import OnlineRecognizer
@@ -32,6 +32,9 @@ class OnlineRecognizer(object): @@ -32,6 +32,9 @@ class OnlineRecognizer(object):
32 rule1_min_trailing_silence: float = 2.4, 32
33 rule2_min_trailing_silence: float = 1.2, 33
34 rule3_min_utterance_length: int = 20, 34 rule3_min_utterance_length: int = 20,
  35 + decoding_method: str = "greedy_search",
  36 + max_active_paths: int = 4,
  37 + max_feature_vectors: int = -1,
35 ): 38 ):
36 """ 39 """
37 Please refer to 40 Please refer to
@@ -74,6 +77,14 @@ class OnlineRecognizer(object): @@ -74,6 +77,14 @@ class OnlineRecognizer(object):
74 Used only when enable_endpoint_detection is True. If the utterance 77 Used only when enable_endpoint_detection is True. If the utterance
75 length in seconds is larger than this value, we assume an endpoint 78 length in seconds is larger than this value, we assume an endpoint
76 is detected. 79 is detected.
  80 + decoding_method:
  81 + Valid values are greedy_search, modified_beam_search.
  82 + max_active_paths:
  83 + Used only when decoding_method is modified_beam_search. It specifies
  84 + the maximum number of active paths during beam search.
  85 + max_feature_vectors:
  86 + Number of feature vectors to cache. -1 means to cache all feature
  87 + vectors that have been computed.
77 """ 88 """
78 _assert_file_exists(tokens) 89 _assert_file_exists(tokens)
79 _assert_file_exists(encoder) 90 _assert_file_exists(encoder)
@@ -93,6 +104,7 @@ class OnlineRecognizer(object): @@ -93,6 +104,7 @@ class OnlineRecognizer(object):
93 feat_config = FeatureExtractorConfig( 104 feat_config = FeatureExtractorConfig(
94 sampling_rate=sample_rate, 105 sampling_rate=sample_rate,
95 feature_dim=feature_dim, 106 feature_dim=feature_dim,
  107 + max_feature_vectors=max_feature_vectors,
96 ) 108 )
97 109
98 endpoint_config = EndpointConfig( 110 endpoint_config = EndpointConfig(
@@ -106,6 +118,8 @@ class OnlineRecognizer(object): @@ -106,6 +118,8 @@ class OnlineRecognizer(object):
106 model_config=model_config, 118 model_config=model_config,
107 endpoint_config=endpoint_config, 119 endpoint_config=endpoint_config,
108 enable_endpoint=enable_endpoint_detection, 120 enable_endpoint=enable_endpoint_detection,
  121 + decoding_method=decoding_method,
  122 + max_active_paths=max_active_paths,
109 ) 123 )
110 124
111 self.recognizer = _Recognizer(recognizer_config) 125 self.recognizer = _Recognizer(recognizer_config)
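
Putting the Python-side changes together: a decode sketch using the high-level wrapper, assuming the model-path keyword arguments (tokens, encoder, decoder, joiner) accepted earlier in this constructor, with placeholder file paths:

    from sherpa_onnx import OnlineRecognizer

    recognizer = OnlineRecognizer(
        tokens="./tokens.txt",
        encoder="./encoder.onnx",
        decoder="./decoder.onnx",
        joiner="./joiner.onnx",
        decoding_method="modified_beam_search",
        max_active_paths=4,
    )
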