Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
Fangjun Kuang
2024-05-10 12:15:39 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2024-05-10 12:15:39 +0800
Commit
17cd3a5f01a62e23aa4bf06adfc6d31425d25a82
17cd3a5f
1 parent
5d8c35e4
Add C++ runtime for non-streaming faster conformer transducer from NeMo. (#854)
显示空白字符变更
内嵌
并排对比
正在显示
31 个修改的文件
包含
1091 行增加
和
151 行删除
.github/scripts/test-offline-transducer.sh
.github/workflows/linux.yaml
.github/workflows/macos.yaml
.gitignore
python-api-examples/offline-nemo-ctc-decode-files.py
python-api-examples/offline-nemo-transducer-decode-files.py
sherpa-onnx/csrc/CMakeLists.txt
sherpa-onnx/csrc/features.h
sherpa-onnx/csrc/keyword-spotter-transducer-impl.h
sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc
sherpa-onnx/csrc/offline-recognizer-ctc-impl.h
sherpa-onnx/csrc/offline-recognizer-impl.cc
sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h
sherpa-onnx/csrc/offline-recognizer-whisper-impl.h
sherpa-onnx/csrc/offline-recognizer.h
sherpa-onnx/csrc/offline-stream.cc
sherpa-onnx/csrc/offline-stream.h
sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h
sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.cc
sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.h
sherpa-onnx/csrc/offline-transducer-nemo-model.cc
sherpa-onnx/csrc/offline-transducer-nemo-model.h
sherpa-onnx/csrc/online-recognizer-ctc-impl.h
sherpa-onnx/csrc/online-recognizer-transducer-impl.h
sherpa-onnx/csrc/slice.h
sherpa-onnx/csrc/symbol-table.cc
sherpa-onnx/csrc/symbol-table.h
sherpa-onnx/csrc/utils.cc
sherpa-onnx/python/csrc/offline-recognizer.cc
sherpa-onnx/python/csrc/offline-stream.cc
sherpa-onnx/python/sherpa_onnx/offline_recognizer.py
.github/scripts/test-offline-transducer.sh
查看文件 @
17cd3a5
...
...
@@ -13,6 +13,105 @@ echo "PATH: $PATH"
which
$EXE
log
"------------------------------------------------------------------------"
log
"Run Nemo fast conformer hybrid transducer ctc models (transducer branch)"
log
"------------------------------------------------------------------------"
url
=
https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k.tar.bz2
name
=
$(
basename
$url
)
curl -SL -O
$url
tar xvf
$name
rm
$name
repo
=
$(
basename -s .tar.bz2
$name
)
ls -lh
$repo
log
"test
$repo
"
test_wavs
=(
de-german.wav
es-spanish.wav
hr-croatian.wav
po-polish.wav
uk-ukrainian.wav
en-english.wav
fr-french.wav
it-italian.wav
ru-russian.wav
)
for
w
in
${
test_wavs
[@]
}
;
do
time
$EXE
\
--tokens
=
$repo
/tokens.txt
\
--encoder
=
$repo
/encoder.onnx
\
--decoder
=
$repo
/decoder.onnx
\
--joiner
=
$repo
/joiner.onnx
\
--debug
=
1
\
$repo
/test_wavs/
$w
done
rm -rf
$repo
url
=
https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-fast-conformer-transducer-en-24500.tar.bz2
name
=
$(
basename
$url
)
curl -SL -O
$url
tar xvf
$name
rm
$name
repo
=
$(
basename -s .tar.bz2
$name
)
ls -lh
$repo
log
"Test
$repo
"
time
$EXE
\
--tokens
=
$repo
/tokens.txt
\
--encoder
=
$repo
/encoder.onnx
\
--decoder
=
$repo
/decoder.onnx
\
--joiner
=
$repo
/joiner.onnx
\
--debug
=
1
\
$repo
/test_wavs/en-english.wav
rm -rf
$repo
url
=
https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-fast-conformer-transducer-es-1424.tar.bz2
name
=
$(
basename
$url
)
curl -SL -O
$url
tar xvf
$name
rm
$name
repo
=
$(
basename -s .tar.bz2
$name
)
ls -lh
$repo
log
"test
$repo
"
time
$EXE
\
--tokens
=
$repo
/tokens.txt
\
--encoder
=
$repo
/encoder.onnx
\
--decoder
=
$repo
/decoder.onnx
\
--joiner
=
$repo
/joiner.onnx
\
--debug
=
1
\
$repo
/test_wavs/es-spanish.wav
rm -rf
$repo
url
=
https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-fast-conformer-transducer-en-de-es-fr-14288.tar.bz2
name
=
$(
basename
$url
)
curl -SL -O
$url
tar xvf
$name
rm
$name
repo
=
$(
basename -s .tar.bz2
$name
)
ls -lh
$repo
log
"Test
$repo
"
time
$EXE
\
--tokens
=
$repo
/tokens.txt
\
--encoder
=
$repo
/encoder.onnx
\
--decoder
=
$repo
/decoder.onnx
\
--joiner
=
$repo
/joiner.onnx
\
--debug
=
1
\
$repo
/test_wavs/en-english.wav
\
$repo
/test_wavs/de-german.wav
\
$repo
/test_wavs/fr-french.wav
\
$repo
/test_wavs/es-spanish.wav
rm -rf
$repo
log
"------------------------------------------------------------"
log
"Run Conformer transducer (English)"
log
"------------------------------------------------------------"
...
...
.github/workflows/linux.yaml
查看文件 @
17cd3a5
...
...
@@ -128,6 +128,14 @@ jobs:
name
:
release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
path
:
install/*
-
name
:
Test offline transducer
shell
:
bash
run
:
|
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline
.github/scripts/test-offline-transducer.sh
-
name
:
Test spoken language identification (C++ API)
shell
:
bash
run
:
|
...
...
@@ -215,14 +223,6 @@ jobs:
.github/scripts/test-online-paraformer.sh
-
name
:
Test offline transducer
shell
:
bash
run
:
|
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline
.github/scripts/test-offline-transducer.sh
-
name
:
Test online transducer
shell
:
bash
run
:
|
...
...
.github/workflows/macos.yaml
查看文件 @
17cd3a5
...
...
@@ -107,6 +107,14 @@ jobs:
otool -L build/bin/sherpa-onnx
otool -l build/bin/sherpa-onnx
-
name
:
Test offline transducer
shell
:
bash
run
:
|
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline
.github/scripts/test-offline-transducer.sh
-
name
:
Test online CTC
shell
:
bash
run
:
|
...
...
@@ -192,14 +200,6 @@ jobs:
.github/scripts/test-offline-ctc.sh
-
name
:
Test offline transducer
shell
:
bash
run
:
|
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline
.github/scripts/test-offline-transducer.sh
-
name
:
Test online transducer
shell
:
bash
run
:
|
...
...
.gitignore
查看文件 @
17cd3a5
...
...
@@ -104,3 +104,4 @@ sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01
sherpa-onnx-ced-*
node_modules
package-lock.json
sherpa-onnx-nemo-*
...
...
python-api-examples/offline-nemo-ctc-decode-files.py
0 → 100755
查看文件 @
17cd3a5
#!/usr/bin/env python3
"""
This file shows how to use a non-streaming CTC model from NeMo
to decode files.
Please download model files from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
The example model supports 10 languages and it is converted from
https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_multilingual_fastconformer_hybrid_large_pc
"""
from
pathlib
import
Path
import
sherpa_onnx
import
soundfile
as
sf
def
create_recognizer
():
model
=
"./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/model.onnx"
tokens
=
"./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/tokens.txt"
test_wav
=
"./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/de-german.wav"
# test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/en-english.wav"
# test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/es-spanish.wav"
# test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/fr-french.wav"
# test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/hr-croatian.wav"
# test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/it-italian.wav"
# test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/po-polish.wav"
# test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/ru-russian.wav"
# test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/uk-ukrainian.wav"
if
not
Path
(
model
)
.
is_file
()
or
not
Path
(
test_wav
)
.
is_file
():
raise
ValueError
(
"""Please download model files from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
"""
)
return
(
sherpa_onnx
.
OfflineRecognizer
.
from_nemo_ctc
(
model
=
model
,
tokens
=
tokens
,
debug
=
True
,
),
test_wav
,
)
def
main
():
recognizer
,
wave_filename
=
create_recognizer
()
audio
,
sample_rate
=
sf
.
read
(
wave_filename
,
dtype
=
"float32"
,
always_2d
=
True
)
audio
=
audio
[:,
0
]
# only use the first channel
# audio is a 1-D float32 numpy array normalized to the range [-1, 1]
# sample_rate does not need to be 16000 Hz
stream
=
recognizer
.
create_stream
()
stream
.
accept_waveform
(
sample_rate
,
audio
)
recognizer
.
decode_stream
(
stream
)
print
(
wave_filename
)
print
(
stream
.
result
)
if
__name__
==
"__main__"
:
main
()
...
...
python-api-examples/offline-nemo-transducer-decode-files.py
0 → 100755
查看文件 @
17cd3a5
#!/usr/bin/env python3
"""
This file shows how to use a non-streaming transducer model from NeMo
to decode files.
Please download model files from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
The example model supports 10 languages and it is converted from
https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_multilingual_fastconformer_hybrid_large_pc
"""
from
pathlib
import
Path
import
sherpa_onnx
import
soundfile
as
sf
def
create_recognizer
():
encoder
=
"./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/encoder.onnx"
decoder
=
"./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/decoder.onnx"
joiner
=
"./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/joiner.onnx"
tokens
=
"./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/tokens.txt"
test_wav
=
"./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/de-german.wav"
# test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/en-english.wav"
# test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/es-spanish.wav"
# test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/fr-french.wav"
# test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/hr-croatian.wav"
# test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/it-italian.wav"
# test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/po-polish.wav"
# test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/ru-russian.wav"
# test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/uk-ukrainian.wav"
if
not
Path
(
encoder
)
.
is_file
()
or
not
Path
(
test_wav
)
.
is_file
():
raise
ValueError
(
"""Please download model files from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
"""
)
return
(
sherpa_onnx
.
OfflineRecognizer
.
from_transducer
(
encoder
=
encoder
,
decoder
=
decoder
,
joiner
=
joiner
,
tokens
=
tokens
,
model_type
=
"nemo_transducer"
,
debug
=
True
,
),
test_wav
,
)
def
main
():
recognizer
,
wave_filename
=
create_recognizer
()
audio
,
sample_rate
=
sf
.
read
(
wave_filename
,
dtype
=
"float32"
,
always_2d
=
True
)
audio
=
audio
[:,
0
]
# only use the first channel
# audio is a 1-D float32 numpy array normalized to the range [-1, 1]
# sample_rate does not need to be 16000 Hz
stream
=
recognizer
.
create_stream
()
stream
.
accept_waveform
(
sample_rate
,
audio
)
recognizer
.
decode_stream
(
stream
)
print
(
wave_filename
)
print
(
stream
.
result
)
if
__name__
==
"__main__"
:
main
()
...
...
sherpa-onnx/csrc/CMakeLists.txt
查看文件 @
17cd3a5
...
...
@@ -40,9 +40,11 @@ set(sources
offline-tdnn-ctc-model.cc
offline-tdnn-model-config.cc
offline-transducer-greedy-search-decoder.cc
offline-transducer-greedy-search-nemo-decoder.cc
offline-transducer-model-config.cc
offline-transducer-model.cc
offline-transducer-modified-beam-search-decoder.cc
offline-transducer-nemo-model.cc
offline-wenet-ctc-model-config.cc
offline-wenet-ctc-model.cc
offline-whisper-greedy-search-decoder.cc
...
...
sherpa-onnx/csrc/features.h
查看文件 @
17cd3a5
...
...
@@ -56,6 +56,19 @@ struct FeatureExtractorConfig {
bool
remove_dc_offset
=
true
;
// Subtract mean of wave before FFT.
std
::
string
window_type
=
"povey"
;
// e.g. Hamming window
// For models from NeMo
// This option is not exposed and is set internally when loading models.
// Possible values:
// - per_feature
// - all_features (not implemented yet)
// - fixed_mean (not implemented)
// - fixed_std (not implemented)
// - or just leave it to empty
// See
// https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/preprocessing/features.py#L59
// for details
std
::
string
nemo_normalize_type
;
std
::
string
ToString
()
const
;
void
Register
(
ParseOptions
*
po
);
...
...
sherpa-onnx/csrc/keyword-spotter-transducer-impl.h
查看文件 @
17cd3a5
...
...
@@ -68,7 +68,7 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {
:
config_
(
config
),
model_
(
OnlineTransducerModel
::
Create
(
config
.
model_config
)),
sym_
(
config
.
model_config
.
tokens
)
{
if
(
sym_
.
c
ontains
(
"<unk>"
))
{
if
(
sym_
.
C
ontains
(
"<unk>"
))
{
unk_id_
=
sym_
[
"<unk>"
];
}
...
...
@@ -87,7 +87,7 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {
:
config_
(
config
),
model_
(
OnlineTransducerModel
::
Create
(
mgr
,
config
.
model_config
)),
sym_
(
mgr
,
config
.
model_config
.
tokens
)
{
if
(
sym_
.
c
ontains
(
"<unk>"
))
{
if
(
sym_
.
C
ontains
(
"<unk>"
))
{
unk_id_
=
sym_
[
"<unk>"
];
}
...
...
sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc
查看文件 @
17cd3a5
// sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc
//
// Copyright (c) 2023 Xiaomi Corporation
// Copyright (c) 2023
-2024
Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h"
...
...
sherpa-onnx/csrc/offline-recognizer-ctc-impl.h
查看文件 @
17cd3a5
...
...
@@ -38,7 +38,7 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src,
std
::
string
text
;
for
(
int32_t
i
=
0
;
i
!=
src
.
tokens
.
size
();
++
i
)
{
if
(
sym_table
.
c
ontains
(
"SIL"
)
&&
src
.
tokens
[
i
]
==
sym_table
[
"SIL"
])
{
if
(
sym_table
.
C
ontains
(
"SIL"
)
&&
src
.
tokens
[
i
]
==
sym_table
[
"SIL"
])
{
// tdnn models from yesno have a SIL token, we should remove it.
continue
;
}
...
...
@@ -103,9 +103,9 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
decoder_
=
std
::
make_unique
<
OfflineCtcFstDecoder
>
(
config_
.
ctc_fst_decoder_config
);
}
else
if
(
config_
.
decoding_method
==
"greedy_search"
)
{
if
(
!
symbol_table_
.
contains
(
"<blk>"
)
&&
!
symbol_table_
.
contains
(
"<eps>"
)
&&
!
symbol_table_
.
contains
(
"<blank>"
))
{
if
(
!
symbol_table_
.
Contains
(
"<blk>"
)
&&
!
symbol_table_
.
Contains
(
"<eps>"
)
&&
!
symbol_table_
.
Contains
(
"<blank>"
))
{
SHERPA_ONNX_LOGE
(
"We expect that tokens.txt contains "
"the symbol <blk> or <eps> or <blank> and its ID."
);
...
...
@@ -113,12 +113,12 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
}
int32_t
blank_id
=
0
;
if
(
symbol_table_
.
c
ontains
(
"<blk>"
))
{
if
(
symbol_table_
.
C
ontains
(
"<blk>"
))
{
blank_id
=
symbol_table_
[
"<blk>"
];
}
else
if
(
symbol_table_
.
c
ontains
(
"<eps>"
))
{
}
else
if
(
symbol_table_
.
C
ontains
(
"<eps>"
))
{
// for tdnn models of the yesno recipe from icefall
blank_id
=
symbol_table_
[
"<eps>"
];
}
else
if
(
symbol_table_
.
c
ontains
(
"<blank>"
))
{
}
else
if
(
symbol_table_
.
C
ontains
(
"<blank>"
))
{
// for Wenet CTC models
blank_id
=
symbol_table_
[
"<blank>"
];
}
...
...
sherpa-onnx/csrc/offline-recognizer-impl.cc
查看文件 @
17cd3a5
...
...
@@ -11,6 +11,7 @@
#include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer-whisper-impl.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/text-utils.h"
...
...
@@ -23,6 +24,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
const
auto
&
model_type
=
config
.
model_config
.
model_type
;
if
(
model_type
==
"transducer"
)
{
return
std
::
make_unique
<
OfflineRecognizerTransducerImpl
>
(
config
);
}
else
if
(
model_type
==
"nemo_transducer"
)
{
return
std
::
make_unique
<
OfflineRecognizerTransducerNeMoImpl
>
(
config
);
}
else
if
(
model_type
==
"paraformer"
)
{
return
std
::
make_unique
<
OfflineRecognizerParaformerImpl
>
(
config
);
}
else
if
(
model_type
==
"nemo_ctc"
||
model_type
==
"tdnn"
||
...
...
@@ -122,6 +125,12 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
return
std
::
make_unique
<
OfflineRecognizerParaformerImpl
>
(
config
);
}
if
(
model_type
==
"EncDecHybridRNNTCTCBPEModel"
&&
!
config
.
model_config
.
transducer
.
decoder_filename
.
empty
()
&&
!
config
.
model_config
.
transducer
.
joiner_filename
.
empty
())
{
return
std
::
make_unique
<
OfflineRecognizerTransducerNeMoImpl
>
(
config
);
}
if
(
model_type
==
"EncDecCTCModelBPE"
||
model_type
==
"EncDecHybridRNNTCTCBPEModel"
||
model_type
==
"tdnn"
||
model_type
==
"zipformer2_ctc"
||
model_type
==
"wenet_ctc"
)
{
...
...
@@ -155,6 +164,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
const
auto
&
model_type
=
config
.
model_config
.
model_type
;
if
(
model_type
==
"transducer"
)
{
return
std
::
make_unique
<
OfflineRecognizerTransducerImpl
>
(
mgr
,
config
);
}
else
if
(
model_type
==
"nemo_transducer"
)
{
return
std
::
make_unique
<
OfflineRecognizerTransducerNeMoImpl
>
(
mgr
,
config
);
}
else
if
(
model_type
==
"paraformer"
)
{
return
std
::
make_unique
<
OfflineRecognizerParaformerImpl
>
(
mgr
,
config
);
}
else
if
(
model_type
==
"nemo_ctc"
||
model_type
==
"tdnn"
||
...
...
@@ -254,6 +265,12 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
return
std
::
make_unique
<
OfflineRecognizerParaformerImpl
>
(
mgr
,
config
);
}
if
(
model_type
==
"EncDecHybridRNNTCTCBPEModel"
&&
!
config
.
model_config
.
transducer
.
decoder_filename
.
empty
()
&&
!
config
.
model_config
.
transducer
.
joiner_filename
.
empty
())
{
return
std
::
make_unique
<
OfflineRecognizerTransducerNeMoImpl
>
(
mgr
,
config
);
}
if
(
model_type
==
"EncDecCTCModelBPE"
||
model_type
==
"EncDecHybridRNNTCTCBPEModel"
||
model_type
==
"tdnn"
||
model_type
==
"zipformer2_ctc"
||
model_type
==
"wenet_ctc"
)
{
...
...
sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h
0 → 100644
查看文件 @
17cd3a5
// sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h
//
// Copyright (c) 2022-2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
#include <fstream>
#include <ios>
#include <memory>
#include <regex> // NOLINT
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer.h"
#include "sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.h"
#include "sherpa-onnx/csrc/offline-transducer-nemo-model.h"
#include "sherpa-onnx/csrc/pad-sequence.h"
#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/transpose.h"
#include "sherpa-onnx/csrc/utils.h"
namespace
sherpa_onnx
{
// defined in ./offline-recognizer-transducer-impl.h
OfflineRecognitionResult
Convert
(
const
OfflineTransducerDecoderResult
&
src
,
const
SymbolTable
&
sym_table
,
int32_t
frame_shift_ms
,
int32_t
subsampling_factor
);
class
OfflineRecognizerTransducerNeMoImpl
:
public
OfflineRecognizerImpl
{
public
:
explicit
OfflineRecognizerTransducerNeMoImpl
(
const
OfflineRecognizerConfig
&
config
)
:
config_
(
config
),
symbol_table_
(
config_
.
model_config
.
tokens
),
model_
(
std
::
make_unique
<
OfflineTransducerNeMoModel
>
(
config_
.
model_config
))
{
if
(
config_
.
decoding_method
==
"greedy_search"
)
{
decoder_
=
std
::
make_unique
<
OfflineTransducerGreedySearchNeMoDecoder
>
(
model_
.
get
(),
config_
.
blank_penalty
);
}
else
{
SHERPA_ONNX_LOGE
(
"Unsupported decoding method: %s"
,
config_
.
decoding_method
.
c_str
());
exit
(
-
1
);
}
PostInit
();
}
#if __ANDROID_API__ >= 9
explicit
OfflineRecognizerTransducerNeMoImpl
(
AAssetManager
*
mgr
,
const
OfflineRecognizerConfig
&
config
)
:
config_
(
config
),
symbol_table_
(
mgr
,
config_
.
model_config
.
tokens
),
model_
(
std
::
make_unique
<
OfflineTransducerNeMoModel
>
(
mgr
,
config_
.
model_config
))
{
if
(
config_
.
decoding_method
==
"greedy_search"
)
{
decoder_
=
std
::
make_unique
<
OfflineTransducerGreedySearchNeMoDecoder
>
(
model_
.
get
(),
config_
.
blank_penalty
);
}
else
{
SHERPA_ONNX_LOGE
(
"Unsupported decoding method: %s"
,
config_
.
decoding_method
.
c_str
());
exit
(
-
1
);
}
PostInit
();
}
#endif
std
::
unique_ptr
<
OfflineStream
>
CreateStream
()
const
override
{
return
std
::
make_unique
<
OfflineStream
>
(
config_
.
feat_config
);
}
void
DecodeStreams
(
OfflineStream
**
ss
,
int32_t
n
)
const
override
{
auto
memory_info
=
Ort
::
MemoryInfo
::
CreateCpu
(
OrtDeviceAllocator
,
OrtMemTypeDefault
);
int32_t
feat_dim
=
ss
[
0
]
->
FeatureDim
();
std
::
vector
<
Ort
::
Value
>
features
;
features
.
reserve
(
n
);
std
::
vector
<
std
::
vector
<
float
>>
features_vec
(
n
);
std
::
vector
<
int64_t
>
features_length_vec
(
n
);
for
(
int32_t
i
=
0
;
i
!=
n
;
++
i
)
{
auto
f
=
ss
[
i
]
->
GetFrames
();
int32_t
num_frames
=
f
.
size
()
/
feat_dim
;
features_length_vec
[
i
]
=
num_frames
;
features_vec
[
i
]
=
std
::
move
(
f
);
std
::
array
<
int64_t
,
2
>
shape
=
{
num_frames
,
feat_dim
};
Ort
::
Value
x
=
Ort
::
Value
::
CreateTensor
(
memory_info
,
features_vec
[
i
].
data
(),
features_vec
[
i
].
size
(),
shape
.
data
(),
shape
.
size
());
features
.
push_back
(
std
::
move
(
x
));
}
std
::
vector
<
const
Ort
::
Value
*>
features_pointer
(
n
);
for
(
int32_t
i
=
0
;
i
!=
n
;
++
i
)
{
features_pointer
[
i
]
=
&
features
[
i
];
}
std
::
array
<
int64_t
,
1
>
features_length_shape
=
{
n
};
Ort
::
Value
x_length
=
Ort
::
Value
::
CreateTensor
(
memory_info
,
features_length_vec
.
data
(),
n
,
features_length_shape
.
data
(),
features_length_shape
.
size
());
Ort
::
Value
x
=
PadSequence
(
model_
->
Allocator
(),
features_pointer
,
0
);
auto
t
=
model_
->
RunEncoder
(
std
::
move
(
x
),
std
::
move
(
x_length
));
// t[0] encoder_out, float tensor, (batch_size, dim, T)
// t[1] encoder_out_length, int64 tensor, (batch_size,)
Ort
::
Value
encoder_out
=
Transpose12
(
model_
->
Allocator
(),
&
t
[
0
]);
auto
results
=
decoder_
->
Decode
(
std
::
move
(
encoder_out
),
std
::
move
(
t
[
1
]));
int32_t
frame_shift_ms
=
10
;
for
(
int32_t
i
=
0
;
i
!=
n
;
++
i
)
{
auto
r
=
Convert
(
results
[
i
],
symbol_table_
,
frame_shift_ms
,
model_
->
SubsamplingFactor
());
ss
[
i
]
->
SetResult
(
r
);
}
}
private
:
void
PostInit
()
{
config_
.
feat_config
.
nemo_normalize_type
=
model_
->
FeatureNormalizationMethod
();
config_
.
feat_config
.
low_freq
=
0
;
// config_.feat_config.high_freq = 8000;
config_
.
feat_config
.
is_librosa
=
true
;
config_
.
feat_config
.
remove_dc_offset
=
false
;
// config_.feat_config.window_type = "hann";
config_
.
feat_config
.
dither
=
0
;
config_
.
feat_config
.
nemo_normalize_type
=
model_
->
FeatureNormalizationMethod
();
int32_t
vocab_size
=
model_
->
VocabSize
();
// check the blank ID
if
(
!
symbol_table_
.
Contains
(
"<blk>"
))
{
SHERPA_ONNX_LOGE
(
"tokens.txt does not include the blank token <blk>"
);
exit
(
-
1
);
}
if
(
symbol_table_
[
"<blk>"
]
!=
vocab_size
-
1
)
{
SHERPA_ONNX_LOGE
(
"<blk> is not the last token!"
);
exit
(
-
1
);
}
if
(
symbol_table_
.
NumSymbols
()
!=
vocab_size
)
{
SHERPA_ONNX_LOGE
(
"number of lines in tokens.txt %d != %d (vocab_size)"
,
symbol_table_
.
NumSymbols
(),
vocab_size
);
exit
(
-
1
);
}
}
private
:
OfflineRecognizerConfig
config_
;
SymbolTable
symbol_table_
;
std
::
unique_ptr
<
OfflineTransducerNeMoModel
>
model_
;
std
::
unique_ptr
<
OfflineTransducerDecoder
>
decoder_
;
};
}
// namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
...
...
sherpa-onnx/csrc/offline-recognizer-whisper-impl.h
查看文件 @
17cd3a5
...
...
@@ -35,7 +35,7 @@ static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src,
std
::
string
text
;
for
(
auto
i
:
src
.
tokens
)
{
if
(
!
sym_table
.
c
ontains
(
i
))
{
if
(
!
sym_table
.
C
ontains
(
i
))
{
continue
;
}
...
...
sherpa-onnx/csrc/offline-recognizer.h
查看文件 @
17cd3a5
...
...
@@ -14,6 +14,7 @@
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/features.h"
#include "sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h"
#include "sherpa-onnx/csrc/offline-lm-config.h"
#include "sherpa-onnx/csrc/offline-model-config.h"
...
...
@@ -26,7 +27,7 @@ namespace sherpa_onnx {
struct
OfflineRecognitionResult
;
struct
OfflineRecognizerConfig
{
Offline
FeatureExtractorConfig
feat_config
;
FeatureExtractorConfig
feat_config
;
OfflineModelConfig
model_config
;
OfflineLMConfig
lm_config
;
OfflineCtcFstDecoderConfig
ctc_fst_decoder_config
;
...
...
@@ -44,7 +45,7 @@ struct OfflineRecognizerConfig {
OfflineRecognizerConfig
()
=
default
;
OfflineRecognizerConfig
(
const
Offline
FeatureExtractorConfig
&
feat_config
,
const
FeatureExtractorConfig
&
feat_config
,
const
OfflineModelConfig
&
model_config
,
const
OfflineLMConfig
&
lm_config
,
const
OfflineCtcFstDecoderConfig
&
ctc_fst_decoder_config
,
const
std
::
string
&
decoding_method
,
int32_t
max_active_paths
,
...
...
sherpa-onnx/csrc/offline-stream.cc
查看文件 @
17cd3a5
...
...
@@ -52,42 +52,25 @@ static void ComputeMeanAndInvStd(const float *p, int32_t num_rows,
}
}
void
OfflineFeatureExtractorConfig
::
Register
(
ParseOptions
*
po
)
{
po
->
Register
(
"sample-rate"
,
&
sampling_rate
,
"Sampling rate of the input waveform. "
"Note: You can have a different "
"sample rate for the input waveform. We will do resampling "
"inside the feature extractor"
);
po
->
Register
(
"feat-dim"
,
&
feature_dim
,
"Feature dimension. Must match the one expected by the model."
);
}
std
::
string
OfflineFeatureExtractorConfig
::
ToString
()
const
{
std
::
ostringstream
os
;
os
<<
"OfflineFeatureExtractorConfig("
;
os
<<
"sampling_rate="
<<
sampling_rate
<<
", "
;
os
<<
"feature_dim="
<<
feature_dim
<<
")"
;
return
os
.
str
();
}
class
OfflineStream
::
Impl
{
public
:
explicit
Impl
(
const
Offline
FeatureExtractorConfig
&
config
,
explicit
Impl
(
const
FeatureExtractorConfig
&
config
,
ContextGraphPtr
context_graph
)
:
config_
(
config
),
context_graph_
(
context_graph
)
{
opts_
.
frame_opts
.
dither
=
0
;
opts_
.
frame_opts
.
snip_edges
=
false
;
opts_
.
frame_opts
.
dither
=
config
.
dither
;
opts_
.
frame_opts
.
snip_edges
=
config
.
snip_edges
;
opts_
.
frame_opts
.
samp_freq
=
config
.
sampling_rate
;
opts_
.
frame_opts
.
frame_shift_ms
=
config
.
frame_shift_ms
;
opts_
.
frame_opts
.
frame_length_ms
=
config
.
frame_length_ms
;
opts_
.
frame_opts
.
remove_dc_offset
=
config
.
remove_dc_offset
;
opts_
.
frame_opts
.
window_type
=
config
.
window_type
;
opts_
.
mel_opts
.
num_bins
=
config
.
feature_dim
;
// Please see
// https://github.com/lhotse-speech/lhotse/blob/master/lhotse/features/fbank.py#L27
// and
// https://github.com/k2-fsa/sherpa-onnx/issues/514
opts_
.
mel_opts
.
high_freq
=
-
400
;
opts_
.
mel_opts
.
high_freq
=
config
.
high_freq
;
opts_
.
mel_opts
.
low_freq
=
config
.
low_freq
;
opts_
.
mel_opts
.
is_librosa
=
config
.
is_librosa
;
fbank_
=
std
::
make_unique
<
knf
::
OnlineFbank
>
(
opts_
);
}
...
...
@@ -237,7 +220,7 @@ class OfflineStream::Impl {
}
private
:
Offline
FeatureExtractorConfig
config_
;
FeatureExtractorConfig
config_
;
std
::
unique_ptr
<
knf
::
OnlineFbank
>
fbank_
;
std
::
unique_ptr
<
knf
::
OnlineWhisperFbank
>
whisper_fbank_
;
knf
::
FbankOptions
opts_
;
...
...
@@ -245,8 +228,7 @@ class OfflineStream::Impl {
ContextGraphPtr
context_graph_
;
};
OfflineStream
::
OfflineStream
(
const
OfflineFeatureExtractorConfig
&
config
/*= {}*/
,
OfflineStream
::
OfflineStream
(
const
FeatureExtractorConfig
&
config
/*= {}*/
,
ContextGraphPtr
context_graph
/*= nullptr*/
)
:
impl_
(
std
::
make_unique
<
Impl
>
(
config
,
context_graph
))
{}
...
...
sherpa-onnx/csrc/offline-stream.h
查看文件 @
17cd3a5
...
...
@@ -11,6 +11,7 @@
#include <vector>
#include "sherpa-onnx/csrc/context-graph.h"
#include "sherpa-onnx/csrc/features.h"
#include "sherpa-onnx/csrc/parse-options.h"
namespace
sherpa_onnx
{
...
...
@@ -32,46 +33,12 @@ struct OfflineRecognitionResult {
std
::
string
AsJsonString
()
const
;
};
struct
OfflineFeatureExtractorConfig
{
// Sampling rate used by the feature extractor. If it is different from
// the sampling rate of the input waveform, we will do resampling inside.
int32_t
sampling_rate
=
16000
;
// Feature dimension
int32_t
feature_dim
=
80
;
// Set internally by some models, e.g., paraformer and wenet CTC models set
// it to false.
// This parameter is not exposed to users from the commandline
// If true, the feature extractor expects inputs to be normalized to
// the range [-1, 1].
// If false, we will multiply the inputs by 32768
bool
normalize_samples
=
true
;
// For models from NeMo
// This option is not exposed and is set internally when loading models.
// Possible values:
// - per_feature
// - all_features (not implemented yet)
// - fixed_mean (not implemented)
// - fixed_std (not implemented)
// - or just leave it to empty
// See
// https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/preprocessing/features.py#L59
// for details
std
::
string
nemo_normalize_type
;
std
::
string
ToString
()
const
;
void
Register
(
ParseOptions
*
po
);
};
struct
WhisperTag
{};
struct
CEDTag
{};
class
OfflineStream
{
public
:
explicit
OfflineStream
(
const
Offline
FeatureExtractorConfig
&
config
=
{},
explicit
OfflineStream
(
const
FeatureExtractorConfig
&
config
=
{},
ContextGraphPtr
context_graph
=
{});
explicit
OfflineStream
(
WhisperTag
tag
);
...
...
sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h
查看文件 @
17cd3a5
...
...
@@ -14,7 +14,7 @@ namespace sherpa_onnx {
class
OfflineTransducerGreedySearchDecoder
:
public
OfflineTransducerDecoder
{
public
:
explicit
OfflineTransducerGreedySearchDecoder
(
OfflineTransducerModel
*
model
,
OfflineTransducerGreedySearchDecoder
(
OfflineTransducerModel
*
model
,
float
blank_penalty
)
:
model_
(
model
),
blank_penalty_
(
blank_penalty
)
{}
...
...
sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.cc
0 → 100644
查看文件 @
17cd3a5
// sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.h"
#include <algorithm>
#include <iterator>
#include <utility>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
namespace
sherpa_onnx
{
static
std
::
pair
<
Ort
::
Value
,
Ort
::
Value
>
BuildDecoderInput
(
int32_t
token
,
OrtAllocator
*
allocator
)
{
std
::
array
<
int64_t
,
2
>
shape
{
1
,
1
};
Ort
::
Value
decoder_input
=
Ort
::
Value
::
CreateTensor
<
int32_t
>
(
allocator
,
shape
.
data
(),
shape
.
size
());
std
::
array
<
int64_t
,
1
>
length_shape
{
1
};
Ort
::
Value
decoder_input_length
=
Ort
::
Value
::
CreateTensor
<
int32_t
>
(
allocator
,
length_shape
.
data
(),
length_shape
.
size
());
int32_t
*
p
=
decoder_input
.
GetTensorMutableData
<
int32_t
>
();
int32_t
*
p_length
=
decoder_input_length
.
GetTensorMutableData
<
int32_t
>
();
p
[
0
]
=
token
;
p_length
[
0
]
=
1
;
return
{
std
::
move
(
decoder_input
),
std
::
move
(
decoder_input_length
)};
}
// Greedy search over a single utterance.
//
// @param p Pointer to this utterance's encoder output; `num_rows` frames of
//          `num_cols` floats each.
// @param num_rows Number of valid encoder frames.
// @param num_cols Encoder output dimension per frame.
// @param model The NeMo transducer model. Borrowed; not owned.
// @param blank_penalty Subtracted from the blank logit before the argmax.
static OfflineTransducerDecoderResult DecodeOne(
    const float *p, int32_t num_rows, int32_t num_cols,
    OfflineTransducerNeMoModel *model, float blank_penalty) {
  auto memory_info =
      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

  OfflineTransducerDecoderResult result;

  const int32_t vocab_size = model->VocabSize();
  // The blank token occupies the last vocabulary entry (VocabSize() already
  // includes the blank appended on top of NeMo's count).
  const int32_t blank_id = vocab_size - 1;

  // Prime the decoder with the blank token and zero-initialized states.
  auto decoder_input = BuildDecoderInput(blank_id, model->Allocator());
  std::pair<Ort::Value, std::vector<Ort::Value>> decoder_output =
      model->RunDecoder(std::move(decoder_input.first),
                        std::move(decoder_input.second),
                        model->GetDecoderInitStates(1));

  std::array<int64_t, 3> frame_shape{1, num_cols, 1};

  for (int32_t frame = 0; frame != num_rows; ++frame) {
    // Wrap the current encoder frame (no copy) as a (1, num_cols, 1) tensor.
    Ort::Value cur_frame = Ort::Value::CreateTensor(
        memory_info, const_cast<float *>(p) + frame * num_cols, num_cols,
        frame_shape.data(), frame_shape.size());

    Ort::Value logit =
        model->RunJoiner(std::move(cur_frame), View(&decoder_output.first));

    float *p_logit = logit.GetTensorMutableData<float>();
    if (blank_penalty > 0) {
      p_logit[blank_id] -= blank_penalty;
    }

    const float *logit_begin = p_logit;
    auto best = static_cast<int32_t>(
        std::max_element(logit_begin, logit_begin + vocab_size) - logit_begin);

    if (best != blank_id) {
      result.tokens.push_back(best);
      result.timestamps.push_back(frame);

      // Feed the emitted token back into the decoder, carrying its states
      // forward. At most one token is emitted per frame.
      decoder_input = BuildDecoderInput(best, model->Allocator());
      decoder_output = model->RunDecoder(std::move(decoder_input.first),
                                         std::move(decoder_input.second),
                                         std::move(decoder_output.second));
    }
  }

  return result;
}
std::vector<OfflineTransducerDecoderResult>
OfflineTransducerGreedySearchNeMoDecoder::Decode(
    Ort::Value encoder_out, Ort::Value encoder_out_length,
    OfflineStream ** /*ss = nullptr*/, int32_t /*n= 0*/) {
  // encoder_out is 3-D: (batch, dim1, dim2). Each utterance is decoded
  // independently, using only its first lengths[b] rows.
  auto shape = encoder_out.GetTensorTypeAndShapeInfo().GetShape();
  auto batch_size = static_cast<int32_t>(shape[0]);
  auto dim1 = static_cast<int32_t>(shape[1]);
  auto dim2 = static_cast<int32_t>(shape[2]);

  const int64_t *lengths = encoder_out_length.GetTensorData<int64_t>();
  const float *data = encoder_out.GetTensorData<float>();

  std::vector<OfflineTransducerDecoderResult> results(batch_size);

  for (int32_t b = 0; b != batch_size; ++b) {
    const float *utt = data + dim1 * dim2 * b;
    results[b] = DecodeOne(utt, static_cast<int32_t>(lengths[b]), dim2, model_,
                           blank_penalty_);
  }

  return results;
}
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.h
0 → 100644
查看文件 @
17cd3a5
// sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_
#define SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_
#include <vector>
#include "sherpa-onnx/csrc/offline-transducer-decoder.h"
#include "sherpa-onnx/csrc/offline-transducer-nemo-model.h"
namespace
sherpa_onnx
{
class OfflineTransducerGreedySearchNeMoDecoder
    : public OfflineTransducerDecoder {
 public:
  /** Constructor.
   *
   * @param model Pointer to the NeMo transducer model. Borrowed; not owned.
   * @param blank_penalty Value subtracted from the blank logit at every frame
   *                      before taking the argmax; <= 0 disables it.
   */
  OfflineTransducerGreedySearchNeMoDecoder(OfflineTransducerNeMoModel *model,
                                           float blank_penalty)
      : model_(model), blank_penalty_(blank_penalty) {}

  /** Greedily decode a batch of encoder outputs.
   *
   * @param encoder_out A 3-D float tensor from the encoder.
   * @param encoder_out_length A 1-D int64 tensor with the number of valid
   *                           frames per utterance.
   * @param ss Unused by this decoder.
   * @param n Unused by this decoder.
   * @return One result (tokens + timestamps) per utterance.
   */
  std::vector<OfflineTransducerDecoderResult> Decode(
      Ort::Value encoder_out, Ort::Value encoder_out_length,
      OfflineStream **ss = nullptr, int32_t n = 0) override;

 private:
  OfflineTransducerNeMoModel *model_;  // Not owned
  float blank_penalty_;
};
}
// namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_
...
...
sherpa-onnx/csrc/offline-transducer-nemo-model.cc
0 → 100644
查看文件 @
17cd3a5
// sherpa-onnx/csrc/offline-transducer-nemo-model.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-transducer-nemo-model.h"
#include <algorithm>
#include <string>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-transducer-decoder.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/transpose.h"
namespace
sherpa_onnx
{
// Private implementation (pimpl) of OfflineTransducerNeMoModel. Owns the
// three onnxruntime sessions (encoder, decoder, joiner) and the metadata
// read from the encoder model.
class OfflineTransducerNeMoModel::Impl {
 public:
  // Loads the encoder, decoder, and joiner ONNX models from the filesystem
  // paths given in config.transducer.
  explicit Impl(const OfflineModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_WARNING),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    {
      auto buf = ReadFile(config.transducer.encoder_filename);
      InitEncoder(buf.data(), buf.size());
    }

    {
      auto buf = ReadFile(config.transducer.decoder_filename);
      InitDecoder(buf.data(), buf.size());
    }

    {
      auto buf = ReadFile(config.transducer.joiner_filename);
      InitJoiner(buf.data(), buf.size());
    }
  }

#if __ANDROID_API__ >= 9
  // Same as the primary constructor but reads the model files from APK
  // assets through `mgr`.
  Impl(AAssetManager *mgr, const OfflineModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_WARNING),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    {
      auto buf = ReadFile(mgr, config.transducer.encoder_filename);
      InitEncoder(buf.data(), buf.size());
    }

    {
      auto buf = ReadFile(mgr, config.transducer.decoder_filename);
      InitDecoder(buf.data(), buf.size());
    }

    {
      auto buf = ReadFile(mgr, config.transducer.joiner_filename);
      InitJoiner(buf.data(), buf.size());
    }
  }
#endif

  // Runs the encoder.
  //
  // @param features A float tensor of shape (B, T, C); consumed in-place.
  // @param features_length A 1-D tensor with the valid frame count per row.
  // @return The raw encoder session outputs.
  std::vector<Ort::Value> RunEncoder(Ort::Value features,
                                     Ort::Value features_length) {
    // (B, T, C) -> (B, C, T): the encoder expects channel-major input.
    features = Transpose12(allocator_, &features);

    std::array<Ort::Value, 2> encoder_inputs = {std::move(features),
                                                std::move(features_length)};

    auto encoder_out = encoder_sess_->Run(
        {}, encoder_input_names_ptr_.data(), encoder_inputs.data(),
        encoder_inputs.size(), encoder_output_names_ptr_.data(),
        encoder_output_names_ptr_.size());

    return encoder_out;
  }

  // Runs the stateful decoder (prediction network).
  //
  // @param targets int32 tensor of shape (B, 1) with the previous token.
  // @param targets_length int32 tensor of shape (B,).
  // @param states Decoder states; see GetDecoderInitStates().
  // @return {decoder_out, states_next}. The model's decoder_output_length
  //         output is intentionally discarded.
  std::pair<Ort::Value, std::vector<Ort::Value>> RunDecoder(
      Ort::Value targets, Ort::Value targets_length,
      std::vector<Ort::Value> states) {
    std::vector<Ort::Value> decoder_inputs;
    decoder_inputs.reserve(2 + states.size());

    decoder_inputs.push_back(std::move(targets));
    decoder_inputs.push_back(std::move(targets_length));

    for (auto &s : states) {
      decoder_inputs.push_back(std::move(s));
    }

    auto decoder_out = decoder_sess_->Run(
        {}, decoder_input_names_ptr_.data(), decoder_inputs.data(),
        decoder_inputs.size(), decoder_output_names_ptr_.data(),
        decoder_output_names_ptr_.size());

    // decoder_out[0]: decoder_output
    // decoder_out[1]: decoder_output_length (discarded)
    // decoder_out[2:]: states_next
    std::vector<Ort::Value> states_next;
    states_next.reserve(states.size());

    // size_t avoids a signed/unsigned comparison with states.size()
    for (size_t i = 0; i != states.size(); ++i) {
      states_next.push_back(std::move(decoder_out[i + 2]));
    }

    return {std::move(decoder_out[0]), std::move(states_next)};
  }

  // Runs the joint network on one encoder output and one decoder output and
  // returns the logits tensor.
  Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) {
    std::array<Ort::Value, 2> joiner_input = {std::move(encoder_out),
                                              std::move(decoder_out)};
    auto logit = joiner_sess_->Run({}, joiner_input_names_ptr_.data(),
                                   joiner_input.data(), joiner_input.size(),
                                   joiner_output_names_ptr_.data(),
                                   joiner_output_names_ptr_.size());

    return std::move(logit[0]);
  }

  // Returns two zero-filled float tensors of shape
  // (pred_rnn_layers_, batch_size, pred_hidden_) — the decoder's initial
  // states (presumably the RNN hidden/cell pair; the metadata only gives
  // layer count and hidden size).
  std::vector<Ort::Value> GetDecoderInitStates(int32_t batch_size) const {
    std::array<int64_t, 3> s0_shape{pred_rnn_layers_, batch_size,
                                    pred_hidden_};
    Ort::Value s0 = Ort::Value::CreateTensor<float>(
        allocator_, s0_shape.data(), s0_shape.size());
    Fill<float>(&s0, 0);

    std::array<int64_t, 3> s1_shape{pred_rnn_layers_, batch_size,
                                    pred_hidden_};
    Ort::Value s1 = Ort::Value::CreateTensor<float>(
        allocator_, s1_shape.data(), s1_shape.size());
    Fill<float>(&s1, 0);

    std::vector<Ort::Value> states;
    states.reserve(2);
    states.push_back(std::move(s0));
    states.push_back(std::move(s1));

    return states;
  }

  int32_t SubsamplingFactor() const { return subsampling_factor_; }

  int32_t VocabSize() const { return vocab_size_; }

  OrtAllocator *Allocator() const { return allocator_; }

  std::string FeatureNormalizationMethod() const { return normalize_type_; }

 private:
  // Creates the encoder session and reads the model metadata
  // (vocab_size, subsampling_factor, normalize_type, pred_rnn_layers,
  // pred_hidden).
  void InitEncoder(void *model_data, size_t model_data_length) {
    encoder_sess_ = std::make_unique<Ort::Session>(
        env_, model_data, model_data_length, sess_opts_);

    GetInputNames(encoder_sess_.get(), &encoder_input_names_,
                  &encoder_input_names_ptr_);
    GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
                   &encoder_output_names_ptr_);

    // get meta data
    Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
    if (config_.debug) {
      std::ostringstream os;
      os << "---encoder---\n";
      PrintModelMetadata(os, meta_data);
      SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
    }

    Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
    SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");

    // need to increase by 1 since the blank token is not included in
    // computing vocab_size in NeMo.
    vocab_size_ += 1;

    SHERPA_ONNX_READ_META_DATA(subsampling_factor_, "subsampling_factor");
    SHERPA_ONNX_READ_META_DATA_STR(normalize_type_, "normalize_type");
    SHERPA_ONNX_READ_META_DATA(pred_rnn_layers_, "pred_rnn_layers");
    SHERPA_ONNX_READ_META_DATA(pred_hidden_, "pred_hidden");

    if (normalize_type_ == "NA") {
      normalize_type_ = "";
    }
  }

  // Creates the decoder session and caches its input/output names.
  void InitDecoder(void *model_data, size_t model_data_length) {
    decoder_sess_ = std::make_unique<Ort::Session>(
        env_, model_data, model_data_length, sess_opts_);

    GetInputNames(decoder_sess_.get(), &decoder_input_names_,
                  &decoder_input_names_ptr_);
    GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
                   &decoder_output_names_ptr_);
  }

  // Creates the joiner session and caches its input/output names.
  void InitJoiner(void *model_data, size_t model_data_length) {
    joiner_sess_ = std::make_unique<Ort::Session>(
        env_, model_data, model_data_length, sess_opts_);

    GetInputNames(joiner_sess_.get(), &joiner_input_names_,
                  &joiner_input_names_ptr_);
    GetOutputNames(joiner_sess_.get(), &joiner_output_names_,
                   &joiner_output_names_ptr_);
  }

 private:
  OfflineModelConfig config_;
  Ort::Env env_;
  Ort::SessionOptions sess_opts_;
  Ort::AllocatorWithDefaultOptions allocator_;

  std::unique_ptr<Ort::Session> encoder_sess_;
  std::unique_ptr<Ort::Session> decoder_sess_;
  std::unique_ptr<Ort::Session> joiner_sess_;

  // Cached session input/output names; the *_ptr_ vectors point into the
  // corresponding string vectors and are what Ort::Session::Run() consumes.
  std::vector<std::string> encoder_input_names_;
  std::vector<const char *> encoder_input_names_ptr_;

  std::vector<std::string> encoder_output_names_;
  std::vector<const char *> encoder_output_names_ptr_;

  std::vector<std::string> decoder_input_names_;
  std::vector<const char *> decoder_input_names_ptr_;

  std::vector<std::string> decoder_output_names_;
  std::vector<const char *> decoder_output_names_ptr_;

  std::vector<std::string> joiner_input_names_;
  std::vector<const char *> joiner_input_names_ptr_;

  std::vector<std::string> joiner_output_names_;
  std::vector<const char *> joiner_output_names_ptr_;

  // Metadata read from the encoder model in InitEncoder().
  int32_t vocab_size_ = 0;  // includes the extra blank token
  int32_t subsampling_factor_ = 8;
  std::string normalize_type_;
  int32_t pred_rnn_layers_ = -1;
  int32_t pred_hidden_ = -1;
};
// Constructs the model; the Impl constructor loads the encoder, decoder, and
// joiner ONNX sessions from the paths in `config`.
OfflineTransducerNeMoModel::OfflineTransducerNeMoModel(
    const OfflineModelConfig &config)
    : impl_(std::make_unique<Impl>(config)) {}
#if __ANDROID_API__ >= 9
OfflineTransducerNeMoModel
::
OfflineTransducerNeMoModel
(
AAssetManager
*
mgr
,
const
OfflineModelConfig
&
config
)
:
impl_
(
std
::
make_unique
<
Impl
>
(
mgr
,
config
))
{}
#endif
// Defined out of line so std::unique_ptr<Impl> sees Impl's complete
// definition (pimpl idiom).
OfflineTransducerNeMoModel::~OfflineTransducerNeMoModel() = default;
// Forwards to Impl::RunEncoder(). See the header for the tensor layouts.
std::vector<Ort::Value> OfflineTransducerNeMoModel::RunEncoder(
    Ort::Value features, Ort::Value features_length) const {
  return impl_->RunEncoder(std::move(features), std::move(features_length));
}
// Forwards to Impl::RunDecoder(). Returns {decoder_out, states_next}.
std::pair<Ort::Value, std::vector<Ort::Value>>
OfflineTransducerNeMoModel::RunDecoder(Ort::Value targets,
                                       Ort::Value targets_length,
                                       std::vector<Ort::Value> states) const {
  return impl_->RunDecoder(std::move(targets), std::move(targets_length),
                           std::move(states));
}
// Forwards to Impl::GetDecoderInitStates(): zero-initialized decoder states
// for a batch of `batch_size`.
std::vector<Ort::Value> OfflineTransducerNeMoModel::GetDecoderInitStates(
    int32_t batch_size) const {
  return impl_->GetDecoderInitStates(batch_size);
}
// Forwards to Impl::RunJoiner(); returns the joiner logits.
Ort::Value OfflineTransducerNeMoModel::RunJoiner(
    Ort::Value encoder_out, Ort::Value decoder_out) const {
  return impl_->RunJoiner(std::move(encoder_out), std::move(decoder_out));
}
// Subsampling factor of the encoder, read from the model metadata.
int32_t OfflineTransducerNeMoModel::SubsamplingFactor() const {
  return impl_->SubsamplingFactor();
}
// Vocabulary size including the blank token (added on top of NeMo's count).
int32_t OfflineTransducerNeMoModel::VocabSize() const {
  return impl_->VocabSize();
}
// Allocator used for creating tensors owned by this model.
OrtAllocator *OfflineTransducerNeMoModel::Allocator() const {
  return impl_->Allocator();
}
// Feature normalization method from the encoder metadata; "" if the model
// specifies "NA".
std::string OfflineTransducerNeMoModel::FeatureNormalizationMethod() const {
  return impl_->FeatureNormalizationMethod();
}
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/offline-transducer-nemo-model.h
0 → 100644
查看文件 @
17cd3a5
// sherpa-onnx/csrc/offline-transducer-nemo-model.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_NEMO_MODEL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_NEMO_MODEL_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-model-config.h"
namespace
sherpa_onnx
{
// see
// https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py#L40
// Its decoder is stateful, not stateless.
class OfflineTransducerNeMoModel {
 public:
  explicit OfflineTransducerNeMoModel(const OfflineModelConfig &config);

#if __ANDROID_API__ >= 9
  OfflineTransducerNeMoModel(AAssetManager *mgr,
                             const OfflineModelConfig &config);
#endif

  ~OfflineTransducerNeMoModel();

  /** Run the encoder.
   *
   * @param features A tensor of shape (N, T, C). It is changed in-place.
   * @param features_length A 1-D tensor of shape (N,) containing number of
   *                        valid frames in `features` before padding.
   *                        Its dtype is int64_t.
   *
   * @return Return a vector containing:
   *         - encoder_out: A 3-D tensor of shape (N, T', encoder_dim)
   *         - encoder_out_length: A 1-D tensor of shape (N,) containing
   *           number of frames in `encoder_out` before padding.
   */
  std::vector<Ort::Value> RunEncoder(Ort::Value features,
                                     Ort::Value features_length) const;

  /** Run the decoder network.
   *
   * @param targets A int32 tensor of shape (batch_size, 1)
   * @param targets_length A int32 tensor of shape (batch_size,)
   * @param states The states for the decoder model.
   * @return Return a pair:
   *         - first: decoder_out (a float tensor)
   *         - second: states_next, to pass to the next invocation
   *         Note: the model's decoder_out_length output is discarded.
   */
  std::pair<Ort::Value, std::vector<Ort::Value>> RunDecoder(
      Ort::Value targets, Ort::Value targets_length,
      std::vector<Ort::Value> states) const;

  /** Return zero-initialized decoder states for a batch of the given size. */
  std::vector<Ort::Value> GetDecoderInitStates(int32_t batch_size) const;

  /** Run the joint network.
   *
   * @param encoder_out Output of the encoder network.
   * @param decoder_out Output of the decoder network.
   * @return Return a tensor of shape (N, 1, 1, vocab_size) containing
   *         logits.
   */
  Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) const;

  /** Return the subsampling factor of the model.
   */
  int32_t SubsamplingFactor() const;

  /** Return the vocabulary size, including the blank token. */
  int32_t VocabSize() const;

  /** Return an allocator for allocating memory
   */
  OrtAllocator *Allocator() const;

  // Feature normalization method read from the model metadata.
  // Possible values:
  //  - per_feature
  //  - all_features (not implemented yet)
  //  - fixed_mean (not implemented)
  //  - fixed_std (not implemented)
  //  - or just leave it to empty
  // See
  // https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/preprocessing/features.py#L59
  // for details
  std::string FeatureNormalizationMethod() const;

 private:
  class Impl;  // pimpl: hides the onnxruntime sessions from this header
  std::unique_ptr<Impl> impl_;
};
}
// namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_NEMO_MODEL_H_
...
...
sherpa-onnx/csrc/online-recognizer-ctc-impl.h
查看文件 @
17cd3a5
...
...
@@ -223,8 +223,8 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
private
:
void
InitDecoder
()
{
if
(
!
sym_
.
contains
(
"<blk>"
)
&&
!
sym_
.
contains
(
"<eps>"
)
&&
!
sym_
.
contains
(
"<blank>"
))
{
if
(
!
sym_
.
Contains
(
"<blk>"
)
&&
!
sym_
.
Contains
(
"<eps>"
)
&&
!
sym_
.
Contains
(
"<blank>"
))
{
SHERPA_ONNX_LOGE
(
"We expect that tokens.txt contains "
"the symbol <blk> or <eps> or <blank> and its ID."
);
...
...
@@ -232,12 +232,12 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
}
int32_t
blank_id
=
0
;
if
(
sym_
.
c
ontains
(
"<blk>"
))
{
if
(
sym_
.
C
ontains
(
"<blk>"
))
{
blank_id
=
sym_
[
"<blk>"
];
}
else
if
(
sym_
.
c
ontains
(
"<eps>"
))
{
}
else
if
(
sym_
.
C
ontains
(
"<eps>"
))
{
// for tdnn models of the yesno recipe from icefall
blank_id
=
sym_
[
"<eps>"
];
}
else
if
(
sym_
.
c
ontains
(
"<blank>"
))
{
}
else
if
(
sym_
.
C
ontains
(
"<blank>"
))
{
// for WeNet CTC models
blank_id
=
sym_
[
"<blank>"
];
}
...
...
sherpa-onnx/csrc/online-recognizer-transducer-impl.h
查看文件 @
17cd3a5
...
...
@@ -87,7 +87,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
model_
(
OnlineTransducerModel
::
Create
(
config
.
model_config
)),
sym_
(
config
.
model_config
.
tokens
),
endpoint_
(
config_
.
endpoint_config
)
{
if
(
sym_
.
c
ontains
(
"<unk>"
))
{
if
(
sym_
.
C
ontains
(
"<unk>"
))
{
unk_id_
=
sym_
[
"<unk>"
];
}
...
...
@@ -103,19 +103,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
}
decoder_
=
std
::
make_unique
<
OnlineTransducerModifiedBeamSearchDecoder
>
(
model_
.
get
(),
lm_
.
get
(),
config_
.
max_active_paths
,
config_
.
lm_config
.
scale
,
unk_id_
,
config_
.
blank_penalty
,
model_
.
get
(),
lm_
.
get
(),
config_
.
max_active_paths
,
config_
.
lm_config
.
scale
,
unk_id_
,
config_
.
blank_penalty
,
config_
.
temperature_scale
);
}
else
if
(
config
.
decoding_method
==
"greedy_search"
)
{
decoder_
=
std
::
make_unique
<
OnlineTransducerGreedySearchDecoder
>
(
model_
.
get
(),
unk_id_
,
config_
.
blank_penalty
,
model_
.
get
(),
unk_id_
,
config_
.
blank_penalty
,
config_
.
temperature_scale
);
}
else
{
...
...
@@ -132,7 +126,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
model_
(
OnlineTransducerModel
::
Create
(
mgr
,
config
.
model_config
)),
sym_
(
mgr
,
config
.
model_config
.
tokens
),
endpoint_
(
config_
.
endpoint_config
)
{
if
(
sym_
.
c
ontains
(
"<unk>"
))
{
if
(
sym_
.
C
ontains
(
"<unk>"
))
{
unk_id_
=
sym_
[
"<unk>"
];
}
...
...
@@ -151,19 +145,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
}
decoder_
=
std
::
make_unique
<
OnlineTransducerModifiedBeamSearchDecoder
>
(
model_
.
get
(),
lm_
.
get
(),
config_
.
max_active_paths
,
config_
.
lm_config
.
scale
,
unk_id_
,
config_
.
blank_penalty
,
model_
.
get
(),
lm_
.
get
(),
config_
.
max_active_paths
,
config_
.
lm_config
.
scale
,
unk_id_
,
config_
.
blank_penalty
,
config_
.
temperature_scale
);
}
else
if
(
config
.
decoding_method
==
"greedy_search"
)
{
decoder_
=
std
::
make_unique
<
OnlineTransducerGreedySearchDecoder
>
(
model_
.
get
(),
unk_id_
,
config_
.
blank_penalty
,
model_
.
get
(),
unk_id_
,
config_
.
blank_penalty
,
config_
.
temperature_scale
);
}
else
{
...
...
sherpa-onnx/csrc/slice.h
查看文件 @
17cd3a5
...
...
@@ -13,7 +13,7 @@ namespace sherpa_onnx {
* It returns v[dim0_start:dim0_end, dim1_start:dim1_end, :]
*
* @param allocator
* @param v A
2
-D tensor. Its data type is T.
* @param v A
3
-D tensor. Its data type is T.
* @param dim0_start Start index of the first dimension..
* @param dim0_end End index of the first dimension..
* @param dim1_start Start index of the second dimension.
...
...
sherpa-onnx/csrc/symbol-table.cc
查看文件 @
17cd3a5
...
...
@@ -100,9 +100,9 @@ int32_t SymbolTable::operator[](const std::string &sym) const {
return
sym2id_
.
at
(
sym
);
}
bool
SymbolTable
::
c
ontains
(
int32_t
id
)
const
{
return
id2sym_
.
count
(
id
)
!=
0
;
}
bool
SymbolTable
::
C
ontains
(
int32_t
id
)
const
{
return
id2sym_
.
count
(
id
)
!=
0
;
}
bool
SymbolTable
::
c
ontains
(
const
std
::
string
&
sym
)
const
{
bool
SymbolTable
::
C
ontains
(
const
std
::
string
&
sym
)
const
{
return
sym2id_
.
count
(
sym
)
!=
0
;
}
...
...
sherpa-onnx/csrc/symbol-table.h
查看文件 @
17cd3a5
...
...
@@ -40,14 +40,16 @@ class SymbolTable {
int32_t
operator
[](
const
std
::
string
&
sym
)
const
;
/// Return true if there is a symbol with the given ID.
bool
c
ontains
(
int32_t
id
)
const
;
bool
C
ontains
(
int32_t
id
)
const
;
/// Return true if there is a given symbol in the symbol table.
bool
c
ontains
(
const
std
::
string
&
sym
)
const
;
bool
C
ontains
(
const
std
::
string
&
sym
)
const
;
// for tokens.txt from Whisper
void
ApplyBase64Decode
();
int32_t
NumSymbols
()
const
{
return
id2sym_
.
size
();
}
private
:
void
Init
(
std
::
istream
&
is
);
...
...
sherpa-onnx/csrc/utils.cc
查看文件 @
17cd3a5
...
...
@@ -49,7 +49,7 @@ static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table,
word
=
word
.
replace
(
0
,
3
,
" "
);
}
}
if
(
symbol_table
.
c
ontains
(
word
))
{
if
(
symbol_table
.
C
ontains
(
word
))
{
int32_t
id
=
symbol_table
[
word
];
tmp_ids
.
push_back
(
id
);
}
else
{
...
...
sherpa-onnx/python/csrc/offline-recognizer.cc
查看文件 @
17cd3a5
...
...
@@ -14,10 +14,10 @@ namespace sherpa_onnx {
static
void
PybindOfflineRecognizerConfig
(
py
::
module
*
m
)
{
using
PyClass
=
OfflineRecognizerConfig
;
py
::
class_
<
PyClass
>
(
*
m
,
"OfflineRecognizerConfig"
)
.
def
(
py
::
init
<
const
OfflineFeatureExtractorConfig
&
,
const
OfflineModelConfig
&
,
const
OfflineLMConfig
&
,
const
OfflineCtcFstDecoderConfig
&
,
const
std
::
string
&
,
int32_t
,
const
std
::
string
&
,
float
,
float
>
(),
.
def
(
py
::
init
<
const
FeatureExtractorConfig
&
,
const
OfflineModelConfig
&
,
const
OfflineLMConfig
&
,
const
OfflineCtcFstDecoderConfig
&
,
const
std
::
string
&
,
int32_t
,
const
std
::
string
&
,
float
,
float
>
(),
py
::
arg
(
"feat_config"
),
py
::
arg
(
"model_config"
),
py
::
arg
(
"lm_config"
)
=
OfflineLMConfig
(),
py
::
arg
(
"ctc_fst_decoder_config"
)
=
OfflineCtcFstDecoderConfig
(),
...
...
sherpa-onnx/python/csrc/offline-stream.cc
查看文件 @
17cd3a5
...
...
@@ -25,6 +25,7 @@ Args:
static
void
PybindOfflineRecognitionResult
(
py
::
module
*
m
)
{
// NOLINT
using
PyClass
=
OfflineRecognitionResult
;
py
::
class_
<
PyClass
>
(
*
m
,
"OfflineRecognitionResult"
)
.
def
(
"__str__"
,
&
PyClass
::
AsJsonString
)
.
def_property_readonly
(
"text"
,
[](
const
PyClass
&
self
)
->
py
::
str
{
...
...
@@ -37,18 +38,7 @@ static void PybindOfflineRecognitionResult(py::module *m) { // NOLINT
"timestamps"
,
[](
const
PyClass
&
self
)
{
return
self
.
timestamps
;
});
}
static
void
PybindOfflineFeatureExtractorConfig
(
py
::
module
*
m
)
{
using
PyClass
=
OfflineFeatureExtractorConfig
;
py
::
class_
<
PyClass
>
(
*
m
,
"OfflineFeatureExtractorConfig"
)
.
def
(
py
::
init
<
int32_t
,
int32_t
>
(),
py
::
arg
(
"sampling_rate"
)
=
16000
,
py
::
arg
(
"feature_dim"
)
=
80
)
.
def_readwrite
(
"sampling_rate"
,
&
PyClass
::
sampling_rate
)
.
def_readwrite
(
"feature_dim"
,
&
PyClass
::
feature_dim
)
.
def
(
"__str__"
,
&
PyClass
::
ToString
);
}
void
PybindOfflineStream
(
py
::
module
*
m
)
{
PybindOfflineFeatureExtractorConfig
(
m
);
PybindOfflineRecognitionResult
(
m
);
using
PyClass
=
OfflineStream
;
...
...
sherpa-onnx/python/sherpa_onnx/offline_recognizer.py
查看文件 @
17cd3a5
...
...
@@ -4,8 +4,8 @@ from pathlib import Path
from
typing
import
List
,
Optional
from
_sherpa_onnx
import
(
FeatureExtractorConfig
,
OfflineCtcFstDecoderConfig
,
OfflineFeatureExtractorConfig
,
OfflineModelConfig
,
OfflineNemoEncDecCtcModelConfig
,
OfflineParaformerModelConfig
,
...
...
@@ -51,6 +51,7 @@ class OfflineRecognizer(object):
blank_penalty
:
float
=
0.0
,
debug
:
bool
=
False
,
provider
:
str
=
"cpu"
,
model_type
:
str
=
"transducer"
,
):
"""
Please refer to
...
...
@@ -106,10 +107,10 @@ class OfflineRecognizer(object):
num_threads
=
num_threads
,
debug
=
debug
,
provider
=
provider
,
model_type
=
"transducer"
,
model_type
=
model_type
,
)
feat_config
=
Offline
FeatureExtractorConfig
(
feat_config
=
FeatureExtractorConfig
(
sampling_rate
=
sample_rate
,
feature_dim
=
feature_dim
,
)
...
...
@@ -182,7 +183,7 @@ class OfflineRecognizer(object):
model_type
=
"paraformer"
,
)
feat_config
=
Offline
FeatureExtractorConfig
(
feat_config
=
FeatureExtractorConfig
(
sampling_rate
=
sample_rate
,
feature_dim
=
feature_dim
,
)
...
...
@@ -246,7 +247,7 @@ class OfflineRecognizer(object):
model_type
=
"nemo_ctc"
,
)
feat_config
=
Offline
FeatureExtractorConfig
(
feat_config
=
FeatureExtractorConfig
(
sampling_rate
=
sample_rate
,
feature_dim
=
feature_dim
,
)
...
...
@@ -326,7 +327,7 @@ class OfflineRecognizer(object):
model_type
=
"whisper"
,
)
feat_config
=
Offline
FeatureExtractorConfig
(
feat_config
=
FeatureExtractorConfig
(
sampling_rate
=
16000
,
feature_dim
=
80
,
)
...
...
@@ -389,7 +390,7 @@ class OfflineRecognizer(object):
model_type
=
"tdnn"
,
)
feat_config
=
Offline
FeatureExtractorConfig
(
feat_config
=
FeatureExtractorConfig
(
sampling_rate
=
sample_rate
,
feature_dim
=
feature_dim
,
)
...
...
@@ -453,7 +454,7 @@ class OfflineRecognizer(object):
model_type
=
"wenet_ctc"
,
)
feat_config
=
Offline
FeatureExtractorConfig
(
feat_config
=
FeatureExtractorConfig
(
sampling_rate
=
sample_rate
,
feature_dim
=
feature_dim
,
)
...
...
请
注册
或
登录
后发表评论