Showing 4 changed files with 129 additions and 0 deletions
@@ -16,6 +16,7 @@ set(sources
   online-transducer-modified-beam-search-decoder.cc
   online-zipformer-transducer-model.cc
   onnx-utils.cc
+  pad-sequence.cc
   parse-options.cc
   resample.cc
   slice.cc
@@ -122,6 +123,7 @@ endif()
 if(SHERPA_ONNX_ENABLE_TESTS)
   set(sherpa_onnx_test_srcs
     cat-test.cc
+    pad-sequence-test.cc
     slice-test.cc
     transpose-test.cc
     unbind-test.cc
sherpa-onnx/csrc/pad-sequence-test.cc (new file, mode 100644)
// sherpa-onnx/csrc/pad-sequence-test.cc
//
// Copyright (c) 2023 Xiaomi Corporation

#include "sherpa-onnx/csrc/pad-sequence.h"

#include <array>
#include <numeric>

#include "gtest/gtest.h"
#include "sherpa-onnx/csrc/onnx-utils.h"

namespace sherpa_onnx {

TEST(PadSequence, ThreeTensors) {
  Ort::AllocatorWithDefaultOptions allocator;

  std::array<int64_t, 2> shape1{3, 5};
  Ort::Value v1 =
      Ort::Value::CreateTensor<float>(allocator, shape1.data(), shape1.size());
  float *p1 = v1.GetTensorMutableData<float>();
  std::iota(p1, p1 + shape1[0] * shape1[1], 0);

  std::array<int64_t, 2> shape2{4, 5};
  Ort::Value v2 =
      Ort::Value::CreateTensor<float>(allocator, shape2.data(), shape2.size());
  float *p2 = v2.GetTensorMutableData<float>();
  std::iota(p2, p2 + shape2[0] * shape2[1], 0);

  std::array<int64_t, 2> shape3{2, 5};
  Ort::Value v3 =
      Ort::Value::CreateTensor<float>(allocator, shape3.data(), shape3.size());
  float *p3 = v3.GetTensorMutableData<float>();
  std::iota(p3, p3 + shape3[0] * shape3[1], 0);

  // Pad the three (T, 5) tensors to the longest one (T = 4) with -1.
  auto ans = PadSequence(allocator, {&v1, &v2, &v3}, -1);

  Print2D(&v1);
  Print2D(&v2);
  Print2D(&v3);
  Print3D(&ans);
}

}  // namespace sherpa_onnx
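The test above only prints the inputs and the padded result; nothing is asserted (see the TODO in pad-sequence.cc below). As a hedged sketch, the additional test below shows how the padded values could be checked if it were placed inside the sherpa_onnx namespace of the same file; it assumes only the PadSequence signature shown here, and the test name and shapes are made up for illustration.

TEST(PadSequence, CheckPaddedValuesSketch) {
  Ort::AllocatorWithDefaultOptions allocator;

  // Two 2-D tensors with 3 columns and 2 / 3 rows, filled with 0, 1, 2, ...
  std::array<int64_t, 2> shape1{2, 3};
  Ort::Value v1 =
      Ort::Value::CreateTensor<float>(allocator, shape1.data(), shape1.size());
  float *p1 = v1.GetTensorMutableData<float>();
  std::iota(p1, p1 + 6, 0);

  std::array<int64_t, 2> shape2{3, 3};
  Ort::Value v2 =
      Ort::Value::CreateTensor<float>(allocator, shape2.data(), shape2.size());
  float *p2 = v2.GetTensorMutableData<float>();
  std::iota(p2, p2 + 9, 0);

  auto ans = PadSequence(allocator, {&v1, &v2}, -1);

  // Expected shape: (batch=2, max_T=3, feature_dim=3).
  auto ans_shape = ans.GetTensorTypeAndShapeInfo().GetShape();
  EXPECT_EQ(ans_shape[0], 2);
  EXPECT_EQ(ans_shape[1], 3);
  EXPECT_EQ(ans_shape[2], 3);

  const float *pa = ans.GetTensorData<float>();
  // Batch element 0: rows 0-1 come from v1, row 2 is padding.
  for (int i = 0; i != 6; ++i) EXPECT_EQ(pa[i], p1[i]);
  for (int i = 6; i != 9; ++i) EXPECT_EQ(pa[i], -1.0f);
  // Batch element 1: all three rows come from v2.
  for (int i = 0; i != 9; ++i) EXPECT_EQ(pa[9 + i], p2[i]);
}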
sherpa-onnx/csrc/pad-sequence.cc (new file, mode 100644)
// sherpa-onnx/csrc/pad-sequence.cc
//
// Copyright (c) 2023 Xiaomi Corporation

#include "sherpa-onnx/csrc/pad-sequence.h"

#include <assert.h>

#include <algorithm>
#include <array>
#include <vector>

namespace sherpa_onnx {

Ort::Value PadSequence(OrtAllocator *allocator,
                       const std::vector<const Ort::Value *> &values,
                       float padding_value) {
  int32_t batch_size = static_cast<int32_t>(values.size());

  std::vector<int64_t> shape0 =
      values[0]->GetTensorTypeAndShapeInfo().GetShape();
  assert(shape0.size() == 2);

  auto feature_dim = shape0[1];
  auto max_T = shape0[0];

  // All inputs must be 2-D with the same feature dimension; the padded
  // length is the largest number of rows in the batch.
  for (int32_t i = 1; i != batch_size; ++i) {
    auto shape = values[i]->GetTensorTypeAndShapeInfo().GetShape();

    assert(shape.size() == 2);
    assert(shape[1] == feature_dim);

    max_T = std::max(max_T, shape[0]);
  }
  std::array<int64_t, 3> ans_shape{batch_size, max_T, feature_dim};

  Ort::Value ans = Ort::Value::CreateTensor<float>(allocator, ans_shape.data(),
                                                   ans_shape.size());
  float *dst = ans.GetTensorMutableData<float>();
  std::fill(dst, dst + batch_size * max_T * feature_dim, padding_value);

  // Copy each input into the first shape[0] rows of its (max_T, feature_dim)
  // slot; the remaining rows keep padding_value.
  for (const auto *v : values) {
    const float *src = v->GetTensorData<float>();
    auto shape = v->GetTensorTypeAndShapeInfo().GetShape();
    std::copy(src, src + shape[0] * shape[1], dst);
    dst += max_T * feature_dim;
  }

  return ans;

  // TODO(fangjun): Check that the returned value is correct.
}

}  // namespace sherpa_onnx
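For a concrete picture of the layout PadSequence produces, here is a small standalone sketch using only the standard library (not part of the change set); it mirrors the fill-then-copy logic above on plain vectors so it can be compiled and run without onnxruntime.

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  // Two "tensors" with feature_dim = 2 and 3 / 4 rows, stored row-major.
  int feature_dim = 2;
  std::vector<std::vector<float>> inputs = {
      {0, 1, 2, 3, 4, 5},        // 3 rows
      {0, 1, 2, 3, 4, 5, 6, 7},  // 4 rows -> max_T = 4
  };

  int max_T = 0;
  for (const auto &v : inputs) {
    max_T = std::max(max_T, static_cast<int>(v.size()) / feature_dim);
  }

  // Fill the (B, max_T, feature_dim) buffer with the padding value, then
  // copy each input into the first rows of its slot.
  std::vector<float> out(inputs.size() * max_T * feature_dim, -1.0f);
  float *dst = out.data();
  for (const auto &v : inputs) {
    std::copy(v.begin(), v.end(), dst);
    dst += max_T * feature_dim;
  }

  for (float f : out) printf("%.0f ", f);
  printf("\n");  // prints: 0 1 2 3 4 5 -1 -1 0 1 2 3 4 5 6 7
}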
sherpa-onnx/csrc/pad-sequence.h (new file, mode 100644)
// sherpa-onnx/csrc/pad-sequence.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_PAD_SEQUENCE_H_
#define SHERPA_ONNX_CSRC_PAD_SEQUENCE_H_

#include <vector>

#include "onnxruntime_cxx_api.h"  // NOLINT

namespace sherpa_onnx {

/** Similar to torch.nn.utils.rnn.pad_sequence but it supports only
 * batch_first=true.
 *
 * @param allocator Allocator used to allocate the space for the returned
 *                  tensor.
 * @param values A list of 2-D tensors. Each tensor's second dimension
 *               must be the same and the data type of each tensor should
 *               be float.
 * @param padding_value Value used for padding. For log-fbank, you usually use
 *                      -23.025850929940457f as the padding value.
 *
 * @return Return a 3-D tensor of shape (B, max_T, C).
 */
Ort::Value PadSequence(OrtAllocator *allocator,
                       const std::vector<const Ort::Value *> &values,
                       float padding_value);

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_PAD_SEQUENCE_H_
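For context, a minimal usage sketch (assuming this header is included; the helper name and its arguments are hypothetical, not part of the change set) that batches two 2-D (T, C) log-fbank tensors with the padding value suggested in the comment above:

// Hypothetical call site: pad two (T, C) float tensors into one
// (2, max_T, C) tensor using the log-fbank padding value.
Ort::Value BatchTwoFeatures(OrtAllocator *allocator, const Ort::Value &a,
                            const Ort::Value &b) {
  std::vector<const Ort::Value *> features = {&a, &b};
  return sherpa_onnx::PadSequence(allocator, features,
                                  -23.025850929940457f);
}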