davidliu
Committed by GitHub

Queue requests when reconnecting (#59)

* Queue requests when reconnecting

* more tests
@@ -5,7 +5,9 @@ import io.livekit.android.ConnectOptions @@ -5,7 +5,9 @@ import io.livekit.android.ConnectOptions
5 import io.livekit.android.dagger.InjectionNames 5 import io.livekit.android.dagger.InjectionNames
6 import io.livekit.android.room.participant.ParticipantTrackPermission 6 import io.livekit.android.room.participant.ParticipantTrackPermission
7 import io.livekit.android.room.track.TrackException 7 import io.livekit.android.room.track.TrackException
8 -import io.livekit.android.room.util.* 8 +import io.livekit.android.room.util.MediaConstraintKeys
  9 +import io.livekit.android.room.util.createAnswer
  10 +import io.livekit.android.room.util.setLocalDescription
9 import io.livekit.android.util.CloseableCoroutineScope 11 import io.livekit.android.util.CloseableCoroutineScope
10 import io.livekit.android.util.Either 12 import io.livekit.android.util.Either
11 import io.livekit.android.util.LKLog 13 import io.livekit.android.util.LKLog
@@ -132,7 +134,7 @@ internal constructor( @@ -132,7 +134,7 @@ internal constructor(
132 if (!this.isSubscriberPrimary) { 134 if (!this.isSubscriberPrimary) {
133 negotiate() 135 negotiate()
134 } 136 }
135 - client.onReady() 137 + client.onReadyForResponses()
136 return joinResponse 138 return joinResponse
137 } 139 }
138 140
@@ -350,7 +352,7 @@ internal constructor( @@ -350,7 +352,7 @@ internal constructor(
350 try { 352 try {
351 client.reconnect(url, token) 353 client.reconnect(url, token)
352 // no join response for regular reconnects 354 // no join response for regular reconnects
353 - client.onReady() 355 + client.onReadyForResponses()
354 } catch (e: Exception) { 356 } catch (e: Exception) {
355 LKLog.w(e) { "Error during reconnection." } 357 LKLog.w(e) { "Error during reconnection." }
356 // ws reconnect failed, retry. 358 // ws reconnect failed, retry.
@@ -378,6 +380,7 @@ internal constructor( @@ -378,6 +380,7 @@ internal constructor(
378 } 380 }
379 381
380 if (connectionState == ConnectionState.CONNECTED) { 382 if (connectionState == ConnectionState.CONNECTED) {
  383 + client.onPCConnected()
381 listener?.onPostReconnect(isFullReconnect) 384 listener?.onPostReconnect(isFullReconnect)
382 return@launch 385 return@launch
383 } 386 }
1 package io.livekit.android.room 1 package io.livekit.android.room
2 2
3 -import android.net.Uri  
4 import com.google.protobuf.util.JsonFormat 3 import com.google.protobuf.util.JsonFormat
5 import com.vdurmont.semver4j.Semver 4 import com.vdurmont.semver4j.Semver
6 import io.livekit.android.ConnectOptions 5 import io.livekit.android.ConnectOptions
7 -import io.livekit.android.Version  
8 import io.livekit.android.dagger.InjectionNames 6 import io.livekit.android.dagger.InjectionNames
9 import io.livekit.android.room.participant.ParticipantTrackPermission 7 import io.livekit.android.room.participant.ParticipantTrackPermission
10 import io.livekit.android.room.track.Track 8 import io.livekit.android.room.track.Track
@@ -60,6 +58,10 @@ constructor( @@ -60,6 +58,10 @@ constructor(
60 private var joinContinuation: CancellableContinuation<Either<LivekitRtc.JoinResponse, Unit>>? = null 58 private var joinContinuation: CancellableContinuation<Either<LivekitRtc.JoinResponse, Unit>>? = null
61 private lateinit var coroutineScope: CloseableCoroutineScope 59 private lateinit var coroutineScope: CloseableCoroutineScope
62 60
  61 + private val requestFlowJobLock = Object()
  62 + private var requestFlowJob: Job? = null
  63 + private val requestFlow = MutableSharedFlow<LivekitRtc.SignalRequest>(Int.MAX_VALUE)
  64 +
63 private val responseFlow = MutableSharedFlow<LivekitRtc.SignalResponse>(Int.MAX_VALUE) 65 private val responseFlow = MutableSharedFlow<LivekitRtc.SignalResponse>(Int.MAX_VALUE)
64 66
65 /** 67 /**
@@ -150,7 +152,7 @@ constructor( @@ -150,7 +152,7 @@ constructor(
150 * Should be called after resolving the join message. 152 * Should be called after resolving the join message.
151 */ 153 */
152 @OptIn(ExperimentalCoroutinesApi::class) 154 @OptIn(ExperimentalCoroutinesApi::class)
153 - fun onReady() { 155 + fun onReadyForResponses() {
154 coroutineScope.launch { 156 coroutineScope.launch {
155 responseFlow.collect { 157 responseFlow.collect {
156 responseFlow.resetReplayCache() 158 responseFlow.resetReplayCache()
@@ -159,6 +161,31 @@ constructor( @@ -159,6 +161,31 @@ constructor(
159 } 161 }
160 } 162 }
161 163
  164 + @OptIn(ExperimentalCoroutinesApi::class)
  165 + private fun startRequestQueue() {
  166 + if (requestFlowJob != null) {
  167 + return
  168 + }
  169 + synchronized(requestFlowJobLock) {
  170 + if (requestFlowJob == null) {
  171 + requestFlowJob = coroutineScope.launch {
  172 + requestFlow.collect {
  173 + requestFlow.resetReplayCache()
  174 + sendRequestImpl(it)
  175 + }
  176 + }
  177 + }
  178 + }
  179 + }
  180 +
  181 + /**
  182 + * On reconnection, SignalClient waits until the peer connection is established to send messages.
  183 + * Call this method when it is connected.
  184 + */
  185 + fun onPCConnected() {
  186 + startRequestQueue()
  187 + }
  188 +
162 //--------------------------------- WebSocket Listener --------------------------------------// 189 //--------------------------------- WebSocket Listener --------------------------------------//
163 override fun onOpen(webSocket: WebSocket, response: Response) { 190 override fun onOpen(webSocket: WebSocket, response: Response) {
164 if (isReconnecting) { 191 if (isReconnecting) {
@@ -402,6 +429,16 @@ constructor( @@ -402,6 +429,16 @@ constructor(
402 } 429 }
403 430
404 private fun sendRequest(request: LivekitRtc.SignalRequest) { 431 private fun sendRequest(request: LivekitRtc.SignalRequest) {
  432 + val skipQueue = skipQueueTypes.contains(request.messageCase)
  433 +
  434 + if (skipQueue) {
  435 + sendRequestImpl(request)
  436 + } else {
  437 + requestFlow.tryEmit(request)
  438 + }
  439 + }
  440 +
  441 + private fun sendRequestImpl(request: LivekitRtc.SignalRequest) {
405 LKLog.v { "sending request: $request" } 442 LKLog.v { "sending request: $request" }
406 if (!isConnected || currentWs == null) { 443 if (!isConnected || currentWs == null) {
407 LKLog.w { "not connected, could not send request $request" } 444 LKLog.w { "not connected, could not send request $request" }
@@ -428,6 +465,7 @@ constructor( @@ -428,6 +465,7 @@ constructor(
428 // Only handle joins if not connected. 465 // Only handle joins if not connected.
429 if (response.hasJoin()) { 466 if (response.hasJoin()) {
430 isConnected = true 467 isConnected = true
  468 + startRequestQueue()
431 try { 469 try {
432 serverVersion = Semver(response.join.serverVersion) 470 serverVersion = Semver(response.join.serverVersion)
433 } catch (t: Throwable) { 471 } catch (t: Throwable) {
@@ -439,9 +477,7 @@ constructor( @@ -439,9 +477,7 @@ constructor(
439 } 477 }
440 return 478 return
441 } 479 }
442 - coroutineScope.launch {  
443 - responseFlow.tryEmit(response)  
444 - } 480 + responseFlow.tryEmit(response)
445 } 481 }
446 482
447 private fun handleSignalResponseImpl(response: LivekitRtc.SignalResponse) { 483 private fun handleSignalResponseImpl(response: LivekitRtc.SignalResponse) {
@@ -519,6 +555,7 @@ constructor( @@ -519,6 +555,7 @@ constructor(
519 fun close(code: Int = 1000, reason: String = "Normal Closure") { 555 fun close(code: Int = 1000, reason: String = "Normal Closure") {
520 isConnected = false 556 isConnected = false
521 isReconnecting = false 557 isReconnecting = false
  558 + requestFlowJob = null
522 if (::coroutineScope.isInitialized) { 559 if (::coroutineScope.isInitialized) {
523 coroutineScope.close() 560 coroutineScope.close()
524 } 561 }
@@ -564,6 +601,14 @@ constructor( @@ -564,6 +601,14 @@ constructor(
564 const val PROTOCOL_VERSION = 6 601 const val PROTOCOL_VERSION = 6
565 const val SDK_TYPE = "android" 602 const val SDK_TYPE = "android"
566 603
  604 + private val skipQueueTypes = listOf(
  605 + LivekitRtc.SignalRequest.MessageCase.SYNC_STATE,
  606 + LivekitRtc.SignalRequest.MessageCase.TRICKLE,
  607 + LivekitRtc.SignalRequest.MessageCase.OFFER,
  608 + LivekitRtc.SignalRequest.MessageCase.ANSWER,
  609 + LivekitRtc.SignalRequest.MessageCase.SIMULATE
  610 + )
  611 +
567 private fun iceServer(url: String) = 612 private fun iceServer(url: String) =
568 PeerConnection.IceServer.builder(url).createIceServer() 613 PeerConnection.IceServer.builder(url).createIceServer()
569 614
@@ -8,7 +8,7 @@ class MockWebSocketFactory : WebSocket.Factory { @@ -8,7 +8,7 @@ class MockWebSocketFactory : WebSocket.Factory {
8 /** 8 /**
9 * The most recently created [WebSocket]. 9 * The most recently created [WebSocket].
10 */ 10 */
11 - lateinit var ws: WebSocket 11 + lateinit var ws: MockWebSocket
12 12
13 /** 13 /**
14 * The request used to create [ws] 14 * The request used to create [ws]
@@ -33,7 +33,7 @@ class RTCEngineMockE2ETest : MockE2ETest() { @@ -33,7 +33,7 @@ class RTCEngineMockE2ETest : MockE2ETest() {
33 rtcEngine.subscriber.peerConnection.remoteDescription.description 33 rtcEngine.subscriber.peerConnection.remoteDescription.description
34 ) 34 )
35 35
36 - val ws = wsFactory.ws as MockWebSocket 36 + val ws = wsFactory.ws
37 val sentRequest = LivekitRtc.SignalRequest.newBuilder() 37 val sentRequest = LivekitRtc.SignalRequest.newBuilder()
38 .mergeFrom(ws.sentRequests[0].toPBByteString()) 38 .mergeFrom(ws.sentRequests[0].toPBByteString())
39 .build() 39 .build()
@@ -5,6 +5,7 @@ import io.livekit.android.BaseTest @@ -5,6 +5,7 @@ import io.livekit.android.BaseTest
5 import io.livekit.android.mock.MockWebSocketFactory 5 import io.livekit.android.mock.MockWebSocketFactory
6 import io.livekit.android.mock.TestData 6 import io.livekit.android.mock.TestData
7 import io.livekit.android.util.toOkioByteString 7 import io.livekit.android.util.toOkioByteString
  8 +import io.livekit.android.util.toPBByteString
8 import kotlinx.coroutines.ExperimentalCoroutinesApi 9 import kotlinx.coroutines.ExperimentalCoroutinesApi
9 import kotlinx.coroutines.async 10 import kotlinx.coroutines.async
10 import kotlinx.serialization.json.Json 11 import kotlinx.serialization.json.Json
@@ -130,7 +131,7 @@ class SignalClientTest : BaseTest() { @@ -130,7 +131,7 @@ class SignalClientTest : BaseTest() {
130 client.onMessage(wsFactory.ws, OFFER.toOkioByteString()) 131 client.onMessage(wsFactory.ws, OFFER.toOkioByteString())
131 132
132 job.await() 133 job.await()
133 - client.onReady() 134 + client.onReadyForResponses()
134 Mockito.verify(listener) 135 Mockito.verify(listener)
135 .onOffer(argThat { type == SessionDescription.Type.OFFER && description == OFFER.offer.sdp }) 136 .onOffer(argThat { type == SessionDescription.Type.OFFER && description == OFFER.offer.sdp })
136 } 137 }
@@ -153,6 +154,68 @@ class SignalClientTest : BaseTest() { @@ -153,6 +154,68 @@ class SignalClientTest : BaseTest() {
153 .onClose(any(), any()) 154 .onClose(any(), any())
154 } 155 }
155 156
  157 + @Test
  158 + fun sendRequest() = runTest {
  159 + val job = async { client.join(EXAMPLE_URL, "") }
  160 + connectWebsocketAndJoin()
  161 + job.await()
  162 +
  163 + client.sendMuteTrack("sid", true)
  164 +
  165 + val ws = wsFactory.ws
  166 +
  167 + Assert.assertEquals(1, ws.sentRequests.size)
  168 + val sentRequest = LivekitRtc.SignalRequest.newBuilder()
  169 + .mergeFrom(ws.sentRequests[0].toPBByteString())
  170 + .build()
  171 +
  172 + Assert.assertTrue(sentRequest.hasMute())
  173 + }
  174 +
  175 + @Test
  176 + fun queuedRequests() = runTest {
  177 + client.sendMuteTrack("sid", true)
  178 + client.sendMuteTrack("sid", true)
  179 + client.sendMuteTrack("sid", true)
  180 +
  181 + val job = async { client.join(EXAMPLE_URL, "") }
  182 + connectWebsocketAndJoin()
  183 + job.await()
  184 +
  185 + val ws = wsFactory.ws
  186 + Assert.assertEquals(3, ws.sentRequests.size)
  187 + val sentRequest = LivekitRtc.SignalRequest.newBuilder()
  188 + .mergeFrom(ws.sentRequests[0].toPBByteString())
  189 + .build()
  190 +
  191 + Assert.assertTrue(sentRequest.hasMute())
  192 + }
  193 +
  194 + @Test
  195 + fun queuedRequestsWhileReconnecting() = runTest {
  196 + client.sendMuteTrack("sid", true)
  197 + client.sendMuteTrack("sid", true)
  198 + client.sendMuteTrack("sid", true)
  199 +
  200 + val job = async { client.reconnect(EXAMPLE_URL, "") }
  201 + client.onOpen(wsFactory.ws, createOpenResponse(wsFactory.request))
  202 + job.await()
  203 +
  204 + val ws = wsFactory.ws
  205 +
  206 + // Wait until peer connection is connected to send requests.
  207 + Assert.assertEquals(0, ws.sentRequests.size)
  208 +
  209 + client.onPCConnected()
  210 +
  211 + Assert.assertEquals(3, ws.sentRequests.size)
  212 + val sentRequest = LivekitRtc.SignalRequest.newBuilder()
  213 + .mergeFrom(ws.sentRequests[0].toPBByteString())
  214 + .build()
  215 +
  216 + Assert.assertTrue(sentRequest.hasMute())
  217 + }
  218 +
156 // mock data 219 // mock data
157 companion object { 220 companion object {
158 const val EXAMPLE_URL = "ws://www.example.com" 221 const val EXAMPLE_URL = "ws://www.example.com"