davidliu
Committed by GitHub

Synchronize only one reconnecting job occurs at a time (#111)

* Use lock to ensure only one reconnection attempt at a time

* fix test
... ... @@ -17,6 +17,7 @@ import io.livekit.android.webrtc.isConnected
import io.livekit.android.webrtc.isDisconnected
import io.livekit.android.webrtc.toProtoSessionDescription
import kotlinx.coroutines.*
import kotlinx.coroutines.sync.Mutex
import livekit.LivekitModels
import livekit.LivekitRtc
import org.webrtc.*
... ... @@ -32,7 +33,6 @@ import kotlin.coroutines.suspendCoroutine
/**
* @suppress
*/
@OptIn(ExperimentalCoroutinesApi::class)
@Singleton
class RTCEngine
@Inject
... ... @@ -76,6 +76,7 @@ internal constructor(
}
private var reconnectingJob: Job? = null
private val reconnectingLock = Mutex()
private var fullReconnectOnNext = false
private val pendingTrackResolvers: MutableMap<String, Continuation<LivekitModels.TrackInfo>> =
... ... @@ -330,7 +331,8 @@ internal constructor(
* reconnect Signal and PeerConnections
*/
internal fun reconnect() {
if (reconnectingJob != null) {
val didLock = reconnectingLock.tryLock()
if (!didLock) {
return
}
if (this.isClosed) {
... ... @@ -345,6 +347,7 @@ internal constructor(
val forceFullReconnect = fullReconnectOnNext
fullReconnectOnNext = false
val job = coroutineScope.launch {
connectionState = ConnectionState.RECONNECTING
listener?.onEngineReconnecting()
... ... @@ -372,6 +375,7 @@ internal constructor(
continue
}
} else {
subscriber.prepareForIceRestart()
try {
client.reconnect(url, token)
// no join response for regular reconnects
... ... @@ -385,7 +389,6 @@ internal constructor(
LKLog.v { "ws reconnected, restarting ICE" }
listener?.onSignalConnected(!isFullReconnect)
subscriber.prepareForIceRestart()
// trigger publisher reconnect
// only restart publisher if it's needed
if (hasPublished) {
... ... @@ -435,6 +438,7 @@ internal constructor(
if (reconnectingJob == job) {
reconnectingJob = null
}
reconnectingLock.unlock()
}
}
... ...
... ... @@ -111,11 +111,11 @@ constructor(
val request = Request.Builder()
.url(wsUrlString)
.build()
currentWs = websocketFactory.newWebSocket(request, this)
return suspendCancellableCoroutine {
// Wait for join response through WebSocketListener
joinContinuation = it
currentWs = websocketFactory.newWebSocket(request, this)
}
}
... ...
... ... @@ -61,13 +61,20 @@ abstract class MockE2ETest : BaseTest() {
job.join()
}
suspend fun connectPeerConnection() {
fun connectPeerConnection() {
subscriber = component.rtcEngine().subscriber
simulateMessageFromServer(SignalClientTest.OFFER)
val subPeerConnection = subscriber.peerConnection as MockPeerConnection
subPeerConnection.moveToIceConnectionState(PeerConnection.IceConnectionState.CONNECTED)
}
fun disconnectPeerConnection() {
subscriber = component.rtcEngine().subscriber
simulateMessageFromServer(SignalClientTest.OFFER)
val subPeerConnection = subscriber.peerConnection as MockPeerConnection
subPeerConnection.moveToIceConnectionState(PeerConnection.IceConnectionState.FAILED)
}
fun createOpenResponse(request: Request): Response {
return Response.Builder()
.request(request)
... ...
... ... @@ -2,9 +2,14 @@ package io.livekit.android.mock
import okhttp3.Request
import okhttp3.WebSocket
import okhttp3.WebSocketListener
import okio.ByteString
import okio.IOException
class MockWebSocket(private val request: Request) : WebSocket {
class MockWebSocket(
private val request: Request,
private val listener: WebSocketListener
) : WebSocket {
var isClosed = false
private set
... ... @@ -15,11 +20,17 @@ class MockWebSocket(private val request: Request) : WebSocket {
override fun cancel() {
isClosed = true
listener.onFailure(this, IOException("cancelled"), null)
}
override fun close(code: Int, reason: String?): Boolean {
val willClose = !isClosed
if (!willClose) {
return false
}
isClosed = true
listener.onClosing(this, code, reason ?: "")
listener.onClosed(this, code, reason ?: "")
return willClose
}
... ... @@ -30,9 +41,10 @@ class MockWebSocket(private val request: Request) : WebSocket {
override fun send(text: String): Boolean = !isClosed
override fun send(bytes: ByteString): Boolean {
if (isClosed) {
return false
}
mutableSentRequests.add(bytes)
return !isClosed
}
}
\ No newline at end of file
... ...
... ... @@ -20,10 +20,13 @@ class MockWebSocketFactory : WebSocket.Factory {
*/
lateinit var listener: WebSocketListener
override fun newWebSocket(request: Request, listener: WebSocketListener): WebSocket {
this.ws = MockWebSocket(request)
this.ws = MockWebSocket(request, listener)
this.listener = listener
this.request = request
onOpen?.invoke(this)
return ws
}
var onOpen: ((MockWebSocketFactory) -> Unit)? = null
}
\ No newline at end of file
... ...
... ... @@ -15,6 +15,7 @@ import io.livekit.android.room.track.Track
import io.livekit.android.util.flow
import io.livekit.android.util.toOkioByteString
import junit.framework.Assert.assertEquals
import junit.framework.Assert.assertTrue
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.launch
import org.junit.Assert
... ... @@ -305,10 +306,27 @@ class RoomMockE2ETest : MockE2ETest() {
}
@Test
fun reconnectAfterDisconnect() = runTest {
fun connectAfterDisconnect() = runTest {
connect()
room.disconnect()
connect()
Assert.assertEquals(room.state, Room.State.CONNECTED)
}
@Test
fun reconnectFromPeerConnectionDisconnect() = runTest {
connect()
val eventCollector = EventCollector(room.events, coroutineRule.scope)
wsFactory.onOpen = {
wsFactory.listener.onOpen(wsFactory.ws, createOpenResponse(wsFactory.request))
connectPeerConnection()
}
disconnectPeerConnection()
val events = eventCollector.stopCollecting()
assertEquals(2, events.size)
assertTrue(events[0] is RoomEvent.Reconnecting)
assertTrue(events[1] is RoomEvent.Reconnected)
}
}
\ No newline at end of file
... ...