Fangjun Kuang
Committed by GitHub

Support reading multi-channel wave files with 8/16/32-bit encoded samples (#1258)

@@ -38,14 +38,28 @@ done @@ -38,14 +38,28 @@ done
38 38
39 39
40 # test wav reader for non-standard wav files 40 # test wav reader for non-standard wav files
41 -curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/naudio.wav  
42 -curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/junk-padding.wav 41 +waves=(
  42 + naudio.wav
  43 + junk-padding.wav
  44 + int8-1-channel-zh.wav
  45 + int8-2-channel-zh.wav
  46 + int8-4-channel-zh.wav
  47 + int16-1-channel-zh.wav
  48 + int16-2-channel-zh.wav
  49 + int32-1-channel-zh.wav
  50 + int32-2-channel-zh.wav
  51 + float32-1-channel-zh.wav
  52 + float32-2-channel-zh.wav
  53 +)
  54 +for w in ${waves[@]}; do
  55 + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/$w
43 56
44 -time $EXE \  
45 - --tokens=$repo/tokens.txt \  
46 - --sense-voice-model=$repo/model.int8.onnx \  
47 - ./naudio.wav \  
48 - ./junk-padding.wav 57 + time $EXE \
  58 + --tokens=$repo/tokens.txt \
  59 + --sense-voice-model=$repo/model.int8.onnx \
  60 + $w
  61 + rm -v $w
  62 +done
49 63
50 rm -rf $repo 64 rm -rf $repo
51 65
@@ -143,35 +143,34 @@ jobs: @@ -143,35 +143,34 @@ jobs:
143 name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }} 143 name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
144 path: install/* 144 path: install/*
145 145
146 - - name: Test online punctuation 146 + - name: Test offline CTC
147 shell: bash 147 shell: bash
148 run: | 148 run: |
149 du -h -d1 . 149 du -h -d1 .
150 export PATH=$PWD/build/bin:$PATH 150 export PATH=$PWD/build/bin:$PATH
151 - export EXE=sherpa-onnx-online-punctuation 151 + export EXE=sherpa-onnx-offline
152 152
153 - .github/scripts/test-online-punctuation.sh 153 + .github/scripts/test-offline-ctc.sh
154 du -h -d1 . 154 du -h -d1 .
155 155
156 - - name: Test offline transducer 156 + - name: Test online punctuation
157 shell: bash 157 shell: bash
158 run: | 158 run: |
159 du -h -d1 . 159 du -h -d1 .
160 export PATH=$PWD/build/bin:$PATH 160 export PATH=$PWD/build/bin:$PATH
161 - export EXE=sherpa-onnx-offline 161 + export EXE=sherpa-onnx-online-punctuation
162 162
163 - .github/scripts/test-offline-transducer.sh 163 + .github/scripts/test-online-punctuation.sh
164 du -h -d1 . 164 du -h -d1 .
165 165
166 -  
167 - - name: Test offline CTC 166 + - name: Test offline transducer
168 shell: bash 167 shell: bash
169 run: | 168 run: |
170 du -h -d1 . 169 du -h -d1 .
171 export PATH=$PWD/build/bin:$PATH 170 export PATH=$PWD/build/bin:$PATH
172 export EXE=sherpa-onnx-offline 171 export EXE=sherpa-onnx-offline
173 172
174 - .github/scripts/test-offline-ctc.sh 173 + .github/scripts/test-offline-transducer.sh
175 du -h -d1 . 174 du -h -d1 .
176 175
177 - name: Test online transducer 176 - name: Test online transducer
@@ -6,6 +6,7 @@ @@ -6,6 +6,7 @@
6 #define SHERPA_ONNX_CSRC_OFFLINE_TTS_FRONTEND_H_ 6 #define SHERPA_ONNX_CSRC_OFFLINE_TTS_FRONTEND_H_
7 #include <cstdint> 7 #include <cstdint>
8 #include <string> 8 #include <string>
  9 +#include <utility>
9 #include <vector> 10 #include <vector>
10 11
11 #include "sherpa-onnx/csrc/macros.h" 12 #include "sherpa-onnx/csrc/macros.h"
@@ -50,6 +50,16 @@ struct WaveHeader { @@ -50,6 +50,16 @@ struct WaveHeader {
50 }; 50 };
51 static_assert(sizeof(WaveHeader) == 44); 51 static_assert(sizeof(WaveHeader) == 44);
52 52
  53 +/*
  54 +sox int16-1-channel-zh.wav -b 8 int8-1-channel-zh.wav
  55 +
  56 +sox int16-1-channel-zh.wav -c 2 int16-2-channel-zh.wav
  57 +
  58 +We use Audacity to generate int32-1-channel-zh.wav and float32-1-channel-zh.wav
  59 +because sox uses WAVE_FORMAT_EXTENSIBLE, which is not easy to support
  60 +in sherpa-onnx.
  61 + */
  62 +
