Fangjun Kuang
Committed by GitHub

Code refactoring (#74)

* Don't reset model state and feature extractor on endpointing

* Support passing decoding_method from the command line

* Add modified_beam_search to Python API

* Fix C API example

* Fix style issues
@@ -36,4 +36,5 @@ tokens.txt
36 *.onnx 36 *.onnx
37 log.txt 37 log.txt
38 tags 38 tags
  39 +run-decode-file-python.sh
39 android/SherpaOnnx/app/src/main/assets/ 40 android/SherpaOnnx/app/src/main/assets/
@@ -19,14 +19,16 @@ const char *kUsage =
19 " /path/to/encoder.onnx \\\n" 19 " /path/to/encoder.onnx \\\n"
20 " /path/to/decoder.onnx \\\n" 20 " /path/to/decoder.onnx \\\n"
21 " /path/to/joiner.onnx \\\n" 21 " /path/to/joiner.onnx \\\n"
22 - " /path/to/foo.wav [num_threads]\n" 22 + " /path/to/foo.wav [num_threads [decoding_method]]\n"
23 "\n\n" 23 "\n\n"
  24 + "Default num_threads is 1.\n"
  25 + "Valid decoding_method: greedy_search (default), modified_beam_search\n\n"
24 "Please refer to \n" 26 "Please refer to \n"
25 "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html\n" 27 "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html\n"
26 "for a list of pre-trained models to download.\n"; 28 "for a list of pre-trained models to download.\n";
27 29
28 int32_t main(int32_t argc, char *argv[]) { 30 int32_t main(int32_t argc, char *argv[]) {
29 - if (argc < 6 || argc > 7) { 31 + if (argc < 6 || argc > 8) {
30 fprintf(stderr, "%s\n", kUsage); 32 fprintf(stderr, "%s\n", kUsage);
31 return -1; 33 return -1;
32 } 34 }
@@ -36,13 +38,20 @@ int32_t main(int32_t argc, char *argv[]) {
36 config.model_config.decoder = argv[3]; 38 config.model_config.decoder = argv[3];
37 config.model_config.joiner = argv[4]; 39 config.model_config.joiner = argv[4];
38 40
39 - int32_t num_threads = 4; 41 + int32_t num_threads = 1;
40 if (argc == 7 && atoi(argv[6]) > 0) { 42 if (argc == 7 && atoi(argv[6]) > 0) {
41 num_threads = atoi(argv[6]); 43 num_threads = atoi(argv[6]);
42 } 44 }
43 config.model_config.num_threads = num_threads; 45 config.model_config.num_threads = num_threads;
44 config.model_config.debug = 0; 46 config.model_config.debug = 0;
45 47
  48 + config.decoding_method = "greedy_search";
  49 + if (argc == 8) {
  50 + config.decoding_method = argv[7];
  51 + }
  52 +
  53 + config.max_active_paths = 4;
  54 +
46 config.feat_config.sample_rate = 16000; 55 config.feat_config.sample_rate = 16000;
47 config.feat_config.feature_dim = 80; 56 config.feat_config.feature_dim = 80;
48 57
@@ -54,6 +63,9 @@ int32_t main(int32_t argc, char *argv[]) {
54 SherpaOnnxOnlineRecognizer *recognizer = CreateOnlineRecognizer(&config); 63 SherpaOnnxOnlineRecognizer *recognizer = CreateOnlineRecognizer(&config);
55 SherpaOnnxOnlineStream *stream = CreateOnlineStream(recognizer); 64 SherpaOnnxOnlineStream *stream = CreateOnlineStream(recognizer);
56 65
  66 + SherpaOnnxDisplay *display = CreateDisplay(50);
  67 + int32_t segment_id = 0;
  68 +
57 const char *wav_filename = argv[5]; 69 const char *wav_filename = argv[5];
58 FILE *fp = fopen(wav_filename, "rb"); 70 FILE *fp = fopen(wav_filename, "rb");
59 if (!fp) { 71 if (!fp) {
@@ -84,9 +96,18 @@ int32_t main(int32_t argc, char *argv[]) {
84 96
85 SherpaOnnxOnlineRecognizerResult *r = 97 SherpaOnnxOnlineRecognizerResult *r =
86 GetOnlineStreamResult(recognizer, stream); 98 GetOnlineStreamResult(recognizer, stream);
  99 +
87 if (strlen(r->text)) { 100 if (strlen(r->text)) {
88 - fprintf(stderr, "%s\n", r->text); 101 + SherpaOnnxPrint(display, segment_id, r->text);
89 } 102 }
  103 +
  104 + if (IsEndpoint(recognizer, stream)) {
  105 + if (strlen(r->text)) {
  106 + ++segment_id;
  107 + }
  108 + Reset(recognizer, stream);
  109 + }
  110 +
90 DestroyOnlineRecognizerResult(r); 111 DestroyOnlineRecognizerResult(r);
91 } 112 }
92 } 113 }
@@ -103,14 +124,17 @@ int32_t main(int32_t argc, char *argv[]) {
103 124
104 SherpaOnnxOnlineRecognizerResult *r = 125 SherpaOnnxOnlineRecognizerResult *r =
105 GetOnlineStreamResult(recognizer, stream); 126 GetOnlineStreamResult(recognizer, stream);
  127 +
106 if (strlen(r->text)) { 128 if (strlen(r->text)) {
107 - fprintf(stderr, "%s\n", r->text); 129 + SherpaOnnxPrint(display, segment_id, r->text);
108 } 130 }
109 131
110 DestroyOnlineRecognizerResult(r); 132 DestroyOnlineRecognizerResult(r);
111 133
  134 + DestroyDisplay(display);
112 DestoryOnlineStream(stream); 135 DestoryOnlineStream(stream);
113 DestroyOnlineRecognizer(recognizer); 136 DestroyOnlineRecognizer(recognizer);
  137 + fprintf(stderr, "\n");
114 138
115 return 0; 139 return 0;
116 } 140 }
@@ -26,12 +26,17 @@ if [ ! -f ./sherpa-onnx-ffmpeg ]; then
26 make 26 make
27 fi 27 fi
28 28
29 -../ffmpeg-examples/sherpa-onnx-ffmpeg \  
30 - ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt \  
31 - ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \  
32 - ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \  
33 - ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \  
34 - ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/4.wav 29 +for method in greedy_search modified_beam_search; do
  30 + echo "test method: $method"
  31 + ../ffmpeg-examples/sherpa-onnx-ffmpeg \
  32 + ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt \
  33 + ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \
  34 + ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \
  35 + ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \
  36 + ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/0.wav \
  37 + 2 \
  38 + $method
  39 +done
35 40
36 echo "Decoding a URL" 41 echo "Decoding a URL"
37 42
@@ -7,7 +7,6 @@
7 7
8 #include "sherpa-onnx/c-api/c-api.h" 8 #include "sherpa-onnx/c-api/c-api.h"
9 9
10 -  
11 /* 10 /*
12 * Copyright (c) 2010 Nicolas George 11 * Copyright (c) 2010 Nicolas George
13 * Copyright (c) 2011 Stefano Sabatini 12 * Copyright (c) 2011 Stefano Sabatini
@@ -43,14 +42,15 @@
43 #include <unistd.h> 42 #include <unistd.h>
44 extern "C" { 43 extern "C" {
45 #include <libavcodec/avcodec.h> 44 #include <libavcodec/avcodec.h>
46 -#include <libavformat/avformat.h>  
47 #include <libavfilter/buffersink.h> 45 #include <libavfilter/buffersink.h>
48 #include <libavfilter/buffersrc.h> 46 #include <libavfilter/buffersrc.h>
  47 +#include <libavformat/avformat.h>
49 #include <libavutil/channel_layout.h> 48 #include <libavutil/channel_layout.h>
50 #include <libavutil/opt.h> 49 #include <libavutil/opt.h>
51 } 50 }
52 51
53 -static const char *filter_descr = "aresample=16000,aformat=sample_fmts=s16:channel_layouts=mono"; 52 +static const char *filter_descr =
  53 + "aresample=16000,aformat=sample_fmts=s16:channel_layouts=mono";
54 54
55 static AVFormatContext *fmt_ctx; 55 static AVFormatContext *fmt_ctx;
56 static AVCodecContext *dec_ctx; 56 static AVCodecContext *dec_ctx;
@@ -59,308 +59,172 @@ AVFilterContext *buffersrc_ctx;
59 AVFilterGraph *filter_graph; 59 AVFilterGraph *filter_graph;
60 static int audio_stream_index = -1; 60 static int audio_stream_index = -1;
61 61
62 -static int open_input_file(const char *filename)  
63 -{  
64 - const AVCodec *dec;  
65 - int ret; 62 +static int open_input_file(const char *filename) {
  63 + const AVCodec *dec;
  64 + int ret;
66 65
67 - if ((ret = avformat_open_input(&fmt_ctx, filename, NULL, NULL)) < 0) {  
68 - av_log(NULL, AV_LOG_ERROR, "Cannot open input file %s\n", filename);  
69 - return ret;  
70 - } 66 + if ((ret = avformat_open_input(&fmt_ctx, filename, NULL, NULL)) < 0) {
  67 + av_log(NULL, AV_LOG_ERROR, "Cannot open input file %s\n", filename);
  68 + return ret;
  69 + }
71 70
72 - if ((ret = avformat_find_stream_info(fmt_ctx, NULL)) < 0) {  
73 - av_log(NULL, AV_LOG_ERROR, "Cannot find stream information\n");  
74 - return ret;  
75 - } 71 + if ((ret = avformat_find_stream_info(fmt_ctx, NULL)) < 0) {
  72 + av_log(NULL, AV_LOG_ERROR, "Cannot find stream information\n");
  73 + return ret;
  74 + }
76 75
77 - /* select the audio stream */  
78 - ret = av_find_best_stream(fmt_ctx, AVMEDIA_TYPE_AUDIO, -1, -1, &dec, 0);  
79 - if (ret < 0) {  
80 - av_log(NULL, AV_LOG_ERROR, "Cannot find an audio stream in the input file\n");  
81 - return ret;  
82 - }  
83 - audio_stream_index = ret;  
84 -  
85 - /* create decoding context */  
86 - dec_ctx = avcodec_alloc_context3(dec);  
87 - if (!dec_ctx)  
88 - return AVERROR(ENOMEM);  
89 - avcodec_parameters_to_context(dec_ctx, fmt_ctx->streams[audio_stream_index]->codecpar);  
90 -  
91 - /* init the audio decoder */  
92 - if ((ret = avcodec_open2(dec_ctx, dec, NULL)) < 0) {  
93 - av_log(NULL, AV_LOG_ERROR, "Cannot open audio decoder\n");  
94 - return ret;  
95 - } 76 + /* select the audio stream */
  77 + ret = av_find_best_stream(fmt_ctx, AVMEDIA_TYPE_AUDIO, -1, -1, &dec, 0);
  78 + if (ret < 0) {
  79 + av_log(NULL, AV_LOG_ERROR,
  80 + "Cannot find an audio stream in the input file\n");
  81 + return ret;
  82 + }
  83 + audio_stream_index = ret;
  84 +
  85 + /* create decoding context */
  86 + dec_ctx = avcodec_alloc_context3(dec);
  87 + if (!dec_ctx) return AVERROR(ENOMEM);
  88 + avcodec_parameters_to_context(dec_ctx,
  89 + fmt_ctx->streams[audio_stream_index]->codecpar);
  90 +
  91 + /* init the audio decoder */
  92 + if ((ret = avcodec_open2(dec_ctx, dec, NULL)) < 0) {
  93 + av_log(NULL, AV_LOG_ERROR, "Cannot open audio decoder\n");
  94 + return ret;
  95 + }
96 96
97 - return 0; 97 + return 0;
98 } 98 }
99 99
100 -static int init_filters(const char *filters_descr)  
101 -{  
102 - char args[512];  
103 - int ret = 0;  
104 - const AVFilter *abuffersrc = avfilter_get_by_name("abuffer");  
105 - const AVFilter *abuffersink = avfilter_get_by_name("abuffersink");  
106 - AVFilterInOut *outputs = avfilter_inout_alloc();  
107 - AVFilterInOut *inputs = avfilter_inout_alloc();  
108 - static const enum AVSampleFormat out_sample_fmts[] = { AV_SAMPLE_FMT_S16, AV_SAMPLE_FMT_NONE };  
109 - static const int out_sample_rates[] = { 16000, -1 };  
110 - const AVFilterLink *outlink;  
111 - AVRational time_base = fmt_ctx->streams[audio_stream_index]->time_base;  
112 -  
113 - filter_graph = avfilter_graph_alloc();  
114 - if (!outputs || !inputs || !filter_graph) {  
115 - ret = AVERROR(ENOMEM);  
116 - goto end;  
117 - }  
118 -  
119 - /* buffer audio source: the decoded frames from the decoder will be inserted here. */  
120 - if (dec_ctx->ch_layout.order == AV_CHANNEL_ORDER_UNSPEC)  
121 - av_channel_layout_default(&dec_ctx->ch_layout, dec_ctx->ch_layout.nb_channels);  
122 - ret = snprintf(args, sizeof(args),  
123 - "time_base=%d/%d:sample_rate=%d:sample_fmt=%s:channel_layout=",  
124 - time_base.num, time_base.den, dec_ctx->sample_rate,  
125 - av_get_sample_fmt_name(dec_ctx->sample_fmt));  
126 - av_channel_layout_describe(&dec_ctx->ch_layout, args + ret, sizeof(args) - ret);  
127 - ret = avfilter_graph_create_filter(&buffersrc_ctx, abuffersrc, "in",  
128 - args, NULL, filter_graph);  
129 - if (ret < 0) {  
130 - av_log(NULL, AV_LOG_ERROR, "Cannot create audio buffer source\n");  
131 - goto end;  
132 - }  
133 -  
134 - /* buffer audio sink: to terminate the filter chain. */  
135 - ret = avfilter_graph_create_filter(&buffersink_ctx, abuffersink, "out",  
136 - NULL, NULL, filter_graph);  
137 - if (ret < 0) {  
138 - av_log(NULL, AV_LOG_ERROR, "Cannot create audio buffer sink\n");  
139 - goto end;  
140 - }  
141 -  
142 - ret = av_opt_set_int_list(buffersink_ctx, "sample_fmts", out_sample_fmts, -1,  
143 - AV_OPT_SEARCH_CHILDREN);  
144 - if (ret < 0) {  
145 - av_log(NULL, AV_LOG_ERROR, "Cannot set output sample format\n");  
146 - goto end;  
147 - }  
148 -  
149 - ret = av_opt_set(buffersink_ctx, "ch_layouts", "mono",  
150 - AV_OPT_SEARCH_CHILDREN);  
151 - if (ret < 0) {  
152 - av_log(NULL, AV_LOG_ERROR, "Cannot set output channel layout\n");  
153 - goto end;  
154 - }  
155 -  
156 - ret = av_opt_set_int_list(buffersink_ctx, "sample_rates", out_sample_rates, -1,  
157 - AV_OPT_SEARCH_CHILDREN);  
158 - if (ret < 0) {  
159 - av_log(NULL, AV_LOG_ERROR, "Cannot set output sample rate\n");  
160 - goto end;  
161 - }  
162 -  
163 - /*  
164 - * Set the endpoints for the filter graph. The filter_graph will  
165 - * be linked to the graph described by filters_descr.  
166 - */  
167 -  
168 - /*  
169 - * The buffer source output must be connected to the input pad of  
170 - * the first filter described by filters_descr; since the first  
171 - * filter input label is not specified, it is set to "in" by  
172 - * default.  
173 - */  
174 - outputs->name = av_strdup("in");  
175 - outputs->filter_ctx = buffersrc_ctx;  
176 - outputs->pad_idx = 0;  
177 - outputs->next = NULL;  
178 -  
179 - /*  
180 - * The buffer sink input must be connected to the output pad of  
181 - * the last filter described by filters_descr; since the last  
182 - * filter output label is not specified, it is set to "out" by  
183 - * default.  
184 - */  
185 - inputs->name = av_strdup("out");  
186 - inputs->filter_ctx = buffersink_ctx;  
187 - inputs->pad_idx = 0;  
188 - inputs->next = NULL;  
189 -  
190 - if ((ret = avfilter_graph_parse_ptr(filter_graph, filters_descr,  
191 - &inputs, &outputs, NULL)) < 0)  
192 - goto end;  
193 -  
194 - if ((ret = avfilter_graph_config(filter_graph, NULL)) < 0)  
195 - goto end;  
196 -  
197 - /* Print summary of the sink buffer  
198 - * Note: args buffer is reused to store channel layout string */  
199 - outlink = buffersink_ctx->inputs[0];  
200 - av_channel_layout_describe(&outlink->ch_layout, args, sizeof(args));  
201 - av_log(NULL, AV_LOG_INFO, "Output: srate:%dHz fmt:%s chlayout:%s\n",  
202 - (int)outlink->sample_rate,  
203 - (char *)av_x_if_null(av_get_sample_fmt_name((AVSampleFormat)outlink->format), "?"),  
204 - args); 100 +static int init_filters(const char *filters_descr) {
  101 + char args[512];
  102 + int ret = 0;
  103 + const AVFilter *abuffersrc = avfilter_get_by_name("abuffer");
  104 + const AVFilter *abuffersink = avfilter_get_by_name("abuffersink");
  105 + AVFilterInOut *outputs = avfilter_inout_alloc();
  106 + AVFilterInOut *inputs = avfilter_inout_alloc();
  107 + static const enum AVSampleFormat out_sample_fmts[] = {AV_SAMPLE_FMT_S16,
  108 + AV_SAMPLE_FMT_NONE};
  109 + static const int out_sample_rates[] = {16000, -1};
  110 + const AVFilterLink *outlink;
  111 + AVRational time_base = fmt_ctx->streams[audio_stream_index]->time_base;
  112 +
  113 + filter_graph = avfilter_graph_alloc();
  114 + if (!outputs || !inputs || !filter_graph) {
  115 + ret = AVERROR(ENOMEM);
  116 + goto end;
  117 + }
  118 +
  119 + /* buffer audio source: the decoded frames from the decoder will be inserted
  120 + * here. */
  121 + if (dec_ctx->ch_layout.order == AV_CHANNEL_ORDER_UNSPEC)
  122 + av_channel_layout_default(&dec_ctx->ch_layout,
  123 + dec_ctx->ch_layout.nb_channels);
  124 + ret = snprintf(args, sizeof(args),
  125 + "time_base=%d/%d:sample_rate=%d:sample_fmt=%s:channel_layout=",
  126 + time_base.num, time_base.den, dec_ctx->sample_rate,
  127 + av_get_sample_fmt_name(dec_ctx->sample_fmt));
  128 + av_channel_layout_describe(&dec_ctx->ch_layout, args + ret,
  129 + sizeof(args) - ret);
  130 + ret = avfilter_graph_create_filter(&buffersrc_ctx, abuffersrc, "in", args,
  131 + NULL, filter_graph);
  132 + if (ret < 0) {
  133 + av_log(NULL, AV_LOG_ERROR, "Cannot create audio buffer source\n");
  134 + goto end;
  135 + }
  136 +
  137 + /* buffer audio sink: to terminate the filter chain. */
  138 + ret = avfilter_graph_create_filter(&buffersink_ctx, abuffersink, "out", NULL,
  139 + NULL, filter_graph);
  140 + if (ret < 0) {
  141 + av_log(NULL, AV_LOG_ERROR, "Cannot create audio buffer sink\n");
  142 + goto end;
  143 + }
  144 +
  145 + ret = av_opt_set_int_list(buffersink_ctx, "sample_fmts", out_sample_fmts, -1,
  146 + AV_OPT_SEARCH_CHILDREN);
  147 + if (ret < 0) {
  148 + av_log(NULL, AV_LOG_ERROR, "Cannot set output sample format\n");
  149 + goto end;
  150 + }
  151 +
  152 + ret =
  153 + av_opt_set(buffersink_ctx, "ch_layouts", "mono", AV_OPT_SEARCH_CHILDREN);
  154 + if (ret < 0) {
  155 + av_log(NULL, AV_LOG_ERROR, "Cannot set output channel layout\n");
  156 + goto end;
  157 + }
  158 +
  159 + ret = av_opt_set_int_list(buffersink_ctx, "sample_rates", out_sample_rates,
  160 + -1, AV_OPT_SEARCH_CHILDREN);
  161 + if (ret < 0) {
  162 + av_log(NULL, AV_LOG_ERROR, "Cannot set output sample rate\n");
  163 + goto end;
  164 + }
  165 +
  166 + /*
  167 + * Set the endpoints for the filter graph. The filter_graph will
  168 + * be linked to the graph described by filters_descr.
  169 + */
  170 +
  171 + /*
  172 + * The buffer source output must be connected to the input pad of
  173 + * the first filter described by filters_descr; since the first
  174 + * filter input label is not specified, it is set to "in" by
  175 + * default.
  176 + */
  177 + outputs->name = av_strdup("in");
  178 + outputs->filter_ctx = buffersrc_ctx;
  179 + outputs->pad_idx = 0;
  180 + outputs->next = NULL;
  181 +
  182 + /*
  183 + * The buffer sink input must be connected to the output pad of
  184 + * the last filter described by filters_descr; since the last
  185 + * filter output label is not specified, it is set to "out" by
  186 + * default.
  187 + */
  188 + inputs->name = av_strdup("out");
  189 + inputs->filter_ctx = buffersink_ctx;
  190 + inputs->pad_idx = 0;
  191 + inputs->next = NULL;
  192 +
  193 + if ((ret = avfilter_graph_parse_ptr(filter_graph, filters_descr, &inputs,
  194 + &outputs, NULL)) < 0)
  195 + goto end;
  196 +
  197 + if ((ret = avfilter_graph_config(filter_graph, NULL)) < 0) goto end;
  198 +
  199 + /* Print summary of the sink buffer
  200 + * Note: args buffer is reused to store channel layout string */
  201 + outlink = buffersink_ctx->inputs[0];
  202 + av_channel_layout_describe(&outlink->ch_layout, args, sizeof(args));
  203 + av_log(NULL, AV_LOG_INFO, "Output: srate:%dHz fmt:%s chlayout:%s\n",
  204 + (int)outlink->sample_rate,
  205 + (char *)av_x_if_null(
  206 + av_get_sample_fmt_name((AVSampleFormat)outlink->format), "?"),
  207 + args);
205 208
206 end: 209 end:
207 - avfilter_inout_free(&inputs);  
208 - avfilter_inout_free(&outputs); 210 + avfilter_inout_free(&inputs);
  211 + avfilter_inout_free(&outputs);
209 212
210 - return ret; 213 + return ret;
211 } 214 }
212 215
213 -static void sherpa_decode_frame(const AVFrame *frame, SherpaOnnxOnlineRecognizer *recognizer,  
214 - SherpaOnnxOnlineStream* stream)  
215 -{ 216 +static void sherpa_decode_frame(const AVFrame *frame,
  217 + SherpaOnnxOnlineRecognizer *recognizer,
  218 + SherpaOnnxOnlineStream *stream,
  219 + SherpaOnnxDisplay *display,
  220 + int32_t *segment_id) {
216 #define N 3200 // 100s. Sample rate is fixed to 16 kHz 221 #define N 3200 // 100s. Sample rate is fixed to 16 kHz
217 - static float samples[N];  
218 - static int nb_samples = 0;  
219 - const int16_t *p = (int16_t*)frame->data[0];  
220 -  
221 - if (frame->nb_samples + nb_samples > N) {  
222 - AcceptWaveform(stream, 16000, samples, nb_samples);  
223 - while (IsOnlineStreamReady(recognizer, stream)) {  
224 - DecodeOnlineStream(recognizer, stream);  
225 - }  
226 -  
227 -  
228 - if (IsEndpoint(recognizer, stream)) {  
229 - SherpaOnnxOnlineRecognizerResult *r =  
230 - GetOnlineStreamResult(recognizer, stream);  
231 - if (strlen(r->text)) {  
232 - fprintf(stderr, "%s\n", r->text);  
233 - }  
234 - DestroyOnlineRecognizerResult(r);  
235 -  
236 - Reset(recognizer, stream);  
237 - }  
238 - nb_samples = 0;  
239 - }  
240 -  
241 - for (int i = 0; i < frame->nb_samples; i++) {  
242 - samples[nb_samples++] = p[i] / 32768.;  
243 - }  
244 -}  
245 -  
246 -static inline char *__av_err2str(int errnum)  
247 -{  
248 - static char str[AV_ERROR_MAX_STRING_SIZE];  
249 - memset(str, 0, sizeof(str));  
250 - return av_make_error_string(str, AV_ERROR_MAX_STRING_SIZE, errnum);  
251 -}  
252 -  
253 -int main(int argc, char **argv)  
254 -{  
255 - int ret;  
256 - int num_threads = 4;  
257 - AVPacket *packet = av_packet_alloc();  
258 - AVFrame *frame = av_frame_alloc();  
259 - AVFrame *filt_frame = av_frame_alloc();  
260 - const char *kUsage =  
261 - "\n"  
262 - "Usage:\n"  
263 - " ./sherpa-onnx-ffmpeg \\\n"  
264 - " /path/to/tokens.txt \\\n"  
265 - " /path/to/encoder.onnx\\\n"  
266 - " /path/to/decoder.onnx\\\n"  
267 - " /path/to/joiner.onnx\\\n"  
268 - " /path/to/foo.wav [num_threads]"  
269 - "\n\n"  
270 - "Please refer to \n"  
271 - "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html\n"  
272 - "for a list of pre-trained models to download.\n";  
273 -  
274 -  
275 - if (!packet || !frame || !filt_frame) {  
276 - fprintf(stderr, "Could not allocate frame or packet\n");  
277 - exit(1);  
278 - }  
279 -  
280 - if (argc < 6 || argc > 7) {  
281 - fprintf(stderr, "%s\n", kUsage);  
282 - return -1;  
283 - }  
284 -  
285 - SherpaOnnxOnlineRecognizerConfig config;  
286 - config.model_config.tokens = argv[1];  
287 - config.model_config.encoder = argv[2];  
288 - config.model_config.decoder = argv[3];  
289 - config.model_config.joiner = argv[4];  
290 -  
291 - if (argc == 7 && atoi(argv[6]) > 0) {  
292 - num_threads = atoi(argv[6]);  
293 - }  
294 - config.model_config.num_threads = num_threads;  
295 - config.model_config.debug = 0;  
296 -  
297 - config.feat_config.sample_rate = 16000;  
298 - config.feat_config.feature_dim = 80;  
299 -  
300 - config.enable_endpoint = 1;  
301 - config.rule1_min_trailing_silence = 2.4;  
302 - config.rule2_min_trailing_silence = 1.2;  
303 - config.rule3_min_utterance_length = 300;  
304 -  
305 - SherpaOnnxOnlineRecognizer *recognizer = CreateOnlineRecognizer(&config);  
306 - SherpaOnnxOnlineStream *stream = CreateOnlineStream(recognizer);  
307 -  
308 - if ((ret = open_input_file(argv[5])) < 0)  
309 - exit(1);  
310 -  
311 - if ((ret = init_filters(filter_descr)) < 0)  
312 - exit(1);  
313 -  
314 - /* read all packets */  
315 - while (1) {  
316 - if ((ret = av_read_frame(fmt_ctx, packet)) < 0)  
317 - break;  
318 -  
319 - if (packet->stream_index == audio_stream_index) {  
320 - ret = avcodec_send_packet(dec_ctx, packet);  
321 - if (ret < 0) {  
322 - av_log(NULL, AV_LOG_ERROR, "Error while sending a packet to the decoder\n");  
323 - break;  
324 - }  
325 -  
326 - while (ret >= 0) {  
327 - ret = avcodec_receive_frame(dec_ctx, frame);  
328 - if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) {  
329 - break;  
330 - } else if (ret < 0) {  
331 - av_log(NULL, AV_LOG_ERROR, "Error while receiving a frame from the decoder\n");  
332 - exit(1);  
333 - }  
334 -  
335 - if (ret >= 0) {  
336 - /* push the audio data from decoded frame into the filtergraph */  
337 - if (av_buffersrc_add_frame_flags(buffersrc_ctx, frame, AV_BUFFERSRC_FLAG_KEEP_REF) < 0) {  
338 - av_log(NULL, AV_LOG_ERROR, "Error while feeding the audio filtergraph\n");  
339 - break;  
340 - }  
341 -  
342 - /* pull filtered audio from the filtergraph */  
343 - while (1) {  
344 - ret = av_buffersink_get_frame(buffersink_ctx, filt_frame);  
345 - if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF)  
346 - break;  
347 - if (ret < 0)  
348 - exit(1);  
349 - sherpa_decode_frame(filt_frame, recognizer, stream);  
350 - av_frame_unref(filt_frame);  
351 - }  
352 - av_frame_unref(frame);  
353 - }  
354 - }  
355 - }  
356 - av_packet_unref(packet);  
357 - }  
358 -  
359 - // add some tail padding  
360 - float tail_paddings[4800] = {0}; // 0.3 seconds at 16 kHz sample rate  
361 - AcceptWaveform(stream, 16000, tail_paddings, 4800);  
362 - InputFinished(stream); 222 + static float samples[N];
  223 + static int nb_samples = 0;
  224 + const int16_t *p = (int16_t *)frame->data[0];
363 225
  226 + if (frame->nb_samples + nb_samples > N) {
  227 + AcceptWaveform(stream, 16000, samples, nb_samples);
364 while (IsOnlineStreamReady(recognizer, stream)) { 228 while (IsOnlineStreamReady(recognizer, stream)) {
365 DecodeOnlineStream(recognizer, stream); 229 DecodeOnlineStream(recognizer, stream);
366 } 230 }
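A note on the buffering in sherpa_decode_frame above: N is 3200 samples and the sample rate is fixed at 16 kHz, so each flush hands AcceptWaveform 3200 / 16000 = 0.2 seconds of audio before the decode loop and endpoint check run; the buffer is therefore drained roughly five times per second of input.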
@@ -368,25 +232,180 @@ int main(int argc, char **argv)
368 SherpaOnnxOnlineRecognizerResult *r = 232 SherpaOnnxOnlineRecognizerResult *r =
369 GetOnlineStreamResult(recognizer, stream); 233 GetOnlineStreamResult(recognizer, stream);
370 if (strlen(r->text)) { 234 if (strlen(r->text)) {
371 - fprintf(stderr, "%s\n", r->text); 235 + SherpaOnnxPrint(display, *segment_id, r->text);
  236 + }
  237 +
  238 + if (IsEndpoint(recognizer, stream)) {
  239 + if (strlen(r->text)) {
  240 + ++*segment_id;
  241 + }
  242 + Reset(recognizer, stream);
372 } 243 }
373 244
374 DestroyOnlineRecognizerResult(r); 245 DestroyOnlineRecognizerResult(r);
  246 + nb_samples = 0;
  247 + }
375 248
376 - DestoryOnlineStream(stream);  
377 - DestroyOnlineRecognizer(recognizer); 249 + for (int i = 0; i < frame->nb_samples; i++) {
  250 + samples[nb_samples++] = p[i] / 32768.;
  251 + }
  252 +}
378 253
379 - avfilter_graph_free(&filter_graph);  
380 - avcodec_free_context(&dec_ctx);  
381 - avformat_close_input(&fmt_ctx);  
382 - av_packet_free(&packet);  
383 - av_frame_free(&frame);  
384 - av_frame_free(&filt_frame); 254 +static inline char *__av_err2str(int errnum) {
  255 + static char str[AV_ERROR_MAX_STRING_SIZE];
  256 + memset(str, 0, sizeof(str));
  257 + return av_make_error_string(str, AV_ERROR_MAX_STRING_SIZE, errnum);
  258 +}
385 259
386 - if (ret < 0 && ret != AVERROR_EOF) {  
387 - fprintf(stderr, "Error occurred: %s\n", __av_err2str(ret));  
388 - exit(1);  
389 - } 260 +int main(int argc, char **argv) {
  261 + int ret;
  262 + int num_threads = 1;
  263 + AVPacket *packet = av_packet_alloc();
  264 + AVFrame *frame = av_frame_alloc();
  265 + AVFrame *filt_frame = av_frame_alloc();
  266 + const char *kUsage =
  267 + "\n"
  268 + "Usage:\n"
  269 + " ./sherpa-onnx-ffmpeg \\\n"
  270 + " /path/to/tokens.txt \\\n"
  271 + " /path/to/encoder.onnx\\\n"
  272 + " /path/to/decoder.onnx\\\n"
  273 + " /path/to/joiner.onnx\\\n"
  274 + " /path/to/foo.wav [num_threads [decoding_method]]"
  275 + "\n\n"
  276 + "Default num_threads is 1.\n"
  277 + "Valid decoding_method: greedy_search (default), modified_beam_search\n\n"
  278 + "Please refer to \n"
  279 + "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html\n"
  280 + "for a list of pre-trained models to download.\n";
  281 +
  282 + if (!packet || !frame || !filt_frame) {
  283 + fprintf(stderr, "Could not allocate frame or packet\n");
  284 + exit(1);
  285 + }
  286 +
  287 + if (argc < 6 || argc > 8) {
  288 + fprintf(stderr, "%s\n", kUsage);
  289 + return -1;
  290 + }
  291 +
  292 + SherpaOnnxOnlineRecognizerConfig config;
  293 + config.model_config.tokens = argv[1];
  294 + config.model_config.encoder = argv[2];
  295 + config.model_config.decoder = argv[3];
  296 + config.model_config.joiner = argv[4];
  297 +
  298 + if (argc == 7 && atoi(argv[6]) > 0) {
  299 + num_threads = atoi(argv[6]);
  300 + }
  301 +
  302 + config.model_config.num_threads = num_threads;
  303 + config.model_config.debug = 0;
  304 +
  305 + config.feat_config.sample_rate = 16000;
  306 + config.feat_config.feature_dim = 80;
  307 +
  308 + config.decoding_method = "greedy_search";
  309 + if (argc == 8) {
  310 + config.decoding_method = argv[7];
  311 + }
  312 +
  313 + config.max_active_paths = 4;
  314 +
  315 + config.enable_endpoint = 1;
  316 + config.rule1_min_trailing_silence = 2.4;
  317 + config.rule2_min_trailing_silence = 1.2;
  318 + config.rule3_min_utterance_length = 300;
  319 +
  320 + SherpaOnnxOnlineRecognizer *recognizer = CreateOnlineRecognizer(&config);
  321 + SherpaOnnxOnlineStream *stream = CreateOnlineStream(recognizer);
  322 + SherpaOnnxDisplay *display = CreateDisplay(50);
  323 + int32_t segment_id = 0;
  324 +
  325 + if ((ret = open_input_file(argv[5])) < 0) exit(1);
  326 +
  327 + if ((ret = init_filters(filter_descr)) < 0) exit(1);
  328 +
  329 + /* read all packets */
  330 + while (1) {
  331 + if ((ret = av_read_frame(fmt_ctx, packet)) < 0) break;
  332 +
  333 + if (packet->stream_index == audio_stream_index) {
  334 + ret = avcodec_send_packet(dec_ctx, packet);
  335 + if (ret < 0) {
  336 + av_log(NULL, AV_LOG_ERROR,
  337 + "Error while sending a packet to the decoder\n");
  338 + break;
  339 + }
  340 +
  341 + while (ret >= 0) {
  342 + ret = avcodec_receive_frame(dec_ctx, frame);
  343 + if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) {
  344 + break;
  345 + } else if (ret < 0) {
  346 + av_log(NULL, AV_LOG_ERROR,
  347 + "Error while receiving a frame from the decoder\n");
  348 + exit(1);
  349 + }
390 350
391 - return 0; 351 + if (ret >= 0) {
  352 + /* push the audio data from decoded frame into the filtergraph */
  353 + if (av_buffersrc_add_frame_flags(buffersrc_ctx, frame,
  354 + AV_BUFFERSRC_FLAG_KEEP_REF) < 0) {
  355 + av_log(NULL, AV_LOG_ERROR,
  356 + "Error while feeding the audio filtergraph\n");
  357 + break;
  358 + }
  359 +
  360 + /* pull filtered audio from the filtergraph */
  361 + while (1) {
  362 + ret = av_buffersink_get_frame(buffersink_ctx, filt_frame);
  363 + if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) break;
  364 + if (ret < 0) exit(1);
  365 + sherpa_decode_frame(filt_frame, recognizer, stream, display,
  366 + &segment_id);
  367 + av_frame_unref(filt_frame);
  368 + }
  369 + av_frame_unref(frame);
  370 + }
  371 + }
  372 + }
  373 + av_packet_unref(packet);
  374 + }
  375 +
  376 + // add some tail padding
  377 + float tail_paddings[4800] = {0}; // 0.3 seconds at 16 kHz sample rate
  378 + AcceptWaveform(stream, 16000, tail_paddings, 4800);
  379 + InputFinished(stream);
  380 +
  381 + while (IsOnlineStreamReady(recognizer, stream)) {
  382 + DecodeOnlineStream(recognizer, stream);
  383 + }
  384 +
  385 + SherpaOnnxOnlineRecognizerResult *r =
  386 + GetOnlineStreamResult(recognizer, stream);
  387 + if (strlen(r->text)) {
  388 + SherpaOnnxPrint(display, segment_id, r->text);
  389 + }
  390 +
  391 + DestroyOnlineRecognizerResult(r);
  392 +
  393 + DestroyDisplay(display);
  394 + DestoryOnlineStream(stream);
  395 + DestroyOnlineRecognizer(recognizer);
  396 +
  397 + avfilter_graph_free(&filter_graph);
  398 + avcodec_free_context(&dec_ctx);
  399 + avformat_close_input(&fmt_ctx);
  400 + av_packet_free(&packet);
  401 + av_frame_free(&frame);
  402 + av_frame_free(&filt_frame);
  403 +
  404 + if (ret < 0 && ret != AVERROR_EOF) {
  405 + fprintf(stderr, "Error occurred: %s\n", __av_err2str(ret));
  406 + exit(1);
  407 + }
  408 + fprintf(stderr, "\n");
  409 +
  410 + return 0;
392 } 411 }
@@ -54,6 +54,20 @@ def get_args():
54 ) 54 )
55 55
56 parser.add_argument( 56 parser.add_argument(
  57 + "--num-threads",
  58 + type=int,
  59 + default=1,
  60 + help="Number of threads for neural network computation",
  61 + )
  62 +
  63 + parser.add_argument(
  64 + "--decoding-method",
  65 + type=str,
  66 + default="greedy_search",
  67 + help="Valid values are greedy_search and modified_beam_search",
  68 + )
  69 +
  70 + parser.add_argument(
57 "--wave-filename", 71 "--wave-filename",
58 type=str, 72 type=str,
59 help="""Path to the wave filename. Must be 16 kHz, 73 help="""Path to the wave filename. Must be 16 kHz,
@@ -65,7 +79,6 @@ def get_args():
65 79
66 def main(): 80 def main():
67 sample_rate = 16000 81 sample_rate = 16000
68 - num_threads = 2  
69 82
70 args = get_args() 83 args = get_args()
71 assert_file_exists(args.encoder) 84 assert_file_exists(args.encoder)
@@ -81,9 +94,10 @@ def main():
81 encoder=args.encoder, 94 encoder=args.encoder,
82 decoder=args.decoder, 95 decoder=args.decoder,
83 joiner=args.joiner, 96 joiner=args.joiner,
84 - num_threads=num_threads, 97 + num_threads=args.num_threads,
85 sample_rate=sample_rate, 98 sample_rate=sample_rate,
86 feature_dim=80, 99 feature_dim=80,
  100 + decoding_method=args.decoding_method,
87 ) 101 )
88 with wave.open(args.wave_filename) as f: 102 with wave.open(args.wave_filename) as f:
89 assert f.getframerate() == sample_rate, f.getframerate() 103 assert f.getframerate() == sample_rate, f.getframerate()
@@ -119,7 +133,8 @@ def main():
119 end_time = time.time() 133 end_time = time.time()
120 elapsed_seconds = end_time - start_time 134 elapsed_seconds = end_time - start_time
121 rtf = elapsed_seconds / duration 135 rtf = elapsed_seconds / duration
122 - print(f"num_threads: {num_threads}") 136 + print(f"num_threads: {args.num_threads}")
  137 + print(f"decoding_method: {args.decoding_method}")
123 print(f"Wave duration: {duration:.3f} s") 138 print(f"Wave duration: {duration:.3f} s")
124 print(f"Elapsed time: {elapsed_seconds:.3f} s") 139 print(f"Elapsed time: {elapsed_seconds:.3f} s")
125 print(f"Real time factor (RTF): {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}") 140 print(f"Real time factor (RTF): {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}")
@@ -60,10 +60,10 @@ def get_args():
60 ) 60 )
61 61
62 parser.add_argument( 62 parser.add_argument(
63 - "--wave-filename", 63 + "--decoding-method",
64 type=str, 64 type=str,
65 - help="""Path to the wave filename. Must be 16 kHz,  
66 - mono with 16-bit samples""", 65 + default="greedy_search",
  66 + help="Valid values are greedy_search and modified_beam_search",
67 ) 67 )
68 68
69 return parser.parse_args() 69 return parser.parse_args()
@@ -83,17 +83,23 @@ def create_recognizer():
83 encoder=args.encoder, 83 encoder=args.encoder,
84 decoder=args.decoder, 84 decoder=args.decoder,
85 joiner=args.joiner, 85 joiner=args.joiner,
  86 + num_threads=1,
  87 + sample_rate=16000,
  88 + feature_dim=80,
86 enable_endpoint_detection=True, 89 enable_endpoint_detection=True,
87 rule1_min_trailing_silence=2.4, 90 rule1_min_trailing_silence=2.4,
88 rule2_min_trailing_silence=1.2, 91 rule2_min_trailing_silence=1.2,
89 rule3_min_utterance_length=300, # it essentially disables this rule 92 rule3_min_utterance_length=300, # it essentially disables this rule
  93 + decoding_method=args.decoding_method,
  94 + max_feature_vectors=100, # 1 second
90 ) 95 )
91 return recognizer 96 return recognizer
92 97
93 98
94 def main(): 99 def main():
95 - print("Started! Please speak")  
96 recognizer = create_recognizer() 100 recognizer = create_recognizer()
  101 + print("Started! Please speak")
  102 +
97 sample_rate = 16000 103 sample_rate = 16000
98 samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms 104 samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms
99 last_result = "" 105 last_result = ""
@@ -101,6 +107,7 @@ def main():
101 107
102 last_result = "" 108 last_result = ""
103 segment_id = 0 109 segment_id = 0
  110 + display = sherpa_onnx.Display(max_word_per_line=30)
104 with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s: 111 with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s:
105 while True: 112 while True:
106 samples, _ = s.read(samples_per_read) # a blocking read 113 samples, _ = s.read(samples_per_read) # a blocking read
@@ -115,7 +122,7 @@ def main():
115 122
116 if result and (last_result != result): 123 if result and (last_result != result):
117 last_result = result 124 last_result = result
118 - print(f"{segment_id}: {result}") 125 + display.print(segment_id, result)
119 126
120 if is_endpoint: 127 if is_endpoint:
121 if result: 128 if result:
@@ -59,10 +59,10 @@ def get_args():
59 ) 59 )
60 60
61 parser.add_argument( 61 parser.add_argument(
62 - "--wave-filename", 62 + "--decoding-method",
63 type=str, 63 type=str,
64 - help="""Path to the wave filename. Must be 16 kHz,  
65 - mono with 16-bit samples""", 64 + default="greedy_search",
  65 + help="Valid values are greedy_search and modified_beam_search",
66 ) 66 )
67 67
68 return parser.parse_args() 68 return parser.parse_args()
@@ -82,9 +82,11 @@ def create_recognizer():
82 encoder=args.encoder, 82 encoder=args.encoder,
83 decoder=args.decoder, 83 decoder=args.decoder,
84 joiner=args.joiner, 84 joiner=args.joiner,
85 - num_threads=4, 85 + num_threads=1,
86 sample_rate=16000, 86 sample_rate=16000,
87 feature_dim=80, 87 feature_dim=80,
  88 + decoding_method=args.decoding_method,
  89 + max_feature_vectors=100, # 1 second
88 ) 90 )
89 return recognizer 91 return recognizer
90 92
@@ -96,6 +98,7 @@ def main():
96 samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms 98 samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms
97 last_result = "" 99 last_result = ""
98 stream = recognizer.create_stream() 100 stream = recognizer.create_stream()
  101 + display = sherpa_onnx.Display(max_word_per_line=40)
99 with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s: 102 with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s:
100 while True: 103 while True:
101 samples, _ = s.read(samples_per_read) # a blocking read 104 samples, _ = s.read(samples_per_read) # a blocking read
@@ -106,7 +109,7 @@ def main():
106 result = recognizer.get_result(stream) 109 result = recognizer.get_result(stream)
107 if last_result != result: 110 if last_result != result:
108 last_result = result 111 last_result = result
109 - print(result) 112 + display.print(-1, result)
110 113
111 114
112 if __name__ == "__main__": 115 if __name__ == "__main__":
@@ -9,6 +9,7 @@
9 #include <utility> 9 #include <utility>
10 #include <vector> 10 #include <vector>
11 11
  12 +#include "sherpa-onnx/csrc/display.h"
12 #include "sherpa-onnx/csrc/online-recognizer.h" 13 #include "sherpa-onnx/csrc/online-recognizer.h"
13 14
14 struct SherpaOnnxOnlineRecognizer { 15 struct SherpaOnnxOnlineRecognizer {
@@ -21,6 +22,10 @@ struct SherpaOnnxOnlineStream {
21 : impl(std::move(p)) {} 22 : impl(std::move(p)) {}
22 }; 23 };
23 24
  25 +struct SherpaOnnxDisplay {
  26 + std::unique_ptr<sherpa_onnx::Display> impl;
  27 +};
  28 +
24 SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer( 29 SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer(
25 const SherpaOnnxOnlineRecognizerConfig *config) { 30 const SherpaOnnxOnlineRecognizerConfig *config) {
26 sherpa_onnx::OnlineRecognizerConfig recognizer_config; 31 sherpa_onnx::OnlineRecognizerConfig recognizer_config;
@@ -37,6 +42,9 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer(
37 recognizer_config.model_config.num_threads = config->model_config.num_threads; 42 recognizer_config.model_config.num_threads = config->model_config.num_threads;
38 recognizer_config.model_config.debug = config->model_config.debug; 43 recognizer_config.model_config.debug = config->model_config.debug;
39 44
  45 + recognizer_config.decoding_method = config->decoding_method;
  46 + recognizer_config.max_active_paths = config->max_active_paths;
  47 +
40 recognizer_config.enable_endpoint = config->enable_endpoint; 48 recognizer_config.enable_endpoint = config->enable_endpoint;
41 49
42 recognizer_config.endpoint_config.rule1.min_trailing_silence = 50 recognizer_config.endpoint_config.rule1.min_trailing_silence =
@@ -124,3 +132,15 @@ int32_t IsEndpoint(SherpaOnnxOnlineRecognizer *recognizer,
124 SherpaOnnxOnlineStream *stream) { 132 SherpaOnnxOnlineStream *stream) {
125 return recognizer->impl->IsEndpoint(stream->impl.get()); 133 return recognizer->impl->IsEndpoint(stream->impl.get());
126 } 134 }
  135 +
  136 +SherpaOnnxDisplay *CreateDisplay(int32_t max_word_per_line) {
  137 + SherpaOnnxDisplay *ans = new SherpaOnnxDisplay;
  138 + ans->impl = std::make_unique<sherpa_onnx::Display>(max_word_per_line);
  139 + return ans;
  140 +}
  141 +
  142 +void DestroyDisplay(SherpaOnnxDisplay *display) { delete display; }
  143 +
  144 +void SherpaOnnxPrint(SherpaOnnxDisplay *display, int32_t idx, const char *s) {
  145 + display->impl->Print(idx, s);
  146 +}
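The three wrappers added above expose the C++ sherpa_onnx::Display helper to C callers. A minimal sketch of the intended lifecycle, using only functions introduced by this commit; the strings are placeholders for what a real caller would take from GetOnlineStreamResult():

#include "sherpa-onnx/c-api/c-api.h"

int main() {
  SherpaOnnxDisplay *display = CreateDisplay(50);  // at most ~50 words per line

  // On Linux/macOS, repeated prints for the same segment redraw the current
  // line(s) in place; a new segment id starts a fresh line.
  SherpaOnnxPrint(display, 0, "hello");
  SherpaOnnxPrint(display, 0, "hello world");
  SherpaOnnxPrint(display, 1, "next utterance");

  DestroyDisplay(display);
  return 0;
}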
@@ -48,6 +48,13 @@ typedef struct SherpaOnnxOnlineRecognizerConfig {
48 SherpaOnnxFeatureConfig feat_config; 48 SherpaOnnxFeatureConfig feat_config;
49 SherpaOnnxOnlineTransducerModelConfig model_config; 49 SherpaOnnxOnlineTransducerModelConfig model_config;
50 50
  51 + /// Possible values are: greedy_search, modified_beam_search
  52 + const char *decoding_method;
  53 +
  54 + /// Used only when decoding_method is modified_beam_search
  55 + /// Example value: 4
  56 + int32_t max_active_paths;
  57 +
51 /// 0 to disable endpoint detection. 58 /// 0 to disable endpoint detection.
52 /// A non-zero value to enable endpoint detection. 59 /// A non-zero value to enable endpoint detection.
53 int32_t enable_endpoint; 60 int32_t enable_endpoint;
@@ -187,6 +194,18 @@ void InputFinished(SherpaOnnxOnlineStream *stream);
187 int32_t IsEndpoint(SherpaOnnxOnlineRecognizer *recognizer, 194 int32_t IsEndpoint(SherpaOnnxOnlineRecognizer *recognizer,
188 SherpaOnnxOnlineStream *stream); 195 SherpaOnnxOnlineStream *stream);
189 196
  197 +// for displaying results on Linux/macOS.
  198 +typedef struct SherpaOnnxDisplay SherpaOnnxDisplay;
  199 +
  200 +/// Create a display object. Must be freed using DestroyDisplay to avoid
  201 +/// memory leak.
  202 +SherpaOnnxDisplay *CreateDisplay(int32_t max_word_per_line);
  203 +
  204 +void DestroyDisplay(SherpaOnnxDisplay *display);
  205 +
  206 +/// Print the result.
  207 +void SherpaOnnxPrint(SherpaOnnxDisplay *display, int32_t idx, const char *s);
  208 +
190 #ifdef __cplusplus 209 #ifdef __cplusplus
191 } /* extern "C" */ 210 } /* extern "C" */
192 #endif 211 #endif
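With the two new config fields declared above, switching the C API to beam search is a small change on the caller side. Below is a hedged sketch: CreateBeamSearchRecognizer is a hypothetical helper (not part of this commit), zero-initializing the config and leaving endpointing disabled is an assumption made for brevity, and every field and function name comes from this header:

#include <string.h>

#include "sherpa-onnx/c-api/c-api.h"

// Hypothetical helper: apart from decoding_method and max_active_paths,
// the fields mirror the C API example earlier in this commit.
static SherpaOnnxOnlineRecognizer *CreateBeamSearchRecognizer(
    const char *tokens, const char *encoder, const char *decoder,
    const char *joiner) {
  SherpaOnnxOnlineRecognizerConfig config;
  memset(&config, 0, sizeof(config));  // endpointing left disabled in this sketch

  config.model_config.tokens = tokens;
  config.model_config.encoder = encoder;
  config.model_config.decoder = decoder;
  config.model_config.joiner = joiner;
  config.model_config.num_threads = 1;

  config.feat_config.sample_rate = 16000;
  config.feat_config.feature_dim = 80;

  config.decoding_method = "modified_beam_search";  // default is greedy_search
  config.max_active_paths = 4;  // beam size; used only by modified_beam_search

  return CreateOnlineRecognizer(&config);
}

A caller would pair this with CreateOnlineStream(), the usual AcceptWaveform()/DecodeOnlineStream() loop, and the matching Destroy calls, exactly as in that example.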
@@ -9,10 +9,11 @@ set(sources
9 online-lstm-transducer-model.cc 9 online-lstm-transducer-model.cc
10 online-recognizer.cc 10 online-recognizer.cc
11 online-stream.cc 11 online-stream.cc
  12 + online-transducer-decoder.cc
12 online-transducer-greedy-search-decoder.cc 13 online-transducer-greedy-search-decoder.cc
13 online-transducer-model-config.cc 14 online-transducer-model-config.cc
14 - online-transducer-modified-beam-search-decoder.cc  
15 online-transducer-model.cc 15 online-transducer-model.cc
  16 + online-transducer-modified-beam-search-decoder.cc
16 online-zipformer-transducer-model.cc 17 online-zipformer-transducer-model.cc
17 onnx-utils.cc 18 onnx-utils.cc
18 parse-options.cc 19 parse-options.cc
@@ -12,9 +12,16 @@ namespace sherpa_onnx {
12 12
13 class Display { 13 class Display {
14 public: 14 public:
  15 + explicit Display(int32_t max_word_per_line = 60)
  16 + : max_word_per_line_(max_word_per_line) {}
  17 +
15 void Print(int32_t segment_id, const std::string &s) { 18 void Print(int32_t segment_id, const std::string &s) {
16 #ifdef _MSC_VER 19 #ifdef _MSC_VER
17 - fprintf(stderr, "%d:%s\n", segment_id, s.c_str()); 20 + if (segment_id != -1) {
  21 + fprintf(stderr, "%d:%s\n", segment_id, s.c_str());
  22 + } else {
  23 + fprintf(stderr, "%s\n", s.c_str());
  24 + }
18 return; 25 return;
19 #endif 26 #endif
20 if (last_segment_ == segment_id) { 27 if (last_segment_ == segment_id) {
@@ -27,7 +34,9 @@ class Display {
27 num_previous_lines_ = 0; 34 num_previous_lines_ = 0;
28 } 35 }
29 36
30 - fprintf(stderr, "\r%d:", segment_id); 37 + if (segment_id != -1) {
  38 + fprintf(stderr, "\r%d:", segment_id);
  39 + }
31 40
32 int32_t i = 0; 41 int32_t i = 0;
33 for (size_t n = 0; n < s.size();) { 42 for (size_t n = 0; n < s.size();) {
@@ -69,7 +78,7 @@ class Display {
69 void GoUpOneLine() const { fprintf(stderr, "\033[1A\r"); } 78 void GoUpOneLine() const { fprintf(stderr, "\033[1A\r"); }
70 79
71 private: 80 private:
72 - int32_t max_word_per_line_ = 60; 81 + int32_t max_word_per_line_;
73 int32_t num_previous_lines_ = 0; 82 int32_t num_previous_lines_ = 0;
74 int32_t last_segment_ = -1; 83 int32_t last_segment_ = -1;
75 }; 84 };
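For code that uses sherpa_onnx::Display directly from C++, the new constructor argument and the segment_id == -1 convention look like this; a minimal sketch using only what the header above declares:

#include "sherpa-onnx/csrc/display.h"

int main() {
  // Wrap after at most 30 words per line, the value the endpointing
  // microphone example now passes.
  sherpa_onnx::Display display(30);

  display.Print(0, "first segment");   // printed with a "0:" prefix
  display.Print(-1, "no segment id");  // -1 suppresses the prefix
  return 0;
}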
@@ -28,7 +28,8 @@ std::string FeatureExtractorConfig::ToString() const {
28 28
29 os << "FeatureExtractorConfig("; 29 os << "FeatureExtractorConfig(";
30 os << "sampling_rate=" << sampling_rate << ", "; 30 os << "sampling_rate=" << sampling_rate << ", ";
31 - os << "feature_dim=" << feature_dim << ")"; 31 + os << "feature_dim=" << feature_dim << ", ";
  32 + os << "max_feature_vectors=" << max_feature_vectors << ")";
32 33
33 return os.str(); 34 return os.str();
34 } 35 }
@@ -40,9 +41,7 @@ class FeatureExtractor::Impl {
40 opts_.frame_opts.snip_edges = false; 41 opts_.frame_opts.snip_edges = false;
41 opts_.frame_opts.samp_freq = config.sampling_rate; 42 opts_.frame_opts.samp_freq = config.sampling_rate;
42 43
43 - // cache 100 seconds of feature frames, which is more than enough  
44 - // for real needs  
45 - opts_.frame_opts.max_feature_vectors = 100 * 100; 44 + opts_.frame_opts.max_feature_vectors = config.max_feature_vectors;
46 45
47 opts_.mel_opts.num_bins = config.feature_dim; 46 opts_.mel_opts.num_bins = config.feature_dim;
48 47
@@ -16,6 +16,7 @@ namespace sherpa_onnx {
16 struct FeatureExtractorConfig { 16 struct FeatureExtractorConfig {
17 float sampling_rate = 16000; 17 float sampling_rate = 16000;
18 int32_t feature_dim = 80; 18 int32_t feature_dim = 80;
  19 + int32_t max_feature_vectors = -1;
19 20
20 std::string ToString() const; 21 std::string ToString() const;
21 22
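The new max_feature_vectors field caps how many feature frames the extractor keeps. At the default 10 ms frame shift, the value 100 used by the updated microphone examples corresponds to roughly 1 second of cached features, while the default of -1 places no limit. A small sketch (printf and the 10 ms figure are the only things not taken from the diff):

#include <cstdio>

#include "sherpa-onnx/csrc/features.h"

int main() {
  sherpa_onnx::FeatureExtractorConfig config;
  config.sampling_rate = 16000;
  config.feature_dim = 80;
  config.max_feature_vectors = 100;  // ~1 second at a 10 ms frame shift

  // The ToString() change above makes the new field visible here.
  printf("%s\n", config.ToString().c_str());
  return 0;
}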
@@ -18,7 +18,7 @@ namespace sherpa_onnx {
18 18
19 struct Hypothesis { 19 struct Hypothesis {
20 // The predicted tokens so far. Newly predicated tokens are appended. 20 // The predicted tokens so far. Newly predicated tokens are appended.
21 - std::vector<int32_t> ys; 21 + std::vector<int64_t> ys;
22 22
23 // timestamps[i] contains the frame number after subsampling 23 // timestamps[i] contains the frame number after subsampling
24 // on which ys[i] is decoded. 24 // on which ys[i] is decoded.
@@ -30,7 +30,7 @@ struct Hypothesis {
30 int32_t num_trailing_blanks = 0; 30 int32_t num_trailing_blanks = 0;
31 31
32 Hypothesis() = default; 32 Hypothesis() = default;
33 - Hypothesis(const std::vector<int32_t> &ys, double log_prob) 33 + Hypothesis(const std::vector<int64_t> &ys, double log_prob)
34 : ys(ys), log_prob(log_prob) {} 34 : ys(ys), log_prob(log_prob) {}
35 35
36 // If two Hypotheses have the same `Key`, then they contain 36 // If two Hypotheses have the same `Key`, then they contain
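Widening Hypothesis::ys from int32_t to int64_t matches the 64-bit integer element type of the onnxruntime tensors that feed the transducer's decoder network, presumably so the token history can be copied into those tensors without a per-element conversion.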
@@ -43,7 +43,7 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
43 "True to enable endpoint detection. False to disable it."); 43 "True to enable endpoint detection. False to disable it.");
44 po->Register("max-active-paths", &max_active_paths, 44 po->Register("max-active-paths", &max_active_paths,
45 "beam size used in modified beam search."); 45 "beam size used in modified beam search.");
46 - po->Register("decoding-mothod", &decoding_method, 46 + po->Register("decoding-method", &decoding_method,
47 "decoding method," 47 "decoding method,"
48 "now support greedy_search and modified_beam_search."); 48 "now support greedy_search and modified_beam_search.");
49 } 49 }
@@ -59,8 +59,8 @@ std::string OnlineRecognizerConfig::ToString() const {
59 os << "feat_config=" << feat_config.ToString() << ", "; 59 os << "feat_config=" << feat_config.ToString() << ", ";
60 os << "model_config=" << model_config.ToString() << ", "; 60 os << "model_config=" << model_config.ToString() << ", ";
61 os << "endpoint_config=" << endpoint_config.ToString() << ", "; 61 os << "endpoint_config=" << endpoint_config.ToString() << ", ";
62 - os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ",";  
63 - os << "max_active_paths=" << max_active_paths << ","; 62 + os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", ";
  63 + os << "max_active_paths=" << max_active_paths << ", ";
64 os << "decoding_method=\"" << decoding_method << "\")"; 64 os << "decoding_method=\"" << decoding_method << "\")";
65 65
66 return os.str(); 66 return os.str();
@@ -187,16 +187,14 @@ class OnlineRecognizer::Impl {
187 } 187 }
188 188
189 void Reset(OnlineStream *s) const { 189 void Reset(OnlineStream *s) const {
190 - // reset result, neural network model state, and  
191 - // the feature extractor state  
192 -  
193 - // reset result 190 + // we keep the decoder_out
  191 + decoder_->UpdateDecoderOut(&s->GetResult());
  192 + Ort::Value decoder_out = std::move(s->GetResult().decoder_out);
194 s->SetResult(decoder_->GetEmptyResult()); 193 s->SetResult(decoder_->GetEmptyResult());
  194 + s->GetResult().decoder_out = std::move(decoder_out);
195 195
196 - // reset neural network model state  
197 - s->SetStates(model_->GetEncoderInitStates());  
198 -  
199 - // reset feature extractor 196 + // Note: We only update counters. The underlying audio samples
  197 + // are not discarded.
200 s->Reset(); 198 s->Reset();
201 } 199 }
202 200
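Net effect of the Reset() change above: after an endpoint the recognizer now swaps in an empty result while carrying the cached decoder_out across the swap, leaves the encoder states untouched (the SetStates call is gone), and only asks the stream to rewind its frame counters. Previously it also re-created the encoder's initial states and cleared the feature extractor, which is what the first bullet of the commit message refers to.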
@@ -33,21 +33,26 @@ struct OnlineRecognizerConfig {
33 OnlineTransducerModelConfig model_config; 33 OnlineTransducerModelConfig model_config;
34 EndpointConfig endpoint_config; 34 EndpointConfig endpoint_config;
35 bool enable_endpoint = true; 35 bool enable_endpoint = true;
36 - int32_t max_active_paths = 4;  
37 36
38 - std::string decoding_method = "modified_beam_search"; 37 + std::string decoding_method = "greedy_search";
39 // now support modified_beam_search and greedy_search 38 // now support modified_beam_search and greedy_search
40 39
  40 + int32_t max_active_paths = 4; // used only for modified_beam_search
  41 +
41 OnlineRecognizerConfig() = default; 42 OnlineRecognizerConfig() = default;
42 43
43 OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config, 44 OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config,
44 const OnlineTransducerModelConfig &model_config, 45 const OnlineTransducerModelConfig &model_config,
45 const EndpointConfig &endpoint_config, 46 const EndpointConfig &endpoint_config,
46 - bool enable_endpoint) 47 + bool enable_endpoint,
  48 + const std::string &decoding_method,
  49 + int32_t max_active_paths)
47 : feat_config(feat_config), 50 : feat_config(feat_config),
48 model_config(model_config), 51 model_config(model_config),
49 endpoint_config(endpoint_config), 52 endpoint_config(endpoint_config),
50 - enable_endpoint(enable_endpoint) {} 53 + enable_endpoint(enable_endpoint),
  54 + decoding_method(decoding_method),
  55 + max_active_paths(max_active_paths) {}
51 56
52 void Register(ParseOptions *po); 57 void Register(ParseOptions *po);
53 bool Validate() const; 58 bool Validate() const;
@@ -22,18 +22,21 @@ class OnlineStream::Impl {
22 22
23 void InputFinished() { feat_extractor_.InputFinished(); } 23 void InputFinished() { feat_extractor_.InputFinished(); }
24 24
25 - int32_t NumFramesReady() const { return feat_extractor_.NumFramesReady(); } 25 + int32_t NumFramesReady() const {
  26 + return feat_extractor_.NumFramesReady() - start_frame_index_;
  27 + }
26 28
27 bool IsLastFrame(int32_t frame) const { 29 bool IsLastFrame(int32_t frame) const {
28 return feat_extractor_.IsLastFrame(frame); 30 return feat_extractor_.IsLastFrame(frame);
29 } 31 }
30 32
31 std::vector<float> GetFrames(int32_t frame_index, int32_t n) const { 33 std::vector<float> GetFrames(int32_t frame_index, int32_t n) const {
32 - return feat_extractor_.GetFrames(frame_index, n); 34 + return feat_extractor_.GetFrames(frame_index + start_frame_index_, n);
33 } 35 }
34 36
35 void Reset() { 37 void Reset() {
36 - feat_extractor_.Reset(); 38 + // we don't reset the feature extractor
  39 + start_frame_index_ += num_processed_frames_;
37 num_processed_frames_ = 0; 40 num_processed_frames_ = 0;
38 } 41 }
39 42
@@ -41,7 +44,7 @@ class OnlineStream::Impl {
41 44
42 void SetResult(const OnlineTransducerDecoderResult &r) { result_ = r; } 45 void SetResult(const OnlineTransducerDecoderResult &r) { result_ = r; }
43 46
44 - const OnlineTransducerDecoderResult &GetResult() const { return result_; } 47 + OnlineTransducerDecoderResult &GetResult() { return result_; }
45 48
46 int32_t FeatureDim() const { return feat_extractor_.FeatureDim(); } 49 int32_t FeatureDim() const { return feat_extractor_.FeatureDim(); }
47 50
@@ -54,6 +57,7 @@ class OnlineStream::Impl {
54 private: 57 private:
55 FeatureExtractor feat_extractor_; 58 FeatureExtractor feat_extractor_;
56 int32_t num_processed_frames_ = 0; // before subsampling 59 int32_t num_processed_frames_ = 0; // before subsampling
  60 + int32_t start_frame_index_ = 0; // never reset
57 OnlineTransducerDecoderResult result_; 61 OnlineTransducerDecoderResult result_;
58 std::vector<Ort::Value> states_; 62 std::vector<Ort::Value> states_;
59 }; 63 };
@@ -93,7 +97,7 @@ void OnlineStream::SetResult(const OnlineTransducerDecoderResult &r) {
93 impl_->SetResult(r); 97 impl_->SetResult(r);
94 } 98 }
95 99
96 -const OnlineTransducerDecoderResult &OnlineStream::GetResult() const { 100 +OnlineTransducerDecoderResult &OnlineStream::GetResult() {
97 return impl_->GetResult(); 101 return impl_->GetResult();
98 } 102 }
99 103
@@ -63,7 +63,7 @@ class OnlineStream {
63 int32_t &GetNumProcessedFrames(); 63 int32_t &GetNumProcessedFrames();
64 64
65 void SetResult(const OnlineTransducerDecoderResult &r); 65 void SetResult(const OnlineTransducerDecoderResult &r);
66 - const OnlineTransducerDecoderResult &GetResult() const; 66 + OnlineTransducerDecoderResult &GetResult();
67 67
68 void SetStates(std::vector<Ort::Value> states); 68 void SetStates(std::vector<Ort::Value> states);
69 std::vector<Ort::Value> &GetStates(); 69 std::vector<Ort::Value> &GetStates();
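To make the new stream bookkeeping concrete, here is a toy model; it is an illustration only, not the real OnlineStream, but the arithmetic mirrors the NumFramesReady()/GetFrames()/Reset() changes above:

#include <cstdio>

// Illustration only: the feature cache keeps growing, and Reset() just moves
// the window start instead of discarding features or model state.
struct ToyStream {
  int frames_in_extractor = 0;   // frames the extractor has produced so far
  int num_processed_frames = 0;  // frames consumed since the last reset
  int start_frame_index = 0;     // never reset

  int NumFramesReady() const { return frames_in_extractor - start_frame_index; }
  int AbsoluteFrame(int relative) const { return relative + start_frame_index; }

  void Reset() {
    start_frame_index += num_processed_frames;
    num_processed_frames = 0;
  }
};

int main() {
  ToyStream s;
  s.frames_in_extractor = 120;   // ~1.2 s of features so far
  s.num_processed_frames = 100;  // decoding reached frame 100 when an endpoint fired

  s.Reset();
  s.frames_in_extractor = 150;   // more audio arrives after the endpoint

  // Frame 0 of the new segment maps to absolute frame 100 in the extractor,
  // and 50 new frames are ready without recomputing anything.
  printf("ready=%d first_abs=%d\n", s.NumFramesReady(), s.AbsoluteFrame(0));
  return 0;
}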
  1 +// sherpa-onnx/csrc/online-transducer-decoder.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/online-transducer-decoder.h"
  6 +
  7 +#include <utility>
  8 +#include <vector>
  9 +
  10 +#include "onnxruntime_cxx_api.h" // NOLINT
  11 +#include "sherpa-onnx/csrc/onnx-utils.h"
  12 +
  13 +namespace sherpa_onnx {
  14 +
  15 +OnlineTransducerDecoderResult::OnlineTransducerDecoderResult(
  16 + const OnlineTransducerDecoderResult &other)
  17 + : OnlineTransducerDecoderResult() {
  18 + *this = other;
  19 +}
  20 +
  21 +OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=(
  22 + const OnlineTransducerDecoderResult &other) {
  23 + if (this == &other) {
  24 + return *this;
  25 + }
  26 +
  27 + tokens = other.tokens;
  28 + num_trailing_blanks = other.num_trailing_blanks;
  29 +
  30 + Ort::AllocatorWithDefaultOptions allocator;
  31 + if (other.decoder_out) {
  32 + decoder_out = Clone(allocator, &other.decoder_out);
  33 + }
  34 +
  35 + hyps = other.hyps;
  36 +
  37 + return *this;
  38 +}
  39 +
  40 +OnlineTransducerDecoderResult::OnlineTransducerDecoderResult(
  41 + OnlineTransducerDecoderResult &&other)
  42 + : OnlineTransducerDecoderResult() {
  43 + *this = std::move(other);
  44 +}
  45 +
  46 +OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=(
  47 + OnlineTransducerDecoderResult &&other) {
  48 + if (this == &other) {
  49 + return *this;
  50 + }
  51 +
  52 + tokens = std::move(other.tokens);
  53 + num_trailing_blanks = other.num_trailing_blanks;
  54 + decoder_out = std::move(other.decoder_out);
  55 + hyps = std::move(other.hyps);
  56 +
  57 + return *this;
  58 +}
  59 +
  60 +} // namespace sherpa_onnx
@@ -19,8 +19,24 @@ struct OnlineTransducerDecoderResult { @@ -19,8 +19,24 @@ struct OnlineTransducerDecoderResult {
19 /// number of trailing blank frames decoded so far 19 /// number of trailing blank frames decoded so far
20 int32_t num_trailing_blanks = 0; 20 int32_t num_trailing_blanks = 0;
21 21
  22 + // Cache decoder_out for endpointing
  23 + Ort::Value decoder_out;
  24 +
22 // used only in modified beam_search 25 // used only in modified beam_search
23 Hypotheses hyps; 26 Hypotheses hyps;
  27 +
  28 + OnlineTransducerDecoderResult()
  29 + : tokens{}, num_trailing_blanks(0), decoder_out{nullptr}, hyps{} {}
  30 +
  31 + OnlineTransducerDecoderResult(const OnlineTransducerDecoderResult &other);
  32 +
  33 + OnlineTransducerDecoderResult &operator=(
  34 + const OnlineTransducerDecoderResult &other);
  35 +
  36 + OnlineTransducerDecoderResult(OnlineTransducerDecoderResult &&other);
  37 +
  38 + OnlineTransducerDecoderResult &operator=(
  39 + OnlineTransducerDecoderResult &&other);
24 }; 40 };
25 41
26 class OnlineTransducerDecoder { 42 class OnlineTransducerDecoder {
@@ -53,6 +69,9 @@ class OnlineTransducerDecoder { @@ -53,6 +69,9 @@ class OnlineTransducerDecoder {
53 */ 69 */
54 virtual void Decode(Ort::Value encoder_out, 70 virtual void Decode(Ort::Value encoder_out,
55 std::vector<OnlineTransducerDecoderResult> *result) = 0; 71 std::vector<OnlineTransducerDecoderResult> *result) = 0;
  72 +
  73 + // used for endpointing. We need to keep decoder_out after reset
  74 + virtual void UpdateDecoderOut(OnlineTransducerDecoderResult *result) {}
56 }; 75 };
57 76
58 } // namespace sherpa_onnx 77 } // namespace sherpa_onnx
@@ -13,6 +13,43 @@ @@ -13,6 +13,43 @@
13 13
14 namespace sherpa_onnx { 14 namespace sherpa_onnx {
15 15
  16 +static void UseCachedDecoderOut(
  17 + const std::vector<OnlineTransducerDecoderResult> &results,
  18 + Ort::Value *decoder_out) {
  19 + std::vector<int64_t> shape =
  20 + decoder_out->GetTensorTypeAndShapeInfo().GetShape();
  21 + float *dst = decoder_out->GetTensorMutableData<float>();
  22 + for (const auto &r : results) {
  23 + if (r.decoder_out) {
  24 + const float *src = r.decoder_out.GetTensorData<float>();
  25 + std::copy(src, src + shape[1], dst);
  26 + }
  27 + dst += shape[1];
  28 + }
  29 +}
  30 +
  31 +static void UpdateCachedDecoderOut(
  32 + OrtAllocator *allocator, const Ort::Value *decoder_out,
  33 + std::vector<OnlineTransducerDecoderResult> *results) {
  34 + std::vector<int64_t> shape =
  35 + decoder_out->GetTensorTypeAndShapeInfo().GetShape();
  36 + auto memory_info =
  37 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  38 + std::array<int64_t, 2> v_shape{1, shape[1]};
  39 +
  40 + const float *src = decoder_out->GetTensorData<float>();
  41 + for (auto &r : *results) {
  42 + if (!r.decoder_out) {
  43 + r.decoder_out = Ort::Value::CreateTensor<float>(allocator, v_shape.data(),
  44 + v_shape.size());
  45 + }
  46 +
  47 + float *dst = r.decoder_out.GetTensorMutableData<float>();
  48 + std::copy(src, src + shape[1], dst);
  49 + src += shape[1];
  50 + }
  51 +}
  52 +
16 OnlineTransducerDecoderResult 53 OnlineTransducerDecoderResult
17 OnlineTransducerGreedySearchDecoder::GetEmptyResult() const { 54 OnlineTransducerGreedySearchDecoder::GetEmptyResult() const {
18 int32_t context_size = model_->ContextSize(); 55 int32_t context_size = model_->ContextSize();
@@ -53,6 +90,7 @@ void OnlineTransducerGreedySearchDecoder::Decode( @@ -53,6 +90,7 @@ void OnlineTransducerGreedySearchDecoder::Decode(
53 90
54 Ort::Value decoder_input = model_->BuildDecoderInput(*result); 91 Ort::Value decoder_input = model_->BuildDecoderInput(*result);
55 Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input)); 92 Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
  93 + UseCachedDecoderOut(*result, &decoder_out);
56 94
57 for (int32_t t = 0; t != num_frames; ++t) { 95 for (int32_t t = 0; t != num_frames; ++t) {
58 Ort::Value cur_encoder_out = 96 Ort::Value cur_encoder_out =
@@ -77,10 +115,12 @@ void OnlineTransducerGreedySearchDecoder::Decode( @@ -77,10 +115,12 @@ void OnlineTransducerGreedySearchDecoder::Decode(
77 } 115 }
78 } 116 }
79 if (emitted) { 117 if (emitted) {
80 - decoder_input = model_->BuildDecoderInput(*result); 118 + Ort::Value decoder_input = model_->BuildDecoderInput(*result);
81 decoder_out = model_->RunDecoder(std::move(decoder_input)); 119 decoder_out = model_->RunDecoder(std::move(decoder_input));
82 } 120 }
83 } 121 }
  122 +
  123 + UpdateCachedDecoderOut(model_->Allocator(), &decoder_out, result);
84 } 124 }
85 125
86 } // namespace sherpa_onnx 126 } // namespace sherpa_onnx
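
The two new helpers above are the core of the endpointing change for greedy search: UseCachedDecoderOut restores each stream's cached decoder output into the freshly computed batched decoder_out before the time loop, and UpdateCachedDecoderOut copies the final rows back into every result so they survive an endpoint reset. A rough numpy sketch of that data flow (the function names, the Result stand-in, and the numpy representation are illustrative only, not the actual Ort::Value code):

    import numpy as np

    class Result:
        """Stand-in for OnlineTransducerDecoderResult."""
        def __init__(self):
            self.decoder_out = None  # cached (dim,) row, or None if nothing cached yet

    def use_cached_decoder_out(results, decoder_out):
        # decoder_out: (batch_size, dim). Overwrite row i with the row cached by
        # stream i, if that stream still carries state from before a reset.
        for i, r in enumerate(results):
            if r.decoder_out is not None:
                decoder_out[i, :] = r.decoder_out

    def update_cached_decoder_out(decoder_out, results):
        # Copy each stream's final row out of the batched tensor so it can be
        # restored by use_cached_decoder_out() on the next Decode() call.
        for i, r in enumerate(results):
            r.decoder_out = decoder_out[i, :].copy()
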
@@ -13,6 +13,29 @@ @@ -13,6 +13,29 @@
13 13
14 namespace sherpa_onnx { 14 namespace sherpa_onnx {
15 15
  16 +static void UseCachedDecoderOut(
  17 + const std::vector<int32_t> &hyps_num_split,
  18 + const std::vector<OnlineTransducerDecoderResult> &results,
  19 + int32_t context_size, Ort::Value *decoder_out) {
  20 + std::vector<int64_t> shape =
  21 + decoder_out->GetTensorTypeAndShapeInfo().GetShape();
  22 +
  23 + float *dst = decoder_out->GetTensorMutableData<float>();
  24 +
  25 + int32_t batch_size = static_cast<int32_t>(results.size());
  26 + for (int32_t i = 0; i != batch_size; ++i) {
  27 + int32_t num_hyps = hyps_num_split[i + 1] - hyps_num_split[i];
  28 + if (num_hyps > 1 || !results[i].decoder_out) {
  29 + dst += num_hyps * shape[1];
  30 + continue;
  31 + }
  32 +
  33 + const float *src = results[i].decoder_out.GetTensorData<float>();
  34 + std::copy(src, src + shape[1], dst);
  35 + dst += shape[1];
  36 + }
  37 +}
  38 +
16 static Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out, 39 static Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out,
17 const std::vector<int32_t> &hyps_num_split) { 40 const std::vector<int32_t> &hyps_num_split) {
18 std::vector<int64_t> cur_encoder_out_shape = 41 std::vector<int64_t> cur_encoder_out_shape =
@@ -50,7 +73,7 @@ OnlineTransducerModifiedBeamSearchDecoder::GetEmptyResult() const { @@ -50,7 +73,7 @@ OnlineTransducerModifiedBeamSearchDecoder::GetEmptyResult() const {
50 int32_t context_size = model_->ContextSize(); 73 int32_t context_size = model_->ContextSize();
51 int32_t blank_id = 0; // always 0 74 int32_t blank_id = 0; // always 0
52 OnlineTransducerDecoderResult r; 75 OnlineTransducerDecoderResult r;
53 - std::vector<int32_t> blanks(context_size, blank_id); 76 + std::vector<int64_t> blanks(context_size, blank_id);
54 Hypotheses blank_hyp({{blanks, 0}}); 77 Hypotheses blank_hyp({{blanks, 0}});
55 r.hyps = std::move(blank_hyp); 78 r.hyps = std::move(blank_hyp);
56 return r; 79 return r;
@@ -110,6 +133,10 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( @@ -110,6 +133,10 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
110 133
111 Ort::Value decoder_input = model_->BuildDecoderInput(prev); 134 Ort::Value decoder_input = model_->BuildDecoderInput(prev);
112 Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input)); 135 Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
  136 + if (t == 0) {
  137 + UseCachedDecoderOut(hyps_num_split, *result, model_->ContextSize(),
  138 + &decoder_out);
  139 + }
113 140
114 Ort::Value cur_encoder_out = 141 Ort::Value cur_encoder_out =
115 GetEncoderOutFrame(model_->Allocator(), &encoder_out, t); 142 GetEncoderOutFrame(model_->Allocator(), &encoder_out, t);
@@ -147,8 +174,23 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( @@ -147,8 +174,23 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
147 } 174 }
148 175
149 for (int32_t b = 0; b != batch_size; ++b) { 176 for (int32_t b = 0; b != batch_size; ++b) {
150 - (*result)[b].hyps = std::move(cur[b]); 177 + auto &hyps = cur[b];
  178 + auto best_hyp = hyps.GetMostProbable(true);
  179 +
  180 + (*result)[b].hyps = std::move(hyps);
  181 + (*result)[b].tokens = std::move(best_hyp.ys);
  182 + (*result)[b].num_trailing_blanks = best_hyp.num_trailing_blanks;
  183 + }
  184 +}
  185 +
  186 +void OnlineTransducerModifiedBeamSearchDecoder::UpdateDecoderOut(
  187 + OnlineTransducerDecoderResult *result) {
  188 + if (result->tokens.size() == model_->ContextSize()) {
  189 + result->decoder_out = Ort::Value{nullptr};
  190 + return;
151 } 191 }
  192 + Ort::Value decoder_input = model_->BuildDecoderInput({*result});
  193 + result->decoder_out = model_->RunDecoder(std::move(decoder_input));
152 } 194 }
153 195
154 } // namespace sherpa_onnx 196 } // namespace sherpa_onnx
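
The end of Decode() now refreshes the visible result after every chunk instead of only when a token is emitted: it picks the most probable hypothesis and copies its token sequence and trailing-blank count into the result, which is what the endpoint rules and the display read. A small Python sketch of that selection step, with the hypothesis set modelled as a plain dict (illustrative only, not the Hypotheses class):

    def finalize_results(cur, results):
        # cur[b]: dict mapping a token tuple -> (log_prob, num_trailing_blanks)
        for b, hyps in enumerate(cur):
            best_tokens, (_, num_trailing_blanks) = max(
                hyps.items(), key=lambda kv: kv[1][0]
            )
            results[b].hyps = hyps
            results[b].tokens = list(best_tokens)
            results[b].num_trailing_blanks = num_trailing_blanks
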
@@ -27,6 +27,8 @@ class OnlineTransducerModifiedBeamSearchDecoder @@ -27,6 +27,8 @@ class OnlineTransducerModifiedBeamSearchDecoder
27 void Decode(Ort::Value encoder_out, 27 void Decode(Ort::Value encoder_out,
28 std::vector<OnlineTransducerDecoderResult> *result) override; 28 std::vector<OnlineTransducerDecoderResult> *result) override;
29 29
  30 + void UpdateDecoderOut(OnlineTransducerDecoderResult *result) override;
  31 +
30 private: 32 private:
31 OnlineTransducerModel *model_; // Not owned 33 OnlineTransducerModel *model_; // Not owned
32 int32_t max_active_paths_; 34 int32_t max_active_paths_;
@@ -21,7 +21,7 @@ static void Handler(int sig) { @@ -21,7 +21,7 @@ static void Handler(int sig) {
21 } 21 }
22 22
23 int main(int32_t argc, char *argv[]) { 23 int main(int32_t argc, char *argv[]) {
24 - if (argc < 6 || argc > 7) { 24 + if (argc < 6 || argc > 8) {
25 const char *usage = R"usage( 25 const char *usage = R"usage(
26 Usage: 26 Usage:
27 ./bin/sherpa-onnx-alsa \ 27 ./bin/sherpa-onnx-alsa \
@@ -30,7 +30,10 @@ Usage: @@ -30,7 +30,10 @@ Usage:
30 /path/to/decoder.onnx \ 30 /path/to/decoder.onnx \
31 /path/to/joiner.onnx \ 31 /path/to/joiner.onnx \
32 device_name \ 32 device_name \
33 - [num_threads] 33 + [num_threads [decoding_method]]
  34 +
  35 +Default value for num_threads is 2.
  36 +Valid values for decoding_method: greedy_search (default), modified_beam_search.
34 37
35 Please refer to 38 Please refer to
36 https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html 39 https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
@@ -79,6 +82,11 @@ as the device_name. @@ -79,6 +82,11 @@ as the device_name.
79 config.model_config.num_threads = atoi(argv[6]); 82 config.model_config.num_threads = atoi(argv[6]);
80 } 83 }
81 84
  85 + if (argc == 8) {
  86 + config.decoding_method = argv[7];
  87 + }
  88 + config.max_active_paths = 4;
  89 +
82 config.enable_endpoint = true; 90 config.enable_endpoint = true;
83 91
84 config.endpoint_config.rule1.min_trailing_silence = 2.4; 92 config.endpoint_config.rule1.min_trailing_silence = 2.4;
@@ -36,7 +36,7 @@ static void Handler(int32_t sig) { @@ -36,7 +36,7 @@ static void Handler(int32_t sig) {
36 } 36 }
37 37
38 int32_t main(int32_t argc, char *argv[]) { 38 int32_t main(int32_t argc, char *argv[]) {
39 - if (argc < 5 || argc > 6) { 39 + if (argc < 5 || argc > 7) {
40 const char *usage = R"usage( 40 const char *usage = R"usage(
41 Usage: 41 Usage:
42 ./bin/sherpa-onnx-microphone \ 42 ./bin/sherpa-onnx-microphone \
@@ -44,7 +44,10 @@ Usage: @@ -44,7 +44,10 @@ Usage:
44 /path/to/encoder.onnx\ 44 /path/to/encoder.onnx\
45 /path/to/decoder.onnx\ 45 /path/to/decoder.onnx\
46 /path/to/joiner.onnx\ 46 /path/to/joiner.onnx\
47 - [num_threads] 47 + [num_threads [decoding_method]]
  48 +
  49 +Default value for num_threads is 2.
  50 +Valid values for decoding_method: greedy_search (default), modified_beam_search.
48 51
49 Please refer to 52 Please refer to
50 https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html 53 https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
@@ -70,6 +73,11 @@ for a list of pre-trained models to download. @@ -70,6 +73,11 @@ for a list of pre-trained models to download.
70 config.model_config.num_threads = atoi(argv[5]); 73 config.model_config.num_threads = atoi(argv[5]);
71 } 74 }
72 75
  76 + if (argc == 7) {
  77 + config.decoding_method = argv[6];
  78 + }
  79 + config.max_active_paths = 4;
  80 +
73 config.enable_endpoint = true; 81 config.enable_endpoint = true;
74 82
75 config.endpoint_config.rule1.min_trailing_silence = 2.4; 83 config.endpoint_config.rule1.min_trailing_silence = 2.4;
@@ -14,7 +14,7 @@ @@ -14,7 +14,7 @@
14 #include "sherpa-onnx/csrc/wave-reader.h" 14 #include "sherpa-onnx/csrc/wave-reader.h"
15 15
16 int main(int32_t argc, char *argv[]) { 16 int main(int32_t argc, char *argv[]) {
17 - if (argc < 6 || argc > 7) { 17 + if (argc < 6 || argc > 8) {
18 const char *usage = R"usage( 18 const char *usage = R"usage(
19 Usage: 19 Usage:
20 ./bin/sherpa-onnx \ 20 ./bin/sherpa-onnx \
@@ -22,7 +22,10 @@ Usage: @@ -22,7 +22,10 @@ Usage:
22 /path/to/encoder.onnx \ 22 /path/to/encoder.onnx \
23 /path/to/decoder.onnx \ 23 /path/to/decoder.onnx \
24 /path/to/joiner.onnx \ 24 /path/to/joiner.onnx \
25 - /path/to/foo.wav [num_threads] 25 + /path/to/foo.wav [num_threads [decoding_method]]
  26 +
  27 +Default value for num_threads is 2.
  28 +Valid values for decoding_method: greedy_search (default), modified_beam_search.
26 29
27 Please refer to 30 Please refer to
28 https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html 31 https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
@@ -45,9 +48,15 @@ for a list of pre-trained models to download. @@ -45,9 +48,15 @@ for a list of pre-trained models to download.
45 std::string wav_filename = argv[5]; 48 std::string wav_filename = argv[5];
46 49
47 config.model_config.num_threads = 2; 50 config.model_config.num_threads = 2;
48 - if (argc == 7) { 51 + if (argc == 7 && atoi(argv[6]) > 0) {
49 config.model_config.num_threads = atoi(argv[6]); 52 config.model_config.num_threads = atoi(argv[6]);
50 } 53 }
  54 +
  55 + if (argc == 8) {
  56 + config.decoding_method = argv[7];
  57 + }
  58 + config.max_active_paths = 4;
  59 +
51 fprintf(stderr, "%s\n", config.ToString().c_str()); 60 fprintf(stderr, "%s\n", config.ToString().c_str());
52 61
53 sherpa_onnx::OnlineRecognizer recognizer(config); 62 sherpa_onnx::OnlineRecognizer recognizer(config);
@@ -98,6 +107,7 @@ for a list of pre-trained models to download. @@ -98,6 +107,7 @@ for a list of pre-trained models to download.
98 1000.; 107 1000.;
99 108
100 fprintf(stderr, "num threads: %d\n", config.model_config.num_threads); 109 fprintf(stderr, "num threads: %d\n", config.model_config.num_threads);
  110 + fprintf(stderr, "decoding method: %s\n", config.decoding_method.c_str());
101 111
102 fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); 112 fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
103 float rtf = elapsed_seconds / duration; 113 float rtf = elapsed_seconds / duration;
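
The same options plumbed through the C++ binaries above are exposed by the Python bindings later in this diff. A sketch of the equivalent file decoding in Python, assuming a 16 kHz 16-bit mono wav, hypothetical model paths, and the existing OnlineRecognizer/OnlineStream helpers (create_stream, accept_waveform, input_finished, is_ready, decode_stream, get_result):

    import wave

    import numpy as np
    import sherpa_onnx

    recognizer = sherpa_onnx.OnlineRecognizer(
        tokens="./tokens.txt",    # hypothetical paths
        encoder="./encoder.onnx",
        decoder="./decoder.onnx",
        joiner="./joiner.onnx",
        num_threads=2,
        decoding_method="modified_beam_search",  # or "greedy_search" (default)
        max_active_paths=4,
    )

    with wave.open("./foo.wav") as f:
        samples = np.frombuffer(f.readframes(f.getnframes()), dtype=np.int16)
        samples = samples.astype(np.float32) / 32768  # scale to [-1, 1]

    stream = recognizer.create_stream()
    stream.accept_waveform(16000, samples)
    stream.input_finished()
    while recognizer.is_ready(stream):
        recognizer.decode_stream(stream)
    print(recognizer.get_result(stream))
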
1 include_directories(${CMAKE_SOURCE_DIR}) 1 include_directories(${CMAKE_SOURCE_DIR})
2 2
3 pybind11_add_module(_sherpa_onnx 3 pybind11_add_module(_sherpa_onnx
  4 + display.cc
  5 + endpoint.cc
4 features.cc 6 features.cc
  7 + online-recognizer.cc
  8 + online-stream.cc
5 online-transducer-model-config.cc 9 online-transducer-model-config.cc
6 sherpa-onnx.cc 10 sherpa-onnx.cc
7 - endpoint.cc  
8 - online-stream.cc  
9 - online-recognizer.cc  
10 ) 11 )
11 12
12 if(APPLE) 13 if(APPLE)
  1 +// sherpa-onnx/python/csrc/display.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/python/csrc/display.h"
  6 +
  7 +#include "sherpa-onnx/csrc/display.h"
  8 +
  9 +namespace sherpa_onnx {
  10 +
  11 +void PybindDisplay(py::module *m) {
  12 + using PyClass = Display;
  13 + py::class_<PyClass>(*m, "Display")
  14 + .def(py::init<int32_t>(), py::arg("max_word_per_line") = 60)
  15 + .def("print", &PyClass::Print, py::arg("idx"), py::arg("s"));
  16 +}
  17 +
  18 +} // namespace sherpa_onnx
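
A minimal sketch of using the Display binding defined above from Python (Display is re-exported by the package's __init__.py, changed later in this diff):

    from sherpa_onnx import Display

    display = Display(max_word_per_line=50)  # the default is 60
    display.print(0, "partial result of the first segment")
    display.print(1, "a new segment begins after an endpoint is detected")
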
  1 +// sherpa-onnx/python/csrc/display.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_PYTHON_CSRC_DISPLAY_H_
  6 +#define SHERPA_ONNX_PYTHON_CSRC_DISPLAY_H_
  7 +
  8 +#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void PybindDisplay(py::module *m);
  13 +
  14 +} // namespace sherpa_onnx
  15 +
  16 +#endif // SHERPA_ONNX_PYTHON_CSRC_DISPLAY_H_
@@ -11,10 +11,12 @@ namespace sherpa_onnx { @@ -11,10 +11,12 @@ namespace sherpa_onnx {
11 static void PybindFeatureExtractorConfig(py::module *m) { 11 static void PybindFeatureExtractorConfig(py::module *m) {
12 using PyClass = FeatureExtractorConfig; 12 using PyClass = FeatureExtractorConfig;
13 py::class_<PyClass>(*m, "FeatureExtractorConfig") 13 py::class_<PyClass>(*m, "FeatureExtractorConfig")
14 - .def(py::init<float, int32_t>(), py::arg("sampling_rate") = 16000,  
15 - py::arg("feature_dim") = 80) 14 + .def(py::init<float, int32_t, int32_t>(),
  15 + py::arg("sampling_rate") = 16000, py::arg("feature_dim") = 80,
  16 + py::arg("max_feature_vectors") = -1)
16 .def_readwrite("sampling_rate", &PyClass::sampling_rate) 17 .def_readwrite("sampling_rate", &PyClass::sampling_rate)
17 .def_readwrite("feature_dim", &PyClass::feature_dim) 18 .def_readwrite("feature_dim", &PyClass::feature_dim)
  19 + .def_readwrite("max_feature_vectors", &PyClass::max_feature_vectors)
18 .def("__str__", &PyClass::ToString); 20 .def("__str__", &PyClass::ToString);
19 } 21 }
20 22
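
max_feature_vectors bounds how many feature frames the extractor keeps around now that it is no longer reset on endpointing; -1 keeps everything. A small sketch using the binding above (imported from the private _sherpa_onnx extension module, as online_recognizer.py does):

    from _sherpa_onnx import FeatureExtractorConfig

    feat_config = FeatureExtractorConfig(
        sampling_rate=16000,
        feature_dim=80,
        max_feature_vectors=-1,  # -1: cache all processed frames
    )
    print(feat_config)  # __str__ is bound to ToString()
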
@@ -22,13 +22,16 @@ static void PybindOnlineRecognizerConfig(py::module *m) { @@ -22,13 +22,16 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
22 py::class_<PyClass>(*m, "OnlineRecognizerConfig") 22 py::class_<PyClass>(*m, "OnlineRecognizerConfig")
23 .def(py::init<const FeatureExtractorConfig &, 23 .def(py::init<const FeatureExtractorConfig &,
24 const OnlineTransducerModelConfig &, const EndpointConfig &, 24 const OnlineTransducerModelConfig &, const EndpointConfig &,
25 - bool>(), 25 + bool, const std::string &, int32_t>(),
26 py::arg("feat_config"), py::arg("model_config"), 26 py::arg("feat_config"), py::arg("model_config"),
27 - py::arg("endpoint_config"), py::arg("enable_endpoint")) 27 + py::arg("endpoint_config"), py::arg("enable_endpoint"),
  28 + py::arg("decoding_method"), py::arg("max_active_paths"))
28 .def_readwrite("feat_config", &PyClass::feat_config) 29 .def_readwrite("feat_config", &PyClass::feat_config)
29 .def_readwrite("model_config", &PyClass::model_config) 30 .def_readwrite("model_config", &PyClass::model_config)
30 .def_readwrite("endpoint_config", &PyClass::endpoint_config) 31 .def_readwrite("endpoint_config", &PyClass::endpoint_config)
31 .def_readwrite("enable_endpoint", &PyClass::enable_endpoint) 32 .def_readwrite("enable_endpoint", &PyClass::enable_endpoint)
  33 + .def_readwrite("decoding_method", &PyClass::decoding_method)
  34 + .def_readwrite("max_active_paths", &PyClass::max_active_paths)
32 .def("__str__", &PyClass::ToString); 35 .def("__str__", &PyClass::ToString);
33 } 36 }
34 37
@@ -4,6 +4,7 @@ @@ -4,6 +4,7 @@
4 4
5 #include "sherpa-onnx/python/csrc/sherpa-onnx.h" 5 #include "sherpa-onnx/python/csrc/sherpa-onnx.h"
6 6
  7 +#include "sherpa-onnx/python/csrc/display.h"
7 #include "sherpa-onnx/python/csrc/endpoint.h" 8 #include "sherpa-onnx/python/csrc/endpoint.h"
8 #include "sherpa-onnx/python/csrc/features.h" 9 #include "sherpa-onnx/python/csrc/features.h"
9 #include "sherpa-onnx/python/csrc/online-recognizer.h" 10 #include "sherpa-onnx/python/csrc/online-recognizer.h"
@@ -19,6 +20,8 @@ PYBIND11_MODULE(_sherpa_onnx, m) { @@ -19,6 +20,8 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
19 PybindOnlineStream(&m); 20 PybindOnlineStream(&m);
20 PybindEndpoint(&m); 21 PybindEndpoint(&m);
21 PybindOnlineRecognizer(&m); 22 PybindOnlineRecognizer(&m);
  23 +
  24 + PybindDisplay(&m);
22 } 25 }
23 26
24 } // namespace sherpa_onnx 27 } // namespace sherpa_onnx
1 -from _sherpa_onnx import (  
2 - EndpointConfig,  
3 - FeatureExtractorConfig,  
4 - OnlineRecognizerConfig,  
5 - OnlineStream,  
6 - OnlineTransducerModelConfig,  
7 -) 1 +from _sherpa_onnx import Display
8 2
9 from .online_recognizer import OnlineRecognizer 3 from .online_recognizer import OnlineRecognizer
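
Note that after this change only Display is re-exported from the compiled extension at package level (plus OnlineRecognizer from the pure-Python module); code that used the config classes through the package now has to import them from the extension directly:

    import sherpa_onnx

    sherpa_onnx.OnlineRecognizer  # still exported
    sherpa_onnx.Display           # newly exported

    # The config classes are no longer re-exported by the package:
    from _sherpa_onnx import (
        EndpointConfig,
        FeatureExtractorConfig,
        OnlineRecognizerConfig,
        OnlineTransducerModelConfig,
    )
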
@@ -32,6 +32,9 @@ class OnlineRecognizer(object): @@ -32,6 +32,9 @@ class OnlineRecognizer(object):
32 rule1_min_trailing_silence: int = 2.4, 32 rule1_min_trailing_silence: int = 2.4,
33 rule2_min_trailing_silence: int = 1.2, 33 rule2_min_trailing_silence: int = 1.2,
34 rule3_min_utterance_length: int = 20, 34 rule3_min_utterance_length: int = 20,
  35 + decoding_method: str = "greedy_search",
  36 + max_active_paths: int = 4,
  37 + max_feature_vectors: int = -1,
35 ): 38 ):
36 """ 39 """
37 Please refer to 40 Please refer to
@@ -74,6 +77,14 @@ class OnlineRecognizer(object): @@ -74,6 +77,14 @@ class OnlineRecognizer(object):
74 Used only when enable_endpoint_detection is True. If the utterance 77 Used only when enable_endpoint_detection is True. If the utterance
75 length in seconds is larger than this value, we assume an endpoint 78 length in seconds is larger than this value, we assume an endpoint
76 is detected. 79 is detected.
  80 + decoding_method:
  81 + Valid values are greedy_search, modified_beam_search.
  82 + max_active_paths:
  83 + Used only when decoding_method is modified_beam_search. It specifies
  84 + the maximum number of active paths during beam search.
  85 + max_feature_vectors:
  86 + Number of feature vectors to cache. -1 means to cache all feature
  87 + frames that have been processed.
77 """ 88 """
78 _assert_file_exists(tokens) 89 _assert_file_exists(tokens)
79 _assert_file_exists(encoder) 90 _assert_file_exists(encoder)
@@ -93,6 +104,7 @@ class OnlineRecognizer(object): @@ -93,6 +104,7 @@ class OnlineRecognizer(object):
93 feat_config = FeatureExtractorConfig( 104 feat_config = FeatureExtractorConfig(
94 sampling_rate=sample_rate, 105 sampling_rate=sample_rate,
95 feature_dim=feature_dim, 106 feature_dim=feature_dim,
  107 + max_feature_vectors=max_feature_vectors,
96 ) 108 )
97 109
98 endpoint_config = EndpointConfig( 110 endpoint_config = EndpointConfig(
@@ -106,6 +118,8 @@ class OnlineRecognizer(object): @@ -106,6 +118,8 @@ class OnlineRecognizer(object):
106 model_config=model_config, 118 model_config=model_config,
107 endpoint_config=endpoint_config, 119 endpoint_config=endpoint_config,
108 enable_endpoint=enable_endpoint_detection, 120 enable_endpoint=enable_endpoint_detection,
  121 + decoding_method=decoding_method,
  122 + max_active_paths=max_active_paths,
109 ) 123 )
110 124
111 self.recognizer = _Recognizer(recognizer_config) 125 self.recognizer = _Recognizer(recognizer_config)
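
Putting the Python-side changes together, a sketch of streaming recognition with endpointing, modified beam search, and the new Display class (model paths and the microphone_chunks() generator are hypothetical; is_endpoint, reset, and is_ready are the existing OnlineRecognizer helpers):

    import sherpa_onnx

    recognizer = sherpa_onnx.OnlineRecognizer(
        tokens="./tokens.txt",    # hypothetical paths
        encoder="./encoder.onnx",
        decoder="./decoder.onnx",
        joiner="./joiner.onnx",
        enable_endpoint_detection=True,
        decoding_method="modified_beam_search",
        max_active_paths=4,
        max_feature_vectors=-1,
    )

    display = sherpa_onnx.Display(50)
    stream = recognizer.create_stream()
    segment_id = 0

    for samples in microphone_chunks():  # hypothetical 16 kHz float32 chunks
        stream.accept_waveform(16000, samples)
        while recognizer.is_ready(stream):
            recognizer.decode_stream(stream)
        display.print(segment_id, recognizer.get_result(stream))
        if recognizer.is_endpoint(stream):
            # Only the decoding result is reset; the model state and the
            # feature extractor are kept, which is the point of this refactoring.
            recognizer.reset(stream)
            segment_id += 1
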