Committed by
GitHub
Support slicing a shallow copy of a 3-d tensor (#83)
正在显示
5 个修改的文件
包含
98 行增加
和
6 行删除
| @@ -18,6 +18,7 @@ set(sources | @@ -18,6 +18,7 @@ set(sources | ||
| 18 | onnx-utils.cc | 18 | onnx-utils.cc |
| 19 | parse-options.cc | 19 | parse-options.cc |
| 20 | resample.cc | 20 | resample.cc |
| 21 | + slice.cc | ||
| 21 | symbol-table.cc | 22 | symbol-table.cc |
| 22 | text-utils.cc | 23 | text-utils.cc |
| 23 | transpose.cc | 24 | transpose.cc |
| @@ -121,6 +122,7 @@ endif() | @@ -121,6 +122,7 @@ endif() | ||
| 121 | if(SHERPA_ONNX_ENABLE_TESTS) | 122 | if(SHERPA_ONNX_ENABLE_TESTS) |
| 122 | set(sherpa_onnx_test_srcs | 123 | set(sherpa_onnx_test_srcs |
| 123 | cat-test.cc | 124 | cat-test.cc |
| 125 | + slice-test.cc | ||
| 124 | transpose-test.cc | 126 | transpose-test.cc |
| 125 | unbind-test.cc | 127 | unbind-test.cc |
| 126 | ) | 128 | ) |
| @@ -57,9 +57,6 @@ Ort::Value GetEncoderOutFrame(OrtAllocator *allocator, Ort::Value *encoder_out, | @@ -57,9 +57,6 @@ Ort::Value GetEncoderOutFrame(OrtAllocator *allocator, Ort::Value *encoder_out, | ||
| 57 | 57 | ||
| 58 | auto offset = num_frames * encoder_out_dim; | 58 | auto offset = num_frames * encoder_out_dim; |
| 59 | 59 | ||
| 60 | - auto memory_info = | ||
| 61 | - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 62 | - | ||
| 63 | std::array<int64_t, 2> shape{batch_size, encoder_out_dim}; | 60 | std::array<int64_t, 2> shape{batch_size, encoder_out_dim}; |
| 64 | 61 | ||
| 65 | Ort::Value ans = | 62 | Ort::Value ans = |
| @@ -90,9 +87,6 @@ Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v) { | @@ -90,9 +87,6 @@ Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v) { | ||
| 90 | auto type_and_shape = v->GetTensorTypeAndShapeInfo(); | 87 | auto type_and_shape = v->GetTensorTypeAndShapeInfo(); |
| 91 | std::vector<int64_t> shape = type_and_shape.GetShape(); | 88 | std::vector<int64_t> shape = type_and_shape.GetShape(); |
| 92 | 89 | ||
| 93 | - auto memory_info = | ||
| 94 | - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 95 | - | ||
| 96 | switch (type_and_shape.GetElementType()) { | 90 | switch (type_and_shape.GetElementType()) { |
| 97 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { | 91 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { |
| 98 | Ort::Value ans = Ort::Value::CreateTensor<int32_t>( | 92 | Ort::Value ans = Ort::Value::CreateTensor<int32_t>( |
sherpa-onnx/csrc/slice-test.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/slice-test.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/slice.h" | ||
| 6 | + | ||
| 7 | +#include <numeric> | ||
| 8 | + | ||
| 9 | +#include "gtest/gtest.h" | ||
| 10 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +TEST(Slice, Slice3D) { | ||
| 15 | + Ort::AllocatorWithDefaultOptions allocator; | ||
| 16 | + std::array<int64_t, 3> shape{3, 5, 4}; | ||
| 17 | + Ort::Value v = | ||
| 18 | + Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size()); | ||
| 19 | + float *p = v.GetTensorMutableData<float>(); | ||
| 20 | + | ||
| 21 | + std::iota(p, p + shape[0] * shape[1] * shape[2], 0); | ||
| 22 | + | ||
| 23 | + auto v1 = Slice(&v, 0, 2, 5); | ||
| 24 | + auto v2 = Slice(&v, 1, 2, 4); | ||
| 25 | + | ||
| 26 | + Print3D(&v); | ||
| 27 | + Print2D(&v1); | ||
| 28 | + Print2D(&v2); | ||
| 29 | + | ||
| 30 | + // TODO(fangjun): Check that the results are correct | ||
| 31 | +} | ||
| 32 | + | ||
| 33 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/slice.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/slice.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/slice.h" | ||
| 6 | + | ||
| 7 | +#include <assert.h> | ||
| 8 | + | ||
| 9 | +#include <vector> | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
| 13 | +template <typename T /*=float*/> | ||
| 14 | +Ort::Value Slice(const Ort::Value *v, int32_t dim0, int32_t dim1_start, | ||
| 15 | + int32_t dim1_end) { | ||
| 16 | + std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape(); | ||
| 17 | + assert(shape.size() == 3); | ||
| 18 | + | ||
| 19 | + auto memory_info = | ||
| 20 | + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 21 | + | ||
| 22 | + std::array<int64_t, 2> ans_shape{dim1_end - dim1_start, shape[2]}; | ||
| 23 | + const T *src = v->GetTensorData<T>(); | ||
| 24 | + src += dim0 * shape[1] * shape[2] + dim1_start * shape[2]; | ||
| 25 | + | ||
| 26 | + return Ort::Value::CreateTensor(memory_info, const_cast<T *>(src), | ||
| 27 | + ans_shape[0] * ans_shape[1], ans_shape.data(), | ||
| 28 | + ans_shape.size()); | ||
| 29 | +} | ||
| 30 | + | ||
| 31 | +template Ort::Value Slice<float>(const Ort::Value *v, int32_t dim0, | ||
| 32 | + int32_t dim1_start, int32_t dim1_end); | ||
| 33 | + | ||
| 34 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/slice.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/slice.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_SLICE_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_SLICE_H_ | ||
| 6 | + | ||
| 7 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 8 | + | ||
| 9 | +namespace sherpa_onnx { | ||
| 10 | + | ||
| 11 | +/** Get a shallow copy by slicing v. | ||
| 12 | + * | ||
| 13 | + * It returns v[dim0, dim1_start:dim1_end] | ||
| 14 | + * | ||
| 15 | + * @param v A 3-D tensor. Its data type is T. | ||
| 16 | + * @param dim0 Start index of the first dimension.. | ||
| 17 | + * @param dim1_start Start index of the second dimension. | ||
| 18 | + * @param dim1_end End index of the second dimension. | ||
| 19 | + * | ||
| 20 | + * @return Return a 2-D tensor of shape (dim1_end-dim1_start, v.shape[2]) | ||
| 21 | + * | ||
| 22 | + * @caution: The returned tensor is a shallow copy of `v`! | ||
| 23 | + */ | ||
| 24 | +template <typename T = float> | ||
| 25 | +Ort::Value Slice(const Ort::Value *v, int32_t dim0, int32_t dim1_start, | ||
| 26 | + int32_t dim1_end); | ||
| 27 | +} // namespace sherpa_onnx | ||
| 28 | + | ||
| 29 | +#endif // SHERPA_ONNX_CSRC_SLICE_H_ |
-
请 注册 或 登录 后发表评论