Fangjun Kuang
Committed by GitHub

Release GIL to support multithreading in websocket servers. (#451)

@@ -414,7 +414,7 @@ def get_args(): @@ -414,7 +414,7 @@ def get_args():
414 parser.add_argument( 414 parser.add_argument(
415 "--max-batch-size", 415 "--max-batch-size",
416 type=int, 416 type=int,
417 - default=25, 417 + default=3,
418 help="""Max batch size for computation. Note if there are not enough 418 help="""Max batch size for computation. Note if there are not enough
419 requests in the queue, it will wait for max_wait_ms time. After that, 419 requests in the queue, it will wait for max_wait_ms time. After that,
420 even if there are not enough requests, it still sends the 420 even if there are not enough requests, it still sends the
@@ -459,7 +459,7 @@ def get_args(): @@ -459,7 +459,7 @@ def get_args():
459 parser.add_argument( 459 parser.add_argument(
460 "--max-active-connections", 460 "--max-active-connections",
461 type=int, 461 type=int,
462 - default=500, 462 + default=200,
463 help="""Maximum number of active connections. The server will refuse 463 help="""Maximum number of active connections. The server will refuse
464 to accept new connections once the current number of active connections 464 to accept new connections once the current number of active connections
465 equals to this limit. 465 equals to this limit.
@@ -533,6 +533,7 @@ class NonStreamingServer: @@ -533,6 +533,7 @@ class NonStreamingServer:
533 self.certificate = certificate 533 self.certificate = certificate
534 self.http_server = HttpServer(doc_root) 534 self.http_server = HttpServer(doc_root)
535 535
  536 + self.nn_pool_size = nn_pool_size
536 self.nn_pool = ThreadPoolExecutor( 537 self.nn_pool = ThreadPoolExecutor(
537 max_workers=nn_pool_size, 538 max_workers=nn_pool_size,
538 thread_name_prefix="nn", 539 thread_name_prefix="nn",
@@ -604,7 +605,9 @@ or <a href="/offline_record.html">/offline_record.html</a> @@ -604,7 +605,9 @@ or <a href="/offline_record.html">/offline_record.html</a>
604 async def run(self, port: int): 605 async def run(self, port: int):
605 logging.info("started") 606 logging.info("started")
606 607
607 - task = asyncio.create_task(self.stream_consumer_task()) 608 + tasks = []
  609 + for i in range(self.nn_pool_size):
  610 + tasks.append(asyncio.create_task(self.stream_consumer_task()))
608 611
609 if self.certificate: 612 if self.certificate:
610 logging.info(f"Using certificate: {self.certificate}") 613 logging.info(f"Using certificate: {self.certificate}")
@@ -636,7 +639,7 @@ or <a href="/offline_record.html">/offline_record.html</a> @@ -636,7 +639,7 @@ or <a href="/offline_record.html">/offline_record.html</a>
636 639
637 await asyncio.Future() # run forever 640 await asyncio.Future() # run forever
638 641
639 - await task # not reachable 642 + await asyncio.gather(*tasks) # not reachable
640 643
641 async def recv_audio_samples( 644 async def recv_audio_samples(
642 self, 645 self,
@@ -722,6 +725,7 @@ or <a href="/offline_record.html">/offline_record.html</a> @@ -722,6 +725,7 @@ or <a href="/offline_record.html">/offline_record.html</a>
722 batch.append(item) 725 batch.append(item)
723 except asyncio.QueueEmpty: 726 except asyncio.QueueEmpty:
724 pass 727 pass
  728 +
725 stream_list = [b[0] for b in batch] 729 stream_list = [b[0] for b in batch]
726 future_list = [b[1] for b in batch] 730 future_list = [b[1] for b in batch]
727 731
@@ -296,7 +296,7 @@ def get_args(): @@ -296,7 +296,7 @@ def get_args():
296 parser.add_argument( 296 parser.add_argument(
297 "--max-batch-size", 297 "--max-batch-size",
298 type=int, 298 type=int,
299 - default=50, 299 + default=3,
300 help="""Max batch size for computation. Note if there are not enough 300 help="""Max batch size for computation. Note if there are not enough
301 requests in the queue, it will wait for max_wait_ms time. After that, 301 requests in the queue, it will wait for max_wait_ms time. After that,
302 even if there are not enough requests, it still sends the 302 even if there are not enough requests, it still sends the
@@ -334,7 +334,7 @@ def get_args(): @@ -334,7 +334,7 @@ def get_args():
334 parser.add_argument( 334 parser.add_argument(
335 "--max-active-connections", 335 "--max-active-connections",
336 type=int, 336 type=int,
337 - default=500, 337 + default=200,
338 help="""Maximum number of active connections. The server will refuse 338 help="""Maximum number of active connections. The server will refuse
339 to accept new connections once the current number of active connections 339 to accept new connections once the current number of active connections
340 equals to this limit. 340 equals to this limit.
@@ -478,6 +478,7 @@ class StreamingServer(object): @@ -478,6 +478,7 @@ class StreamingServer(object):
478 self.certificate = certificate 478 self.certificate = certificate
479 self.http_server = HttpServer(doc_root) 479 self.http_server = HttpServer(doc_root)
480 480
  481 + self.nn_pool_size = nn_pool_size
481 self.nn_pool = ThreadPoolExecutor( 482 self.nn_pool = ThreadPoolExecutor(
482 max_workers=nn_pool_size, 483 max_workers=nn_pool_size,
483 thread_name_prefix="nn", 484 thread_name_prefix="nn",
@@ -591,7 +592,9 @@ Go back to <a href="/streaming_record.html">/streaming_record.html</a> @@ -591,7 +592,9 @@ Go back to <a href="/streaming_record.html">/streaming_record.html</a>
591 return status, header, response 592 return status, header, response
592 593
593 async def run(self, port: int): 594 async def run(self, port: int):
594 - task = asyncio.create_task(self.stream_consumer_task()) 595 + tasks = []
  596 + for i in range(self.nn_pool_size):
  597 + tasks.append(asyncio.create_task(self.stream_consumer_task()))
595 598
596 if self.certificate: 599 if self.certificate:
597 logging.info(f"Using certificate: {self.certificate}") 600 logging.info(f"Using certificate: {self.certificate}")
@@ -629,7 +632,7 @@ Go back to <a href="/streaming_record.html">/streaming_record.html</a> @@ -629,7 +632,7 @@ Go back to <a href="/streaming_record.html">/streaming_record.html</a>
629 632
630 await asyncio.Future() # run forever 633 await asyncio.Future() # run forever
631 634
632 - await task # not reachable 635 + await asyncio.gather(*tasks) # not reachable
633 636
634 async def handle_connection( 637 async def handle_connection(
635 self, 638 self,
@@ -19,10 +19,12 @@ void PybindCircularBuffer(py::module *m) { @@ -19,10 +19,12 @@ void PybindCircularBuffer(py::module *m) {
19 [](PyClass &self, const std::vector<float> &samples) { 19 [](PyClass &self, const std::vector<float> &samples) {
20 self.Push(samples.data(), samples.size()); 20 self.Push(samples.data(), samples.size());
21 }, 21 },
22 - py::arg("samples"))  
23 - .def("get", &PyClass::Get, py::arg("start_index"), py::arg("n"))  
24 - .def("pop", &PyClass::Pop, py::arg("n"))  
25 - .def("reset", &PyClass::Reset) 22 + py::arg("samples"), py::call_guard<py::gil_scoped_release>())
  23 + .def("get", &PyClass::Get, py::arg("start_index"), py::arg("n"),
  24 + py::call_guard<py::gil_scoped_release>())
  25 + .def("pop", &PyClass::Pop, py::arg("n"),
  26 + py::call_guard<py::gil_scoped_release>())
  27 + .def("reset", &PyClass::Reset, py::call_guard<py::gil_scoped_release>())
26 .def_property_readonly("size", &PyClass::Size) 28 .def_property_readonly("size", &PyClass::Size)
27 .def_property_readonly("head", &PyClass::Head) 29 .def_property_readonly("head", &PyClass::Head)
28 .def_property_readonly("tail", &PyClass::Tail); 30 .def_property_readonly("tail", &PyClass::Tail);
@@ -41,19 +41,24 @@ void PybindOfflineRecognizer(py::module *m) { @@ -41,19 +41,24 @@ void PybindOfflineRecognizer(py::module *m) {
41 using PyClass = OfflineRecognizer; 41 using PyClass = OfflineRecognizer;
42 py::class_<PyClass>(*m, "OfflineRecognizer") 42 py::class_<PyClass>(*m, "OfflineRecognizer")
43 .def(py::init<const OfflineRecognizerConfig &>(), py::arg("config")) 43 .def(py::init<const OfflineRecognizerConfig &>(), py::arg("config"))
44 - .def("create_stream",  
45 - [](const PyClass &self) { return self.CreateStream(); }) 44 + .def(
  45 + "create_stream",
  46 + [](const PyClass &self) { return self.CreateStream(); },
  47 + py::call_guard<py::gil_scoped_release>())
46 .def( 48 .def(
47 "create_stream", 49 "create_stream",
48 [](PyClass &self, const std::string &hotwords) { 50 [](PyClass &self, const std::string &hotwords) {
49 return self.CreateStream(hotwords); 51 return self.CreateStream(hotwords);
50 }, 52 },
51 - py::arg("hotwords"))  
52 - .def("decode_stream", &PyClass::DecodeStream)  
53 - .def("decode_streams",  
54 - [](const PyClass &self, std::vector<OfflineStream *> ss) {  
55 - self.DecodeStreams(ss.data(), ss.size());  
56 - }); 53 + py::arg("hotwords"), py::call_guard<py::gil_scoped_release>())
  54 + .def("decode_stream", &PyClass::DecodeStream,
  55 + py::call_guard<py::gil_scoped_release>())
  56 + .def(
  57 + "decode_streams",
  58 + [](const PyClass &self, std::vector<OfflineStream *> ss) {
  59 + self.DecodeStreams(ss.data(), ss.size());
  60 + },
  61 + py::call_guard<py::gil_scoped_release>());
57 } 62 }
58 63
59 } // namespace sherpa_onnx 64 } // namespace sherpa_onnx
@@ -50,9 +50,20 @@ void PybindOfflineStream(py::module *m) { @@ -50,9 +50,20 @@ void PybindOfflineStream(py::module *m) {
50 .def( 50 .def(
51 "accept_waveform", 51 "accept_waveform",
52 [](PyClass &self, float sample_rate, py::array_t<float> waveform) { 52 [](PyClass &self, float sample_rate, py::array_t<float> waveform) {
  53 +#if 0
  54 + auto report_gil_status = []() {
  55 + auto is_gil_held = false;
  56 + if (auto tstate = py::detail::get_thread_state_unchecked())
  57 + is_gil_held = (tstate == PyGILState_GetThisThreadState());
  58 +
  59 + return is_gil_held ? "GIL held" : "GIL released";
  60 + };
  61 + std::cout << report_gil_status() << "\n";
  62 +#endif
53 self.AcceptWaveform(sample_rate, waveform.data(), waveform.size()); 63 self.AcceptWaveform(sample_rate, waveform.data(), waveform.size());
54 }, 64 },
55 - py::arg("sample_rate"), py::arg("waveform"), kAcceptWaveformUsage) 65 + py::arg("sample_rate"), py::arg("waveform"), kAcceptWaveformUsage,
  66 + py::call_guard<py::gil_scoped_release>())
56 .def_property_readonly("result", &PyClass::GetResult); 67 .def_property_readonly("result", &PyClass::GetResult);
57 } 68 }
58 69
@@ -45,7 +45,7 @@ void PybindOfflineTts(py::module *m) { @@ -45,7 +45,7 @@ void PybindOfflineTts(py::module *m) {
45 py::class_<PyClass>(*m, "OfflineTts") 45 py::class_<PyClass>(*m, "OfflineTts")
46 .def(py::init<const OfflineTtsConfig &>(), py::arg("config")) 46 .def(py::init<const OfflineTtsConfig &>(), py::arg("config"))
47 .def("generate", &PyClass::Generate, py::arg("text"), py::arg("sid") = 0, 47 .def("generate", &PyClass::Generate, py::arg("text"), py::arg("sid") = 0,
48 - py::arg("speed") = 1.0); 48 + py::arg("speed") = 1.0, py::call_guard<py::gil_scoped_release>());
49 } 49 }
50 50
51 } // namespace sherpa_onnx 51 } // namespace sherpa_onnx
@@ -54,23 +54,31 @@ void PybindOnlineRecognizer(py::module *m) { @@ -54,23 +54,31 @@ void PybindOnlineRecognizer(py::module *m) {
54 using PyClass = OnlineRecognizer; 54 using PyClass = OnlineRecognizer;
55 py::class_<PyClass>(*m, "OnlineRecognizer") 55 py::class_<PyClass>(*m, "OnlineRecognizer")
56 .def(py::init<const OnlineRecognizerConfig &>(), py::arg("config")) 56 .def(py::init<const OnlineRecognizerConfig &>(), py::arg("config"))
57 - .def("create_stream",  
58 - [](const PyClass &self) { return self.CreateStream(); }) 57 + .def(
  58 + "create_stream",
  59 + [](const PyClass &self) { return self.CreateStream(); },
  60 + py::call_guard<py::gil_scoped_release>())
59 .def( 61 .def(
60 "create_stream", 62 "create_stream",
61 [](PyClass &self, const std::string &hotwords) { 63 [](PyClass &self, const std::string &hotwords) {
62 return self.CreateStream(hotwords); 64 return self.CreateStream(hotwords);
63 }, 65 },
64 - py::arg("hotwords"))  
65 - .def("is_ready", &PyClass::IsReady)  
66 - .def("decode_stream", &PyClass::DecodeStream)  
67 - .def("decode_streams",  
68 - [](PyClass &self, std::vector<OnlineStream *> ss) {  
69 - self.DecodeStreams(ss.data(), ss.size());  
70 - })  
71 - .def("get_result", &PyClass::GetResult)  
72 - .def("is_endpoint", &PyClass::IsEndpoint)  
73 - .def("reset", &PyClass::Reset); 66 + py::arg("hotwords"), py::call_guard<py::gil_scoped_release>())
  67 + .def("is_ready", &PyClass::IsReady,
  68 + py::call_guard<py::gil_scoped_release>())
  69 + .def("decode_stream", &PyClass::DecodeStream,
  70 + py::call_guard<py::gil_scoped_release>())
  71 + .def(
  72 + "decode_streams",
  73 + [](PyClass &self, std::vector<OnlineStream *> ss) {
  74 + self.DecodeStreams(ss.data(), ss.size());
  75 + },
  76 + py::call_guard<py::gil_scoped_release>())
  77 + .def("get_result", &PyClass::GetResult,
  78 + py::call_guard<py::gil_scoped_release>())
  79 + .def("is_endpoint", &PyClass::IsEndpoint,
  80 + py::call_guard<py::gil_scoped_release>())
  81 + .def("reset", &PyClass::Reset, py::call_guard<py::gil_scoped_release>());
74 } 82 }
75 83
76 } // namespace sherpa_onnx 84 } // namespace sherpa_onnx
@@ -28,8 +28,10 @@ void PybindOnlineStream(py::module *m) { @@ -28,8 +28,10 @@ void PybindOnlineStream(py::module *m) {
28 [](PyClass &self, float sample_rate, py::array_t<float> waveform) { 28 [](PyClass &self, float sample_rate, py::array_t<float> waveform) {
29 self.AcceptWaveform(sample_rate, waveform.data(), waveform.size()); 29 self.AcceptWaveform(sample_rate, waveform.data(), waveform.size());
30 }, 30 },
31 - py::arg("sample_rate"), py::arg("waveform"), kAcceptWaveformUsage)  
32 - .def("input_finished", &PyClass::InputFinished); 31 + py::arg("sample_rate"), py::arg("waveform"), kAcceptWaveformUsage,
  32 + py::call_guard<py::gil_scoped_release>())
  33 + .def("input_finished", &PyClass::InputFinished,
  34 + py::call_guard<py::gil_scoped_release>());
33 } 35 }
34 36
35 } // namespace sherpa_onnx 37 } // namespace sherpa_onnx
@@ -13,17 +13,21 @@ namespace sherpa_onnx { @@ -13,17 +13,21 @@ namespace sherpa_onnx {
13 void PybindVadModel(py::module *m) { 13 void PybindVadModel(py::module *m) {
14 using PyClass = VadModel; 14 using PyClass = VadModel;
15 py::class_<PyClass>(*m, "VadModel") 15 py::class_<PyClass>(*m, "VadModel")
16 - .def_static("create", &PyClass::Create, py::arg("config"))  
17 - .def("reset", &PyClass::Reset) 16 + .def_static("create", &PyClass::Create, py::arg("config"),
  17 + py::call_guard<py::gil_scoped_release>())
  18 + .def("reset", &PyClass::Reset, py::call_guard<py::gil_scoped_release>())
18 .def( 19 .def(
19 "is_speech", 20 "is_speech",
20 [](PyClass &self, const std::vector<float> &samples) -> bool { 21 [](PyClass &self, const std::vector<float> &samples) -> bool {
21 return self.IsSpeech(samples.data(), samples.size()); 22 return self.IsSpeech(samples.data(), samples.size());
22 }, 23 },
23 - py::arg("samples"))  
24 - .def("window_size", &PyClass::WindowSize)  
25 - .def("min_silence_duration_samples", &PyClass::MinSilenceDurationSamples)  
26 - .def("min_speech_duration_samples", &PyClass::MinSpeechDurationSamples); 24 + py::arg("samples"), py::call_guard<py::gil_scoped_release>())
  25 + .def("window_size", &PyClass::WindowSize,
  26 + py::call_guard<py::gil_scoped_release>())
  27 + .def("min_silence_duration_samples", &PyClass::MinSilenceDurationSamples,
  28 + py::call_guard<py::gil_scoped_release>())
  29 + .def("min_speech_duration_samples", &PyClass::MinSpeechDurationSamples,
  30 + py::call_guard<py::gil_scoped_release>());
27 } 31 }
28 32
29 } // namespace sherpa_onnx 33 } // namespace sherpa_onnx
@@ -30,11 +30,12 @@ void PybindVoiceActivityDetector(py::module *m) { @@ -30,11 +30,12 @@ void PybindVoiceActivityDetector(py::module *m) {
30 [](PyClass &self, const std::vector<float> &samples) { 30 [](PyClass &self, const std::vector<float> &samples) {
31 self.AcceptWaveform(samples.data(), samples.size()); 31 self.AcceptWaveform(samples.data(), samples.size());
32 }, 32 },
33 - py::arg("samples"))  
34 - .def("empty", &PyClass::Empty)  
35 - .def("pop", &PyClass::Pop)  
36 - .def("is_speech_detected", &PyClass::IsSpeechDetected)  
37 - .def("reset", &PyClass::Reset) 33 + py::arg("samples"), py::call_guard<py::gil_scoped_release>())
  34 + .def("empty", &PyClass::Empty, py::call_guard<py::gil_scoped_release>())
  35 + .def("pop", &PyClass::Pop, py::call_guard<py::gil_scoped_release>())
  36 + .def("is_speech_detected", &PyClass::IsSpeechDetected,
  37 + py::call_guard<py::gil_scoped_release>())
  38 + .def("reset", &PyClass::Reset, py::call_guard<py::gil_scoped_release>())
38 .def_property_readonly("front", &PyClass::Front); 39 .def_property_readonly("front", &PyClass::Front);
39 } 40 }
40 41