Fangjun Kuang
Committed by GitHub

Support spoken language identification with whisper (#694)

Showing 36 changed files with 1,173 additions and 200 deletions
  1 +#!/usr/bin/env bash
  2 +
  3 +set -e
  4 +
  5 +log() {
  6 + # This function is from espnet
  7 + local fname=${BASH_SOURCE[1]##*/}
  8 + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
  9 +}
  10 +
  11 +echo "EXE is $EXE"
  12 +echo "PATH: $PATH"
  13 +
  14 +which $EXE
  15 +
  16 +names=(
  17 +tiny
  18 +base
  19 +small
  20 +medium
  21 +)
  22 +
  23 +# all_language_codes=bo,ml,tt,fa,sl,bg,sn,sr,tl,km,ln,mr,hr,eu,ro,ba,bs,pl,as,nn,sk,ko,oc,ar,uz,pa,tg,mk,kk,hi,ha,uk,is,de,el,ja,yo,be,so,tk,id,sa,ru,yi,en,am,cs,ne,la,sv,su,pt,mi,ca,sd,hy,haw,fi,et,kn,da,lt,it,nl,he,mg,ur,tr,af,br,bn,ta,no,my,si,mt,th,gl,sw,mn,jw,ms,ps,fo,ka,hu,zh,ht,az,fr,lo,sq,gu,cy,lv,es,lb,te,vi
  24 +
  25 +log "Download test waves"
  26 +waves=(
  27 +ar-arabic.wav
  28 +bg-bulgarian.wav
  29 +cs-czech.wav
  30 +da-danish.wav
  31 +de-german.wav
  32 +el-greek.wav
  33 +en-english.wav
  34 +es-spanish.wav
  35 +fa-persian.wav
  36 +fi-finnish.wav
  37 +fr-french.wav
  38 +hi-hindi.wav
  39 +hr-croatian.wav
  40 +id-indonesian.wav
  41 +it-italian.wav
  42 +ja-japanese.wav
  43 +ko-korean.wav
  44 +nl-dutch.wav
  45 +no-norwegian.wav
  46 +po-polish.wav
  47 +pt-portuguese.wav
  48 +ro-romanian.wav
  49 +ru-russian.wav
  50 +sk-slovak.wav
  51 +sv-swedish.wav
  52 +ta-tamil.wav
  53 +tl-tagalog.wav
  54 +tr-turkish.wav
  55 +uk-ukrainian.wav
  56 +zh-chinese.wav
  57 +)
  58 +
  59 +for wav in ${waves[@]}; do
  60 + echo "Downloading $wav"
  61 + curl -SL -O https://hf-mirror.com/spaces/k2-fsa/spoken-language-identification/resolve/main/test_wavs/$wav
  62 + ls -lh *.wav
  63 +done
  64 +
  65 +for name in ${names[@]}; do
  66 + log "------------------------------------------------------------"
  67 + log "Run $name"
  68 + log "------------------------------------------------------------"
  69 +
  70 + repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-whisper-$name
  71 + log "Start testing ${repo_url}"
  72 + repo=$(basename $repo_url)
  73 + log "Download pretrained model and test-data from $repo_url"
  74 +
  75 + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
  76 + pushd $repo
  77 + git lfs pull --include "*.onnx"
  78 + # git lfs pull --include "*.ort"
  79 + ls -lh *.onnx
  80 + popd
  81 +
  82 + for wav in ${waves[@]}; do
  83 + log "test fp32 onnx"
  84 +
  85 + time $EXE \
  86 + --whisper-encoder=$repo/${name}-encoder.onnx \
  87 + --whisper-decoder=$repo/${name}-decoder.onnx \
  88 + $wav
  89 +
  90 + log "test int8 onnx"
  91 +
  92 + time $EXE \
  93 + --whisper-encoder=$repo/${name}-encoder.int8.onnx \
  94 + --whisper-decoder=$repo/${name}-decoder.int8.onnx \
  95 + $wav
  96 + done
  97 + rm -rf $repo
  98 +done
@@ -82,7 +82,6 @@ jobs: @@ -82,7 +82,6 @@ jobs:
82 env: 82 env:
83 HF_TOKEN: ${{ secrets.HF_TOKEN }} 83 HF_TOKEN: ${{ secrets.HF_TOKEN }}
84 uses: nick-fields/retry@v3 84 uses: nick-fields/retry@v3
85 - shell: bash  
86 with: 85 with:
87 max_attempts: 20 86 max_attempts: 20
88 timeout_seconds: 200 87 timeout_seconds: 200
@@ -21,27 +21,12 @@ jobs: @@ -21,27 +21,12 @@ jobs:
21 fail-fast: false 21 fail-fast: false
22 matrix: 22 matrix:
23 os: [macos-latest] 23 os: [macos-latest]
24 - python-version: ["cp37", "cp38", "cp39", "cp310", "cp311", "cp312"] 24 + python-version: ["cp38", "cp39", "cp310", "cp311", "cp312"]
25 25
26 steps: 26 steps:
27 - uses: actions/checkout@v4 27 - uses: actions/checkout@v4
28 28
29 - # see https://cibuildwheel.readthedocs.io/en/stable/changelog/  
30 - # for a list of versions  
31 - name: Build wheels 29 - name: Build wheels
32 - if: matrix.python-version == 'cp37'  
33 - uses: pypa/cibuildwheel@v2.11.4  
34 - env:  
35 - CIBW_BUILD: "${{ matrix.python-version}}-* "  
36 - CIBW_ENVIRONMENT: SHERPA_ONNX_CMAKE_ARGS="-DCMAKE_OSX_ARCHITECTURES='arm64'"  
37 - CIBW_ARCHS: "arm64"  
38 - CIBW_BUILD_VERBOSITY: 3  
39 -  
40 - # Don't repair macOS wheels  
41 - CIBW_REPAIR_WHEEL_COMMAND_MACOS: ""  
42 -  
43 - - name: Build wheels  
44 - if: matrix.python-version != 'cp37'  
45 uses: pypa/cibuildwheel@v2.15.0 30 uses: pypa/cibuildwheel@v2.15.0
46 env: 31 env:
47 CIBW_BUILD: "${{ matrix.python-version}}-* " 32 CIBW_BUILD: "${{ matrix.python-version}}-* "
@@ -92,6 +92,14 @@ jobs: @@ -92,6 +92,14 @@ jobs:
92 file build/bin/sherpa-onnx 92 file build/bin/sherpa-onnx
93 readelf -d build/bin/sherpa-onnx 93 readelf -d build/bin/sherpa-onnx
94 94
  95 + - name: Test spoken language identification
  96 + shell: bash
  97 + run: |
  98 + export PATH=$PWD/build/bin:$PATH
  99 + export EXE=sherpa-onnx-offline-language-identification
  100 +
  101 + .github/scripts/test-spoken-language-identification.sh
  102 +
95 - name: Test online CTC 103 - name: Test online CTC
96 shell: bash 104 shell: bash
97 run: | 105 run: |
@@ -116,6 +124,7 @@ jobs: @@ -116,6 +124,7 @@ jobs:
116 124
117 .github/scripts/test-online-paraformer.sh 125 .github/scripts/test-online-paraformer.sh
118 126
  127 +
119 - name: Test offline Whisper 128 - name: Test offline Whisper
120 shell: bash 129 shell: bash
121 run: | 130 run: |
@@ -123,6 +123,15 @@ jobs: @@ -123,6 +123,15 @@ jobs:
123 name: release-${{ matrix.build_type }}-${{ matrix.shared_lib }} 123 name: release-${{ matrix.build_type }}-${{ matrix.shared_lib }}
124 path: build/bin/* 124 path: build/bin/*
125 125
  126 + - name: Test spoken language identification
  127 + if: matrix.build_type != 'Debug'
  128 + shell: bash
  129 + run: |
  130 + export PATH=$PWD/build/bin:$PATH
  131 + export EXE=sherpa-onnx-offline-language-identification
  132 +
  133 + .github/scripts/test-spoken-language-identification.sh
  134 +
126 - name: Test transducer kws 135 - name: Test transducer kws
127 shell: bash 136 shell: bash
128 run: | 137 run: |
@@ -140,6 +149,7 @@ jobs: @@ -140,6 +149,7 @@ jobs:
140 .github/scripts/test-online-ctc.sh 149 .github/scripts/test-online-ctc.sh
141 150
142 - name: Test offline Whisper 151 - name: Test offline Whisper
  152 + if: matrix.build_type != 'Debug'
143 shell: bash 153 shell: bash
144 run: | 154 run: |
145 export PATH=$PWD/build/bin:$PATH 155 export PATH=$PWD/build/bin:$PATH
@@ -102,6 +102,15 @@ jobs: @@ -102,6 +102,15 @@ jobs:
102 otool -L build/bin/sherpa-onnx 102 otool -L build/bin/sherpa-onnx
103 otool -l build/bin/sherpa-onnx 103 otool -l build/bin/sherpa-onnx
104 104
  105 + - name: Test spoken language identification
  106 + if: matrix.build_type != 'Debug'
  107 + shell: bash
  108 + run: |
  109 + export PATH=$PWD/build/bin:$PATH
  110 + export EXE=sherpa-onnx-offline-language-identification
  111 +
  112 + .github/scripts/test-spoken-language-identification.sh
  113 +
105 - name: Test transducer kws 114 - name: Test transducer kws
106 shell: bash 115 shell: bash
107 run: | 116 run: |
@@ -135,6 +144,7 @@ jobs: @@ -135,6 +144,7 @@ jobs:
135 .github/scripts/test-online-paraformer.sh 144 .github/scripts/test-online-paraformer.sh
136 145
137 - name: Test offline Whisper 146 - name: Test offline Whisper
  147 + if: matrix.build_type != 'Debug'
138 shell: bash 148 shell: bash
139 run: | 149 run: |
140 export PATH=$PWD/build/bin:$PATH 150 export PATH=$PWD/build/bin:$PATH
@@ -68,6 +68,14 @@ jobs: @@ -68,6 +68,14 @@ jobs:
68 68
69 ls -lh ./bin/Release/sherpa-onnx.exe 69 ls -lh ./bin/Release/sherpa-onnx.exe
70 70
  71 + - name: Test spoken language identification
  72 + shell: bash
  73 + run: |
  74 + export PATH=$PWD/build/bin/Release:$PATH
  75 + export EXE=sherpa-onnx-offline-language-identification.exe
  76 +
  77 + .github/scripts/test-spoken-language-identification.sh
  78 +
71 - name: Test online CTC 79 - name: Test online CTC
72 shell: bash 80 shell: bash
73 run: | 81 run: |
@@ -68,6 +68,14 @@ jobs: @@ -68,6 +68,14 @@ jobs:
68 68
69 ls -lh ./bin/Release/sherpa-onnx.exe 69 ls -lh ./bin/Release/sherpa-onnx.exe
70 70
  71 + - name: Test spoken language identification
  72 + shell: bash
  73 + run: |
  74 + export PATH=$PWD/build/bin/Release:$PATH
  75 + export EXE=sherpa-onnx-offline-language-identification.exe
  76 +
  77 + .github/scripts/test-spoken-language-identification.sh
  78 +
71 - name: Test online CTC 79 - name: Test online CTC
72 shell: bash 80 shell: bash
73 run: | 81 run: |
@@ -69,6 +69,14 @@ jobs: @@ -69,6 +69,14 @@ jobs:
69 69
70 ls -lh ./bin/Release/sherpa-onnx.exe 70 ls -lh ./bin/Release/sherpa-onnx.exe
71 71
  72 + # - name: Test spoken language identification
  73 + # shell: bash
  74 + # run: |
  75 + # export PATH=$PWD/build/bin/Release:$PATH
  76 + # export EXE=sherpa-onnx-offline-language-identification.exe
  77 + #
  78 + # .github/scripts/test-spoken-language-identification.sh
  79 +
72 - name: Test online CTC 80 - name: Test online CTC
73 shell: bash 81 shell: bash
74 run: | 82 run: |
1 cmake_minimum_required(VERSION 3.13 FATAL_ERROR) 1 cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
2 project(sherpa-onnx) 2 project(sherpa-onnx)
3 3
4 -set(SHERPA_ONNX_VERSION "1.9.13") 4 +set(SHERPA_ONNX_VERSION "1.9.14")
5 5
6 # Disable warning about 6 # Disable warning about
7 # 7 #
@@ -43,6 +43,50 @@ def enable_alsa(): @@ -43,6 +43,50 @@ def enable_alsa():
43 return build_alsa and is_linux() and (is_arm64() or is_x86()) 43 return build_alsa and is_linux() and (is_arm64() or is_x86())
44 44
45 45
  46 +def get_binaries():
  47 + binaries = [
  48 + "sherpa-onnx",
  49 + "sherpa-onnx-keyword-spotter",
  50 + "sherpa-onnx-microphone",
  51 + "sherpa-onnx-microphone-offline",
  52 + "sherpa-onnx-microphone-offline-speaker-identification",
  53 + "sherpa-onnx-offline",
  54 + "sherpa-onnx-offline-language-identification",
  55 + "sherpa-onnx-offline-tts",
  56 + "sherpa-onnx-offline-tts-play",
  57 + "sherpa-onnx-offline-websocket-server",
  58 + "sherpa-onnx-online-websocket-client",
  59 + "sherpa-onnx-online-websocket-server",
  60 + "sherpa-onnx-vad-microphone",
  61 + "sherpa-onnx-vad-microphone-offline-asr",
  62 + ]
  63 +
  64 + if enable_alsa():
  65 + binaries += [
  66 + "sherpa-onnx-alsa",
  67 + "sherpa-onnx-alsa-offline",
  68 + "sherpa-onnx-alsa-offline-speaker-identification",
  69 + "sherpa-onnx-offline-tts-play-alsa",
  70 + ]
  71 +
  72 + if is_windows():
  73 + binaries += [
  74 + "espeak-ng.dll",
  75 + "kaldi-decoder-core.dll",
  76 + "kaldi-native-fbank-core.dll",
  77 + "onnxruntime.dll",
  78 + "piper_phonemize.dll",
  79 + "sherpa-onnx-c-api.dll",
  80 + "sherpa-onnx-core.dll",
  81 + "sherpa-onnx-fst.lib",
  82 + "sherpa-onnx-kaldifst-core.lib",
  83 + "sherpa-onnx-portaudio.dll",
  84 + "ucd.dll",
  85 + ]
  86 +
  87 + return binaries
  88 +
  89 +
46 try: 90 try:
47 from wheel.bdist_wheel import bdist_wheel as _bdist_wheel 91 from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
48 92
@@ -150,38 +194,7 @@ class BuildExtension(build_ext): @@ -150,38 +194,7 @@ class BuildExtension(build_ext):
150 suffix = ".exe" if is_windows() else "" 194 suffix = ".exe" if is_windows() else ""
151 # Remember to also change setup.py 195 # Remember to also change setup.py
152 196
153 - binaries = ["sherpa-onnx"]  
154 - binaries += ["sherpa-onnx-keyword-spotter"]  
155 - binaries += ["sherpa-onnx-offline"]  
156 - binaries += ["sherpa-onnx-microphone"]  
157 - binaries += ["sherpa-onnx-microphone-offline"]  
158 - binaries += ["sherpa-onnx-microphone-offline-speaker-identification"]  
159 - binaries += ["sherpa-onnx-online-websocket-server"]  
160 - binaries += ["sherpa-onnx-offline-websocket-server"]  
161 - binaries += ["sherpa-onnx-online-websocket-client"]  
162 - binaries += ["sherpa-onnx-vad-microphone"]  
163 - binaries += ["sherpa-onnx-vad-microphone-offline-asr"]  
164 - binaries += ["sherpa-onnx-offline-tts"]  
165 - binaries += ["sherpa-onnx-offline-tts-play"]  
166 -  
167 - if enable_alsa():  
168 - binaries += ["sherpa-onnx-alsa"]  
169 - binaries += ["sherpa-onnx-alsa-offline"]  
170 - binaries += ["sherpa-onnx-offline-tts-play-alsa"]  
171 - binaries += ["sherpa-onnx-alsa-offline-speaker-identification"]  
172 -  
173 - if is_windows():  
174 - binaries += ["kaldi-native-fbank-core.dll"]  
175 - binaries += ["sherpa-onnx-c-api.dll"]  
176 - binaries += ["sherpa-onnx-core.dll"]  
177 - binaries += ["sherpa-onnx-portaudio.dll"]  
178 - binaries += ["onnxruntime.dll"]  
179 - binaries += ["piper_phonemize.dll"]  
180 - binaries += ["espeak-ng.dll"]  
181 - binaries += ["ucd.dll"]  
182 - binaries += ["kaldi-decoder-core.dll"]  
183 - binaries += ["sherpa-onnx-fst.lib"]  
184 - binaries += ["sherpa-onnx-kaldifst-core.lib"] 197 + binaries = get_binaries()
185 198
186 for f in binaries: 199 for f in binaries:
187 suffix = "" if (".dll" in f or ".lib" in f) else suffix 200 suffix = "" if (".dll" in f or ".lib" in f) else suffix
  1 +#!/usr/bin/env python3
  2 +
  3 +"""
  4 +This script shows how to use Python APIs for spoken language identification.
  5 +It detects the language spoken in the given wave file.
  6 +
  7 +Usage:
  8 +
  9 +1. Download a whisper multilingual model. We use a tiny model below.
  10 +Please refer to https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
  11 +to download more models.
  12 +
  13 +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.tar.bz2
  14 +tar xvf sherpa-onnx-whisper-tiny.tar.bz2
  15 +rm sherpa-onnx-whisper-tiny.tar.bz2
  16 +
  17 +We only use the int8.onnx models below.
  18 +
  19 +2. Download a test wave.
  20 +
  21 +You can find many wave files for different languages at
  22 +https://hf-mirror.com/spaces/k2-fsa/spoken-language-identification/tree/main/test_wavs
  23 +
  24 +wget https://hf-mirror.com/spaces/k2-fsa/spoken-language-identification/resolve/main/test_wavs/de-german.wav
  25 +
  26 +python3 ./python-api-examples/spoken-language-identification.py \
  27 + --whisper-encoder=sherpa-onnx-whisper-tiny/tiny-encoder.int8.onnx \
  28 + --whisper-decoder=sherpa-onnx-whisper-tiny/tiny-decoder.int8.onnx \
  29 + --num-threads=1 \
  30 + ./de-german.wav
  31 +"""
  32 +
  33 +import argparse
  34 +import logging
  35 +import time
  36 +import wave
  37 +from pathlib import Path
  38 +from typing import Tuple
  39 +
  40 +import numpy as np
  41 +import sherpa_onnx
  42 +
  43 +
  44 +def get_args():
  45 + parser = argparse.ArgumentParser(
  46 + formatter_class=argparse.ArgumentDefaultsHelpFormatter
  47 + )
  48 +
  49 + parser.add_argument(
  50 + "--whisper-encoder",
  51 + required=True,
  52 + type=str,
  53 + help="Path to a multilingual whisper encoder model",
  54 + )
  55 +
  56 + parser.add_argument(
  57 + "--whisper-decoder",
  58 + required=True,
  59 + type=str,
  60 + help="Path to a multilingual whisper decoder model",
  61 + )
  62 +
  63 + parser.add_argument(
  64 + "--num-threads",
  65 + type=int,
  66 + default=1,
  67 + help="Number of threads for neural network computation",
  68 + )
  69 +
  70 + parser.add_argument(
  71 + "--debug",
  72 + type=bool,
  73 + default=False,
  74 + help="True to show debug messages",
  75 + )
  76 +
  77 + parser.add_argument(
  78 + "--provider",
  79 + type=str,
  80 + default="cpu",
  81 + help="Valid values: cpu, cuda, coreml",
  82 + )
  83 +
  84 + parser.add_argument(
  85 + "sound_file",
  86 + type=str,
  87 + help="The input sound file to identify. It must be of WAVE "
  88 + "format with a single channel, and each sample should be 16-bit, "
  89 + "i.e., int16_t. "
  90 + "The sample rate of the file can be arbitrary and does not need to "
  91 + "be 16 kHz",
  92 + )
  93 +
  94 + return parser.parse_args()
  95 +
  96 +
  97 +def assert_file_exists(filename: str):
  98 + assert Path(filename).is_file(), (
  99 + f"{filename} does not exist!\n"
  100 + "Please refer to "
  101 + "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/index.html to download it"
  102 + )
  103 +
  104 +
  105 +def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
  106 + """
  107 + Args:
  108 + wave_filename:
  109 + Path to a wave file. It should be single channel and each sample should
  110 + be 16-bit. Its sample rate does not need to be 16kHz.
  111 + Returns:
  112 + Return a tuple containing:
  113 + - A 1-D array of dtype np.float32 containing the samples, which are
  114 + normalized to the range [-1, 1].
  115 + - sample rate of the wave file
  116 + """
  117 +
  118 + with wave.open(wave_filename) as f:
  119 + assert f.getnchannels() == 1, f.getnchannels()
  120 + assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes
  121 + num_samples = f.getnframes()
  122 + samples = f.readframes(num_samples)
  123 + samples_int16 = np.frombuffer(samples, dtype=np.int16)
  124 + samples_float32 = samples_int16.astype(np.float32)
  125 +
  126 + samples_float32 = samples_float32 / 32768
  127 + return samples_float32, f.getframerate()
  128 +
  129 +
  130 +def main():
  131 + args = get_args()
  132 + assert_file_exists(args.whisper_encoder)
  133 + assert_file_exists(args.whisper_decoder)
  134 + assert args.num_threads > 0, args.num_threads
  135 + config = sherpa_onnx.SpokenLanguageIdentificationConfig(
  136 + whisper=sherpa_onnx.SpokenLanguageIdentificationWhisperConfig(
  137 + encoder=args.whisper_encoder,
  138 + decoder=args.whisper_decoder,
  139 + ),
  140 + num_threads=args.num_threads,
  141 + debug=args.debug,
  142 + provider=args.provider,
  143 + )
  144 + slid = sherpa_onnx.SpokenLanguageIdentification(config)
  145 +
  146 + samples, sample_rate = read_wave(args.sound_file)
  147 +
  148 + start_time = time.time()
  149 + stream = slid.create_stream()
  150 + stream.accept_waveform(sample_rate=sample_rate, waveform=samples)
  151 + lang = slid.compute(stream)
  152 + end_time = time.time()
  153 +
  154 + elapsed_seconds = end_time - start_time
  155 + audio_duration = len(samples) / sample_rate
  156 + real_time_factor = elapsed_seconds / audio_duration
  157 +
  158 + logging.info(f"File: {args.sound_file}")
  159 + logging.info(f"Detected language: {lang}")
  160 + logging.info(f"Elapsed seconds: {elapsed_seconds:.3f}")
  161 + logging.info(f"Audio duration in seconds: {audio_duration:.3f}")
  162 + logging.info(
  163 + f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}"
  164 + )
  165 +
  166 +
  167 +if __name__ == "__main__":
  168 + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
  169 +
  170 + logging.basicConfig(format=formatter, level=logging.INFO)
  171 +
  172 + main()
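(Not part of this commit.) For readers who want to call the Python API directly instead of via the command line shown in the docstring above, here is a minimal sketch that reuses the same calls from the example script (SpokenLanguageIdentificationConfig, create_stream, accept_waveform, compute) to label several files in one process. The model paths and wave file names are placeholders taken from the usage notes above.

#!/usr/bin/env python3
# Minimal sketch, not part of this commit. Assumes the tiny int8 whisper
# model from the usage notes above has been downloaded and that the wave
# files are 16-bit, single-channel PCM.
import wave

import numpy as np
import sherpa_onnx


def read_wave(path):
    # Same conversion as the example script: int16 -> float32 in [-1, 1].
    with wave.open(path) as f:
        assert f.getnchannels() == 1
        assert f.getsampwidth() == 2
        samples = np.frombuffer(f.readframes(f.getnframes()), dtype=np.int16)
        return samples.astype(np.float32) / 32768, f.getframerate()


config = sherpa_onnx.SpokenLanguageIdentificationConfig(
    whisper=sherpa_onnx.SpokenLanguageIdentificationWhisperConfig(
        encoder="sherpa-onnx-whisper-tiny/tiny-encoder.int8.onnx",
        decoder="sherpa-onnx-whisper-tiny/tiny-decoder.int8.onnx",
    ),
    num_threads=1,
)
slid = sherpa_onnx.SpokenLanguageIdentification(config)

for path in ["de-german.wav", "en-english.wav"]:  # placeholder file names
    samples, sample_rate = read_wave(path)
    stream = slid.create_stream()
    stream.accept_waveform(sample_rate=sample_rate, waveform=samples)
    print(path, "->", slid.compute(stream))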
1 #!/usr/bin/env python3 1 #!/usr/bin/env python3
2 2
3 -import os  
4 import re 3 import re
5 -import sys  
6 from pathlib import Path 4 from pathlib import Path
7 5
8 import setuptools 6 import setuptools
@@ -11,7 +9,7 @@ from cmake.cmake_extension import ( @@ -11,7 +9,7 @@ from cmake.cmake_extension import (
11 BuildExtension, 9 BuildExtension,
12 bdist_wheel, 10 bdist_wheel,
13 cmake_extension, 11 cmake_extension,
14 - enable_alsa, 12 + get_binaries,
15 is_windows, 13 is_windows,
16 ) 14 )
17 15
@@ -42,39 +40,7 @@ def get_binaries_to_install(): @@ -42,39 +40,7 @@ def get_binaries_to_install():
42 bin_dir.mkdir(parents=True, exist_ok=True) 40 bin_dir.mkdir(parents=True, exist_ok=True)
43 suffix = ".exe" if is_windows() else "" 41 suffix = ".exe" if is_windows() else ""
44 42
45 - # Remember to also change cmake/cmake_extension.py  
46 - binaries = ["sherpa-onnx"]  
47 - binaries += ["sherpa-onnx-keyword-spotter"]  
48 - binaries += ["sherpa-onnx-offline"]  
49 - binaries += ["sherpa-onnx-microphone"]  
50 - binaries += ["sherpa-onnx-microphone-offline"]  
51 - binaries += ["sherpa-onnx-microphone-offline-speaker-identification"]  
52 - binaries += ["sherpa-onnx-online-websocket-server"]  
53 - binaries += ["sherpa-onnx-offline-websocket-server"]  
54 - binaries += ["sherpa-onnx-online-websocket-client"]  
55 - binaries += ["sherpa-onnx-vad-microphone"]  
56 - binaries += ["sherpa-onnx-vad-microphone-offline-asr"]  
57 - binaries += ["sherpa-onnx-offline-tts"]  
58 - binaries += ["sherpa-onnx-offline-tts-play"]  
59 -  
60 - if enable_alsa():  
61 - binaries += ["sherpa-onnx-alsa"]  
62 - binaries += ["sherpa-onnx-alsa-offline"]  
63 - binaries += ["sherpa-onnx-offline-tts-play-alsa"]  
64 - binaries += ["sherpa-onnx-alsa-offline-speaker-identification"]  
65 -  
66 - if is_windows():  
67 - binaries += ["kaldi-native-fbank-core.dll"]  
68 - binaries += ["sherpa-onnx-c-api.dll"]  
69 - binaries += ["sherpa-onnx-core.dll"]  
70 - binaries += ["sherpa-onnx-portaudio.dll"]  
71 - binaries += ["onnxruntime.dll"]  
72 - binaries += ["piper_phonemize.dll"]  
73 - binaries += ["espeak-ng.dll"]  
74 - binaries += ["ucd.dll"]  
75 - binaries += ["kaldi-decoder-core.dll"]  
76 - binaries += ["sherpa-onnx-fst.lib"]  
77 - binaries += ["sherpa-onnx-kaldifst-core.lib"] 43 + binaries = get_binaries()
78 44
79 exe = [] 45 exe = []
80 for f in binaries: 46 for f in binaries:
@@ -86,6 +86,8 @@ set(sources @@ -86,6 +86,8 @@ set(sources
86 silero-vad-model-config.cc 86 silero-vad-model-config.cc
87 silero-vad-model.cc 87 silero-vad-model.cc
88 slice.cc 88 slice.cc
  89 + spoken-language-identification-impl.cc
  90 + spoken-language-identification.cc
89 stack.cc 91 stack.cc
90 symbol-table.cc 92 symbol-table.cc
91 text-utils.cc 93 text-utils.cc
@@ -184,6 +186,7 @@ if(SHERPA_ONNX_ENABLE_BINARY) @@ -184,6 +186,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
184 add_executable(sherpa-onnx-offline sherpa-onnx-offline.cc) 186 add_executable(sherpa-onnx-offline sherpa-onnx-offline.cc)
185 add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc) 187 add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc)
186 add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc) 188 add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc)
  189 + add_executable(sherpa-onnx-offline-language-identification sherpa-onnx-offline-language-identification.cc)
187 190
188 set(main_exes 191 set(main_exes
189 sherpa-onnx 192 sherpa-onnx
@@ -191,6 +194,7 @@ if(SHERPA_ONNX_ENABLE_BINARY) @@ -191,6 +194,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
191 sherpa-onnx-offline 194 sherpa-onnx-offline
192 sherpa-onnx-offline-parallel 195 sherpa-onnx-offline-parallel
193 sherpa-onnx-offline-tts 196 sherpa-onnx-offline-tts
  197 + sherpa-onnx-offline-language-identification
194 ) 198 )
195 199
196 foreach(exe IN LISTS main_exes) 200 foreach(exe IN LISTS main_exes)
@@ -23,7 +23,7 @@ enum class ModelType { @@ -23,7 +23,7 @@ enum class ModelType {
23 kTdnn, 23 kTdnn,
24 kZipformerCtc, 24 kZipformerCtc,
25 kWenetCtc, 25 kWenetCtc,
26 - kUnkown, 26 + kUnknown,
27 }; 27 };
28 28
29 } // namespace 29 } // namespace
@@ -59,7 +59,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, @@ -59,7 +59,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
59 "run.sh\n" 59 "run.sh\n"
60 "\n" 60 "\n"
61 "for how to add metadta to model.onnx\n"); 61 "for how to add metadta to model.onnx\n");
62 - return ModelType::kUnkown; 62 + return ModelType::kUnknown;
63 } 63 }
64 64
65 if (model_type.get() == std::string("EncDecCTCModelBPE")) { 65 if (model_type.get() == std::string("EncDecCTCModelBPE")) {
@@ -72,13 +72,13 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, @@ -72,13 +72,13 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
72 return ModelType::kWenetCtc; 72 return ModelType::kWenetCtc;
73 } else { 73 } else {
74 SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); 74 SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
75 - return ModelType::kUnkown; 75 + return ModelType::kUnknown;
76 } 76 }
77 } 77 }
78 78
79 std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( 79 std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
80 const OfflineModelConfig &config) { 80 const OfflineModelConfig &config) {
81 - ModelType model_type = ModelType::kUnkown; 81 + ModelType model_type = ModelType::kUnknown;
82 82
83 std::string filename; 83 std::string filename;
84 if (!config.nemo_ctc.model.empty()) { 84 if (!config.nemo_ctc.model.empty()) {
@@ -113,7 +113,7 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( @@ -113,7 +113,7 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
113 case ModelType::kWenetCtc: 113 case ModelType::kWenetCtc:
114 return std::make_unique<OfflineWenetCtcModel>(config); 114 return std::make_unique<OfflineWenetCtcModel>(config);
115 break; 115 break;
116 - case ModelType::kUnkown: 116 + case ModelType::kUnknown:
117 SHERPA_ONNX_LOGE("Unknown model type in offline CTC!"); 117 SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
118 return nullptr; 118 return nullptr;
119 } 119 }
@@ -125,7 +125,7 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( @@ -125,7 +125,7 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
125 125
126 std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( 126 std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
127 AAssetManager *mgr, const OfflineModelConfig &config) { 127 AAssetManager *mgr, const OfflineModelConfig &config) {
128 - ModelType model_type = ModelType::kUnkown; 128 + ModelType model_type = ModelType::kUnknown;
129 129
130 std::string filename; 130 std::string filename;
131 if (!config.nemo_ctc.model.empty()) { 131 if (!config.nemo_ctc.model.empty()) {
@@ -160,7 +160,7 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( @@ -160,7 +160,7 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
160 case ModelType::kWenetCtc: 160 case ModelType::kWenetCtc:
161 return std::make_unique<OfflineWenetCtcModel>(mgr, config); 161 return std::make_unique<OfflineWenetCtcModel>(mgr, config);
162 break; 162 break;
163 - case ModelType::kUnkown: 163 + case ModelType::kUnknown:
164 SHERPA_ONNX_LOGE("Unknown model type in offline CTC!"); 164 SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
165 return nullptr; 165 return nullptr;
166 } 166 }
@@ -114,7 +114,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { @@ -114,7 +114,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
114 num_frames = max_num_frames - 50; 114 num_frames = max_num_frames - 50;
115 } 115 }
116 116
117 - NormalizeFeatures(f.data(), num_frames, feat_dim); 117 + model_->NormalizeFeatures(f.data(), num_frames, feat_dim);
118 118
119 // note that 1000 is an experience-value. 119 // note that 1000 is an experience-value.
120 // You can replace 1000 by other values, say, 100. 120 // You can replace 1000 by other values, say, 100.
@@ -163,38 +163,6 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { @@ -163,38 +163,6 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
163 } 163 }
164 164
165 private: 165 private:
166 - static void NormalizeFeatures(float *features, int32_t num_frames,  
167 - int32_t feat_dim) {  
168 - // log_spec = torch.clamp(features, min=1e-10).log10()  
169 - // log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)  
170 - // mel = (log_spec + 4.0) / 4.0  
171 -  
172 - int32_t n = num_frames * feat_dim;  
173 - float max_v = -1e20;  
174 - for (int32_t i = 0; i != n; ++i) {  
175 - float f = features[i];  
176 -  
177 - f = std::max<float>(f, 1e-10);  
178 - f = std::log10(f);  
179 -  
180 - max_v = std::max(f, max_v);  
181 -  
182 - features[i] = f;  
183 - }  
184 -  
185 - max_v -= 8;  
186 -  
187 - for (int32_t i = 0; i != n; ++i) {  
188 - float f = features[i];  
189 - f = std::max(f, max_v);  
190 -  
191 - f = (f + 4) / 4;  
192 -  
193 - features[i] = f;  
194 - }  
195 - }  
196 -  
197 - private:  
198 OfflineRecognizerConfig config_; 166 OfflineRecognizerConfig config_;
199 SymbolTable symbol_table_; 167 SymbolTable symbol_table_;
200 std::unique_ptr<OfflineWhisperModel> model_; 168 std::unique_ptr<OfflineWhisperModel> model_;
@@ -12,56 +12,6 @@ @@ -12,56 +12,6 @@
12 12
13 namespace sherpa_onnx { 13 namespace sherpa_onnx {
14 14
15 -int32_t OfflineWhisperGreedySearchDecoder::DetectLanguage(  
16 - Ort::Value &cross_k, Ort::Value &cross_v) const { // NOLINT  
17 - int64_t token_val = model_->SOT();  
18 - std::array<int64_t, 2> token_shape{1, 1};  
19 -  
20 - auto memory_info =  
21 - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);  
22 -  
23 - Ort::Value tokens = Ort::Value::CreateTensor(  
24 - memory_info, &token_val, 1, token_shape.data(), token_shape.size());  
25 -  
26 - auto self_kv_cache = model_->GetInitialSelfKVCache();  
27 -  
28 - std::array<int64_t, 1> offset_shape{1};  
29 - Ort::Value offset = Ort::Value::CreateTensor<int64_t>(  
30 - model_->Allocator(), offset_shape.data(), offset_shape.size());  
31 - *(offset.GetTensorMutableData<int64_t>()) = 0;  
32 -  
33 - auto decoder_out = model_->ForwardDecoder(  
34 - std::move(tokens), std::move(self_kv_cache.first),  
35 - std::move(self_kv_cache.second), std::move(cross_k), std::move(cross_v),  
36 - std::move(offset));  
37 -  
38 - cross_k = std::move(std::get<3>(decoder_out));  
39 - cross_v = std::move(std::get<4>(decoder_out));  
40 -  
41 - const float *p_logits = std::get<0>(decoder_out).GetTensorData<float>();  
42 - int32_t vocab_size = model_->VocabSize();  
43 - const auto &all_language_ids = model_->GetAllLanguageIDs();  
44 -  
45 - int32_t lang_id = all_language_ids[0];  
46 - float this_logit = p_logits[lang_id];  
47 -  
48 - for (int32_t i = 1; i != all_language_ids.size(); ++i) {  
49 - int32_t id = all_language_ids[i];  
50 - float p = p_logits[id];  
51 -  
52 - if (p > this_logit) {  
53 - this_logit = p;  
54 - lang_id = id;  
55 - }  
56 - }  
57 -#if 1  
58 - SHERPA_ONNX_LOGE("Detected language: %s",  
59 - model_->GetID2Lang().at(lang_id).c_str());  
60 -#endif  
61 -  
62 - return lang_id;  
63 -}  
64 -  
65 std::vector<OfflineWhisperDecoderResult> 15 std::vector<OfflineWhisperDecoderResult>
66 OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k, 16 OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
67 Ort::Value cross_v) { 17 Ort::Value cross_v) {
@@ -89,7 +39,7 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k, @@ -89,7 +39,7 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
89 // 0: sot, 1: lang_id, 2: task, 3: no_timestamps 39 // 0: sot, 1: lang_id, 2: task, 3: no_timestamps
90 initial_tokens[1] = lang_id; 40 initial_tokens[1] = lang_id;
91 } else { 41 } else {
92 - int32_t lang_id = DetectLanguage(cross_k, cross_v); 42 + int32_t lang_id = model_->DetectLanguage(cross_k, cross_v);
93 43
94 // 0: sot, 1: lang_id, 2: task, 3: no_timestamps 44 // 0: sot, 1: lang_id, 2: task, 3: no_timestamps
95 initial_tokens[1] = lang_id; 45 initial_tokens[1] = lang_id;
@@ -22,9 +22,6 @@ class OfflineWhisperGreedySearchDecoder : public OfflineWhisperDecoder { @@ -22,9 +22,6 @@ class OfflineWhisperGreedySearchDecoder : public OfflineWhisperDecoder {
22 std::vector<OfflineWhisperDecoderResult> Decode(Ort::Value cross_k, 22 std::vector<OfflineWhisperDecoderResult> Decode(Ort::Value cross_k,
23 Ort::Value cross_v) override; 23 Ort::Value cross_v) override;
24 24
25 - int32_t DetectLanguage(Ort::Value &cross_k, // NOLINT  
26 - Ort::Value &cross_v) const; // NOLINT  
27 -  
28 private: 25 private:
29 OfflineWhisperModelConfig config_; 26 OfflineWhisperModelConfig config_;
30 OfflineWhisperModel *model_; // not owned 27 OfflineWhisperModel *model_; // not owned
@@ -35,19 +35,28 @@ void OfflineWhisperModelConfig::Register(ParseOptions *po) { @@ -35,19 +35,28 @@ void OfflineWhisperModelConfig::Register(ParseOptions *po) {
35 35
36 po->Register( 36 po->Register(
37 "whisper-tail-paddings", &tail_paddings, 37 "whisper-tail-paddings", &tail_paddings,
38 - "Suggest value: 50 for English models. 300 for multilingual models. " 38 + "Suggested value: 50 for English models. 300 for multilingual models. "
39 "Since we have removed the 30-second constraint, we need to add some " 39 "Since we have removed the 30-second constraint, we need to add some "
40 "tail padding frames " 40 "tail padding frames "
41 - "so that whisper can detect the eot token. Leave it to -1 to use 50 for "  
42 - "English models and 300 for multilingual models."); 41 + "so that whisper can detect the eot token. Leave it to -1 to use 1000.");
43 } 42 }
44 43
45 bool OfflineWhisperModelConfig::Validate() const { 44 bool OfflineWhisperModelConfig::Validate() const {
  45 + if (encoder.empty()) {
  46 + SHERPA_ONNX_LOGE("Please provide --whisper-encoder");
  47 + return false;
  48 + }
  49 +
46 if (!FileExists(encoder)) { 50 if (!FileExists(encoder)) {
47 SHERPA_ONNX_LOGE("whisper encoder file %s does not exist", encoder.c_str()); 51 SHERPA_ONNX_LOGE("whisper encoder file %s does not exist", encoder.c_str());
48 return false; 52 return false;
49 } 53 }
50 54
  55 + if (decoder.empty()) {
  56 + SHERPA_ONNX_LOGE("Please provide --whisper-decoder");
  57 + return false;
  58 + }
  59 +
51 if (!FileExists(decoder)) { 60 if (!FileExists(decoder)) {
52 SHERPA_ONNX_LOGE("whisper decoder file %s does not exist", decoder.c_str()); 61 SHERPA_ONNX_LOGE("whisper decoder file %s does not exist", decoder.c_str());
53 return false; 62 return false;
@@ -24,6 +24,24 @@ class OfflineWhisperModel::Impl { @@ -24,6 +24,24 @@ class OfflineWhisperModel::Impl {
24 env_(ORT_LOGGING_LEVEL_ERROR), 24 env_(ORT_LOGGING_LEVEL_ERROR),
25 sess_opts_(GetSessionOptions(config)), 25 sess_opts_(GetSessionOptions(config)),
26 allocator_{} { 26 allocator_{} {
  27 + debug_ = config_.debug;
  28 + {
  29 + auto buf = ReadFile(config.whisper.encoder);
  30 + InitEncoder(buf.data(), buf.size());
  31 + }
  32 +
  33 + {
  34 + auto buf = ReadFile(config.whisper.decoder);
  35 + InitDecoder(buf.data(), buf.size());
  36 + }
  37 + }
  38 +
  39 + explicit Impl(const SpokenLanguageIdentificationConfig &config)
  40 + : lid_config_(config),
  41 + env_(ORT_LOGGING_LEVEL_ERROR),
  42 + sess_opts_(GetSessionOptions(config)),
  43 + allocator_{} {
  44 + debug_ = config_.debug;
27 { 45 {
28 auto buf = ReadFile(config.whisper.encoder); 46 auto buf = ReadFile(config.whisper.encoder);
29 InitEncoder(buf.data(), buf.size()); 47 InitEncoder(buf.data(), buf.size());
@@ -41,6 +59,7 @@ class OfflineWhisperModel::Impl { @@ -41,6 +59,7 @@ class OfflineWhisperModel::Impl {
41 env_(ORT_LOGGING_LEVEL_ERROR), 59 env_(ORT_LOGGING_LEVEL_ERROR),
42 sess_opts_(GetSessionOptions(config)), 60 sess_opts_(GetSessionOptions(config)),
43 allocator_{} { 61 allocator_{} {
  62 + debug_ = config_.debug;
44 { 63 {
45 auto buf = ReadFile(mgr, config.whisper.encoder); 64 auto buf = ReadFile(mgr, config.whisper.encoder);
46 InitEncoder(buf.data(), buf.size()); 65 InitEncoder(buf.data(), buf.size());
@@ -85,6 +104,57 @@ class OfflineWhisperModel::Impl { @@ -85,6 +104,57 @@ class OfflineWhisperModel::Impl {
85 std::move(decoder_input[4]), std::move(decoder_input[5])}; 104 std::move(decoder_input[4]), std::move(decoder_input[5])};
86 } 105 }
87 106
  107 + int32_t DetectLanguage(Ort::Value &cross_k, // NOLINT
  108 + Ort::Value &cross_v) { // NOLINT
  109 + int64_t token_val = SOT();
  110 + std::array<int64_t, 2> token_shape{1, 1};
  111 +
  112 + auto memory_info =
  113 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  114 +
  115 + Ort::Value tokens = Ort::Value::CreateTensor(
  116 + memory_info, &token_val, 1, token_shape.data(), token_shape.size());
  117 +
  118 + auto self_kv_cache = GetInitialSelfKVCache();
  119 +
  120 + std::array<int64_t, 1> offset_shape{1};
  121 + Ort::Value offset = Ort::Value::CreateTensor<int64_t>(
  122 + Allocator(), offset_shape.data(), offset_shape.size());
  123 + *(offset.GetTensorMutableData<int64_t>()) = 0;
  124 +
  125 + auto decoder_out =
  126 + ForwardDecoder(std::move(tokens), std::move(self_kv_cache.first),
  127 + std::move(self_kv_cache.second), std::move(cross_k),
  128 + std::move(cross_v), std::move(offset));
  129 +
  130 + cross_k = std::move(std::get<3>(decoder_out));
  131 + cross_v = std::move(std::get<4>(decoder_out));
  132 +
  133 + const float *p_logits = std::get<0>(decoder_out).GetTensorData<float>();
  134 + int32_t vocab_size = VocabSize();
  135 + const auto &all_language_ids = GetAllLanguageIDs();
  136 +
  137 + int32_t lang_id = all_language_ids[0];
  138 + float this_logit = p_logits[lang_id];
  139 +
  140 + for (int32_t i = 1; i != all_language_ids.size(); ++i) {
  141 + int32_t id = all_language_ids[i];
  142 + float p = p_logits[id];
  143 +
  144 + if (p > this_logit) {
  145 + this_logit = p;
  146 + lang_id = id;
  147 + }
  148 + }
  149 +
  150 + if (debug_) {
  151 + SHERPA_ONNX_LOGE("Detected language: %s",
  152 + GetID2Lang().at(lang_id).c_str());
  153 + }
  154 +
  155 + return lang_id;
  156 + }
  157 +
88 std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache() { 158 std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache() {
89 std::array<int64_t, 4> shape{n_text_layer_, 1, n_text_ctx_, n_text_state_}; 159 std::array<int64_t, 4> shape{n_text_layer_, 1, n_text_ctx_, n_text_state_};
90 160
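Side note (not part of the diff): the DetectLanguage method added above runs the decoder once on the SOT token, keeps only the logits of Whisper's language tokens, and returns the ID with the largest logit. A small NumPy sketch of that selection step, where the names logits and language_token_ids are hypothetical stand-ins for the first decoder output and GetAllLanguageIDs():

import numpy as np


def pick_language_token(logits: np.ndarray, language_token_ids: list) -> int:
    # logits: vocabulary-sized logit vector from the first decoder step
    # language_token_ids: token IDs of all languages known to the model
    lang_logits = logits[language_token_ids]
    return int(language_token_ids[int(np.argmax(lang_logits))])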
@@ -148,7 +218,7 @@ class OfflineWhisperModel::Impl { @@ -148,7 +218,7 @@ class OfflineWhisperModel::Impl {
148 218
149 // get meta data 219 // get meta data
150 Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata(); 220 Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
151 - if (config_.debug) { 221 + if (debug_) {
152 std::ostringstream os; 222 std::ostringstream os;
153 os << "---encoder---\n"; 223 os << "---encoder---\n";
154 PrintModelMetadata(os, meta_data); 224 PrintModelMetadata(os, meta_data);
@@ -203,6 +273,8 @@ class OfflineWhisperModel::Impl { @@ -203,6 +273,8 @@ class OfflineWhisperModel::Impl {
203 273
204 private: 274 private:
205 OfflineModelConfig config_; 275 OfflineModelConfig config_;
  276 + SpokenLanguageIdentificationConfig lid_config_;
  277 + bool debug_ = false;
206 Ort::Env env_; 278 Ort::Env env_;
207 Ort::SessionOptions sess_opts_; 279 Ort::SessionOptions sess_opts_;
208 Ort::AllocatorWithDefaultOptions allocator_; 280 Ort::AllocatorWithDefaultOptions allocator_;
@@ -246,6 +318,10 @@ class OfflineWhisperModel::Impl { @@ -246,6 +318,10 @@ class OfflineWhisperModel::Impl {
246 OfflineWhisperModel::OfflineWhisperModel(const OfflineModelConfig &config) 318 OfflineWhisperModel::OfflineWhisperModel(const OfflineModelConfig &config)
247 : impl_(std::make_unique<Impl>(config)) {} 319 : impl_(std::make_unique<Impl>(config)) {}
248 320
  321 +OfflineWhisperModel::OfflineWhisperModel(
  322 + const SpokenLanguageIdentificationConfig &config)
  323 + : impl_(std::make_unique<Impl>(config)) {}
  324 +
249 #if __ANDROID_API__ >= 9 325 #if __ANDROID_API__ >= 9
250 OfflineWhisperModel::OfflineWhisperModel(AAssetManager *mgr, 326 OfflineWhisperModel::OfflineWhisperModel(AAssetManager *mgr,
251 const OfflineModelConfig &config) 327 const OfflineModelConfig &config)
@@ -273,6 +349,11 @@ OfflineWhisperModel::ForwardDecoder(Ort::Value tokens, @@ -273,6 +349,11 @@ OfflineWhisperModel::ForwardDecoder(Ort::Value tokens,
273 std::move(n_layer_cross_v), std::move(offset)); 349 std::move(n_layer_cross_v), std::move(offset));
274 } 350 }
275 351
  352 +int32_t OfflineWhisperModel::DetectLanguage(Ort::Value &cross_k, // NOLINT
  353 + Ort::Value &cross_v) { // NOLINT
  354 + return impl_->DetectLanguage(cross_k, cross_v);
  355 +}
  356 +
276 std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::GetInitialSelfKVCache() 357 std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::GetInitialSelfKVCache()
277 const { 358 const {
278 return impl_->GetInitialSelfKVCache(); 359 return impl_->GetInitialSelfKVCache();
@@ -318,4 +399,35 @@ bool OfflineWhisperModel::IsMultiLingual() const { @@ -318,4 +399,35 @@ bool OfflineWhisperModel::IsMultiLingual() const {
318 return impl_->IsMultiLingual(); 399 return impl_->IsMultiLingual();
319 } 400 }
320 401
  402 +void OfflineWhisperModel::NormalizeFeatures(float *features, int32_t num_frames,
  403 + int32_t feat_dim) {
  404 + // log_spec = torch.clamp(features, min=1e-10).log10()
  405 + // log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
  406 + // mel = (log_spec + 4.0) / 4.0
  407 +
  408 + int32_t n = num_frames * feat_dim;
  409 + float max_v = -1e20;
  410 + for (int32_t i = 0; i != n; ++i) {
  411 + float f = features[i];
  412 +
  413 + f = std::max<float>(f, 1e-10);
  414 + f = std::log10(f);
  415 +
  416 + max_v = std::max(f, max_v);
  417 +
  418 + features[i] = f;
  419 + }
  420 +
  421 + max_v -= 8;
  422 +
  423 + for (int32_t i = 0; i != n; ++i) {
  424 + float f = features[i];
  425 + f = std::max(f, max_v);
  426 +
  427 + f = (f + 4) / 4;
  428 +
  429 + features[i] = f;
  430 + }
  431 +}
  432 +
321 } // namespace sherpa_onnx 433 } // namespace sherpa_onnx
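(Not part of this commit.) The NormalizeFeatures helper moved into OfflineWhisperModel above implements the same normalization as the torch snippet quoted in its comment. A NumPy equivalent for reference, assuming features is a (num_frames, feat_dim) mel spectrogram:

import numpy as np


def normalize_whisper_features(features: np.ndarray) -> np.ndarray:
    # log_spec = torch.clamp(features, min=1e-10).log10()
    log_spec = np.log10(np.maximum(features, 1e-10))
    # log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
    # mel = (log_spec + 4.0) / 4.0
    return (log_spec + 4.0) / 4.0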
@@ -18,6 +18,7 @@ @@ -18,6 +18,7 @@
18 18
19 #include "onnxruntime_cxx_api.h" // NOLINT 19 #include "onnxruntime_cxx_api.h" // NOLINT
20 #include "sherpa-onnx/csrc/offline-model-config.h" 20 #include "sherpa-onnx/csrc/offline-model-config.h"
  21 +#include "sherpa-onnx/csrc/spoken-language-identification.h"
21 22
22 namespace sherpa_onnx { 23 namespace sherpa_onnx {
23 24
@@ -25,6 +26,9 @@ class OfflineWhisperModel { @@ -25,6 +26,9 @@ class OfflineWhisperModel {
25 public: 26 public:
26 explicit OfflineWhisperModel(const OfflineModelConfig &config); 27 explicit OfflineWhisperModel(const OfflineModelConfig &config);
27 28
  29 + explicit OfflineWhisperModel(
  30 + const SpokenLanguageIdentificationConfig &config);
  31 +
28 #if __ANDROID_API__ >= 9 32 #if __ANDROID_API__ >= 9
29 OfflineWhisperModel(AAssetManager *mgr, const OfflineModelConfig &config); 33 OfflineWhisperModel(AAssetManager *mgr, const OfflineModelConfig &config);
30 #endif 34 #endif
@@ -72,7 +76,8 @@ class OfflineWhisperModel { @@ -72,7 +76,8 @@ class OfflineWhisperModel {
72 Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k, 76 Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k,
73 Ort::Value n_layer_cross_v, Ort::Value offset) const; 77 Ort::Value n_layer_cross_v, Ort::Value offset) const;
74 78
75 - int32_t DetectLanguage() const; 79 + int32_t DetectLanguage(Ort::Value &cross_k, // NOLINT
  80 + Ort::Value &cross_v); // NOLINT
76 81
77 /** Return the initial self kv cache in a pair 82 /** Return the initial self kv cache in a pair
78 * - n_layer_self_k_cache A 4-D tensor of shape 83 * - n_layer_self_k_cache A 4-D tensor of shape
@@ -98,6 +103,9 @@ class OfflineWhisperModel { @@ -98,6 +103,9 @@ class OfflineWhisperModel {
98 int32_t Translate() const; 103 int32_t Translate() const;
99 bool IsMultiLingual() const; 104 bool IsMultiLingual() const;
100 105
  106 + static void NormalizeFeatures(float *features, int32_t num_frames,
  107 + int32_t feat_dim);
  108 +
101 private: 109 private:
102 class Impl; 110 class Impl;
103 std::unique_ptr<Impl> impl_; 111 std::unique_ptr<Impl> impl_;
@@ -28,7 +28,7 @@ enum class ModelType { @@ -28,7 +28,7 @@ enum class ModelType {
28 kLstm, 28 kLstm,
29 kZipformer, 29 kZipformer,
30 kZipformer2, 30 kZipformer2,
31 - kUnkown, 31 + kUnknown,
32 }; 32 };
33 33
34 } // namespace 34 } // namespace
@@ -58,7 +58,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, @@ -58,7 +58,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
58 "No model_type in the metadata!\n" 58 "No model_type in the metadata!\n"
59 "Please make sure you are using the latest export-onnx.py from icefall " 59 "Please make sure you are using the latest export-onnx.py from icefall "
60 "to export your transducer models"); 60 "to export your transducer models");
61 - return ModelType::kUnkown; 61 + return ModelType::kUnknown;
62 } 62 }
63 63
64 if (model_type.get() == std::string("conformer")) { 64 if (model_type.get() == std::string("conformer")) {
@@ -71,7 +71,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, @@ -71,7 +71,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
71 return ModelType::kZipformer2; 71 return ModelType::kZipformer2;
72 } else { 72 } else {
73 SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); 73 SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
74 - return ModelType::kUnkown; 74 + return ModelType::kUnknown;
75 } 75 }
76 } 76 }
77 77
@@ -93,7 +93,7 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create( @@ -93,7 +93,7 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
93 model_type.c_str()); 93 model_type.c_str());
94 } 94 }
95 } 95 }
96 - ModelType model_type = ModelType::kUnkown; 96 + ModelType model_type = ModelType::kUnknown;
97 97
98 { 98 {
99 auto buffer = ReadFile(config.transducer.encoder); 99 auto buffer = ReadFile(config.transducer.encoder);
@@ -110,7 +110,7 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create( @@ -110,7 +110,7 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
110 return std::make_unique<OnlineZipformerTransducerModel>(config); 110 return std::make_unique<OnlineZipformerTransducerModel>(config);
111 case ModelType::kZipformer2: 111 case ModelType::kZipformer2:
112 return std::make_unique<OnlineZipformer2TransducerModel>(config); 112 return std::make_unique<OnlineZipformer2TransducerModel>(config);
113 - case ModelType::kUnkown: 113 + case ModelType::kUnknown:
114 SHERPA_ONNX_LOGE("Unknown model type in online transducer!"); 114 SHERPA_ONNX_LOGE("Unknown model type in online transducer!");
115 return nullptr; 115 return nullptr;
116 } 116 }
@@ -185,7 +185,7 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create( @@ -185,7 +185,7 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
185 return std::make_unique<OnlineZipformerTransducerModel>(mgr, config); 185 return std::make_unique<OnlineZipformerTransducerModel>(mgr, config);
186 case ModelType::kZipformer2: 186 case ModelType::kZipformer2:
187 return std::make_unique<OnlineZipformer2TransducerModel>(mgr, config); 187 return std::make_unique<OnlineZipformer2TransducerModel>(mgr, config);
188 - case ModelType::kUnkown: 188 + case ModelType::kUnknown:
189 SHERPA_ONNX_LOGE("Unknown model type in online transducer!"); 189 SHERPA_ONNX_LOGE("Unknown model type in online transducer!");
190 return nullptr; 190 return nullptr;
191 } 191 }
@@ -149,4 +149,9 @@ Ort::SessionOptions GetSessionOptions( @@ -149,4 +149,9 @@ Ort::SessionOptions GetSessionOptions(
149 return GetSessionOptionsImpl(config.num_threads, config.provider); 149 return GetSessionOptionsImpl(config.num_threads, config.provider);
150 } 150 }
151 151
  152 +Ort::SessionOptions GetSessionOptions(
  153 + const SpokenLanguageIdentificationConfig &config) {
  154 + return GetSessionOptionsImpl(config.num_threads, config.provider);
  155 +}
  156 +
152 } // namespace sherpa_onnx 157 } // namespace sherpa_onnx
@@ -12,6 +12,7 @@ @@ -12,6 +12,7 @@
12 #include "sherpa-onnx/csrc/online-lm-config.h" 12 #include "sherpa-onnx/csrc/online-lm-config.h"
13 #include "sherpa-onnx/csrc/online-model-config.h" 13 #include "sherpa-onnx/csrc/online-model-config.h"
14 #include "sherpa-onnx/csrc/speaker-embedding-extractor.h" 14 #include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
  15 +#include "sherpa-onnx/csrc/spoken-language-identification.h"
15 #include "sherpa-onnx/csrc/vad-model-config.h" 16 #include "sherpa-onnx/csrc/vad-model-config.h"
16 17
17 namespace sherpa_onnx { 18 namespace sherpa_onnx {
@@ -30,6 +31,10 @@ Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config); @@ -30,6 +31,10 @@ Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config);
30 31
31 Ort::SessionOptions GetSessionOptions( 32 Ort::SessionOptions GetSessionOptions(
32 const SpeakerEmbeddingExtractorConfig &config); 33 const SpeakerEmbeddingExtractorConfig &config);
  34 +
  35 +Ort::SessionOptions GetSessionOptions(
  36 + const SpokenLanguageIdentificationConfig &config);
  37 +
33 } // namespace sherpa_onnx 38 } // namespace sherpa_onnx
34 39
35 #endif // SHERPA_ONNX_CSRC_SESSION_H_ 40 #endif // SHERPA_ONNX_CSRC_SESSION_H_
  1 +// sherpa-onnx/csrc/sherpa-onnx-offline-language-identification.cc
  2 +//
  3 +// Copyright (c) 2022-2024 Xiaomi Corporation
  4 +
  5 +#include <stdio.h>
  6 +
  7 +#include <chrono> // NOLINT
  8 +#include <string>
  9 +#include <vector>
  10 +
  11 +#include "sherpa-onnx/csrc/parse-options.h"
  12 +#include "sherpa-onnx/csrc/spoken-language-identification.h"
  13 +#include "sherpa-onnx/csrc/wave-reader.h"
  14 +
  15 +int main(int32_t argc, char *argv[]) {
  16 + const char *kUsageMessage = R"usage(
  17 +Spoken language identification with sherpa-onnx.
  18 +
  19 +Usage:
  20 +
  21 +(1) Use a whisper multilingual model
  22 +
  23 +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.tar.bz2
  24 +tar xvf sherpa-onnx-whisper-tiny.tar.bz2
  25 +rm sherpa-onnx-whisper-tiny.tar.bz2
  26 +
  27 +We only use the int8.onnx models below.
  28 +
  29 +./bin/sherpa-onnx-offline-language-identification \
  30 + --whisper-encoder=sherpa-onnx-whisper-tiny/tiny-encoder.int8.onnx \
  31 + --whisper-decoder=sherpa-onnx-whisper-tiny/tiny-decoder.int8.onnx \
  32 + --num-threads=1 \
  33 + /path/to/foo.wav
  34 +
  35 +foo.wav should be a single-channel, 16-bit PCM encoded wave file; its
  36 +sampling rate can be arbitrary and does not need to be 16kHz.
  37 +You can find test waves for different languages at
  38 +https://hf-mirror.com/spaces/k2-fsa/spoken-language-identification/tree/main/test_wavs
  39 +
  40 +Please refer to
  41 +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/index.html
  42 +for a list of pre-trained models to download.
  43 +Note that only whisper multilingual models are supported. For instance,
  44 +"tiny" is supported but "tiny.en" is not.
  45 +)usage";
  46 +
  47 + sherpa_onnx::ParseOptions po(kUsageMessage);
  48 + sherpa_onnx::SpokenLanguageIdentificationConfig config;
  49 + config.Register(&po);
  50 +
  51 + po.Read(argc, argv);
  52 + if (po.NumArgs() != 1) {
  53 + fprintf(stderr, "Error: Please provide 1 wave file.\n\n");
  54 + po.PrintUsage();
  55 + exit(EXIT_FAILURE);
  56 + }
  57 +
  58 + fprintf(stderr, "%s\n", config.ToString().c_str());
  59 +
  60 + if (!config.Validate()) {
  61 + fprintf(stderr, "Errors in config!\n");
  62 + return -1;
  63 + }
  64 +
  65 + fprintf(stderr, "Creating spoken language identifier ...\n");
  66 + sherpa_onnx::SpokenLanguageIdentification slid(config);
  67 +
  68 + fprintf(stderr, "Started\n");
  69 + const std::string wav_filename = po.GetArg(1);
  70 +
  71 + int32_t sampling_rate = -1;
  72 + bool is_ok = false;
  73 + const std::vector<float> samples =
  74 + sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok);
  75 + if (!is_ok) {
  76 + fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
  77 + return -1;
  78 + }
  79 + float duration = samples.size() / static_cast<float>(sampling_rate);
  80 +
  81 + const auto begin = std::chrono::steady_clock::now();
  82 +
  83 + auto s = slid.CreateStream();
  84 + s->AcceptWaveform(sampling_rate, samples.data(), samples.size());
  85 +
  86 + auto language = slid.Compute(s.get());
  87 +
  88 + const auto end = std::chrono::steady_clock::now();
  89 +
  90 + fprintf(stderr, "Done!\n\n");
  91 + fprintf(stderr, "%s\nDetected language: %s\n", wav_filename.c_str(),
  92 + language.c_str());
  93 +
  94 + float elapsed_seconds =
  95 + std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
  96 + .count() /
  97 + 1000.;
  98 +
  99 + fprintf(stderr, "num threads: %d\n", config.num_threads);
  100 +
  101 + fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
  102 + float rtf = elapsed_seconds / duration;
  103 + fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n",
  104 + elapsed_seconds, duration, rtf);
  105 +
  106 + return 0;
  107 +}
@@ -16,7 +16,7 @@ enum class ModelType { @@ -16,7 +16,7 @@ enum class ModelType {
16 kWeSpeaker, 16 kWeSpeaker,
17 k3dSpeaker, 17 k3dSpeaker,
18 kNeMo, 18 kNeMo,
19 - kUnkown, 19 + kUnknown,
20 }; 20 };
21 21
22 } // namespace 22 } // namespace
@@ -47,7 +47,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, @@ -47,7 +47,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
47 "https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wespeaker/" 47 "https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wespeaker/"
48 "add_meta_data.py" 48 "add_meta_data.py"
49 "to add metadata to models from WeSpeaker\n"); 49 "to add metadata to models from WeSpeaker\n");
50 - return ModelType::kUnkown; 50 + return ModelType::kUnknown;
51 } 51 }
52 52
53 if (model_type.get() == std::string("wespeaker")) { 53 if (model_type.get() == std::string("wespeaker")) {
@@ -58,14 +58,14 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, @@ -58,14 +58,14 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
58 return ModelType::kNeMo; 58 return ModelType::kNeMo;
59 } else { 59 } else {
60 SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); 60 SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
61 - return ModelType::kUnkown; 61 + return ModelType::kUnknown;
62 } 62 }
63 } 63 }
64 64
65 std::unique_ptr<SpeakerEmbeddingExtractorImpl> 65 std::unique_ptr<SpeakerEmbeddingExtractorImpl>
66 SpeakerEmbeddingExtractorImpl::Create( 66 SpeakerEmbeddingExtractorImpl::Create(
67 const SpeakerEmbeddingExtractorConfig &config) { 67 const SpeakerEmbeddingExtractorConfig &config) {
68 - ModelType model_type = ModelType::kUnkown; 68 + ModelType model_type = ModelType::kUnknown;
69 69
70 { 70 {
71 auto buffer = ReadFile(config.model); 71 auto buffer = ReadFile(config.model);
@@ -80,9 +80,8 @@ SpeakerEmbeddingExtractorImpl::Create( @@ -80,9 +80,8 @@ SpeakerEmbeddingExtractorImpl::Create(
80 return std::make_unique<SpeakerEmbeddingExtractorGeneralImpl>(config); 80 return std::make_unique<SpeakerEmbeddingExtractorGeneralImpl>(config);
81 case ModelType::kNeMo: 81 case ModelType::kNeMo:
82 return std::make_unique<SpeakerEmbeddingExtractorNeMoImpl>(config); 82 return std::make_unique<SpeakerEmbeddingExtractorNeMoImpl>(config);
83 - case ModelType::kUnkown:  
84 - SHERPA_ONNX_LOGE(  
85 - "Unknown model type in for speaker embedding extractor!"); 83 + case ModelType::kUnknown:
  84 + SHERPA_ONNX_LOGE("Unknown model type for speaker embedding extractor!");
86 return nullptr; 85 return nullptr;
87 } 86 }
88 87
@@ -94,7 +93,7 @@ SpeakerEmbeddingExtractorImpl::Create( @@ -94,7 +93,7 @@ SpeakerEmbeddingExtractorImpl::Create(
94 std::unique_ptr<SpeakerEmbeddingExtractorImpl> 93 std::unique_ptr<SpeakerEmbeddingExtractorImpl>
95 SpeakerEmbeddingExtractorImpl::Create( 94 SpeakerEmbeddingExtractorImpl::Create(
96 AAssetManager *mgr, const SpeakerEmbeddingExtractorConfig &config) { 95 AAssetManager *mgr, const SpeakerEmbeddingExtractorConfig &config) {
97 - ModelType model_type = ModelType::kUnkown; 96 + ModelType model_type = ModelType::kUnknown;
98 97
99 { 98 {
100 auto buffer = ReadFile(mgr, config.model); 99 auto buffer = ReadFile(mgr, config.model);
@@ -110,7 +109,7 @@ SpeakerEmbeddingExtractorImpl::Create( @@ -110,7 +109,7 @@ SpeakerEmbeddingExtractorImpl::Create(
110 config); 109 config);
111 case ModelType::kNeMo: 110 case ModelType::kNeMo:
112 return std::make_unique<SpeakerEmbeddingExtractorNeMoImpl>(mgr, config); 111 return std::make_unique<SpeakerEmbeddingExtractorNeMoImpl>(mgr, config);
113 - case ModelType::kUnkown: 112 + case ModelType::kUnknown:
114 SHERPA_ONNX_LOGE( 113 SHERPA_ONNX_LOGE(
115 "Unknown model type in for speaker embedding extractor!"); 114 "Unknown model type in for speaker embedding extractor!");
116 return nullptr; 115 return nullptr;
  1 +// sherpa-onnx/csrc/spoken-language-identification-impl.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +#include "sherpa-onnx/csrc/spoken-language-identification-impl.h"
  5 +
  6 +#include <memory>
  7 +
  8 +#include "sherpa-onnx/csrc/macros.h"
  9 +#include "sherpa-onnx/csrc/onnx-utils.h"
  10 +#include "sherpa-onnx/csrc/spoken-language-identification-whisper-impl.h"
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +namespace {
  15 +
  16 +enum class ModelType {
  17 + kWhisper,
  18 + kUnknown,
  19 +};
  20 +
  21 +}
  22 +
  23 +static ModelType GetModelType(char *model_data, size_t model_data_length,
  24 + bool debug) {
  25 + Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
  26 + Ort::SessionOptions sess_opts;
  27 +
  28 + auto sess = std::make_unique<Ort::Session>(env, model_data, model_data_length,
  29 + sess_opts);
  30 +
  31 + Ort::ModelMetadata meta_data = sess->GetModelMetadata();
  32 + if (debug) {
  33 + std::ostringstream os;
  34 + PrintModelMetadata(os, meta_data);
  35 + SHERPA_ONNX_LOGE("%s", os.str().c_str());
  36 + }
  37 +
  38 + Ort::AllocatorWithDefaultOptions allocator;
  39 + auto model_type =
  40 + meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
  41 + if (!model_type) {
  42 + SHERPA_ONNX_LOGE(
  43 + "No model_type in the metadata!\n"
  44 + "Please make sure you have added metadata to the model.\n\n"
  45 + "For instance, you can use\n"
  46 + "https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/whisper/"
  47 + "export-onnx.py "
  48 + "to add metadata to models from whisper\n");
  49 + return ModelType::kUnknown;
  50 + }
  51 +
  52 + auto model_type_str = std::string(model_type.get());
  53 + if (model_type_str.find("whisper") == 0) {
  54 + return ModelType::kWhisper;
  55 + } else {
  56 + SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
  57 + return ModelType::kUnknown;
  58 + }
  59 +}
  60 +
  61 +std::unique_ptr<SpokenLanguageIdentificationImpl>
  62 +SpokenLanguageIdentificationImpl::Create(
  63 + const SpokenLanguageIdentificationConfig &config) {
  64 + ModelType model_type = ModelType::kUnknown;
  65 + {
  66 + if (config.whisper.encoder.empty()) {
  67 + SHERPA_ONNX_LOGE("Only whisper models are supported at present");
  68 + exit(-1);
  69 + }
  70 + auto buffer = ReadFile(config.whisper.encoder);
  71 +
  72 + model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
  73 + }
  74 +
  75 + switch (model_type) {
  76 + case ModelType::kWhisper:
  77 + return std::make_unique<SpokenLanguageIdentificationWhisperImpl>(config);
  78 + case ModelType::kUnknown:
  79 + SHERPA_ONNX_LOGE(
  80 + "Unknown model type for spoken language identification!");
  81 + return nullptr;
  82 + }
  83 +
  84 + // unreachable code
  85 + return nullptr;
  86 +}
  87 +
  88 +} // namespace sherpa_onnx
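GetModelType() above selects the implementation by reading the model_type entry from the ONNX model's custom metadata. The same check can be reproduced from Python with onnxruntime, which is handy for verifying that an exported encoder actually carries the metadata. A minimal sketch follows; tiny-encoder.onnx is a placeholder path for a model exported with scripts/whisper/export-onnx.py:

#!/usr/bin/env python3
# Inspect the custom metadata of an exported whisper encoder (illustrative only).
import onnxruntime as ort

# Placeholder path; use any encoder exported by scripts/whisper/export-onnx.py.
sess = ort.InferenceSession("tiny-encoder.onnx",
                            providers=["CPUExecutionProvider"])

# custom_metadata_map is a plain dict of str -> str.
meta = sess.get_modelmeta().custom_metadata_map

model_type = meta.get("model_type")
if model_type is None:
    print("No model_type in the metadata; please re-export the model")
elif model_type.startswith("whisper"):
    print("whisper model detected:", model_type)
else:
    print("Unsupported model_type:", model_type)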
  1 +// sherpa-onnx/csrc/spoken-language-identification-impl.h
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_IMPL_H_
  5 +#define SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_IMPL_H_
  6 +
  7 +#include <memory>
  8 +#include <string>
  9 +
  10 +#include "sherpa-onnx/csrc/spoken-language-identification.h"
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +class SpokenLanguageIdentificationImpl {
  15 + public:
  16 + virtual ~SpokenLanguageIdentificationImpl() = default;
  17 +
  18 + static std::unique_ptr<SpokenLanguageIdentificationImpl> Create(
  19 + const SpokenLanguageIdentificationConfig &config);
  20 +
  21 + virtual std::unique_ptr<OfflineStream> CreateStream() const = 0;
  22 +
  23 + virtual std::string Compute(OfflineStream *s) const = 0;
  24 +};
  25 +
  26 +} // namespace sherpa_onnx
  27 +
  28 +#endif // SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_IMPL_H_
  1 +// sherpa-onnx/csrc/spoken-language-identification-whisper-impl.h
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_WHISPER_IMPL_H_
  6 +#define SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_WHISPER_IMPL_H_
  7 +
  8 +#include <algorithm>
  9 +#include <memory>
  10 +#include <string>
  11 +#include <utility>
  12 +#include <vector>
  13 +
  14 +#include "sherpa-onnx/csrc/offline-whisper-model.h"
  15 +#include "sherpa-onnx/csrc/spoken-language-identification-impl.h"
  16 +#include "sherpa-onnx/csrc/transpose.h"
  17 +
  18 +namespace sherpa_onnx {
  19 +
  20 +class SpokenLanguageIdentificationWhisperImpl
  21 + : public SpokenLanguageIdentificationImpl {
  22 + public:
  23 + explicit SpokenLanguageIdentificationWhisperImpl(
  24 + const SpokenLanguageIdentificationConfig &config)
  25 + : config_(config), model_(std::make_unique<OfflineWhisperModel>(config)) {
  26 + Check();
  27 + }
  28 +
  29 + std::unique_ptr<OfflineStream> CreateStream() const override {
  30 + return std::make_unique<OfflineStream>(WhisperTag{});
  31 + }
  32 +
  33 + std::string Compute(OfflineStream *s) const override {
  34 + int32_t max_num_frames = 3000;
  35 + auto memory_info =
  36 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  37 +
  38 + int32_t feat_dim = s->FeatureDim();
  39 + std::vector<float> f = s->GetFrames();
  40 + int32_t num_frames = f.size() / feat_dim;
  41 +
   42 + // we use 50 here so that there will be some zero tail padding
  43 + if (num_frames >= max_num_frames - 50) {
  44 + SHERPA_ONNX_LOGE(
   45 + "Only waveforms shorter than 30 seconds are supported. We process only "
   46 + "the first 30 seconds and discard the remaining data.");
  47 + num_frames = max_num_frames - 50;
  48 + }
  49 +
  50 + model_->NormalizeFeatures(f.data(), num_frames, feat_dim);
  51 +
   52 + // Note that 1000 is an empirical value.
   53 + // You can replace 1000 with other values, say, 100.
   54 + //
   55 + // Since we have removed the 30-second constraint, we need
   56 + // tail_padding_frames so that whisper is able to detect the eot token.
  57 + int32_t tail_padding_frames = 1000;
  58 +
  59 + if (config_.whisper.tail_paddings > 0) {
  60 + tail_padding_frames = config_.whisper.tail_paddings;
  61 + }
  62 +
  63 + int32_t actual_frames =
  64 + std::min(num_frames + tail_padding_frames, max_num_frames);
  65 +
  66 + std::array<int64_t, 3> shape{1, actual_frames, feat_dim};
  67 +
  68 + Ort::Value mel = Ort::Value::CreateTensor<float>(
  69 + model_->Allocator(), shape.data(), shape.size());
  70 +
  71 + float *p_mel = mel.GetTensorMutableData<float>();
  72 + std::copy(f.data(), f.data() + num_frames * feat_dim, p_mel);
  73 +
  74 + std::fill_n(p_mel + num_frames * feat_dim,
  75 + (actual_frames - num_frames) * feat_dim, 0);
  76 +
  77 + mel = Transpose12(model_->Allocator(), &mel);
  78 +
  79 + try {
  80 + auto cross_kv = model_->ForwardEncoder(std::move(mel));
  81 + int32_t lang_id = model_->DetectLanguage(cross_kv.first, cross_kv.second);
  82 + const auto &id2lang = model_->GetID2Lang();
  83 + if (id2lang.count(lang_id)) {
  84 + return id2lang.at(lang_id);
  85 + } else {
   86 + SHERPA_ONNX_LOGE("Unknown language ID: %d. Returning an empty string.",
  87 + lang_id);
  88 + return "";
  89 + }
  90 + } catch (const Ort::Exception &ex) {
  91 + SHERPA_ONNX_LOGE(
   92 + "\n\nCaught exception:\n\n%s\n\nReturning an empty result. Number of "
   93 + "input frames: %d, current tail "
  94 + "paddings: %d. If you see a lot of such exceptions, please consider "
  95 + "using a larger --whisper-tail-paddings",
  96 + ex.what(), num_frames, tail_padding_frames);
  97 + return "";
  98 + }
  99 + }
  100 +
  101 + private:
  102 + void Check() const {
  103 + if (!model_->IsMultiLingual()) {
  104 + SHERPA_ONNX_LOGE(
  105 + "Only whisper multilingual models can be used for spoken language "
  106 + "identification. Given: %s, %s",
  107 + config_.whisper.encoder.c_str(), config_.whisper.decoder.c_str());
  108 + exit(-1);
  109 + }
  110 + }
  111 +
  112 + private:
  113 + SpokenLanguageIdentificationConfig config_;
  114 + std::unique_ptr<OfflineWhisperModel> model_;
  115 +};
  116 +
  117 +} // namespace sherpa_onnx
  118 +
  119 +#endif // SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_WHISPER_IMPL_H_
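For context, the constants in Compute() correspond to whisper's fixed 30-second window: 3000 frames at a 10 ms frame shift (the 10 ms figure is an assumption about the feature extractor, not stated in this file), a 50-frame margin so that some zero tail padding always remains, and an empirical default of 1000 tail-padding frames. The sketch below mirrors that bookkeeping in Python; none of these names exist in sherpa-onnx:

# Illustrative sketch of the frame/padding arithmetic used in Compute().
MAX_NUM_FRAMES = 3000        # 30 s at an assumed 10 ms frame shift
MARGIN = 50                  # keep some zero tail padding
DEFAULT_TAIL_PADDING = 1000  # empirical default when tail_paddings is -1

def padded_num_frames(num_frames: int, tail_paddings: int = -1) -> int:
    """Number of frames fed to the whisper encoder."""
    # Anything longer than MAX_NUM_FRAMES - MARGIN is truncated.
    num_frames = min(num_frames, MAX_NUM_FRAMES - MARGIN)
    tail = tail_paddings if tail_paddings > 0 else DEFAULT_TAIL_PADDING
    return min(num_frames + tail, MAX_NUM_FRAMES)

print(padded_num_frames(1000))       # 10 s of audio -> 2000 frames
print(padded_num_frames(4000, 300))  # truncated to 2950 -> capped at 3000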
  1 +// sherpa-onnx/csrc/spoken-language-identification.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/spoken-language-identification.h"
  6 +
  7 +#include <string>
  8 +
  9 +#include "sherpa-onnx/csrc/file-utils.h"
  10 +#include "sherpa-onnx/csrc/macros.h"
  11 +#include "sherpa-onnx/csrc/spoken-language-identification-impl.h"
  12 +
  13 +namespace sherpa_onnx {
  14 +
  15 +void SpokenLanguageIdentificationWhisperConfig::Register(ParseOptions *po) {
  16 + po->Register(
  17 + "whisper-encoder", &encoder,
   18 + "Path to the encoder of a whisper multilingual model. Only "
   19 + "tiny, base, small, medium, and large are supported.");
  20 +
  21 + po->Register(
  22 + "whisper-decoder", &decoder,
   23 + "Path to the decoder of a whisper multilingual model. Only "
   24 + "tiny, base, small, medium, and large are supported.");
  25 +
  26 + po->Register(
  27 + "whisper-tail-paddings", &tail_paddings,
  28 + "Suggested value: 300 for multilingual models. "
  29 + "Since we have removed the 30-second constraint, we need to add some "
  30 + "tail padding frames "
   31 + "so that whisper can detect the eot token. Leave it at -1 to use 1000.");
  32 +}
  33 +
  34 +bool SpokenLanguageIdentificationWhisperConfig::Validate() const {
  35 + if (encoder.empty()) {
  36 + SHERPA_ONNX_LOGE("Please provide --whisper-encoder");
  37 + return false;
  38 + }
  39 +
  40 + if (!FileExists(encoder)) {
  41 + SHERPA_ONNX_LOGE("whisper encoder file %s does not exist", encoder.c_str());
  42 + return false;
  43 + }
  44 +
  45 + if (decoder.empty()) {
  46 + SHERPA_ONNX_LOGE("Please provide --whisper-decoder");
  47 + return false;
  48 + }
  49 +
  50 + if (!FileExists(decoder)) {
  51 + SHERPA_ONNX_LOGE("whisper decoder file %s does not exist", decoder.c_str());
  52 + return false;
  53 + }
  54 +
  55 + return true;
  56 +}
  57 +
  58 +std::string SpokenLanguageIdentificationWhisperConfig::ToString() const {
  59 + std::ostringstream os;
  60 +
  61 + os << "SpokenLanguageIdentificationWhisperConfig(";
  62 + os << "encoder=\"" << encoder << "\", ";
  63 + os << "decoder=\"" << decoder << "\", ";
  64 + os << "tail_paddings=" << tail_paddings << ")";
  65 +
  66 + return os.str();
  67 +}
  68 +
  69 +void SpokenLanguageIdentificationConfig::Register(ParseOptions *po) {
  70 + whisper.Register(po);
  71 +
  72 + po->Register("num-threads", &num_threads,
  73 + "Number of threads to run the neural network");
  74 +
  75 + po->Register("debug", &debug,
  76 + "true to print model information while loading it.");
  77 +
  78 + po->Register("provider", &provider,
  79 + "Specify a provider to use: cpu, cuda, coreml");
  80 +}
  81 +
  82 +bool SpokenLanguageIdentificationConfig::Validate() const {
  83 + if (!whisper.Validate()) {
  84 + return false;
  85 + }
  86 +
  87 + return true;
  88 +}
  89 +
  90 +std::string SpokenLanguageIdentificationConfig::ToString() const {
  91 + std::ostringstream os;
  92 +
  93 + os << "SpokenLanguageIdentificationConfig(";
  94 + os << "whisper=\"" << whisper.ToString() << "\", ";
  95 + os << "num_threads=" << num_threads << ", ";
  96 + os << "debug=" << (debug ? "True" : "False") << ", ";
  97 + os << "provider=\"" << provider << "\")";
  98 +
  99 + return os.str();
  100 +}
  101 +
  102 +SpokenLanguageIdentification::SpokenLanguageIdentification(
  103 + const SpokenLanguageIdentificationConfig &config)
  104 + : impl_(SpokenLanguageIdentificationImpl::Create(config)) {}
  105 +
  106 +SpokenLanguageIdentification::~SpokenLanguageIdentification() = default;
  107 +
  108 +std::unique_ptr<OfflineStream> SpokenLanguageIdentification::CreateStream()
  109 + const {
  110 + return impl_->CreateStream();
  111 +}
  112 +
  113 +std::string SpokenLanguageIdentification::Compute(OfflineStream *s) const {
  114 + return impl_->Compute(s);
  115 +}
  116 +
  117 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/spoken-language-identification.h
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_
  5 +#define SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_
  6 +
  7 +#include <memory>
  8 +#include <string>
  9 +
  10 +#include "sherpa-onnx/csrc/offline-stream.h"
  11 +#include "sherpa-onnx/csrc/parse-options.h"
  12 +
  13 +namespace sherpa_onnx {
  14 +
  15 +struct SpokenLanguageIdentificationWhisperConfig {
   16 + // Requires a multilingual whisper model.
   17 + // That is, it supports only tiny, base, small, medium, and large.
   18 + // Note: It does NOT support tiny.en, base.en, small.en, or medium.en.
  19 + std::string encoder;
  20 + std::string decoder;
  21 +
  22 + // Number of tail padding frames.
  23 + //
  24 + // Since we remove the 30-second constraint, we need to add some paddings
  25 + // at the end.
  26 + //
  27 + // Recommended values:
  28 + // - 50 for English models
  29 + // - 300 for multilingual models
  30 + int32_t tail_paddings = -1;
  31 +
  32 + SpokenLanguageIdentificationWhisperConfig() = default;
  33 +
  34 + SpokenLanguageIdentificationWhisperConfig(const std::string &encoder,
  35 + const std::string &decoder,
  36 + int32_t tail_paddings)
  37 + : encoder(encoder), decoder(decoder), tail_paddings(tail_paddings) {}
  38 +
  39 + void Register(ParseOptions *po);
  40 + bool Validate() const;
  41 + std::string ToString() const;
  42 +};
  43 +
  44 +struct SpokenLanguageIdentificationConfig {
  45 + SpokenLanguageIdentificationWhisperConfig whisper;
  46 +
  47 + int32_t num_threads = 1;
  48 + bool debug = false;
  49 + std::string provider = "cpu";
  50 +
  51 + SpokenLanguageIdentificationConfig() = default;
  52 +
  53 + SpokenLanguageIdentificationConfig(
  54 + const SpokenLanguageIdentificationWhisperConfig &whisper,
  55 + int32_t num_threads, bool debug, const std::string &provider)
  56 + : whisper(whisper),
  57 + num_threads(num_threads),
  58 + debug(debug),
  59 + provider(provider) {}
  60 +
  61 + void Register(ParseOptions *po);
  62 + bool Validate() const;
  63 + std::string ToString() const;
  64 +};
  65 +
  66 +class SpokenLanguageIdentificationImpl;
  67 +
  68 +class SpokenLanguageIdentification {
  69 + public:
  70 + explicit SpokenLanguageIdentification(
  71 + const SpokenLanguageIdentificationConfig &config);
  72 +
  73 + ~SpokenLanguageIdentification();
  74 +
  75 + // Create a stream to accept audio samples and compute features
  76 + std::unique_ptr<OfflineStream> CreateStream() const;
  77 +
   78 + // Return a string containing the language code, e.g., en, zh, de,
  79 + // etc.
  80 + // Note: en is for English, zh is for Chinese, de is for German, etc.
  81 + std::string Compute(OfflineStream *s) const;
  82 +
  83 + private:
  84 + std::unique_ptr<SpokenLanguageIdentificationImpl> impl_;
  85 +};
  86 +
  87 +} // namespace sherpa_onnx
  88 +
  89 +#endif // SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_
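Compute() returns a two-letter code such as en, zh, or de, as noted above. Applications that want a human-readable name have to map the code themselves; a tiny illustrative Python snippet (the mapping covers only the codes mentioned in the comment and is not provided by sherpa-onnx):

# Hypothetical helper: turn a detected language code into a display name.
LANG_NAMES = {"en": "English", "zh": "Chinese", "de": "German"}

def display_name(code: str) -> str:
    # Fall back to the raw code for languages not in this toy mapping.
    return LANG_NAMES.get(code, code)

print(display_name("de"))  # German
print(display_name("fr"))  # fr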
@@ -33,6 +33,7 @@ set(srcs @@ -33,6 +33,7 @@ set(srcs
33 silero-vad-model-config.cc 33 silero-vad-model-config.cc
34 speaker-embedding-extractor.cc 34 speaker-embedding-extractor.cc
35 speaker-embedding-manager.cc 35 speaker-embedding-manager.cc
  36 + spoken-language-identification.cc
36 vad-model-config.cc 37 vad-model-config.cc
37 vad-model.cc 38 vad-model.cc
38 voice-activity-detector.cc 39 voice-activity-detector.cc
@@ -22,6 +22,7 @@ @@ -22,6 +22,7 @@
22 #include "sherpa-onnx/python/csrc/online-stream.h" 22 #include "sherpa-onnx/python/csrc/online-stream.h"
23 #include "sherpa-onnx/python/csrc/speaker-embedding-extractor.h" 23 #include "sherpa-onnx/python/csrc/speaker-embedding-extractor.h"
24 #include "sherpa-onnx/python/csrc/speaker-embedding-manager.h" 24 #include "sherpa-onnx/python/csrc/speaker-embedding-manager.h"
  25 +#include "sherpa-onnx/python/csrc/spoken-language-identification.h"
25 #include "sherpa-onnx/python/csrc/vad-model-config.h" 26 #include "sherpa-onnx/python/csrc/vad-model-config.h"
26 #include "sherpa-onnx/python/csrc/vad-model.h" 27 #include "sherpa-onnx/python/csrc/vad-model.h"
27 #include "sherpa-onnx/python/csrc/voice-activity-detector.h" 28 #include "sherpa-onnx/python/csrc/voice-activity-detector.h"
@@ -55,6 +56,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) { @@ -55,6 +56,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
55 PybindOfflineTts(&m); 56 PybindOfflineTts(&m);
56 PybindSpeakerEmbeddingExtractor(&m); 57 PybindSpeakerEmbeddingExtractor(&m);
57 PybindSpeakerEmbeddingManager(&m); 58 PybindSpeakerEmbeddingManager(&m);
  59 + PybindSpokenLanguageIdentification(&m);
58 60
59 PybindAlsa(&m); 61 PybindAlsa(&m);
60 } 62 }
  1 +// sherpa-onnx/python/csrc/spoken-language-identification.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/python/csrc/spoken-language-identification.h"
  6 +
  7 +#include <string>
  8 +
  9 +#include "sherpa-onnx/csrc/spoken-language-identification.h"
  10 +
  11 +namespace sherpa_onnx {
  12 +
  13 +static void PybindSpokenLanguageIdentificationWhisperConfig(py::module *m) {
  14 + using PyClass = SpokenLanguageIdentificationWhisperConfig;
  15 +
  16 + py::class_<PyClass>(*m, "SpokenLanguageIdentificationWhisperConfig")
  17 + .def(py::init<>())
  18 + .def(py::init<const std::string &, const std::string &, int32_t>(),
  19 + py::arg("encoder"), py::arg("decoder"),
  20 + py::arg("tail_paddings") = -1)
  21 + .def_readwrite("encoder", &PyClass::encoder)
  22 + .def_readwrite("decoder", &PyClass::decoder)
  23 + .def_readwrite("tail_paddings", &PyClass::tail_paddings)
  24 + .def("validate", &PyClass::Validate)
  25 + .def("__str__", &PyClass::ToString);
  26 +}
  27 +
  28 +static void PybindSpokenLanguageIdentificationConfig(py::module *m) {
  29 + PybindSpokenLanguageIdentificationWhisperConfig(m);
  30 +
  31 + using PyClass = SpokenLanguageIdentificationConfig;
  32 +
  33 + py::class_<PyClass>(*m, "SpokenLanguageIdentificationConfig")
  34 + .def(py::init<>())
  35 + .def(py::init<const SpokenLanguageIdentificationWhisperConfig &, int32_t,
  36 + bool, const std::string>(),
  37 + py::arg("whisper"), py::arg("num_threads") = 1,
  38 + py::arg("debug") = false, py::arg("provider") = "cpu")
  39 + .def_readwrite("whisper", &PyClass::whisper)
  40 + .def_readwrite("num_threads", &PyClass::num_threads)
  41 + .def_readwrite("debug", &PyClass::debug)
  42 + .def_readwrite("provider", &PyClass::provider)
  43 + .def("validate", &PyClass::Validate)
  44 + .def("__str__", &PyClass::ToString);
  45 +}
  46 +
  47 +void PybindSpokenLanguageIdentification(py::module *m) {
  48 + PybindSpokenLanguageIdentificationConfig(m);
  49 +
  50 + using PyClass = SpokenLanguageIdentification;
  51 + py::class_<PyClass>(*m, "SpokenLanguageIdentification")
  52 + .def(py::init<const SpokenLanguageIdentificationConfig &>(),
  53 + py::arg("config"), py::call_guard<py::gil_scoped_release>())
  54 + .def("create_stream", &PyClass::CreateStream,
  55 + py::call_guard<py::gil_scoped_release>())
  56 + .def("compute", &PyClass::Compute,
  57 + py::call_guard<py::gil_scoped_release>());
  58 +}
  59 +
  60 +} // namespace sherpa_onnx
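The bindings above expose the two config classes and the SpokenLanguageIdentification wrapper to Python, using the keyword arguments shown in the py::arg calls. A minimal end-to-end usage sketch follows; the model paths and the wave file are placeholders, soundfile is just one possible way to read audio, and calling accept_waveform on the returned stream assumes the existing OfflineStream bindings:

#!/usr/bin/env python3
# Minimal sketch: identify the spoken language of a wave file (paths are placeholders).
import sherpa_onnx
import soundfile as sf

config = sherpa_onnx.SpokenLanguageIdentificationConfig(
    whisper=sherpa_onnx.SpokenLanguageIdentificationWhisperConfig(
        encoder="./tiny-encoder.onnx",
        decoder="./tiny-decoder.onnx",
        tail_paddings=300,  # suggested value for multilingual models
    ),
    num_threads=1,
    debug=False,
    provider="cpu",
)
assert config.validate(), "Please check the model paths"

slid = sherpa_onnx.SpokenLanguageIdentification(config)

samples, sample_rate = sf.read("./test.wav", dtype="float32")
stream = slid.create_stream()
stream.accept_waveform(sample_rate=sample_rate, waveform=samples)

lang = slid.compute(stream)  # e.g., "de" for German
print("Detected language:", lang)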
  1 +// sherpa-onnx/python/csrc/spoken-language-identification.h
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_PYTHON_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_
  6 +#define SHERPA_ONNX_PYTHON_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_
  7 +
  8 +#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void PybindSpokenLanguageIdentification(py::module *m);
  13 +
   14 +}  // namespace sherpa_onnx
  15 +
  16 +#endif // SHERPA_ONNX_PYTHON_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_
@@ -13,6 +13,9 @@ from _sherpa_onnx import ( @@ -13,6 +13,9 @@ from _sherpa_onnx import (
13 SpeakerEmbeddingExtractorConfig, 13 SpeakerEmbeddingExtractorConfig,
14 SpeakerEmbeddingManager, 14 SpeakerEmbeddingManager,
15 SpeechSegment, 15 SpeechSegment,
  16 + SpokenLanguageIdentification,
  17 + SpokenLanguageIdentificationConfig,
  18 + SpokenLanguageIdentificationWhisperConfig,
16 VadModel, 19 VadModel,
17 VadModelConfig, 20 VadModelConfig,
18 VoiceActivityDetector, 21 VoiceActivityDetector,