davidliu
Committed by GitHub

Fix local participant not republishing tracks upon full reconnect (#139)

* fix reconnect not republishing tracks

* apply delay fix to other reconnect tests

* fix tests
@@ -43,6 +43,9 @@ class LiveKit { @@ -43,6 +43,9 @@ class LiveKit {
43 @JvmStatic 43 @JvmStatic
44 var enableWebRTCLogging: Boolean = false 44 var enableWebRTCLogging: Boolean = false
45 45
  46 + /**
  47 + * Create a Room object.
  48 + */
46 fun create( 49 fun create(
47 appContext: Context, 50 appContext: Context,
48 options: RoomOptions = RoomOptions(), 51 options: RoomOptions = RoomOptions(),
@@ -75,6 +75,7 @@ internal constructor( @@ -75,6 +75,7 @@ internal constructor(
75 } 75 }
76 } 76 }
77 77
  78 + internal var reconnectType: ReconnectType = ReconnectType.DEFAULT
78 private var reconnectingJob: Job? = null 79 private var reconnectingJob: Job? = null
79 private val reconnectingLock = Mutex() 80 private val reconnectingLock = Mutex()
80 private var fullReconnectOnNext = false 81 private var fullReconnectOnNext = false
@@ -353,7 +354,7 @@ internal constructor( @@ -353,7 +354,7 @@ internal constructor(
353 354
354 val reconnectStartTime = SystemClock.elapsedRealtime() 355 val reconnectStartTime = SystemClock.elapsedRealtime()
355 for (retries in 0 until MAX_RECONNECT_RETRIES) { 356 for (retries in 0 until MAX_RECONNECT_RETRIES) {
356 - var startDelay = retries.toLong() * retries * 500 357 + var startDelay = 100 + retries.toLong() * retries * 500
357 if (startDelay > 5000) { 358 if (startDelay > 5000) {
358 startDelay = 5000 359 startDelay = 5000
359 } 360 }
@@ -361,10 +362,15 @@ internal constructor( @@ -361,10 +362,15 @@ internal constructor(
361 LKLog.i { "Reconnecting to signal, attempt ${retries + 1}" } 362 LKLog.i { "Reconnecting to signal, attempt ${retries + 1}" }
362 delay(startDelay) 363 delay(startDelay)
363 364
364 - // full reconnect after first try.  
365 - val isFullReconnect = retries != 0 || forceFullReconnect 365 + val isFullReconnect = when (reconnectType) {
  366 + // full reconnect after first try.
  367 + ReconnectType.DEFAULT -> retries != 0 || forceFullReconnect
  368 + ReconnectType.FORCE_SOFT_RECONNECT -> false
  369 + ReconnectType.FORCE_FULL_RECONNECT -> true
  370 + }
366 371
367 if (isFullReconnect) { 372 if (isFullReconnect) {
  373 + LKLog.v { "Attempting full reconnect." }
368 try { 374 try {
369 closeResources() 375 closeResources()
370 listener?.onFullReconnecting() 376 listener?.onFullReconnecting()
@@ -375,6 +381,7 @@ internal constructor( @@ -375,6 +381,7 @@ internal constructor(
375 continue 381 continue
376 } 382 }
377 } else { 383 } else {
  384 + LKLog.v { "Attempting soft reconnect." }
378 subscriber.prepareForIceRestart() 385 subscriber.prepareForIceRestart()
379 try { 386 try {
380 client.reconnect(url, token) 387 client.reconnect(url, token)
@@ -783,4 +790,13 @@ internal constructor( @@ -783,4 +790,13 @@ internal constructor(
783 790
784 client.sendSyncState(syncState) 791 client.sendSyncState(syncState)
785 } 792 }
  793 +}
  794 +
  795 +/**
  796 + * @suppress
  797 + */
  798 +enum class ReconnectType {
  799 + DEFAULT,
  800 + FORCE_SOFT_RECONNECT,
  801 + FORCE_FULL_RECONNECT;
786 } 802 }
@@ -7,6 +7,7 @@ import android.net.ConnectivityManager @@ -7,6 +7,7 @@ import android.net.ConnectivityManager
7 import android.net.Network 7 import android.net.Network
8 import android.net.NetworkCapabilities 8 import android.net.NetworkCapabilities
9 import android.net.NetworkRequest 9 import android.net.NetworkRequest
  10 +import androidx.annotation.VisibleForTesting
10 import dagger.assisted.Assisted 11 import dagger.assisted.Assisted
11 import dagger.assisted.AssistedFactory 12 import dagger.assisted.AssistedFactory
12 import dagger.assisted.AssistedInject 13 import dagger.assisted.AssistedInject
@@ -872,6 +873,16 @@ constructor( @@ -872,6 +873,16 @@ constructor(
872 eventBus.postEvent(event) 873 eventBus.postEvent(event)
873 } 874 }
874 } 875 }
  876 +
  877 + // Debug options
  878 +
  879 + /**
  880 + * @suppress
  881 + */
  882 + @VisibleForTesting
  883 + fun setReconnectionType(reconnectType: ReconnectType) {
  884 + engine.reconnectType = reconnectType
  885 + }
875 } 886 }
876 887
877 /** 888 /**
@@ -43,6 +43,7 @@ internal constructor( @@ -43,6 +43,7 @@ internal constructor(
43 var videoTrackCaptureDefaults: LocalVideoTrackOptions by defaultsManager::videoTrackCaptureDefaults 43 var videoTrackCaptureDefaults: LocalVideoTrackOptions by defaultsManager::videoTrackCaptureDefaults
44 var videoTrackPublishDefaults: VideoTrackPublishDefaults by defaultsManager::videoTrackPublishDefaults 44 var videoTrackPublishDefaults: VideoTrackPublishDefaults by defaultsManager::videoTrackPublishDefaults
45 45
  46 + var republishes = emptyList<LocalTrackPublication>()
46 private val localTrackPublications 47 private val localTrackPublications
47 get() = tracks.values 48 get() = tracks.values
48 .mapNotNull { it as? LocalTrackPublication } 49 .mapNotNull { it as? LocalTrackPublication }
@@ -522,7 +523,8 @@ internal constructor( @@ -522,7 +523,8 @@ internal constructor(
522 } 523 }
523 524
524 fun prepareForFullReconnect() { 525 fun prepareForFullReconnect() {
525 - val pubs = localTrackPublications // creates a copy, so is safe from the following removal. 526 + val pubs = localTrackPublications.toList() // creates a copy, so is safe from the following removal.
  527 + republishes = pubs
526 tracks = tracks.toMutableMap().apply { clear() } 528 tracks = tracks.toMutableMap().apply { clear() }
527 529
528 for (publication in pubs) { 530 for (publication in pubs) {
@@ -532,9 +534,9 @@ internal constructor( @@ -532,9 +534,9 @@ internal constructor(
532 } 534 }
533 535
534 suspend fun republishTracks() { 536 suspend fun republishTracks() {
535 - val republishes = localTrackPublications  
536 -  
537 - for (pub in republishes) { 537 + val publish = republishes.toList()
  538 + republishes = emptyList()
  539 + for (pub in publish) {
538 val track = pub.track ?: continue 540 val track = pub.track ?: continue
539 unpublishTrack(track, false) 541 unpublishTrack(track, false)
540 // Cannot publish muted tracks. 542 // Cannot publish muted tracks.
1 package io.livekit.android 1 package io.livekit.android
2 2
3 import io.livekit.android.coroutines.TestCoroutineRule 3 import io.livekit.android.coroutines.TestCoroutineRule
  4 +import io.livekit.android.util.LoggingRule
4 import kotlinx.coroutines.ExperimentalCoroutinesApi 5 import kotlinx.coroutines.ExperimentalCoroutinesApi
5 import kotlinx.coroutines.test.TestScope 6 import kotlinx.coroutines.test.TestScope
6 import kotlinx.coroutines.test.runTest 7 import kotlinx.coroutines.test.runTest
@@ -10,8 +11,8 @@ import org.mockito.junit.MockitoJUnit @@ -10,8 +11,8 @@ import org.mockito.junit.MockitoJUnit
10 @OptIn(ExperimentalCoroutinesApi::class) 11 @OptIn(ExperimentalCoroutinesApi::class)
11 abstract class BaseTest { 12 abstract class BaseTest {
12 // Uncomment to enable logging in tests. 13 // Uncomment to enable logging in tests.
13 - //@get:Rule  
14 - //var loggingRule = LoggingRule() 14 + @get:Rule
  15 + var loggingRule = LoggingRule()
15 16
16 @get:Rule 17 @get:Rule
17 var mockitoRule = MockitoJUnit.rule() 18 var mockitoRule = MockitoJUnit.rule()
@@ -70,7 +70,6 @@ abstract class MockE2ETest : BaseTest() { @@ -70,7 +70,6 @@ abstract class MockE2ETest : BaseTest() {
70 70
71 fun disconnectPeerConnection() { 71 fun disconnectPeerConnection() {
72 subscriber = component.rtcEngine().subscriber 72 subscriber = component.rtcEngine().subscriber
73 - simulateMessageFromServer(SignalClientTest.OFFER)  
74 val subPeerConnection = subscriber.peerConnection as MockPeerConnection 73 val subPeerConnection = subscriber.peerConnection as MockPeerConnection
75 subPeerConnection.moveToIceConnectionState(PeerConnection.IceConnectionState.FAILED) 74 subPeerConnection.moveToIceConnectionState(PeerConnection.IceConnectionState.FAILED)
76 } 75 }
  1 +package io.livekit.android.assert
  2 +
  3 +import org.junit.Assert
  4 +
  5 +fun assertIsClass(expectedClass: Class<*>, actual: Any) {
  6 + val klazz = actual::class.java
  7 +
  8 + Assert.assertEquals(expectedClass, klazz)
  9 +}
  10 +
  11 +fun assertIsClassList(expectedClasses: List<Class<*>>, actual: List<*>) {
  12 + val klazzes = actual.map {
  13 + if (it == null) {
  14 + Nothing::class.java
  15 + } else {
  16 + it::class.java
  17 + }
  18 + }
  19 +
  20 + Assert.assertEquals(expectedClasses, klazzes)
  21 +}
@@ -51,7 +51,7 @@ class RTCEngineMockE2ETest : MockE2ETest() { @@ -51,7 +51,7 @@ class RTCEngineMockE2ETest : MockE2ETest() {
51 connect() 51 connect()
52 val oldWs = wsFactory.ws 52 val oldWs = wsFactory.ws
53 wsFactory.listener.onFailure(oldWs, Exception(), null) 53 wsFactory.listener.onFailure(oldWs, Exception(), null)
54 - 54 + testScheduler.advanceTimeBy(1000)
55 val newWs = wsFactory.ws 55 val newWs = wsFactory.ws
56 Assert.assertNotEquals(oldWs, newWs) 56 Assert.assertNotEquals(oldWs, newWs)
57 } 57 }
@@ -63,6 +63,7 @@ class RTCEngineMockE2ETest : MockE2ETest() { @@ -63,6 +63,7 @@ class RTCEngineMockE2ETest : MockE2ETest() {
63 63
64 val subPeerConnection = rtcEngine.subscriber.peerConnection as MockPeerConnection 64 val subPeerConnection = rtcEngine.subscriber.peerConnection as MockPeerConnection
65 subPeerConnection.moveToIceConnectionState(PeerConnection.IceConnectionState.FAILED) 65 subPeerConnection.moveToIceConnectionState(PeerConnection.IceConnectionState.FAILED)
  66 + testScheduler.advanceTimeBy(1000)
66 67
67 val newWs = wsFactory.ws 68 val newWs = wsFactory.ws
68 Assert.assertNotEquals(oldWs, newWs) 69 Assert.assertNotEquals(oldWs, newWs)
@@ -75,6 +76,7 @@ class RTCEngineMockE2ETest : MockE2ETest() { @@ -75,6 +76,7 @@ class RTCEngineMockE2ETest : MockE2ETest() {
75 76
76 val pubPeerConnection = rtcEngine.publisher.peerConnection as MockPeerConnection 77 val pubPeerConnection = rtcEngine.publisher.peerConnection as MockPeerConnection
77 pubPeerConnection.moveToIceConnectionState(PeerConnection.IceConnectionState.FAILED) 78 pubPeerConnection.moveToIceConnectionState(PeerConnection.IceConnectionState.FAILED)
  79 + testScheduler.advanceTimeBy(1000)
78 80
79 val newWs = wsFactory.ws 81 val newWs = wsFactory.ws
80 Assert.assertNotEquals(oldWs, newWs) 82 Assert.assertNotEquals(oldWs, newWs)
@@ -88,6 +90,7 @@ class RTCEngineMockE2ETest : MockE2ETest() { @@ -88,6 +90,7 @@ class RTCEngineMockE2ETest : MockE2ETest() {
88 wsFactory.listener.onMessage(wsFactory.ws, SignalClientTest.REFRESH_TOKEN.toOkioByteString()) 90 wsFactory.listener.onMessage(wsFactory.ws, SignalClientTest.REFRESH_TOKEN.toOkioByteString())
89 wsFactory.listener.onFailure(wsFactory.ws, Exception(), null) 91 wsFactory.listener.onFailure(wsFactory.ws, Exception(), null)
90 92
  93 + testScheduler.advanceUntilIdle()
91 val newToken = wsFactory.request.url.queryParameter(SignalClient.CONNECT_QUERY_TOKEN) 94 val newToken = wsFactory.request.url.queryParameter(SignalClient.CONNECT_QUERY_TOKEN)
92 Assert.assertNotEquals(oldToken, newToken) 95 Assert.assertNotEquals(oldToken, newToken)
93 Assert.assertEquals(SignalClientTest.REFRESH_TOKEN.refreshToken, newToken) 96 Assert.assertEquals(SignalClientTest.REFRESH_TOKEN.refreshToken, newToken)
@@ -18,7 +18,6 @@ import io.livekit.android.room.track.Track @@ -18,7 +18,6 @@ import io.livekit.android.room.track.Track
18 import io.livekit.android.util.flow 18 import io.livekit.android.util.flow
19 import io.livekit.android.util.toOkioByteString 19 import io.livekit.android.util.toOkioByteString
20 import junit.framework.Assert.assertEquals 20 import junit.framework.Assert.assertEquals
21 -import junit.framework.Assert.assertTrue  
22 import kotlinx.coroutines.ExperimentalCoroutinesApi 21 import kotlinx.coroutines.ExperimentalCoroutinesApi
23 import kotlinx.coroutines.launch 22 import kotlinx.coroutines.launch
24 import org.junit.Assert 23 import org.junit.Assert
@@ -328,38 +327,4 @@ class RoomMockE2ETest : MockE2ETest() { @@ -328,38 +327,4 @@ class RoomMockE2ETest : MockE2ETest() {
328 connect() 327 connect()
329 Assert.assertEquals(room.state, Room.State.CONNECTED) 328 Assert.assertEquals(room.state, Room.State.CONNECTED)
330 } 329 }
331 -  
332 - @Test  
333 - fun reconnectFromPeerConnectionDisconnect() = runTest {  
334 - connect()  
335 -  
336 - val eventCollector = EventCollector(room.events, coroutineRule.scope)  
337 - wsFactory.onOpen = {  
338 - wsFactory.listener.onOpen(wsFactory.ws, createOpenResponse(wsFactory.request))  
339 - connectPeerConnection()  
340 - }  
341 - disconnectPeerConnection()  
342 - val events = eventCollector.stopCollecting()  
343 -  
344 - assertEquals(2, events.size)  
345 - assertTrue(events[0] is RoomEvent.Reconnecting)  
346 - assertTrue(events[1] is RoomEvent.Reconnected)  
347 - }  
348 -  
349 - @Test  
350 - fun reconnectFromWebSocketFailure() = runTest {  
351 - connect()  
352 -  
353 - val eventCollector = EventCollector(room.events, coroutineRule.scope)  
354 - wsFactory.onOpen = {  
355 - wsFactory.listener.onOpen(wsFactory.ws, createOpenResponse(wsFactory.request))  
356 - connectPeerConnection()  
357 - }  
358 - wsFactory.ws.cancel()  
359 - val events = eventCollector.stopCollecting()  
360 -  
361 - assertEquals(2, events.size)  
362 - assertTrue(events[0] is RoomEvent.Reconnecting)  
363 - assertTrue(events[1] is RoomEvent.Reconnected)  
364 - }  
365 } 330 }
  1 +package io.livekit.android.room
  2 +
  3 +import io.livekit.android.MockE2ETest
  4 +import io.livekit.android.assert.assertIsClassList
  5 +import io.livekit.android.events.EventCollector
  6 +import io.livekit.android.events.FlowCollector
  7 +import io.livekit.android.events.RoomEvent
  8 +import io.livekit.android.mock.MockAudioStreamTrack
  9 +import io.livekit.android.room.track.LocalAudioTrack
  10 +import io.livekit.android.util.flow
  11 +import io.livekit.android.util.toPBByteString
  12 +import junit.framework.Assert.assertEquals
  13 +import kotlinx.coroutines.ExperimentalCoroutinesApi
  14 +import kotlinx.coroutines.launch
  15 +import livekit.LivekitRtc
  16 +import org.junit.Assert
  17 +import org.junit.Test
  18 +import org.junit.runner.RunWith
  19 +import org.robolectric.RobolectricTestRunner
  20 +
  21 +@ExperimentalCoroutinesApi
  22 +@RunWith(RobolectricTestRunner::class)
  23 +class RoomReconnectionMockE2ETest : MockE2ETest() {
  24 +
  25 + private fun prepareForReconnect(softReconnect: Boolean = false) {
  26 + wsFactory.onOpen = {
  27 + wsFactory.listener.onOpen(wsFactory.ws, createOpenResponse(wsFactory.request))
  28 + if (!softReconnect) {
  29 + simulateMessageFromServer(SignalClientTest.JOIN)
  30 + }
  31 + }
  32 + }
  33 +
  34 + @Test
  35 + fun reconnectFromPeerConnectionDisconnect() = runTest {
  36 + connect()
  37 +
  38 + val eventCollector = EventCollector(room.events, coroutineRule.scope)
  39 + val stateCollector = FlowCollector(room::state.flow, coroutineRule.scope)
  40 + prepareForReconnect()
  41 + disconnectPeerConnection()
  42 + // Wait so that the reconnect job properly starts first.
  43 + testScheduler.advanceTimeBy(1000)
  44 + connectPeerConnection()
  45 +
  46 + testScheduler.advanceUntilIdle()
  47 + val events = eventCollector.stopCollecting()
  48 + val states = stateCollector.stopCollecting()
  49 +
  50 + assertIsClassList(
  51 + listOf(
  52 + RoomEvent.Reconnecting::class.java,
  53 + RoomEvent.Reconnected::class.java,
  54 + ),
  55 + events
  56 + )
  57 +
  58 + assertEquals(
  59 + listOf(
  60 + Room.State.CONNECTED,
  61 + Room.State.RECONNECTING,
  62 + Room.State.CONNECTED,
  63 + ),
  64 + states
  65 + )
  66 + }
  67 +
  68 + @Test
  69 + fun reconnectFromWebSocketFailure() = runTest {
  70 + connect()
  71 +
  72 + val eventCollector = EventCollector(room.events, coroutineRule.scope)
  73 + val stateCollector = FlowCollector(room::state.flow, coroutineRule.scope)
  74 + prepareForReconnect()
  75 + wsFactory.ws.cancel()
  76 + // Wait so that the reconnect job properly starts first.
  77 + testScheduler.advanceTimeBy(1000)
  78 + connectPeerConnection()
  79 +
  80 + testScheduler.advanceUntilIdle()
  81 + val events = eventCollector.stopCollecting()
  82 + val states = stateCollector.stopCollecting()
  83 +
  84 + assertIsClassList(
  85 + listOf(
  86 + RoomEvent.Reconnecting::class.java,
  87 + RoomEvent.Reconnected::class.java,
  88 + ),
  89 + events
  90 + )
  91 +
  92 + assertEquals(
  93 + listOf(
  94 + Room.State.CONNECTED,
  95 + Room.State.RECONNECTING,
  96 + Room.State.CONNECTED,
  97 + ),
  98 + states
  99 + )
  100 + }
  101 +
  102 + @Test
  103 + fun softReconnectSendsSyncState() = runTest {
  104 + room.setReconnectionType(ReconnectType.FORCE_SOFT_RECONNECT)
  105 +
  106 + connect()
  107 + prepareForReconnect()
  108 + disconnectPeerConnection()
  109 + // Wait so that the reconnect job properly starts first.
  110 + testScheduler.advanceTimeBy(1000)
  111 + connectPeerConnection()
  112 +
  113 + testScheduler.advanceUntilIdle()
  114 + val sentRequests = wsFactory.ws.sentRequests
  115 + val sentSyncState = sentRequests.any { requestString ->
  116 + val sentRequest = LivekitRtc.SignalRequest.newBuilder()
  117 + .mergeFrom(requestString.toPBByteString())
  118 + .build()
  119 +
  120 + return@any sentRequest.hasSyncState()
  121 + }
  122 +
  123 + Assert.assertTrue(sentSyncState)
  124 + }
  125 +
  126 + @Test
  127 + fun fullReconnectRepublishesTracks() = runTest {
  128 + room.setReconnectionType(ReconnectType.FORCE_FULL_RECONNECT)
  129 + connect()
  130 +
  131 + // publish track
  132 + val publishJob = launch {
  133 + room.localParticipant.publishAudioTrack(
  134 + LocalAudioTrack(
  135 + "",
  136 + MockAudioStreamTrack(id = SignalClientTest.LOCAL_TRACK_PUBLISHED.trackPublished.cid)
  137 + )
  138 + )
  139 + }
  140 + simulateMessageFromServer(SignalClientTest.LOCAL_TRACK_PUBLISHED)
  141 + publishJob.join()
  142 +
  143 + prepareForReconnect()
  144 + disconnectPeerConnection()
  145 + // Wait so that the reconnect job properly starts first.
  146 + testScheduler.advanceTimeBy(1000)
  147 + connectPeerConnection()
  148 +
  149 + testScheduler.advanceUntilIdle()
  150 + val sentRequests = wsFactory.ws.sentRequests
  151 + val sentAddTrack = sentRequests.any { requestString ->
  152 + val sentRequest = LivekitRtc.SignalRequest.newBuilder()
  153 + .mergeFrom(requestString.toPBByteString())
  154 + .build()
  155 +
  156 + return@any sentRequest.hasAddTrack()
  157 + }
  158 +
  159 + println(sentRequests)
  160 + Assert.assertTrue(sentAddTrack)
  161 + }
  162 +}