Fangjun Kuang
Committed by GitHub

Support slicing a shallow copy of a 3-d tensor (#83)

@@ -18,6 +18,7 @@ set(sources @@ -18,6 +18,7 @@ set(sources
18 onnx-utils.cc 18 onnx-utils.cc
19 parse-options.cc 19 parse-options.cc
20 resample.cc 20 resample.cc
  21 + slice.cc
21 symbol-table.cc 22 symbol-table.cc
22 text-utils.cc 23 text-utils.cc
23 transpose.cc 24 transpose.cc
@@ -121,6 +122,7 @@ endif() @@ -121,6 +122,7 @@ endif()
121 if(SHERPA_ONNX_ENABLE_TESTS) 122 if(SHERPA_ONNX_ENABLE_TESTS)
122 set(sherpa_onnx_test_srcs 123 set(sherpa_onnx_test_srcs
123 cat-test.cc 124 cat-test.cc
  125 + slice-test.cc
124 transpose-test.cc 126 transpose-test.cc
125 unbind-test.cc 127 unbind-test.cc
126 ) 128 )
@@ -57,9 +57,6 @@ Ort::Value GetEncoderOutFrame(OrtAllocator *allocator, Ort::Value *encoder_out, @@ -57,9 +57,6 @@ Ort::Value GetEncoderOutFrame(OrtAllocator *allocator, Ort::Value *encoder_out,
57 57
58 auto offset = num_frames * encoder_out_dim; 58 auto offset = num_frames * encoder_out_dim;
59 59
60 - auto memory_info =  
61 - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);  
62 -  
63 std::array<int64_t, 2> shape{batch_size, encoder_out_dim}; 60 std::array<int64_t, 2> shape{batch_size, encoder_out_dim};
64 61
65 Ort::Value ans = 62 Ort::Value ans =
@@ -90,9 +87,6 @@ Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v) { @@ -90,9 +87,6 @@ Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v) {
90 auto type_and_shape = v->GetTensorTypeAndShapeInfo(); 87 auto type_and_shape = v->GetTensorTypeAndShapeInfo();
91 std::vector<int64_t> shape = type_and_shape.GetShape(); 88 std::vector<int64_t> shape = type_and_shape.GetShape();
92 89
93 - auto memory_info =  
94 - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);  
95 -  
96 switch (type_and_shape.GetElementType()) { 90 switch (type_and_shape.GetElementType()) {
97 case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { 91 case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: {
98 Ort::Value ans = Ort::Value::CreateTensor<int32_t>( 92 Ort::Value ans = Ort::Value::CreateTensor<int32_t>(
  1 +// sherpa-onnx/csrc/slice-test.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/slice.h"
  6 +
  7 +#include <numeric>
  8 +
  9 +#include "gtest/gtest.h"
  10 +#include "sherpa-onnx/csrc/onnx-utils.h"
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +TEST(Slice, Slice3D) {
  15 + Ort::AllocatorWithDefaultOptions allocator;
  16 + std::array<int64_t, 3> shape{3, 5, 4};
  17 + Ort::Value v =
  18 + Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());
  19 + float *p = v.GetTensorMutableData<float>();
  20 +
  21 + std::iota(p, p + shape[0] * shape[1] * shape[2], 0);
  22 +
  23 + auto v1 = Slice(&v, 0, 2, 5);
  24 + auto v2 = Slice(&v, 1, 2, 4);
  25 +
  26 + Print3D(&v);
  27 + Print2D(&v1);
  28 + Print2D(&v2);
  29 +
  30 + // TODO(fangjun): Check that the results are correct
  31 +}
  32 +
  33 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/slice.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/slice.h"
  6 +
  7 +#include <assert.h>
  8 +
  9 +#include <vector>
  10 +
  11 +namespace sherpa_onnx {
  12 +
  13 +template <typename T /*=float*/>
  14 +Ort::Value Slice(const Ort::Value *v, int32_t dim0, int32_t dim1_start,
  15 + int32_t dim1_end) {
  16 + std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
  17 + assert(shape.size() == 3);
  18 +
  19 + auto memory_info =
  20 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  21 +
  22 + std::array<int64_t, 2> ans_shape{dim1_end - dim1_start, shape[2]};
  23 + const T *src = v->GetTensorData<T>();
  24 + src += dim0 * shape[1] * shape[2] + dim1_start * shape[2];
  25 +
  26 + return Ort::Value::CreateTensor(memory_info, const_cast<T *>(src),
  27 + ans_shape[0] * ans_shape[1], ans_shape.data(),
  28 + ans_shape.size());
  29 +}
  30 +
  31 +template Ort::Value Slice<float>(const Ort::Value *v, int32_t dim0,
  32 + int32_t dim1_start, int32_t dim1_end);
  33 +
  34 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/slice.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_SLICE_H_
  5 +#define SHERPA_ONNX_CSRC_SLICE_H_
  6 +
  7 +#include "onnxruntime_cxx_api.h" // NOLINT
  8 +
  9 +namespace sherpa_onnx {
  10 +
  11 +/** Get a shallow copy by slicing v.
  12 + *
  13 + * It returns v[dim0, dim1_start:dim1_end]
  14 + *
  15 + * @param v A 3-D tensor. Its data type is T.
  16 + * @param dim0 Start index of the first dimension..
  17 + * @param dim1_start Start index of the second dimension.
  18 + * @param dim1_end End index of the second dimension.
  19 + *
  20 + * @return Return a 2-D tensor of shape (dim1_end-dim1_start, v.shape[2])
  21 + *
  22 + * @caution: The returned tensor is a shallow copy of `v`!
  23 + */
  24 +template <typename T = float>
  25 +Ort::Value Slice(const Ort::Value *v, int32_t dim0, int32_t dim1_start,
  26 + int32_t dim1_end);
  27 +} // namespace sherpa_onnx
  28 +
  29 +#endif // SHERPA_ONNX_CSRC_SLICE_H_