Fangjun Kuang
Committed by GitHub

Support CED models (#792)

Showing 33 changed files with 605 additions and 46 deletions
+name: export-ced-to-onnx
+
+on:
+  workflow_dispatch:
+
+concurrency:
+  group: export-ced-to-onnx-${{ github.ref }}
+  cancel-in-progress: true
+
+jobs:
+  export-ced-to-onnx:
+    if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
+    name: export ced
+    runs-on: ${{ matrix.os }}
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [ubuntu-latest]
+        python-version: ["3.8"]
+
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Setup Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Run
+        shell: bash
+        run: |
+          cd scripts/ced
+          ./run.sh
+
+      - name: Release
+        uses: svenstaro/upload-release-action@v2
+        with:
+          file_glob: true
+          file: ./*.tar.bz2
+          overwrite: true
+          repo_name: k2-fsa/sherpa-onnx
+          repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
+          tag: audio-tagging-models
+
+      - name: Publish to huggingface
+        env:
+          HF_TOKEN: ${{ secrets.HF_TOKEN }}
+        uses: nick-fields/retry@v3
+        with:
+          max_attempts: 20
+          timeout_seconds: 200
+          shell: bash
+          command: |
+            git config --global user.email "csukuangfj@gmail.com"
+            git config --global user.name "Fangjun Kuang"
+
+            models=(
+              tiny
+              mini
+              small
+              base
+            )
+
+            for m in ${models[@]}; do
+              rm -rf huggingface
+              export GIT_LFS_SKIP_SMUDGE=1
+              d=sherpa-onnx-ced-$m-audio-tagging-2024-04-19
+              git clone https://huggingface.co/k2-fsa/$d huggingface
+              mv -v $d/* huggingface
+              cd huggingface
+              git lfs track "*.onnx"
+              git status
+              git add .
+              git status
+              git commit -m "first commit"
+              git push https://csukuangfj:$HF_TOKEN@huggingface.co/k2-fsa/$d main
+              cd ..
+            done
 <resources>
-    <string name="app_name">ASR with Next-gen Kaldi</string>
+    <string name="app_name">ASR</string>
     <string name="hint">Click the Start button to play speech-to-text with Next-gen Kaldi.
 \n
 \n\n\n
 <resources>
-    <string name="app_name">ASR with Next-gen Kaldi</string>
+    <string name="app_name">ASR2pass </string>
     <string name="hint">Click the Start button to play speech-to-text with Next-gen Kaldi.
 \n
 \n\n\n
 package com.k2fsa.sherpa.onnx
 
 import android.content.res.AssetManager
-import android.util.Log
 
-private val TAG = "sherpa-onnx"
+const val TAG = "sherpa-onnx"
 
 data class OfflineZipformerAudioTaggingModelConfig(
-    var model: String,
+    var model: String = "",
 )
 
 data class AudioTaggingModelConfig(
-    var zipformer: OfflineZipformerAudioTaggingModelConfig,
+    var zipformer: OfflineZipformerAudioTaggingModelConfig = OfflineZipformerAudioTaggingModelConfig(),
+    var ced: String = "",
     var numThreads: Int = 1,
     var debug: Boolean = false,
     var provider: String = "cpu",
@@ -103,7 +103,7 @@ class AudioTagging(
 //
 // See also
 // https://k2-fsa.github.io/sherpa/onnx/audio-tagging/
-fun getAudioTaggingConfig(type: Int, numThreads: Int=1): AudioTaggingConfig? {
+fun getAudioTaggingConfig(type: Int, numThreads: Int = 1): AudioTaggingConfig? {
     when (type) {
         0 -> {
             val modelDir = "sherpa-onnx-zipformer-small-audio-tagging-2024-04-15"
@@ -123,7 +123,46 @@ fun getAudioTaggingConfig(type: Int, numThreads: Int = 1): AudioTaggingConfig? {
             return AudioTaggingConfig(
                 model = AudioTaggingModelConfig(
                     zipformer = OfflineZipformerAudioTaggingModelConfig(model = "$modelDir/model.int8.onnx"),
-                    numThreads = 1,
+                    numThreads = numThreads,
+                    debug = true,
+                ),
+                labels = "$modelDir/class_labels_indices.csv",
+                topK = 3,
+            )
+        }
+
+        2 -> {
+            val modelDir = "sherpa-onnx-ced-tiny-audio-tagging-2024-04-19"
+            return AudioTaggingConfig(
+                model = AudioTaggingModelConfig(
+                    ced = "$modelDir/model.int8.onnx",
+                    numThreads = numThreads,
+                    debug = true,
+                ),
+                labels = "$modelDir/class_labels_indices.csv",
+                topK = 3,
+            )
+        }
+
+        3 -> {
+            val modelDir = "sherpa-onnx-ced-mini-audio-tagging-2024-04-19"
+            return AudioTaggingConfig(
+                model = AudioTaggingModelConfig(
+                    ced = "$modelDir/model.int8.onnx",
+                    numThreads = numThreads,
+                    debug = true,
+                ),
+                labels = "$modelDir/class_labels_indices.csv",
+                topK = 3,
+            )
+        }
+
+        4 -> {
+            val modelDir = "sherpa-onnx-ced-small-audio-tagging-2024-04-19"
+            return AudioTaggingConfig(
+                model = AudioTaggingModelConfig(
+                    ced = "$modelDir/model.int8.onnx",
+                    numThreads = numThreads,
                     debug = true,
                 ),
                 labels = "$modelDir/class_labels_indices.csv",
@@ -131,6 +170,18 @@ fun getAudioTaggingConfig(type: Int, numThreads: Int = 1): AudioTaggingConfig? {
             )
         }
 
+        5 -> {
+            val modelDir = "sherpa-onnx-ced-base-audio-tagging-2024-04-19"
+            return AudioTaggingConfig(
+                model = AudioTaggingModelConfig(
+                    ced = "$modelDir/model.int8.onnx",
+                    numThreads = numThreads,
+                    debug = true,
+                ),
+                labels = "$modelDir/class_labels_indices.csv",
+                topK = 3,
+            )
+        }
     }
 
     return null
@@ -3,24 +3,15 @@
 package com.k2fsa.sherpa.onnx.audio.tagging
 
 import android.Manifest
-
 import android.app.Activity
 import android.content.pm.PackageManager
 import android.media.AudioFormat
 import android.media.AudioRecord
-import androidx.compose.foundation.lazy.items
 import android.media.MediaRecorder
 import android.util.Log
 import androidx.compose.foundation.ExperimentalFoundationApi
-import androidx.compose.foundation.background
 import androidx.compose.foundation.layout.Arrangement
 import androidx.compose.foundation.layout.Box
-import androidx.compose.material3.CenterAlignedTopAppBar
-import androidx.compose.runtime.Composable
-import androidx.compose.material3.Scaffold
-import androidx.compose.material3.TopAppBarDefaults
-import androidx.compose.material3.MaterialTheme
-import androidx.compose.material3.Text
 import androidx.compose.foundation.layout.Column
 import androidx.compose.foundation.layout.PaddingValues
 import androidx.compose.foundation.layout.Row
@@ -30,10 +21,17 @@ import androidx.compose.foundation.layout.fillMaxWidth
 import androidx.compose.foundation.layout.height
 import androidx.compose.foundation.layout.padding
 import androidx.compose.foundation.lazy.LazyColumn
+import androidx.compose.foundation.lazy.items
 import androidx.compose.material3.Button
+import androidx.compose.material3.CenterAlignedTopAppBar
 import androidx.compose.material3.ExperimentalMaterial3Api
+import androidx.compose.material3.MaterialTheme
+import androidx.compose.material3.Scaffold
 import androidx.compose.material3.Slider
 import androidx.compose.material3.Surface
+import androidx.compose.material3.Text
+import androidx.compose.material3.TopAppBarDefaults
+import androidx.compose.runtime.Composable
 import androidx.compose.runtime.getValue
 import androidx.compose.runtime.mutableStateListOf
 import androidx.compose.runtime.mutableStateOf
@@ -41,7 +39,6 @@ import androidx.compose.runtime.remember
 import androidx.compose.runtime.setValue
 import androidx.compose.ui.Alignment
 import androidx.compose.ui.Modifier
-import androidx.compose.ui.graphics.Color
 import androidx.compose.ui.platform.LocalContext
 import androidx.compose.ui.text.font.FontWeight
 import androidx.compose.ui.text.style.TextAlign
@@ -49,6 +46,7 @@ import androidx.compose.ui.unit.dp
 import androidx.compose.ui.unit.sp
 import androidx.core.app.ActivityCompat
 import com.k2fsa.sherpa.onnx.AudioEvent
+import com.k2fsa.sherpa.onnx.Tagger
 import kotlin.concurrent.thread
 
 
@@ -13,6 +13,7 @@ import androidx.compose.material3.Surface
 import androidx.compose.runtime.Composable
 import androidx.compose.ui.Modifier
 import androidx.core.app.ActivityCompat
+import com.k2fsa.sherpa.onnx.Tagger
 import com.k2fsa.sherpa.onnx.audio.tagging.ui.theme.SherpaOnnxAudioTaggingTheme
 
 const val TAG = "sherpa-onnx"
-package com.k2fsa.sherpa.onnx.audio.tagging
+package com.k2fsa.sherpa.onnx
 
 import android.content.res.AssetManager
 import android.util.Log
-import com.k2fsa.sherpa.onnx.AudioTagging
-import com.k2fsa.sherpa.onnx.audio.tagging.wear.os.presentation.TAG
-import com.k2fsa.sherpa.onnx.getAudioTaggingConfig
+
 
 object Tagger {
     private var _tagger: AudioTagging? = null
@@ -12,6 +10,7 @@ object Tagger {
         get() {
            return _tagger!!
        }
+
    fun initTagger(assetManager: AssetManager? = null, numThreads: Int = 1) {
        synchronized(this) {
            if (_tagger != null) {
@@ -19,7 +18,7 @@ object Tagger {
             }
 
             Log.i(TAG, "Initializing audio tagger")
-            val config = getAudioTaggingConfig(type = 0, numThreads=numThreads)!!
+            val config = getAudioTaggingConfig(type = 0, numThreads = numThreads)!!
             _tagger = AudioTagging(assetManager, config)
         }
     }
@@ -33,7 +33,7 @@ import androidx.wear.compose.material.Button
 import androidx.wear.compose.material.MaterialTheme
 import androidx.wear.compose.material.Text
 import com.k2fsa.sherpa.onnx.AudioEvent
-import com.k2fsa.sherpa.onnx.audio.tagging.Tagger
+import com.k2fsa.sherpa.onnx.Tagger
 import com.k2fsa.sherpa.onnx.audio.tagging.wear.os.presentation.theme.SherpaOnnxAudioTaggingWearOsTheme
 import kotlin.concurrent.thread
 
@@ -17,7 +17,7 @@ import androidx.activity.compose.setContent
 import androidx.compose.runtime.Composable
 import androidx.core.app.ActivityCompat
 import androidx.core.splashscreen.SplashScreen.Companion.installSplashScreen
-import com.k2fsa.sherpa.onnx.audio.tagging.Tagger
+import com.k2fsa.sherpa.onnx.Tagger
 
 const val TAG = "sherpa-onnx"
 private const val REQUEST_RECORD_AUDIO_PERMISSION = 200
 <resources>
-    <string name="app_name">AudioTagging</string>
+    <string name="app_name">Audio Tagging</string>
     <!--
     This string is used for square devices and overridden by hello_world in
     values-round/strings.xml for round devices.
 <resources>
-    <string name="app_name">Speaker Identification</string>
+    <string name="app_name">Speaker ID</string>
     <string name="start">Start recording</string>
     <string name="stop">Stop recording</string>
     <string name="add">Add speaker</string>
 <resources>
-    <string name="app_name">SherpaOnnxSpokenLanguageIdentification</string>
+    <string name="app_name">Language ID</string>
 </resources>
 <resources>
-    <string name="app_name">Next-gen Kaldi: TTS</string>
+    <string name="app_name">TTS</string>
     <string name="sid_label">Speaker ID</string>
     <string name="sid_hint">0</string>
     <string name="speed_label">Speech speed (large->fast)</string>
 <resources>
-    <string name="app_name">Next-gen Kaldi: TTS</string>
+    <string name="app_name">TTS Engine</string>
 </resources>
 <resources>
-    <string name="app_name">Next-gen Kaldi: SileroVAD</string>
+    <string name="app_name">VAD</string>
 
     <string name="hint">Click the Start button to play Silero VAD with Next-gen Kaldi.</string>
     <string name="start">Start</string>
 <resources>
-    <string name="app_name">ASR with Next-gen Kaldi</string>
+    <string name="app_name">VAD-ASR</string>
     <string name="hint">Click the Start button to play speech-to-text with Next-gen Kaldi.
 \n
 \n\n\n
@@ -46,7 +46,30 @@ def get_models():
         ),
     ]
 
-    return icefall_models
+    ced_models = [
+        AudioTaggingModel(
+            model_name="sherpa-onnx-ced-tiny-audio-tagging-2024-04-19",
+            idx=2,
+            short_name="ced_tiny",
+        ),
+        AudioTaggingModel(
+            model_name="sherpa-onnx-ced-mini-audio-tagging-2024-04-19",
+            idx=3,
+            short_name="ced_mini",
+        ),
+        AudioTaggingModel(
+            model_name="sherpa-onnx-ced-small-audio-tagging-2024-04-19",
+            idx=4,
+            short_name="ced_small",
+        ),
+        AudioTaggingModel(
+            model_name="sherpa-onnx-ced-base-audio-tagging-2024-04-19",
+            idx=5,
+            short_name="ced_base",
+        ),
+    ]
+
+    return icefall_models + ced_models
 
 
 def main():
+#!/usr/bin/env bash
+#
+# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
+
+set -ex
+
+function install_dependencies() {
+  pip install -qq torch==2.1.0+cpu torchaudio==2.1.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
+  pip install -qq onnx onnxruntime==1.17.1
+
+  pip install -r ./requirements.txt
+}
+
+git clone https://github.com/RicherMans/CED
+pushd CED
+
+install_dependencies
+
+models=(
+  tiny
+  mini
+  small
+  base
+)
+
+for m in ${models[@]}; do
+  python3 ./export_onnx.py -m ced_$m
+done
+
+ls -lh *.onnx
+
+curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/sherpa-onnx-zipformer-small-audio-tagging-2024-04-15.tar.bz2
+
+tar xvf sherpa-onnx-zipformer-small-audio-tagging-2024-04-15.tar.bz2
+rm sherpa-onnx-zipformer-small-audio-tagging-2024-04-15.tar.bz2
+src=sherpa-onnx-zipformer-small-audio-tagging-2024-04-15
+
+cat >README.md <<EOF
+# Introduction
+
+Models in this repo are converted from
+https://github.com/RicherMans/CED
+EOF
+
+for m in ${models[@]}; do
+  d=sherpa-onnx-ced-$m-audio-tagging-2024-04-19
+
+  mkdir -p $d
+
+  cp -v README.md $d
+  cp -v $src/class_labels_indices.csv $d
+  cp -a $src/test_wavs $d
+  cp -v ced_$m.onnx $d/model.onnx
+  cp -v ced_$m.int8.onnx $d/model.int8.onnx
+  echo "----------$m----------"
+  ls -lh $d
+  echo "----------------------"
+  tar cjvf $d.tar.bz2 $d
+  mv $d.tar.bz2 ../../..
+  mv $d ../../../
+done
+
+rm -rf sherpa-onnx-zipformer-small-audio-tagging-2024-04-15
+
+cd ../../..
+
+ls -lh *.tar.bz2
+echo "======="
+ls -lh
@@ -1223,6 +1223,7 @@ const SherpaOnnxAudioTagging *SherpaOnnxCreateAudioTagging(
     const SherpaOnnxAudioTaggingConfig *config) {
   sherpa_onnx::AudioTaggingConfig ac;
   ac.model.zipformer.model = SHERPA_ONNX_OR(config->model.zipformer.model, "");
+  ac.model.ced = SHERPA_ONNX_OR(config->model.ced, "");
   ac.model.num_threads = SHERPA_ONNX_OR(config->model.num_threads, 1);
   ac.model.debug = config->model.debug;
   ac.model.provider = SHERPA_ONNX_OR(config->model.provider, "cpu");
@@ -1100,6 +1100,7 @@ SHERPA_ONNX_API typedef struct
 
 SHERPA_ONNX_API typedef struct SherpaOnnxAudioTaggingModelConfig {
   SherpaOnnxOfflineZipformerAudioTaggingModelConfig zipformer;
+  const char *ced;
   int32_t num_threads;
   int32_t debug;  // true to print debug information of the model
   const char *provider;
@@ -117,6 +117,7 @@ list(APPEND sources
   audio-tagging-label-file.cc
   audio-tagging-model-config.cc
   audio-tagging.cc
+  offline-ced-model.cc
   offline-zipformer-audio-tagging-model-config.cc
   offline-zipformer-audio-tagging-model.cc
 )
+// sherpa-onnx/csrc/audio-tagging-ced-impl.h
+//
+// Copyright (c) 2024 Xiaomi Corporation
+#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_CED_IMPL_H_
+#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_CED_IMPL_H_
+
+#include <assert.h>
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#if __ANDROID_API__ >= 9
+#include "android/asset_manager.h"
+#include "android/asset_manager_jni.h"
+#endif
+
+#include "sherpa-onnx/csrc/audio-tagging-impl.h"
+#include "sherpa-onnx/csrc/audio-tagging-label-file.h"
+#include "sherpa-onnx/csrc/audio-tagging.h"
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/math.h"
+#include "sherpa-onnx/csrc/offline-ced-model.h"
+
+namespace sherpa_onnx {
+
+class AudioTaggingCEDImpl : public AudioTaggingImpl {
+ public:
+  explicit AudioTaggingCEDImpl(const AudioTaggingConfig &config)
+      : config_(config), model_(config.model), labels_(config.labels) {
+    if (model_.NumEventClasses() != labels_.NumEventClasses()) {
+      SHERPA_ONNX_LOGE("number of classes: %d (model) != %d (label file)",
+                       model_.NumEventClasses(), labels_.NumEventClasses());
+      exit(-1);
+    }
+  }
+
+#if __ANDROID_API__ >= 9
+  explicit AudioTaggingCEDImpl(AAssetManager *mgr,
+                               const AudioTaggingConfig &config)
+      : config_(config),
+        model_(mgr, config.model),
+        labels_(mgr, config.labels) {
+    if (model_.NumEventClasses() != labels_.NumEventClasses()) {
+      SHERPA_ONNX_LOGE("number of classes: %d (model) != %d (label file)",
+                       model_.NumEventClasses(), labels_.NumEventClasses());
+      exit(-1);
+    }
+  }
+#endif
+
+  std::unique_ptr<OfflineStream> CreateStream() const override {
+    return std::make_unique<OfflineStream>(CEDTag{});
+  }
+
+  std::vector<AudioEvent> Compute(OfflineStream *s,
+                                  int32_t top_k = -1) const override {
+    if (top_k < 0) {
+      top_k = config_.top_k;
+    }
+
+    int32_t num_event_classes = model_.NumEventClasses();
+
+    if (top_k > num_event_classes) {
+      top_k = num_event_classes;
+    }
+
+    auto memory_info =
+        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+
+    // WARNING(fangjun): It is fixed to 64 for CED models
+    int32_t feat_dim = 64;
+    std::vector<float> f = s->GetFrames();
+
+    int32_t num_frames = f.size() / feat_dim;
+    assert(feat_dim * num_frames == f.size());
+
+    std::array<int64_t, 3> shape = {1, num_frames, feat_dim};
+
+    Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(),
+                                            shape.data(), shape.size());
+
+    Ort::Value probs = model_.Forward(std::move(x));
+
+    const float *p = probs.GetTensorData<float>();
+
+    std::vector<int32_t> top_k_indexes = TopkIndex(p, num_event_classes, top_k);
+
+    std::vector<AudioEvent> ans(top_k);
+
+    int32_t i = 0;
+
+    for (int32_t index : top_k_indexes) {
+      ans[i].name = labels_.GetEventName(index);
+      ans[i].index = index;
+      ans[i].prob = p[index];
+      i += 1;
+    }
+
+    return ans;
+  }
+
+ private:
+  AudioTaggingConfig config_;
+  OfflineCEDModel model_;
+  AudioTaggingLabels labels_;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_AUDIO_TAGGING_CED_IMPL_H_
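
Note: Compute() above delegates top-k selection to TopkIndex() from sherpa-onnx/csrc/math.h, which this diff does not touch. A minimal C++ sketch of the assumed semantics (return the indices of the k largest entries of p[0..n), ordered by decreasing value; TopkIndexSketch is a hypothetical stand-in, not the repo's helper):

    #include <algorithm>
    #include <cstdint>
    #include <numeric>
    #include <vector>

    // Hypothetical stand-in for TopkIndex from sherpa-onnx/csrc/math.h.
    std::vector<int32_t> TopkIndexSketch(const float *p, int32_t n, int32_t k) {
      std::vector<int32_t> index(n);
      std::iota(index.begin(), index.end(), 0);  // 0, 1, ..., n-1
      // Bring the indices of the k largest values to the front, sorted so that
      // p[index[0]] >= p[index[1]] >= ... >= p[index[k-1]].
      std::partial_sort(index.begin(), index.begin() + k, index.end(),
                        [p](int32_t a, int32_t b) { return p[a] > p[b]; });
      index.resize(k);
      return index;
    }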
@@ -11,6 +11,7 @@
 #include "android/asset_manager_jni.h"
 #endif
 
+#include "sherpa-onnx/csrc/audio-tagging-ced-impl.h"
 #include "sherpa-onnx/csrc/audio-tagging-zipformer-impl.h"
 #include "sherpa-onnx/csrc/macros.h"
 
@@ -20,6 +21,8 @@ std::unique_ptr<AudioTaggingImpl> AudioTaggingImpl::Create(
     const AudioTaggingConfig &config) {
   if (!config.model.zipformer.model.empty()) {
     return std::make_unique<AudioTaggingZipformerImpl>(config);
+  } else if (!config.model.ced.empty()) {
+    return std::make_unique<AudioTaggingCEDImpl>(config);
   }
 
   SHERPA_ONNX_LOG(
@@ -32,6 +35,8 @@ std::unique_ptr<AudioTaggingImpl> AudioTaggingImpl::Create(
     AAssetManager *mgr, const AudioTaggingConfig &config) {
   if (!config.model.zipformer.model.empty()) {
     return std::make_unique<AudioTaggingZipformerImpl>(mgr, config);
+  } else if (!config.model.ced.empty()) {
+    return std::make_unique<AudioTaggingCEDImpl>(mgr, config);
   }
 
   SHERPA_ONNX_LOG(
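
With this dispatch in place, choosing CED over Zipformer is purely a configuration matter: leave model.zipformer.model empty and set model.ced. A minimal C++ usage sketch, assuming the public sherpa_onnx::AudioTagging class forwards to the AudioTaggingImpl interface shown above (file paths are placeholders):

    #include <vector>

    #include "sherpa-onnx/csrc/audio-tagging.h"

    int main() {
      sherpa_onnx::AudioTaggingConfig config;
      config.model.ced = "./model.int8.onnx";        // placeholder path
      config.labels = "./class_labels_indices.csv";  // placeholder path
      config.top_k = 3;

      sherpa_onnx::AudioTagging tagger(config);

      // 1 second of silence at 16 kHz; a real program would decode a wave file.
      std::vector<float> samples(16000, 0.0f);

      auto stream = tagger.CreateStream();
      stream->AcceptWaveform(16000, samples.data(), samples.size());

      // Each AudioEvent carries the label name, class index, and probability.
      std::vector<sherpa_onnx::AudioEvent> events = tagger.Compute(stream.get());
      return 0;
    }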
@@ -4,11 +4,18 @@
 
 #include "sherpa-onnx/csrc/audio-tagging-model-config.h"
 
+#include "sherpa-onnx/csrc/file-utils.h"
+#include "sherpa-onnx/csrc/macros.h"
+
 namespace sherpa_onnx {
 
 void AudioTaggingModelConfig::Register(ParseOptions *po) {
   zipformer.Register(po);
 
+  po->Register("ced-model", &ced,
+               "Path to CED model. Only need to pass one of --zipformer-model "
+               "or --ced-model");
+
   po->Register("num-threads", &num_threads,
                "Number of threads to run the neural network");
 
@@ -24,6 +31,16 @@ bool AudioTaggingModelConfig::Validate() const {
     return false;
   }
 
+  if (!ced.empty() && !FileExists(ced)) {
+    SHERPA_ONNX_LOGE("CED model file %s does not exist", ced.c_str());
+    return false;
+  }
+
+  if (zipformer.model.empty() && ced.empty()) {
+    SHERPA_ONNX_LOGE("Please provide either --zipformer-model or --ced-model");
+    return false;
+  }
+
   return true;
 }
 
@@ -32,6 +49,7 @@ std::string AudioTaggingModelConfig::ToString() const {
 
   os << "AudioTaggingModelConfig(";
   os << "zipformer=" << zipformer.ToString() << ", ";
+  os << "ced=\"" << ced << "\", ";
   os << "num_threads=" << num_threads << ", ";
   os << "debug=" << (debug ? "True" : "False") << ", ";
   os << "provider=\"" << provider << "\")";
@@ -13,6 +13,7 @@ namespace sherpa_onnx {
 
 struct AudioTaggingModelConfig {
   struct OfflineZipformerAudioTaggingModelConfig zipformer;
+  std::string ced;
 
   int32_t num_threads = 1;
   bool debug = false;
@@ -22,8 +23,10 @@ struct AudioTaggingModelConfig {
 
   AudioTaggingModelConfig(
       const OfflineZipformerAudioTaggingModelConfig &zipformer,
-      int32_t num_threads, bool debug, const std::string &provider)
+      const std::string &ced, int32_t num_threads, bool debug,
+      const std::string &provider)
       : zipformer(zipformer),
+        ced(ced),
         num_threads(num_threads),
         debug(debug),
         provider(provider) {}
@@ -4,6 +4,8 @@
 #ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_ZIPFORMER_IMPL_H_
 #define SHERPA_ONNX_CSRC_AUDIO_TAGGING_ZIPFORMER_IMPL_H_
 
+#include <assert.h>
+
 #include <memory>
 #include <utility>
 #include <vector>
@@ -72,6 +74,8 @@ class AudioTaggingZipformerImpl : public AudioTaggingImpl {
 
     int32_t num_frames = f.size() / feat_dim;
 
+    assert(feat_dim * num_frames == f.size());
+
     std::array<int64_t, 3> shape = {1, num_frames, feat_dim};
 
     Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(),
@@ -24,7 +24,8 @@ void FeatureExtractorConfig::Register(ParseOptions *po) {
                "inside the feature extractor");
 
   po->Register("feat-dim", &feature_dim,
-               "Feature dimension. Must match the one expected by the model.");
+               "Feature dimension. Must match the one expected by the model. "
+               "Not used by whisper and CED models");
 
   po->Register("low-freq", &low_freq, "Low cutoff frequency for mel bins");
 
+// sherpa-onnx/csrc/offline-ced-model.cc
+//
+// Copyright (c) 2024 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/offline-ced-model.h"
+
+#include <string>
+#include <vector>
+
+#include "sherpa-onnx/csrc/onnx-utils.h"
+#include "sherpa-onnx/csrc/session.h"
+#include "sherpa-onnx/csrc/text-utils.h"
+#include "sherpa-onnx/csrc/transpose.h"
+
+namespace sherpa_onnx {
+
+class OfflineCEDModel::Impl {
+ public:
+  explicit Impl(const AudioTaggingModelConfig &config)
+      : config_(config),
+        env_(ORT_LOGGING_LEVEL_ERROR),
+        sess_opts_(GetSessionOptions(config)),
+        allocator_{} {
+    auto buf = ReadFile(config_.ced);
+    Init(buf.data(), buf.size());
+  }
+
+#if __ANDROID_API__ >= 9
+  Impl(AAssetManager *mgr, const AudioTaggingModelConfig &config)
+      : config_(config),
+        env_(ORT_LOGGING_LEVEL_ERROR),
+        sess_opts_(GetSessionOptions(config)),
+        allocator_{} {
+    auto buf = ReadFile(mgr, config_.ced);
+    Init(buf.data(), buf.size());
+  }
+#endif
+
+  Ort::Value Forward(Ort::Value features) {
+    features = Transpose12(allocator_, &features);
+
+    auto ans = sess_->Run({}, input_names_ptr_.data(), &features, 1,
+                          output_names_ptr_.data(), output_names_ptr_.size());
+    return std::move(ans[0]);
+  }
+
+  int32_t NumEventClasses() const { return num_event_classes_; }
+
+  OrtAllocator *Allocator() const { return allocator_; }
+
+ private:
+  void Init(void *model_data, size_t model_data_length) {
+    sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
+                                           sess_opts_);
+
+    GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
+
+    GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
+
+    // get meta data
+    Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
+    if (config_.debug) {
+      std::ostringstream os;
+      PrintModelMetadata(os, meta_data);
+      SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
+    }
+
+    // get num_event_classes from the output[0].shape,
+    // which is (N, num_event_classes)
+    num_event_classes_ =
+        sess_->GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape()[1];
+  }
+
+ private:
+  AudioTaggingModelConfig config_;
+  Ort::Env env_;
+  Ort::SessionOptions sess_opts_;
+  Ort::AllocatorWithDefaultOptions allocator_;
+
+  std::unique_ptr<Ort::Session> sess_;
+
+  std::vector<std::string> input_names_;
+  std::vector<const char *> input_names_ptr_;
+
+  std::vector<std::string> output_names_;
+  std::vector<const char *> output_names_ptr_;
+
+  int32_t num_event_classes_ = 0;
+};
+
+OfflineCEDModel::OfflineCEDModel(const AudioTaggingModelConfig &config)
+    : impl_(std::make_unique<Impl>(config)) {}
+
+#if __ANDROID_API__ >= 9
+OfflineCEDModel::OfflineCEDModel(AAssetManager *mgr,
+                                 const AudioTaggingModelConfig &config)
+    : impl_(std::make_unique<Impl>(mgr, config)) {}
+#endif
+
+OfflineCEDModel::~OfflineCEDModel() = default;
+
+Ort::Value OfflineCEDModel::Forward(Ort::Value features) const {
+  return impl_->Forward(std::move(features));
+}
+
+int32_t OfflineCEDModel::NumEventClasses() const {
+  return impl_->NumEventClasses();
+}
+
+OrtAllocator *OfflineCEDModel::Allocator() const { return impl_->Allocator(); }
+
+}  // namespace sherpa_onnx
+// sherpa-onnx/csrc/offline-ced-model.h
+//
+// Copyright (c) 2024 Xiaomi Corporation
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_CED_MODEL_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_CED_MODEL_H_
+#include <memory>
+#include <utility>
+
+#if __ANDROID_API__ >= 9
+#include "android/asset_manager.h"
+#include "android/asset_manager_jni.h"
+#endif
+
+#include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/audio-tagging-model-config.h"
+
+namespace sherpa_onnx {
+
+/** This class implements the CED model from
+ * https://github.com/RicherMans/CED/blob/main/export_onnx.py
+ */
+class OfflineCEDModel {
+ public:
+  explicit OfflineCEDModel(const AudioTaggingModelConfig &config);
+
+#if __ANDROID_API__ >= 9
+  OfflineCEDModel(AAssetManager *mgr, const AudioTaggingModelConfig &config);
+#endif
+
+  ~OfflineCEDModel();
+
+  /** Run the forward method of the model.
+   *
+   * @param features  A tensor of shape (N, T, C).
+   *
+   * @return Return a tensor
+   *         - probs: A 2-D tensor of shape (N, num_event_classes).
+   */
+  Ort::Value Forward(Ort::Value features) const;
+
+  /** Return the number of event classes of the model
+   */
+  int32_t NumEventClasses() const;
+
+  /** Return an allocator for allocating memory
+   */
+  OrtAllocator *Allocator() const;
+
+ private:
+  class Impl;
+  std::unique_ptr<Impl> impl_;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_CED_MODEL_H_
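
As a quick illustration of the Forward() contract documented above, a hedged C++ sketch that runs a dummy (1, T, 64) feature tensor through the model (the model path is a placeholder; 64 is the fixed CED feature dimension noted in audio-tagging-ced-impl.h):

    #include <array>
    #include <cstdint>
    #include <utility>
    #include <vector>

    #include "sherpa-onnx/csrc/offline-ced-model.h"

    void RunCEDSketch() {
      sherpa_onnx::AudioTaggingModelConfig config;
      config.ced = "./model.int8.onnx";  // placeholder path

      sherpa_onnx::OfflineCEDModel model(config);

      int32_t num_frames = 100;  // T
      int32_t feat_dim = 64;     // C, fixed to 64 for CED models
      std::vector<float> features(num_frames * feat_dim, 0.0f);

      auto memory_info =
          Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
      std::array<int64_t, 3> shape = {1, num_frames, feat_dim};
      Ort::Value x =
          Ort::Value::CreateTensor(memory_info, features.data(), features.size(),
                                   shape.data(), shape.size());

      // probs has shape (1, model.NumEventClasses()).
      Ort::Value probs = model.Forward(std::move(x));
      const float *p = probs.GetTensorData<float>();
      (void)p;  // the first NumEventClasses() floats are the class probabilities
    }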
@@ -92,15 +92,32 @@ class OfflineStream::Impl {
     fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
   }
 
-  Impl(WhisperTag /*tag*/, ContextGraphPtr context_graph)
-      : context_graph_(context_graph) {
+  explicit Impl(WhisperTag /*tag*/) {
     config_.normalize_samples = true;
     opts_.frame_opts.samp_freq = 16000;
-    opts_.mel_opts.num_bins = 80;
+    opts_.mel_opts.num_bins = 80;  // not used
     whisper_fbank_ =
         std::make_unique<knf::OnlineWhisperFbank>(opts_.frame_opts);
   }
 
+  explicit Impl(CEDTag /*tag*/) {
+    // see
+    // https://github.com/RicherMans/CED/blob/main/onnx_inference_with_kaldi.py
+
+    opts_.frame_opts.frame_length_ms = 32;
+    opts_.frame_opts.dither = 0;
+    opts_.frame_opts.preemph_coeff = 0;
+    opts_.frame_opts.remove_dc_offset = false;
+    opts_.frame_opts.window_type = "hann";
+    opts_.frame_opts.snip_edges = false;
+
+    opts_.frame_opts.samp_freq = 16000;  // fixed to 16000
+    opts_.mel_opts.num_bins = 64;
+    opts_.mel_opts.high_freq = 8000;
+
+    fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
+  }
+
   void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
     if (config_.normalize_samples) {
       AcceptWaveformImpl(sampling_rate, waveform, n);
@@ -233,9 +250,10 @@ OfflineStream::OfflineStream(
     ContextGraphPtr context_graph /*= nullptr*/)
     : impl_(std::make_unique<Impl>(config, context_graph)) {}
 
-OfflineStream::OfflineStream(WhisperTag tag,
-                             ContextGraphPtr context_graph /*= {}*/)
-    : impl_(std::make_unique<Impl>(tag, context_graph)) {}
+OfflineStream::OfflineStream(WhisperTag tag)
+    : impl_(std::make_unique<Impl>(tag)) {}
+
+OfflineStream::OfflineStream(CEDTag tag) : impl_(std::make_unique<Impl>(tag)) {}
 
 OfflineStream::~OfflineStream() = default;
 
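
For reference, the Impl(CEDTag) constructor above mirrors the Kaldi-compatible settings from https://github.com/RicherMans/CED/blob/main/onnx_inference_with_kaldi.py. A standalone C++ sketch of the same feature pipeline, calling kaldi-native-fbank directly (the knf calls below are assumed from the API pattern used elsewhere in OfflineStream::Impl):

    #include <cstdint>
    #include <vector>

    #include "kaldi-native-fbank/csrc/online-feature.h"

    // Compute CED-style 64-dim fbank features; returns a flat
    // (num_frames, 64) buffer. `samples` is 16 kHz mono in [-1, 1].
    std::vector<float> CedFbankSketch(const std::vector<float> &samples) {
      knf::FbankOptions opts;
      opts.frame_opts.samp_freq = 16000;
      opts.frame_opts.frame_length_ms = 32;
      opts.frame_opts.dither = 0;
      opts.frame_opts.preemph_coeff = 0;
      opts.frame_opts.remove_dc_offset = false;
      opts.frame_opts.window_type = "hann";
      opts.frame_opts.snip_edges = false;
      opts.mel_opts.num_bins = 64;
      opts.mel_opts.high_freq = 8000;

      knf::OnlineFbank fbank(opts);
      fbank.AcceptWaveform(16000, samples.data(), samples.size());
      fbank.InputFinished();

      // Collect all frames into one flat buffer, 64 floats per frame.
      std::vector<float> features;
      for (int32_t i = 0; i != fbank.NumFramesReady(); ++i) {
        const float *frame = fbank.GetFrame(i);
        features.insert(features.end(), frame, frame + opts.mel_opts.num_bins);
      }
      return features;
    }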
@@ -67,13 +67,15 @@ struct OfflineFeatureExtractorConfig {
 };
 
 struct WhisperTag {};
+struct CEDTag {};
 
 class OfflineStream {
  public:
   explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {},
                          ContextGraphPtr context_graph = {});
 
-  explicit OfflineStream(WhisperTag tag, ContextGraphPtr context_graph = {});
+  explicit OfflineStream(WhisperTag tag);
+  explicit OfflineStream(CEDTag tag);
   ~OfflineStream();
 
   /**
@@ -31,6 +31,12 @@ static AudioTaggingConfig GetAudioTaggingConfig(JNIEnv *env, jobject config) {
   ans.model.zipformer.model = p;
   env->ReleaseStringUTFChars(s, p);
 
+  fid = env->GetFieldID(model_cls, "ced", "Ljava/lang/String;");
+  s = (jstring)env->GetObjectField(model, fid);
+  p = env->GetStringUTFChars(s, nullptr);
+  ans.model.ced = p;
+  env->ReleaseStringUTFChars(s, p);
+
   fid = env->GetFieldID(model_cls, "numThreads", "I");
   ans.model.num_threads = env->GetIntField(model, fid);
 
@@ -27,10 +27,11 @@ static void PybindAudioTaggingModelConfig(py::module *m) {
 
   py::class_<PyClass>(*m, "AudioTaggingModelConfig")
       .def(py::init<>())
-      .def(py::init<const OfflineZipformerAudioTaggingModelConfig &, int32_t,
-                    bool, const std::string &>(),
-           py::arg("zipformer"), py::arg("num_threads") = 1,
-           py::arg("debug") = false, py::arg("provider") = "cpu")
+      .def(py::init<const OfflineZipformerAudioTaggingModelConfig &,
+                    const std::string &, int32_t, bool, const std::string &>(),
+           py::arg("zipformer"), py::arg("ced") = "",
+           py::arg("num_threads") = 1, py::arg("debug") = false,
+           py::arg("provider") = "cpu")
       .def_readwrite("zipformer", &PyClass::zipformer)
       .def_readwrite("num_threads", &PyClass::num_threads)
       .def_readwrite("debug", &PyClass::debug)