davidliu
Committed by GitHub

full reconnect support (#45)

* properly clean up texture view renderer when done

* tests

* full reconnect if regular reconnect doesn't work

* small tweak

* proper join message handling

* fix republishing
@@ -28,6 +28,7 @@ import kotlin.coroutines.suspendCoroutine @@ -28,6 +28,7 @@ import kotlin.coroutines.suspendCoroutine
28 /** 28 /**
29 * @suppress 29 * @suppress
30 */ 30 */
  31 +@OptIn(ExperimentalCoroutinesApi::class)
31 @Singleton 32 @Singleton
32 class RTCEngine 33 class RTCEngine
33 @Inject 34 @Inject
@@ -75,6 +76,7 @@ internal constructor( @@ -75,6 +76,7 @@ internal constructor(
75 mutableMapOf() 76 mutableMapOf()
76 private var sessionUrl: String? = null 77 private var sessionUrl: String? = null
77 private var sessionToken: String? = null 78 private var sessionToken: String? = null
  79 + private var connectOptions: ConnectOptions? = null
78 80
79 private val publisherObserver = PublisherTransportObserver(this, client) 81 private val publisherObserver = PublisherTransportObserver(this, client)
80 private val subscriberObserver = SubscriberTransportObserver(this, client) 82 private val subscriberObserver = SubscriberTransportObserver(this, client)
@@ -113,9 +115,14 @@ internal constructor( @@ -113,9 +115,14 @@ internal constructor(
113 coroutineScope = CloseableCoroutineScope(SupervisorJob() + ioDispatcher) 115 coroutineScope = CloseableCoroutineScope(SupervisorJob() + ioDispatcher)
114 sessionUrl = url 116 sessionUrl = url
115 sessionToken = token 117 sessionToken = token
  118 + return joinImpl(url, token, options)
  119 + }
  120 +
  121 + suspend fun joinImpl(url: String, token: String, options: ConnectOptions): LivekitRtc.JoinResponse {
116 val joinResponse = client.join(url, token, options) 122 val joinResponse = client.join(url, token, options)
  123 + listener?.onJoinResponse(joinResponse)
117 isClosed = false 124 isClosed = false
118 - listener?.onSignalConnected() 125 + listener?.onSignalConnected(false)
119 126
120 isSubscriberPrimary = joinResponse.subscriberPrimary 127 isSubscriberPrimary = joinResponse.subscriberPrimary
121 128
@@ -245,6 +252,7 @@ internal constructor( @@ -245,6 +252,7 @@ internal constructor(
245 throw TrackException.DuplicateTrackException("Track with same ID $cid has already been published!") 252 throw TrackException.DuplicateTrackException("Track with same ID $cid has already been published!")
246 } 253 }
247 254
  255 + // Suspend until signal client receives message confirming track publication.
248 return suspendCoroutine { cont -> 256 return suspendCoroutine { cont ->
249 pendingTrackResolvers[cid] = cont 257 pendingTrackResolvers[cid] = cont
250 client.sendAddTrack(cid, name, kind, builder) 258 client.sendAddTrack(cid, name, kind, builder)
@@ -267,11 +275,31 @@ internal constructor( @@ -267,11 +275,31 @@ internal constructor(
267 return 275 return
268 } 276 }
269 isClosed = true 277 isClosed = true
  278 + hasPublished = false
  279 + sessionUrl = null
  280 + sessionToken = null
  281 + connectOptions = null
  282 + reconnectingJob?.cancel()
  283 + reconnectingJob = null
270 coroutineScope.close() 284 coroutineScope.close()
  285 + closeResources()
  286 + }
  287 +
  288 + private fun closeResources() {
  289 + connectionState = ConnectionState.DISCONNECTED
271 _publisher?.close() 290 _publisher?.close()
272 _publisher = null 291 _publisher = null
273 _subscriber?.close() 292 _subscriber?.close()
274 _subscriber = null 293 _subscriber = null
  294 + reliableDataChannel?.close()
  295 + reliableDataChannel = null
  296 + reliableDataChannelSub?.close()
  297 + reliableDataChannelSub = null
  298 + lossyDataChannel?.close()
  299 + lossyDataChannel = null
  300 + lossyDataChannelSub?.close()
  301 + lossyDataChannelSub = null
  302 + isSubscriberPrimary = false
275 client.close() 303 client.close()
276 } 304 }
277 305
@@ -293,6 +321,7 @@ internal constructor( @@ -293,6 +321,7 @@ internal constructor(
293 } 321 }
294 322
295 val job = coroutineScope.launch { 323 val job = coroutineScope.launch {
  324 + connectionState = ConnectionState.RECONNECTING
296 listener?.onEngineReconnecting() 325 listener?.onEngineReconnecting()
297 326
298 for (wsRetries in 0 until MAX_SIGNAL_RETRIES) { 327 for (wsRetries in 0 until MAX_SIGNAL_RETRIES) {
@@ -302,28 +331,44 @@ internal constructor( @@ -302,28 +331,44 @@ internal constructor(
302 } 331 }
303 332
304 LKLog.i { "Reconnecting to signal, attempt ${wsRetries + 1}" } 333 LKLog.i { "Reconnecting to signal, attempt ${wsRetries + 1}" }
305 -  
306 delay(startDelay) 334 delay(startDelay)
307 - try {  
308 - client.reconnect(url, token)  
309 - } catch (e: Exception) {  
310 - // ws reconnect failed, retry.  
311 - continue  
312 - }  
313 335
314 - LKLog.v { "ws reconnected, restarting ICE" }  
315 - listener?.onSignalConnected() 336 + // full reconnect after first try.
  337 + val isFullReconnect = true
  338 +
  339 + if (isFullReconnect) {
  340 + try {
  341 + closeResources()
  342 + listener?.onFullReconnecting()
  343 + joinImpl(url, token, connectOptions ?: ConnectOptions())
  344 + } catch (e: Exception) {
  345 + LKLog.w(e) { "Error during reconnection." }
  346 + // reconnect failed, retry.
  347 + continue
  348 + }
  349 + } else {
  350 + try {
  351 + client.reconnect(url, token)
  352 + // no join response for regular reconnects
  353 + client.onReady()
  354 + } catch (e: Exception) {
  355 + LKLog.w(e) { "Error during reconnection." }
  356 + // ws reconnect failed, retry.
  357 + continue
  358 + }
316 359
317 - subscriber.prepareForIceRestart()  
318 - connectionState = ConnectionState.RECONNECTING  
319 - // trigger publisher reconnect  
320 - // only restart publisher if it's needed  
321 - if (hasPublished) {  
322 - negotiate()  
323 - } 360 + LKLog.v { "ws reconnected, restarting ICE" }
  361 + listener?.onSignalConnected(!isFullReconnect)
324 362
  363 + subscriber.prepareForIceRestart()
  364 + // trigger publisher reconnect
  365 + // only restart publisher if it's needed
  366 + if (hasPublished) {
  367 + negotiate()
  368 + }
  369 + }
325 // wait until ICE connected 370 // wait until ICE connected
326 - val endTime = SystemClock.elapsedRealtime() + MAX_ICE_CONNECT_TIMEOUT_MS; 371 + val endTime = SystemClock.elapsedRealtime() + MAX_ICE_CONNECT_TIMEOUT_MS
327 while (SystemClock.elapsedRealtime() < endTime) { 372 while (SystemClock.elapsedRealtime() < endTime) {
328 if (connectionState == ConnectionState.CONNECTED) { 373 if (connectionState == ConnectionState.CONNECTED) {
329 LKLog.v { "reconnected to ICE" } 374 LKLog.v { "reconnected to ICE" }
@@ -333,11 +378,13 @@ internal constructor( @@ -333,11 +378,13 @@ internal constructor(
333 } 378 }
334 379
335 if (connectionState == ConnectionState.CONNECTED) { 380 if (connectionState == ConnectionState.CONNECTED) {
  381 + if (isFullReconnect) {
  382 + listener?.onFullReconnect()
  383 + }
336 return@launch 384 return@launch
337 } 385 }
338 } 386 }
339 387
340 -  
341 close() 388 close()
342 listener?.onEngineDisconnected("failed reconnecting.") 389 listener?.onEngineDisconnected("failed reconnecting.")
343 } 390 }
@@ -389,7 +436,7 @@ internal constructor( @@ -389,7 +436,7 @@ internal constructor(
389 publisher.peerConnection.iceConnectionState() != PeerConnection.IceConnectionState.CHECKING 436 publisher.peerConnection.iceConnectionState() != PeerConnection.IceConnectionState.CHECKING
390 ) { 437 ) {
391 // start negotiation 438 // start negotiation
392 - this.negotiate(); 439 + this.negotiate()
393 } 440 }
394 441
395 442
@@ -399,7 +446,7 @@ internal constructor( @@ -399,7 +446,7 @@ internal constructor(
399 } 446 }
400 447
401 // wait until publisher ICE connected 448 // wait until publisher ICE connected
402 - val endTime = SystemClock.elapsedRealtime() + MAX_ICE_CONNECT_TIMEOUT_MS; 449 + val endTime = SystemClock.elapsedRealtime() + MAX_ICE_CONNECT_TIMEOUT_MS
403 while (SystemClock.elapsedRealtime() < endTime) { 450 while (SystemClock.elapsedRealtime() < endTime) {
404 if (this.publisher.peerConnection.isConnected() && targetChannel.state() == DataChannel.State.OPEN) { 451 if (this.publisher.peerConnection.isConnected() && targetChannel.state() == DataChannel.State.OPEN) {
405 return 452 return
@@ -450,6 +497,7 @@ internal constructor( @@ -450,6 +497,7 @@ internal constructor(
450 fun onEngineReconnecting() 497 fun onEngineReconnecting()
451 fun onEngineDisconnected(reason: String) 498 fun onEngineDisconnected(reason: String)
452 fun onFailToConnect(error: Throwable) 499 fun onFailToConnect(error: Throwable)
  500 + fun onJoinResponse(response: LivekitRtc.JoinResponse)
453 fun onAddTrack(track: MediaStreamTrack, streams: Array<out MediaStream>) 501 fun onAddTrack(track: MediaStreamTrack, streams: Array<out MediaStream>)
454 fun onUpdateParticipants(updates: List<LivekitModels.ParticipantInfo>) 502 fun onUpdateParticipants(updates: List<LivekitModels.ParticipantInfo>)
455 fun onActiveSpeakersUpdate(speakers: List<LivekitModels.SpeakerInfo>) 503 fun onActiveSpeakersUpdate(speakers: List<LivekitModels.SpeakerInfo>)
@@ -461,7 +509,9 @@ internal constructor( @@ -461,7 +509,9 @@ internal constructor(
461 fun onStreamStateUpdate(streamStates: List<LivekitRtc.StreamStateInfo>) 509 fun onStreamStateUpdate(streamStates: List<LivekitRtc.StreamStateInfo>)
462 fun onSubscribedQualityUpdate(subscribedQualityUpdate: LivekitRtc.SubscribedQualityUpdate) 510 fun onSubscribedQualityUpdate(subscribedQualityUpdate: LivekitRtc.SubscribedQualityUpdate)
463 fun onSubscriptionPermissionUpdate(subscriptionPermissionUpdate: LivekitRtc.SubscriptionPermissionUpdate) 511 fun onSubscriptionPermissionUpdate(subscriptionPermissionUpdate: LivekitRtc.SubscriptionPermissionUpdate)
464 - fun onSignalConnected() 512 + fun onSignalConnected(isReconnect: Boolean)
  513 + fun onFullReconnecting()
  514 + suspend fun onFullReconnect()
465 } 515 }
466 516
467 companion object { 517 companion object {
  1 +@file:Suppress("unused")
  2 +
1 package io.livekit.android.room 3 package io.livekit.android.room
2 4
3 import android.content.Context 5 import android.content.Context
@@ -117,8 +119,12 @@ constructor( @@ -117,8 +119,12 @@ constructor(
117 */ 119 */
118 var videoTrackPublishDefaults: VideoTrackPublishDefaults by defaultsManager::videoTrackPublishDefaults 120 var videoTrackPublishDefaults: VideoTrackPublishDefaults by defaultsManager::videoTrackPublishDefaults
119 121
120 - lateinit var localParticipant: LocalParticipant  
121 - private set 122 + var _localParticipant: LocalParticipant? = null
  123 + val localParticipant: LocalParticipant
  124 + get() {
  125 + return _localParticipant
  126 + ?: throw UninitializedPropertyAccessException("localParticipant has not been initialized yet.")
  127 + }
122 128
123 private var mutableRemoteParticipants by flowDelegate(emptyMap<String, RemoteParticipant>()) 129 private var mutableRemoteParticipants by flowDelegate(emptyMap<String, RemoteParticipant>())
124 130
@@ -143,25 +149,8 @@ constructor( @@ -143,25 +149,8 @@ constructor(
143 coroutineScope = CoroutineScope(defaultDispatcher + SupervisorJob()) 149 coroutineScope = CoroutineScope(defaultDispatcher + SupervisorJob())
144 state = State.CONNECTING 150 state = State.CONNECTING
145 connectOptions = options 151 connectOptions = options
146 - val response = engine.join(url, token, options)  
147 - LKLog.i { "Connected to server, server version: ${response.serverVersion}, client version: ${Version.CLIENT_VERSION}" }  
148 -  
149 - sid = Sid(response.room.sid)  
150 - name = response.room.name 152 + engine.join(url, token, options)
151 153
152 - if (!response.hasParticipant()) {  
153 - listener?.onFailedToConnect(this, RoomException.ConnectException("server didn't return any participants"))  
154 - return  
155 - }  
156 -  
157 - val lp = localParticipantFactory.create(response.participant, dynacast)  
158 - lp.internalListener = this  
159 - localParticipant = lp  
160 - if (response.otherParticipantsList.isNotEmpty()) {  
161 - response.otherParticipantsList.forEach {  
162 - getOrCreateRemoteParticipant(it.sid, it)  
163 - }  
164 - }  
165 val cm = context.getSystemService(Context.CONNECTIVITY_SERVICE) as ConnectivityManager 154 val cm = context.getSystemService(Context.CONNECTIVITY_SERVICE) as ConnectivityManager
166 val networkRequest = NetworkRequest.Builder() 155 val networkRequest = NetworkRequest.Builder()
167 .addCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET) 156 .addCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET)
@@ -186,11 +175,38 @@ constructor( @@ -186,11 +175,38 @@ constructor(
186 handleDisconnect() 175 handleDisconnect()
187 } 176 }
188 177
  178 + override fun onJoinResponse(response: LivekitRtc.JoinResponse) {
  179 +
  180 + LKLog.i { "Connected to server, server version: ${response.serverVersion}, client version: ${Version.CLIENT_VERSION}" }
  181 +
  182 + sid = Sid(response.room.sid)
  183 + name = response.room.name
  184 +
  185 + if (!response.hasParticipant()) {
  186 + listener?.onFailedToConnect(this, RoomException.ConnectException("server didn't return any participants"))
  187 + return
  188 + }
  189 +
  190 + if (_localParticipant == null) {
  191 + val lp = localParticipantFactory.create(response.participant, dynacast)
  192 + lp.internalListener = this
  193 + _localParticipant = lp
  194 + } else {
  195 + localParticipant.updateFromInfo(response.participant)
  196 + }
  197 +
  198 + if (response.otherParticipantsList.isNotEmpty()) {
  199 + response.otherParticipantsList.forEach {
  200 + getOrCreateRemoteParticipant(it.sid, it)
  201 + }
  202 + }
  203 + }
  204 +
189 private fun handleParticipantDisconnect(sid: String) { 205 private fun handleParticipantDisconnect(sid: String) {
190 val newParticipants = mutableRemoteParticipants.toMutableMap() 206 val newParticipants = mutableRemoteParticipants.toMutableMap()
191 val removedParticipant = newParticipants.remove(sid) ?: return 207 val removedParticipant = newParticipants.remove(sid) ?: return
192 removedParticipant.tracks.values.toList().forEach { publication -> 208 removedParticipant.tracks.values.toList().forEach { publication ->
193 - removedParticipant.unpublishTrack(publication.sid) 209 + removedParticipant.unpublishTrack(publication.sid, true)
194 } 210 }
195 211
196 mutableRemoteParticipants = newParticipants 212 mutableRemoteParticipants = newParticipants
@@ -316,6 +332,15 @@ constructor( @@ -316,6 +332,15 @@ constructor(
316 engine.reconnect() 332 engine.reconnect()
317 } 333 }
318 334
  335 + /**
  336 + * Removes all participants and tracks from the room.
  337 + */
  338 + private fun cleanupRoom() {
  339 + localParticipant.cleanup()
  340 + remoteParticipants.keys.toMutableSet() // copy keys to avoid concurrent modifications.
  341 + .forEach { sid -> handleParticipantDisconnect(sid) }
  342 + }
  343 +
319 private fun handleDisconnect() { 344 private fun handleDisconnect() {
320 if (state == State.DISCONNECTED) { 345 if (state == State.DISCONNECTED) {
321 return 346 return
@@ -328,19 +353,14 @@ constructor( @@ -328,19 +353,14 @@ constructor(
328 // do nothing, may happen on older versions if attempting to unregister twice. 353 // do nothing, may happen on older versions if attempting to unregister twice.
329 } 354 }
330 355
331 - for (pub in localParticipant.tracks.values) {  
332 - pub.track?.stop()  
333 - }  
334 - // stop remote tracks too  
335 - for (p in remoteParticipants.values) {  
336 - for (pub in p.tracks.values) {  
337 - pub.track?.stop()  
338 - }  
339 - } 356 + cleanupRoom()
  357 +
340 engine.close() 358 engine.close()
341 state = State.DISCONNECTED 359 state = State.DISCONNECTED
342 listener?.onDisconnect(this, null) 360 listener?.onDisconnect(this, null)
343 listener = null 361 listener = null
  362 + _localParticipant?.dispose()
  363 + _localParticipant = null
344 364
345 // Ensure all observers see the disconnected before closing scope. 365 // Ensure all observers see the disconnected before closing scope.
346 runBlocking { 366 runBlocking {
@@ -560,13 +580,21 @@ constructor( @@ -560,13 +580,21 @@ constructor(
560 eventBus.tryPostEvent(RoomEvent.FailedToConnect(this, error)) 580 eventBus.tryPostEvent(RoomEvent.FailedToConnect(this, error))
561 } 581 }
562 582
563 - override fun onSignalConnected() {  
564 - if (state == State.RECONNECTING) { 583 + override fun onSignalConnected(isReconnect: Boolean) {
  584 + if (state == State.RECONNECTING && isReconnect) {
565 // during reconnection, need to send sync state upon signal connection. 585 // during reconnection, need to send sync state upon signal connection.
566 sendSyncState() 586 sendSyncState()
567 } 587 }
568 } 588 }
569 589
  590 + override fun onFullReconnecting() {
  591 + localParticipant.prepareForFullReconnect()
  592 + }
  593 +
  594 + override suspend fun onFullReconnect() {
  595 + localParticipant.republishTracks()
  596 + }
  597 +
570 //------------------------------- ParticipantListener --------------------------------// 598 //------------------------------- ParticipantListener --------------------------------//
571 /** 599 /**
572 * This is called for both Local and Remote participants 600 * This is called for both Local and Remote participants
@@ -14,7 +14,6 @@ import io.livekit.android.util.safe @@ -14,7 +14,6 @@ import io.livekit.android.util.safe
14 import io.livekit.android.webrtc.toProtoSessionDescription 14 import io.livekit.android.webrtc.toProtoSessionDescription
15 import kotlinx.coroutines.* 15 import kotlinx.coroutines.*
16 import kotlinx.coroutines.flow.MutableSharedFlow 16 import kotlinx.coroutines.flow.MutableSharedFlow
17 -import kotlinx.coroutines.flow.collect  
18 import kotlinx.serialization.decodeFromString 17 import kotlinx.serialization.decodeFromString
19 import kotlinx.serialization.encodeToString 18 import kotlinx.serialization.encodeToString
20 import kotlinx.serialization.json.Json 19 import kotlinx.serialization.json.Json
@@ -126,7 +125,13 @@ constructor( @@ -126,7 +125,13 @@ constructor(
126 } 125 }
127 } 126 }
128 127
129 - @ExperimentalCoroutinesApi 128 + /**
  129 + * Notifies that the downstream consumers of SignalClient are ready to consume messages.
  130 + * Until this method is called, any messages received through the websocket are buffered.
  131 + *
  132 + * Should be called after resolving the join message.
  133 + */
  134 + @OptIn(ExperimentalCoroutinesApi::class)
130 fun onReady() { 135 fun onReady() {
131 coroutineScope.launch { 136 coroutineScope.launch {
132 responseFlow.collect { 137 responseFlow.collect {
@@ -483,6 +488,11 @@ constructor( @@ -483,6 +488,11 @@ constructor(
483 }.safe() 488 }.safe()
484 } 489 }
485 490
  491 + /**
  492 + * Closes out any existing websocket connection, and cleans up used resources.
  493 + *
  494 + * Can be reused afterwards.
  495 + */
486 fun close(code: Int = 1000, reason: String = "Normal Closure") { 496 fun close(code: Int = 1000, reason: String = "Normal Closure") {
487 isConnected = false 497 isConnected = false
488 if(::coroutineScope.isInitialized) { 498 if(::coroutineScope.isInitialized) {
@@ -9,11 +9,13 @@ import dagger.assisted.AssistedFactory @@ -9,11 +9,13 @@ import dagger.assisted.AssistedFactory
9 import dagger.assisted.AssistedInject 9 import dagger.assisted.AssistedInject
10 import io.livekit.android.dagger.InjectionNames 10 import io.livekit.android.dagger.InjectionNames
11 import io.livekit.android.events.ParticipantEvent 11 import io.livekit.android.events.ParticipantEvent
  12 +import io.livekit.android.room.ConnectionState
12 import io.livekit.android.room.DefaultsManager 13 import io.livekit.android.room.DefaultsManager
13 import io.livekit.android.room.RTCEngine 14 import io.livekit.android.room.RTCEngine
14 import io.livekit.android.room.track.* 15 import io.livekit.android.room.track.*
15 import io.livekit.android.util.LKLog 16 import io.livekit.android.util.LKLog
16 import kotlinx.coroutines.CoroutineDispatcher 17 import kotlinx.coroutines.CoroutineDispatcher
  18 +import kotlinx.coroutines.cancel
17 import livekit.LivekitModels 19 import livekit.LivekitModels
18 import livekit.LivekitRtc 20 import livekit.LivekitRtc
19 import org.webrtc.EglBase 21 import org.webrtc.EglBase
@@ -58,6 +60,13 @@ internal constructor( @@ -58,6 +60,13 @@ internal constructor(
58 .mapNotNull { it as? LocalTrackPublication } 60 .mapNotNull { it as? LocalTrackPublication }
59 .toList() 61 .toList()
60 62
  63 + private var isReconnecting = false
  64 +
  65 + /**
  66 + * Holds on to publishes that need to be republished after a full reconnect.
  67 + */
  68 + private var publishes = mutableMapOf<Track, TrackPublishOptions>()
  69 +
61 /** 70 /**
62 * Creates an audio track, recording audio through the microphone with the given [options]. 71 * Creates an audio track, recording audio through the microphone with the given [options].
63 * 72 *
@@ -189,7 +198,7 @@ internal constructor( @@ -189,7 +198,7 @@ internal constructor(
189 ), 198 ),
190 publishListener: PublishListener? = null 199 publishListener: PublishListener? = null
191 ) { 200 ) {
192 - publishTrackImpl( 201 + val published = publishTrackImpl(
193 track, 202 track,
194 requestConfig = { 203 requestConfig = {
195 disableDtx = !options.dtx 204 disableDtx = !options.dtx
@@ -197,6 +206,10 @@ internal constructor( @@ -197,6 +206,10 @@ internal constructor(
197 }, 206 },
198 publishListener = publishListener, 207 publishListener = publishListener,
199 ) 208 )
  209 +
  210 + if (published) {
  211 + publishes[track] = options
  212 + }
200 } 213 }
201 214
202 suspend fun publishVideoTrack( 215 suspend fun publishVideoTrack(
@@ -208,7 +221,7 @@ internal constructor( @@ -208,7 +221,7 @@ internal constructor(
208 val encodings = computeVideoEncodings(track.dimensions, options) 221 val encodings = computeVideoEncodings(track.dimensions, options)
209 val videoLayers = videoLayersFromEncodings(track.dimensions.width, track.dimensions.height, encodings) 222 val videoLayers = videoLayersFromEncodings(track.dimensions.width, track.dimensions.height, encodings)
210 223
211 - publishTrackImpl( 224 + val published = publishTrackImpl(
212 track, 225 track,
213 requestConfig = { 226 requestConfig = {
214 width = track.dimensions.width 227 width = track.dimensions.width
@@ -223,18 +236,25 @@ internal constructor( @@ -223,18 +236,25 @@ internal constructor(
223 encodings = encodings, 236 encodings = encodings,
224 publishListener = publishListener 237 publishListener = publishListener
225 ) 238 )
  239 +
  240 + if (published) {
  241 + publishes[track] = options
  242 + }
226 } 243 }
227 244
228 245
  246 + /**
  247 + * @return true if the track publish was successful.
  248 + */
229 private suspend fun publishTrackImpl( 249 private suspend fun publishTrackImpl(
230 track: Track, 250 track: Track,
231 requestConfig: LivekitRtc.AddTrackRequest.Builder.() -> Unit, 251 requestConfig: LivekitRtc.AddTrackRequest.Builder.() -> Unit,
232 encodings: List<RtpParameters.Encoding> = emptyList(), 252 encodings: List<RtpParameters.Encoding> = emptyList(),
233 publishListener: PublishListener? = null 253 publishListener: PublishListener? = null
234 - ) { 254 + ): Boolean {
235 if (localTrackPublications.any { it.track == track }) { 255 if (localTrackPublications.any { it.track == track }) {
236 publishListener?.onPublishFailure(TrackException.PublishException("Track has already been published")) 256 publishListener?.onPublishFailure(TrackException.PublishException("Track has already been published"))
237 - return 257 + return false
238 } 258 }
239 259
240 val cid = track.rtcTrack.id() 260 val cid = track.rtcTrack.id()
@@ -264,16 +284,19 @@ internal constructor( @@ -264,16 +284,19 @@ internal constructor(
264 284
265 if (transceiver == null) { 285 if (transceiver == null) {
266 publishListener?.onPublishFailure(TrackException.PublishException("null sender returned from peer connection")) 286 publishListener?.onPublishFailure(TrackException.PublishException("null sender returned from peer connection"))
267 - return 287 + return false
268 } 288 }
269 289
270 // TODO: enable setting preferred codec 290 // TODO: enable setting preferred codec
271 291
272 val publication = LocalTrackPublication(trackInfo, track, this) 292 val publication = LocalTrackPublication(trackInfo, track, this)
273 addTrackPublication(publication) 293 addTrackPublication(publication)
  294 +
274 publishListener?.onPublishSuccess(publication) 295 publishListener?.onPublishSuccess(publication)
275 internalListener?.onTrackPublished(publication, this) 296 internalListener?.onTrackPublished(publication, this)
276 eventBus.postEvent(ParticipantEvent.LocalTrackPublished(this, publication), scope) 297 eventBus.postEvent(ParticipantEvent.LocalTrackPublished(this, publication), scope)
  298 +
  299 + return true
277 } 300 }
278 301
279 private fun computeVideoEncodings( 302 private fun computeVideoEncodings(
@@ -451,14 +474,19 @@ internal constructor( @@ -451,14 +474,19 @@ internal constructor(
451 LKLog.d { "this track was never published." } 474 LKLog.d { "this track was never published." }
452 return 475 return
453 } 476 }
  477 +
  478 + publishes.remove(track)
  479 +
454 val sid = publication.sid 480 val sid = publication.sid
455 tracks = tracks.toMutableMap().apply { remove(sid) } 481 tracks = tracks.toMutableMap().apply { remove(sid) }
456 482
457 - val senders = engine.publisher.peerConnection.senders ?: return  
458 - for (sender in senders) {  
459 - val t = sender.track() ?: continue  
460 - if (t.id() == track.rtcTrack.id()) {  
461 - engine.publisher.peerConnection.removeTrack(sender) 483 + if (engine.connectionState == ConnectionState.CONNECTED) {
  484 + val senders = engine.publisher.peerConnection.senders
  485 + for (sender in senders) {
  486 + val t = sender.track() ?: continue
  487 + if (t.id() == track.rtcTrack.id()) {
  488 + engine.publisher.peerConnection.removeTrack(sender)
  489 + }
462 } 490 }
463 } 491 }
464 track.stop() 492 track.stop()
@@ -555,6 +583,41 @@ internal constructor( @@ -555,6 +583,41 @@ internal constructor(
555 } 583 }
556 } 584 }
557 585
  586 + fun prepareForFullReconnect() {
  587 + val pubs = localTrackPublications // creates a copy, so is safe from the following removal.
  588 + tracks = tracks.toMutableMap().apply { clear() }
  589 +
  590 + for (publication in pubs) {
  591 + internalListener?.onTrackUnpublished(publication, this)
  592 + eventBus.postEvent(ParticipantEvent.LocalTrackUnpublished(this, publication), scope)
  593 + }
  594 + }
  595 +
  596 + suspend fun republishTracks() {
  597 + for ((track, options) in publishes) {
  598 + when (track) {
  599 + is LocalAudioTrack -> publishAudioTrack(track, options as AudioTrackPublishOptions, null)
  600 + is LocalVideoTrack -> publishVideoTrack(track, options as VideoTrackPublishOptions, null)
  601 + else -> throw IllegalStateException("LocalParticipant has a non local track publish?")
  602 + }
  603 + }
  604 + }
  605 +
  606 + fun cleanup() {
  607 + for (pub in tracks.values) {
  608 + val track = pub.track
  609 +
  610 + if (track != null) {
  611 + track.stop()
  612 + unpublishTrack(track)
  613 + }
  614 + }
  615 + }
  616 +
  617 + fun dispose() {
  618 + cleanup()
  619 + scope.cancel()
  620 + }
558 621
559 interface PublishListener { 622 interface PublishListener {
560 fun onPublishSuccess(publication: TrackPublication) {} 623 fun onPublishSuccess(publication: TrackPublication) {}
@@ -678,4 +741,16 @@ data class ParticipantTrackPermission( @@ -678,4 +741,16 @@ data class ParticipantTrackPermission(
678 .addAllTrackSids(allowedTrackSids) 741 .addAllTrackSids(allowedTrackSids)
679 .build() 742 .build()
680 } 743 }
  744 +}
  745 +
  746 +sealed class PublishRecord() {
  747 + data class AudioTrackPublishRecord(
  748 + val track: LocalAudioTrack,
  749 + val options: AudioTrackPublishOptions
  750 + )
  751 +
  752 + data class VideoTrackPublishRecord(
  753 + val track: LocalVideoTrack,
  754 + val options: VideoTrackPublishOptions
  755 + )
681 } 756 }
@@ -32,4 +32,10 @@ class MockDataChannel(private val label: String?) : DataChannel(1L) { @@ -32,4 +32,10 @@ class MockDataChannel(private val label: String?) : DataChannel(1L) {
32 override fun send(buffer: Buffer?): Boolean { 32 override fun send(buffer: Buffer?): Boolean {
33 return true 33 return true
34 } 34 }
  35 +
  36 + override fun close() {
  37 + }
  38 +
  39 + override fun dispose() {
  40 + }
35 } 41 }
@@ -74,16 +74,16 @@ class MockPeerConnection( @@ -74,16 +74,16 @@ class MockPeerConnection(
74 return super.createSender(kind, stream_id) 74 return super.createSender(kind, stream_id)
75 } 75 }
76 76
77 - override fun getSenders(): MutableList<RtpSender> {  
78 - return super.getSenders() 77 + override fun getSenders(): List<RtpSender> {
  78 + return emptyList()
79 } 79 }
80 80
81 - override fun getReceivers(): MutableList<RtpReceiver> {  
82 - return super.getReceivers() 81 + override fun getReceivers(): List<RtpReceiver> {
  82 + return emptyList()
83 } 83 }
84 84
85 - override fun getTransceivers(): MutableList<RtpTransceiver> {  
86 - return super.getTransceivers() 85 + override fun getTransceivers(): List<RtpTransceiver> {
  86 + return emptyList()
87 } 87 }
88 88
89 override fun addTrack(track: MediaStreamTrack?): RtpSender { 89 override fun addTrack(track: MediaStreamTrack?): RtpSender {
@@ -103,10 +103,10 @@ class MockPeerConnection( @@ -103,10 +103,10 @@ class MockPeerConnection(
103 } 103 }
104 104
105 override fun addTransceiver( 105 override fun addTransceiver(
106 - track: MediaStreamTrack?, 106 + track: MediaStreamTrack,
107 init: RtpTransceiver.RtpTransceiverInit? 107 init: RtpTransceiver.RtpTransceiverInit?
108 ): RtpTransceiver { 108 ): RtpTransceiver {
109 - return super.addTransceiver(track, init) 109 + return MockRtpTransceiver.create(track, init ?: RtpTransceiver.RtpTransceiverInit())
110 } 110 }
111 111
112 override fun addTransceiver(mediaType: MediaStreamTrack.MediaType?): RtpTransceiver { 112 override fun addTransceiver(mediaType: MediaStreamTrack.MediaType?): RtpTransceiver {
  1 +package io.livekit.android.mock
  2 +
  3 +import org.mockito.Mockito
  4 +import org.webrtc.MediaStreamTrack
  5 +import org.webrtc.RtpTransceiver
  6 +
  7 +object MockRtpTransceiver {
  8 + fun create(
  9 + track: MediaStreamTrack,
  10 + init: RtpTransceiver.RtpTransceiverInit = RtpTransceiver.RtpTransceiverInit()
  11 + ): RtpTransceiver {
  12 + val mock = Mockito.mock(RtpTransceiver::class.java)
  13 +
  14 + Mockito.`when`(mock.mediaType).then {
  15 + return@then when (track.kind()) {
  16 + MediaStreamTrack.AUDIO_TRACK_KIND -> MediaStreamTrack.MediaType.MEDIA_TYPE_AUDIO
  17 + MediaStreamTrack.VIDEO_TRACK_KIND -> MediaStreamTrack.MediaType.MEDIA_TYPE_VIDEO
  18 + else -> throw IllegalStateException("illegal kind: ${track.kind()}")
  19 + }
  20 + }
  21 +
  22 + return mock
  23 + }
  24 +}
@@ -4,6 +4,12 @@ import livekit.LivekitModels @@ -4,6 +4,12 @@ import livekit.LivekitModels
4 4
5 object TestData { 5 object TestData {
6 6
  7 + val LOCAL_AUDIO_TRACK = with(LivekitModels.TrackInfo.newBuilder()) {
  8 + sid = "local_audio_track_sid"
  9 + type = LivekitModels.TrackType.AUDIO
  10 + build()
  11 + }
  12 +
7 val REMOTE_AUDIO_TRACK = with(LivekitModels.TrackInfo.newBuilder()) { 13 val REMOTE_AUDIO_TRACK = with(LivekitModels.TrackInfo.newBuilder()) {
8 sid = "remote_audio_track_sid" 14 sid = "remote_audio_track_sid"
9 type = LivekitModels.TrackType.AUDIO 15 type = LivekitModels.TrackType.AUDIO
@@ -10,8 +10,8 @@ import io.livekit.android.mock.MockMediaStream @@ -10,8 +10,8 @@ import io.livekit.android.mock.MockMediaStream
10 import io.livekit.android.mock.TestData 10 import io.livekit.android.mock.TestData
11 import io.livekit.android.mock.createMediaStreamId 11 import io.livekit.android.mock.createMediaStreamId
12 import io.livekit.android.room.participant.ConnectionQuality 12 import io.livekit.android.room.participant.ConnectionQuality
  13 +import io.livekit.android.room.track.LocalAudioTrack
13 import io.livekit.android.room.track.Track 14 import io.livekit.android.room.track.Track
14 -import io.livekit.android.util.delegate  
15 import io.livekit.android.util.flow 15 import io.livekit.android.util.flow
16 import io.livekit.android.util.toOkioByteString 16 import io.livekit.android.util.toOkioByteString
17 import kotlinx.coroutines.ExperimentalCoroutinesApi 17 import kotlinx.coroutines.ExperimentalCoroutinesApi
@@ -257,6 +257,30 @@ class RoomMockE2ETest : MockE2ETest() { @@ -257,6 +257,30 @@ class RoomMockE2ETest : MockE2ETest() {
257 } 257 }
258 258
259 @Test 259 @Test
  260 + fun disconnectCleansLocalParticipant() = runTest {
  261 + connect()
  262 +
  263 + val publishJob = launch {
  264 + room.localParticipant.publishAudioTrack(
  265 + LocalAudioTrack(
  266 + "",
  267 + MockAudioStreamTrack(id = SignalClientTest.LOCAL_TRACK_PUBLISHED.trackPublished.cid)
  268 + )
  269 + )
  270 + }
  271 + wsFactory.listener.onMessage(wsFactory.ws, SignalClientTest.LOCAL_TRACK_PUBLISHED.toOkioByteString())
  272 + publishJob.join()
  273 +
  274 + val eventCollector = EventCollector(room.events, coroutineRule.scope)
  275 + room.disconnect()
  276 + val events = eventCollector.stopCollecting()
  277 +
  278 + Assert.assertEquals(2, events.size)
  279 + Assert.assertEquals(true, events[0] is RoomEvent.TrackUnpublished)
  280 + Assert.assertEquals(true, events[1] is RoomEvent.Disconnected)
  281 + }
  282 +
  283 + @Test
260 fun reconnectAfterDisconnect() = runTest { 284 fun reconnectAfterDisconnect() = runTest {
261 connect() 285 connect()
262 room.disconnect() 286 room.disconnect()
@@ -6,7 +6,7 @@ import androidx.test.core.app.ApplicationProvider @@ -6,7 +6,7 @@ import androidx.test.core.app.ApplicationProvider
6 import io.livekit.android.coroutines.TestCoroutineRule 6 import io.livekit.android.coroutines.TestCoroutineRule
7 import io.livekit.android.events.EventCollector 7 import io.livekit.android.events.EventCollector
8 import io.livekit.android.events.RoomEvent 8 import io.livekit.android.events.RoomEvent
9 -import io.livekit.android.mock.MockEglBase 9 +import io.livekit.android.mock.*
10 import io.livekit.android.room.participant.LocalParticipant 10 import io.livekit.android.room.participant.LocalParticipant
11 import kotlinx.coroutines.ExperimentalCoroutinesApi 11 import kotlinx.coroutines.ExperimentalCoroutinesApi
12 import kotlinx.coroutines.test.runTest 12 import kotlinx.coroutines.test.runTest
@@ -107,4 +107,32 @@ class RoomTest { @@ -107,4 +107,32 @@ class RoomTest {
107 Assert.assertEquals(1, events.size) 107 Assert.assertEquals(1, events.size)
108 Assert.assertEquals(true, events[0] is RoomEvent.Disconnected) 108 Assert.assertEquals(true, events[0] is RoomEvent.Disconnected)
109 } 109 }
  110 +
  111 + @Test
  112 + fun disconnectCleansUpParticipants() = runTest {
  113 + connect()
  114 +
  115 + room.onUpdateParticipants(SignalClientTest.PARTICIPANT_JOIN.update.participantsList)
  116 + room.onAddTrack(
  117 + MockAudioStreamTrack(),
  118 + arrayOf(
  119 + MockMediaStream(
  120 + id = createMediaStreamId(
  121 + TestData.REMOTE_PARTICIPANT.sid,
  122 + TestData.REMOTE_AUDIO_TRACK.sid
  123 + )
  124 + )
  125 + )
  126 + )
  127 +
  128 + val eventCollector = EventCollector(room.events, coroutineRule.scope)
  129 + room.onEngineDisconnected("")
  130 + val events = eventCollector.stopCollecting()
  131 +
  132 + Assert.assertEquals(4, events.size)
  133 + Assert.assertEquals(true, events[0] is RoomEvent.TrackUnsubscribed)
  134 + Assert.assertEquals(true, events[1] is RoomEvent.TrackUnpublished)
  135 + Assert.assertEquals(true, events[2] is RoomEvent.ParticipantDisconnected)
  136 + Assert.assertEquals(true, events[3] is RoomEvent.Disconnected)
  137 + }
110 } 138 }
@@ -192,9 +192,10 @@ class SignalClientTest : BaseTest() { @@ -192,9 +192,10 @@ class SignalClientTest : BaseTest() {
192 build() 192 build()
193 } 193 }
194 194
195 - val TRACK_PUBLISHED = with(LivekitRtc.SignalResponse.newBuilder()) { 195 + val LOCAL_TRACK_PUBLISHED = with(LivekitRtc.SignalResponse.newBuilder()) {
196 trackPublished = with(trackPublishedBuilder) { 196 trackPublished = with(trackPublishedBuilder) {
197 - track = TestData.REMOTE_AUDIO_TRACK 197 + cid = "local_cid"
  198 + track = TestData.LOCAL_AUDIO_TRACK
198 build() 199 build()
199 } 200 }
200 build() 201 build()
1 package io.livekit.android.composesample 1 package io.livekit.android.composesample
2 2
3 -import androidx.compose.foundation.layout.fillMaxSize  
4 import androidx.compose.runtime.* 3 import androidx.compose.runtime.*
5 import androidx.compose.ui.Modifier 4 import androidx.compose.ui.Modifier
6 import androidx.compose.ui.layout.onGloballyPositioned 5 import androidx.compose.ui.layout.onGloballyPositioned
@@ -55,6 +54,12 @@ fun VideoItem( @@ -55,6 +54,12 @@ fun VideoItem(
55 } 54 }
56 } 55 }
57 56
  57 + DisposableEffect(currentCompositeKeyHash.toString()) {
  58 + onDispose {
  59 + view?.release()
  60 + }
  61 + }
  62 +
58 AndroidView( 63 AndroidView(
59 factory = { context -> 64 factory = { context ->
60 TextureViewRenderer(context).apply { 65 TextureViewRenderer(context).apply {
1 <resources> 1 <resources>
2 - <string name="app_name">Sample Compose</string> 2 + <string name="app_name">Livekit Compose Sample</string>
3 </resources> 3 </resources>