Fangjun Kuang
Committed by GitHub

add pad_sequence (#84)

@@ -16,6 +16,7 @@ set(sources @@ -16,6 +16,7 @@ set(sources
16 online-transducer-modified-beam-search-decoder.cc 16 online-transducer-modified-beam-search-decoder.cc
17 online-zipformer-transducer-model.cc 17 online-zipformer-transducer-model.cc
18 onnx-utils.cc 18 onnx-utils.cc
  19 + pad-sequence.cc
19 parse-options.cc 20 parse-options.cc
20 resample.cc 21 resample.cc
21 slice.cc 22 slice.cc
@@ -122,6 +123,7 @@ endif() @@ -122,6 +123,7 @@ endif()
122 if(SHERPA_ONNX_ENABLE_TESTS) 123 if(SHERPA_ONNX_ENABLE_TESTS)
123 set(sherpa_onnx_test_srcs 124 set(sherpa_onnx_test_srcs
124 cat-test.cc 125 cat-test.cc
  126 + pad-sequence-test.cc
125 slice-test.cc 127 slice-test.cc
126 transpose-test.cc 128 transpose-test.cc
127 unbind-test.cc 129 unbind-test.cc
  1 +// sherpa-onnx/csrc/pad-sequence-test.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/pad-sequence.h"
  6 +
  7 +#include <numeric>
  8 +
  9 +#include "gtest/gtest.h"
  10 +#include "sherpa-onnx/csrc/onnx-utils.h"
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +TEST(PadSequence, ThreeTensors) {
  15 + Ort::AllocatorWithDefaultOptions allocator;
  16 +
  17 + std::array<int64_t, 2> shape1{3, 5};
  18 + Ort::Value v1 =
  19 + Ort::Value::CreateTensor<float>(allocator, shape1.data(), shape1.size());
  20 + float *p1 = v1.GetTensorMutableData<float>();
  21 + std::iota(p1, p1 + shape1[0] * shape1[1], 0);
  22 +
  23 + std::array<int64_t, 2> shape2{4, 5};
  24 + Ort::Value v2 =
  25 + Ort::Value::CreateTensor<float>(allocator, shape2.data(), shape2.size());
  26 + float *p2 = v2.GetTensorMutableData<float>();
  27 + std::iota(p2, p2 + shape2[0] * shape2[1], 0);
  28 +
  29 + std::array<int64_t, 2> shape3{2, 5};
  30 + Ort::Value v3 =
  31 + Ort::Value::CreateTensor<float>(allocator, shape3.data(), shape3.size());
  32 + float *p3 = v3.GetTensorMutableData<float>();
  33 + std::iota(p3, p3 + shape3[0] * shape3[1], 0);
  34 +
  35 + auto ans = PadSequence(allocator, {&v1, &v2, &v3}, -1);
  36 +
  37 + Print2D(&v1);
  38 + Print2D(&v2);
  39 + Print2D(&v3);
  40 + Print3D(&ans);
  41 +}
  42 +
  43 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/pad-sequence.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/pad-sequence.h"
  6 +
#include <assert.h>

#include <algorithm>
#include <array>
#include <vector>
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +Ort::Value PadSequence(OrtAllocator *allocator,
  15 + const std::vector<const Ort::Value *> &values,
  16 + float padding_value) {
  17 + int32_t batch_size = static_cast<int32_t>(values.size());
  18 +
  19 + std::vector<int64_t> shape0 =
  20 + values[0]->GetTensorTypeAndShapeInfo().GetShape();
  21 + assert(shape0.size() == 2);
  22 +
  23 + auto feature_dim = shape0[1];
  24 + auto max_T = shape0[0];
  25 +
  26 + for (int32_t i = 1; i != batch_size; ++i) {
  27 + auto shape = values[i]->GetTensorTypeAndShapeInfo().GetShape();
  28 +
  29 + assert(shape.size() == 2);
  30 + assert(shape[1] == feature_dim);
  31 +
  32 + max_T = std::max(max_T, shape[0]);
  33 + }
  34 + std::array<int64_t, 3> ans_shape{batch_size, max_T, feature_dim};
  35 +
  36 + Ort::Value ans = Ort::Value::CreateTensor<float>(allocator, ans_shape.data(),
  37 + ans_shape.size());
  38 + float *dst = ans.GetTensorMutableData<float>();
  39 + std::fill(dst, dst + batch_size * max_T * feature_dim, padding_value);
  40 +
  41 + for (const auto *v : values) {
  42 + const float *src = v->GetTensorData<float>();
  43 + auto shape = v->GetTensorTypeAndShapeInfo().GetShape();
  44 + std::copy(src, src + shape[0] * shape[1], dst);
  45 + dst += max_T * feature_dim;
  46 + }
  47 +
  48 + return ans;
  49 +
  50 + // TODO(fangjun): Check that the returned value is correct.
  51 +}
  52 +
  53 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/pad-sequence.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_PAD_SEQUENCE_H_
  5 +#define SHERPA_ONNX_CSRC_PAD_SEQUENCE_H_
  6 +
  7 +#include <vector>
  8 +
  9 +#include "onnxruntime_cxx_api.h" // NOLINT
  10 +
  11 +namespace sherpa_onnx {
  12 +
/** Similar to torch.nn.utils.rnn.pad_sequence but it supports only
 * batch_first=true.
 *
 * @param allocator Allocator used to create the returned tensor.
 * @param values A non-empty list of 2-D tensors. Each tensor's second
 *               dimension must be the same and the data type of each
 *               tensor should be float.
 * @param padding_value Value used for padding. For log-fbank, you usually use
 *                      -23.025850929940457f as the padding value.
 *
 * @return Return a 3-D tensor of shape (B, max_T, C), where max_T is the
 *         largest first dimension among the input tensors; shorter inputs
 *         are padded at the end with padding_value.
 */
Ort::Value PadSequence(OrtAllocator *allocator,
                       const std::vector<const Ort::Value *> &values,
                       float padding_value);
  28 +
  29 +} // namespace sherpa_onnx
  30 +
  31 +#endif // SHERPA_ONNX_CSRC_PAD_SEQUENCE_H_