Fangjun Kuang
Committed by GitHub

Add HLG decoding for streaming CTC models (#731)

#!/usr/bin/env bash
set -e
set -ex
log() {
# This function is from espnet
... ... @@ -14,6 +14,26 @@ echo "PATH: $PATH"
which $EXE
log "------------------------------------------------------------"
log "Run streaming Zipformer2 CTC HLG decoding "
log "------------------------------------------------------------"
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
tar xvf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
rm sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
repo=$PWD/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18
ls -lh $repo
echo "pwd: $PWD"
$EXE \
--zipformer2-ctc-model=$repo/ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx \
--ctc-graph=$repo/HLG.fst \
--tokens=$repo/tokens.txt \
$repo/test_wavs/0.wav \
$repo/test_wavs/1.wav \
$repo/test_wavs/8k.wav
rm -rf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18
log "------------------------------------------------------------"
log "Run streaming Zipformer2 CTC "
log "------------------------------------------------------------"
... ...
#!/usr/bin/env bash
set -e
set -ex
log() {
# This function is from espnet
... ... @@ -8,6 +8,23 @@ log() {
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "test streaming zipformer2 ctc HLG decoding"
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
tar xvf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
rm sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
repo=sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18
python3 ./python-api-examples/online-zipformer-ctc-hlg-decode-file.py \
--debug 1 \
--tokens ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/tokens.txt \
--graph ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/HLG.fst \
--model ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx \
./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/test_wavs/0.wav
rm -rf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18
mkdir -p /tmp/icefall-models
dir=/tmp/icefall-models
... ...
... ... @@ -124,6 +124,14 @@ jobs:
name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
path: build/bin/*
- name: Test online CTC
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx
.github/scripts/test-online-ctc.sh
- name: Test C API
shell: bash
run: |
... ... @@ -149,13 +157,6 @@ jobs:
.github/scripts/test-kws.sh
- name: Test online CTC
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx
.github/scripts/test-online-ctc.sh
- name: Test offline Whisper
if: matrix.build_type != 'Debug'
... ...
function(download_kaldi_decoder)
include(FetchContent)
set(kaldi_decoder_URL "https://github.com/k2-fsa/kaldi-decoder/archive/refs/tags/v0.2.4.tar.gz")
set(kaldi_decoder_URL2 "https://hub.nuaa.cf/k2-fsa/kaldi-decoder/archive/refs/tags/v0.2.4.tar.gz")
set(kaldi_decoder_HASH "SHA256=136d96c2f1f8ec44de095205f81a6ce98981cd867fe4ba840f9415a0b58fe601")
set(kaldi_decoder_URL "https://github.com/k2-fsa/kaldi-decoder/archive/refs/tags/v0.2.5.tar.gz")
set(kaldi_decoder_URL2 "https://hub.nuaa.cf/k2-fsa/kaldi-decoder/archive/refs/tags/v0.2.5.tar.gz")
set(kaldi_decoder_HASH "SHA256=f663e58aef31b33cd8086eaa09ff1383628039845f31300b5abef817d8cc2fff")
set(KALDI_DECODER_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
set(KALDI_DECODER_ENABLE_TESTS OFF CACHE BOOL "" FORCE)
... ... @@ -12,11 +12,11 @@ function(download_kaldi_decoder)
# If you don't have access to the Internet,
# please pre-download kaldi-decoder
set(possible_file_locations
$ENV{HOME}/Downloads/kaldi-decoder-0.2.4.tar.gz
${CMAKE_SOURCE_DIR}/kaldi-decoder-0.2.4.tar.gz
${CMAKE_BINARY_DIR}/kaldi-decoder-0.2.4.tar.gz
/tmp/kaldi-decoder-0.2.4.tar.gz
/star-fj/fangjun/download/github/kaldi-decoder-0.2.4.tar.gz
$ENV{HOME}/Downloads/kaldi-decoder-0.2.5.tar.gz
${CMAKE_SOURCE_DIR}/kaldi-decoder-0.2.5.tar.gz
${CMAKE_BINARY_DIR}/kaldi-decoder-0.2.5.tar.gz
/tmp/kaldi-decoder-0.2.5.tar.gz
/star-fj/fangjun/download/github/kaldi-decoder-0.2.5.tar.gz
)
foreach(f IN LISTS possible_file_locations)
... ...
#!/usr/bin/env python3
# This file shows how to use a streaming zipformer CTC model and an HLG
# graph for decoding.
#
# We use the following model as an example
#
"""
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
tar xvf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
rm sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
python3 ./python-api-examples/online-zipformer-ctc-hlg-decode-file.py \
--tokens ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/tokens.txt \
--graph ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/HLG.fst \
--model ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx \
./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/test_wavs/0.wav
"""
# (The above model is from https://github.com/k2-fsa/icefall/pull/1557)
import argparse
import time
import wave
from pathlib import Path
from typing import List, Tuple
import numpy as np
import sherpa_onnx
def get_args():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--tokens",
type=str,
required=True,
help="Path to tokens.txt",
)
parser.add_argument(
"--model",
type=str,
required=True,
help="Path to the ONNX model",
)
parser.add_argument(
"--graph",
type=str,
required=True,
help="Path to H.fst, HL.fst, or HLG.fst",
)
parser.add_argument(
"--num-threads",
type=int,
default=1,
help="Number of threads for neural network computation",
)
parser.add_argument(
"--provider",
type=str,
default="cpu",
help="Valid values: cpu, cuda, coreml",
)
parser.add_argument(
"--debug",
type=int,
default=0,
help="Valid values: 1, 0",
)
parser.add_argument(
"sound_file",
type=str,
help="The input sound file to decode. It must be of WAVE"
"format with a single channel, and each sample has 16-bit, "
"i.e., int16_t. "
"The sample rate of the file can be arbitrary and does not need to "
"be 16 kHz",
)
return parser.parse_args()
def assert_file_exists(filename: str):
assert Path(filename).is_file(), (
f"{filename} does not exist!\n"
"Please refer to "
"https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
)
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
"""
Args:
wave_filename:
Path to a wave file. It should be single channel and each sample should
be 16-bit. Its sample rate does not need to be 16kHz.
Returns:
Return a tuple containing:
- A 1-D array of dtype np.float32 containing the samples, which are
normalized to the range [-1, 1].
- sample rate of the wave file
"""
with wave.open(wave_filename) as f:
assert f.getnchannels() == 1, f.getnchannels()
assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes
num_samples = f.getnframes()
samples = f.readframes(num_samples)
samples_int16 = np.frombuffer(samples, dtype=np.int16)
samples_float32 = samples_int16.astype(np.float32)
samples_float32 = samples_float32 / 32768
return samples_float32, f.getframerate()
def main():
args = get_args()
print(vars(args))
assert_file_exists(args.tokens)
assert_file_exists(args.graph)
assert_file_exists(args.model)
recognizer = sherpa_onnx.OnlineRecognizer.from_zipformer2_ctc(
tokens=args.tokens,
model=args.model,
num_threads=args.num_threads,
provider=args.provider,
sample_rate=16000,
feature_dim=80,
ctc_graph=args.graph,
)
wave_filename = args.sound_file
assert_file_exists(wave_filename)
samples, sample_rate = read_wave(wave_filename)
duration = len(samples) / sample_rate
print("Started")
start_time = time.time()
s = recognizer.create_stream()
s.accept_waveform(sample_rate, samples)
tail_paddings = np.zeros(int(0.66 * sample_rate), dtype=np.float32)
s.accept_waveform(sample_rate, tail_paddings)
s.input_finished()
while recognizer.is_ready(s):
recognizer.decode_stream(s)
result = recognizer.get_result(s).lower()
end_time = time.time()
elapsed_seconds = end_time - start_time
rtf = elapsed_seconds / duration
print(f"num_threads: {args.num_threads}")
print(f"Wave duration: {duration:.3f} s")
print(f"Elapsed time: {elapsed_seconds:.3f} s")
print(f"Real time factor (RTF): {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}")
print(result)
if __name__ == "__main__":
main()
... ...
... ... @@ -51,6 +51,8 @@ set(sources
offline-zipformer-ctc-model-config.cc
offline-zipformer-ctc-model.cc
online-conformer-transducer-model.cc
online-ctc-fst-decoder-config.cc
online-ctc-fst-decoder.cc
online-ctc-greedy-search-decoder.cc
online-ctc-model.cc
online-lm-config.cc
... ...
... ... @@ -7,6 +7,9 @@
#include <sstream>
#include <string>
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
std::string OfflineCtcFstDecoderConfig::ToString() const {
... ... @@ -29,4 +32,12 @@ void OfflineCtcFstDecoderConfig::Register(ParseOptions *po) {
"Decoder max active states. Larger->slower; more accurate");
}
bool OfflineCtcFstDecoderConfig::Validate() const {
if (!graph.empty() && !FileExists(graph)) {
SHERPA_ONNX_LOGE("graph: %s does not exist", graph.c_str());
return false;
}
return true;
}
} // namespace sherpa_onnx
... ...
... ... @@ -24,6 +24,7 @@ struct OfflineCtcFstDecoderConfig {
std::string ToString() const;
void Register(ParseOptions *po);
bool Validate() const;
};
} // namespace sherpa_onnx
... ...
... ... @@ -20,7 +20,7 @@ namespace sherpa_onnx {
// @param filename Path to a StdVectorFst or StdConstFst graph
// @return The caller should free the returned pointer using `delete` to
// avoid memory leak.
static fst::Fst<fst::StdArc> *ReadGraph(const std::string &filename) {
fst::Fst<fst::StdArc> *ReadGraph(const std::string &filename) {
// read decoding network FST
std::ifstream is(filename, std::ios::binary);
if (!is.good()) {
... ...
... ... @@ -67,6 +67,12 @@ bool OfflineRecognizerConfig::Validate() const {
return false;
}
if (!ctc_fst_decoder_config.graph.empty() &&
!ctc_fst_decoder_config.Validate()) {
SHERPA_ONNX_LOGE("Errors in fst_decoder");
return false;
}
return model_config.Validate();
}
... ...
... ... @@ -5,12 +5,16 @@
#ifndef SHERPA_ONNX_CSRC_ONLINE_CTC_DECODER_H_
#define SHERPA_ONNX_CSRC_ONLINE_CTC_DECODER_H_
#include <memory>
#include <vector>
#include "kaldi-decoder/csrc/faster-decoder.h"
#include "onnxruntime_cxx_api.h" // NOLINT
namespace sherpa_onnx {
class OnlineStream;
struct OnlineCtcDecoderResult {
/// Number of frames after subsampling we have decoded so far
int32_t frame_offset = 0;
... ... @@ -37,7 +41,13 @@ class OnlineCtcDecoder {
* @param results Input & Output parameters..
*/
virtual void Decode(Ort::Value log_probs,
std::vector<OnlineCtcDecoderResult> *results) = 0;
std::vector<OnlineCtcDecoderResult> *results,
OnlineStream **ss = nullptr, int32_t n = 0) = 0;
virtual std::unique_ptr<kaldi_decoder::FasterDecoder> CreateFasterDecoder()
const {
return nullptr;
}
};
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/online-ctc-fst-decoder-config.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/online-ctc-fst-decoder-config.h"
#include <sstream>
#include <string>
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
std::string OnlineCtcFstDecoderConfig::ToString() const {
std::ostringstream os;
os << "OnlineCtcFstDecoderConfig(";
os << "graph=\"" << graph << "\", ";
os << "max_active=" << max_active << ")";
return os.str();
}
void OnlineCtcFstDecoderConfig::Register(ParseOptions *po) {
po->Register("ctc-graph", &graph, "Path to H.fst, HL.fst, or HLG.fst");
po->Register("ctc-max-active", &max_active,
"Decoder max active states. Larger->slower; more accurate");
}
bool OnlineCtcFstDecoderConfig::Validate() const {
if (!graph.empty() && !FileExists(graph)) {
SHERPA_ONNX_LOGE("graph: %s does not exist", graph.c_str());
return false;
}
return true;
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/online-ctc-fst-decoder-config.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_ONLINE_CTC_FST_DECODER_CONFIG_H_
#define SHERPA_ONNX_CSRC_ONLINE_CTC_FST_DECODER_CONFIG_H_
#include <string>
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
struct OnlineCtcFstDecoderConfig {
// Path to H.fst, HL.fst or HLG.fst
std::string graph;
int32_t max_active = 3000;
OnlineCtcFstDecoderConfig() = default;
OnlineCtcFstDecoderConfig(const std::string &graph, int32_t max_active)
: graph(graph), max_active(max_active) {}
std::string ToString() const;
void Register(ParseOptions *po);
bool Validate() const;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONLINE_CTC_FST_DECODER_CONFIG_H_
... ...
// sherpa-onnx/csrc/online-ctc-fst-decoder.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/online-ctc-fst-decoder.h"
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "fst/fstlib.h"
#include "kaldi-decoder/csrc/decodable-ctc.h"
#include "kaldifst/csrc/fstext-utils.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/online-stream.h"
namespace sherpa_onnx {
// defined in ./offline-ctc-fst-decoder.cc
fst::Fst<fst::StdArc> *ReadGraph(const std::string &filename);
OnlineCtcFstDecoder::OnlineCtcFstDecoder(
const OnlineCtcFstDecoderConfig &config, int32_t blank_id)
: config_(config), fst_(ReadGraph(config.graph)), blank_id_(blank_id) {
options_.max_active = config_.max_active;
}
std::unique_ptr<kaldi_decoder::FasterDecoder>
OnlineCtcFstDecoder::CreateFasterDecoder() const {
return std::make_unique<kaldi_decoder::FasterDecoder>(*fst_, options_);
}
static void DecodeOne(const float *log_probs, int32_t num_rows,
int32_t num_cols, OnlineCtcDecoderResult *result,
OnlineStream *s, int32_t blank_id) {
int32_t &processed_frames = s->GetFasterDecoderProcessedFrames();
kaldi_decoder::DecodableCtc decodable(log_probs, num_rows, num_cols,
processed_frames);
kaldi_decoder::FasterDecoder *decoder = s->GetFasterDecoder();
if (processed_frames == 0) {
decoder->InitDecoding();
}
decoder->AdvanceDecoding(&decodable);
if (decoder->ReachedFinal()) {
fst::VectorFst<fst::LatticeArc> fst_out;
bool ok = decoder->GetBestPath(&fst_out);
if (ok) {
std::vector<int32_t> isymbols_out;
std::vector<int32_t> osymbols_out_unused;
ok = fst::GetLinearSymbolSequence(fst_out, &isymbols_out,
&osymbols_out_unused, nullptr);
std::vector<int64_t> tokens;
tokens.reserve(isymbols_out.size());
std::vector<int32_t> timestamps;
timestamps.reserve(isymbols_out.size());
std::ostringstream os;
int32_t prev_id = -1;
int32_t num_trailing_blanks = 0;
int32_t f = 0; // frame number
for (auto i : isymbols_out) {
i -= 1;
if (i == blank_id) {
num_trailing_blanks += 1;
} else {
num_trailing_blanks = 0;
}
if (i != blank_id && i != prev_id) {
tokens.push_back(i);
timestamps.push_back(f);
}
prev_id = i;
f += 1;
}
result->tokens = std::move(tokens);
result->timestamps = std::move(timestamps);
// no need to set frame_offset
}
}
processed_frames += num_rows;
}
void OnlineCtcFstDecoder::Decode(Ort::Value log_probs,
std::vector<OnlineCtcDecoderResult> *results,
OnlineStream **ss, int32_t n) {
std::vector<int64_t> log_probs_shape =
log_probs.GetTensorTypeAndShapeInfo().GetShape();
if (log_probs_shape[0] != results->size()) {
SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, results.size(0): %d",
static_cast<int32_t>(log_probs_shape[0]),
static_cast<int32_t>(results->size()));
exit(-1);
}
if (log_probs_shape[0] != n) {
SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, n: %d",
static_cast<int32_t>(log_probs_shape[0]), n);
exit(-1);
}
int32_t batch_size = static_cast<int32_t>(log_probs_shape[0]);
int32_t num_frames = static_cast<int32_t>(log_probs_shape[1]);
int32_t vocab_size = static_cast<int32_t>(log_probs_shape[2]);
const float *p = log_probs.GetTensorData<float>();
for (int32_t i = 0; i != batch_size; ++i) {
DecodeOne(p + i * num_frames * vocab_size, num_frames, vocab_size,
&(*results)[i], ss[i], blank_id_);
}
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/online-ctc-fst-decoder.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_ONLINE_CTC_FST_DECODER_H_
#define SHERPA_ONNX_CSRC_ONLINE_CTC_FST_DECODER_H_
#include <memory>
#include <vector>
#include "fst/fst.h"
#include "sherpa-onnx/csrc/online-ctc-decoder.h"
#include "sherpa-onnx/csrc/online-ctc-fst-decoder-config.h"
namespace sherpa_onnx {
class OnlineCtcFstDecoder : public OnlineCtcDecoder {
public:
OnlineCtcFstDecoder(const OnlineCtcFstDecoderConfig &config,
int32_t blank_id);
void Decode(Ort::Value log_probs,
std::vector<OnlineCtcDecoderResult> *results,
OnlineStream **ss = nullptr, int32_t n = 0) override;
std::unique_ptr<kaldi_decoder::FasterDecoder> CreateFasterDecoder()
const override;
private:
OnlineCtcFstDecoderConfig config_;
kaldi_decoder::FasterDecoderOptions options_;
std::unique_ptr<fst::Fst<fst::StdArc>> fst_;
int32_t blank_id_ = 0;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONLINE_CTC_FST_DECODER_H_
... ...
... ... @@ -13,7 +13,8 @@
namespace sherpa_onnx {
void OnlineCtcGreedySearchDecoder::Decode(
Ort::Value log_probs, std::vector<OnlineCtcDecoderResult> *results) {
Ort::Value log_probs, std::vector<OnlineCtcDecoderResult> *results,
OnlineStream ** /*ss=nullptr*/, int32_t /*n = 0*/) {
std::vector<int64_t> log_probs_shape =
log_probs.GetTensorTypeAndShapeInfo().GetShape();
... ...
... ... @@ -17,7 +17,8 @@ class OnlineCtcGreedySearchDecoder : public OnlineCtcDecoder {
: blank_id_(blank_id) {}
void Decode(Ort::Value log_probs,
std::vector<OnlineCtcDecoderResult> *results) override;
std::vector<OnlineCtcDecoderResult> *results,
OnlineStream **ss = nullptr, int32_t n = 0) override;
private:
int32_t blank_id_;
... ...
... ... @@ -16,6 +16,7 @@
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/online-ctc-decoder.h"
#include "sherpa-onnx/csrc/online-ctc-fst-decoder.h"
#include "sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h"
#include "sherpa-onnx/csrc/online-ctc-model.h"
#include "sherpa-onnx/csrc/online-recognizer-impl.h"
... ... @@ -99,6 +100,7 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
std::unique_ptr<OnlineStream> CreateStream() const override {
auto stream = std::make_unique<OnlineStream>(config_.feat_config);
stream->SetStates(model_->GetInitStates());
stream->SetFasterDecoder(decoder_->CreateFasterDecoder());
return stream;
}
... ... @@ -165,7 +167,7 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
std::vector<std::vector<Ort::Value>> next_states =
model_->UnStackStates(std::move(out_states));
decoder_->Decode(std::move(out[0]), &results);
decoder_->Decode(std::move(out[0]), &results, ss, n);
for (int32_t k = 0; k != n; ++k) {
ss[k]->SetCtcResult(results[k]);
... ... @@ -221,7 +223,6 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
private:
void InitDecoder() {
if (config_.decoding_method == "greedy_search") {
if (!sym_.contains("<blk>") && !sym_.contains("<eps>") &&
!sym_.contains("<blank>")) {
SHERPA_ONNX_LOGE(
... ... @@ -241,9 +242,14 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
blank_id = sym_["<blank>"];
}
if (!config_.ctc_fst_decoder_config.graph.empty()) {
decoder_ = std::make_unique<OnlineCtcFstDecoder>(
config_.ctc_fst_decoder_config, blank_id);
} else if (config_.decoding_method == "greedy_search") {
decoder_ = std::make_unique<OnlineCtcGreedySearchDecoder>(blank_id);
} else {
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
SHERPA_ONNX_LOGE(
"Unsupported decoding method: %s for streaming CTC models",
config_.decoding_method.c_str());
exit(-1);
}
... ... @@ -281,7 +287,7 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
std::vector<OnlineCtcDecoderResult> results(1);
results[0] = std::move(s->GetCtcResult());
decoder_->Decode(std::move(out[0]), &results);
decoder_->Decode(std::move(out[0]), &results, &s, 1);
s->SetCtcResult(results[0]);
}
... ...
... ... @@ -19,13 +19,13 @@
namespace sherpa_onnx {
/// Helper for `OnlineRecognizerResult::AsJsonString()`
template<typename T>
std::string VecToString(const std::vector<T>& vec, int32_t precision = 6) {
template <typename T>
std::string VecToString(const std::vector<T> &vec, int32_t precision = 6) {
std::ostringstream oss;
oss << std::fixed << std::setprecision(precision);
oss << "[ ";
std::string sep = "";
for (const auto& item : vec) {
for (const auto &item : vec) {
oss << sep << item;
sep = ", ";
}
... ... @@ -34,13 +34,13 @@ std::string VecToString(const std::vector<T>& vec, int32_t precision = 6) {
}
/// Helper for `OnlineRecognizerResult::AsJsonString()`
template<> // explicit specialization for T = std::string
std::string VecToString<std::string>(const std::vector<std::string>& vec,
template <> // explicit specialization for T = std::string
std::string VecToString<std::string>(const std::vector<std::string> &vec,
int32_t) { // ignore 2nd arg
std::ostringstream oss;
oss << "[ ";
std::string sep = "";
for (const auto& item : vec) {
for (const auto &item : vec) {
oss << sep << "\"" << item << "\"";
sep = ", ";
}
... ... @@ -51,15 +51,17 @@ std::string VecToString<std::string>(const std::vector<std::string>& vec,
std::string OnlineRecognizerResult::AsJsonString() const {
std::ostringstream os;
os << "{ ";
os << "\"text\": " << "\"" << text << "\"" << ", ";
os << "\"text\": "
<< "\"" << text << "\""
<< ", ";
os << "\"tokens\": " << VecToString(tokens) << ", ";
os << "\"timestamps\": " << VecToString(timestamps, 2) << ", ";
os << "\"ys_probs\": " << VecToString(ys_probs, 6) << ", ";
os << "\"lm_probs\": " << VecToString(lm_probs, 6) << ", ";
os << "\"context_scores\": " << VecToString(context_scores, 6) << ", ";
os << "\"segment\": " << segment << ", ";
os << "\"start_time\": " << std::fixed << std::setprecision(2)
<< start_time << ", ";
os << "\"start_time\": " << std::fixed << std::setprecision(2) << start_time
<< ", ";
os << "\"is_final\": " << (is_final ? "true" : "false");
os << "}";
return os.str();
... ... @@ -70,6 +72,7 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
model_config.Register(po);
endpoint_config.Register(po);
lm_config.Register(po);
ctc_fst_decoder_config.Register(po);
po->Register("enable-endpoint", &enable_endpoint,
"True to enable endpoint detection. False to disable it.");
... ... @@ -116,6 +119,12 @@ bool OnlineRecognizerConfig::Validate() const {
return false;
}
if (!ctc_fst_decoder_config.graph.empty() &&
!ctc_fst_decoder_config.Validate()) {
SHERPA_ONNX_LOGE("Errors in ctc_fst_decoder_config");
return false;
}
return model_config.Validate();
}
... ... @@ -127,6 +136,7 @@ std::string OnlineRecognizerConfig::ToString() const {
os << "model_config=" << model_config.ToString() << ", ";
os << "lm_config=" << lm_config.ToString() << ", ";
os << "endpoint_config=" << endpoint_config.ToString() << ", ";
os << "ctc_fst_decoder_config=" << ctc_fst_decoder_config.ToString() << ", ";
os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", ";
os << "max_active_paths=" << max_active_paths << ", ";
os << "hotwords_score=" << hotwords_score << ", ";
... ...
... ... @@ -16,6 +16,7 @@
#include "sherpa-onnx/csrc/endpoint.h"
#include "sherpa-onnx/csrc/features.h"
#include "sherpa-onnx/csrc/online-ctc-fst-decoder-config.h"
#include "sherpa-onnx/csrc/online-lm-config.h"
#include "sherpa-onnx/csrc/online-model-config.h"
#include "sherpa-onnx/csrc/online-stream.h"
... ... @@ -80,6 +81,7 @@ struct OnlineRecognizerConfig {
OnlineModelConfig model_config;
OnlineLMConfig lm_config;
EndpointConfig endpoint_config;
OnlineCtcFstDecoderConfig ctc_fst_decoder_config;
bool enable_endpoint = true;
std::string decoding_method = "greedy_search";
... ... @@ -96,19 +98,19 @@ struct OnlineRecognizerConfig {
OnlineRecognizerConfig() = default;
OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config,
const OnlineModelConfig &model_config,
const OnlineLMConfig &lm_config,
OnlineRecognizerConfig(
const FeatureExtractorConfig &feat_config,
const OnlineModelConfig &model_config, const OnlineLMConfig &lm_config,
const EndpointConfig &endpoint_config,
bool enable_endpoint,
const std::string &decoding_method,
int32_t max_active_paths,
const std::string &hotwords_file, float hotwords_score,
float blank_penalty)
const OnlineCtcFstDecoderConfig &ctc_fst_decoder_config,
bool enable_endpoint, const std::string &decoding_method,
int32_t max_active_paths, const std::string &hotwords_file,
float hotwords_score, float blank_penalty)
: feat_config(feat_config),
model_config(model_config),
lm_config(lm_config),
endpoint_config(endpoint_config),
ctc_fst_decoder_config(ctc_fst_decoder_config),
enable_endpoint(enable_endpoint),
decoding_method(decoding_method),
max_active_paths(max_active_paths),
... ...
... ... @@ -104,6 +104,18 @@ class OnlineStream::Impl {
return paraformer_alpha_cache_;
}
void SetFasterDecoder(std::unique_ptr<kaldi_decoder::FasterDecoder> decoder) {
faster_decoder_ = std::move(decoder);
}
kaldi_decoder::FasterDecoder *GetFasterDecoder() const {
return faster_decoder_.get();
}
int32_t &GetFasterDecoderProcessedFrames() {
return faster_decoder_processed_frames_;
}
private:
FeatureExtractor feat_extractor_;
/// For contextual-biasing
... ... @@ -121,6 +133,8 @@ class OnlineStream::Impl {
std::vector<float> paraformer_encoder_out_cache_;
std::vector<float> paraformer_alpha_cache_;
OnlineParaformerDecoderResult paraformer_result_;
std::unique_ptr<kaldi_decoder::FasterDecoder> faster_decoder_;
int32_t faster_decoder_processed_frames_ = 0;
};
OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/,
... ... @@ -208,6 +222,19 @@ const ContextGraphPtr &OnlineStream::GetContextGraph() const {
return impl_->GetContextGraph();
}
void OnlineStream::SetFasterDecoder(
std::unique_ptr<kaldi_decoder::FasterDecoder> decoder) {
impl_->SetFasterDecoder(std::move(decoder));
}
kaldi_decoder::FasterDecoder *OnlineStream::GetFasterDecoder() const {
return impl_->GetFasterDecoder();
}
int32_t &OnlineStream::GetFasterDecoderProcessedFrames() {
return impl_->GetFasterDecoderProcessedFrames();
}
std::vector<float> &OnlineStream::GetParaformerFeatCache() {
return impl_->GetParaformerFeatCache();
}
... ...
... ... @@ -8,6 +8,7 @@
#include <memory>
#include <vector>
#include "kaldi-decoder/csrc/faster-decoder.h"
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/context-graph.h"
#include "sherpa-onnx/csrc/features.h"
... ... @@ -97,6 +98,11 @@ class OnlineStream {
*/
const ContextGraphPtr &GetContextGraph() const;
// for online ctc decoder
void SetFasterDecoder(std::unique_ptr<kaldi_decoder::FasterDecoder> decoder);
kaldi_decoder::FasterDecoder *GetFasterDecoder() const;
int32_t &GetFasterDecoderProcessedFrames();
// for streaming paraformer
std::vector<float> &GetParaformerFeatCache();
std::vector<float> &GetParaformerEncoderOutCache();
... ...
... ... @@ -18,6 +18,7 @@ set(srcs
offline-wenet-ctc-model-config.cc
offline-whisper-model-config.cc
offline-zipformer-ctc-model-config.cc
online-ctc-fst-decoder-config.cc
online-lm-config.cc
online-model-config.cc
online-paraformer-model-config.cc
... ...
// sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.h"
#include <string>
#include "sherpa-onnx/csrc/online-ctc-fst-decoder-config.h"
namespace sherpa_onnx {
void PybindOnlineCtcFstDecoderConfig(py::module *m) {
using PyClass = OnlineCtcFstDecoderConfig;
py::class_<PyClass>(*m, "OnlineCtcFstDecoderConfig")
.def(py::init<const std::string &, int32_t>(), py::arg("graph") = "",
py::arg("max_active") = 3000)
.def_readwrite("graph", &PyClass::graph)
.def_readwrite("max_active", &PyClass::max_active)
.def("__str__", &PyClass::ToString);
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_CTC_FST_DECODER_CONFIG_H_
#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_CTC_FST_DECODER_CONFIG_H_
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
namespace sherpa_onnx {
void PybindOnlineCtcFstDecoderConfig(py::module *m);
}
#endif // SHERPA_ONNX_PYTHON_CSRC_ONLINE_CTC_FST_DECODER_CONFIG_H_
... ...
... ... @@ -24,8 +24,7 @@ static void PybindOnlineRecognizerResult(py::module *m) {
"tokens",
[](PyClass &self) -> std::vector<std::string> { return self.tokens; })
.def_property_readonly(
"start_time",
[](PyClass &self) -> float { return self.start_time; })
"start_time", [](PyClass &self) -> float { return self.start_time; })
.def_property_readonly(
"timestamps",
[](PyClass &self) -> std::vector<float> { return self.timestamps; })
... ... @@ -35,17 +34,14 @@ static void PybindOnlineRecognizerResult(py::module *m) {
.def_property_readonly(
"lm_probs",
[](PyClass &self) -> std::vector<float> { return self.lm_probs; })
.def_property_readonly(
"context_scores",
.def_property_readonly("context_scores",
[](PyClass &self) -> std::vector<float> {
return self.context_scores;
})
.def_property_readonly(
"segment",
[](PyClass &self) -> int32_t { return self.segment; })
"segment", [](PyClass &self) -> int32_t { return self.segment; })
.def_property_readonly(
"is_final",
[](PyClass &self) -> bool { return self.is_final; })
"is_final", [](PyClass &self) -> bool { return self.is_final; })
.def("as_json_string", &PyClass::AsJsonString,
py::call_guard<py::gil_scoped_release>());
}
... ... @@ -53,12 +49,15 @@ static void PybindOnlineRecognizerResult(py::module *m) {
static void PybindOnlineRecognizerConfig(py::module *m) {
using PyClass = OnlineRecognizerConfig;
py::class_<PyClass>(*m, "OnlineRecognizerConfig")
.def(py::init<const FeatureExtractorConfig &, const OnlineModelConfig &,
const OnlineLMConfig &, const EndpointConfig &, bool,
const std::string &, int32_t, const std::string &, float,
float>(),
.def(
py::init<const FeatureExtractorConfig &, const OnlineModelConfig &,
const OnlineLMConfig &, const EndpointConfig &,
const OnlineCtcFstDecoderConfig &, bool, const std::string &,
int32_t, const std::string &, float, float>(),
py::arg("feat_config"), py::arg("model_config"),
py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config"),
py::arg("lm_config") = OnlineLMConfig(),
py::arg("endpoint_config") = EndpointConfig(),
py::arg("ctc_fst_decoder_config") = OnlineCtcFstDecoderConfig(),
py::arg("enable_endpoint"), py::arg("decoding_method"),
py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0)
... ... @@ -66,6 +65,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
.def_readwrite("model_config", &PyClass::model_config)
.def_readwrite("lm_config", &PyClass::lm_config)
.def_readwrite("endpoint_config", &PyClass::endpoint_config)
.def_readwrite("ctc_fst_decoder_config", &PyClass::ctc_fst_decoder_config)
.def_readwrite("enable_endpoint", &PyClass::enable_endpoint)
.def_readwrite("decoding_method", &PyClass::decoding_method)
.def_readwrite("max_active_paths", &PyClass::max_active_paths)
... ...
... ... @@ -15,6 +15,7 @@
#include "sherpa-onnx/python/csrc/offline-model-config.h"
#include "sherpa-onnx/python/csrc/offline-recognizer.h"
#include "sherpa-onnx/python/csrc/offline-stream.h"
#include "sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.h"
#include "sherpa-onnx/python/csrc/online-lm-config.h"
#include "sherpa-onnx/python/csrc/online-model-config.h"
#include "sherpa-onnx/python/csrc/online-recognizer.h"
... ... @@ -36,6 +37,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
m.doc() = "pybind11 binding of sherpa-onnx";
PybindFeatures(&m);
PybindOnlineCtcFstDecoderConfig(&m);
PybindOnlineModelConfig(&m);
PybindOnlineLMConfig(&m);
PybindOnlineStream(&m);
... ...
... ... @@ -16,6 +16,7 @@ from _sherpa_onnx import (
OnlineTransducerModelConfig,
OnlineWenetCtcModelConfig,
OnlineZipformer2CtcModelConfig,
OnlineCtcFstDecoderConfig,
)
... ... @@ -314,6 +315,8 @@ class OnlineRecognizer(object):
rule2_min_trailing_silence: float = 1.2,
rule3_min_utterance_length: float = 20.0,
decoding_method: str = "greedy_search",
ctc_graph: str = "",
ctc_max_active: int = 3000,
provider: str = "cpu",
):
"""
... ... @@ -355,6 +358,12 @@ class OnlineRecognizer(object):
is detected.
decoding_method:
The only valid value is greedy_search.
ctc_graph:
If not empty, decoding_method is ignored. It contains the path to
H.fst, HL.fst, or HLG.fst
ctc_max_active:
Used only when ctc_graph is not empty. It specifies the maximum
active paths at a time.
provider:
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
"""
... ... @@ -384,10 +393,16 @@ class OnlineRecognizer(object):
rule3_min_utterance_length=rule3_min_utterance_length,
)
ctc_fst_decoder_config = OnlineCtcFstDecoderConfig(
graph=ctc_graph,
max_active=ctc_max_active,
)
recognizer_config = OnlineRecognizerConfig(
feat_config=feat_config,
model_config=model_config,
endpoint_config=endpoint_config,
ctc_fst_decoder_config=ctc_fst_decoder_config,
enable_endpoint=enable_endpoint_detection,
decoding_method=decoding_method,
)
... ...