Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
PF Luo
2023-03-01 15:32:54 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2023-03-01 15:32:54 +0800
Commit
5326d0f81f82ba7785f75bc3866c26b97611dba3
5326d0f8
1 parent
e0b76655
add modified beam search (#69)
隐藏空白字符变更
内嵌
并排对比
正在显示
19 个修改的文件
包含
614 行增加
和
87 行删除
.gitignore
sherpa-onnx/csrc/CMakeLists.txt
sherpa-onnx/csrc/hypothesis.cc
sherpa-onnx/csrc/hypothesis.h
sherpa-onnx/csrc/math.h
sherpa-onnx/csrc/online-lstm-transducer-model.cc
sherpa-onnx/csrc/online-lstm-transducer-model.h
sherpa-onnx/csrc/online-recognizer.cc
sherpa-onnx/csrc/online-recognizer.h
sherpa-onnx/csrc/online-transducer-decoder.h
sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc
sherpa-onnx/csrc/online-transducer-model.cc
sherpa-onnx/csrc/online-transducer-model.h
sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc
sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h
sherpa-onnx/csrc/online-zipformer-transducer-model.cc
sherpa-onnx/csrc/online-zipformer-transducer-model.h
sherpa-onnx/csrc/onnx-utils.cc
sherpa-onnx/csrc/onnx-utils.h
.gitignore
查看文件 @
5326d0f
...
...
@@ -34,3 +34,4 @@ decode-file
tokens.txt
*.onnx
log.txt
tags
...
...
sherpa-onnx/csrc/CMakeLists.txt
查看文件 @
5326d0f
...
...
@@ -5,11 +5,13 @@ set(sources
endpoint.cc
features.cc
file-utils.cc
hypothesis.cc
online-lstm-transducer-model.cc
online-recognizer.cc
online-stream.cc
online-transducer-greedy-search-decoder.cc
online-transducer-model-config.cc
online-transducer-modified-beam-search-decoder.cc
online-transducer-model.cc
online-zipformer-transducer-model.cc
onnx-utils.cc
...
...
sherpa-onnx/csrc/hypothesis.cc
0 → 100644
查看文件 @
5326d0f
/**
* Copyright (c) 2023 Xiaomi Corporation
*
*/
#include "sherpa-onnx/csrc/hypothesis.h"
#include <algorithm>
#include <utility>
namespace
sherpa_onnx
{
// Insert `hyp` into this collection.
//
// Entries are keyed by Hypothesis::Key(). When an entry with the same key
// is already stored, only its log_prob is updated (the two scores are
// combined with LogAdd) and every other field of `hyp` is dropped.
// Otherwise `hyp` is moved into the map under its key.
void Hypotheses::Add(Hypothesis hyp) {
  auto key = hyp.Key();
  auto pos = hyps_dict_.find(key);

  if (pos != hyps_dict_.end()) {
    Hypothesis &existing = pos->second;
    existing.log_prob = LogAdd<double>()(existing.log_prob, hyp.log_prob);
    return;
  }

  hyps_dict_[key] = std::move(hyp);
}
// Return a copy of the stored hypothesis with the highest score.
//
// @param length_norm  When true, each log_prob is divided by the number of
//                     tokens in `ys` before comparing; when false the raw
//                     log_prob values are compared.
//
// NOTE(review): the result of std::max_element is dereferenced without an
// emptiness check, so callers must not invoke this on an empty collection.
Hypothesis Hypotheses::GetMostProbable(bool length_norm) const {
  if (length_norm) {
    // Compare by per-token score.
    auto best = std::max_element(
        hyps_dict_.begin(), hyps_dict_.end(),
        [](const auto &lhs, const auto &rhs) -> bool {
          const auto &a = lhs.second;
          const auto &b = rhs.second;
          return a.log_prob / a.ys.size() < b.log_prob / b.ys.size();
        });
    return best->second;
  }

  // Compare by total score.
  auto best = std::max_element(
      hyps_dict_.begin(), hyps_dict_.end(),
      [](const auto &lhs, const auto &rhs) -> bool {
        return lhs.second.log_prob < rhs.second.log_prob;
      });
  return best->second;
}
std
::
vector
<
Hypothesis
>
Hypotheses
::
GetTopK
(
int32_t
k
,
bool
length_norm
)
const
{
k
=
std
::
max
(
k
,
1
);
k
=
std
::
min
(
k
,
Size
());
std
::
vector
<
Hypothesis
>
all_hyps
=
Vec
();
if
(
length_norm
==
false
)
{
std
::
partial_sort
(
all_hyps
.
begin
(),
all_hyps
.
begin
()
+
k
,
all_hyps
.
end
(),
[](
const
auto
&
a
,
const
auto
&
b
)
{
return
a
.
log_prob
>
b
.
log_prob
;
});
}
else
{
// for length_norm is true
std
::
partial_sort
(
all_hyps
.
begin
(),
all_hyps
.
begin
()
+
k
,
all_hyps
.
end
(),
[](
const
auto
&
a
,
const
auto
&
b
)
{
return
a
.
log_prob
/
a
.
ys
.
size
()
>
b
.
log_prob
/
b
.
ys
.
size
();
});
}
return
{
all_hyps
.
begin
(),
all_hyps
.
begin
()
+
k
};
}
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/hypothesis.h
0 → 100644
查看文件 @
5326d0f
/**
* Copyright (c) 2023 Xiaomi Corporation
*
*/
#ifndef SHERPA_ONNX_CSRC_HYPOTHESIS_H_
#define SHERPA_ONNX_CSRC_HYPOTHESIS_H_
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/math.h"
namespace
sherpa_onnx
{
// A single decoding path: the tokens emitted so far plus its score.
struct Hypothesis {
  // The predicted tokens so far. Newly predicted tokens are appended.
  std::vector<int32_t> ys;

  // timestamps[i] contains the frame number after subsampling
  // on which ys[i] is decoded.
  std::vector<int32_t> timestamps;

  // The total score of ys in log space.
  double log_prob = 0;

  // Counter of blanks at the end of this path. The modified beam search
  // decoder in this project resets it when a non-blank token is appended
  // and increments it otherwise.
  int32_t num_trailing_blanks = 0;

  Hypothesis() = default;

  // Create a path holding a copy of `ys` with total score `log_prob`.
  Hypothesis(const std::vector<int32_t> &ys, double log_prob)
      : ys(ys), log_prob(log_prob) {}

  // If two Hypotheses have the same `Key`, then they contain
  // the same token sequence.
  //
  // The key is every token followed by "-", e.g. {1, 2, 3} -> "1-2-3-".
  std::string Key() const {
    // TODO(fangjun): Use a hash function?
    const std::string delimiter = "-";
    std::ostringstream buffer;
    for (auto it = ys.begin(); it != ys.end(); ++it) {
      buffer << *it << delimiter;
    }
    return buffer.str();
  }

  // For debugging: renders "(<key>, <log_prob>)".
  std::string ToString() const {
    std::ostringstream buffer;
    buffer << "(" << Key();
    buffer << ", " << log_prob;
    buffer << ")";
    return buffer.str();
  }
};
class
Hypotheses
{
public
:
Hypotheses
()
=
default
;
explicit
Hypotheses
(
std
::
vector
<
Hypothesis
>
hyps
)
{
for
(
auto
&
h
:
hyps
)
{
hyps_dict_
[
h
.
Key
()]
=
std
::
move
(
h
);
}
}
explicit
Hypotheses
(
std
::
unordered_map
<
std
::
string
,
Hypothesis
>
hyps_dict
)
:
hyps_dict_
(
std
::
move
(
hyps_dict
))
{}
// Add hyp to this object. If it already exists, its log_prob
// is updated with the given hyp using log-sum-exp.
void
Add
(
Hypothesis
hyp
);
// Get the hyp that has the largest log_prob.
// If length_norm is true, hyp's log_prob is divided by
// len(hyp.ys) before comparison.
Hypothesis
GetMostProbable
(
bool
length_norm
)
const
;
// Get the k hyps that have the largest log_prob.
// If length_norm is true, hyp's log_prob is divided by
// len(hyp.ys) before comparison.
std
::
vector
<
Hypothesis
>
GetTopK
(
int32_t
k
,
bool
length_norm
)
const
;
int32_t
Size
()
const
{
return
hyps_dict_
.
size
();
}
std
::
string
ToString
()
const
{
std
::
ostringstream
os
;
for
(
const
auto
&
p
:
hyps_dict_
)
{
os
<<
p
.
second
.
ToString
()
<<
"
\n
"
;
}
return
os
.
str
();
}
const
auto
begin
()
const
{
return
hyps_dict_
.
begin
();
}
const
auto
end
()
const
{
return
hyps_dict_
.
end
();
}
void
Clear
()
{
hyps_dict_
.
clear
();
}
private
:
// Return a list of hyps contained in this object.
std
::
vector
<
Hypothesis
>
Vec
()
const
{
std
::
vector
<
Hypothesis
>
ans
;
ans
.
reserve
(
hyps_dict_
.
size
());
for
(
const
auto
&
p
:
hyps_dict_
)
{
ans
.
push_back
(
p
.
second
);
}
return
ans
;
}
private
:
using
Map
=
std
::
unordered_map
<
std
::
string
,
Hypothesis
>
;
Map
hyps_dict_
;
};
}
// namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_HYPOTHESIS_H_
...
...
sherpa-onnx/csrc/math.h
0 → 100644
查看文件 @
5326d0f
/**
* Copyright (c) 2022 Xiaomi Corporation (authors: Daniel Povey)
* Copyright (c) 2023 (Pingfeng Luo)
*
*/
// This file is copied from k2/csrc/utils.h
#ifndef SHERPA_ONNX_CSRC_MATH_H_
#define SHERPA_ONNX_CSRC_MATH_H_
#include <algorithm>
#include <cassert>
#include <cmath>
#include <numeric>
#include <vector>
namespace
sherpa_onnx
{
// logf(FLT_EPSILON)
#define SHERPA_ONNX_MIN_LOG_DIFF_FLOAT -15.9423847198486328125f

// log(DBL_EPSILON)
#define SHERPA_ONNX_MIN_LOG_DIFF_DOUBLE \
  -36.0436533891171535515240975655615329742431640625

// Functor returning log(exp(x) + exp(y)) for two values given in log space.
//
// The result is computed as max + log1p(exp(min - max)). When the gap
// between the operands is below the log-epsilon threshold of the value
// type, the larger operand is returned unchanged.
template <typename T>
struct LogAdd;

template <>
struct LogAdd<double> {
  double operator()(double x, double y) const {
    double diff;

    if (x < y) {
      diff = x - y;
      x = y;
    } else {
      diff = y - x;
    }
    // diff is negative. x is now the larger one.

    if (diff >= SHERPA_ONNX_MIN_LOG_DIFF_DOUBLE) {
      return x + std::log1p(std::exp(diff));
    }

    return x;  // return the larger one.
  }
};

template <>
struct LogAdd<float> {
  float operator()(float x, float y) const {
    float diff;

    if (x < y) {
      diff = x - y;
      x = y;
    } else {
      diff = y - x;
    }
    // diff is negative. x is now the larger one.

    // Use the single-precision threshold here. Comparing against the
    // double-precision constant left SHERPA_ONNX_MIN_LOG_DIFF_FLOAT unused
    // and promoted the comparison to double.
    if (diff >= SHERPA_ONNX_MIN_LOG_DIFF_FLOAT) {
      return x + std::log1p(std::exp(diff));
    }

    return x;  // return the larger one.
  }
};
// Replace the `input_len` values starting at `input` with their
// log-softmax, in place.
//
// The maximum is subtracted before exponentiating so that the exponentials
// cannot overflow.
//
// @param input      Non-null pointer to the first value.
// @param input_len  Number of values. Nothing is done when it is <= 0, so
//                   the end of an empty range is never dereferenced.
template <class T>
void LogSoftmax(T *input, int32_t input_len) {
  assert(input);

  if (input_len <= 0) {
    return;
  }

  T *const last = input + input_len;
  const T m = *std::max_element(input, last);

  T sum = 0.0;
  for (const T *p = input; p != last; ++p) {
    // std::exp selects the overload matching T (no float -> double detour).
    sum += std::exp(*p - m);
  }

  const T offset = m + std::log(sum);
  for (T *p = input; p != last; ++p) {
    *p -= offset;
  }
}
// Return the indexes of the `topk` largest values of `vec`, ordered from
// the largest value to the smallest.
//
// @param vec   Pointer to `size` values.
// @param size  Number of values in `vec`.
// @param topk  Number of indexes wanted; capped at `size`.
// @return      min(size, topk) indexes. Empty when `size` or `topk` is not
//              positive. The relative order of equal values is unspecified.
template <class T>
std::vector<int32_t> TopkIndex(const T *vec, int32_t size, int32_t topk) {
  // A non-positive count would otherwise yield a bogus allocation size or
  // an iterator before begin().
  if (size <= 0 || topk <= 0) {
    return {};
  }

  std::vector<int32_t> vec_index(size);
  std::iota(vec_index.begin(), vec_index.end(), 0);

  const int32_t k_num = std::min<int32_t>(size, topk);

  // Only the first k_num positions need to be ordered.
  std::partial_sort(vec_index.begin(), vec_index.begin() + k_num,
                    vec_index.end(),
                    [vec](int32_t index_1, int32_t index_2) {
                      return vec[index_1] > vec[index_2];
                    });

  vec_index.resize(k_num);
  return vec_index;
}
}
// namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_MATH_H_
...
...
sherpa-onnx/csrc/online-lstm-transducer-model.cc
查看文件 @
5326d0f
...
...
@@ -247,24 +247,6 @@ OnlineLstmTransducerModel::RunEncoder(Ort::Value features,
return
{
std
::
move
(
encoder_out
[
0
]),
std
::
move
(
next_states
)};
}
// Build the int64 decoder input tensor of shape
// (results.size(), context_size_): row i holds the last context_size_
// tokens of results[i].
//
// NOTE(review): every result is expected to hold at least context_size_
// tokens; this is not checked here.
Ort::Value OnlineLstmTransducerModel::BuildDecoderInput(
    const std::vector<OnlineTransducerDecoderResult> &results) {
  int32_t batch_size = static_cast<int32_t>(results.size());
  std::array<int64_t, 2> shape{batch_size, context_size_};

  Ort::Value decoder_input = Ort::Value::CreateTensor<int64_t>(
      allocator_, shape.data(), shape.size());

  int64_t *dst = decoder_input.GetTensorMutableData<int64_t>();
  for (const auto &r : results) {
    const int64_t *tokens_end = r.tokens.data() + r.tokens.size();
    // std::copy returns the position just past the copied row.
    dst = std::copy(tokens_end - context_size_, tokens_end, dst);
  }

  return decoder_input;
}
Ort
::
Value
OnlineLstmTransducerModel
::
RunDecoder
(
Ort
::
Value
decoder_input
)
{
auto
decoder_out
=
decoder_sess_
->
Run
(
{},
decoder_input_names_ptr_
.
data
(),
&
decoder_input
,
1
,
...
...
sherpa-onnx/csrc/online-lstm-transducer-model.h
查看文件 @
5326d0f
...
...
@@ -40,9 +40,6 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel {
std
::
pair
<
Ort
::
Value
,
std
::
vector
<
Ort
::
Value
>>
RunEncoder
(
Ort
::
Value
features
,
std
::
vector
<
Ort
::
Value
>
states
)
override
;
Ort
::
Value
BuildDecoderInput
(
const
std
::
vector
<
OnlineTransducerDecoderResult
>
&
results
)
override
;
Ort
::
Value
RunDecoder
(
Ort
::
Value
decoder_input
)
override
;
Ort
::
Value
RunJoiner
(
Ort
::
Value
encoder_out
,
Ort
::
Value
decoder_out
)
override
;
...
...
sherpa-onnx/csrc/online-recognizer.cc
查看文件 @
5326d0f
// sherpa-onnx/csrc/online-recognizer.cc
//
// Copyright (c) 2023 Xiaomi Corporation
// Copyright (c) 2023 Pingfeng Luo
#include "sherpa-onnx/csrc/online-recognizer.h"
...
...
@@ -16,6 +17,7 @@
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
#include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h"
#include "sherpa-onnx/csrc/online-transducer-model.h"
#include "sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h"
#include "sherpa-onnx/csrc/symbol-table.h"
namespace
sherpa_onnx
{
...
...
@@ -39,6 +41,11 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
po
->
Register
(
"enable-endpoint"
,
&
enable_endpoint
,
"True to enable endpoint detection. False to disable it."
);
po
->
Register
(
"max-active-paths"
,
&
max_active_paths
,
"beam size used in modified beam search."
);
// The option name was misspelled as "decoding-mothod", so the documented
// spelling --decoding-method was not recognized.
// NOTE(review): scripts that relied on the misspelled flag must be updated.
po->Register("decoding-method", &decoding_method,
             "decoding method,"
             "now support greedy_search and modified_beam_search.");
}
bool
OnlineRecognizerConfig
::
Validate
()
const
{
...
...
@@ -52,7 +59,9 @@ std::string OnlineRecognizerConfig::ToString() const {
os
<<
"feat_config="
<<
feat_config
.
ToString
()
<<
", "
;
os
<<
"model_config="
<<
model_config
.
ToString
()
<<
", "
;
os
<<
"endpoint_config="
<<
endpoint_config
.
ToString
()
<<
", "
;
os
<<
"enable_endpoint="
<<
(
enable_endpoint
?
"True"
:
"False"
)
<<
")"
;
os
<<
"enable_endpoint="
<<
(
enable_endpoint
?
"True"
:
"False"
)
<<
","
;
os
<<
"max_active_paths="
<<
max_active_paths
<<
","
;
os
<<
"decoding_method=
\"
"
<<
decoding_method
<<
"
\"
)"
;
return
os
.
str
();
}
...
...
@@ -64,8 +73,17 @@ class OnlineRecognizer::Impl {
model_
(
OnlineTransducerModel
::
Create
(
config
.
model_config
)),
sym_
(
config
.
model_config
.
tokens
),
endpoint_
(
config_
.
endpoint_config
)
{
decoder_
=
std
::
make_unique
<
OnlineTransducerGreedySearchDecoder
>
(
model_
.
get
());
if
(
config
.
decoding_method
==
"modified_beam_search"
)
{
decoder_
=
std
::
make_unique
<
OnlineTransducerModifiedBeamSearchDecoder
>
(
model_
.
get
(),
config_
.
max_active_paths
);
}
else
if
(
config
.
decoding_method
==
"greedy_search"
)
{
decoder_
=
std
::
make_unique
<
OnlineTransducerGreedySearchDecoder
>
(
model_
.
get
());
}
else
{
fprintf
(
stderr
,
"Unsupported decoding method: %s
\n
"
,
config
.
decoding_method
.
c_str
());
exit
(
-
1
);
}
}
#if __ANDROID_API__ >= 9
...
...
@@ -74,8 +92,17 @@ class OnlineRecognizer::Impl {
model_
(
OnlineTransducerModel
::
Create
(
mgr
,
config
.
model_config
)),
sym_
(
mgr
,
config
.
model_config
.
tokens
),
endpoint_
(
config_
.
endpoint_config
)
{
decoder_
=
std
::
make_unique
<
OnlineTransducerGreedySearchDecoder
>
(
model_
.
get
());
if
(
config
.
decoding_method
==
"modified_beam_search"
)
{
decoder_
=
std
::
make_unique
<
OnlineTransducerModifiedBeamSearchDecoder
>
(
model_
.
get
(),
config_
.
max_active_paths
);
}
else
if
(
config
.
decoding_method
==
"greedy_search"
)
{
decoder_
=
std
::
make_unique
<
OnlineTransducerGreedySearchDecoder
>
(
model_
.
get
());
}
else
{
fprintf
(
stderr
,
"Unsupported decoding method: %s
\n
"
,
config
.
decoding_method
.
c_str
());
exit
(
-
1
);
}
}
#endif
...
...
sherpa-onnx/csrc/online-recognizer.h
查看文件 @
5326d0f
...
...
@@ -32,7 +32,11 @@ struct OnlineRecognizerConfig {
FeatureExtractorConfig
feat_config
;
OnlineTransducerModelConfig
model_config
;
EndpointConfig
endpoint_config
;
bool
enable_endpoint
;
bool
enable_endpoint
=
true
;
int32_t
max_active_paths
=
4
;
std
::
string
decoding_method
=
"modified_beam_search"
;
// now support modified_beam_search and greedy_search
OnlineRecognizerConfig
()
=
default
;
...
...
sherpa-onnx/csrc/online-transducer-decoder.h
查看文件 @
5326d0f
...
...
@@ -8,6 +8,7 @@
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/hypothesis.h"
namespace
sherpa_onnx
{
...
...
@@ -17,6 +18,9 @@ struct OnlineTransducerDecoderResult {
/// number of trailing blank frames decoded so far
int32_t
num_trailing_blanks
=
0
;
// used only in modified beam_search
Hypotheses
hyps
;
};
class
OnlineTransducerDecoder
{
...
...
sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc
查看文件 @
5326d0f
...
...
@@ -4,8 +4,6 @@
#include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h"
#include <assert.h>
#include <algorithm>
#include <utility>
#include <vector>
...
...
@@ -15,39 +13,6 @@
namespace
sherpa_onnx
{
// Extract frame `t` of every batch item from `encoder_out`.
//
// @param allocator    Allocator used for the returned tensor.
// @param encoder_out  Float tensor indexed as (batch, frame, dim).
// @param t            Frame index; must satisfy 0 <= t < number of frames.
// @return             A new float tensor of shape (batch, dim).
static Ort::Value GetFrame(OrtAllocator *allocator, Ort::Value *encoder_out,
                           int32_t t) {
  std::vector<int64_t> encoder_out_shape =
      encoder_out->GetTensorTypeAndShapeInfo().GetShape();

  const int64_t batch_size = encoder_out_shape[0];
  const int64_t num_frames = encoder_out_shape[1];
  const int64_t encoder_out_dim = encoder_out_shape[2];
  assert(t >= 0 && t < num_frames);

  // Distance between the same frame of two consecutive batch items.
  const int64_t stride = num_frames * encoder_out_dim;

  std::array<int64_t, 2> shape{batch_size, encoder_out_dim};
  Ort::Value ans =
      Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());

  float *dst = ans.GetTensorMutableData<float>();
  const float *src =
      encoder_out->GetTensorData<float>() + t * encoder_out_dim;

  // The counter has the same type as batch_size to avoid a mixed
  // int32/int64 comparison.
  for (int64_t i = 0; i != batch_size; ++i) {
    dst = std::copy(src, src + encoder_out_dim, dst);
    src += stride;
  }

  return ans;
}
OnlineTransducerDecoderResult
OnlineTransducerGreedySearchDecoder
::
GetEmptyResult
()
const
{
int32_t
context_size
=
model_
->
ContextSize
();
...
...
@@ -90,7 +55,8 @@ void OnlineTransducerGreedySearchDecoder::Decode(
Ort
::
Value
decoder_out
=
model_
->
RunDecoder
(
std
::
move
(
decoder_input
));
for
(
int32_t
t
=
0
;
t
!=
num_frames
;
++
t
)
{
Ort
::
Value
cur_encoder_out
=
GetFrame
(
model_
->
Allocator
(),
&
encoder_out
,
t
);
Ort
::
Value
cur_encoder_out
=
GetEncoderOutFrame
(
model_
->
Allocator
(),
&
encoder_out
,
t
);
Ort
::
Value
logit
=
model_
->
RunJoiner
(
std
::
move
(
cur_encoder_out
),
Clone
(
model_
->
Allocator
(),
&
decoder_out
));
...
...
sherpa-onnx/csrc/online-transducer-model.cc
查看文件 @
5326d0f
// sherpa-onnx/csrc/online-transducer-model.cc
//
// Copyright (c) 2023 Xiaomi Corporation
// Copyright (c) 2023 Pingfeng Luo
#include "sherpa-onnx/csrc/online-transducer-model.h"
#if __ANDROID_API__ >= 9
...
...
@@ -8,6 +9,7 @@
#include "android/asset_manager_jni.h"
#endif
#include <algorithm>
#include <memory>
#include <sstream>
#include <string>
...
...
@@ -75,6 +77,40 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
return
nullptr
;
}
// Build the int64 decoder input tensor of shape
// (results.size(), ContextSize()): row i holds the last ContextSize()
// tokens of results[i].
//
// NOTE(review): every result is expected to hold at least ContextSize()
// tokens; this is not checked here.
Ort::Value OnlineTransducerModel::BuildDecoderInput(
    const std::vector<OnlineTransducerDecoderResult> &results) {
  int32_t batch_size = static_cast<int32_t>(results.size());
  int32_t context_size = ContextSize();
  std::array<int64_t, 2> shape{batch_size, context_size};

  Ort::Value decoder_input = Ort::Value::CreateTensor<int64_t>(
      Allocator(), shape.data(), shape.size());

  int64_t *dst = decoder_input.GetTensorMutableData<int64_t>();
  for (const auto &r : results) {
    const int64_t *tokens_end = r.tokens.data() + r.tokens.size();
    // std::copy returns the position just past the copied row.
    dst = std::copy(tokens_end - context_size, tokens_end, dst);
  }

  return decoder_input;
}
// Build the int64 decoder input tensor of shape
// (hyps.size(), ContextSize()): row i holds the last ContextSize() tokens
// of hyps[i].ys, widened to int64.
//
// NOTE(review): every hypothesis is expected to hold at least
// ContextSize() tokens; this is not checked here.
Ort::Value OnlineTransducerModel::BuildDecoderInput(
    const std::vector<Hypothesis> &hyps) {
  int32_t batch_size = static_cast<int32_t>(hyps.size());
  int32_t context_size = ContextSize();
  std::array<int64_t, 2> shape{batch_size, context_size};

  Ort::Value decoder_input = Ort::Value::CreateTensor<int64_t>(
      Allocator(), shape.data(), shape.size());

  int64_t *dst = decoder_input.GetTensorMutableData<int64_t>();
  for (const auto &h : hyps) {
    const auto tail_end = h.ys.end();
    // std::copy returns the position just past the copied row.
    dst = std::copy(tail_end - context_size, tail_end, dst);
  }

  return decoder_input;
}
#if __ANDROID_API__ >= 9
std
::
unique_ptr
<
OnlineTransducerModel
>
OnlineTransducerModel
::
Create
(
AAssetManager
*
mgr
,
const
OnlineTransducerModelConfig
&
config
)
{
...
...
sherpa-onnx/csrc/online-transducer-model.h
查看文件 @
5326d0f
...
...
@@ -14,6 +14,8 @@
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/hypothesis.h"
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
namespace
sherpa_onnx
{
...
...
@@ -71,9 +73,6 @@ class OnlineTransducerModel {
Ort
::
Value
features
,
std
::
vector
<
Ort
::
Value
>
states
)
=
0
;
// NOLINT
virtual
Ort
::
Value
BuildDecoderInput
(
const
std
::
vector
<
OnlineTransducerDecoderResult
>
&
results
)
=
0
;
/** Run the decoder network.
*
* Caution: We assume there are no recurrent connections in the decoder and
...
...
@@ -125,7 +124,13 @@ class OnlineTransducerModel {
virtual
int32_t
VocabSize
()
const
=
0
;
virtual
int32_t
SubsamplingFactor
()
const
{
return
4
;
}
virtual
OrtAllocator
*
Allocator
()
=
0
;
Ort
::
Value
BuildDecoderInput
(
const
std
::
vector
<
OnlineTransducerDecoderResult
>
&
results
);
Ort
::
Value
BuildDecoderInput
(
const
std
::
vector
<
Hypothesis
>
&
hyps
);
};
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc
0 → 100644
查看文件 @
5326d0f
// sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc
//
// Copyright (c) 2023 Pingfeng Luo
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h"
#include <algorithm>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/onnx-utils.h"
namespace
sherpa_onnx
{
// Duplicate each row of `cur_encoder_out` once per hypothesis of the
// corresponding stream.
//
// @param allocator        Allocator used for the returned tensor.
// @param cur_encoder_out  Float tensor read as one row of
//                         shape[1] values per stream.
// @param hyps_num_split   Cumulative hypothesis counts: stream b owns
//                         hyps_num_split[b + 1] - hyps_num_split[b]
//                         hypotheses, and back() is the total.
// @return                 Float tensor of shape
//                         (hyps_num_split.back(), shape[1]).
static Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out,
                         const std::vector<int32_t> &hyps_num_split) {
  std::vector<int64_t> cur_encoder_out_shape =
      cur_encoder_out->GetTensorTypeAndShapeInfo().GetShape();
  const int64_t row_len = cur_encoder_out_shape[1];

  std::array<int64_t, 2> ans_shape{hyps_num_split.back(), row_len};
  Ort::Value ans = Ort::Value::CreateTensor<float>(allocator, ans_shape.data(),
                                                   ans_shape.size());

  const float *src = cur_encoder_out->GetTensorData<float>();
  float *dst = ans.GetTensorMutableData<float>();

  const int32_t num_streams = static_cast<int32_t>(hyps_num_split.size()) - 1;
  for (int32_t b = 0; b != num_streams; ++b, src += row_len) {
    const int32_t copies = hyps_num_split[b + 1] - hyps_num_split[b];
    for (int32_t i = 0; i != copies; ++i) {
      dst = std::copy(src, src + row_len, dst);
    }
  }

  return ans;
}
// Apply the one-dimensional LogSoftmax in place to each of the `h`
// consecutive rows of `w` floats starting at `in`.
static void LogSoftmax(float *in, int32_t w, int32_t h) {
  float *row = in;
  for (int32_t r = 0; r != h; ++r, row += w) {
    LogSoftmax(row, w);
  }
}
// Create the starting result for a stream: its hypothesis set holds one
// hypothesis made of ContextSize() blank tokens with log_prob 0.
OnlineTransducerDecoderResult
OnlineTransducerModifiedBeamSearchDecoder::GetEmptyResult() const {
  const int32_t blank_id = 0;  // always 0
  const int32_t context_size = model_->ContextSize();

  std::vector<int32_t> blanks(context_size, blank_id);
  Hypotheses blank_hyp({{blanks, 0}});

  OnlineTransducerDecoderResult r;
  r.hyps = std::move(blank_hyp);
  return r;
}
// Fill r->tokens and r->num_trailing_blanks from the best hypothesis in
// r->hyps (chosen with length normalization), dropping the first
// ContextSize() tokens of that hypothesis.
void OnlineTransducerModifiedBeamSearchDecoder::StripLeadingBlanks(
    OnlineTransducerDecoderResult *r) const {
  const int32_t context_size = model_->ContextSize();
  const auto best = r->hyps.GetMostProbable(true);

  const auto first_real_token = best.ys.begin() + context_size;
  r->tokens = std::vector<int64_t>(first_real_token, best.ys.end());
  r->num_trailing_blanks = best.num_trailing_blanks;
}
// Run modified beam search over every frame of `encoder_out`.
//
// @param encoder_out  Encoder output; dimension 0 is the stream (batch)
//                     index and dimension 1 the frame index.
// @param result       One entry per stream. Each entry's `hyps` is consumed
//                     as the starting beam and replaced by the beam obtained
//                     after the last frame. Its size must equal the batch
//                     dimension of `encoder_out`; otherwise an error is
//                     printed to stderr and the process exits.
//
// On every frame the surviving hypotheses of all streams are batched
// through the decoder and joiner, and each stream keeps the best
// `max_active_paths_` (hypothesis, token) expansions, merging expansions
// that produce the same token sequence.
void OnlineTransducerModifiedBeamSearchDecoder::Decode(
    Ort::Value encoder_out,
    std::vector<OnlineTransducerDecoderResult> *result) {
  std::vector<int64_t> encoder_out_shape =
      encoder_out.GetTensorTypeAndShapeInfo().GetShape();

  if (encoder_out_shape[0] != static_cast<int64_t>(result->size())) {
    fprintf(stderr,
            "Size mismatch! encoder_out.size(0) %d, result.size(0): %d\n",
            static_cast<int32_t>(encoder_out_shape[0]),
            static_cast<int32_t>(result->size()));
    exit(-1);
  }

  int32_t batch_size = static_cast<int32_t>(encoder_out_shape[0]);
  int32_t num_frames = static_cast<int32_t>(encoder_out_shape[1]);
  int32_t vocab_size = model_->VocabSize();

  // cur[b] is the beam of stream b.
  std::vector<Hypotheses> cur;
  for (auto &r : *result) {
    cur.push_back(std::move(r.hyps));
  }

  // Hypotheses of all streams at the start of the current frame, flattened.
  std::vector<Hypothesis> prev;

  // Scratch buffer reused for every stream and frame: accumulated path
  // score + token log-prob of every candidate expansion.
  std::vector<float> total_scores;

  for (int32_t t = 0; t != num_frames; ++t) {
    // Due to merging paths with identical token sequences,
    // not all utterances have "num_active_paths" paths.
    // hyps_num_split[b] .. hyps_num_split[b + 1] is the range of stream b
    // inside `prev`.
    int32_t hyps_num_acc = 0;
    std::vector<int32_t> hyps_num_split;
    hyps_num_split.push_back(0);

    prev.clear();
    for (auto &hyps : cur) {
      for (auto &h : hyps) {
        prev.push_back(std::move(h.second));
        hyps_num_acc++;
      }
      hyps_num_split.push_back(hyps_num_acc);
    }
    cur.clear();
    cur.reserve(batch_size);

    Ort::Value decoder_input = model_->BuildDecoderInput(prev);
    Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));

    // One encoder row per hypothesis, so it lines up with decoder_out.
    Ort::Value cur_encoder_out =
        GetEncoderOutFrame(model_->Allocator(), &encoder_out, t);
    cur_encoder_out =
        Repeat(model_->Allocator(), &cur_encoder_out, hyps_num_split);

    Ort::Value logit = model_->RunJoiner(
        std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out));
    // Read below as one row of vocab_size values per hypothesis.
    float *p_logit = logit.GetTensorMutableData<float>();

    for (int32_t b = 0; b < batch_size; ++b) {
      int32_t start = hyps_num_split[b];
      int32_t end = hyps_num_split[b + 1];
      const int32_t num_hyps = end - start;
      const int32_t num_candidates = vocab_size * num_hyps;

      LogSoftmax(p_logit, vocab_size, num_hyps);

      // Candidates must be ranked by the score of the whole path, i.e. the
      // accumulated log_prob of the source hypothesis plus the log-prob of
      // the new token. Ranking on the token log-prob alone would let a weak
      // path with one confident token push out a strong path.
      total_scores.resize(num_candidates);
      for (int32_t j = 0; j != num_hyps; ++j) {
        const float path_score = static_cast<float>(prev[start + j].log_prob);
        const float *token_log_prob = p_logit + j * vocab_size;
        float *score = total_scores.data() + j * vocab_size;
        for (int32_t v = 0; v != vocab_size; ++v) {
          score[v] = token_log_prob[v] + path_score;
        }
      }

      auto topk =
          TopkIndex(total_scores.data(), num_candidates, max_active_paths_);

      Hypotheses hyps;
      for (auto i : topk) {
        // i indexes the flattened (hypothesis, token) grid of this stream.
        int32_t hyp_index = i / vocab_size + start;
        int32_t new_token = i % vocab_size;

        Hypothesis new_hyp = prev[hyp_index];
        if (new_token != 0) {
          // Non-blank: extend the token sequence.
          new_hyp.ys.push_back(new_token);
          new_hyp.num_trailing_blanks = 0;
        } else {
          // Blank: the token sequence is unchanged.
          ++new_hyp.num_trailing_blanks;
        }
        // Accumulate in double using the unshifted token log-prob.
        new_hyp.log_prob += p_logit[i];
        hyps.Add(std::move(new_hyp));
      }
      cur.push_back(std::move(hyps));
      p_logit += num_candidates;
    }
  }

  for (int32_t b = 0; b != batch_size; ++b) {
    (*result)[b].hyps = std::move(cur[b]);
  }
}
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h
0 → 100644
查看文件 @
5326d0f
// sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h
//
// Copyright (c) 2023 Pingfeng Luo
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_H_
#define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_H_
#include <vector>
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
#include "sherpa-onnx/csrc/online-transducer-model.h"
namespace
sherpa_onnx
{
class
OnlineTransducerModifiedBeamSearchDecoder
:
public
OnlineTransducerDecoder
{
public
:
OnlineTransducerModifiedBeamSearchDecoder
(
OnlineTransducerModel
*
model
,
int32_t
max_active_paths
)
:
model_
(
model
),
max_active_paths_
(
max_active_paths
)
{}
OnlineTransducerDecoderResult
GetEmptyResult
()
const
override
;
void
StripLeadingBlanks
(
OnlineTransducerDecoderResult
*
r
)
const
override
;
void
Decode
(
Ort
::
Value
encoder_out
,
std
::
vector
<
OnlineTransducerDecoderResult
>
*
result
)
override
;
private
:
OnlineTransducerModel
*
model_
;
// Not owned
int32_t
max_active_paths_
;
};
}
// namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_H_
...
...
sherpa-onnx/csrc/online-zipformer-transducer-model.cc
查看文件 @
5326d0f
...
...
@@ -461,24 +461,6 @@ OnlineZipformerTransducerModel::RunEncoder(Ort::Value features,
return
{
std
::
move
(
encoder_out
[
0
]),
std
::
move
(
next_states
)};
}
// Build the int64 decoder input tensor of shape
// (results.size(), context_size_): row i holds the last context_size_
// tokens of results[i].
//
// NOTE(review): every result is expected to hold at least context_size_
// tokens; this is not checked here.
Ort::Value OnlineZipformerTransducerModel::BuildDecoderInput(
    const std::vector<OnlineTransducerDecoderResult> &results) {
  int32_t batch_size = static_cast<int32_t>(results.size());
  std::array<int64_t, 2> shape{batch_size, context_size_};

  Ort::Value decoder_input = Ort::Value::CreateTensor<int64_t>(
      allocator_, shape.data(), shape.size());

  int64_t *dst = decoder_input.GetTensorMutableData<int64_t>();
  for (const auto &r : results) {
    const int64_t *tokens_end = r.tokens.data() + r.tokens.size();
    // std::copy returns the position just past the copied row.
    dst = std::copy(tokens_end - context_size_, tokens_end, dst);
  }

  return decoder_input;
}
Ort
::
Value
OnlineZipformerTransducerModel
::
RunDecoder
(
Ort
::
Value
decoder_input
)
{
auto
decoder_out
=
decoder_sess_
->
Run
(
...
...
sherpa-onnx/csrc/online-zipformer-transducer-model.h
查看文件 @
5326d0f
...
...
@@ -41,9 +41,6 @@ class OnlineZipformerTransducerModel : public OnlineTransducerModel {
std
::
pair
<
Ort
::
Value
,
std
::
vector
<
Ort
::
Value
>>
RunEncoder
(
Ort
::
Value
features
,
std
::
vector
<
Ort
::
Value
>
states
)
override
;
Ort
::
Value
BuildDecoderInput
(
const
std
::
vector
<
OnlineTransducerDecoderResult
>
&
results
)
override
;
Ort
::
Value
RunDecoder
(
Ort
::
Value
decoder_input
)
override
;
Ort
::
Value
RunJoiner
(
Ort
::
Value
encoder_out
,
Ort
::
Value
decoder_out
)
override
;
...
...
sherpa-onnx/csrc/onnx-utils.cc
查看文件 @
5326d0f
...
...
@@ -44,6 +44,38 @@ void GetOutputNames(Ort::Session *sess, std::vector<std::string> *output_names,
}
}
// Extract frame `t` of every batch item from `encoder_out`.
//
// @param allocator    Allocator used for the returned tensor.
// @param encoder_out  Float tensor indexed as (batch, frame, dim).
// @param t            Frame index; must satisfy 0 <= t < number of frames.
// @return             A new float tensor of shape (batch, dim).
Ort::Value GetEncoderOutFrame(OrtAllocator *allocator, Ort::Value *encoder_out,
                              int32_t t) {
  std::vector<int64_t> encoder_out_shape =
      encoder_out->GetTensorTypeAndShapeInfo().GetShape();

  const int64_t batch_size = encoder_out_shape[0];
  const int64_t num_frames = encoder_out_shape[1];
  const int64_t encoder_out_dim = encoder_out_shape[2];
  assert(t >= 0 && t < num_frames);

  // Distance between the same frame of two consecutive batch items.
  const int64_t stride = num_frames * encoder_out_dim;

  // The output is allocated through `allocator`; the unused CPU
  // MemoryInfo object that used to be created here has been removed.
  std::array<int64_t, 2> shape{batch_size, encoder_out_dim};
  Ort::Value ans =
      Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());

  float *dst = ans.GetTensorMutableData<float>();
  const float *src =
      encoder_out->GetTensorData<float>() + t * encoder_out_dim;

  // The counter has the same type as batch_size to avoid a mixed
  // int32/int64 comparison.
  for (int64_t i = 0; i != batch_size; ++i) {
    dst = std::copy(src, src + encoder_out_dim, dst);
    src += stride;
  }

  return ans;
}
void
PrintModelMetadata
(
std
::
ostream
&
os
,
const
Ort
::
ModelMetadata
&
meta_data
)
{
Ort
::
AllocatorWithDefaultOptions
allocator
;
std
::
vector
<
Ort
::
AllocatedStringPtr
>
v
=
...
...
sherpa-onnx/csrc/onnx-utils.h
查看文件 @
5326d0f
...
...
@@ -10,6 +10,7 @@
#include <locale>
#endif
#include <cassert>
#include <ostream>
#include <string>
#include <vector>
...
...
@@ -57,6 +58,17 @@ void GetInputNames(Ort::Session *sess, std::vector<std::string> *input_names,
void
GetOutputNames
(
Ort
::
Session
*
sess
,
std
::
vector
<
std
::
string
>
*
output_names
,
std
::
vector
<
const
char
*>
*
output_names_ptr
);
/**
* Get the output frame of Encoder
*
* @param allocator allocator of onnxruntime
* @param encoder_out encoder out tensor
* @param t frame_index
*
*/
Ort
::
Value
GetEncoderOutFrame
(
OrtAllocator
*
allocator
,
Ort
::
Value
*
encoder_out
,
int32_t
t
);
void
PrintModelMetadata
(
std
::
ostream
&
os
,
const
Ort
::
ModelMetadata
&
meta_data
);
// NOLINT
...
...
请
注册
或
登录
后发表评论