Fangjun Kuang
Committed by GitHub

Fix c api (#2545)

This PR fixes the C API by adding proper support for durations in offline recognition results. It addresses problems introduced by a previous PR, in which the durations field was added to the C API struct but was not handled properly across all language bindings.

Key changes:

- Adds durations field handling across multiple language bindings (Swift, Kotlin, Java, C#)
- Fixes field ordering in C API struct to ensure ABI compatibility
- Updates JNI implementation to properly extract and pass durations data
... ... @@ -40,7 +40,15 @@ function(download_cppinyin)
if(NOT cppinyin_POPULATED)
message(STATUS "Downloading cppinyin ${cppinyin_URL}")
FetchContent_Populate(cppinyin)
file(REMOVE ${cppinyin_SOURCE_DIR}/CMakeLists.txt)
configure_file(
${CMAKE_SOURCE_DIR}/cmake/cppinyin.patch
${cppinyin_SOURCE_DIR}/CMakeLists.txt
COPYONLY
)
endif()
message(STATUS "cppinyin is downloaded to ${cppinyin_SOURCE_DIR}")
if(BUILD_SHARED_LIBS)
... ...
# Top-level build configuration for the cppinyin project.
#
# Sets output directories, RPATH handling, the C++ standard and MSVC warning
# suppression, optionally pulls in pybind11 / googletest, and then descends
# into the cppinyin/ subdirectory.
cmake_minimum_required(VERSION 3.12 FATAL_ERROR)

project(cppinyin)

set(CPPINYIN_VERSION "0.10")

# Archives and shared libraries go to <build>/lib, executables to <build>/bin.
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin")

# RPATH setup: binaries look for their shared libraries next to themselves.
set(CMAKE_SKIP_BUILD_RPATH FALSE)
# CMake only reads the CMAKE_-prefixed variable to initialize the
# BUILD_RPATH_USE_ORIGIN target property; the unprefixed name is ignored by
# CMake itself. The unprefixed variable is still set in case anything in the
# subdirectories reads it.
set(CMAKE_BUILD_RPATH_USE_ORIGIN TRUE)
set(BUILD_RPATH_USE_ORIGIN TRUE)
set(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE)

if(NOT APPLE)
  set(CPPINYIN_RPATH_ORIGIN "$ORIGIN")
else()
  set(CPPINYIN_RPATH_ORIGIN "@loader_path")
endif()

set(CMAKE_INSTALL_RPATH ${CPPINYIN_RPATH_ORIGIN})
set(CMAKE_BUILD_RPATH ${CPPINYIN_RPATH_ORIGIN})

option(CPPINYIN_ENABLE_TESTS "Whether to build tests" OFF)
option(CPPINYIN_BUILD_PYTHON "Whether to build Python" OFF)
option(BUILD_SHARED_LIBS "Whether to build shared libraries" ON)

if(NOT CMAKE_BUILD_TYPE)
  message(STATUS "No CMAKE_BUILD_TYPE given, default to Release")
  set(CMAKE_BUILD_TYPE Release)
endif()

# NOTE(review): CMAKE_SOURCE_DIR is the root of the whole build, not of this
# project. When cppinyin is added as a subproject these paths point into the
# parent tree -- confirm that this is intended before switching to
# PROJECT_SOURCE_DIR.
list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules)
list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake)

include(CheckCXXCompilerFlag)
# Pick the flag spelling from the compiler front end, not from the host OS:
# GCC-style drivers on Windows (e.g. MinGW) do not understand /std:c++14 and
# would otherwise fail the check below.
if(MSVC)
  check_cxx_compiler_flag("/std:c++14" CPPINYIN_COMPILER_SUPPORTS_CXX14)
else()
  check_cxx_compiler_flag("-std=c++14" CPPINYIN_COMPILER_SUPPORTS_CXX14)
endif()

