Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
Fangjun Kuang
2023-02-26 14:54:01 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2023-02-26 14:54:01 +0800
Commit
5a8c3a6d10e74702a723f7cb99b2f4fb9e532766
5a8c3a6d
1 parent
475caf22
Use deep copy in Clone() (#66)
隐藏空白字符变更
内嵌
并排对比
正在显示
7 个修改的文件
包含
41 行增加
和
24 行删除
android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt
ios-swift/SherpaOnnx/SherpaOnnx/ViewController.swift
sherpa-onnx/csrc/cat.cc
sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc
sherpa-onnx/csrc/onnx-utils.cc
sherpa-onnx/csrc/onnx-utils.h
sherpa-onnx/csrc/unbind.cc
android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt
查看文件 @
5a8c3a6
...
...
@@ -171,9 +171,14 @@ class MainActivity : AppCompatActivity() {
}
private fun initModel() {
// Please change getModelConfig() to add new models
// See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
// for a list of available models
val type = 0
println("Select model type ${type}")
val config = OnlineRecognizerConfig(
featConfig = getFeatureConfig(sampleRate = 16000.0f, featureDim = 80),
modelConfig = getModelConfig(type =
1
)!!,
modelConfig = getModelConfig(type =
type
)!!,
endpointConfig = getEndpointConfig(),
enableEndpoint = true
)
...
...
ios-swift/SherpaOnnx/SherpaOnnx/ViewController.swift
查看文件 @
5a8c3a6
...
...
@@ -63,7 +63,7 @@ class ViewController: UIViewController {
super
.
viewDidLoad
()
// Do any additional setup after loading the view.
resultLabel
.
text
=
"ASR with Next-gen Kaldi
\n\n
Press the Start button to run!"
resultLabel
.
text
=
"ASR with Next-gen Kaldi
\n\n
See https://github.com/k2-fsa/sherpa-onnx
\n\n
Press the Start button to run!"
recordBtn
.
setTitle
(
"Start"
,
for
:
.
normal
)
initRecognizer
()
initRecorder
()
...
...
sherpa-onnx/csrc/cat.cc
查看文件 @
5a8c3a6
...
...
@@ -37,7 +37,7 @@ template <typename T /*=float*/>
Ort
::
Value
Cat
(
OrtAllocator
*
allocator
,
const
std
::
vector
<
const
Ort
::
Value
*>
&
values
,
int32_t
dim
)
{
if
(
values
.
size
()
==
1u
)
{
return
Clone
(
values
[
0
]);
return
Clone
(
allocator
,
values
[
0
]);
}
std
::
vector
<
int64_t
>
v0_shape
=
...
...
sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc
查看文件 @
5a8c3a6
...
...
@@ -100,8 +100,8 @@ void OnlineTransducerGreedySearchDecoder::Decode(
for
(
int32_t
t
=
0
;
t
!=
num_frames
;
++
t
)
{
Ort
::
Value
cur_encoder_out
=
GetFrame
(
&
encoder_out
,
t
);
cur_encoder_out
=
Repeat
(
model_
->
Allocator
(),
&
cur_encoder_out
,
batch_size
);
Ort
::
Value
logit
=
model_
->
RunJoiner
(
std
::
move
(
cur_encoder_out
),
Clone
(
&
decoder_out
));
Ort
::
Value
logit
=
model_
->
RunJoiner
(
std
::
move
(
cur_encoder_out
),
Clone
(
model_
->
Allocator
(),
&
decoder_out
));
const
float
*
p_logit
=
logit
.
GetTensorData
<
float
>
();
bool
emitted
=
false
;
...
...
sherpa-onnx/csrc/onnx-utils.cc
查看文件 @
5a8c3a6
...
...
@@ -53,7 +53,7 @@ void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data) {
}
}
Ort
::
Value
Clone
(
const
Ort
::
Value
*
v
)
{
Ort
::
Value
Clone
(
OrtAllocator
*
allocator
,
const
Ort
::
Value
*
v
)
{
auto
type_and_shape
=
v
->
GetTensorTypeAndShapeInfo
();
std
::
vector
<
int64_t
>
shape
=
type_and_shape
.
GetShape
();
...
...
@@ -61,21 +61,33 @@ Ort::Value Clone(const Ort::Value *v) {
Ort
::
MemoryInfo
::
CreateCpu
(
OrtDeviceAllocator
,
OrtMemTypeDefault
);
switch
(
type_and_shape
.
GetElementType
())
{
case
ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32
:
return
Ort
::
Value
::
CreateTensor
(
memory_info
,
const_cast
<
Ort
::
Value
*>
(
v
)
->
GetTensorMutableData
<
int32_t
>
(),
type_and_shape
.
GetElementCount
(),
shape
.
data
(),
shape
.
size
());
case
ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64
:
return
Ort
::
Value
::
CreateTensor
(
memory_info
,
const_cast
<
Ort
::
Value
*>
(
v
)
->
GetTensorMutableData
<
int64_t
>
(),
type_and_shape
.
GetElementCount
(),
shape
.
data
(),
shape
.
size
());
case
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT
:
return
Ort
::
Value
::
CreateTensor
(
memory_info
,
const_cast
<
Ort
::
Value
*>
(
v
)
->
GetTensorMutableData
<
float
>
(),
type_and_shape
.
GetElementCount
(),
shape
.
data
(),
shape
.
size
());
case
ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32
:
{
Ort
::
Value
ans
=
Ort
::
Value
::
CreateTensor
<
int32_t
>
(
allocator
,
shape
.
data
(),
shape
.
size
());
const
int32_t
*
start
=
v
->
GetTensorData
<
int32_t
>
();
const
int32_t
*
end
=
start
+
type_and_shape
.
GetElementCount
();
int32_t
*
dst
=
ans
.
GetTensorMutableData
<
int32_t
>
();
std
::
copy
(
start
,
end
,
dst
);
return
ans
;
}
case
ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64
:
{
Ort
::
Value
ans
=
Ort
::
Value
::
CreateTensor
<
int64_t
>
(
allocator
,
shape
.
data
(),
shape
.
size
());
const
int64_t
*
start
=
v
->
GetTensorData
<
int64_t
>
();
const
int64_t
*
end
=
start
+
type_and_shape
.
GetElementCount
();
int64_t
*
dst
=
ans
.
GetTensorMutableData
<
int64_t
>
();
std
::
copy
(
start
,
end
,
dst
);
return
ans
;
}
case
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT
:
{
Ort
::
Value
ans
=
Ort
::
Value
::
CreateTensor
<
float
>
(
allocator
,
shape
.
data
(),
shape
.
size
());
const
float
*
start
=
v
->
GetTensorData
<
float
>
();
const
float
*
end
=
start
+
type_and_shape
.
GetElementCount
();
float
*
dst
=
ans
.
GetTensorMutableData
<
float
>
();
std
::
copy
(
start
,
end
,
dst
);
return
ans
;
}
default
:
fprintf
(
stderr
,
"Unsupported type: %d
\n
"
,
static_cast
<
int32_t
>
(
type_and_shape
.
GetElementType
()));
...
...
sherpa-onnx/csrc/onnx-utils.h
查看文件 @
5a8c3a6
...
...
@@ -60,8 +60,8 @@ void GetOutputNames(Ort::Session *sess, std::vector<std::string> *output_names,
void
PrintModelMetadata
(
std
::
ostream
&
os
,
const
Ort
::
ModelMetadata
&
meta_data
);
// NOLINT
// Return a shallow copy of v
Ort
::
Value
Clone
(
const
Ort
::
Value
*
v
);
// Return a deep copy of v
Ort
::
Value
Clone
(
OrtAllocator
*
allocator
,
const
Ort
::
Value
*
v
);
// Print a 1-D tensor to stderr
void
Print1D
(
Ort
::
Value
*
v
);
...
...
sherpa-onnx/csrc/unbind.cc
查看文件 @
5a8c3a6
...
...
@@ -26,7 +26,7 @@ std::vector<Ort::Value> Unbind(OrtAllocator *allocator, const Ort::Value *value,
int32_t
n
=
static_cast
<
int32_t
>
(
shape
[
dim
]);
if
(
n
==
1
)
{
std
::
vector
<
Ort
::
Value
>
ans
;
ans
.
push_back
(
Clone
(
value
));
ans
.
push_back
(
Clone
(
allocator
,
value
));
return
ans
;
}
...
...
请
注册
或
登录
后发表评论