正在显示
11 个修改的文件
包含
110 行增加
和
0 行删除
| @@ -18,6 +18,7 @@ function(download_asio) | @@ -18,6 +18,7 @@ function(download_asio) | ||
| 18 | foreach(f IN LISTS possible_file_locations) | 18 | foreach(f IN LISTS possible_file_locations) |
| 19 | if(EXISTS ${f}) | 19 | if(EXISTS ${f}) |
| 20 | set(asio_URL "file://${f}") | 20 | set(asio_URL "file://${f}") |
| 21 | + set(asio_URL2) | ||
| 21 | break() | 22 | break() |
| 22 | endif() | 23 | endif() |
| 23 | endforeach() | 24 | endforeach() |
| @@ -18,6 +18,7 @@ function(download_googltest) | @@ -18,6 +18,7 @@ function(download_googltest) | ||
| 18 | foreach(f IN LISTS possible_file_locations) | 18 | foreach(f IN LISTS possible_file_locations) |
| 19 | if(EXISTS ${f}) | 19 | if(EXISTS ${f}) |
| 20 | set(googletest_URL "file://${f}") | 20 | set(googletest_URL "file://${f}") |
| 21 | + set(googletest_URL2) | ||
| 21 | break() | 22 | break() |
| 22 | endif() | 23 | endif() |
| 23 | endforeach() | 24 | endforeach() |
| @@ -22,6 +22,7 @@ function(download_kaldi_native_fbank) | @@ -22,6 +22,7 @@ function(download_kaldi_native_fbank) | ||
| 22 | foreach(f IN LISTS possible_file_locations) | 22 | foreach(f IN LISTS possible_file_locations) |
| 23 | if(EXISTS ${f}) | 23 | if(EXISTS ${f}) |
| 24 | set(kaldi_native_fbank_URL "file://${f}") | 24 | set(kaldi_native_fbank_URL "file://${f}") |
| 25 | + set(kaldi_native_fbank_URL2 ) | ||
| 25 | break() | 26 | break() |
| 26 | endif() | 27 | endif() |
| 27 | endforeach() | 28 | endforeach() |
| @@ -78,6 +78,7 @@ function(download_onnxruntime) | @@ -78,6 +78,7 @@ function(download_onnxruntime) | ||
| 78 | foreach(f IN LISTS possible_file_locations) | 78 | foreach(f IN LISTS possible_file_locations) |
| 79 | if(EXISTS ${f}) | 79 | if(EXISTS ${f}) |
| 80 | set(onnxruntime_URL "file://${f}") | 80 | set(onnxruntime_URL "file://${f}") |
| 81 | + set(onnxruntime_URL2) | ||
| 81 | break() | 82 | break() |
| 82 | endif() | 83 | endif() |
| 83 | endforeach() | 84 | endforeach() |
| @@ -19,6 +19,7 @@ function(download_portaudio) | @@ -19,6 +19,7 @@ function(download_portaudio) | ||
| 19 | foreach(f IN LISTS possible_file_locations) | 19 | foreach(f IN LISTS possible_file_locations) |
| 20 | if(EXISTS ${f}) | 20 | if(EXISTS ${f}) |
| 21 | set(portaudio_URL "file://${f}") | 21 | set(portaudio_URL "file://${f}") |
| 22 | + set(portaudio_URL2) | ||
| 22 | break() | 23 | break() |
| 23 | endif() | 24 | endif() |
| 24 | endforeach() | 25 | endforeach() |
| @@ -18,6 +18,7 @@ function(download_pybind11) | @@ -18,6 +18,7 @@ function(download_pybind11) | ||
| 18 | foreach(f IN LISTS possible_file_locations) | 18 | foreach(f IN LISTS possible_file_locations) |
| 19 | if(EXISTS ${f}) | 19 | if(EXISTS ${f}) |
| 20 | set(pybind11_URL "file://${f}") | 20 | set(pybind11_URL "file://${f}") |
| 21 | + set(pybind11_URL2) | ||
| 21 | break() | 22 | break() |
| 22 | endif() | 23 | endif() |
| 23 | endforeach() | 24 | endforeach() |
| @@ -19,6 +19,7 @@ function(download_websocketpp) | @@ -19,6 +19,7 @@ function(download_websocketpp) | ||
| 19 | foreach(f IN LISTS possible_file_locations) | 19 | foreach(f IN LISTS possible_file_locations) |
| 20 | if(EXISTS ${f}) | 20 | if(EXISTS ${f}) |
| 21 | set(websocketpp_URL "file://${f}") | 21 | set(websocketpp_URL "file://${f}") |
| 22 | + set(websocketpp_URL2) | ||
| 22 | break() | 23 | break() |
| 23 | endif() | 24 | endif() |
| 24 | endforeach() | 25 | endforeach() |
| @@ -20,6 +20,7 @@ set(sources | @@ -20,6 +20,7 @@ set(sources | ||
| 20 | resample.cc | 20 | resample.cc |
| 21 | symbol-table.cc | 21 | symbol-table.cc |
| 22 | text-utils.cc | 22 | text-utils.cc |
| 23 | + transpose.cc | ||
| 23 | unbind.cc | 24 | unbind.cc |
| 24 | wave-reader.cc | 25 | wave-reader.cc |
| 25 | ) | 26 | ) |
| @@ -120,6 +121,7 @@ endif() | @@ -120,6 +121,7 @@ endif() | ||
| 120 | if(SHERPA_ONNX_ENABLE_TESTS) | 121 | if(SHERPA_ONNX_ENABLE_TESTS) |
| 121 | set(sherpa_onnx_test_srcs | 122 | set(sherpa_onnx_test_srcs |
| 122 | cat-test.cc | 123 | cat-test.cc |
| 124 | + transpose-test.cc | ||
| 123 | unbind-test.cc | 125 | unbind-test.cc |
| 124 | ) | 126 | ) |
| 125 | 127 |
sherpa-onnx/csrc/transpose-test.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/transpose-test.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/transpose.h" | ||
| 6 | + | ||
| 7 | +#include <numeric> | ||
| 8 | + | ||
| 9 | +#include "gtest/gtest.h" | ||
| 10 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +TEST(Tranpose, Tranpose01) { | ||
| 15 | + Ort::AllocatorWithDefaultOptions allocator; | ||
| 16 | + std::array<int64_t, 3> shape{3, 2, 5}; | ||
| 17 | + Ort::Value v = | ||
| 18 | + Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size()); | ||
| 19 | + float *p = v.GetTensorMutableData<float>(); | ||
| 20 | + | ||
| 21 | + std::iota(p, p + shape[0] * shape[1] * shape[2], 0); | ||
| 22 | + | ||
| 23 | + auto ans = Transpose01(allocator, &v); | ||
| 24 | + auto v2 = Transpose01(allocator, &ans); | ||
| 25 | + | ||
| 26 | + Print3D(&v); | ||
| 27 | + Print3D(&ans); | ||
| 28 | + Print3D(&v2); | ||
| 29 | + | ||
| 30 | + const float *q = v2.GetTensorData<float>(); | ||
| 31 | + | ||
| 32 | + for (int32_t i = 0; i != static_cast<int32_t>(shape[0] * shape[1] * shape[2]); | ||
| 33 | + ++i) { | ||
| 34 | + EXPECT_EQ(p[i], q[i]); | ||
| 35 | + } | ||
| 36 | +} | ||
| 37 | + | ||
| 38 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/transpose.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/transpose.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/transpose.h" | ||
| 6 | + | ||
| 7 | +#include <assert.h> | ||
| 8 | + | ||
| 9 | +#include <algorithm> | ||
| 10 | +#include <vector> | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +template <typename T /*=float*/> | ||
| 15 | +Ort::Value Transpose01(OrtAllocator *allocator, const Ort::Value *v) { | ||
| 16 | + std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape(); | ||
| 17 | + assert(shape.size() == 3); | ||
| 18 | + | ||
| 19 | + std::array<int64_t, 3> ans_shape{shape[1], shape[0], shape[2]}; | ||
| 20 | + Ort::Value ans = Ort::Value::CreateTensor<float>(allocator, ans_shape.data(), | ||
| 21 | + ans_shape.size()); | ||
| 22 | + | ||
| 23 | + T *dst = ans.GetTensorMutableData<T>(); | ||
| 24 | + auto plane_offset = shape[1] * shape[2]; | ||
| 25 | + | ||
| 26 | + for (int64_t i = 0; i != ans_shape[0]; ++i) { | ||
| 27 | + const T *src = v->GetTensorData<T>() + i * shape[2]; | ||
| 28 | + for (int64_t k = 0; k != ans_shape[1]; ++k) { | ||
| 29 | + std::copy(src, src + shape[2], dst); | ||
| 30 | + src += plane_offset; | ||
| 31 | + dst += shape[2]; | ||
| 32 | + } | ||
| 33 | + } | ||
| 34 | + | ||
| 35 | + return ans; | ||
| 36 | +} | ||
| 37 | + | ||
| 38 | +template Ort::Value Transpose01<float>(OrtAllocator *allocator, | ||
| 39 | + const Ort::Value *v); | ||
| 40 | + | ||
| 41 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/transpose.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/transpose.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_TRANSPOSE_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_TRANSPOSE_H_ | ||
| 6 | + | ||
| 7 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 8 | + | ||
| 9 | +namespace sherpa_onnx { | ||
| 10 | +/** Transpose a 3-D tensor from shape (B, T, C) to (T, B, C). | ||
| 11 | + * | ||
| 12 | + * @param allocator | ||
| 13 | + * @param v A 3-D tensor of shape (B, T, C). Its dataype is T. | ||
| 14 | + * | ||
| 15 | + * @return Return a 3-D tensor of shape (T, B, C). Its datatype is T. | ||
| 16 | + */ | ||
| 17 | +template <typename T = float> | ||
| 18 | +Ort::Value Transpose01(OrtAllocator *allocator, const Ort::Value *v); | ||
| 19 | + | ||
| 20 | +} // namespace sherpa_onnx | ||
| 21 | + | ||
| 22 | +#endif // SHERPA_ONNX_CSRC_TRANSPOSE_H_ |
-
请 注册 或 登录 后发表评论