davidliu
Committed by GitHub

full reconnect support (#45)

* properly clean up texture view renderer when done

* tests

* full reconnect if regular reconnect doesn't work

* small tweak

* proper join message handling

* fix republishing
... ... @@ -28,6 +28,7 @@ import kotlin.coroutines.suspendCoroutine
/**
* @suppress
*/
@OptIn(ExperimentalCoroutinesApi::class)
@Singleton
class RTCEngine
@Inject
... ... @@ -75,6 +76,7 @@ internal constructor(
mutableMapOf()
private var sessionUrl: String? = null
private var sessionToken: String? = null
private var connectOptions: ConnectOptions? = null
private val publisherObserver = PublisherTransportObserver(this, client)
private val subscriberObserver = SubscriberTransportObserver(this, client)
... ... @@ -113,9 +115,14 @@ internal constructor(
coroutineScope = CloseableCoroutineScope(SupervisorJob() + ioDispatcher)
sessionUrl = url
sessionToken = token
return joinImpl(url, token, options)
}
suspend fun joinImpl(url: String, token: String, options: ConnectOptions): LivekitRtc.JoinResponse {
val joinResponse = client.join(url, token, options)
listener?.onJoinResponse(joinResponse)
isClosed = false
listener?.onSignalConnected()
listener?.onSignalConnected(false)
isSubscriberPrimary = joinResponse.subscriberPrimary
... ... @@ -245,6 +252,7 @@ internal constructor(
throw TrackException.DuplicateTrackException("Track with same ID $cid has already been published!")
}
// Suspend until signal client receives message confirming track publication.
return suspendCoroutine { cont ->
pendingTrackResolvers[cid] = cont
client.sendAddTrack(cid, name, kind, builder)
... ... @@ -267,11 +275,31 @@ internal constructor(
return
}
isClosed = true
hasPublished = false
sessionUrl = null
sessionToken = null
connectOptions = null
reconnectingJob?.cancel()
reconnectingJob = null
coroutineScope.close()
closeResources()
}
private fun closeResources() {
connectionState = ConnectionState.DISCONNECTED
_publisher?.close()
_publisher = null
_subscriber?.close()
_subscriber = null
reliableDataChannel?.close()
reliableDataChannel = null
reliableDataChannelSub?.close()
reliableDataChannelSub = null
lossyDataChannel?.close()
lossyDataChannel = null
lossyDataChannelSub?.close()
lossyDataChannelSub = null
isSubscriberPrimary = false
client.close()
}
... ... @@ -293,6 +321,7 @@ internal constructor(
}
val job = coroutineScope.launch {
connectionState = ConnectionState.RECONNECTING
listener?.onEngineReconnecting()
for (wsRetries in 0 until MAX_SIGNAL_RETRIES) {
... ... @@ -302,28 +331,44 @@ internal constructor(
}
LKLog.i { "Reconnecting to signal, attempt ${wsRetries + 1}" }
delay(startDelay)
try {
client.reconnect(url, token)
} catch (e: Exception) {
// ws reconnect failed, retry.
continue
}
LKLog.v { "ws reconnected, restarting ICE" }
listener?.onSignalConnected()
// full reconnect after first try.
val isFullReconnect = true
if (isFullReconnect) {
try {
closeResources()
listener?.onFullReconnecting()
joinImpl(url, token, connectOptions ?: ConnectOptions())
} catch (e: Exception) {
LKLog.w(e) { "Error during reconnection." }
// reconnect failed, retry.
continue
}
} else {
try {
client.reconnect(url, token)
// no join response for regular reconnects
client.onReady()
} catch (e: Exception) {
LKLog.w(e) { "Error during reconnection." }
// ws reconnect failed, retry.
continue
}
subscriber.prepareForIceRestart()
connectionState = ConnectionState.RECONNECTING
// trigger publisher reconnect
// only restart publisher if it's needed
if (hasPublished) {
negotiate()
}
LKLog.v { "ws reconnected, restarting ICE" }
listener?.onSignalConnected(!isFullReconnect)
subscriber.prepareForIceRestart()
// trigger publisher reconnect
// only restart publisher if it's needed
if (hasPublished) {
negotiate()
}
}
// wait until ICE connected
val endTime = SystemClock.elapsedRealtime() + MAX_ICE_CONNECT_TIMEOUT_MS;
val endTime = SystemClock.elapsedRealtime() + MAX_ICE_CONNECT_TIMEOUT_MS
while (SystemClock.elapsedRealtime() < endTime) {
if (connectionState == ConnectionState.CONNECTED) {
LKLog.v { "reconnected to ICE" }
... ... @@ -333,11 +378,13 @@ internal constructor(
}
if (connectionState == ConnectionState.CONNECTED) {
if (isFullReconnect) {
listener?.onFullReconnect()
}
return@launch
}
}
close()
listener?.onEngineDisconnected("failed reconnecting.")
}
... ... @@ -389,7 +436,7 @@ internal constructor(
publisher.peerConnection.iceConnectionState() != PeerConnection.IceConnectionState.CHECKING
) {
// start negotiation
this.negotiate();
this.negotiate()
}
... ... @@ -399,7 +446,7 @@ internal constructor(
}
// wait until publisher ICE connected
val endTime = SystemClock.elapsedRealtime() + MAX_ICE_CONNECT_TIMEOUT_MS;
val endTime = SystemClock.elapsedRealtime() + MAX_ICE_CONNECT_TIMEOUT_MS
while (SystemClock.elapsedRealtime() < endTime) {
if (this.publisher.peerConnection.isConnected() && targetChannel.state() == DataChannel.State.OPEN) {
return
... ... @@ -450,6 +497,7 @@ internal constructor(
fun onEngineReconnecting()
fun onEngineDisconnected(reason: String)
fun onFailToConnect(error: Throwable)
fun onJoinResponse(response: LivekitRtc.JoinResponse)
fun onAddTrack(track: MediaStreamTrack, streams: Array<out MediaStream>)
fun onUpdateParticipants(updates: List<LivekitModels.ParticipantInfo>)
fun onActiveSpeakersUpdate(speakers: List<LivekitModels.SpeakerInfo>)
... ... @@ -461,7 +509,9 @@ internal constructor(
fun onStreamStateUpdate(streamStates: List<LivekitRtc.StreamStateInfo>)
fun onSubscribedQualityUpdate(subscribedQualityUpdate: LivekitRtc.SubscribedQualityUpdate)
fun onSubscriptionPermissionUpdate(subscriptionPermissionUpdate: LivekitRtc.SubscriptionPermissionUpdate)
fun onSignalConnected()
fun onSignalConnected(isReconnect: Boolean)
fun onFullReconnecting()
suspend fun onFullReconnect()
}
companion object {
... ...
@file:Suppress("unused")
package io.livekit.android.room
import android.content.Context
... ... @@ -117,8 +119,12 @@ constructor(
*/
var videoTrackPublishDefaults: VideoTrackPublishDefaults by defaultsManager::videoTrackPublishDefaults
lateinit var localParticipant: LocalParticipant
private set
var _localParticipant: LocalParticipant? = null
val localParticipant: LocalParticipant
get() {
return _localParticipant
?: throw UninitializedPropertyAccessException("localParticipant has not been initialized yet.")
}
private var mutableRemoteParticipants by flowDelegate(emptyMap<String, RemoteParticipant>())
... ... @@ -143,25 +149,8 @@ constructor(
coroutineScope = CoroutineScope(defaultDispatcher + SupervisorJob())
state = State.CONNECTING
connectOptions = options
val response = engine.join(url, token, options)
LKLog.i { "Connected to server, server version: ${response.serverVersion}, client version: ${Version.CLIENT_VERSION}" }
sid = Sid(response.room.sid)
name = response.room.name
engine.join(url, token, options)
if (!response.hasParticipant()) {
listener?.onFailedToConnect(this, RoomException.ConnectException("server didn't return any participants"))
return
}
val lp = localParticipantFactory.create(response.participant, dynacast)
lp.internalListener = this
localParticipant = lp
if (response.otherParticipantsList.isNotEmpty()) {
response.otherParticipantsList.forEach {
getOrCreateRemoteParticipant(it.sid, it)
}
}
val cm = context.getSystemService(Context.CONNECTIVITY_SERVICE) as ConnectivityManager
val networkRequest = NetworkRequest.Builder()
.addCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET)
... ... @@ -186,11 +175,38 @@ constructor(
handleDisconnect()
}
override fun onJoinResponse(response: LivekitRtc.JoinResponse) {
LKLog.i { "Connected to server, server version: ${response.serverVersion}, client version: ${Version.CLIENT_VERSION}" }
sid = Sid(response.room.sid)
name = response.room.name
if (!response.hasParticipant()) {
listener?.onFailedToConnect(this, RoomException.ConnectException("server didn't return any participants"))
return
}
if (_localParticipant == null) {
val lp = localParticipantFactory.create(response.participant, dynacast)
lp.internalListener = this
_localParticipant = lp
} else {
localParticipant.updateFromInfo(response.participant)
}
if (response.otherParticipantsList.isNotEmpty()) {
response.otherParticipantsList.forEach {
getOrCreateRemoteParticipant(it.sid, it)
}
}
}
private fun handleParticipantDisconnect(sid: String) {
val newParticipants = mutableRemoteParticipants.toMutableMap()
val removedParticipant = newParticipants.remove(sid) ?: return
removedParticipant.tracks.values.toList().forEach { publication ->
removedParticipant.unpublishTrack(publication.sid)
removedParticipant.unpublishTrack(publication.sid, true)
}
mutableRemoteParticipants = newParticipants
... ... @@ -316,6 +332,15 @@ constructor(
engine.reconnect()
}
/**
* Removes all participants and tracks from the room.
*/
private fun cleanupRoom() {
localParticipant.cleanup()
remoteParticipants.keys.toMutableSet() // copy keys to avoid concurrent modifications.
.forEach { sid -> handleParticipantDisconnect(sid) }
}
private fun handleDisconnect() {
if (state == State.DISCONNECTED) {
return
... ... @@ -328,19 +353,14 @@ constructor(
// do nothing, may happen on older versions if attempting to unregister twice.
}
for (pub in localParticipant.tracks.values) {
pub.track?.stop()
}
// stop remote tracks too
for (p in remoteParticipants.values) {
for (pub in p.tracks.values) {
pub.track?.stop()
}
}
cleanupRoom()
engine.close()
state = State.DISCONNECTED
listener?.onDisconnect(this, null)
listener = null
_localParticipant?.dispose()
_localParticipant = null
// Ensure all observers see the disconnected before closing scope.
runBlocking {
... ... @@ -560,13 +580,21 @@ constructor(
eventBus.tryPostEvent(RoomEvent.FailedToConnect(this, error))
}
override fun onSignalConnected() {
if (state == State.RECONNECTING) {
override fun onSignalConnected(isReconnect: Boolean) {
if (state == State.RECONNECTING && isReconnect) {
// during reconnection, need to send sync state upon signal connection.
sendSyncState()
}
}
override fun onFullReconnecting() {
localParticipant.prepareForFullReconnect()
}
override suspend fun onFullReconnect() {
localParticipant.republishTracks()
}
//------------------------------- ParticipantListener --------------------------------//
/**
* This is called for both Local and Remote participants
... ...
... ... @@ -14,7 +14,6 @@ import io.livekit.android.util.safe
import io.livekit.android.webrtc.toProtoSessionDescription
import kotlinx.coroutines.*
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.flow.collect
import kotlinx.serialization.decodeFromString
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.Json
... ... @@ -126,7 +125,13 @@ constructor(
}
}
@ExperimentalCoroutinesApi
/**
* Notifies that the downstream consumers of SignalClient are ready to consume messages.
* Until this method is called, any messages received through the websocket are buffered.
*
* Should be called after resolving the join message.
*/
@OptIn(ExperimentalCoroutinesApi::class)
fun onReady() {
coroutineScope.launch {
responseFlow.collect {
... ... @@ -483,6 +488,11 @@ constructor(
}.safe()
}
/**
* Closes out any existing websocket connection, and cleans up used resources.
*
* Can be reused afterwards.
*/
fun close(code: Int = 1000, reason: String = "Normal Closure") {
isConnected = false
if(::coroutineScope.isInitialized) {
... ...
... ... @@ -9,11 +9,13 @@ import dagger.assisted.AssistedFactory
import dagger.assisted.AssistedInject
import io.livekit.android.dagger.InjectionNames
import io.livekit.android.events.ParticipantEvent
import io.livekit.android.room.ConnectionState
import io.livekit.android.room.DefaultsManager
import io.livekit.android.room.RTCEngine
import io.livekit.android.room.track.*
import io.livekit.android.util.LKLog
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.cancel
import livekit.LivekitModels
import livekit.LivekitRtc
import org.webrtc.EglBase
... ... @@ -58,6 +60,13 @@ internal constructor(
.mapNotNull { it as? LocalTrackPublication }
.toList()
private var isReconnecting = false
/**
* Holds on to publishes that need to be republished after a full reconnect.
*/
private var publishes = mutableMapOf<Track, TrackPublishOptions>()
/**
* Creates an audio track, recording audio through the microphone with the given [options].
*
... ... @@ -189,7 +198,7 @@ internal constructor(
),
publishListener: PublishListener? = null
) {
publishTrackImpl(
val published = publishTrackImpl(
track,
requestConfig = {
disableDtx = !options.dtx
... ... @@ -197,6 +206,10 @@ internal constructor(
},
publishListener = publishListener,
)
if (published) {
publishes[track] = options
}
}
suspend fun publishVideoTrack(
... ... @@ -208,7 +221,7 @@ internal constructor(
val encodings = computeVideoEncodings(track.dimensions, options)
val videoLayers = videoLayersFromEncodings(track.dimensions.width, track.dimensions.height, encodings)
publishTrackImpl(
val published = publishTrackImpl(
track,
requestConfig = {
width = track.dimensions.width
... ... @@ -223,18 +236,25 @@ internal constructor(
encodings = encodings,
publishListener = publishListener
)
if (published) {
publishes[track] = options
}
}
/**
* @return true if the track publish was successful.
*/
private suspend fun publishTrackImpl(
track: Track,
requestConfig: LivekitRtc.AddTrackRequest.Builder.() -> Unit,
encodings: List<RtpParameters.Encoding> = emptyList(),
publishListener: PublishListener? = null
) {
): Boolean {
if (localTrackPublications.any { it.track == track }) {
publishListener?.onPublishFailure(TrackException.PublishException("Track has already been published"))
return
return false
}
val cid = track.rtcTrack.id()
... ... @@ -264,16 +284,19 @@ internal constructor(
if (transceiver == null) {
publishListener?.onPublishFailure(TrackException.PublishException("null sender returned from peer connection"))
return
return false
}
// TODO: enable setting preferred codec
val publication = LocalTrackPublication(trackInfo, track, this)
addTrackPublication(publication)
publishListener?.onPublishSuccess(publication)
internalListener?.onTrackPublished(publication, this)
eventBus.postEvent(ParticipantEvent.LocalTrackPublished(this, publication), scope)
return true
}
private fun computeVideoEncodings(
... ... @@ -451,14 +474,19 @@ internal constructor(
LKLog.d { "this track was never published." }
return
}
publishes.remove(track)
val sid = publication.sid
tracks = tracks.toMutableMap().apply { remove(sid) }
val senders = engine.publisher.peerConnection.senders ?: return
for (sender in senders) {
val t = sender.track() ?: continue
if (t.id() == track.rtcTrack.id()) {
engine.publisher.peerConnection.removeTrack(sender)
if (engine.connectionState == ConnectionState.CONNECTED) {
val senders = engine.publisher.peerConnection.senders
for (sender in senders) {
val t = sender.track() ?: continue
if (t.id() == track.rtcTrack.id()) {
engine.publisher.peerConnection.removeTrack(sender)
}
}
}
track.stop()
... ... @@ -555,6 +583,41 @@ internal constructor(
}
}
fun prepareForFullReconnect() {
val pubs = localTrackPublications // creates a copy, so is safe from the following removal.
tracks = tracks.toMutableMap().apply { clear() }
for (publication in pubs) {
internalListener?.onTrackUnpublished(publication, this)
eventBus.postEvent(ParticipantEvent.LocalTrackUnpublished(this, publication), scope)
}
}
suspend fun republishTracks() {
for ((track, options) in publishes) {
when (track) {
is LocalAudioTrack -> publishAudioTrack(track, options as AudioTrackPublishOptions, null)
is LocalVideoTrack -> publishVideoTrack(track, options as VideoTrackPublishOptions, null)
else -> throw IllegalStateException("LocalParticipant has a non local track publish?")
}
}
}
fun cleanup() {
for (pub in tracks.values) {
val track = pub.track
if (track != null) {
track.stop()
unpublishTrack(track)
}
}
}
fun dispose() {
cleanup()
scope.cancel()
}
interface PublishListener {
fun onPublishSuccess(publication: TrackPublication) {}
... ... @@ -678,4 +741,16 @@ data class ParticipantTrackPermission(
.addAllTrackSids(allowedTrackSids)
.build()
}
}
sealed class PublishRecord() {
data class AudioTrackPublishRecord(
val track: LocalAudioTrack,
val options: AudioTrackPublishOptions
)
data class VideoTrackPublishRecord(
val track: LocalVideoTrack,
val options: VideoTrackPublishOptions
)
}
\ No newline at end of file
... ...
... ... @@ -32,4 +32,10 @@ class MockDataChannel(private val label: String?) : DataChannel(1L) {
override fun send(buffer: Buffer?): Boolean {
return true
}
override fun close() {
}
override fun dispose() {
}
}
\ No newline at end of file
... ...
... ... @@ -74,16 +74,16 @@ class MockPeerConnection(
return super.createSender(kind, stream_id)
}
override fun getSenders(): MutableList<RtpSender> {
return super.getSenders()
override fun getSenders(): List<RtpSender> {
return emptyList()
}
override fun getReceivers(): MutableList<RtpReceiver> {
return super.getReceivers()
override fun getReceivers(): List<RtpReceiver> {
return emptyList()
}
override fun getTransceivers(): MutableList<RtpTransceiver> {
return super.getTransceivers()
override fun getTransceivers(): List<RtpTransceiver> {
return emptyList()
}
override fun addTrack(track: MediaStreamTrack?): RtpSender {
... ... @@ -103,10 +103,10 @@ class MockPeerConnection(
}
override fun addTransceiver(
track: MediaStreamTrack?,
track: MediaStreamTrack,
init: RtpTransceiver.RtpTransceiverInit?
): RtpTransceiver {
return super.addTransceiver(track, init)
return MockRtpTransceiver.create(track, init ?: RtpTransceiver.RtpTransceiverInit())
}
override fun addTransceiver(mediaType: MediaStreamTrack.MediaType?): RtpTransceiver {
... ...
package io.livekit.android.mock
import org.mockito.Mockito
import org.webrtc.MediaStreamTrack
import org.webrtc.RtpTransceiver
object MockRtpTransceiver {
fun create(
track: MediaStreamTrack,
init: RtpTransceiver.RtpTransceiverInit = RtpTransceiver.RtpTransceiverInit()
): RtpTransceiver {
val mock = Mockito.mock(RtpTransceiver::class.java)
Mockito.`when`(mock.mediaType).then {
return@then when (track.kind()) {
MediaStreamTrack.AUDIO_TRACK_KIND -> MediaStreamTrack.MediaType.MEDIA_TYPE_AUDIO
MediaStreamTrack.VIDEO_TRACK_KIND -> MediaStreamTrack.MediaType.MEDIA_TYPE_VIDEO
else -> throw IllegalStateException("illegal kind: ${track.kind()}")
}
}
return mock
}
}
\ No newline at end of file
... ...
... ... @@ -4,6 +4,12 @@ import livekit.LivekitModels
object TestData {
val LOCAL_AUDIO_TRACK = with(LivekitModels.TrackInfo.newBuilder()) {
sid = "local_audio_track_sid"
type = LivekitModels.TrackType.AUDIO
build()
}
val REMOTE_AUDIO_TRACK = with(LivekitModels.TrackInfo.newBuilder()) {
sid = "remote_audio_track_sid"
type = LivekitModels.TrackType.AUDIO
... ...
... ... @@ -10,8 +10,8 @@ import io.livekit.android.mock.MockMediaStream
import io.livekit.android.mock.TestData
import io.livekit.android.mock.createMediaStreamId
import io.livekit.android.room.participant.ConnectionQuality
import io.livekit.android.room.track.LocalAudioTrack
import io.livekit.android.room.track.Track
import io.livekit.android.util.delegate
import io.livekit.android.util.flow
import io.livekit.android.util.toOkioByteString
import kotlinx.coroutines.ExperimentalCoroutinesApi
... ... @@ -257,6 +257,30 @@ class RoomMockE2ETest : MockE2ETest() {
}
@Test
fun disconnectCleansLocalParticipant() = runTest {
connect()
val publishJob = launch {
room.localParticipant.publishAudioTrack(
LocalAudioTrack(
"",
MockAudioStreamTrack(id = SignalClientTest.LOCAL_TRACK_PUBLISHED.trackPublished.cid)
)
)
}
wsFactory.listener.onMessage(wsFactory.ws, SignalClientTest.LOCAL_TRACK_PUBLISHED.toOkioByteString())
publishJob.join()
val eventCollector = EventCollector(room.events, coroutineRule.scope)
room.disconnect()
val events = eventCollector.stopCollecting()
Assert.assertEquals(2, events.size)
Assert.assertEquals(true, events[0] is RoomEvent.TrackUnpublished)
Assert.assertEquals(true, events[1] is RoomEvent.Disconnected)
}
@Test
fun reconnectAfterDisconnect() = runTest {
connect()
room.disconnect()
... ...
... ... @@ -6,7 +6,7 @@ import androidx.test.core.app.ApplicationProvider
import io.livekit.android.coroutines.TestCoroutineRule
import io.livekit.android.events.EventCollector
import io.livekit.android.events.RoomEvent
import io.livekit.android.mock.MockEglBase
import io.livekit.android.mock.*
import io.livekit.android.room.participant.LocalParticipant
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.test.runTest
... ... @@ -107,4 +107,32 @@ class RoomTest {
Assert.assertEquals(1, events.size)
Assert.assertEquals(true, events[0] is RoomEvent.Disconnected)
}
@Test
fun disconnectCleansUpParticipants() = runTest {
connect()
room.onUpdateParticipants(SignalClientTest.PARTICIPANT_JOIN.update.participantsList)
room.onAddTrack(
MockAudioStreamTrack(),
arrayOf(
MockMediaStream(
id = createMediaStreamId(
TestData.REMOTE_PARTICIPANT.sid,
TestData.REMOTE_AUDIO_TRACK.sid
)
)
)
)
val eventCollector = EventCollector(room.events, coroutineRule.scope)
room.onEngineDisconnected("")
val events = eventCollector.stopCollecting()
Assert.assertEquals(4, events.size)
Assert.assertEquals(true, events[0] is RoomEvent.TrackUnsubscribed)
Assert.assertEquals(true, events[1] is RoomEvent.TrackUnpublished)
Assert.assertEquals(true, events[2] is RoomEvent.ParticipantDisconnected)
Assert.assertEquals(true, events[3] is RoomEvent.Disconnected)
}
}
\ No newline at end of file
... ...
... ... @@ -192,9 +192,10 @@ class SignalClientTest : BaseTest() {
build()
}
val TRACK_PUBLISHED = with(LivekitRtc.SignalResponse.newBuilder()) {
val LOCAL_TRACK_PUBLISHED = with(LivekitRtc.SignalResponse.newBuilder()) {
trackPublished = with(trackPublishedBuilder) {
track = TestData.REMOTE_AUDIO_TRACK
cid = "local_cid"
track = TestData.LOCAL_AUDIO_TRACK
build()
}
build()
... ...
package io.livekit.android.composesample
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.runtime.*
import androidx.compose.ui.Modifier
import androidx.compose.ui.layout.onGloballyPositioned
... ... @@ -55,6 +54,12 @@ fun VideoItem(
}
}
DisposableEffect(currentCompositeKeyHash.toString()) {
onDispose {
view?.release()
}
}
AndroidView(
factory = { context ->
TextureViewRenderer(context).apply {
... ...
<resources>
<string name="app_name">Sample Compose</string>
<string name="app_name">Livekit Compose Sample</string>
</resources>
\ No newline at end of file
... ...