speaker-embedding-extractor.cc
4.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
// sherpa-onnx/jni/speaker-embedding-extractor.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include <memory>
#include <string>
#include <vector>

#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
#include "sherpa-onnx/jni/common.h"
namespace sherpa_onnx {
static SpeakerEmbeddingExtractorConfig GetSpeakerEmbeddingExtractorConfig(
JNIEnv *env, jobject config) {
SpeakerEmbeddingExtractorConfig ans;
jclass cls = env->GetObjectClass(config);
jfieldID fid = env->GetFieldID(cls, "model", "Ljava/lang/String;");
jstring s = (jstring)env->GetObjectField(config, fid);
const char *p = env->GetStringUTFChars(s, nullptr);
ans.model = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(cls, "numThreads", "I");
ans.num_threads = env->GetIntField(config, fid);
fid = env->GetFieldID(cls, "debug", "Z");
ans.debug = env->GetBooleanField(config, fid);
fid = env->GetFieldID(cls, "provider", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.provider = p;
env->ReleaseStringUTFChars(s, p);
return ans;
}
} // namespace sherpa_onnx
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_newFromAsset(
    JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
  // Creates a SpeakerEmbeddingExtractor and returns its address as a jlong
  // handle (0 on failure to obtain the asset manager).
  //
  // When __ANDROID_API__ >= 9 the native AAssetManager is obtained from
  // `asset_manager` and handed to the extractor; otherwise `asset_manager` is
  // ignored and only the config is used.
#if __ANDROID_API__ >= 9
  AAssetManager *asset_mgr = AAssetManager_fromJava(env, asset_manager);
  if (asset_mgr == nullptr) {
    SHERPA_ONNX_LOGE("Failed to get asset manager: %p", asset_mgr);
    return 0;
  }
#endif

  auto cfg = sherpa_onnx::GetSpeakerEmbeddingExtractorConfig(env, _config);
  SHERPA_ONNX_LOGE("new config:\n%s", cfg.ToString().c_str());

#if __ANDROID_API__ >= 9
  auto *handle = new sherpa_onnx::SpeakerEmbeddingExtractor(asset_mgr, cfg);
#else
  auto *handle = new sherpa_onnx::SpeakerEmbeddingExtractor(cfg);
#endif

  // Ownership passes to the Java side; freed by ..._delete().
  return reinterpret_cast<jlong>(handle);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_newFromFile(
    JNIEnv *env, jobject /*obj*/, jobject _config) {
  // Creates a SpeakerEmbeddingExtractor from a config whose model is read
  // from the filesystem, and returns its address as a jlong handle.
  auto cfg = sherpa_onnx::GetSpeakerEmbeddingExtractorConfig(env, _config);
  SHERPA_ONNX_LOGE("newFromFile config:\n%s", cfg.ToString().c_str());

  // Validation problems are only logged here; construction is still attempted.
  const bool valid = cfg.Validate();
  if (!valid) {
    SHERPA_ONNX_LOGE("Errors found in config!");
  }

  auto *handle = new sherpa_onnx::SpeakerEmbeddingExtractor(cfg);
  return reinterpret_cast<jlong>(handle);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_delete(JNIEnv * /*env*/,
                                                            jobject /*obj*/,
                                                            jlong ptr) {
  // Destroys the SpeakerEmbeddingExtractor whose address is `ptr`
  // (presumably a handle returned by newFromAsset()/newFromFile()).
  auto *extractor =
      reinterpret_cast<sherpa_onnx::SpeakerEmbeddingExtractor *>(ptr);
  delete extractor;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_createStream(
    JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) {
  // Asks the extractor at `ptr` for a new OnlineStream and hands the raw
  // pointer back to Java as a jlong.
  auto *extractor =
      reinterpret_cast<sherpa_onnx::SpeakerEmbeddingExtractor *>(ptr);

  // The unique_ptr gives up ownership here; the caller must free the stream
  // via Java_com_k2fsa_sherpa_onnx_OnlineStream_delete() in ./online-stream.cc.
  sherpa_onnx::OnlineStream *raw_stream = extractor->CreateStream().release();
  return reinterpret_cast<jlong>(raw_stream);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_isReady(JNIEnv * /*env*/,
                                                             jobject /*obj*/,
                                                             jlong ptr,
                                                             jlong stream_ptr) {
  // Forwards to SpeakerEmbeddingExtractor::IsReady() for the given stream.
  auto *s = reinterpret_cast<sherpa_onnx::OnlineStream *>(stream_ptr);
  return reinterpret_cast<sherpa_onnx::SpeakerEmbeddingExtractor *>(ptr)
      ->IsReady(s);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jfloatArray JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_compute(JNIEnv *env,
                                                             jobject /*obj*/,
                                                             jlong ptr,
                                                             jlong stream_ptr) {
  // Runs the extractor at `ptr` on the stream at `stream_ptr` and returns the
  // resulting embedding as a new Java float[].
  //
  // Returns null if the Java array cannot be allocated. In that case JNI has
  // already raised an OutOfMemoryError, which is thrown when this returns.
  auto extractor =
      reinterpret_cast<sherpa_onnx::SpeakerEmbeddingExtractor *>(ptr);
  auto stream = reinterpret_cast<sherpa_onnx::OnlineStream *>(stream_ptr);

  std::vector<float> embedding = extractor->Compute(stream);

  // JNI array lengths are jsize; make the narrowing from size_t explicit.
  const auto n = static_cast<jsize>(embedding.size());

  jfloatArray embedding_arr = env->NewFloatArray(n);
  if (embedding_arr == nullptr) {
    // Writing into a null array would crash; let the pending exception
    // propagate to Java instead.
    return nullptr;
  }

  env->SetFloatArrayRegion(embedding_arr, 0, n, embedding.data());
  return embedding_arr;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jint JNICALL Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_dim(
    JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) {
  // Returns SpeakerEmbeddingExtractor::Dim() of the extractor at `ptr`.
  return reinterpret_cast<sherpa_onnx::SpeakerEmbeddingExtractor *>(ptr)->Dim();
}