Committed by
GitHub
Support spoken language identification with whisper (#694)
Showing 36 changed files with 1,173 additions and 200 deletions
| 1 | +#!/usr/bin/env bash | ||
| 2 | + | ||
| 3 | +set -e | ||
| 4 | + | ||
| 5 | +log() { | ||
| 6 | + # This function is from espnet | ||
| 7 | + local fname=${BASH_SOURCE[1]##*/} | ||
| 8 | + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" | ||
| 9 | +} | ||
| 10 | + | ||
| 11 | +echo "EXE is $EXE" | ||
| 12 | +echo "PATH: $PATH" | ||
| 13 | + | ||
| 14 | +which $EXE | ||
| 15 | + | ||
| 16 | +names=( | ||
| 17 | +tiny | ||
| 18 | +base | ||
| 19 | +small | ||
| 20 | +medium | ||
| 21 | +) | ||
| 22 | + | ||
| 23 | +# all_language_codes=bo,ml,tt,fa,sl,bg,sn,sr,tl,km,ln,mr,hr,eu,ro,ba,bs,pl,as,nn,sk,ko,oc,ar,uz,pa,tg,mk,kk,hi,ha,uk,is,de,el,ja,yo,be,so,tk,id,sa,ru,yi,en,am,cs,ne,la,sv,su,pt,mi,ca,sd,hy,haw,fi,et,kn,da,lt,it,nl,he,mg,ur,tr,af,br,bn,ta,no,my,si,mt,th,gl,sw,mn,jw,ms,ps,fo,ka,hu,zh,ht,az,fr,lo,sq,gu,cy,lv,es,lb,te,vi | ||
| 24 | + | ||
| 25 | +log "Download test waves" | ||
| 26 | +waves=( | ||
| 27 | +ar-arabic.wav | ||
| 28 | +bg-bulgarian.wav | ||
| 29 | +cs-czech.wav | ||
| 30 | +da-danish.wav | ||
| 31 | +de-german.wav | ||
| 32 | +el-greek.wav | ||
| 33 | +en-english.wav | ||
| 34 | +es-spanish.wav | ||
| 35 | +fa-persian.wav | ||
| 36 | +fi-finnish.wav | ||
| 37 | +fr-french.wav | ||
| 38 | +hi-hindi.wav | ||
| 39 | +hr-croatian.wav | ||
| 40 | +id-indonesian.wav | ||
| 41 | +it-italian.wav | ||
| 42 | +ja-japanese.wav | ||
| 43 | +ko-korean.wav | ||
| 44 | +nl-dutch.wav | ||
| 45 | +no-norwegian.wav | ||
| 46 | +po-polish.wav | ||
| 47 | +pt-portuguese.wav | ||
| 48 | +ro-romanian.wav | ||
| 49 | +ru-russian.wav | ||
| 50 | +sk-slovak.wav | ||
| 51 | +sv-swedish.wav | ||
| 52 | +ta-tamil.wav | ||
| 53 | +tl-tagalog.wav | ||
| 54 | +tr-turkish.wav | ||
| 55 | +uk-ukrainian.wav | ||
| 56 | +zh-chinese.wav | ||
| 57 | +) | ||
| 58 | + | ||
| 59 | +for wav in "${waves[@]}"; do | ||
| 60 | + echo "Downloading $wav" | ||
| 61 | + curl -SL -O https://hf-mirror.com/spaces/k2-fsa/spoken-language-identification/resolve/main/test_wavs/$wav | ||
| 62 | + ls -lh *.wav | ||
| 63 | +done | ||
| 64 | + | ||
| 65 | +for name in "${names[@]}"; do | ||
| 66 | + log "------------------------------------------------------------" | ||
| 67 | + log "Run $name" | ||
| 68 | + log "------------------------------------------------------------" | ||
| 69 | + | ||
| 70 | + repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-whisper-$name | ||
| 71 | + log "Start testing ${repo_url}" | ||
| 72 | + repo=$(basename $repo_url) | ||
| 73 | + log "Download pretrained model and test-data from $repo_url" | ||
| 74 | + | ||
| 75 | + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url | ||
| 76 | + pushd $repo | ||
| 77 | + git lfs pull --include "*.onnx" | ||
| 78 | + # git lfs pull --include "*.ort" | ||
| 79 | + ls -lh *.onnx | ||
| 80 | + popd | ||
| 81 | + | ||
| 82 | + for wav in "${waves[@]}"; do | ||
| 83 | + log "test fp32 onnx" | ||
| 84 | + | ||
| 85 | + time $EXE \ | ||
| 86 | + --whisper-encoder=$repo/${name}-encoder.onnx \ | ||
| 87 | + --whisper-decoder=$repo/${name}-decoder.onnx \ | ||
| 88 | + $wav | ||
| 89 | + | ||
| 90 | + log "test int8 onnx" | ||
| 91 | + | ||
| 92 | + time $EXE \ | ||
| 93 | + --whisper-encoder=$repo/${name}-encoder.int8.onnx \ | ||
| 94 | + --whisper-decoder=$repo/${name}-decoder.int8.onnx \ | ||
| 95 | + $wav | ||
| 96 | + done | ||
| 97 | + rm -rf $repo | ||
| 98 | +done |
| @@ -82,7 +82,6 @@ jobs: | @@ -82,7 +82,6 @@ jobs: | ||
| 82 | env: | 82 | env: |
| 83 | HF_TOKEN: ${{ secrets.HF_TOKEN }} | 83 | HF_TOKEN: ${{ secrets.HF_TOKEN }} |
| 84 | uses: nick-fields/retry@v3 | 84 | uses: nick-fields/retry@v3 |
| 85 | - shell: bash | ||
| 86 | with: | 85 | with: |
| 87 | max_attempts: 20 | 86 | max_attempts: 20 |
| 88 | timeout_seconds: 200 | 87 | timeout_seconds: 200 |
| @@ -21,27 +21,12 @@ jobs: | @@ -21,27 +21,12 @@ jobs: | ||
| 21 | fail-fast: false | 21 | fail-fast: false |
| 22 | matrix: | 22 | matrix: |
| 23 | os: [macos-latest] | 23 | os: [macos-latest] |
| 24 | - python-version: ["cp37", "cp38", "cp39", "cp310", "cp311", "cp312"] | 24 | + python-version: ["cp38", "cp39", "cp310", "cp311", "cp312"] |
| 25 | 25 | ||
| 26 | steps: | 26 | steps: |
| 27 | - uses: actions/checkout@v4 | 27 | - uses: actions/checkout@v4 |
| 28 | 28 | ||
| 29 | - # see https://cibuildwheel.readthedocs.io/en/stable/changelog/ | ||
| 30 | - # for a list of versions | ||
| 31 | - name: Build wheels | 29 | - name: Build wheels |
| 32 | - if: matrix.python-version == 'cp37' | ||
| 33 | - uses: pypa/cibuildwheel@v2.11.4 | ||
| 34 | - env: | ||
| 35 | - CIBW_BUILD: "${{ matrix.python-version}}-* " | ||
| 36 | - CIBW_ENVIRONMENT: SHERPA_ONNX_CMAKE_ARGS="-DCMAKE_OSX_ARCHITECTURES='arm64'" | ||
| 37 | - CIBW_ARCHS: "arm64" | ||
| 38 | - CIBW_BUILD_VERBOSITY: 3 | ||
| 39 | - | ||
| 40 | - # Don't repair macOS wheels | ||
| 41 | - CIBW_REPAIR_WHEEL_COMMAND_MACOS: "" | ||
| 42 | - | ||
| 43 | - - name: Build wheels | ||
| 44 | - if: matrix.python-version != 'cp37' | ||
| 45 | uses: pypa/cibuildwheel@v2.15.0 | 30 | uses: pypa/cibuildwheel@v2.15.0 |
| 46 | env: | 31 | env: |
| 47 | CIBW_BUILD: "${{ matrix.python-version}}-* " | 32 | CIBW_BUILD: "${{ matrix.python-version}}-* " |
| @@ -92,6 +92,14 @@ jobs: | @@ -92,6 +92,14 @@ jobs: | ||
| 92 | file build/bin/sherpa-onnx | 92 | file build/bin/sherpa-onnx |
| 93 | readelf -d build/bin/sherpa-onnx | 93 | readelf -d build/bin/sherpa-onnx |
| 94 | 94 | ||
| 95 | + - name: Test spoken language identification | ||
| 96 | + shell: bash | ||
| 97 | + run: | | ||
| 98 | + export PATH=$PWD/build/bin:$PATH | ||
| 99 | + export EXE=sherpa-onnx-offline-language-identification | ||
| 100 | + | ||
| 101 | + .github/scripts/test-spoken-language-identification.sh | ||
| 102 | + | ||
| 95 | - name: Test online CTC | 103 | - name: Test online CTC |
| 96 | shell: bash | 104 | shell: bash |
| 97 | run: | | 105 | run: | |
| @@ -116,6 +124,7 @@ jobs: | @@ -116,6 +124,7 @@ jobs: | ||
| 116 | 124 | ||
| 117 | .github/scripts/test-online-paraformer.sh | 125 | .github/scripts/test-online-paraformer.sh |
| 118 | 126 | ||
| 127 | + | ||
| 119 | - name: Test offline Whisper | 128 | - name: Test offline Whisper |
| 120 | shell: bash | 129 | shell: bash |
| 121 | run: | | 130 | run: | |
| @@ -123,6 +123,15 @@ jobs: | @@ -123,6 +123,15 @@ jobs: | ||
| 123 | name: release-${{ matrix.build_type }}-${{ matrix.shared_lib }} | 123 | name: release-${{ matrix.build_type }}-${{ matrix.shared_lib }} |
| 124 | path: build/bin/* | 124 | path: build/bin/* |
| 125 | 125 | ||
| 126 | + - name: Test spoken language identification | ||
| 127 | + if: matrix.build_type != 'Debug' | ||
| 128 | + shell: bash | ||
| 129 | + run: | | ||
| 130 | + export PATH=$PWD/build/bin:$PATH | ||
| 131 | + export EXE=sherpa-onnx-offline-language-identification | ||
| 132 | + | ||
| 133 | + .github/scripts/test-spoken-language-identification.sh | ||
| 134 | + | ||
| 126 | - name: Test transducer kws | 135 | - name: Test transducer kws |
| 127 | shell: bash | 136 | shell: bash |
| 128 | run: | | 137 | run: | |
| @@ -140,6 +149,7 @@ jobs: | @@ -140,6 +149,7 @@ jobs: | ||
| 140 | .github/scripts/test-online-ctc.sh | 149 | .github/scripts/test-online-ctc.sh |
| 141 | 150 | ||
| 142 | - name: Test offline Whisper | 151 | - name: Test offline Whisper |
| 152 | + if: matrix.build_type != 'Debug' | ||
| 143 | shell: bash | 153 | shell: bash |
| 144 | run: | | 154 | run: | |
| 145 | export PATH=$PWD/build/bin:$PATH | 155 | export PATH=$PWD/build/bin:$PATH |
| @@ -102,6 +102,15 @@ jobs: | @@ -102,6 +102,15 @@ jobs: | ||
| 102 | otool -L build/bin/sherpa-onnx | 102 | otool -L build/bin/sherpa-onnx |
| 103 | otool -l build/bin/sherpa-onnx | 103 | otool -l build/bin/sherpa-onnx |
| 104 | 104 | ||
| 105 | + - name: Test spoken language identification | ||
| 106 | + if: matrix.build_type != 'Debug' | ||
| 107 | + shell: bash | ||
| 108 | + run: | | ||
| 109 | + export PATH=$PWD/build/bin:$PATH | ||
| 110 | + export EXE=sherpa-onnx-offline-language-identification | ||
| 111 | + | ||
| 112 | + .github/scripts/test-spoken-language-identification.sh | ||
| 113 | + | ||
| 105 | - name: Test transducer kws | 114 | - name: Test transducer kws |
| 106 | shell: bash | 115 | shell: bash |
| 107 | run: | | 116 | run: | |
| @@ -135,6 +144,7 @@ jobs: | @@ -135,6 +144,7 @@ jobs: | ||
| 135 | .github/scripts/test-online-paraformer.sh | 144 | .github/scripts/test-online-paraformer.sh |
| 136 | 145 | ||
| 137 | - name: Test offline Whisper | 146 | - name: Test offline Whisper |
| 147 | + if: matrix.build_type != 'Debug' | ||
| 138 | shell: bash | 148 | shell: bash |
| 139 | run: | | 149 | run: | |
| 140 | export PATH=$PWD/build/bin:$PATH | 150 | export PATH=$PWD/build/bin:$PATH |
| @@ -68,6 +68,14 @@ jobs: | @@ -68,6 +68,14 @@ jobs: | ||
| 68 | 68 | ||
| 69 | ls -lh ./bin/Release/sherpa-onnx.exe | 69 | ls -lh ./bin/Release/sherpa-onnx.exe |
| 70 | 70 | ||
| 71 | + - name: Test spoken language identification | ||
| 72 | + shell: bash | ||
| 73 | + run: | | ||
| 74 | + export PATH=$PWD/build/bin/Release:$PATH | ||
| 75 | + export EXE=sherpa-onnx-offline-language-identification.exe | ||
| 76 | + | ||
| 77 | + .github/scripts/test-spoken-language-identification.sh | ||
| 78 | + | ||
| 71 | - name: Test online CTC | 79 | - name: Test online CTC |
| 72 | shell: bash | 80 | shell: bash |
| 73 | run: | | 81 | run: | |
| @@ -68,6 +68,14 @@ jobs: | @@ -68,6 +68,14 @@ jobs: | ||
| 68 | 68 | ||
| 69 | ls -lh ./bin/Release/sherpa-onnx.exe | 69 | ls -lh ./bin/Release/sherpa-onnx.exe |
| 70 | 70 | ||
| 71 | + - name: Test spoken language identification | ||
| 72 | + shell: bash | ||
| 73 | + run: | | ||
| 74 | + export PATH=$PWD/build/bin/Release:$PATH | ||
| 75 | + export EXE=sherpa-onnx-offline-language-identification.exe | ||
| 76 | + | ||
| 77 | + .github/scripts/test-spoken-language-identification.sh | ||
| 78 | + | ||
| 71 | - name: Test online CTC | 79 | - name: Test online CTC |
| 72 | shell: bash | 80 | shell: bash |
| 73 | run: | | 81 | run: | |
| @@ -69,6 +69,14 @@ jobs: | @@ -69,6 +69,14 @@ jobs: | ||
| 69 | 69 | ||
| 70 | ls -lh ./bin/Release/sherpa-onnx.exe | 70 | ls -lh ./bin/Release/sherpa-onnx.exe |
| 71 | 71 | ||
| 72 | + # - name: Test spoken language identification | ||
| 73 | + # shell: bash | ||
| 74 | + # run: | | ||
| 75 | + # export PATH=$PWD/build/bin/Release:$PATH | ||
| 76 | + # export EXE=sherpa-onnx-offline-language-identification.exe | ||
| 77 | + # | ||
| 78 | + # .github/scripts/test-spoken-language-identification.sh | ||
| 79 | + | ||
| 72 | - name: Test online CTC | 80 | - name: Test online CTC |
| 73 | shell: bash | 81 | shell: bash |
| 74 | run: | | 82 | run: | |
| @@ -43,6 +43,50 @@ def enable_alsa(): | @@ -43,6 +43,50 @@ def enable_alsa(): | ||
| 43 | return build_alsa and is_linux() and (is_arm64() or is_x86()) | 43 | return build_alsa and is_linux() and (is_arm64() or is_x86()) |
| 44 | 44 | ||
| 45 | 45 | ||
| 46 | +def get_binaries(): | ||
| 47 | + binaries = [ | ||
| 48 | + "sherpa-onnx", | ||
| 49 | + "sherpa-onnx-keyword-spotter", | ||
| 50 | + "sherpa-onnx-microphone", | ||
| 51 | + "sherpa-onnx-microphone-offline", | ||
| 52 | + "sherpa-onnx-microphone-offline-speaker-identification", | ||
| 53 | + "sherpa-onnx-offline", | ||
| 54 | + "sherpa-onnx-offline-language-identification", | ||
| 55 | + "sherpa-onnx-offline-tts", | ||
| 56 | + "sherpa-onnx-offline-tts-play", | ||
| 57 | + "sherpa-onnx-offline-websocket-server", | ||
| 58 | + "sherpa-onnx-online-websocket-client", | ||
| 59 | + "sherpa-onnx-online-websocket-server", | ||
| 60 | + "sherpa-onnx-vad-microphone", | ||
| 61 | + "sherpa-onnx-vad-microphone-offline-asr", | ||
| 62 | + ] | ||
| 63 | + | ||
| 64 | + if enable_alsa(): | ||
| 65 | + binaries += [ | ||
| 66 | + "sherpa-onnx-alsa", | ||
| 67 | + "sherpa-onnx-alsa-offline", | ||
| 68 | + "sherpa-onnx-alsa-offline-speaker-identification", | ||
| 69 | + "sherpa-onnx-offline-tts-play-alsa", | ||
| 70 | + ] | ||
| 71 | + | ||
| 72 | + if is_windows(): | ||
| 73 | + binaries += [ | ||
| 74 | + "espeak-ng.dll", | ||
| 75 | + "kaldi-decoder-core.dll", | ||
| 76 | + "kaldi-native-fbank-core.dll", | ||
| 77 | + "onnxruntime.dll", | ||
| 78 | + "piper_phonemize.dll", | ||
| 79 | + "sherpa-onnx-c-api.dll", | ||
| 80 | + "sherpa-onnx-core.dll", | ||
| 81 | + "sherpa-onnx-fst.lib", | ||
| 82 | + "sherpa-onnx-kaldifst-core.lib", | ||
| 83 | + "sherpa-onnx-portaudio.dll", | ||
| 84 | + "ucd.dll", | ||
| 85 | + ] | ||
| 86 | + | ||
| 87 | + return binaries | ||
| 88 | + | ||
| 89 | + | ||
| 46 | try: | 90 | try: |
| 47 | from wheel.bdist_wheel import bdist_wheel as _bdist_wheel | 91 | from wheel.bdist_wheel import bdist_wheel as _bdist_wheel |
| 48 | 92 | ||
| @@ -150,38 +194,7 @@ class BuildExtension(build_ext): | @@ -150,38 +194,7 @@ class BuildExtension(build_ext): | ||
| 150 | suffix = ".exe" if is_windows() else "" | 194 | suffix = ".exe" if is_windows() else "" |
| 151 | # Remember to also change setup.py | 195 | # Remember to also change setup.py |
| 152 | 196 | ||
| 153 | - binaries = ["sherpa-onnx"] | ||
| 154 | - binaries += ["sherpa-onnx-keyword-spotter"] | ||
| 155 | - binaries += ["sherpa-onnx-offline"] | ||
| 156 | - binaries += ["sherpa-onnx-microphone"] | ||
| 157 | - binaries += ["sherpa-onnx-microphone-offline"] | ||
| 158 | - binaries += ["sherpa-onnx-microphone-offline-speaker-identification"] | ||
| 159 | - binaries += ["sherpa-onnx-online-websocket-server"] | ||
| 160 | - binaries += ["sherpa-onnx-offline-websocket-server"] | ||
| 161 | - binaries += ["sherpa-onnx-online-websocket-client"] | ||
| 162 | - binaries += ["sherpa-onnx-vad-microphone"] | ||
| 163 | - binaries += ["sherpa-onnx-vad-microphone-offline-asr"] | ||
| 164 | - binaries += ["sherpa-onnx-offline-tts"] | ||
| 165 | - binaries += ["sherpa-onnx-offline-tts-play"] | ||
| 166 | - | ||
| 167 | - if enable_alsa(): | ||
| 168 | - binaries += ["sherpa-onnx-alsa"] | ||
| 169 | - binaries += ["sherpa-onnx-alsa-offline"] | ||
| 170 | - binaries += ["sherpa-onnx-offline-tts-play-alsa"] | ||
| 171 | - binaries += ["sherpa-onnx-alsa-offline-speaker-identification"] | ||
| 172 | - | ||
| 173 | - if is_windows(): | ||
| 174 | - binaries += ["kaldi-native-fbank-core.dll"] | ||
| 175 | - binaries += ["sherpa-onnx-c-api.dll"] | ||
| 176 | - binaries += ["sherpa-onnx-core.dll"] | ||
| 177 | - binaries += ["sherpa-onnx-portaudio.dll"] | ||
| 178 | - binaries += ["onnxruntime.dll"] | ||
| 179 | - binaries += ["piper_phonemize.dll"] | ||
| 180 | - binaries += ["espeak-ng.dll"] | ||
| 181 | - binaries += ["ucd.dll"] | ||
| 182 | - binaries += ["kaldi-decoder-core.dll"] | ||
| 183 | - binaries += ["sherpa-onnx-fst.lib"] | ||
| 184 | - binaries += ["sherpa-onnx-kaldifst-core.lib"] | 197 | + binaries = get_binaries() |
| 185 | 198 | ||
| 186 | for f in binaries: | 199 | for f in binaries: |
| 187 | suffix = "" if (".dll" in f or ".lib" in f) else suffix | 200 | suffix = "" if (".dll" in f or ".lib" in f) else suffix |
| 1 | +#!/usr/bin/env python3 | ||
| 2 | + | ||
| 3 | +""" | ||
| 4 | +This script shows how to use Python APIs for spoken language identification. | ||
| 5 | +It detects the language spoken in the given wave file. | ||
| 6 | + | ||
| 7 | +Usage: | ||
| 8 | + | ||
| 9 | +1. Download a whisper multilingual model. We use a tiny model below. | ||
| 10 | +Please refer to https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models | ||
| 11 | +to download more models. | ||
| 12 | + | ||
| 13 | +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.tar.bz2 | ||
| 14 | +tar xvf sherpa-onnx-whisper-tiny.tar.bz2 | ||
| 15 | +rm sherpa-onnx-whisper-tiny.tar.bz2 | ||
| 16 | + | ||
| 17 | +We only use the int8.onnx models below. | ||
| 18 | + | ||
| 19 | +2. Download a test wave. | ||
| 20 | + | ||
| 21 | +You can find many wave files for different languages at | ||
| 22 | +https://hf-mirror.com/spaces/k2-fsa/spoken-language-identification/tree/main/test_wavs | ||
| 23 | + | ||
| 24 | +wget https://hf-mirror.com/spaces/k2-fsa/spoken-language-identification/resolve/main/test_wavs/de-german.wav | ||
| 25 | + | ||
| 26 | +python3 ./python-api-examples/spoken-language-identification.py \ | ||
| 27 | + --whisper-encoder=sherpa-onnx-whisper-tiny/tiny-encoder.int8.onnx \ | ||
| 28 | + --whisper-decoder=sherpa-onnx-whisper-tiny/tiny-decoder.int8.onnx \ | ||
| 29 | + --num-threads=1 \ | ||
| 30 | + ./de-german.wav | ||
| 31 | +""" | ||
| 32 | + | ||
| 33 | +import argparse | ||
| 34 | +import logging | ||
| 35 | +import time | ||
| 36 | +import wave | ||
| 37 | +from pathlib import Path | ||
| 38 | +from typing import Tuple | ||
| 39 | + | ||
| 40 | +import numpy as np | ||
| 41 | +import sherpa_onnx | ||
| 42 | + | ||
| 43 | + | ||
| 44 | +def get_args(): | ||
| 45 | + parser = argparse.ArgumentParser( | ||
| 46 | + formatter_class=argparse.ArgumentDefaultsHelpFormatter | ||
| 47 | + ) | ||
| 48 | + | ||
| 49 | + parser.add_argument( | ||
| 50 | + "--whisper-encoder", | ||
| 51 | + required=True, | ||
| 52 | + type=str, | ||
| 53 | + help="Path to a multilingual whisper encoder model", | ||
| 54 | + ) | ||
| 55 | + | ||
| 56 | + parser.add_argument( | ||
| 57 | + "--whisper-decoder", | ||
| 58 | + required=True, | ||
| 59 | + type=str, | ||
| 60 | + help="Path to a multilingual whisper decoder model", | ||
| 61 | + ) | ||
| 62 | + | ||
| 63 | + parser.add_argument( | ||
| 64 | + "--num-threads", | ||
| 65 | + type=int, | ||
| 66 | + default=1, | ||
| 67 | + help="Number of threads for neural network computation", | ||
| 68 | + ) | ||
| 69 | + | ||
| 70 | + parser.add_argument( | ||
| 71 | + "--debug", | ||
| 72 | + action="store_true", | ||
| 73 | + default=False, | ||
| 74 | + help="Show debug messages", | ||
| 75 | + ) | ||
| 76 | + | ||
| 77 | + parser.add_argument( | ||
| 78 | + "--provider", | ||
| 79 | + type=str, | ||
| 80 | + default="cpu", | ||
| 81 | + help="Valid values: cpu, cuda, coreml", | ||
| 82 | + ) | ||
| 83 | + | ||
| 84 | + parser.add_argument( | ||
| 85 | + "sound_file", | ||
| 86 | + type=str, | ||
| 87 | + help="The input sound file to identify. It must be of WAVE" | ||
| 88 | + "format with a single channel, and each sample has 16-bit, " | ||
| 89 | + "i.e., int16_t. " | ||
| 90 | + "The sample rate of the file can be arbitrary and does not need to " | ||
| 91 | + "be 16 kHz", | ||
| 92 | + ) | ||
| 93 | + | ||
| 94 | + return parser.parse_args() | ||
| 95 | + | ||
| 96 | + | ||
| 97 | +def assert_file_exists(filename: str): | ||
| 98 | + assert Path(filename).is_file(), ( | ||
| 99 | + f"{filename} does not exist!\n" | ||
| 100 | + "Please refer to " | ||
| 101 | + "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/index.html to download it" | ||
| 102 | + ) | ||
| 103 | + | ||
| 104 | + | ||
| 105 | +def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: | ||
| 106 | + """ | ||
| 107 | + Args: | ||
| 108 | + wave_filename: | ||
| 109 | + Path to a wave file. It should be single channel and each sample should | ||
| 110 | + be 16-bit. Its sample rate does not need to be 16kHz. | ||
| 111 | + Returns: | ||
| 112 | + Return a tuple containing: | ||
| 113 | + - A 1-D array of dtype np.float32 containing the samples, which are | ||
| 114 | + normalized to the range [-1, 1]. | ||
| 115 | + - sample rate of the wave file | ||
| 116 | + """ | ||
| 117 | + | ||
| 118 | + with wave.open(wave_filename) as f: | ||
| 119 | + assert f.getnchannels() == 1, f.getnchannels() | ||
| 120 | + assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes | ||
| 121 | + num_samples = f.getnframes() | ||
| 122 | + samples = f.readframes(num_samples) | ||
| 123 | + samples_int16 = np.frombuffer(samples, dtype=np.int16) | ||
| 124 | + samples_float32 = samples_int16.astype(np.float32) | ||
| 125 | + | ||
| 126 | + samples_float32 = samples_float32 / 32768 | ||
| 127 | + return samples_float32, f.getframerate() | ||
| 128 | + | ||
| 129 | + | ||
| 130 | +def main(): | ||
| 131 | + args = get_args() | ||
| 132 | + assert_file_exists(args.whisper_encoder) | ||
| 133 | + assert_file_exists(args.whisper_decoder) | ||
| 134 | + assert args.num_threads > 0, args.num_threads | ||
| 135 | + config = sherpa_onnx.SpokenLanguageIdentificationConfig( | ||
| 136 | + whisper=sherpa_onnx.SpokenLanguageIdentificationWhisperConfig( | ||
| 137 | + encoder=args.whisper_encoder, | ||
| 138 | + decoder=args.whisper_decoder, | ||
| 139 | + ), | ||
| 140 | + num_threads=args.num_threads, | ||
| 141 | + debug=args.debug, | ||
| 142 | + provider=args.provider, | ||
| 143 | + ) | ||
| 144 | + slid = sherpa_onnx.SpokenLanguageIdentification(config) | ||
| 145 | + | ||
| 146 | + samples, sample_rate = read_wave(args.sound_file) | ||
| 147 | + | ||
| 148 | + start_time = time.time() | ||
| 149 | + stream = slid.create_stream() | ||
| 150 | + stream.accept_waveform(sample_rate=sample_rate, waveform=samples) | ||
| 151 | + lang = slid.compute(stream) | ||
| 152 | + end_time = time.time() | ||
| 153 | + | ||
| 154 | + elapsed_seconds = end_time - start_time | ||
| 155 | + audio_duration = len(samples) / sample_rate | ||
| 156 | + real_time_factor = elapsed_seconds / audio_duration | ||
| 157 | + | ||
| 158 | + logging.info(f"File: {args.sound_file}") | ||
| 159 | + logging.info(f"Detected language: {lang}") | ||
| 160 | + logging.info(f"Elapsed seconds: {elapsed_seconds:.3f}") | ||
| 161 | + logging.info(f"Audio duration in seconds: {audio_duration:.3f}") | ||
| 162 | + logging.info( | ||
| 163 | + f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}" | ||
| 164 | + ) | ||
| 165 | + | ||
| 166 | + | ||
| 167 | +if __name__ == "__main__": | ||
| 168 | + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" | ||
| 169 | + | ||
| 170 | + logging.basicConfig(format=formatter, level=logging.INFO) | ||
| 171 | + | ||
| 172 | + main() |
| 1 | #!/usr/bin/env python3 | 1 | #!/usr/bin/env python3 |
| 2 | 2 | ||
| 3 | -import os | ||
| 4 | import re | 3 | import re |
| 5 | -import sys | ||
| 6 | from pathlib import Path | 4 | from pathlib import Path |
| 7 | 5 | ||
| 8 | import setuptools | 6 | import setuptools |
| @@ -11,7 +9,7 @@ from cmake.cmake_extension import ( | @@ -11,7 +9,7 @@ from cmake.cmake_extension import ( | ||
| 11 | BuildExtension, | 9 | BuildExtension, |
| 12 | bdist_wheel, | 10 | bdist_wheel, |
| 13 | cmake_extension, | 11 | cmake_extension, |
| 14 | - enable_alsa, | 12 | + get_binaries, |
| 15 | is_windows, | 13 | is_windows, |
| 16 | ) | 14 | ) |
| 17 | 15 | ||
| @@ -42,39 +40,7 @@ def get_binaries_to_install(): | @@ -42,39 +40,7 @@ def get_binaries_to_install(): | ||
| 42 | bin_dir.mkdir(parents=True, exist_ok=True) | 40 | bin_dir.mkdir(parents=True, exist_ok=True) |
| 43 | suffix = ".exe" if is_windows() else "" | 41 | suffix = ".exe" if is_windows() else "" |
| 44 | 42 | ||
| 45 | - # Remember to also change cmake/cmake_extension.py | ||
| 46 | - binaries = ["sherpa-onnx"] | ||
| 47 | - binaries += ["sherpa-onnx-keyword-spotter"] | ||
| 48 | - binaries += ["sherpa-onnx-offline"] | ||
| 49 | - binaries += ["sherpa-onnx-microphone"] | ||
| 50 | - binaries += ["sherpa-onnx-microphone-offline"] | ||
| 51 | - binaries += ["sherpa-onnx-microphone-offline-speaker-identification"] | ||
| 52 | - binaries += ["sherpa-onnx-online-websocket-server"] | ||
| 53 | - binaries += ["sherpa-onnx-offline-websocket-server"] | ||
| 54 | - binaries += ["sherpa-onnx-online-websocket-client"] | ||
| 55 | - binaries += ["sherpa-onnx-vad-microphone"] | ||
| 56 | - binaries += ["sherpa-onnx-vad-microphone-offline-asr"] | ||
| 57 | - binaries += ["sherpa-onnx-offline-tts"] | ||
| 58 | - binaries += ["sherpa-onnx-offline-tts-play"] | ||
| 59 | - | ||
| 60 | - if enable_alsa(): | ||
| 61 | - binaries += ["sherpa-onnx-alsa"] | ||
| 62 | - binaries += ["sherpa-onnx-alsa-offline"] | ||
| 63 | - binaries += ["sherpa-onnx-offline-tts-play-alsa"] | ||
| 64 | - binaries += ["sherpa-onnx-alsa-offline-speaker-identification"] | ||
| 65 | - | ||
| 66 | - if is_windows(): | ||
| 67 | - binaries += ["kaldi-native-fbank-core.dll"] | ||
| 68 | - binaries += ["sherpa-onnx-c-api.dll"] | ||
| 69 | - binaries += ["sherpa-onnx-core.dll"] | ||
| 70 | - binaries += ["sherpa-onnx-portaudio.dll"] | ||
| 71 | - binaries += ["onnxruntime.dll"] | ||
| 72 | - binaries += ["piper_phonemize.dll"] | ||
| 73 | - binaries += ["espeak-ng.dll"] | ||
| 74 | - binaries += ["ucd.dll"] | ||
| 75 | - binaries += ["kaldi-decoder-core.dll"] | ||
| 76 | - binaries += ["sherpa-onnx-fst.lib"] | ||
| 77 | - binaries += ["sherpa-onnx-kaldifst-core.lib"] | 43 | + binaries = get_binaries() |
| 78 | 44 | ||
| 79 | exe = [] | 45 | exe = [] |
| 80 | for f in binaries: | 46 | for f in binaries: |
| @@ -86,6 +86,8 @@ set(sources | @@ -86,6 +86,8 @@ set(sources | ||
| 86 | silero-vad-model-config.cc | 86 | silero-vad-model-config.cc |
| 87 | silero-vad-model.cc | 87 | silero-vad-model.cc |
| 88 | slice.cc | 88 | slice.cc |
| 89 | + spoken-language-identification-impl.cc | ||
| 90 | + spoken-language-identification.cc | ||
| 89 | stack.cc | 91 | stack.cc |
| 90 | symbol-table.cc | 92 | symbol-table.cc |
| 91 | text-utils.cc | 93 | text-utils.cc |
| @@ -184,6 +186,7 @@ if(SHERPA_ONNX_ENABLE_BINARY) | @@ -184,6 +186,7 @@ if(SHERPA_ONNX_ENABLE_BINARY) | ||
| 184 | add_executable(sherpa-onnx-offline sherpa-onnx-offline.cc) | 186 | add_executable(sherpa-onnx-offline sherpa-onnx-offline.cc) |
| 185 | add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc) | 187 | add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc) |
| 186 | add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc) | 188 | add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc) |
| 189 | + add_executable(sherpa-onnx-offline-language-identification sherpa-onnx-offline-language-identification.cc) | ||
| 187 | 190 | ||
| 188 | set(main_exes | 191 | set(main_exes |
| 189 | sherpa-onnx | 192 | sherpa-onnx |
| @@ -191,6 +194,7 @@ if(SHERPA_ONNX_ENABLE_BINARY) | @@ -191,6 +194,7 @@ if(SHERPA_ONNX_ENABLE_BINARY) | ||
| 191 | sherpa-onnx-offline | 194 | sherpa-onnx-offline |
| 192 | sherpa-onnx-offline-parallel | 195 | sherpa-onnx-offline-parallel |
| 193 | sherpa-onnx-offline-tts | 196 | sherpa-onnx-offline-tts |
| 197 | + sherpa-onnx-offline-language-identification | ||
| 194 | ) | 198 | ) |
| 195 | 199 | ||
| 196 | foreach(exe IN LISTS main_exes) | 200 | foreach(exe IN LISTS main_exes) |
| @@ -23,7 +23,7 @@ enum class ModelType { | @@ -23,7 +23,7 @@ enum class ModelType { | ||
| 23 | kTdnn, | 23 | kTdnn, |
| 24 | kZipformerCtc, | 24 | kZipformerCtc, |
| 25 | kWenetCtc, | 25 | kWenetCtc, |
| 26 | - kUnkown, | 26 | + kUnknown, |
| 27 | }; | 27 | }; |
| 28 | 28 | ||
| 29 | } // namespace | 29 | } // namespace |
| @@ -59,7 +59,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, | @@ -59,7 +59,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, | ||
| 59 | "run.sh\n" | 59 | "run.sh\n" |
| 60 | "\n" | 60 | "\n" |
| 61 | "for how to add metadta to model.onnx\n"); | 61 | "for how to add metadta to model.onnx\n"); |
| 62 | - return ModelType::kUnkown; | 62 | + return ModelType::kUnknown; |
| 63 | } | 63 | } |
| 64 | 64 | ||
| 65 | if (model_type.get() == std::string("EncDecCTCModelBPE")) { | 65 | if (model_type.get() == std::string("EncDecCTCModelBPE")) { |
| @@ -72,13 +72,13 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, | @@ -72,13 +72,13 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, | ||
| 72 | return ModelType::kWenetCtc; | 72 | return ModelType::kWenetCtc; |
| 73 | } else { | 73 | } else { |
| 74 | SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); | 74 | SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); |
| 75 | - return ModelType::kUnkown; | 75 | + return ModelType::kUnknown; |
| 76 | } | 76 | } |
| 77 | } | 77 | } |
| 78 | 78 | ||
| 79 | std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( | 79 | std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( |
| 80 | const OfflineModelConfig &config) { | 80 | const OfflineModelConfig &config) { |
| 81 | - ModelType model_type = ModelType::kUnkown; | 81 | + ModelType model_type = ModelType::kUnknown; |
| 82 | 82 | ||
| 83 | std::string filename; | 83 | std::string filename; |
| 84 | if (!config.nemo_ctc.model.empty()) { | 84 | if (!config.nemo_ctc.model.empty()) { |
| @@ -113,7 +113,7 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( | @@ -113,7 +113,7 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( | ||
| 113 | case ModelType::kWenetCtc: | 113 | case ModelType::kWenetCtc: |
| 114 | return std::make_unique<OfflineWenetCtcModel>(config); | 114 | return std::make_unique<OfflineWenetCtcModel>(config); |
| 115 | break; | 115 | break; |
| 116 | - case ModelType::kUnkown: | 116 | + case ModelType::kUnknown: |
| 117 | SHERPA_ONNX_LOGE("Unknown model type in offline CTC!"); | 117 | SHERPA_ONNX_LOGE("Unknown model type in offline CTC!"); |
| 118 | return nullptr; | 118 | return nullptr; |
| 119 | } | 119 | } |
| @@ -125,7 +125,7 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( | @@ -125,7 +125,7 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( | ||
| 125 | 125 | ||
| 126 | std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( | 126 | std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( |
| 127 | AAssetManager *mgr, const OfflineModelConfig &config) { | 127 | AAssetManager *mgr, const OfflineModelConfig &config) { |
| 128 | - ModelType model_type = ModelType::kUnkown; | 128 | + ModelType model_type = ModelType::kUnknown; |
| 129 | 129 | ||
| 130 | std::string filename; | 130 | std::string filename; |
| 131 | if (!config.nemo_ctc.model.empty()) { | 131 | if (!config.nemo_ctc.model.empty()) { |
| @@ -160,7 +160,7 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( | @@ -160,7 +160,7 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( | ||
| 160 | case ModelType::kWenetCtc: | 160 | case ModelType::kWenetCtc: |
| 161 | return std::make_unique<OfflineWenetCtcModel>(mgr, config); | 161 | return std::make_unique<OfflineWenetCtcModel>(mgr, config); |
| 162 | break; | 162 | break; |
| 163 | - case ModelType::kUnkown: | 163 | + case ModelType::kUnknown: |
| 164 | SHERPA_ONNX_LOGE("Unknown model type in offline CTC!"); | 164 | SHERPA_ONNX_LOGE("Unknown model type in offline CTC!"); |
| 165 | return nullptr; | 165 | return nullptr; |
| 166 | } | 166 | } |
| @@ -114,7 +114,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { | @@ -114,7 +114,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { | ||
| 114 | num_frames = max_num_frames - 50; | 114 | num_frames = max_num_frames - 50; |
| 115 | } | 115 | } |
| 116 | 116 | ||
| 117 | - NormalizeFeatures(f.data(), num_frames, feat_dim); | 117 | + model_->NormalizeFeatures(f.data(), num_frames, feat_dim); |
| 118 | 118 | ||
| 119 | // note that 1000 is an experience-value. | 119 | // note that 1000 is an experience-value. |
| 120 | // You can replace 1000 by other values, say, 100. | 120 | // You can replace 1000 by other values, say, 100. |
| @@ -163,38 +163,6 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { | @@ -163,38 +163,6 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { | ||
| 163 | } | 163 | } |
| 164 | 164 | ||
| 165 | private: | 165 | private: |
| 166 | - static void NormalizeFeatures(float *features, int32_t num_frames, | ||
| 167 | - int32_t feat_dim) { | ||
| 168 | - // log_spec = torch.clamp(features, min=1e-10).log10() | ||
| 169 | - // log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) | ||
| 170 | - // mel = (log_spec + 4.0) / 4.0 | ||
| 171 | - | ||
| 172 | - int32_t n = num_frames * feat_dim; | ||
| 173 | - float max_v = -1e20; | ||
| 174 | - for (int32_t i = 0; i != n; ++i) { | ||
| 175 | - float f = features[i]; | ||
| 176 | - | ||
| 177 | - f = std::max<float>(f, 1e-10); | ||
| 178 | - f = std::log10(f); | ||
| 179 | - | ||
| 180 | - max_v = std::max(f, max_v); | ||
| 181 | - | ||
| 182 | - features[i] = f; | ||
| 183 | - } | ||
| 184 | - | ||
| 185 | - max_v -= 8; | ||
| 186 | - | ||
| 187 | - for (int32_t i = 0; i != n; ++i) { | ||
| 188 | - float f = features[i]; | ||
| 189 | - f = std::max(f, max_v); | ||
| 190 | - | ||
| 191 | - f = (f + 4) / 4; | ||
| 192 | - | ||
| 193 | - features[i] = f; | ||
| 194 | - } | ||
| 195 | - } | ||
| 196 | - | ||
| 197 | - private: | ||
| 198 | OfflineRecognizerConfig config_; | 166 | OfflineRecognizerConfig config_; |
| 199 | SymbolTable symbol_table_; | 167 | SymbolTable symbol_table_; |
| 200 | std::unique_ptr<OfflineWhisperModel> model_; | 168 | std::unique_ptr<OfflineWhisperModel> model_; |
| @@ -12,56 +12,6 @@ | @@ -12,56 +12,6 @@ | ||
| 12 | 12 | ||
| 13 | namespace sherpa_onnx { | 13 | namespace sherpa_onnx { |
| 14 | 14 | ||
| 15 | -int32_t OfflineWhisperGreedySearchDecoder::DetectLanguage( | ||
| 16 | - Ort::Value &cross_k, Ort::Value &cross_v) const { // NOLINT | ||
| 17 | - int64_t token_val = model_->SOT(); | ||
| 18 | - std::array<int64_t, 2> token_shape{1, 1}; | ||
| 19 | - | ||
| 20 | - auto memory_info = | ||
| 21 | - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 22 | - | ||
| 23 | - Ort::Value tokens = Ort::Value::CreateTensor( | ||
| 24 | - memory_info, &token_val, 1, token_shape.data(), token_shape.size()); | ||
| 25 | - | ||
| 26 | - auto self_kv_cache = model_->GetInitialSelfKVCache(); | ||
| 27 | - | ||
| 28 | - std::array<int64_t, 1> offset_shape{1}; | ||
| 29 | - Ort::Value offset = Ort::Value::CreateTensor<int64_t>( | ||
| 30 | - model_->Allocator(), offset_shape.data(), offset_shape.size()); | ||
| 31 | - *(offset.GetTensorMutableData<int64_t>()) = 0; | ||
| 32 | - | ||
| 33 | - auto decoder_out = model_->ForwardDecoder( | ||
| 34 | - std::move(tokens), std::move(self_kv_cache.first), | ||
| 35 | - std::move(self_kv_cache.second), std::move(cross_k), std::move(cross_v), | ||
| 36 | - std::move(offset)); | ||
| 37 | - | ||
| 38 | - cross_k = std::move(std::get<3>(decoder_out)); | ||
| 39 | - cross_v = std::move(std::get<4>(decoder_out)); | ||
| 40 | - | ||
| 41 | - const float *p_logits = std::get<0>(decoder_out).GetTensorData<float>(); | ||
| 42 | - int32_t vocab_size = model_->VocabSize(); | ||
| 43 | - const auto &all_language_ids = model_->GetAllLanguageIDs(); | ||
| 44 | - | ||
| 45 | - int32_t lang_id = all_language_ids[0]; | ||
| 46 | - float this_logit = p_logits[lang_id]; | ||
| 47 | - | ||
| 48 | - for (int32_t i = 1; i != all_language_ids.size(); ++i) { | ||
| 49 | - int32_t id = all_language_ids[i]; | ||
| 50 | - float p = p_logits[id]; | ||
| 51 | - | ||
| 52 | - if (p > this_logit) { | ||
| 53 | - this_logit = p; | ||
| 54 | - lang_id = id; | ||
| 55 | - } | ||
| 56 | - } | ||
| 57 | -#if 1 | ||
| 58 | - SHERPA_ONNX_LOGE("Detected language: %s", | ||
| 59 | - model_->GetID2Lang().at(lang_id).c_str()); | ||
| 60 | -#endif | ||
| 61 | - | ||
| 62 | - return lang_id; | ||
| 63 | -} | ||
| 64 | - | ||
| 65 | std::vector<OfflineWhisperDecoderResult> | 15 | std::vector<OfflineWhisperDecoderResult> |
| 66 | OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k, | 16 | OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k, |
| 67 | Ort::Value cross_v) { | 17 | Ort::Value cross_v) { |
| @@ -89,7 +39,7 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k, | @@ -89,7 +39,7 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k, | ||
| 89 | // 0: sot, 1: lang_id, 2: task, 3: no_timestamps | 39 | // 0: sot, 1: lang_id, 2: task, 3: no_timestamps |
| 90 | initial_tokens[1] = lang_id; | 40 | initial_tokens[1] = lang_id; |
| 91 | } else { | 41 | } else { |
| 92 | - int32_t lang_id = DetectLanguage(cross_k, cross_v); | 42 | + int32_t lang_id = model_->DetectLanguage(cross_k, cross_v); |
| 93 | 43 | ||
| 94 | // 0: sot, 1: lang_id, 2: task, 3: no_timestamps | 44 | // 0: sot, 1: lang_id, 2: task, 3: no_timestamps |
| 95 | initial_tokens[1] = lang_id; | 45 | initial_tokens[1] = lang_id; |
| @@ -22,9 +22,6 @@ class OfflineWhisperGreedySearchDecoder : public OfflineWhisperDecoder { | @@ -22,9 +22,6 @@ class OfflineWhisperGreedySearchDecoder : public OfflineWhisperDecoder { | ||
| 22 | std::vector<OfflineWhisperDecoderResult> Decode(Ort::Value cross_k, | 22 | std::vector<OfflineWhisperDecoderResult> Decode(Ort::Value cross_k, |
| 23 | Ort::Value cross_v) override; | 23 | Ort::Value cross_v) override; |
| 24 | 24 | ||
| 25 | - int32_t DetectLanguage(Ort::Value &cross_k, // NOLINT | ||
| 26 | - Ort::Value &cross_v) const; // NOLINT | ||
| 27 | - | ||
| 28 | private: | 25 | private: |
| 29 | OfflineWhisperModelConfig config_; | 26 | OfflineWhisperModelConfig config_; |
| 30 | OfflineWhisperModel *model_; // not owned | 27 | OfflineWhisperModel *model_; // not owned |
| @@ -35,19 +35,28 @@ void OfflineWhisperModelConfig::Register(ParseOptions *po) { | @@ -35,19 +35,28 @@ void OfflineWhisperModelConfig::Register(ParseOptions *po) { | ||
| 35 | 35 | ||
| 36 | po->Register( | 36 | po->Register( |
| 37 | "whisper-tail-paddings", &tail_paddings, | 37 | "whisper-tail-paddings", &tail_paddings, |
| 38 | - "Suggest value: 50 for English models. 300 for multilingual models. " | 38 | + "Suggested value: 50 for English models. 300 for multilingual models. " |
| 39 | "Since we have removed the 30-second constraint, we need to add some " | 39 | "Since we have removed the 30-second constraint, we need to add some " |
| 40 | "tail padding frames " | 40 | "tail padding frames " |
| 41 | - "so that whisper can detect the eot token. Leave it to -1 to use 50 for " | ||
| 42 | - "English models and 300 for multilingual models."); | 41 | + "so that whisper can detect the eot token. Leave it to -1 to use 1000."); |
| 43 | } | 42 | } |
| 44 | 43 | ||
| 45 | bool OfflineWhisperModelConfig::Validate() const { | 44 | bool OfflineWhisperModelConfig::Validate() const { |
| 45 | + if (encoder.empty()) { | ||
| 46 | + SHERPA_ONNX_LOGE("Please provide --whisper-encoder"); | ||
| 47 | + return false; | ||
| 48 | + } | ||
| 49 | + | ||
| 46 | if (!FileExists(encoder)) { | 50 | if (!FileExists(encoder)) { |
| 47 | SHERPA_ONNX_LOGE("whisper encoder file %s does not exist", encoder.c_str()); | 51 | SHERPA_ONNX_LOGE("whisper encoder file %s does not exist", encoder.c_str()); |
| 48 | return false; | 52 | return false; |
| 49 | } | 53 | } |
| 50 | 54 | ||
| 55 | + if (decoder.empty()) { | ||
| 56 | + SHERPA_ONNX_LOGE("Please provide --whisper-decoder"); | ||
| 57 | + return false; | ||
| 58 | + } | ||
| 59 | + | ||
| 51 | if (!FileExists(decoder)) { | 60 | if (!FileExists(decoder)) { |
| 52 | SHERPA_ONNX_LOGE("whisper decoder file %s does not exist", decoder.c_str()); | 61 | SHERPA_ONNX_LOGE("whisper decoder file %s does not exist", decoder.c_str()); |
| 53 | return false; | 62 | return false; |
| @@ -24,6 +24,24 @@ class OfflineWhisperModel::Impl { | @@ -24,6 +24,24 @@ class OfflineWhisperModel::Impl { | ||
| 24 | env_(ORT_LOGGING_LEVEL_ERROR), | 24 | env_(ORT_LOGGING_LEVEL_ERROR), |
| 25 | sess_opts_(GetSessionOptions(config)), | 25 | sess_opts_(GetSessionOptions(config)), |
| 26 | allocator_{} { | 26 | allocator_{} { |
| 27 | + debug_ = config_.debug; | ||
| 28 | + { | ||
| 29 | + auto buf = ReadFile(config.whisper.encoder); | ||
| 30 | + InitEncoder(buf.data(), buf.size()); | ||
| 31 | + } | ||
| 32 | + | ||
| 33 | + { | ||
| 34 | + auto buf = ReadFile(config.whisper.decoder); | ||
| 35 | + InitDecoder(buf.data(), buf.size()); | ||
| 36 | + } | ||
| 37 | + } | ||
| 38 | + | ||
| 39 | + explicit Impl(const SpokenLanguageIdentificationConfig &config) | ||
| 40 | + : lid_config_(config), | ||
| 41 | + env_(ORT_LOGGING_LEVEL_ERROR), | ||
| 42 | + sess_opts_(GetSessionOptions(config)), | ||
| 43 | + allocator_{} { | ||
| 44 | + debug_ = lid_config_.debug; | ||
| 27 | { | 45 | { |
| 28 | auto buf = ReadFile(config.whisper.encoder); | 46 | auto buf = ReadFile(config.whisper.encoder); |
| 29 | InitEncoder(buf.data(), buf.size()); | 47 | InitEncoder(buf.data(), buf.size()); |
| @@ -41,6 +59,7 @@ class OfflineWhisperModel::Impl { | @@ -41,6 +59,7 @@ class OfflineWhisperModel::Impl { | ||
| 41 | env_(ORT_LOGGING_LEVEL_ERROR), | 59 | env_(ORT_LOGGING_LEVEL_ERROR), |
| 42 | sess_opts_(GetSessionOptions(config)), | 60 | sess_opts_(GetSessionOptions(config)), |
| 43 | allocator_{} { | 61 | allocator_{} { |
| 62 | + debug_ = config_.debug; | ||
| 44 | { | 63 | { |
| 45 | auto buf = ReadFile(mgr, config.whisper.encoder); | 64 | auto buf = ReadFile(mgr, config.whisper.encoder); |
| 46 | InitEncoder(buf.data(), buf.size()); | 65 | InitEncoder(buf.data(), buf.size()); |
| @@ -85,6 +104,57 @@ class OfflineWhisperModel::Impl { | @@ -85,6 +104,57 @@ class OfflineWhisperModel::Impl { | ||
| 85 | std::move(decoder_input[4]), std::move(decoder_input[5])}; | 104 | std::move(decoder_input[4]), std::move(decoder_input[5])}; |
| 86 | } | 105 | } |
| 87 | 106 | ||
| 107 | + int32_t DetectLanguage(Ort::Value &cross_k, // NOLINT | ||
| 108 | + Ort::Value &cross_v) { // NOLINT | ||
| 109 | + int64_t token_val = SOT(); | ||
| 110 | + std::array<int64_t, 2> token_shape{1, 1}; | ||
| 111 | + | ||
| 112 | + auto memory_info = | ||
| 113 | + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 114 | + | ||
| 115 | + Ort::Value tokens = Ort::Value::CreateTensor( | ||
| 116 | + memory_info, &token_val, 1, token_shape.data(), token_shape.size()); | ||
| 117 | + | ||
| 118 | + auto self_kv_cache = GetInitialSelfKVCache(); | ||
| 119 | + | ||
| 120 | + std::array<int64_t, 1> offset_shape{1}; | ||
| 121 | + Ort::Value offset = Ort::Value::CreateTensor<int64_t>( | ||
| 122 | + Allocator(), offset_shape.data(), offset_shape.size()); | ||
| 123 | + *(offset.GetTensorMutableData<int64_t>()) = 0; | ||
| 124 | + | ||
| 125 | + auto decoder_out = | ||
| 126 | + ForwardDecoder(std::move(tokens), std::move(self_kv_cache.first), | ||
| 127 | + std::move(self_kv_cache.second), std::move(cross_k), | ||
| 128 | + std::move(cross_v), std::move(offset)); | ||
| 129 | + | ||
| 130 | + cross_k = std::move(std::get<3>(decoder_out)); | ||
| 131 | + cross_v = std::move(std::get<4>(decoder_out)); | ||
| 132 | + | ||
| 133 | + const float *p_logits = std::get<0>(decoder_out).GetTensorData<float>(); | ||
| 134 | + int32_t vocab_size = VocabSize(); | ||
| 135 | + const auto &all_language_ids = GetAllLanguageIDs(); | ||
| 136 | + | ||
| 137 | + int32_t lang_id = all_language_ids[0]; | ||
| 138 | + float this_logit = p_logits[lang_id]; | ||
| 139 | + | ||
| 140 | + for (int32_t i = 1; i != all_language_ids.size(); ++i) { | ||
| 141 | + int32_t id = all_language_ids[i]; | ||
| 142 | + float p = p_logits[id]; | ||
| 143 | + | ||
| 144 | + if (p > this_logit) { | ||
| 145 | + this_logit = p; | ||
| 146 | + lang_id = id; | ||
| 147 | + } | ||
| 148 | + } | ||
| 149 | + | ||
| 150 | + if (debug_) { | ||
| 151 | + SHERPA_ONNX_LOGE("Detected language: %s", | ||
| 152 | + GetID2Lang().at(lang_id).c_str()); | ||
| 153 | + } | ||
| 154 | + | ||
| 155 | + return lang_id; | ||
| 156 | + } | ||
| 157 | + | ||
| 88 | std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache() { | 158 | std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache() { |
| 89 | std::array<int64_t, 4> shape{n_text_layer_, 1, n_text_ctx_, n_text_state_}; | 159 | std::array<int64_t, 4> shape{n_text_layer_, 1, n_text_ctx_, n_text_state_}; |
| 90 | 160 | ||
| @@ -148,7 +218,7 @@ class OfflineWhisperModel::Impl { | @@ -148,7 +218,7 @@ class OfflineWhisperModel::Impl { | ||
| 148 | 218 | ||
| 149 | // get meta data | 219 | // get meta data |
| 150 | Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata(); | 220 | Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata(); |
| 151 | - if (config_.debug) { | 221 | + if (debug_) { |
| 152 | std::ostringstream os; | 222 | std::ostringstream os; |
| 153 | os << "---encoder---\n"; | 223 | os << "---encoder---\n"; |
| 154 | PrintModelMetadata(os, meta_data); | 224 | PrintModelMetadata(os, meta_data); |
| @@ -203,6 +273,8 @@ class OfflineWhisperModel::Impl { | @@ -203,6 +273,8 @@ class OfflineWhisperModel::Impl { | ||
| 203 | 273 | ||
| 204 | private: | 274 | private: |
| 205 | OfflineModelConfig config_; | 275 | OfflineModelConfig config_; |
| 276 | + SpokenLanguageIdentificationConfig lid_config_; | ||
| 277 | + bool debug_ = false; | ||
| 206 | Ort::Env env_; | 278 | Ort::Env env_; |
| 207 | Ort::SessionOptions sess_opts_; | 279 | Ort::SessionOptions sess_opts_; |
| 208 | Ort::AllocatorWithDefaultOptions allocator_; | 280 | Ort::AllocatorWithDefaultOptions allocator_; |
| @@ -246,6 +318,10 @@ class OfflineWhisperModel::Impl { | @@ -246,6 +318,10 @@ class OfflineWhisperModel::Impl { | ||
| 246 | OfflineWhisperModel::OfflineWhisperModel(const OfflineModelConfig &config) | 318 | OfflineWhisperModel::OfflineWhisperModel(const OfflineModelConfig &config) |
| 247 | : impl_(std::make_unique<Impl>(config)) {} | 319 | : impl_(std::make_unique<Impl>(config)) {} |
| 248 | 320 | ||
| 321 | +OfflineWhisperModel::OfflineWhisperModel( | ||
| 322 | + const SpokenLanguageIdentificationConfig &config) | ||
| 323 | + : impl_(std::make_unique<Impl>(config)) {} | ||
| 324 | + | ||
| 249 | #if __ANDROID_API__ >= 9 | 325 | #if __ANDROID_API__ >= 9 |
| 250 | OfflineWhisperModel::OfflineWhisperModel(AAssetManager *mgr, | 326 | OfflineWhisperModel::OfflineWhisperModel(AAssetManager *mgr, |
| 251 | const OfflineModelConfig &config) | 327 | const OfflineModelConfig &config) |
| @@ -273,6 +349,11 @@ OfflineWhisperModel::ForwardDecoder(Ort::Value tokens, | @@ -273,6 +349,11 @@ OfflineWhisperModel::ForwardDecoder(Ort::Value tokens, | ||
| 273 | std::move(n_layer_cross_v), std::move(offset)); | 349 | std::move(n_layer_cross_v), std::move(offset)); |
| 274 | } | 350 | } |
| 275 | 351 | ||
| 352 | +int32_t OfflineWhisperModel::DetectLanguage(Ort::Value &cross_k, // NOLINT | ||
| 353 | + Ort::Value &cross_v) { // NOLINT | ||
| 354 | + return impl_->DetectLanguage(cross_k, cross_v); | ||
| 355 | +} | ||
| 356 | + | ||
| 276 | std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::GetInitialSelfKVCache() | 357 | std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::GetInitialSelfKVCache() |
| 277 | const { | 358 | const { |
| 278 | return impl_->GetInitialSelfKVCache(); | 359 | return impl_->GetInitialSelfKVCache(); |
| @@ -318,4 +399,35 @@ bool OfflineWhisperModel::IsMultiLingual() const { | @@ -318,4 +399,35 @@ bool OfflineWhisperModel::IsMultiLingual() const { | ||
| 318 | return impl_->IsMultiLingual(); | 399 | return impl_->IsMultiLingual(); |
| 319 | } | 400 | } |
| 320 | 401 | ||
| 402 | +void OfflineWhisperModel::NormalizeFeatures(float *features, int32_t num_frames, | ||
| 403 | + int32_t feat_dim) { | ||
| 404 | + // log_spec = torch.clamp(features, min=1e-10).log10() | ||
| 405 | + // log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) | ||
| 406 | + // mel = (log_spec + 4.0) / 4.0 | ||
| 407 | + | ||
| 408 | + int32_t n = num_frames * feat_dim; | ||
| 409 | + float max_v = -1e20; | ||
| 410 | + for (int32_t i = 0; i != n; ++i) { | ||
| 411 | + float f = features[i]; | ||
| 412 | + | ||
| 413 | + f = std::max<float>(f, 1e-10); | ||
| 414 | + f = std::log10(f); | ||
| 415 | + | ||
| 416 | + max_v = std::max(f, max_v); | ||
| 417 | + | ||
| 418 | + features[i] = f; | ||
| 419 | + } | ||
| 420 | + | ||
| 421 | + max_v -= 8; | ||
| 422 | + | ||
| 423 | + for (int32_t i = 0; i != n; ++i) { | ||
| 424 | + float f = features[i]; | ||
| 425 | + f = std::max(f, max_v); | ||
| 426 | + | ||
| 427 | + f = (f + 4) / 4; | ||
| 428 | + | ||
| 429 | + features[i] = f; | ||
| 430 | + } | ||
| 431 | +} | ||
| 432 | + | ||
| 321 | } // namespace sherpa_onnx | 433 | } // namespace sherpa_onnx |
| @@ -18,6 +18,7 @@ | @@ -18,6 +18,7 @@ | ||
| 18 | 18 | ||
| 19 | #include "onnxruntime_cxx_api.h" // NOLINT | 19 | #include "onnxruntime_cxx_api.h" // NOLINT |
| 20 | #include "sherpa-onnx/csrc/offline-model-config.h" | 20 | #include "sherpa-onnx/csrc/offline-model-config.h" |
| 21 | +#include "sherpa-onnx/csrc/spoken-language-identification.h" | ||
| 21 | 22 | ||
| 22 | namespace sherpa_onnx { | 23 | namespace sherpa_onnx { |
| 23 | 24 | ||
| @@ -25,6 +26,9 @@ class OfflineWhisperModel { | @@ -25,6 +26,9 @@ class OfflineWhisperModel { | ||
| 25 | public: | 26 | public: |
| 26 | explicit OfflineWhisperModel(const OfflineModelConfig &config); | 27 | explicit OfflineWhisperModel(const OfflineModelConfig &config); |
| 27 | 28 | ||
| 29 | + explicit OfflineWhisperModel( | ||
| 30 | + const SpokenLanguageIdentificationConfig &config); | ||
| 31 | + | ||
| 28 | #if __ANDROID_API__ >= 9 | 32 | #if __ANDROID_API__ >= 9 |
| 29 | OfflineWhisperModel(AAssetManager *mgr, const OfflineModelConfig &config); | 33 | OfflineWhisperModel(AAssetManager *mgr, const OfflineModelConfig &config); |
| 30 | #endif | 34 | #endif |
| @@ -72,7 +76,8 @@ class OfflineWhisperModel { | @@ -72,7 +76,8 @@ class OfflineWhisperModel { | ||
| 72 | Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k, | 76 | Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k, |
| 73 | Ort::Value n_layer_cross_v, Ort::Value offset) const; | 77 | Ort::Value n_layer_cross_v, Ort::Value offset) const; |
| 74 | 78 | ||
| 75 | - int32_t DetectLanguage() const; | 79 | + int32_t DetectLanguage(Ort::Value &cross_k, // NOLINT |
| 80 | + Ort::Value &cross_v); // NOLINT | ||
| 76 | 81 | ||
| 77 | /** Return the initial self kv cache in a pair | 82 | /** Return the initial self kv cache in a pair |
| 78 | * - n_layer_self_k_cache A 4-D tensor of shape | 83 | * - n_layer_self_k_cache A 4-D tensor of shape |
| @@ -98,6 +103,9 @@ class OfflineWhisperModel { | @@ -98,6 +103,9 @@ class OfflineWhisperModel { | ||
| 98 | int32_t Translate() const; | 103 | int32_t Translate() const; |
| 99 | bool IsMultiLingual() const; | 104 | bool IsMultiLingual() const; |
| 100 | 105 | ||
| 106 | + static void NormalizeFeatures(float *features, int32_t num_frames, | ||
| 107 | + int32_t feat_dim); | ||
| 108 | + | ||
| 101 | private: | 109 | private: |
| 102 | class Impl; | 110 | class Impl; |
| 103 | std::unique_ptr<Impl> impl_; | 111 | std::unique_ptr<Impl> impl_; |
| @@ -28,7 +28,7 @@ enum class ModelType { | @@ -28,7 +28,7 @@ enum class ModelType { | ||
| 28 | kLstm, | 28 | kLstm, |
| 29 | kZipformer, | 29 | kZipformer, |
| 30 | kZipformer2, | 30 | kZipformer2, |
| 31 | - kUnkown, | 31 | + kUnknown, |
| 32 | }; | 32 | }; |
| 33 | 33 | ||
| 34 | } // namespace | 34 | } // namespace |
| @@ -58,7 +58,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, | @@ -58,7 +58,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, | ||
| 58 | "No model_type in the metadata!\n" | 58 | "No model_type in the metadata!\n" |
| 59 | "Please make sure you are using the latest export-onnx.py from icefall " | 59 | "Please make sure you are using the latest export-onnx.py from icefall " |
| 60 | "to export your transducer models"); | 60 | "to export your transducer models"); |
| 61 | - return ModelType::kUnkown; | 61 | + return ModelType::kUnknown; |
| 62 | } | 62 | } |
| 63 | 63 | ||
| 64 | if (model_type.get() == std::string("conformer")) { | 64 | if (model_type.get() == std::string("conformer")) { |
| @@ -71,7 +71,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, | @@ -71,7 +71,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, | ||
| 71 | return ModelType::kZipformer2; | 71 | return ModelType::kZipformer2; |
| 72 | } else { | 72 | } else { |
| 73 | SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); | 73 | SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); |
| 74 | - return ModelType::kUnkown; | 74 | + return ModelType::kUnknown; |
| 75 | } | 75 | } |
| 76 | } | 76 | } |
| 77 | 77 | ||
| @@ -93,7 +93,7 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create( | @@ -93,7 +93,7 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create( | ||
| 93 | model_type.c_str()); | 93 | model_type.c_str()); |
| 94 | } | 94 | } |
| 95 | } | 95 | } |
| 96 | - ModelType model_type = ModelType::kUnkown; | 96 | + ModelType model_type = ModelType::kUnknown; |
| 97 | 97 | ||
| 98 | { | 98 | { |
| 99 | auto buffer = ReadFile(config.transducer.encoder); | 99 | auto buffer = ReadFile(config.transducer.encoder); |
| @@ -110,7 +110,7 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create( | @@ -110,7 +110,7 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create( | ||
| 110 | return std::make_unique<OnlineZipformerTransducerModel>(config); | 110 | return std::make_unique<OnlineZipformerTransducerModel>(config); |
| 111 | case ModelType::kZipformer2: | 111 | case ModelType::kZipformer2: |
| 112 | return std::make_unique<OnlineZipformer2TransducerModel>(config); | 112 | return std::make_unique<OnlineZipformer2TransducerModel>(config); |
| 113 | - case ModelType::kUnkown: | 113 | + case ModelType::kUnknown: |
| 114 | SHERPA_ONNX_LOGE("Unknown model type in online transducer!"); | 114 | SHERPA_ONNX_LOGE("Unknown model type in online transducer!"); |
| 115 | return nullptr; | 115 | return nullptr; |
| 116 | } | 116 | } |
| @@ -185,7 +185,7 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create( | @@ -185,7 +185,7 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create( | ||
| 185 | return std::make_unique<OnlineZipformerTransducerModel>(mgr, config); | 185 | return std::make_unique<OnlineZipformerTransducerModel>(mgr, config); |
| 186 | case ModelType::kZipformer2: | 186 | case ModelType::kZipformer2: |
| 187 | return std::make_unique<OnlineZipformer2TransducerModel>(mgr, config); | 187 | return std::make_unique<OnlineZipformer2TransducerModel>(mgr, config); |
| 188 | - case ModelType::kUnkown: | 188 | + case ModelType::kUnknown: |
| 189 | SHERPA_ONNX_LOGE("Unknown model type in online transducer!"); | 189 | SHERPA_ONNX_LOGE("Unknown model type in online transducer!"); |
| 190 | return nullptr; | 190 | return nullptr; |
| 191 | } | 191 | } |
| @@ -149,4 +149,9 @@ Ort::SessionOptions GetSessionOptions( | @@ -149,4 +149,9 @@ Ort::SessionOptions GetSessionOptions( | ||
| 149 | return GetSessionOptionsImpl(config.num_threads, config.provider); | 149 | return GetSessionOptionsImpl(config.num_threads, config.provider); |
| 150 | } | 150 | } |
| 151 | 151 | ||
| 152 | +Ort::SessionOptions GetSessionOptions( | ||
| 153 | + const SpokenLanguageIdentificationConfig &config) { | ||
| 154 | + return GetSessionOptionsImpl(config.num_threads, config.provider); | ||
| 155 | +} | ||
| 156 | + | ||
| 152 | } // namespace sherpa_onnx | 157 | } // namespace sherpa_onnx |
| @@ -12,6 +12,7 @@ | @@ -12,6 +12,7 @@ | ||
| 12 | #include "sherpa-onnx/csrc/online-lm-config.h" | 12 | #include "sherpa-onnx/csrc/online-lm-config.h" |
| 13 | #include "sherpa-onnx/csrc/online-model-config.h" | 13 | #include "sherpa-onnx/csrc/online-model-config.h" |
| 14 | #include "sherpa-onnx/csrc/speaker-embedding-extractor.h" | 14 | #include "sherpa-onnx/csrc/speaker-embedding-extractor.h" |
| 15 | +#include "sherpa-onnx/csrc/spoken-language-identification.h" | ||
| 15 | #include "sherpa-onnx/csrc/vad-model-config.h" | 16 | #include "sherpa-onnx/csrc/vad-model-config.h" |
| 16 | 17 | ||
| 17 | namespace sherpa_onnx { | 18 | namespace sherpa_onnx { |
| @@ -30,6 +31,10 @@ Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config); | @@ -30,6 +31,10 @@ Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config); | ||
| 30 | 31 | ||
| 31 | Ort::SessionOptions GetSessionOptions( | 32 | Ort::SessionOptions GetSessionOptions( |
| 32 | const SpeakerEmbeddingExtractorConfig &config); | 33 | const SpeakerEmbeddingExtractorConfig &config); |
| 34 | + | ||
| 35 | +Ort::SessionOptions GetSessionOptions( | ||
| 36 | + const SpokenLanguageIdentificationConfig &config); | ||
| 37 | + | ||
| 33 | } // namespace sherpa_onnx | 38 | } // namespace sherpa_onnx |
| 34 | 39 | ||
| 35 | #endif // SHERPA_ONNX_CSRC_SESSION_H_ | 40 | #endif // SHERPA_ONNX_CSRC_SESSION_H_ |
| 1 | +// sherpa-onnx/csrc/sherpa-onnx-offline-language-identification.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2022-2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include <stdio.h> | ||
| 6 | + | ||
| 7 | +#include <chrono> // NOLINT | ||
| 8 | +#include <string> | ||
| 9 | +#include <vector> | ||
| 10 | + | ||
| 11 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 12 | +#include "sherpa-onnx/csrc/spoken-language-identification.h" | ||
| 13 | +#include "sherpa-onnx/csrc/wave-reader.h" | ||
| 14 | + | ||
| 15 | +int main(int32_t argc, char *argv[]) { | ||
| 16 | + const char *kUsageMessage = R"usage( | ||
| 17 | +Spoken language identification with sherpa-onnx. | ||
| 18 | + | ||
| 19 | +Usage: | ||
| 20 | + | ||
| 21 | +(1) Use a whisper multilingual model | ||
| 22 | + | ||
| 23 | +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.tar.bz2 | ||
| 24 | +tar xvf sherpa-onnx-whisper-tiny.tar.bz2 | ||
| 25 | +rm sherpa-onnx-whisper-tiny.tar.bz2 | ||
| 26 | + | ||
| 27 | +We only use the int8.onnx models below. | ||
| 28 | + | ||
| 29 | +./bin/sherpa-onnx-offline-spoken-language-identification \ | ||
| 30 | + --whisper-encoder=sherpa-onnx-whisper-tiny/tiny-encoder.int8.onnx \ | ||
| 31 | + --whisper-decoder=sherpa-onnx-whisper-tiny/tiny-decoder.int8.onnx \ | ||
| 32 | + --num-threads=1 \ | ||
| 33 | + /path/to/foo.wav | ||
| 34 | + | ||
| 35 | +foo.wav should be of single channel, 16-bit PCM encoded wave file; its | ||
| 36 | +sampling rate can be arbitrary and does not need to be 16kHz. | ||
| 37 | +You can find test waves for different languages at | ||
| 38 | +https://hf-mirror.com/spaces/k2-fsa/spoken-language-identification/tree/main/test_wavs | ||
| 39 | + | ||
| 40 | +Please refer to | ||
| 41 | +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/index.html | ||
| 42 | +for a list of pre-trained models to download. | ||
| 43 | +Note that only whisper multilingual models are supported. For instance, | ||
| 44 | +"tiny" is supported but "tiny.en" is not. | ||
| 45 | +)usage"; | ||
| 46 | + | ||
| 47 | + sherpa_onnx::ParseOptions po(kUsageMessage); | ||
| 48 | + sherpa_onnx::SpokenLanguageIdentificationConfig config; | ||
| 49 | + config.Register(&po); | ||
| 50 | + | ||
| 51 | + po.Read(argc, argv); | ||
| 52 | + if (po.NumArgs() != 1) { | ||
| 53 | + fprintf(stderr, "Error: Please provide 1 wave file.\n\n"); | ||
| 54 | + po.PrintUsage(); | ||
| 55 | + exit(EXIT_FAILURE); | ||
| 56 | + } | ||
| 57 | + | ||
| 58 | + fprintf(stderr, "%s\n", config.ToString().c_str()); | ||
| 59 | + | ||
| 60 | + if (!config.Validate()) { | ||
| 61 | + fprintf(stderr, "Errors in config!\n"); | ||
| 62 | + return -1; | ||
| 63 | + } | ||
| 64 | + | ||
| 65 | + fprintf(stderr, "Creating spoken language identifier ...\n"); | ||
| 66 | + sherpa_onnx::SpokenLanguageIdentification slid(config); | ||
| 67 | + | ||
| 68 | + fprintf(stderr, "Started\n"); | ||
| 69 | + const std::string wav_filename = po.GetArg(1); | ||
| 70 | + | ||
| 71 | + int32_t sampling_rate = -1; | ||
| 72 | + bool is_ok = false; | ||
| 73 | + const std::vector<float> samples = | ||
| 74 | + sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok); | ||
| 75 | + if (!is_ok) { | ||
| 76 | + fprintf(stderr, "Failed to read %s\n", wav_filename.c_str()); | ||
| 77 | + return -1; | ||
| 78 | + } | ||
| 79 | + float duration = samples.size() / static_cast<float>(sampling_rate); | ||
| 80 | + | ||
| 81 | + const auto begin = std::chrono::steady_clock::now(); | ||
| 82 | + | ||
| 83 | + auto s = slid.CreateStream(); | ||
| 84 | + s->AcceptWaveform(sampling_rate, samples.data(), samples.size()); | ||
| 85 | + | ||
| 86 | + auto language = slid.Compute(s.get()); | ||
| 87 | + | ||
| 88 | + const auto end = std::chrono::steady_clock::now(); | ||
| 89 | + | ||
| 90 | + fprintf(stderr, "Done!\n\n"); | ||
| 91 | + fprintf(stderr, "%s\nDetected language: %s\n", wav_filename.c_str(), | ||
| 92 | + language.c_str()); | ||
| 93 | + | ||
| 94 | + float elapsed_seconds = | ||
| 95 | + std::chrono::duration_cast<std::chrono::milliseconds>(end - begin) | ||
| 96 | + .count() / | ||
| 97 | + 1000.; | ||
| 98 | + | ||
| 99 | + fprintf(stderr, "num threads: %d\n", config.num_threads); | ||
| 100 | + | ||
| 101 | + fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); | ||
| 102 | + float rtf = elapsed_seconds / duration; | ||
| 103 | + fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n", | ||
| 104 | + elapsed_seconds, duration, rtf); | ||
| 105 | + | ||
| 106 | + return 0; | ||
| 107 | +} |
| @@ -16,7 +16,7 @@ enum class ModelType { | @@ -16,7 +16,7 @@ enum class ModelType { | ||
| 16 | kWeSpeaker, | 16 | kWeSpeaker, |
| 17 | k3dSpeaker, | 17 | k3dSpeaker, |
| 18 | kNeMo, | 18 | kNeMo, |
| 19 | - kUnkown, | 19 | + kUnknown, |
| 20 | }; | 20 | }; |
| 21 | 21 | ||
| 22 | } // namespace | 22 | } // namespace |
| @@ -47,7 +47,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, | @@ -47,7 +47,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, | ||
| 47 | "https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wespeaker/" | 47 | "https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wespeaker/" |
| 48 | "add_meta_data.py" | 48 | "add_meta_data.py" |
| 49 | "to add metadata to models from WeSpeaker\n"); | 49 | "to add metadata to models from WeSpeaker\n"); |
| 50 | - return ModelType::kUnkown; | 50 | + return ModelType::kUnknown; |
| 51 | } | 51 | } |
| 52 | 52 | ||
| 53 | if (model_type.get() == std::string("wespeaker")) { | 53 | if (model_type.get() == std::string("wespeaker")) { |
| @@ -58,14 +58,14 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, | @@ -58,14 +58,14 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, | ||
| 58 | return ModelType::kNeMo; | 58 | return ModelType::kNeMo; |
| 59 | } else { | 59 | } else { |
| 60 | SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); | 60 | SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); |
| 61 | - return ModelType::kUnkown; | 61 | + return ModelType::kUnknown; |
| 62 | } | 62 | } |
| 63 | } | 63 | } |
| 64 | 64 | ||
| 65 | std::unique_ptr<SpeakerEmbeddingExtractorImpl> | 65 | std::unique_ptr<SpeakerEmbeddingExtractorImpl> |
| 66 | SpeakerEmbeddingExtractorImpl::Create( | 66 | SpeakerEmbeddingExtractorImpl::Create( |
| 67 | const SpeakerEmbeddingExtractorConfig &config) { | 67 | const SpeakerEmbeddingExtractorConfig &config) { |
| 68 | - ModelType model_type = ModelType::kUnkown; | 68 | + ModelType model_type = ModelType::kUnknown; |
| 69 | 69 | ||
| 70 | { | 70 | { |
| 71 | auto buffer = ReadFile(config.model); | 71 | auto buffer = ReadFile(config.model); |
| @@ -80,9 +80,8 @@ SpeakerEmbeddingExtractorImpl::Create( | @@ -80,9 +80,8 @@ SpeakerEmbeddingExtractorImpl::Create( | ||
| 80 | return std::make_unique<SpeakerEmbeddingExtractorGeneralImpl>(config); | 80 | return std::make_unique<SpeakerEmbeddingExtractorGeneralImpl>(config); |
| 81 | case ModelType::kNeMo: | 81 | case ModelType::kNeMo: |
| 82 | return std::make_unique<SpeakerEmbeddingExtractorNeMoImpl>(config); | 82 | return std::make_unique<SpeakerEmbeddingExtractorNeMoImpl>(config); |
| 83 | - case ModelType::kUnkown: | ||
| 84 | - SHERPA_ONNX_LOGE( | ||
| 85 | - "Unknown model type in for speaker embedding extractor!"); | 83 | + case ModelType::kUnknown: |
| 84 | + SHERPA_ONNX_LOGE("Unknown model type for speaker embedding extractor!"); | ||
| 86 | return nullptr; | 85 | return nullptr; |
| 87 | } | 86 | } |
| 88 | 87 | ||
| @@ -94,7 +93,7 @@ SpeakerEmbeddingExtractorImpl::Create( | @@ -94,7 +93,7 @@ SpeakerEmbeddingExtractorImpl::Create( | ||
| 94 | std::unique_ptr<SpeakerEmbeddingExtractorImpl> | 93 | std::unique_ptr<SpeakerEmbeddingExtractorImpl> |
| 95 | SpeakerEmbeddingExtractorImpl::Create( | 94 | SpeakerEmbeddingExtractorImpl::Create( |
| 96 | AAssetManager *mgr, const SpeakerEmbeddingExtractorConfig &config) { | 95 | AAssetManager *mgr, const SpeakerEmbeddingExtractorConfig &config) { |
| 97 | - ModelType model_type = ModelType::kUnkown; | 96 | + ModelType model_type = ModelType::kUnknown; |
| 98 | 97 | ||
| 99 | { | 98 | { |
| 100 | auto buffer = ReadFile(mgr, config.model); | 99 | auto buffer = ReadFile(mgr, config.model); |
| @@ -110,7 +109,7 @@ SpeakerEmbeddingExtractorImpl::Create( | @@ -110,7 +109,7 @@ SpeakerEmbeddingExtractorImpl::Create( | ||
| 110 | config); | 109 | config); |
| 111 | case ModelType::kNeMo: | 110 | case ModelType::kNeMo: |
| 112 | return std::make_unique<SpeakerEmbeddingExtractorNeMoImpl>(mgr, config); | 111 | return std::make_unique<SpeakerEmbeddingExtractorNeMoImpl>(mgr, config); |
| 113 | - case ModelType::kUnkown: | 112 | + case ModelType::kUnknown: |
| 114 | SHERPA_ONNX_LOGE( | 113 | SHERPA_ONNX_LOGE( |
| 115 | "Unknown model type in for speaker embedding extractor!"); | 114 | "Unknown model type in for speaker embedding extractor!"); |
| 116 | return nullptr; | 115 | return nullptr; |
| 1 | +// sherpa-onnx/csrc/spoken-language-identification-impl.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | +#include "sherpa-onnx/csrc/spoken-language-identification-impl.h" | ||
| 5 | + | ||
| 6 | +#include <memory> | ||
| 7 | + | ||
| 8 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 9 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 10 | +#include "sherpa-onnx/csrc/spoken-language-identification-whisper-impl.h" | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +namespace { | ||
| 15 | + | ||
| 16 | +enum class ModelType { | ||
| 17 | + kWhisper, | ||
| 18 | + kUnknown, | ||
| 19 | +}; | ||
| 20 | + | ||
| 21 | +}  // namespace | ||
| 22 | + | ||
| 23 | +static ModelType GetModelType(char *model_data, size_t model_data_length, | ||
| 24 | + bool debug) { | ||
| 25 | + Ort::Env env(ORT_LOGGING_LEVEL_WARNING); | ||
| 26 | + Ort::SessionOptions sess_opts; | ||
| 27 | + | ||
| 28 | + auto sess = std::make_unique<Ort::Session>(env, model_data, model_data_length, | ||
| 29 | + sess_opts); | ||
| 30 | + | ||
| 31 | + Ort::ModelMetadata meta_data = sess->GetModelMetadata(); | ||
| 32 | + if (debug) { | ||
| 33 | + std::ostringstream os; | ||
| 34 | + PrintModelMetadata(os, meta_data); | ||
| 35 | + SHERPA_ONNX_LOGE("%s", os.str().c_str()); | ||
| 36 | + } | ||
| 37 | + | ||
| 38 | + Ort::AllocatorWithDefaultOptions allocator; | ||
| 39 | + auto model_type = | ||
| 40 | + meta_data.LookupCustomMetadataMapAllocated("model_type", allocator); | ||
| 41 | + if (!model_type) { | ||
| 42 | + SHERPA_ONNX_LOGE( | ||
| 43 | + "No model_type in the metadata!\n" | ||
| 44 | + "Please make sure you have added metadata to the model.\n\n" | ||
| 45 | + "For instance, you can use\n" | ||
| 46 | + "https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/whisper/" | ||
| 47 | + "export-onnx.py " | ||
| 48 | + "to add metadata to models from whisper\n"); | ||
| 49 | + return ModelType::kUnknown; | ||
| 50 | + } | ||
| 51 | + | ||
| 52 | + auto model_type_str = std::string(model_type.get()); | ||
| 53 | + if (model_type_str.find("whisper") == 0) { | ||
| 54 | + return ModelType::kWhisper; | ||
| 55 | + } else { | ||
| 56 | + SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); | ||
| 57 | + return ModelType::kUnknown; | ||
| 58 | + } | ||
| 59 | +} | ||
| 60 | + | ||
| 61 | +std::unique_ptr<SpokenLanguageIdentificationImpl> | ||
| 62 | +SpokenLanguageIdentificationImpl::Create( | ||
| 63 | + const SpokenLanguageIdentificationConfig &config) { | ||
| 64 | + ModelType model_type = ModelType::kUnknown; | ||
| 65 | + { | ||
| 66 | + if (config.whisper.encoder.empty()) { | ||
| 67 | + SHERPA_ONNX_LOGE("Only whisper models are supported at present"); | ||
| 68 | + exit(-1); | ||
| 69 | + } | ||
| 70 | + auto buffer = ReadFile(config.whisper.encoder); | ||
| 71 | + | ||
| 72 | + model_type = GetModelType(buffer.data(), buffer.size(), config.debug); | ||
| 73 | + } | ||
| 74 | + | ||
| 75 | + switch (model_type) { | ||
| 76 | + case ModelType::kWhisper: | ||
| 77 | + return std::make_unique<SpokenLanguageIdentificationWhisperImpl>(config); | ||
| 78 | + case ModelType::kUnknown: | ||
| 79 | + SHERPA_ONNX_LOGE( | ||
| 80 | + "Unknown model type for spoken language identification!"); | ||
| 81 | + return nullptr; | ||
| 82 | + } | ||
| 83 | + | ||
| 84 | + // unreachable code | ||
| 85 | + return nullptr; | ||
| 86 | +} | ||
| 87 | + | ||
| 88 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/spoken-language-identification-impl.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_IMPL_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_IMPL_H_ | ||
| 6 | + | ||
| 7 | +#include <memory> | ||
| 8 | +#include <string> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/csrc/spoken-language-identification.h" | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +class SpokenLanguageIdentificationImpl { | ||
| 15 | + public: | ||
| 16 | + virtual ~SpokenLanguageIdentificationImpl() = default; | ||
| 17 | + | ||
| 18 | + static std::unique_ptr<SpokenLanguageIdentificationImpl> Create( | ||
| 19 | + const SpokenLanguageIdentificationConfig &config); | ||
| 20 | + | ||
| 21 | + virtual std::unique_ptr<OfflineStream> CreateStream() const = 0; | ||
| 22 | + | ||
| 23 | + virtual std::string Compute(OfflineStream *s) const = 0; | ||
| 24 | +}; | ||
| 25 | + | ||
| 26 | +} // namespace sherpa_onnx | ||
| 27 | + | ||
| 28 | +#endif // SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_IMPL_H_ |
| 1 | +// sherpa-onnx/csrc/spoken-language-identification-whisper-impl.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_WHISPER_IMPL_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_WHISPER_IMPL_H_ | ||
| 7 | + | ||
| 8 | +#include <algorithm> | ||
| 9 | +#include <memory> | ||
| 10 | +#include <string> | ||
| 11 | +#include <utility> | ||
| 12 | +#include <vector> | ||
| 13 | + | ||
| 14 | +#include "sherpa-onnx/csrc/offline-whisper-model.h" | ||
| 15 | +#include "sherpa-onnx/csrc/spoken-language-identification-impl.h" | ||
| 16 | +#include "sherpa-onnx/csrc/transpose.h" | ||
| 17 | + | ||
| 18 | +namespace sherpa_onnx { | ||
| 19 | + | ||
| 20 | +class SpokenLanguageIdentificationWhisperImpl | ||
| 21 | + : public SpokenLanguageIdentificationImpl { | ||
| 22 | + public: | ||
| 23 | + explicit SpokenLanguageIdentificationWhisperImpl( | ||
| 24 | + const SpokenLanguageIdentificationConfig &config) | ||
| 25 | + : config_(config), model_(std::make_unique<OfflineWhisperModel>(config)) { | ||
| 26 | + Check(); | ||
| 27 | + } | ||
| 28 | + | ||
| 29 | + std::unique_ptr<OfflineStream> CreateStream() const override { | ||
| 30 | + return std::make_unique<OfflineStream>(WhisperTag{}); | ||
| 31 | + } | ||
| 32 | + | ||
| 33 | + std::string Compute(OfflineStream *s) const override { | ||
| 34 | + int32_t max_num_frames = 3000; | ||
| 35 | + auto memory_info = | ||
| 36 | + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 37 | + | ||
| 38 | + int32_t feat_dim = s->FeatureDim(); | ||
| 39 | + std::vector<float> f = s->GetFrames(); | ||
| 40 | + int32_t num_frames = f.size() / feat_dim; | ||
| 41 | + | ||
| 42 | + // we use 50 here so that there will be some zero tail paddings | ||
| 43 | + if (num_frames >= max_num_frames - 50) { | ||
| 44 | + SHERPA_ONNX_LOGE( | ||
| 45 | + "Only waves less than 30 seconds are supported. We process only the " | ||
| 46 | + "first 30 seconds and discard the remaining data"); | ||
| 47 | + num_frames = max_num_frames - 50; | ||
| 48 | + } | ||
| 49 | + | ||
| 50 | + model_->NormalizeFeatures(f.data(), num_frames, feat_dim); | ||
| 51 | + | ||
| 52 | + // Note: 1000 is an empirical value. | ||
| 53 | + // You can replace it with other values, say, 100. | ||
| 54 | + // | ||
| 55 | + // Since we have removed the 30 seconds constraint, we need | ||
| 56 | + // tail_padding_frames so that whisper is able to detect the eot token. | ||
| 57 | + int32_t tail_padding_frames = 1000; | ||
| 58 | + | ||
| 59 | + if (config_.whisper.tail_paddings > 0) { | ||
| 60 | + tail_padding_frames = config_.whisper.tail_paddings; | ||
| 61 | + } | ||
| 62 | + | ||
| 63 | + int32_t actual_frames = | ||
| 64 | + std::min(num_frames + tail_padding_frames, max_num_frames); | ||
| 65 | + | ||
| 66 | + std::array<int64_t, 3> shape{1, actual_frames, feat_dim}; | ||
| 67 | + | ||
| 68 | + Ort::Value mel = Ort::Value::CreateTensor<float>( | ||
| 69 | + model_->Allocator(), shape.data(), shape.size()); | ||
| 70 | + | ||
| 71 | + float *p_mel = mel.GetTensorMutableData<float>(); | ||
| 72 | + std::copy(f.data(), f.data() + num_frames * feat_dim, p_mel); | ||
| 73 | + | ||
| 74 | + std::fill_n(p_mel + num_frames * feat_dim, | ||
| 75 | + (actual_frames - num_frames) * feat_dim, 0); | ||
| 76 | + | ||
| 77 | + mel = Transpose12(model_->Allocator(), &mel); | ||
| 78 | + | ||
| 79 | + try { | ||
| 80 | + auto cross_kv = model_->ForwardEncoder(std::move(mel)); | ||
| 81 | + int32_t lang_id = model_->DetectLanguage(cross_kv.first, cross_kv.second); | ||
| 82 | + const auto &id2lang = model_->GetID2Lang(); | ||
| 83 | + if (id2lang.count(lang_id)) { | ||
| 84 | + return id2lang.at(lang_id); | ||
| 85 | + } else { | ||
| 86 | + SHERPA_ONNX_LOGE("Unknown language ID: %d. Return an empty string.", | ||
| 87 | + lang_id); | ||
| 88 | + return ""; | ||
| 89 | + } | ||
| 90 | + } catch (const Ort::Exception &ex) { | ||
| 91 | + SHERPA_ONNX_LOGE( | ||
| 92 | + "\n\nCaught exception:\n\n%s\n\nReturn an empty result. Number of " | ||
| 93 | + "input frames: %d, Current tail " | ||
| 94 | + "paddings: %d. If you see a lot of such exceptions, please consider " | ||
| 95 | + "using a larger --whisper-tail-paddings", | ||
| 96 | + ex.what(), num_frames, tail_padding_frames); | ||
| 97 | + return ""; | ||
| 98 | + } | ||
| 99 | + } | ||
| 100 | + | ||
| 101 | + private: | ||
| 102 | + void Check() const { | ||
| 103 | + if (!model_->IsMultiLingual()) { | ||
| 104 | + SHERPA_ONNX_LOGE( | ||
| 105 | + "Only whisper multilingual models can be used for spoken language " | ||
| 106 | + "identification. Given: %s,%s", | ||
| 107 | + config_.whisper.encoder.c_str(), config_.whisper.decoder.c_str()); | ||
| 108 | + exit(-1); | ||
| 109 | + } | ||
| 110 | + } | ||
| 111 | + | ||
| 112 | + private: | ||
| 113 | + SpokenLanguageIdentificationConfig config_; | ||
| 114 | + std::unique_ptr<OfflineWhisperModel> model_; | ||
| 115 | +}; | ||
| 116 | + | ||
| 117 | +} // namespace sherpa_onnx | ||
| 118 | + | ||
| 119 | +#endif // SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_WHISPER_IMPL_H_ |
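The frame budgeting inside `Compute()` above is worth spelling out. Below is a standalone sketch of just that arithmetic; it is illustrative only and not part of the patch (the 3200-frame input is an assumed example; the constants mirror the implementation above, where 3000 mel frames correspond to whisper's 30-second window at 10 ms per frame):

```cpp
// Standalone sketch of the frame budgeting in
// SpokenLanguageIdentificationWhisperImpl::Compute() above.
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  const int32_t max_num_frames = 3000;  // whisper's 30-second mel window
  int32_t tail_padding_frames = 1000;   // default; --whisper-tail-paddings overrides it
  int32_t num_frames = 3200;            // assumed example: a ~32-second wave

  // Clip the input so that at least 50 zero frames remain at the tail.
  if (num_frames >= max_num_frames - 50) {
    num_frames = max_num_frames - 50;  // 2950; the rest is discarded
  }

  // Append tail padding, but never exceed the 3000-frame budget.
  int32_t actual_frames =
      std::min(num_frames + tail_padding_frames, max_num_frames);  // 3000

  std::printf("kept: %d, padded to: %d, zero frames: %d\n", num_frames,
              actual_frames, actual_frames - num_frames);
  return 0;
}
```

This is why the exception handler above suggests a larger `--whisper-tail-paddings` when decoding fails: short inputs get the full padding, but inputs near the 30-second limit get only whatever zero frames fit in the remaining budget.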
| 1 | +// sherpa-onnx/csrc/spoken-language-identification.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/spoken-language-identification.h" | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | + | ||
| 9 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 10 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 11 | +#include "sherpa-onnx/csrc/spoken-language-identification-impl.h" | ||
| 12 | + | ||
| 13 | +namespace sherpa_onnx { | ||
| 14 | + | ||
| 15 | +void SpokenLanguageIdentificationWhisperConfig::Register(ParseOptions *po) { | ||
| 16 | + po->Register( | ||
| 17 | + "whisper-encoder", &encoder, | ||
| 18 | + "Path to then encoder of a whisper multilingual model. Support only " | ||
| 19 | + "tiny, base, small, medium, large."); | ||
| 20 | + | ||
| 21 | + po->Register( | ||
| 22 | + "whisper-decoder", &decoder, | ||
| 23 | + "Path to the decoder of a whisper multilingual model. Support only " | ||
| 24 | + "tiny, base, small, medium, large."); | ||
| 25 | + | ||
| 26 | + po->Register( | ||
| 27 | + "whisper-tail-paddings", &tail_paddings, | ||
| 28 | + "Suggested value: 300 for multilingual models. " | ||
| 29 | + "Since we have removed the 30-second constraint, we need to add some " | ||
| 30 | + "tail padding frames " | ||
| 31 | + "so that whisper can detect the eot token. Leave it to -1 to use 1000"); | ||
| 32 | +} | ||
| 33 | + | ||
| 34 | +bool SpokenLanguageIdentificationWhisperConfig::Validate() const { | ||
| 35 | + if (encoder.empty()) { | ||
| 36 | + SHERPA_ONNX_LOGE("Please provide --whisper-encoder"); | ||
| 37 | + return false; | ||
| 38 | + } | ||
| 39 | + | ||
| 40 | + if (!FileExists(encoder)) { | ||
| 41 | + SHERPA_ONNX_LOGE("whisper encoder file %s does not exist", encoder.c_str()); | ||
| 42 | + return false; | ||
| 43 | + } | ||
| 44 | + | ||
| 45 | + if (decoder.empty()) { | ||
| 46 | + SHERPA_ONNX_LOGE("Please provide --whisper-decoder"); | ||
| 47 | + return false; | ||
| 48 | + } | ||
| 49 | + | ||
| 50 | + if (!FileExists(decoder)) { | ||
| 51 | + SHERPA_ONNX_LOGE("whisper decoder file %s does not exist", decoder.c_str()); | ||
| 52 | + return false; | ||
| 53 | + } | ||
| 54 | + | ||
| 55 | + return true; | ||
| 56 | +} | ||
| 57 | + | ||
| 58 | +std::string SpokenLanguageIdentificationWhisperConfig::ToString() const { | ||
| 59 | + std::ostringstream os; | ||
| 60 | + | ||
| 61 | + os << "SpokenLanguageIdentificationWhisperConfig("; | ||
| 62 | + os << "encoder=\"" << encoder << "\", "; | ||
| 63 | + os << "decoder=\"" << decoder << "\", "; | ||
| 64 | + os << "tail_paddings=" << tail_paddings << ")"; | ||
| 65 | + | ||
| 66 | + return os.str(); | ||
| 67 | +} | ||
| 68 | + | ||
| 69 | +void SpokenLanguageIdentificationConfig::Register(ParseOptions *po) { | ||
| 70 | + whisper.Register(po); | ||
| 71 | + | ||
| 72 | + po->Register("num-threads", &num_threads, | ||
| 73 | + "Number of threads to run the neural network"); | ||
| 74 | + | ||
| 75 | + po->Register("debug", &debug, | ||
| 76 | + "true to print model information while loading it."); | ||
| 77 | + | ||
| 78 | + po->Register("provider", &provider, | ||
| 79 | + "Specify a provider to use: cpu, cuda, coreml"); | ||
| 80 | +} | ||
| 81 | + | ||
| 82 | +bool SpokenLanguageIdentificationConfig::Validate() const { | ||
| 83 | + if (!whisper.Validate()) { | ||
| 84 | + return false; | ||
| 85 | + } | ||
| 86 | + | ||
| 87 | + return true; | ||
| 88 | +} | ||
| 89 | + | ||
| 90 | +std::string SpokenLanguageIdentificationConfig::ToString() const { | ||
| 91 | + std::ostringstream os; | ||
| 92 | + | ||
| 93 | + os << "SpokenLanguageIdentificationConfig("; | ||
| 94 | + os << "whisper=\"" << whisper.ToString() << "\", "; | ||
| 95 | + os << "num_threads=" << num_threads << ", "; | ||
| 96 | + os << "debug=" << (debug ? "True" : "False") << ", "; | ||
| 97 | + os << "provider=\"" << provider << "\")"; | ||
| 98 | + | ||
| 99 | + return os.str(); | ||
| 100 | +} | ||
| 101 | + | ||
| 102 | +SpokenLanguageIdentification::SpokenLanguageIdentification( | ||
| 103 | + const SpokenLanguageIdentificationConfig &config) | ||
| 104 | + : impl_(SpokenLanguageIdentificationImpl::Create(config)) {} | ||
| 105 | + | ||
| 106 | +SpokenLanguageIdentification::~SpokenLanguageIdentification() = default; | ||
| 107 | + | ||
| 108 | +std::unique_ptr<OfflineStream> SpokenLanguageIdentification::CreateStream() | ||
| 109 | + const { | ||
| 110 | + return impl_->CreateStream(); | ||
| 111 | +} | ||
| 112 | + | ||
| 113 | +std::string SpokenLanguageIdentification::Compute(OfflineStream *s) const { | ||
| 114 | + return impl_->Compute(s); | ||
| 115 | +} | ||
| 116 | + | ||
| 117 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/spoken-language-identification.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_ | ||
| 6 | + | ||
| 7 | +#include <memory> | ||
| 8 | +#include <string> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/csrc/offline-stream.h" | ||
| 11 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 12 | + | ||
| 13 | +namespace sherpa_onnx { | ||
| 14 | + | ||
| 15 | +struct SpokenLanguageIdentificationWhisperConfig { | ||
| 16 | + // Requires a multilingual whisper model. | ||
| 17 | + // That is, it supports only tiny, base, small, medium, large. | ||
| 18 | + // Note: It does NOT support tiny.en, base.en, small.en, medium.en | ||
| 19 | + std::string encoder; | ||
| 20 | + std::string decoder; | ||
| 21 | + | ||
| 22 | + // Number of tail padding frames. | ||
| 23 | + // | ||
| 24 | + // Since we remove the 30-second constraint, we need to add some paddings | ||
| 25 | + // at the end. | ||
| 26 | + // | ||
| 27 | + // Recommended values: | ||
| 28 | + // - 50 for English models | ||
| 29 | + // - 300 for multilingual models | ||
| 30 | + int32_t tail_paddings = -1; | ||
| 31 | + | ||
| 32 | + SpokenLanguageIdentificationWhisperConfig() = default; | ||
| 33 | + | ||
| 34 | + SpokenLanguageIdentificationWhisperConfig(const std::string &encoder, | ||
| 35 | + const std::string &decoder, | ||
| 36 | + int32_t tail_paddings) | ||
| 37 | + : encoder(encoder), decoder(decoder), tail_paddings(tail_paddings) {} | ||
| 38 | + | ||
| 39 | + void Register(ParseOptions *po); | ||
| 40 | + bool Validate() const; | ||
| 41 | + std::string ToString() const; | ||
| 42 | +}; | ||
| 43 | + | ||
| 44 | +struct SpokenLanguageIdentificationConfig { | ||
| 45 | + SpokenLanguageIdentificationWhisperConfig whisper; | ||
| 46 | + | ||
| 47 | + int32_t num_threads = 1; | ||
| 48 | + bool debug = false; | ||
| 49 | + std::string provider = "cpu"; | ||
| 50 | + | ||
| 51 | + SpokenLanguageIdentificationConfig() = default; | ||
| 52 | + | ||
| 53 | + SpokenLanguageIdentificationConfig( | ||
| 54 | + const SpokenLanguageIdentificationWhisperConfig &whisper, | ||
| 55 | + int32_t num_threads, bool debug, const std::string &provider) | ||
| 56 | + : whisper(whisper), | ||
| 57 | + num_threads(num_threads), | ||
| 58 | + debug(debug), | ||
| 59 | + provider(provider) {} | ||
| 60 | + | ||
| 61 | + void Register(ParseOptions *po); | ||
| 62 | + bool Validate() const; | ||
| 63 | + std::string ToString() const; | ||
| 64 | +}; | ||
| 65 | + | ||
| 66 | +class SpokenLanguageIdentificationImpl; | ||
| 67 | + | ||
| 68 | +class SpokenLanguageIdentification { | ||
| 69 | + public: | ||
| 70 | + explicit SpokenLanguageIdentification( | ||
| 71 | + const SpokenLanguageIdentificationConfig &config); | ||
| 72 | + | ||
| 73 | + ~SpokenLanguageIdentification(); | ||
| 74 | + | ||
| 75 | + // Create a stream to accept audio samples and compute features | ||
| 76 | + std::unique_ptr<OfflineStream> CreateStream() const; | ||
| 77 | + | ||
| 78 | + // Return the detected language code as a string, e.g., en, zh, de, | ||
| 79 | + // etc. | ||
| 80 | + // Note: en is English, zh is Chinese, de is German, etc. | ||
| 81 | + std::string Compute(OfflineStream *s) const; | ||
| 82 | + | ||
| 83 | + private: | ||
| 84 | + std::unique_ptr<SpokenLanguageIdentificationImpl> impl_; | ||
| 85 | +}; | ||
| 86 | + | ||
| 87 | +} // namespace sherpa_onnx | ||
| 88 | + | ||
| 89 | +#endif // SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_ |
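For quick reference, the public API declared above composes as in the driver program added earlier in this commit. The following condensed sketch is not part of the patch; the model paths and wave file are placeholders taken from the usage message, and error handling is trimmed:

```cpp
// Condensed usage sketch of the SpokenLanguageIdentification API above,
// mirroring sherpa-onnx-offline-language-identification.cc.
#include <cstdio>
#include <vector>

#include "sherpa-onnx/csrc/spoken-language-identification.h"
#include "sherpa-onnx/csrc/wave-reader.h"

int main() {
  sherpa_onnx::SpokenLanguageIdentificationConfig config;
  config.whisper.encoder = "sherpa-onnx-whisper-tiny/tiny-encoder.int8.onnx";
  config.whisper.decoder = "sherpa-onnx-whisper-tiny/tiny-decoder.int8.onnx";
  if (!config.Validate()) return -1;

  sherpa_onnx::SpokenLanguageIdentification slid(config);

  int32_t sampling_rate = -1;
  bool is_ok = false;
  std::vector<float> samples =
      sherpa_onnx::ReadWave("foo.wav", &sampling_rate, &is_ok);
  if (!is_ok) return -1;

  // One stream per utterance: feed samples, then run detection.
  auto s = slid.CreateStream();
  s->AcceptWaveform(sampling_rate, samples.data(), samples.size());

  std::printf("Detected language: %s\n", slid.Compute(s.get()).c_str());
  return 0;
}
```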
| @@ -33,6 +33,7 @@ set(srcs | @@ -33,6 +33,7 @@ set(srcs | ||
| 33 | silero-vad-model-config.cc | 33 | silero-vad-model-config.cc |
| 34 | speaker-embedding-extractor.cc | 34 | speaker-embedding-extractor.cc |
| 35 | speaker-embedding-manager.cc | 35 | speaker-embedding-manager.cc |
| 36 | + spoken-language-identification.cc | ||
| 36 | vad-model-config.cc | 37 | vad-model-config.cc |
| 37 | vad-model.cc | 38 | vad-model.cc |
| 38 | voice-activity-detector.cc | 39 | voice-activity-detector.cc |
| @@ -22,6 +22,7 @@ | @@ -22,6 +22,7 @@ | ||
| 22 | #include "sherpa-onnx/python/csrc/online-stream.h" | 22 | #include "sherpa-onnx/python/csrc/online-stream.h" |
| 23 | #include "sherpa-onnx/python/csrc/speaker-embedding-extractor.h" | 23 | #include "sherpa-onnx/python/csrc/speaker-embedding-extractor.h" |
| 24 | #include "sherpa-onnx/python/csrc/speaker-embedding-manager.h" | 24 | #include "sherpa-onnx/python/csrc/speaker-embedding-manager.h" |
| 25 | +#include "sherpa-onnx/python/csrc/spoken-language-identification.h" | ||
| 25 | #include "sherpa-onnx/python/csrc/vad-model-config.h" | 26 | #include "sherpa-onnx/python/csrc/vad-model-config.h" |
| 26 | #include "sherpa-onnx/python/csrc/vad-model.h" | 27 | #include "sherpa-onnx/python/csrc/vad-model.h" |
| 27 | #include "sherpa-onnx/python/csrc/voice-activity-detector.h" | 28 | #include "sherpa-onnx/python/csrc/voice-activity-detector.h" |
| @@ -55,6 +56,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) { | @@ -55,6 +56,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) { | ||
| 55 | PybindOfflineTts(&m); | 56 | PybindOfflineTts(&m); |
| 56 | PybindSpeakerEmbeddingExtractor(&m); | 57 | PybindSpeakerEmbeddingExtractor(&m); |
| 57 | PybindSpeakerEmbeddingManager(&m); | 58 | PybindSpeakerEmbeddingManager(&m); |
| 59 | + PybindSpokenLanguageIdentification(&m); | ||
| 58 | 60 | ||
| 59 | PybindAlsa(&m); | 61 | PybindAlsa(&m); |
| 60 | } | 62 | } |
| 1 | +// sherpa-onnx/python/csrc/spoken-language-identification.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/python/csrc/spoken-language-identification.h" | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | + | ||
| 9 | +#include "sherpa-onnx/csrc/spoken-language-identification.h" | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
| 13 | +static void PybindSpokenLanguageIdentificationWhisperConfig(py::module *m) { | ||
| 14 | + using PyClass = SpokenLanguageIdentificationWhisperConfig; | ||
| 15 | + | ||
| 16 | + py::class_<PyClass>(*m, "SpokenLanguageIdentificationWhisperConfig") | ||
| 17 | + .def(py::init<>()) | ||
| 18 | + .def(py::init<const std::string &, const std::string &, int32_t>(), | ||
| 19 | + py::arg("encoder"), py::arg("decoder"), | ||
| 20 | + py::arg("tail_paddings") = -1) | ||
| 21 | + .def_readwrite("encoder", &PyClass::encoder) | ||
| 22 | + .def_readwrite("decoder", &PyClass::decoder) | ||
| 23 | + .def_readwrite("tail_paddings", &PyClass::tail_paddings) | ||
| 24 | + .def("validate", &PyClass::Validate) | ||
| 25 | + .def("__str__", &PyClass::ToString); | ||
| 26 | +} | ||
| 27 | + | ||
| 28 | +static void PybindSpokenLanguageIdentificationConfig(py::module *m) { | ||
| 29 | + PybindSpokenLanguageIdentificationWhisperConfig(m); | ||
| 30 | + | ||
| 31 | + using PyClass = SpokenLanguageIdentificationConfig; | ||
| 32 | + | ||
| 33 | + py::class_<PyClass>(*m, "SpokenLanguageIdentificationConfig") | ||
| 34 | + .def(py::init<>()) | ||
| 35 | + .def(py::init<const SpokenLanguageIdentificationWhisperConfig &, int32_t, | ||
| 36 | + bool, const std::string &>(), | ||
| 37 | + py::arg("whisper"), py::arg("num_threads") = 1, | ||
| 38 | + py::arg("debug") = false, py::arg("provider") = "cpu") | ||
| 39 | + .def_readwrite("whisper", &PyClass::whisper) | ||
| 40 | + .def_readwrite("num_threads", &PyClass::num_threads) | ||
| 41 | + .def_readwrite("debug", &PyClass::debug) | ||
| 42 | + .def_readwrite("provider", &PyClass::provider) | ||
| 43 | + .def("validate", &PyClass::Validate) | ||
| 44 | + .def("__str__", &PyClass::ToString); | ||
| 45 | +} | ||
| 46 | + | ||
| 47 | +void PybindSpokenLanguageIdentification(py::module *m) { | ||
| 48 | + PybindSpokenLanguageIdentificationConfig(m); | ||
| 49 | + | ||
| 50 | + using PyClass = SpokenLanguageIdentification; | ||
| 51 | + py::class_<PyClass>(*m, "SpokenLanguageIdentification") | ||
| 52 | + .def(py::init<const SpokenLanguageIdentificationConfig &>(), | ||
| 53 | + py::arg("config"), py::call_guard<py::gil_scoped_release>()) | ||
| 54 | + .def("create_stream", &PyClass::CreateStream, | ||
| 55 | + py::call_guard<py::gil_scoped_release>()) | ||
| 56 | + .def("compute", &PyClass::Compute, | ||
| 57 | + py::call_guard<py::gil_scoped_release>()); | ||
| 58 | +} | ||
| 59 | + | ||
| 60 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/python/csrc/spoken-language-identification.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_PYTHON_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_ | ||
| 6 | +#define SHERPA_ONNX_PYTHON_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_ | ||
| 7 | + | ||
| 8 | +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
| 12 | +void PybindSpokenLanguageIdentification(py::module *m); | ||
| 13 | + | ||
| 14 | +}  // namespace sherpa_onnx | ||
| 15 | + | ||
| 16 | +#endif // SHERPA_ONNX_PYTHON_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_ |
| @@ -13,6 +13,9 @@ from _sherpa_onnx import ( | @@ -13,6 +13,9 @@ from _sherpa_onnx import ( | ||
| 13 | SpeakerEmbeddingExtractorConfig, | 13 | SpeakerEmbeddingExtractorConfig, |
| 14 | SpeakerEmbeddingManager, | 14 | SpeakerEmbeddingManager, |
| 15 | SpeechSegment, | 15 | SpeechSegment, |
| 16 | + SpokenLanguageIdentification, | ||
| 17 | + SpokenLanguageIdentificationConfig, | ||
| 18 | + SpokenLanguageIdentificationWhisperConfig, | ||
| 16 | VadModel, | 19 | VadModel, |
| 17 | VadModelConfig, | 20 | VadModelConfig, |
| 18 | VoiceActivityDetector, | 21 | VoiceActivityDetector, |