if(NOT CPPINYIN_COMPILER_SUPPORTS_CXX14)
  message(FATAL_ERROR "
cppinyin requires a compiler supporting at least C++14.
If you are using GCC, please upgrade it to at least version 7.0.
If you are using Clang, please upgrade it to at least version 3.4.")
endif()

# Respect a C++ standard chosen by the user or a parent project.
if(NOT CMAKE_CXX_STANDARD)
  set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
endif()
set(CMAKE_CXX_EXTENSIONS OFF)

message(STATUS "C++ Standard version: ${CMAKE_CXX_STANDARD}")

if(CPPINYIN_BUILD_PYTHON)
  include(pybind11)
endif()

include_directories(${CMAKE_SOURCE_DIR})

# The /wdNNNN switches are MSVC-specific, so gate them on the compiler front
# end rather than on the platform.
if(MSVC)
  # disable various warnings for MSVC
  # 4244: 'initializing': conversion from 'float' to 'int32_t',
  # 4267: 'argument': conversion from 'size_t' to 'uint32_t', possible loss of data
  set(disabled_warnings
    /wd4244
    /wd4267
  )
  message(STATUS "Disabled warnings: ${disabled_warnings}")
  foreach(w IN LISTS disabled_warnings)
    string(APPEND CMAKE_CXX_FLAGS " ${w} ")
  endforeach()
endif()

if(CPPINYIN_ENABLE_TESTS)
  include(googletest)
  enable_testing()
endif()

add_subdirectory(cppinyin)
... ...
... ... @@ -52,7 +52,7 @@ function process_windows_x64() {
cd t
curl -OL https://$HF_MIRROR/csukuangfj/sherpa-onnx-wheels/resolve/main/cpu/$SHERPA_ONNX_VERSION/$windows_x64_wheel_filename
unzip $windows_x64_wheel_filename
cp -v sherpa_onnx-${SHERPA_ONNX_VERSION}.data/data/bin/*.dll ../windows
cp -v sherpa_onnx/lib/*.dll ../windows
cd ..
rm -rf t
}
... ...
... ... @@ -77,6 +77,21 @@ namespace SherpaOnnx
}
}
unsafe
{
if (impl.Durations != IntPtr.Zero)
{
float *d = (float*)impl.Durations;
_durations = new float[impl.Count];
fixed (float* f = _durations)
{
for (int k = 0; k < impl.Count; k++)
{
f[k] = d[k];
}
}
}
}
}
[StructLayout(LayoutKind.Sequential)]
... ... @@ -86,6 +101,7 @@ namespace SherpaOnnx
public IntPtr Timestamps;
public int Count;
public IntPtr Tokens;
public IntPtr Durations;
}
private String _text;
... ... @@ -96,5 +112,8 @@ namespace SherpaOnnx
private float[] _timestamps;
public float[] Timestamps => _timestamps;
private float[] _durations;
public float[] Durations => _durations;
}
}
... ...
... ... @@ -134,7 +134,7 @@ if [ ! -f $src_dir/windows-x64/sherpa-onnx-c-api.dll ]; then
curl -OL https://$HF_MIRROR/csukuangfj/sherpa-onnx-wheels/resolve/main/cpu/$SHERPA_ONNX_VERSION/$windows_x64_wheel_filename
fi
unzip $windows_x64_wheel_filename
cp -v sherpa_onnx-${SHERPA_ONNX_VERSION}.data/data/bin/*.dll ../
cp -v sherpa_onnx/lib/*.dll ../
cd ..
rm -rf wheel
... ...
... ... @@ -141,7 +141,7 @@ function windows() {
wget -q https://huggingface.co/csukuangfj/sherpa-onnx-wheels/resolve/main/cpu/$SHERPA_ONNX_VERSION/sherpa_onnx_core-${SHERPA_ONNX_VERSION}-py3-none-win_amd64.whl
unzip ./sherpa_onnx_core-${SHERPA_ONNX_VERSION}-py3-none-win_amd64.whl
cp -v sherpa_onnx_core-${SHERPA_ONNX_VERSION}.data/data/Scripts/*.dll $dst
cp -v sherpa_onnx/lib/*.dll $dst
cd ..
rm -rf t
... ... @@ -153,7 +153,7 @@ function windows() {
wget -q https://huggingface.co/csukuangfj/sherpa-onnx-wheels/resolve/main/cpu/$SHERPA_ONNX_VERSION/sherpa_onnx_core-${SHERPA_ONNX_VERSION}-py3-none-win32.whl
unzip ./sherpa_onnx_core-${SHERPA_ONNX_VERSION}-py3-none-win32.whl
cp -v sherpa_onnx_core-${SHERPA_ONNX_VERSION}.data/data/Scripts/*.dll $dst
cp -v sherpa_onnx/lib/*.dll $dst
cd ..
rm -rf t
... ...
... ... @@ -614,10 +614,6 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult {
// It is NULL if the model does not support timestamps
float *timestamps;
// Pointer to continuous memory which holds durations (in seconds) for each token
// It is NULL if the model does not support durations
float *durations;
// number of entries in timestamps
int32_t count;
... ... @@ -651,6 +647,10 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult {
// return event.
const char *event;
// Pointer to continuous memory which holds durations (in seconds) for each
// token. It is NULL if the model does not support durations.
float *durations;
} SherpaOnnxOfflineRecognizerResult;
/// Get the result of the offline stream.
... ...
... ... @@ -356,22 +356,28 @@ OfflineRecognizerResult OfflineRecognizer::GetResult(
ans.event = r->event ? r->event : "";
}
if (r->durations) {
ans.durations.resize(r->count);
std::copy(r->durations, r->durations + r->count, ans.durations.data());
}
SherpaOnnxDestroyOfflineRecognizerResult(r);
return ans;
}
std::shared_ptr<OfflineRecognizerResult> OfflineRecognizer::GetResultPtr(const OfflineStream *s) const
{
std::shared_ptr<OfflineRecognizerResult> OfflineRecognizer::GetResultPtr(
const OfflineStream *s) const {
auto r = SherpaOnnxGetOfflineStreamResult(s->Get());
OfflineRecognizerResult* ans = new OfflineRecognizerResult;
OfflineRecognizerResult *ans = new OfflineRecognizerResult;
if (r) {
ans->text = r->text;
if (r->timestamps) {
ans->timestamps.resize(r->count);
std::copy(r->timestamps, r->timestamps + r->count, ans->timestamps.data());
std::copy(r->timestamps, r->timestamps + r->count,
ans->timestamps.data());
}
ans->tokens.resize(r->count);
... ... @@ -484,8 +490,7 @@ std::shared_ptr<GeneratedAudio> OfflineTts::Generate2(
ans->samples = std::move(audio.samples);
ans->sample_rate = audio.sample_rate;
return std::shared_ptr<GeneratedAudio>(ans,
[](GeneratedAudio *p) { delete p; });
return std::shared_ptr<GeneratedAudio>(ans);
}
KeywordSpotter KeywordSpotter::Create(const KeywordSpotterConfig &config) {
... ... @@ -594,7 +599,6 @@ KeywordResult KeywordSpotter::GetResult(const OnlineStream *s) const {
return ans;
}
// Resets the keyword stream `s` by forwarding this object's `p_` and the
// stream's `Get()` value to SherpaOnnxResetKeywordStream().
void KeywordSpotter::Reset(const OnlineStream *s) const {
  auto stream = s->Get();
  SherpaOnnxResetKeywordStream(p_, stream);
}
... ... @@ -753,11 +757,10 @@ SpeechSegment VoiceActivityDetector::Front() const {
return segment;
}
std::shared_ptr<SpeechSegment> VoiceActivityDetector::FrontPtr() const
{
std::shared_ptr<SpeechSegment> VoiceActivityDetector::FrontPtr() const {
auto f = SherpaOnnxVoiceActivityDetectorFront(p_);
SpeechSegment* segment = new SpeechSegment;
SpeechSegment *segment = new SpeechSegment;
segment->start = f->start;
segment->samples = std::vector<float>{f->samples, f->samples + f->n};
... ... @@ -824,7 +827,8 @@ bool FileExists(const std::string &filename) {
// ============================================================
// For Offline Punctuation
// ============================================================
OfflinePunctuation OfflinePunctuation::Create(const OfflinePunctuationConfig &config) {
OfflinePunctuation OfflinePunctuation::Create(
const OfflinePunctuationConfig &config) {
struct SherpaOnnxOfflinePunctuationConfig c;
memset(&c, 0, sizeof(c));
c.model.ct_transformer = config.model.ct_transformer.c_str();
... ... @@ -832,12 +836,13 @@ OfflinePunctuation OfflinePunctuation::Create(const OfflinePunctuationConfig &co
c.model.debug = config.model.debug;
c.model.provider = config.model.provider.c_str();
const SherpaOnnxOfflinePunctuation *punct = SherpaOnnxCreateOfflinePunctuation(&c);
const SherpaOnnxOfflinePunctuation *punct =
SherpaOnnxCreateOfflinePunctuation(&c);
return OfflinePunctuation(punct);
}
OfflinePunctuation::OfflinePunctuation(const SherpaOnnxOfflinePunctuation *p)
: MoveOnly<OfflinePunctuation, SherpaOnnxOfflinePunctuation>(p) {}
: MoveOnly<OfflinePunctuation, SherpaOnnxOfflinePunctuation>(p) {}
void OfflinePunctuation::Destroy(const SherpaOnnxOfflinePunctuation *p) const {
SherpaOnnxDestroyOfflinePunctuation(p);
... ...
... ... @@ -319,6 +319,9 @@ struct SHERPA_ONNX_API OfflineRecognizerResult {
std::string lang;
std::string emotion;
std::string event;
// non-empty only for TDT models
std::vector<float> durations;
};
class SHERPA_ONNX_API OfflineStream
... ... @@ -349,8 +352,9 @@ class SHERPA_ONNX_API OfflineRecognizer
OfflineRecognizerResult GetResult(const OfflineStream *s) const;
std::shared_ptr<OfflineRecognizerResult> GetResultPtr(const OfflineStream *s) const;
std::shared_ptr<OfflineRecognizerResult> GetResultPtr(
const OfflineStream *s) const;
void SetConfig(const OfflineRecognizerConfig &config) const;
private:
... ...
... ... @@ -61,7 +61,8 @@ public class OfflineRecognizer {
String lang = (String) arr[3];
String emotion = (String) arr[4];
String event = (String) arr[5];
return new OfflineRecognizerResult(text, tokens, timestamps, lang, emotion, event);
float[] durations = (float[]) arr[6];
return new OfflineRecognizerResult(text, tokens, timestamps, lang, emotion, event, durations);
}
private native void delete(long ptr);
... ...
... ... @@ -9,14 +9,16 @@ public class OfflineRecognizerResult {
private final String lang;
private final String emotion;
private final String event;
private final float[] durations;
public OfflineRecognizerResult(String text, String[] tokens, float[] timestamps, String lang, String emotion, String event) {
public OfflineRecognizerResult(String text, String[] tokens, float[] timestamps, String lang, String emotion, String event, float[] durations) {
this.text = text;
this.tokens = tokens;
this.timestamps = timestamps;
this.lang = lang;
this.emotion = emotion;
this.event = event;
this.durations = durations;
}
public String getText() {
... ... @@ -42,4 +44,8 @@ public class OfflineRecognizerResult {
/**
 * Returns the event string that was passed to the constructor.
 *
 * @return the stored {@code event} value
 */
public String getEvent() {
  return this.event;
}
/**
 * Returns the durations array that was passed to the constructor.
 *
 * <p>The stored array itself is returned; no copy is made.
 *
 * @return the stored {@code durations} array
 */
public float[] getDurations() {
  return this.durations;
}
}
... ...
... ... @@ -509,8 +509,9 @@ Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_getResult(JNIEnv *env,
// [3]: lang, jstring
// [4]: emotion, jstring
// [5]: event, jstring
// [6]: durations, array of float
jobjectArray obj_arr = (jobjectArray)env->NewObjectArray(
6, env->FindClass("java/lang/Object"), nullptr);
7, env->FindClass("java/lang/Object"), nullptr);
jstring text = env->NewStringUTF(result.text.c_str());
env->SetObjectArrayElement(obj_arr, 0, text);
... ... @@ -543,5 +544,12 @@ Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_getResult(JNIEnv *env,
env->SetObjectArrayElement(obj_arr, 5,
env->NewStringUTF(result.event.c_str()));
// [6]: durations, array of float
jfloatArray durations_arr = env->NewFloatArray(result.durations.size());
env->SetFloatArrayRegion(durations_arr, 0, result.durations.size(),
result.durations.data());
env->SetObjectArrayElement(obj_arr, 6, durations_arr);
return obj_arr;
}
... ...
... ... @@ -9,6 +9,9 @@ data class OfflineRecognizerResult(
val lang: String,
val emotion: String,
val event: String,
// valid only for TDT models
val durations: FloatArray,
)
data class OfflineTransducerModelConfig(
... ... @@ -139,13 +142,15 @@ class OfflineRecognizer(
val lang = objArray[3] as String
val emotion = objArray[4] as String
val event = objArray[5] as String
val durations = objArray[6] as FloatArray
return OfflineRecognizerResult(
text = text,
tokens = tokens,
timestamps = timestamps,
lang = lang,
emotion = emotion,
event = event
event = event,
durations = durations,
)
}
... ...
... ... @@ -539,6 +539,11 @@ class SherpaOnnxOfflineRecongitionResult {
return (0..<result.pointee.count).map { p[Int($0)] }
}()
// Lazily copies the C result's `durations` buffer into a Swift array.
// Yields an empty array when the `durations` pointer is nil; otherwise reads
// `count` elements from it.
private lazy var _durations: [Float] = {
  guard let base = result.pointee.durations else { return [] }
  let n = Int(result.pointee.count)
  return (0..<n).map { base[$0] }
}()
private lazy var _lang: String = {
guard let cstr = result.pointee.lang else { return "" }
return String(cString: cstr)
... ... @@ -561,6 +566,9 @@ class SherpaOnnxOfflineRecongitionResult {
var count: Int { Int(result.pointee.count) }
var timestamps: [Float] { _timestamps }
// Non-empty for TDT models. Empty for all other non-TDT models
var durations: [Float] { _durations }
// For SenseVoice models, it can be zh, en, ja, yue, ko
// where zh is for Chinese
// en is for English
... ...