Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
Fangjun Kuang
2024-08-06 17:43:24 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2024-08-06 17:43:24 +0800
Commit
375c055ff8bf233f9793df538618b87ff04f6880
375c055f
1 parent
1414e4dc
Fix style issues for online punctuation source files (#1225)
显示空白字符变更
内嵌
并排对比
正在显示
10 个修改的文件
包含
70 行增加
和
56 行删除
cmake/cmake_extension.py
sherpa-onnx/csrc/online-cnn-bilstm-model.cc
sherpa-onnx/csrc/online-cnn-bilstm-model.h
sherpa-onnx/csrc/online-punctuation-cnn-bilstm-impl.h
sherpa-onnx/csrc/online-punctuation-impl.cc
sherpa-onnx/csrc/online-punctuation-model-config.cc
sherpa-onnx/csrc/online-punctuation-model-config.h
sherpa-onnx/csrc/online-punctuation.cc
sherpa-onnx/csrc/online-punctuation.h
sherpa-onnx/csrc/sherpa-onnx-online-punctuation.cc
cmake/cmake_extension.py
查看文件 @
375c055
...
...
@@ -58,6 +58,7 @@ def get_binaries():
"sherpa-onnx-offline-tts"
,
"sherpa-onnx-offline-tts-play"
,
"sherpa-onnx-offline-websocket-server"
,
"sherpa-onnx-online-punctuation"
,
"sherpa-onnx-online-websocket-client"
,
"sherpa-onnx-online-websocket-server"
,
"sherpa-onnx-vad-microphone"
,
...
...
sherpa-onnx/csrc/online-cnn-bilstm-model.cc
查看文件 @
375c055
...
...
@@ -35,8 +35,11 @@ class OnlineCNNBiLSTMModel::Impl {
}
#endif
std
::
pair
<
Ort
::
Value
,
Ort
::
Value
>
Forward
(
Ort
::
Value
token_ids
,
Ort
::
Value
valid_ids
,
Ort
::
Value
label_lens
)
{
std
::
array
<
Ort
::
Value
,
3
>
inputs
=
{
std
::
move
(
token_ids
),
std
::
move
(
valid_ids
),
std
::
move
(
label_lens
)};
std
::
pair
<
Ort
::
Value
,
Ort
::
Value
>
Forward
(
Ort
::
Value
token_ids
,
Ort
::
Value
valid_ids
,
Ort
::
Value
label_lens
)
{
std
::
array
<
Ort
::
Value
,
3
>
inputs
=
{
std
::
move
(
token_ids
),
std
::
move
(
valid_ids
),
std
::
move
(
label_lens
)};
auto
ans
=
sess_
->
Run
({},
input_names_ptr_
.
data
(),
inputs
.
data
(),
inputs
.
size
(),
...
...
@@ -117,18 +120,18 @@ OnlineCNNBiLSTMModel::OnlineCNNBiLSTMModel(
OnlineCNNBiLSTMModel
::~
OnlineCNNBiLSTMModel
()
=
default
;
std
::
pair
<
Ort
::
Value
,
Ort
::
Value
>
OnlineCNNBiLSTMModel
::
Forward
(
Ort
::
Value
token_ids
,
Ort
::
Value
valid_ids
,
Ort
::
Value
label_lens
)
const
{
return
impl_
->
Forward
(
std
::
move
(
token_ids
),
std
::
move
(
valid_ids
),
std
::
move
(
label_lens
));
std
::
pair
<
Ort
::
Value
,
Ort
::
Value
>
OnlineCNNBiLSTMModel
::
Forward
(
Ort
::
Value
token_ids
,
Ort
::
Value
valid_ids
,
Ort
::
Value
label_lens
)
const
{
return
impl_
->
Forward
(
std
::
move
(
token_ids
),
std
::
move
(
valid_ids
),
std
::
move
(
label_lens
));
}
OrtAllocator
*
OnlineCNNBiLSTMModel
::
Allocator
()
const
{
return
impl_
->
Allocator
();
}
const
OnlineCNNBiLSTMModelMetaData
&
OnlineCNNBiLSTMModel
::
GetModelMetadata
()
const
{
const
OnlineCNNBiLSTMModelMetaData
&
OnlineCNNBiLSTMModel
::
GetModelMetadata
()
const
{
return
impl_
->
GetModelMetadata
();
}
...
...
sherpa-onnx/csrc/online-cnn-bilstm-model.h
查看文件 @
375c055
...
...
@@ -23,8 +23,7 @@ namespace sherpa_onnx {
*/
class
OnlineCNNBiLSTMModel
{
public
:
explicit
OnlineCNNBiLSTMModel
(
const
OnlinePunctuationModelConfig
&
config
);
explicit
OnlineCNNBiLSTMModel
(
const
OnlinePunctuationModelConfig
&
config
);
#if __ANDROID_API__ >= 9
OnlineCNNBiLSTMModel
(
AAssetManager
*
mgr
,
...
...
@@ -43,7 +42,9 @@ class OnlineCNNBiLSTMModel {
* - case_logits: A 2-D tensor of shape (T', num_cases).
* - punct_logits: A 2-D tensor of shape (T', num_puncts).
*/
std
::
pair
<
Ort
::
Value
,
Ort
::
Value
>
Forward
(
Ort
::
Value
token_ids
,
Ort
::
Value
valid_ids
,
Ort
::
Value
label_lens
)
const
;
std
::
pair
<
Ort
::
Value
,
Ort
::
Value
>
Forward
(
Ort
::
Value
token_ids
,
Ort
::
Value
valid_ids
,
Ort
::
Value
label_lens
)
const
;
/** Return an allocator for allocating memory
*/
...
...
sherpa-onnx/csrc/online-punctuation-cnn-bilstm-impl.h
查看文件 @
375c055
...
...
@@ -7,27 +7,28 @@
#include <math.h>
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include <algorithm>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include <chrono> // NOLINT
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/math.h"
#include "sherpa-onnx/csrc/online-cnn-bilstm-model-meta-data.h"
#include "sherpa-onnx/csrc/online-cnn-bilstm-model.h"
#include "sherpa-onnx/csrc/online-punctuation-impl.h"
#include "sherpa-onnx/csrc/online-punctuation.h"
#include "sherpa-onnx/csrc/online-cnn-bilstm-model-meta-data.h"
#include "sherpa-onnx/csrc/text-utils.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/text-utils.h"
#include "ssentencepiece/csrc/ssentencepiece.h"
#include <chrono> // NOLINT
namespace
sherpa_onnx
{
...
...
@@ -35,8 +36,7 @@ static const int32_t kMaxSeqLen = 200;
class
OnlinePunctuationCNNBiLSTMImpl
:
public
OnlinePunctuationImpl
{
public
:
explicit
OnlinePunctuationCNNBiLSTMImpl
(
const
OnlinePunctuationConfig
&
config
)
explicit
OnlinePunctuationCNNBiLSTMImpl
(
const
OnlinePunctuationConfig
&
config
)
:
config_
(
config
),
model_
(
config
.
model
)
{
if
(
!
config_
.
model
.
bpe_vocab
.
empty
())
{
bpe_encoder_
=
std
::
make_unique
<
ssentencepiece
::
Ssentencepiece
>
(
...
...
@@ -75,34 +75,43 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl {
int32_t
n
=
label_len_list
.
size
();
std
::
array
<
int64_t
,
2
>
token_ids_shape
=
{
n
,
kMaxSeqLen
};
Ort
::
Value
token_ids
=
Ort
::
Value
::
CreateTensor
(
memory_info
,
tokens_list
.
data
(),
tokens_list
.
size
(),
Ort
::
Value
token_ids
=
Ort
::
Value
::
CreateTensor
(
memory_info
,
tokens_list
.
data
(),
tokens_list
.
size
(),
token_ids_shape
.
data
(),
token_ids_shape
.
size
());
std
::
array
<
int64_t
,
2
>
valid_ids_shape
=
{
n
,
kMaxSeqLen
};
Ort
::
Value
valid_ids
=
Ort
::
Value
::
CreateTensor
(
memory_info
,
valids_list
.
data
(),
valids_list
.
size
(),
Ort
::
Value
valid_ids
=
Ort
::
Value
::
CreateTensor
(
memory_info
,
valids_list
.
data
(),
valids_list
.
size
(),
valid_ids_shape
.
data
(),
valid_ids_shape
.
size
());
std
::
array
<
int64_t
,
1
>
label_len_shape
=
{
n
};
Ort
::
Value
label_len
=
Ort
::
Value
::
CreateTensor
(
memory_info
,
label_len_list
.
data
(),
label_len_list
.
size
(),
Ort
::
Value
label_len
=
Ort
::
Value
::
CreateTensor
(
memory_info
,
label_len_list
.
data
(),
label_len_list
.
size
(),
label_len_shape
.
data
(),
label_len_shape
.
size
());
auto
pair
=
model_
.
Forward
(
std
::
move
(
token_ids
),
std
::
move
(
valid_ids
),
std
::
move
(
label_len
));
auto
pair
=
model_
.
Forward
(
std
::
move
(
token_ids
),
std
::
move
(
valid_ids
),
std
::
move
(
label_len
));
std
::
vector
<
int32_t
>
case_pred
;
std
::
vector
<
int32_t
>
punct_pred
;
const
float
*
active_case_logits
=
pair
.
first
.
GetTensorData
<
float
>
();
const
float
*
active_punct_logits
=
pair
.
second
.
GetTensorData
<
float
>
();
std
::
vector
<
int64_t
>
case_logits_shape
=
pair
.
first
.
GetTensorTypeAndShapeInfo
().
GetShape
();
const
float
*
active_case_logits
=
pair
.
first
.
GetTensorData
<
float
>
();
const
float
*
active_punct_logits
=
pair
.
second
.
GetTensorData
<
float
>
();
std
::
vector
<
int64_t
>
case_logits_shape
=
pair
.
first
.
GetTensorTypeAndShapeInfo
().
GetShape
();
for
(
int32_t
i
=
0
;
i
<
case_logits_shape
[
0
];
++
i
)
{
const
float
*
p_cur_case
=
active_case_logits
+
i
*
meta_data
.
num_cases
;
const
float
*
p_cur_case
=
active_case_logits
+
i
*
meta_data
.
num_cases
;
auto
index_case
=
static_cast
<
int32_t
>
(
std
::
distance
(
p_cur_case
,
std
::
max_element
(
p_cur_case
,
p_cur_case
+
meta_data
.
num_cases
)));
p_cur_case
,
std
::
max_element
(
p_cur_case
,
p_cur_case
+
meta_data
.
num_cases
)));
case_pred
.
push_back
(
index_case
);
const
float
*
p_cur_punct
=
active_punct_logits
+
i
*
meta_data
.
num_punctuations
;
const
float
*
p_cur_punct
=
active_punct_logits
+
i
*
meta_data
.
num_punctuations
;
auto
index_punct
=
static_cast
<
int32_t
>
(
std
::
distance
(
p_cur_punct
,
std
::
max_element
(
p_cur_punct
,
p_cur_punct
+
meta_data
.
num_punctuations
)));
p_cur_punct
,
std
::
max_element
(
p_cur_punct
,
p_cur_punct
+
meta_data
.
num_punctuations
)));
punct_pred
.
push_back
(
index_punct
);
}
...
...
@@ -112,10 +121,10 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl {
}
private
:
void
EncodeSentences
(
const
std
::
string
&
text
,
std
::
vector
<
int32_t
>&
tokens_list
,
std
::
vector
<
int32_t
>&
valids_list
,
std
::
vector
<
int32_t
>&
label_len_list
)
const
{
void
EncodeSentences
(
const
std
::
string
&
text
,
std
::
vector
<
int32_t
>
&
tokens_list
,
// NOLINT
std
::
vector
<
int32_t
>
&
valids_list
,
// NOLINT
std
::
vector
<
int32_t
>
&
label_len_list
)
const
{
// NOLINT
std
::
vector
<
int32_t
>
tokens
;
std
::
vector
<
int32_t
>
valids
;
int32_t
label_len
=
0
;
...
...
@@ -184,9 +193,9 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl {
}
}
std
::
string
DecodeSentences
(
const
std
::
string
&
raw_text
,
const
std
::
vector
<
int32_t
>&
case_pred
,
const
std
::
vector
<
int32_t
>&
punct_pred
)
const
{
std
::
string
DecodeSentences
(
const
std
::
string
&
raw_text
,
const
std
::
vector
<
int32_t
>
&
case_pred
,
const
std
::
vector
<
int32_t
>
&
punct_pred
)
const
{
std
::
string
result_text
;
std
::
istringstream
iss
(
raw_text
);
std
::
vector
<
std
::
string
>
words
;
...
...
@@ -205,7 +214,8 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl {
switch
(
case_pred
[
i
])
{
case
1
:
// upper
{
std
::
transform
(
words
[
i
].
begin
(),
words
[
i
].
end
(),
words
[
i
].
begin
(),
[](
auto
c
){
return
std
::
toupper
(
c
);
});
std
::
transform
(
words
[
i
].
begin
(),
words
[
i
].
end
(),
words
[
i
].
begin
(),
[](
auto
c
)
{
return
std
::
toupper
(
c
);
});
result_text
+=
words
[
i
];
break
;
}
...
...
@@ -217,14 +227,14 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl {
}
case
3
:
// mix case
{
// TODO:
// Need to add a map containing supported mix case words so that we can fetch the predicted word from the map
// e.g. mcdonald's -> McDonald's
// TODO(frankyoujian):
// Need to add a map containing supported mix case words so that we
// can fetch the predicted word from the map e.g. mcdonald's ->
// McDonald's
result_text
+=
words
[
i
];
break
;
}
default
:
{
default
:
{
result_text
+=
words
[
i
];
break
;
}
...
...
sherpa-onnx/csrc/online-punctuation-impl.cc
查看文件 @
375c055
...
...
@@ -20,7 +20,9 @@ std::unique_ptr<OnlinePunctuationImpl> OnlinePunctuationImpl::Create(
return
std
::
make_unique
<
OnlinePunctuationCNNBiLSTMImpl
>
(
config
);
}
SHERPA_ONNX_LOGE
(
"Please specify a punctuation model and bpe vocab! Return a null pointer"
);
SHERPA_ONNX_LOGE
(
"Please specify a punctuation model and bpe vocab! Return a null "
"pointer"
);
return
nullptr
;
}
...
...
@@ -31,7 +33,9 @@ std::unique_ptr<OnlinePunctuationImpl> OnlinePunctuationImpl::Create(
return
std
::
make_unique
<
OnlinePunctuationCNNBiLSTMImpl
>
(
mgr
,
config
);
}
SHERPA_ONNX_LOGE
(
"Please specify a punctuation model and bpe vocab! Return a null pointer"
);
SHERPA_ONNX_LOGE
(
"Please specify a punctuation model and bpe vocab! Return a null "
"pointer"
);
return
nullptr
;
}
#endif
...
...
sherpa-onnx/csrc/online-punctuation-model-config.cc
查看文件 @
375c055
...
...
@@ -13,8 +13,7 @@ void OnlinePunctuationModelConfig::Register(ParseOptions *po) {
po
->
Register
(
"cnn-bilstm"
,
&
cnn_bilstm
,
"Path to the light-weight CNN-BiLSTM model"
);
po
->
Register
(
"bpe-vocab"
,
&
bpe_vocab
,
"Path to the bpe vocab file"
);
po
->
Register
(
"bpe-vocab"
,
&
bpe_vocab
,
"Path to the bpe vocab file"
);
po
->
Register
(
"num-threads"
,
&
num_threads
,
"Number of threads to run the neural network"
);
...
...
@@ -33,8 +32,7 @@ bool OnlinePunctuationModelConfig::Validate() const {
}
if
(
!
FileExists
(
cnn_bilstm
))
{
SHERPA_ONNX_LOGE
(
"--cnn-bilstm '%s' does not exist"
,
cnn_bilstm
.
c_str
());
SHERPA_ONNX_LOGE
(
"--cnn-bilstm '%s' does not exist"
,
cnn_bilstm
.
c_str
());
return
false
;
}
...
...
@@ -44,8 +42,7 @@ bool OnlinePunctuationModelConfig::Validate() const {
}
if
(
!
FileExists
(
bpe_vocab
))
{
SHERPA_ONNX_LOGE
(
"--bpe-vocab '%s' does not exist"
,
bpe_vocab
.
c_str
());
SHERPA_ONNX_LOGE
(
"--bpe-vocab '%s' does not exist"
,
bpe_vocab
.
c_str
());
return
false
;
}
...
...
sherpa-onnx/csrc/online-punctuation-model-config.h
查看文件 @
375c055
sherpa-onnx/csrc/online-punctuation.cc
查看文件 @
375c055
...
...
@@ -14,9 +14,7 @@
namespace
sherpa_onnx
{
void
OnlinePunctuationConfig
::
Register
(
ParseOptions
*
po
)
{
model
.
Register
(
po
);
}
void
OnlinePunctuationConfig
::
Register
(
ParseOptions
*
po
)
{
model
.
Register
(
po
);
}
bool
OnlinePunctuationConfig
::
Validate
()
const
{
if
(
!
model
.
Validate
())
{
...
...
@@ -46,7 +44,8 @@ OnlinePunctuation::OnlinePunctuation(AAssetManager *mgr,
OnlinePunctuation
::~
OnlinePunctuation
()
=
default
;
std
::
string
OnlinePunctuation
::
AddPunctuationWithCase
(
const
std
::
string
&
text
)
const
{
std
::
string
OnlinePunctuation
::
AddPunctuationWithCase
(
const
std
::
string
&
text
)
const
{
return
impl_
->
AddPunctuationWithCase
(
text
);
}
...
...
sherpa-onnx/csrc/online-punctuation.h
查看文件 @
375c055
...
...
@@ -40,8 +40,7 @@ class OnlinePunctuation {
explicit
OnlinePunctuation
(
const
OnlinePunctuationConfig
&
config
);
#if __ANDROID_API__ >= 9
OnlinePunctuation
(
AAssetManager
*
mgr
,
const
OnlinePunctuationConfig
&
config
);
OnlinePunctuation
(
AAssetManager
*
mgr
,
const
OnlinePunctuationConfig
&
config
);
#endif
~
OnlinePunctuation
();
...
...
sherpa-onnx/csrc/sherpa-onnx-online-punctuation.cc
查看文件 @
375c055
...
...
@@ -3,9 +3,9 @@
// Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems)
#include <stdio.h>
#include <iostream>
#include <chrono> // NOLINT
#include <iostream>
#include "sherpa-onnx/csrc/online-punctuation.h"
#include "sherpa-onnx/csrc/parse-options.h"
...
...
请
注册
或
登录
后发表评论