Committed by
GitHub
Queue requests when reconnecting (#59)
* Queue requests when reconnecting * more tests
正在显示
5 个修改的文件
包含
122 行增加
和
11 行删除
| @@ -5,7 +5,9 @@ import io.livekit.android.ConnectOptions | @@ -5,7 +5,9 @@ import io.livekit.android.ConnectOptions | ||
| 5 | import io.livekit.android.dagger.InjectionNames | 5 | import io.livekit.android.dagger.InjectionNames |
| 6 | import io.livekit.android.room.participant.ParticipantTrackPermission | 6 | import io.livekit.android.room.participant.ParticipantTrackPermission |
| 7 | import io.livekit.android.room.track.TrackException | 7 | import io.livekit.android.room.track.TrackException |
| 8 | -import io.livekit.android.room.util.* | 8 | +import io.livekit.android.room.util.MediaConstraintKeys |
| 9 | +import io.livekit.android.room.util.createAnswer | ||
| 10 | +import io.livekit.android.room.util.setLocalDescription | ||
| 9 | import io.livekit.android.util.CloseableCoroutineScope | 11 | import io.livekit.android.util.CloseableCoroutineScope |
| 10 | import io.livekit.android.util.Either | 12 | import io.livekit.android.util.Either |
| 11 | import io.livekit.android.util.LKLog | 13 | import io.livekit.android.util.LKLog |
| @@ -132,7 +134,7 @@ internal constructor( | @@ -132,7 +134,7 @@ internal constructor( | ||
| 132 | if (!this.isSubscriberPrimary) { | 134 | if (!this.isSubscriberPrimary) { |
| 133 | negotiate() | 135 | negotiate() |
| 134 | } | 136 | } |
| 135 | - client.onReady() | 137 | + client.onReadyForResponses() |
| 136 | return joinResponse | 138 | return joinResponse |
| 137 | } | 139 | } |
| 138 | 140 | ||
| @@ -350,7 +352,7 @@ internal constructor( | @@ -350,7 +352,7 @@ internal constructor( | ||
| 350 | try { | 352 | try { |
| 351 | client.reconnect(url, token) | 353 | client.reconnect(url, token) |
| 352 | // no join response for regular reconnects | 354 | // no join response for regular reconnects |
| 353 | - client.onReady() | 355 | + client.onReadyForResponses() |
| 354 | } catch (e: Exception) { | 356 | } catch (e: Exception) { |
| 355 | LKLog.w(e) { "Error during reconnection." } | 357 | LKLog.w(e) { "Error during reconnection." } |
| 356 | // ws reconnect failed, retry. | 358 | // ws reconnect failed, retry. |
| @@ -378,6 +380,7 @@ internal constructor( | @@ -378,6 +380,7 @@ internal constructor( | ||
| 378 | } | 380 | } |
| 379 | 381 | ||
| 380 | if (connectionState == ConnectionState.CONNECTED) { | 382 | if (connectionState == ConnectionState.CONNECTED) { |
| 383 | + client.onPCConnected() | ||
| 381 | listener?.onPostReconnect(isFullReconnect) | 384 | listener?.onPostReconnect(isFullReconnect) |
| 382 | return@launch | 385 | return@launch |
| 383 | } | 386 | } |
| 1 | package io.livekit.android.room | 1 | package io.livekit.android.room |
| 2 | 2 | ||
| 3 | -import android.net.Uri | ||
| 4 | import com.google.protobuf.util.JsonFormat | 3 | import com.google.protobuf.util.JsonFormat |
| 5 | import com.vdurmont.semver4j.Semver | 4 | import com.vdurmont.semver4j.Semver |
| 6 | import io.livekit.android.ConnectOptions | 5 | import io.livekit.android.ConnectOptions |
| 7 | -import io.livekit.android.Version | ||
| 8 | import io.livekit.android.dagger.InjectionNames | 6 | import io.livekit.android.dagger.InjectionNames |
| 9 | import io.livekit.android.room.participant.ParticipantTrackPermission | 7 | import io.livekit.android.room.participant.ParticipantTrackPermission |
| 10 | import io.livekit.android.room.track.Track | 8 | import io.livekit.android.room.track.Track |
| @@ -60,6 +58,10 @@ constructor( | @@ -60,6 +58,10 @@ constructor( | ||
| 60 | private var joinContinuation: CancellableContinuation<Either<LivekitRtc.JoinResponse, Unit>>? = null | 58 | private var joinContinuation: CancellableContinuation<Either<LivekitRtc.JoinResponse, Unit>>? = null |
| 61 | private lateinit var coroutineScope: CloseableCoroutineScope | 59 | private lateinit var coroutineScope: CloseableCoroutineScope |
| 62 | 60 | ||
| 61 | + private val requestFlowJobLock = Object() | ||
| 62 | + private var requestFlowJob: Job? = null | ||
| 63 | + private val requestFlow = MutableSharedFlow<LivekitRtc.SignalRequest>(Int.MAX_VALUE) | ||
| 64 | + | ||
| 63 | private val responseFlow = MutableSharedFlow<LivekitRtc.SignalResponse>(Int.MAX_VALUE) | 65 | private val responseFlow = MutableSharedFlow<LivekitRtc.SignalResponse>(Int.MAX_VALUE) |
| 64 | 66 | ||
| 65 | /** | 67 | /** |
| @@ -150,7 +152,7 @@ constructor( | @@ -150,7 +152,7 @@ constructor( | ||
| 150 | * Should be called after resolving the join message. | 152 | * Should be called after resolving the join message. |
| 151 | */ | 153 | */ |
| 152 | @OptIn(ExperimentalCoroutinesApi::class) | 154 | @OptIn(ExperimentalCoroutinesApi::class) |
| 153 | - fun onReady() { | 155 | + fun onReadyForResponses() { |
| 154 | coroutineScope.launch { | 156 | coroutineScope.launch { |
| 155 | responseFlow.collect { | 157 | responseFlow.collect { |
| 156 | responseFlow.resetReplayCache() | 158 | responseFlow.resetReplayCache() |
| @@ -159,6 +161,31 @@ constructor( | @@ -159,6 +161,31 @@ constructor( | ||
| 159 | } | 161 | } |
| 160 | } | 162 | } |
| 161 | 163 | ||
| 164 | + @OptIn(ExperimentalCoroutinesApi::class) | ||
| 165 | + private fun startRequestQueue() { | ||
| 166 | + if (requestFlowJob != null) { | ||
| 167 | + return | ||
| 168 | + } | ||
| 169 | + synchronized(requestFlowJobLock) { | ||
| 170 | + if (requestFlowJob == null) { | ||
| 171 | + requestFlowJob = coroutineScope.launch { | ||
| 172 | + requestFlow.collect { | ||
| 173 | + requestFlow.resetReplayCache() | ||
| 174 | + sendRequestImpl(it) | ||
| 175 | + } | ||
| 176 | + } | ||
| 177 | + } | ||
| 178 | + } | ||
| 179 | + } | ||
| 180 | + | ||
| 181 | + /** | ||
| 182 | + * On reconnection, SignalClient waits until the peer connection is established to send messages. | ||
| 183 | + * Call this method when it is connected. | ||
| 184 | + */ | ||
| 185 | + fun onPCConnected() { | ||
| 186 | + startRequestQueue() | ||
| 187 | + } | ||
| 188 | + | ||
| 162 | //--------------------------------- WebSocket Listener --------------------------------------// | 189 | //--------------------------------- WebSocket Listener --------------------------------------// |
| 163 | override fun onOpen(webSocket: WebSocket, response: Response) { | 190 | override fun onOpen(webSocket: WebSocket, response: Response) { |
| 164 | if (isReconnecting) { | 191 | if (isReconnecting) { |
| @@ -402,6 +429,16 @@ constructor( | @@ -402,6 +429,16 @@ constructor( | ||
| 402 | } | 429 | } |
| 403 | 430 | ||
| 404 | private fun sendRequest(request: LivekitRtc.SignalRequest) { | 431 | private fun sendRequest(request: LivekitRtc.SignalRequest) { |
| 432 | + val skipQueue = skipQueueTypes.contains(request.messageCase) | ||
| 433 | + | ||
| 434 | + if (skipQueue) { | ||
| 435 | + sendRequestImpl(request) | ||
| 436 | + } else { | ||
| 437 | + requestFlow.tryEmit(request) | ||
| 438 | + } | ||
| 439 | + } | ||
| 440 | + | ||
| 441 | + private fun sendRequestImpl(request: LivekitRtc.SignalRequest) { | ||
| 405 | LKLog.v { "sending request: $request" } | 442 | LKLog.v { "sending request: $request" } |
| 406 | if (!isConnected || currentWs == null) { | 443 | if (!isConnected || currentWs == null) { |
| 407 | LKLog.w { "not connected, could not send request $request" } | 444 | LKLog.w { "not connected, could not send request $request" } |
| @@ -428,6 +465,7 @@ constructor( | @@ -428,6 +465,7 @@ constructor( | ||
| 428 | // Only handle joins if not connected. | 465 | // Only handle joins if not connected. |
| 429 | if (response.hasJoin()) { | 466 | if (response.hasJoin()) { |
| 430 | isConnected = true | 467 | isConnected = true |
| 468 | + startRequestQueue() | ||
| 431 | try { | 469 | try { |
| 432 | serverVersion = Semver(response.join.serverVersion) | 470 | serverVersion = Semver(response.join.serverVersion) |
| 433 | } catch (t: Throwable) { | 471 | } catch (t: Throwable) { |
| @@ -439,10 +477,8 @@ constructor( | @@ -439,10 +477,8 @@ constructor( | ||
| 439 | } | 477 | } |
| 440 | return | 478 | return |
| 441 | } | 479 | } |
| 442 | - coroutineScope.launch { | ||
| 443 | responseFlow.tryEmit(response) | 480 | responseFlow.tryEmit(response) |
| 444 | } | 481 | } |
| 445 | - } | ||
| 446 | 482 | ||
| 447 | private fun handleSignalResponseImpl(response: LivekitRtc.SignalResponse) { | 483 | private fun handleSignalResponseImpl(response: LivekitRtc.SignalResponse) { |
| 448 | when (response.messageCase) { | 484 | when (response.messageCase) { |
| @@ -519,6 +555,7 @@ constructor( | @@ -519,6 +555,7 @@ constructor( | ||
| 519 | fun close(code: Int = 1000, reason: String = "Normal Closure") { | 555 | fun close(code: Int = 1000, reason: String = "Normal Closure") { |
| 520 | isConnected = false | 556 | isConnected = false |
| 521 | isReconnecting = false | 557 | isReconnecting = false |
| 558 | + requestFlowJob = null | ||
| 522 | if (::coroutineScope.isInitialized) { | 559 | if (::coroutineScope.isInitialized) { |
| 523 | coroutineScope.close() | 560 | coroutineScope.close() |
| 524 | } | 561 | } |
| @@ -564,6 +601,14 @@ constructor( | @@ -564,6 +601,14 @@ constructor( | ||
| 564 | const val PROTOCOL_VERSION = 6 | 601 | const val PROTOCOL_VERSION = 6 |
| 565 | const val SDK_TYPE = "android" | 602 | const val SDK_TYPE = "android" |
| 566 | 603 | ||
| 604 | + private val skipQueueTypes = listOf( | ||
| 605 | + LivekitRtc.SignalRequest.MessageCase.SYNC_STATE, | ||
| 606 | + LivekitRtc.SignalRequest.MessageCase.TRICKLE, | ||
| 607 | + LivekitRtc.SignalRequest.MessageCase.OFFER, | ||
| 608 | + LivekitRtc.SignalRequest.MessageCase.ANSWER, | ||
| 609 | + LivekitRtc.SignalRequest.MessageCase.SIMULATE | ||
| 610 | + ) | ||
| 611 | + | ||
| 567 | private fun iceServer(url: String) = | 612 | private fun iceServer(url: String) = |
| 568 | PeerConnection.IceServer.builder(url).createIceServer() | 613 | PeerConnection.IceServer.builder(url).createIceServer() |
| 569 | 614 |
| @@ -8,7 +8,7 @@ class MockWebSocketFactory : WebSocket.Factory { | @@ -8,7 +8,7 @@ class MockWebSocketFactory : WebSocket.Factory { | ||
| 8 | /** | 8 | /** |
| 9 | * The most recently created [WebSocket]. | 9 | * The most recently created [WebSocket]. |
| 10 | */ | 10 | */ |
| 11 | - lateinit var ws: WebSocket | 11 | + lateinit var ws: MockWebSocket |
| 12 | 12 | ||
| 13 | /** | 13 | /** |
| 14 | * The request used to create [ws] | 14 | * The request used to create [ws] |
| @@ -33,7 +33,7 @@ class RTCEngineMockE2ETest : MockE2ETest() { | @@ -33,7 +33,7 @@ class RTCEngineMockE2ETest : MockE2ETest() { | ||
| 33 | rtcEngine.subscriber.peerConnection.remoteDescription.description | 33 | rtcEngine.subscriber.peerConnection.remoteDescription.description |
| 34 | ) | 34 | ) |
| 35 | 35 | ||
| 36 | - val ws = wsFactory.ws as MockWebSocket | 36 | + val ws = wsFactory.ws |
| 37 | val sentRequest = LivekitRtc.SignalRequest.newBuilder() | 37 | val sentRequest = LivekitRtc.SignalRequest.newBuilder() |
| 38 | .mergeFrom(ws.sentRequests[0].toPBByteString()) | 38 | .mergeFrom(ws.sentRequests[0].toPBByteString()) |
| 39 | .build() | 39 | .build() |
| @@ -5,6 +5,7 @@ import io.livekit.android.BaseTest | @@ -5,6 +5,7 @@ import io.livekit.android.BaseTest | ||
| 5 | import io.livekit.android.mock.MockWebSocketFactory | 5 | import io.livekit.android.mock.MockWebSocketFactory |
| 6 | import io.livekit.android.mock.TestData | 6 | import io.livekit.android.mock.TestData |
| 7 | import io.livekit.android.util.toOkioByteString | 7 | import io.livekit.android.util.toOkioByteString |
| 8 | +import io.livekit.android.util.toPBByteString | ||
| 8 | import kotlinx.coroutines.ExperimentalCoroutinesApi | 9 | import kotlinx.coroutines.ExperimentalCoroutinesApi |
| 9 | import kotlinx.coroutines.async | 10 | import kotlinx.coroutines.async |
| 10 | import kotlinx.serialization.json.Json | 11 | import kotlinx.serialization.json.Json |
| @@ -130,7 +131,7 @@ class SignalClientTest : BaseTest() { | @@ -130,7 +131,7 @@ class SignalClientTest : BaseTest() { | ||
| 130 | client.onMessage(wsFactory.ws, OFFER.toOkioByteString()) | 131 | client.onMessage(wsFactory.ws, OFFER.toOkioByteString()) |
| 131 | 132 | ||
| 132 | job.await() | 133 | job.await() |
| 133 | - client.onReady() | 134 | + client.onReadyForResponses() |
| 134 | Mockito.verify(listener) | 135 | Mockito.verify(listener) |
| 135 | .onOffer(argThat { type == SessionDescription.Type.OFFER && description == OFFER.offer.sdp }) | 136 | .onOffer(argThat { type == SessionDescription.Type.OFFER && description == OFFER.offer.sdp }) |
| 136 | } | 137 | } |
| @@ -153,6 +154,68 @@ class SignalClientTest : BaseTest() { | @@ -153,6 +154,68 @@ class SignalClientTest : BaseTest() { | ||
| 153 | .onClose(any(), any()) | 154 | .onClose(any(), any()) |
| 154 | } | 155 | } |
| 155 | 156 | ||
| 157 | + @Test | ||
| 158 | + fun sendRequest() = runTest { | ||
| 159 | + val job = async { client.join(EXAMPLE_URL, "") } | ||
| 160 | + connectWebsocketAndJoin() | ||
| 161 | + job.await() | ||
| 162 | + | ||
| 163 | + client.sendMuteTrack("sid", true) | ||
| 164 | + | ||
| 165 | + val ws = wsFactory.ws | ||
| 166 | + | ||
| 167 | + Assert.assertEquals(1, ws.sentRequests.size) | ||
| 168 | + val sentRequest = LivekitRtc.SignalRequest.newBuilder() | ||
| 169 | + .mergeFrom(ws.sentRequests[0].toPBByteString()) | ||
| 170 | + .build() | ||
| 171 | + | ||
| 172 | + Assert.assertTrue(sentRequest.hasMute()) | ||
| 173 | + } | ||
| 174 | + | ||
| 175 | + @Test | ||
| 176 | + fun queuedRequests() = runTest { | ||
| 177 | + client.sendMuteTrack("sid", true) | ||
| 178 | + client.sendMuteTrack("sid", true) | ||
| 179 | + client.sendMuteTrack("sid", true) | ||
| 180 | + | ||
| 181 | + val job = async { client.join(EXAMPLE_URL, "") } | ||
| 182 | + connectWebsocketAndJoin() | ||
| 183 | + job.await() | ||
| 184 | + | ||
| 185 | + val ws = wsFactory.ws | ||
| 186 | + Assert.assertEquals(3, ws.sentRequests.size) | ||
| 187 | + val sentRequest = LivekitRtc.SignalRequest.newBuilder() | ||
| 188 | + .mergeFrom(ws.sentRequests[0].toPBByteString()) | ||
| 189 | + .build() | ||
| 190 | + | ||
| 191 | + Assert.assertTrue(sentRequest.hasMute()) | ||
| 192 | + } | ||
| 193 | + | ||
| 194 | + @Test | ||
| 195 | + fun queuedRequestsWhileReconnecting() = runTest { | ||
| 196 | + client.sendMuteTrack("sid", true) | ||
| 197 | + client.sendMuteTrack("sid", true) | ||
| 198 | + client.sendMuteTrack("sid", true) | ||
| 199 | + | ||
| 200 | + val job = async { client.reconnect(EXAMPLE_URL, "") } | ||
| 201 | + client.onOpen(wsFactory.ws, createOpenResponse(wsFactory.request)) | ||
| 202 | + job.await() | ||
| 203 | + | ||
| 204 | + val ws = wsFactory.ws | ||
| 205 | + | ||
| 206 | + // Wait until peer connection is connected to send requests. | ||
| 207 | + Assert.assertEquals(0, ws.sentRequests.size) | ||
| 208 | + | ||
| 209 | + client.onPCConnected() | ||
| 210 | + | ||
| 211 | + Assert.assertEquals(3, ws.sentRequests.size) | ||
| 212 | + val sentRequest = LivekitRtc.SignalRequest.newBuilder() | ||
| 213 | + .mergeFrom(ws.sentRequests[0].toPBByteString()) | ||
| 214 | + .build() | ||
| 215 | + | ||
| 216 | + Assert.assertTrue(sentRequest.hasMute()) | ||
| 217 | + } | ||
| 218 | + | ||
| 156 | // mock data | 219 | // mock data |
| 157 | companion object { | 220 | companion object { |
| 158 | const val EXAMPLE_URL = "ws://www.example.com" | 221 | const val EXAMPLE_URL = "ws://www.example.com" |
-
请 注册 或 登录 后发表评论