Fangjun Kuang
Committed by GitHub

Add non-streaming ASR (#92)

Showing 48 changed files with 1,526 additions and 150 deletions
... ... @@ -33,18 +33,20 @@ fun main() {
config = config,
)
var samples = WaveReader.readWave(
var objArray = WaveReader.readWave(
assetManager = AssetManager(),
filename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/1089-134686-0001.wav",
)
var samples : FloatArray = objArray[0] as FloatArray
var sampleRate : Int = objArray[1] as Int
model.acceptWaveform(samples!!, sampleRate=16000)
model.acceptWaveform(samples, sampleRate=sampleRate)
while (model.isReady()) {
model.decode()
}
var tail_paddings = FloatArray(8000) // 0.5 seconds
model.acceptWaveform(tail_paddings, sampleRate=16000)
var tail_paddings = FloatArray((sampleRate * 0.5).toInt()) // 0.5 seconds
model.acceptWaveform(tail_paddings, sampleRate=sampleRate)
model.inputFinished()
while (model.isReady()) {
model.decode()
... ...
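Note that the tail padding is now derived from the actual sample rate: at 16 kHz, (16000 * 0.5).toInt() = 8000 samples, matching the previously hard-coded value.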
#!/usr/bin/env bash
set -e
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
echo "EXE is $EXE"
echo "PATH: $PATH"
which $EXE
log "------------------------------------------------------------"
log "Run Conformer transducer (English)"
log "------------------------------------------------------------"
repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-conformer-en-2023-03-18
log "Start testing ${repo_url}"
repo=$(basename $repo_url)
log "Download pretrained model and test-data from $repo_url"
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
pushd $repo
git lfs pull --include "*.onnx"
cd test_wavs
popd
waves=(
$repo/test_wavs/0.wav
$repo/test_wavs/1.wav
$repo/test_wavs/2.wav
)
for wave in ${waves[@]}; do
time $EXE \
$repo/tokens.txt \
$repo/encoder-epoch-99-avg-1.onnx \
$repo/decoder-epoch-99-avg-1.onnx \
$repo/joiner-epoch-99-avg-1.onnx \
$wave \
2
done
if command -v sox &> /dev/null; then
echo "test 8kHz"
sox $repo/test_wavs/0.wav -r 8000 8k.wav
time $EXE \
$repo/tokens.txt \
$repo/encoder-epoch-99-avg-1.onnx \
$repo/decoder-epoch-99-avg-1.onnx \
$repo/joiner-epoch-99-avg-1.onnx \
8k.wav \
2
fi
rm -rf $repo
... ...
... ... @@ -40,7 +40,7 @@ for wave in ${waves[@]}; do
$repo/decoder-epoch-99-avg-1.onnx \
$repo/joiner-epoch-99-avg-1.onnx \
$wave \
4
2
done
rm -rf $repo
... ... @@ -72,7 +72,7 @@ for wave in ${waves[@]}; do
$repo/decoder-epoch-11-avg-1.onnx \
$repo/joiner-epoch-11-avg-1.onnx \
$wave \
4
2
done
rm -rf $repo
... ... @@ -104,7 +104,7 @@ for wave in ${waves[@]}; do
$repo/decoder-epoch-99-avg-1.onnx \
$repo/joiner-epoch-99-avg-1.onnx \
$wave \
4
2
done
rm -rf $repo
... ... @@ -138,7 +138,7 @@ for wave in ${waves[@]}; do
$repo/decoder-epoch-99-avg-1.onnx \
$repo/joiner-epoch-99-avg-1.onnx \
$wave \
4
2
done
# Decode a URL
... ... @@ -149,7 +149,7 @@ if [ $EXE == "sherpa-onnx-ffmpeg" ]; then
$repo/decoder-epoch-99-avg-1.onnx \
$repo/joiner-epoch-99-avg-1.onnx \
https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/resolve/main/test_wavs/4.wav \
4
2
fi
rm -rf $repo
... ...
... ... @@ -7,11 +7,11 @@ on:
paths:
- '.github/workflows/linux.yaml'
- '.github/scripts/test-online-transducer.sh'
- '.github/scripts/test-offline-transducer.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
- 'sherpa-onnx/c-api/*'
- 'ffmpeg-examples/**'
- 'c-api-examples/**'
pull_request:
branches:
... ... @@ -19,11 +19,11 @@ on:
paths:
- '.github/workflows/linux.yaml'
- '.github/scripts/test-online-transducer.sh'
- '.github/scripts/test-offline-transducer.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
- 'sherpa-onnx/c-api/*'
- 'ffmpeg-examples/**'
concurrency:
group: linux-${{ github.ref }}
... ... @@ -39,35 +39,26 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest]
build_type: [Release, Debug]
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Install ffmpeg
- name: Install sox
shell: bash
run: |
sudo apt-get install -y software-properties-common
sudo add-apt-repository ppa:savoury1/ffmpeg4
sudo add-apt-repository ppa:savoury1/ffmpeg5
sudo apt-get install -y libavdevice-dev libavutil-dev ffmpeg
pkg-config --modversion libavutil
ffmpeg -version
- name: Show ffmpeg version
shell: bash
run: |
pkg-config --modversion libavutil
ffmpeg -version
sudo apt-get update
sudo apt-get install -y sox
sox -h
- name: Configure CMake
shell: bash
run: |
mkdir build
cd build
cmake -D CMAKE_BUILD_TYPE=Release ..
cmake -D CMAKE_BUILD_TYPE=${{ matrix.build_type }} ..
- name: Build sherpa-onnx for ubuntu
shell: bash
... ... @@ -78,21 +69,19 @@ jobs:
ls -lh lib
ls -lh bin
cd ../ffmpeg-examples
make
- name: Display dependencies of sherpa-onnx for linux
shell: bash
run: |
file build/bin/sherpa-onnx
readelf -d build/bin/sherpa-onnx
- name: Test sherpa-onnx-ffmpeg
- name: Test offline transducer
shell: bash
run: |
export PATH=$PWD/ffmpeg-examples:$PATH
export EXE=sherpa-onnx-ffmpeg
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline
.github/scripts/test-online-transducer.sh
.github/scripts/test-offline-transducer.sh
- name: Test online transducer
shell: bash
... ...
... ... @@ -7,6 +7,7 @@ on:
paths:
- '.github/workflows/macos.yaml'
- '.github/scripts/test-online-transducer.sh'
- '.github/scripts/test-offline-transducer.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
... ... @@ -16,6 +17,7 @@ on:
paths:
- '.github/workflows/macos.yaml'
- '.github/scripts/test-online-transducer.sh'
- '.github/scripts/test-offline-transducer.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
... ... @@ -34,18 +36,25 @@ jobs:
fail-fast: false
matrix:
os: [macos-latest]
build_type: [Release, Debug]
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Install sox
shell: bash
run: |
brew install sox
sox -h
- name: Configure CMake
shell: bash
run: |
mkdir build
cd build
cmake -D CMAKE_BUILD_TYPE=Release ..
cmake -D CMAKE_BUILD_TYPE=${{ matrix.build_type }} ..
- name: Build sherpa-onnx for macos
shell: bash
... ... @@ -64,6 +73,14 @@ jobs:
otool -L build/bin/sherpa-onnx
otool -l build/bin/sherpa-onnx
- name: Test offline transducer
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline
.github/scripts/test-offline-transducer.sh
- name: Test online transducer
shell: bash
run: |
... ...
... ... @@ -39,3 +39,5 @@ tags
run-decode-file-python.sh
android/SherpaOnnx/app/src/main/assets/
*.ncnn.*
run-sherpa-onnx-offline.sh
sherpa-onnx-conformer-en-2023-03-18
... ...
... ... @@ -13,7 +13,7 @@ endif()
option(SHERPA_ONNX_ENABLE_PYTHON "Whether to build Python" OFF)
option(SHERPA_ONNX_ENABLE_TESTS "Whether to build tests" OFF)
option(SHERPA_ONNX_ENABLE_CHECK "Whether to build with assert" ON)
option(SHERPA_ONNX_ENABLE_CHECK "Whether to build with assert" OFF)
option(BUILD_SHARED_LIBS "Whether to build shared libraries" OFF)
option(SHERPA_ONNX_ENABLE_PORTAUDIO "Whether to build with portaudio" ON)
option(SHERPA_ONNX_ENABLE_JNI "Whether to build JNI interface" OFF)
... ...
... ... @@ -121,7 +121,7 @@ class MainActivity : AppCompatActivity() {
val ret = audioRecord?.read(buffer, 0, buffer.size)
if (ret != null && ret > 0) {
val samples = FloatArray(ret) { buffer[it] / 32768.0f }
model.acceptWaveform(samples, sampleRate=16000)
model.acceptWaveform(samples, sampleRate=sampleRateInHz)
while (model.isReady()) {
model.decode()
}
... ... @@ -180,7 +180,7 @@ class MainActivity : AppCompatActivity() {
val type = 0
println("Select model type ${type}")
val config = OnlineRecognizerConfig(
featConfig = getFeatureConfig(sampleRate = 16000, featureDim = 80),
featConfig = getFeatureConfig(sampleRate = sampleRateInHz, featureDim = 80),
modelConfig = getModelConfig(type = type)!!,
endpointConfig = getEndpointConfig(),
enableEndpoint = true,
... ...
... ... @@ -8,7 +8,7 @@ class WaveReader {
// No resampling is performed.
external fun readWave(
assetManager: AssetManager, filename: String, expected_sample_rate: Float = 16000.0f
): FloatArray?
): Array<Any>
init {
System.loadLibrary("sherpa-onnx-jni")
... ...
function(download_kaldi_native_fbank)
include(FetchContent)
set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.13.tar.gz")
set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.13.tar.gz")
set(kaldi_native_fbank_HASH "SHA256=1f4d228f9fe3e3e9f92a74a7eecd2489071a03982e4ba6d7c70fc5fa7444df57")
set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.14.tar.gz")
set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.14.tar.gz")
set(kaldi_native_fbank_HASH "SHA256=6a66638a111d3ce21fe6f29cbf9ab3dbcae2331c77391bf825927df5cbf2babe")
set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE)
set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
... ... @@ -12,11 +12,11 @@ function(download_kaldi_native_fbank)
# If you don't have access to the Internet,
# please pre-download kaldi-native-fbank
set(possible_file_locations
$ENV{HOME}/Downloads/kaldi-native-fbank-1.13.tar.gz
${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.13.tar.gz
${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.13.tar.gz
/tmp/kaldi-native-fbank-1.13.tar.gz
/star-fj/fangjun/download/github/kaldi-native-fbank-1.13.tar.gz
$ENV{HOME}/Downloads/kaldi-native-fbank-1.14.tar.gz
${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.14.tar.gz
${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.14.tar.gz
/tmp/kaldi-native-fbank-1.14.tar.gz
/star-fj/fangjun/download/github/kaldi-native-fbank-1.14.tar.gz
)
foreach(f IN LISTS possible_file_locations)
... ...
... ... @@ -91,7 +91,6 @@ def create_recognizer():
rule2_min_trailing_silence=1.2,
rule3_min_utterance_length=300, # it essentially disables this rule
decoding_method=args.decoding_method,
max_feature_vectors=100, # 1 second
)
return recognizer
... ...
... ... @@ -86,7 +86,6 @@ def create_recognizer():
sample_rate=16000,
feature_dim=80,
decoding_method=args.decoding_method,
max_feature_vectors=100, # 1 second
)
return recognizer
... ...
... ... @@ -6,6 +6,11 @@ set(sources
features.cc
file-utils.cc
hypothesis.cc
offline-stream.cc
offline-transducer-greedy-search-decoder.cc
offline-transducer-model-config.cc
offline-transducer-model.cc
offline-recognizer.cc
online-lstm-transducer-model.cc
online-recognizer.cc
online-stream.cc
... ... @@ -56,10 +61,13 @@ if(SHERPA_ONNX_ENABLE_CHECK)
endif()
add_executable(sherpa-onnx sherpa-onnx.cc)
add_executable(sherpa-onnx-offline sherpa-onnx-offline.cc)
target_link_libraries(sherpa-onnx sherpa-onnx-core)
target_link_libraries(sherpa-onnx-offline sherpa-onnx-core)
if(NOT WIN32)
target_link_libraries(sherpa-onnx "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib")
target_link_libraries(sherpa-onnx-offline "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib")
endif()
if(SHERPA_ONNX_ENABLE_PYTHON AND WIN32)
... ... @@ -68,7 +76,13 @@ else()
install(TARGETS sherpa-onnx-core DESTINATION lib)
endif()
install(TARGETS sherpa-onnx DESTINATION bin)
install(
TARGETS
sherpa-onnx
sherpa-onnx-offline
DESTINATION
bin
)
if(SHERPA_ONNX_HAS_ALSA)
add_executable(sherpa-onnx-alsa sherpa-onnx-alsa.cc alsa.cc)
... ...
... ... @@ -19,7 +19,9 @@ namespace sherpa_onnx {
void FeatureExtractorConfig::Register(ParseOptions *po) {
po->Register("sample-rate", &sampling_rate,
"Sampling rate of the input waveform. Must match the one "
"expected by the model.");
"expected by the model. Note: You can have a different "
"sample rate for the input waveform. We will do resampling "
"inside the feature extractor");
po->Register("feat-dim", &feature_dim,
"Feature dimension. Must match the one expected by the model.");
... ... @@ -30,8 +32,7 @@ std::string FeatureExtractorConfig::ToString() const {
os << "FeatureExtractorConfig(";
os << "sampling_rate=" << sampling_rate << ", ";
os << "feature_dim=" << feature_dim << ", ";
os << "max_feature_vectors=" << max_feature_vectors << ")";
os << "feature_dim=" << feature_dim << ")";
return os.str();
}
... ... @@ -43,8 +44,6 @@ class FeatureExtractor::Impl {
opts_.frame_opts.snip_edges = false;
opts_.frame_opts.samp_freq = config.sampling_rate;
opts_.frame_opts.max_feature_vectors = config.max_feature_vectors;
opts_.mel_opts.num_bins = config.feature_dim;
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
... ... @@ -95,7 +94,7 @@ class FeatureExtractor::Impl {
fbank_->AcceptWaveform(sampling_rate, waveform, n);
}
void InputFinished() {
void InputFinished() const {
std::lock_guard<std::mutex> lock(mutex_);
fbank_->InputFinished();
}
... ... @@ -110,12 +109,21 @@ class FeatureExtractor::Impl {
return fbank_->IsLastFrame(frame);
}
std::vector<float> GetFrames(int32_t frame_index, int32_t n) const {
if (frame_index + n > NumFramesReady()) {
fprintf(stderr, "%d + %d > %d\n", frame_index, n, NumFramesReady());
std::vector<float> GetFrames(int32_t frame_index, int32_t n) {
std::lock_guard<std::mutex> lock(mutex_);
if (frame_index + n > fbank_->NumFramesReady()) {
SHERPA_ONNX_LOGE("%d + %d > %d\n", frame_index, n,
fbank_->NumFramesReady());
exit(-1);
}
std::lock_guard<std::mutex> lock(mutex_);
int32_t discard_num = frame_index - last_frame_index_;
if (discard_num < 0) {
SHERPA_ONNX_LOGE("last_frame_index_: %d, frame_index_: %d",
last_frame_index_, frame_index);
exit(-1);
}
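// Frames before frame_index have already been consumed; popping them bounds
// memory growth (this replaces the removed max_feature_vectors option).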
fbank_->Pop(discard_num);
int32_t feature_dim = fbank_->Dim();
std::vector<float> features(feature_dim * n);
... ... @@ -128,12 +136,9 @@ class FeatureExtractor::Impl {
p += feature_dim;
}
return features;
}
void Reset() {
std::lock_guard<std::mutex> lock(mutex_);
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
}
last_frame_index_ = frame_index;
return features;
}
int32_t FeatureDim() const { return opts_.mel_opts.num_bins; }
... ... @@ -143,6 +148,7 @@ class FeatureExtractor::Impl {
knf::FbankOptions opts_;
mutable std::mutex mutex_;
std::unique_ptr<LinearResample> resampler_;
int32_t last_frame_index_ = 0;
};
FeatureExtractor::FeatureExtractor(const FeatureExtractorConfig &config /*={}*/)
... ... @@ -151,11 +157,11 @@ FeatureExtractor::FeatureExtractor(const FeatureExtractorConfig &config /*={}*/)
FeatureExtractor::~FeatureExtractor() = default;
void FeatureExtractor::AcceptWaveform(int32_t sampling_rate,
const float *waveform, int32_t n) {
const float *waveform, int32_t n) const {
impl_->AcceptWaveform(sampling_rate, waveform, n);
}
void FeatureExtractor::InputFinished() { impl_->InputFinished(); }
void FeatureExtractor::InputFinished() const { impl_->InputFinished(); }
int32_t FeatureExtractor::NumFramesReady() const {
return impl_->NumFramesReady();
... ... @@ -170,8 +176,6 @@ std::vector<float> FeatureExtractor::GetFrames(int32_t frame_index,
return impl_->GetFrames(frame_index, n);
}
void FeatureExtractor::Reset() { impl_->Reset(); }
int32_t FeatureExtractor::FeatureDim() const { return impl_->FeatureDim(); }
} // namespace sherpa_onnx
... ...
... ... @@ -14,9 +14,12 @@
namespace sherpa_onnx {
struct FeatureExtractorConfig {
// Sampling rate used by the feature extractor. If it is different from
// the sampling rate of the input waveform, we will do resampling inside.
int32_t sampling_rate = 16000;
// Feature dimension
int32_t feature_dim = 80;
int32_t max_feature_vectors = -1;
std::string ToString() const;
... ... @@ -36,7 +39,8 @@ class FeatureExtractor {
the range [-1, 1].
@param n Number of entries in waveform
*/
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n);
void AcceptWaveform(int32_t sampling_rate, const float *waveform,
int32_t n) const;
/**
* InputFinished() tells the class you won't be providing any
... ... @@ -44,7 +48,7 @@ class FeatureExtractor {
* of features, in the case where snip-edges == false; it also
* affects the return value of IsLastFrame().
*/
void InputFinished();
void InputFinished() const;
int32_t NumFramesReady() const;
... ... @@ -62,8 +66,6 @@ class FeatureExtractor {
*/
std::vector<float> GetFrames(int32_t frame_index, int32_t n) const;
void Reset();
/// Return feature dim of this extractor
int32_t FeatureDim() const;
... ...
// sherpa-onnx/csrc/offline-recognizer.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-recognizer.h"
#include <memory>
#include <utility>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-transducer-decoder.h"
#include "sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h"
#include "sherpa-onnx/csrc/offline-transducer-model.h"
#include "sherpa-onnx/csrc/pad-sequence.h"
#include "sherpa-onnx/csrc/symbol-table.h"
namespace sherpa_onnx {
static OfflineRecognitionResult Convert(
const OfflineTransducerDecoderResult &src, const SymbolTable &sym_table,
int32_t frame_shift_ms, int32_t subsampling_factor) {
OfflineRecognitionResult r;
r.tokens.reserve(src.tokens.size());
r.timestamps.reserve(src.timestamps.size());
std::string text;
for (auto i : src.tokens) {
auto sym = sym_table[i];
text.append(sym);
r.tokens.push_back(std::move(sym));
}
r.text = std::move(text);
float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor;
for (auto t : src.timestamps) {
float time = frame_shift_s * t;
r.timestamps.push_back(time);
}
return r;
}
void OfflineRecognizerConfig::Register(ParseOptions *po) {
feat_config.Register(po);
model_config.Register(po);
po->Register("decoding-method", &decoding_method,
"decoding method,"
"Valid values: greedy_search.");
}
bool OfflineRecognizerConfig::Validate() const {
return model_config.Validate();
}
std::string OfflineRecognizerConfig::ToString() const {
std::ostringstream os;
os << "OfflineRecognizerConfig(";
os << "feat_config=" << feat_config.ToString() << ", ";
os << "model_config=" << model_config.ToString() << ", ";
os << "decoding_method=\"" << decoding_method << "\")";
return os.str();
}
class OfflineRecognizer::Impl {
public:
explicit Impl(const OfflineRecognizerConfig &config)
: config_(config),
symbol_table_(config_.model_config.tokens),
model_(std::make_unique<OfflineTransducerModel>(config_.model_config)) {
if (config_.decoding_method == "greedy_search") {
decoder_ =
std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get());
} else if (config_.decoding_method == "modified_beam_search") {
SHERPA_ONNX_LOGE("TODO: modified_beam_search is to be implemented");
exit(-1);
} else {
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
config_.decoding_method.c_str());
exit(-1);
}
}
std::unique_ptr<OfflineStream> CreateStream() const {
return std::make_unique<OfflineStream>(config_.feat_config);
}
void DecodeStreams(OfflineStream **ss, int32_t n) const {
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
int32_t feat_dim = ss[0]->FeatureDim();
std::vector<Ort::Value> features;
features.reserve(n);
std::vector<std::vector<float>> features_vec(n);
std::vector<int64_t> features_length_vec(n);
for (int32_t i = 0; i != n; ++i) {
auto f = ss[i]->GetFrames();
int32_t num_frames = f.size() / feat_dim;
features_length_vec[i] = num_frames;
features_vec[i] = std::move(f);
std::array<int64_t, 2> shape = {num_frames, feat_dim};
Ort::Value x = Ort::Value::CreateTensor(
memory_info, features_vec[i].data(), features_vec[i].size(),
shape.data(), shape.size());
features.push_back(std::move(x));
}
std::vector<const Ort::Value *> features_pointer(n);
for (int32_t i = 0; i != n; ++i) {
features_pointer[i] = &features[i];
}
std::array<int64_t, 1> features_length_shape = {n};
Ort::Value x_length = Ort::Value::CreateTensor(
memory_info, features_length_vec.data(), n,
features_length_shape.data(), features_length_shape.size());
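// Pad all feature matrices to the length of the longest one.
// -23.025850929940457 is log(1e-10), a floor value for log-mel features.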
Ort::Value x = PadSequence(model_->Allocator(), features_pointer,
-23.025850929940457f);
auto t = model_->RunEncoder(std::move(x), std::move(x_length));
auto results = decoder_->Decode(std::move(t.first), std::move(t.second));
int32_t frame_shift_ms = 10;
for (int32_t i = 0; i != n; ++i) {
auto r = Convert(results[i], symbol_table_, frame_shift_ms,
model_->SubsamplingFactor());
ss[i]->SetResult(r);
}
}
private:
OfflineRecognizerConfig config_;
SymbolTable symbol_table_;
std::unique_ptr<OfflineTransducerModel> model_;
std::unique_ptr<OfflineTransducerDecoder> decoder_;
};
OfflineRecognizer::OfflineRecognizer(const OfflineRecognizerConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
OfflineRecognizer::~OfflineRecognizer() = default;
std::unique_ptr<OfflineStream> OfflineRecognizer::CreateStream() const {
return impl_->CreateStream();
}
void OfflineRecognizer::DecodeStreams(OfflineStream **ss, int32_t n) const {
impl_->DecodeStreams(ss, n);
}
} // namespace sherpa_onnx
... ...
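As a worked example of the timestamp conversion in Convert() above: with frame_shift_ms = 10 and a subsampling factor of 4, frame_shift_s = 10 / 1000 × 4 = 0.04 s, so a token decoded at output frame 25 is assigned the timestamp 25 × 0.04 = 1.0 s.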
// sherpa-onnx/csrc/offline-recognizer.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_H_
#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_H_
#include <memory>
#include <string>
#include <vector>
#include "sherpa-onnx/csrc/offline-stream.h"
#include "sherpa-onnx/csrc/offline-transducer-model-config.h"
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
struct OfflineRecognitionResult {
// Recognition results.
// For English, it consists of space separated words.
// For Chinese, it consists of Chinese words without spaces.
std::string text;
// Decoded results at the token level.
// For instance, for BPE-based models it consists of a list of BPE tokens.
std::vector<std::string> tokens;
/// timestamps.size() == tokens.size()
/// timestamps[i] records the time in seconds when tokens[i] is decoded.
std::vector<float> timestamps;
};
struct OfflineRecognizerConfig {
OfflineFeatureExtractorConfig feat_config;
OfflineTransducerModelConfig model_config;
std::string decoding_method = "greedy_search";
// only greedy_search is implemented
// TODO(fangjun): Implement modified_beam_search
OfflineRecognizerConfig() = default;
OfflineRecognizerConfig(const OfflineFeatureExtractorConfig &feat_config,
const OfflineTransducerModelConfig &model_config,
const std::string &decoding_method)
: feat_config(feat_config),
model_config(model_config),
decoding_method(decoding_method) {}
void Register(ParseOptions *po);
bool Validate() const;
std::string ToString() const;
};
class OfflineRecognizer {
public:
~OfflineRecognizer();
explicit OfflineRecognizer(const OfflineRecognizerConfig &config);
/// Create a stream for decoding.
std::unique_ptr<OfflineStream> CreateStream() const;
/** Decode a single stream
*
* @param s The stream to decode.
*/
void DecodeStream(OfflineStream *s) const {
OfflineStream *ss[1] = {s};
DecodeStreams(ss, 1);
}
/** Decode a list of streams.
*
* @param ss Pointer to an array of streams.
* @param n Size of the input array.
*/
void DecodeStreams(OfflineStream **ss, int32_t n) const;
private:
class Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_H_
... ...
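For orientation (this sketch is not part of the commit), the new offline API is used roughly as follows. It condenses the driver program sherpa-onnx-offline.cc shown later in this diff; the model and wave paths are placeholders:

#include <stdio.h>
#include <vector>

#include "sherpa-onnx/csrc/offline-recognizer.h"
#include "sherpa-onnx/csrc/wave-reader.h"

int main() {
  sherpa_onnx::OfflineRecognizerConfig config;
  config.model_config.tokens = "tokens.txt";  // placeholder paths
  config.model_config.encoder_filename = "encoder.onnx";
  config.model_config.decoder_filename = "decoder.onnx";
  config.model_config.joiner_filename = "joiner.onnx";
  if (!config.Validate()) return -1;

  sherpa_onnx::OfflineRecognizer recognizer(config);
  auto s = recognizer.CreateStream();

  int32_t sampling_rate = -1;
  bool is_ok = false;
  std::vector<float> samples =
      sherpa_onnx::ReadWave("foo.wav", &sampling_rate, &is_ok);
  if (!is_ok) return -1;

  // An offline stream accepts all samples in a single call.
  s->AcceptWaveform(sampling_rate, samples.data(), samples.size());
  recognizer.DecodeStream(s.get());
  fprintf(stderr, "%s\n", s->GetResult().text.c_str());
  return 0;
}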
// sherpa-onnx/csrc/offline-stream.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-stream.h"
#include <assert.h>
#include <algorithm>
#include "kaldi-native-fbank/csrc/online-feature.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-recognizer.h"
#include "sherpa-onnx/csrc/resample.h"
namespace sherpa_onnx {
void OfflineFeatureExtractorConfig::Register(ParseOptions *po) {
po->Register("sample-rate", &sampling_rate,
"Sampling rate of the input waveform. Must match the one "
"expected by the model. Note: You can have a different "
"sample rate for the input waveform. We will do resampling "
"inside the feature extractor");
po->Register("feat-dim", &feature_dim,
"Feature dimension. Must match the one expected by the model.");
}
std::string OfflineFeatureExtractorConfig::ToString() const {
std::ostringstream os;
os << "OfflineFeatureExtractorConfig(";
os << "sampling_rate=" << sampling_rate << ", ";
os << "feature_dim=" << feature_dim << ")";
return os.str();
}
class OfflineStream::Impl {
public:
explicit Impl(const OfflineFeatureExtractorConfig &config) {
opts_.frame_opts.dither = 0;
opts_.frame_opts.snip_edges = false;
opts_.frame_opts.samp_freq = config.sampling_rate;
opts_.mel_opts.num_bins = config.feature_dim;
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
}
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
if (sampling_rate != opts_.frame_opts.samp_freq) {
SHERPA_ONNX_LOGE(
"Creating a resampler:\n"
" in_sample_rate: %d\n"
" output_sample_rate: %d\n",
sampling_rate, static_cast<int32_t>(opts_.frame_opts.samp_freq));
float min_freq =
std::min<int32_t>(sampling_rate, opts_.frame_opts.samp_freq);
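// Cut off just below the Nyquist frequency of the lower of the two rates.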
float lowpass_cutoff = 0.99 * 0.5 * min_freq;
int32_t lowpass_filter_width = 6;
auto resampler = std::make_unique<LinearResample>(
sampling_rate, opts_.frame_opts.samp_freq, lowpass_cutoff,
lowpass_filter_width);
std::vector<float> samples;
resampler->Resample(waveform, n, true, &samples);
fbank_->AcceptWaveform(opts_.frame_opts.samp_freq, samples.data(),
samples.size());
fbank_->InputFinished();
return;
}
fbank_->AcceptWaveform(sampling_rate, waveform, n);
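// An offline stream receives all samples in a single call (see the
// caution in offline-stream.h), so we can flush the extractor right away.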
fbank_->InputFinished();
}
int32_t FeatureDim() const { return opts_.mel_opts.num_bins; }
std::vector<float> GetFrames() const {
int32_t n = fbank_->NumFramesReady();
assert(n > 0 && "Please first call AcceptWaveform()");
int32_t feature_dim = FeatureDim();
std::vector<float> features(n * feature_dim);
float *p = features.data();
for (int32_t i = 0; i != n; ++i) {
const float *f = fbank_->GetFrame(i);
std::copy(f, f + feature_dim, p);
p += feature_dim;
}
return features;
}
void SetResult(const OfflineRecognitionResult &r) { r_ = r; }
const OfflineRecognitionResult &GetResult() const { return r_; }
private:
std::unique_ptr<knf::OnlineFbank> fbank_;
knf::FbankOptions opts_;
OfflineRecognitionResult r_;
};
OfflineStream::OfflineStream(
const OfflineFeatureExtractorConfig &config /*= {}*/)
: impl_(std::make_unique<Impl>(config)) {}
OfflineStream::~OfflineStream() = default;
void OfflineStream::AcceptWaveform(int32_t sampling_rate, const float *waveform,
int32_t n) const {
impl_->AcceptWaveform(sampling_rate, waveform, n);
}
int32_t OfflineStream::FeatureDim() const { return impl_->FeatureDim(); }
std::vector<float> OfflineStream::GetFrames() const {
return impl_->GetFrames();
}
void OfflineStream::SetResult(const OfflineRecognitionResult &r) {
impl_->SetResult(r);
}
const OfflineRecognitionResult &OfflineStream::GetResult() const {
return impl_->GetResult();
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-stream.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_STREAM_H_
#define SHERPA_ONNX_CSRC_OFFLINE_STREAM_H_
#include <stdint.h>
#include <memory>
#include <string>
#include <vector>
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
struct OfflineRecognitionResult;
struct OfflineFeatureExtractorConfig {
// Sampling rate used by the feature extractor. If it is different from
// the sampling rate of the input waveform, we will do resampling inside.
int32_t sampling_rate = 16000;
// Feature dimension
int32_t feature_dim = 80;
std::string ToString() const;
void Register(ParseOptions *po);
};
class OfflineStream {
public:
explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {});
~OfflineStream();
/**
@param sampling_rate The sampling_rate of the input waveform. If it does
not equal to config.sampling_rate, we will do
resampling inside.
@param waveform Pointer to a 1-D array of size n. It must be normalized to
the range [-1, 1].
@param n Number of entries in waveform
Caution: You can only invoke this function once, so you have to input
all the samples at once.
*/
void AcceptWaveform(int32_t sampling_rate, const float *waveform,
int32_t n) const;
/// Return feature dim of this extractor
int32_t FeatureDim() const;
// Get all the feature frames of this stream in a 1-D array, which is
// flattened from a 2-D array of shape (num_frames, feat_dim).
std::vector<float> GetFrames() const;
/** Set the recognition result for this stream. */
void SetResult(const OfflineRecognitionResult &r);
/** Get the recognition result of this stream */
const OfflineRecognitionResult &GetResult() const;
private:
class Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_STREAM_H_
... ...
// sherpa-onnx/csrc/offline-transducer-decoder.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_DECODER_H_
#define SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_DECODER_H_
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
namespace sherpa_onnx {
struct OfflineTransducerDecoderResult {
/// The decoded token IDs
std::vector<int64_t> tokens;
/// timestamps[i] contains the output frame index where tokens[i] is decoded.
/// Note: The index is after subsampling
std::vector<int32_t> timestamps;
};
class OfflineTransducerDecoder {
public:
virtual ~OfflineTransducerDecoder() = default;
/** Run transducer beam search given the output from the encoder model.
*
* @param encoder_out A 3-D tensor of shape (N, T, joiner_dim)
* @param encoder_out_length A 1-D tensor of shape (N,) containing number
* of valid frames in encoder_out before padding.
*
* @return Return a vector of size `N` containing the decoded results.
*/
virtual std::vector<OfflineTransducerDecoderResult> Decode(
Ort::Value encoder_out, Ort::Value encoder_out_length) = 0;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_DECODER_H_
... ...
// sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h"
#include <algorithm>
#include <iterator>
#include <utility>
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/packed-sequence.h"
#include "sherpa-onnx/csrc/slice.h"
namespace sherpa_onnx {
std::vector<OfflineTransducerDecoderResult>
OfflineTransducerGreedySearchDecoder::Decode(Ort::Value encoder_out,
Ort::Value encoder_out_length) {
PackedSequence packed_encoder_out = PackPaddedSequence(
model_->Allocator(), &encoder_out, &encoder_out_length);
int32_t batch_size =
static_cast<int32_t>(packed_encoder_out.sorted_indexes.size());
int32_t vocab_size = model_->VocabSize();
int32_t context_size = model_->ContextSize();
std::vector<OfflineTransducerDecoderResult> ans(batch_size);
for (auto &r : ans) {
// 0 is the ID of the blank token
r.tokens.resize(context_size, 0);
}
auto decoder_input = model_->BuildDecoderInput(ans, ans.size());
Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
int32_t start = 0;
int32_t t = 0;
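// Iterate over output frames; at frame t, only the first n streams
// (packed longest-first) are still active.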
for (auto n : packed_encoder_out.batch_sizes) {
Ort::Value cur_encoder_out = packed_encoder_out.Get(start, n);
Ort::Value cur_decoder_out = Slice(model_->Allocator(), &decoder_out, 0, n);
start += n;
Ort::Value logit = model_->RunJoiner(std::move(cur_encoder_out),
std::move(cur_decoder_out));
const float *p_logit = logit.GetTensorData<float>();
bool emitted = false;
for (int32_t i = 0; i != n; ++i) {
auto y = static_cast<int32_t>(std::distance(
static_cast<const float *>(p_logit),
std::max_element(static_cast<const float *>(p_logit),
static_cast<const float *>(p_logit) + vocab_size)));
p_logit += vocab_size;
if (y != 0) {
ans[i].tokens.push_back(y);
ans[i].timestamps.push_back(t);
emitted = true;
}
}
if (emitted) {
Ort::Value decoder_input = model_->BuildDecoderInput(ans, n);
decoder_out = model_->RunDecoder(std::move(decoder_input));
}
++t;
}
for (auto &r : ans) {
r.tokens = {r.tokens.begin() + context_size, r.tokens.end()};
}
std::vector<OfflineTransducerDecoderResult> unsorted_ans(batch_size);
for (int32_t i = 0; i != batch_size; ++i) {
unsorted_ans[packed_encoder_out.sorted_indexes[i]] = std::move(ans[i]);
}
return unsorted_ans;
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_GREEDY_SEARCH_DECODER_H_
#define SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_GREEDY_SEARCH_DECODER_H_
#include <vector>
#include "sherpa-onnx/csrc/offline-transducer-decoder.h"
#include "sherpa-onnx/csrc/offline-transducer-model.h"
namespace sherpa_onnx {
class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder {
public:
explicit OfflineTransducerGreedySearchDecoder(OfflineTransducerModel *model)
: model_(model) {}
std::vector<OfflineTransducerDecoderResult> Decode(
Ort::Value encoder_out, Ort::Value encoder_out_length) override;
private:
OfflineTransducerModel *model_; // Not owned
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_GREEDY_SEARCH_DECODER_H_
... ...
// sherpa-onnx/csrc/offline-transducer-model-config.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-transducer-model-config.h"
#include <sstream>
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
void OfflineTransducerModelConfig::Register(ParseOptions *po) {
po->Register("encoder", &encoder_filename, "Path to encoder.onnx");
po->Register("decoder", &decoder_filename, "Path to decoder.onnx");
po->Register("joiner", &joiner_filename, "Path to joiner.onnx");
po->Register("tokens", &tokens, "Path to tokens.txt");
po->Register("num_threads", &num_threads,
"Number of threads to run the neural network");
po->Register("debug", &debug,
"true to print model information while loading it.");
}
bool OfflineTransducerModelConfig::Validate() const {
if (!FileExists(tokens)) {
SHERPA_ONNX_LOGE("%s does not exist", tokens.c_str());
return false;
}
if (!FileExists(encoder_filename)) {
SHERPA_ONNX_LOGE("%s does not exist", encoder_filename.c_str());
return false;
}
if (!FileExists(decoder_filename)) {
SHERPA_ONNX_LOGE("%s does not exist", decoder_filename.c_str());
return false;
}
if (!FileExists(joiner_filename)) {
SHERPA_ONNX_LOGE("%s does not exist", joiner_filename.c_str());
return false;
}
if (num_threads < 1) {
SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads);
return false;
}
return true;
}
std::string OfflineTransducerModelConfig::ToString() const {
std::ostringstream os;
os << "OfflineTransducerModelConfig(";
os << "encoder_filename=\"" << encoder_filename << "\", ";
os << "decoder_filename=\"" << decoder_filename << "\", ";
os << "joiner_filename=\"" << joiner_filename << "\", ";
os << "tokens=\"" << tokens << "\", ";
os << "num_threads=" << num_threads << ", ";
os << "debug=" << (debug ? "True" : "False") << ")";
return os.str();
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-transducer-model-config.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODEL_CONFIG_H_
#include <string>
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
struct OfflineTransducerModelConfig {
std::string encoder_filename;
std::string decoder_filename;
std::string joiner_filename;
std::string tokens;
int32_t num_threads = 2;
bool debug = false;
OfflineTransducerModelConfig() = default;
OfflineTransducerModelConfig(const std::string &encoder_filename,
const std::string &decoder_filename,
const std::string &joiner_filename,
const std::string &tokens, int32_t num_threads,
bool debug)
: encoder_filename(encoder_filename),
decoder_filename(decoder_filename),
joiner_filename(joiner_filename),
tokens(tokens),
num_threads(num_threads),
debug(debug) {}
void Register(ParseOptions *po);
bool Validate() const;
std::string ToString() const;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODEL_CONFIG_H_
... ...
// sherpa-onnx/csrc/offline-transducer-model.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-transducer-model.h"
#include <algorithm>
#include <string>
#include <vector>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-transducer-decoder.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
namespace sherpa_onnx {
class OfflineTransducerModel::Impl {
public:
explicit Impl(const OfflineTransducerModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_WARNING),
sess_opts_{},
allocator_{} {
sess_opts_.SetIntraOpNumThreads(config.num_threads);
sess_opts_.SetInterOpNumThreads(config.num_threads);
{
auto buf = ReadFile(config.encoder_filename);
InitEncoder(buf.data(), buf.size());
}
{
auto buf = ReadFile(config.decoder_filename);
InitDecoder(buf.data(), buf.size());
}
{
auto buf = ReadFile(config.joiner_filename);
InitJoiner(buf.data(), buf.size());
}
}
std::pair<Ort::Value, Ort::Value> RunEncoder(Ort::Value features,
Ort::Value features_length) {
std::array<Ort::Value, 2> encoder_inputs = {std::move(features),
std::move(features_length)};
auto encoder_out = encoder_sess_->Run(
{}, encoder_input_names_ptr_.data(), encoder_inputs.data(),
encoder_inputs.size(), encoder_output_names_ptr_.data(),
encoder_output_names_ptr_.size());
return {std::move(encoder_out[0]), std::move(encoder_out[1])};
}
Ort::Value RunDecoder(Ort::Value decoder_input) {
auto decoder_out = decoder_sess_->Run(
{}, decoder_input_names_ptr_.data(), &decoder_input, 1,
decoder_output_names_ptr_.data(), decoder_output_names_ptr_.size());
return std::move(decoder_out[0]);
}
Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) {
std::array<Ort::Value, 2> joiner_input = {std::move(encoder_out),
std::move(decoder_out)};
auto logit = joiner_sess_->Run({}, joiner_input_names_ptr_.data(),
joiner_input.data(), joiner_input.size(),
joiner_output_names_ptr_.data(),
joiner_output_names_ptr_.size());
return std::move(logit[0]);
}
int32_t VocabSize() const { return vocab_size_; }
int32_t ContextSize() const { return context_size_; }
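// Note: hard-coded; the offline transducer models targeted here subsample by 4.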
int32_t SubsamplingFactor() const { return 4; }
OrtAllocator *Allocator() const { return allocator_; }
Ort::Value BuildDecoderInput(
const std::vector<OfflineTransducerDecoderResult> &results,
int32_t end_index) const {
assert(end_index <= results.size());
int32_t batch_size = end_index;
int32_t context_size = ContextSize();
std::array<int64_t, 2> shape{batch_size, context_size};
Ort::Value decoder_input = Ort::Value::CreateTensor<int64_t>(
Allocator(), shape.data(), shape.size());
int64_t *p = decoder_input.GetTensorMutableData<int64_t>();
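// For each utterance, feed its last context_size decoded tokens to the decoder.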
for (int32_t i = 0; i != batch_size; ++i) {
const auto &r = results[i];
const int64_t *begin = r.tokens.data() + r.tokens.size() - context_size;
const int64_t *end = r.tokens.data() + r.tokens.size();
std::copy(begin, end, p);
p += context_size;
}
return decoder_input;
}
private:
void InitEncoder(void *model_data, size_t model_data_length) {
encoder_sess_ = std::make_unique<Ort::Session>(
env_, model_data, model_data_length, sess_opts_);
GetInputNames(encoder_sess_.get(), &encoder_input_names_,
&encoder_input_names_ptr_);
GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
&encoder_output_names_ptr_);
// get meta data
Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
if (config_.debug) {
std::ostringstream os;
os << "---encoder---\n";
PrintModelMetadata(os, meta_data);
SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
}
}
void InitDecoder(void *model_data, size_t model_data_length) {
decoder_sess_ = std::make_unique<Ort::Session>(
env_, model_data, model_data_length, sess_opts_);
GetInputNames(decoder_sess_.get(), &decoder_input_names_,
&decoder_input_names_ptr_);
GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
&decoder_output_names_ptr_);
// get meta data
Ort::ModelMetadata meta_data = decoder_sess_->GetModelMetadata();
if (config_.debug) {
std::ostringstream os;
os << "---decoder---\n";
PrintModelMetadata(os, meta_data);
SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
}
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
SHERPA_ONNX_READ_META_DATA(context_size_, "context_size");
}
void InitJoiner(void *model_data, size_t model_data_length) {
joiner_sess_ = std::make_unique<Ort::Session>(
env_, model_data, model_data_length, sess_opts_);
GetInputNames(joiner_sess_.get(), &joiner_input_names_,
&joiner_input_names_ptr_);
GetOutputNames(joiner_sess_.get(), &joiner_output_names_,
&joiner_output_names_ptr_);
// get meta data
Ort::ModelMetadata meta_data = joiner_sess_->GetModelMetadata();
if (config_.debug) {
std::ostringstream os;
os << "---joiner---\n";
PrintModelMetadata(os, meta_data);
SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
}
}
private:
OfflineTransducerModelConfig config_;
Ort::Env env_;
Ort::SessionOptions sess_opts_;
Ort::AllocatorWithDefaultOptions allocator_;
std::unique_ptr<Ort::Session> encoder_sess_;
std::unique_ptr<Ort::Session> decoder_sess_;
std::unique_ptr<Ort::Session> joiner_sess_;
std::vector<std::string> encoder_input_names_;
std::vector<const char *> encoder_input_names_ptr_;
std::vector<std::string> encoder_output_names_;
std::vector<const char *> encoder_output_names_ptr_;
std::vector<std::string> decoder_input_names_;
std::vector<const char *> decoder_input_names_ptr_;
std::vector<std::string> decoder_output_names_;
std::vector<const char *> decoder_output_names_ptr_;
std::vector<std::string> joiner_input_names_;
std::vector<const char *> joiner_input_names_ptr_;
std::vector<std::string> joiner_output_names_;
std::vector<const char *> joiner_output_names_ptr_;
int32_t vocab_size_ = 0; // initialized in InitDecoder
int32_t context_size_ = 0; // initialized in InitDecoder
};
OfflineTransducerModel::OfflineTransducerModel(
const OfflineTransducerModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
OfflineTransducerModel::~OfflineTransducerModel() = default;
std::pair<Ort::Value, Ort::Value> OfflineTransducerModel::RunEncoder(
Ort::Value features, Ort::Value features_length) {
return impl_->RunEncoder(std::move(features), std::move(features_length));
}
Ort::Value OfflineTransducerModel::RunDecoder(Ort::Value decoder_input) {
return impl_->RunDecoder(std::move(decoder_input));
}
Ort::Value OfflineTransducerModel::RunJoiner(Ort::Value encoder_out,
Ort::Value decoder_out) {
return impl_->RunJoiner(std::move(encoder_out), std::move(decoder_out));
}
int32_t OfflineTransducerModel::VocabSize() const { return impl_->VocabSize(); }
int32_t OfflineTransducerModel::ContextSize() const {
return impl_->ContextSize();
}
int32_t OfflineTransducerModel::SubsamplingFactor() const {
return impl_->SubsamplingFactor();
}
OrtAllocator *OfflineTransducerModel::Allocator() const {
return impl_->Allocator();
}
Ort::Value OfflineTransducerModel::BuildDecoderInput(
const std::vector<OfflineTransducerDecoderResult> &results,
int32_t end_index) const {
return impl_->BuildDecoderInput(results, end_index);
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-transducer-model.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODEL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODEL_H_
#include <memory>
#include <utility>
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-transducer-model-config.h"
namespace sherpa_onnx {
struct OfflineTransducerDecoderResult;
class OfflineTransducerModel {
public:
explicit OfflineTransducerModel(const OfflineTransducerModelConfig &config);
~OfflineTransducerModel();
/** Run the encoder.
*
* @param features A tensor of shape (N, T, C). It is changed in-place.
* @param features_length A 1-D tensor of shape (N,) containing number of
* valid frames in `features` before padding.
*
* @return Return a pair containing:
* - encoder_out: A 3-D tensor of shape (N, T', encoder_dim)
* - encoder_out_length: A 1-D tensor of shape (N,) containing number
* of frames in `encoder_out` before padding.
*/
std::pair<Ort::Value, Ort::Value> RunEncoder(Ort::Value features,
Ort::Value features_length);
/** Run the decoder network.
*
* Caution: We assume there are no recurrent connections in the decoder and
* the decoder is stateless. See
* https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py
* for an example
*
* @param decoder_input It is usually of shape (N, context_size)
* @return Return a tensor of shape (N, decoder_dim).
*/
Ort::Value RunDecoder(Ort::Value decoder_input);
/** Run the joint network.
*
* @param encoder_out Output of the encoder network. A tensor of shape
* (N, joiner_dim).
* @param decoder_out Output of the decoder network. A tensor of shape
* (N, joiner_dim).
* @return Return a tensor of shape (N, vocab_size). In icefall, the last
* last layer of the joint network is `nn.Linear`,
* not `nn.LogSoftmax`.
*/
Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out);
/** Return the vocabulary size of the model
*/
int32_t VocabSize() const;
/** Return the context_size of the decoder model.
*/
int32_t ContextSize() const;
/** Return the subsampling factor of the model.
*/
int32_t SubsamplingFactor() const;
/** Return an allocator for allocating memory
*/
OrtAllocator *Allocator() const;
/** Build decoder_input from the current results.
*
* @param results Current decoded results.
* @param end_index We only use results[0:end_index] to build
* the decoder_input.
* @return Return a tensor of shape (results.size(), ContextSize())
*/
Ort::Value BuildDecoderInput(
const std::vector<OfflineTransducerDecoderResult> &results,
int32_t end_index) const;
private:
class Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODEL_H_
... ...
... ... @@ -95,7 +95,7 @@ void OnlineLstmTransducerModel::InitEncoder(void *model_data,
std::ostringstream os;
os << "---encoder---\n";
PrintModelMetadata(os, meta_data);
fprintf(stderr, "%s\n", os.str().c_str());
SHERPA_ONNX_LOGE("%s", os.str().c_str());
}
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
... ... @@ -123,7 +123,7 @@ void OnlineLstmTransducerModel::InitDecoder(void *model_data,
std::ostringstream os;
os << "---decoder---\n";
PrintModelMetadata(os, meta_data);
fprintf(stderr, "%s\n", os.str().c_str());
SHERPA_ONNX_LOGE("%s", os.str().c_str());
}
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
... ... @@ -148,7 +148,7 @@ void OnlineLstmTransducerModel::InitJoiner(void *model_data,
std::ostringstream os;
os << "---joiner---\n";
PrintModelMetadata(os, meta_data);
fprintf(stderr, "%s\n", os.str().c_str());
SHERPA_ONNX_LOGE("%s", os.str().c_str());
}
}
... ... @@ -228,9 +228,6 @@ std::vector<Ort::Value> OnlineLstmTransducerModel::GetEncoderInitStates() {
std::pair<Ort::Value, std::vector<Ort::Value>>
OnlineLstmTransducerModel::RunEncoder(Ort::Value features,
std::vector<Ort::Value> states) {
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
std::array<Ort::Value, 3> encoder_inputs = {
std::move(features), std::move(states[0]), std::move(states[1])};
... ...
... ... @@ -20,7 +20,7 @@ class OnlineStream::Impl {
feat_extractor_.AcceptWaveform(sampling_rate, waveform, n);
}
void InputFinished() { feat_extractor_.InputFinished(); }
void InputFinished() const { feat_extractor_.InputFinished(); }
int32_t NumFramesReady() const {
return feat_extractor_.NumFramesReady() - start_frame_index_;
... ... @@ -68,11 +68,11 @@ OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/)
OnlineStream::~OnlineStream() = default;
void OnlineStream::AcceptWaveform(int32_t sampling_rate, const float *waveform,
int32_t n) {
int32_t n) const {
impl_->AcceptWaveform(sampling_rate, waveform, n);
}
void OnlineStream::InputFinished() { impl_->InputFinished(); }
void OnlineStream::InputFinished() const { impl_->InputFinished(); }
int32_t OnlineStream::NumFramesReady() const { return impl_->NumFramesReady(); }
... ...
... ... @@ -27,7 +27,8 @@ class OnlineStream {
the range [-1, 1].
@param n Number of entries in waveform
*/
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n);
void AcceptWaveform(int32_t sampling_rate, const float *waveform,
int32_t n) const;
/**
* InputFinished() tells the class you won't be providing any
... ... @@ -35,7 +36,7 @@ class OnlineStream {
* of features, in the case where snip-edges == false; it also
* affects the return value of IsLastFrame().
*/
void InputFinished();
void InputFinished() const;
int32_t NumFramesReady() const;
... ...
... ... @@ -248,14 +248,21 @@ int32_t main(int32_t argc, char *argv[]) {
std::string wave_filename = po.GetArg(1);
bool is_ok = false;
int32_t actual_sample_rate = -1;
std::vector<float> samples =
sherpa_onnx::ReadWave(wave_filename, sample_rate, &is_ok);
sherpa_onnx::ReadWave(wave_filename, &actual_sample_rate, &is_ok);
if (!is_ok) {
SHERPA_ONNX_LOGE("Failed to read %s", wave_filename.c_str());
return -1;
}
if (actual_sample_rate != sample_rate) {
SHERPA_ONNX_LOGE("Expected sample rate: %d, given %d", sample_rate,
actual_sample_rate);
return -1;
}
asio::io_context io_conn; // for network connections
Client c(io_conn, server_ip, server_port, samples, samples_per_message,
seconds_per_message);
... ...
... ... @@ -97,7 +97,7 @@ void OnlineZipformerTransducerModel::InitEncoder(void *model_data,
std::ostringstream os;
os << "---encoder---\n";
PrintModelMetadata(os, meta_data);
fprintf(stderr, "%s\n", os.str().c_str());
SHERPA_ONNX_LOGE("%s", os.str().c_str());
}
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
... ... @@ -123,8 +123,8 @@ void OnlineZipformerTransducerModel::InitEncoder(void *model_data,
print(num_encoder_layers_, "num_encoder_layers");
print(cnn_module_kernels_, "cnn_module_kernels");
print(left_context_len_, "left_context_len");
fprintf(stderr, "T: %d\n", T_);
fprintf(stderr, "decode_chunk_len_: %d\n", decode_chunk_len_);
SHERPA_ONNX_LOGE("T: %d", T_);
SHERPA_ONNX_LOGE("decode_chunk_len_: %d", decode_chunk_len_);
}
}
... ... @@ -145,7 +145,7 @@ void OnlineZipformerTransducerModel::InitDecoder(void *model_data,
std::ostringstream os;
os << "---decoder---\n";
PrintModelMetadata(os, meta_data);
fprintf(stderr, "%s\n", os.str().c_str());
SHERPA_ONNX_LOGE("%s", os.str().c_str());
}
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
... ... @@ -170,7 +170,7 @@ void OnlineZipformerTransducerModel::InitJoiner(void *model_data,
std::ostringstream os;
os << "---joiner---\n";
PrintModelMetadata(os, meta_data);
fprintf(stderr, "%s\n", os.str().c_str());
SHERPA_ONNX_LOGE("%s", os.str().c_str());
}
}
... ... @@ -435,9 +435,6 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::GetEncoderInitStates() {
std::pair<Ort::Value, std::vector<Ort::Value>>
OnlineZipformerTransducerModel::RunEncoder(Ort::Value features,
std::vector<Ort::Value> states) {
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
std::vector<Ort::Value> encoder_inputs;
encoder_inputs.reserve(1 + states.size());
... ...
... ... @@ -41,7 +41,7 @@ PackedSequence PackPaddedSequence(OrtAllocator *allocator,
std::vector<int64_t> l_shape = length->GetTensorTypeAndShapeInfo().GetShape();
assert(v_shape.size() == 3);
assert(l_shape.size() == 3);
assert(l_shape.size() == 1);
assert(v_shape[0] == l_shape[0]);
std::vector<int32_t> indexes(v_shape[0]);
... ...
... ... @@ -13,7 +13,26 @@ namespace sherpa_onnx {
struct PackedSequence {
std::vector<int32_t> sorted_indexes;
std::vector<int32_t> batch_sizes;
// data is a 2-D tensor of shape (sum(batch_sizes), channels)
Ort::Value data{nullptr};
// Return a shallow copy of data[start:start+size, :]
Ort::Value Get(int32_t start, int32_t size) {
auto shape = data.GetTensorTypeAndShapeInfo().GetShape();
std::array<int64_t, 2> ans_shape{size, shape[1]};
float *p = data.GetTensorMutableData<float>();
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
// a shallow copy
return Ort::Value::CreateTensor(memory_info, p + start * shape[1],
size * shape[1], ans_shape.data(),
ans_shape.size());
}
};
/** Similar to torch.nn.utils.rnn.pad_sequence but it supports only
... ...
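To make the packed layout concrete, here is a small hypothetical example (the lengths are made up for illustration):

// Suppose three encoder outputs of lengths 4, 2, 2 are packed, sorted
// longest-first. batch_sizes[t] is the number of sequences still active
// at output frame t:
//
//   lengths (sorted) = [4, 2, 2]
//   batch_sizes      = [3, 3, 1, 1]
//
// data then holds 3 + 3 + 1 + 1 = 8 rows, grouped by time step, and
// Get(start, n) returns a shallow 2-D view of the n active rows at one
// time step -- exactly what the greedy search decoder iterates over.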
... ... @@ -46,7 +46,7 @@ I Gcd(I m, I n) {
// this function is copied from kaldi/src/base/kaldi-math.h
if (m == 0 || n == 0) {
if (m == 0 && n == 0) { // gcd not defined, as all integers are divisors.
fprintf(stderr, "Undefined GCD since m = 0, n = 0.");
fprintf(stderr, "Undefined GCD since m = 0, n = 0.\n");
exit(-1);
}
return (m == 0 ? (n > 0 ? n : -n) : (m > 0 ? m : -m));
... ...
... ... @@ -95,6 +95,10 @@ as the device_name.
fprintf(stderr, "%s\n", config.ToString().c_str());
if (!config.Validate()) {
fprintf(stderr, "Errors in config!\n");
return -1;
}
sherpa_onnx::OnlineRecognizer recognizer(config);
int32_t expected_sample_rate = config.feat_config.sampling_rate;
... ...
... ... @@ -86,6 +86,11 @@ for a list of pre-trained models to download.
fprintf(stderr, "%s\n", config.ToString().c_str());
if (!config.Validate()) {
fprintf(stderr, "Errors in config!\n");
return -1;
}
sherpa_onnx::OnlineRecognizer recognizer(config);
auto s = recognizer.CreateStream();
... ...
// sherpa-onnx/csrc/sherpa-onnx-offline.cc
//
// Copyright (c) 2022-2023 Xiaomi Corporation
#include <stdio.h>
#include <chrono> // NOLINT
#include <string>
#include <vector>
#include "sherpa-onnx/csrc/offline-recognizer.h"
#include "sherpa-onnx/csrc/offline-stream.h"
#include "sherpa-onnx/csrc/offline-transducer-decoder.h"
#include "sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h"
#include "sherpa-onnx/csrc/offline-transducer-model.h"
#include "sherpa-onnx/csrc/pad-sequence.h"
#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/wave-reader.h"
int main(int32_t argc, char *argv[]) {
if (argc < 6 || argc > 8) {
const char *usage = R"usage(
Usage:
./bin/sherpa-onnx-offline \
/path/to/tokens.txt \
/path/to/encoder.onnx \
/path/to/decoder.onnx \
/path/to/joiner.onnx \
/path/to/foo.wav [num_threads [decoding_method]]
Default value for num_threads is 2.
Valid values for decoding_method: greedy_search.
foo.wav should be a single-channel, 16-bit PCM-encoded wave file; its
sampling rate can be arbitrary and does not need to be 16kHz.
Please refer to
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
for a list of pre-trained models to download.
)usage";
fprintf(stderr, "%s\n", usage);
return 0;
}
sherpa_onnx::OfflineRecognizerConfig config;
config.model_config.tokens = argv[1];
config.model_config.debug = false;
config.model_config.encoder_filename = argv[2];
config.model_config.decoder_filename = argv[3];
config.model_config.joiner_filename = argv[4];
std::string wav_filename = argv[5];
config.model_config.num_threads = 2;
if (argc >= 7 && atoi(argv[6]) > 0) {
config.model_config.num_threads = atoi(argv[6]);
}
if (argc == 8) {
config.decoding_method = argv[7];
}
fprintf(stderr, "%s\n", config.ToString().c_str());
if (!config.Validate()) {
fprintf(stderr, "Errors in config!\n");
return -1;
}
int32_t sampling_rate = -1;
bool is_ok = false;
std::vector<float> samples =
sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok);
if (!is_ok) {
fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
return -1;
}
fprintf(stderr, "sampling rate of input file: %d\n", sampling_rate);
float duration = samples.size() / static_cast<float>(sampling_rate);
sherpa_onnx::OfflineRecognizer recognizer(config);
auto s = recognizer.CreateStream();
auto begin = std::chrono::steady_clock::now();
fprintf(stderr, "Started\n");
s->AcceptWaveform(sampling_rate, samples.data(), samples.size());
recognizer.DecodeStream(s.get());
fprintf(stderr, "Done!\n");
fprintf(stderr, "Recognition result for %s:\n%s\n", wav_filename.c_str(),
s->GetResult().text.c_str());
auto end = std::chrono::steady_clock::now();
float elapsed_seconds =
std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
.count() /
1000.;
fprintf(stderr, "num threads: %d\n", config.model_config.num_threads);
fprintf(stderr, "decoding method: %s\n", config.decoding_method.c_str());
fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
float rtf = elapsed_seconds / duration;
fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n",
elapsed_seconds, duration, rtf);
return 0;
}
... ...
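As a sanity check on the arithmetic above: RTF = elapsed_seconds / duration, so decoding a 5.0 s wave in 0.5 s gives an RTF of 0.5 / 5.0 = 0.1; values below 1 mean faster than real time.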
... ... @@ -26,6 +26,8 @@ Usage:
Default value for num_threads is 2.
Valid values for decoding_method: greedy_search (default), modified_beam_search.
foo.wav should be a single-channel, 16-bit PCM-encoded wave file; its
sampling rate can be arbitrary and does not need to be 16kHz.
Please refer to
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
... ... @@ -59,20 +61,26 @@ for a list of pre-trained models to download.
fprintf(stderr, "%s\n", config.ToString().c_str());
if (!config.Validate()) {
fprintf(stderr, "Errors in config!\n");
return -1;
}
sherpa_onnx::OnlineRecognizer recognizer(config);
int32_t expected_sampling_rate = config.feat_config.sampling_rate;
int32_t sampling_rate = -1;
bool is_ok = false;
std::vector<float> samples =
sherpa_onnx::ReadWave(wav_filename, expected_sampling_rate, &is_ok);
sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok);
if (!is_ok) {
fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
return -1;
}
fprintf(stderr, "sampling rate of input file: %d\n", sampling_rate);
float duration = samples.size() / static_cast<float>(expected_sampling_rate);
float duration = samples.size() / static_cast<float>(sampling_rate);
fprintf(stderr, "wav filename: %s\n", wav_filename.c_str());
fprintf(stderr, "wav duration (s): %.3f\n", duration);
... ... @@ -81,12 +89,13 @@ for a list of pre-trained models to download.
fprintf(stderr, "Started\n");
auto s = recognizer.CreateStream();
s->AcceptWaveform(expected_sampling_rate, samples.data(), samples.size());
s->AcceptWaveform(sampling_rate, samples.data(), samples.size());
std::vector<float> tail_paddings(static_cast<int>(0.2 * sampling_rate));
// Note: We can call AcceptWaveform() multiple times.
s->AcceptWaveform(sampling_rate, tail_paddings.data(), tail_paddings.size());
std::vector<float> tail_paddings(
static_cast<int>(0.2 * expected_sampling_rate));
s->AcceptWaveform(expected_sampling_rate, tail_paddings.data(),
tail_paddings.size());
// Call InputFinished() to indicate that no audio samples are available
s->InputFinished();
while (recognizer.IsReady(s.get())) {
... ...
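A note on the tail padding above: a streaming encoder needs a few frames of right context before it can emit output for the most recent audio, so the final real samples would otherwise stay stuck in its internal buffer. Feeding a short stretch of zeros (0.2 s here; the JNI wrapper below uses 0.32 s) before calling InputFinished() flushes those last frames through the model. The exact padding length is a heuristic, not something the model prescribes.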
... ... @@ -30,4 +30,23 @@ TEST(Slice, Slice3D) {
// TODO(fangjun): Check that the results are correct
}
TEST(Slice, Slice2D) {
Ort::AllocatorWithDefaultOptions allocator;
std::array<int64_t, 2> shape{5, 8};
Ort::Value v =
Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());
float *p = v.GetTensorMutableData<float>();
std::iota(p, p + shape[0] * shape[1], 0);
auto v1 = Slice(allocator, &v, 1, 3);
auto v2 = Slice(allocator, &v, 0, 2);
Print2D(&v);
Print2D(&v1);
Print2D(&v2);
// TODO(fangjun): Check that the results are correct
}
} // namespace sherpa_onnx
... ...
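The TODO in Slice2D above is easy to resolve because the tensor is filled with std::iota: v1 = Slice(allocator, &v, 1, 3) must hold rows 1..2 of the 5x8 tensor (values 8..23) and v2 = Slice(allocator, &v, 0, 2) rows 0..1 (values 0..15). A hypothetical check that could replace the TODO at the end of TEST(Slice, Slice2D), sketched here rather than taken from the patch:

// Hypothetical assertions; the expected values follow from std::iota.
const float *p1 = v1.GetTensorData<float>();
const float *p2 = v2.GetTensorData<float>();
for (int32_t i = 0; i != 2 * 8; ++i) {
  EXPECT_EQ(p1[i], 8 + i);  // v1 holds rows 1..2, i.e. values 8..23
  EXPECT_EQ(p2[i], i);      // v2 holds rows 0..1, i.e. values 0..15
}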
... ... @@ -24,7 +24,7 @@ Ort::Value Slice(OrtAllocator *allocator, const Ort::Value *v,
assert(0 <= dim1_start);
assert(dim1_start < dim1_end);
assert(dim1_end < shape[1]);
assert(dim1_end <= shape[1]);
const T *src = v->GetTensorData<T>();
... ... @@ -46,8 +46,35 @@ Ort::Value Slice(OrtAllocator *allocator, const Ort::Value *v,
return ans;
}
template <typename T /*= float*/>
Ort::Value Slice(OrtAllocator *allocator, const Ort::Value *v,
int32_t dim0_start, int32_t dim0_end) {
std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
assert(shape.size() == 2);
assert(0 <= dim0_start);
assert(dim0_start < dim0_end);
assert(dim0_end <= shape[0]);
const T *src = v->GetTensorData<T>();
std::array<int64_t, 2> ans_shape{dim0_end - dim0_start, shape[1]};
Ort::Value ans = Ort::Value::CreateTensor<T>(allocator, ans_shape.data(),
ans_shape.size());
// Rows of a 2-D tensor are stored contiguously, so slicing along dim0 is a
// single contiguous copy of (dim0_end - dim0_start) * shape[1] elements.
const T *start = src + dim0_start * shape[1];
const T *end = src + dim0_end * shape[1];
T *dst = ans.GetTensorMutableData<T>();
std::copy(start, end, dst);
return ans;
}
template Ort::Value Slice<float>(OrtAllocator *allocator, const Ort::Value *v,
int32_t dim0_start, int32_t dim0_end,
int32_t dim1_start, int32_t dim1_end);
template Ort::Value Slice<float>(OrtAllocator *allocator, const Ort::Value *v,
int32_t dim0_start, int32_t dim0_end);
} // namespace sherpa_onnx
... ...
... ... @@ -8,12 +8,12 @@
namespace sherpa_onnx {
/** Get a deep copy by slicing v.
/** Get a deep copy by slicing a 3-D tensor v.
*
* It returns v[dim0_start:dim0_end, dim1_start:dim1_end]
* It returns v[dim0_start:dim0_end, dim1_start:dim1_end, :]
*
* @param allocator
* @param v A 3-D tensor. Its data type is T.
* @param dim0_start Start index of the first dimension.
* @param dim0_end End index of the first dimension.
* @param dim1_start Start index of the second dimension.
... ... @@ -26,6 +26,23 @@ template <typename T = float>
Ort::Value Slice(OrtAllocator *allocator, const Ort::Value *v,
int32_t dim0_start, int32_t dim0_end, int32_t dim1_start,
int32_t dim1_end);
/** Get a deep copy by slicing a 2-D tensor v.
*
* It returns v[dim0_start:dim0_end, :]
*
* @param allocator
* @param v A 2-D tensor. Its data type is T.
* @param dim0_start Start index of the first dimension.
* @param dim0_end End index of the first dimension.
*
* @return Return a 2-D tensor of shape
* (dim0_end-dim0_start, v.shape[1])
*/
template <typename T = float>
Ort::Value Slice(OrtAllocator *allocator, const Ort::Value *v,
int32_t dim0_start, int32_t dim0_end);
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_SLICE_H_
... ...
... ... @@ -6,10 +6,11 @@
#include <cassert>
#include <fstream>
#include <iostream>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
namespace {
// see http://soundfile.sapp.org/doc/WaveFormat/
... ... @@ -20,26 +21,34 @@ struct WaveHeader {
bool Validate() const {
// F F I R
if (chunk_id != 0x46464952) {
SHERPA_ONNX_LOGE("Expected chunk_id RIFF. Given: 0x%08x\n", chunk_id);
return false;
}
// E V A W
if (format != 0x45564157) {
SHERPA_ONNX_LOGE("Expected format WAVE. Given: 0x%08x\n", format);
return false;
}
if (subchunk1_id != 0x20746d66) {
SHERPA_ONNX_LOGE("Expected subchunk1_id 0x20746d66. Given: 0x%08x\n",
subchunk1_id);
return false;
}
if (subchunk1_size != 16) { // 16 for PCM
SHERPA_ONNX_LOGE("Expected subchunk1_size 16. Given: %d\n",
subchunk1_size);
return false;
}
if (audio_format != 1) { // 1 for PCM
SHERPA_ONNX_LOGE("Expected audio_format 1. Given: %d\n", audio_format);
return false;
}
if (num_channels != 1) { // we support only single channel for now
SHERPA_ONNX_LOGE("Expected single channel. Given: %d\n", num_channels);
return false;
}
if (byte_rate != (sample_rate * num_channels * bits_per_sample / 8)) {
... ... @@ -51,6 +60,8 @@ struct WaveHeader {
}
if (bits_per_sample != 16) { // we support only 16 bits per sample
SHERPA_ONNX_LOGE("Expected bits_per_sample 16. Given: %d\n",
bits_per_sample);
return false;
}
... ... @@ -62,7 +73,7 @@ struct WaveHeader {
// and
// https://www.robotplanet.dk/audio/wav_meta_data/riff_mci.pdf
void SeekToDataChunk(std::istream &is) {
// a t a d
// a t a d
while (is && subchunk2_id != 0x61746164) {
// const char *p = reinterpret_cast<const char *>(&subchunk2_id);
// printf("Skip chunk (%x): %c%c%c%c of size: %d\n", subchunk2_id, p[0],
... ... @@ -91,7 +102,7 @@ static_assert(sizeof(WaveHeader) == 44, "");
// Read a mono-channel wave file.
// Return its samples normalized to the range [-1, 1).
std::vector<float> ReadWaveImpl(std::istream &is, float expected_sample_rate,
std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
bool *is_ok) {
WaveHeader header;
is.read(reinterpret_cast<char *>(&header), sizeof(header));
... ... @@ -111,10 +122,7 @@ std::vector<float> ReadWaveImpl(std::istream &is, float expected_sample_rate,
return {};
}
if (expected_sample_rate != header.sample_rate) {
*is_ok = false;
return {};
}
*sampling_rate = header.sample_rate;
// header.subchunk2_size contains the number of bytes in the data.
// Since each sample occupies two bytes, it is divided by 2 here.
... ... @@ -137,15 +145,15 @@ std::vector<float> ReadWaveImpl(std::istream &is, float expected_sample_rate,
} // namespace
std::vector<float> ReadWave(const std::string &filename,
float expected_sample_rate, bool *is_ok) {
std::vector<float> ReadWave(const std::string &filename, int32_t *sampling_rate,
bool *is_ok) {
std::ifstream is(filename, std::ifstream::binary);
return ReadWave(is, expected_sample_rate, is_ok);
return ReadWave(is, sampling_rate, is_ok);
}
std::vector<float> ReadWave(std::istream &is, float expected_sample_rate,
std::vector<float> ReadWave(std::istream &is, int32_t *sampling_rate,
bool *is_ok) {
auto samples = ReadWaveImpl(is, expected_sample_rate, is_ok);
auto samples = ReadWaveImpl(is, sampling_rate, is_ok);
return samples;
}
... ...
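The magic numbers checked in WaveHeader::Validate() are simply the four-character chunk tags read as little-endian 32-bit integers, which is why the byte comments above appear reversed ("F F I R" for RIFF). A small self-contained sketch, not part of the patch, that reproduces the constants:

#include <cstdint>

// Hypothetical helper: build the little-endian 32-bit tag that
// WaveHeader::Validate() compares against. Byte a ends up in the
// lowest-order position, exactly as when the file bytes are read
// into a little-endian uint32.
constexpr std::uint32_t FourCC(char a, char b, char c, char d) {
  return static_cast<std::uint32_t>(static_cast<std::uint8_t>(a)) |
         (static_cast<std::uint32_t>(static_cast<std::uint8_t>(b)) << 8) |
         (static_cast<std::uint32_t>(static_cast<std::uint8_t>(c)) << 16) |
         (static_cast<std::uint32_t>(static_cast<std::uint8_t>(d)) << 24);
}

static_assert(FourCC('R', 'I', 'F', 'F') == 0x46464952, "RIFF");
static_assert(FourCC('W', 'A', 'V', 'E') == 0x45564157, "WAVE");
static_assert(FourCC('f', 'm', 't', ' ') == 0x20746d66, "fmt ");
static_assert(FourCC('d', 'a', 't', 'a') == 0x61746164, "data");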
... ... @@ -13,17 +13,17 @@ namespace sherpa_onnx {
/** Read a wave file.
@param filename Path to a wave file. It MUST be single channel, PCM encoded.
@param expected_sample_rate Expected sample rate of the wave file. If the
sample rate don't match, it throws an exception.
@param filename Path to a wave file. It MUST be single channel, 16-bit
PCM encoded.
@param sampling_rate On return, it contains the sampling rate of the file.
@param is_ok On return it is true if the reading succeeded; false otherwise.
@return Return wave samples normalized to the range [-1, 1).
*/
std::vector<float> ReadWave(const std::string &filename,
float expected_sample_rate, bool *is_ok);
std::vector<float> ReadWave(const std::string &filename, int32_t *sampling_rate,
bool *is_ok);
std::vector<float> ReadWave(std::istream &is, float expected_sample_rate,
std::vector<float> ReadWave(std::istream &is, int32_t *sampling_rate,
bool *is_ok);
} // namespace sherpa_onnx
... ...
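For reference, the "[-1, 1)" normalization promised by ReadWave() corresponds to dividing each 16-bit sample by 32768. A minimal sketch of that convention; the helper name is illustrative, not from the patch:

#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative helper: map int16 PCM samples to floats in [-1, 1).
// -32768 / 32768.0f == -1.0f and 32767 / 32768.0f < 1.0f, hence the
// half-open range.
std::vector<float> NormalizeInt16(const std::vector<std::int16_t> &pcm) {
  std::vector<float> ans(pcm.size());
  for (std::size_t i = 0; i != pcm.size(); ++i) {
    ans[i] = pcm[i] / 32768.0f;
  }
  return ans;
}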
... ... @@ -11,6 +11,7 @@
#include "jni.h" // NOLINT
#include <strstream>
#include <utility>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
... ... @@ -43,14 +44,18 @@ class SherpaOnnx {
stream_(recognizer_.CreateStream()) {
}
void AcceptWaveform(int32_t sample_rate, const float *samples,
int32_t n) const {
void AcceptWaveform(int32_t sample_rate, const float *samples, int32_t n) {
if (input_sample_rate_ == -1) {
input_sample_rate_ = sample_rate;
}
stream_->AcceptWaveform(sample_rate, samples, n);
}
void InputFinished() const {
std::vector<float> tail_padding(16000 * 0.32, 0);
stream_->AcceptWaveform(16000, tail_padding.data(), tail_padding.size());
std::vector<float> tail_padding(input_sample_rate_ * 0.32, 0);
stream_->AcceptWaveform(input_sample_rate_, tail_padding.data(),
tail_padding.size());
stream_->InputFinished();
}
... ... @@ -70,6 +75,7 @@ class SherpaOnnx {
private:
sherpa_onnx::OnlineRecognizer recognizer_;
std::unique_ptr<sherpa_onnx::OnlineStream> stream_;
int32_t input_sample_rate_ = -1;
};
static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
... ... @@ -276,17 +282,24 @@ JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_getText(
return env->NewStringUTF(text.c_str());
}
// see
// https://stackoverflow.com/questions/29043872/android-jni-return-multiple-variables
static jobject NewInteger(JNIEnv *env, int32_t value) {
jclass cls = env->FindClass("java/lang/Integer");
jmethodID constructor = env->GetMethodID(cls, "<init>", "(I)V");
return env->NewObject(cls, constructor, value);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jfloatArray JNICALL
JNIEXPORT jobjectArray JNICALL
Java_com_k2fsa_sherpa_onnx_WaveReader_00024Companion_readWave(
JNIEnv *env, jclass /*cls*/, jobject asset_manager, jstring filename,
jfloat expected_sample_rate) {
JNIEnv *env, jclass /*cls*/, jobject asset_manager, jstring filename) {
const char *p_filename = env->GetStringUTFChars(filename, nullptr);
#if __ANDROID_API__ >= 9
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
return nullptr;
exit(-1);
}
std::vector<char> buffer = sherpa_onnx::ReadFile(mgr, p_filename);
... ... @@ -297,16 +310,25 @@ Java_com_k2fsa_sherpa_onnx_WaveReader_00024Companion_readWave(
#endif
bool is_ok = false;
int32_t sampling_rate = -1;
std::vector<float> samples =
sherpa_onnx::ReadWave(is, expected_sample_rate, &is_ok);
sherpa_onnx::ReadWave(is, &sampling_rate, &is_ok);
env->ReleaseStringUTFChars(filename, p_filename);
if (!is_ok) {
return nullptr;
SHERPA_ONNX_LOGE("Failed to read %s", p_filename);
exit(-1);
}
jfloatArray ans = env->NewFloatArray(samples.size());
env->SetFloatArrayRegion(ans, 0, samples.size(), samples.data());
return ans;
jobjectArray obj_arr =
env->NewObjectArray(2, env->FindClass("java/lang/Object"), nullptr);
env->SetObjectArrayElement(obj_arr, 0, ans);
env->SetObjectArrayElement(obj_arr, 1, NewInteger(env, sampling_rate));
return obj_arr;
}
... ...
... ... @@ -11,12 +11,10 @@ namespace sherpa_onnx {
static void PybindFeatureExtractorConfig(py::module *m) {
using PyClass = FeatureExtractorConfig;
py::class_<PyClass>(*m, "FeatureExtractorConfig")
.def(py::init<int32_t, int32_t, int32_t>(),
py::arg("sampling_rate") = 16000, py::arg("feature_dim") = 80,
py::arg("max_feature_vectors") = -1)
.def(py::init<int32_t, int32_t>(), py::arg("sampling_rate") = 16000,
py::arg("feature_dim") = 80)
.def_readwrite("sampling_rate", &PyClass::sampling_rate)
.def_readwrite("feature_dim", &PyClass::feature_dim)
.def_readwrite("max_feature_vectors", &PyClass::max_feature_vectors)
.def("__str__", &PyClass::ToString);
}
... ...
... ... @@ -34,7 +34,6 @@ class OnlineRecognizer(object):
rule3_min_utterance_length: int = 20,
decoding_method: str = "greedy_search",
max_active_paths: int = 4,
max_feature_vectors: int = -1,
):
"""
Please refer to
... ... @@ -82,9 +81,6 @@ class OnlineRecognizer(object):
max_active_paths:
Use only when decoding_method is modified_beam_search. It specifies
the maximum number of active paths during beam search.
max_feature_vectors:
Number of feature vectors to cache. -1 means to cache all feature
frames that have been processed.
"""
_assert_file_exists(tokens)
_assert_file_exists(encoder)
... ... @@ -104,7 +100,6 @@ class OnlineRecognizer(object):
feat_config = FeatureExtractorConfig(
sampling_rate=sample_rate,
feature_dim=feature_dim,
max_feature_vectors=max_feature_vectors,
)
endpoint_config = EndpointConfig(
... ...
... ... @@ -8,18 +8,18 @@
import unittest
import sherpa_onnx
import _sherpa_onnx
class TestFeatureExtractorConfig(unittest.TestCase):
def test_default_constructor(self):
config = sherpa_onnx.FeatureExtractorConfig()
config = _sherpa_onnx.FeatureExtractorConfig()
assert config.sampling_rate == 16000, config.sampling_rate
assert config.feature_dim == 80, config.feature_dim
print(config)
def test_constructor(self):
config = sherpa_onnx.FeatureExtractorConfig(sampling_rate=8000, feature_dim=40)
config = _sherpa_onnx.FeatureExtractorConfig(sampling_rate=8000, feature_dim=40)
assert config.sampling_rate == 8000, config.sampling_rate
assert config.feature_dim == 40, config.feature_dim
print(config)
... ...
... ... @@ -8,21 +8,23 @@
import unittest
import sherpa_onnx
import _sherpa_onnx
class TestOnlineTransducerModelConfig(unittest.TestCase):
def test_constructor(self):
config = sherpa_onnx.OnlineTransducerModelConfig(
config = _sherpa_onnx.OnlineTransducerModelConfig(
encoder_filename="encoder.onnx",
decoder_filename="decoder.onnx",
joiner_filename="joiner.onnx",
tokens="tokens.txt",
num_threads=8,
debug=True,
)
assert config.encoder_filename == "encoder.onnx", config.encoder_filename
assert config.decoder_filename == "decoder.onnx", config.decoder_filename
assert config.joiner_filename == "joiner.onnx", config.joiner_filename
assert config.tokens == "tokens.txt", config.tokens
assert config.num_threads == 8, config.num_threads
assert config.debug is True, config.debug
print(config)
... ...