正在显示
33 个修改的文件
包含
605 行增加
和
46 行删除
.github/workflows/export-ced-to-onnx.yaml
0 → 100644
| 1 | +name: export-ced-to-onnx | ||
| 2 | + | ||
| 3 | +on: | ||
| 4 | + workflow_dispatch: | ||
| 5 | + | ||
| 6 | +concurrency: | ||
| 7 | + group: export-ced-to-onnx-${{ github.ref }} | ||
| 8 | + cancel-in-progress: true | ||
| 9 | + | ||
| 10 | +jobs: | ||
| 11 | + export-ced-to-onnx: | ||
| 12 | + if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj' | ||
| 13 | + name: export ced | ||
| 14 | + runs-on: ${{ matrix.os }} | ||
| 15 | + strategy: | ||
| 16 | + fail-fast: false | ||
| 17 | + matrix: | ||
| 18 | + os: [ubuntu-latest] | ||
| 19 | + python-version: ["3.8"] | ||
| 20 | + | ||
| 21 | + steps: | ||
| 22 | + - uses: actions/checkout@v4 | ||
| 23 | + | ||
| 24 | + - name: Setup Python ${{ matrix.python-version }} | ||
| 25 | + uses: actions/setup-python@v5 | ||
| 26 | + with: | ||
| 27 | + python-version: ${{ matrix.python-version }} | ||
| 28 | + | ||
| 29 | + - name: Run | ||
| 30 | + shell: bash | ||
| 31 | + run: | | ||
| 32 | + cd scripts/ced | ||
| 33 | + ./run.sh | ||
| 34 | + | ||
| 35 | + - name: Release | ||
| 36 | + uses: svenstaro/upload-release-action@v2 | ||
| 37 | + with: | ||
| 38 | + file_glob: true | ||
| 39 | + file: ./*.tar.bz2 | ||
| 40 | + overwrite: true | ||
| 41 | + repo_name: k2-fsa/sherpa-onnx | ||
| 42 | + repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }} | ||
| 43 | + tag: audio-tagging-models | ||
| 44 | + | ||
| 45 | + - name: Publish to huggingface | ||
| 46 | + env: | ||
| 47 | + HF_TOKEN: ${{ secrets.HF_TOKEN }} | ||
| 48 | + uses: nick-fields/retry@v3 | ||
| 49 | + with: | ||
| 50 | + max_attempts: 20 | ||
| 51 | + timeout_seconds: 200 | ||
| 52 | + shell: bash | ||
| 53 | + command: | | ||
| 54 | + git config --global user.email "csukuangfj@gmail.com" | ||
| 55 | + git config --global user.name "Fangjun Kuang" | ||
| 56 | + | ||
| 57 | + models=( | ||
| 58 | + tiny | ||
| 59 | + mini | ||
| 60 | + small | ||
| 61 | + base | ||
| 62 | + ) | ||
| 63 | + | ||
| 64 | + for m in ${models[@]}; do | ||
| 65 | + rm -rf huggingface | ||
| 66 | + export GIT_LFS_SKIP_SMUDGE=1 | ||
| 67 | + d=sherpa-onnx-ced-$m-audio-tagging-2024-04-19 | ||
| 68 | + git clone https://huggingface.co/k2-fsa/$d huggingface | ||
| 69 | + mv -v $d/* huggingface | ||
| 70 | + cd huggingface | ||
| 71 | + git lfs track "*.onnx" | ||
| 72 | + git status | ||
| 73 | + git add . | ||
| 74 | + git status | ||
| 75 | + git commit -m "first commit" | ||
| 76 | + git push https://csukuangfj:$HF_TOKEN@huggingface.co/k2-fsa/$d main | ||
| 77 | + cd .. | ||
| 78 | + done |
| 1 | <resources> | 1 | <resources> |
| 2 | - <string name="app_name">ASR with Next-gen Kaldi</string> | 2 | + <string name="app_name">ASR</string> |
| 3 | <string name="hint">Click the Start button to play speech-to-text with Next-gen Kaldi. | 3 | <string name="hint">Click the Start button to play speech-to-text with Next-gen Kaldi. |
| 4 | \n | 4 | \n |
| 5 | \n\n\n | 5 | \n\n\n |
| 1 | <resources> | 1 | <resources> |
| 2 | - <string name="app_name">ASR with Next-gen Kaldi</string> | 2 | + <string name="app_name">ASR2pass </string> |
| 3 | <string name="hint">Click the Start button to play speech-to-text with Next-gen Kaldi. | 3 | <string name="hint">Click the Start button to play speech-to-text with Next-gen Kaldi. |
| 4 | \n | 4 | \n |
| 5 | \n\n\n | 5 | \n\n\n |
android/SherpaOnnxAudioTagging/app/src/main/java/com/k2fsa/sherpa/onnx/audio/tagging/AudioTagging.kt
| 1 | package com.k2fsa.sherpa.onnx | 1 | package com.k2fsa.sherpa.onnx |
| 2 | 2 | ||
| 3 | import android.content.res.AssetManager | 3 | import android.content.res.AssetManager |
| 4 | -import android.util.Log | ||
| 5 | 4 | ||
| 6 | -private val TAG = "sherpa-onnx" | 5 | +const val TAG = "sherpa-onnx" |
| 7 | 6 | ||
| 8 | data class OfflineZipformerAudioTaggingModelConfig( | 7 | data class OfflineZipformerAudioTaggingModelConfig( |
| 9 | - var model: String, | 8 | + var model: String = "", |
| 10 | ) | 9 | ) |
| 11 | 10 | ||
| 12 | data class AudioTaggingModelConfig( | 11 | data class AudioTaggingModelConfig( |
| 13 | - var zipformer: OfflineZipformerAudioTaggingModelConfig, | 12 | + var zipformer: OfflineZipformerAudioTaggingModelConfig = OfflineZipformerAudioTaggingModelConfig(), |
| 13 | + var ced: String = "", | ||
| 14 | var numThreads: Int = 1, | 14 | var numThreads: Int = 1, |
| 15 | var debug: Boolean = false, | 15 | var debug: Boolean = false, |
| 16 | var provider: String = "cpu", | 16 | var provider: String = "cpu", |
| @@ -103,7 +103,7 @@ class AudioTagging( | @@ -103,7 +103,7 @@ class AudioTagging( | ||
| 103 | // | 103 | // |
| 104 | // See also | 104 | // See also |
| 105 | // https://k2-fsa.github.io/sherpa/onnx/audio-tagging/ | 105 | // https://k2-fsa.github.io/sherpa/onnx/audio-tagging/ |
| 106 | -fun getAudioTaggingConfig(type: Int, numThreads: Int=1): AudioTaggingConfig? { | 106 | +fun getAudioTaggingConfig(type: Int, numThreads: Int = 1): AudioTaggingConfig? { |
| 107 | when (type) { | 107 | when (type) { |
| 108 | 0 -> { | 108 | 0 -> { |
| 109 | val modelDir = "sherpa-onnx-zipformer-small-audio-tagging-2024-04-15" | 109 | val modelDir = "sherpa-onnx-zipformer-small-audio-tagging-2024-04-15" |
| @@ -123,7 +123,46 @@ fun getAudioTaggingConfig(type: Int, numThreads: Int=1): AudioTaggingConfig? { | @@ -123,7 +123,46 @@ fun getAudioTaggingConfig(type: Int, numThreads: Int=1): AudioTaggingConfig? { | ||
| 123 | return AudioTaggingConfig( | 123 | return AudioTaggingConfig( |
| 124 | model = AudioTaggingModelConfig( | 124 | model = AudioTaggingModelConfig( |
| 125 | zipformer = OfflineZipformerAudioTaggingModelConfig(model = "$modelDir/model.int8.onnx"), | 125 | zipformer = OfflineZipformerAudioTaggingModelConfig(model = "$modelDir/model.int8.onnx"), |
| 126 | - numThreads = 1, | 126 | + numThreads = numThreads, |
| 127 | + debug = true, | ||
| 128 | + ), | ||
| 129 | + labels = "$modelDir/class_labels_indices.csv", | ||
| 130 | + topK = 3, | ||
| 131 | + ) | ||
| 132 | + } | ||
| 133 | + | ||
| 134 | + 2 -> { | ||
| 135 | + val modelDir = "sherpa-onnx-ced-tiny-audio-tagging-2024-04-19" | ||
| 136 | + return AudioTaggingConfig( | ||
| 137 | + model = AudioTaggingModelConfig( | ||
| 138 | + ced = "$modelDir/model.int8.onnx", | ||
| 139 | + numThreads = numThreads, | ||
| 140 | + debug = true, | ||
| 141 | + ), | ||
| 142 | + labels = "$modelDir/class_labels_indices.csv", | ||
| 143 | + topK = 3, | ||
| 144 | + ) | ||
| 145 | + } | ||
| 146 | + | ||
| 147 | + 3 -> { | ||
| 148 | + val modelDir = "sherpa-onnx-ced-mini-audio-tagging-2024-04-19" | ||
| 149 | + return AudioTaggingConfig( | ||
| 150 | + model = AudioTaggingModelConfig( | ||
| 151 | + ced = "$modelDir/model.int8.onnx", | ||
| 152 | + numThreads = numThreads, | ||
| 153 | + debug = true, | ||
| 154 | + ), | ||
| 155 | + labels = "$modelDir/class_labels_indices.csv", | ||
| 156 | + topK = 3, | ||
| 157 | + ) | ||
| 158 | + } | ||
| 159 | + | ||
| 160 | + 4 -> { | ||
| 161 | + val modelDir = "sherpa-onnx-ced-small-audio-tagging-2024-04-19" | ||
| 162 | + return AudioTaggingConfig( | ||
| 163 | + model = AudioTaggingModelConfig( | ||
| 164 | + ced = "$modelDir/model.int8.onnx", | ||
| 165 | + numThreads = numThreads, | ||
| 127 | debug = true, | 166 | debug = true, |
| 128 | ), | 167 | ), |
| 129 | labels = "$modelDir/class_labels_indices.csv", | 168 | labels = "$modelDir/class_labels_indices.csv", |
| @@ -131,6 +170,18 @@ fun getAudioTaggingConfig(type: Int, numThreads: Int=1): AudioTaggingConfig? { | @@ -131,6 +170,18 @@ fun getAudioTaggingConfig(type: Int, numThreads: Int=1): AudioTaggingConfig? { | ||
| 131 | ) | 170 | ) |
| 132 | } | 171 | } |
| 133 | 172 | ||
| 173 | + 5 -> { | ||
| 174 | + val modelDir = "sherpa-onnx-ced-base-audio-tagging-2024-04-19" | ||
| 175 | + return AudioTaggingConfig( | ||
| 176 | + model = AudioTaggingModelConfig( | ||
| 177 | + ced = "$modelDir/model.int8.onnx", | ||
| 178 | + numThreads = numThreads, | ||
| 179 | + debug = true, | ||
| 180 | + ), | ||
| 181 | + labels = "$modelDir/class_labels_indices.csv", | ||
| 182 | + topK = 3, | ||
| 183 | + ) | ||
| 184 | + } | ||
| 134 | } | 185 | } |
| 135 | 186 | ||
| 136 | return null | 187 | return null |
| @@ -3,24 +3,15 @@ | @@ -3,24 +3,15 @@ | ||
| 3 | package com.k2fsa.sherpa.onnx.audio.tagging | 3 | package com.k2fsa.sherpa.onnx.audio.tagging |
| 4 | 4 | ||
| 5 | import android.Manifest | 5 | import android.Manifest |
| 6 | - | ||
| 7 | import android.app.Activity | 6 | import android.app.Activity |
| 8 | import android.content.pm.PackageManager | 7 | import android.content.pm.PackageManager |
| 9 | import android.media.AudioFormat | 8 | import android.media.AudioFormat |
| 10 | import android.media.AudioRecord | 9 | import android.media.AudioRecord |
| 11 | -import androidx.compose.foundation.lazy.items | ||
| 12 | import android.media.MediaRecorder | 10 | import android.media.MediaRecorder |
| 13 | import android.util.Log | 11 | import android.util.Log |
| 14 | import androidx.compose.foundation.ExperimentalFoundationApi | 12 | import androidx.compose.foundation.ExperimentalFoundationApi |
| 15 | -import androidx.compose.foundation.background | ||
| 16 | import androidx.compose.foundation.layout.Arrangement | 13 | import androidx.compose.foundation.layout.Arrangement |
| 17 | import androidx.compose.foundation.layout.Box | 14 | import androidx.compose.foundation.layout.Box |
| 18 | -import androidx.compose.material3.CenterAlignedTopAppBar | ||
| 19 | -import androidx.compose.runtime.Composable | ||
| 20 | -import androidx.compose.material3.Scaffold | ||
| 21 | -import androidx.compose.material3.TopAppBarDefaults | ||
| 22 | -import androidx.compose.material3.MaterialTheme | ||
| 23 | -import androidx.compose.material3.Text | ||
| 24 | import androidx.compose.foundation.layout.Column | 15 | import androidx.compose.foundation.layout.Column |
| 25 | import androidx.compose.foundation.layout.PaddingValues | 16 | import androidx.compose.foundation.layout.PaddingValues |
| 26 | import androidx.compose.foundation.layout.Row | 17 | import androidx.compose.foundation.layout.Row |
| @@ -30,10 +21,17 @@ import androidx.compose.foundation.layout.fillMaxWidth | @@ -30,10 +21,17 @@ import androidx.compose.foundation.layout.fillMaxWidth | ||
| 30 | import androidx.compose.foundation.layout.height | 21 | import androidx.compose.foundation.layout.height |
| 31 | import androidx.compose.foundation.layout.padding | 22 | import androidx.compose.foundation.layout.padding |
| 32 | import androidx.compose.foundation.lazy.LazyColumn | 23 | import androidx.compose.foundation.lazy.LazyColumn |
| 24 | +import androidx.compose.foundation.lazy.items | ||
| 33 | import androidx.compose.material3.Button | 25 | import androidx.compose.material3.Button |
| 26 | +import androidx.compose.material3.CenterAlignedTopAppBar | ||
| 34 | import androidx.compose.material3.ExperimentalMaterial3Api | 27 | import androidx.compose.material3.ExperimentalMaterial3Api |
| 28 | +import androidx.compose.material3.MaterialTheme | ||
| 29 | +import androidx.compose.material3.Scaffold | ||
| 35 | import androidx.compose.material3.Slider | 30 | import androidx.compose.material3.Slider |
| 36 | import androidx.compose.material3.Surface | 31 | import androidx.compose.material3.Surface |
| 32 | +import androidx.compose.material3.Text | ||
| 33 | +import androidx.compose.material3.TopAppBarDefaults | ||
| 34 | +import androidx.compose.runtime.Composable | ||
| 37 | import androidx.compose.runtime.getValue | 35 | import androidx.compose.runtime.getValue |
| 38 | import androidx.compose.runtime.mutableStateListOf | 36 | import androidx.compose.runtime.mutableStateListOf |
| 39 | import androidx.compose.runtime.mutableStateOf | 37 | import androidx.compose.runtime.mutableStateOf |
| @@ -41,7 +39,6 @@ import androidx.compose.runtime.remember | @@ -41,7 +39,6 @@ import androidx.compose.runtime.remember | ||
| 41 | import androidx.compose.runtime.setValue | 39 | import androidx.compose.runtime.setValue |
| 42 | import androidx.compose.ui.Alignment | 40 | import androidx.compose.ui.Alignment |
| 43 | import androidx.compose.ui.Modifier | 41 | import androidx.compose.ui.Modifier |
| 44 | -import androidx.compose.ui.graphics.Color | ||
| 45 | import androidx.compose.ui.platform.LocalContext | 42 | import androidx.compose.ui.platform.LocalContext |
| 46 | import androidx.compose.ui.text.font.FontWeight | 43 | import androidx.compose.ui.text.font.FontWeight |
| 47 | import androidx.compose.ui.text.style.TextAlign | 44 | import androidx.compose.ui.text.style.TextAlign |
| @@ -49,6 +46,7 @@ import androidx.compose.ui.unit.dp | @@ -49,6 +46,7 @@ import androidx.compose.ui.unit.dp | ||
| 49 | import androidx.compose.ui.unit.sp | 46 | import androidx.compose.ui.unit.sp |
| 50 | import androidx.core.app.ActivityCompat | 47 | import androidx.core.app.ActivityCompat |
| 51 | import com.k2fsa.sherpa.onnx.AudioEvent | 48 | import com.k2fsa.sherpa.onnx.AudioEvent |
| 49 | +import com.k2fsa.sherpa.onnx.Tagger | ||
| 52 | import kotlin.concurrent.thread | 50 | import kotlin.concurrent.thread |
| 53 | 51 | ||
| 54 | 52 |
android/SherpaOnnxAudioTagging/app/src/main/java/com/k2fsa/sherpa/onnx/audio/tagging/MainActivity.kt
| @@ -13,6 +13,7 @@ import androidx.compose.material3.Surface | @@ -13,6 +13,7 @@ import androidx.compose.material3.Surface | ||
| 13 | import androidx.compose.runtime.Composable | 13 | import androidx.compose.runtime.Composable |
| 14 | import androidx.compose.ui.Modifier | 14 | import androidx.compose.ui.Modifier |
| 15 | import androidx.core.app.ActivityCompat | 15 | import androidx.core.app.ActivityCompat |
| 16 | +import com.k2fsa.sherpa.onnx.Tagger | ||
| 16 | import com.k2fsa.sherpa.onnx.audio.tagging.ui.theme.SherpaOnnxAudioTaggingTheme | 17 | import com.k2fsa.sherpa.onnx.audio.tagging.ui.theme.SherpaOnnxAudioTaggingTheme |
| 17 | 18 | ||
| 18 | const val TAG = "sherpa-onnx" | 19 | const val TAG = "sherpa-onnx" |
| 1 | -package com.k2fsa.sherpa.onnx.audio.tagging | 1 | +package com.k2fsa.sherpa.onnx |
| 2 | 2 | ||
| 3 | import android.content.res.AssetManager | 3 | import android.content.res.AssetManager |
| 4 | import android.util.Log | 4 | import android.util.Log |
| 5 | -import com.k2fsa.sherpa.onnx.AudioTagging | ||
| 6 | -import com.k2fsa.sherpa.onnx.audio.tagging.wear.os.presentation.TAG | ||
| 7 | -import com.k2fsa.sherpa.onnx.getAudioTaggingConfig | 5 | + |
| 8 | 6 | ||
| 9 | object Tagger { | 7 | object Tagger { |
| 10 | private var _tagger: AudioTagging? = null | 8 | private var _tagger: AudioTagging? = null |
| @@ -12,6 +10,7 @@ object Tagger { | @@ -12,6 +10,7 @@ object Tagger { | ||
| 12 | get() { | 10 | get() { |
| 13 | return _tagger!! | 11 | return _tagger!! |
| 14 | } | 12 | } |
| 13 | + | ||
| 15 | fun initTagger(assetManager: AssetManager? = null, numThreads: Int = 1) { | 14 | fun initTagger(assetManager: AssetManager? = null, numThreads: Int = 1) { |
| 16 | synchronized(this) { | 15 | synchronized(this) { |
| 17 | if (_tagger != null) { | 16 | if (_tagger != null) { |
| @@ -19,7 +18,7 @@ object Tagger { | @@ -19,7 +18,7 @@ object Tagger { | ||
| 19 | } | 18 | } |
| 20 | 19 | ||
| 21 | Log.i(TAG, "Initializing audio tagger") | 20 | Log.i(TAG, "Initializing audio tagger") |
| 22 | - val config = getAudioTaggingConfig(type = 0, numThreads=numThreads)!! | 21 | + val config = getAudioTaggingConfig(type = 0, numThreads = numThreads)!! |
| 23 | _tagger = AudioTagging(assetManager, config) | 22 | _tagger = AudioTagging(assetManager, config) |
| 24 | } | 23 | } |
| 25 | } | 24 | } |
| @@ -33,7 +33,7 @@ import androidx.wear.compose.material.Button | @@ -33,7 +33,7 @@ import androidx.wear.compose.material.Button | ||
| 33 | import androidx.wear.compose.material.MaterialTheme | 33 | import androidx.wear.compose.material.MaterialTheme |
| 34 | import androidx.wear.compose.material.Text | 34 | import androidx.wear.compose.material.Text |
| 35 | import com.k2fsa.sherpa.onnx.AudioEvent | 35 | import com.k2fsa.sherpa.onnx.AudioEvent |
| 36 | -import com.k2fsa.sherpa.onnx.audio.tagging.Tagger | 36 | +import com.k2fsa.sherpa.onnx.Tagger |
| 37 | import com.k2fsa.sherpa.onnx.audio.tagging.wear.os.presentation.theme.SherpaOnnxAudioTaggingWearOsTheme | 37 | import com.k2fsa.sherpa.onnx.audio.tagging.wear.os.presentation.theme.SherpaOnnxAudioTaggingWearOsTheme |
| 38 | import kotlin.concurrent.thread | 38 | import kotlin.concurrent.thread |
| 39 | 39 |
| @@ -17,7 +17,7 @@ import androidx.activity.compose.setContent | @@ -17,7 +17,7 @@ import androidx.activity.compose.setContent | ||
| 17 | import androidx.compose.runtime.Composable | 17 | import androidx.compose.runtime.Composable |
| 18 | import androidx.core.app.ActivityCompat | 18 | import androidx.core.app.ActivityCompat |
| 19 | import androidx.core.splashscreen.SplashScreen.Companion.installSplashScreen | 19 | import androidx.core.splashscreen.SplashScreen.Companion.installSplashScreen |
| 20 | -import com.k2fsa.sherpa.onnx.audio.tagging.Tagger | 20 | +import com.k2fsa.sherpa.onnx.Tagger |
| 21 | 21 | ||
| 22 | const val TAG = "sherpa-onnx" | 22 | const val TAG = "sherpa-onnx" |
| 23 | private const val REQUEST_RECORD_AUDIO_PERMISSION = 200 | 23 | private const val REQUEST_RECORD_AUDIO_PERMISSION = 200 |
| 1 | <resources> | 1 | <resources> |
| 2 | - <string name="app_name">AudioTagging</string> | 2 | + <string name="app_name">Audio Tagging</string> |
| 3 | <!-- | 3 | <!-- |
| 4 | This string is used for square devices and overridden by hello_world in | 4 | This string is used for square devices and overridden by hello_world in |
| 5 | values-round/strings.xml for round devices. | 5 | values-round/strings.xml for round devices. |
| 1 | <resources> | 1 | <resources> |
| 2 | - <string name="app_name">Speaker Identification</string> | 2 | + <string name="app_name">Speaker ID</string> |
| 3 | <string name="start">Start recording</string> | 3 | <string name="start">Start recording</string> |
| 4 | <string name="stop">Stop recording</string> | 4 | <string name="stop">Stop recording</string> |
| 5 | <string name="add">Add speaker</string> | 5 | <string name="add">Add speaker</string> |
| 1 | <resources> | 1 | <resources> |
| 2 | - <string name="app_name">Next-gen Kaldi: TTS</string> | 2 | + <string name="app_name">TTS</string> |
| 3 | <string name="sid_label">Speaker ID</string> | 3 | <string name="sid_label">Speaker ID</string> |
| 4 | <string name="sid_hint">0</string> | 4 | <string name="sid_hint">0</string> |
| 5 | <string name="speed_label">Speech speed (large->fast)</string> | 5 | <string name="speed_label">Speech speed (large->fast)</string> |
| 1 | <resources> | 1 | <resources> |
| 2 | - <string name="app_name">Next-gen Kaldi: SileroVAD</string> | 2 | + <string name="app_name">VAD</string> |
| 3 | 3 | ||
| 4 | <string name="hint">Click the Start button to play Silero VAD with Next-gen Kaldi.</string> | 4 | <string name="hint">Click the Start button to play Silero VAD with Next-gen Kaldi.</string> |
| 5 | <string name="start">Start</string> | 5 | <string name="start">Start</string> |
| 1 | <resources> | 1 | <resources> |
| 2 | - <string name="app_name">ASR with Next-gen Kaldi</string> | 2 | + <string name="app_name">VAD-ASR</string> |
| 3 | <string name="hint">Click the Start button to play speech-to-text with Next-gen Kaldi. | 3 | <string name="hint">Click the Start button to play speech-to-text with Next-gen Kaldi. |
| 4 | \n | 4 | \n |
| 5 | \n\n\n | 5 | \n\n\n |
| @@ -46,7 +46,30 @@ def get_models(): | @@ -46,7 +46,30 @@ def get_models(): | ||
| 46 | ), | 46 | ), |
| 47 | ] | 47 | ] |
| 48 | 48 | ||
| 49 | - return icefall_models | 49 | + ced_models = [ |
| 50 | + AudioTaggingModel( | ||
| 51 | + model_name="sherpa-onnx-ced-tiny-audio-tagging-2024-04-19", | ||
| 52 | + idx=2, | ||
| 53 | + short_name="ced_tiny", | ||
| 54 | + ), | ||
| 55 | + AudioTaggingModel( | ||
| 56 | + model_name="sherpa-onnx-ced-mini-audio-tagging-2024-04-19", | ||
| 57 | + idx=3, | ||
| 58 | + short_name="ced_mini", | ||
| 59 | + ), | ||
| 60 | + AudioTaggingModel( | ||
| 61 | + model_name="sherpa-onnx-ced-small-audio-tagging-2024-04-19", | ||
| 62 | + idx=4, | ||
| 63 | + short_name="ced_small", | ||
| 64 | + ), | ||
| 65 | + AudioTaggingModel( | ||
| 66 | + model_name="sherpa-onnx-ced-base-audio-tagging-2024-04-19", | ||
| 67 | + idx=5, | ||
| 68 | + short_name="ced_base", | ||
| 69 | + ), | ||
| 70 | + ] | ||
| 71 | + | ||
| 72 | + return icefall_models + ced_models | ||
| 50 | 73 | ||
| 51 | 74 | ||
| 52 | def main(): | 75 | def main(): |
scripts/ced/run.sh
0 → 100755
| 1 | +#!/usr/bin/env bash | ||
| 2 | +# | ||
| 3 | +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) | ||
| 4 | + | ||
| 5 | +set -ex | ||
| 6 | + | ||
| 7 | +function install_dependencies() { | ||
| 8 | + pip install -qq torch==2.1.0+cpu torchaudio==2.1.0+cpu -f https://download.pytorch.org/whl/torch_stable.html | ||
| 9 | + pip install -qq onnx onnxruntime==1.17.1 | ||
| 10 | + | ||
| 11 | + pip install -r ./requirements.txt | ||
| 12 | +} | ||
| 13 | + | ||
| 14 | +git clone https://github.com/RicherMans/CED | ||
| 15 | +pushd CED | ||
| 16 | + | ||
| 17 | +install_dependencies | ||
| 18 | + | ||
| 19 | +models=( | ||
| 20 | +tiny | ||
| 21 | +mini | ||
| 22 | +small | ||
| 23 | +base | ||
| 24 | +) | ||
| 25 | + | ||
| 26 | +for m in ${models[@]}; do | ||
| 27 | + python3 ./export_onnx.py -m ced_$m | ||
| 28 | +done | ||
| 29 | + | ||
| 30 | +ls -lh *.onnx | ||
| 31 | + | ||
| 32 | +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/sherpa-onnx-zipformer-small-audio-tagging-2024-04-15.tar.bz2 | ||
| 33 | + | ||
| 34 | +tar xvf sherpa-onnx-zipformer-small-audio-tagging-2024-04-15.tar.bz2 | ||
| 35 | +rm sherpa-onnx-zipformer-small-audio-tagging-2024-04-15.tar.bz2 | ||
| 36 | +src=sherpa-onnx-zipformer-small-audio-tagging-2024-04-15 | ||
| 37 | + | ||
| 38 | +cat >README.md <<EOF | ||
| 39 | +# Introduction | ||
| 40 | + | ||
| 41 | +Models in this repo are converted from | ||
| 42 | +https://github.com/RicherMans/CED | ||
| 43 | +EOF | ||
| 44 | + | ||
| 45 | +for m in ${models[@]}; do | ||
| 46 | + d=sherpa-onnx-ced-$m-audio-tagging-2024-04-19 | ||
| 47 | + | ||
| 48 | + mkdir -p $d | ||
| 49 | + | ||
| 50 | + cp -v README.md $d | ||
| 51 | + cp -v $src/class_labels_indices.csv $d | ||
| 52 | + cp -a $src/test_wavs $d | ||
| 53 | + cp -v ced_$m.onnx $d/model.onnx | ||
| 54 | + cp -v ced_$m.int8.onnx $d/model.int8.onnx | ||
| 55 | + echo "----------$m----------" | ||
| 56 | + ls -lh $d | ||
| 57 | + echo "----------------------" | ||
| 58 | + tar cjvf $d.tar.bz2 $d | ||
| 59 | + mv $d.tar.bz2 ../../.. | ||
| 60 | + mv $d ../../../ | ||
| 61 | +done | ||
| 62 | + | ||
| 63 | +rm -rf sherpa-onnx-zipformer-small-audio-tagging-2024-04-15 | ||
| 64 | + | ||
| 65 | +cd ../../.. | ||
| 66 | + | ||
| 67 | +ls -lh *.tar.bz2 | ||
| 68 | +echo "=======" | ||
| 69 | +ls -lh |
| @@ -1223,6 +1223,7 @@ const SherpaOnnxAudioTagging *SherpaOnnxCreateAudioTagging( | @@ -1223,6 +1223,7 @@ const SherpaOnnxAudioTagging *SherpaOnnxCreateAudioTagging( | ||
| 1223 | const SherpaOnnxAudioTaggingConfig *config) { | 1223 | const SherpaOnnxAudioTaggingConfig *config) { |
| 1224 | sherpa_onnx::AudioTaggingConfig ac; | 1224 | sherpa_onnx::AudioTaggingConfig ac; |
| 1225 | ac.model.zipformer.model = SHERPA_ONNX_OR(config->model.zipformer.model, ""); | 1225 | ac.model.zipformer.model = SHERPA_ONNX_OR(config->model.zipformer.model, ""); |
| 1226 | + ac.model.ced = SHERPA_ONNX_OR(config->model.ced, ""); | ||
| 1226 | ac.model.num_threads = SHERPA_ONNX_OR(config->model.num_threads, 1); | 1227 | ac.model.num_threads = SHERPA_ONNX_OR(config->model.num_threads, 1); |
| 1227 | ac.model.debug = config->model.debug; | 1228 | ac.model.debug = config->model.debug; |
| 1228 | ac.model.provider = SHERPA_ONNX_OR(config->model.provider, "cpu"); | 1229 | ac.model.provider = SHERPA_ONNX_OR(config->model.provider, "cpu"); |
| @@ -1100,6 +1100,7 @@ SHERPA_ONNX_API typedef struct | @@ -1100,6 +1100,7 @@ SHERPA_ONNX_API typedef struct | ||
| 1100 | 1100 | ||
| 1101 | SHERPA_ONNX_API typedef struct SherpaOnnxAudioTaggingModelConfig { | 1101 | SHERPA_ONNX_API typedef struct SherpaOnnxAudioTaggingModelConfig { |
| 1102 | SherpaOnnxOfflineZipformerAudioTaggingModelConfig zipformer; | 1102 | SherpaOnnxOfflineZipformerAudioTaggingModelConfig zipformer; |
| 1103 | + const char *ced; | ||
| 1103 | int32_t num_threads; | 1104 | int32_t num_threads; |
| 1104 | int32_t debug; // true to print debug information of the model | 1105 | int32_t debug; // true to print debug information of the model |
| 1105 | const char *provider; | 1106 | const char *provider; |
| @@ -117,6 +117,7 @@ list(APPEND sources | @@ -117,6 +117,7 @@ list(APPEND sources | ||
| 117 | audio-tagging-label-file.cc | 117 | audio-tagging-label-file.cc |
| 118 | audio-tagging-model-config.cc | 118 | audio-tagging-model-config.cc |
| 119 | audio-tagging.cc | 119 | audio-tagging.cc |
| 120 | + offline-ced-model.cc | ||
| 120 | offline-zipformer-audio-tagging-model-config.cc | 121 | offline-zipformer-audio-tagging-model-config.cc |
| 121 | offline-zipformer-audio-tagging-model.cc | 122 | offline-zipformer-audio-tagging-model.cc |
| 122 | ) | 123 | ) |
sherpa-onnx/csrc/audio-tagging-ced-impl.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/audio-tagging-ced-impl.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_CED_IMPL_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_CED_IMPL_H_ | ||
| 6 | + | ||
| 7 | +#include <assert.h> | ||
| 8 | + | ||
| 9 | +#include <memory> | ||
| 10 | +#include <utility> | ||
| 11 | +#include <vector> | ||
| 12 | + | ||
| 13 | +#if __ANDROID_API__ >= 9 | ||
| 14 | +#include "android/asset_manager.h" | ||
| 15 | +#include "android/asset_manager_jni.h" | ||
| 16 | +#endif | ||
| 17 | + | ||
| 18 | +#include "sherpa-onnx/csrc/audio-tagging-impl.h" | ||
| 19 | +#include "sherpa-onnx/csrc/audio-tagging-label-file.h" | ||
| 20 | +#include "sherpa-onnx/csrc/audio-tagging.h" | ||
| 21 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 22 | +#include "sherpa-onnx/csrc/math.h" | ||
| 23 | +#include "sherpa-onnx/csrc/offline-ced-model.h" | ||
| 24 | + | ||
| 25 | +namespace sherpa_onnx { | ||
| 26 | + | ||
| 27 | +class AudioTaggingCEDImpl : public AudioTaggingImpl { | ||
| 28 | + public: | ||
| 29 | + explicit AudioTaggingCEDImpl(const AudioTaggingConfig &config) | ||
| 30 | + : config_(config), model_(config.model), labels_(config.labels) { | ||
| 31 | + if (model_.NumEventClasses() != labels_.NumEventClasses()) { | ||
| 32 | + SHERPA_ONNX_LOGE("number of classes: %d (model) != %d (label file)", | ||
| 33 | + model_.NumEventClasses(), labels_.NumEventClasses()); | ||
| 34 | + exit(-1); | ||
| 35 | + } | ||
| 36 | + } | ||
| 37 | + | ||
| 38 | +#if __ANDROID_API__ >= 9 | ||
| 39 | + explicit AudioTaggingCEDImpl(AAssetManager *mgr, | ||
| 40 | + const AudioTaggingConfig &config) | ||
| 41 | + : config_(config), | ||
| 42 | + model_(mgr, config.model), | ||
| 43 | + labels_(mgr, config.labels) { | ||
| 44 | + if (model_.NumEventClasses() != labels_.NumEventClasses()) { | ||
| 45 | + SHERPA_ONNX_LOGE("number of classes: %d (model) != %d (label file)", | ||
| 46 | + model_.NumEventClasses(), labels_.NumEventClasses()); | ||
| 47 | + exit(-1); | ||
| 48 | + } | ||
| 49 | + } | ||
| 50 | +#endif | ||
| 51 | + | ||
| 52 | + std::unique_ptr<OfflineStream> CreateStream() const override { | ||
| 53 | + return std::make_unique<OfflineStream>(CEDTag{}); | ||
| 54 | + } | ||
| 55 | + | ||
| 56 | + std::vector<AudioEvent> Compute(OfflineStream *s, | ||
| 57 | + int32_t top_k = -1) const override { | ||
| 58 | + if (top_k < 0) { | ||
| 59 | + top_k = config_.top_k; | ||
| 60 | + } | ||
| 61 | + | ||
| 62 | + int32_t num_event_classes = model_.NumEventClasses(); | ||
| 63 | + | ||
| 64 | + if (top_k > num_event_classes) { | ||
| 65 | + top_k = num_event_classes; | ||
| 66 | + } | ||
| 67 | + | ||
| 68 | + auto memory_info = | ||
| 69 | + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 70 | + | ||
| 71 | + // WARNING(fangjun): It is fixed to 64 for CED models | ||
| 72 | + int32_t feat_dim = 64; | ||
| 73 | + std::vector<float> f = s->GetFrames(); | ||
| 74 | + | ||
| 75 | + int32_t num_frames = f.size() / feat_dim; | ||
| 76 | + assert(feat_dim * num_frames == f.size()); | ||
| 77 | + | ||
| 78 | + std::array<int64_t, 3> shape = {1, num_frames, feat_dim}; | ||
| 79 | + | ||
| 80 | + Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(), | ||
| 81 | + shape.data(), shape.size()); | ||
| 82 | + | ||
| 83 | + Ort::Value probs = model_.Forward(std::move(x)); | ||
| 84 | + | ||
| 85 | + const float *p = probs.GetTensorData<float>(); | ||
| 86 | + | ||
| 87 | + std::vector<int32_t> top_k_indexes = TopkIndex(p, num_event_classes, top_k); | ||
| 88 | + | ||
| 89 | + std::vector<AudioEvent> ans(top_k); | ||
| 90 | + | ||
| 91 | + int32_t i = 0; | ||
| 92 | + | ||
| 93 | + for (int32_t index : top_k_indexes) { | ||
| 94 | + ans[i].name = labels_.GetEventName(index); | ||
| 95 | + ans[i].index = index; | ||
| 96 | + ans[i].prob = p[index]; | ||
| 97 | + i += 1; | ||
| 98 | + } | ||
| 99 | + | ||
| 100 | + return ans; | ||
| 101 | + } | ||
| 102 | + | ||
| 103 | + private: | ||
| 104 | + AudioTaggingConfig config_; | ||
| 105 | + OfflineCEDModel model_; | ||
| 106 | + AudioTaggingLabels labels_; | ||
| 107 | +}; | ||
| 108 | + | ||
| 109 | +} // namespace sherpa_onnx | ||
| 110 | + | ||
| 111 | +#endif // SHERPA_ONNX_CSRC_AUDIO_TAGGING_CED_IMPL_H_ |
| @@ -11,6 +11,7 @@ | @@ -11,6 +11,7 @@ | ||
| 11 | #include "android/asset_manager_jni.h" | 11 | #include "android/asset_manager_jni.h" |
| 12 | #endif | 12 | #endif |
| 13 | 13 | ||
| 14 | +#include "sherpa-onnx/csrc/audio-tagging-ced-impl.h" | ||
| 14 | #include "sherpa-onnx/csrc/audio-tagging-zipformer-impl.h" | 15 | #include "sherpa-onnx/csrc/audio-tagging-zipformer-impl.h" |
| 15 | #include "sherpa-onnx/csrc/macros.h" | 16 | #include "sherpa-onnx/csrc/macros.h" |
| 16 | 17 | ||
| @@ -20,6 +21,8 @@ std::unique_ptr<AudioTaggingImpl> AudioTaggingImpl::Create( | @@ -20,6 +21,8 @@ std::unique_ptr<AudioTaggingImpl> AudioTaggingImpl::Create( | ||
| 20 | const AudioTaggingConfig &config) { | 21 | const AudioTaggingConfig &config) { |
| 21 | if (!config.model.zipformer.model.empty()) { | 22 | if (!config.model.zipformer.model.empty()) { |
| 22 | return std::make_unique<AudioTaggingZipformerImpl>(config); | 23 | return std::make_unique<AudioTaggingZipformerImpl>(config); |
| 24 | + } else if (!config.model.ced.empty()) { | ||
| 25 | + return std::make_unique<AudioTaggingCEDImpl>(config); | ||
| 23 | } | 26 | } |
| 24 | 27 | ||
| 25 | SHERPA_ONNX_LOG( | 28 | SHERPA_ONNX_LOG( |
| @@ -32,6 +35,8 @@ std::unique_ptr<AudioTaggingImpl> AudioTaggingImpl::Create( | @@ -32,6 +35,8 @@ std::unique_ptr<AudioTaggingImpl> AudioTaggingImpl::Create( | ||
| 32 | AAssetManager *mgr, const AudioTaggingConfig &config) { | 35 | AAssetManager *mgr, const AudioTaggingConfig &config) { |
| 33 | if (!config.model.zipformer.model.empty()) { | 36 | if (!config.model.zipformer.model.empty()) { |
| 34 | return std::make_unique<AudioTaggingZipformerImpl>(mgr, config); | 37 | return std::make_unique<AudioTaggingZipformerImpl>(mgr, config); |
| 38 | + } else if (!config.model.ced.empty()) { | ||
| 39 | + return std::make_unique<AudioTaggingCEDImpl>(mgr, config); | ||
| 35 | } | 40 | } |
| 36 | 41 | ||
| 37 | SHERPA_ONNX_LOG( | 42 | SHERPA_ONNX_LOG( |
| @@ -4,11 +4,18 @@ | @@ -4,11 +4,18 @@ | ||
| 4 | 4 | ||
| 5 | #include "sherpa-onnx/csrc/audio-tagging-model-config.h" | 5 | #include "sherpa-onnx/csrc/audio-tagging-model-config.h" |
| 6 | 6 | ||
| 7 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 8 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 9 | + | ||
| 7 | namespace sherpa_onnx { | 10 | namespace sherpa_onnx { |
| 8 | 11 | ||
| 9 | void AudioTaggingModelConfig::Register(ParseOptions *po) { | 12 | void AudioTaggingModelConfig::Register(ParseOptions *po) { |
| 10 | zipformer.Register(po); | 13 | zipformer.Register(po); |
| 11 | 14 | ||
| 15 | + po->Register("ced-model", &ced, | ||
| 16 | + "Path to CED model. Only need to pass one of --zipformer-model " | ||
| 17 | + "or --ced-model"); | ||
| 18 | + | ||
| 12 | po->Register("num-threads", &num_threads, | 19 | po->Register("num-threads", &num_threads, |
| 13 | "Number of threads to run the neural network"); | 20 | "Number of threads to run the neural network"); |
| 14 | 21 | ||
| @@ -24,6 +31,16 @@ bool AudioTaggingModelConfig::Validate() const { | @@ -24,6 +31,16 @@ bool AudioTaggingModelConfig::Validate() const { | ||
| 24 | return false; | 31 | return false; |
| 25 | } | 32 | } |
| 26 | 33 | ||
| 34 | + if (!ced.empty() && !FileExists(ced)) { | ||
| 35 | + SHERPA_ONNX_LOGE("CED model file %s does not exist", ced.c_str()); | ||
| 36 | + return false; | ||
| 37 | + } | ||
| 38 | + | ||
| 39 | + if (zipformer.model.empty() && ced.empty()) { | ||
| 40 | + SHERPA_ONNX_LOGE("Please provide either --zipformer-model or --ced-model"); | ||
| 41 | + return false; | ||
| 42 | + } | ||
| 43 | + | ||
| 27 | return true; | 44 | return true; |
| 28 | } | 45 | } |
| 29 | 46 | ||
| @@ -32,6 +49,7 @@ std::string AudioTaggingModelConfig::ToString() const { | @@ -32,6 +49,7 @@ std::string AudioTaggingModelConfig::ToString() const { | ||
| 32 | 49 | ||
| 33 | os << "AudioTaggingModelConfig("; | 50 | os << "AudioTaggingModelConfig("; |
| 34 | os << "zipformer=" << zipformer.ToString() << ", "; | 51 | os << "zipformer=" << zipformer.ToString() << ", "; |
| 52 | + os << "ced=\"" << ced << "\", "; | ||
| 35 | os << "num_threads=" << num_threads << ", "; | 53 | os << "num_threads=" << num_threads << ", "; |
| 36 | os << "debug=" << (debug ? "True" : "False") << ", "; | 54 | os << "debug=" << (debug ? "True" : "False") << ", "; |
| 37 | os << "provider=\"" << provider << "\")"; | 55 | os << "provider=\"" << provider << "\")"; |
| @@ -13,6 +13,7 @@ namespace sherpa_onnx { | @@ -13,6 +13,7 @@ namespace sherpa_onnx { | ||
| 13 | 13 | ||
| 14 | struct AudioTaggingModelConfig { | 14 | struct AudioTaggingModelConfig { |
| 15 | struct OfflineZipformerAudioTaggingModelConfig zipformer; | 15 | struct OfflineZipformerAudioTaggingModelConfig zipformer; |
| 16 | + std::string ced; | ||
| 16 | 17 | ||
| 17 | int32_t num_threads = 1; | 18 | int32_t num_threads = 1; |
| 18 | bool debug = false; | 19 | bool debug = false; |
| @@ -22,8 +23,10 @@ struct AudioTaggingModelConfig { | @@ -22,8 +23,10 @@ struct AudioTaggingModelConfig { | ||
| 22 | 23 | ||
| 23 | AudioTaggingModelConfig( | 24 | AudioTaggingModelConfig( |
| 24 | const OfflineZipformerAudioTaggingModelConfig &zipformer, | 25 | const OfflineZipformerAudioTaggingModelConfig &zipformer, |
| 25 | - int32_t num_threads, bool debug, const std::string &provider) | 26 | + const std::string &ced, int32_t num_threads, bool debug, |
| 27 | + const std::string &provider) | ||
| 26 | : zipformer(zipformer), | 28 | : zipformer(zipformer), |
| 29 | + ced(ced), | ||
| 27 | num_threads(num_threads), | 30 | num_threads(num_threads), |
| 28 | debug(debug), | 31 | debug(debug), |
| 29 | provider(provider) {} | 32 | provider(provider) {} |
| @@ -4,6 +4,8 @@ | @@ -4,6 +4,8 @@ | ||
| 4 | #ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_ZIPFORMER_IMPL_H_ | 4 | #ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_ZIPFORMER_IMPL_H_ |
| 5 | #define SHERPA_ONNX_CSRC_AUDIO_TAGGING_ZIPFORMER_IMPL_H_ | 5 | #define SHERPA_ONNX_CSRC_AUDIO_TAGGING_ZIPFORMER_IMPL_H_ |
| 6 | 6 | ||
| 7 | +#include <assert.h> | ||
| 8 | + | ||
| 7 | #include <memory> | 9 | #include <memory> |
| 8 | #include <utility> | 10 | #include <utility> |
| 9 | #include <vector> | 11 | #include <vector> |
| @@ -72,6 +74,8 @@ class AudioTaggingZipformerImpl : public AudioTaggingImpl { | @@ -72,6 +74,8 @@ class AudioTaggingZipformerImpl : public AudioTaggingImpl { | ||
| 72 | 74 | ||
| 73 | int32_t num_frames = f.size() / feat_dim; | 75 | int32_t num_frames = f.size() / feat_dim; |
| 74 | 76 | ||
| 77 | + assert(feat_dim * num_frames == f.size()); | ||
| 78 | + | ||
| 75 | std::array<int64_t, 3> shape = {1, num_frames, feat_dim}; | 79 | std::array<int64_t, 3> shape = {1, num_frames, feat_dim}; |
| 76 | 80 | ||
| 77 | Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(), | 81 | Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(), |
| @@ -24,7 +24,8 @@ void FeatureExtractorConfig::Register(ParseOptions *po) { | @@ -24,7 +24,8 @@ void FeatureExtractorConfig::Register(ParseOptions *po) { | ||
| 24 | "inside the feature extractor"); | 24 | "inside the feature extractor"); |
| 25 | 25 | ||
| 26 | po->Register("feat-dim", &feature_dim, | 26 | po->Register("feat-dim", &feature_dim, |
| 27 | - "Feature dimension. Must match the one expected by the model."); | 27 | + "Feature dimension. Must match the one expected by the model. " |
| 28 | + "Not used by whisper and CED models"); | ||
| 28 | 29 | ||
| 29 | po->Register("low-freq", &low_freq, "Low cutoff frequency for mel bins"); | 30 | po->Register("low-freq", &low_freq, "Low cutoff frequency for mel bins"); |
| 30 | 31 |
sherpa-onnx/csrc/offline-ced-model.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-ced-model.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-ced-model.h" | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | +#include <vector> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 11 | +#include "sherpa-onnx/csrc/session.h" | ||
| 12 | +#include "sherpa-onnx/csrc/text-utils.h" | ||
| 13 | +#include "sherpa-onnx/csrc/transpose.h" | ||
| 14 | + | ||
| 15 | +namespace sherpa_onnx { | ||
| 16 | + | ||
| 17 | +class OfflineCEDModel::Impl { | ||
| 18 | + public: | ||
| 19 | + explicit Impl(const AudioTaggingModelConfig &config) | ||
| 20 | + : config_(config), | ||
| 21 | + env_(ORT_LOGGING_LEVEL_ERROR), | ||
| 22 | + sess_opts_(GetSessionOptions(config)), | ||
| 23 | + allocator_{} { | ||
| 24 | + auto buf = ReadFile(config_.ced); | ||
| 25 | + Init(buf.data(), buf.size()); | ||
| 26 | + } | ||
| 27 | + | ||
| 28 | +#if __ANDROID_API__ >= 9 | ||
| 29 | + Impl(AAssetManager *mgr, const AudioTaggingModelConfig &config) | ||
| 30 | + : config_(config), | ||
| 31 | + env_(ORT_LOGGING_LEVEL_ERROR), | ||
| 32 | + sess_opts_(GetSessionOptions(config)), | ||
| 33 | + allocator_{} { | ||
| 34 | + auto buf = ReadFile(mgr, config_.ced); | ||
| 35 | + Init(buf.data(), buf.size()); | ||
| 36 | + } | ||
| 37 | +#endif | ||
| 38 | + | ||
| 39 | + Ort::Value Forward(Ort::Value features) { | ||
| 40 | + features = Transpose12(allocator_, &features); | ||
| 41 | + | ||
| 42 | + auto ans = sess_->Run({}, input_names_ptr_.data(), &features, 1, | ||
| 43 | + output_names_ptr_.data(), output_names_ptr_.size()); | ||
| 44 | + return std::move(ans[0]); | ||
| 45 | + } | ||
| 46 | + | ||
| 47 | + int32_t NumEventClasses() const { return num_event_classes_; } | ||
| 48 | + | ||
| 49 | + OrtAllocator *Allocator() const { return allocator_; } | ||
| 50 | + | ||
| 51 | + private: | ||
| 52 | + void Init(void *model_data, size_t model_data_length) { | ||
| 53 | + sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length, | ||
| 54 | + sess_opts_); | ||
| 55 | + | ||
| 56 | + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); | ||
| 57 | + | ||
| 58 | + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); | ||
| 59 | + | ||
| 60 | + // get meta data | ||
| 61 | + Ort::ModelMetadata meta_data = sess_->GetModelMetadata(); | ||
| 62 | + if (config_.debug) { | ||
| 63 | + std::ostringstream os; | ||
| 64 | + PrintModelMetadata(os, meta_data); | ||
| 65 | + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); | ||
| 66 | + } | ||
| 67 | + | ||
| 68 | + // get num_event_classes from the output[0].shape, | ||
| 69 | + // which is (N, num_event_classes) | ||
| 70 | + num_event_classes_ = | ||
| 71 | + sess_->GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape()[1]; | ||
| 72 | + } | ||
| 73 | + | ||
| 74 | + private: | ||
| 75 | + AudioTaggingModelConfig config_; | ||
| 76 | + Ort::Env env_; | ||
| 77 | + Ort::SessionOptions sess_opts_; | ||
| 78 | + Ort::AllocatorWithDefaultOptions allocator_; | ||
| 79 | + | ||
| 80 | + std::unique_ptr<Ort::Session> sess_; | ||
| 81 | + | ||
| 82 | + std::vector<std::string> input_names_; | ||
| 83 | + std::vector<const char *> input_names_ptr_; | ||
| 84 | + | ||
| 85 | + std::vector<std::string> output_names_; | ||
| 86 | + std::vector<const char *> output_names_ptr_; | ||
| 87 | + | ||
| 88 | + int32_t num_event_classes_ = 0; | ||
| 89 | +}; | ||
| 90 | + | ||
| 91 | +OfflineCEDModel::OfflineCEDModel(const AudioTaggingModelConfig &config) | ||
| 92 | + : impl_(std::make_unique<Impl>(config)) {} | ||
| 93 | + | ||
| 94 | +#if __ANDROID_API__ >= 9 | ||
| 95 | +OfflineCEDModel::OfflineCEDModel(AAssetManager *mgr, | ||
| 96 | + const AudioTaggingModelConfig &config) | ||
| 97 | + : impl_(std::make_unique<Impl>(mgr, config)) {} | ||
| 98 | +#endif | ||
| 99 | + | ||
| 100 | +OfflineCEDModel::~OfflineCEDModel() = default; | ||
| 101 | + | ||
| 102 | +Ort::Value OfflineCEDModel::Forward(Ort::Value features) const { | ||
| 103 | + return impl_->Forward(std::move(features)); | ||
| 104 | +} | ||
| 105 | + | ||
| 106 | +int32_t OfflineCEDModel::NumEventClasses() const { | ||
| 107 | + return impl_->NumEventClasses(); | ||
| 108 | +} | ||
| 109 | + | ||
| 110 | +OrtAllocator *OfflineCEDModel::Allocator() const { return impl_->Allocator(); } | ||
| 111 | + | ||
| 112 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/offline-ced-model.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-ced-model.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_CED_MODEL_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_OFFLINE_CED_MODEL_H_ | ||
| 6 | +#include <memory> | ||
| 7 | +#include <utility> | ||
| 8 | + | ||
| 9 | +#if __ANDROID_API__ >= 9 | ||
| 10 | +#include "android/asset_manager.h" | ||
| 11 | +#include "android/asset_manager_jni.h" | ||
| 12 | +#endif | ||
| 13 | + | ||
| 14 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 15 | +#include "sherpa-onnx/csrc/audio-tagging-model-config.h" | ||
| 16 | + | ||
| 17 | +namespace sherpa_onnx { | ||
| 18 | + | ||
| 19 | +/** This class wraps an ONNX export of the CED model. See | |
| 20 | + * https://github.com/RicherMans/CED/blob/main/export_onnx.py | ||
| 21 | + */ | ||
| 22 | +class OfflineCEDModel { | ||
| 23 | + public: | ||
| 24 | + explicit OfflineCEDModel(const AudioTaggingModelConfig &config); | ||
| 25 | + | ||
| 26 | +#if __ANDROID_API__ >= 9 | ||
| 27 | + OfflineCEDModel(AAssetManager *mgr, const AudioTaggingModelConfig &config); | ||
| 28 | +#endif | ||
| 29 | + | ||
| 30 | + ~OfflineCEDModel(); | ||
| 31 | + | ||
| 32 | + /** Run the forward method of the model. | ||
| 33 | + * | ||
| 34 | + * @param features A tensor of shape (N, T, C). | ||
| 35 | + * | ||
| 36 | + * @return A single tensor: | |
| 37 | + * - probs: A 2-D tensor of shape (N, num_event_classes). | ||
| 38 | + */ | ||
| 39 | + Ort::Value Forward(Ort::Value features) const; | ||
| 40 | + | ||
| 41 | + /** Return the number of event classes of the model | ||
| 42 | + */ | ||
| 43 | + int32_t NumEventClasses() const; | ||
| 44 | + | ||
| 45 | + /** Return an allocator for allocating memory | ||
| 46 | + */ | ||
| 47 | + OrtAllocator *Allocator() const; | ||
| 48 | + | ||
| 49 | + private: | ||
| 50 | + class Impl; | ||
| 51 | + std::unique_ptr<Impl> impl_; | ||
| 52 | +}; | ||
| 53 | + | ||
| 54 | +} // namespace sherpa_onnx | ||
| 55 | + | ||
| 56 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_CED_MODEL_H_ |
| @@ -92,15 +92,32 @@ class OfflineStream::Impl { | @@ -92,15 +92,32 @@ class OfflineStream::Impl { | ||
| 92 | fbank_ = std::make_unique<knf::OnlineFbank>(opts_); | 92 | fbank_ = std::make_unique<knf::OnlineFbank>(opts_); |
| 93 | } | 93 | } |
| 94 | 94 | ||
| 95 | - Impl(WhisperTag /*tag*/, ContextGraphPtr context_graph) | ||
| 96 | - : context_graph_(context_graph) { | 95 | + explicit Impl(WhisperTag /*tag*/) { |
| 97 | config_.normalize_samples = true; | 96 | config_.normalize_samples = true; |
| 98 | opts_.frame_opts.samp_freq = 16000; | 97 | opts_.frame_opts.samp_freq = 16000; |
| 99 | - opts_.mel_opts.num_bins = 80; | 98 | + opts_.mel_opts.num_bins = 80; // not used |
| 100 | whisper_fbank_ = | 99 | whisper_fbank_ = |
| 101 | std::make_unique<knf::OnlineWhisperFbank>(opts_.frame_opts); | 100 | std::make_unique<knf::OnlineWhisperFbank>(opts_.frame_opts); |
| 102 | } | 101 | } |
| 103 | 102 | ||
| 103 | + explicit Impl(CEDTag /*tag*/) { | ||
| 104 | + // see | ||
| 105 | + // https://github.com/RicherMans/CED/blob/main/onnx_inference_with_kaldi.py | ||
| 106 | + | ||
| 107 | + opts_.frame_opts.frame_length_ms = 32; | ||
| 108 | + opts_.frame_opts.dither = 0; | ||
| 109 | + opts_.frame_opts.preemph_coeff = 0; | ||
| 110 | + opts_.frame_opts.remove_dc_offset = false; | ||
| 111 | + opts_.frame_opts.window_type = "hann"; | ||
| 112 | + opts_.frame_opts.snip_edges = false; | ||
| 113 | + | ||
| 114 | + opts_.frame_opts.samp_freq = 16000; // fixed to 16000 | ||
| 115 | + opts_.mel_opts.num_bins = 64; | ||
| 116 | + opts_.mel_opts.high_freq = 8000; | ||
| 117 | + | ||
| 118 | + fbank_ = std::make_unique<knf::OnlineFbank>(opts_); | ||
| 119 | + } | ||
| 120 | + | ||
| 104 | void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) { | 121 | void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) { |
| 105 | if (config_.normalize_samples) { | 122 | if (config_.normalize_samples) { |
| 106 | AcceptWaveformImpl(sampling_rate, waveform, n); | 123 | AcceptWaveformImpl(sampling_rate, waveform, n); |
| @@ -233,9 +250,10 @@ OfflineStream::OfflineStream( | @@ -233,9 +250,10 @@ OfflineStream::OfflineStream( | ||
| 233 | ContextGraphPtr context_graph /*= nullptr*/) | 250 | ContextGraphPtr context_graph /*= nullptr*/) |
| 234 | : impl_(std::make_unique<Impl>(config, context_graph)) {} | 251 | : impl_(std::make_unique<Impl>(config, context_graph)) {} |
| 235 | 252 | ||
| 236 | -OfflineStream::OfflineStream(WhisperTag tag, | ||
| 237 | - ContextGraphPtr context_graph /*= {}*/) | ||
| 238 | - : impl_(std::make_unique<Impl>(tag, context_graph)) {} | 253 | +OfflineStream::OfflineStream(WhisperTag tag) |
| 254 | + : impl_(std::make_unique<Impl>(tag)) {} | ||
| 255 | + | ||
| 256 | +OfflineStream::OfflineStream(CEDTag tag) : impl_(std::make_unique<Impl>(tag)) {} | ||
| 239 | 257 | ||
| 240 | OfflineStream::~OfflineStream() = default; | 258 | OfflineStream::~OfflineStream() = default; |
| 241 | 259 |
| @@ -67,13 +67,15 @@ struct OfflineFeatureExtractorConfig { | @@ -67,13 +67,15 @@ struct OfflineFeatureExtractorConfig { | ||
| 67 | }; | 67 | }; |
| 68 | 68 | ||
| 69 | struct WhisperTag {}; | 69 | struct WhisperTag {}; |
| 70 | +struct CEDTag {}; | ||
| 70 | 71 | ||
| 71 | class OfflineStream { | 72 | class OfflineStream { |
| 72 | public: | 73 | public: |
| 73 | explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {}, | 74 | explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {}, |
| 74 | ContextGraphPtr context_graph = {}); | 75 | ContextGraphPtr context_graph = {}); |
| 75 | 76 | ||
| 76 | - explicit OfflineStream(WhisperTag tag, ContextGraphPtr context_graph = {}); | 77 | + explicit OfflineStream(WhisperTag tag); |
| 78 | + explicit OfflineStream(CEDTag tag); | ||
| 77 | ~OfflineStream(); | 79 | ~OfflineStream(); |
| 78 | 80 | ||
| 79 | /** | 81 | /** |
| @@ -31,6 +31,12 @@ static AudioTaggingConfig GetAudioTaggingConfig(JNIEnv *env, jobject config) { | @@ -31,6 +31,12 @@ static AudioTaggingConfig GetAudioTaggingConfig(JNIEnv *env, jobject config) { | ||
| 31 | ans.model.zipformer.model = p; | 31 | ans.model.zipformer.model = p; |
| 32 | env->ReleaseStringUTFChars(s, p); | 32 | env->ReleaseStringUTFChars(s, p); |
| 33 | 33 | ||
| 34 | + fid = env->GetFieldID(model_cls, "ced", "Ljava/lang/String;"); | ||
| 35 | + s = (jstring)env->GetObjectField(model, fid); | ||
| 36 | + p = env->GetStringUTFChars(s, nullptr); | ||
| 37 | + ans.model.ced = p; | ||
| 38 | + env->ReleaseStringUTFChars(s, p); | ||
| 39 | + | ||
| 34 | fid = env->GetFieldID(model_cls, "numThreads", "I"); | 40 | fid = env->GetFieldID(model_cls, "numThreads", "I"); |
| 35 | ans.model.num_threads = env->GetIntField(model, fid); | 41 | ans.model.num_threads = env->GetIntField(model, fid); |
| 36 | 42 |
| @@ -27,10 +27,11 @@ static void PybindAudioTaggingModelConfig(py::module *m) { | @@ -27,10 +27,11 @@ static void PybindAudioTaggingModelConfig(py::module *m) { | ||
| 27 | 27 | ||
| 28 | py::class_<PyClass>(*m, "AudioTaggingModelConfig") | 28 | py::class_<PyClass>(*m, "AudioTaggingModelConfig") |
| 29 | .def(py::init<>()) | 29 | .def(py::init<>()) |
| 30 | - .def(py::init<const OfflineZipformerAudioTaggingModelConfig &, int32_t, | ||
| 31 | - bool, const std::string &>(), | ||
| 32 | - py::arg("zipformer"), py::arg("num_threads") = 1, | ||
| 33 | - py::arg("debug") = false, py::arg("provider") = "cpu") | 30 | + .def(py::init<const OfflineZipformerAudioTaggingModelConfig &, |
| 31 | + const std::string &, int32_t, bool, const std::string &>(), | ||
| 32 | + py::arg("zipformer"), py::arg("ced") = "", | ||
| 33 | + py::arg("num_threads") = 1, py::arg("debug") = false, | ||
| 34 | + py::arg("provider") = "cpu") | ||
| 34 | .def_readwrite("zipformer", &PyClass::zipformer) | 35 | .def_readwrite("zipformer", &PyClass::zipformer) |
| 35 | .def_readwrite("num_threads", &PyClass::num_threads) | 36 | .def_readwrite("num_threads", &PyClass::num_threads) |
| 36 | .def_readwrite("debug", &PyClass::debug) | 37 | .def_readwrite("debug", &PyClass::debug) |
-
请 注册 或 登录 后发表评论