Commit 0d258dd1509e2ac0d7eb5c9db0c21983a28689f8 (0d258dd1)
1 parent: 3cdad9b5

Authored by: Fangjun Kuang, 2024-03-24 22:57:00 +0800
Committed by: GitHub, 2024-03-24 22:57:00 +0800

Support spoken language identification with whisper (#694)
Showing 36 changed files with 1173 additions and 200 deletions:
.github/scripts/test-spoken-language-identification.sh
.github/workflows/build-wheels-linux.yaml
.github/workflows/build-wheels-macos-arm64.yaml
.github/workflows/linux-gpu.yaml
.github/workflows/linux.yaml
.github/workflows/macos.yaml
.github/workflows/windows-x64-cuda.yaml
.github/workflows/windows-x64.yaml
.github/workflows/windows-x86.yaml
CMakeLists.txt
cmake/cmake_extension.py
python-api-examples/spoken-language-identification.py
setup.py
sherpa-onnx/csrc/CMakeLists.txt
sherpa-onnx/csrc/offline-ctc-model.cc
sherpa-onnx/csrc/offline-recognizer-whisper-impl.h
sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc
sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h
sherpa-onnx/csrc/offline-whisper-model-config.cc
sherpa-onnx/csrc/offline-whisper-model.cc
sherpa-onnx/csrc/offline-whisper-model.h
sherpa-onnx/csrc/online-transducer-model.cc
sherpa-onnx/csrc/session.cc
sherpa-onnx/csrc/session.h
sherpa-onnx/csrc/sherpa-onnx-offline-language-identification.cc
sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc
sherpa-onnx/csrc/spoken-language-identification-impl.cc
sherpa-onnx/csrc/spoken-language-identification-impl.h
sherpa-onnx/csrc/spoken-language-identification-whisper-impl.h
sherpa-onnx/csrc/spoken-language-identification.cc
sherpa-onnx/csrc/spoken-language-identification.h
sherpa-onnx/python/csrc/CMakeLists.txt
sherpa-onnx/python/csrc/sherpa-onnx.cc
sherpa-onnx/python/csrc/spoken-language-identification.cc
sherpa-onnx/python/csrc/spoken-language-identification.h
sherpa-onnx/python/sherpa_onnx/__init__.py
.github/scripts/test-spoken-language-identification.sh (new file, mode 100755)
#!/usr/bin/env bash

set -e

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

echo "EXE is $EXE"
echo "PATH: $PATH"

which $EXE

names=(
tiny
base
small
medium
)

# all_language_codes=bo,ml,tt,fa,sl,bg,sn,sr,tl,km,ln,mr,hr,eu,ro,ba,bs,pl,as,nn,sk,ko,oc,ar,uz,pa,tg,mk,kk,hi,ha,uk,is,de,el,ja,yo,be,so,tk,id,sa,ru,yi,en,am,cs,ne,la,sv,su,pt,mi,ca,sd,hy,haw,fi,et,kn,da,lt,it,nl,he,mg,ur,tr,af,br,bn,ta,no,my,si,mt,th,gl,sw,mn,jw,ms,ps,fo,ka,hu,zh,ht,az,fr,lo,sq,gu,cy,lv,es,lb,te,vi

log "Download test waves"
waves=(
ar-arabic.wav
bg-bulgarian.wav
cs-czech.wav
da-danish.wav
de-german.wav
el-greek.wav
en-english.wav
es-spanish.wav
fa-persian.wav
fi-finnish.wav
fr-french.wav
hi-hindi.wav
hr-croatian.wav
id-indonesian.wav
it-italian.wav
ja-japanese.wav
ko-korean.wav
nl-dutch.wav
no-norwegian.wav
po-polish.wav
pt-portuguese.wav
ro-romanian.wav
ru-russian.wav
sk-slovak.wav
sv-swedish.wav
ta-tamil.wav
tl-tagalog.wav
tr-turkish.wav
uk-ukrainian.wav
zh-chinese.wav
)

for wav in ${waves[@]}; do
  echo "Downloading $wav"
  curl -SL -O https://hf-mirror.com/spaces/k2-fsa/spoken-language-identification/resolve/main/test_wavs/$wav
  ls -lh *.wav
done

for name in ${names[@]}; do
  log "------------------------------------------------------------"
  log "Run $name"
  log "------------------------------------------------------------"

  repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-whisper-$name
  log "Start testing ${repo_url}"
  repo=$(basename $repo_url)
  log "Download pretrained model and test-data from $repo_url"

  GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
  pushd $repo
  git lfs pull --include "*.onnx"
  # git lfs pull --include "*.ort"
  ls -lh *.onnx
  popd

  for wav in ${waves[@]}; do
    log "test fp32 onnx"
    time $EXE \
      --whisper-encoder=$repo/${name}-encoder.onnx \
      --whisper-decoder=$repo/${name}-decoder.onnx \
      $wav

    log "test int8 onnx"
    time $EXE \
      --whisper-encoder=$repo/${name}-encoder.int8.onnx \
      --whisper-decoder=$repo/${name}-decoder.int8.onnx \
      $wav
  done

  rm -rf $repo
done
.github/workflows/build-wheels-linux.yaml

@@ -82,7 +82,6 @@ jobs:
       env:
         HF_TOKEN: ${{ secrets.HF_TOKEN }}
       uses: nick-fields/retry@v3
       shell: bash
       with:
         max_attempts: 20
         timeout_seconds: 200
.github/workflows/build-wheels-macos-arm64.yaml

@@ -21,27 +21,12 @@ jobs:
       fail-fast: false
       matrix:
         os: [macos-latest]
-        python-version: ["cp37", "cp38", "cp39", "cp310", "cp311", "cp312"]
+        python-version: ["cp38", "cp39", "cp310", "cp311", "cp312"]

     steps:
       - uses: actions/checkout@v4

       # see https://cibuildwheel.readthedocs.io/en/stable/changelog/
       # for a list of versions
-      - name: Build wheels
-        if: matrix.python-version == 'cp37'
-        uses: pypa/cibuildwheel@v2.11.4
-        env:
-          CIBW_BUILD: "${{ matrix.python-version}}-*"
-          CIBW_ENVIRONMENT: SHERPA_ONNX_CMAKE_ARGS="-DCMAKE_OSX_ARCHITECTURES='arm64'"
-          CIBW_ARCHS: "arm64"
-          CIBW_BUILD_VERBOSITY: 3
-
-          # Don't repair macOS wheels
-          CIBW_REPAIR_WHEEL_COMMAND_MACOS: ""
-
       - name: Build wheels
-        if: matrix.python-version != 'cp37'
         uses: pypa/cibuildwheel@v2.15.0
         env:
           CIBW_BUILD: "${{ matrix.python-version}}-*"
.github/workflows/linux-gpu.yaml

@@ -92,6 +92,14 @@ jobs:
           file build/bin/sherpa-onnx
           readelf -d build/bin/sherpa-onnx

+      - name: Test spoken language identification
+        shell: bash
+        run: |
+          export PATH=$PWD/build/bin:$PATH
+          export EXE=sherpa-onnx-offline-language-identification
+
+          .github/scripts/test-spoken-language-identification.sh
+
       - name: Test online CTC
         shell: bash
         run: |

@@ -116,6 +124,7 @@ jobs:
           .github/scripts/test-online-paraformer.sh

       - name: Test offline Whisper
         shell: bash
         run: |
.github/workflows/linux.yaml

@@ -123,6 +123,15 @@ jobs:
           name: release-${{ matrix.build_type }}-${{ matrix.shared_lib }}
           path: build/bin/*

+      - name: Test spoken language identification
+        if: matrix.build_type != 'Debug'
+        shell: bash
+        run: |
+          export PATH=$PWD/build/bin:$PATH
+          export EXE=sherpa-onnx-offline-language-identification
+
+          .github/scripts/test-spoken-language-identification.sh
+
       - name: Test transducer kws
         shell: bash
         run: |

@@ -140,6 +149,7 @@ jobs:
           .github/scripts/test-online-ctc.sh

       - name: Test offline Whisper
+        if: matrix.build_type != 'Debug'
         shell: bash
         run: |
           export PATH=$PWD/build/bin:$PATH
.github/workflows/macos.yaml

@@ -102,6 +102,15 @@ jobs:
           otool -L build/bin/sherpa-onnx
           otool -l build/bin/sherpa-onnx

+      - name: Test spoken language identification
+        if: matrix.build_type != 'Debug'
+        shell: bash
+        run: |
+          export PATH=$PWD/build/bin:$PATH
+          export EXE=sherpa-onnx-offline-language-identification
+
+          .github/scripts/test-spoken-language-identification.sh
+
       - name: Test transducer kws
         shell: bash
         run: |

@@ -135,6 +144,7 @@ jobs:
           .github/scripts/test-online-paraformer.sh

       - name: Test offline Whisper
+        if: matrix.build_type != 'Debug'
         shell: bash
         run: |
           export PATH=$PWD/build/bin:$PATH
.github/workflows/windows-x64-cuda.yaml

@@ -68,6 +68,14 @@ jobs:
           ls -lh ./bin/Release/sherpa-onnx.exe

+      - name: Test spoken language identification
+        shell: bash
+        run: |
+          export PATH=$PWD/build/bin/Release:$PATH
+          export EXE=sherpa-onnx-offline-language-identification.exe
+
+          .github/scripts/test-spoken-language-identification.sh
+
       - name: Test online CTC
         shell: bash
         run: |
.github/workflows/windows-x64.yaml

@@ -68,6 +68,14 @@ jobs:
           ls -lh ./bin/Release/sherpa-onnx.exe

+      - name: Test spoken language identification
+        shell: bash
+        run: |
+          export PATH=$PWD/build/bin/Release:$PATH
+          export EXE=sherpa-onnx-offline-language-identification.exe
+
+          .github/scripts/test-spoken-language-identification.sh
+
       - name: Test online CTC
         shell: bash
         run: |
.github/workflows/windows-x86.yaml

@@ -69,6 +69,14 @@ jobs:
           ls -lh ./bin/Release/sherpa-onnx.exe

+      # - name: Test spoken language identification
+      #   shell: bash
+      #   run: |
+      #     export PATH=$PWD/build/bin/Release:$PATH
+      #     export EXE=sherpa-onnx-offline-language-identification.exe
+      #
+      #     .github/scripts/test-spoken-language-identification.sh
+
       - name: Test online CTC
         shell: bash
         run: |
CMakeLists.txt

 cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
 project(sherpa-onnx)

-set(SHERPA_ONNX_VERSION "1.9.13")
+set(SHERPA_ONNX_VERSION "1.9.14")

 # Disable warning about
 #
cmake/cmake_extension.py

@@ -43,6 +43,50 @@ def enable_alsa():
     return build_alsa and is_linux() and (is_arm64() or is_x86())


+def get_binaries():
+    binaries = [
+        "sherpa-onnx",
+        "sherpa-onnx-keyword-spotter",
+        "sherpa-onnx-microphone",
+        "sherpa-onnx-microphone-offline",
+        "sherpa-onnx-microphone-offline-speaker-identification",
+        "sherpa-onnx-offline",
+        "sherpa-onnx-offline-language-identification",
+        "sherpa-onnx-offline-tts",
+        "sherpa-onnx-offline-tts-play",
+        "sherpa-onnx-offline-websocket-server",
+        "sherpa-onnx-online-websocket-client",
+        "sherpa-onnx-online-websocket-server",
+        "sherpa-onnx-vad-microphone",
+        "sherpa-onnx-vad-microphone-offline-asr",
+    ]
+
+    if enable_alsa():
+        binaries += [
+            "sherpa-onnx-alsa",
+            "sherpa-onnx-alsa-offline",
+            "sherpa-onnx-alsa-offline-speaker-identification",
+            "sherpa-onnx-offline-tts-play-alsa",
+        ]
+
+    if is_windows():
+        binaries += [
+            "espeak-ng.dll",
+            "kaldi-decoder-core.dll",
+            "kaldi-native-fbank-core.dll",
+            "onnxruntime.dll",
+            "piper_phonemize.dll",
+            "sherpa-onnx-c-api.dll",
+            "sherpa-onnx-core.dll",
+            "sherpa-onnx-fst.lib",
+            "sherpa-onnx-kaldifst-core.lib",
+            "sherpa-onnx-portaudio.dll",
+            "ucd.dll",
+        ]
+
+    return binaries
+
+
 try:
     from wheel.bdist_wheel import bdist_wheel as _bdist_wheel

@@ -150,38 +194,7 @@ class BuildExtension(build_ext):
         suffix = ".exe" if is_windows() else ""
-        # Remember to also change setup.py
-        binaries = ["sherpa-onnx"]
-        binaries += ["sherpa-onnx-keyword-spotter"]
-        binaries += ["sherpa-onnx-offline"]
-        binaries += ["sherpa-onnx-microphone"]
-        binaries += ["sherpa-onnx-microphone-offline"]
-        binaries += ["sherpa-onnx-microphone-offline-speaker-identification"]
-        binaries += ["sherpa-onnx-online-websocket-server"]
-        binaries += ["sherpa-onnx-offline-websocket-server"]
-        binaries += ["sherpa-onnx-online-websocket-client"]
-        binaries += ["sherpa-onnx-vad-microphone"]
-        binaries += ["sherpa-onnx-vad-microphone-offline-asr"]
-        binaries += ["sherpa-onnx-offline-tts"]
-        binaries += ["sherpa-onnx-offline-tts-play"]
-
-        if enable_alsa():
-            binaries += ["sherpa-onnx-alsa"]
-            binaries += ["sherpa-onnx-alsa-offline"]
-            binaries += ["sherpa-onnx-offline-tts-play-alsa"]
-            binaries += ["sherpa-onnx-alsa-offline-speaker-identification"]
-
-        if is_windows():
-            binaries += ["kaldi-native-fbank-core.dll"]
-            binaries += ["sherpa-onnx-c-api.dll"]
-            binaries += ["sherpa-onnx-core.dll"]
-            binaries += ["sherpa-onnx-portaudio.dll"]
-            binaries += ["onnxruntime.dll"]
-            binaries += ["piper_phonemize.dll"]
-            binaries += ["espeak-ng.dll"]
-            binaries += ["ucd.dll"]
-            binaries += ["kaldi-decoder-core.dll"]
-            binaries += ["sherpa-onnx-fst.lib"]
-            binaries += ["sherpa-onnx-kaldifst-core.lib"]
+        binaries = get_binaries()

         for f in binaries:
             suffix = "" if (".dll" in f or ".lib" in f) else suffix
python-api-examples/spoken-language-identification.py (new file, mode 100755)

#!/usr/bin/env python3

"""
This script shows how to use Python APIs for spoken language identification.

It detects the language spoken in the given wave file.

Usage:

1. Download a whisper multilingual model. We use a tiny model below.
   Please refer to https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
   to download more models.

wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.tar.bz2
tar xvf sherpa-onnx-whisper-tiny.tar.bz2
rm sherpa-onnx-whisper-tiny.tar.bz2

We only use the int8.onnx models below.

2. Download a test wave.
   You can find many wave files for different languages at
   https://hf-mirror.com/spaces/k2-fsa/spoken-language-identification/tree/main/test_wavs

wget https://hf-mirror.com/spaces/k2-fsa/spoken-language-identification/resolve/main/test_wavs/de-german.wav

python3 ./python-api-examples/spoken-language-identification.py \
  --whisper-encoder=sherpa-onnx-whisper-tiny/tiny-encoder.int8.onnx \
  --whisper-decoder=sherpa-onnx-whisper-tiny/tiny-decoder.int8.onnx \
  --num-threads=1 \
  ./de-german.wav
"""

import argparse
import logging
import time
import wave
from pathlib import Path
from typing import Tuple

import numpy as np
import sherpa_onnx


def get_args():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--whisper-encoder",
        required=True,
        type=str,
        help="Path to a multilingual whisper encoder model",
    )

    parser.add_argument(
        "--whisper-decoder",
        required=True,
        type=str,
        help="Path to a multilingual whisper decoder model",
    )

    parser.add_argument(
        "--num-threads",
        type=int,
        default=1,
        help="Number of threads for neural network computation",
    )

    parser.add_argument(
        "--debug",
        type=bool,
        default=False,
        help="True to show debug messages",
    )

    parser.add_argument(
        "--provider",
        type=str,
        default="cpu",
        help="Valid values: cpu, cuda, coreml",
    )

    parser.add_argument(
        "sound_file",
        type=str,
        help="The input sound file to identify. It must be of WAVE "
        "format with a single channel, and each sample has 16-bit, "
        "i.e., int16_t. "
        "The sample rate of the file can be arbitrary and does not need to "
        "be 16 kHz",
    )

    return parser.parse_args()


def assert_file_exists(filename: str):
    assert Path(filename).is_file(), (
        f"{filename} does not exist!\n"
        "Please refer to "
        "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/index.html to download it"
    )


def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
    """
    Args:
      wave_filename:
        Path to a wave file. It should be single channel and each sample should
        be 16-bit. Its sample rate does not need to be 16kHz.
    Returns:
      Return a tuple containing:
       - A 1-D array of dtype np.float32 containing the samples, which are
         normalized to the range [-1, 1].
       - sample rate of the wave file
    """

    with wave.open(wave_filename) as f:
        assert f.getnchannels() == 1, f.getnchannels()
        assert f.getsampwidth() == 2, f.getsampwidth()  # it is in bytes
        num_samples = f.getnframes()
        samples = f.readframes(num_samples)
        samples_int16 = np.frombuffer(samples, dtype=np.int16)
        samples_float32 = samples_int16.astype(np.float32)

        samples_float32 = samples_float32 / 32768
        return samples_float32, f.getframerate()


def main():
    args = get_args()
    assert_file_exists(args.whisper_encoder)
    assert_file_exists(args.whisper_decoder)
    assert args.num_threads > 0, args.num_threads

    config = sherpa_onnx.SpokenLanguageIdentificationConfig(
        whisper=sherpa_onnx.SpokenLanguageIdentificationWhisperConfig(
            encoder=args.whisper_encoder,
            decoder=args.whisper_decoder,
        ),
        num_threads=args.num_threads,
        debug=args.debug,
        provider=args.provider,
    )
    slid = sherpa_onnx.SpokenLanguageIdentification(config)

    samples, sample_rate = read_wave(args.sound_file)

    start_time = time.time()

    stream = slid.create_stream()
    stream.accept_waveform(sample_rate=sample_rate, waveform=samples)
    lang = slid.compute(stream)

    end_time = time.time()

    elapsed_seconds = end_time - start_time
    audio_duration = len(samples) / sample_rate
    real_time_factor = elapsed_seconds / audio_duration

    logging.info(f"File: {args.sound_file}")
    logging.info(f"Detected language: {lang}")
    logging.info(f"Elapsed seconds: {elapsed_seconds:.3f}")
    logging.info(f"Audio duration in seconds: {audio_duration:.3f}")
    logging.info(
        f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}"
    )


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
setup.py

 #!/usr/bin/env python3

 import os
-import re
-import sys
 from pathlib import Path

 import setuptools

@@ -11,7 +9,7 @@ from cmake.cmake_extension import (
     BuildExtension,
     bdist_wheel,
     cmake_extension,
-    enable_alsa,
+    get_binaries,
     is_windows,
 )

@@ -42,39 +40,7 @@ def get_binaries_to_install():
     bin_dir.mkdir(parents=True, exist_ok=True)
     suffix = ".exe" if is_windows() else ""
-    # Remember to also change cmake/cmake_extension.py
-    binaries = ["sherpa-onnx"]
-    binaries += ["sherpa-onnx-keyword-spotter"]
-    binaries += ["sherpa-onnx-offline"]
-    binaries += ["sherpa-onnx-microphone"]
-    binaries += ["sherpa-onnx-microphone-offline"]
-    binaries += ["sherpa-onnx-microphone-offline-speaker-identification"]
-    binaries += ["sherpa-onnx-online-websocket-server"]
-    binaries += ["sherpa-onnx-offline-websocket-server"]
-    binaries += ["sherpa-onnx-online-websocket-client"]
-    binaries += ["sherpa-onnx-vad-microphone"]
-    binaries += ["sherpa-onnx-vad-microphone-offline-asr"]
-    binaries += ["sherpa-onnx-offline-tts"]
-    binaries += ["sherpa-onnx-offline-tts-play"]
-
-    if enable_alsa():
-        binaries += ["sherpa-onnx-alsa"]
-        binaries += ["sherpa-onnx-alsa-offline"]
-        binaries += ["sherpa-onnx-offline-tts-play-alsa"]
-        binaries += ["sherpa-onnx-alsa-offline-speaker-identification"]
-
-    if is_windows():
-        binaries += ["kaldi-native-fbank-core.dll"]
-        binaries += ["sherpa-onnx-c-api.dll"]
-        binaries += ["sherpa-onnx-core.dll"]
-        binaries += ["sherpa-onnx-portaudio.dll"]
-        binaries += ["onnxruntime.dll"]
-        binaries += ["piper_phonemize.dll"]
-        binaries += ["espeak-ng.dll"]
-        binaries += ["ucd.dll"]
-        binaries += ["kaldi-decoder-core.dll"]
-        binaries += ["sherpa-onnx-fst.lib"]
-        binaries += ["sherpa-onnx-kaldifst-core.lib"]
+    binaries = get_binaries()

     exe = []
     for f in binaries:
sherpa-onnx/csrc/CMakeLists.txt

@@ -86,6 +86,8 @@ set(sources
   silero-vad-model-config.cc
   silero-vad-model.cc
   slice.cc
+  spoken-language-identification-impl.cc
+  spoken-language-identification.cc
   stack.cc
   symbol-table.cc
   text-utils.cc

@@ -184,6 +186,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
   add_executable(sherpa-onnx-offline sherpa-onnx-offline.cc)
   add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc)
   add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc)
+  add_executable(sherpa-onnx-offline-language-identification sherpa-onnx-offline-language-identification.cc)

   set(main_exes
     sherpa-onnx

@@ -191,6 +194,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
     sherpa-onnx-offline
     sherpa-onnx-offline-parallel
     sherpa-onnx-offline-tts
+    sherpa-onnx-offline-language-identification
   )

   foreach(exe IN LISTS main_exes)
sherpa-onnx/csrc/offline-ctc-model.cc

@@ -23,7 +23,7 @@ enum class ModelType {
   kTdnn,
   kZipformerCtc,
   kWenetCtc,
-  kUnkown,
+  kUnknown,
 };

 }  // namespace

@@ -59,7 +59,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
         "run.sh\n"
         "\n"
         "for how to add metadta to model.onnx\n");
-    return ModelType::kUnkown;
+    return ModelType::kUnknown;
   }

   if (model_type.get() == std::string("EncDecCTCModelBPE")) {

@@ -72,13 +72,13 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
     return ModelType::kWenetCtc;
   } else {
     SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
-    return ModelType::kUnkown;
+    return ModelType::kUnknown;
   }
 }

 std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
     const OfflineModelConfig &config) {
-  ModelType model_type = ModelType::kUnkown;
+  ModelType model_type = ModelType::kUnknown;

   std::string filename;
   if (!config.nemo_ctc.model.empty()) {

@@ -113,7 +113,7 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
     case ModelType::kWenetCtc:
       return std::make_unique<OfflineWenetCtcModel>(config);
       break;
-    case ModelType::kUnkown:
+    case ModelType::kUnknown:
       SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
       return nullptr;
   }

@@ -125,7 +125,7 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
     AAssetManager *mgr, const OfflineModelConfig &config) {
-  ModelType model_type = ModelType::kUnkown;
+  ModelType model_type = ModelType::kUnknown;

   std::string filename;
   if (!config.nemo_ctc.model.empty()) {

@@ -160,7 +160,7 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
     case ModelType::kWenetCtc:
       return std::make_unique<OfflineWenetCtcModel>(mgr, config);
       break;
-    case ModelType::kUnkown:
+    case ModelType::kUnknown:
      SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
      return nullptr;
   }
sherpa-onnx/csrc/offline-recognizer-whisper-impl.h

@@ -114,7 +114,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
       num_frames = max_num_frames - 50;
     }

-    NormalizeFeatures(f.data(), num_frames, feat_dim);
+    model_->NormalizeFeatures(f.data(), num_frames, feat_dim);

     // note that 1000 is an experience-value.
     // You can replace 1000 by other values, say, 100.

@@ -163,38 +163,6 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
   }

  private:
-  static void NormalizeFeatures(float *features, int32_t num_frames,
-                                int32_t feat_dim) {
-    // log_spec = torch.clamp(features, min=1e-10).log10()
-    // log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
-    // mel = (log_spec + 4.0) / 4.0
-
-    int32_t n = num_frames * feat_dim;
-    float max_v = -1e20;
-
-    for (int32_t i = 0; i != n; ++i) {
-      float f = features[i];
-
-      f = std::max<float>(f, 1e-10);
-      f = std::log10(f);
-
-      max_v = std::max(f, max_v);
-
-      features[i] = f;
-    }
-
-    max_v -= 8;
-
-    for (int32_t i = 0; i != n; ++i) {
-      float f = features[i];
-      f = std::max(f, max_v);
-
-      f = (f + 4) / 4;
-
-      features[i] = f;
-    }
-  }
-
- private:
   OfflineRecognizerConfig config_;
   SymbolTable symbol_table_;
   std::unique_ptr<OfflineWhisperModel> model_;
sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc

@@ -12,56 +12,6 @@

 namespace sherpa_onnx {

-int32_t OfflineWhisperGreedySearchDecoder::DetectLanguage(
-    Ort::Value &cross_k, Ort::Value &cross_v) const {  // NOLINT
-  int64_t token_val = model_->SOT();
-  std::array<int64_t, 2> token_shape{1, 1};
-
-  auto memory_info =
-      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
-
-  Ort::Value tokens = Ort::Value::CreateTensor(
-      memory_info, &token_val, 1, token_shape.data(), token_shape.size());
-
-  auto self_kv_cache = model_->GetInitialSelfKVCache();
-
-  std::array<int64_t, 1> offset_shape{1};
-  Ort::Value offset = Ort::Value::CreateTensor<int64_t>(
-      model_->Allocator(), offset_shape.data(), offset_shape.size());
-  *(offset.GetTensorMutableData<int64_t>()) = 0;
-
-  auto decoder_out = model_->ForwardDecoder(
-      std::move(tokens), std::move(self_kv_cache.first),
-      std::move(self_kv_cache.second), std::move(cross_k), std::move(cross_v),
-      std::move(offset));
-
-  cross_k = std::move(std::get<3>(decoder_out));
-  cross_v = std::move(std::get<4>(decoder_out));
-
-  const float *p_logits = std::get<0>(decoder_out).GetTensorData<float>();
-
-  int32_t vocab_size = model_->VocabSize();
-  const auto &all_language_ids = model_->GetAllLanguageIDs();
-
-  int32_t lang_id = all_language_ids[0];
-  float this_logit = p_logits[lang_id];
-
-  for (int32_t i = 1; i != all_language_ids.size(); ++i) {
-    int32_t id = all_language_ids[i];
-    float p = p_logits[id];
-
-    if (p > this_logit) {
-      this_logit = p;
-      lang_id = id;
-    }
-  }
-
-#if 1
-  SHERPA_ONNX_LOGE("Detected language: %s",
-                   model_->GetID2Lang().at(lang_id).c_str());
-#endif
-
-  return lang_id;
-}
-
 std::vector<OfflineWhisperDecoderResult>
 OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
                                           Ort::Value cross_v) {

@@ -89,7 +39,7 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
     // 0: sot, 1: lang_id, 2: task, 3: no_timestamps
     initial_tokens[1] = lang_id;
   } else {
-    int32_t lang_id = DetectLanguage(cross_k, cross_v);
+    int32_t lang_id = model_->DetectLanguage(cross_k, cross_v);

     // 0: sot, 1: lang_id, 2: task, 3: no_timestamps
     initial_tokens[1] = lang_id;
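The detection step that moved into OfflineWhisperModel is an argmax over the decoder logits restricted to Whisper's language tokens, computed after feeding a single SOT token through the decoder. A minimal NumPy sketch of just that selection step (function and argument names are illustrative, not part of the commit):

import numpy as np

def pick_language_id(p_logits: np.ndarray, all_language_ids: list) -> int:
    """Mirror of the loop in DetectLanguage: argmax over language tokens only."""
    scores = p_logits[all_language_ids]  # logits at the language-token positions
    return int(all_language_ids[int(np.argmax(scores))])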
sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h

@@ -22,9 +22,6 @@ class OfflineWhisperGreedySearchDecoder : public OfflineWhisperDecoder {
   std::vector<OfflineWhisperDecoderResult> Decode(Ort::Value cross_k,
                                                   Ort::Value cross_v) override;

-  int32_t DetectLanguage(Ort::Value &cross_k,        // NOLINT
-                         Ort::Value &cross_v) const;  // NOLINT
-
  private:
   OfflineWhisperModelConfig config_;
   OfflineWhisperModel *model_;  // not owned
sherpa-onnx/csrc/offline-whisper-model-config.cc

@@ -35,19 +35,28 @@ void OfflineWhisperModelConfig::Register(ParseOptions *po) {
   po->Register(
       "whisper-tail-paddings", &tail_paddings,
-      "Suggest value: 50 for English models. 300 for multilingual models. "
+      "Suggested value: 50 for English models. 300 for multilingual models. "
       "Since we have removed the 30-second constraint, we need to add some "
       "tail padding frames "
-      "so that whisper can detect the eot token. Leave it to -1 to use 50 for "
-      "English models and 300 for multilingual models.");
+      "so that whisper can detect the eot token. Leave it to -1 to use 1000.");
 }

 bool OfflineWhisperModelConfig::Validate() const {
+  if (encoder.empty()) {
+    SHERPA_ONNX_LOGE("Please provide --whisper-encoder");
+    return false;
+  }
+
   if (!FileExists(encoder)) {
     SHERPA_ONNX_LOGE("whisper encoder file %s does not exist", encoder.c_str());
     return false;
   }

+  if (decoder.empty()) {
+    SHERPA_ONNX_LOGE("Please provide --whisper-decoder");
+    return false;
+  }
+
   if (!FileExists(decoder)) {
     SHERPA_ONNX_LOGE("whisper decoder file %s does not exist", decoder.c_str());
     return false;
sherpa-onnx/csrc/offline-whisper-model.cc

@@ -24,6 +24,24 @@ class OfflineWhisperModel::Impl {
         env_(ORT_LOGGING_LEVEL_ERROR),
         sess_opts_(GetSessionOptions(config)),
         allocator_{} {
+    debug_ = config_.debug;
     {
       auto buf = ReadFile(config.whisper.encoder);
       InitEncoder(buf.data(), buf.size());
     }

     {
       auto buf = ReadFile(config.whisper.decoder);
       InitDecoder(buf.data(), buf.size());
     }
   }

+  explicit Impl(const SpokenLanguageIdentificationConfig &config)
+      : lid_config_(config),
+        env_(ORT_LOGGING_LEVEL_ERROR),
+        sess_opts_(GetSessionOptions(config)),
+        allocator_{} {
+    debug_ = config_.debug;
+    {
+      auto buf = ReadFile(config.whisper.encoder);
+      InitEncoder(buf.data(), buf.size());

@@ -41,6 +59,7 @@ class OfflineWhisperModel::Impl {
         env_(ORT_LOGGING_LEVEL_ERROR),
         sess_opts_(GetSessionOptions(config)),
         allocator_{} {
+    debug_ = config_.debug;
     {
       auto buf = ReadFile(mgr, config.whisper.encoder);
       InitEncoder(buf.data(), buf.size());

@@ -85,6 +104,57 @@ class OfflineWhisperModel::Impl {
             std::move(decoder_input[4]), std::move(decoder_input[5])};
   }

+  int32_t DetectLanguage(Ort::Value &cross_k,    // NOLINT
+                         Ort::Value &cross_v) {  // NOLINT
+    int64_t token_val = SOT();
+    std::array<int64_t, 2> token_shape{1, 1};
+
+    auto memory_info =
+        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+
+    Ort::Value tokens = Ort::Value::CreateTensor(
+        memory_info, &token_val, 1, token_shape.data(), token_shape.size());
+
+    auto self_kv_cache = GetInitialSelfKVCache();
+
+    std::array<int64_t, 1> offset_shape{1};
+    Ort::Value offset = Ort::Value::CreateTensor<int64_t>(
+        Allocator(), offset_shape.data(), offset_shape.size());
+    *(offset.GetTensorMutableData<int64_t>()) = 0;
+
+    auto decoder_out = ForwardDecoder(
+        std::move(tokens), std::move(self_kv_cache.first),
+        std::move(self_kv_cache.second), std::move(cross_k),
+        std::move(cross_v), std::move(offset));
+
+    cross_k = std::move(std::get<3>(decoder_out));
+    cross_v = std::move(std::get<4>(decoder_out));
+
+    const float *p_logits = std::get<0>(decoder_out).GetTensorData<float>();
+
+    int32_t vocab_size = VocabSize();
+    const auto &all_language_ids = GetAllLanguageIDs();
+
+    int32_t lang_id = all_language_ids[0];
+    float this_logit = p_logits[lang_id];
+
+    for (int32_t i = 1; i != all_language_ids.size(); ++i) {
+      int32_t id = all_language_ids[i];
+      float p = p_logits[id];
+
+      if (p > this_logit) {
+        this_logit = p;
+        lang_id = id;
+      }
+    }
+
+    if (debug_) {
+      SHERPA_ONNX_LOGE("Detected language: %s",
+                       GetID2Lang().at(lang_id).c_str());
+    }
+
+    return lang_id;
+  }
+
   std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache() {
     std::array<int64_t, 4> shape{n_text_layer_, 1, n_text_ctx_, n_text_state_};

@@ -148,7 +218,7 @@ class OfflineWhisperModel::Impl {
     // get meta data
     Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
-    if (config_.debug) {
+    if (debug_) {
       std::ostringstream os;
       os << "---encoder---\n";
       PrintModelMetadata(os, meta_data);

@@ -203,6 +273,8 @@ class OfflineWhisperModel::Impl {
  private:
   OfflineModelConfig config_;
+  SpokenLanguageIdentificationConfig lid_config_;
+  bool debug_ = false;
   Ort::Env env_;
   Ort::SessionOptions sess_opts_;
   Ort::AllocatorWithDefaultOptions allocator_;

@@ -246,6 +318,10 @@ class OfflineWhisperModel::Impl {
 OfflineWhisperModel::OfflineWhisperModel(const OfflineModelConfig &config)
     : impl_(std::make_unique<Impl>(config)) {}

+OfflineWhisperModel::OfflineWhisperModel(
+    const SpokenLanguageIdentificationConfig &config)
+    : impl_(std::make_unique<Impl>(config)) {}
+
 #if __ANDROID_API__ >= 9
 OfflineWhisperModel::OfflineWhisperModel(AAssetManager *mgr,
                                          const OfflineModelConfig &config)

@@ -273,6 +349,11 @@ OfflineWhisperModel::ForwardDecoder(Ort::Value tokens,
                                std::move(n_layer_cross_v), std::move(offset));
 }

+int32_t OfflineWhisperModel::DetectLanguage(Ort::Value &cross_k,    // NOLINT
+                                            Ort::Value &cross_v) {  // NOLINT
+  return impl_->DetectLanguage(cross_k, cross_v);
+}
+
 std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::GetInitialSelfKVCache()
     const {
   return impl_->GetInitialSelfKVCache();

@@ -318,4 +399,35 @@ bool OfflineWhisperModel::IsMultiLingual() const {
   return impl_->IsMultiLingual();
 }

+void OfflineWhisperModel::NormalizeFeatures(float *features,
+                                            int32_t num_frames,
+                                            int32_t feat_dim) {
+  // log_spec = torch.clamp(features, min=1e-10).log10()
+  // log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
+  // mel = (log_spec + 4.0) / 4.0
+
+  int32_t n = num_frames * feat_dim;
+  float max_v = -1e20;
+
+  for (int32_t i = 0; i != n; ++i) {
+    float f = features[i];
+
+    f = std::max<float>(f, 1e-10);
+    f = std::log10(f);
+
+    max_v = std::max(f, max_v);
+
+    features[i] = f;
+  }
+
+  max_v -= 8;
+
+  for (int32_t i = 0; i != n; ++i) {
+    float f = features[i];
+    f = std::max(f, max_v);
+
+    f = (f + 4) / 4;
+
+    features[i] = f;
+  }
+}
+
 }  // namespace sherpa_onnx
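The comments in NormalizeFeatures quote the PyTorch expressions from the original Whisper feature pipeline. A minimal NumPy sketch of the same normalization, included here only as a translation for illustration, not as part of the commit:

import numpy as np

def normalize_whisper_mel(features: np.ndarray) -> np.ndarray:
    """Mirror of NormalizeFeatures: clamp, log10, floor at max-8, then scale."""
    log_spec = np.log10(np.maximum(features, 1e-10))
    log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
    return (log_spec + 4.0) / 4.0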
sherpa-onnx/csrc/offline-whisper-model.h

@@ -18,6 +18,7 @@
 #include "onnxruntime_cxx_api.h"  // NOLINT
 #include "sherpa-onnx/csrc/offline-model-config.h"
+#include "sherpa-onnx/csrc/spoken-language-identification.h"

 namespace sherpa_onnx {

@@ -25,6 +26,9 @@ class OfflineWhisperModel {
  public:
   explicit OfflineWhisperModel(const OfflineModelConfig &config);

+  explicit OfflineWhisperModel(
+      const SpokenLanguageIdentificationConfig &config);
+
 #if __ANDROID_API__ >= 9
   OfflineWhisperModel(AAssetManager *mgr, const OfflineModelConfig &config);
 #endif

@@ -72,7 +76,8 @@ class OfflineWhisperModel {
                 Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k,
                 Ort::Value n_layer_cross_v, Ort::Value offset) const;

-  int32_t DetectLanguage() const;
+  int32_t DetectLanguage(Ort::Value &cross_k,    // NOLINT
+                         Ort::Value &cross_v);   // NOLINT

   /** Return the initial self kv cache in a pair
    *  - n_layer_self_k_cache A 4-D tensor of shape

@@ -98,6 +103,9 @@ class OfflineWhisperModel {
   int32_t Translate() const;
   bool IsMultiLingual() const;

+  static void NormalizeFeatures(float *features, int32_t num_frames,
+                                int32_t feat_dim);
+
  private:
   class Impl;
   std::unique_ptr<Impl> impl_;
sherpa-onnx/csrc/online-transducer-model.cc

@@ -28,7 +28,7 @@ enum class ModelType {
   kLstm,
   kZipformer,
   kZipformer2,
-  kUnkown,
+  kUnknown,
 };

 }  // namespace

@@ -58,7 +58,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
         "No model_type in the metadata!\n"
         "Please make sure you are using the latest export-onnx.py from icefall "
         "to export your transducer models");
-    return ModelType::kUnkown;
+    return ModelType::kUnknown;
   }

   if (model_type.get() == std::string("conformer")) {

@@ -71,7 +71,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
     return ModelType::kZipformer2;
   } else {
     SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
-    return ModelType::kUnkown;
+    return ModelType::kUnknown;
   }
 }

@@ -93,7 +93,7 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
                        model_type.c_str());
     }
   }
-  ModelType model_type = ModelType::kUnkown;
+  ModelType model_type = ModelType::kUnknown;

   {
     auto buffer = ReadFile(config.transducer.encoder);

@@ -110,7 +110,7 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
       return std::make_unique<OnlineZipformerTransducerModel>(config);
     case ModelType::kZipformer2:
       return std::make_unique<OnlineZipformer2TransducerModel>(config);
-    case ModelType::kUnkown:
+    case ModelType::kUnknown:
       SHERPA_ONNX_LOGE("Unknown model type in online transducer!");
       return nullptr;
   }

@@ -185,7 +185,7 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
       return std::make_unique<OnlineZipformerTransducerModel>(mgr, config);
     case ModelType::kZipformer2:
       return std::make_unique<OnlineZipformer2TransducerModel>(mgr, config);
-    case ModelType::kUnkown:
+    case ModelType::kUnknown:
       SHERPA_ONNX_LOGE("Unknown model type in online transducer!");
       return nullptr;
   }
sherpa-onnx/csrc/session.cc

@@ -149,4 +149,9 @@ Ort::SessionOptions GetSessionOptions(
   return GetSessionOptionsImpl(config.num_threads, config.provider);
 }

+Ort::SessionOptions GetSessionOptions(
+    const SpokenLanguageIdentificationConfig &config) {
+  return GetSessionOptionsImpl(config.num_threads, config.provider);
+}
+
 }  // namespace sherpa_onnx
sherpa-onnx/csrc/session.h

@@ -12,6 +12,7 @@
 #include "sherpa-onnx/csrc/online-lm-config.h"
 #include "sherpa-onnx/csrc/online-model-config.h"
 #include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
+#include "sherpa-onnx/csrc/spoken-language-identification.h"
 #include "sherpa-onnx/csrc/vad-model-config.h"

 namespace sherpa_onnx {

@@ -30,6 +31,10 @@ Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config);
 Ort::SessionOptions GetSessionOptions(
     const SpeakerEmbeddingExtractorConfig &config);

+Ort::SessionOptions GetSessionOptions(
+    const SpokenLanguageIdentificationConfig &config);
+
 }  // namespace sherpa_onnx

 #endif  // SHERPA_ONNX_CSRC_SESSION_H_
sherpa-onnx/csrc/sherpa-onnx-offline-language-identification.cc (new file, mode 100644)

// sherpa-onnx/csrc/sherpa-onnx-offline-language-identification.cc
//
// Copyright (c)  2022-2024  Xiaomi Corporation

#include <stdio.h>

#include <chrono>  // NOLINT
#include <string>
#include <vector>

#include "sherpa-onnx/csrc/parse-options.h"
#include "sherpa-onnx/csrc/spoken-language-identification.h"
#include "sherpa-onnx/csrc/wave-reader.h"

int main(int32_t argc, char *argv[]) {
  const char *kUsageMessage = R"usage(
Spoken language identification with sherpa-onnx.

Usage:

(1) Use a whisper multilingual model

wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.tar.bz2
tar xvf sherpa-onnx-whisper-tiny.tar.bz2
rm sherpa-onnx-whisper-tiny.tar.bz2

We only use the int8.onnx models below.

./bin/sherpa-onnx-offline-language-identification \
  --whisper-encoder=sherpa-onnx-whisper-tiny/tiny-encoder.int8.onnx \
  --whisper-decoder=sherpa-onnx-whisper-tiny/tiny-decoder.int8.onnx \
  --num-threads=1 \
  /path/to/foo.wav

foo.wav should be a single-channel, 16-bit PCM encoded wave file; its
sampling rate can be arbitrary and does not need to be 16kHz.

You can find test waves for different languages at
https://hf-mirror.com/spaces/k2-fsa/spoken-language-identification/tree/main/test_wavs

Please refer to
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/index.html
for a list of pre-trained models to download.

Note that only whisper multilingual models are supported. For instance,
"tiny" is supported but "tiny.en" is not.
)usage";

  sherpa_onnx::ParseOptions po(kUsageMessage);
  sherpa_onnx::SpokenLanguageIdentificationConfig config;

  config.Register(&po);
  po.Read(argc, argv);

  if (po.NumArgs() != 1) {
    fprintf(stderr, "Error: Please provide 1 wave file.\n\n");
    po.PrintUsage();
    exit(EXIT_FAILURE);
  }

  fprintf(stderr, "%s\n", config.ToString().c_str());

  if (!config.Validate()) {
    fprintf(stderr, "Errors in config!\n");
    return -1;
  }

  fprintf(stderr, "Creating spoken language identifier ...\n");
  sherpa_onnx::SpokenLanguageIdentification slid(config);

  fprintf(stderr, "Started\n");

  const std::string wav_filename = po.GetArg(1);
  int32_t sampling_rate = -1;
  bool is_ok = false;
  const std::vector<float> samples =
      sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok);
  if (!is_ok) {
    fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
    return -1;
  }
  float duration = samples.size() / static_cast<float>(sampling_rate);

  const auto begin = std::chrono::steady_clock::now();

  auto s = slid.CreateStream();
  s->AcceptWaveform(sampling_rate, samples.data(), samples.size());
  auto language = slid.Compute(s.get());

  const auto end = std::chrono::steady_clock::now();

  fprintf(stderr, "Done!\n\n");
  fprintf(stderr, "%s\nDetected language: %s\n", wav_filename.c_str(),
          language.c_str());

  float elapsed_seconds =
      std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
          .count() /
      1000.;

  fprintf(stderr, "num threads: %d\n", config.num_threads);
  fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
  float rtf = elapsed_seconds / duration;
  fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n",
          elapsed_seconds, duration, rtf);

  return 0;
}
sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc

@@ -16,7 +16,7 @@ enum class ModelType {
   kWeSpeaker,
   k3dSpeaker,
   kNeMo,
-  kUnkown,
+  kUnknown,
 };

 }  // namespace

@@ -47,7 +47,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
         "https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wespeaker/"
         "add_meta_data.py"
         "to add metadata to models from WeSpeaker\n");
-    return ModelType::kUnkown;
+    return ModelType::kUnknown;
   }

   if (model_type.get() == std::string("wespeaker")) {

@@ -58,14 +58,14 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
     return ModelType::kNeMo;
   } else {
     SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
-    return ModelType::kUnkown;
+    return ModelType::kUnknown;
   }
 }

 std::unique_ptr<SpeakerEmbeddingExtractorImpl>
 SpeakerEmbeddingExtractorImpl::Create(
     const SpeakerEmbeddingExtractorConfig &config) {
-  ModelType model_type = ModelType::kUnkown;
+  ModelType model_type = ModelType::kUnknown;
   {
     auto buffer = ReadFile(config.model);

@@ -80,9 +80,8 @@ SpeakerEmbeddingExtractorImpl::Create(
       return std::make_unique<SpeakerEmbeddingExtractorGeneralImpl>(config);
     case ModelType::kNeMo:
       return std::make_unique<SpeakerEmbeddingExtractorNeMoImpl>(config);
-    case ModelType::kUnkown:
-      SHERPA_ONNX_LOGE("Unknown model type in for speaker embedding extractor!");
+    case ModelType::kUnknown:
+      SHERPA_ONNX_LOGE("Unknown model type for speaker embedding extractor!");
       return nullptr;
   }

@@ -94,7 +93,7 @@ SpeakerEmbeddingExtractorImpl::Create(
 std::unique_ptr<SpeakerEmbeddingExtractorImpl>
 SpeakerEmbeddingExtractorImpl::Create(
     AAssetManager *mgr, const SpeakerEmbeddingExtractorConfig &config) {
-  ModelType model_type = ModelType::kUnkown;
+  ModelType model_type = ModelType::kUnknown;
   {
     auto buffer = ReadFile(mgr, config.model);

@@ -110,7 +109,7 @@ SpeakerEmbeddingExtractorImpl::Create(
           config);
     case ModelType::kNeMo:
       return std::make_unique<SpeakerEmbeddingExtractorNeMoImpl>(mgr, config);
-    case ModelType::kUnkown:
+    case ModelType::kUnknown:
       SHERPA_ONNX_LOGE("Unknown model type in for speaker embedding extractor!");
       return nullptr;
sherpa-onnx/csrc/spoken-language-identification-impl.cc (new file, mode 100644)

// sherpa-onnx/csrc/spoken-language-identification-impl.cc
//
// Copyright (c)  2024  Xiaomi Corporation

#include "sherpa-onnx/csrc/spoken-language-identification-impl.h"

#include <memory>

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/spoken-language-identification-whisper-impl.h"

namespace sherpa_onnx {

namespace {

enum class ModelType {
  kWhisper,
  kUnknown,
};

}

static ModelType GetModelType(char *model_data, size_t model_data_length,
                              bool debug) {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
  Ort::SessionOptions sess_opts;

  auto sess = std::make_unique<Ort::Session>(env, model_data,
                                             model_data_length, sess_opts);

  Ort::ModelMetadata meta_data = sess->GetModelMetadata();
  if (debug) {
    std::ostringstream os;
    PrintModelMetadata(os, meta_data);
    SHERPA_ONNX_LOGE("%s", os.str().c_str());
  }

  Ort::AllocatorWithDefaultOptions allocator;
  auto model_type =
      meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
  if (!model_type) {
    SHERPA_ONNX_LOGE(
        "No model_type in the metadata!\n"
        "Please make sure you have added metadata to the model.\n\n"
        "For instance, you can use\n"
        "https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/whisper/"
        "export-onnx.py "
        "to add metadata to models from whisper\n");
    return ModelType::kUnknown;
  }

  auto model_type_str = std::string(model_type.get());

  if (model_type_str.find("whisper") == 0) {
    return ModelType::kWhisper;
  } else {
    SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
    return ModelType::kUnknown;
  }
}

std::unique_ptr<SpokenLanguageIdentificationImpl>
SpokenLanguageIdentificationImpl::Create(
    const SpokenLanguageIdentificationConfig &config) {
  ModelType model_type = ModelType::kUnknown;

  {
    if (config.whisper.encoder.empty()) {
      SHERPA_ONNX_LOGE("Only whisper models are supported at present");
      exit(-1);
    }

    auto buffer = ReadFile(config.whisper.encoder);

    model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
  }

  switch (model_type) {
    case ModelType::kWhisper:
      return std::make_unique<SpokenLanguageIdentificationWhisperImpl>(config);
    case ModelType::kUnknown:
      SHERPA_ONNX_LOGE("Unknown model type for spoken language identification!");
      return nullptr;
  }

  // unreachable code
  return nullptr;
}

}  // namespace sherpa_onnx
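GetModelType decides which implementation to build purely from the model_type entry stored in the ONNX custom metadata. The same lookup can be reproduced from Python with onnxruntime (the model path is a placeholder; the exact metadata value depends on the export script):

import onnxruntime as ort

# Load only to inspect the metadata; the encoder path below is a placeholder.
sess = ort.InferenceSession(
    "sherpa-onnx-whisper-tiny/tiny-encoder.int8.onnx",
    providers=["CPUExecutionProvider"],
)
meta = sess.get_modelmeta()
# The C++ side accepts any value that starts with "whisper".
print(meta.custom_metadata_map.get("model_type"))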
sherpa-onnx/csrc/spoken-language-identification-impl.h (new file, mode 100644)

// sherpa-onnx/csrc/spoken-language-identification-impl.h
//
// Copyright (c)  2024  Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_IMPL_H_
#define SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_IMPL_H_

#include <memory>
#include <string>

#include "sherpa-onnx/csrc/spoken-language-identification.h"

namespace sherpa_onnx {

class SpokenLanguageIdentificationImpl {
 public:
  virtual ~SpokenLanguageIdentificationImpl() = default;

  static std::unique_ptr<SpokenLanguageIdentificationImpl> Create(
      const SpokenLanguageIdentificationConfig &config);

  virtual std::unique_ptr<OfflineStream> CreateStream() const = 0;

  virtual std::string Compute(OfflineStream *s) const = 0;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_IMPL_H_
sherpa-onnx/csrc/spoken-language-identification-whisper-impl.h (new file, mode 100644)

// sherpa-onnx/csrc/spoken-language-identification-whisper-impl.h
//
// Copyright (c)  2024  Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_WHISPER_IMPL_H_
#define SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_WHISPER_IMPL_H_

#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "sherpa-onnx/csrc/offline-whisper-model.h"
#include "sherpa-onnx/csrc/spoken-language-identification-impl.h"
#include "sherpa-onnx/csrc/transpose.h"

namespace sherpa_onnx {

class SpokenLanguageIdentificationWhisperImpl
    : public SpokenLanguageIdentificationImpl {
 public:
  explicit SpokenLanguageIdentificationWhisperImpl(
      const SpokenLanguageIdentificationConfig &config)
      : config_(config),
        model_(std::make_unique<OfflineWhisperModel>(config)) {
    Check();
  }

  std::unique_ptr<OfflineStream> CreateStream() const override {
    return std::make_unique<OfflineStream>(WhisperTag{});
  }

  std::string Compute(OfflineStream *s) const override {
    int32_t max_num_frames = 3000;
    auto memory_info =
        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

    int32_t feat_dim = s->FeatureDim();
    std::vector<float> f = s->GetFrames();
    int32_t num_frames = f.size() / feat_dim;

    // we use 50 here so that there will be some zero tail paddings
    if (num_frames >= max_num_frames - 50) {
      SHERPA_ONNX_LOGE(
          "Only waves less than 30 seconds are supported. We process only the "
          "first 30 seconds and discard the remaining data");
      num_frames = max_num_frames - 50;
    }

    model_->NormalizeFeatures(f.data(), num_frames, feat_dim);

    // note that 1000 is an experience-value.
    // You can replace 1000 by other values, say, 100.
    //
    // Since we have removed the 30 seconds constraint, we need
    // tail_padding_frames so that whisper is able to detect the eot token.
    int32_t tail_padding_frames = 1000;

    if (config_.whisper.tail_paddings > 0) {
      tail_padding_frames = config_.whisper.tail_paddings;
    }

    int32_t actual_frames =
        std::min(num_frames + tail_padding_frames, max_num_frames);

    std::array<int64_t, 3> shape{1, actual_frames, feat_dim};

    Ort::Value mel = Ort::Value::CreateTensor<float>(
        model_->Allocator(), shape.data(), shape.size());

    float *p_mel = mel.GetTensorMutableData<float>();
    std::copy(f.data(), f.data() + num_frames * feat_dim, p_mel);

    std::fill_n(p_mel + num_frames * feat_dim,
                (actual_frames - num_frames) * feat_dim, 0);

    mel = Transpose12(model_->Allocator(), &mel);

    try {
      auto cross_kv = model_->ForwardEncoder(std::move(mel));
      int32_t lang_id =
          model_->DetectLanguage(cross_kv.first, cross_kv.second);
      const auto &id2lang = model_->GetID2Lang();
      if (id2lang.count(lang_id)) {
        return id2lang.at(lang_id);
      } else {
        SHERPA_ONNX_LOGE("Unknown language ID: %d. Return an empty string.",
                         lang_id);
        return "";
      }
    } catch (const Ort::Exception &ex) {
      SHERPA_ONNX_LOGE(
          "\n\nCaught exception:\n\n%s\n\nReturn an empty result. Number of "
          "input frames: %d, Current tail "
          "paddings: %d. If you see a lot of such exceptions, please consider "
          "using a larger --whisper-tail-paddings",
          ex.what(), num_frames, tail_padding_frames);
      return "";
    }
  }

 private:
  void Check() const {
    if (!model_->IsMultiLingual()) {
      SHERPA_ONNX_LOGE(
          "Only whisper multilingual models can be used for spoken language "
          "identification. Given: %s,%s",
          config_.whisper.encoder.c_str(), config_.whisper.decoder.c_str());
      exit(-1);
    }
  }

 private:
  SpokenLanguageIdentificationConfig config_;
  std::unique_ptr<OfflineWhisperModel> model_;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_WHISPER_IMPL_H_
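The frame budget in Compute() follows directly from Whisper's fixed 30-second (3000-frame) input: the wave is truncated so that at least 50 zero frames remain, then padded with tail frames so the decoder can still emit the end-of-text token. A small Python sketch of the same arithmetic, with the constants copied from the source:

def whisper_frame_budget(num_frames: int, tail_paddings: int = -1):
    """Mirror of the truncation/padding arithmetic in Compute()."""
    max_num_frames = 3000  # 30 s of 10 ms frames
    tail_padding_frames = tail_paddings if tail_paddings > 0 else 1000

    # keep at least 50 zero frames at the end
    num_frames = min(num_frames, max_num_frames - 50)
    actual_frames = min(num_frames + tail_padding_frames, max_num_frames)
    return num_frames, actual_frames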
sherpa-onnx/csrc/spoken-language-identification.cc (new file, mode 100644)

// sherpa-onnx/csrc/spoken-language-identification.cc
//
// Copyright (c)  2024  Xiaomi Corporation

#include "sherpa-onnx/csrc/spoken-language-identification.h"

#include <string>

#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/spoken-language-identification-impl.h"

namespace sherpa_onnx {

void SpokenLanguageIdentificationWhisperConfig::Register(ParseOptions *po) {
  po->Register(
      "whisper-encoder", &encoder,
      "Path to the encoder of a whisper multilingual model. Support only "
      "tiny, base, small, medium, large.");

  po->Register(
      "whisper-decoder", &decoder,
      "Path to the decoder of a whisper multilingual model. Support only "
      "tiny, base, small, medium, large.");

  po->Register(
      "whisper-tail-paddings", &tail_paddings,
      "Suggested value: 300 for multilingual models. "
      "Since we have removed the 30-second constraint, we need to add some "
      "tail padding frames "
      "so that whisper can detect the eot token. Leave it to -1 to use 1000");
}

bool SpokenLanguageIdentificationWhisperConfig::Validate() const {
  if (encoder.empty()) {
    SHERPA_ONNX_LOGE("Please provide --whisper-encoder");
    return false;
  }

  if (!FileExists(encoder)) {
    SHERPA_ONNX_LOGE("whisper encoder file %s does not exist", encoder.c_str());
    return false;
  }

  if (decoder.empty()) {
    SHERPA_ONNX_LOGE("Please provide --whisper-decoder");
    return false;
  }

  if (!FileExists(decoder)) {
    SHERPA_ONNX_LOGE("whisper decoder file %s does not exist", decoder.c_str());
    return false;
  }

  return true;
}

std::string SpokenLanguageIdentificationWhisperConfig::ToString() const {
  std::ostringstream os;

  os << "SpokenLanguageIdentificationWhisperConfig(";
  os << "encoder=\"" << encoder << "\", ";
  os << "decoder=\"" << decoder << "\", ";
  os << "tail_paddings=" << tail_paddings << ")";

  return os.str();
}

void SpokenLanguageIdentificationConfig::Register(ParseOptions *po) {
  whisper.Register(po);

  po->Register("num-threads", &num_threads,
               "Number of threads to run the neural network");

  po->Register("debug", &debug,
               "true to print model information while loading it.");

  po->Register("provider", &provider,
               "Specify a provider to use: cpu, cuda, coreml");
}

bool SpokenLanguageIdentificationConfig::Validate() const {
  if (!whisper.Validate()) {
    return false;
  }

  return true;
}

std::string SpokenLanguageIdentificationConfig::ToString() const {
  std::ostringstream os;

  os << "SpokenLanguageIdentificationConfig(";
  os << "whisper=\"" << whisper.ToString() << "\", ";
  os << "num_threads=" << num_threads << ", ";
  os << "debug=" << (debug ? "True" : "False") << ", ";
  os << "provider=\"" << provider << "\")";

  return os.str();
}

SpokenLanguageIdentification::SpokenLanguageIdentification(
    const SpokenLanguageIdentificationConfig &config)
    : impl_(SpokenLanguageIdentificationImpl::Create(config)) {}

SpokenLanguageIdentification::~SpokenLanguageIdentification() = default;

std::unique_ptr<OfflineStream> SpokenLanguageIdentification::CreateStream()
    const {
  return impl_->CreateStream();
}

std::string SpokenLanguageIdentification::Compute(OfflineStream *s) const {
  return impl_->Compute(s);
}

}  // namespace sherpa_onnx
sherpa-onnx/csrc/spoken-language-identification.h (new file, mode 100644)

// sherpa-onnx/csrc/spoken-language-identification.h
//
// Copyright (c)  2024  Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_
#define SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_

#include <memory>
#include <string>

#include "sherpa-onnx/csrc/offline-stream.h"
#include "sherpa-onnx/csrc/parse-options.h"

namespace sherpa_onnx {

struct SpokenLanguageIdentificationWhisperConfig {
  // Requires a multi-lingual whisper model.
  // That is, it supports only tiny, base, small, medium, large.
  // Note: It does NOT support tiny.en, base.en, small.en, medium.en
  std::string encoder;
  std::string decoder;

  // Number of tail padding frames.
  //
  // Since we remove the 30-second constraint, we need to add some paddings
  // at the end.
  //
  // Recommended values:
  //   - 50 for English models
  //   - 300 for multilingual models
  int32_t tail_paddings = -1;

  SpokenLanguageIdentificationWhisperConfig() = default;

  SpokenLanguageIdentificationWhisperConfig(const std::string &encoder,
                                            const std::string &decoder,
                                            int32_t tail_paddings)
      : encoder(encoder), decoder(decoder), tail_paddings(tail_paddings) {}

  void Register(ParseOptions *po);
  bool Validate() const;

  std::string ToString() const;
};

struct SpokenLanguageIdentificationConfig {
  SpokenLanguageIdentificationWhisperConfig whisper;

  int32_t num_threads = 1;
  bool debug = false;
  std::string provider = "cpu";

  SpokenLanguageIdentificationConfig() = default;

  SpokenLanguageIdentificationConfig(
      const SpokenLanguageIdentificationWhisperConfig &whisper,
      int32_t num_threads, bool debug, const std::string &provider)
      : whisper(whisper),
        num_threads(num_threads),
        debug(debug),
        provider(provider) {}

  void Register(ParseOptions *po);
  bool Validate() const;

  std::string ToString() const;
};

class SpokenLanguageIdentificationImpl;

class SpokenLanguageIdentification {
 public:
  explicit SpokenLanguageIdentification(
      const SpokenLanguageIdentificationConfig &config);

  ~SpokenLanguageIdentification();

  // Create a stream to accept audio samples and compute features
  std::unique_ptr<OfflineStream> CreateStream() const;

  // Return a string containing the language, e.g., en, zh, de, etc.
  // Note: en is for English, zh is for Chinese, de is for German, etc.
  std::string Compute(OfflineStream *s) const;

 private:
  std::unique_ptr<SpokenLanguageIdentificationImpl> impl_;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_
sherpa-onnx/python/csrc/CMakeLists.txt

@@ -33,6 +33,7 @@ set(srcs
   silero-vad-model-config.cc
   speaker-embedding-extractor.cc
   speaker-embedding-manager.cc
+  spoken-language-identification.cc
   vad-model-config.cc
   vad-model.cc
   voice-activity-detector.cc
sherpa-onnx/python/csrc/sherpa-onnx.cc

@@ -22,6 +22,7 @@
 #include "sherpa-onnx/python/csrc/online-stream.h"
 #include "sherpa-onnx/python/csrc/speaker-embedding-extractor.h"
 #include "sherpa-onnx/python/csrc/speaker-embedding-manager.h"
+#include "sherpa-onnx/python/csrc/spoken-language-identification.h"
 #include "sherpa-onnx/python/csrc/vad-model-config.h"
 #include "sherpa-onnx/python/csrc/vad-model.h"
 #include "sherpa-onnx/python/csrc/voice-activity-detector.h"

@@ -55,6 +56,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
   PybindOfflineTts(&m);
   PybindSpeakerEmbeddingExtractor(&m);
   PybindSpeakerEmbeddingManager(&m);
+  PybindSpokenLanguageIdentification(&m);

   PybindAlsa(&m);
 }
sherpa-onnx/python/csrc/spoken-language-identification.cc (new file, mode 100644)

// sherpa-onnx/python/csrc/spoken-language-identification.cc
//
// Copyright (c)  2024  Xiaomi Corporation

#include "sherpa-onnx/python/csrc/spoken-language-identification.h"

#include <string>

#include "sherpa-onnx/csrc/spoken-language-identification.h"

namespace sherpa_onnx {

static void PybindSpokenLanguageIdentificationWhisperConfig(py::module *m) {
  using PyClass = SpokenLanguageIdentificationWhisperConfig;

  py::class_<PyClass>(*m, "SpokenLanguageIdentificationWhisperConfig")
      .def(py::init<>())
      .def(py::init<const std::string &, const std::string &, int32_t>(),
           py::arg("encoder"), py::arg("decoder"),
           py::arg("tail_paddings") = -1)
      .def_readwrite("encoder", &PyClass::encoder)
      .def_readwrite("decoder", &PyClass::decoder)
      .def_readwrite("tail_paddings", &PyClass::tail_paddings)
      .def("validate", &PyClass::Validate)
      .def("__str__", &PyClass::ToString);
}

static void PybindSpokenLanguageIdentificationConfig(py::module *m) {
  PybindSpokenLanguageIdentificationWhisperConfig(m);

  using PyClass = SpokenLanguageIdentificationConfig;
  py::class_<PyClass>(*m, "SpokenLanguageIdentificationConfig")
      .def(py::init<>())
      .def(py::init<const SpokenLanguageIdentificationWhisperConfig &, int32_t,
                    bool, const std::string>(),
           py::arg("whisper"), py::arg("num_threads") = 1,
           py::arg("debug") = false, py::arg("provider") = "cpu")
      .def_readwrite("whisper", &PyClass::whisper)
      .def_readwrite("num_threads", &PyClass::num_threads)
      .def_readwrite("debug", &PyClass::debug)
      .def_readwrite("provider", &PyClass::provider)
      .def("validate", &PyClass::Validate)
      .def("__str__", &PyClass::ToString);
}

void PybindSpokenLanguageIdentification(py::module *m) {
  PybindSpokenLanguageIdentificationConfig(m);

  using PyClass = SpokenLanguageIdentification;
  py::class_<PyClass>(*m, "SpokenLanguageIdentification")
      .def(py::init<const SpokenLanguageIdentificationConfig &>(),
           py::arg("config"), py::call_guard<py::gil_scoped_release>())
      .def("create_stream", &PyClass::CreateStream,
           py::call_guard<py::gil_scoped_release>())
      .def("compute", &PyClass::Compute,
           py::call_guard<py::gil_scoped_release>());
}

}  // namespace sherpa_onnx
sherpa-onnx/python/csrc/spoken-language-identification.h (new file, mode 100644)

// sherpa-onnx/python/csrc/spoken-language-identification.h
//
// Copyright (c)  2024  Xiaomi Corporation

#ifndef SHERPA_ONNX_PYTHON_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_
#define SHERPA_ONNX_PYTHON_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_

#include "sherpa-onnx/python/csrc/sherpa-onnx.h"

namespace sherpa_onnx {

void PybindSpokenLanguageIdentification(py::module *m);

}

#endif  // SHERPA_ONNX_PYTHON_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_
sherpa-onnx/python/sherpa_onnx/__init__.py

@@ -13,6 +13,9 @@ from _sherpa_onnx import (
     SpeakerEmbeddingExtractorConfig,
     SpeakerEmbeddingManager,
     SpeechSegment,
+    SpokenLanguageIdentification,
+    SpokenLanguageIdentificationConfig,
+    SpokenLanguageIdentificationWhisperConfig,
     VadModel,
     VadModelConfig,
     VoiceActivityDetector,
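With these exports in place, the new API can be driven end to end from Python. A minimal sketch (model paths are placeholders, and the zero waveform only makes the snippet self-contained; see python-api-examples/spoken-language-identification.py above for a real run):

import numpy as np
import sherpa_onnx

config = sherpa_onnx.SpokenLanguageIdentificationConfig(
    whisper=sherpa_onnx.SpokenLanguageIdentificationWhisperConfig(
        encoder="sherpa-onnx-whisper-tiny/tiny-encoder.int8.onnx",  # placeholder path
        decoder="sherpa-onnx-whisper-tiny/tiny-decoder.int8.onnx",  # placeholder path
    ),
    num_threads=1,
)
slid = sherpa_onnx.SpokenLanguageIdentification(config)

samples = np.zeros(16000, dtype=np.float32)  # stand-in for real audio in [-1, 1]
stream = slid.create_stream()
stream.accept_waveform(sample_rate=16000, waveform=samples)
print(slid.compute(stream))  # prints a language code such as "de"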