Fangjun Kuang
Committed by GitHub

Code refactoring (#74)

* Don't reset model state and feature extractor on endpointing

* Support passing decoding_method from the command line

* Add modified_beam_search to the Python API (see the sketch after this list)

* Fix the C API example

* Fix style issues
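
This commit threads decoding_method and max_active_paths through the C API, the command-line binaries, and the Python wrapper. Below is a minimal sketch of the new Python surface, assuming a pretrained streaming transducer model; the file paths are placeholders, not files added by this commit.

import sherpa_onnx

# Sketch only: model paths are placeholders.
recognizer = sherpa_onnx.OnlineRecognizer(
    tokens="./tokens.txt",
    encoder="./encoder.onnx",
    decoder="./decoder.onnx",
    joiner="./joiner.onnx",
    num_threads=1,
    sample_rate=16000,
    feature_dim=80,
    decoding_method="modified_beam_search",  # greedy_search is the default
    max_active_paths=4,  # beam size; used only by modified_beam_search
)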
... ... @@ -36,4 +36,5 @@ tokens.txt
*.onnx
log.txt
tags
run-decode-file-python.sh
android/SherpaOnnx/app/src/main/assets/
... ...
... ... @@ -19,14 +19,16 @@ const char *kUsage =
" /path/to/encoder.onnx \\\n"
" /path/to/decoder.onnx \\\n"
" /path/to/joiner.onnx \\\n"
" /path/to/foo.wav [num_threads]\n"
" /path/to/foo.wav [num_threads [decoding_method]]\n"
"\n\n"
"Default num_threads is 1.\n"
"Valid decoding_method: greedy_search (default), modified_beam_search\n\n"
"Please refer to \n"
"https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html\n"
"for a list of pre-trained models to download.\n";
int32_t main(int32_t argc, char *argv[]) {
if (argc < 6 || argc > 7) {
if (argc < 6 || argc > 8) {
fprintf(stderr, "%s\n", kUsage);
return -1;
}
... ... @@ -36,13 +38,20 @@ int32_t main(int32_t argc, char *argv[]) {
config.model_config.decoder = argv[3];
config.model_config.joiner = argv[4];
int32_t num_threads = 4;
int32_t num_threads = 1;
if (argc == 7 && atoi(argv[6]) > 0) {
num_threads = atoi(argv[6]);
}
config.model_config.num_threads = num_threads;
config.model_config.debug = 0;
config.decoding_method = "greedy_search";
if (argc == 8) {
config.decoding_method = argv[7];
}
config.max_active_paths = 4;
config.feat_config.sample_rate = 16000;
config.feat_config.feature_dim = 80;
... ... @@ -54,6 +63,9 @@ int32_t main(int32_t argc, char *argv[]) {
SherpaOnnxOnlineRecognizer *recognizer = CreateOnlineRecognizer(&config);
SherpaOnnxOnlineStream *stream = CreateOnlineStream(recognizer);
SherpaOnnxDisplay *display = CreateDisplay(50);
int32_t segment_id = 0;
const char *wav_filename = argv[5];
FILE *fp = fopen(wav_filename, "rb");
if (!fp) {
... ... @@ -84,9 +96,18 @@ int32_t main(int32_t argc, char *argv[]) {
SherpaOnnxOnlineRecognizerResult *r =
GetOnlineStreamResult(recognizer, stream);
if (strlen(r->text)) {
fprintf(stderr, "%s\n", r->text);
SherpaOnnxPrint(display, segment_id, r->text);
}
if (IsEndpoint(recognizer, stream)) {
if (strlen(r->text)) {
++segment_id;
}
Reset(recognizer, stream);
}
DestroyOnlineRecognizerResult(r);
}
}
... ... @@ -103,14 +124,17 @@ int32_t main(int32_t argc, char *argv[]) {
SherpaOnnxOnlineRecognizerResult *r =
GetOnlineStreamResult(recognizer, stream);
if (strlen(r->text)) {
fprintf(stderr, "%s\n", r->text);
SherpaOnnxPrint(display, segment_id, r->text);
}
DestroyOnlineRecognizerResult(r);
DestroyDisplay(display);
DestoryOnlineStream(stream);
DestroyOnlineRecognizer(recognizer);
fprintf(stderr, "\n");
return 0;
}
... ...
... ... @@ -26,12 +26,17 @@ if [ ! -f ./sherpa-onnx-ffmpeg ]; then
make
fi
../ffmpeg-examples/sherpa-onnx-ffmpeg \
./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt \
./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \
./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \
./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \
./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/4.wav
for method in greedy_search modified_beam_search; do
echo "test method: $method"
../ffmpeg-examples/sherpa-onnx-ffmpeg \
./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt \
./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \
./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \
./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \
./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/0.wav \
2 \
$method
done
echo "Decoding a URL"
... ...
... ... @@ -7,7 +7,6 @@
#include "sherpa-onnx/c-api/c-api.h"
/*
* Copyright (c) 2010 Nicolas George
* Copyright (c) 2011 Stefano Sabatini
... ... @@ -43,14 +42,15 @@
#include <unistd.h>
extern "C" {
#include <libavcodec/avcodec.h>
#include <libavformat/avformat.h>
#include <libavfilter/buffersink.h>
#include <libavfilter/buffersrc.h>
#include <libavformat/avformat.h>
#include <libavutil/channel_layout.h>
#include <libavutil/opt.h>
}
static const char *filter_descr = "aresample=16000,aformat=sample_fmts=s16:channel_layouts=mono";
static const char *filter_descr =
"aresample=16000,aformat=sample_fmts=s16:channel_layouts=mono";
static AVFormatContext *fmt_ctx;
static AVCodecContext *dec_ctx;
... ... @@ -59,308 +59,172 @@ AVFilterContext *buffersrc_ctx;
AVFilterGraph *filter_graph;
static int audio_stream_index = -1;
static int open_input_file(const char *filename)
{
const AVCodec *dec;
int ret;
static int open_input_file(const char *filename) {
const AVCodec *dec;
int ret;
if ((ret = avformat_open_input(&fmt_ctx, filename, NULL, NULL)) < 0) {
av_log(NULL, AV_LOG_ERROR, "Cannot open input file %s\n", filename);
return ret;
}
if ((ret = avformat_open_input(&fmt_ctx, filename, NULL, NULL)) < 0) {
av_log(NULL, AV_LOG_ERROR, "Cannot open input file %s\n", filename);
return ret;
}
if ((ret = avformat_find_stream_info(fmt_ctx, NULL)) < 0) {
av_log(NULL, AV_LOG_ERROR, "Cannot find stream information\n");
return ret;
}
if ((ret = avformat_find_stream_info(fmt_ctx, NULL)) < 0) {
av_log(NULL, AV_LOG_ERROR, "Cannot find stream information\n");
return ret;
}
/* select the audio stream */
ret = av_find_best_stream(fmt_ctx, AVMEDIA_TYPE_AUDIO, -1, -1, &dec, 0);
if (ret < 0) {
av_log(NULL, AV_LOG_ERROR, "Cannot find an audio stream in the input file\n");
return ret;
}
audio_stream_index = ret;
/* create decoding context */
dec_ctx = avcodec_alloc_context3(dec);
if (!dec_ctx)
return AVERROR(ENOMEM);
avcodec_parameters_to_context(dec_ctx, fmt_ctx->streams[audio_stream_index]->codecpar);
/* init the audio decoder */
if ((ret = avcodec_open2(dec_ctx, dec, NULL)) < 0) {
av_log(NULL, AV_LOG_ERROR, "Cannot open audio decoder\n");
return ret;
}
/* select the audio stream */
ret = av_find_best_stream(fmt_ctx, AVMEDIA_TYPE_AUDIO, -1, -1, &dec, 0);
if (ret < 0) {
av_log(NULL, AV_LOG_ERROR,
"Cannot find an audio stream in the input file\n");
return ret;
}
audio_stream_index = ret;
/* create decoding context */
dec_ctx = avcodec_alloc_context3(dec);
if (!dec_ctx) return AVERROR(ENOMEM);
avcodec_parameters_to_context(dec_ctx,
fmt_ctx->streams[audio_stream_index]->codecpar);
/* init the audio decoder */
if ((ret = avcodec_open2(dec_ctx, dec, NULL)) < 0) {
av_log(NULL, AV_LOG_ERROR, "Cannot open audio decoder\n");
return ret;
}
return 0;
return 0;
}
static int init_filters(const char *filters_descr)
{
char args[512];
int ret = 0;
const AVFilter *abuffersrc = avfilter_get_by_name("abuffer");
const AVFilter *abuffersink = avfilter_get_by_name("abuffersink");
AVFilterInOut *outputs = avfilter_inout_alloc();
AVFilterInOut *inputs = avfilter_inout_alloc();
static const enum AVSampleFormat out_sample_fmts[] = { AV_SAMPLE_FMT_S16, AV_SAMPLE_FMT_NONE };
static const int out_sample_rates[] = { 16000, -1 };
const AVFilterLink *outlink;
AVRational time_base = fmt_ctx->streams[audio_stream_index]->time_base;
filter_graph = avfilter_graph_alloc();
if (!outputs || !inputs || !filter_graph) {
ret = AVERROR(ENOMEM);
goto end;
}
/* buffer audio source: the decoded frames from the decoder will be inserted here. */
if (dec_ctx->ch_layout.order == AV_CHANNEL_ORDER_UNSPEC)
av_channel_layout_default(&dec_ctx->ch_layout, dec_ctx->ch_layout.nb_channels);
ret = snprintf(args, sizeof(args),
"time_base=%d/%d:sample_rate=%d:sample_fmt=%s:channel_layout=",
time_base.num, time_base.den, dec_ctx->sample_rate,
av_get_sample_fmt_name(dec_ctx->sample_fmt));
av_channel_layout_describe(&dec_ctx->ch_layout, args + ret, sizeof(args) - ret);
ret = avfilter_graph_create_filter(&buffersrc_ctx, abuffersrc, "in",
args, NULL, filter_graph);
if (ret < 0) {
av_log(NULL, AV_LOG_ERROR, "Cannot create audio buffer source\n");
goto end;
}
/* buffer audio sink: to terminate the filter chain. */
ret = avfilter_graph_create_filter(&buffersink_ctx, abuffersink, "out",
NULL, NULL, filter_graph);
if (ret < 0) {
av_log(NULL, AV_LOG_ERROR, "Cannot create audio buffer sink\n");
goto end;
}
ret = av_opt_set_int_list(buffersink_ctx, "sample_fmts", out_sample_fmts, -1,
AV_OPT_SEARCH_CHILDREN);
if (ret < 0) {
av_log(NULL, AV_LOG_ERROR, "Cannot set output sample format\n");
goto end;
}
ret = av_opt_set(buffersink_ctx, "ch_layouts", "mono",
AV_OPT_SEARCH_CHILDREN);
if (ret < 0) {
av_log(NULL, AV_LOG_ERROR, "Cannot set output channel layout\n");
goto end;
}
ret = av_opt_set_int_list(buffersink_ctx, "sample_rates", out_sample_rates, -1,
AV_OPT_SEARCH_CHILDREN);
if (ret < 0) {
av_log(NULL, AV_LOG_ERROR, "Cannot set output sample rate\n");
goto end;
}
/*
* Set the endpoints for the filter graph. The filter_graph will
* be linked to the graph described by filters_descr.
*/
/*
* The buffer source output must be connected to the input pad of
* the first filter described by filters_descr; since the first
* filter input label is not specified, it is set to "in" by
* default.
*/
outputs->name = av_strdup("in");
outputs->filter_ctx = buffersrc_ctx;
outputs->pad_idx = 0;
outputs->next = NULL;
/*
* The buffer sink input must be connected to the output pad of
* the last filter described by filters_descr; since the last
* filter output label is not specified, it is set to "out" by
* default.
*/
inputs->name = av_strdup("out");
inputs->filter_ctx = buffersink_ctx;
inputs->pad_idx = 0;
inputs->next = NULL;
if ((ret = avfilter_graph_parse_ptr(filter_graph, filters_descr,
&inputs, &outputs, NULL)) < 0)
goto end;
if ((ret = avfilter_graph_config(filter_graph, NULL)) < 0)
goto end;
/* Print summary of the sink buffer
* Note: args buffer is reused to store channel layout string */
outlink = buffersink_ctx->inputs[0];
av_channel_layout_describe(&outlink->ch_layout, args, sizeof(args));
av_log(NULL, AV_LOG_INFO, "Output: srate:%dHz fmt:%s chlayout:%s\n",
(int)outlink->sample_rate,
(char *)av_x_if_null(av_get_sample_fmt_name((AVSampleFormat)outlink->format), "?"),
args);
static int init_filters(const char *filters_descr) {
char args[512];
int ret = 0;
const AVFilter *abuffersrc = avfilter_get_by_name("abuffer");
const AVFilter *abuffersink = avfilter_get_by_name("abuffersink");
AVFilterInOut *outputs = avfilter_inout_alloc();
AVFilterInOut *inputs = avfilter_inout_alloc();
static const enum AVSampleFormat out_sample_fmts[] = {AV_SAMPLE_FMT_S16,
AV_SAMPLE_FMT_NONE};
static const int out_sample_rates[] = {16000, -1};
const AVFilterLink *outlink;
AVRational time_base = fmt_ctx->streams[audio_stream_index]->time_base;
filter_graph = avfilter_graph_alloc();
if (!outputs || !inputs || !filter_graph) {
ret = AVERROR(ENOMEM);
goto end;
}
/* buffer audio source: the decoded frames from the decoder will be inserted
* here. */
if (dec_ctx->ch_layout.order == AV_CHANNEL_ORDER_UNSPEC)
av_channel_layout_default(&dec_ctx->ch_layout,
dec_ctx->ch_layout.nb_channels);
ret = snprintf(args, sizeof(args),
"time_base=%d/%d:sample_rate=%d:sample_fmt=%s:channel_layout=",
time_base.num, time_base.den, dec_ctx->sample_rate,
av_get_sample_fmt_name(dec_ctx->sample_fmt));
av_channel_layout_describe(&dec_ctx->ch_layout, args + ret,
sizeof(args) - ret);
ret = avfilter_graph_create_filter(&buffersrc_ctx, abuffersrc, "in", args,
NULL, filter_graph);
if (ret < 0) {
av_log(NULL, AV_LOG_ERROR, "Cannot create audio buffer source\n");
goto end;
}
/* buffer audio sink: to terminate the filter chain. */
ret = avfilter_graph_create_filter(&buffersink_ctx, abuffersink, "out", NULL,
NULL, filter_graph);
if (ret < 0) {
av_log(NULL, AV_LOG_ERROR, "Cannot create audio buffer sink\n");
goto end;
}
ret = av_opt_set_int_list(buffersink_ctx, "sample_fmts", out_sample_fmts, -1,
AV_OPT_SEARCH_CHILDREN);
if (ret < 0) {
av_log(NULL, AV_LOG_ERROR, "Cannot set output sample format\n");
goto end;
}
ret =
av_opt_set(buffersink_ctx, "ch_layouts", "mono", AV_OPT_SEARCH_CHILDREN);
if (ret < 0) {
av_log(NULL, AV_LOG_ERROR, "Cannot set output channel layout\n");
goto end;
}
ret = av_opt_set_int_list(buffersink_ctx, "sample_rates", out_sample_rates,
-1, AV_OPT_SEARCH_CHILDREN);
if (ret < 0) {
av_log(NULL, AV_LOG_ERROR, "Cannot set output sample rate\n");
goto end;
}
/*
* Set the endpoints for the filter graph. The filter_graph will
* be linked to the graph described by filters_descr.
*/
/*
* The buffer source output must be connected to the input pad of
* the first filter described by filters_descr; since the first
* filter input label is not specified, it is set to "in" by
* default.
*/
outputs->name = av_strdup("in");
outputs->filter_ctx = buffersrc_ctx;
outputs->pad_idx = 0;
outputs->next = NULL;
/*
* The buffer sink input must be connected to the output pad of
* the last filter described by filters_descr; since the last
* filter output label is not specified, it is set to "out" by
* default.
*/
inputs->name = av_strdup("out");
inputs->filter_ctx = buffersink_ctx;
inputs->pad_idx = 0;
inputs->next = NULL;
if ((ret = avfilter_graph_parse_ptr(filter_graph, filters_descr, &inputs,
&outputs, NULL)) < 0)
goto end;
if ((ret = avfilter_graph_config(filter_graph, NULL)) < 0) goto end;
/* Print summary of the sink buffer
* Note: args buffer is reused to store channel layout string */
outlink = buffersink_ctx->inputs[0];
av_channel_layout_describe(&outlink->ch_layout, args, sizeof(args));
av_log(NULL, AV_LOG_INFO, "Output: srate:%dHz fmt:%s chlayout:%s\n",
(int)outlink->sample_rate,
(char *)av_x_if_null(
av_get_sample_fmt_name((AVSampleFormat)outlink->format), "?"),
args);
end:
avfilter_inout_free(&inputs);
avfilter_inout_free(&outputs);
avfilter_inout_free(&inputs);
avfilter_inout_free(&outputs);
return ret;
return ret;
}
static void sherpa_decode_frame(const AVFrame *frame, SherpaOnnxOnlineRecognizer *recognizer,
SherpaOnnxOnlineStream* stream)
{
static void sherpa_decode_frame(const AVFrame *frame,
SherpaOnnxOnlineRecognizer *recognizer,
SherpaOnnxOnlineStream *stream,
SherpaOnnxDisplay *display,
int32_t *segment_id) {
#define N 3200 // 0.2 seconds of audio; the sample rate is fixed to 16 kHz
static float samples[N];
static int nb_samples = 0;
const int16_t *p = (int16_t*)frame->data[0];
if (frame->nb_samples + nb_samples > N) {
AcceptWaveform(stream, 16000, samples, nb_samples);
while (IsOnlineStreamReady(recognizer, stream)) {
DecodeOnlineStream(recognizer, stream);
}
if (IsEndpoint(recognizer, stream)) {
SherpaOnnxOnlineRecognizerResult *r =
GetOnlineStreamResult(recognizer, stream);
if (strlen(r->text)) {
fprintf(stderr, "%s\n", r->text);
}
DestroyOnlineRecognizerResult(r);
Reset(recognizer, stream);
}
nb_samples = 0;
}
for (int i = 0; i < frame->nb_samples; i++) {
samples[nb_samples++] = p[i] / 32768.;
}
}
static inline char *__av_err2str(int errnum)
{
static char str[AV_ERROR_MAX_STRING_SIZE];
memset(str, 0, sizeof(str));
return av_make_error_string(str, AV_ERROR_MAX_STRING_SIZE, errnum);
}
int main(int argc, char **argv)
{
int ret;
int num_threads = 4;
AVPacket *packet = av_packet_alloc();
AVFrame *frame = av_frame_alloc();
AVFrame *filt_frame = av_frame_alloc();
const char *kUsage =
"\n"
"Usage:\n"
" ./sherpa-onnx-ffmpeg \\\n"
" /path/to/tokens.txt \\\n"
" /path/to/encoder.onnx\\\n"
" /path/to/decoder.onnx\\\n"
" /path/to/joiner.onnx\\\n"
" /path/to/foo.wav [num_threads]"
"\n\n"
"Please refer to \n"
"https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html\n"
"for a list of pre-trained models to download.\n";
if (!packet || !frame || !filt_frame) {
fprintf(stderr, "Could not allocate frame or packet\n");
exit(1);
}
if (argc < 6 || argc > 7) {
fprintf(stderr, "%s\n", kUsage);
return -1;
}
SherpaOnnxOnlineRecognizerConfig config;
config.model_config.tokens = argv[1];
config.model_config.encoder = argv[2];
config.model_config.decoder = argv[3];
config.model_config.joiner = argv[4];
if (argc == 7 && atoi(argv[6]) > 0) {
num_threads = atoi(argv[6]);
}
config.model_config.num_threads = num_threads;
config.model_config.debug = 0;
config.feat_config.sample_rate = 16000;
config.feat_config.feature_dim = 80;
config.enable_endpoint = 1;
config.rule1_min_trailing_silence = 2.4;
config.rule2_min_trailing_silence = 1.2;
config.rule3_min_utterance_length = 300;
SherpaOnnxOnlineRecognizer *recognizer = CreateOnlineRecognizer(&config);
SherpaOnnxOnlineStream *stream = CreateOnlineStream(recognizer);
if ((ret = open_input_file(argv[5])) < 0)
exit(1);
if ((ret = init_filters(filter_descr)) < 0)
exit(1);
/* read all packets */
while (1) {
if ((ret = av_read_frame(fmt_ctx, packet)) < 0)
break;
if (packet->stream_index == audio_stream_index) {
ret = avcodec_send_packet(dec_ctx, packet);
if (ret < 0) {
av_log(NULL, AV_LOG_ERROR, "Error while sending a packet to the decoder\n");
break;
}
while (ret >= 0) {
ret = avcodec_receive_frame(dec_ctx, frame);
if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) {
break;
} else if (ret < 0) {
av_log(NULL, AV_LOG_ERROR, "Error while receiving a frame from the decoder\n");
exit(1);
}
if (ret >= 0) {
/* push the audio data from decoded frame into the filtergraph */
if (av_buffersrc_add_frame_flags(buffersrc_ctx, frame, AV_BUFFERSRC_FLAG_KEEP_REF) < 0) {
av_log(NULL, AV_LOG_ERROR, "Error while feeding the audio filtergraph\n");
break;
}
/* pull filtered audio from the filtergraph */
while (1) {
ret = av_buffersink_get_frame(buffersink_ctx, filt_frame);
if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF)
break;
if (ret < 0)
exit(1);
sherpa_decode_frame(filt_frame, recognizer, stream);
av_frame_unref(filt_frame);
}
av_frame_unref(frame);
}
}
}
av_packet_unref(packet);
}
// add some tail padding
float tail_paddings[4800] = {0}; // 0.3 seconds at 16 kHz sample rate
AcceptWaveform(stream, 16000, tail_paddings, 4800);
InputFinished(stream);
static float samples[N];
static int nb_samples = 0;
const int16_t *p = (int16_t *)frame->data[0];
if (frame->nb_samples + nb_samples > N) {
AcceptWaveform(stream, 16000, samples, nb_samples);
while (IsOnlineStreamReady(recognizer, stream)) {
DecodeOnlineStream(recognizer, stream);
}
... ... @@ -368,25 +232,180 @@ int main(int argc, char **argv)
SherpaOnnxOnlineRecognizerResult *r =
GetOnlineStreamResult(recognizer, stream);
if (strlen(r->text)) {
fprintf(stderr, "%s\n", r->text);
SherpaOnnxPrint(display, *segment_id, r->text);
}
if (IsEndpoint(recognizer, stream)) {
if (strlen(r->text)) {
++*segment_id;
}
Reset(recognizer, stream);
}
DestroyOnlineRecognizerResult(r);
nb_samples = 0;
}
DestoryOnlineStream(stream);
DestroyOnlineRecognizer(recognizer);
for (int i = 0; i < frame->nb_samples; i++) {
samples[nb_samples++] = p[i] / 32768.;
}
}
avfilter_graph_free(&filter_graph);
avcodec_free_context(&dec_ctx);
avformat_close_input(&fmt_ctx);
av_packet_free(&packet);
av_frame_free(&frame);
av_frame_free(&filt_frame);
static inline char *__av_err2str(int errnum) {
static char str[AV_ERROR_MAX_STRING_SIZE];
memset(str, 0, sizeof(str));
return av_make_error_string(str, AV_ERROR_MAX_STRING_SIZE, errnum);
}
if (ret < 0 && ret != AVERROR_EOF) {
fprintf(stderr, "Error occurred: %s\n", __av_err2str(ret));
exit(1);
}
int main(int argc, char **argv) {
int ret;
int num_threads = 1;
AVPacket *packet = av_packet_alloc();
AVFrame *frame = av_frame_alloc();
AVFrame *filt_frame = av_frame_alloc();
const char *kUsage =
"\n"
"Usage:\n"
" ./sherpa-onnx-ffmpeg \\\n"
" /path/to/tokens.txt \\\n"
" /path/to/encoder.onnx\\\n"
" /path/to/decoder.onnx\\\n"
" /path/to/joiner.onnx\\\n"
" /path/to/foo.wav [num_threads [decoding_method]]"
"\n\n"
"Default num_threads is 1.\n"
"Valid decoding_method: greedy_search (default), modified_beam_search\n\n"
"Please refer to \n"
"https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html\n"
"for a list of pre-trained models to download.\n";
if (!packet || !frame || !filt_frame) {
fprintf(stderr, "Could not allocate frame or packet\n");
exit(1);
}
if (argc < 6 || argc > 8) {
fprintf(stderr, "%s\n", kUsage);
return -1;
}
SherpaOnnxOnlineRecognizerConfig config;
config.model_config.tokens = argv[1];
config.model_config.encoder = argv[2];
config.model_config.decoder = argv[3];
config.model_config.joiner = argv[4];
if (argc == 7 && atoi(argv[6]) > 0) {
num_threads = atoi(argv[6]);
}
config.model_config.num_threads = num_threads;
config.model_config.debug = 0;
config.feat_config.sample_rate = 16000;
config.feat_config.feature_dim = 80;
config.decoding_method = "greedy_search";
if (argc == 8) {
config.decoding_method = argv[7];
}
config.max_active_paths = 4;
config.enable_endpoint = 1;
config.rule1_min_trailing_silence = 2.4;
config.rule2_min_trailing_silence = 1.2;
config.rule3_min_utterance_length = 300;
SherpaOnnxOnlineRecognizer *recognizer = CreateOnlineRecognizer(&config);
SherpaOnnxOnlineStream *stream = CreateOnlineStream(recognizer);
SherpaOnnxDisplay *display = CreateDisplay(50);
int32_t segment_id = 0;
if ((ret = open_input_file(argv[5])) < 0) exit(1);
if ((ret = init_filters(filter_descr)) < 0) exit(1);
/* read all packets */
while (1) {
if ((ret = av_read_frame(fmt_ctx, packet)) < 0) break;
if (packet->stream_index == audio_stream_index) {
ret = avcodec_send_packet(dec_ctx, packet);
if (ret < 0) {
av_log(NULL, AV_LOG_ERROR,
"Error while sending a packet to the decoder\n");
break;
}
while (ret >= 0) {
ret = avcodec_receive_frame(dec_ctx, frame);
if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) {
break;
} else if (ret < 0) {
av_log(NULL, AV_LOG_ERROR,
"Error while receiving a frame from the decoder\n");
exit(1);
}
return 0;
if (ret >= 0) {
/* push the audio data from decoded frame into the filtergraph */
if (av_buffersrc_add_frame_flags(buffersrc_ctx, frame,
AV_BUFFERSRC_FLAG_KEEP_REF) < 0) {
av_log(NULL, AV_LOG_ERROR,
"Error while feeding the audio filtergraph\n");
break;
}
/* pull filtered audio from the filtergraph */
while (1) {
ret = av_buffersink_get_frame(buffersink_ctx, filt_frame);
if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) break;
if (ret < 0) exit(1);
sherpa_decode_frame(filt_frame, recognizer, stream, display,
&segment_id);
av_frame_unref(filt_frame);
}
av_frame_unref(frame);
}
}
}
av_packet_unref(packet);
}
// add some tail padding
float tail_paddings[4800] = {0}; // 0.3 seconds at 16 kHz sample rate
AcceptWaveform(stream, 16000, tail_paddings, 4800);
InputFinished(stream);
while (IsOnlineStreamReady(recognizer, stream)) {
DecodeOnlineStream(recognizer, stream);
}
SherpaOnnxOnlineRecognizerResult *r =
GetOnlineStreamResult(recognizer, stream);
if (strlen(r->text)) {
SherpaOnnxPrint(display, segment_id, r->text);
}
DestroyOnlineRecognizerResult(r);
DestroyDisplay(display);
DestoryOnlineStream(stream);
DestroyOnlineRecognizer(recognizer);
avfilter_graph_free(&filter_graph);
avcodec_free_context(&dec_ctx);
avformat_close_input(&fmt_ctx);
av_packet_free(&packet);
av_frame_free(&frame);
av_frame_free(&filt_frame);
if (ret < 0 && ret != AVERROR_EOF) {
fprintf(stderr, "Error occurred: %s\n", __av_err2str(ret));
exit(1);
}
fprintf(stderr, "\n");
return 0;
}
... ...
... ... @@ -54,6 +54,20 @@ def get_args():
)
parser.add_argument(
"--num-threads",
type=int,
default=1,
help="Number of threads for neural network computation",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="Valid values are greedy_search and modified_beam_search",
)
parser.add_argument(
"--wave-filename",
type=str,
help="""Path to the wave filename. Must be 16 kHz,
... ... @@ -65,7 +79,6 @@ def get_args():
def main():
sample_rate = 16000
num_threads = 2
args = get_args()
assert_file_exists(args.encoder)
... ... @@ -81,9 +94,10 @@ def main():
encoder=args.encoder,
decoder=args.decoder,
joiner=args.joiner,
num_threads=num_threads,
num_threads=args.num_threads,
sample_rate=sample_rate,
feature_dim=80,
decoding_method=args.decoding_method,
)
with wave.open(args.wave_filename) as f:
assert f.getframerate() == sample_rate, f.getframerate()
... ... @@ -119,7 +133,8 @@ def main():
end_time = time.time()
elapsed_seconds = end_time - start_time
rtf = elapsed_seconds / duration
print(f"num_threads: {num_threads}")
print(f"num_threads: {args.num_threads}")
print(f"decoding_method: {args.decoding_method}")
print(f"Wave duration: {duration:.3f} s")
print(f"Elapsed time: {elapsed_seconds:.3f} s")
print(f"Real time factor (RTF): {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}")
... ...
... ... @@ -60,10 +60,10 @@ def get_args():
)
parser.add_argument(
"--wave-filename",
"--decoding-method",
type=str,
help="""Path to the wave filename. Must be 16 kHz,
mono with 16-bit samples""",
default="greedy_search",
help="Valid values are greedy_search and modified_beam_search",
)
return parser.parse_args()
... ... @@ -83,17 +83,23 @@ def create_recognizer():
encoder=args.encoder,
decoder=args.decoder,
joiner=args.joiner,
num_threads=1,
sample_rate=16000,
feature_dim=80,
enable_endpoint_detection=True,
rule1_min_trailing_silence=2.4,
rule2_min_trailing_silence=1.2,
rule3_min_utterance_length=300, # it essentially disables this rule
decoding_method=args.decoding_method,
max_feature_vectors=100, # 1 second
)
return recognizer
def main():
print("Started! Please speak")
recognizer = create_recognizer()
print("Started! Please speak")
sample_rate = 16000
samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms
last_result = ""
... ... @@ -101,6 +107,7 @@ def main():
last_result = ""
segment_id = 0
display = sherpa_onnx.Display(max_word_per_line=30)
with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s:
while True:
samples, _ = s.read(samples_per_read) # a blocking read
... ... @@ -115,7 +122,7 @@ def main():
if result and (last_result != result):
last_result = result
print(f"{segment_id}: {result}")
display.print(segment_id, result)
if is_endpoint:
if result:
... ...
... ... @@ -59,10 +59,10 @@ def get_args():
)
parser.add_argument(
"--wave-filename",
"--decoding-method",
type=str,
help="""Path to the wave filename. Must be 16 kHz,
mono with 16-bit samples""",
default="greedy_search",
help="Valid values are greedy_search and modified_beam_search",
)
return parser.parse_args()
... ... @@ -82,9 +82,11 @@ def create_recognizer():
encoder=args.encoder,
decoder=args.decoder,
joiner=args.joiner,
num_threads=4,
num_threads=1,
sample_rate=16000,
feature_dim=80,
decoding_method=args.decoding_method,
max_feature_vectors=100, # 1 second
)
return recognizer
... ... @@ -96,6 +98,7 @@ def main():
samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms
last_result = ""
stream = recognizer.create_stream()
display = sherpa_onnx.Display(max_word_per_line=40)
with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s:
while True:
samples, _ = s.read(samples_per_read) # a blocking read
... ... @@ -106,7 +109,7 @@ def main():
result = recognizer.get_result(stream)
if last_result != result:
last_result = result
print(result)
display.print(-1, result)
if __name__ == "__main__":
... ...
... ... @@ -9,6 +9,7 @@
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/display.h"
#include "sherpa-onnx/csrc/online-recognizer.h"
struct SherpaOnnxOnlineRecognizer {
... ... @@ -21,6 +22,10 @@ struct SherpaOnnxOnlineStream {
: impl(std::move(p)) {}
};
struct SherpaOnnxDisplay {
std::unique_ptr<sherpa_onnx::Display> impl;
};
SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer(
const SherpaOnnxOnlineRecognizerConfig *config) {
sherpa_onnx::OnlineRecognizerConfig recognizer_config;
... ... @@ -37,6 +42,9 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer(
recognizer_config.model_config.num_threads = config->model_config.num_threads;
recognizer_config.model_config.debug = config->model_config.debug;
recognizer_config.decoding_method = config->decoding_method;
recognizer_config.max_active_paths = config->max_active_paths;
recognizer_config.enable_endpoint = config->enable_endpoint;
recognizer_config.endpoint_config.rule1.min_trailing_silence =
... ... @@ -124,3 +132,15 @@ int32_t IsEndpoint(SherpaOnnxOnlineRecognizer *recognizer,
SherpaOnnxOnlineStream *stream) {
return recognizer->impl->IsEndpoint(stream->impl.get());
}
SherpaOnnxDisplay *CreateDisplay(int32_t max_word_per_line) {
SherpaOnnxDisplay *ans = new SherpaOnnxDisplay;
ans->impl = std::make_unique<sherpa_onnx::Display>(max_word_per_line);
return ans;
}
void DestroyDisplay(SherpaOnnxDisplay *display) { delete display; }
void SherpaOnnxPrint(SherpaOnnxDisplay *display, int32_t idx, const char *s) {
display->impl->Print(idx, s);
}
... ...
... ... @@ -48,6 +48,13 @@ typedef struct SherpaOnnxOnlineRecognizerConfig {
SherpaOnnxFeatureConfig feat_config;
SherpaOnnxOnlineTransducerModelConfig model_config;
/// Possible values are: greedy_search, modified_beam_search
const char *decoding_method;
/// Used only when decoding_method is modified_beam_search
/// Example value: 4
int32_t max_active_paths;
/// 0 to disable endpoint detection.
/// A non-zero value to enable endpoint detection.
int32_t enable_endpoint;
... ... @@ -187,6 +194,18 @@ void InputFinished(SherpaOnnxOnlineStream *stream);
int32_t IsEndpoint(SherpaOnnxOnlineRecognizer *recognizer,
SherpaOnnxOnlineStream *stream);
// for displaying results on Linux/macOS.
typedef struct SherpaOnnxDisplay SherpaOnnxDisplay;
/// Create a display object. Must be freed using DestroyDisplay to avoid
/// memory leak.
SherpaOnnxDisplay *CreateDisplay(int32_t max_word_per_line);
void DestroyDisplay(SherpaOnnxDisplay *display);
/// Print the result.
void SherpaOnnxPrint(SherpaOnnxDisplay *display, int32_t idx, const char *s);
#ifdef __cplusplus
} /* extern "C" */
#endif
... ...
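The display helpers added to the C API above (CreateDisplay, SherpaOnnxPrint, DestroyDisplay) are mirrored by the sherpa_onnx.Display binding added later in this diff. A rough sketch of the Python side, following the updated microphone examples; the strings printed here are purely illustrative:

import sherpa_onnx

display = sherpa_onnx.Display(max_word_per_line=30)
display.print(0, "hello")         # prints "0:hello"
display.print(0, "hello world")   # same idx: redraws segment 0 in place
display.print(1, "next segment")  # new idx: moves on to segment 1
display.print(-1, "no prefix")    # idx == -1 prints the text without an "id:" prefix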
... ... @@ -9,10 +9,11 @@ set(sources
online-lstm-transducer-model.cc
online-recognizer.cc
online-stream.cc
online-transducer-decoder.cc
online-transducer-greedy-search-decoder.cc
online-transducer-model-config.cc
online-transducer-modified-beam-search-decoder.cc
online-transducer-model.cc
online-transducer-modified-beam-search-decoder.cc
online-zipformer-transducer-model.cc
onnx-utils.cc
parse-options.cc
... ...
... ... @@ -12,9 +12,16 @@ namespace sherpa_onnx {
class Display {
public:
explicit Display(int32_t max_word_per_line = 60)
: max_word_per_line_(max_word_per_line) {}
void Print(int32_t segment_id, const std::string &s) {
#ifdef _MSC_VER
fprintf(stderr, "%d:%s\n", segment_id, s.c_str());
if (segment_id != -1) {
fprintf(stderr, "%d:%s\n", segment_id, s.c_str());
} else {
fprintf(stderr, "%s\n", s.c_str());
}
return;
#endif
if (last_segment_ == segment_id) {
... ... @@ -27,7 +34,9 @@ class Display {
num_previous_lines_ = 0;
}
fprintf(stderr, "\r%d:", segment_id);
if (segment_id != -1) {
fprintf(stderr, "\r%d:", segment_id);
}
int32_t i = 0;
for (size_t n = 0; n < s.size();) {
... ... @@ -69,7 +78,7 @@ class Display {
void GoUpOneLine() const { fprintf(stderr, "\033[1A\r"); }
private:
int32_t max_word_per_line_ = 60;
int32_t max_word_per_line_;
int32_t num_previous_lines_ = 0;
int32_t last_segment_ = -1;
};
... ...
... ... @@ -28,7 +28,8 @@ std::string FeatureExtractorConfig::ToString() const {
os << "FeatureExtractorConfig(";
os << "sampling_rate=" << sampling_rate << ", ";
os << "feature_dim=" << feature_dim << ")";
os << "feature_dim=" << feature_dim << ", ";
os << "max_feature_vectors=" << max_feature_vectors << ")";
return os.str();
}
... ... @@ -40,9 +41,7 @@ class FeatureExtractor::Impl {
opts_.frame_opts.snip_edges = false;
opts_.frame_opts.samp_freq = config.sampling_rate;
// cache 100 seconds of feature frames, which is more than enough
// for real needs
opts_.frame_opts.max_feature_vectors = 100 * 100;
opts_.frame_opts.max_feature_vectors = config.max_feature_vectors;
opts_.mel_opts.num_bins = config.feature_dim;
... ...
... ... @@ -16,6 +16,7 @@ namespace sherpa_onnx {
struct FeatureExtractorConfig {
float sampling_rate = 16000;
int32_t feature_dim = 80;
int32_t max_feature_vectors = -1;
std::string ToString() const;
... ...
... ... @@ -18,7 +18,7 @@ namespace sherpa_onnx {
struct Hypothesis {
// The predicted tokens so far. Newly predicted tokens are appended.
std::vector<int32_t> ys;
std::vector<int64_t> ys;
// timestamps[i] contains the frame number after subsampling
// on which ys[i] is decoded.
... ... @@ -30,7 +30,7 @@ struct Hypothesis {
int32_t num_trailing_blanks = 0;
Hypothesis() = default;
Hypothesis(const std::vector<int32_t> &ys, double log_prob)
Hypothesis(const std::vector<int64_t> &ys, double log_prob)
: ys(ys), log_prob(log_prob) {}
// If two Hypotheses have the same `Key`, then they contain
... ...
... ... @@ -43,7 +43,7 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
"True to enable endpoint detection. False to disable it.");
po->Register("max-active-paths", &max_active_paths,
"beam size used in modified beam search.");
po->Register("decoding-mothod", &decoding_method,
po->Register("decoding-method", &decoding_method,
"decoding method,"
"now support greedy_search and modified_beam_search.");
}
... ... @@ -59,8 +59,8 @@ std::string OnlineRecognizerConfig::ToString() const {
os << "feat_config=" << feat_config.ToString() << ", ";
os << "model_config=" << model_config.ToString() << ", ";
os << "endpoint_config=" << endpoint_config.ToString() << ", ";
os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ",";
os << "max_active_paths=" << max_active_paths << ",";
os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", ";
os << "max_active_paths=" << max_active_paths << ", ";
os << "decoding_method=\"" << decoding_method << "\")";
return os.str();
... ... @@ -187,16 +187,14 @@ class OnlineRecognizer::Impl {
}
void Reset(OnlineStream *s) const {
// reset result, neural network model state, and
// the feature extractor state
// reset result
// we keep the decoder_out
decoder_->UpdateDecoderOut(&s->GetResult());
Ort::Value decoder_out = std::move(s->GetResult().decoder_out);
s->SetResult(decoder_->GetEmptyResult());
s->GetResult().decoder_out = std::move(decoder_out);
// reset neural network model state
s->SetStates(model_->GetEncoderInitStates());
// reset feature extractor
// Note: We only update counters. The underlying audio samples
// are not discarded.
s->Reset();
}
... ...
... ... @@ -33,21 +33,26 @@ struct OnlineRecognizerConfig {
OnlineTransducerModelConfig model_config;
EndpointConfig endpoint_config;
bool enable_endpoint = true;
int32_t max_active_paths = 4;
std::string decoding_method = "modified_beam_search";
std::string decoding_method = "greedy_search";
// now support modified_beam_search and greedy_search
int32_t max_active_paths = 4; // used only for modified_beam_search
OnlineRecognizerConfig() = default;
OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config,
const OnlineTransducerModelConfig &model_config,
const EndpointConfig &endpoint_config,
bool enable_endpoint)
bool enable_endpoint,
const std::string &decoding_method,
int32_t max_active_paths)
: feat_config(feat_config),
model_config(model_config),
endpoint_config(endpoint_config),
enable_endpoint(enable_endpoint) {}
enable_endpoint(enable_endpoint),
decoding_method(decoding_method),
max_active_paths(max_active_paths) {}
void Register(ParseOptions *po);
bool Validate() const;
... ...
... ... @@ -22,18 +22,21 @@ class OnlineStream::Impl {
void InputFinished() { feat_extractor_.InputFinished(); }
int32_t NumFramesReady() const { return feat_extractor_.NumFramesReady(); }
int32_t NumFramesReady() const {
return feat_extractor_.NumFramesReady() - start_frame_index_;
}
bool IsLastFrame(int32_t frame) const {
return feat_extractor_.IsLastFrame(frame);
}
std::vector<float> GetFrames(int32_t frame_index, int32_t n) const {
return feat_extractor_.GetFrames(frame_index, n);
return feat_extractor_.GetFrames(frame_index + start_frame_index_, n);
}
void Reset() {
feat_extractor_.Reset();
// we don't reset the feature extractor
start_frame_index_ += num_processed_frames_;
num_processed_frames_ = 0;
}
... ... @@ -41,7 +44,7 @@ class OnlineStream::Impl {
void SetResult(const OnlineTransducerDecoderResult &r) { result_ = r; }
const OnlineTransducerDecoderResult &GetResult() const { return result_; }
OnlineTransducerDecoderResult &GetResult() { return result_; }
int32_t FeatureDim() const { return feat_extractor_.FeatureDim(); }
... ... @@ -54,6 +57,7 @@ class OnlineStream::Impl {
private:
FeatureExtractor feat_extractor_;
int32_t num_processed_frames_ = 0; // before subsampling
int32_t start_frame_index_ = 0; // never reset
OnlineTransducerDecoderResult result_;
std::vector<Ort::Value> states_;
};
... ... @@ -93,7 +97,7 @@ void OnlineStream::SetResult(const OnlineTransducerDecoderResult &r) {
impl_->SetResult(r);
}
const OnlineTransducerDecoderResult &OnlineStream::GetResult() const {
OnlineTransducerDecoderResult &OnlineStream::GetResult() {
return impl_->GetResult();
}
... ...
... ... @@ -63,7 +63,7 @@ class OnlineStream {
int32_t &GetNumProcessedFrames();
void SetResult(const OnlineTransducerDecoderResult &r);
const OnlineTransducerDecoderResult &GetResult() const;
OnlineTransducerDecoderResult &GetResult();
void SetStates(std::vector<Ort::Value> states);
std::vector<Ort::Value> &GetStates();
... ...
// sherpa-onnx/csrc/online-transducer-decoder.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
#include <utility>
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/onnx-utils.h"
namespace sherpa_onnx {
OnlineTransducerDecoderResult::OnlineTransducerDecoderResult(
const OnlineTransducerDecoderResult &other)
: OnlineTransducerDecoderResult() {
*this = other;
}
OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=(
const OnlineTransducerDecoderResult &other) {
if (this == &other) {
return *this;
}
tokens = other.tokens;
num_trailing_blanks = other.num_trailing_blanks;
Ort::AllocatorWithDefaultOptions allocator;
if (other.decoder_out) {
decoder_out = Clone(allocator, &other.decoder_out);
}
hyps = other.hyps;
return *this;
}
OnlineTransducerDecoderResult::OnlineTransducerDecoderResult(
OnlineTransducerDecoderResult &&other)
: OnlineTransducerDecoderResult() {
*this = std::move(other);
}
OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=(
OnlineTransducerDecoderResult &&other) {
if (this == &other) {
return *this;
}
tokens = std::move(other.tokens);
num_trailing_blanks = other.num_trailing_blanks;
decoder_out = std::move(other.decoder_out);
hyps = std::move(other.hyps);
return *this;
}
} // namespace sherpa_onnx
... ...
... ... @@ -19,8 +19,24 @@ struct OnlineTransducerDecoderResult {
/// number of trailing blank frames decoded so far
int32_t num_trailing_blanks = 0;
// Cache decoder_out for endpointing
Ort::Value decoder_out;
// used only in modified_beam_search
Hypotheses hyps;
OnlineTransducerDecoderResult()
: tokens{}, num_trailing_blanks(0), decoder_out{nullptr}, hyps{} {}
OnlineTransducerDecoderResult(const OnlineTransducerDecoderResult &other);
OnlineTransducerDecoderResult &operator=(
const OnlineTransducerDecoderResult &other);
OnlineTransducerDecoderResult(OnlineTransducerDecoderResult &&other);
OnlineTransducerDecoderResult &operator=(
OnlineTransducerDecoderResult &&other);
};
class OnlineTransducerDecoder {
... ... @@ -53,6 +69,9 @@ class OnlineTransducerDecoder {
*/
virtual void Decode(Ort::Value encoder_out,
std::vector<OnlineTransducerDecoderResult> *result) = 0;
// used for endpointing. We need to keep decoder_out after reset
virtual void UpdateDecoderOut(OnlineTransducerDecoderResult *result) {}
};
} // namespace sherpa_onnx
... ...
... ... @@ -13,6 +13,43 @@
namespace sherpa_onnx {
static void UseCachedDecoderOut(
const std::vector<OnlineTransducerDecoderResult> &results,
Ort::Value *decoder_out) {
std::vector<int64_t> shape =
decoder_out->GetTensorTypeAndShapeInfo().GetShape();
float *dst = decoder_out->GetTensorMutableData<float>();
for (const auto &r : results) {
if (r.decoder_out) {
const float *src = r.decoder_out.GetTensorData<float>();
std::copy(src, src + shape[1], dst);
}
dst += shape[1];
}
}
static void UpdateCachedDecoderOut(
OrtAllocator *allocator, const Ort::Value *decoder_out,
std::vector<OnlineTransducerDecoderResult> *results) {
std::vector<int64_t> shape =
decoder_out->GetTensorTypeAndShapeInfo().GetShape();
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
std::array<int64_t, 2> v_shape{1, shape[1]};
const float *src = decoder_out->GetTensorData<float>();
for (auto &r : *results) {
if (!r.decoder_out) {
r.decoder_out = Ort::Value::CreateTensor<float>(allocator, v_shape.data(),
v_shape.size());
}
float *dst = r.decoder_out.GetTensorMutableData<float>();
std::copy(src, src + shape[1], dst);
src += shape[1];
}
}
OnlineTransducerDecoderResult
OnlineTransducerGreedySearchDecoder::GetEmptyResult() const {
int32_t context_size = model_->ContextSize();
... ... @@ -53,6 +90,7 @@ void OnlineTransducerGreedySearchDecoder::Decode(
Ort::Value decoder_input = model_->BuildDecoderInput(*result);
Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
UseCachedDecoderOut(*result, &decoder_out);
for (int32_t t = 0; t != num_frames; ++t) {
Ort::Value cur_encoder_out =
... ... @@ -77,10 +115,12 @@ void OnlineTransducerGreedySearchDecoder::Decode(
}
}
if (emitted) {
decoder_input = model_->BuildDecoderInput(*result);
Ort::Value decoder_input = model_->BuildDecoderInput(*result);
decoder_out = model_->RunDecoder(std::move(decoder_input));
}
}
UpdateCachedDecoderOut(model_->Allocator(), &decoder_out, result);
}
} // namespace sherpa_onnx
... ...
... ... @@ -13,6 +13,29 @@
namespace sherpa_onnx {
static void UseCachedDecoderOut(
const std::vector<int32_t> &hyps_num_split,
const std::vector<OnlineTransducerDecoderResult> &results,
int32_t context_size, Ort::Value *decoder_out) {
std::vector<int64_t> shape =
decoder_out->GetTensorTypeAndShapeInfo().GetShape();
float *dst = decoder_out->GetTensorMutableData<float>();
int32_t batch_size = static_cast<int32_t>(results.size());
for (int32_t i = 0; i != batch_size; ++i) {
int32_t num_hyps = hyps_num_split[i + 1] - hyps_num_split[i];
if (num_hyps > 1 || !results[i].decoder_out) {
dst += num_hyps * shape[1];
continue;
}
const float *src = results[i].decoder_out.GetTensorData<float>();
std::copy(src, src + shape[1], dst);
dst += shape[1];
}
}
static Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out,
const std::vector<int32_t> &hyps_num_split) {
std::vector<int64_t> cur_encoder_out_shape =
... ... @@ -50,7 +73,7 @@ OnlineTransducerModifiedBeamSearchDecoder::GetEmptyResult() const {
int32_t context_size = model_->ContextSize();
int32_t blank_id = 0; // always 0
OnlineTransducerDecoderResult r;
std::vector<int32_t> blanks(context_size, blank_id);
std::vector<int64_t> blanks(context_size, blank_id);
Hypotheses blank_hyp({{blanks, 0}});
r.hyps = std::move(blank_hyp);
return r;
... ... @@ -110,6 +133,10 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
Ort::Value decoder_input = model_->BuildDecoderInput(prev);
Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
if (t == 0) {
UseCachedDecoderOut(hyps_num_split, *result, model_->ContextSize(),
&decoder_out);
}
Ort::Value cur_encoder_out =
GetEncoderOutFrame(model_->Allocator(), &encoder_out, t);
... ... @@ -147,8 +174,23 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
}
for (int32_t b = 0; b != batch_size; ++b) {
(*result)[b].hyps = std::move(cur[b]);
auto &hyps = cur[b];
auto best_hyp = hyps.GetMostProbable(true);
(*result)[b].hyps = std::move(hyps);
(*result)[b].tokens = std::move(best_hyp.ys);
(*result)[b].num_trailing_blanks = best_hyp.num_trailing_blanks;
}
}
void OnlineTransducerModifiedBeamSearchDecoder::UpdateDecoderOut(
OnlineTransducerDecoderResult *result) {
if (result->tokens.size() == model_->ContextSize()) {
result->decoder_out = Ort::Value{nullptr};
return;
}
Ort::Value decoder_input = model_->BuildDecoderInput({*result});
result->decoder_out = model_->RunDecoder(std::move(decoder_input));
}
} // namespace sherpa_onnx
... ...
... ... @@ -27,6 +27,8 @@ class OnlineTransducerModifiedBeamSearchDecoder
void Decode(Ort::Value encoder_out,
std::vector<OnlineTransducerDecoderResult> *result) override;
void UpdateDecoderOut(OnlineTransducerDecoderResult *result) override;
private:
OnlineTransducerModel *model_; // Not owned
int32_t max_active_paths_;
... ...
... ... @@ -21,7 +21,7 @@ static void Handler(int sig) {
}
int main(int32_t argc, char *argv[]) {
if (argc < 6 || argc > 7) {
if (argc < 6 || argc > 8) {
const char *usage = R"usage(
Usage:
./bin/sherpa-onnx-alsa \
... ... @@ -30,7 +30,10 @@ Usage:
/path/to/decoder.onnx \
/path/to/joiner.onnx \
device_name \
[num_threads]
[num_threads [decoding_method]]
Default value for num_threads is 2.
Valid values for decoding_method: greedy_search (default), modified_beam_search.
Please refer to
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
... ... @@ -79,6 +82,11 @@ as the device_name.
config.model_config.num_threads = atoi(argv[6]);
}
if (argc == 8) {
config.decoding_method = argv[7];
}
config.max_active_paths = 4;
config.enable_endpoint = true;
config.endpoint_config.rule1.min_trailing_silence = 2.4;
... ...
... ... @@ -36,7 +36,7 @@ static void Handler(int32_t sig) {
}
int32_t main(int32_t argc, char *argv[]) {
if (argc < 5 || argc > 6) {
if (argc < 5 || argc > 7) {
const char *usage = R"usage(
Usage:
./bin/sherpa-onnx-microphone \
... ... @@ -44,7 +44,10 @@ Usage:
/path/to/encoder.onnx\
/path/to/decoder.onnx\
/path/to/joiner.onnx\
[num_threads]
[num_threads [decoding_method]]
Default value for num_threads is 2.
Valid values for decoding_method: greedy_search (default), modified_beam_search.
Please refer to
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
... ... @@ -70,6 +73,11 @@ for a list of pre-trained models to download.
config.model_config.num_threads = atoi(argv[5]);
}
if (argc == 7) {
config.decoding_method = argv[6];
}
config.max_active_paths = 4;
config.enable_endpoint = true;
config.endpoint_config.rule1.min_trailing_silence = 2.4;
... ...
... ... @@ -14,7 +14,7 @@
#include "sherpa-onnx/csrc/wave-reader.h"
int main(int32_t argc, char *argv[]) {
if (argc < 6 || argc > 7) {
if (argc < 6 || argc > 8) {
const char *usage = R"usage(
Usage:
./bin/sherpa-onnx \
... ... @@ -22,7 +22,10 @@ Usage:
/path/to/encoder.onnx \
/path/to/decoder.onnx \
/path/to/joiner.onnx \
/path/to/foo.wav [num_threads]
/path/to/foo.wav [num_threads [decoding_method]]
Default value for num_threads is 2.
Valid values for decoding_method: greedy_search (default), modified_beam_search.
Please refer to
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
... ... @@ -45,9 +48,15 @@ for a list of pre-trained models to download.
std::string wav_filename = argv[5];
config.model_config.num_threads = 2;
if (argc == 7) {
if (argc == 7 && atoi(argv[6]) > 0) {
config.model_config.num_threads = atoi(argv[6]);
}
if (argc == 8) {
config.decoding_method = argv[7];
}
config.max_active_paths = 4;
fprintf(stderr, "%s\n", config.ToString().c_str());
sherpa_onnx::OnlineRecognizer recognizer(config);
... ... @@ -98,6 +107,7 @@ for a list of pre-trained models to download.
1000.;
fprintf(stderr, "num threads: %d\n", config.model_config.num_threads);
fprintf(stderr, "decoding method: %s\n", config.decoding_method.c_str());
fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
float rtf = elapsed_seconds / duration;
... ...
include_directories(${CMAKE_SOURCE_DIR})
pybind11_add_module(_sherpa_onnx
display.cc
endpoint.cc
features.cc
online-recognizer.cc
online-stream.cc
online-transducer-model-config.cc
sherpa-onnx.cc
endpoint.cc
online-stream.cc
online-recognizer.cc
)
if(APPLE)
... ...
// sherpa-onnx/python/csrc/display.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/python/csrc/display.h"
#include "sherpa-onnx/csrc/display.h"
namespace sherpa_onnx {
void PybindDisplay(py::module *m) {
using PyClass = Display;
py::class_<PyClass>(*m, "Display")
.def(py::init<int32_t>(), py::arg("max_word_per_line") = 60)
.def("print", &PyClass::Print, py::arg("idx"), py::arg("s"));
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/python/csrc/display.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_PYTHON_CSRC_DISPLAY_H_
#define SHERPA_ONNX_PYTHON_CSRC_DISPLAY_H_
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
namespace sherpa_onnx {
void PybindDisplay(py::module *m);
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_PYTHON_CSRC_DISPLAY_H_
... ...
... ... @@ -11,10 +11,12 @@ namespace sherpa_onnx {
static void PybindFeatureExtractorConfig(py::module *m) {
using PyClass = FeatureExtractorConfig;
py::class_<PyClass>(*m, "FeatureExtractorConfig")
.def(py::init<float, int32_t>(), py::arg("sampling_rate") = 16000,
py::arg("feature_dim") = 80)
.def(py::init<float, int32_t, int32_t>(),
py::arg("sampling_rate") = 16000, py::arg("feature_dim") = 80,
py::arg("max_feature_vectors") = -1)
.def_readwrite("sampling_rate", &PyClass::sampling_rate)
.def_readwrite("feature_dim", &PyClass::feature_dim)
.def_readwrite("max_feature_vectors", &PyClass::max_feature_vectors)
.def("__str__", &PyClass::ToString);
}
... ...
... ... @@ -22,13 +22,16 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
py::class_<PyClass>(*m, "OnlineRecognizerConfig")
.def(py::init<const FeatureExtractorConfig &,
const OnlineTransducerModelConfig &, const EndpointConfig &,
bool>(),
bool, const std::string &, int32_t>(),
py::arg("feat_config"), py::arg("model_config"),
py::arg("endpoint_config"), py::arg("enable_endpoint"))
py::arg("endpoint_config"), py::arg("enable_endpoint"),
py::arg("decoding_method"), py::arg("max_active_paths"))
.def_readwrite("feat_config", &PyClass::feat_config)
.def_readwrite("model_config", &PyClass::model_config)
.def_readwrite("endpoint_config", &PyClass::endpoint_config)
.def_readwrite("enable_endpoint", &PyClass::enable_endpoint)
.def_readwrite("decoding_method", &PyClass::decoding_method)
.def_readwrite("max_active_paths", &PyClass::max_active_paths)
.def("__str__", &PyClass::ToString);
}
... ...
... ... @@ -4,6 +4,7 @@
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
#include "sherpa-onnx/python/csrc/display.h"
#include "sherpa-onnx/python/csrc/endpoint.h"
#include "sherpa-onnx/python/csrc/features.h"
#include "sherpa-onnx/python/csrc/online-recognizer.h"
... ... @@ -19,6 +20,8 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
PybindOnlineStream(&m);
PybindEndpoint(&m);
PybindOnlineRecognizer(&m);
PybindDisplay(&m);
}
} // namespace sherpa_onnx
... ...
from _sherpa_onnx import (
EndpointConfig,
FeatureExtractorConfig,
OnlineRecognizerConfig,
OnlineStream,
OnlineTransducerModelConfig,
)
from _sherpa_onnx import Display
from .online_recognizer import OnlineRecognizer
... ...
... ... @@ -32,6 +32,9 @@ class OnlineRecognizer(object):
rule1_min_trailing_silence: int = 2.4,
rule2_min_trailing_silence: int = 1.2,
rule3_min_utterance_length: int = 20,
decoding_method: str = "greedy_search",
max_active_paths: int = 4,
max_feature_vectors: int = -1,
):
"""
Please refer to
... ... @@ -74,6 +77,14 @@ class OnlineRecognizer(object):
Used only when enable_endpoint_detection is True. If the utterance
length in seconds is larger than this value, we assume an endpoint
is detected.
decoding_method:
Valid values are greedy_search, modified_beam_search.
max_active_paths:
Used only when decoding_method is modified_beam_search. It specifies
the maximum number of active paths during beam search.
max_feature_vectors:
Number of feature vectors to cache. -1 means to cache all feature
frames that have been processed.
"""
_assert_file_exists(tokens)
_assert_file_exists(encoder)
... ... @@ -93,6 +104,7 @@ class OnlineRecognizer(object):
feat_config = FeatureExtractorConfig(
sampling_rate=sample_rate,
feature_dim=feature_dim,
max_feature_vectors=max_feature_vectors,
)
endpoint_config = EndpointConfig(
... ... @@ -106,6 +118,8 @@ class OnlineRecognizer(object):
model_config=model_config,
endpoint_config=endpoint_config,
enable_endpoint=enable_endpoint_detection,
decoding_method=decoding_method,
max_active_paths=max_active_paths,
)
self.recognizer = _Recognizer(recognizer_config)
... ...