Fangjun Kuang
Committed by GitHub

Fix reading non-standard wav files. (#1199)

@@ -36,6 +36,17 @@ for m in model.onnx model.int8.onnx; do @@ -36,6 +36,17 @@ for m in model.onnx model.int8.onnx; do
36 done 36 done
37 done 37 done
38 38
  39 +
  40 +# test wav reader for non-standard wav files
  41 +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/naudio.wav
  42 +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/junk-padding.wav
  43 +
  44 +time $EXE \
  45 + --tokens=$repo/tokens.txt \
  46 + --sense-voice-model=$repo/model.int8.onnx \
  47 + ./naudio.wav \
  48 + ./junk-padding.wav
  49 +
39 rm -rf $repo 50 rm -rf $repo
40 51
41 if true; then 52 if true; then
@@ -18,58 +18,6 @@ namespace { @@ -18,58 +18,6 @@ namespace {
18 // Note: We assume little endian here 18 // Note: We assume little endian here
19 // TODO(fangjun): Support big endian 19 // TODO(fangjun): Support big endian
20 struct WaveHeader { 20 struct WaveHeader {
21 - bool Validate() const {  
22 - // F F I R  
23 - if (chunk_id != 0x46464952) {  
24 - SHERPA_ONNX_LOGE("Expected chunk_id RIFF. Given: 0x%08x\n", chunk_id);  
25 - return false;  
26 - }  
27 - // E V A W  
28 - if (format != 0x45564157) {  
29 - SHERPA_ONNX_LOGE("Expected format WAVE. Given: 0x%08x\n", format);  
30 - return false;  
31 - }  
32 -  
33 - if (subchunk1_id != 0x20746d66) {  
34 - SHERPA_ONNX_LOGE("Expected subchunk1_id 0x20746d66. Given: 0x%08x\n",  
35 - subchunk1_id);  
36 - return false;  
37 - }  
38 -  
39 - // NAudio uses 18  
40 - // See https://github.com/naudio/NAudio/issues/1132  
41 - if (subchunk1_size != 16 && subchunk1_size != 18) { // 16 for PCM  
42 - SHERPA_ONNX_LOGE("Expected subchunk1_size 16. Given: %d\n",  
43 - subchunk1_size);  
44 - return false;  
45 - }  
46 -  
47 - if (audio_format != 1) { // 1 for PCM  
48 - SHERPA_ONNX_LOGE("Expected audio_format 1. Given: %d\n", audio_format);  
49 - return false;  
50 - }  
51 -  
52 - if (num_channels != 1) { // we support only single channel for now  
53 - SHERPA_ONNX_LOGE("Expected single channel. Given: %d\n", num_channels);  
54 - return false;  
55 - }  
56 - if (byte_rate != (sample_rate * num_channels * bits_per_sample / 8)) {  
57 - return false;  
58 - }  
59 -  
60 - if (block_align != (num_channels * bits_per_sample / 8)) {  
61 - return false;  
62 - }  
63 -  
64 - if (bits_per_sample != 16) { // we support only 16 bits per sample  
65 - SHERPA_ONNX_LOGE("Expected bits_per_sample 16. Given: %d\n",  
66 - bits_per_sample);  
67 - return false;  
68 - }  
69 -  
70 - return true;  
71 - }  
72 -  
73 // See 21 // See
74 // https://en.wikipedia.org/wiki/WAV#Metadata 22 // https://en.wikipedia.org/wiki/WAV#Metadata
75 // and 23 // and
@@ -107,13 +55,115 @@ static_assert(sizeof(WaveHeader) == 44); @@ -107,13 +55,115 @@ static_assert(sizeof(WaveHeader) == 44);
107 std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate, 55 std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
108 bool *is_ok) { 56 bool *is_ok) {
109 WaveHeader header{}; 57 WaveHeader header{};
110 - is.read(reinterpret_cast<char *>(&header), sizeof(header));  
111 - if (!is) { 58 + is.read(reinterpret_cast<char *>(&header.chunk_id), sizeof(header.chunk_id));
  59 +
  60 + // F F I R
  61 + if (header.chunk_id != 0x46464952) {
  62 + SHERPA_ONNX_LOGE("Expected chunk_id RIFF. Given: 0x%08x\n",
  63 + header.chunk_id);
  64 + *is_ok = false;
  65 + return {};
  66 + }
  67 +
  68 + is.read(reinterpret_cast<char *>(&header.chunk_size),
  69 + sizeof(header.chunk_size));
  70 +
  71 + is.read(reinterpret_cast<char *>(&header.format), sizeof(header.format));
  72 +
  73 + // E V A W
  74 + if (header.format != 0x45564157) {
  75 + SHERPA_ONNX_LOGE("Expected format WAVE. Given: 0x%08x\n", header.format);
  76 + *is_ok = false;
  77 + return {};
  78 + }
  79 +
  80 + is.read(reinterpret_cast<char *>(&header.subchunk1_id),
  81 + sizeof(header.subchunk1_id));
  82 +
  83 + is.read(reinterpret_cast<char *>(&header.subchunk1_size),
  84 + sizeof(header.subchunk1_size));
  85 +
  86 + if (header.subchunk1_id == 0x4b4e554a) {
  87 + // skip junk padding
  88 + is.seekg(header.subchunk1_size, std::istream::cur);
  89 +
  90 + is.read(reinterpret_cast<char *>(&header.subchunk1_id),
  91 + sizeof(header.subchunk1_id));
  92 +
  93 + is.read(reinterpret_cast<char *>(&header.subchunk1_size),
  94 + sizeof(header.subchunk1_size));
  95 + }
  96 +
  97 + if (header.subchunk1_id != 0x20746d66) {
  98 + SHERPA_ONNX_LOGE("Expected subchunk1_id 0x20746d66. Given: 0x%08x\n",
  99 + header.subchunk1_id);
112 *is_ok = false; 100 *is_ok = false;
113 return {}; 101 return {};
114 } 102 }
115 103
116 - if (!header.Validate()) { 104 + // NAudio uses 18
  105 + // See https://github.com/naudio/NAudio/issues/1132
  106 + if (header.subchunk1_size != 16 &&
  107 + header.subchunk1_size != 18) { // 16 for PCM
  108 + SHERPA_ONNX_LOGE("Expected subchunk1_size 16. Given: %d\n",
  109 + header.subchunk1_size);
  110 + *is_ok = false;
  111 + return {};
  112 + }
  113 +
  114 + is.read(reinterpret_cast<char *>(&header.audio_format),
  115 + sizeof(header.audio_format));
  116 +
  117 + if (header.audio_format != 1) { // 1 for PCM
  118 + SHERPA_ONNX_LOGE("Expected audio_format 1. Given: %d\n",
  119 + header.audio_format);
  120 + *is_ok = false;
  121 + return {};
  122 + }
  123 +
  124 + is.read(reinterpret_cast<char *>(&header.num_channels),
  125 + sizeof(header.num_channels));
  126 +
  127 + if (header.num_channels != 1) { // we support only single channel for now
  128 + SHERPA_ONNX_LOGE("Expected single channel. Given: %d\n",
  129 + header.num_channels);
  130 + *is_ok = false;
  131 + return {};
  132 + }
  133 +
  134 + is.read(reinterpret_cast<char *>(&header.sample_rate),
  135 + sizeof(header.sample_rate));
  136 +
  137 + is.read(reinterpret_cast<char *>(&header.byte_rate),
  138 + sizeof(header.byte_rate));
  139 +
  140 + is.read(reinterpret_cast<char *>(&header.block_align),
  141 + sizeof(header.block_align));
  142 +
  143 + is.read(reinterpret_cast<char *>(&header.bits_per_sample),
  144 + sizeof(header.bits_per_sample));
  145 +
  146 + if (header.byte_rate !=
  147 + (header.sample_rate * header.num_channels * header.bits_per_sample / 8)) {
  148 + SHERPA_ONNX_LOGE("Incorrect byte rate: %d. Expected: %d", header.byte_rate,
  149 + (header.sample_rate * header.num_channels *
  150 + header.bits_per_sample / 8));
  151 + *is_ok = false;
  152 + return {};
  153 + }
  154 +
  155 + if (header.block_align !=
  156 + (header.num_channels * header.bits_per_sample / 8)) {
  157 + SHERPA_ONNX_LOGE("Incorrect block align: %d. Expected: %d\n",
  158 + header.block_align,
  159 + (header.num_channels * header.bits_per_sample / 8));
  160 + *is_ok = false;
  161 + return {};
  162 + }
  163 +
  164 + if (header.bits_per_sample != 16) { // we support only 16 bits per sample
  165 + SHERPA_ONNX_LOGE("Expected bits_per_sample 16. Given: %d\n",
  166 + header.bits_per_sample);
117 *is_ok = false; 167 *is_ok = false;
118 return {}; 168 return {};
119 } 169 }
@@ -123,8 +173,6 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate, @@ -123,8 +173,6 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
123 // See 173 // See
124 // https://github.com/naudio/NAudio/blob/master/NAudio.Core/Wave/WaveFormats/WaveFormat.cs#L223 174 // https://github.com/naudio/NAudio/blob/master/NAudio.Core/Wave/WaveFormats/WaveFormat.cs#L223
125 175
126 - is.seekg(36, std::istream::beg);  
127 -  
128 int16_t extra_size = -1; 176 int16_t extra_size = -1;
129 is.read(reinterpret_cast<char *>(&extra_size), sizeof(int16_t)); 177 is.read(reinterpret_cast<char *>(&extra_size), sizeof(int16_t));
130 if (extra_size != 0) { 178 if (extra_size != 0) {
@@ -135,13 +183,14 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate, @@ -135,13 +183,14 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
135 *is_ok = false; 183 *is_ok = false;
136 return {}; 184 return {};
137 } 185 }
138 -  
139 - is.read(reinterpret_cast<char *>(&header.subchunk2_id),  
140 - sizeof(header.subchunk2_id));  
141 - is.read(reinterpret_cast<char *>(&header.subchunk2_size),  
142 - sizeof(header.subchunk2_size));  
143 } 186 }
144 187
  188 + is.read(reinterpret_cast<char *>(&header.subchunk2_id),
  189 + sizeof(header.subchunk2_id));
  190 +
  191 + is.read(reinterpret_cast<char *>(&header.subchunk2_size),
  192 + sizeof(header.subchunk2_size));
  193 +
145 header.SeekToDataChunk(is); 194 header.SeekToDataChunk(is);
146 if (!is) { 195 if (!is) {
147 *is_ok = false; 196 *is_ok = false;