53 // Read a wave file of mono-channel. 63 // Read a wave file of mono-channel.
54 // Return its samples normalized to the range [-1, 1). 64 // Return its samples normalized to the range [-1, 1).
55 std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate, 65 std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
@@ -114,9 +124,18 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate, @@ -114,9 +124,18 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
114 is.read(reinterpret_cast<char *>(&header.audio_format), 124 is.read(reinterpret_cast<char *>(&header.audio_format),
115 sizeof(header.audio_format)); 125 sizeof(header.audio_format));
116 126
117 - if (header.audio_format != 1) { // 1 for PCM 127 + if (header.audio_format != 1 && header.audio_format != 3) {
  128 + // 1 for integer PCM
  129 + // 3 for floating point PCM
  130 + // see https://www.mmsp.ece.mcgill.ca/Documents/AudioFormats/WAVE/WAVE.html
  131 + // and https://github.com/microsoft/DirectXTK/wiki/Wave-Formats
118 SHERPA_ONNX_LOGE("Expected audio_format 1. Given: %d\n", 132 SHERPA_ONNX_LOGE("Expected audio_format 1. Given: %d\n",
119 header.audio_format); 133 header.audio_format);
  134 +
  135 + if (header.audio_format == static_cast<int16_t>(0xfffe)) {
  136 + SHERPA_ONNX_LOGE("We don't support WAVE_FORMAT_EXTENSIBLE files.");
  137 + }
  138 +
120 *is_ok = false; 139 *is_ok = false;
121 return {}; 140 return {};
122 } 141 }
@@ -125,10 +144,9 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate, @@ -125,10 +144,9 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
125 sizeof(header.num_channels)); 144 sizeof(header.num_channels));
126 145
127 if (header.num_channels != 1) { // we support only single channel for now 146 if (header.num_channels != 1) { // we support only single channel for now
128 - SHERPA_ONNX_LOGE("Expected single channel. Given: %d\n",  
129 - header.num_channels);  
130 - *is_ok = false;  
131 - return {}; 147 + SHERPA_ONNX_LOGE(
  148 + "Warning: %d channels are found. We only use the first channel.\n",
  149 + header.num_channels);
132 } 150 }
133 151
134 is.read(reinterpret_cast<char *>(&header.sample_rate), 152 is.read(reinterpret_cast<char *>(&header.sample_rate),
@@ -161,8 +179,9 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate, @@ -161,8 +179,9 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
161 return {}; 179 return {};
162 } 180 }
163 181
164 - if (header.bits_per_sample != 16) { // we support only 16 bits per sample  
165 - SHERPA_ONNX_LOGE("Expected bits_per_sample 16. Given: %d\n", 182 + if (header.bits_per_sample != 8 && header.bits_per_sample != 16 &&
  183 + header.bits_per_sample != 32) {
  184 + SHERPA_ONNX_LOGE("Expected bits_per_sample 8, 16 or 32. Given: %d\n",
166 header.bits_per_sample); 185 header.bits_per_sample);
167 *is_ok = false; 186 *is_ok = false;
168 return {}; 187 return {};
@@ -199,21 +218,95 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate, @@ -199,21 +218,95 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
199 218
200 *sampling_rate = header.sample_rate; 219 *sampling_rate = header.sample_rate;
201 220
202 - // header.subchunk2_size contains the number of bytes in the data.  
203 - // As we assume each sample contains two bytes, so it is divided by 2 here  
204 - std::vector<int16_t> samples(header.subchunk2_size / 2); 221 + std::vector<float> ans;
205 222
206 - is.read(reinterpret_cast<char *>(samples.data()), header.subchunk2_size);  
207 - if (!is) { 223 + if (header.bits_per_sample == 16 && header.audio_format == 1) {
  224 + // header.subchunk2_size contains the number of bytes in the data.
  225 + // As we assume each sample contains two bytes, so it is divided by 2 here
  226 + std::vector<int16_t> samples(header.subchunk2_size / 2);
  227 + SHERPA_ONNX_LOGE("%d samples, bytes: %d", (int)samples.size(),
  228 + header.subchunk2_size);
  229 +
  230 + is.read(reinterpret_cast<char *>(samples.data()), header.subchunk2_size);
  231 + if (!is) {
  232 + SHERPA_ONNX_LOGE("Failed to read %d bytes", header.subchunk2_size);
  233 + *is_ok = false;
  234 + return {};
  235 + }
  236 +
  237 + ans.resize(samples.size() / header.num_channels);
  238 +
  239 + // samples are interleaved
  240 + for (int32_t i = 0; i != static_cast<int32_t>(ans.size()); ++i) {
  241 + ans[i] = samples[i * header.num_channels] / 32768.;
  242 + }
  243 + } else if (header.bits_per_sample == 8 && header.audio_format == 1) {
  244 + // number of samples == number of bytes for 8-bit encoded samples
  245 + //
  246 + // For 8-bit encoded samples, they are unsigned!
  247 + std::vector<uint8_t> samples(header.subchunk2_size);
  248 +
  249 + is.read(reinterpret_cast<char *>(samples.data()), header.subchunk2_size);
  250 + if (!is) {
  251 + SHERPA_ONNX_LOGE("Failed to read %d bytes", header.subchunk2_size);
  252 + *is_ok = false;
  253 + return {};
  254 + }
  255 +
  256 + ans.resize(samples.size() / header.num_channels);
  257 + for (int32_t i = 0; i != static_cast<int32_t>(ans.size()); ++i) {
  258 + // Note(fangjun): We want to normalize each sample into the range [-1, 1]
  259 + // Since each original sample is in the range [0, 255], dividing
  260 + // them by 128 converts them to the range [0, 2];
  261 + // so after subtracting 1, we get the range [-1, 1]
  262 + //
  263 + ans[i] = samples[i * header.num_channels] / 128. - 1;
  264 + }
  265 + } else if (header.bits_per_sample == 32 && header.audio_format == 1) {
  266 + // 32 here is for int32
  267 + //
  268 + // header.subchunk2_size contains the number of bytes in the data.
  269 + // As we assume each sample contains 4 bytes, so it is divided by 4 here
  270 + std::vector<int32_t> samples(header.subchunk2_size / 4);
  271 +
  272 + is.read(reinterpret_cast<char *>(samples.data()), header.subchunk2_size);
  273 + if (!is) {
  274 + SHERPA_ONNX_LOGE("Failed to read %d bytes", header.subchunk2_size);
  275 + *is_ok = false;
  276 + return {};
  277 + }
  278 +
  279 + ans.resize(samples.size() / header.num_channels);
  280 + for (int32_t i = 0; i != static_cast<int32_t>(ans.size()); ++i) {
  281 + ans[i] = static_cast<float>(samples[i * header.num_channels]) / (1 << 31);
  282 + }
  283 + } else if (header.bits_per_sample == 32 && header.audio_format == 3) {
  284 + // 32 here is for float32
  285 + //
  286 + // header.subchunk2_size contains the number of bytes in the data.
  287 + // As we assume each sample contains 4 bytes, so it is divided by 4 here
  288 + std::vector<float> samples(header.subchunk2_size / 4);
  289 +
  290 + is.read(reinterpret_cast<char *>(samples.data()), header.subchunk2_size);
  291 + if (!is) {
  292 + SHERPA_ONNX_LOGE("Failed to read %d bytes", header.subchunk2_size);
  293 + *is_ok = false;
  294 + return {};
  295 + }
  296 +
  297 + ans.resize(samples.size() / header.num_channels);
  298 + for (int32_t i = 0; i != static_cast<int32_t>(ans.size()); ++i) {
  299 + ans[i] = samples[i * header.num_channels];
  300 + }
  301 + } else {
  302 + SHERPA_ONNX_LOGE(
  303 + "Unsupported %d bits per sample and audio format: %d. Supported values "
  304 + "are: 8, 16, 32.",
  305 + header.bits_per_sample, header.audio_format);
208 *is_ok = false; 306 *is_ok = false;
209 return {}; 307 return {};
210 } 308 }
211 309
212 - std::vector<float> ans(samples.size());  
213 - for (int32_t i = 0; i != static_cast<int32_t>(ans.size()); ++i) {  
214 - ans[i] = samples[i] / 32768.;  
215 - }  
216 -  
217 *is_ok = true; 310 *is_ok = true;
218 return ans; 311 return ans;
219 } 312 }
@@ -264,13 +264,9 @@ Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_newFromFile(JNIEnv *env, @@ -264,13 +264,9 @@ Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_newFromFile(JNIEnv *env,
264 return (jlong)model; 264 return (jlong)model;
265 } 265 }
266 266
267 -  
268 SHERPA_ONNX_EXTERN_C 267 SHERPA_ONNX_EXTERN_C
269 -JNIEXPORT void JNICALL  
270 -Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_setConfig(JNIEnv *env,  
271 - jobject /*obj*/,  
272 - jlong ptr,  
273 - jobject _config) { 268 +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_setConfig(
  269 + JNIEnv *env, jobject /*obj*/, jlong ptr, jobject _config) {
274 auto config = sherpa_onnx::GetOfflineConfig(env, _config); 270 auto config = sherpa_onnx::GetOfflineConfig(env, _config);
275 SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); 271 SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
276 272
@@ -350,9 +346,12 @@ Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_getResult(JNIEnv *env, @@ -350,9 +346,12 @@ Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_getResult(JNIEnv *env,
350 // [3]: lang, jstring 346 // [3]: lang, jstring
351 // [4]: emotion, jstring 347 // [4]: emotion, jstring
352 // [5]: event, jstring 348 // [5]: event, jstring
353 - env->SetObjectArrayElement(obj_arr, 3, env->NewStringUTF(result.lang.c_str()));  
354 - env->SetObjectArrayElement(obj_arr, 4, env->NewStringUTF(result.emotion.c_str()));  
355 - env->SetObjectArrayElement(obj_arr, 5, env->NewStringUTF(result.event.c_str())); 349 + env->SetObjectArrayElement(obj_arr, 3,
  350 + env->NewStringUTF(result.lang.c_str()));
  351 + env->SetObjectArrayElement(obj_arr, 4,
  352 + env->NewStringUTF(result.emotion.c_str()));
  353 + env->SetObjectArrayElement(obj_arr, 5,
  354 + env->NewStringUTF(result.event.c_str()));
356 355
357 return obj_arr; 356 return obj_arr;
358 } 357 }