davidliu
Committed by GitHub

Fix local participant not republishing tracks upon full reconnect (#139)

* fix reconnect not republishing tracks

* apply delay fix to other reconnect tests

* fix tests
... ... @@ -43,6 +43,9 @@ class LiveKit {
@JvmStatic
var enableWebRTCLogging: Boolean = false
/**
* Create a Room object.
*/
fun create(
appContext: Context,
options: RoomOptions = RoomOptions(),
... ...
... ... @@ -75,6 +75,7 @@ internal constructor(
}
}
internal var reconnectType: ReconnectType = ReconnectType.DEFAULT
private var reconnectingJob: Job? = null
private val reconnectingLock = Mutex()
private var fullReconnectOnNext = false
... ... @@ -353,7 +354,7 @@ internal constructor(
val reconnectStartTime = SystemClock.elapsedRealtime()
for (retries in 0 until MAX_RECONNECT_RETRIES) {
var startDelay = retries.toLong() * retries * 500
var startDelay = 100 + retries.toLong() * retries * 500
if (startDelay > 5000) {
startDelay = 5000
}
... ... @@ -361,10 +362,15 @@ internal constructor(
LKLog.i { "Reconnecting to signal, attempt ${retries + 1}" }
delay(startDelay)
// full reconnect after first try.
val isFullReconnect = retries != 0 || forceFullReconnect
val isFullReconnect = when (reconnectType) {
// full reconnect after first try.
ReconnectType.DEFAULT -> retries != 0 || forceFullReconnect
ReconnectType.FORCE_SOFT_RECONNECT -> false
ReconnectType.FORCE_FULL_RECONNECT -> true
}
if (isFullReconnect) {
LKLog.v { "Attempting full reconnect." }
try {
closeResources()
listener?.onFullReconnecting()
... ... @@ -375,6 +381,7 @@ internal constructor(
continue
}
} else {
LKLog.v { "Attempting soft reconnect." }
subscriber.prepareForIceRestart()
try {
client.reconnect(url, token)
... ... @@ -783,4 +790,13 @@ internal constructor(
client.sendSyncState(syncState)
}
}
/**
* @suppress
*/
enum class ReconnectType {
DEFAULT,
FORCE_SOFT_RECONNECT,
FORCE_FULL_RECONNECT;
}
\ No newline at end of file
... ...
... ... @@ -7,6 +7,7 @@ import android.net.ConnectivityManager
import android.net.Network
import android.net.NetworkCapabilities
import android.net.NetworkRequest
import androidx.annotation.VisibleForTesting
import dagger.assisted.Assisted
import dagger.assisted.AssistedFactory
import dagger.assisted.AssistedInject
... ... @@ -872,6 +873,16 @@ constructor(
eventBus.postEvent(event)
}
}
// Debug options
/**
* @suppress
*/
@VisibleForTesting
fun setReconnectionType(reconnectType: ReconnectType) {
engine.reconnectType = reconnectType
}
}
/**
... ...
... ... @@ -43,6 +43,7 @@ internal constructor(
var videoTrackCaptureDefaults: LocalVideoTrackOptions by defaultsManager::videoTrackCaptureDefaults
var videoTrackPublishDefaults: VideoTrackPublishDefaults by defaultsManager::videoTrackPublishDefaults
var republishes = emptyList<LocalTrackPublication>()
private val localTrackPublications
get() = tracks.values
.mapNotNull { it as? LocalTrackPublication }
... ... @@ -522,7 +523,8 @@ internal constructor(
}
fun prepareForFullReconnect() {
val pubs = localTrackPublications // creates a copy, so is safe from the following removal.
val pubs = localTrackPublications.toList() // creates a copy, so is safe from the following removal.
republishes = pubs
tracks = tracks.toMutableMap().apply { clear() }
for (publication in pubs) {
... ... @@ -532,9 +534,9 @@ internal constructor(
}
suspend fun republishTracks() {
val republishes = localTrackPublications
for (pub in republishes) {
val publish = republishes.toList()
republishes = emptyList()
for (pub in publish) {
val track = pub.track ?: continue
unpublishTrack(track, false)
// Cannot publish muted tracks.
... ...
package io.livekit.android
import io.livekit.android.coroutines.TestCoroutineRule
import io.livekit.android.util.LoggingRule
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.test.TestScope
import kotlinx.coroutines.test.runTest
... ... @@ -10,8 +11,8 @@ import org.mockito.junit.MockitoJUnit
@OptIn(ExperimentalCoroutinesApi::class)
abstract class BaseTest {
// Uncomment to enable logging in tests.
//@get:Rule
//var loggingRule = LoggingRule()
@get:Rule
var loggingRule = LoggingRule()
@get:Rule
var mockitoRule = MockitoJUnit.rule()
... ...
... ... @@ -70,7 +70,6 @@ abstract class MockE2ETest : BaseTest() {
fun disconnectPeerConnection() {
subscriber = component.rtcEngine().subscriber
simulateMessageFromServer(SignalClientTest.OFFER)
val subPeerConnection = subscriber.peerConnection as MockPeerConnection
subPeerConnection.moveToIceConnectionState(PeerConnection.IceConnectionState.FAILED)
}
... ...
package io.livekit.android.assert
import org.junit.Assert
fun assertIsClass(expectedClass: Class<*>, actual: Any) {
val klazz = actual::class.java
Assert.assertEquals(expectedClass, klazz)
}
fun assertIsClassList(expectedClasses: List<Class<*>>, actual: List<*>) {
val klazzes = actual.map {
if (it == null) {
Nothing::class.java
} else {
it::class.java
}
}
Assert.assertEquals(expectedClasses, klazzes)
}
\ No newline at end of file
... ...
... ... @@ -51,7 +51,7 @@ class RTCEngineMockE2ETest : MockE2ETest() {
connect()
val oldWs = wsFactory.ws
wsFactory.listener.onFailure(oldWs, Exception(), null)
testScheduler.advanceTimeBy(1000)
val newWs = wsFactory.ws
Assert.assertNotEquals(oldWs, newWs)
}
... ... @@ -63,6 +63,7 @@ class RTCEngineMockE2ETest : MockE2ETest() {
val subPeerConnection = rtcEngine.subscriber.peerConnection as MockPeerConnection
subPeerConnection.moveToIceConnectionState(PeerConnection.IceConnectionState.FAILED)
testScheduler.advanceTimeBy(1000)
val newWs = wsFactory.ws
Assert.assertNotEquals(oldWs, newWs)
... ... @@ -75,6 +76,7 @@ class RTCEngineMockE2ETest : MockE2ETest() {
val pubPeerConnection = rtcEngine.publisher.peerConnection as MockPeerConnection
pubPeerConnection.moveToIceConnectionState(PeerConnection.IceConnectionState.FAILED)
testScheduler.advanceTimeBy(1000)
val newWs = wsFactory.ws
Assert.assertNotEquals(oldWs, newWs)
... ... @@ -88,6 +90,7 @@ class RTCEngineMockE2ETest : MockE2ETest() {
wsFactory.listener.onMessage(wsFactory.ws, SignalClientTest.REFRESH_TOKEN.toOkioByteString())
wsFactory.listener.onFailure(wsFactory.ws, Exception(), null)
testScheduler.advanceUntilIdle()
val newToken = wsFactory.request.url.queryParameter(SignalClient.CONNECT_QUERY_TOKEN)
Assert.assertNotEquals(oldToken, newToken)
Assert.assertEquals(SignalClientTest.REFRESH_TOKEN.refreshToken, newToken)
... ...
... ... @@ -18,7 +18,6 @@ import io.livekit.android.room.track.Track
import io.livekit.android.util.flow
import io.livekit.android.util.toOkioByteString
import junit.framework.Assert.assertEquals
import junit.framework.Assert.assertTrue
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.launch
import org.junit.Assert
... ... @@ -328,38 +327,4 @@ class RoomMockE2ETest : MockE2ETest() {
connect()
Assert.assertEquals(room.state, Room.State.CONNECTED)
}
@Test
fun reconnectFromPeerConnectionDisconnect() = runTest {
connect()
val eventCollector = EventCollector(room.events, coroutineRule.scope)
wsFactory.onOpen = {
wsFactory.listener.onOpen(wsFactory.ws, createOpenResponse(wsFactory.request))
connectPeerConnection()
}
disconnectPeerConnection()
val events = eventCollector.stopCollecting()
assertEquals(2, events.size)
assertTrue(events[0] is RoomEvent.Reconnecting)
assertTrue(events[1] is RoomEvent.Reconnected)
}
@Test
fun reconnectFromWebSocketFailure() = runTest {
connect()
val eventCollector = EventCollector(room.events, coroutineRule.scope)
wsFactory.onOpen = {
wsFactory.listener.onOpen(wsFactory.ws, createOpenResponse(wsFactory.request))
connectPeerConnection()
}
wsFactory.ws.cancel()
val events = eventCollector.stopCollecting()
assertEquals(2, events.size)
assertTrue(events[0] is RoomEvent.Reconnecting)
assertTrue(events[1] is RoomEvent.Reconnected)
}
}
\ No newline at end of file
... ...
package io.livekit.android.room
import io.livekit.android.MockE2ETest
import io.livekit.android.assert.assertIsClassList
import io.livekit.android.events.EventCollector
import io.livekit.android.events.FlowCollector
import io.livekit.android.events.RoomEvent
import io.livekit.android.mock.MockAudioStreamTrack
import io.livekit.android.room.track.LocalAudioTrack
import io.livekit.android.util.flow
import io.livekit.android.util.toPBByteString
import junit.framework.Assert.assertEquals
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.launch
import livekit.LivekitRtc
import org.junit.Assert
import org.junit.Test
import org.junit.runner.RunWith
import org.robolectric.RobolectricTestRunner
@ExperimentalCoroutinesApi
@RunWith(RobolectricTestRunner::class)
class RoomReconnectionMockE2ETest : MockE2ETest() {
private fun prepareForReconnect(softReconnect: Boolean = false) {
wsFactory.onOpen = {
wsFactory.listener.onOpen(wsFactory.ws, createOpenResponse(wsFactory.request))
if (!softReconnect) {
simulateMessageFromServer(SignalClientTest.JOIN)
}
}
}
@Test
fun reconnectFromPeerConnectionDisconnect() = runTest {
connect()
val eventCollector = EventCollector(room.events, coroutineRule.scope)
val stateCollector = FlowCollector(room::state.flow, coroutineRule.scope)
prepareForReconnect()
disconnectPeerConnection()
// Wait so that the reconnect job properly starts first.
testScheduler.advanceTimeBy(1000)
connectPeerConnection()
testScheduler.advanceUntilIdle()
val events = eventCollector.stopCollecting()
val states = stateCollector.stopCollecting()
assertIsClassList(
listOf(
RoomEvent.Reconnecting::class.java,
RoomEvent.Reconnected::class.java,
),
events
)
assertEquals(
listOf(
Room.State.CONNECTED,
Room.State.RECONNECTING,
Room.State.CONNECTED,
),
states
)
}
@Test
fun reconnectFromWebSocketFailure() = runTest {
connect()
val eventCollector = EventCollector(room.events, coroutineRule.scope)
val stateCollector = FlowCollector(room::state.flow, coroutineRule.scope)
prepareForReconnect()
wsFactory.ws.cancel()
// Wait so that the reconnect job properly starts first.
testScheduler.advanceTimeBy(1000)
connectPeerConnection()
testScheduler.advanceUntilIdle()
val events = eventCollector.stopCollecting()
val states = stateCollector.stopCollecting()
assertIsClassList(
listOf(
RoomEvent.Reconnecting::class.java,
RoomEvent.Reconnected::class.java,
),
events
)
assertEquals(
listOf(
Room.State.CONNECTED,
Room.State.RECONNECTING,
Room.State.CONNECTED,
),
states
)
}
@Test
fun softReconnectSendsSyncState() = runTest {
room.setReconnectionType(ReconnectType.FORCE_SOFT_RECONNECT)
connect()
prepareForReconnect()
disconnectPeerConnection()
// Wait so that the reconnect job properly starts first.
testScheduler.advanceTimeBy(1000)
connectPeerConnection()
testScheduler.advanceUntilIdle()
val sentRequests = wsFactory.ws.sentRequests
val sentSyncState = sentRequests.any { requestString ->
val sentRequest = LivekitRtc.SignalRequest.newBuilder()
.mergeFrom(requestString.toPBByteString())
.build()
return@any sentRequest.hasSyncState()
}
Assert.assertTrue(sentSyncState)
}
@Test
fun fullReconnectRepublishesTracks() = runTest {
room.setReconnectionType(ReconnectType.FORCE_FULL_RECONNECT)
connect()
// publish track
val publishJob = launch {
room.localParticipant.publishAudioTrack(
LocalAudioTrack(
"",
MockAudioStreamTrack(id = SignalClientTest.LOCAL_TRACK_PUBLISHED.trackPublished.cid)
)
)
}
simulateMessageFromServer(SignalClientTest.LOCAL_TRACK_PUBLISHED)
publishJob.join()
prepareForReconnect()
disconnectPeerConnection()
// Wait so that the reconnect job properly starts first.
testScheduler.advanceTimeBy(1000)
connectPeerConnection()
testScheduler.advanceUntilIdle()
val sentRequests = wsFactory.ws.sentRequests
val sentAddTrack = sentRequests.any { requestString ->
val sentRequest = LivekitRtc.SignalRequest.newBuilder()
.mergeFrom(requestString.toPBByteString())
.build()
return@any sentRequest.hasAddTrack()
}
println(sentRequests)
Assert.assertTrue(sentAddTrack)
}
}
\ No newline at end of file
... ...