davidliu
Committed by GitHub

Fix memory leak caused by disconnecting before connect finished (#386)

* State locking for Room and RTC engine around critical spots

* Cancel connect job if invoking coroutine is cancelled

* cleanup

* Clean up test logs

* revert stress test changes to sample apps
... ... @@ -144,7 +144,7 @@ constructor(
restartingIce = true
}
if (this.peerConnection.signalingState() == SignalingState.HAVE_LOCAL_OFFER) {
if (peerConnection.signalingState() == SignalingState.HAVE_LOCAL_OFFER) {
// we're waiting for the peer to accept our offer, so we'll just wait
// the only exception to this is when ICE restart is needed
val curSd = peerConnection.remoteDescription
... ... @@ -313,7 +313,7 @@ constructor(
}
@OptIn(ExperimentalContracts::class)
private suspend inline fun <T> launchRTCIfNotClosed(noinline action: suspend () -> T): T? {
private suspend inline fun <T> launchRTCIfNotClosed(noinline action: suspend CoroutineScope.() -> T): T? {
contract { callsInPlace(action, InvocationKind.AT_MOST_ONCE) }
if (isClosed()) {
return null
... ...
... ... @@ -35,13 +35,17 @@ import io.livekit.android.util.FlowObservable
import io.livekit.android.util.LKLog
import io.livekit.android.util.flowDelegate
import io.livekit.android.util.nullSafe
import io.livekit.android.util.withCheckLock
import io.livekit.android.webrtc.RTCStatsGetter
import io.livekit.android.webrtc.copy
import io.livekit.android.webrtc.isConnected
import io.livekit.android.webrtc.isDisconnected
import io.livekit.android.webrtc.peerconnection.executeBlockingOnRTCThread
import io.livekit.android.webrtc.peerconnection.launchBlockingOnRTCThread
import io.livekit.android.webrtc.toProtoSessionDescription
import kotlinx.coroutines.*
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import livekit.LivekitModels
import livekit.LivekitRtc
import livekit.LivekitRtc.JoinResponse
... ... @@ -134,6 +138,12 @@ internal constructor(
private var coroutineScope = CloseableCoroutineScope(SupervisorJob() + ioDispatcher)
/**
 * Guards the publisher/subscriber setup performed in [configure] and the
 * matching teardown performed in [closeResources].
 *
 * Note: If this lock is ever used in conjunction with the RTC thread,
 * this must be grabbed on the RTC thread to prevent deadlocks.
 */
// Declared as `val`: the lock is only ever acquired, never replaced. A reassignable
// mutex reference could let two callers end up synchronizing on different instances.
private val configurationLock = Mutex()
init {
client.listener = this
}
... ... @@ -158,8 +168,10 @@ internal constructor(
token: String,
options: ConnectOptions,
roomOptions: RoomOptions,
): JoinResponse {
): JoinResponse = coroutineScope {
val joinResponse = client.join(url, token, options, roomOptions)
ensureActive()
listener?.onJoinResponse(joinResponse)
isClosed = false
listener?.onSignalConnected(false)
... ... @@ -169,19 +181,25 @@ internal constructor(
configure(joinResponse, options)
// create offer
if (!this.isSubscriberPrimary) {
if (!isSubscriberPrimary) {
negotiatePublisher()
}
client.onReadyForResponses()
return joinResponse
return@coroutineScope joinResponse
}
private suspend fun configure(joinResponse: JoinResponse, connectOptions: ConnectOptions) {
launchBlockingOnRTCThread {
configurationLock.withCheckLock(
{
ensureActive()
if (publisher != null && subscriber != null) {
// already configured
return
return@launchBlockingOnRTCThread
}
},
) {
participantSid = if (joinResponse.hasParticipant()) {
joinResponse.participant.sid
} else {
... ... @@ -235,6 +253,7 @@ internal constructor(
publisherObserver.connectionChangeListener = connectionStateListener
}
ensureActive()
// data channels
val reliableInit = DataChannel.Init()
reliableInit.ordered = true
... ... @@ -247,6 +266,7 @@ internal constructor(
}
}
ensureActive()
val lossyInit = DataChannel.Init()
lossyInit.ordered = true
lossyInit.maxRetransmits = 0
... ... @@ -259,6 +279,8 @@ internal constructor(
}
}
}
}
}
/**
* @param builder an optional builder to include other parameters related to the track
... ... @@ -327,6 +349,8 @@ internal constructor(
private fun closeResources(reason: String) {
executeBlockingOnRTCThread {
runBlocking {
configurationLock.withLock {
publisherObserver.connectionChangeListener = null
subscriberObserver.connectionChangeListener = null
publisher?.closeBlocking()
... ... @@ -339,6 +363,7 @@ internal constructor(
this?.close()
this?.dispose()
}
reliableDataChannel?.completeDispose()
reliableDataChannel = null
reliableDataChannelSub?.completeDispose()
... ... @@ -349,6 +374,8 @@ internal constructor(
lossyDataChannelSub = null
isSubscriberPrimary = false
}
}
}
client.close(reason = reason)
}
... ...
... ... @@ -49,6 +49,8 @@ import io.livekit.android.webrtc.getFilteredStats
import kotlinx.coroutines.*
import kotlinx.coroutines.flow.filterNotNull
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import kotlinx.serialization.Serializable
import livekit.LivekitModels
import livekit.LivekitRtc
... ... @@ -243,6 +245,8 @@ constructor(
private var hasLostConnectivity: Boolean = false
private var connectOptions: ConnectOptions = ConnectOptions()
/**
 * Serializes room state transitions: held by [connect] while it checks for
 * [State.DISCONNECTED] and moves to [State.CONNECTING], and by
 * [handleDisconnect] while it tears the room down.
 */
// Declared as `val`: the lock is only ever acquired, never replaced. A reassignable
// mutex reference could let two callers end up synchronizing on different instances.
private val stateLock = Mutex()
private fun getCurrentRoomOptions(): RoomOptions =
RoomOptions(
adaptiveStream = adaptiveStream,
... ... @@ -260,15 +264,32 @@ constructor(
* @param url
* @param token
* @param options
*
* @throws IllegalStateException when connect is attempted while the room is not disconnected.
* @throws Exception when connection fails
*/
@Throws(Exception::class)
suspend fun connect(url: String, token: String, options: ConnectOptions = ConnectOptions()) {
if (this::coroutineScope.isInitialized) {
suspend fun connect(url: String, token: String, options: ConnectOptions = ConnectOptions()) = coroutineScope {
if (state != State.DISCONNECTED) {
throw IllegalStateException("Room.connect attempted while room is not disconnected!")
}
val roomOptions: RoomOptions
stateLock.withLock {
if (state != State.DISCONNECTED) {
throw IllegalStateException("Room.connect attempted while room is not disconnected!")
}
if (::coroutineScope.isInitialized) {
val job = coroutineScope.coroutineContext.job
coroutineScope.cancel()
job.join()
}
state = State.CONNECTING
connectOptions = options
coroutineScope = CoroutineScope(defaultDispatcher + SupervisorJob())
val roomOptions = getCurrentRoomOptions()
roomOptions = getCurrentRoomOptions()
// Setup local participant.
localParticipant.reinitialize()
... ... @@ -319,9 +340,6 @@ constructor(
}
}
state = State.CONNECTING
connectOptions = options
if (roomOptions.e2eeOptions != null) {
e2eeManager = e2EEManagerFactory.create(roomOptions.e2eeOptions.keyProvider).apply {
setup(this@Room) { event ->
... ... @@ -331,7 +349,14 @@ constructor(
}
}
}
}
// Use an empty coroutineExceptionHandler since we want to
// rethrow all throwables from the connect job.
val emptyCoroutineExceptionHandler = CoroutineExceptionHandler { _, _ -> }
val connectJob = coroutineScope.launch(
ioDispatcher + emptyCoroutineExceptionHandler,
) {
engine.join(url, token, options, roomOptions)
val cm = context.getSystemService(Context.CONNECTIVITY_SERVICE) as ConnectivityManager
val networkRequest = NetworkRequest.Builder()
... ... @@ -339,16 +364,35 @@ constructor(
.build()
cm.registerNetworkCallback(networkRequest, networkCallback)
ensureActive()
if (options.audio) {
val audioTrack = localParticipant.createAudioTrack()
localParticipant.publishAudioTrack(audioTrack)
}
ensureActive()
if (options.video) {
val videoTrack = localParticipant.createVideoTrack()
localParticipant.publishVideoTrack(videoTrack)
}
}
val outerHandler = coroutineContext.job.invokeOnCompletion { cause ->
// Cancel connect job if invoking coroutine is cancelled.
if (cause is CancellationException) {
connectJob.cancel(cause)
}
}
var error: Throwable? = null
connectJob.invokeOnCompletion { cause ->
outerHandler.dispose()
error = cause
}
connectJob.join()
error?.let { throw it }
}
/**
* Disconnect from the room.
*/
... ... @@ -592,28 +636,15 @@ constructor(
engine.reconnect()
}
/**
 * Removes all participants and tracks from the room.
 *
 * Cleans up and drops the E2EE manager, cleans up the local participant,
 * runs the disconnect handling for every remote participant, and finally
 * resets [sid], [metadata], [name], [isRecording] and the sid-to-identity map.
 */
private fun cleanupRoom() {
    e2eeManager?.cleanUp()
    e2eeManager = null

    localParticipant.cleanup()

    // Walk a copy of the keys to avoid concurrent modifications while
    // participants are being disconnected.
    val remoteSids = remoteParticipants.keys.toMutableSet()
    for (remoteSid in remoteSids) {
        handleParticipantDisconnect(remoteSid)
    }

    sid = null
    metadata = null
    name = null
    isRecording = false
    sidToIdentity.clear()
}
private fun handleDisconnect(reason: DisconnectReason) {
if (state == State.DISCONNECTED) {
return
}
runBlocking {
stateLock.withLock {
if (state == State.DISCONNECTED) {
return@runBlocking
}
try {
val cm = context.getSystemService(Context.CONNECTIVITY_SERVICE) as ConnectivityManager
cm.unregisterNetworkCallback(networkCallback)
... ... @@ -628,11 +659,28 @@ constructor(
localParticipant.dispose()
// Ensure all observers see the disconnected before closing scope.
runBlocking {
eventBus.postEvent(RoomEvent.Disconnected(this@Room, null, reason), coroutineScope).join()
}
coroutineScope.cancel()
}
}
}
/**
 * Removes all participants and tracks from the room.
 *
 * Cleans up and drops the E2EE manager, cleans up the local participant,
 * runs the disconnect handling for every remote participant, and finally
 * resets [sid], [metadata], [name], [isRecording] and the sid-to-identity map.
 */
private fun cleanupRoom() {
    e2eeManager?.cleanUp()
    e2eeManager = null

    localParticipant.cleanup()

    // Walk a copy of the keys to avoid concurrent modifications while
    // participants are being disconnected.
    val remoteSids = remoteParticipants.keys.toMutableSet()
    for (remoteSid in remoteSids) {
        handleParticipantDisconnect(remoteSid)
    }

    sid = null
    metadata = null
    name = null
    isRecording = false
    sidToIdentity.clear()
}
private fun sendSyncState() {
// Whether we're sending subscribed tracks or tracks to unsubscribe.
... ...
/*
* Copyright 2024 LiveKit, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.livekit.android.util
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
/**
 * Runs [check] once before acquiring this mutex and once more after acquiring it,
 * then runs [action] while the mutex is still held and returns its result.
 *
 * [check] signals "do not proceed" by not returning normally: either by throwing,
 * or (because this function is inline) by returning from the enclosing function.
 *
 * @param check validation executed both outside and inside the lock.
 * @param action work executed under the lock once both checks have passed.
 * @return the value produced by [action].
 */
suspend inline fun <T> Mutex.withCheckLock(check: () -> Unit, action: () -> T): T {
    // First pass without the lock, so callers that fail the check never wait on the mutex.
    check()

    val result = withLock {
        // Second pass under the lock: state may have changed while waiting to acquire it.
        check()
        action()
    }
    return result
}
... ...
/*
* Copyright 2023 LiveKit, Inc.
* Copyright 2023-2024 LiveKit, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
... ... @@ -18,6 +18,7 @@ package io.livekit.android.webrtc.peerconnection
import androidx.annotation.VisibleForTesting
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.asCoroutineDispatcher
import kotlinx.coroutines.async
import kotlinx.coroutines.coroutineScope
... ... @@ -41,7 +42,7 @@ private val threadFactory = object : ThreadFactory {
}
}
// var only for testing purposes, do not alter!
// var only for testing purposes, do not alter in production!
private var executor = Executors.newSingleThreadExecutor(threadFactory)
private var rtcDispatcher: CoroutineDispatcher = executor.asCoroutineDispatcher()
... ... @@ -82,12 +83,12 @@ fun <T> executeBlockingOnRTCThread(action: () -> T): T {
* is generally not thread safe, so all actions relating to
* peer connection objects should go through the RTC thread.
*/
// NOTE(review): this listing shows the signature and each `action` call twice:
// first with a receiver-less `suspend () -> T` action, then with a
// `suspend CoroutineScope.() -> T` action invoked as `this.action()`.
suspend fun <T> launchBlockingOnRTCThread(action: suspend () -> T): T = coroutineScope {
suspend fun <T> launchBlockingOnRTCThread(action: suspend CoroutineScope.() -> T): T = coroutineScope {
// If the current thread's name starts with EXECUTOR_THREADNAME_PREFIX, run the
// action directly on this thread instead of dispatching it again.
return@coroutineScope if (Thread.currentThread().name.startsWith(EXECUTOR_THREADNAME_PREFIX)) {
action()
this.action()
} else {
// Otherwise start the action with async on rtcDispatcher and suspend
// until its result is available.
async(rtcDispatcher) {
action()
this.action()
}.await()
}
}
... ...