davidliu
Committed by GitHub

Fix memory leak caused by disconnecting before connect finished (#386)

* State locking for Room and RTC engine around critical spots

* Cancel connect job if invoking coroutine is cancelled

* cleanup

* Clean up test logs

* revert stress test changes to sample apps
... ... @@ -144,7 +144,7 @@ constructor(
restartingIce = true
}
if (this.peerConnection.signalingState() == SignalingState.HAVE_LOCAL_OFFER) {
if (peerConnection.signalingState() == SignalingState.HAVE_LOCAL_OFFER) {
// we're waiting for the peer to accept our offer, so we'll just wait
// the only exception to this is when ICE restart is needed
val curSd = peerConnection.remoteDescription
... ... @@ -313,7 +313,7 @@ constructor(
}
@OptIn(ExperimentalContracts::class)
private suspend inline fun <T> launchRTCIfNotClosed(noinline action: suspend () -> T): T? {
private suspend inline fun <T> launchRTCIfNotClosed(noinline action: suspend CoroutineScope.() -> T): T? {
contract { callsInPlace(action, InvocationKind.AT_MOST_ONCE) }
if (isClosed()) {
return null
... ...
... ... @@ -35,13 +35,17 @@ import io.livekit.android.util.FlowObservable
import io.livekit.android.util.LKLog
import io.livekit.android.util.flowDelegate
import io.livekit.android.util.nullSafe
import io.livekit.android.util.withCheckLock
import io.livekit.android.webrtc.RTCStatsGetter
import io.livekit.android.webrtc.copy
import io.livekit.android.webrtc.isConnected
import io.livekit.android.webrtc.isDisconnected
import io.livekit.android.webrtc.peerconnection.executeBlockingOnRTCThread
import io.livekit.android.webrtc.peerconnection.launchBlockingOnRTCThread
import io.livekit.android.webrtc.toProtoSessionDescription
import kotlinx.coroutines.*
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import livekit.LivekitModels
import livekit.LivekitRtc
import livekit.LivekitRtc.JoinResponse
... ... @@ -134,6 +138,12 @@ internal constructor(
private var coroutineScope = CloseableCoroutineScope(SupervisorJob() + ioDispatcher)
/**
* Note: If this lock is ever used in conjunction with the RTC thread,
* this must be grabbed on the RTC thread to prevent deadlocks.
*/
private var configurationLock = Mutex()
init {
client.listener = this
}
... ... @@ -158,8 +168,10 @@ internal constructor(
token: String,
options: ConnectOptions,
roomOptions: RoomOptions,
): JoinResponse {
): JoinResponse = coroutineScope {
val joinResponse = client.join(url, token, options, roomOptions)
ensureActive()
listener?.onJoinResponse(joinResponse)
isClosed = false
listener?.onSignalConnected(false)
... ... @@ -169,93 +181,103 @@ internal constructor(
configure(joinResponse, options)
// create offer
if (!this.isSubscriberPrimary) {
if (!isSubscriberPrimary) {
negotiatePublisher()
}
client.onReadyForResponses()
return joinResponse
return@coroutineScope joinResponse
}
private suspend fun configure(joinResponse: JoinResponse, connectOptions: ConnectOptions) {
if (publisher != null && subscriber != null) {
// already configured
return
}
launchBlockingOnRTCThread {
configurationLock.withCheckLock(
{
ensureActive()
if (publisher != null && subscriber != null) {
// already configured
return@launchBlockingOnRTCThread
}
},
) {
participantSid = if (joinResponse.hasParticipant()) {
joinResponse.participant.sid
} else {
null
}
participantSid = if (joinResponse.hasParticipant()) {
joinResponse.participant.sid
} else {
null
}
// Setup peer connections
val rtcConfig = makeRTCConfig(Either.Left(joinResponse), connectOptions)
// Setup peer connections
val rtcConfig = makeRTCConfig(Either.Left(joinResponse), connectOptions)
publisher?.close()
publisher = pctFactory.create(
rtcConfig,
publisherObserver,
publisherObserver,
)
subscriber?.close()
subscriber = pctFactory.create(
rtcConfig,
subscriberObserver,
null,
)
publisher?.close()
publisher = pctFactory.create(
rtcConfig,
publisherObserver,
publisherObserver,
)
subscriber?.close()
subscriber = pctFactory.create(
rtcConfig,
subscriberObserver,
null,
)
val connectionStateListener: (PeerConnection.PeerConnectionState) -> Unit = { newState ->
LKLog.v { "onIceConnection new state: $newState" }
if (newState.isConnected()) {
connectionState = ConnectionState.CONNECTED
} else if (newState.isDisconnected()) {
connectionState = ConnectionState.DISCONNECTED
}
}
val connectionStateListener: (PeerConnection.PeerConnectionState) -> Unit = { newState ->
LKLog.v { "onIceConnection new state: $newState" }
if (newState.isConnected()) {
connectionState = ConnectionState.CONNECTED
} else if (newState.isDisconnected()) {
connectionState = ConnectionState.DISCONNECTED
}
}
if (joinResponse.subscriberPrimary) {
// in subscriber primary mode, server side opens sub data channels.
subscriberObserver.dataChannelListener = onDataChannel@{ dataChannel: DataChannel ->
when (dataChannel.label()) {
RELIABLE_DATA_CHANNEL_LABEL -> reliableDataChannelSub = dataChannel
LOSSY_DATA_CHANNEL_LABEL -> lossyDataChannelSub = dataChannel
else -> return@onDataChannel
}
dataChannel.registerObserver(DataChannelObserver(dataChannel))
}
if (joinResponse.subscriberPrimary) {
// in subscriber primary mode, server side opens sub data channels.
subscriberObserver.dataChannelListener = onDataChannel@{ dataChannel: DataChannel ->
when (dataChannel.label()) {
RELIABLE_DATA_CHANNEL_LABEL -> reliableDataChannelSub = dataChannel
LOSSY_DATA_CHANNEL_LABEL -> lossyDataChannelSub = dataChannel
else -> return@onDataChannel
subscriberObserver.connectionChangeListener = connectionStateListener
// Also reconnect on publisher disconnect
publisherObserver.connectionChangeListener = { newState ->
if (newState.isDisconnected()) {
reconnect()
}
}
} else {
publisherObserver.connectionChangeListener = connectionStateListener
}
dataChannel.registerObserver(DataChannelObserver(dataChannel))
}
subscriberObserver.connectionChangeListener = connectionStateListener
// Also reconnect on publisher disconnect
publisherObserver.connectionChangeListener = { newState ->
if (newState.isDisconnected()) {
reconnect()
ensureActive()
// data channels
val reliableInit = DataChannel.Init()
reliableInit.ordered = true
reliableDataChannel = publisher?.withPeerConnection {
createDataChannel(
RELIABLE_DATA_CHANNEL_LABEL,
reliableInit,
).also { dataChannel ->
dataChannel.registerObserver(DataChannelObserver(dataChannel))
}
}
}
} else {
publisherObserver.connectionChangeListener = connectionStateListener
}
// data channels
val reliableInit = DataChannel.Init()
reliableInit.ordered = true
reliableDataChannel = publisher?.withPeerConnection {
createDataChannel(
RELIABLE_DATA_CHANNEL_LABEL,
reliableInit,
).also { dataChannel ->
dataChannel.registerObserver(DataChannelObserver(dataChannel))
}
}
val lossyInit = DataChannel.Init()
lossyInit.ordered = true
lossyInit.maxRetransmits = 0
lossyDataChannel = publisher?.withPeerConnection {
createDataChannel(
LOSSY_DATA_CHANNEL_LABEL,
lossyInit,
).also { dataChannel ->
dataChannel.registerObserver(DataChannelObserver(dataChannel))
ensureActive()
val lossyInit = DataChannel.Init()
lossyInit.ordered = true
lossyInit.maxRetransmits = 0
lossyDataChannel = publisher?.withPeerConnection {
createDataChannel(
LOSSY_DATA_CHANNEL_LABEL,
lossyInit,
).also { dataChannel ->
dataChannel.registerObserver(DataChannelObserver(dataChannel))
}
}
}
}
}
... ... @@ -327,27 +349,32 @@ internal constructor(
private fun closeResources(reason: String) {
executeBlockingOnRTCThread {
publisherObserver.connectionChangeListener = null
subscriberObserver.connectionChangeListener = null
publisher?.closeBlocking()
publisher = null
subscriber?.closeBlocking()
subscriber = null
fun DataChannel?.completeDispose() {
this?.unregisterObserver()
this?.close()
this?.dispose()
runBlocking {
configurationLock.withLock {
publisherObserver.connectionChangeListener = null
subscriberObserver.connectionChangeListener = null
publisher?.closeBlocking()
publisher = null
subscriber?.closeBlocking()
subscriber = null
fun DataChannel?.completeDispose() {
this?.unregisterObserver()
this?.close()
this?.dispose()
}
reliableDataChannel?.completeDispose()
reliableDataChannel = null
reliableDataChannelSub?.completeDispose()
reliableDataChannelSub = null
lossyDataChannel?.completeDispose()
lossyDataChannel = null
lossyDataChannelSub?.completeDispose()
lossyDataChannelSub = null
isSubscriberPrimary = false
}
}
reliableDataChannel?.completeDispose()
reliableDataChannel = null
reliableDataChannelSub?.completeDispose()
reliableDataChannelSub = null
lossyDataChannel?.completeDispose()
lossyDataChannel = null
lossyDataChannelSub?.completeDispose()
lossyDataChannelSub = null
isSubscriberPrimary = false
}
client.close(reason = reason)
}
... ...
... ... @@ -49,6 +49,8 @@ import io.livekit.android.webrtc.getFilteredStats
import kotlinx.coroutines.*
import kotlinx.coroutines.flow.filterNotNull
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import kotlinx.serialization.Serializable
import livekit.LivekitModels
import livekit.LivekitRtc
... ... @@ -243,6 +245,8 @@ constructor(
private var hasLostConnectivity: Boolean = false
private var connectOptions: ConnectOptions = ConnectOptions()
private var stateLock = Mutex()
private fun getCurrentRoomOptions(): RoomOptions =
RoomOptions(
adaptiveStream = adaptiveStream,
... ... @@ -260,93 +264,133 @@ constructor(
* @param url
* @param token
* @param options
*
* @throws IllegalStateException when connect is attempted while the room is not disconnected.
* @throws Exception when connection fails
*/
@Throws(Exception::class)
suspend fun connect(url: String, token: String, options: ConnectOptions = ConnectOptions()) {
if (this::coroutineScope.isInitialized) {
coroutineScope.cancel()
suspend fun connect(url: String, token: String, options: ConnectOptions = ConnectOptions()) = coroutineScope {
if (state != State.DISCONNECTED) {
throw IllegalStateException("Room.connect attempted while room is not disconnected!")
}
coroutineScope = CoroutineScope(defaultDispatcher + SupervisorJob())
val roomOptions: RoomOptions
stateLock.withLock {
if (state != State.DISCONNECTED) {
throw IllegalStateException("Room.connect attempted while room is not disconnected!")
}
if (::coroutineScope.isInitialized) {
val job = coroutineScope.coroutineContext.job
coroutineScope.cancel()
job.join()
}
val roomOptions = getCurrentRoomOptions()
state = State.CONNECTING
connectOptions = options
// Setup local participant.
localParticipant.reinitialize()
coroutineScope.launch {
localParticipant.events.collect {
when (it) {
is ParticipantEvent.TrackPublished -> emitWhenConnected(
RoomEvent.TrackPublished(
room = this@Room,
publication = it.publication,
participant = it.participant,
),
)
coroutineScope = CoroutineScope(defaultDispatcher + SupervisorJob())
is ParticipantEvent.ParticipantPermissionsChanged -> emitWhenConnected(
RoomEvent.ParticipantPermissionsChanged(
room = this@Room,
participant = it.participant,
newPermissions = it.newPermissions,
oldPermissions = it.oldPermissions,
),
)
roomOptions = getCurrentRoomOptions()
is ParticipantEvent.MetadataChanged -> {
emitWhenConnected(
RoomEvent.ParticipantMetadataChanged(
this@Room,
it.participant,
it.prevMetadata,
// Setup local participant.
localParticipant.reinitialize()
coroutineScope.launch {
localParticipant.events.collect {
when (it) {
is ParticipantEvent.TrackPublished -> emitWhenConnected(
RoomEvent.TrackPublished(
room = this@Room,
publication = it.publication,
participant = it.participant,
),
)
}
is ParticipantEvent.NameChanged -> {
emitWhenConnected(
RoomEvent.ParticipantNameChanged(
this@Room,
it.participant,
it.name,
is ParticipantEvent.ParticipantPermissionsChanged -> emitWhenConnected(
RoomEvent.ParticipantPermissionsChanged(
room = this@Room,
participant = it.participant,
newPermissions = it.newPermissions,
oldPermissions = it.oldPermissions,
),
)
}
else -> {
// do nothing
is ParticipantEvent.MetadataChanged -> {
emitWhenConnected(
RoomEvent.ParticipantMetadataChanged(
this@Room,
it.participant,
it.prevMetadata,
),
)
}
is ParticipantEvent.NameChanged -> {
emitWhenConnected(
RoomEvent.ParticipantNameChanged(
this@Room,
it.participant,
it.name,
),
)
}
else -> {
// do nothing
}
}
}
}
}
state = State.CONNECTING
connectOptions = options
if (roomOptions.e2eeOptions != null) {
e2eeManager = e2EEManagerFactory.create(roomOptions.e2eeOptions.keyProvider).apply {
setup(this@Room) { event ->
coroutineScope.launch {
emitWhenConnected(event)
if (roomOptions.e2eeOptions != null) {
e2eeManager = e2EEManagerFactory.create(roomOptions.e2eeOptions.keyProvider).apply {
setup(this@Room) { event ->
coroutineScope.launch {
emitWhenConnected(event)
}
}
}
}
}
engine.join(url, token, options, roomOptions)
val cm = context.getSystemService(Context.CONNECTIVITY_SERVICE) as ConnectivityManager
val networkRequest = NetworkRequest.Builder()
.addCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET)
.build()
cm.registerNetworkCallback(networkRequest, networkCallback)
// Use an empty coroutineExceptionHandler since we want to
// rethrow all throwables from the connect job.
val emptyCoroutineExceptionHandler = CoroutineExceptionHandler { _, _ -> }
val connectJob = coroutineScope.launch(
ioDispatcher + emptyCoroutineExceptionHandler,
) {
engine.join(url, token, options, roomOptions)
val cm = context.getSystemService(Context.CONNECTIVITY_SERVICE) as ConnectivityManager
val networkRequest = NetworkRequest.Builder()
.addCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET)
.build()
cm.registerNetworkCallback(networkRequest, networkCallback)
ensureActive()
if (options.audio) {
val audioTrack = localParticipant.createAudioTrack()
localParticipant.publishAudioTrack(audioTrack)
}
ensureActive()
if (options.video) {
val videoTrack = localParticipant.createVideoTrack()
localParticipant.publishVideoTrack(videoTrack)
}
}
if (options.audio) {
val audioTrack = localParticipant.createAudioTrack()
localParticipant.publishAudioTrack(audioTrack)
val outerHandler = coroutineContext.job.invokeOnCompletion { cause ->
// Cancel connect job if invoking coroutine is cancelled.
if (cause is CancellationException) {
connectJob.cancel(cause)
}
}
if (options.video) {
val videoTrack = localParticipant.createVideoTrack()
localParticipant.publishVideoTrack(videoTrack)
var error: Throwable? = null
connectJob.invokeOnCompletion { cause ->
outerHandler.dispose()
error = cause
}
connectJob.join()
error?.let { throw it }
}
/**
... ... @@ -592,6 +636,35 @@ constructor(
engine.reconnect()
}
private fun handleDisconnect(reason: DisconnectReason) {
if (state == State.DISCONNECTED) {
return
}
runBlocking {
stateLock.withLock {
if (state == State.DISCONNECTED) {
return@runBlocking
}
try {
val cm = context.getSystemService(Context.CONNECTIVITY_SERVICE) as ConnectivityManager
cm.unregisterNetworkCallback(networkCallback)
} catch (e: IllegalArgumentException) {
// do nothing, may happen on older versions if attempting to unregister twice.
}
state = State.DISCONNECTED
cleanupRoom()
engine.close()
localParticipant.dispose()
// Ensure all observers see the disconnected before closing scope.
eventBus.postEvent(RoomEvent.Disconnected(this@Room, null, reason), coroutineScope).join()
coroutineScope.cancel()
}
}
}
/**
* Removes all participants and tracks from the room.
*/
... ... @@ -609,31 +682,6 @@ constructor(
sidToIdentity.clear()
}
private fun handleDisconnect(reason: DisconnectReason) {
if (state == State.DISCONNECTED) {
return
}
try {
val cm = context.getSystemService(Context.CONNECTIVITY_SERVICE) as ConnectivityManager
cm.unregisterNetworkCallback(networkCallback)
} catch (e: IllegalArgumentException) {
// do nothing, may happen on older versions if attempting to unregister twice.
}
state = State.DISCONNECTED
cleanupRoom()
engine.close()
localParticipant.dispose()
// Ensure all observers see the disconnected before closing scope.
runBlocking {
eventBus.postEvent(RoomEvent.Disconnected(this@Room, null, reason), coroutineScope).join()
}
coroutineScope.cancel()
}
private fun sendSyncState() {
// Whether we're sending subscribed tracks or tracks to unsubscribe.
val sendUnsub = connectOptions.autoSubscribe
... ...
/*
* Copyright 2024 LiveKit, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.livekit.android.util
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
/**
* Applies a double-checked lock before running [action].
*/
suspend inline fun <T> Mutex.withCheckLock(check: () -> Unit, action: () -> T): T {
check()
return withLock {
check()
action()
}
}
... ...
/*
* Copyright 2023 LiveKit, Inc.
* Copyright 2023-2024 LiveKit, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
... ... @@ -18,6 +18,7 @@ package io.livekit.android.webrtc.peerconnection
import androidx.annotation.VisibleForTesting
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.asCoroutineDispatcher
import kotlinx.coroutines.async
import kotlinx.coroutines.coroutineScope
... ... @@ -41,7 +42,7 @@ private val threadFactory = object : ThreadFactory {
}
}
// var only for testing purposes, do not alter!
// var only for testing purposes, do not alter in production!
private var executor = Executors.newSingleThreadExecutor(threadFactory)
private var rtcDispatcher: CoroutineDispatcher = executor.asCoroutineDispatcher()
... ... @@ -82,12 +83,12 @@ fun <T> executeBlockingOnRTCThread(action: () -> T): T {
* is generally not thread safe, so all actions relating to
* peer connection objects should go through the RTC thread.
*/
suspend fun <T> launchBlockingOnRTCThread(action: suspend () -> T): T = coroutineScope {
suspend fun <T> launchBlockingOnRTCThread(action: suspend CoroutineScope.() -> T): T = coroutineScope {
return@coroutineScope if (Thread.currentThread().name.startsWith(EXECUTOR_THREADNAME_PREFIX)) {
action()
this.action()
} else {
async(rtcDispatcher) {
action()
this.action()
}.await()
}
}
... ...