正在显示
7 个修改的文件
包含
236 行增加
和
21 行删除
| @@ -16,6 +16,7 @@ set(sources | @@ -16,6 +16,7 @@ set(sources | ||
| 16 | online-transducer-modified-beam-search-decoder.cc | 16 | online-transducer-modified-beam-search-decoder.cc |
| 17 | online-zipformer-transducer-model.cc | 17 | online-zipformer-transducer-model.cc |
| 18 | onnx-utils.cc | 18 | onnx-utils.cc |
| 19 | + packed-sequence.cc | ||
| 19 | pad-sequence.cc | 20 | pad-sequence.cc |
| 20 | parse-options.cc | 21 | parse-options.cc |
| 21 | resample.cc | 22 | resample.cc |
| @@ -123,6 +124,7 @@ endif() | @@ -123,6 +124,7 @@ endif() | ||
| 123 | if(SHERPA_ONNX_ENABLE_TESTS) | 124 | if(SHERPA_ONNX_ENABLE_TESTS) |
| 124 | set(sherpa_onnx_test_srcs | 125 | set(sherpa_onnx_test_srcs |
| 125 | cat-test.cc | 126 | cat-test.cc |
| 127 | + packed-sequence-test.cc | ||
| 126 | pad-sequence-test.cc | 128 | pad-sequence-test.cc |
| 127 | slice-test.cc | 129 | slice-test.cc |
| 128 | transpose-test.cc | 130 | transpose-test.cc |
sherpa-onnx/csrc/packed-sequence-test.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/packed-sequence-test.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/packed-sequence.h" | ||
| 6 | + | ||
| 7 | +#include <numeric> | ||
| 8 | + | ||
| 9 | +#include "gtest/gtest.h" | ||
| 10 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +TEST(PackedSequence, Case1) { | ||
| 15 | + Ort::AllocatorWithDefaultOptions allocator; | ||
| 16 | + std::array<int64_t, 3> shape{5, 5, 4}; | ||
| 17 | + Ort::Value v = | ||
| 18 | + Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size()); | ||
| 19 | + float *p = v.GetTensorMutableData<float>(); | ||
| 20 | + | ||
| 21 | + std::iota(p, p + shape[0] * shape[1] * shape[2], 0); | ||
| 22 | + | ||
| 23 | + Ort::Value length = | ||
| 24 | + Ort::Value::CreateTensor<int64_t>(allocator, shape.data(), 1); | ||
| 25 | + int64_t *p_length = length.GetTensorMutableData<int64_t>(); | ||
| 26 | + p_length[0] = 1; | ||
| 27 | + p_length[1] = 2; | ||
| 28 | + p_length[2] = 3; | ||
| 29 | + p_length[3] = 5; | ||
| 30 | + p_length[4] = 2; | ||
| 31 | + | ||
| 32 | + auto packed_seq = PackPaddedSequence(allocator, &v, &length); | ||
| 33 | + fprintf(stderr, "sorted indexes: "); | ||
| 34 | + for (auto i : packed_seq.sorted_indexes) { | ||
| 35 | + fprintf(stderr, "%d ", static_cast<int32_t>(i)); | ||
| 36 | + } | ||
| 37 | + fprintf(stderr, "\n"); | ||
| 38 | + // output index: 0 1 2 3 4 | ||
| 39 | + // sorted indexes: 3 2 1 4 0 | ||
| 40 | + // length: 5 3 2 2 1 | ||
| 41 | + Print3D(&v); | ||
| 42 | + Print2D(&packed_seq.data); | ||
| 43 | + fprintf(stderr, "batch sizes per time step: "); | ||
| 44 | + for (auto i : packed_seq.batch_sizes) { | ||
| 45 | + fprintf(stderr, "%d ", static_cast<int32_t>(i)); | ||
| 46 | + } | ||
| 47 | + fprintf(stderr, "\n"); | ||
| 48 | + | ||
| 49 | + // TODO(fangjun): Check that the return value is correct | ||
| 50 | +} | ||
| 51 | + | ||
| 52 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/packed-sequence.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/packed-sequence.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/packed-sequence.h" | ||
| 6 | + | ||
| 7 | +#include <assert.h> | ||
| 8 | + | ||
| 9 | +#include <algorithm> | ||
| 10 | +#include <numeric> | ||
| 11 | +#include <utility> | ||
| 12 | + | ||
| 13 | +#include "sherpa-onnx/csrc/slice.h" | ||
| 14 | +#include "sherpa-onnx/csrc/transpose.h" | ||
| 15 | + | ||
| 16 | +namespace sherpa_onnx { | ||
| 17 | + | ||
| 18 | +static Ort::Value IndexSelect(OrtAllocator *allocator, const Ort::Value *value, | ||
| 19 | + const std::vector<int32_t> &sorted_indexes) { | ||
| 20 | + auto shape = value->GetTensorTypeAndShapeInfo().GetShape(); | ||
| 21 | + assert(shape.size() == 3); | ||
| 22 | + std::array<int64_t, 3> ans_shape{static_cast<int64_t>(sorted_indexes.size()), | ||
| 23 | + shape[1], shape[2]}; | ||
| 24 | + | ||
| 25 | + Ort::Value ans = Ort::Value::CreateTensor<float>(allocator, ans_shape.data(), | ||
| 26 | + ans_shape.size()); | ||
| 27 | + float *dst = ans.GetTensorMutableData<float>(); | ||
| 28 | + const float *src = value->GetTensorData<float>(); | ||
| 29 | + | ||
| 30 | + for (auto i : sorted_indexes) { | ||
| 31 | + const float *start = src + i * shape[1] * shape[2]; | ||
| 32 | + std::copy(start, start + shape[1] * shape[2], dst); | ||
| 33 | + dst += shape[1] * shape[2]; | ||
| 34 | + } | ||
| 35 | + return ans; | ||
| 36 | +} | ||
| 37 | + | ||
| 38 | +PackedSequence PackPaddedSequence(OrtAllocator *allocator, | ||
| 39 | + const Ort::Value *value, Ort::Value *length) { | ||
| 40 | + std::vector<int64_t> v_shape = value->GetTensorTypeAndShapeInfo().GetShape(); | ||
| 41 | + std::vector<int64_t> l_shape = length->GetTensorTypeAndShapeInfo().GetShape(); | ||
| 42 | + | ||
| 43 | + assert(v_shape.size() == 3); | ||
| 44 | + assert(l_shape.size() == 3); | ||
| 45 | + assert(v_shape[0] == l_shape[0]); | ||
| 46 | + | ||
| 47 | + std::vector<int32_t> indexes(v_shape[0]); | ||
| 48 | + std::iota(indexes.begin(), indexes.end(), 0); | ||
| 49 | + | ||
| 50 | + const int64_t *p_length = length->GetTensorData<int64_t>(); | ||
| 51 | + // sort in descending order | ||
| 52 | + std::sort(indexes.begin(), indexes.end(), [p_length](int32_t i, int32_t j) { | ||
| 53 | + return p_length[i] > p_length[j]; | ||
| 54 | + }); | ||
| 55 | + | ||
| 56 | + int32_t n = static_cast<int32_t>(v_shape[0]); | ||
| 57 | + | ||
| 58 | + int64_t max_T = p_length[indexes[0]]; | ||
| 59 | + | ||
| 60 | + int32_t sum_T = std::accumulate(p_length, p_length + n, 0); | ||
| 61 | + | ||
| 62 | + std::array<int64_t, 2> data_shape{sum_T, v_shape[2]}; | ||
| 63 | + | ||
| 64 | + Ort::Value data = Ort::Value::CreateTensor<float>( | ||
| 65 | + allocator, data_shape.data(), data_shape.size()); | ||
| 66 | + float *dst = data.GetTensorMutableData<float>(); | ||
| 67 | + | ||
| 68 | + Ort::Value tensor = IndexSelect(allocator, value, indexes); | ||
| 69 | + tensor = Transpose01(allocator, &tensor); | ||
| 70 | + | ||
| 71 | + // batch size at each time step | ||
| 72 | + std::vector<int32_t> batch_sizes; | ||
| 73 | + batch_sizes.reserve(max_T); | ||
| 74 | + | ||
| 75 | + int64_t prev_l = 0; | ||
| 76 | + for (int32_t i = 0; i != n; ++i) { | ||
| 77 | + auto cur_l = p_length[indexes[n - 1 - i]]; | ||
| 78 | + assert(cur_l >= prev_l); | ||
| 79 | + if (cur_l == prev_l) { | ||
| 80 | + continue; | ||
| 81 | + } | ||
| 82 | + | ||
| 83 | + auto cur_batch_size = n - i; | ||
| 84 | + | ||
| 85 | + Ort::Value cur_batch = | ||
| 86 | + Slice(allocator, &tensor, prev_l, cur_l, 0, cur_batch_size); | ||
| 87 | + auto count = cur_batch.GetTensorTypeAndShapeInfo().GetElementCount(); | ||
| 88 | + const float *src = cur_batch.GetTensorData<float>(); | ||
| 89 | + std::copy(src, src + count, dst); | ||
| 90 | + dst += count; | ||
| 91 | + | ||
| 92 | + for (int32_t j = prev_l; j < cur_l; ++j) { | ||
| 93 | + batch_sizes.push_back(cur_batch_size); | ||
| 94 | + } | ||
| 95 | + | ||
| 96 | + prev_l = cur_l; | ||
| 97 | + } | ||
| 98 | + | ||
| 99 | + PackedSequence packed_seq; | ||
| 100 | + packed_seq.sorted_indexes = std::move(indexes); | ||
| 101 | + packed_seq.data = std::move(data); | ||
| 102 | + packed_seq.batch_sizes = std::move(batch_sizes); | ||
| 103 | + | ||
| 104 | + return packed_seq; | ||
| 105 | +} | ||
| 106 | + | ||
| 107 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/packed-sequence.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/packed-sequence.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_PACKED_SEQUENCE_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_PACKED_SEQUENCE_H_ | ||
| 6 | + | ||
| 7 | +#include <vector> | ||
| 8 | + | ||
| 9 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
| 13 | +struct PackedSequence { | ||
| 14 | + std::vector<int32_t> sorted_indexes; | ||
| 15 | + std::vector<int32_t> batch_sizes; | ||
| 16 | + Ort::Value data{nullptr}; | ||
| 17 | +}; | ||
| 18 | + | ||
| 19 | +/** Similar to torch.nn.utils.rnn.pad_sequence but it supports only | ||
| 20 | + * batch_first=true. | ||
| 21 | + * | ||
| 22 | + * @param allocator | ||
| 23 | + * @param value A 3-D tensor of shape (B, T, C). Its dtype is float. | ||
| 24 | + * @param length A 1-D tensor of shape (B,). Its dtype is int64_t. Each | ||
| 25 | + * element in it specifies the valid length of the corresponding | ||
| 26 | + * entry in value before padding. | ||
| 27 | + */ | ||
| 28 | +PackedSequence PackPaddedSequence(OrtAllocator *allocator, | ||
| 29 | + const Ort::Value *value, Ort::Value *length); | ||
| 30 | + | ||
| 31 | +} // namespace sherpa_onnx | ||
| 32 | + | ||
| 33 | +#endif // SHERPA_ONNX_CSRC_PACKED_SEQUENCE_H_ |
| @@ -13,19 +13,19 @@ namespace sherpa_onnx { | @@ -13,19 +13,19 @@ namespace sherpa_onnx { | ||
| 13 | 13 | ||
| 14 | TEST(Slice, Slice3D) { | 14 | TEST(Slice, Slice3D) { |
| 15 | Ort::AllocatorWithDefaultOptions allocator; | 15 | Ort::AllocatorWithDefaultOptions allocator; |
| 16 | - std::array<int64_t, 3> shape{3, 5, 4}; | 16 | + std::array<int64_t, 3> shape{5, 5, 4}; |
| 17 | Ort::Value v = | 17 | Ort::Value v = |
| 18 | Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size()); | 18 | Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size()); |
| 19 | float *p = v.GetTensorMutableData<float>(); | 19 | float *p = v.GetTensorMutableData<float>(); |
| 20 | 20 | ||
| 21 | std::iota(p, p + shape[0] * shape[1] * shape[2], 0); | 21 | std::iota(p, p + shape[0] * shape[1] * shape[2], 0); |
| 22 | 22 | ||
| 23 | - auto v1 = Slice(&v, 0, 2, 5); | ||
| 24 | - auto v2 = Slice(&v, 1, 2, 4); | 23 | + auto v1 = Slice(allocator, &v, 2, 4, 0, 2); |
| 24 | + auto v2 = Slice(allocator, &v, 1, 3, 1, 3); | ||
| 25 | 25 | ||
| 26 | Print3D(&v); | 26 | Print3D(&v); |
| 27 | - Print2D(&v1); | ||
| 28 | - Print2D(&v2); | 27 | + Print3D(&v1); |
| 28 | + Print3D(&v2); | ||
| 29 | 29 | ||
| 30 | // TODO(fangjun): Check that the results are correct | 30 | // TODO(fangjun): Check that the results are correct |
| 31 | } | 31 | } |
| @@ -6,29 +6,48 @@ | @@ -6,29 +6,48 @@ | ||
| 6 | 6 | ||
| 7 | #include <assert.h> | 7 | #include <assert.h> |
| 8 | 8 | ||
| 9 | +#include <algorithm> | ||
| 9 | #include <vector> | 10 | #include <vector> |
| 10 | 11 | ||
| 11 | namespace sherpa_onnx { | 12 | namespace sherpa_onnx { |
| 12 | 13 | ||
| 13 | template <typename T /*=float*/> | 14 | template <typename T /*=float*/> |
| 14 | -Ort::Value Slice(const Ort::Value *v, int32_t dim0, int32_t dim1_start, | 15 | +Ort::Value Slice(OrtAllocator *allocator, const Ort::Value *v, |
| 16 | + int32_t dim0_start, int32_t dim0_end, int32_t dim1_start, | ||
| 15 | int32_t dim1_end) { | 17 | int32_t dim1_end) { |
| 16 | std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape(); | 18 | std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape(); |
| 17 | assert(shape.size() == 3); | 19 | assert(shape.size() == 3); |
| 18 | 20 | ||
| 19 | - auto memory_info = | ||
| 20 | - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | 21 | + assert(0 <= dim0_start); |
| 22 | + assert(dim0_start < dim0_end); | ||
| 23 | + assert(dim0_end <= shape[0]); | ||
| 24 | + | ||
| 25 | + assert(0 <= dim1_start); | ||
| 26 | + assert(dim1_start < dim1_end); | ||
| 27 | + assert(dim1_end < shape[1]); | ||
| 21 | 28 | ||
| 22 | - std::array<int64_t, 2> ans_shape{dim1_end - dim1_start, shape[2]}; | ||
| 23 | const T *src = v->GetTensorData<T>(); | 29 | const T *src = v->GetTensorData<T>(); |
| 24 | - src += dim0 * shape[1] * shape[2] + dim1_start * shape[2]; | ||
| 25 | 30 | ||
| 26 | - return Ort::Value::CreateTensor(memory_info, const_cast<T *>(src), | ||
| 27 | - ans_shape[0] * ans_shape[1], ans_shape.data(), | ||
| 28 | - ans_shape.size()); | 31 | + std::array<int64_t, 3> ans_shape{dim0_end - dim0_start, dim1_end - dim1_start, |
| 32 | + shape[2]}; | ||
| 33 | + | ||
| 34 | + Ort::Value ans = Ort::Value::CreateTensor<T>(allocator, ans_shape.data(), | ||
| 35 | + ans_shape.size()); | ||
| 36 | + T *dst = ans.GetTensorMutableData<T>(); | ||
| 37 | + for (int32_t i = dim0_start; i != dim0_end; ++i) { | ||
| 38 | + const T *src = v->GetTensorData<T>() + i * shape[1] * shape[2]; | ||
| 39 | + const T *start = src + dim1_start * shape[2]; | ||
| 40 | + const T *end = src + dim1_end * shape[2]; | ||
| 41 | + | ||
| 42 | + std::copy(start, end, dst); | ||
| 43 | + dst += ans_shape[1] * ans_shape[2]; | ||
| 44 | + } | ||
| 45 | + | ||
| 46 | + return ans; | ||
| 29 | } | 47 | } |
| 30 | 48 | ||
| 31 | -template Ort::Value Slice<float>(const Ort::Value *v, int32_t dim0, | 49 | +template Ort::Value Slice<float>(OrtAllocator *allocator, const Ort::Value *v, |
| 50 | + int32_t dim0_start, int32_t dim0_end, | ||
| 32 | int32_t dim1_start, int32_t dim1_end); | 51 | int32_t dim1_start, int32_t dim1_end); |
| 33 | 52 | ||
| 34 | } // namespace sherpa_onnx | 53 | } // namespace sherpa_onnx |
| @@ -8,21 +8,23 @@ | @@ -8,21 +8,23 @@ | ||
| 8 | 8 | ||
| 9 | namespace sherpa_onnx { | 9 | namespace sherpa_onnx { |
| 10 | 10 | ||
| 11 | -/** Get a shallow copy by slicing v. | 11 | +/** Get a deep copy by slicing v. |
| 12 | * | 12 | * |
| 13 | - * It returns v[dim0, dim1_start:dim1_end] | 13 | + * It returns v[dim0_start:dim0_end, dim1_start:dim1_end] |
| 14 | * | 14 | * |
| 15 | + * @param allocator | ||
| 15 | * @param v A 3-D tensor. Its data type is T. | 16 | * @param v A 3-D tensor. Its data type is T. |
| 16 | - * @param dim0 Start index of the first dimension.. | 17 | + * @param dim0_start Start index of the first dimension.. |
| 18 | + * @param dim0_end End index of the first dimension.. | ||
| 17 | * @param dim1_start Start index of the second dimension. | 19 | * @param dim1_start Start index of the second dimension. |
| 18 | * @param dim1_end End index of the second dimension. | 20 | * @param dim1_end End index of the second dimension. |
| 19 | * | 21 | * |
| 20 | - * @return Return a 2-D tensor of shape (dim1_end-dim1_start, v.shape[2]) | ||
| 21 | - * | ||
| 22 | - * @caution: The returned tensor is a shallow copy of `v`! | 22 | + * @return Return a 3-D tensor of shape |
| 23 | + * (dim0_end-dim0_start, dim1_end-dim1_start, v.shape[2]) | ||
| 23 | */ | 24 | */ |
| 24 | template <typename T = float> | 25 | template <typename T = float> |
| 25 | -Ort::Value Slice(const Ort::Value *v, int32_t dim0, int32_t dim1_start, | 26 | +Ort::Value Slice(OrtAllocator *allocator, const Ort::Value *v, |
| 27 | + int32_t dim0_start, int32_t dim0_end, int32_t dim1_start, | ||
| 26 | int32_t dim1_end); | 28 | int32_t dim1_end); |
| 27 | } // namespace sherpa_onnx | 29 | } // namespace sherpa_onnx |
| 28 | 30 |
-
请 注册 或 登录 后发表评论