packed-sequence-test.cc
1.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
// sherpa-onnx/csrc/packed-sequence-test.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/packed-sequence.h"

#include <array>
#include <cstdint>
#include <cstdio>
#include <numeric>

#include "gtest/gtest.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
namespace sherpa_onnx {
// Packs a padded batch of five sequences whose lengths are {1, 2, 3, 5, 2}
// and verifies the bookkeeping returned by PackPaddedSequence():
//  - sorted_indexes is a permutation of the batch indexes that visits the
//    sequences in order of non-increasing length;
//  - batch_sizes holds one entry per time step, equal to the number of
//    sequences that are longer than that time step.
// The input and the packed result are also printed to ease debugging.
TEST(PackedSequence, Case1) {
  constexpr int32_t kBatchSize = 5;

  // Reference copy of the lengths. The checks below read from this array so
  // that they do not depend on whether PackPaddedSequence() touches the
  // length tensor it receives through a non-const pointer.
  const std::array<int64_t, kBatchSize> lengths{1, 2, 3, 5, 2};

  int64_t max_length = 0;
  for (int64_t len : lengths) {
    if (len > max_length) {
      max_length = len;
    }
  }

  Ort::AllocatorWithDefaultOptions allocator;

  // Axis 0 is the batch axis (one length per entry); axes 1 and 2 are
  // presumably (time, feature) -- confirm against packed-sequence.h.
  std::array<int64_t, 3> shape{kBatchSize, 5, 4};
  Ort::Value v =
      Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());
  float *p = v.GetTensorMutableData<float>();
  // Fill with 0, 1, 2, ... so that every element is identifiable in the dump.
  std::iota(p, p + shape[0] * shape[1] * shape[2], 0);

  // 1-D tensor with one length per batch entry.
  std::array<int64_t, 1> length_shape{kBatchSize};
  Ort::Value length = Ort::Value::CreateTensor<int64_t>(
      allocator, length_shape.data(), length_shape.size());
  int64_t *p_length = length.GetTensorMutableData<int64_t>();
  for (int64_t len : lengths) {
    *p_length++ = len;
  }

  auto packed_seq = PackPaddedSequence(allocator, &v, &length);

  fprintf(stderr, "sorted indexes: ");
  for (auto i : packed_seq.sorted_indexes) {
    fprintf(stderr, "%d ", static_cast<int32_t>(i));
  }
  fprintf(stderr, "\n");

  // One valid outcome (the relative order of the two length-2 sequences,
  // indexes 1 and 4, is not checked below):
  //   output index:   0 1 2 3 4
  //   sorted indexes: 3 2 1 4 0
  //   length:         5 3 2 2 1
  Print3D(&v);
  Print2D(&packed_seq.data);

  fprintf(stderr, "batch sizes per time step: ");
  for (auto i : packed_seq.batch_sizes) {
    fprintf(stderr, "%d ", static_cast<int32_t>(i));
  }
  fprintf(stderr, "\n");

  // sorted_indexes must be a permutation of [0, kBatchSize) whose lengths
  // never increase.
  std::array<bool, kBatchSize> seen{};
  int32_t num_indexes = 0;
  int64_t prev_length = max_length;
  for (auto i : packed_seq.sorted_indexes) {
    const auto idx = static_cast<int32_t>(i);
    ASSERT_GE(idx, 0);
    ASSERT_LT(idx, kBatchSize);
    EXPECT_FALSE(seen[idx]) << "duplicated index " << idx;
    seen[idx] = true;

    EXPECT_LE(lengths[idx], prev_length)
        << "lengths are not sorted in descending order at index " << idx;
    prev_length = lengths[idx];
    ++num_indexes;
  }
  EXPECT_EQ(num_indexes, kBatchSize);

  // batch_sizes[t] must equal the number of sequences longer than t, i.e.
  // 5 4 2 1 1 for the lengths used here.
  int64_t t = 0;
  for (auto b : packed_seq.batch_sizes) {
    int32_t expected = 0;
    for (int64_t len : lengths) {
      if (len > t) {
        ++expected;
      }
    }
    EXPECT_EQ(static_cast<int32_t>(b), expected) << "time step " << t;
    ++t;
  }
  EXPECT_EQ(t, max_length);
}
} // namespace sherpa_onnx