Fangjun Kuang
Committed by GitHub

Add transpose (#82)

... ... @@ -18,6 +18,7 @@ function(download_asio)
foreach(f IN LISTS possible_file_locations)
if(EXISTS ${f})
set(asio_URL "file://${f}")
set(asio_URL2)
break()
endif()
endforeach()
... ...
... ... @@ -18,6 +18,7 @@ function(download_googltest)
foreach(f IN LISTS possible_file_locations)
if(EXISTS ${f})
set(googletest_URL "file://${f}")
set(googletest_URL2)
break()
endif()
endforeach()
... ...
... ... @@ -22,6 +22,7 @@ function(download_kaldi_native_fbank)
foreach(f IN LISTS possible_file_locations)
if(EXISTS ${f})
set(kaldi_native_fbank_URL "file://${f}")
set(kaldi_native_fbank_URL2 )
break()
endif()
endforeach()
... ...
... ... @@ -78,6 +78,7 @@ function(download_onnxruntime)
foreach(f IN LISTS possible_file_locations)
if(EXISTS ${f})
set(onnxruntime_URL "file://${f}")
set(onnxruntime_URL2)
break()
endif()
endforeach()
... ...
... ... @@ -19,6 +19,7 @@ function(download_portaudio)
foreach(f IN LISTS possible_file_locations)
if(EXISTS ${f})
set(portaudio_URL "file://${f}")
set(portaudio_URL2)
break()
endif()
endforeach()
... ...
... ... @@ -18,6 +18,7 @@ function(download_pybind11)
foreach(f IN LISTS possible_file_locations)
if(EXISTS ${f})
set(pybind11_URL "file://${f}")
set(pybind11_URL2)
break()
endif()
endforeach()
... ...
... ... @@ -19,6 +19,7 @@ function(download_websocketpp)
foreach(f IN LISTS possible_file_locations)
if(EXISTS ${f})
set(websocketpp_URL "file://${f}")
set(websocketpp_URL2)
break()
endif()
endforeach()
... ...
... ... @@ -20,6 +20,7 @@ set(sources
resample.cc
symbol-table.cc
text-utils.cc
transpose.cc
unbind.cc
wave-reader.cc
)
... ... @@ -120,6 +121,7 @@ endif()
if(SHERPA_ONNX_ENABLE_TESTS)
set(sherpa_onnx_test_srcs
cat-test.cc
transpose-test.cc
unbind-test.cc
)
... ...
// sherpa-onnx/csrc/transpose-test.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/transpose.h"
#include <numeric>
#include "gtest/gtest.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
namespace sherpa_onnx {
TEST(Tranpose, Tranpose01) {
Ort::AllocatorWithDefaultOptions allocator;
std::array<int64_t, 3> shape{3, 2, 5};
Ort::Value v =
Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());
float *p = v.GetTensorMutableData<float>();
std::iota(p, p + shape[0] * shape[1] * shape[2], 0);
auto ans = Transpose01(allocator, &v);
auto v2 = Transpose01(allocator, &ans);
Print3D(&v);
Print3D(&ans);
Print3D(&v2);
const float *q = v2.GetTensorData<float>();
for (int32_t i = 0; i != static_cast<int32_t>(shape[0] * shape[1] * shape[2]);
++i) {
EXPECT_EQ(p[i], q[i]);
}
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/transpose.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/transpose.h"
#include <assert.h>
#include <algorithm>
#include <vector>
namespace sherpa_onnx {
template <typename T /*=float*/>
Ort::Value Transpose01(OrtAllocator *allocator, const Ort::Value *v) {
std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
assert(shape.size() == 3);
std::array<int64_t, 3> ans_shape{shape[1], shape[0], shape[2]};
Ort::Value ans = Ort::Value::CreateTensor<float>(allocator, ans_shape.data(),
ans_shape.size());
T *dst = ans.GetTensorMutableData<T>();
auto plane_offset = shape[1] * shape[2];
for (int64_t i = 0; i != ans_shape[0]; ++i) {
const T *src = v->GetTensorData<T>() + i * shape[2];
for (int64_t k = 0; k != ans_shape[1]; ++k) {
std::copy(src, src + shape[2], dst);
src += plane_offset;
dst += shape[2];
}
}
return ans;
}
template Ort::Value Transpose01<float>(OrtAllocator *allocator,
const Ort::Value *v);
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/transpose.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_TRANSPOSE_H_
#define SHERPA_ONNX_CSRC_TRANSPOSE_H_
#include "onnxruntime_cxx_api.h" // NOLINT
namespace sherpa_onnx {
/** Transpose a 3-D tensor from shape (B, T, C) to (T, B, C).
*
* @param allocator
* @param v A 3-D tensor of shape (B, T, C). Its dataype is T.
*
* @return Return a 3-D tensor of shape (T, B, C). Its datatype is T.
*/
template <typename T = float>
Ort::Value Transpose01(OrtAllocator *allocator, const Ort::Value *v);
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_TRANSPOSE_H_
... ...