math.h
3.0 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
/**
* Copyright (c) 2022 Xiaomi Corporation (authors: Daniel Povey)
* Copyright (c) 2023 (Pingfeng Luo)
*
*/
// This file is copied from k2/csrc/utils.h
#ifndef SHERPA_ONNX_CSRC_MATH_H_
#define SHERPA_ONNX_CSRC_MATH_H_
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <numeric>
#include <vector>
namespace sherpa_onnx {
// Thresholds used by LogAdd: when the (non-positive) difference between the
// smaller and the larger operand is below the threshold, exp(diff) is smaller
// than the machine epsilon of the type and the correction term is skipped.

// logf(FLT_EPSILON)
#define SHERPA_ONNX_MIN_LOG_DIFF_FLOAT -15.9423847198486328125f

// log(DBL_EPSILON)
#define SHERPA_ONNX_MIN_LOG_DIFF_DOUBLE \
  -36.0436533891171535515240975655615329742431640625

/**
 * Functor computing log(exp(x) + exp(y)) without leaving the log domain.
 *
 * Only the `float` and `double` specializations below are defined.
 */
template <typename T>
struct LogAdd;

/// Double-precision log-add.
template <>
struct LogAdd<double> {
  /**
   * @param x First operand, in the log domain.
   * @param y Second operand, in the log domain.
   * @return max(x, y) + log1p(exp(-|x - y|)), or just the larger operand when
   *         the difference is below SHERPA_ONNX_MIN_LOG_DIFF_DOUBLE.
   */
  double operator()(double x, double y) const {
    double diff;
    if (x < y) {
      diff = x - y;
      x = y;
    } else {
      diff = y - x;
    }
    // diff is non-positive and x now holds the larger operand.

    if (diff >= SHERPA_ONNX_MIN_LOG_DIFF_DOUBLE) {
      return x + log1p(exp(diff));
    }

    // Also reached when diff is NaN (e.g. both operands are -infinity), since
    // the comparison above is then false.
    return x;
  }
};

/// Single-precision log-add.
template <>
struct LogAdd<float> {
  /**
   * @param x First operand, in the log domain.
   * @param y Second operand, in the log domain.
   * @return max(x, y) + log1pf(expf(-|x - y|)), or just the larger operand
   *         when the difference is below SHERPA_ONNX_MIN_LOG_DIFF_FLOAT.
   */
  float operator()(float x, float y) const {
    float diff;
    if (x < y) {
      diff = x - y;
      x = y;
    } else {
      diff = y - x;
    }
    // diff is non-positive and x now holds the larger operand.

    // Compare against the float threshold (not the double one): below it,
    // expf(diff) is smaller than FLT_EPSILON.
    if (diff >= SHERPA_ONNX_MIN_LOG_DIFF_FLOAT) {
      return x + log1pf(expf(diff));
    }

    // Also reached when diff is NaN (e.g. both operands are -infinity), since
    // the comparison above is then false.
    return x;
  }
};
/**
 * Applies log-softmax in place to a contiguous array.
 *
 * Every element is replaced by `input[i] - (m + log(sum_j exp(input[j] - m)))`
 * where `m` is the maximum element; subtracting the maximum first keeps the
 * exponentials from overflowing.
 *
 * @param input      Pointer to the first element; must not be null.
 * @param input_len  Number of elements. If it is not positive the function
 *                   returns without touching `input`.
 */
template <class T>
void LogSoftmax(T *input, int32_t input_len) {
  assert(input);

  // An empty range has no maximum; dereferencing max_element's result would
  // be undefined behaviour.
  if (input_len <= 0) {
    return;
  }

  // Select the overload matching T (e.g. the float one for T = float) instead
  // of whatever the global C functions happen to be.
  using std::exp;
  using std::log;

  T *const end = input + input_len;
  const T m = *std::max_element(input, end);

  T sum = 0.0;
  for (const T *p = input; p != end; ++p) {
    sum += exp(*p - m);
  }

  const T offset = m + log(sum);
  for (T *p = input; p != end; ++p) {
    *p -= offset;
  }
}
/**
 * Applies the one-dimensional LogSoftmax in place to each row of a row-major
 * buffer.
 *
 * @param in  Pointer to the first element of the first row.
 * @param w   Number of elements per row.
 * @param h   Number of consecutive rows to process.
 */
template <typename T>
void LogSoftmax(T *in, int32_t w, int32_t h) {
  T *row = in;
  for (int32_t r = 0; r != h; ++r, row += w) {
    LogSoftmax(row, w);
  }
}
/**
 * Subtracts a penalty, in place, from one column of a row-major buffer.
 *
 * For each of the `h` consecutive rows of `w` elements, the element at
 * position `blank_idx` within the row is decreased by `blank_penalty`. All
 * other elements are left untouched.
 *
 * @param in             Pointer to the first element of the first row.
 * @param w              Number of elements per row.
 * @param h              Number of rows.
 * @param blank_idx      Index, within a row, of the element to modify.
 * @param blank_penalty  Value subtracted from that element.
 */
template <typename T>
void SubtractBlank(T *in, int32_t w, int32_t h, int32_t blank_idx,
                   float blank_penalty) {
  int32_t row = 0;
  while (row != h) {
    in[blank_idx] -= blank_penalty;
    in += w;  // step to the start of the next row
    ++row;
  }
}
/**
 * Returns the indices of the largest elements of an array, ordered from the
 * largest value to the smallest.
 *
 * @param vec   Pointer to `size` elements comparable with `operator>`.
 * @param size  Number of elements in `vec`.
 * @param topk  Maximum number of indices to return.
 * @return `min(size, topk)` indices into `vec`, sorted by descending value.
 *         Empty if `size` or `topk` is not positive. The relative order of
 *         indices whose values compare equal is unspecified.
 */
template <class T>
std::vector<int32_t> TopkIndex(const T *vec, int32_t size, int32_t topk) {
  if (size <= 0 || topk <= 0) {
    return {};
  }

  // Clamp before sorting: partial_sort requires its middle iterator to lie
  // inside [begin, end].
  const int32_t k_num = std::min<int32_t>(size, topk);

  std::vector<int32_t> vec_index(size);
  std::iota(vec_index.begin(), vec_index.end(), 0);

  std::partial_sort(vec_index.begin(), vec_index.begin() + k_num,
                    vec_index.end(), [vec](int32_t index_1, int32_t index_2) {
                      return vec[index_1] > vec[index_2];
                    });

  return {vec_index.begin(), vec_index.begin() + k_num};
}
/**
 * Returns the indices of the largest elements of a vector of rows, where the
 * indices refer to the rows concatenated in order into one flat array.
 *
 * @param vec   Rows of values; rows may be empty or of different lengths.
 * @param topk  Maximum number of indices to return.
 * @return Up to `topk` flat indices, produced by the pointer overload of
 *         TopkIndex on the concatenated data. Empty if `vec` holds no
 *         elements or `topk` is not positive.
 */
template <class T>
std::vector<int32_t> TopkIndex(const std::vector<std::vector<T>> &vec,
                               int32_t topk) {
  // Sum the real row sizes instead of assuming every row matches vec[0];
  // this also avoids indexing into an empty outer vector.
  typename std::vector<T>::size_type total = 0;
  for (const auto &v : vec) {
    total += v.size();
  }

  if (total == 0 || topk <= 0) {
    return {};
  }

  std::vector<T> flatten;
  flatten.reserve(total);
  for (const auto &v : vec) {
    flatten.insert(flatten.end(), v.begin(), v.end());
  }

  // Never request more indices than there are elements.
  const int32_t size = static_cast<int32_t>(flatten.size());
  return TopkIndex(flatten.data(), size, std::min<int32_t>(topk, size));
}
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_MATH_H_