Code refactoring (#74)
Committed by GitHub
* Don't reset model state and feature extractor on endpointing
* Support passing decoding_method from the command line
* Add modified_beam_search to the Python API
* Fix the C API example
* Fix style issues
Showing 34 changed files with 504 additions and 134 deletions.
| @@ -19,14 +19,16 @@ const char *kUsage = | @@ -19,14 +19,16 @@ const char *kUsage = | ||
| 19 | " /path/to/encoder.onnx \\\n" | 19 | " /path/to/encoder.onnx \\\n" |
| 20 | " /path/to/decoder.onnx \\\n" | 20 | " /path/to/decoder.onnx \\\n" |
| 21 | " /path/to/joiner.onnx \\\n" | 21 | " /path/to/joiner.onnx \\\n" |
| 22 | - " /path/to/foo.wav [num_threads]\n" | 22 | + " /path/to/foo.wav [num_threads [decoding_method]]\n" |
| 23 | "\n\n" | 23 | "\n\n" |
| 24 | + "Default num_threads is 1.\n" | ||
| 25 | + "Valid decoding_method: greedy_search (default), modified_beam_search\n\n" | ||
| 24 | "Please refer to \n" | 26 | "Please refer to \n" |
| 25 | "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html\n" | 27 | "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html\n" |
| 26 | "for a list of pre-trained models to download.\n"; | 28 | "for a list of pre-trained models to download.\n"; |
| 27 | 29 | ||
| 28 | int32_t main(int32_t argc, char *argv[]) { | 30 | int32_t main(int32_t argc, char *argv[]) { |
| 29 | - if (argc < 6 || argc > 7) { | 31 | + if (argc < 6 || argc > 8) { |
| 30 | fprintf(stderr, "%s\n", kUsage); | 32 | fprintf(stderr, "%s\n", kUsage); |
| 31 | return -1; | 33 | return -1; |
| 32 | } | 34 | } |
| @@ -36,13 +38,20 @@ int32_t main(int32_t argc, char *argv[]) { | @@ -36,13 +38,20 @@ int32_t main(int32_t argc, char *argv[]) { | ||
| 36 | config.model_config.decoder = argv[3]; | 38 | config.model_config.decoder = argv[3]; |
| 37 | config.model_config.joiner = argv[4]; | 39 | config.model_config.joiner = argv[4]; |
| 38 | 40 | ||
| 39 | - int32_t num_threads = 4; | 41 | + int32_t num_threads = 1; |
| 40 | if (argc == 7 && atoi(argv[6]) > 0) { | 42 | if (argc == 7 && atoi(argv[6]) > 0) { |
| 41 | num_threads = atoi(argv[6]); | 43 | num_threads = atoi(argv[6]); |
| 42 | } | 44 | } |
| 43 | config.model_config.num_threads = num_threads; | 45 | config.model_config.num_threads = num_threads; |
| 44 | config.model_config.debug = 0; | 46 | config.model_config.debug = 0; |
| 45 | 47 | ||
| 48 | + config.decoding_method = "greedy_search"; | ||
| 49 | + if (argc == 8) { | ||
| 50 | + config.decoding_method = argv[7]; | ||
| 51 | + } | ||
| 52 | + | ||
| 53 | + config.max_active_paths = 4; | ||
| 54 | + | ||
| 46 | config.feat_config.sample_rate = 16000; | 55 | config.feat_config.sample_rate = 16000; |
| 47 | config.feat_config.feature_dim = 80; | 56 | config.feat_config.feature_dim = 80; |
| 48 | 57 | ||
| @@ -54,6 +63,9 @@ int32_t main(int32_t argc, char *argv[]) { | @@ -54,6 +63,9 @@ int32_t main(int32_t argc, char *argv[]) { | ||
| 54 | SherpaOnnxOnlineRecognizer *recognizer = CreateOnlineRecognizer(&config); | 63 | SherpaOnnxOnlineRecognizer *recognizer = CreateOnlineRecognizer(&config); |
| 55 | SherpaOnnxOnlineStream *stream = CreateOnlineStream(recognizer); | 64 | SherpaOnnxOnlineStream *stream = CreateOnlineStream(recognizer); |
| 56 | 65 | ||
| 66 | + SherpaOnnxDisplay *display = CreateDisplay(50); | ||
| 67 | + int32_t segment_id = 0; | ||
| 68 | + | ||
| 57 | const char *wav_filename = argv[5]; | 69 | const char *wav_filename = argv[5]; |
| 58 | FILE *fp = fopen(wav_filename, "rb"); | 70 | FILE *fp = fopen(wav_filename, "rb"); |
| 59 | if (!fp) { | 71 | if (!fp) { |
| @@ -84,9 +96,18 @@ int32_t main(int32_t argc, char *argv[]) { | @@ -84,9 +96,18 @@ int32_t main(int32_t argc, char *argv[]) { | ||
| 84 | 96 | ||
| 85 | SherpaOnnxOnlineRecognizerResult *r = | 97 | SherpaOnnxOnlineRecognizerResult *r = |
| 86 | GetOnlineStreamResult(recognizer, stream); | 98 | GetOnlineStreamResult(recognizer, stream); |
| 99 | + | ||
| 87 | if (strlen(r->text)) { | 100 | if (strlen(r->text)) { |
| 88 | - fprintf(stderr, "%s\n", r->text); | 101 | + SherpaOnnxPrint(display, segment_id, r->text); |
| 89 | } | 102 | } |
| 103 | + | ||
| 104 | + if (IsEndpoint(recognizer, stream)) { | ||
| 105 | + if (strlen(r->text)) { | ||
| 106 | + ++segment_id; | ||
| 107 | + } | ||
| 108 | + Reset(recognizer, stream); | ||
| 109 | + } | ||
| 110 | + | ||
| 90 | DestroyOnlineRecognizerResult(r); | 111 | DestroyOnlineRecognizerResult(r); |
| 91 | } | 112 | } |
| 92 | } | 113 | } |
| @@ -103,14 +124,17 @@ int32_t main(int32_t argc, char *argv[]) { | @@ -103,14 +124,17 @@ int32_t main(int32_t argc, char *argv[]) { | ||
| 103 | 124 | ||
| 104 | SherpaOnnxOnlineRecognizerResult *r = | 125 | SherpaOnnxOnlineRecognizerResult *r = |
| 105 | GetOnlineStreamResult(recognizer, stream); | 126 | GetOnlineStreamResult(recognizer, stream); |
| 127 | + | ||
| 106 | if (strlen(r->text)) { | 128 | if (strlen(r->text)) { |
| 107 | - fprintf(stderr, "%s\n", r->text); | 129 | + SherpaOnnxPrint(display, segment_id, r->text); |
| 108 | } | 130 | } |
| 109 | 131 | ||
| 110 | DestroyOnlineRecognizerResult(r); | 132 | DestroyOnlineRecognizerResult(r); |
| 111 | 133 | ||
| 134 | + DestroyDisplay(display); | ||
| 112 | DestoryOnlineStream(stream); | 135 | DestoryOnlineStream(stream); |
| 113 | DestroyOnlineRecognizer(recognizer); | 136 | DestroyOnlineRecognizer(recognizer); |
| 137 | + fprintf(stderr, "\n"); | ||
| 114 | 138 | ||
| 115 | return 0; | 139 | return 0; |
| 116 | } | 140 | } |
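
Taken together, the hunks above rework the C API example: num_threads now defaults to 1, an optional decoding_method argument is accepted after it, and partial results are rendered through the new display API with endpoint-driven segment numbering. Below is a minimal sketch of the resulting shape, not the shipped example: it reads raw 16 kHz float32 PCM from stdin instead of a WAV file, and config.model_config.tokens is an assumed field name (the hunks only show the decoder and joiner assignments).

#include <stdint.h>
#include <stdio.h>
#include <string.h>

#include "sherpa-onnx/c-api/c-api.h"

int main() {
  SherpaOnnxOnlineRecognizerConfig config;
  memset(&config, 0, sizeof(config));

  config.model_config.tokens = "tokens.txt"; /* assumed field name */
  config.model_config.encoder = "encoder.onnx";
  config.model_config.decoder = "decoder.onnx";
  config.model_config.joiner = "joiner.onnx";
  config.model_config.num_threads = 1;

  config.feat_config.sample_rate = 16000;
  config.feat_config.feature_dim = 80;

  config.decoding_method = "modified_beam_search"; /* or "greedy_search" */
  config.max_active_paths = 4; /* used only by modified_beam_search */

  config.enable_endpoint = 1;
  config.rule1_min_trailing_silence = 2.4;
  config.rule2_min_trailing_silence = 1.2;

  SherpaOnnxOnlineRecognizer *recognizer = CreateOnlineRecognizer(&config);
  SherpaOnnxOnlineStream *stream = CreateOnlineStream(recognizer);
  SherpaOnnxDisplay *display = CreateDisplay(50);
  int32_t segment_id = 0;

  float samples[1600]; /* 0.1 s at 16 kHz */
  size_t n;
  while ((n = fread(samples, sizeof(float), 1600, stdin)) > 0) {
    AcceptWaveform(stream, 16000, samples, (int32_t)n);
    DecodeOnlineStream(recognizer, stream);

    SherpaOnnxOnlineRecognizerResult *r =
        GetOnlineStreamResult(recognizer, stream);
    if (strlen(r->text)) {
      SherpaOnnxPrint(display, segment_id, r->text);
    }
    if (IsEndpoint(recognizer, stream)) {
      if (strlen(r->text)) {
        ++segment_id;
      }
      Reset(recognizer, stream);
    }
    DestroyOnlineRecognizerResult(r);
  }

  DestroyDisplay(display);
  DestoryOnlineStream(stream); /* sic: this is how the header spells it */
  DestroyOnlineRecognizer(recognizer);
  return 0;
}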
| @@ -26,12 +26,17 @@ if [ ! -f ./sherpa-onnx-ffmpeg ]; then | @@ -26,12 +26,17 @@ if [ ! -f ./sherpa-onnx-ffmpeg ]; then | ||
| 26 | make | 26 | make |
| 27 | fi | 27 | fi |
| 28 | 28 | ||
| 29 | -../ffmpeg-examples/sherpa-onnx-ffmpeg \ | 29 | +for method in greedy_search modified_beam_search; do |
| 30 | + echo "test method: $method" | ||
| 31 | + ../ffmpeg-examples/sherpa-onnx-ffmpeg \ | ||
| 30 | ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt \ | 32 | ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt \ |
| 31 | ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \ | 33 | ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \ |
| 32 | ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \ | 34 | ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \ |
| 33 | ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \ | 35 | ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \ |
| 34 | - ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/4.wav | 36 | + ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/0.wav \ |
| 37 | + 2 \ | ||
| 38 | + $method | ||
| 39 | +done | ||
| 35 | 40 | ||
| 36 | echo "Decoding a URL" | 41 | echo "Decoding a URL" |
| 37 | 42 |
| @@ -7,7 +7,6 @@ | @@ -7,7 +7,6 @@ | ||
| 7 | 7 | ||
| 8 | #include "sherpa-onnx/c-api/c-api.h" | 8 | #include "sherpa-onnx/c-api/c-api.h" |
| 9 | 9 | ||
| 10 | - | ||
| 11 | /* | 10 | /* |
| 12 | * Copyright (c) 2010 Nicolas George | 11 | * Copyright (c) 2010 Nicolas George |
| 13 | * Copyright (c) 2011 Stefano Sabatini | 12 | * Copyright (c) 2011 Stefano Sabatini |
| @@ -43,14 +42,15 @@ | @@ -43,14 +42,15 @@ | ||
| 43 | #include <unistd.h> | 42 | #include <unistd.h> |
| 44 | extern "C" { | 43 | extern "C" { |
| 45 | #include <libavcodec/avcodec.h> | 44 | #include <libavcodec/avcodec.h> |
| 46 | -#include <libavformat/avformat.h> | ||
| 47 | #include <libavfilter/buffersink.h> | 45 | #include <libavfilter/buffersink.h> |
| 48 | #include <libavfilter/buffersrc.h> | 46 | #include <libavfilter/buffersrc.h> |
| 47 | +#include <libavformat/avformat.h> | ||
| 49 | #include <libavutil/channel_layout.h> | 48 | #include <libavutil/channel_layout.h> |
| 50 | #include <libavutil/opt.h> | 49 | #include <libavutil/opt.h> |
| 51 | } | 50 | } |
| 52 | 51 | ||
| 53 | -static const char *filter_descr = "aresample=16000,aformat=sample_fmts=s16:channel_layouts=mono"; | 52 | +static const char *filter_descr = |
| 53 | + "aresample=16000,aformat=sample_fmts=s16:channel_layouts=mono"; | ||
| 54 | 54 | ||
| 55 | static AVFormatContext *fmt_ctx; | 55 | static AVFormatContext *fmt_ctx; |
| 56 | static AVCodecContext *dec_ctx; | 56 | static AVCodecContext *dec_ctx; |
| @@ -59,8 +59,7 @@ AVFilterContext *buffersrc_ctx; | @@ -59,8 +59,7 @@ AVFilterContext *buffersrc_ctx; | ||
| 59 | AVFilterGraph *filter_graph; | 59 | AVFilterGraph *filter_graph; |
| 60 | static int audio_stream_index = -1; | 60 | static int audio_stream_index = -1; |
| 61 | 61 | ||
| 62 | -static int open_input_file(const char *filename) | ||
| 63 | -{ | 62 | +static int open_input_file(const char *filename) { |
| 64 | const AVCodec *dec; | 63 | const AVCodec *dec; |
| 65 | int ret; | 64 | int ret; |
| 66 | 65 | ||
| @@ -77,16 +76,17 @@ static int open_input_file(const char *filename) | @@ -77,16 +76,17 @@ static int open_input_file(const char *filename) | ||
| 77 | /* select the audio stream */ | 76 | /* select the audio stream */ |
| 78 | ret = av_find_best_stream(fmt_ctx, AVMEDIA_TYPE_AUDIO, -1, -1, &dec, 0); | 77 | ret = av_find_best_stream(fmt_ctx, AVMEDIA_TYPE_AUDIO, -1, -1, &dec, 0); |
| 79 | if (ret < 0) { | 78 | if (ret < 0) { |
| 80 | - av_log(NULL, AV_LOG_ERROR, "Cannot find an audio stream in the input file\n"); | 79 | + av_log(NULL, AV_LOG_ERROR, |
| 80 | + "Cannot find an audio stream in the input file\n"); | ||
| 81 | return ret; | 81 | return ret; |
| 82 | } | 82 | } |
| 83 | audio_stream_index = ret; | 83 | audio_stream_index = ret; |
| 84 | 84 | ||
| 85 | /* create decoding context */ | 85 | /* create decoding context */ |
| 86 | dec_ctx = avcodec_alloc_context3(dec); | 86 | dec_ctx = avcodec_alloc_context3(dec); |
| 87 | - if (!dec_ctx) | ||
| 88 | - return AVERROR(ENOMEM); | ||
| 89 | - avcodec_parameters_to_context(dec_ctx, fmt_ctx->streams[audio_stream_index]->codecpar); | 87 | + if (!dec_ctx) return AVERROR(ENOMEM); |
| 88 | + avcodec_parameters_to_context(dec_ctx, | ||
| 89 | + fmt_ctx->streams[audio_stream_index]->codecpar); | ||
| 90 | 90 | ||
| 91 | /* init the audio decoder */ | 91 | /* init the audio decoder */ |
| 92 | if ((ret = avcodec_open2(dec_ctx, dec, NULL)) < 0) { | 92 | if ((ret = avcodec_open2(dec_ctx, dec, NULL)) < 0) { |
| @@ -97,16 +97,16 @@ static int open_input_file(const char *filename) | @@ -97,16 +97,16 @@ static int open_input_file(const char *filename) | ||
| 97 | return 0; | 97 | return 0; |
| 98 | } | 98 | } |
| 99 | 99 | ||
| 100 | -static int init_filters(const char *filters_descr) | ||
| 101 | -{ | 100 | +static int init_filters(const char *filters_descr) { |
| 102 | char args[512]; | 101 | char args[512]; |
| 103 | int ret = 0; | 102 | int ret = 0; |
| 104 | const AVFilter *abuffersrc = avfilter_get_by_name("abuffer"); | 103 | const AVFilter *abuffersrc = avfilter_get_by_name("abuffer"); |
| 105 | const AVFilter *abuffersink = avfilter_get_by_name("abuffersink"); | 104 | const AVFilter *abuffersink = avfilter_get_by_name("abuffersink"); |
| 106 | AVFilterInOut *outputs = avfilter_inout_alloc(); | 105 | AVFilterInOut *outputs = avfilter_inout_alloc(); |
| 107 | AVFilterInOut *inputs = avfilter_inout_alloc(); | 106 | AVFilterInOut *inputs = avfilter_inout_alloc(); |
| 108 | - static const enum AVSampleFormat out_sample_fmts[] = { AV_SAMPLE_FMT_S16, AV_SAMPLE_FMT_NONE }; | ||
| 109 | - static const int out_sample_rates[] = { 16000, -1 }; | 107 | + static const enum AVSampleFormat out_sample_fmts[] = {AV_SAMPLE_FMT_S16, |
| 108 | + AV_SAMPLE_FMT_NONE}; | ||
| 109 | + static const int out_sample_rates[] = {16000, -1}; | ||
| 110 | const AVFilterLink *outlink; | 110 | const AVFilterLink *outlink; |
| 111 | AVRational time_base = fmt_ctx->streams[audio_stream_index]->time_base; | 111 | AVRational time_base = fmt_ctx->streams[audio_stream_index]->time_base; |
| 112 | 112 | ||
| @@ -116,24 +116,27 @@ static int init_filters(const char *filters_descr) | @@ -116,24 +116,27 @@ static int init_filters(const char *filters_descr) | ||
| 116 | goto end; | 116 | goto end; |
| 117 | } | 117 | } |
| 118 | 118 | ||
| 119 | - /* buffer audio source: the decoded frames from the decoder will be inserted here. */ | 119 | + /* buffer audio source: the decoded frames from the decoder will be inserted |
| 120 | + * here. */ | ||
| 120 | if (dec_ctx->ch_layout.order == AV_CHANNEL_ORDER_UNSPEC) | 121 | if (dec_ctx->ch_layout.order == AV_CHANNEL_ORDER_UNSPEC) |
| 121 | - av_channel_layout_default(&dec_ctx->ch_layout, dec_ctx->ch_layout.nb_channels); | 122 | + av_channel_layout_default(&dec_ctx->ch_layout, |
| 123 | + dec_ctx->ch_layout.nb_channels); | ||
| 122 | ret = snprintf(args, sizeof(args), | 124 | ret = snprintf(args, sizeof(args), |
| 123 | "time_base=%d/%d:sample_rate=%d:sample_fmt=%s:channel_layout=", | 125 | "time_base=%d/%d:sample_rate=%d:sample_fmt=%s:channel_layout=", |
| 124 | time_base.num, time_base.den, dec_ctx->sample_rate, | 126 | time_base.num, time_base.den, dec_ctx->sample_rate, |
| 125 | av_get_sample_fmt_name(dec_ctx->sample_fmt)); | 127 | av_get_sample_fmt_name(dec_ctx->sample_fmt)); |
| 126 | - av_channel_layout_describe(&dec_ctx->ch_layout, args + ret, sizeof(args) - ret); | ||
| 127 | - ret = avfilter_graph_create_filter(&buffersrc_ctx, abuffersrc, "in", | ||
| 128 | - args, NULL, filter_graph); | 128 | + av_channel_layout_describe(&dec_ctx->ch_layout, args + ret, |
| 129 | + sizeof(args) - ret); | ||
| 130 | + ret = avfilter_graph_create_filter(&buffersrc_ctx, abuffersrc, "in", args, | ||
| 131 | + NULL, filter_graph); | ||
| 129 | if (ret < 0) { | 132 | if (ret < 0) { |
| 130 | av_log(NULL, AV_LOG_ERROR, "Cannot create audio buffer source\n"); | 133 | av_log(NULL, AV_LOG_ERROR, "Cannot create audio buffer source\n"); |
| 131 | goto end; | 134 | goto end; |
| 132 | } | 135 | } |
| 133 | 136 | ||
| 134 | /* buffer audio sink: to terminate the filter chain. */ | 137 | /* buffer audio sink: to terminate the filter chain. */ |
| 135 | - ret = avfilter_graph_create_filter(&buffersink_ctx, abuffersink, "out", | ||
| 136 | - NULL, NULL, filter_graph); | 138 | + ret = avfilter_graph_create_filter(&buffersink_ctx, abuffersink, "out", NULL, |
| 139 | + NULL, filter_graph); | ||
| 137 | if (ret < 0) { | 140 | if (ret < 0) { |
| 138 | av_log(NULL, AV_LOG_ERROR, "Cannot create audio buffer sink\n"); | 141 | av_log(NULL, AV_LOG_ERROR, "Cannot create audio buffer sink\n"); |
| 139 | goto end; | 142 | goto end; |
| @@ -146,15 +149,15 @@ static int init_filters(const char *filters_descr) | @@ -146,15 +149,15 @@ static int init_filters(const char *filters_descr) | ||
| 146 | goto end; | 149 | goto end; |
| 147 | } | 150 | } |
| 148 | 151 | ||
| 149 | - ret = av_opt_set(buffersink_ctx, "ch_layouts", "mono", | ||
| 150 | - AV_OPT_SEARCH_CHILDREN); | 152 | + ret = |
| 153 | + av_opt_set(buffersink_ctx, "ch_layouts", "mono", AV_OPT_SEARCH_CHILDREN); | ||
| 151 | if (ret < 0) { | 154 | if (ret < 0) { |
| 152 | av_log(NULL, AV_LOG_ERROR, "Cannot set output channel layout\n"); | 155 | av_log(NULL, AV_LOG_ERROR, "Cannot set output channel layout\n"); |
| 153 | goto end; | 156 | goto end; |
| 154 | } | 157 | } |
| 155 | 158 | ||
| 156 | - ret = av_opt_set_int_list(buffersink_ctx, "sample_rates", out_sample_rates, -1, | ||
| 157 | - AV_OPT_SEARCH_CHILDREN); | 159 | + ret = av_opt_set_int_list(buffersink_ctx, "sample_rates", out_sample_rates, |
| 160 | + -1, AV_OPT_SEARCH_CHILDREN); | ||
| 158 | if (ret < 0) { | 161 | if (ret < 0) { |
| 159 | av_log(NULL, AV_LOG_ERROR, "Cannot set output sample rate\n"); | 162 | av_log(NULL, AV_LOG_ERROR, "Cannot set output sample rate\n"); |
| 160 | goto end; | 163 | goto end; |
| @@ -187,12 +190,11 @@ static int init_filters(const char *filters_descr) | @@ -187,12 +190,11 @@ static int init_filters(const char *filters_descr) | ||
| 187 | inputs->pad_idx = 0; | 190 | inputs->pad_idx = 0; |
| 188 | inputs->next = NULL; | 191 | inputs->next = NULL; |
| 189 | 192 | ||
| 190 | - if ((ret = avfilter_graph_parse_ptr(filter_graph, filters_descr, | ||
| 191 | - &inputs, &outputs, NULL)) < 0) | 193 | + if ((ret = avfilter_graph_parse_ptr(filter_graph, filters_descr, &inputs, |
| 194 | + &outputs, NULL)) < 0) | ||
| 192 | goto end; | 195 | goto end; |
| 193 | 196 | ||
| 194 | - if ((ret = avfilter_graph_config(filter_graph, NULL)) < 0) | ||
| 195 | - goto end; | 197 | + if ((ret = avfilter_graph_config(filter_graph, NULL)) < 0) goto end; |
| 196 | 198 | ||
| 197 | /* Print summary of the sink buffer | 199 | /* Print summary of the sink buffer |
| 198 | * Note: args buffer is reused to store channel layout string */ | 200 | * Note: args buffer is reused to store channel layout string */ |
| @@ -200,7 +202,8 @@ static int init_filters(const char *filters_descr) | @@ -200,7 +202,8 @@ static int init_filters(const char *filters_descr) | ||
| 200 | av_channel_layout_describe(&outlink->ch_layout, args, sizeof(args)); | 202 | av_channel_layout_describe(&outlink->ch_layout, args, sizeof(args)); |
| 201 | av_log(NULL, AV_LOG_INFO, "Output: srate:%dHz fmt:%s chlayout:%s\n", | 203 | av_log(NULL, AV_LOG_INFO, "Output: srate:%dHz fmt:%s chlayout:%s\n", |
| 202 | (int)outlink->sample_rate, | 204 | (int)outlink->sample_rate, |
| 203 | - (char *)av_x_if_null(av_get_sample_fmt_name((AVSampleFormat)outlink->format), "?"), | 205 | + (char *)av_x_if_null( |
| 206 | + av_get_sample_fmt_name((AVSampleFormat)outlink->format), "?"), | ||
| 204 | args); | 207 | args); |
| 205 | 208 | ||
| 206 | end: | 209 | end: |
| @@ -210,13 +213,15 @@ end: | @@ -210,13 +213,15 @@ end: | ||
| 210 | return ret; | 213 | return ret; |
| 211 | } | 214 | } |
| 212 | 215 | ||
| 213 | -static void sherpa_decode_frame(const AVFrame *frame, SherpaOnnxOnlineRecognizer *recognizer, | ||
| 214 | - SherpaOnnxOnlineStream* stream) | ||
| 215 | -{ | 216 | +static void sherpa_decode_frame(const AVFrame *frame, |
| 217 | + SherpaOnnxOnlineRecognizer *recognizer, | ||
| 218 | + SherpaOnnxOnlineStream *stream, | ||
| 219 | + SherpaOnnxDisplay *display, | ||
| 220 | + int32_t *segment_id) { | ||
| 216 | #define N 3200 // 0.2 s. Sample rate is fixed to 16 kHz | 221 | #define N 3200 // 0.2 s. Sample rate is fixed to 16 kHz |
| 217 | static float samples[N]; | 222 | static float samples[N]; |
| 218 | static int nb_samples = 0; | 223 | static int nb_samples = 0; |
| 219 | - const int16_t *p = (int16_t*)frame->data[0]; | 224 | + const int16_t *p = (int16_t *)frame->data[0]; |
| 220 | 225 | ||
| 221 | if (frame->nb_samples + nb_samples > N) { | 226 | if (frame->nb_samples + nb_samples > N) { |
| 222 | AcceptWaveform(stream, 16000, samples, nb_samples); | 227 | AcceptWaveform(stream, 16000, samples, nb_samples); |
| @@ -224,17 +229,20 @@ static void sherpa_decode_frame(const AVFrame *frame, SherpaOnnxOnlineRecognizer | @@ -224,17 +229,20 @@ static void sherpa_decode_frame(const AVFrame *frame, SherpaOnnxOnlineRecognizer | ||
| 224 | DecodeOnlineStream(recognizer, stream); | 229 | DecodeOnlineStream(recognizer, stream); |
| 225 | } | 230 | } |
| 226 | 231 | ||
| 227 | - | ||
| 228 | - if (IsEndpoint(recognizer, stream)) { | ||
| 229 | SherpaOnnxOnlineRecognizerResult *r = | 232 | SherpaOnnxOnlineRecognizerResult *r = |
| 230 | GetOnlineStreamResult(recognizer, stream); | 233 | GetOnlineStreamResult(recognizer, stream); |
| 231 | if (strlen(r->text)) { | 234 | if (strlen(r->text)) { |
| 232 | - fprintf(stderr, "%s\n", r->text); | 235 | + SherpaOnnxPrint(display, *segment_id, r->text); |
| 233 | } | 236 | } |
| 234 | - DestroyOnlineRecognizerResult(r); | ||
| 235 | 237 | ||
| 238 | + if (IsEndpoint(recognizer, stream)) { | ||
| 239 | + if (strlen(r->text)) { | ||
| 240 | + ++*segment_id; | ||
| 241 | + } | ||
| 236 | Reset(recognizer, stream); | 242 | Reset(recognizer, stream); |
| 237 | } | 243 | } |
| 244 | + | ||
| 245 | + DestroyOnlineRecognizerResult(r); | ||
| 238 | nb_samples = 0; | 246 | nb_samples = 0; |
| 239 | } | 247 | } |
| 240 | 248 | ||
| @@ -243,17 +251,15 @@ static void sherpa_decode_frame(const AVFrame *frame, SherpaOnnxOnlineRecognizer | @@ -243,17 +251,15 @@ static void sherpa_decode_frame(const AVFrame *frame, SherpaOnnxOnlineRecognizer | ||
| 243 | } | 251 | } |
| 244 | } | 252 | } |
| 245 | 253 | ||
| 246 | -static inline char *__av_err2str(int errnum) | ||
| 247 | -{ | 254 | +static inline char *__av_err2str(int errnum) { |
| 248 | static char str[AV_ERROR_MAX_STRING_SIZE]; | 255 | static char str[AV_ERROR_MAX_STRING_SIZE]; |
| 249 | memset(str, 0, sizeof(str)); | 256 | memset(str, 0, sizeof(str)); |
| 250 | return av_make_error_string(str, AV_ERROR_MAX_STRING_SIZE, errnum); | 257 | return av_make_error_string(str, AV_ERROR_MAX_STRING_SIZE, errnum); |
| 251 | } | 258 | } |
| 252 | 259 | ||
| 253 | -int main(int argc, char **argv) | ||
| 254 | -{ | 260 | +int main(int argc, char **argv) { |
| 255 | int ret; | 261 | int ret; |
| 256 | - int num_threads = 4; | 262 | + int num_threads = 1; |
| 257 | AVPacket *packet = av_packet_alloc(); | 263 | AVPacket *packet = av_packet_alloc(); |
| 258 | AVFrame *frame = av_frame_alloc(); | 264 | AVFrame *frame = av_frame_alloc(); |
| 259 | AVFrame *filt_frame = av_frame_alloc(); | 265 | AVFrame *filt_frame = av_frame_alloc(); |
| @@ -265,19 +271,20 @@ int main(int argc, char **argv) | @@ -265,19 +271,20 @@ int main(int argc, char **argv) | ||
| 265 | " /path/to/encoder.onnx\\\n" | 271 | " /path/to/encoder.onnx\\\n" |
| 266 | " /path/to/decoder.onnx\\\n" | 272 | " /path/to/decoder.onnx\\\n" |
| 267 | " /path/to/joiner.onnx\\\n" | 273 | " /path/to/joiner.onnx\\\n" |
| 268 | - " /path/to/foo.wav [num_threads]" | 274 | + " /path/to/foo.wav [num_threads [decoding_method]]" |
| 269 | "\n\n" | 275 | "\n\n" |
| 276 | + "Default num_threads is 1.\n" | ||
| 277 | + "Valid decoding_method: greedy_search (default), modified_beam_search\n\n" | ||
| 270 | "Please refer to \n" | 278 | "Please refer to \n" |
| 271 | "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html\n" | 279 | "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html\n" |
| 272 | "for a list of pre-trained models to download.\n"; | 280 | "for a list of pre-trained models to download.\n"; |
| 273 | 281 | ||
| 274 | - | ||
| 275 | if (!packet || !frame || !filt_frame) { | 282 | if (!packet || !frame || !filt_frame) { |
| 276 | fprintf(stderr, "Could not allocate frame or packet\n"); | 283 | fprintf(stderr, "Could not allocate frame or packet\n"); |
| 277 | exit(1); | 284 | exit(1); |
| 278 | } | 285 | } |
| 279 | 286 | ||
| 280 | - if (argc < 6 || argc > 7) { | 287 | + if (argc < 6 || argc > 8) { |
| 281 | fprintf(stderr, "%s\n", kUsage); | 288 | fprintf(stderr, "%s\n", kUsage); |
| 282 | return -1; | 289 | return -1; |
| 283 | } | 290 | } |
| @@ -291,12 +298,20 @@ int main(int argc, char **argv) | @@ -291,12 +298,20 @@ int main(int argc, char **argv) | ||
| 291 | if (argc == 7 && atoi(argv[6]) > 0) { | 298 | if (argc == 7 && atoi(argv[6]) > 0) { |
| 292 | num_threads = atoi(argv[6]); | 299 | num_threads = atoi(argv[6]); |
| 293 | } | 300 | } |
| 301 | + | ||
| 294 | config.model_config.num_threads = num_threads; | 302 | config.model_config.num_threads = num_threads; |
| 295 | config.model_config.debug = 0; | 303 | config.model_config.debug = 0; |
| 296 | 304 | ||
| 297 | config.feat_config.sample_rate = 16000; | 305 | config.feat_config.sample_rate = 16000; |
| 298 | config.feat_config.feature_dim = 80; | 306 | config.feat_config.feature_dim = 80; |
| 299 | 307 | ||
| 308 | + config.decoding_method = "greedy_search"; | ||
| 309 | + if (argc == 8) { | ||
| 310 | + config.decoding_method = argv[7]; | ||
| 311 | + } | ||
| 312 | + | ||
| 313 | + config.max_active_paths = 4; | ||
| 314 | + | ||
| 300 | config.enable_endpoint = 1; | 315 | config.enable_endpoint = 1; |
| 301 | config.rule1_min_trailing_silence = 2.4; | 316 | config.rule1_min_trailing_silence = 2.4; |
| 302 | config.rule2_min_trailing_silence = 1.2; | 317 | config.rule2_min_trailing_silence = 1.2; |
| @@ -304,22 +319,22 @@ int main(int argc, char **argv) | @@ -304,22 +319,22 @@ int main(int argc, char **argv) | ||
| 304 | 319 | ||
| 305 | SherpaOnnxOnlineRecognizer *recognizer = CreateOnlineRecognizer(&config); | 320 | SherpaOnnxOnlineRecognizer *recognizer = CreateOnlineRecognizer(&config); |
| 306 | SherpaOnnxOnlineStream *stream = CreateOnlineStream(recognizer); | 321 | SherpaOnnxOnlineStream *stream = CreateOnlineStream(recognizer); |
| 322 | + SherpaOnnxDisplay *display = CreateDisplay(50); | ||
| 323 | + int32_t segment_id = 0; | ||
| 307 | 324 | ||
| 308 | - if ((ret = open_input_file(argv[5])) < 0) | ||
| 309 | - exit(1); | 325 | + if ((ret = open_input_file(argv[5])) < 0) exit(1); |
| 310 | 326 | ||
| 311 | - if ((ret = init_filters(filter_descr)) < 0) | ||
| 312 | - exit(1); | 327 | + if ((ret = init_filters(filter_descr)) < 0) exit(1); |
| 313 | 328 | ||
| 314 | /* read all packets */ | 329 | /* read all packets */ |
| 315 | while (1) { | 330 | while (1) { |
| 316 | - if ((ret = av_read_frame(fmt_ctx, packet)) < 0) | ||
| 317 | - break; | 331 | + if ((ret = av_read_frame(fmt_ctx, packet)) < 0) break; |
| 318 | 332 | ||
| 319 | if (packet->stream_index == audio_stream_index) { | 333 | if (packet->stream_index == audio_stream_index) { |
| 320 | ret = avcodec_send_packet(dec_ctx, packet); | 334 | ret = avcodec_send_packet(dec_ctx, packet); |
| 321 | if (ret < 0) { | 335 | if (ret < 0) { |
| 322 | - av_log(NULL, AV_LOG_ERROR, "Error while sending a packet to the decoder\n"); | 336 | + av_log(NULL, AV_LOG_ERROR, |
| 337 | + "Error while sending a packet to the decoder\n"); | ||
| 323 | break; | 338 | break; |
| 324 | } | 339 | } |
| 325 | 340 | ||
| @@ -328,25 +343,27 @@ int main(int argc, char **argv) | @@ -328,25 +343,27 @@ int main(int argc, char **argv) | ||
| 328 | if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) { | 343 | if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) { |
| 329 | break; | 344 | break; |
| 330 | } else if (ret < 0) { | 345 | } else if (ret < 0) { |
| 331 | - av_log(NULL, AV_LOG_ERROR, "Error while receiving a frame from the decoder\n"); | 346 | + av_log(NULL, AV_LOG_ERROR, |
| 347 | + "Error while receiving a frame from the decoder\n"); | ||
| 332 | exit(1); | 348 | exit(1); |
| 333 | } | 349 | } |
| 334 | 350 | ||
| 335 | if (ret >= 0) { | 351 | if (ret >= 0) { |
| 336 | /* push the audio data from decoded frame into the filtergraph */ | 352 | /* push the audio data from decoded frame into the filtergraph */ |
| 337 | - if (av_buffersrc_add_frame_flags(buffersrc_ctx, frame, AV_BUFFERSRC_FLAG_KEEP_REF) < 0) { | ||
| 338 | - av_log(NULL, AV_LOG_ERROR, "Error while feeding the audio filtergraph\n"); | 353 | + if (av_buffersrc_add_frame_flags(buffersrc_ctx, frame, |
| 354 | + AV_BUFFERSRC_FLAG_KEEP_REF) < 0) { | ||
| 355 | + av_log(NULL, AV_LOG_ERROR, | ||
| 356 | + "Error while feeding the audio filtergraph\n"); | ||
| 339 | break; | 357 | break; |
| 340 | } | 358 | } |
| 341 | 359 | ||
| 342 | /* pull filtered audio from the filtergraph */ | 360 | /* pull filtered audio from the filtergraph */ |
| 343 | while (1) { | 361 | while (1) { |
| 344 | ret = av_buffersink_get_frame(buffersink_ctx, filt_frame); | 362 | ret = av_buffersink_get_frame(buffersink_ctx, filt_frame); |
| 345 | - if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) | ||
| 346 | - break; | ||
| 347 | - if (ret < 0) | ||
| 348 | - exit(1); | ||
| 349 | - sherpa_decode_frame(filt_frame, recognizer, stream); | 363 | + if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) break; |
| 364 | + if (ret < 0) exit(1); | ||
| 365 | + sherpa_decode_frame(filt_frame, recognizer, stream, display, | ||
| 366 | + &segment_id); | ||
| 350 | av_frame_unref(filt_frame); | 367 | av_frame_unref(filt_frame); |
| 351 | } | 368 | } |
| 352 | av_frame_unref(frame); | 369 | av_frame_unref(frame); |
| @@ -368,11 +385,12 @@ int main(int argc, char **argv) | @@ -368,11 +385,12 @@ int main(int argc, char **argv) | ||
| 368 | SherpaOnnxOnlineRecognizerResult *r = | 385 | SherpaOnnxOnlineRecognizerResult *r = |
| 369 | GetOnlineStreamResult(recognizer, stream); | 386 | GetOnlineStreamResult(recognizer, stream); |
| 370 | if (strlen(r->text)) { | 387 | if (strlen(r->text)) { |
| 371 | - fprintf(stderr, "%s\n", r->text); | 388 | + SherpaOnnxPrint(display, segment_id, r->text); |
| 372 | } | 389 | } |
| 373 | 390 | ||
| 374 | DestroyOnlineRecognizerResult(r); | 391 | DestroyOnlineRecognizerResult(r); |
| 375 | 392 | ||
| 393 | + DestroyDisplay(display); | ||
| 376 | DestoryOnlineStream(stream); | 394 | DestoryOnlineStream(stream); |
| 377 | DestroyOnlineRecognizer(recognizer); | 395 | DestroyOnlineRecognizer(recognizer); |
| 378 | 396 | ||
| @@ -387,6 +405,7 @@ int main(int argc, char **argv) | @@ -387,6 +405,7 @@ int main(int argc, char **argv) | ||
| 387 | fprintf(stderr, "Error occurred: %s\n", __av_err2str(ret)); | 405 | fprintf(stderr, "Error occurred: %s\n", __av_err2str(ret)); |
| 388 | exit(1); | 406 | exit(1); |
| 389 | } | 407 | } |
| 408 | + fprintf(stderr, "\n"); | ||
| 390 | 409 | ||
| 391 | return 0; | 410 | return 0; |
| 392 | } | 411 | } |
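
One behavioral fix hides in the restructured sherpa_decode_frame(): the result is now fetched and printed on every flush (not only at endpoints), the endpoint check runs afterwards, and DestroyOnlineRecognizerResult() is called once at the end rather than inside the endpoint branch. The buffering around it is sketched standalone below; the s16-to-float conversion loop is elided from the diff, so its exact form here is an assumption.

#include <cstdint>
#include <functional>

// Standalone sketch of the sample buffering in sherpa_decode_frame():
// s16 frames accumulate in a fixed float buffer that is flushed to the
// recognizer only when the next frame would overflow it. Assumes each
// incoming frame has at most kN samples, as the real code implicitly does.
constexpr int32_t kN = 3200;  // 0.2 s at 16 kHz

static float samples[kN];
static int32_t nb_samples = 0;

void OnAudioFrame(const int16_t *p, int32_t n,
                  const std::function<void(const float *, int32_t)> &flush) {
  if (n + nb_samples > kN) {
    flush(samples, nb_samples);  // AcceptWaveform + DecodeOnlineStream in the real code
    nb_samples = 0;
  }
  for (int32_t i = 0; i != n; ++i) {
    samples[nb_samples++] = p[i] / 32768.0f;  // assumed s16 -> [-1, 1] scaling
  }
}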
| @@ -54,6 +54,20 @@ def get_args(): | @@ -54,6 +54,20 @@ def get_args(): | ||
| 54 | ) | 54 | ) |
| 55 | 55 | ||
| 56 | parser.add_argument( | 56 | parser.add_argument( |
| 57 | + "--num-threads", | ||
| 58 | + type=int, | ||
| 59 | + default=1, | ||
| 60 | + help="Number of threads for neural network computation", | ||
| 61 | + ) | ||
| 62 | + | ||
| 63 | + parser.add_argument( | ||
| 64 | + "--decoding-method", | ||
| 65 | + type=str, | ||
| 66 | + default="greedy_search", | ||
| 67 | + help="Valid values are greedy_search and modified_beam_search", | ||
| 68 | + ) | ||
| 69 | + | ||
| 70 | + parser.add_argument( | ||
| 57 | "--wave-filename", | 71 | "--wave-filename", |
| 58 | type=str, | 72 | type=str, |
| 59 | help="""Path to the wave filename. Must be 16 kHz, | 73 | help="""Path to the wave filename. Must be 16 kHz, |
| @@ -65,7 +79,6 @@ def get_args(): | @@ -65,7 +79,6 @@ def get_args(): | ||
| 65 | 79 | ||
| 66 | def main(): | 80 | def main(): |
| 67 | sample_rate = 16000 | 81 | sample_rate = 16000 |
| 68 | - num_threads = 2 | ||
| 69 | 82 | ||
| 70 | args = get_args() | 83 | args = get_args() |
| 71 | assert_file_exists(args.encoder) | 84 | assert_file_exists(args.encoder) |
| @@ -81,9 +94,10 @@ def main(): | @@ -81,9 +94,10 @@ def main(): | ||
| 81 | encoder=args.encoder, | 94 | encoder=args.encoder, |
| 82 | decoder=args.decoder, | 95 | decoder=args.decoder, |
| 83 | joiner=args.joiner, | 96 | joiner=args.joiner, |
| 84 | - num_threads=num_threads, | 97 | + num_threads=args.num_threads, |
| 85 | sample_rate=sample_rate, | 98 | sample_rate=sample_rate, |
| 86 | feature_dim=80, | 99 | feature_dim=80, |
| 100 | + decoding_method=args.decoding_method, | ||
| 87 | ) | 101 | ) |
| 88 | with wave.open(args.wave_filename) as f: | 102 | with wave.open(args.wave_filename) as f: |
| 89 | assert f.getframerate() == sample_rate, f.getframerate() | 103 | assert f.getframerate() == sample_rate, f.getframerate() |
| @@ -119,7 +133,8 @@ def main(): | @@ -119,7 +133,8 @@ def main(): | ||
| 119 | end_time = time.time() | 133 | end_time = time.time() |
| 120 | elapsed_seconds = end_time - start_time | 134 | elapsed_seconds = end_time - start_time |
| 121 | rtf = elapsed_seconds / duration | 135 | rtf = elapsed_seconds / duration |
| 122 | - print(f"num_threads: {num_threads}") | 136 | + print(f"num_threads: {args.num_threads}") |
| 137 | + print(f"decoding_method: {args.decoding_method}") | ||
| 123 | print(f"Wave duration: {duration:.3f} s") | 138 | print(f"Wave duration: {duration:.3f} s") |
| 124 | print(f"Elapsed time: {elapsed_seconds:.3f} s") | 139 | print(f"Elapsed time: {elapsed_seconds:.3f} s") |
| 125 | print(f"Real time factor (RTF): {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}") | 140 | print(f"Real time factor (RTF): {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}") |
| @@ -60,10 +60,10 @@ def get_args(): | @@ -60,10 +60,10 @@ def get_args(): | ||
| 60 | ) | 60 | ) |
| 61 | 61 | ||
| 62 | parser.add_argument( | 62 | parser.add_argument( |
| 63 | - "--wave-filename", | 63 | + "--decoding-method", |
| 64 | type=str, | 64 | type=str, |
| 65 | - help="""Path to the wave filename. Must be 16 kHz, | ||
| 66 | - mono with 16-bit samples""", | 65 | + default="greedy_search", |
| 66 | + help="Valid values are greedy_search and modified_beam_search", | ||
| 67 | ) | 67 | ) |
| 68 | 68 | ||
| 69 | return parser.parse_args() | 69 | return parser.parse_args() |
| @@ -83,17 +83,23 @@ def create_recognizer(): | @@ -83,17 +83,23 @@ def create_recognizer(): | ||
| 83 | encoder=args.encoder, | 83 | encoder=args.encoder, |
| 84 | decoder=args.decoder, | 84 | decoder=args.decoder, |
| 85 | joiner=args.joiner, | 85 | joiner=args.joiner, |
| 86 | + num_threads=1, | ||
| 87 | + sample_rate=16000, | ||
| 88 | + feature_dim=80, | ||
| 86 | enable_endpoint_detection=True, | 89 | enable_endpoint_detection=True, |
| 87 | rule1_min_trailing_silence=2.4, | 90 | rule1_min_trailing_silence=2.4, |
| 88 | rule2_min_trailing_silence=1.2, | 91 | rule2_min_trailing_silence=1.2, |
| 89 | rule3_min_utterance_length=300, # it essentially disables this rule | 92 | rule3_min_utterance_length=300, # it essentially disables this rule |
| 93 | + decoding_method=args.decoding_method, | ||
| 94 | + max_feature_vectors=100, # 1 second | ||
| 90 | ) | 95 | ) |
| 91 | return recognizer | 96 | return recognizer |
| 92 | 97 | ||
| 93 | 98 | ||
| 94 | def main(): | 99 | def main(): |
| 95 | - print("Started! Please speak") | ||
| 96 | recognizer = create_recognizer() | 100 | recognizer = create_recognizer() |
| 101 | + print("Started! Please speak") | ||
| 102 | + | ||
| 97 | sample_rate = 16000 | 103 | sample_rate = 16000 |
| 98 | samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms | 104 | samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms |
| 99 | last_result = "" | 105 | last_result = "" |
| @@ -101,6 +107,7 @@ def main(): | @@ -101,6 +107,7 @@ def main(): | ||
| 101 | 107 | ||
| 102 | last_result = "" | 108 | last_result = "" |
| 103 | segment_id = 0 | 109 | segment_id = 0 |
| 110 | + display = sherpa_onnx.Display(max_word_per_line=30) | ||
| 104 | with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s: | 111 | with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s: |
| 105 | while True: | 112 | while True: |
| 106 | samples, _ = s.read(samples_per_read) # a blocking read | 113 | samples, _ = s.read(samples_per_read) # a blocking read |
| @@ -115,7 +122,7 @@ def main(): | @@ -115,7 +122,7 @@ def main(): | ||
| 115 | 122 | ||
| 116 | if result and (last_result != result): | 123 | if result and (last_result != result): |
| 117 | last_result = result | 124 | last_result = result |
| 118 | - print(f"{segment_id}: {result}") | 125 | + display.print(segment_id, result) |
| 119 | 126 | ||
| 120 | if is_endpoint: | 127 | if is_endpoint: |
| 121 | if result: | 128 | if result: |
| @@ -59,10 +59,10 @@ def get_args(): | @@ -59,10 +59,10 @@ def get_args(): | ||
| 59 | ) | 59 | ) |
| 60 | 60 | ||
| 61 | parser.add_argument( | 61 | parser.add_argument( |
| 62 | - "--wave-filename", | 62 | + "--decoding-method", |
| 63 | type=str, | 63 | type=str, |
| 64 | - help="""Path to the wave filename. Must be 16 kHz, | ||
| 65 | - mono with 16-bit samples""", | 64 | + default="greedy_search", |
| 65 | + help="Valid values are greedy_search and modified_beam_search", | ||
| 66 | ) | 66 | ) |
| 67 | 67 | ||
| 68 | return parser.parse_args() | 68 | return parser.parse_args() |
| @@ -82,9 +82,11 @@ def create_recognizer(): | @@ -82,9 +82,11 @@ def create_recognizer(): | ||
| 82 | encoder=args.encoder, | 82 | encoder=args.encoder, |
| 83 | decoder=args.decoder, | 83 | decoder=args.decoder, |
| 84 | joiner=args.joiner, | 84 | joiner=args.joiner, |
| 85 | - num_threads=4, | 85 | + num_threads=1, |
| 86 | sample_rate=16000, | 86 | sample_rate=16000, |
| 87 | feature_dim=80, | 87 | feature_dim=80, |
| 88 | + decoding_method=args.decoding_method, | ||
| 89 | + max_feature_vectors=100, # 1 second | ||
| 88 | ) | 90 | ) |
| 89 | return recognizer | 91 | return recognizer |
| 90 | 92 | ||
| @@ -96,6 +98,7 @@ def main(): | @@ -96,6 +98,7 @@ def main(): | ||
| 96 | samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms | 98 | samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms |
| 97 | last_result = "" | 99 | last_result = "" |
| 98 | stream = recognizer.create_stream() | 100 | stream = recognizer.create_stream() |
| 101 | + display = sherpa_onnx.Display(max_word_per_line=40) | ||
| 99 | with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s: | 102 | with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s: |
| 100 | while True: | 103 | while True: |
| 101 | samples, _ = s.read(samples_per_read) # a blocking read | 104 | samples, _ = s.read(samples_per_read) # a blocking read |
| @@ -106,7 +109,7 @@ def main(): | @@ -106,7 +109,7 @@ def main(): | ||
| 106 | result = recognizer.get_result(stream) | 109 | result = recognizer.get_result(stream) |
| 107 | if last_result != result: | 110 | if last_result != result: |
| 108 | last_result = result | 111 | last_result = result |
| 109 | - print(result) | 112 | + display.print(-1, result) |
| 110 | 113 | ||
| 111 | 114 | ||
| 112 | if __name__ == "__main__": | 115 | if __name__ == "__main__": |
| @@ -9,6 +9,7 @@ | @@ -9,6 +9,7 @@ | ||
| 9 | #include <utility> | 9 | #include <utility> |
| 10 | #include <vector> | 10 | #include <vector> |
| 11 | 11 | ||
| 12 | +#include "sherpa-onnx/csrc/display.h" | ||
| 12 | #include "sherpa-onnx/csrc/online-recognizer.h" | 13 | #include "sherpa-onnx/csrc/online-recognizer.h" |
| 13 | 14 | ||
| 14 | struct SherpaOnnxOnlineRecognizer { | 15 | struct SherpaOnnxOnlineRecognizer { |
| @@ -21,6 +22,10 @@ struct SherpaOnnxOnlineStream { | @@ -21,6 +22,10 @@ struct SherpaOnnxOnlineStream { | ||
| 21 | : impl(std::move(p)) {} | 22 | : impl(std::move(p)) {} |
| 22 | }; | 23 | }; |
| 23 | 24 | ||
| 25 | +struct SherpaOnnxDisplay { | ||
| 26 | + std::unique_ptr<sherpa_onnx::Display> impl; | ||
| 27 | +}; | ||
| 28 | + | ||
| 24 | SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer( | 29 | SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer( |
| 25 | const SherpaOnnxOnlineRecognizerConfig *config) { | 30 | const SherpaOnnxOnlineRecognizerConfig *config) { |
| 26 | sherpa_onnx::OnlineRecognizerConfig recognizer_config; | 31 | sherpa_onnx::OnlineRecognizerConfig recognizer_config; |
| @@ -37,6 +42,9 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer( | @@ -37,6 +42,9 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer( | ||
| 37 | recognizer_config.model_config.num_threads = config->model_config.num_threads; | 42 | recognizer_config.model_config.num_threads = config->model_config.num_threads; |
| 38 | recognizer_config.model_config.debug = config->model_config.debug; | 43 | recognizer_config.model_config.debug = config->model_config.debug; |
| 39 | 44 | ||
| 45 | + recognizer_config.decoding_method = config->decoding_method; | ||
| 46 | + recognizer_config.max_active_paths = config->max_active_paths; | ||
| 47 | + | ||
| 40 | recognizer_config.enable_endpoint = config->enable_endpoint; | 48 | recognizer_config.enable_endpoint = config->enable_endpoint; |
| 41 | 49 | ||
| 42 | recognizer_config.endpoint_config.rule1.min_trailing_silence = | 50 | recognizer_config.endpoint_config.rule1.min_trailing_silence = |
| @@ -124,3 +132,15 @@ int32_t IsEndpoint(SherpaOnnxOnlineRecognizer *recognizer, | @@ -124,3 +132,15 @@ int32_t IsEndpoint(SherpaOnnxOnlineRecognizer *recognizer, | ||
| 124 | SherpaOnnxOnlineStream *stream) { | 132 | SherpaOnnxOnlineStream *stream) { |
| 125 | return recognizer->impl->IsEndpoint(stream->impl.get()); | 133 | return recognizer->impl->IsEndpoint(stream->impl.get()); |
| 126 | } | 134 | } |
| 135 | + | ||
| 136 | +SherpaOnnxDisplay *CreateDisplay(int32_t max_word_per_line) { | ||
| 137 | + SherpaOnnxDisplay *ans = new SherpaOnnxDisplay; | ||
| 138 | + ans->impl = std::make_unique<sherpa_onnx::Display>(max_word_per_line); | ||
| 139 | + return ans; | ||
| 140 | +} | ||
| 141 | + | ||
| 142 | +void DestroyDisplay(SherpaOnnxDisplay *display) { delete display; } | ||
| 143 | + | ||
| 144 | +void SherpaOnnxPrint(SherpaOnnxDisplay *display, int32_t idx, const char *s) { | ||
| 145 | + display->impl->Print(idx, s); | ||
| 146 | +} |
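
SherpaOnnxDisplay reuses the opaque-handle pattern of the existing recognizer and stream wrappers: the C-visible struct owns the C++ object through a unique_ptr, so DestroyDisplay() is a bare delete. The pattern in isolation, as a generic sketch (Widget is a placeholder, not a sherpa-onnx type):

#include <memory>

namespace impl {
class Widget {};  // placeholder for a C++ implementation class
}  // namespace impl

// The C-visible handle is an opaque struct owning the implementation...
struct Handle {
  std::unique_ptr<impl::Widget> impl;
};

// ...and the C API only ever hands out pointers to it.
extern "C" Handle *CreateHandle() {
  Handle *ans = new Handle;
  ans->impl = std::make_unique<impl::Widget>();
  return ans;
}

// Deleting the handle releases the impl through the unique_ptr.
extern "C" void DestroyHandle(Handle *h) { delete h; }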
| @@ -48,6 +48,13 @@ typedef struct SherpaOnnxOnlineRecognizerConfig { | @@ -48,6 +48,13 @@ typedef struct SherpaOnnxOnlineRecognizerConfig { | ||
| 48 | SherpaOnnxFeatureConfig feat_config; | 48 | SherpaOnnxFeatureConfig feat_config; |
| 49 | SherpaOnnxOnlineTransducerModelConfig model_config; | 49 | SherpaOnnxOnlineTransducerModelConfig model_config; |
| 50 | 50 | ||
| 51 | + /// Possible values are: greedy_search, modified_beam_search | ||
| 52 | + const char *decoding_method; | ||
| 53 | + | ||
| 54 | + /// Used only when decoding_method is modified_beam_search | ||
| 55 | + /// Example value: 4 | ||
| 56 | + int32_t max_active_paths; | ||
| 57 | + | ||
| 51 | /// 0 to disable endpoint detection. | 58 | /// 0 to disable endpoint detection. |
| 52 | /// A non-zero value to enable endpoint detection. | 59 | /// A non-zero value to enable endpoint detection. |
| 53 | int32_t enable_endpoint; | 60 | int32_t enable_endpoint; |
| @@ -187,6 +194,18 @@ void InputFinished(SherpaOnnxOnlineStream *stream); | @@ -187,6 +194,18 @@ void InputFinished(SherpaOnnxOnlineStream *stream); | ||
| 187 | int32_t IsEndpoint(SherpaOnnxOnlineRecognizer *recognizer, | 194 | int32_t IsEndpoint(SherpaOnnxOnlineRecognizer *recognizer, |
| 188 | SherpaOnnxOnlineStream *stream); | 195 | SherpaOnnxOnlineStream *stream); |
| 189 | 196 | ||
| 197 | +// for displaying results on Linux/macOS. | ||
| 198 | +typedef struct SherpaOnnxDisplay SherpaOnnxDisplay; | ||
| 199 | + | ||
| 200 | +/// Create a display object. Must be freed using DestroyDisplay to avoid | ||
| 201 | +/// a memory leak. | ||
| 202 | +SherpaOnnxDisplay *CreateDisplay(int32_t max_word_per_line); | ||
| 203 | + | ||
| 204 | +void DestroyDisplay(SherpaOnnxDisplay *display); | ||
| 205 | + | ||
| 206 | +/// Print the result. | ||
| 207 | +void SherpaOnnxPrint(SherpaOnnxDisplay *display, int32_t idx, const char *s); | ||
| 208 | + | ||
| 190 | #ifdef __cplusplus | 209 | #ifdef __cplusplus |
| 191 | } /* extern "C" */ | 210 | } /* extern "C" */ |
| 192 | #endif | 211 | #endif |
| @@ -9,10 +9,11 @@ set(sources | @@ -9,10 +9,11 @@ set(sources | ||
| 9 | online-lstm-transducer-model.cc | 9 | online-lstm-transducer-model.cc |
| 10 | online-recognizer.cc | 10 | online-recognizer.cc |
| 11 | online-stream.cc | 11 | online-stream.cc |
| 12 | + online-transducer-decoder.cc | ||
| 12 | online-transducer-greedy-search-decoder.cc | 13 | online-transducer-greedy-search-decoder.cc |
| 13 | online-transducer-model-config.cc | 14 | online-transducer-model-config.cc |
| 14 | - online-transducer-modified-beam-search-decoder.cc | ||
| 15 | online-transducer-model.cc | 15 | online-transducer-model.cc |
| 16 | + online-transducer-modified-beam-search-decoder.cc | ||
| 16 | online-zipformer-transducer-model.cc | 17 | online-zipformer-transducer-model.cc |
| 17 | onnx-utils.cc | 18 | onnx-utils.cc |
| 18 | parse-options.cc | 19 | parse-options.cc |
| @@ -12,9 +12,16 @@ namespace sherpa_onnx { | @@ -12,9 +12,16 @@ namespace sherpa_onnx { | ||
| 12 | 12 | ||
| 13 | class Display { | 13 | class Display { |
| 14 | public: | 14 | public: |
| 15 | + explicit Display(int32_t max_word_per_line = 60) | ||
| 16 | + : max_word_per_line_(max_word_per_line) {} | ||
| 17 | + | ||
| 15 | void Print(int32_t segment_id, const std::string &s) { | 18 | void Print(int32_t segment_id, const std::string &s) { |
| 16 | #ifdef _MSC_VER | 19 | #ifdef _MSC_VER |
| 20 | + if (segment_id != -1) { | ||
| 17 | fprintf(stderr, "%d:%s\n", segment_id, s.c_str()); | 21 | fprintf(stderr, "%d:%s\n", segment_id, s.c_str()); |
| 22 | + } else { | ||
| 23 | + fprintf(stderr, "%s\n", s.c_str()); | ||
| 24 | + } | ||
| 18 | return; | 25 | return; |
| 19 | #endif | 26 | #endif |
| 20 | if (last_segment_ == segment_id) { | 27 | if (last_segment_ == segment_id) { |
| @@ -27,7 +34,9 @@ class Display { | @@ -27,7 +34,9 @@ class Display { | ||
| 27 | num_previous_lines_ = 0; | 34 | num_previous_lines_ = 0; |
| 28 | } | 35 | } |
| 29 | 36 | ||
| 37 | + if (segment_id != -1) { | ||
| 30 | fprintf(stderr, "\r%d:", segment_id); | 38 | fprintf(stderr, "\r%d:", segment_id); |
| 39 | + } | ||
| 31 | 40 | ||
| 32 | int32_t i = 0; | 41 | int32_t i = 0; |
| 33 | for (size_t n = 0; n < s.size();) { | 42 | for (size_t n = 0; n < s.size();) { |
| @@ -69,7 +78,7 @@ class Display { | @@ -69,7 +78,7 @@ class Display { | ||
| 69 | void GoUpOneLine() const { fprintf(stderr, "\033[1A\r"); } | 78 | void GoUpOneLine() const { fprintf(stderr, "\033[1A\r"); } |
| 70 | 79 | ||
| 71 | private: | 80 | private: |
| 72 | - int32_t max_word_per_line_ = 60; | 81 | + int32_t max_word_per_line_; |
| 73 | int32_t num_previous_lines_ = 0; | 82 | int32_t num_previous_lines_ = 0; |
| 74 | int32_t last_segment_ = -1; | 83 | int32_t last_segment_ = -1; |
| 75 | }; | 84 | }; |
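
With the configurable constructor and the segment_id == -1 escape hatch, one Display class now serves both the endpointing examples (numbered segments, redrawn in place) and the plain streaming one (no prefix). Direct usage, sketched on the assumption that the header lives at the include path shown in the c-api.cc hunk:

#include "sherpa-onnx/csrc/display.h"

int main() {
  sherpa_onnx::Display display(50);         // wrap after 50 words per line
  display.Print(0, "partial result");       // prefixed with "0:"
  display.Print(0, "partial result more");  // same segment: redrawn in place
  display.Print(-1, "no segment prefix");   // -1 suppresses the "N:" prefix
  return 0;
}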
| @@ -28,7 +28,8 @@ std::string FeatureExtractorConfig::ToString() const { | @@ -28,7 +28,8 @@ std::string FeatureExtractorConfig::ToString() const { | ||
| 28 | 28 | ||
| 29 | os << "FeatureExtractorConfig("; | 29 | os << "FeatureExtractorConfig("; |
| 30 | os << "sampling_rate=" << sampling_rate << ", "; | 30 | os << "sampling_rate=" << sampling_rate << ", "; |
| 31 | - os << "feature_dim=" << feature_dim << ")"; | 31 | + os << "feature_dim=" << feature_dim << ", "; |
| 32 | + os << "max_feature_vectors=" << max_feature_vectors << ")"; | ||
| 32 | 33 | ||
| 33 | return os.str(); | 34 | return os.str(); |
| 34 | } | 35 | } |
| @@ -40,9 +41,7 @@ class FeatureExtractor::Impl { | @@ -40,9 +41,7 @@ class FeatureExtractor::Impl { | ||
| 40 | opts_.frame_opts.snip_edges = false; | 41 | opts_.frame_opts.snip_edges = false; |
| 41 | opts_.frame_opts.samp_freq = config.sampling_rate; | 42 | opts_.frame_opts.samp_freq = config.sampling_rate; |
| 42 | 43 | ||
| 43 | - // cache 100 seconds of feature frames, which is more than enough | ||
| 44 | - // for real needs | ||
| 45 | - opts_.frame_opts.max_feature_vectors = 100 * 100; | 44 | + opts_.frame_opts.max_feature_vectors = config.max_feature_vectors; |
| 46 | 45 | ||
| 47 | opts_.mel_opts.num_bins = config.feature_dim; | 46 | opts_.mel_opts.num_bins = config.feature_dim; |
| 48 | 47 |
| @@ -16,6 +16,7 @@ namespace sherpa_onnx { | @@ -16,6 +16,7 @@ namespace sherpa_onnx { | ||
| 16 | struct FeatureExtractorConfig { | 16 | struct FeatureExtractorConfig { |
| 17 | float sampling_rate = 16000; | 17 | float sampling_rate = 16000; |
| 18 | int32_t feature_dim = 80; | 18 | int32_t feature_dim = 80; |
| 19 | + int32_t max_feature_vectors = -1; | ||
| 19 | 20 | ||
| 20 | std::string ToString() const; | 21 | std::string ToString() const; |
| 21 | 22 |
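
The 100-second cache that feature-extractor.cc used to hard-code is now the max_feature_vectors knob, with -1 apparently meaning "cache everything" (kaldi's convention for this option). The microphone examples above pass 100, i.e. roughly one second at the usual 10 ms frame shift. A sketch of setting it directly:

#include <cstdio>

#include "sherpa-onnx/csrc/feature-extractor.h"

int main() {
  sherpa_onnx::FeatureExtractorConfig config;
  config.sampling_rate = 16000;
  config.feature_dim = 80;
  config.max_feature_vectors = 100;  // ~1 s of frames at a 10 ms frame shift

  // ToString() now reports the new field as well.
  printf("%s\n", config.ToString().c_str());
  return 0;
}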
| @@ -18,7 +18,7 @@ namespace sherpa_onnx { | @@ -18,7 +18,7 @@ namespace sherpa_onnx { | ||
| 18 | 18 | ||
| 19 | struct Hypothesis { | 19 | struct Hypothesis { |
| 20 | // The predicted tokens so far. Newly predicted tokens are appended. | 20 | // The predicted tokens so far. Newly predicted tokens are appended. |
| 21 | - std::vector<int32_t> ys; | 21 | + std::vector<int64_t> ys; |
| 22 | 22 | ||
| 23 | // timestamps[i] contains the frame number after subsampling | 23 | // timestamps[i] contains the frame number after subsampling |
| 24 | // on which ys[i] is decoded. | 24 | // on which ys[i] is decoded. |
| @@ -30,7 +30,7 @@ struct Hypothesis { | @@ -30,7 +30,7 @@ struct Hypothesis { | ||
| 30 | int32_t num_trailing_blanks = 0; | 30 | int32_t num_trailing_blanks = 0; |
| 31 | 31 | ||
| 32 | Hypothesis() = default; | 32 | Hypothesis() = default; |
| 33 | - Hypothesis(const std::vector<int32_t> &ys, double log_prob) | 33 | + Hypothesis(const std::vector<int64_t> &ys, double log_prob) |
| 34 | : ys(ys), log_prob(log_prob) {} | 34 | : ys(ys), log_prob(log_prob) {} |
| 35 | 35 | ||
| 36 | // If two Hypotheses have the same `Key`, then they contain | 36 | // If two Hypotheses have the same `Key`, then they contain |
| @@ -43,7 +43,7 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) { | @@ -43,7 +43,7 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) { | ||
| 43 | "True to enable endpoint detection. False to disable it."); | 43 | "True to enable endpoint detection. False to disable it."); |
| 44 | po->Register("max-active-paths", &max_active_paths, | 44 | po->Register("max-active-paths", &max_active_paths, |
| 45 | "beam size used in modified beam search."); | 45 | "beam size used in modified beam search."); |
| 46 | - po->Register("decoding-mothod", &decoding_method, | 46 | + po->Register("decoding-method", &decoding_method, |
| 47 | "decoding method," | 47 | "decoding method," |
| 48 | "now support greedy_search and modified_beam_search."); | 48 | "now support greedy_search and modified_beam_search."); |
| 49 | } | 49 | } |
| @@ -59,8 +59,8 @@ std::string OnlineRecognizerConfig::ToString() const { | @@ -59,8 +59,8 @@ std::string OnlineRecognizerConfig::ToString() const { | ||
| 59 | os << "feat_config=" << feat_config.ToString() << ", "; | 59 | os << "feat_config=" << feat_config.ToString() << ", "; |
| 60 | os << "model_config=" << model_config.ToString() << ", "; | 60 | os << "model_config=" << model_config.ToString() << ", "; |
| 61 | os << "endpoint_config=" << endpoint_config.ToString() << ", "; | 61 | os << "endpoint_config=" << endpoint_config.ToString() << ", "; |
| 62 | - os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ","; | ||
| 63 | - os << "max_active_paths=" << max_active_paths << ","; | 62 | + os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", "; |
| 63 | + os << "max_active_paths=" << max_active_paths << ", "; | ||
| 64 | os << "decoding_method=\"" << decoding_method << "\")"; | 64 | os << "decoding_method=\"" << decoding_method << "\")"; |
| 65 | 65 | ||
| 66 | return os.str(); | 66 | return os.str(); |
| @@ -187,16 +187,14 @@ class OnlineRecognizer::Impl { | @@ -187,16 +187,14 @@ class OnlineRecognizer::Impl { | ||
| 187 | } | 187 | } |
| 188 | 188 | ||
| 189 | void Reset(OnlineStream *s) const { | 189 | void Reset(OnlineStream *s) const { |
| 190 | - // reset result, neural network model state, and | ||
| 191 | - // the feature extractor state | ||
| 192 | - | ||
| 193 | - // reset result | 190 | + // we keep the decoder_out |
| 191 | + decoder_->UpdateDecoderOut(&s->GetResult()); | ||
| 192 | + Ort::Value decoder_out = std::move(s->GetResult().decoder_out); | ||
| 194 | s->SetResult(decoder_->GetEmptyResult()); | 193 | s->SetResult(decoder_->GetEmptyResult()); |
| 194 | + s->GetResult().decoder_out = std::move(decoder_out); | ||
| 195 | 195 | ||
| 196 | - // reset neural network model state | ||
| 197 | - s->SetStates(model_->GetEncoderInitStates()); | ||
| 198 | - | ||
| 199 | - // reset feature extractor | 196 | + // Note: We only update counters. The underlying audio samples |
| 197 | + // are not discarded. | ||
| 200 | s->Reset(); | 198 | s->Reset(); |
| 201 | } | 199 | } |
| 202 | 200 |
| @@ -33,21 +33,26 @@ struct OnlineRecognizerConfig { | @@ -33,21 +33,26 @@ struct OnlineRecognizerConfig { | ||
| 33 | OnlineTransducerModelConfig model_config; | 33 | OnlineTransducerModelConfig model_config; |
| 34 | EndpointConfig endpoint_config; | 34 | EndpointConfig endpoint_config; |
| 35 | bool enable_endpoint = true; | 35 | bool enable_endpoint = true; |
| 36 | - int32_t max_active_paths = 4; | ||
| 37 | 36 | ||
| 38 | - std::string decoding_method = "modified_beam_search"; | 37 | + std::string decoding_method = "greedy_search"; |
| 39 | // now support modified_beam_search and greedy_search | 38 | // now support modified_beam_search and greedy_search |
| 40 | 39 | ||
| 40 | + int32_t max_active_paths = 4; // used only for modified_beam_search | ||
| 41 | + | ||
| 41 | OnlineRecognizerConfig() = default; | 42 | OnlineRecognizerConfig() = default; |
| 42 | 43 | ||
| 43 | OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config, | 44 | OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config, |
| 44 | const OnlineTransducerModelConfig &model_config, | 45 | const OnlineTransducerModelConfig &model_config, |
| 45 | const EndpointConfig &endpoint_config, | 46 | const EndpointConfig &endpoint_config, |
| 46 | - bool enable_endpoint) | 47 | + bool enable_endpoint, |
| 48 | + const std::string &decoding_method, | ||
| 49 | + int32_t max_active_paths) | ||
| 47 | : feat_config(feat_config), | 50 | : feat_config(feat_config), |
| 48 | model_config(model_config), | 51 | model_config(model_config), |
| 49 | endpoint_config(endpoint_config), | 52 | endpoint_config(endpoint_config), |
| 50 | - enable_endpoint(enable_endpoint) {} | 53 | + enable_endpoint(enable_endpoint), |
| 54 | + decoding_method(decoding_method), | ||
| 55 | + max_active_paths(max_active_paths) {} | ||
| 51 | 56 | ||
| 52 | void Register(ParseOptions *po); | 57 | void Register(ParseOptions *po); |
| 53 | bool Validate() const; | 58 | bool Validate() const; |
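
The convenience constructor grows two trailing parameters, so any caller of the old four-argument form (the Python bindings, presumably) must be updated in the same commit. A hedged sketch of the new call, with default-constructed sub-configs purely for illustration:

#include "sherpa-onnx/csrc/online-recognizer.h"

int main() {
  sherpa_onnx::FeatureExtractorConfig feat_config;
  sherpa_onnx::OnlineTransducerModelConfig model_config;
  sherpa_onnx::EndpointConfig endpoint_config;

  sherpa_onnx::OnlineRecognizerConfig config(
      feat_config, model_config, endpoint_config,
      /*enable_endpoint=*/true,
      /*decoding_method=*/"modified_beam_search",
      /*max_active_paths=*/4);
  return 0;
}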
| @@ -22,18 +22,21 @@ class OnlineStream::Impl { | @@ -22,18 +22,21 @@ class OnlineStream::Impl { | ||
| 22 | 22 | ||
| 23 | void InputFinished() { feat_extractor_.InputFinished(); } | 23 | void InputFinished() { feat_extractor_.InputFinished(); } |
| 24 | 24 | ||
| 25 | - int32_t NumFramesReady() const { return feat_extractor_.NumFramesReady(); } | 25 | + int32_t NumFramesReady() const { |
| 26 | + return feat_extractor_.NumFramesReady() - start_frame_index_; | ||
| 27 | + } | ||
| 26 | 28 | ||
| 27 | bool IsLastFrame(int32_t frame) const { | 29 | bool IsLastFrame(int32_t frame) const { |
| 28 | return feat_extractor_.IsLastFrame(frame); | 30 | return feat_extractor_.IsLastFrame(frame); |
| 29 | } | 31 | } |
| 30 | 32 | ||
| 31 | std::vector<float> GetFrames(int32_t frame_index, int32_t n) const { | 33 | std::vector<float> GetFrames(int32_t frame_index, int32_t n) const { |
| 32 | - return feat_extractor_.GetFrames(frame_index, n); | 34 | + return feat_extractor_.GetFrames(frame_index + start_frame_index_, n); |
| 33 | } | 35 | } |
| 34 | 36 | ||
| 35 | void Reset() { | 37 | void Reset() { |
| 36 | - feat_extractor_.Reset(); | 38 | + // we don't reset the feature extractor |
| 39 | + start_frame_index_ += num_processed_frames_; | ||
| 37 | num_processed_frames_ = 0; | 40 | num_processed_frames_ = 0; |
| 38 | } | 41 | } |
| 39 | 42 | ||
| @@ -41,7 +44,7 @@ class OnlineStream::Impl { | @@ -41,7 +44,7 @@ class OnlineStream::Impl { | ||
| 41 | 44 | ||
| 42 | void SetResult(const OnlineTransducerDecoderResult &r) { result_ = r; } | 45 | void SetResult(const OnlineTransducerDecoderResult &r) { result_ = r; } |
| 43 | 46 | ||
| 44 | - const OnlineTransducerDecoderResult &GetResult() const { return result_; } | 47 | + OnlineTransducerDecoderResult &GetResult() { return result_; } |
| 45 | 48 | ||
| 46 | int32_t FeatureDim() const { return feat_extractor_.FeatureDim(); } | 49 | int32_t FeatureDim() const { return feat_extractor_.FeatureDim(); } |
| 47 | 50 | ||
| @@ -54,6 +57,7 @@ class OnlineStream::Impl { | @@ -54,6 +57,7 @@ class OnlineStream::Impl { | ||
| 54 | private: | 57 | private: |
| 55 | FeatureExtractor feat_extractor_; | 58 | FeatureExtractor feat_extractor_; |
| 56 | int32_t num_processed_frames_ = 0; // before subsampling | 59 | int32_t num_processed_frames_ = 0; // before subsampling |
| 60 | + int32_t start_frame_index_ = 0; // never reset | ||
| 57 | OnlineTransducerDecoderResult result_; | 61 | OnlineTransducerDecoderResult result_; |
| 58 | std::vector<Ort::Value> states_; | 62 | std::vector<Ort::Value> states_; |
| 59 | }; | 63 | }; |
| @@ -93,7 +97,7 @@ void OnlineStream::SetResult(const OnlineTransducerDecoderResult &r) { | @@ -93,7 +97,7 @@ void OnlineStream::SetResult(const OnlineTransducerDecoderResult &r) { | ||
| 93 | impl_->SetResult(r); | 97 | impl_->SetResult(r); |
| 94 | } | 98 | } |
| 95 | 99 | ||
| 96 | -const OnlineTransducerDecoderResult &OnlineStream::GetResult() const { | 100 | +OnlineTransducerDecoderResult &OnlineStream::GetResult() { |
| 97 | return impl_->GetResult(); | 101 | return impl_->GetResult(); |
| 98 | } | 102 | } |
| 99 | 103 |
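
The heart of the "don't reset the feature extractor" change is a moving origin: Reset() now advances start_frame_index_ instead of clearing cached features, and NumFramesReady()/GetFrames() translate between stream-relative and extractor-absolute frame indices. The arithmetic, sketched standalone with the extractor reduced to a frame counter:

#include <cassert>
#include <cstdint>

// Standalone sketch of the frame-offset bookkeeping added above.
struct StreamOffsets {
  int32_t extractor_frames = 0;      // produced by the never-reset extractor
  int32_t start_frame_index = 0;     // origin of the current segment; never decreases
  int32_t num_processed_frames = 0;  // frames decoded in the current segment

  int32_t NumFramesReady() const { return extractor_frames - start_frame_index; }

  int32_t ToExtractorIndex(int32_t frame_index) const {
    return frame_index + start_frame_index;  // what GetFrames() now does
  }

  void Reset() {  // called on endpoint: keep features, move the origin
    start_frame_index += num_processed_frames;
    num_processed_frames = 0;
  }
};

int main() {
  StreamOffsets s;
  s.extractor_frames = 120;
  s.num_processed_frames = 100;
  s.Reset();
  assert(s.NumFramesReady() == 20);      // the 20 undecoded frames survive the reset
  assert(s.ToExtractorIndex(0) == 100);  // stream frame 0 maps past the old segment
  return 0;
}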
| @@ -63,7 +63,7 @@ class OnlineStream { | @@ -63,7 +63,7 @@ class OnlineStream { | ||
| 63 | int32_t &GetNumProcessedFrames(); | 63 | int32_t &GetNumProcessedFrames(); |
| 64 | 64 | ||
| 65 | void SetResult(const OnlineTransducerDecoderResult &r); | 65 | void SetResult(const OnlineTransducerDecoderResult &r); |
| 66 | - const OnlineTransducerDecoderResult &GetResult() const; | 66 | + OnlineTransducerDecoderResult &GetResult(); |
| 67 | 67 | ||
| 68 | void SetStates(std::vector<Ort::Value> states); | 68 | void SetStates(std::vector<Ort::Value> states); |
| 69 | std::vector<Ort::Value> &GetStates(); | 69 | std::vector<Ort::Value> &GetStates(); |
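The net effect of the stream changes above: an endpoint reset no longer wipes the feature extractor. The stream instead remembers how many frames the finished segment consumed (start_frame_index_, which is never reset) and exposes a view that starts after them, so per-segment frame indices still begin at 0 while the extractor keeps its internal state. A minimal Python sketch of that bookkeeping (hypothetical names, not the actual API):

    class StreamView:
        """Per-segment frame indexing on top of an extractor that is never reset."""

        def __init__(self, extractor):
            self.extractor = extractor
            self.start_frame_index = 0   # never reset
            self.num_processed_frames = 0

        def num_frames_ready(self):
            return self.extractor.num_frames_ready() - self.start_frame_index

        def get_frames(self, frame_index, n):
            return self.extractor.get_frames(frame_index + self.start_frame_index, n)

        def reset(self):
            # Keep the extractor's state; just slide the window forward.
            self.start_frame_index += self.num_processed_frames
            self.num_processed_frames = 0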
| 1 | +// sherpa-onnx/csrc/online-transducer-decoder.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/online-transducer-decoder.h" | ||
| 6 | + | ||
| 7 | +#include <utility> | ||
| 8 | +#include <vector> | ||
| 9 | + | ||
| 10 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 11 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 12 | + | ||
| 13 | +namespace sherpa_onnx { | ||
| 14 | + | ||
| 15 | +OnlineTransducerDecoderResult::OnlineTransducerDecoderResult( | ||
| 16 | + const OnlineTransducerDecoderResult &other) | ||
| 17 | + : OnlineTransducerDecoderResult() { | ||
| 18 | + *this = other; | ||
| 19 | +} | ||
| 20 | + | ||
| 21 | +OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=( | ||
| 22 | + const OnlineTransducerDecoderResult &other) { | ||
| 23 | + if (this == &other) { | ||
| 24 | + return *this; | ||
| 25 | + } | ||
| 26 | + | ||
| 27 | + tokens = other.tokens; | ||
| 28 | + num_trailing_blanks = other.num_trailing_blanks; | ||
| 29 | + | ||
| 30 | + Ort::AllocatorWithDefaultOptions allocator; | ||
| 31 | + if (other.decoder_out) { | ||
| 32 | + decoder_out = Clone(allocator, &other.decoder_out); | ||
| 33 | + } | ||
| 34 | + | ||
| 35 | + hyps = other.hyps; | ||
| 36 | + | ||
| 37 | + return *this; | ||
| 38 | +} | ||
| 39 | + | ||
| 40 | +OnlineTransducerDecoderResult::OnlineTransducerDecoderResult( | ||
| 41 | + OnlineTransducerDecoderResult &&other) | ||
| 42 | + : OnlineTransducerDecoderResult() { | ||
| 43 | + *this = std::move(other); | ||
| 44 | +} | ||
| 45 | + | ||
| 46 | +OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=( | ||
| 47 | + OnlineTransducerDecoderResult &&other) { | ||
| 48 | + if (this == &other) { | ||
| 49 | + return *this; | ||
| 50 | + } | ||
| 51 | + | ||
| 52 | + tokens = std::move(other.tokens); | ||
| 53 | + num_trailing_blanks = other.num_trailing_blanks; | ||
| 54 | + decoder_out = std::move(other.decoder_out); | ||
| 55 | + hyps = std::move(other.hyps); | ||
| 56 | + | ||
| 57 | + return *this; | ||
| 58 | +} | ||
| 59 | + | ||
| 60 | +} // namespace sherpa_onnx |
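Ort::Value is move-only and owns its tensor buffer, so the compiler-generated copy operations for this struct would be deleted; the hand-written copy constructor and copy assignment above therefore clone decoder_out via Clone() rather than share it. A rough Python analogy of the pattern (hypothetical Result class, not the real binding):

    class Result:
        """tokens and hyps copy cheaply; decoder_out stands in for a
        uniquely owned buffer that must be cloned explicitly."""

        def __init__(self):
            self.tokens = []
            self.num_trailing_blanks = 0
            self.decoder_out = None  # stand-in for Ort::Value

        def clone(self):
            r = Result()
            r.tokens = list(self.tokens)
            r.num_trailing_blanks = self.num_trailing_blanks
            if self.decoder_out is not None:
                r.decoder_out = bytearray(self.decoder_out)  # deep copy, like Clone()
            return r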
| @@ -19,8 +19,24 @@ struct OnlineTransducerDecoderResult { | @@ -19,8 +19,24 @@ struct OnlineTransducerDecoderResult { | ||
| 19 | /// number of trailing blank frames decoded so far | 19 | /// number of trailing blank frames decoded so far |
| 20 | int32_t num_trailing_blanks = 0; | 20 | int32_t num_trailing_blanks = 0; |
| 21 | 21 | ||
| 22 | + // Cache decoder_out for endpointing | ||
| 23 | + Ort::Value decoder_out; | ||
| 24 | + | ||
| 22 | // used only in modified beam_search | 25 | // used only in modified beam_search |
| 23 | Hypotheses hyps; | 26 | Hypotheses hyps; |
| 27 | + | ||
| 28 | + OnlineTransducerDecoderResult() | ||
| 29 | + : tokens{}, num_trailing_blanks(0), decoder_out{nullptr}, hyps{} {} | ||
| 30 | + | ||
| 31 | + OnlineTransducerDecoderResult(const OnlineTransducerDecoderResult &other); | ||
| 32 | + | ||
| 33 | + OnlineTransducerDecoderResult &operator=( | ||
| 34 | + const OnlineTransducerDecoderResult &other); | ||
| 35 | + | ||
| 36 | + OnlineTransducerDecoderResult(OnlineTransducerDecoderResult &&other); | ||
| 37 | + | ||
| 38 | + OnlineTransducerDecoderResult &operator=( | ||
| 39 | + OnlineTransducerDecoderResult &&other); | ||
| 24 | }; | 40 | }; |
| 25 | 41 | ||
| 26 | class OnlineTransducerDecoder { | 42 | class OnlineTransducerDecoder { |
| @@ -53,6 +69,9 @@ class OnlineTransducerDecoder { | @@ -53,6 +69,9 @@ class OnlineTransducerDecoder { | ||
| 53 | */ | 69 | */ |
| 54 | virtual void Decode(Ort::Value encoder_out, | 70 | virtual void Decode(Ort::Value encoder_out, |
| 55 | std::vector<OnlineTransducerDecoderResult> *result) = 0; | 71 | std::vector<OnlineTransducerDecoderResult> *result) = 0; |
| 72 | + | ||
| 73 | + // Used for endpointing: we need to keep decoder_out after reset. | ||
| 74 | + virtual void UpdateDecoderOut(OnlineTransducerDecoderResult *result) {} | ||
| 56 | }; | 75 | }; |
| 57 | 76 | ||
| 58 | } // namespace sherpa_onnx | 77 | } // namespace sherpa_onnx |
| @@ -13,6 +13,43 @@ | @@ -13,6 +13,43 @@ | ||
| 13 | 13 | ||
| 14 | namespace sherpa_onnx { | 14 | namespace sherpa_onnx { |
| 15 | 15 | ||
| 16 | +static void UseCachedDecoderOut( | ||
| 17 | + const std::vector<OnlineTransducerDecoderResult> &results, | ||
| 18 | + Ort::Value *decoder_out) { | ||
| 19 | + std::vector<int64_t> shape = | ||
| 20 | + decoder_out->GetTensorTypeAndShapeInfo().GetShape(); | ||
| 21 | + float *dst = decoder_out->GetTensorMutableData<float>(); | ||
| 22 | + for (const auto &r : results) { | ||
| 23 | + if (r.decoder_out) { | ||
| 24 | + const float *src = r.decoder_out.GetTensorData<float>(); | ||
| 25 | + std::copy(src, src + shape[1], dst); | ||
| 26 | + } | ||
| 27 | + dst += shape[1]; | ||
| 28 | + } | ||
| 29 | +} | ||
| 30 | + | ||
| 31 | +static void UpdateCachedDecoderOut( | ||
| 32 | + OrtAllocator *allocator, const Ort::Value *decoder_out, | ||
| 33 | + std::vector<OnlineTransducerDecoderResult> *results) { | ||
| 34 | + std::vector<int64_t> shape = | ||
| 35 | + decoder_out->GetTensorTypeAndShapeInfo().GetShape(); | ||
| 36 | + auto memory_info = | ||
| 37 | + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 38 | + std::array<int64_t, 2> v_shape{1, shape[1]}; | ||
| 39 | + | ||
| 40 | + const float *src = decoder_out->GetTensorData<float>(); | ||
| 41 | + for (auto &r : *results) { | ||
| 42 | + if (!r.decoder_out) { | ||
| 43 | + r.decoder_out = Ort::Value::CreateTensor<float>(allocator, v_shape.data(), | ||
| 44 | + v_shape.size()); | ||
| 45 | + } | ||
| 46 | + | ||
| 47 | + float *dst = r.decoder_out.GetTensorMutableData<float>(); | ||
| 48 | + std::copy(src, src + shape[1], dst); | ||
| 49 | + src += shape[1]; | ||
| 50 | + } | ||
| 51 | +} | ||
| 52 | + | ||
| 16 | OnlineTransducerDecoderResult | 53 | OnlineTransducerDecoderResult |
| 17 | OnlineTransducerGreedySearchDecoder::GetEmptyResult() const { | 54 | OnlineTransducerGreedySearchDecoder::GetEmptyResult() const { |
| 18 | int32_t context_size = model_->ContextSize(); | 55 | int32_t context_size = model_->ContextSize(); |
| @@ -53,6 +90,7 @@ void OnlineTransducerGreedySearchDecoder::Decode( | @@ -53,6 +90,7 @@ void OnlineTransducerGreedySearchDecoder::Decode( | ||
| 53 | 90 | ||
| 54 | Ort::Value decoder_input = model_->BuildDecoderInput(*result); | 91 | Ort::Value decoder_input = model_->BuildDecoderInput(*result); |
| 55 | Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input)); | 92 | Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input)); |
| 93 | + UseCachedDecoderOut(*result, &decoder_out); | ||
| 56 | 94 | ||
| 57 | for (int32_t t = 0; t != num_frames; ++t) { | 95 | for (int32_t t = 0; t != num_frames; ++t) { |
| 58 | Ort::Value cur_encoder_out = | 96 | Ort::Value cur_encoder_out = |
| @@ -77,10 +115,12 @@ void OnlineTransducerGreedySearchDecoder::Decode( | @@ -77,10 +115,12 @@ void OnlineTransducerGreedySearchDecoder::Decode( | ||
| 77 | } | 115 | } |
| 78 | } | 116 | } |
| 79 | if (emitted) { | 117 | if (emitted) { |
| 80 | - decoder_input = model_->BuildDecoderInput(*result); | 118 | + Ort::Value decoder_input = model_->BuildDecoderInput(*result); |
| 81 | decoder_out = model_->RunDecoder(std::move(decoder_input)); | 119 | decoder_out = model_->RunDecoder(std::move(decoder_input)); |
| 82 | } | 120 | } |
| 83 | } | 121 | } |
| 122 | + | ||
| 123 | + UpdateCachedDecoderOut(model_->Allocator(), &decoder_out, result); | ||
| 84 | } | 124 | } |
| 85 | 125 | ||
| 86 | } // namespace sherpa_onnx | 126 | } // namespace sherpa_onnx |
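The two helpers added at the top of this file let greedy-search decoder state survive across Decode() calls: on entry, UseCachedDecoderOut() overwrites the freshly computed decoder output with each stream's cached row from the previous call; on exit, UpdateCachedDecoderOut() writes the final row back into each result. A NumPy sketch of the idea (an analogy only, assuming one decoder-out row per stream):

    import numpy as np

    def use_cached_decoder_out(results, decoder_out: np.ndarray):
        # decoder_out has shape (batch_size, dim); restore cached rows.
        for i, r in enumerate(results):
            if r.decoder_out is not None:
                decoder_out[i] = r.decoder_out

    def update_cached_decoder_out(decoder_out: np.ndarray, results):
        # Save each stream's final decoder output for the next call.
        for i, r in enumerate(results):
            r.decoder_out = decoder_out[i].copy()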
| @@ -13,6 +13,29 @@ | @@ -13,6 +13,29 @@ | ||
| 13 | 13 | ||
| 14 | namespace sherpa_onnx { | 14 | namespace sherpa_onnx { |
| 15 | 15 | ||
| 16 | +static void UseCachedDecoderOut( | ||
| 17 | + const std::vector<int32_t> &hyps_num_split, | ||
| 18 | + const std::vector<OnlineTransducerDecoderResult> &results, | ||
| 19 | + int32_t context_size, Ort::Value *decoder_out) { | ||
| 20 | + std::vector<int64_t> shape = | ||
| 21 | + decoder_out->GetTensorTypeAndShapeInfo().GetShape(); | ||
| 22 | + | ||
| 23 | + float *dst = decoder_out->GetTensorMutableData<float>(); | ||
| 24 | + | ||
| 25 | + int32_t batch_size = static_cast<int32_t>(results.size()); | ||
| 26 | + for (int32_t i = 0; i != batch_size; ++i) { | ||
| 27 | + int32_t num_hyps = hyps_num_split[i + 1] - hyps_num_split[i]; | ||
| 28 | + if (num_hyps > 1 || !results[i].decoder_out) { | ||
| 29 | + dst += num_hyps * shape[1]; | ||
| 30 | + continue; | ||
| 31 | + } | ||
| 32 | + | ||
| 33 | + const float *src = results[i].decoder_out.GetTensorData<float>(); | ||
| 34 | + std::copy(src, src + shape[1], dst); | ||
| 35 | + dst += shape[1]; | ||
| 36 | + } | ||
| 37 | +} | ||
| 38 | + | ||
| 16 | static Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out, | 39 | static Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out, |
| 17 | const std::vector<int32_t> &hyps_num_split) { | 40 | const std::vector<int32_t> &hyps_num_split) { |
| 18 | std::vector<int64_t> cur_encoder_out_shape = | 41 | std::vector<int64_t> cur_encoder_out_shape = |
| @@ -50,7 +73,7 @@ OnlineTransducerModifiedBeamSearchDecoder::GetEmptyResult() const { | @@ -50,7 +73,7 @@ OnlineTransducerModifiedBeamSearchDecoder::GetEmptyResult() const { | ||
| 50 | int32_t context_size = model_->ContextSize(); | 73 | int32_t context_size = model_->ContextSize(); |
| 51 | int32_t blank_id = 0; // always 0 | 74 | int32_t blank_id = 0; // always 0 |
| 52 | OnlineTransducerDecoderResult r; | 75 | OnlineTransducerDecoderResult r; |
| 53 | - std::vector<int32_t> blanks(context_size, blank_id); | 76 | + std::vector<int64_t> blanks(context_size, blank_id); |
| 54 | Hypotheses blank_hyp({{blanks, 0}}); | 77 | Hypotheses blank_hyp({{blanks, 0}}); |
| 55 | r.hyps = std::move(blank_hyp); | 78 | r.hyps = std::move(blank_hyp); |
| 56 | return r; | 79 | return r; |
| @@ -110,6 +133,10 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( | @@ -110,6 +133,10 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( | ||
| 110 | 133 | ||
| 111 | Ort::Value decoder_input = model_->BuildDecoderInput(prev); | 134 | Ort::Value decoder_input = model_->BuildDecoderInput(prev); |
| 112 | Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input)); | 135 | Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input)); |
| 136 | + if (t == 0) { | ||
| 137 | + UseCachedDecoderOut(hyps_num_split, *result, model_->ContextSize(), | ||
| 138 | + &decoder_out); | ||
| 139 | + } | ||
| 113 | 140 | ||
| 114 | Ort::Value cur_encoder_out = | 141 | Ort::Value cur_encoder_out = |
| 115 | GetEncoderOutFrame(model_->Allocator(), &encoder_out, t); | 142 | GetEncoderOutFrame(model_->Allocator(), &encoder_out, t); |
| @@ -147,8 +174,23 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( | @@ -147,8 +174,23 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( | ||
| 147 | } | 174 | } |
| 148 | 175 | ||
| 149 | for (int32_t b = 0; b != batch_size; ++b) { | 176 | for (int32_t b = 0; b != batch_size; ++b) { |
| 150 | - (*result)[b].hyps = std::move(cur[b]); | 177 | + auto &hyps = cur[b]; |
| 178 | + auto best_hyp = hyps.GetMostProbable(true); | ||
| 179 | + | ||
| 180 | + (*result)[b].hyps = std::move(hyps); | ||
| 181 | + (*result)[b].tokens = std::move(best_hyp.ys); | ||
| 182 | + (*result)[b].num_trailing_blanks = best_hyp.num_trailing_blanks; | ||
| 183 | + } | ||
| 184 | +} | ||
| 185 | + | ||
| 186 | +void OnlineTransducerModifiedBeamSearchDecoder::UpdateDecoderOut( | ||
| 187 | + OnlineTransducerDecoderResult *result) { | ||
| 188 | + if (result->tokens.size() == model_->ContextSize()) { | ||
| 189 | + result->decoder_out = Ort::Value{nullptr}; | ||
| 190 | + return; | ||
| 151 | } | 191 | } |
| 192 | + Ort::Value decoder_input = model_->BuildDecoderInput({*result}); | ||
| 193 | + result->decoder_out = model_->RunDecoder(std::move(decoder_input)); | ||
| 152 | } | 194 | } |
| 153 | 195 | ||
| 154 | } // namespace sherpa_onnx | 196 | } // namespace sherpa_onnx |
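UpdateDecoderOut() supports endpointing: a result kept across a reset either holds only the context_size leading blanks (nothing decoded yet), in which case the cache is cleared, or it holds real tokens, in which case decoder_out is recomputed from them so the next chunk resumes from the correct decoder state. (Its counterpart UseCachedDecoderOut() restores the cache at t == 0, and only for streams whose hypothesis set has collapsed to a single path.) Roughly, in Python (hypothetical model interface):

    def update_decoder_out(result, model):
        if len(result.tokens) == model.context_size:
            result.decoder_out = None  # only the initial blanks; nothing to cache
            return
        decoder_input = model.build_decoder_input([result])
        result.decoder_out = model.run_decoder(decoder_input)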
| @@ -27,6 +27,8 @@ class OnlineTransducerModifiedBeamSearchDecoder | @@ -27,6 +27,8 @@ class OnlineTransducerModifiedBeamSearchDecoder | ||
| 27 | void Decode(Ort::Value encoder_out, | 27 | void Decode(Ort::Value encoder_out, |
| 28 | std::vector<OnlineTransducerDecoderResult> *result) override; | 28 | std::vector<OnlineTransducerDecoderResult> *result) override; |
| 29 | 29 | ||
| 30 | + void UpdateDecoderOut(OnlineTransducerDecoderResult *result) override; | ||
| 31 | + | ||
| 30 | private: | 32 | private: |
| 31 | OnlineTransducerModel *model_; // Not owned | 33 | OnlineTransducerModel *model_; // Not owned |
| 32 | int32_t max_active_paths_; | 34 | int32_t max_active_paths_; |
| @@ -21,7 +21,7 @@ static void Handler(int sig) { | @@ -21,7 +21,7 @@ static void Handler(int sig) { | ||
| 21 | } | 21 | } |
| 22 | 22 | ||
| 23 | int main(int32_t argc, char *argv[]) { | 23 | int main(int32_t argc, char *argv[]) { |
| 24 | - if (argc < 6 || argc > 7) { | 24 | + if (argc < 6 || argc > 8) { |
| 25 | const char *usage = R"usage( | 25 | const char *usage = R"usage( |
| 26 | Usage: | 26 | Usage: |
| 27 | ./bin/sherpa-onnx-alsa \ | 27 | ./bin/sherpa-onnx-alsa \ |
| @@ -30,7 +30,10 @@ Usage: | @@ -30,7 +30,10 @@ Usage: | ||
| 30 | /path/to/decoder.onnx \ | 30 | /path/to/decoder.onnx \ |
| 31 | /path/to/joiner.onnx \ | 31 | /path/to/joiner.onnx \ |
| 32 | device_name \ | 32 | device_name \ |
| 33 | - [num_threads] | 33 | + [num_threads [decoding_method]] |
| 34 | + | ||
| 35 | +Default value for num_threads is 2. | ||
| 36 | +Valid values for decoding_method: greedy_search (default), modified_beam_search. | ||
| 34 | 37 | ||
| 35 | Please refer to | 38 | Please refer to |
| 36 | https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html | 39 | https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html |
| @@ -79,6 +82,11 @@ as the device_name. | @@ -79,6 +82,11 @@ as the device_name. | ||
| 79 | config.model_config.num_threads = atoi(argv[6]); | 82 | config.model_config.num_threads = atoi(argv[6]); |
| 80 | } | 83 | } |
| 81 | 84 | ||
| 85 | + if (argc == 8) { | ||
| 86 | + config.decoding_method = argv[7]; | ||
| 87 | + } | ||
| 88 | + config.max_active_paths = 4; | ||
| 89 | + | ||
| 82 | config.enable_endpoint = true; | 90 | config.enable_endpoint = true; |
| 83 | 91 | ||
| 84 | config.endpoint_config.rule1.min_trailing_silence = 2.4; | 92 | config.endpoint_config.rule1.min_trailing_silence = 2.4; |
| @@ -36,7 +36,7 @@ static void Handler(int32_t sig) { | @@ -36,7 +36,7 @@ static void Handler(int32_t sig) { | ||
| 36 | } | 36 | } |
| 37 | 37 | ||
| 38 | int32_t main(int32_t argc, char *argv[]) { | 38 | int32_t main(int32_t argc, char *argv[]) { |
| 39 | - if (argc < 5 || argc > 6) { | 39 | + if (argc < 5 || argc > 7) { |
| 40 | const char *usage = R"usage( | 40 | const char *usage = R"usage( |
| 41 | Usage: | 41 | Usage: |
| 42 | ./bin/sherpa-onnx-microphone \ | 42 | ./bin/sherpa-onnx-microphone \ |
| @@ -44,7 +44,10 @@ Usage: | @@ -44,7 +44,10 @@ Usage: | ||
| 44 | /path/to/encoder.onnx\ | 44 | /path/to/encoder.onnx\ |
| 45 | /path/to/decoder.onnx\ | 45 | /path/to/decoder.onnx\ |
| 46 | /path/to/joiner.onnx\ | 46 | /path/to/joiner.onnx\ |
| 47 | - [num_threads] | 47 | + [num_threads [decoding_method]] |
| 48 | + | ||
| 49 | +Default value for num_threads is 2. | ||
| 50 | +Valid values for decoding_method: greedy_search (default), modified_beam_search. | ||
| 48 | 51 | ||
| 49 | Please refer to | 52 | Please refer to |
| 50 | https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html | 53 | https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html |
| @@ -70,6 +73,11 @@ for a list of pre-trained models to download. | @@ -70,6 +73,11 @@ for a list of pre-trained models to download. | ||
| 70 | config.model_config.num_threads = atoi(argv[5]); | 73 | config.model_config.num_threads = atoi(argv[5]); |
| 71 | } | 74 | } |
| 72 | 75 | ||
| 76 | + if (argc == 7) { | ||
| 77 | + config.decoding_method = argv[6]; | ||
| 78 | + } | ||
| 79 | + config.max_active_paths = 4; | ||
| 80 | + | ||
| 73 | config.enable_endpoint = true; | 81 | config.enable_endpoint = true; |
| 74 | 82 | ||
| 75 | config.endpoint_config.rule1.min_trailing_silence = 2.4; | 83 | config.endpoint_config.rule1.min_trailing_silence = 2.4; |
| @@ -14,7 +14,7 @@ | @@ -14,7 +14,7 @@ | ||
| 14 | #include "sherpa-onnx/csrc/wave-reader.h" | 14 | #include "sherpa-onnx/csrc/wave-reader.h" |
| 15 | 15 | ||
| 16 | int main(int32_t argc, char *argv[]) { | 16 | int main(int32_t argc, char *argv[]) { |
| 17 | - if (argc < 6 || argc > 7) { | 17 | + if (argc < 6 || argc > 8) { |
| 18 | const char *usage = R"usage( | 18 | const char *usage = R"usage( |
| 19 | Usage: | 19 | Usage: |
| 20 | ./bin/sherpa-onnx \ | 20 | ./bin/sherpa-onnx \ |
| @@ -22,7 +22,10 @@ Usage: | @@ -22,7 +22,10 @@ Usage: | ||
| 22 | /path/to/encoder.onnx \ | 22 | /path/to/encoder.onnx \ |
| 23 | /path/to/decoder.onnx \ | 23 | /path/to/decoder.onnx \ |
| 24 | /path/to/joiner.onnx \ | 24 | /path/to/joiner.onnx \ |
| 25 | - /path/to/foo.wav [num_threads] | 25 | + /path/to/foo.wav [num_threads [decoding_method]] |
| 26 | + | ||
| 27 | +Default value for num_threads is 2. | ||
| 28 | +Valid values for decoding_method: greedy_search (default), modified_beam_search. | ||
| 26 | 29 | ||
| 27 | Please refer to | 30 | Please refer to |
| 28 | https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html | 31 | https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html |
| @@ -45,9 +48,15 @@ for a list of pre-trained models to download. | @@ -45,9 +48,15 @@ for a list of pre-trained models to download. | ||
| 45 | std::string wav_filename = argv[5]; | 48 | std::string wav_filename = argv[5]; |
| 46 | 49 | ||
| 47 | config.model_config.num_threads = 2; | 50 | config.model_config.num_threads = 2; |
| 48 | - if (argc == 7) { | 51 | + if (argc >= 7 && atoi(argv[6]) > 0) {
| 49 | config.model_config.num_threads = atoi(argv[6]); | 52 | config.model_config.num_threads = atoi(argv[6]); |
| 50 | } | 53 | } |
| 54 | + | ||
| 55 | + if (argc == 8) { | ||
| 56 | + config.decoding_method = argv[7]; | ||
| 57 | + } | ||
| 58 | + config.max_active_paths = 4; | ||
| 59 | + | ||
| 51 | fprintf(stderr, "%s\n", config.ToString().c_str()); | 60 | fprintf(stderr, "%s\n", config.ToString().c_str()); |
| 52 | 61 | ||
| 53 | sherpa_onnx::OnlineRecognizer recognizer(config); | 62 | sherpa_onnx::OnlineRecognizer recognizer(config); |
| @@ -98,6 +107,7 @@ for a list of pre-trained models to download. | @@ -98,6 +107,7 @@ for a list of pre-trained models to download. | ||
| 98 | 1000.; | 107 | 1000.; |
| 99 | 108 | ||
| 100 | fprintf(stderr, "num threads: %d\n", config.model_config.num_threads); | 109 | fprintf(stderr, "num threads: %d\n", config.model_config.num_threads); |
| 110 | + fprintf(stderr, "decoding method: %s\n", config.decoding_method.c_str()); | ||
| 101 | 111 | ||
| 102 | fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); | 112 | fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); |
| 103 | float rtf = elapsed_seconds / duration; | 113 | float rtf = elapsed_seconds / duration; |
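With the new optional positional arguments, a typical invocation of the binary looks like the following (hypothetical model paths; per the usage string, num_threads must come before decoding_method):

    ./bin/sherpa-onnx tokens.txt encoder.onnx decoder.onnx joiner.onnx foo.wav 2 modified_beam_search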
| 1 | include_directories(${CMAKE_SOURCE_DIR}) | 1 | include_directories(${CMAKE_SOURCE_DIR}) |
| 2 | 2 | ||
| 3 | pybind11_add_module(_sherpa_onnx | 3 | pybind11_add_module(_sherpa_onnx |
| 4 | + display.cc | ||
| 5 | + endpoint.cc | ||
| 4 | features.cc | 6 | features.cc |
| 7 | + online-recognizer.cc | ||
| 8 | + online-stream.cc | ||
| 5 | online-transducer-model-config.cc | 9 | online-transducer-model-config.cc |
| 6 | sherpa-onnx.cc | 10 | sherpa-onnx.cc |
| 7 | - endpoint.cc | ||
| 8 | - online-stream.cc | ||
| 9 | - online-recognizer.cc | ||
| 10 | ) | 11 | ) |
| 11 | 12 | ||
| 12 | if(APPLE) | 13 | if(APPLE) |
sherpa-onnx/python/csrc/display.cc
0 → 100644
| 1 | +// sherpa-onnx/python/csrc/display.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/python/csrc/display.h" | ||
| 6 | + | ||
| 7 | +#include "sherpa-onnx/csrc/display.h" | ||
| 8 | + | ||
| 9 | +namespace sherpa_onnx { | ||
| 10 | + | ||
| 11 | +void PybindDisplay(py::module *m) { | ||
| 12 | + using PyClass = Display; | ||
| 13 | + py::class_<PyClass>(*m, "Display") | ||
| 14 | + .def(py::init<int32_t>(), py::arg("max_word_per_line") = 60) | ||
| 15 | + .def("print", &PyClass::Print, py::arg("idx"), py::arg("s")); | ||
| 16 | +} | ||
| 17 | + | ||
| 18 | +} // namespace sherpa_onnx |
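The binding above makes the C++ Display class callable from Python. A minimal use, with argument names taken directly from the .def() calls (idx is presumably the segment index being updated):

    from sherpa_onnx import Display

    display = Display(max_word_per_line=50)
    display.print(0, "hello world")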
sherpa-onnx/python/csrc/display.h
0 → 100644
| 1 | +// sherpa-onnx/python/csrc/display.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_PYTHON_CSRC_DISPLAY_H_ | ||
| 6 | +#define SHERPA_ONNX_PYTHON_CSRC_DISPLAY_H_ | ||
| 7 | + | ||
| 8 | +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
| 12 | +void PybindDisplay(py::module *m); | ||
| 13 | + | ||
| 14 | +} // namespace sherpa_onnx | ||
| 15 | + | ||
| 16 | +#endif // SHERPA_ONNX_PYTHON_CSRC_DISPLAY_H_ |
| @@ -11,10 +11,12 @@ namespace sherpa_onnx { | @@ -11,10 +11,12 @@ namespace sherpa_onnx { | ||
| 11 | static void PybindFeatureExtractorConfig(py::module *m) { | 11 | static void PybindFeatureExtractorConfig(py::module *m) { |
| 12 | using PyClass = FeatureExtractorConfig; | 12 | using PyClass = FeatureExtractorConfig; |
| 13 | py::class_<PyClass>(*m, "FeatureExtractorConfig") | 13 | py::class_<PyClass>(*m, "FeatureExtractorConfig") |
| 14 | - .def(py::init<float, int32_t>(), py::arg("sampling_rate") = 16000, | ||
| 15 | - py::arg("feature_dim") = 80) | 14 | + .def(py::init<float, int32_t, int32_t>(), |
| 15 | + py::arg("sampling_rate") = 16000, py::arg("feature_dim") = 80, | ||
| 16 | + py::arg("max_feature_vectors") = -1) | ||
| 16 | .def_readwrite("sampling_rate", &PyClass::sampling_rate) | 17 | .def_readwrite("sampling_rate", &PyClass::sampling_rate) |
| 17 | .def_readwrite("feature_dim", &PyClass::feature_dim) | 18 | .def_readwrite("feature_dim", &PyClass::feature_dim) |
| 19 | + .def_readwrite("max_feature_vectors", &PyClass::max_feature_vectors) | ||
| 18 | .def("__str__", &PyClass::ToString); | 20 | .def("__str__", &PyClass::ToString); |
| 19 | } | 21 | } |
| 20 | 22 |
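With the third constructor argument bound, the new field is settable from Python using the keyword names registered above:

    from _sherpa_onnx import FeatureExtractorConfig

    config = FeatureExtractorConfig(
        sampling_rate=16000,
        feature_dim=80,
        max_feature_vectors=-1,  # -1 caches every processed frame
    )
    print(config)  # __str__ is bound to ToString()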
| @@ -22,13 +22,16 @@ static void PybindOnlineRecognizerConfig(py::module *m) { | @@ -22,13 +22,16 @@ static void PybindOnlineRecognizerConfig(py::module *m) { | ||
| 22 | py::class_<PyClass>(*m, "OnlineRecognizerConfig") | 22 | py::class_<PyClass>(*m, "OnlineRecognizerConfig") |
| 23 | .def(py::init<const FeatureExtractorConfig &, | 23 | .def(py::init<const FeatureExtractorConfig &, |
| 24 | const OnlineTransducerModelConfig &, const EndpointConfig &, | 24 | const OnlineTransducerModelConfig &, const EndpointConfig &, |
| 25 | - bool>(), | 25 | + bool, const std::string &, int32_t>(), |
| 26 | py::arg("feat_config"), py::arg("model_config"), | 26 | py::arg("feat_config"), py::arg("model_config"), |
| 27 | - py::arg("endpoint_config"), py::arg("enable_endpoint")) | 27 | + py::arg("endpoint_config"), py::arg("enable_endpoint"), |
| 28 | + py::arg("decoding_method"), py::arg("max_active_paths")) | ||
| 28 | .def_readwrite("feat_config", &PyClass::feat_config) | 29 | .def_readwrite("feat_config", &PyClass::feat_config) |
| 29 | .def_readwrite("model_config", &PyClass::model_config) | 30 | .def_readwrite("model_config", &PyClass::model_config) |
| 30 | .def_readwrite("endpoint_config", &PyClass::endpoint_config) | 31 | .def_readwrite("endpoint_config", &PyClass::endpoint_config) |
| 31 | .def_readwrite("enable_endpoint", &PyClass::enable_endpoint) | 32 | .def_readwrite("enable_endpoint", &PyClass::enable_endpoint) |
| 33 | + .def_readwrite("decoding_method", &PyClass::decoding_method) | ||
| 34 | + .def_readwrite("max_active_paths", &PyClass::max_active_paths) | ||
| 32 | .def("__str__", &PyClass::ToString); | 35 | .def("__str__", &PyClass::ToString); |
| 33 | } | 36 | } |
| 34 | 37 |
| @@ -4,6 +4,7 @@ | @@ -4,6 +4,7 @@ | ||
| 4 | 4 | ||
| 5 | #include "sherpa-onnx/python/csrc/sherpa-onnx.h" | 5 | #include "sherpa-onnx/python/csrc/sherpa-onnx.h" |
| 6 | 6 | ||
| 7 | +#include "sherpa-onnx/python/csrc/display.h" | ||
| 7 | #include "sherpa-onnx/python/csrc/endpoint.h" | 8 | #include "sherpa-onnx/python/csrc/endpoint.h" |
| 8 | #include "sherpa-onnx/python/csrc/features.h" | 9 | #include "sherpa-onnx/python/csrc/features.h" |
| 9 | #include "sherpa-onnx/python/csrc/online-recognizer.h" | 10 | #include "sherpa-onnx/python/csrc/online-recognizer.h" |
| @@ -19,6 +20,8 @@ PYBIND11_MODULE(_sherpa_onnx, m) { | @@ -19,6 +20,8 @@ PYBIND11_MODULE(_sherpa_onnx, m) { | ||
| 19 | PybindOnlineStream(&m); | 20 | PybindOnlineStream(&m); |
| 20 | PybindEndpoint(&m); | 21 | PybindEndpoint(&m); |
| 21 | PybindOnlineRecognizer(&m); | 22 | PybindOnlineRecognizer(&m); |
| 23 | + | ||
| 24 | + PybindDisplay(&m); | ||
| 22 | } | 25 | } |
| 23 | 26 | ||
| 24 | } // namespace sherpa_onnx | 27 | } // namespace sherpa_onnx |
| 1 | -from _sherpa_onnx import ( | ||
| 2 | - EndpointConfig, | ||
| 3 | - FeatureExtractorConfig, | ||
| 4 | - OnlineRecognizerConfig, | ||
| 5 | - OnlineStream, | ||
| 6 | - OnlineTransducerModelConfig, | ||
| 7 | -) | 1 | +from _sherpa_onnx import Display |
| 8 | 2 | ||
| 9 | from .online_recognizer import OnlineRecognizer | 3 | from .online_recognizer import OnlineRecognizer |
| @@ -32,6 +32,9 @@ class OnlineRecognizer(object): | @@ -32,6 +32,9 @@ class OnlineRecognizer(object): | ||
| 32 | rule1_min_trailing_silence: float = 2.4, | 32 | rule1_min_trailing_silence: float = 2.4, |
| 33 | rule2_min_trailing_silence: float = 1.2, | 33 | rule2_min_trailing_silence: float = 1.2, |
| 34 | rule3_min_utterance_length: int = 20, | 34 | rule3_min_utterance_length: int = 20, |
| 35 | + decoding_method: str = "greedy_search", | ||
| 36 | + max_active_paths: int = 4, | ||
| 37 | + max_feature_vectors: int = -1, | ||
| 35 | ): | 38 | ): |
| 36 | """ | 39 | """ |
| 37 | Please refer to | 40 | Please refer to |
| @@ -74,6 +77,14 @@ class OnlineRecognizer(object): | @@ -74,6 +77,14 @@ class OnlineRecognizer(object): | ||
| 74 | Used only when enable_endpoint_detection is True. If the utterance | 77 | Used only when enable_endpoint_detection is True. If the utterance |
| 75 | length in seconds is larger than this value, we assume an endpoint | 78 | length in seconds is larger than this value, we assume an endpoint |
| 76 | is detected. | 79 | is detected. |
| 80 | + decoding_method: | ||
| 81 | + Valid values are greedy_search, modified_beam_search. | ||
| 82 | + max_active_paths: | ||
| 83 | + Used only when decoding_method is modified_beam_search. It specifies | ||
| 84 | + the maximum number of active paths during beam search. | ||
| 85 | + max_feature_vectors: | ||
| 86 | + Number of feature vectors to cache. -1 means to cache all feature | ||
| 87 | + frames that have been processed. | ||
| 77 | """ | 88 | """ |
| 78 | _assert_file_exists(tokens) | 89 | _assert_file_exists(tokens) |
| 79 | _assert_file_exists(encoder) | 90 | _assert_file_exists(encoder) |
| @@ -93,6 +104,7 @@ class OnlineRecognizer(object): | @@ -93,6 +104,7 @@ class OnlineRecognizer(object): | ||
| 93 | feat_config = FeatureExtractorConfig( | 104 | feat_config = FeatureExtractorConfig( |
| 94 | sampling_rate=sample_rate, | 105 | sampling_rate=sample_rate, |
| 95 | feature_dim=feature_dim, | 106 | feature_dim=feature_dim, |
| 107 | + max_feature_vectors=max_feature_vectors, | ||
| 96 | ) | 108 | ) |
| 97 | 109 | ||
| 98 | endpoint_config = EndpointConfig( | 110 | endpoint_config = EndpointConfig( |
| @@ -106,6 +118,8 @@ class OnlineRecognizer(object): | @@ -106,6 +118,8 @@ class OnlineRecognizer(object): | ||
| 106 | model_config=model_config, | 118 | model_config=model_config, |
| 107 | endpoint_config=endpoint_config, | 119 | endpoint_config=endpoint_config, |
| 108 | enable_endpoint=enable_endpoint_detection, | 120 | enable_endpoint=enable_endpoint_detection, |
| 121 | + decoding_method=decoding_method, | ||
| 122 | + max_active_paths=max_active_paths, | ||
| 109 | ) | 123 | ) |
| 110 | 124 | ||
| 111 | self.recognizer = _Recognizer(recognizer_config) | 125 | self.recognizer = _Recognizer(recognizer_config) |
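Putting the Python-side changes together, the new keyword arguments pass straight through to the C++ config. A sketch assuming the keyword names shown in the diff (the paths are placeholders, and the constructor may take further arguments not visible here):

    from sherpa_onnx import OnlineRecognizer

    recognizer = OnlineRecognizer(
        tokens="tokens.txt",
        encoder="encoder.onnx",
        decoder="decoder.onnx",
        joiner="joiner.onnx",
        decoding_method="modified_beam_search",
        max_active_paths=4,
    )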