Fangjun Kuang

Fix modified beam search for iOS and android (#76)

* Use Int type for sampling rate

* Fix swift

* Fix iOS
Makefile
*.jar
hs_err_pid*.log
... ...
... ... @@ -4,7 +4,7 @@ import android.content.res.AssetManager
fun main() {
var featConfig = FeatureConfig(
- sampleRate = 16000.0f,
+ sampleRate = 16000,
featureDim = 80,
)
... ... @@ -13,7 +13,7 @@ fun main() {
decoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/decoder-epoch-99-avg-1.onnx",
joiner = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/joiner-epoch-99-avg-1.onnx",
tokens = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt",
- numThreads = 4,
+ numThreads = 1,
debug = false,
)
... ... @@ -24,22 +24,31 @@ fun main() {
featConfig = featConfig,
endpointConfig = endpointConfig,
enableEndpoint = true,
decodingMethod = "greedy_search",
maxActivePaths = 4,
)
var model = SherpaOnnx(
assetManager = AssetManager(),
config = config,
)
var samples = WaveReader.readWave(
assetManager = AssetManager(),
filename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/1089-134686-0001.wav",
)
- model.decodeSamples(samples!!)
+ model.acceptWaveform(samples!!, sampleRate=16000)
+ while (model.isReady()) {
+     model.decode()
+ }
var tail_paddings = FloatArray(8000) // 0.5 seconds
- model.decodeSamples(tail_paddings)
+ model.acceptWaveform(tail_paddings, sampleRate=16000)
+ model.inputFinished()
+ while (model.isReady()) {
+     model.decode()
+ }
println("results: ${model.text}")
}
... ...
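The hunk above captures the shape of the new API: feeding audio (acceptWaveform) is decoupled from decoding (decode), and the caller drains the stream while isReady() reports that enough feature frames are buffered. A minimal sketch of that pattern as a reusable helper, assuming the Kotlin API from this commit; the helper name decodeAll is hypothetical:

    fun SherpaOnnx.decodeAll(samples: FloatArray, sampleRate: Int = 16000) {
        // Feed audio first; nothing is decoded until the caller asks for it.
        acceptWaveform(samples, sampleRate = sampleRate)
        // Drain the stream: each decode() call processes one batch of
        // feature frames; isReady() stays true while frames remain.
        while (isReady()) {
            decode()
        }
    }

After the last chunk, the example also feeds 0.5 s of zero padding (8000 samples at 16 kHz) and calls inputFinished() so the trailing frames are flushed through the model.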
... ... @@ -38,3 +38,4 @@ log.txt
tags
run-decode-file-python.sh
android/SherpaOnnx/app/src/main/assets/
+ *.ncnn.*
... ...
... ... @@ -121,7 +121,10 @@ class MainActivity : AppCompatActivity() {
val ret = audioRecord?.read(buffer, 0, buffer.size)
if (ret != null && ret > 0) {
val samples = FloatArray(ret) { buffer[it] / 32768.0f }
- model.decodeSamples(samples)
+ model.acceptWaveform(samples, sampleRate=16000)
+ while (model.isReady()) {
+     model.decode()
+ }
runOnUiThread {
val isEndpoint = model.isEndpoint()
val text = model.text
... ... @@ -177,33 +180,17 @@ class MainActivity : AppCompatActivity() {
val type = 0
println("Select model type ${type}")
val config = OnlineRecognizerConfig(
- featConfig = getFeatureConfig(sampleRate = 16000.0f, featureDim = 80),
+ featConfig = getFeatureConfig(sampleRate = 16000, featureDim = 80),
modelConfig = getModelConfig(type = type)!!,
endpointConfig = getEndpointConfig(),
- enableEndpoint = true
+ enableEndpoint = true,
+ decodingMethod = "greedy_search",
+ maxActivePaths = 4,
)
model = SherpaOnnx(
assetManager = application.assets,
config = config,
)
- /*
- println("reading samples")
- val samples = WaveReader.readWave(
-     assetManager = application.assets,
-     // filename = "sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/0.wav",
-     filename = "sherpa-onnx-lstm-zh-2023-02-20/test_wavs/0.wav",
-     // filename="sherpa-onnx-lstm-en-2023-02-17/test_wavs/1089-134686-0001.wav"
- )
- println("samples read done!")
- model.decodeSamples(samples!!)
- val tailPaddings = FloatArray(8000) // 0.5 seconds
- model.decodeSamples(tailPaddings)
- println("result is: ${model.text}")
- model.reset()
- */
}
}
... ...
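For reference, the conversion inside the recording loop above maps 16-bit PCM from AudioRecord into the [-1, 1] float range that acceptWaveform expects; dividing by 32768, the magnitude of the most negative 16-bit value, guarantees the result never leaves that range. A standalone sketch (function name hypothetical):

    fun pcm16ToFloat(buffer: ShortArray, validLength: Int): FloatArray =
        FloatArray(validLength) { i -> buffer[i] / 32768.0f }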
... ... @@ -24,7 +24,7 @@ data class OnlineTransducerModelConfig(
)
data class FeatureConfig(
- var sampleRate: Float = 16000.0f,
+ var sampleRate: Int = 16000,
var featureDim: Int = 80,
)
... ... @@ -32,7 +32,9 @@ data class OnlineRecognizerConfig(
var featConfig: FeatureConfig = FeatureConfig(),
var modelConfig: OnlineTransducerModelConfig,
var endpointConfig: EndpointConfig = EndpointConfig(),
- var enableEndpoint: Boolean,
+ var enableEndpoint: Boolean = true,
+ var decodingMethod: String = "greedy_search",
+ var maxActivePaths: Int = 4,
)
class SherpaOnnx(
... ... @@ -49,12 +51,14 @@ class SherpaOnnx(
}
- fun decodeSamples(samples: FloatArray) =
-     decodeSamples(ptr, samples, sampleRate = config.featConfig.sampleRate)
+ fun acceptWaveform(samples: FloatArray, sampleRate: Int) =
+     acceptWaveform(ptr, samples, sampleRate)
fun inputFinished() = inputFinished(ptr)
fun reset() = reset(ptr)
+ fun decode() = decode(ptr)
fun isEndpoint(): Boolean = isEndpoint(ptr)
+ fun isReady(): Boolean = isReady(ptr)
val text: String
get() = getText(ptr)
... ... @@ -66,11 +70,13 @@ class SherpaOnnx(
config: OnlineRecognizerConfig,
): Long
- private external fun decodeSamples(ptr: Long, samples: FloatArray, sampleRate: Float)
+ private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int)
private external fun inputFinished(ptr: Long)
private external fun getText(ptr: Long): String
private external fun reset(ptr: Long)
+ private external fun decode(ptr: Long)
private external fun isEndpoint(ptr: Long): Boolean
+ private external fun isReady(ptr: Long): Boolean
companion object {
init {
... ... @@ -79,7 +85,7 @@ class SherpaOnnx(
}
}
- fun getFeatureConfig(sampleRate: Float, featureDim: Int): FeatureConfig {
+ fun getFeatureConfig(sampleRate: Int, featureDim: Int): FeatureConfig {
return FeatureConfig(sampleRate=sampleRate, featureDim=featureDim)
}
... ...
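With the new defaults in OnlineRecognizerConfig, switching a Kotlin client to modified beam search only means overriding decodingMethod. A sketch, reusing the model paths from the example above (the encoder path is inferred from the same naming pattern and is illustrative):

    val config = OnlineRecognizerConfig(
        featConfig = FeatureConfig(sampleRate = 16000, featureDim = 80),
        modelConfig = OnlineTransducerModelConfig(
            encoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/encoder-epoch-99-avg-1.onnx",
            decoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/decoder-epoch-99-avg-1.onnx",
            joiner = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/joiner-epoch-99-avg-1.onnx",
            tokens = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt",
            numThreads = 1,
        ),
        decodingMethod = "modified_beam_search",
        maxActivePaths = 4, // hypotheses kept per frame; consulted only by beam search
    )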
... ... @@ -23,10 +23,10 @@ extension AVAudioPCMBuffer {
class ViewController: UIViewController {
@IBOutlet weak var resultLabel: UILabel!
@IBOutlet weak var recordBtn: UIButton!
var audioEngine: AVAudioEngine? = nil
var recognizer: SherpaOnnxRecognizer! = nil
/// It saves the decoded results so far
var sentences: [String] = [] {
didSet {
... ... @@ -42,7 +42,7 @@ class ViewController: UIViewController {
if sentences.isEmpty {
return "0: \(lastSentence.lowercased())"
}
let start = max(sentences.count - maxSentence, 0)
if lastSentence.isEmpty {
return sentences.enumerated().map { (index, s) in "\(index): \(s.lowercased())" }[start...]
... ... @@ -52,23 +52,23 @@ class ViewController: UIViewController {
.joined(separator: "\n") + "\n\(sentences.count): \(lastSentence.lowercased())"
}
}
func updateLabel() {
DispatchQueue.main.async {
self.resultLabel.text = self.results
}
}
override func viewDidLoad() {
super.viewDidLoad()
// Do any additional setup after loading the view.
resultLabel.text = "ASR with Next-gen Kaldi\n\nSee https://github.com/k2-fsa/sherpa-onnx\n\nPress the Start button to run!"
recordBtn.setTitle("Start", for: .normal)
initRecognizer()
initRecorder()
}
@IBAction func onRecordBtnClick(_ sender: UIButton) {
if recordBtn.currentTitle == "Start" {
startRecorder()
... ... @@ -78,30 +78,32 @@ class ViewController: UIViewController {
recordBtn.setTitle("Start", for: .normal)
}
}
func initRecognizer() {
// Please select one model that is best suitable for you.
//
// You can also modify Model.swift to add new pre-trained models from
// https://k2-fsa.github.io/sherpa/ncnn/pretrained_models/index.html
let modelConfig = getBilingualStreamZhEnZipformer20230220()
let featConfig = sherpaOnnxFeatureConfig(
sampleRate: 16000,
featureDim: 80)
var config = sherpaOnnxOnlineRecognizerConfig(
featConfig: featConfig,
modelConfig: modelConfig,
enableEndpoint: true,
rule1MinTrailingSilence: 2.4,
rule2MinTrailingSilence: 0.8,
- rule3MinUtteranceLength: 30
+ rule3MinUtteranceLength: 30,
+ decodingMethod: "greedy_search",
+ maxActivePaths: 4
)
recognizer = SherpaOnnxRecognizer(config: &config)
}
func initRecorder() {
print("init recorder")
audioEngine = AVAudioEngine()
... ... @@ -112,9 +114,9 @@ class ViewController: UIViewController {
commonFormat: .pcmFormatFloat32,
sampleRate: 16000, channels: 1,
interleaved: false)!
let converter = AVAudioConverter(from: inputFormat!, to: outputFormat)!
inputNode!.installTap(
onBus: bus,
bufferSize: 1024,
... ... @@ -122,34 +124,34 @@ class ViewController: UIViewController {
) {
(buffer: AVAudioPCMBuffer, when: AVAudioTime) in
var newBufferAvailable = true
let inputCallback: AVAudioConverterInputBlock = {
inNumPackets, outStatus in
if newBufferAvailable {
outStatus.pointee = .haveData
newBufferAvailable = false
return buffer
} else {
outStatus.pointee = .noDataNow
return nil
}
}
let convertedBuffer = AVAudioPCMBuffer(
pcmFormat: outputFormat,
frameCapacity:
AVAudioFrameCount(outputFormat.sampleRate)
* buffer.frameLength
/ AVAudioFrameCount(buffer.format.sampleRate))!
var error: NSError?
let _ = converter.convert(
to: convertedBuffer,
error: &error, withInputFrom: inputCallback)
// TODO(fangjun): Handle status != haveData
let array = convertedBuffer.array()
if !array.isEmpty {
self.recognizer.acceptWaveform(samples: array)
... ... @@ -158,13 +160,13 @@ class ViewController: UIViewController {
}
let isEndpoint = self.recognizer.isEndpoint()
let text = self.recognizer.getResult().text
if !text.isEmpty && self.lastSentence != text {
self.lastSentence = text
self.updateLabel()
print(text)
}
if isEndpoint {
if !text.isEmpty {
let tmp = self.lastSentence
... ... @@ -175,13 +177,13 @@ class ViewController: UIViewController {
}
}
}
}
func startRecorder() {
lastSentence = ""
sentences = []
do {
try self.audioEngine?.start()
} catch let error as NSError {
... ... @@ -189,7 +191,7 @@ class ViewController: UIViewController {
}
print("started")
}
func stopRecorder() {
audioEngine?.stop()
print("stopped")
... ...
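One detail worth noting in the tap callback above: the converted buffer is sized from the resampling ratio, outputFrames = inputFrames * outputRate / inputRate. The same arithmetic sketched in Kotlin (function name hypothetical):

    fun outputFrameCapacity(inputFrames: Long, inputRate: Double, outputRate: Double): Long =
        (outputRate * inputFrames / inputRate).toLong()

For example, a 1024-frame tap buffer captured at 48 kHz needs a capacity of about 341 frames at 16 kHz.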
... ... @@ -76,7 +76,7 @@ SherpaOnnxOnlineStream *CreateOnlineStream(
void DestoryOnlineStream(SherpaOnnxOnlineStream *stream) { delete stream; }
- void AcceptWaveform(SherpaOnnxOnlineStream *stream, float sample_rate,
+ void AcceptWaveform(SherpaOnnxOnlineStream *stream, int32_t sample_rate,
const float *samples, int32_t n) {
stream->impl->AcceptWaveform(sample_rate, samples, n);
}
... ...
... ... @@ -120,7 +120,7 @@ void DestoryOnlineStream(SherpaOnnxOnlineStream *stream);
/// @param samples A pointer to a 1-D array containing audio samples.
/// The range of samples has to be normalized to [-1, 1].
/// @param n Number of elements in the samples array.
- void AcceptWaveform(SherpaOnnxOnlineStream *stream, float sample_rate,
+ void AcceptWaveform(SherpaOnnxOnlineStream *stream, int32_t sample_rate,
const float *samples, int32_t n);
/// Return 1 if there are enough number of feature frames for decoding.
... ...
... ... @@ -48,7 +48,7 @@ class FeatureExtractor::Impl {
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
}
- void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n) {
+ void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
std::lock_guard<std::mutex> lock(mutex_);
fbank_->AcceptWaveform(sampling_rate, waveform, n);
}
... ... @@ -107,7 +107,7 @@ FeatureExtractor::FeatureExtractor(const FeatureExtractorConfig &config /*={}*/)
FeatureExtractor::~FeatureExtractor() = default;
- void FeatureExtractor::AcceptWaveform(float sampling_rate,
+ void FeatureExtractor::AcceptWaveform(int32_t sampling_rate,
const float *waveform, int32_t n) {
impl_->AcceptWaveform(sampling_rate, waveform, n);
}
... ...
... ... @@ -14,7 +14,7 @@
namespace sherpa_onnx {
struct FeatureExtractorConfig {
- float sampling_rate = 16000;
+ int32_t sampling_rate = 16000;
int32_t feature_dim = 80;
int32_t max_feature_vectors = -1;
... ... @@ -34,7 +34,7 @@ class FeatureExtractor {
@param waveform Pointer to a 1-D array of size n
@param n Number of entries in waveform
*/
- void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n);
+ void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n);
/**
* InputFinished() tells the class you won't be providing any
... ...
... ... @@ -112,7 +112,7 @@ for a list of pre-trained models to download.
param.suggestedLatency = info->defaultLowInputLatency;
param.hostApiSpecificStreamInfo = nullptr;
- const float sample_rate = 16000;
+ float sample_rate = 16000;
PaStream *stream;
PaError err =
... ...
... ... @@ -61,7 +61,7 @@ for a list of pre-trained models to download.
sherpa_onnx::OnlineRecognizer recognizer(config);
- float expected_sampling_rate = config.feat_config.sampling_rate;
+ int32_t expected_sampling_rate = config.feat_config.sampling_rate;
bool is_ok = false;
std::vector<float> samples =
... ... @@ -72,7 +72,7 @@ for a list of pre-trained models to download.
return -1;
}
- float duration = samples.size() / expected_sampling_rate;
+ float duration = samples.size() / static_cast<float>(expected_sampling_rate);
fprintf(stderr, "wav filename: %s\n", wav_filename.c_str());
fprintf(stderr, "wav duration (s): %.3f\n", duration);
... ...
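The duration change above exists because expected_sampling_rate is now int32_t: samples.size() is also integral, so without the static_cast the division would truncate. The same pitfall, sketched in Kotlin:

    fun main() {
        val numSamples = 24000
        val sampleRate = 16000                        // Int, as after this commit
        println(numSamples / sampleRate)              // 1   -- integer division truncates
        println(numSamples / sampleRate.toFloat())    // 1.5 -- matches the static_cast fix
    }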
... ... @@ -40,19 +40,18 @@ class SherpaOnnx {
mgr,
#endif
config),
- stream_(recognizer_.CreateStream()),
- tail_padding_(16000 * 0.32, 0) {
+ stream_(recognizer_.CreateStream()) {
}
- void DecodeSamples(float sample_rate, const float *samples, int32_t n) const {
+ void AcceptWaveform(int32_t sample_rate, const float *samples,
+                     int32_t n) const {
stream_->AcceptWaveform(sample_rate, samples, n);
- Decode();
}
void InputFinished() const {
- stream_->AcceptWaveform(16000, tail_padding_.data(), tail_padding_.size());
+ std::vector<float> tail_padding(16000 * 0.32, 0);
+ stream_->AcceptWaveform(16000, tail_padding.data(), tail_padding.size());
stream_->InputFinished();
- Decode();
}
const std::string GetText() const {
... ... @@ -62,19 +61,15 @@ class SherpaOnnx {
bool IsEndpoint() const { return recognizer_.IsEndpoint(stream_.get()); }
bool IsReady() const { return recognizer_.IsReady(stream_.get()); }
void Reset() const { return recognizer_.Reset(stream_.get()); }
- private:
- void Decode() const {
-     while (recognizer_.IsReady(stream_.get())) {
-         recognizer_.DecodeStream(stream_.get());
-     }
- }
+ void Decode() const { recognizer_.DecodeStream(stream_.get()); }
+ private:
sherpa_onnx::OnlineRecognizer recognizer_;
std::unique_ptr<sherpa_onnx::OnlineStream> stream_;
- std::vector<float> tail_padding_;
};
static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
... ... @@ -86,14 +81,24 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
// https://docs.oracle.com/javase/7/docs/technotes/guides/jni/spec/types.html
// https://courses.cs.washington.edu/courses/cse341/99wi/java/tutorial/native1.1/implementing/field.html
+ //---------- decoding ----------
+ fid = env->GetFieldID(cls, "decodingMethod", "Ljava/lang/String;");
+ jstring s = (jstring)env->GetObjectField(config, fid);
+ const char *p = env->GetStringUTFChars(s, nullptr);
+ ans.decoding_method = p;
+ env->ReleaseStringUTFChars(s, p);
+ fid = env->GetFieldID(cls, "maxActivePaths", "I");
+ ans.max_active_paths = env->GetIntField(config, fid);
//---------- feat config ----------
fid = env->GetFieldID(cls, "featConfig",
"Lcom/k2fsa/sherpa/onnx/FeatureConfig;");
jobject feat_config = env->GetObjectField(config, fid);
jclass feat_config_cls = env->GetObjectClass(feat_config);
fid = env->GetFieldID(feat_config_cls, "sampleRate", "F");
ans.feat_config.sampling_rate = env->GetFloatField(feat_config, fid);
fid = env->GetFieldID(feat_config_cls, "sampleRate", "I");
ans.feat_config.sampling_rate = env->GetIntField(feat_config, fid);
fid = env->GetFieldID(feat_config_cls, "featureDim", "I");
ans.feat_config.feature_dim = env->GetIntField(feat_config, fid);
... ... @@ -153,8 +158,8 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
jclass model_config_cls = env->GetObjectClass(model_config);
fid = env->GetFieldID(model_config_cls, "encoder", "Ljava/lang/String;");
- jstring s = (jstring)env->GetObjectField(model_config, fid);
- const char *p = env->GetStringUTFChars(s, nullptr);
+ s = (jstring)env->GetObjectField(model_config, fid);
+ p = env->GetStringUTFChars(s, nullptr);
ans.model_config.encoder_filename = p;
env->ReleaseStringUTFChars(s, p);
... ... @@ -198,6 +203,7 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_new(
#endif
auto config = sherpa_onnx::GetConfig(env, _config);
+ SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
auto model = new sherpa_onnx::SherpaOnnx(
#if __ANDROID_API__ >= 9
mgr,
... ... @@ -221,6 +227,13 @@ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_reset(
}
+ SHERPA_ONNX_EXTERN_C
+ JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_isReady(
+     JNIEnv *env, jobject /*obj*/, jlong ptr) {
+   auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
+   return model->IsReady();
+ }
SHERPA_ONNX_EXTERN_C
JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_isEndpoint(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
... ... @@ -228,15 +241,22 @@ JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_isEndpoint(
}
SHERPA_ONNX_EXTERN_C
- JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_decodeSamples(
+ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_decode(
+     JNIEnv *env, jobject /*obj*/, jlong ptr) {
+   auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
+   model->Decode();
+ }
+ SHERPA_ONNX_EXTERN_C
+ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_acceptWaveform(
JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples,
- jfloat sample_rate) {
+ jint sample_rate) {
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
jfloat *p = env->GetFloatArrayElements(samples, nullptr);
jsize n = env->GetArrayLength(samples);
- model->DecodeSamples(sample_rate, p, n);
+ model->AcceptWaveform(sample_rate, p, n);
env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
}
... ...
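The sampleRate lookup above also shows why the Kotlin and JNI sides must change in lockstep: GetFieldID takes a JVM type descriptor, and a stale descriptor ("F" for a field that is now Int) makes GetFieldID return null with a pending NoSuchFieldError at runtime. A Kotlin-side sketch of the mapping (comments are illustrative):

    data class FeatureConfig(
        var sampleRate: Int = 16000, // JNI: GetFieldID(..., "sampleRate", "I") + GetIntField
        var featureDim: Int = 80,    // JNI: GetFieldID(..., "featureDim", "I") + GetIntField
    )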
... ... @@ -62,11 +62,15 @@ func sherpaOnnxOnlineRecognizerConfig(
enableEndpoint: Bool = false,
rule1MinTrailingSilence: Float = 2.4,
rule2MinTrailingSilence: Float = 1.2,
- rule3MinUtteranceLength: Float = 30
+ rule3MinUtteranceLength: Float = 30,
+ decodingMethod: String = "greedy_search",
+ maxActivePaths: Int = 4
) -> SherpaOnnxOnlineRecognizerConfig{
return SherpaOnnxOnlineRecognizerConfig(
feat_config: featConfig,
model_config: modelConfig,
+ decoding_method: toCPointer(decodingMethod),
+ max_active_paths: Int32(maxActivePaths),
enable_endpoint: enableEndpoint ? 1 : 0,
rule1_min_trailing_silence: rule1MinTrailingSilence,
rule2_min_trailing_silence: rule2MinTrailingSilence,
... ... @@ -128,12 +132,12 @@ class SherpaOnnxRecognizer {
/// Decode wave samples.
///
/// - Parameters:
- /// - samples: Audio samples normalzed to the range [-1, 1]
+ /// - samples: Audio samples normalized to the range [-1, 1]
/// - sampleRate: Sample rate of the input audio samples. Must match
/// the one expected by the model. It must be 16000 for
/// models from icefall.
- func acceptWaveform(samples: [Float], sampleRate: Float = 16000) {
-     AcceptWaveform(stream, sampleRate, samples, Int32(samples.count))
+ func acceptWaveform(samples: [Float], sampleRate: Int = 16000) {
+     AcceptWaveform(stream, Int32(sampleRate), samples, Int32(samples.count))
}
func isReady() -> Bool {
... ...
... ... @@ -32,7 +32,9 @@ func run() {
var config = sherpaOnnxOnlineRecognizerConfig(
featConfig: featConfig,
modelConfig: modelConfig,
- enableEndpoint: false
+ enableEndpoint: false,
+ decodingMethod: "modified_beam_search",
+ maxActivePaths: 4
)
... ...