Fangjun Kuang
Committed by GitHub

Add transpose (#82)

@@ -18,6 +18,7 @@ function(download_asio) @@ -18,6 +18,7 @@ function(download_asio)
18 foreach(f IN LISTS possible_file_locations) 18 foreach(f IN LISTS possible_file_locations)
19 if(EXISTS ${f}) 19 if(EXISTS ${f})
20 set(asio_URL "file://${f}") 20 set(asio_URL "file://${f}")
  21 + set(asio_URL2)
21 break() 22 break()
22 endif() 23 endif()
23 endforeach() 24 endforeach()
@@ -18,6 +18,7 @@ function(download_googltest) @@ -18,6 +18,7 @@ function(download_googltest)
18 foreach(f IN LISTS possible_file_locations) 18 foreach(f IN LISTS possible_file_locations)
19 if(EXISTS ${f}) 19 if(EXISTS ${f})
20 set(googletest_URL "file://${f}") 20 set(googletest_URL "file://${f}")
  21 + set(googletest_URL2)
21 break() 22 break()
22 endif() 23 endif()
23 endforeach() 24 endforeach()
@@ -22,6 +22,7 @@ function(download_kaldi_native_fbank) @@ -22,6 +22,7 @@ function(download_kaldi_native_fbank)
22 foreach(f IN LISTS possible_file_locations) 22 foreach(f IN LISTS possible_file_locations)
23 if(EXISTS ${f}) 23 if(EXISTS ${f})
24 set(kaldi_native_fbank_URL "file://${f}") 24 set(kaldi_native_fbank_URL "file://${f}")
  25 + set(kaldi_native_fbank_URL2 )
25 break() 26 break()
26 endif() 27 endif()
27 endforeach() 28 endforeach()
@@ -78,6 +78,7 @@ function(download_onnxruntime) @@ -78,6 +78,7 @@ function(download_onnxruntime)
78 foreach(f IN LISTS possible_file_locations) 78 foreach(f IN LISTS possible_file_locations)
79 if(EXISTS ${f}) 79 if(EXISTS ${f})
80 set(onnxruntime_URL "file://${f}") 80 set(onnxruntime_URL "file://${f}")
  81 + set(onnxruntime_URL2)
81 break() 82 break()
82 endif() 83 endif()
83 endforeach() 84 endforeach()
@@ -19,6 +19,7 @@ function(download_portaudio) @@ -19,6 +19,7 @@ function(download_portaudio)
19 foreach(f IN LISTS possible_file_locations) 19 foreach(f IN LISTS possible_file_locations)
20 if(EXISTS ${f}) 20 if(EXISTS ${f})
21 set(portaudio_URL "file://${f}") 21 set(portaudio_URL "file://${f}")
  22 + set(portaudio_URL2)
22 break() 23 break()
23 endif() 24 endif()
24 endforeach() 25 endforeach()
@@ -18,6 +18,7 @@ function(download_pybind11) @@ -18,6 +18,7 @@ function(download_pybind11)
18 foreach(f IN LISTS possible_file_locations) 18 foreach(f IN LISTS possible_file_locations)
19 if(EXISTS ${f}) 19 if(EXISTS ${f})
20 set(pybind11_URL "file://${f}") 20 set(pybind11_URL "file://${f}")
  21 + set(pybind11_URL2)
21 break() 22 break()
22 endif() 23 endif()
23 endforeach() 24 endforeach()
@@ -19,6 +19,7 @@ function(download_websocketpp) @@ -19,6 +19,7 @@ function(download_websocketpp)
19 foreach(f IN LISTS possible_file_locations) 19 foreach(f IN LISTS possible_file_locations)
20 if(EXISTS ${f}) 20 if(EXISTS ${f})
21 set(websocketpp_URL "file://${f}") 21 set(websocketpp_URL "file://${f}")
  22 + set(websocketpp_URL2)
22 break() 23 break()
23 endif() 24 endif()
24 endforeach() 25 endforeach()
@@ -20,6 +20,7 @@ set(sources @@ -20,6 +20,7 @@ set(sources
20 resample.cc 20 resample.cc
21 symbol-table.cc 21 symbol-table.cc
22 text-utils.cc 22 text-utils.cc
  23 + transpose.cc
23 unbind.cc 24 unbind.cc
24 wave-reader.cc 25 wave-reader.cc
25 ) 26 )
@@ -120,6 +121,7 @@ endif() @@ -120,6 +121,7 @@ endif()
120 if(SHERPA_ONNX_ENABLE_TESTS) 121 if(SHERPA_ONNX_ENABLE_TESTS)
121 set(sherpa_onnx_test_srcs 122 set(sherpa_onnx_test_srcs
122 cat-test.cc 123 cat-test.cc
  124 + transpose-test.cc
123 unbind-test.cc 125 unbind-test.cc
124 ) 126 )
125 127
  1 +// sherpa-onnx/csrc/transpose-test.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/transpose.h"
  6 +
  7 +#include <numeric>
  8 +
  9 +#include "gtest/gtest.h"
  10 +#include "sherpa-onnx/csrc/onnx-utils.h"
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +TEST(Tranpose, Tranpose01) {
  15 + Ort::AllocatorWithDefaultOptions allocator;
  16 + std::array<int64_t, 3> shape{3, 2, 5};
  17 + Ort::Value v =
  18 + Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());
  19 + float *p = v.GetTensorMutableData<float>();
  20 +
  21 + std::iota(p, p + shape[0] * shape[1] * shape[2], 0);
  22 +
  23 + auto ans = Transpose01(allocator, &v);
  24 + auto v2 = Transpose01(allocator, &ans);
  25 +
  26 + Print3D(&v);
  27 + Print3D(&ans);
  28 + Print3D(&v2);
  29 +
  30 + const float *q = v2.GetTensorData<float>();
  31 +
  32 + for (int32_t i = 0; i != static_cast<int32_t>(shape[0] * shape[1] * shape[2]);
  33 + ++i) {
  34 + EXPECT_EQ(p[i], q[i]);
  35 + }
  36 +}
  37 +
  38 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/transpose.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/transpose.h"
  6 +
  7 +#include <assert.h>
  8 +
  9 +#include <algorithm>
  10 +#include <vector>
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +template <typename T /*=float*/>
  15 +Ort::Value Transpose01(OrtAllocator *allocator, const Ort::Value *v) {
  16 + std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
  17 + assert(shape.size() == 3);
  18 +
  19 + std::array<int64_t, 3> ans_shape{shape[1], shape[0], shape[2]};
  20 + Ort::Value ans = Ort::Value::CreateTensor<float>(allocator, ans_shape.data(),
  21 + ans_shape.size());
  22 +
  23 + T *dst = ans.GetTensorMutableData<T>();
  24 + auto plane_offset = shape[1] * shape[2];
  25 +
  26 + for (int64_t i = 0; i != ans_shape[0]; ++i) {
  27 + const T *src = v->GetTensorData<T>() + i * shape[2];
  28 + for (int64_t k = 0; k != ans_shape[1]; ++k) {
  29 + std::copy(src, src + shape[2], dst);
  30 + src += plane_offset;
  31 + dst += shape[2];
  32 + }
  33 + }
  34 +
  35 + return ans;
  36 +}
  37 +
  38 +template Ort::Value Transpose01<float>(OrtAllocator *allocator,
  39 + const Ort::Value *v);
  40 +
  41 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/transpose.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_TRANSPOSE_H_
  5 +#define SHERPA_ONNX_CSRC_TRANSPOSE_H_
  6 +
  7 +#include "onnxruntime_cxx_api.h" // NOLINT
  8 +
  9 +namespace sherpa_onnx {
  10 +/** Transpose a 3-D tensor from shape (B, T, C) to (T, B, C).
  11 + *
  12 + * @param allocator
  13 + * @param v A 3-D tensor of shape (B, T, C). Its dataype is T.
  14 + *
  15 + * @return Return a 3-D tensor of shape (T, B, C). Its datatype is T.
  16 + */
  17 +template <typename T = float>
  18 +Ort::Value Transpose01(OrtAllocator *allocator, const Ort::Value *v);
  19 +
  20 +} // namespace sherpa_onnx
  21 +
  22 +#endif // SHERPA_ONNX_CSRC_TRANSPOSE_H_