davidliu
Committed by GitHub

Synchronize only one reconnecting job occurs at a time (#111)

* Use lock to ensure only one reconnection attempt at a time

* fix test
@@ -17,6 +17,7 @@ import io.livekit.android.webrtc.isConnected @@ -17,6 +17,7 @@ import io.livekit.android.webrtc.isConnected
17 import io.livekit.android.webrtc.isDisconnected 17 import io.livekit.android.webrtc.isDisconnected
18 import io.livekit.android.webrtc.toProtoSessionDescription 18 import io.livekit.android.webrtc.toProtoSessionDescription
19 import kotlinx.coroutines.* 19 import kotlinx.coroutines.*
  20 +import kotlinx.coroutines.sync.Mutex
20 import livekit.LivekitModels 21 import livekit.LivekitModels
21 import livekit.LivekitRtc 22 import livekit.LivekitRtc
22 import org.webrtc.* 23 import org.webrtc.*
@@ -32,7 +33,6 @@ import kotlin.coroutines.suspendCoroutine @@ -32,7 +33,6 @@ import kotlin.coroutines.suspendCoroutine
32 /** 33 /**
33 * @suppress 34 * @suppress
34 */ 35 */
35 -@OptIn(ExperimentalCoroutinesApi::class)  
36 @Singleton 36 @Singleton
37 class RTCEngine 37 class RTCEngine
38 @Inject 38 @Inject
@@ -76,6 +76,7 @@ internal constructor( @@ -76,6 +76,7 @@ internal constructor(
76 } 76 }
77 77
78 private var reconnectingJob: Job? = null 78 private var reconnectingJob: Job? = null
  79 + private val reconnectingLock = Mutex()
79 private var fullReconnectOnNext = false 80 private var fullReconnectOnNext = false
80 81
81 private val pendingTrackResolvers: MutableMap<String, Continuation<LivekitModels.TrackInfo>> = 82 private val pendingTrackResolvers: MutableMap<String, Continuation<LivekitModels.TrackInfo>> =
@@ -330,7 +331,8 @@ internal constructor( @@ -330,7 +331,8 @@ internal constructor(
330 * reconnect Signal and PeerConnections 331 * reconnect Signal and PeerConnections
331 */ 332 */
332 internal fun reconnect() { 333 internal fun reconnect() {
333 - if (reconnectingJob != null) { 334 + val didLock = reconnectingLock.tryLock()
  335 + if (!didLock) {
334 return 336 return
335 } 337 }
336 if (this.isClosed) { 338 if (this.isClosed) {
@@ -345,6 +347,7 @@ internal constructor( @@ -345,6 +347,7 @@ internal constructor(
345 val forceFullReconnect = fullReconnectOnNext 347 val forceFullReconnect = fullReconnectOnNext
346 fullReconnectOnNext = false 348 fullReconnectOnNext = false
347 val job = coroutineScope.launch { 349 val job = coroutineScope.launch {
  350 +
348 connectionState = ConnectionState.RECONNECTING 351 connectionState = ConnectionState.RECONNECTING
349 listener?.onEngineReconnecting() 352 listener?.onEngineReconnecting()
350 353
@@ -372,6 +375,7 @@ internal constructor( @@ -372,6 +375,7 @@ internal constructor(
372 continue 375 continue
373 } 376 }
374 } else { 377 } else {
  378 + subscriber.prepareForIceRestart()
375 try { 379 try {
376 client.reconnect(url, token) 380 client.reconnect(url, token)
377 // no join response for regular reconnects 381 // no join response for regular reconnects
@@ -385,7 +389,6 @@ internal constructor( @@ -385,7 +389,6 @@ internal constructor(
385 LKLog.v { "ws reconnected, restarting ICE" } 389 LKLog.v { "ws reconnected, restarting ICE" }
386 listener?.onSignalConnected(!isFullReconnect) 390 listener?.onSignalConnected(!isFullReconnect)
387 391
388 - subscriber.prepareForIceRestart()  
389 // trigger publisher reconnect 392 // trigger publisher reconnect
390 // only restart publisher if it's needed 393 // only restart publisher if it's needed
391 if (hasPublished) { 394 if (hasPublished) {
@@ -435,6 +438,7 @@ internal constructor( @@ -435,6 +438,7 @@ internal constructor(
435 if (reconnectingJob == job) { 438 if (reconnectingJob == job) {
436 reconnectingJob = null 439 reconnectingJob = null
437 } 440 }
  441 + reconnectingLock.unlock()
438 } 442 }
439 } 443 }
440 444
@@ -111,11 +111,11 @@ constructor( @@ -111,11 +111,11 @@ constructor(
111 val request = Request.Builder() 111 val request = Request.Builder()
112 .url(wsUrlString) 112 .url(wsUrlString)
113 .build() 113 .build()
114 - currentWs = websocketFactory.newWebSocket(request, this)  
115 114
116 return suspendCancellableCoroutine { 115 return suspendCancellableCoroutine {
117 // Wait for join response through WebSocketListener 116 // Wait for join response through WebSocketListener
118 joinContinuation = it 117 joinContinuation = it
  118 + currentWs = websocketFactory.newWebSocket(request, this)
119 } 119 }
120 } 120 }
121 121
@@ -61,13 +61,20 @@ abstract class MockE2ETest : BaseTest() { @@ -61,13 +61,20 @@ abstract class MockE2ETest : BaseTest() {
61 job.join() 61 job.join()
62 } 62 }
63 63
64 - suspend fun connectPeerConnection() { 64 + fun connectPeerConnection() {
65 subscriber = component.rtcEngine().subscriber 65 subscriber = component.rtcEngine().subscriber
66 simulateMessageFromServer(SignalClientTest.OFFER) 66 simulateMessageFromServer(SignalClientTest.OFFER)
67 val subPeerConnection = subscriber.peerConnection as MockPeerConnection 67 val subPeerConnection = subscriber.peerConnection as MockPeerConnection
68 subPeerConnection.moveToIceConnectionState(PeerConnection.IceConnectionState.CONNECTED) 68 subPeerConnection.moveToIceConnectionState(PeerConnection.IceConnectionState.CONNECTED)
69 } 69 }
70 70
  71 + fun disconnectPeerConnection() {
  72 + subscriber = component.rtcEngine().subscriber
  73 + simulateMessageFromServer(SignalClientTest.OFFER)
  74 + val subPeerConnection = subscriber.peerConnection as MockPeerConnection
  75 + subPeerConnection.moveToIceConnectionState(PeerConnection.IceConnectionState.FAILED)
  76 + }
  77 +
71 fun createOpenResponse(request: Request): Response { 78 fun createOpenResponse(request: Request): Response {
72 return Response.Builder() 79 return Response.Builder()
73 .request(request) 80 .request(request)
@@ -2,9 +2,14 @@ package io.livekit.android.mock @@ -2,9 +2,14 @@ package io.livekit.android.mock
2 2
3 import okhttp3.Request 3 import okhttp3.Request
4 import okhttp3.WebSocket 4 import okhttp3.WebSocket
  5 +import okhttp3.WebSocketListener
5 import okio.ByteString 6 import okio.ByteString
  7 +import okio.IOException
6 8
7 -class MockWebSocket(private val request: Request) : WebSocket { 9 +class MockWebSocket(
  10 + private val request: Request,
  11 + private val listener: WebSocketListener
  12 +) : WebSocket {
8 13
9 var isClosed = false 14 var isClosed = false
10 private set 15 private set
@@ -15,11 +20,17 @@ class MockWebSocket(private val request: Request) : WebSocket { @@ -15,11 +20,17 @@ class MockWebSocket(private val request: Request) : WebSocket {
15 20
16 override fun cancel() { 21 override fun cancel() {
17 isClosed = true 22 isClosed = true
  23 + listener.onFailure(this, IOException("cancelled"), null)
18 } 24 }
19 25
20 override fun close(code: Int, reason: String?): Boolean { 26 override fun close(code: Int, reason: String?): Boolean {
21 val willClose = !isClosed 27 val willClose = !isClosed
  28 + if (!willClose) {
  29 + return false
  30 + }
22 isClosed = true 31 isClosed = true
  32 + listener.onClosing(this, code, reason ?: "")
  33 + listener.onClosed(this, code, reason ?: "")
23 return willClose 34 return willClose
24 } 35 }
25 36
@@ -30,9 +41,10 @@ class MockWebSocket(private val request: Request) : WebSocket { @@ -30,9 +41,10 @@ class MockWebSocket(private val request: Request) : WebSocket {
30 override fun send(text: String): Boolean = !isClosed 41 override fun send(text: String): Boolean = !isClosed
31 42
32 override fun send(bytes: ByteString): Boolean { 43 override fun send(bytes: ByteString): Boolean {
  44 + if (isClosed) {
  45 + return false
  46 + }
33 mutableSentRequests.add(bytes) 47 mutableSentRequests.add(bytes)
34 return !isClosed 48 return !isClosed
35 } 49 }
36 -  
37 -  
38 } 50 }
@@ -20,10 +20,13 @@ class MockWebSocketFactory : WebSocket.Factory { @@ -20,10 +20,13 @@ class MockWebSocketFactory : WebSocket.Factory {
20 */ 20 */
21 lateinit var listener: WebSocketListener 21 lateinit var listener: WebSocketListener
22 override fun newWebSocket(request: Request, listener: WebSocketListener): WebSocket { 22 override fun newWebSocket(request: Request, listener: WebSocketListener): WebSocket {
23 - this.ws = MockWebSocket(request)  
24 - 23 + this.ws = MockWebSocket(request, listener)
25 this.listener = listener 24 this.listener = listener
26 this.request = request 25 this.request = request
  26 +
  27 + onOpen?.invoke(this)
27 return ws 28 return ws
28 } 29 }
  30 +
  31 + var onOpen: ((MockWebSocketFactory) -> Unit)? = null
29 } 32 }
@@ -15,6 +15,7 @@ import io.livekit.android.room.track.Track @@ -15,6 +15,7 @@ import io.livekit.android.room.track.Track
15 import io.livekit.android.util.flow 15 import io.livekit.android.util.flow
16 import io.livekit.android.util.toOkioByteString 16 import io.livekit.android.util.toOkioByteString
17 import junit.framework.Assert.assertEquals 17 import junit.framework.Assert.assertEquals
  18 +import junit.framework.Assert.assertTrue
18 import kotlinx.coroutines.ExperimentalCoroutinesApi 19 import kotlinx.coroutines.ExperimentalCoroutinesApi
19 import kotlinx.coroutines.launch 20 import kotlinx.coroutines.launch
20 import org.junit.Assert 21 import org.junit.Assert
@@ -305,10 +306,27 @@ class RoomMockE2ETest : MockE2ETest() { @@ -305,10 +306,27 @@ class RoomMockE2ETest : MockE2ETest() {
305 } 306 }
306 307
307 @Test 308 @Test
308 - fun reconnectAfterDisconnect() = runTest { 309 + fun connectAfterDisconnect() = runTest {
309 connect() 310 connect()
310 room.disconnect() 311 room.disconnect()
311 connect() 312 connect()
312 Assert.assertEquals(room.state, Room.State.CONNECTED) 313 Assert.assertEquals(room.state, Room.State.CONNECTED)
313 } 314 }
  315 +
  316 + @Test
  317 + fun reconnectFromPeerConnectionDisconnect() = runTest {
  318 + connect()
  319 +
  320 + val eventCollector = EventCollector(room.events, coroutineRule.scope)
  321 + wsFactory.onOpen = {
  322 + wsFactory.listener.onOpen(wsFactory.ws, createOpenResponse(wsFactory.request))
  323 + connectPeerConnection()
  324 + }
  325 + disconnectPeerConnection()
  326 + val events = eventCollector.stopCollecting()
  327 +
  328 + assertEquals(2, events.size)
  329 + assertTrue(events[0] is RoomEvent.Reconnecting)
  330 + assertTrue(events[1] is RoomEvent.Reconnected)
  331 + }
314 } 332 }