Add PackPaddedSequence (#85)

Author: Fangjun Kuang
Committer: GitHub

sherpa-onnx/csrc/CMakeLists.txt

@@ -16,6 +16,7 @@ set(sources
   online-transducer-modified-beam-search-decoder.cc
   online-zipformer-transducer-model.cc
   onnx-utils.cc
+  packed-sequence.cc
   pad-sequence.cc
   parse-options.cc
   resample.cc
@@ -123,6 +124,7 @@ endif()
 if(SHERPA_ONNX_ENABLE_TESTS)
   set(sherpa_onnx_test_srcs
     cat-test.cc
+    packed-sequence-test.cc
     pad-sequence-test.cc
     slice-test.cc
     transpose-test.cc

// sherpa-onnx/csrc/packed-sequence-test.cc
//
// Copyright (c) 2023 Xiaomi Corporation

#include "sherpa-onnx/csrc/packed-sequence.h"

#include <array>
#include <numeric>

#include "gtest/gtest.h"
#include "sherpa-onnx/csrc/onnx-utils.h"

namespace sherpa_onnx {

TEST(PackedSequence, Case1) {
  Ort::AllocatorWithDefaultOptions allocator;
  std::array<int64_t, 3> shape{5, 5, 4};
  Ort::Value v =
      Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());
  float *p = v.GetTensorMutableData<float>();

  std::iota(p, p + shape[0] * shape[1] * shape[2], 0);

  // length is a 1-D tensor of shape (shape[0],) = (5,), one entry per batch
  // item
  Ort::Value length =
      Ort::Value::CreateTensor<int64_t>(allocator, shape.data(), 1);
  int64_t *p_length = length.GetTensorMutableData<int64_t>();
  p_length[0] = 1;
  p_length[1] = 2;
  p_length[2] = 3;
  p_length[3] = 5;
  p_length[4] = 2;

  auto packed_seq = PackPaddedSequence(allocator, &v, &length);
  fprintf(stderr, "sorted indexes: ");
  for (auto i : packed_seq.sorted_indexes) {
    fprintf(stderr, "%d ", static_cast<int32_t>(i));
  }
  fprintf(stderr, "\n");
  // position in sorted output:  0 1 2 3 4
  // sorted indexes:             3 2 1 4 0
  // length (sorted):            5 3 2 2 1
  Print3D(&v);
  Print2D(&packed_seq.data);
  fprintf(stderr, "batch sizes per time step: ");
  for (auto i : packed_seq.batch_sizes) {
    fprintf(stderr, "%d ", static_cast<int32_t>(i));
  }
  fprintf(stderr, "\n");

  // TODO(fangjun): Check that the return value is correct
}

}  // namespace sherpa_onnx
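
The TODO above leaves the actual checks unwritten. Below is a minimal sketch of what they could look like, reusing packed_seq from the test body and assuming the semantics documented in packed-sequence.h. The expected values are worked out by hand from the lengths {1, 2, 3, 5, 2}; note that std::sort is not stable, so the two length-2 items (1 and 4) may legally swap places.

  // Sketch only, not part of this commit: hand-derived expectations.
  // Lengths sorted in descending order are 5 3 2 2 1, and
  // batch_sizes[t] = number of sequences whose length is > t.
  std::vector<int32_t> expected_batch_sizes = {5, 4, 2, 1, 1};
  EXPECT_EQ(packed_seq.batch_sizes, expected_batch_sizes);

  // data keeps sum(length) = 13 frames of 4 features each
  std::vector<int64_t> expected_shape = {13, 4};
  EXPECT_EQ(packed_seq.data.GetTensorTypeAndShapeInfo().GetShape(),
            expected_shape);

  // Positions 0, 1 and 4 of sorted_indexes are unambiguous; positions 2 and 3
  // hold items 1 and 4 (both of length 2) in an unspecified order.
  EXPECT_EQ(packed_seq.sorted_indexes[0], 3);
  EXPECT_EQ(packed_seq.sorted_indexes[1], 2);
  EXPECT_EQ(packed_seq.sorted_indexes[4], 0);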

// sherpa-onnx/csrc/packed-sequence.cc
//
// Copyright (c) 2023 Xiaomi Corporation

#include "sherpa-onnx/csrc/packed-sequence.h"

#include <assert.h>

#include <algorithm>
#include <array>
#include <numeric>
#include <utility>
#include <vector>

#include "sherpa-onnx/csrc/slice.h"
#include "sherpa-onnx/csrc/transpose.h"

namespace sherpa_onnx {

// Return value[sorted_indexes], i.e., reorder the batch dimension (dim 0)
// of value according to sorted_indexes.
static Ort::Value IndexSelect(OrtAllocator *allocator, const Ort::Value *value,
                              const std::vector<int32_t> &sorted_indexes) {
  auto shape = value->GetTensorTypeAndShapeInfo().GetShape();
  assert(shape.size() == 3);
  std::array<int64_t, 3> ans_shape{static_cast<int64_t>(sorted_indexes.size()),
                                   shape[1], shape[2]};

  Ort::Value ans = Ort::Value::CreateTensor<float>(allocator, ans_shape.data(),
                                                   ans_shape.size());
  float *dst = ans.GetTensorMutableData<float>();
  const float *src = value->GetTensorData<float>();

  for (auto i : sorted_indexes) {
    const float *start = src + i * shape[1] * shape[2];
    std::copy(start, start + shape[1] * shape[2], dst);
    dst += shape[1] * shape[2];
  }
  return ans;
}

PackedSequence PackPaddedSequence(OrtAllocator *allocator,
                                  const Ort::Value *value, Ort::Value *length) {
  std::vector<int64_t> v_shape = value->GetTensorTypeAndShapeInfo().GetShape();
  std::vector<int64_t> l_shape = length->GetTensorTypeAndShapeInfo().GetShape();

  assert(v_shape.size() == 3);
  assert(l_shape.size() == 1);
  assert(v_shape[0] == l_shape[0]);

  std::vector<int32_t> indexes(v_shape[0]);
  std::iota(indexes.begin(), indexes.end(), 0);

  const int64_t *p_length = length->GetTensorData<int64_t>();
  // sort indexes so that lengths are in descending order
  std::sort(indexes.begin(), indexes.end(), [p_length](int32_t i, int32_t j) {
    return p_length[i] > p_length[j];
  });

  int32_t n = static_cast<int32_t>(v_shape[0]);

  int64_t max_T = p_length[indexes[0]];

  int32_t sum_T = std::accumulate(p_length, p_length + n, 0);

  std::array<int64_t, 2> data_shape{sum_T, v_shape[2]};

  Ort::Value data = Ort::Value::CreateTensor<float>(
      allocator, data_shape.data(), data_shape.size());
  float *dst = data.GetTensorMutableData<float>();

  // Reorder the batch by descending length and make it time major:
  // (B, T, C) -> (T, B, C)
  Ort::Value tensor = IndexSelect(allocator, value, indexes);
  tensor = Transpose01(allocator, &tensor);

  // batch size at each time step
  std::vector<int32_t> batch_sizes;
  batch_sizes.reserve(max_T);

  // Iterate over the distinct lengths in ascending order. For the time steps
  // in [prev_l, cur_l) exactly (n - i) sequences are still active, so copy
  // that (time, batch) block into the packed buffer.
  int64_t prev_l = 0;
  for (int32_t i = 0; i != n; ++i) {
    auto cur_l = p_length[indexes[n - 1 - i]];
    assert(cur_l >= prev_l);
    if (cur_l == prev_l) {
      continue;
    }

    auto cur_batch_size = n - i;

    Ort::Value cur_batch =
        Slice(allocator, &tensor, prev_l, cur_l, 0, cur_batch_size);
    auto count = cur_batch.GetTensorTypeAndShapeInfo().GetElementCount();
    const float *src = cur_batch.GetTensorData<float>();
    std::copy(src, src + count, dst);
    dst += count;

    for (int32_t j = prev_l; j < cur_l; ++j) {
      batch_sizes.push_back(cur_batch_size);
    }

    prev_l = cur_l;
  }

  PackedSequence packed_seq;
  packed_seq.sorted_indexes = std::move(indexes);
  packed_seq.data = std::move(data);
  packed_seq.batch_sizes = std::move(batch_sizes);

  return packed_seq;
}

}  // namespace sherpa_onnx
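
The copy loop above lays the packed buffer out time major: first every frame-0 row of the sorted batch, then every frame-1 row of the sequences that are still active at step 1, and so on. As an illustration of that layout (a sketch under that assumption; PackedRow is a hypothetical helper, not part of this commit), the row holding frame t of the b-th sequence in sorted order can be located from batch_sizes alone:

  #include <cassert>
  #include <cstdint>
  #include <numeric>
  #include <vector>

  // Hypothetical helper, for illustration only: given the batch_sizes produced
  // by PackPaddedSequence, return the row index in packed_seq.data that holds
  // frame t of the b-th sequence in sorted (descending-length) order.
  static int32_t PackedRow(const std::vector<int32_t> &batch_sizes, int32_t t,
                           int32_t b) {
    assert(b < batch_sizes[t]);  // sequence b must still be active at frame t
    return std::accumulate(batch_sizes.begin(), batch_sizes.begin() + t, 0) + b;
  }

  // With the lengths used in packed-sequence-test.cc, batch_sizes is
  // {5, 4, 2, 1, 1}, so e.g. PackedRow(batch_sizes, 1, 0) == 5: frame 1 of the
  // longest sequence sits right after the five frame-0 rows.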

// sherpa-onnx/csrc/packed-sequence.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_PACKED_SEQUENCE_H_
#define SHERPA_ONNX_CSRC_PACKED_SEQUENCE_H_

#include <cstdint>
#include <vector>

#include "onnxruntime_cxx_api.h"  // NOLINT

namespace sherpa_onnx {

struct PackedSequence {
  // Indexes that sort the batch by length in descending order, i.e.,
  // entry i of the packed batch comes from value[sorted_indexes[i]].
  std::vector<int32_t> sorted_indexes;

  // batch_sizes[t] is the number of sequences that are still active
  // (not yet padded) at time step t.
  std::vector<int32_t> batch_sizes;

  // A 2-D tensor of shape (sum(length), C) containing the valid frames of
  // all sequences, ordered by time step and then by (sorted) batch index.
  Ort::Value data{nullptr};
};

/** Similar to torch.nn.utils.rnn.pack_padded_sequence, but it supports only
 * batch_first=true.
 *
 * @param allocator Allocator used to create the returned tensor.
 * @param value A 3-D tensor of shape (B, T, C). Its dtype is float.
 * @param length A 1-D tensor of shape (B,). Its dtype is int64_t. Each
 *               element in it specifies the valid length of the corresponding
 *               entry in value before padding.
 *
 * @return Return a PackedSequence; see the comments in the struct above for
 *         the meaning of its members.
 */
PackedSequence PackPaddedSequence(OrtAllocator *allocator,
                                  const Ort::Value *value, Ort::Value *length);

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_PACKED_SEQUENCE_H_
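
For reference, a minimal usage sketch of the declaration above; features and features_length are illustrative names for tensors the caller already owns, not identifiers from this commit:

  // Sketch only: assumes "features" is a float Ort::Value of shape (B, T, C)
  // and "features_length" is an int64_t Ort::Value of shape (B,).
  Ort::AllocatorWithDefaultOptions allocator;

  PackedSequence packed =
      PackPaddedSequence(allocator, &features, &features_length);

  // packed.data has shape (sum(features_length), C); packed.batch_sizes[t] is
  // the number of sequences still active at time step t; packed.sorted_indexes
  // maps rows of the packed batch back to rows of "features".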

sherpa-onnx/csrc/slice-test.cc

@@ -13,19 +13,19 @@ namespace sherpa_onnx {
 
 TEST(Slice, Slice3D) {
   Ort::AllocatorWithDefaultOptions allocator;
-  std::array<int64_t, 3> shape{3, 5, 4};
+  std::array<int64_t, 3> shape{5, 5, 4};
   Ort::Value v =
       Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());
   float *p = v.GetTensorMutableData<float>();
 
   std::iota(p, p + shape[0] * shape[1] * shape[2], 0);
 
-  auto v1 = Slice(&v, 0, 2, 5);
-  auto v2 = Slice(&v, 1, 2, 4);
+  auto v1 = Slice(allocator, &v, 2, 4, 0, 2);
+  auto v2 = Slice(allocator, &v, 1, 3, 1, 3);
 
   Print3D(&v);
-  Print2D(&v1);
-  Print2D(&v2);
+  Print3D(&v1);
+  Print3D(&v2);
 
   // TODO(fangjun): Check that the results are correct
 }
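
Since v is filled with iota, the expected contents of v1 and v2 can be worked out by hand. A sketch of the checks the TODO above leaves open, reusing v1 and v2 from the test body (not part of this commit):

  // v holds 0, 1, ..., 99 with shape (5, 5, 4), so element (i, j, k) has the
  // value i * 20 + j * 4 + k.
  //
  //   v1 = v[2:4, 0:2, :]  -> values 40..47 and 60..67
  //   v2 = v[1:3, 1:3, :]  -> values 24..31 and 44..51
  const float *p1 = v1.GetTensorData<float>();
  EXPECT_EQ(p1[0], 40);
  EXPECT_EQ(p1[7], 47);
  EXPECT_EQ(p1[8], 60);
  EXPECT_EQ(p1[15], 67);

  const float *p2 = v2.GetTensorData<float>();
  EXPECT_EQ(p2[0], 24);
  EXPECT_EQ(p2[7], 31);
  EXPECT_EQ(p2[8], 44);
  EXPECT_EQ(p2[15], 51);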

sherpa-onnx/csrc/slice.cc

@@ -6,29 +6,48 @@
 
 #include <assert.h>
 
+#include <algorithm>
+#include <array>
 #include <vector>
 
 namespace sherpa_onnx {
 
 template <typename T /*=float*/>
-Ort::Value Slice(const Ort::Value *v, int32_t dim0, int32_t dim1_start,
+Ort::Value Slice(OrtAllocator *allocator, const Ort::Value *v,
+                 int32_t dim0_start, int32_t dim0_end, int32_t dim1_start,
                  int32_t dim1_end) {
   std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
   assert(shape.size() == 3);
 
-  auto memory_info =
-      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+  assert(0 <= dim0_start);
+  assert(dim0_start < dim0_end);
+  assert(dim0_end <= shape[0]);
+
+  assert(0 <= dim1_start);
+  assert(dim1_start < dim1_end);
+  assert(dim1_end <= shape[1]);
 
-  std::array<int64_t, 2> ans_shape{dim1_end - dim1_start, shape[2]};
   const T *src = v->GetTensorData<T>();
-  src += dim0 * shape[1] * shape[2] + dim1_start * shape[2];
 
-  return Ort::Value::CreateTensor(memory_info, const_cast<T *>(src),
-                                  ans_shape[0] * ans_shape[1], ans_shape.data(),
+  std::array<int64_t, 3> ans_shape{dim0_end - dim0_start, dim1_end - dim1_start,
+                                   shape[2]};
+
+  Ort::Value ans = Ort::Value::CreateTensor<T>(allocator, ans_shape.data(),
                                   ans_shape.size());
+  T *dst = ans.GetTensorMutableData<T>();
+  for (int32_t i = dim0_start; i != dim0_end; ++i) {
+    const T *row = src + i * shape[1] * shape[2];
+    const T *start = row + dim1_start * shape[2];
+    const T *end = row + dim1_end * shape[2];
+
+    std::copy(start, end, dst);
+    dst += ans_shape[1] * ans_shape[2];
+  }
+
+  return ans;
 }
 
-template Ort::Value Slice<float>(const Ort::Value *v, int32_t dim0,
+template Ort::Value Slice<float>(OrtAllocator *allocator, const Ort::Value *v,
+                                 int32_t dim0_start, int32_t dim0_end,
                                  int32_t dim1_start, int32_t dim1_end);
 
 }  // namespace sherpa_onnx

sherpa-onnx/csrc/slice.h

@@ -8,21 +8,23 @@
 
 namespace sherpa_onnx {
 
-/** Get a shallow copy by slicing v.
+/** Get a deep copy by slicing v.
  *
- * It returns v[dim0, dim1_start:dim1_end]
+ * It returns v[dim0_start:dim0_end, dim1_start:dim1_end]
  *
+ * @param allocator Allocator used to create the returned tensor.
  * @param v A 3-D tensor. Its data type is T.
- * @param dim0 Start index of the first dimension..
+ * @param dim0_start Start index (inclusive) of the first dimension.
+ * @param dim0_end End index (exclusive) of the first dimension.
  * @param dim1_start Start index of the second dimension.
  * @param dim1_end End index of the second dimension.
  *
- * @return Return a 2-D tensor of shape (dim1_end-dim1_start, v.shape[2])
- *
- * @caution: The returned tensor is a shallow copy of `v`!
+ * @return Return a 3-D tensor of shape
+ *         (dim0_end-dim0_start, dim1_end-dim1_start, v.shape[2])
  */
 template <typename T = float>
-Ort::Value Slice(const Ort::Value *v, int32_t dim0, int32_t dim1_start,
+Ort::Value Slice(OrtAllocator *allocator, const Ort::Value *v,
+                 int32_t dim0_start, int32_t dim0_end, int32_t dim1_start,
                  int32_t dim1_end);
 }  // namespace sherpa_onnx
 
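
To illustrate the "deep copy" wording above, a small sketch; v is assumed to be an existing float tensor of shape (3, 5, 4) and the snippet is not part of this commit:

  // Sketch only: slicing now copies the data, so writes to the slice do not
  // touch v (the previous version returned a view into v's buffer).
  Ort::AllocatorWithDefaultOptions allocator;
  Ort::Value piece = Slice(allocator, &v, /*dim0_start=*/0, /*dim0_end=*/2,
                           /*dim1_start=*/1, /*dim1_end=*/3);
  // piece has shape (2, 2, 4)
  piece.GetTensorMutableData<float>()[0] = -1.0f;  // v is unchanged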