davidliu
Committed by GitHub

Force webrtc method calls onto a single thread (#342)

* Force all rtc related calls onto a single dedicated thread

* Fix test issues

* clean up long lock

* Move any callbacks from peerconnection api into local RTC thread

* Spotless
正在显示 18 个修改的文件 包含 405 行增加164 行删除
... ... @@ -18,10 +18,10 @@ package io.livekit.android.room
import android.javax.sdp.MediaDescription
import android.javax.sdp.SdpFactory
import androidx.annotation.VisibleForTesting
import dagger.assisted.Assisted
import dagger.assisted.AssistedFactory
import dagger.assisted.AssistedInject
import io.livekit.android.coroutines.withReentrantLock
import io.livekit.android.dagger.InjectionNames
import io.livekit.android.room.util.*
import io.livekit.android.util.Either
... ... @@ -34,11 +34,12 @@ import io.livekit.android.webrtc.getFmtps
import io.livekit.android.webrtc.getMsid
import io.livekit.android.webrtc.getRtps
import io.livekit.android.webrtc.isConnected
import io.livekit.android.webrtc.peerconnection.executeBlockingOnRTCThread
import io.livekit.android.webrtc.peerconnection.launchBlockingOnRTCThread
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.sync.Mutex
import org.webrtc.*
import org.webrtc.PeerConnection.RTCConfiguration
import org.webrtc.PeerConnection.SignalingState
... ... @@ -64,7 +65,9 @@ constructor(
private val sdpFactory: SdpFactory,
) {
private val coroutineScope = CoroutineScope(ioDispatcher + SupervisorJob())
private val peerConnection: PeerConnection = connectionFactory.createPeerConnection(
@VisibleForTesting
internal val peerConnection: PeerConnection = connectionFactory.createPeerConnection(
config,
pcObserver,
) ?: throw IllegalStateException("peer connection creation failed?")
... ... @@ -73,8 +76,6 @@ constructor(
private var renegotiate = false
private val mutex = Mutex()
private var trackBitrates = mutableMapOf<TrackBitrateInfoKey, TrackBitrateInfo>()
private var isClosed = AtomicBoolean(false)
... ... @@ -83,25 +84,23 @@ constructor(
}
fun addIceCandidate(candidate: IceCandidate) {
runBlocking {
withNotClosedLock {
if (peerConnection.remoteDescription != null && !restartingIce) {
peerConnection.addIceCandidate(candidate)
} else {
pendingCandidates.add(candidate)
}
executeRTCIfNotClosed {
if (peerConnection.remoteDescription != null && !restartingIce) {
peerConnection.addIceCandidate(candidate)
} else {
pendingCandidates.add(candidate)
}
}
}
suspend fun <T> withPeerConnection(action: suspend PeerConnection.() -> T): T? {
return withNotClosedLock {
return launchRTCIfNotClosed {
action(peerConnection)
}
}
suspend fun setRemoteDescription(sd: SessionDescription): Either<Unit, String?> {
val result = withNotClosedLock {
val result = launchRTCIfNotClosed {
val result = peerConnection.setRemoteDescription(sd)
if (result is Either.Left) {
pendingCandidates.forEach { pending ->
... ... @@ -137,7 +136,7 @@ constructor(
var finalSdp: SessionDescription? = null
// TODO: This is a potentially long lock hold. May need to break up.
withNotClosedLock {
launchRTCIfNotClosed {
val iceRestart =
constraints.findConstraint(MediaConstraintKeys.ICE_RESTART) == MediaConstraintKeys.TRUE
if (iceRestart) {
... ... @@ -155,7 +154,7 @@ constructor(
peerConnection.setRemoteDescription(curSd)
} else {
renegotiate = true
return@withNotClosedLock
return@launchRTCIfNotClosed
}
}
... ... @@ -164,10 +163,13 @@ constructor(
is Either.Left -> outcome.value
is Either.Right -> {
LKLog.d { "error creating offer: ${outcome.value}" }
return@withNotClosedLock
return@launchRTCIfNotClosed
}
}
if (isClosed()) {
return@launchRTCIfNotClosed
}
// munge sdp
val sdpDescription = sdpFactory.createSessionDescription(sdpOffer.description)
... ... @@ -195,11 +197,14 @@ constructor(
LKLog.v { "sdp type: ${sdp.type}\ndescription:\n${sdp.description}" }
LKLog.v { "munged sdp type: ${mungedSdp.type}\ndescription:\n${mungedSdp.description}" }
val mungedResult = if (remote) {
peerConnection.setRemoteDescription(mungedSdp)
} else {
peerConnection.setLocalDescription(mungedSdp)
}
val mungedResult = launchRTCIfNotClosed {
if (remote) {
peerConnection.setRemoteDescription(mungedSdp)
} else {
peerConnection.setLocalDescription(mungedSdp)
}
} ?: Either.Right("PCT closed")
val mungedErrorMessage = when (mungedResult) {
is Either.Left -> {
... ... @@ -224,11 +229,13 @@ constructor(
}
LKLog.w { "error: $mungedErrorMessage" }
val result = if (remote) {
peerConnection.setRemoteDescription(sdp)
} else {
peerConnection.setLocalDescription(sdp)
}
val result = launchRTCIfNotClosed {
if (remote) {
peerConnection.setRemoteDescription(sdp)
} else {
peerConnection.setLocalDescription(sdp)
}
} ?: Either.Right("PCT closed")
if (result is Either.Right) {
val errorMessage = if (result.value.isNullOrBlank()) {
... ... @@ -261,19 +268,15 @@ constructor(
}
suspend fun close() {
withNotClosedLock {
launchRTCIfNotClosed {
isClosed.set(true)
peerConnection.close()
// TODO: properly dispose of peer connection
peerConnection.dispose()
}
}
fun updateRTCConfig(config: RTCConfiguration) {
runBlocking {
withNotClosedLock {
peerConnection.setConfiguration(config)
}
executeRTCIfNotClosed {
peerConnection.setConfiguration(config)
}
}
... ... @@ -286,40 +289,56 @@ constructor(
}
suspend fun isConnected(): Boolean {
return withNotClosedLock {
return launchRTCIfNotClosed {
peerConnection.isConnected()
} ?: false
}
suspend fun iceConnectionState(): PeerConnection.IceConnectionState {
return withNotClosedLock {
return launchRTCIfNotClosed {
peerConnection.iceConnectionState()
} ?: PeerConnection.IceConnectionState.CLOSED
}
suspend fun connectionState(): PeerConnection.PeerConnectionState {
return withNotClosedLock {
return launchRTCIfNotClosed {
peerConnection.connectionState()
} ?: PeerConnection.PeerConnectionState.CLOSED
}
suspend fun signalingState(): SignalingState {
return withNotClosedLock {
return launchRTCIfNotClosed {
peerConnection.signalingState()
} ?: SignalingState.CLOSED
}
@OptIn(ExperimentalContracts::class)
private suspend inline fun <T> withNotClosedLock(crossinline action: suspend () -> T): T? {
private suspend inline fun <T> launchRTCIfNotClosed(noinline action: suspend () -> T): T? {
contract { callsInPlace(action, InvocationKind.AT_MOST_ONCE) }
if (isClosed()) {
return null
}
return mutex.withReentrantLock {
if (isClosed()) {
return@withReentrantLock null
return launchBlockingOnRTCThread {
return@launchBlockingOnRTCThread if (isClosed()) {
null
} else {
action()
}
}
}
@OptIn(ExperimentalContracts::class)
private fun <T> executeRTCIfNotClosed(action: () -> T): T? {
contract { callsInPlace(action, InvocationKind.AT_MOST_ONCE) }
if (isClosed()) {
return null
}
return executeBlockingOnRTCThread {
return@executeBlockingOnRTCThread if (isClosed()) {
null
} else {
action()
}
return@withReentrantLock action()
}
}
... ...
... ... @@ -17,8 +17,16 @@
package io.livekit.android.room
import io.livekit.android.util.LKLog
import io.livekit.android.webrtc.peerconnection.executeOnRTCThread
import livekit.LivekitRtc
import org.webrtc.*
import org.webrtc.CandidatePairChangeEvent
import org.webrtc.DataChannel
import org.webrtc.IceCandidate
import org.webrtc.MediaStream
import org.webrtc.PeerConnection
import org.webrtc.RtpReceiver
import org.webrtc.RtpTransceiver
import org.webrtc.SessionDescription
/**
* @suppress
... ... @@ -31,13 +39,17 @@ class PublisherTransportObserver(
var connectionChangeListener: ((newState: PeerConnection.PeerConnectionState) -> Unit)? = null
override fun onIceCandidate(iceCandidate: IceCandidate?) {
val candidate = iceCandidate ?: return
LKLog.v { "onIceCandidate: $candidate" }
client.sendCandidate(candidate, target = LivekitRtc.SignalTarget.PUBLISHER)
executeOnRTCThread {
val candidate = iceCandidate ?: return@executeOnRTCThread
LKLog.v { "onIceCandidate: $candidate" }
client.sendCandidate(candidate, target = LivekitRtc.SignalTarget.PUBLISHER)
}
}
override fun onRenegotiationNeeded() {
engine.negotiatePublisher()
executeOnRTCThread {
engine.negotiatePublisher()
}
}
override fun onIceConnectionChange(newState: PeerConnection.IceConnectionState?) {
... ... @@ -45,15 +57,19 @@ class PublisherTransportObserver(
}
override fun onOffer(sd: SessionDescription) {
client.sendOffer(sd)
executeOnRTCThread {
client.sendOffer(sd)
}
}
override fun onStandardizedIceConnectionChange(newState: PeerConnection.IceConnectionState?) {
}
override fun onConnectionChange(newState: PeerConnection.PeerConnectionState) {
LKLog.v { "onConnection new state: $newState" }
connectionChangeListener?.invoke(newState)
executeOnRTCThread {
LKLog.v { "onConnection new state: $newState" }
connectionChangeListener?.invoke(newState)
}
}
override fun onSelectedCandidatePairChanged(event: CandidatePairChangeEvent?) {
... ...
... ... @@ -32,10 +32,12 @@ import io.livekit.android.room.util.setLocalDescription
import io.livekit.android.util.CloseableCoroutineScope
import io.livekit.android.util.Either
import io.livekit.android.util.LKLog
import io.livekit.android.util.nullSafe
import io.livekit.android.webrtc.RTCStatsGetter
import io.livekit.android.webrtc.copy
import io.livekit.android.webrtc.isConnected
import io.livekit.android.webrtc.isDisconnected
import io.livekit.android.webrtc.peerconnection.executeBlockingOnRTCThread
import io.livekit.android.webrtc.toProtoSessionDescription
import kotlinx.coroutines.*
import livekit.LivekitModels
... ... @@ -316,27 +318,29 @@ internal constructor(
}
private fun closeResources(reason: String) {
publisherObserver.connectionChangeListener = null
subscriberObserver.connectionChangeListener = null
publisher?.closeBlocking()
publisher = null
subscriber?.closeBlocking()
subscriber = null
fun DataChannel?.completeDispose() {
this?.unregisterObserver()
this?.close()
this?.dispose()
executeBlockingOnRTCThread {
publisherObserver.connectionChangeListener = null
subscriberObserver.connectionChangeListener = null
publisher?.closeBlocking()
publisher = null
subscriber?.closeBlocking()
subscriber = null
fun DataChannel?.completeDispose() {
this?.unregisterObserver()
this?.close()
this?.dispose()
}
reliableDataChannel?.completeDispose()
reliableDataChannel = null
reliableDataChannelSub?.completeDispose()
reliableDataChannelSub = null
lossyDataChannel?.completeDispose()
lossyDataChannel = null
lossyDataChannelSub?.completeDispose()
lossyDataChannelSub = null
isSubscriberPrimary = false
}
reliableDataChannel?.completeDispose()
reliableDataChannel = null
reliableDataChannelSub?.completeDispose()
reliableDataChannelSub = null
lossyDataChannel?.completeDispose()
lossyDataChannel = null
lossyDataChannelSub?.completeDispose()
lossyDataChannelSub = null
isSubscriberPrimary = false
client.close(reason = reason)
}
... ... @@ -712,7 +716,7 @@ internal constructor(
LKLog.v { "received server answer: ${sessionDescription.type}, $signalingState" }
coroutineScope.launch {
LKLog.i { sessionDescription.toString() }
when (val outcome = publisher?.setRemoteDescription(sessionDescription)) {
when (val outcome = publisher?.setRemoteDescription(sessionDescription).nullSafe()) {
is Either.Left -> {
// do nothing.
}
... ... @@ -720,10 +724,6 @@ internal constructor(
is Either.Right -> {
LKLog.e { "error setting remote description for answer: ${outcome.value} " }
}
else -> {
LKLog.w { "publisher is null, can't set remote description." }
}
}
}
}
... ... @@ -732,59 +732,49 @@ internal constructor(
val signalingState = runBlocking { publisher?.signalingState() }
LKLog.v { "received server offer: ${sessionDescription.type}, $signalingState" }
coroutineScope.launch {
// TODO: This is a potentially very long lock hold. May need to break up.
val answer = subscriber?.withPeerConnection {
run {
when (
val outcome =
subscriber?.setRemoteDescription(sessionDescription)
) {
is Either.Right -> {
LKLog.e { "error setting remote description for answer: ${outcome.value} " }
return@withPeerConnection null
}
else -> {}
run {
when (val outcome = subscriber?.setRemoteDescription(sessionDescription).nullSafe()) {
is Either.Right -> {
LKLog.e { "error setting remote description for answer: ${outcome.value} " }
return@launch
}
}
if (isClosed) {
return@withPeerConnection null
else -> {}
}
}
val answer = run {
when (val outcome = createAnswer(MediaConstraints())) {
is Either.Left -> outcome.value
is Either.Right -> {
LKLog.e { "error creating answer: ${outcome.value}" }
return@withPeerConnection null
}
}
}
if (isClosed) {
return@launch
}
if (isClosed) {
return@withPeerConnection null
val answer = run {
when (val outcome = subscriber?.withPeerConnection { createAnswer(MediaConstraints()) }.nullSafe()) {
is Either.Left -> outcome.value
is Either.Right -> {
LKLog.e { "error creating answer: ${outcome.value}" }
return@launch
}
}
}
run<Unit> {
when (val outcome = setLocalDescription(answer)) {
is Either.Right -> {
LKLog.e { "error setting local description for answer: ${outcome.value}" }
return@withPeerConnection null
}
if (isClosed) {
return@launch
}
else -> {}
run<Unit> {
when (val outcome = subscriber?.withPeerConnection { setLocalDescription(answer) }.nullSafe()) {
is Either.Left -> Unit
is Either.Right -> {
LKLog.e { "error setting local description for answer: ${outcome.value}" }
return@launch
}
}
if (isClosed) {
return@withPeerConnection null
}
return@withPeerConnection answer
}
answer?.let {
client.sendAnswer(it)
if (isClosed) {
return@launch
}
client.sendAnswer(answer)
}
}
... ... @@ -1018,12 +1008,12 @@ internal constructor(
}
@VisibleForTesting
internal suspend fun getPublisherPeerConnection() =
publisher?.withPeerConnection { this }!!
internal fun getPublisherPeerConnection() =
publisher!!.peerConnection
@VisibleForTesting
internal suspend fun getSubscriberPeerConnection() =
subscriber?.withPeerConnection { this }!!
internal fun getSubscriberPeerConnection() =
subscriber!!.peerConnection
}
/**
... ...
... ... @@ -560,8 +560,8 @@ constructor(
}
state = State.DISCONNECTED
engine.close()
cleanupRoom()
engine.close()
listener?.onDisconnect(this, null)
listener = null
... ...
... ... @@ -17,6 +17,7 @@
package io.livekit.android.room
import io.livekit.android.util.LKLog
import io.livekit.android.webrtc.peerconnection.executeOnRTCThread
import livekit.LivekitRtc
import org.webrtc.CandidatePairChangeEvent
import org.webrtc.DataChannel
... ... @@ -39,14 +40,18 @@ class SubscriberTransportObserver(
var connectionChangeListener: ((PeerConnection.PeerConnectionState) -> Unit)? = null
override fun onIceCandidate(candidate: IceCandidate) {
LKLog.v { "onIceCandidate: $candidate" }
client.sendCandidate(candidate, LivekitRtc.SignalTarget.SUBSCRIBER)
executeOnRTCThread {
LKLog.v { "onIceCandidate: $candidate" }
client.sendCandidate(candidate, LivekitRtc.SignalTarget.SUBSCRIBER)
}
}
override fun onAddTrack(receiver: RtpReceiver, streams: Array<out MediaStream>) {
val track = receiver.track() ?: return
LKLog.v { "onAddTrack: ${track.kind()}, ${track.id()}, ${streams.fold("") { sum, it -> "$sum, $it" }}" }
engine.listener?.onAddTrack(receiver, track, streams)
executeOnRTCThread {
val track = receiver.track() ?: return@executeOnRTCThread
LKLog.v { "onAddTrack: ${track.kind()}, ${track.id()}, ${streams.fold("") { sum, it -> "$sum, $it" }}" }
engine.listener?.onAddTrack(receiver, track, streams)
}
}
override fun onTrack(transceiver: RtpTransceiver) {
... ... @@ -58,15 +63,19 @@ class SubscriberTransportObserver(
}
override fun onDataChannel(channel: DataChannel) {
dataChannelListener?.invoke(channel)
executeOnRTCThread {
dataChannelListener?.invoke(channel)
}
}
override fun onStandardizedIceConnectionChange(newState: PeerConnection.IceConnectionState?) {
}
override fun onConnectionChange(newState: PeerConnection.PeerConnectionState) {
LKLog.v { "onConnectionChange new state: $newState" }
connectionChangeListener?.invoke(newState)
executeOnRTCThread {
LKLog.v { "onConnectionChange new state: $newState" }
connectionChangeListener?.invoke(newState)
}
}
override fun onSelectedCandidatePairChanged(event: CandidatePairChangeEvent?) {
... ...
... ... @@ -182,7 +182,7 @@ class RemoteParticipant(
if (track != null) {
try {
track.stop()
} catch (e: IllegalStateException) {
} catch (e: Exception) {
// track may already be disposed, ignore.
}
internalListener?.onTrackUnsubscribed(track, publication, this)
... ...
... ... @@ -20,6 +20,7 @@ import android.Manifest
import android.content.Context
import android.content.pm.PackageManager
import androidx.core.content.ContextCompat
import io.livekit.android.webrtc.peerconnection.executeBlockingOnRTCThread
import org.webrtc.MediaConstraints
import org.webrtc.PeerConnectionFactory
import org.webrtc.RtpSender
... ... @@ -36,9 +37,9 @@ class LocalAudioTrack(
mediaTrack: org.webrtc.AudioTrack
) : AudioTrack(name, mediaTrack) {
var enabled: Boolean
get() = rtcTrack.enabled()
get() = executeBlockingOnRTCThread { rtcTrack.enabled() }
set(value) {
rtcTrack.setEnabled(value)
executeBlockingOnRTCThread { rtcTrack.setEnabled(value) }
}
internal var transceiver: RtpTransceiver? = null
... ...
... ... @@ -16,6 +16,7 @@
package io.livekit.android.room.track
import io.livekit.android.webrtc.peerconnection.executeBlockingOnRTCThread
import org.webrtc.AudioTrack
import org.webrtc.AudioTrackSink
import org.webrtc.RtpReceiver
... ... @@ -23,7 +24,7 @@ import org.webrtc.RtpReceiver
class RemoteAudioTrack(
name: String,
rtcTrack: AudioTrack,
internal val receiver: RtpReceiver
internal val receiver: RtpReceiver,
) : io.livekit.android.room.track.AudioTrack(name, rtcTrack) {
/**
... ... @@ -35,13 +36,17 @@ class RemoteAudioTrack(
* to use the data after this function returns.
*/
fun addSink(sink: AudioTrackSink) {
rtcTrack.addSink(sink)
executeBlockingOnRTCThread {
rtcTrack.addSink(sink)
}
}
/**
* Removes a previously added sink.
*/
fun removeSink(sink: AudioTrackSink) {
rtcTrack.removeSink(sink)
executeBlockingOnRTCThread {
rtcTrack.removeSink(sink)
}
}
}
... ...
... ... @@ -21,6 +21,7 @@ import io.livekit.android.events.TrackEvent
import io.livekit.android.util.flowDelegate
import io.livekit.android.webrtc.RTCStatsGetter
import io.livekit.android.webrtc.getStats
import io.livekit.android.webrtc.peerconnection.executeBlockingOnRTCThread
import livekit.LivekitModels
import livekit.LivekitRtc
import org.webrtc.MediaStreamTrack
... ... @@ -149,15 +150,21 @@ abstract class Track(
data class Dimensions(val width: Int, val height: Int)
open fun start() {
rtcTrack.setEnabled(true)
executeBlockingOnRTCThread {
rtcTrack.setEnabled(true)
}
}
open fun stop() {
rtcTrack.setEnabled(false)
executeBlockingOnRTCThread {
rtcTrack.setEnabled(false)
}
}
open fun dispose() {
rtcTrack.dispose()
executeBlockingOnRTCThread {
rtcTrack.dispose()
}
}
}
... ...
... ... @@ -16,6 +16,7 @@
package io.livekit.android.room.track
import io.livekit.android.webrtc.peerconnection.executeBlockingOnRTCThread
import org.webrtc.VideoSink
import org.webrtc.VideoTrack
... ... @@ -30,20 +31,26 @@ abstract class VideoTrack(name: String, override val rtcTrack: VideoTrack) :
}
open fun addRenderer(renderer: VideoSink) {
sinks.add(renderer)
rtcTrack.addSink(renderer)
executeBlockingOnRTCThread {
sinks.add(renderer)
rtcTrack.addSink(renderer)
}
}
open fun removeRenderer(renderer: VideoSink) {
rtcTrack.removeSink(renderer)
sinks.remove(renderer)
executeBlockingOnRTCThread {
rtcTrack.removeSink(renderer)
sinks.remove(renderer)
}
}
override fun stop() {
for (sink in sinks) {
rtcTrack.removeSink(sink)
executeBlockingOnRTCThread {
for (sink in sinks) {
rtcTrack.removeSink(sink)
sinks.clear()
}
}
sinks.clear()
super.stop()
}
}
... ...
... ... @@ -20,3 +20,7 @@ sealed class Either<out A, out B> {
class Left<out A>(val value: A) : Either<A, Nothing>()
class Right<out B>(val value: B) : Either<Nothing, B>()
}
fun <A> Either<A, String?>?.nullSafe(): Either<A, String?> {
return this ?: Either.Right("null")
}
... ...
/*
* Copyright 2023 LiveKit, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.livekit.android.webrtc.peerconnection
import org.webrtc.PeerConnection
import org.webrtc.RtpReceiver
import org.webrtc.RtpSender
import org.webrtc.RtpTransceiver
/**
* Objects obtained through [PeerConnection] are transient,
* and should not be kept in memory. Calls to these methods
* dispose all existing objects in the tree and refresh with
* new updated objects:
*
* * [PeerConnection.getTransceivers]
* * [PeerConnection.getReceivers]
* * [PeerConnection.getSenders]
*
* For this reason, any object gotten through the PeerConnection
* should instead be looked up through the PeerConnection as needed.
*/
internal abstract class PeerConnectionResource<T>(val parentPeerConnection: PeerConnection) {
abstract fun get(): T?
}
internal class RtpTransceiverResource(parentPeerConnection: PeerConnection, private val senderId: String) : PeerConnectionResource<RtpTransceiver>(parentPeerConnection) {
override fun get() = executeBlockingOnRTCThread {
parentPeerConnection.transceivers.firstOrNull { t -> t.sender.id() == senderId }
}
}
internal class RtpReceiverResource(parentPeerConnection: PeerConnection, private val receiverId: String) : PeerConnectionResource<RtpReceiver>(parentPeerConnection) {
override fun get() = executeBlockingOnRTCThread {
parentPeerConnection.receivers.firstOrNull { r -> r.id() == receiverId }
}
}
internal class RtpSenderResource(parentPeerConnection: PeerConnection, private val senderId: String) : PeerConnectionResource<RtpSender>(parentPeerConnection) {
override fun get() = executeBlockingOnRTCThread {
parentPeerConnection.senders.firstOrNull { s -> s.id() == senderId }
}
}
... ...
/*
* Copyright 2023 LiveKit, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.livekit.android.webrtc.peerconnection
import androidx.annotation.VisibleForTesting
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.asCoroutineDispatcher
import kotlinx.coroutines.async
import kotlinx.coroutines.coroutineScope
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
import java.util.concurrent.ThreadFactory
import java.util.concurrent.atomic.AtomicInteger
// Executor thread is started once and is used for all
// peer connection API calls to ensure new peer connection factory is
// created on the same thread as previously destroyed factory.
private const val EXECUTOR_THREADNAME_PREFIX = "LK_RTC_THREAD"
private val threadFactory = object : ThreadFactory {
private val idGenerator = AtomicInteger(0)
override fun newThread(r: Runnable): Thread {
val thread = Thread(r)
thread.name = EXECUTOR_THREADNAME_PREFIX + "_" + idGenerator.incrementAndGet()
return thread
}
}
// var only for testing purposes, do not alter!
private var executor = Executors.newSingleThreadExecutor(threadFactory)
private var rtcDispatcher: CoroutineDispatcher = executor.asCoroutineDispatcher()
@VisibleForTesting
internal fun overrideExecutorAndDispatcher(executorService: ExecutorService, dispatcher: CoroutineDispatcher) {
executor = executorService
rtcDispatcher = dispatcher
}
/**
* Execute [action] on the RTC thread. The PeerConnection API
* is generally not thread safe, so all actions relating to
* peer connection objects should go through the RTC thread.
*/
fun <T> executeOnRTCThread(action: () -> T) {
if (Thread.currentThread().name.startsWith(EXECUTOR_THREADNAME_PREFIX)) {
action()
} else {
executor.submit(action)
}
}
/**
* Execute [action] synchronously on the RTC thread. The PeerConnection API
* is generally not thread safe, so all actions relating to
* peer connection objects should go through the RTC thread.
*/
fun <T> executeBlockingOnRTCThread(action: () -> T): T {
return if (Thread.currentThread().name.startsWith(EXECUTOR_THREADNAME_PREFIX)) {
action()
} else {
executor.submit(action).get()
}
}
/**
* Launch [action] synchronously on the RTC thread. The PeerConnection API
* is generally not thread safe, so all actions relating to
* peer connection objects should go through the RTC thread.
*/
suspend fun <T> launchBlockingOnRTCThread(action: suspend () -> T): T = coroutineScope {
return@coroutineScope if (Thread.currentThread().name.startsWith(EXECUTOR_THREADNAME_PREFIX)) {
action()
} else {
async(rtcDispatcher) {
action()
}.await()
}
}
... ...
... ... @@ -16,11 +16,14 @@
package io.livekit.android
import com.google.common.util.concurrent.MoreExecutors
import io.livekit.android.coroutines.TestCoroutineRule
import io.livekit.android.util.LoggingRule
import io.livekit.android.webrtc.peerconnection.overrideExecutorAndDispatcher
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.test.TestScope
import kotlinx.coroutines.test.runTest
import org.junit.Before
import org.junit.Rule
import org.mockito.junit.MockitoJUnit
... ... @@ -36,6 +39,14 @@ abstract class BaseTest {
@get:Rule
var coroutineRule = TestCoroutineRule()
@Before
fun setupRTCThread() {
overrideExecutorAndDispatcher(
executorService = MoreExecutors.newDirectExecutorService(),
dispatcher = coroutineRule.dispatcher,
)
}
@OptIn(ExperimentalCoroutinesApi::class)
fun runTest(testBody: suspend TestScope.() -> Unit) = coroutineRule.scope.runTest(testBody = testBody)
}
... ...
... ... @@ -24,6 +24,9 @@ class MockAudioStreamTrack(
var enabled: Boolean = true,
var state: State = State.LIVE,
) : AudioTrack(1L) {
var disposed = false
override fun id(): String = id
override fun kind(): String = kind
... ... @@ -40,6 +43,10 @@ class MockAudioStreamTrack(
}
override fun dispose() {
if (disposed) {
throw IllegalStateException("already disposed")
}
disposed = true
}
override fun setVolume(volume: Double) {
... ...
... ... @@ -24,6 +24,8 @@ class MockMediaStreamTrack(
var enabled: Boolean = true,
var state: State = State.LIVE,
) : MediaStreamTrack(1L) {
var disposed = false
override fun id(): String = id
override fun kind(): String = kind
... ... @@ -40,5 +42,9 @@ class MockMediaStreamTrack(
}
override fun dispose() {
if (disposed) {
throw IllegalStateException("already disposed")
}
disposed = true
}
}
... ...
... ... @@ -214,7 +214,8 @@ class MockPeerConnection(
IceConnectionState.NEW -> PeerConnectionState.NEW
IceConnectionState.CHECKING -> PeerConnectionState.CONNECTING
IceConnectionState.CONNECTED,
IceConnectionState.COMPLETED -> PeerConnectionState.CONNECTED
IceConnectionState.COMPLETED,
-> PeerConnectionState.CONNECTED
IceConnectionState.DISCONNECTED -> PeerConnectionState.DISCONNECTED
IceConnectionState.FAILED -> PeerConnectionState.FAILED
... ... @@ -242,7 +243,8 @@ class MockPeerConnection(
IceConnectionState.NEW,
IceConnectionState.CHECKING,
IceConnectionState.CONNECTED,
IceConnectionState.COMPLETED -> {
IceConnectionState.COMPLETED,
-> {
val currentOrdinal = iceConnectionState.ordinal
val newOrdinal = newState.ordinal
... ... @@ -258,7 +260,8 @@ class MockPeerConnection(
IceConnectionState.FAILED,
IceConnectionState.DISCONNECTED,
IceConnectionState.CLOSED -> {
IceConnectionState.CLOSED,
-> {
// jump to state directly.
iceConnectionState = newState
}
... ... @@ -278,6 +281,9 @@ class MockPeerConnection(
override fun dispose() {
iceConnectionState = IceConnectionState.CLOSED
closed = true
transceivers.forEach { t -> t.dispose() }
transceivers.clear()
}
override fun getNativePeerConnection(): Long = 0L
... ...
... ... @@ -28,21 +28,24 @@ import timber.log.Timber
*/
class LoggingRule : TestRule {
val logTree = object : Timber.DebugTree() {
override fun log(priority: Int, tag: String?, message: String, t: Throwable?) {
val priorityChar = when (priority) {
Log.VERBOSE -> "v"
Log.DEBUG -> "d"
Log.INFO -> "i"
Log.WARN -> "w"
Log.ERROR -> "e"
Log.ASSERT -> "a"
else -> "?"
}
companion object {
val logTree = object : Timber.DebugTree() {
override fun log(priority: Int, tag: String?, message: String, t: Throwable?) {
val priorityChar = when (priority) {
Log.VERBOSE -> "v"
Log.DEBUG -> "d"
Log.INFO -> "i"
Log.WARN -> "w"
Log.ERROR -> "e"
Log.ASSERT -> "a"
else -> "?"
}
println("$priorityChar: $tag: $message")
if (t != null) {
println(t.toString())
println("$priorityChar: $tag: $message")
if (t != null) {
println(t.toString())
}
}
}
}
... ...