Authored by Fangjun Kuang, 2025-02-16 22:45:24 +0800
Committed by GitHub, 2025-02-16 22:45:24 +0800
Commit 316424b3825db2ac474042917a9c0d9bf4d8c84f (1 parent: 2337169e)
Add C++ and Python API for FireRedASR AED models (#1867)
Showing 20 changed files with 1019 additions and 26 deletions.
.gitignore
python-api-examples/offline-fire-red-asr-decode-files.py
sherpa-onnx/csrc/CMakeLists.txt
sherpa-onnx/csrc/offline-fire-red-asr-decoder.h
sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.cc
sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.h
sherpa-onnx/csrc/offline-fire-red-asr-model-config.cc
sherpa-onnx/csrc/offline-fire-red-asr-model-config.h
sherpa-onnx/csrc/offline-fire-red-asr-model-meta-data.h
sherpa-onnx/csrc/offline-fire-red-asr-model.cc
sherpa-onnx/csrc/offline-fire-red-asr-model.h
sherpa-onnx/csrc/offline-model-config.cc
sherpa-onnx/csrc/offline-model-config.h
sherpa-onnx/csrc/offline-recognizer-fire-red-asr-impl.h
sherpa-onnx/csrc/offline-recognizer-impl.cc
sherpa-onnx/python/csrc/CMakeLists.txt
sherpa-onnx/python/csrc/offline-fire-red-asr-model-config.cc
sherpa-onnx/python/csrc/offline-fire-red-asr-model-config.h
sherpa-onnx/python/csrc/offline-model-config.cc
sherpa-onnx/python/sherpa_onnx/offline_recognizer.py
.gitignore
@@ -133,3 +133,4 @@ lexicon.txt
 us_gold.json
 us_silver.json
 kokoro-multi-lang-v1_0
+sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16
python-api-examples/offline-fire-red-asr-decode-files.py
0 → 100644
#!/usr/bin/env python3
"""
This file shows how to use a non-streaming FireRedAsr AED model from
https://github.com/FireRedTeam/FireRedASR
to decode files.

Please download model files from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models

For instance,

wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16.tar.bz2
tar xvf sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16.tar.bz2
rm sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16.tar.bz2
"""

from pathlib import Path

import sherpa_onnx
import soundfile as sf


def create_recognizer():
    encoder = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/encoder.int8.onnx"
    decoder = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/decoder.int8.onnx"
    tokens = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/tokens.txt"

    test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/0.wav"
    # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/1.wav"
    # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/2.wav"
    # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/3.wav"
    # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/8k.wav"
    # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/3-sichuan.wav"
    # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/4-tianjin.wav"
    # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/5-henan.wav"

    if (
        not Path(encoder).is_file()
        or not Path(decoder).is_file()
        or not Path(test_wav).is_file()
    ):
        raise ValueError(
            """Please download model files from
            https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
            """
        )

    return (
        sherpa_onnx.OfflineRecognizer.from_fire_red_asr(
            encoder=encoder,
            decoder=decoder,
            tokens=tokens,
            debug=True,
        ),
        test_wav,
    )


def main():
    recognizer, wave_filename = create_recognizer()

    audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
    audio = audio[:, 0]  # only use the first channel

    # audio is a 1-D float32 numpy array normalized to the range [-1, 1]
    # sample_rate does not need to be 16000 Hz

    stream = recognizer.create_stream()
    stream.accept_waveform(sample_rate, audio)
    recognizer.decode_stream(stream)
    print(wave_filename)
    print(stream.result)


if __name__ == "__main__":
    main()
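The script prints stream.result, which carries the recognized text along with token-level details. A minimal follow-up sketch for inspecting individual fields, meant to be appended at the end of main() above; the text/tokens/timestamps attribute names reflect how current sherpa_onnx releases expose the result and should be treated as assumptions if your version differs:

    # Inspecting individual fields of the result (attribute names assumed as in
    # current sherpa_onnx releases; not part of this commit).
    result = stream.result
    print(result.text)        # recognized text
    print(result.tokens)      # decoded tokens
    print(result.timestamps)  # per-token timestamps, if the model provides them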
sherpa-onnx/csrc/CMakeLists.txt
@@ -27,6 +27,9 @@ set(sources
   offline-ctc-fst-decoder.cc
   offline-ctc-greedy-search-decoder.cc
   offline-ctc-model.cc
+  offline-fire-red-asr-greedy-search-decoder.cc
+  offline-fire-red-asr-model-config.cc
+  offline-fire-red-asr-model.cc
   offline-lm-config.cc
   offline-lm.cc
   offline-model-config.cc
sherpa-onnx/csrc/offline-fire-red-asr-decoder.h
0 → 100644
// sherpa-onnx/csrc/offline-fire-red-asr-decoder.h
//
// Copyright (c)  2025  Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_DECODER_H_
#define SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_DECODER_H_

#include <cstdint>
#include <vector>

#include "onnxruntime_cxx_api.h"  // NOLINT

namespace sherpa_onnx {

struct OfflineFireRedAsrDecoderResult {
  /// The decoded token IDs
  std::vector<int32_t> tokens;
};

class OfflineFireRedAsrDecoder {
 public:
  virtual ~OfflineFireRedAsrDecoder() = default;

  /** Run beam search given the output from the FireRedAsr encoder model.
   *
   * @param n_layer_cross_k A 4-D tensor of shape
   *                        (num_decoder_layers, N, T, d_model).
   * @param n_layer_cross_v A 4-D tensor of shape
   *                        (num_decoder_layers, N, T, d_model).
   *
   * @return Return a vector of size `N` containing the decoded results.
   */
  virtual std::vector<OfflineFireRedAsrDecoderResult> Decode(
      Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v) = 0;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_DECODER_H_
sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.cc
0 → 100644
// sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.cc
//
// Copyright (c)  2025  Xiaomi Corporation

#include "sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.h"

#include <algorithm>
#include <tuple>
#include <utility>

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"

namespace sherpa_onnx {

// Note: this function works only for batch size == 1 at present
std::vector<OfflineFireRedAsrDecoderResult>
OfflineFireRedAsrGreedySearchDecoder::Decode(Ort::Value cross_k,
                                             Ort::Value cross_v) {
  const auto &meta_data = model_->GetModelMetadata();

  auto memory_info =
      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

  // For multilingual models, initial_tokens contains [sot, language, task]
  //   - language is English by default
  //   - task is transcribe by default
  //
  // For non-multilingual models, initial_tokens contains [sot]
  std::array<int64_t, 2> token_shape = {1, 1};
  int64_t token = meta_data.sos_id;

  int32_t batch_size = 1;

  Ort::Value tokens = Ort::Value::CreateTensor(
      memory_info, &token, 1, token_shape.data(), token_shape.size());

  std::array<int64_t, 1> offset_shape{1};
  Ort::Value offset = Ort::Value::CreateTensor<int64_t>(
      model_->Allocator(), offset_shape.data(), offset_shape.size());
  *(offset.GetTensorMutableData<int64_t>()) = 0;

  std::vector<OfflineFireRedAsrDecoderResult> ans(1);

  auto self_kv_cache = model_->GetInitialSelfKVCache();

  std::tuple<Ort::Value, Ort::Value, Ort::Value, Ort::Value, Ort::Value,
             Ort::Value>
      decoder_out = {Ort::Value{nullptr},
                     std::move(self_kv_cache.first),
                     std::move(self_kv_cache.second),
                     std::move(cross_k),
                     std::move(cross_v),
                     std::move(offset)};

  for (int32_t i = 0; i < meta_data.max_len; ++i) {
    decoder_out = model_->ForwardDecoder(View(&tokens),
                                         std::move(std::get<1>(decoder_out)),
                                         std::move(std::get<2>(decoder_out)),
                                         std::move(std::get<3>(decoder_out)),
                                         std::move(std::get<4>(decoder_out)),
                                         std::move(std::get<5>(decoder_out)));

    const auto &logits = std::get<0>(decoder_out);
    const float *p_logits = logits.GetTensorData<float>();

    auto logits_shape = logits.GetTensorTypeAndShapeInfo().GetShape();
    int32_t vocab_size = logits_shape[2];

    int32_t max_token_id = static_cast<int32_t>(std::distance(
        p_logits, std::max_element(p_logits, p_logits + vocab_size)));

    if (max_token_id == meta_data.eos_id) {
      break;
    }

    ans[0].tokens.push_back(max_token_id);

    token = max_token_id;

    // increment offset
    *(std::get<5>(decoder_out).GetTensorMutableData<int64_t>()) += 1;
  }

  return ans;
}

}  // namespace sherpa_onnx
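At a high level, the greedy search above runs the encoder once to obtain the cross-attention key/value tensors and then calls the decoder one token at a time, taking the argmax of the logits and stopping at the end-of-sentence id. A minimal Python sketch of the same loop, purely illustrative and not part of the commit; run_encoder and run_decoder_step are hypothetical stand-ins for ForwardEncoder and ForwardDecoder:

    # Illustrative sketch of the greedy AED search above (not part of the commit).
    # run_encoder/run_decoder_step are hypothetical callables; logits is assumed
    # to behave like a 1-D numpy array over the vocabulary.
    def greedy_decode(run_encoder, run_decoder_step, features,
                      sos_id, eos_id, max_len):
        cross_k, cross_v = run_encoder(features)  # computed once, reused every step
        self_k, self_v = None, None               # zero-initialized KV cache in the C++ code
        token, offset, tokens = sos_id, 0, []
        for _ in range(max_len):
            logits, self_k, self_v = run_decoder_step(
                token, self_k, self_v, cross_k, cross_v, offset
            )
            token = int(logits.argmax())          # greedy: pick the most likely token
            if token == eos_id:                   # stop at end-of-sentence
                break
            tokens.append(token)
            offset += 1                           # advance the position in the KV cache
        return tokens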
sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.h
0 → 100644
// sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.h
//
// Copyright (c)  2025  Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_GREEDY_SEARCH_DECODER_H_
#define SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_GREEDY_SEARCH_DECODER_H_

#include <vector>

#include "sherpa-onnx/csrc/offline-fire-red-asr-decoder.h"
#include "sherpa-onnx/csrc/offline-fire-red-asr-model.h"

namespace sherpa_onnx {

class OfflineFireRedAsrGreedySearchDecoder : public OfflineFireRedAsrDecoder {
 public:
  explicit OfflineFireRedAsrGreedySearchDecoder(OfflineFireRedAsrModel *model)
      : model_(model) {}

  std::vector<OfflineFireRedAsrDecoderResult> Decode(
      Ort::Value cross_k, Ort::Value cross_v) override;

 private:
  OfflineFireRedAsrModel *model_;  // not owned
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_GREEDY_SEARCH_DECODER_H_
sherpa-onnx/csrc/offline-fire-red-asr-model-config.cc
0 → 100644
// sherpa-onnx/csrc/offline-fire-red-asr-model-config.cc
//
// Copyright (c)  2023  Xiaomi Corporation

#include "sherpa-onnx/csrc/offline-fire-red-asr-model-config.h"

#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"

namespace sherpa_onnx {

void OfflineFireRedAsrModelConfig::Register(ParseOptions *po) {
  po->Register("fire-red-asr-encoder", &encoder,
               "Path to onnx encoder of FireRedAsr");

  po->Register("fire-red-asr-decoder", &decoder,
               "Path to onnx decoder of FireRedAsr");
}

bool OfflineFireRedAsrModelConfig::Validate() const {
  if (encoder.empty()) {
    SHERPA_ONNX_LOGE("Please provide --fire-red-asr-encoder");
    return false;
  }

  if (!FileExists(encoder)) {
    SHERPA_ONNX_LOGE("FireRedAsr encoder file '%s' does not exist",
                     encoder.c_str());
    return false;
  }

  if (decoder.empty()) {
    SHERPA_ONNX_LOGE("Please provide --fire-red-asr-decoder");
    return false;
  }

  if (!FileExists(decoder)) {
    SHERPA_ONNX_LOGE("FireRedAsr decoder file '%s' does not exist",
                     decoder.c_str());
    return false;
  }

  return true;
}

std::string OfflineFireRedAsrModelConfig::ToString() const {
  std::ostringstream os;

  os << "OfflineFireRedAsrModelConfig(";
  os << "encoder=\"" << encoder << "\", ";
  os << "decoder=\"" << decoder << "\")";

  return os.str();
}

}  // namespace sherpa_onnx
sherpa-onnx/csrc/offline-fire-red-asr-model-config.h
0 → 100644
// sherpa-onnx/csrc/offline-fire-red-asr-model-config.h
//
// Copyright (c)  2023  Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_CONFIG_H_

#include <string>

#include "sherpa-onnx/csrc/parse-options.h"

namespace sherpa_onnx {

// see https://github.com/FireRedTeam/FireRedASR
struct OfflineFireRedAsrModelConfig {
  std::string encoder;
  std::string decoder;

  OfflineFireRedAsrModelConfig() = default;

  OfflineFireRedAsrModelConfig(const std::string &encoder,
                               const std::string &decoder)
      : encoder(encoder), decoder(decoder) {}

  void Register(ParseOptions *po);
  bool Validate() const;

  std::string ToString() const;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_CONFIG_H_
sherpa-onnx/csrc/offline-fire-red-asr-model-meta-data.h
0 → 100644
// sherpa-onnx/csrc/offline-fire-red-asr-model-meta-data.h
//
// Copyright (c)  2025  Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_META_DATA_H_
#define SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_META_DATA_H_

#include <string>
#include <unordered_map>
#include <vector>

namespace sherpa_onnx {

struct OfflineFireRedAsrModelMetaData {
  int32_t sos_id;
  int32_t eos_id;
  int32_t max_len;
  int32_t num_decoder_layers;
  int32_t num_head;
  int32_t head_dim;

  std::vector<float> mean;
  std::vector<float> inv_stddev;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_META_DATA_H_
sherpa-onnx/csrc/offline-fire-red-asr-model.cc
0 → 100644
// sherpa-onnx/csrc/offline-fire-red-asr-model.cc
//
// Copyright (c)  2025  Xiaomi Corporation

#include "sherpa-onnx/csrc/offline-fire-red-asr-model.h"

#include <algorithm>
#include <cmath>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#if __OHOS__
#include "rawfile/raw_file_manager.h"
#endif

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/text-utils.h"

namespace sherpa_onnx {

class OfflineFireRedAsrModel::Impl {
 public:
  explicit Impl(const OfflineModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    {
      auto buf = ReadFile(config.fire_red_asr.encoder);
      InitEncoder(buf.data(), buf.size());
    }

    {
      auto buf = ReadFile(config.fire_red_asr.decoder);
      InitDecoder(buf.data(), buf.size());
    }
  }

  template <typename Manager>
  Impl(Manager *mgr, const OfflineModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    {
      auto buf = ReadFile(mgr, config.fire_red_asr.encoder);
      InitEncoder(buf.data(), buf.size());
    }

    {
      auto buf = ReadFile(mgr, config.fire_red_asr.decoder);
      InitDecoder(buf.data(), buf.size());
    }
  }

  std::pair<Ort::Value, Ort::Value> ForwardEncoder(Ort::Value features,
                                                   Ort::Value features_length) {
    std::array<Ort::Value, 2> inputs{std::move(features),
                                     std::move(features_length)};

    auto encoder_out = encoder_sess_->Run(
        {}, encoder_input_names_ptr_.data(), inputs.data(), inputs.size(),
        encoder_output_names_ptr_.data(), encoder_output_names_ptr_.size());

    return {std::move(encoder_out[0]), std::move(encoder_out[1])};
  }

  std::tuple<Ort::Value, Ort::Value, Ort::Value, Ort::Value, Ort::Value,
             Ort::Value>
  ForwardDecoder(Ort::Value tokens, Ort::Value n_layer_self_k_cache,
                 Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k,
                 Ort::Value n_layer_cross_v, Ort::Value offset) {
    std::array<Ort::Value, 6> decoder_input = {std::move(tokens),
                                               std::move(n_layer_self_k_cache),
                                               std::move(n_layer_self_v_cache),
                                               std::move(n_layer_cross_k),
                                               std::move(n_layer_cross_v),
                                               std::move(offset)};

    auto decoder_out = decoder_sess_->Run(
        {}, decoder_input_names_ptr_.data(), decoder_input.data(),
        decoder_input.size(), decoder_output_names_ptr_.data(),
        decoder_output_names_ptr_.size());

    return std::tuple<Ort::Value, Ort::Value, Ort::Value, Ort::Value,
                      Ort::Value, Ort::Value>{
        std::move(decoder_out[0]),   std::move(decoder_out[1]),
        std::move(decoder_out[2]),   std::move(decoder_input[3]),
        std::move(decoder_input[4]), std::move(decoder_input[5])};
  }

  std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache() {
    int32_t batch_size = 1;
    std::array<int64_t, 5> shape{meta_data_.num_decoder_layers, batch_size,
                                 meta_data_.max_len, meta_data_.num_head,
                                 meta_data_.head_dim};

    Ort::Value n_layer_self_k_cache = Ort::Value::CreateTensor<float>(
        Allocator(), shape.data(), shape.size());

    Ort::Value n_layer_self_v_cache = Ort::Value::CreateTensor<float>(
        Allocator(), shape.data(), shape.size());

    auto n = shape[0] * shape[1] * shape[2] * shape[3] * shape[4];

    float *p_k = n_layer_self_k_cache.GetTensorMutableData<float>();
    float *p_v = n_layer_self_v_cache.GetTensorMutableData<float>();

    memset(p_k, 0, sizeof(float) * n);
    memset(p_v, 0, sizeof(float) * n);

    return {std::move(n_layer_self_k_cache), std::move(n_layer_self_v_cache)};
  }

  OrtAllocator *Allocator() { return allocator_; }

  const OfflineFireRedAsrModelMetaData &GetModelMetadata() const {
    return meta_data_;
  }

 private:
  void InitEncoder(void *model_data, size_t model_data_length) {
    encoder_sess_ = std::make_unique<Ort::Session>(
        env_, model_data, model_data_length, sess_opts_);

    GetInputNames(encoder_sess_.get(), &encoder_input_names_,
                  &encoder_input_names_ptr_);

    GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
                   &encoder_output_names_ptr_);

    // get meta data
    Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
    if (config_.debug) {
      std::ostringstream os;
      os << "---encoder---\n";
      PrintModelMetadata(os, meta_data);
#if __OHOS__
      SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str());
#else
      SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
#endif
    }

    Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
    SHERPA_ONNX_READ_META_DATA(meta_data_.num_decoder_layers,
                               "num_decoder_layers");
    SHERPA_ONNX_READ_META_DATA(meta_data_.num_head, "num_head");
    SHERPA_ONNX_READ_META_DATA(meta_data_.head_dim, "head_dim");
    SHERPA_ONNX_READ_META_DATA(meta_data_.sos_id, "sos");
    SHERPA_ONNX_READ_META_DATA(meta_data_.eos_id, "eos");
    SHERPA_ONNX_READ_META_DATA(meta_data_.max_len, "max_len");
    SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(meta_data_.mean, "cmvn_mean");
    SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(meta_data_.inv_stddev,
                                         "cmvn_inv_stddev");
  }

  void InitDecoder(void *model_data, size_t model_data_length) {
    decoder_sess_ = std::make_unique<Ort::Session>(
        env_, model_data, model_data_length, sess_opts_);

    GetInputNames(decoder_sess_.get(), &decoder_input_names_,
                  &decoder_input_names_ptr_);

    GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
                   &decoder_output_names_ptr_);
  }

 private:
  OfflineModelConfig config_;
  Ort::Env env_;
  Ort::SessionOptions sess_opts_;
  Ort::AllocatorWithDefaultOptions allocator_;

  std::unique_ptr<Ort::Session> encoder_sess_;
  std::unique_ptr<Ort::Session> decoder_sess_;

  std::vector<std::string> encoder_input_names_;
  std::vector<const char *> encoder_input_names_ptr_;

  std::vector<std::string> encoder_output_names_;
  std::vector<const char *> encoder_output_names_ptr_;

  std::vector<std::string> decoder_input_names_;
  std::vector<const char *> decoder_input_names_ptr_;

  std::vector<std::string> decoder_output_names_;
  std::vector<const char *> decoder_output_names_ptr_;

  OfflineFireRedAsrModelMetaData meta_data_;
};

OfflineFireRedAsrModel::OfflineFireRedAsrModel(const OfflineModelConfig &config)
    : impl_(std::make_unique<Impl>(config)) {}

template <typename Manager>
OfflineFireRedAsrModel::OfflineFireRedAsrModel(Manager *mgr,
                                               const OfflineModelConfig &config)
    : impl_(std::make_unique<Impl>(mgr, config)) {}

OfflineFireRedAsrModel::~OfflineFireRedAsrModel() = default;

std::pair<Ort::Value, Ort::Value> OfflineFireRedAsrModel::ForwardEncoder(
    Ort::Value features, Ort::Value features_length) const {
  return impl_->ForwardEncoder(std::move(features), std::move(features_length));
}

std::tuple<Ort::Value, Ort::Value, Ort::Value, Ort::Value, Ort::Value,
           Ort::Value>
OfflineFireRedAsrModel::ForwardDecoder(Ort::Value tokens,
                                       Ort::Value n_layer_self_k_cache,
                                       Ort::Value n_layer_self_v_cache,
                                       Ort::Value n_layer_cross_k,
                                       Ort::Value n_layer_cross_v,
                                       Ort::Value offset) const {
  return impl_->ForwardDecoder(std::move(tokens),
                               std::move(n_layer_self_k_cache),
                               std::move(n_layer_self_v_cache),
                               std::move(n_layer_cross_k),
                               std::move(n_layer_cross_v), std::move(offset));
}

std::pair<Ort::Value, Ort::Value> OfflineFireRedAsrModel::GetInitialSelfKVCache()
    const {
  return impl_->GetInitialSelfKVCache();
}

OrtAllocator *OfflineFireRedAsrModel::Allocator() const {
  return impl_->Allocator();
}

const OfflineFireRedAsrModelMetaData &OfflineFireRedAsrModel::GetModelMetadata()
    const {
  return impl_->GetModelMetadata();
}

#if __ANDROID_API__ >= 9
template OfflineFireRedAsrModel::OfflineFireRedAsrModel(
    AAssetManager *mgr, const OfflineModelConfig &config);
#endif

#if __OHOS__
template OfflineFireRedAsrModel::OfflineFireRedAsrModel(
    NativeResourceManager *mgr, const OfflineModelConfig &config);
#endif

}  // namespace sherpa_onnx
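GetInitialSelfKVCache() allocates two zero-filled 5-D tensors of shape (num_decoder_layers, N, max_len, num_head, head_dim), which the decoder then fills in step by step during greedy search. A small numpy sketch of the equivalent allocation; the concrete dimension values below are made-up placeholders, since the real ones are read from the encoder's ONNX metadata:

    # Sketch of the zero-initialized self-attention KV cache (values are
    # placeholders; real ones come from the model metadata).
    import numpy as np

    num_decoder_layers, batch_size, max_len, num_head, head_dim = 16, 1, 448, 20, 64
    shape = (num_decoder_layers, batch_size, max_len, num_head, head_dim)
    n_layer_self_k_cache = np.zeros(shape, dtype=np.float32)
    n_layer_self_v_cache = np.zeros(shape, dtype=np.float32)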
sherpa-onnx/csrc/offline-fire-red-asr-model.h
0 → 100644
// sherpa-onnx/csrc/offline-fire-red-asr-model.h
//
// Copyright (c)  2025  Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_H_

#include <memory>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>

#include "onnxruntime_cxx_api.h"  // NOLINT
#include "sherpa-onnx/csrc/offline-fire-red-asr-model-meta-data.h"
#include "sherpa-onnx/csrc/offline-model-config.h"

namespace sherpa_onnx {

class OfflineFireRedAsrModel {
 public:
  explicit OfflineFireRedAsrModel(const OfflineModelConfig &config);

  template <typename Manager>
  OfflineFireRedAsrModel(Manager *mgr, const OfflineModelConfig &config);

  ~OfflineFireRedAsrModel();

  /** Run the encoder model.
   *
   * @param features  A tensor of shape (N, T, C).
   * @param features_len  A tensor of shape (N,) with dtype int64.
   *
   * @return Return a pair containing:
   *  - n_layer_cross_k: A 4-D tensor of shape
   *                     (num_decoder_layers, N, T, d_model)
   *  - n_layer_cross_v: A 4-D tensor of shape
   *                     (num_decoder_layers, N, T, d_model)
   */
  std::pair<Ort::Value, Ort::Value> ForwardEncoder(
      Ort::Value features, Ort::Value features_length) const;

  /** Run the decoder model.
   *
   * @param tokens An int64 tensor of shape (N, num_words)
   * @param n_layer_self_k_cache  A 5-D tensor of shape
   *                              (num_decoder_layers, N, max_len, num_head,
   *                              head_dim).
   * @param n_layer_self_v_cache  A 5-D tensor of shape
   *                              (num_decoder_layers, N, max_len, num_head,
   *                              head_dim).
   * @param n_layer_cross_k  A 4-D tensor of shape
   *                         (num_decoder_layers, N, T, d_model).
   * @param n_layer_cross_v  A 4-D tensor of shape
   *                         (num_decoder_layers, N, T, d_model).
   * @param offset An int64 tensor of shape (N,)
   *
   * @return Return a tuple containing 6 tensors:
   *
   *  - logits  A 3-D tensor of shape (N, num_words, vocab_size)
   *  - out_n_layer_self_k_cache  Same shape as n_layer_self_k_cache
   *  - out_n_layer_self_v_cache  Same shape as n_layer_self_v_cache
   *  - out_n_layer_cross_k  Same as n_layer_cross_k
   *  - out_n_layer_cross_v  Same as n_layer_cross_v
   *  - out_offset  Same as offset
   */
  std::tuple<Ort::Value, Ort::Value, Ort::Value, Ort::Value, Ort::Value,
             Ort::Value>
  ForwardDecoder(Ort::Value tokens, Ort::Value n_layer_self_k_cache,
                 Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k,
                 Ort::Value n_layer_cross_v, Ort::Value offset) const;

  /** Return the initial self kv cache in a pair
   *  - n_layer_self_k_cache  A 5-D tensor of shape
   *                          (num_decoder_layers, N, max_len, num_head,
   *                          head_dim).
   *  - n_layer_self_v_cache  A 5-D tensor of shape
   *                          (num_decoder_layers, N, max_len, num_head,
   *                          head_dim).
   */
  std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache() const;

  const OfflineFireRedAsrModelMetaData &GetModelMetadata() const;

  /** Return an allocator for allocating memory
   */
  OrtAllocator *Allocator() const;

 private:
  class Impl;
  std::unique_ptr<Impl> impl_;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_H_
sherpa-onnx/csrc/offline-model-config.cc
@@ -15,6 +15,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
   paraformer.Register(po);
   nemo_ctc.Register(po);
   whisper.Register(po);
+  fire_red_asr.Register(po);
   tdnn.Register(po);
   zipformer_ctc.Register(po);
   wenet_ctc.Register(po);
@@ -38,7 +39,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
   po->Register("model-type", &model_type,
                "Specify it to reduce model initialization time. "
                "Valid values are: transducer, paraformer, nemo_ctc, whisper, "
-               "tdnn, zipformer2_ctc, telespeech_ctc."
+               "tdnn, zipformer2_ctc, telespeech_ctc, fire_red_asr."
                "All other values lead to loading the model twice.");

   po->Register("modeling-unit", &modeling_unit,
                "The modeling unit of the model, commonly used units are bpe, "
@@ -84,6 +85,10 @@ bool OfflineModelConfig::Validate() const {
     return whisper.Validate();
   }

+  if (!fire_red_asr.encoder.empty()) {
+    return fire_red_asr.Validate();
+  }
+
   if (!tdnn.model.empty()) {
     return tdnn.Validate();
   }
@@ -125,6 +130,7 @@ std::string OfflineModelConfig::ToString() const {
   os << "paraformer=" << paraformer.ToString() << ", ";
   os << "nemo_ctc=" << nemo_ctc.ToString() << ", ";
   os << "whisper=" << whisper.ToString() << ", ";
+  os << "fire_red_asr=" << fire_red_asr.ToString() << ", ";
   os << "tdnn=" << tdnn.ToString() << ", ";
   os << "zipformer_ctc=" << zipformer_ctc.ToString() << ", ";
   os << "wenet_ctc=" << wenet_ctc.ToString() << ", ";
sherpa-onnx/csrc/offline-model-config.h
@@ -6,6 +6,7 @@
 #include <string>

+#include "sherpa-onnx/csrc/offline-fire-red-asr-model-config.h"
 #include "sherpa-onnx/csrc/offline-moonshine-model-config.h"
 #include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h"
 #include "sherpa-onnx/csrc/offline-paraformer-model-config.h"
@@ -23,6 +24,7 @@ struct OfflineModelConfig {
   OfflineParaformerModelConfig paraformer;
   OfflineNemoEncDecCtcModelConfig nemo_ctc;
   OfflineWhisperModelConfig whisper;
+  OfflineFireRedAsrModelConfig fire_red_asr;
   OfflineTdnnModelConfig tdnn;
   OfflineZipformerCtcModelConfig zipformer_ctc;
   OfflineWenetCtcModelConfig wenet_ctc;
@@ -54,6 +56,7 @@ struct OfflineModelConfig {
                      const OfflineParaformerModelConfig &paraformer,
                      const OfflineNemoEncDecCtcModelConfig &nemo_ctc,
                      const OfflineWhisperModelConfig &whisper,
+                     const OfflineFireRedAsrModelConfig &fire_red_asr,
                      const OfflineTdnnModelConfig &tdnn,
                      const OfflineZipformerCtcModelConfig &zipformer_ctc,
                      const OfflineWenetCtcModelConfig &wenet_ctc,
@@ -68,6 +71,7 @@ struct OfflineModelConfig {
         paraformer(paraformer),
         nemo_ctc(nemo_ctc),
         whisper(whisper),
+        fire_red_asr(fire_red_asr),
         tdnn(tdnn),
         zipformer_ctc(zipformer_ctc),
         wenet_ctc(wenet_ctc),
sherpa-onnx/csrc/offline-recognizer-fire-red-asr-impl.h
0 → 100644
// sherpa-onnx/csrc/offline-recognizer-fire-red-asr-impl.h
//
// Copyright (c)  2022-2023  Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_FIRE_RED_ASR_IMPL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_FIRE_RED_ASR_IMPL_H_

#include <algorithm>
#include <cmath>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "sherpa-onnx/csrc/offline-fire-red-asr-decoder.h"
#include "sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.h"
#include "sherpa-onnx/csrc/offline-fire-red-asr-model.h"
#include "sherpa-onnx/csrc/offline-model-config.h"
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer.h"
#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/transpose.h"

namespace sherpa_onnx {

static OfflineRecognitionResult Convert(const OfflineFireRedAsrDecoderResult &src,
                                        const SymbolTable &sym_table) {
  OfflineRecognitionResult r;
  r.tokens.reserve(src.tokens.size());

  std::string text;
  for (auto i : src.tokens) {
    if (!sym_table.Contains(i)) {
      continue;
    }

    const auto &s = sym_table[i];
    text += s;
    r.tokens.push_back(s);
  }

  r.text = text;

  return r;
}

class OfflineRecognizerFireRedAsrImpl : public OfflineRecognizerImpl {
 public:
  explicit OfflineRecognizerFireRedAsrImpl(const OfflineRecognizerConfig &config)
      : OfflineRecognizerImpl(config),
        config_(config),
        symbol_table_(config_.model_config.tokens),
        model_(std::make_unique<OfflineFireRedAsrModel>(config.model_config)) {
    Init();
  }

  template <typename Manager>
  OfflineRecognizerFireRedAsrImpl(Manager *mgr,
                                  const OfflineRecognizerConfig &config)
      : OfflineRecognizerImpl(mgr, config),
        config_(config),
        symbol_table_(mgr, config_.model_config.tokens),
        model_(std::make_unique<OfflineFireRedAsrModel>(mgr,
                                                        config.model_config)) {
    Init();
  }

  void Init() {
    if (config_.decoding_method == "greedy_search") {
      decoder_ =
          std::make_unique<OfflineFireRedAsrGreedySearchDecoder>(model_.get());
    } else {
      SHERPA_ONNX_LOGE(
          "Only greedy_search is supported at present for FireRedAsr. Given %s",
          config_.decoding_method.c_str());
      SHERPA_ONNX_EXIT(-1);
    }

    const auto &meta_data = model_->GetModelMetadata();
    config_.feat_config.normalize_samples = false;
    config_.feat_config.high_freq = 0;
    config_.feat_config.snip_edges = true;
  }

  std::unique_ptr<OfflineStream> CreateStream() const override {
    return std::make_unique<OfflineStream>(config_.feat_config);
  }

  void DecodeStreams(OfflineStream **ss, int32_t n) const override {
    // batch decoding is not implemented yet
    for (int32_t i = 0; i != n; ++i) {
      DecodeStream(ss[i]);
    }
  }

  OfflineRecognizerConfig GetConfig() const override { return config_; }

 private:
  void DecodeStream(OfflineStream *s) const {
    auto memory_info =
        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

    int32_t feat_dim = s->FeatureDim();
    std::vector<float> f = s->GetFrames();
    ApplyCMVN(&f);

    int64_t num_frames = f.size() / feat_dim;

    std::array<int64_t, 3> shape{1, num_frames, feat_dim};

    Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(),
                                            shape.data(), shape.size());

    int64_t len_shape = 1;
    Ort::Value x_len =
        Ort::Value::CreateTensor(memory_info, &num_frames, 1, &len_shape, 1);

    auto cross_kv = model_->ForwardEncoder(std::move(x), std::move(x_len));

    auto results = decoder_->Decode(std::move(cross_kv.first),
                                    std::move(cross_kv.second));

    auto r = Convert(results[0], symbol_table_);
    r.text = ApplyInverseTextNormalization(std::move(r.text));
    s->SetResult(r);
  }

  void ApplyCMVN(std::vector<float> *v) const {
    const auto &meta_data = model_->GetModelMetadata();
    const auto &mean = meta_data.mean;
    const auto &inv_stddev = meta_data.inv_stddev;
    int32_t feat_dim = static_cast<int32_t>(mean.size());
    int32_t num_frames = static_cast<int32_t>(v->size()) / feat_dim;

    float *p = v->data();

    for (int32_t i = 0; i != num_frames; ++i) {
      for (int32_t k = 0; k != feat_dim; ++k) {
        p[k] = (p[k] - mean[k]) * inv_stddev[k];
      }

      p += feat_dim;
    }
  }

 private:
  OfflineRecognizerConfig config_;
  SymbolTable symbol_table_;
  std::unique_ptr<OfflineFireRedAsrModel> model_;
  std::unique_ptr<OfflineFireRedAsrDecoder> decoder_;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_FIRE_RED_ASR_IMPL_H_
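ApplyCMVN() normalizes every feature frame with the mean and inverse standard deviation stored in the encoder's metadata, i.e. each value becomes (x - mean) * inv_stddev. A minimal numpy sketch of the same transform, assuming features laid out as (num_frames, feat_dim); it is an illustration, not part of the commit:

    # Per-frame CMVN as applied in ApplyCMVN() above (illustrative sketch).
    import numpy as np

    def apply_cmvn(features: np.ndarray, mean: np.ndarray,
                   inv_stddev: np.ndarray) -> np.ndarray:
        # features: (num_frames, feat_dim); mean/inv_stddev: (feat_dim,)
        return (features - mean) * inv_stddev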
sherpa-onnx/csrc/offline-recognizer-impl.cc
@@ -24,6 +24,7 @@
 #include "onnxruntime_cxx_api.h"  // NOLINT
 #include "sherpa-onnx/csrc/macros.h"
 #include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h"
+#include "sherpa-onnx/csrc/offline-recognizer-fire-red-asr-impl.h"
 #include "sherpa-onnx/csrc/offline-recognizer-moonshine-impl.h"
 #include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h"
 #include "sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h"
@@ -56,6 +57,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
     return std::make_unique<OfflineRecognizerWhisperImpl>(config);
   }

+  if (!config.model_config.fire_red_asr.encoder.empty()) {
+    return std::make_unique<OfflineRecognizerFireRedAsrImpl>(config);
+  }
+
   if (!config.model_config.moonshine.preprocessor.empty()) {
     return std::make_unique<OfflineRecognizerMoonshineImpl>(config);
   }
@@ -237,6 +242,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
     return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config);
   }

+  if (!config.model_config.fire_red_asr.encoder.empty()) {
+    return std::make_unique<OfflineRecognizerFireRedAsrImpl>(mgr, config);
+  }
+
   if (!config.model_config.moonshine.preprocessor.empty()) {
     return std::make_unique<OfflineRecognizerMoonshineImpl>(mgr, config);
   }
sherpa-onnx/python/csrc/CMakeLists.txt
@@ -9,6 +9,7 @@ set(srcs
   features.cc
   keyword-spotter.cc
   offline-ctc-fst-decoder-config.cc
+  offline-fire-red-asr-model-config.cc
  offline-lm-config.cc
  offline-model-config.cc
  offline-moonshine-model-config.cc
sherpa-onnx/python/csrc/offline-fire-red-asr-model-config.cc
0 → 100644
// sherpa-onnx/python/csrc/offline-fire-red-asr-model-config.cc
//
// Copyright (c)  2025  Xiaomi Corporation

#include "sherpa-onnx/csrc/offline-fire-red-asr-model-config.h"

#include <string>
#include <vector>

#include "sherpa-onnx/python/csrc/offline-fire-red-asr-model-config.h"

namespace sherpa_onnx {

void PybindOfflineFireRedAsrModelConfig(py::module *m) {
  using PyClass = OfflineFireRedAsrModelConfig;
  py::class_<PyClass>(*m, "OfflineFireRedAsrModelConfig")
      .def(py::init<const std::string &, const std::string &>(),
           py::arg("encoder"), py::arg("decoder"))
      .def_readwrite("encoder", &PyClass::encoder)
      .def_readwrite("decoder", &PyClass::decoder)
      .def("__str__", &PyClass::ToString);
}

}  // namespace sherpa_onnx
sherpa-onnx/python/csrc/offline-fire-red-asr-model-config.h
0 → 100644
// sherpa-onnx/python/csrc/offline-fire-red-asr-model-config.h
//
// Copyright (c)  2025  Xiaomi Corporation

#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_CONFIG_H_
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_CONFIG_H_

#include "sherpa-onnx/python/csrc/sherpa-onnx.h"

namespace sherpa_onnx {

void PybindOfflineFireRedAsrModelConfig(py::module *m);

}

#endif  // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_CONFIG_H_
sherpa-onnx/python/csrc/offline-model-config.cc
@@ -8,6 +8,7 @@
 #include <vector>

 #include "sherpa-onnx/csrc/offline-model-config.h"
+#include "sherpa-onnx/python/csrc/offline-fire-red-asr-model-config.h"
 #include "sherpa-onnx/python/csrc/offline-moonshine-model-config.h"
 #include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h"
 #include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"
@@ -25,6 +26,7 @@ void PybindOfflineModelConfig(py::module *m) {
   PybindOfflineParaformerModelConfig(m);
   PybindOfflineNemoEncDecCtcModelConfig(m);
   PybindOfflineWhisperModelConfig(m);
+  PybindOfflineFireRedAsrModelConfig(m);
   PybindOfflineTdnnModelConfig(m);
   PybindOfflineZipformerCtcModelConfig(m);
   PybindOfflineWenetCtcModelConfig(m);
@@ -33,35 +35,38 @@ void PybindOfflineModelConfig(py::module *m) {
   using PyClass = OfflineModelConfig;
   py::class_<PyClass>(*m, "OfflineModelConfig")
-      .def(py::init<const OfflineTransducerModelConfig &,
-                    const OfflineParaformerModelConfig &,
-                    const OfflineNemoEncDecCtcModelConfig &,
-                    const OfflineWhisperModelConfig &,
-                    const OfflineTdnnModelConfig &,
-                    const OfflineZipformerCtcModelConfig &,
-                    const OfflineWenetCtcModelConfig &,
-                    const OfflineSenseVoiceModelConfig &,
-                    const OfflineMoonshineModelConfig &, const std::string &,
-                    const std::string &, int32_t, bool, const std::string &,
-                    const std::string &, const std::string &,
-                    const std::string &>(),
-           py::arg("transducer") = OfflineTransducerModelConfig(),
-           py::arg("paraformer") = OfflineParaformerModelConfig(),
-           py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(),
-           py::arg("whisper") = OfflineWhisperModelConfig(),
-           py::arg("tdnn") = OfflineTdnnModelConfig(),
-           py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(),
-           py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(),
-           py::arg("sense_voice") = OfflineSenseVoiceModelConfig(),
-           py::arg("moonshine") = OfflineMoonshineModelConfig(),
-           py::arg("telespeech_ctc") = "", py::arg("tokens"),
-           py::arg("num_threads"), py::arg("debug") = false,
-           py::arg("provider") = "cpu", py::arg("model_type") = "",
-           py::arg("modeling_unit") = "cjkchar", py::arg("bpe_vocab") = "")
+      .def(py::init<const OfflineTransducerModelConfig &,
+                    const OfflineParaformerModelConfig &,
+                    const OfflineNemoEncDecCtcModelConfig &,
+                    const OfflineWhisperModelConfig &,
+                    const OfflineFireRedAsrModelConfig &,
+                    const OfflineTdnnModelConfig &,
+                    const OfflineZipformerCtcModelConfig &,
+                    const OfflineWenetCtcModelConfig &,
+                    const OfflineSenseVoiceModelConfig &,
+                    const OfflineMoonshineModelConfig &, const std::string &,
+                    const std::string &, int32_t, bool, const std::string &,
+                    const std::string &, const std::string &,
+                    const std::string &>(),
+           py::arg("transducer") = OfflineTransducerModelConfig(),
+           py::arg("paraformer") = OfflineParaformerModelConfig(),
+           py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(),
+           py::arg("whisper") = OfflineWhisperModelConfig(),
+           py::arg("fire_red_asr") = OfflineFireRedAsrModelConfig(),
+           py::arg("tdnn") = OfflineTdnnModelConfig(),
+           py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(),
+           py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(),
+           py::arg("sense_voice") = OfflineSenseVoiceModelConfig(),
+           py::arg("moonshine") = OfflineMoonshineModelConfig(),
+           py::arg("telespeech_ctc") = "", py::arg("tokens"),
+           py::arg("num_threads"), py::arg("debug") = false,
+           py::arg("provider") = "cpu", py::arg("model_type") = "",
+           py::arg("modeling_unit") = "cjkchar", py::arg("bpe_vocab") = "")
       .def_readwrite("transducer", &PyClass::transducer)
       .def_readwrite("paraformer", &PyClass::paraformer)
       .def_readwrite("nemo_ctc", &PyClass::nemo_ctc)
       .def_readwrite("whisper", &PyClass::whisper)
+      .def_readwrite("fire_red_asr", &PyClass::fire_red_asr)
       .def_readwrite("tdnn", &PyClass::tdnn)
       .def_readwrite("zipformer_ctc", &PyClass::zipformer_ctc)
       .def_readwrite("wenet_ctc", &PyClass::wenet_ctc)
sherpa-onnx/python/sherpa_onnx/offline_recognizer.py
@@ -6,6 +6,7 @@ from typing import List, Optional
 from _sherpa_onnx import (
     FeatureExtractorConfig,
     OfflineCtcFstDecoderConfig,
+    OfflineFireRedAsrModelConfig,
     OfflineLMConfig,
     OfflineModelConfig,
     OfflineMoonshineModelConfig,
@@ -572,6 +573,78 @@ class OfflineRecognizer(object):
         return self

+    @classmethod
+    def from_fire_red_asr(
+        cls,
+        encoder: str,
+        decoder: str,
+        tokens: str,
+        num_threads: int = 1,
+        decoding_method: str = "greedy_search",
+        debug: bool = False,
+        provider: str = "cpu",
+        rule_fsts: str = "",
+        rule_fars: str = "",
+    ):
+        """
+        Please refer to
+        `<https://k2-fsa.github.io/sherpa/onnx/fire_red_asr/index.html>`_
+        to download pre-trained models for different kinds of FireRedAsr
+        models, e.g., xs, large, etc.
+
+        Args:
+          encoder:
+            Path to the encoder model.
+          decoder:
+            Path to the decoder model.
+          tokens:
+            Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
+            columns::
+
+                symbol integer_id
+
+          num_threads:
+            Number of threads for neural network computation.
+          decoding_method:
+            Valid values: greedy_search.
+          debug:
+            True to show debug messages.
+          provider:
+            onnxruntime execution providers. Valid values are: cpu, cuda,
+            coreml.
+          rule_fsts:
+            If not empty, it specifies fsts for inverse text normalization.
+            If there are multiple fsts, they are separated by a comma.
+          rule_fars:
+            If not empty, it specifies fst archives for inverse text
+            normalization. If there are multiple archives, they are separated
+            by a comma.
+        """
+        self = cls.__new__(cls)
+        model_config = OfflineModelConfig(
+            fire_red_asr=OfflineFireRedAsrModelConfig(
+                encoder=encoder,
+                decoder=decoder,
+            ),
+            tokens=tokens,
+            num_threads=num_threads,
+            debug=debug,
+            provider=provider,
+        )
+
+        feat_config = FeatureExtractorConfig(
+            sampling_rate=16000,
+            feature_dim=80,
+        )
+
+        recognizer_config = OfflineRecognizerConfig(
+            feat_config=feat_config,
+            model_config=model_config,
+            decoding_method=decoding_method,
+            rule_fsts=rule_fsts,
+            rule_fars=rule_fars,
+        )
+        self.recognizer = _Recognizer(recognizer_config)
+        self.config = recognizer_config
+        return self
+
     @classmethod
     def from_moonshine(
         cls,
         preprocessor: str,
注册
或
登录
后发表评论