// transpose.h (970 bytes)
// sherpa-onnx/csrc/transpose.h
//
// Copyright (c)  2023  Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_TRANSPOSE_H_
#define SHERPA_ONNX_CSRC_TRANSPOSE_H_

#include "onnxruntime_cxx_api.h"  // NOLINT

namespace sherpa_onnx {
/** Transpose a 3-D tensor from shape (B, T, C) to (T, B, C), i.e., swap
 *  axis 0 and axis 1 while leaving axis 2 in place.
 *
 * Only the declaration lives in this header; the definition is not visible
 * here. NOTE(review): presumably the matching source file provides it
 * together with explicit instantiations for the supported element types --
 * confirm before using a type other than the default.
 *
 * @tparam type Element type of the tensor. Defaults to float.
 * @param allocator Allocator passed to the implementation. NOTE(review):
 *                  presumably used to allocate the returned tensor --
 *                  confirm against the definition.
 * @param v Pointer to a 3-D tensor of shape (B, T, C) whose element type is
 *          `type`. NOTE(review): presumably must not be null -- confirm.
 *
 * @return A 3-D tensor of shape (T, B, C) whose element type is `type`.
 */
template <typename type = float>
Ort::Value Transpose01(OrtAllocator *allocator, const Ort::Value *v);

/** Transpose a 3-D tensor from shape (B, T, C) to (B, C, T), i.e., swap
 *  axis 1 and axis 2 while leaving axis 0 in place.
 *
 * Only the declaration lives in this header; the definition is not visible
 * here. NOTE(review): presumably the matching source file provides it
 * together with explicit instantiations for the supported element types --
 * confirm before using a type other than the default.
 *
 * @tparam type Element type of the tensor. Defaults to float.
 * @param allocator Allocator passed to the implementation. NOTE(review):
 *                  presumably used to allocate the returned tensor --
 *                  confirm against the definition.
 * @param v Pointer to a 3-D tensor of shape (B, T, C) whose element type is
 *          `type`. NOTE(review): presumably must not be null -- confirm.
 *
 * @return A 3-D tensor of shape (B, C, T) whose element type is `type`.
 */
template <typename type = float>
Ort::Value Transpose12(OrtAllocator *allocator, const Ort::Value *v);

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_TRANSPOSE_H_