Fangjun Kuang
Committed by GitHub

Support CED models (#792)

Showing 33 changed files with 605 additions and 46 deletions
name: export-ced-to-onnx
on:
workflow_dispatch:
concurrency:
group: export-ced-to-onnx-${{ github.ref }}
cancel-in-progress: true
jobs:
export-ced-to-onnx:
if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
name: export ced
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]
python-version: ["3.8"]
steps:
- uses: actions/checkout@v4
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Run
shell: bash
run: |
cd scripts/ced
./run.sh
- name: Release
uses: svenstaro/upload-release-action@v2
with:
file_glob: true
file: ./*.tar.bz2
overwrite: true
repo_name: k2-fsa/sherpa-onnx
repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
tag: audio-tagging-models
- name: Publish to huggingface
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
uses: nick-fields/retry@v3
with:
max_attempts: 20
timeout_seconds: 200
shell: bash
command: |
git config --global user.email "csukuangfj@gmail.com"
git config --global user.name "Fangjun Kuang"
models=(
tiny
mini
small
base
)
for m in ${models[@]}; do
rm -rf huggingface
export GIT_LFS_SKIP_SMUDGE=1
d=sherpa-onnx-ced-$m-audio-tagging-2024-04-19
git clone https://huggingface.co/k2-fsa/$d huggingface
mv -v $d/* huggingface
cd huggingface
git lfs track "*.onnx"
git status
git add .
git status
git commit -m "first commit"
git push https://csukuangfj:$HF_TOKEN@huggingface.co/k2-fsa/$d main
cd ..
done
... ...
<resources>
<string name="app_name">ASR with Next-gen Kaldi</string>
<string name="app_name">ASR</string>
<string name="hint">Click the Start button to play speech-to-text with Next-gen Kaldi.
\n
\n\n\n
... ...
<resources>
<string name="app_name">ASR with Next-gen Kaldi</string>
<string name="app_name">ASR2pass </string>
<string name="hint">Click the Start button to play speech-to-text with Next-gen Kaldi.
\n
\n\n\n
... ...
package com.k2fsa.sherpa.onnx
import android.content.res.AssetManager
import android.util.Log
private val TAG = "sherpa-onnx"
const val TAG = "sherpa-onnx"
data class OfflineZipformerAudioTaggingModelConfig(
var model: String,
var model: String = "",
)
data class AudioTaggingModelConfig(
var zipformer: OfflineZipformerAudioTaggingModelConfig,
var zipformer: OfflineZipformerAudioTaggingModelConfig = OfflineZipformerAudioTaggingModelConfig(),
var ced: String = "",
var numThreads: Int = 1,
var debug: Boolean = false,
var provider: String = "cpu",
... ... @@ -103,7 +103,7 @@ class AudioTagging(
//
// See also
// https://k2-fsa.github.io/sherpa/onnx/audio-tagging/
fun getAudioTaggingConfig(type: Int, numThreads: Int=1): AudioTaggingConfig? {
fun getAudioTaggingConfig(type: Int, numThreads: Int = 1): AudioTaggingConfig? {
when (type) {
0 -> {
val modelDir = "sherpa-onnx-zipformer-small-audio-tagging-2024-04-15"
... ... @@ -123,7 +123,46 @@ fun getAudioTaggingConfig(type: Int, numThreads: Int=1): AudioTaggingConfig? {
return AudioTaggingConfig(
model = AudioTaggingModelConfig(
zipformer = OfflineZipformerAudioTaggingModelConfig(model = "$modelDir/model.int8.onnx"),
numThreads = 1,
numThreads = numThreads,
debug = true,
),
labels = "$modelDir/class_labels_indices.csv",
topK = 3,
)
}
2 -> {
val modelDir = "sherpa-onnx-ced-tiny-audio-tagging-2024-04-19"
return AudioTaggingConfig(
model = AudioTaggingModelConfig(
ced = "$modelDir/model.int8.onnx",
numThreads = numThreads,
debug = true,
),
labels = "$modelDir/class_labels_indices.csv",
topK = 3,
)
}
3 -> {
val modelDir = "sherpa-onnx-ced-mini-audio-tagging-2024-04-19"
return AudioTaggingConfig(
model = AudioTaggingModelConfig(
ced = "$modelDir/model.int8.onnx",
numThreads = numThreads,
debug = true,
),
labels = "$modelDir/class_labels_indices.csv",
topK = 3,
)
}
4 -> {
val modelDir = "sherpa-onnx-ced-small-audio-tagging-2024-04-19"
return AudioTaggingConfig(
model = AudioTaggingModelConfig(
ced = "$modelDir/model.int8.onnx",
numThreads = numThreads,
debug = true,
),
labels = "$modelDir/class_labels_indices.csv",
... ... @@ -131,6 +170,18 @@ fun getAudioTaggingConfig(type: Int, numThreads: Int=1): AudioTaggingConfig? {
)
}
5 -> {
val modelDir = "sherpa-onnx-ced-base-audio-tagging-2024-04-19"
return AudioTaggingConfig(
model = AudioTaggingModelConfig(
ced = "$modelDir/model.int8.onnx",
numThreads = numThreads,
debug = true,
),
labels = "$modelDir/class_labels_indices.csv",
topK = 3,
)
}
}
return null
... ...
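The Kotlin selector above maps a type index to a bundled model directory. For non-Android callers, the same CED configuration can be expressed against the C++ API this PR extends; a minimal sketch, assuming the model archive has been unpacked into the working directory (paths are illustrative):

#include <memory>
#include "sherpa-onnx/csrc/audio-tagging.h"

int main() {
  sherpa_onnx::AudioTaggingConfig config;
  // A non-empty `ced` selects AudioTaggingCEDImpl; leave
  // config.model.zipformer.model empty (see audio-tagging-impl.cc below).
  config.model.ced =
      "sherpa-onnx-ced-tiny-audio-tagging-2024-04-19/model.int8.onnx";
  config.model.num_threads = 1;
  config.labels =
      "sherpa-onnx-ced-tiny-audio-tagging-2024-04-19/class_labels_indices.csv";
  config.top_k = 3;

  sherpa_onnx::AudioTagging tagger(config);
  std::unique_ptr<sherpa_onnx::OfflineStream> stream = tagger.CreateStream();
  // stream->AcceptWaveform(16000, samples, num_samples); then:
  // std::vector<sherpa_onnx::AudioEvent> events = tagger.Compute(stream.get());
  return 0;
}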
... ... @@ -3,24 +3,15 @@
package com.k2fsa.sherpa.onnx.audio.tagging
import android.Manifest
import android.app.Activity
import android.content.pm.PackageManager
import android.media.AudioFormat
import android.media.AudioRecord
import androidx.compose.foundation.lazy.items
import android.media.MediaRecorder
import android.util.Log
import androidx.compose.foundation.ExperimentalFoundationApi
import androidx.compose.foundation.background
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.material3.CenterAlignedTopAppBar
import androidx.compose.runtime.Composable
import androidx.compose.material3.Scaffold
import androidx.compose.material3.TopAppBarDefaults
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.PaddingValues
import androidx.compose.foundation.layout.Row
... ... @@ -30,10 +21,17 @@ import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.lazy.LazyColumn
import androidx.compose.foundation.lazy.items
import androidx.compose.material3.Button
import androidx.compose.material3.CenterAlignedTopAppBar
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Scaffold
import androidx.compose.material3.Slider
import androidx.compose.material3.Surface
import androidx.compose.material3.Text
import androidx.compose.material3.TopAppBarDefaults
import androidx.compose.runtime.Composable
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateListOf
import androidx.compose.runtime.mutableStateOf
... ... @@ -41,7 +39,6 @@ import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.text.style.TextAlign
... ... @@ -49,6 +46,7 @@ import androidx.compose.ui.unit.dp
import androidx.compose.ui.unit.sp
import androidx.core.app.ActivityCompat
import com.k2fsa.sherpa.onnx.AudioEvent
import com.k2fsa.sherpa.onnx.Tagger
import kotlin.concurrent.thread
... ...
... ... @@ -13,6 +13,7 @@ import androidx.compose.material3.Surface
import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier
import androidx.core.app.ActivityCompat
import com.k2fsa.sherpa.onnx.Tagger
import com.k2fsa.sherpa.onnx.audio.tagging.ui.theme.SherpaOnnxAudioTaggingTheme
const val TAG = "sherpa-onnx"
... ...
package com.k2fsa.sherpa.onnx.audio.tagging
package com.k2fsa.sherpa.onnx
import android.content.res.AssetManager
import android.util.Log
import com.k2fsa.sherpa.onnx.AudioTagging
import com.k2fsa.sherpa.onnx.audio.tagging.wear.os.presentation.TAG
import com.k2fsa.sherpa.onnx.getAudioTaggingConfig
object Tagger {
private var _tagger: AudioTagging? = null
... ... @@ -12,6 +10,7 @@ object Tagger {
get() {
return _tagger!!
}
fun initTagger(assetManager: AssetManager? = null, numThreads: Int = 1) {
synchronized(this) {
if (_tagger != null) {
... ... @@ -19,7 +18,7 @@ object Tagger {
}
Log.i(TAG, "Initializing audio tagger")
val config = getAudioTaggingConfig(type = 0, numThreads=numThreads)!!
val config = getAudioTaggingConfig(type = 0, numThreads = numThreads)!!
_tagger = AudioTagging(assetManager, config)
}
}
... ...
... ... @@ -33,7 +33,7 @@ import androidx.wear.compose.material.Button
import androidx.wear.compose.material.MaterialTheme
import androidx.wear.compose.material.Text
import com.k2fsa.sherpa.onnx.AudioEvent
import com.k2fsa.sherpa.onnx.audio.tagging.Tagger
import com.k2fsa.sherpa.onnx.Tagger
import com.k2fsa.sherpa.onnx.audio.tagging.wear.os.presentation.theme.SherpaOnnxAudioTaggingWearOsTheme
import kotlin.concurrent.thread
... ...
... ... @@ -17,7 +17,7 @@ import androidx.activity.compose.setContent
import androidx.compose.runtime.Composable
import androidx.core.app.ActivityCompat
import androidx.core.splashscreen.SplashScreen.Companion.installSplashScreen
import com.k2fsa.sherpa.onnx.audio.tagging.Tagger
import com.k2fsa.sherpa.onnx.Tagger
const val TAG = "sherpa-onnx"
private const val REQUEST_RECORD_AUDIO_PERMISSION = 200
... ...
<resources>
<string name="app_name">AudioTagging</string>
<string name="app_name">Audio Tagging</string>
<!--
This string is used for square devices and overridden by hello_world in
values-round/strings.xml for round devices.
... ...
<resources>
<string name="app_name">Speaker Identification</string>
<string name="app_name">Speaker ID</string>
<string name="start">Start recording</string>
<string name="stop">Stop recording</string>
<string name="add">Add speaker</string>
... ...
<resources>
<string name="app_name">SherpaOnnxSpokenLanguageIdentification</string>
<string name="app_name">Language ID</string>
</resources>
\ No newline at end of file
... ...
<resources>
<string name="app_name">Next-gen Kaldi: TTS</string>
<string name="app_name">TTS</string>
<string name="sid_label">Speaker ID</string>
<string name="sid_hint">0</string>
<string name="speed_label">Speech speed (large->fast)</string>
... ...
<resources>
<string name="app_name">Next-gen Kaldi: TTS</string>
<string name="app_name">TTS Engine</string>
</resources>
\ No newline at end of file
... ...
<resources>
<string name="app_name">Next-gen Kaldi: SileroVAD</string>
<string name="app_name">VAD</string>
<string name="hint">Click the Start button to play Silero VAD with Next-gen Kaldi.</string>
<string name="start">Start</string>
... ...
<resources>
<string name="app_name">ASR with Next-gen Kaldi</string>
<string name="app_name">VAD-ASR</string>
<string name="hint">Click the Start button to play speech-to-text with Next-gen Kaldi.
\n
\n\n\n
... ...
... ... @@ -46,7 +46,30 @@ def get_models():
),
]
return icefall_models
ced_models = [
AudioTaggingModel(
model_name="sherpa-onnx-ced-tiny-audio-tagging-2024-04-19",
idx=2,
short_name="ced_tiny",
),
AudioTaggingModel(
model_name="sherpa-onnx-ced-mini-audio-tagging-2024-04-19",
idx=3,
short_name="ced_mini",
),
AudioTaggingModel(
model_name="sherpa-onnx-ced-small-audio-tagging-2024-04-19",
idx=4,
short_name="ced_small",
),
AudioTaggingModel(
model_name="sherpa-onnx-ced-base-audio-tagging-2024-04-19",
idx=5,
short_name="ced_base",
),
]
return icefall_models + ced_models
def main():
... ...
#!/usr/bin/env bash
#
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
set -ex
function install_dependencies() {
pip install -qq torch==2.1.0+cpu torchaudio==2.1.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install -qq onnx onnxruntime==1.17.1
pip install -r ./requirements.txt
}
git clone https://github.com/RicherMans/CED
pushd CED
install_dependencies
models=(
tiny
mini
small
base
)
for m in ${models[@]}; do
python3 ./export_onnx.py -m ced_$m
done
ls -lh *.onnx
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/sherpa-onnx-zipformer-small-audio-tagging-2024-04-15.tar.bz2
tar xvf sherpa-onnx-zipformer-small-audio-tagging-2024-04-15.tar.bz2
rm sherpa-onnx-zipformer-small-audio-tagging-2024-04-15.tar.bz2
src=sherpa-onnx-zipformer-small-audio-tagging-2024-04-15
cat >README.md <<EOF
# Introduction
Models in this repo are converted from
https://github.com/RicherMans/CED
EOF
for m in ${models[@]}; do
d=sherpa-onnx-ced-$m-audio-tagging-2024-04-19
mkdir -p $d
cp -v README.md $d
cp -v $src/class_labels_indices.csv $d
cp -a $src/test_wavs $d
cp -v ced_$m.onnx $d/model.onnx
cp -v ced_$m.int8.onnx $d/model.int8.onnx
echo "----------$m----------"
ls -lh $d
echo "----------------------"
tar cjvf $d.tar.bz2 $d
mv $d.tar.bz2 ../../..
mv $d ../../../
done
rm -rf sherpa-onnx-zipformer-small-audio-tagging-2024-04-15
cd ../../..
ls -lh *.tar.bz2
echo "======="
ls -lh
... ...
... ... @@ -1223,6 +1223,7 @@ const SherpaOnnxAudioTagging *SherpaOnnxCreateAudioTagging(
const SherpaOnnxAudioTaggingConfig *config) {
sherpa_onnx::AudioTaggingConfig ac;
ac.model.zipformer.model = SHERPA_ONNX_OR(config->model.zipformer.model, "");
ac.model.ced = SHERPA_ONNX_OR(config->model.ced, "");
ac.model.num_threads = SHERPA_ONNX_OR(config->model.num_threads, 1);
ac.model.debug = config->model.debug;
ac.model.provider = SHERPA_ONNX_OR(config->model.provider, "cpu");
... ...
... ... @@ -1100,6 +1100,7 @@ SHERPA_ONNX_API typedef struct
SHERPA_ONNX_API typedef struct SherpaOnnxAudioTaggingModelConfig {
SherpaOnnxOfflineZipformerAudioTaggingModelConfig zipformer;
const char *ced;
int32_t num_threads;
int32_t debug; // true to print debug information of the model
const char *provider;
... ...
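From C, the new `ced` field slots in next to the existing zipformer config. A hedged sketch of filling the struct; the zero-initialization idiom, the `labels`/`top_k` fields, and the destroy call are assumptions based on the existing C API conventions rather than lines shown in this diff:

#include "sherpa-onnx/c-api/c-api.h"

int main() {
  SherpaOnnxAudioTaggingConfig config = {};  // zero-init: unused fields stay NULL/0
  config.model.ced =
      "sherpa-onnx-ced-tiny-audio-tagging-2024-04-19/model.int8.onnx";
  config.model.num_threads = 1;
  config.labels =
      "sherpa-onnx-ced-tiny-audio-tagging-2024-04-19/class_labels_indices.csv";
  config.top_k = 3;

  const SherpaOnnxAudioTagging *tagger = SherpaOnnxCreateAudioTagging(&config);
  // ... create an offline stream, feed 16 kHz samples, compute events ...
  SherpaOnnxDestroyAudioTagging(tagger);
  return 0;
}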
... ... @@ -117,6 +117,7 @@ list(APPEND sources
audio-tagging-label-file.cc
audio-tagging-model-config.cc
audio-tagging.cc
offline-ced-model.cc
offline-zipformer-audio-tagging-model-config.cc
offline-zipformer-audio-tagging-model.cc
)
... ...
// sherpa-onnx/csrc/audio-tagging-ced-impl.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_CED_IMPL_H_
#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_CED_IMPL_H_
#include <assert.h>
#include <memory>
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/audio-tagging-impl.h"
#include "sherpa-onnx/csrc/audio-tagging-label-file.h"
#include "sherpa-onnx/csrc/audio-tagging.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/math.h"
#include "sherpa-onnx/csrc/offline-ced-model.h"
namespace sherpa_onnx {
class AudioTaggingCEDImpl : public AudioTaggingImpl {
public:
explicit AudioTaggingCEDImpl(const AudioTaggingConfig &config)
: config_(config), model_(config.model), labels_(config.labels) {
if (model_.NumEventClasses() != labels_.NumEventClasses()) {
SHERPA_ONNX_LOGE("number of classes: %d (model) != %d (label file)",
model_.NumEventClasses(), labels_.NumEventClasses());
exit(-1);
}
}
#if __ANDROID_API__ >= 9
explicit AudioTaggingCEDImpl(AAssetManager *mgr,
const AudioTaggingConfig &config)
: config_(config),
model_(mgr, config.model),
labels_(mgr, config.labels) {
if (model_.NumEventClasses() != labels_.NumEventClasses()) {
SHERPA_ONNX_LOGE("number of classes: %d (model) != %d (label file)",
model_.NumEventClasses(), labels_.NumEventClasses());
exit(-1);
}
}
#endif
std::unique_ptr<OfflineStream> CreateStream() const override {
return std::make_unique<OfflineStream>(CEDTag{});
}
std::vector<AudioEvent> Compute(OfflineStream *s,
int32_t top_k = -1) const override {
if (top_k < 0) {
top_k = config_.top_k;
}
int32_t num_event_classes = model_.NumEventClasses();
if (top_k > num_event_classes) {
top_k = num_event_classes;
}
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
// WARNING(fangjun): It is fixed to 64 for CED models
int32_t feat_dim = 64;
std::vector<float> f = s->GetFrames();
int32_t num_frames = f.size() / feat_dim;
assert(feat_dim * num_frames == f.size());
std::array<int64_t, 3> shape = {1, num_frames, feat_dim};
Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(),
shape.data(), shape.size());
Ort::Value probs = model_.Forward(std::move(x));
const float *p = probs.GetTensorData<float>();
std::vector<int32_t> top_k_indexes = TopkIndex(p, num_event_classes, top_k);
std::vector<AudioEvent> ans(top_k);
int32_t i = 0;
for (int32_t index : top_k_indexes) {
ans[i].name = labels_.GetEventName(index);
ans[i].index = index;
ans[i].prob = p[index];
i += 1;
}
return ans;
}
private:
AudioTaggingConfig config_;
OfflineCEDModel model_;
AudioTaggingLabels labels_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_AUDIO_TAGGING_CED_IMPL_H_
... ...
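In Compute() above, TopkIndex (from sherpa-onnx/csrc/math.h) returns the indices of the top_k largest probabilities, which are then mapped to event names through the label file. A self-contained equivalent of that selection step, included only to make the semantics concrete:

#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

std::vector<int32_t> TopkIndexSketch(const float *p, int32_t n, int32_t k) {
  std::vector<int32_t> index(n);
  std::iota(index.begin(), index.end(), 0);  // 0, 1, ..., n-1
  // Keep the k indices with the largest probabilities, in descending order.
  std::partial_sort(index.begin(), index.begin() + k, index.end(),
                    [p](int32_t a, int32_t b) { return p[a] > p[b]; });
  index.resize(k);
  return index;
}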
... ... @@ -11,6 +11,7 @@
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/audio-tagging-ced-impl.h"
#include "sherpa-onnx/csrc/audio-tagging-zipformer-impl.h"
#include "sherpa-onnx/csrc/macros.h"
... ... @@ -20,6 +21,8 @@ std::unique_ptr<AudioTaggingImpl> AudioTaggingImpl::Create(
const AudioTaggingConfig &config) {
if (!config.model.zipformer.model.empty()) {
return std::make_unique<AudioTaggingZipformerImpl>(config);
} else if (!config.model.ced.empty()) {
return std::make_unique<AudioTaggingCEDImpl>(config);
}
SHERPA_ONNX_LOG(
... ... @@ -32,6 +35,8 @@ std::unique_ptr<AudioTaggingImpl> AudioTaggingImpl::Create(
AAssetManager *mgr, const AudioTaggingConfig &config) {
if (!config.model.zipformer.model.empty()) {
return std::make_unique<AudioTaggingZipformerImpl>(mgr, config);
} else if (!config.model.ced.empty()) {
return std::make_unique<AudioTaggingCEDImpl>(mgr, config);
}
SHERPA_ONNX_LOG(
... ...
... ... @@ -4,11 +4,18 @@
#include "sherpa-onnx/csrc/audio-tagging-model-config.h"
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
void AudioTaggingModelConfig::Register(ParseOptions *po) {
zipformer.Register(po);
po->Register("ced-model", &ced,
"Path to CED model. Only need to pass one of --zipformer-model "
"or --ced-model");
po->Register("num-threads", &num_threads,
"Number of threads to run the neural network");
... ... @@ -24,6 +31,16 @@ bool AudioTaggingModelConfig::Validate() const {
return false;
}
if (!ced.empty() && !FileExists(ced)) {
SHERPA_ONNX_LOGE("CED model file %s does not exist", ced.c_str());
return false;
}
if (zipformer.model.empty() && ced.empty()) {
SHERPA_ONNX_LOGE("Please provide either --zipformer-model or --ced-model");
return false;
}
return true;
}
... ... @@ -32,6 +49,7 @@ std::string AudioTaggingModelConfig::ToString() const {
os << "AudioTaggingModelConfig(";
os << "zipformer=" << zipformer.ToString() << ", ";
os << "ced=\"" << ced << "\", ";
os << "num_threads=" << num_threads << ", ";
os << "debug=" << (debug ? "True" : "False") << ", ";
os << "provider=\"" << provider << "\")";
... ...
... ... @@ -13,6 +13,7 @@ namespace sherpa_onnx {
struct AudioTaggingModelConfig {
struct OfflineZipformerAudioTaggingModelConfig zipformer;
std::string ced;
int32_t num_threads = 1;
bool debug = false;
... ... @@ -22,8 +23,10 @@ struct AudioTaggingModelConfig {
AudioTaggingModelConfig(
const OfflineZipformerAudioTaggingModelConfig &zipformer,
int32_t num_threads, bool debug, const std::string &provider)
const std::string &ced, int32_t num_threads, bool debug,
const std::string &provider)
: zipformer(zipformer),
ced(ced),
num_threads(num_threads),
debug(debug),
provider(provider) {}
... ...
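Note that the positional constructor gains the new `ced` argument between `zipformer` and `num_threads`, so existing C++ callers of this constructor need updating. A sketch of the new call shape (values illustrative):

sherpa_onnx::AudioTaggingModelConfig model_config(
    sherpa_onnx::OfflineZipformerAudioTaggingModelConfig{},
    /*ced=*/"path/to/model.int8.onnx",
    /*num_threads=*/2,
    /*debug=*/false,
    /*provider=*/"cpu");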
... ... @@ -4,6 +4,8 @@
#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_ZIPFORMER_IMPL_H_
#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_ZIPFORMER_IMPL_H_
#include <assert.h>
#include <memory>
#include <utility>
#include <vector>
... ... @@ -72,6 +74,8 @@ class AudioTaggingZipformerImpl : public AudioTaggingImpl {
int32_t num_frames = f.size() / feat_dim;
assert(feat_dim * num_frames == f.size());
std::array<int64_t, 3> shape = {1, num_frames, feat_dim};
Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(),
... ...
... ... @@ -24,7 +24,8 @@ void FeatureExtractorConfig::Register(ParseOptions *po) {
"inside the feature extractor");
po->Register("feat-dim", &feature_dim,
"Feature dimension. Must match the one expected by the model.");
"Feature dimension. Must match the one expected by the model. "
"Not used by whisper and CED models");
po->Register("low-freq", &low_freq, "Low cutoff frequency for mel bins");
... ...
// sherpa-onnx/csrc/offline-ced-model.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-ced-model.h"
#include <string>
#include <vector>
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/text-utils.h"
#include "sherpa-onnx/csrc/transpose.h"
namespace sherpa_onnx {
class OfflineCEDModel::Impl {
public:
explicit Impl(const AudioTaggingModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
auto buf = ReadFile(config_.ced);
Init(buf.data(), buf.size());
}
#if __ANDROID_API__ >= 9
Impl(AAssetManager *mgr, const AudioTaggingModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
auto buf = ReadFile(mgr, config_.ced);
Init(buf.data(), buf.size());
}
#endif
Ort::Value Forward(Ort::Value features) {
features = Transpose12(allocator_, &features);
auto ans = sess_->Run({}, input_names_ptr_.data(), &features, 1,
output_names_ptr_.data(), output_names_ptr_.size());
return std::move(ans[0]);
}
int32_t NumEventClasses() const { return num_event_classes_; }
OrtAllocator *Allocator() const { return allocator_; }
private:
void Init(void *model_data, size_t model_data_length) {
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
sess_opts_);
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
// get meta data
Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
if (config_.debug) {
std::ostringstream os;
PrintModelMetadata(os, meta_data);
SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
}
// get num_event_classes from the output[0].shape,
// which is (N, num_event_classes)
num_event_classes_ =
sess_->GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape()[1];
}
private:
AudioTaggingModelConfig config_;
Ort::Env env_;
Ort::SessionOptions sess_opts_;
Ort::AllocatorWithDefaultOptions allocator_;
std::unique_ptr<Ort::Session> sess_;
std::vector<std::string> input_names_;
std::vector<const char *> input_names_ptr_;
std::vector<std::string> output_names_;
std::vector<const char *> output_names_ptr_;
int32_t num_event_classes_ = 0;
};
OfflineCEDModel::OfflineCEDModel(const AudioTaggingModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
#if __ANDROID_API__ >= 9
OfflineCEDModel::OfflineCEDModel(AAssetManager *mgr,
const AudioTaggingModelConfig &config)
: impl_(std::make_unique<Impl>(mgr, config)) {}
#endif
OfflineCEDModel::~OfflineCEDModel() = default;
Ort::Value OfflineCEDModel::Forward(Ort::Value features) const {
return impl_->Forward(std::move(features));
}
int32_t OfflineCEDModel::NumEventClasses() const {
return impl_->NumEventClasses();
}
OrtAllocator *OfflineCEDModel::Allocator() const { return impl_->Allocator(); }
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-ced-model.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_CED_MODEL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_CED_MODEL_H_
#include <memory>
#include <utility>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/audio-tagging-model-config.h"
namespace sherpa_onnx {
/** This class implements the CED model from
* https://github.com/RicherMans/CED/blob/main/export_onnx.py
*/
class OfflineCEDModel {
public:
explicit OfflineCEDModel(const AudioTaggingModelConfig &config);
#if __ANDROID_API__ >= 9
OfflineCEDModel(AAssetManager *mgr, const AudioTaggingModelConfig &config);
#endif
~OfflineCEDModel();
/** Run the forward method of the model.
*
* @param features A tensor of shape (N, T, C).
*
* @return Return a tensor
* - probs: A 2-D tensor of shape (N, num_event_classes).
*/
Ort::Value Forward(Ort::Value features) const;
/** Return the number of event classes of the model
*/
int32_t NumEventClasses() const;
/** Return an allocator for allocating memory
*/
OrtAllocator *Allocator() const;
private:
class Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_CED_MODEL_H_
... ...
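The header above is the entire surface of the model wrapper. A sketch of how the tagging implementation drives it, following the shapes documented in Forward() (the 64-dim feature layout comes from the CEDTag branch of OfflineStream further below; the helper function is hypothetical):

#include <array>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/offline-ced-model.h"

void RunCed(const sherpa_onnx::AudioTaggingModelConfig &config,
            std::vector<float> &features, int32_t num_frames) {
  sherpa_onnx::OfflineCEDModel model(config);
  auto memory_info =
      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  // (N, T, C) on the way in; Impl::Forward() swaps axes 1 and 2 with
  // Transpose12 before running the session (see offline-ced-model.cc).
  std::array<int64_t, 3> shape = {1, num_frames, 64};
  Ort::Value x = Ort::Value::CreateTensor(memory_info, features.data(),
                                          features.size(), shape.data(),
                                          shape.size());
  Ort::Value probs = model.Forward(std::move(x));  // (1, num_event_classes)
  const float *p = probs.GetTensorData<float>();
  (void)p;  // top-k selection happens in AudioTaggingCEDImpl::Compute()
}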
... ... @@ -92,15 +92,32 @@ class OfflineStream::Impl {
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
}
Impl(WhisperTag /*tag*/, ContextGraphPtr context_graph)
: context_graph_(context_graph) {
explicit Impl(WhisperTag /*tag*/) {
config_.normalize_samples = true;
opts_.frame_opts.samp_freq = 16000;
opts_.mel_opts.num_bins = 80;
opts_.mel_opts.num_bins = 80; // not used
whisper_fbank_ =
std::make_unique<knf::OnlineWhisperFbank>(opts_.frame_opts);
}
explicit Impl(CEDTag /*tag*/) {
// see
// https://github.com/RicherMans/CED/blob/main/onnx_inference_with_kaldi.py
opts_.frame_opts.frame_length_ms = 32;
opts_.frame_opts.dither = 0;
opts_.frame_opts.preemph_coeff = 0;
opts_.frame_opts.remove_dc_offset = false;
opts_.frame_opts.window_type = "hann";
opts_.frame_opts.snip_edges = false;
opts_.frame_opts.samp_freq = 16000; // fixed to 16000
opts_.mel_opts.num_bins = 64;
opts_.mel_opts.high_freq = 8000;
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
}
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
if (config_.normalize_samples) {
AcceptWaveformImpl(sampling_rate, waveform, n);
... ... @@ -233,9 +250,10 @@ OfflineStream::OfflineStream(
ContextGraphPtr context_graph /*= nullptr*/)
: impl_(std::make_unique<Impl>(config, context_graph)) {}
OfflineStream::OfflineStream(WhisperTag tag,
ContextGraphPtr context_graph /*= {}*/)
: impl_(std::make_unique<Impl>(tag, context_graph)) {}
OfflineStream::OfflineStream(WhisperTag tag)
: impl_(std::make_unique<Impl>(tag)) {}
OfflineStream::OfflineStream(CEDTag tag) : impl_(std::make_unique<Impl>(tag)) {}
OfflineStream::~OfflineStream() = default;
... ...
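The CEDTag options mirror the reference implementation's kaldi-compatible front end: 32 ms Hann windows with no dither, pre-emphasis, or DC removal, and 64 mel bins up to 8 kHz at a fixed 16 kHz sample rate. A stand-alone sketch with kaldi-native-fbank reproducing the same extractor outside of OfflineStream:

#include <vector>
#include "kaldi-native-fbank/csrc/online-feature.h"

std::vector<float> CedFbank(const std::vector<float> &samples) {
  knf::FbankOptions opts;
  opts.frame_opts.samp_freq = 16000;
  opts.frame_opts.frame_length_ms = 32;  // 512-sample window at 16 kHz
  opts.frame_opts.dither = 0;
  opts.frame_opts.preemph_coeff = 0;
  opts.frame_opts.remove_dc_offset = false;
  opts.frame_opts.window_type = "hann";
  opts.frame_opts.snip_edges = false;
  opts.mel_opts.num_bins = 64;
  opts.mel_opts.high_freq = 8000;

  knf::OnlineFbank fbank(opts);
  fbank.AcceptWaveform(16000, samples.data(), samples.size());
  fbank.InputFinished();

  std::vector<float> frames;
  for (int32_t i = 0; i != fbank.NumFramesReady(); ++i) {
    const float *frame = fbank.GetFrame(i);
    frames.insert(frames.end(), frame, frame + 64);
  }
  return frames;  // row-major (num_frames, 64), as consumed by Compute()
}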
... ... @@ -67,13 +67,15 @@ struct OfflineFeatureExtractorConfig {
};
struct WhisperTag {};
struct CEDTag {};
class OfflineStream {
public:
explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {},
ContextGraphPtr context_graph = {});
explicit OfflineStream(WhisperTag tag, ContextGraphPtr context_graph = {});
explicit OfflineStream(WhisperTag tag);
explicit OfflineStream(CEDTag tag);
~OfflineStream();
/**
... ...
... ... @@ -31,6 +31,12 @@ static AudioTaggingConfig GetAudioTaggingConfig(JNIEnv *env, jobject config) {
ans.model.zipformer.model = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_cls, "ced", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.ced = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_cls, "numThreads", "I");
ans.model.num_threads = env->GetIntField(model, fid);
... ...
... ... @@ -27,10 +27,11 @@ static void PybindAudioTaggingModelConfig(py::module *m) {
py::class_<PyClass>(*m, "AudioTaggingModelConfig")
.def(py::init<>())
.def(py::init<const OfflineZipformerAudioTaggingModelConfig &, int32_t,
bool, const std::string &>(),
py::arg("zipformer"), py::arg("num_threads") = 1,
py::arg("debug") = false, py::arg("provider") = "cpu")
.def(py::init<const OfflineZipformerAudioTaggingModelConfig &,
const std::string &, int32_t, bool, const std::string &>(),
py::arg("zipformer"), py::arg("ced") = "",
py::arg("num_threads") = 1, py::arg("debug") = false,
py::arg("provider") = "cpu")
.def_readwrite("zipformer", &PyClass::zipformer)
.def_readwrite("num_threads", &PyClass::num_threads)
.def_readwrite("debug", &PyClass::debug)
... ...