davidliu
Committed by GitHub

Race condition fixes (#315)

* Add more mutexes surrounding peerconnection usage

* proper locking for response flow job

* Fix race condition in CoroutineSdpObserver

* Lock down peer connection usage

* spotless apply

* Exclude ReentrantMutex from spotless

* Fix tests

* spotless apply
... ... @@ -53,14 +53,17 @@ subprojects {
// through git history (see "license" section below)
licenseHeaderFile rootProject.file("LicenseHeaderFile.txt")
removeUnusedImports()
toggleOffOn()
}
kotlin {
target("src/*/java/**/*.kt")
targetExclude("src/*/java/**/ReentrantMutex.kt") // Different license
ktlint("0.50.0")
.setEditorConfigPath("$rootDir/.editorconfig")
licenseHeaderFile(rootProject.file("LicenseHeaderFile.txt"))
.named('license')
endWithNewline()
toggleOffOn()
}
}
}
... ...
/*
This is free and unencumbered software released into the public domain.
Anyone is free to copy, modify, publish, use, compile, sell, or
distribute this software, either in source code form or as a compiled
binary, for any purpose, commercial or non-commercial, and by any
means.
In jurisdictions that recognize copyright laws, the author or authors
of this software dedicate any and all copyright interest in the
software to the public domain. We make this dedication for the benefit
of the public at large and to the detriment of our heirs and
successors. We intend this dedication to be an overt act of
relinquishment in perpetuity of all present and future rights to this
software under copyright law.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
OTHER DEALINGS IN THE SOFTWARE.
For more information, please refer to <https://unlicense.org>
Original is at https://gist.github.com/elizarov/9a48b9709ffd508909d34fab6786acfe
*/
package io.livekit.android.coroutines
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import kotlinx.coroutines.withContext
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.coroutineContext
internal suspend fun <T> Mutex.withReentrantLock(block: suspend () -> T): T {
val key = ReentrantMutexContextKey(this)
// call block directly when this mutex is already locked in the context
if (coroutineContext[key] != null) return block()
// otherwise add it to the context and lock the mutex
return withContext(ReentrantMutexContextElement(key)) {
withLock { block() }
}
}
internal class ReentrantMutexContextElement(
override val key: ReentrantMutexContextKey,
) : CoroutineContext.Element
internal data class ReentrantMutexContextKey(
val mutex: Mutex,
) : CoroutineContext.Key<ReentrantMutexContextElement>
... ...
... ... @@ -21,6 +21,7 @@ import android.javax.sdp.SdpFactory
import dagger.assisted.Assisted
import dagger.assisted.AssistedFactory
import dagger.assisted.AssistedInject
import io.livekit.android.coroutines.withReentrantLock
import io.livekit.android.dagger.InjectionNames
import io.livekit.android.room.util.*
import io.livekit.android.util.Either
... ... @@ -32,15 +33,20 @@ import io.livekit.android.webrtc.getExts
import io.livekit.android.webrtc.getFmtps
import io.livekit.android.webrtc.getMsid
import io.livekit.android.webrtc.getRtps
import io.livekit.android.webrtc.isConnected
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import org.webrtc.*
import org.webrtc.PeerConnection.RTCConfiguration
import org.webrtc.PeerConnection.SignalingState
import java.util.concurrent.atomic.AtomicBoolean
import javax.inject.Named
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
import kotlin.math.roundToLong
/**
... ... @@ -58,7 +64,7 @@ constructor(
private val sdpFactory: SdpFactory,
) {
private val coroutineScope = CoroutineScope(ioDispatcher + SupervisorJob())
internal val peerConnection: PeerConnection = connectionFactory.createPeerConnection(
private val peerConnection: PeerConnection = connectionFactory.createPeerConnection(
config,
pcObserver,
) ?: throw IllegalStateException("peer connection creation failed?")
... ... @@ -70,6 +76,7 @@ constructor(
private val mutex = Mutex()
private var trackBitrates = mutableMapOf<TrackBitrateInfoKey, TrackBitrateInfo>()
private var isClosed = AtomicBoolean(false)
interface Listener {
fun onOffer(sd: SessionDescription)
... ... @@ -77,7 +84,7 @@ constructor(
fun addIceCandidate(candidate: IceCandidate) {
runBlocking {
mutex.withLock {
withNotClosedLock {
if (peerConnection.remoteDescription != null && !restartingIce) {
peerConnection.addIceCandidate(candidate)
} else {
... ... @@ -87,17 +94,24 @@ constructor(
}
}
suspend fun <T> withPeerConnection(action: suspend PeerConnection.() -> T): T? {
return withNotClosedLock {
action(peerConnection)
}
}
suspend fun setRemoteDescription(sd: SessionDescription): Either<Unit, String?> {
val result = peerConnection.setRemoteDescription(sd)
if (result is Either.Left) {
mutex.withLock {
val result = withNotClosedLock {
val result = peerConnection.setRemoteDescription(sd)
if (result is Either.Left) {
pendingCandidates.forEach { pending ->
peerConnection.addIceCandidate(pending)
}
pendingCandidates.clear()
restartingIce = false
}
}
result
} ?: Either.Right("PCT is closed.")
if (this.renegotiate) {
this.renegotiate = false
... ... @@ -115,60 +129,65 @@ constructor(
}
}
suspend fun createAndSendOffer(constraints: MediaConstraints = MediaConstraints()) {
private suspend fun createAndSendOffer(constraints: MediaConstraints = MediaConstraints()) {
if (listener == null) {
return
}
val iceRestart =
constraints.findConstraint(MediaConstraintKeys.ICE_RESTART) == MediaConstraintKeys.TRUE
if (iceRestart) {
LKLog.d { "restarting ice" }
restartingIce = true
}
var finalSdp: SessionDescription? = null
if (this.peerConnection.signalingState() == PeerConnection.SignalingState.HAVE_LOCAL_OFFER) {
// we're waiting for the peer to accept our offer, so we'll just wait
// the only exception to this is when ICE restart is needed
val curSd = peerConnection.remoteDescription
if (iceRestart && curSd != null) {
// TODO: handle when ICE restart is needed but we don't have a remote description
// the best thing to do is to recreate the peerconnection
peerConnection.setRemoteDescription(curSd)
} else {
renegotiate = true
return
// TODO: This is a potentially long lock hold. May need to break up.
withNotClosedLock {
val iceRestart =
constraints.findConstraint(MediaConstraintKeys.ICE_RESTART) == MediaConstraintKeys.TRUE
if (iceRestart) {
LKLog.d { "restarting ice" }
restartingIce = true
}
}
// actually negotiate
LKLog.d { "starting to negotiate" }
val sdpOffer = when (val outcome = peerConnection.createOffer(constraints)) {
is Either.Left -> outcome.value
is Either.Right -> {
LKLog.d { "error creating offer: ${outcome.value}" }
return
if (this.peerConnection.signalingState() == SignalingState.HAVE_LOCAL_OFFER) {
// we're waiting for the peer to accept our offer, so we'll just wait
// the only exception to this is when ICE restart is needed
val curSd = peerConnection.remoteDescription
if (iceRestart && curSd != null) {
// TODO: handle when ICE restart is needed but we don't have a remote description
// the best thing to do is to recreate the peerconnection
peerConnection.setRemoteDescription(curSd)
} else {
renegotiate = true
return@withNotClosedLock
}
}
}
// munge sdp
val sdpDescription = sdpFactory.createSessionDescription(sdpOffer.description)
val mediaDescs = sdpDescription.getMediaDescriptions(true)
for (mediaDesc in mediaDescs) {
if (mediaDesc !is MediaDescription) {
continue
// actually negotiate
val sdpOffer = when (val outcome = peerConnection.createOffer(constraints)) {
is Either.Left -> outcome.value
is Either.Right -> {
LKLog.d { "error creating offer: ${outcome.value}" }
return@withNotClosedLock
}
}
if (mediaDesc.media.mediaType == "audio") {
// TODO
} else if (mediaDesc.media.mediaType == "video") {
ensureVideoDDExtensionForSVC(mediaDesc)
ensureCodecBitrates(mediaDesc, trackBitrates = trackBitrates)
// munge sdp
val sdpDescription = sdpFactory.createSessionDescription(sdpOffer.description)
val mediaDescs = sdpDescription.getMediaDescriptions(true)
for (mediaDesc in mediaDescs) {
if (mediaDesc !is MediaDescription) {
continue
}
if (mediaDesc.media.mediaType == "audio") {
// TODO
} else if (mediaDesc.media.mediaType == "video") {
ensureVideoDDExtensionForSVC(mediaDesc)
ensureCodecBitrates(mediaDesc, trackBitrates = trackBitrates)
}
}
finalSdp = setMungedSdp(sdpOffer, sdpDescription.toString())
}
if (finalSdp != null) {
listener.onOffer(finalSdp!!)
}
val finalSdp = setMungedSdp(sdpOffer, sdpDescription.toString())
listener.onOffer(finalSdp)
}
private suspend fun setMungedSdp(sdp: SessionDescription, mungedDescription: String, remote: Boolean = false): SessionDescription {
... ... @@ -233,12 +252,27 @@ constructor(
restartingIce = true
}
fun close() {
peerConnection.dispose()
fun isClosed() = isClosed.get()
fun closeBlocking() {
runBlocking {
close()
}
}
suspend fun close() {
withNotClosedLock {
isClosed.set(true)
peerConnection.dispose()
}
}
fun updateRTCConfig(config: RTCConfiguration) {
peerConnection.setConfiguration(config)
runBlocking {
withNotClosedLock {
peerConnection.setConfiguration(config)
}
}
}
fun registerTrackBitrateInfo(cid: String, trackBitrateInfo: TrackBitrateInfo) {
... ... @@ -249,6 +283,44 @@ constructor(
trackBitrates[TrackBitrateInfoKey.Transceiver(transceiver)] = trackBitrateInfo
}
suspend fun isConnected(): Boolean {
return withNotClosedLock {
peerConnection.isConnected()
} ?: false
}
suspend fun iceConnectionState(): PeerConnection.IceConnectionState {
return withNotClosedLock {
peerConnection.iceConnectionState()
} ?: PeerConnection.IceConnectionState.CLOSED
}
suspend fun connectionState(): PeerConnection.PeerConnectionState {
return withNotClosedLock {
peerConnection.connectionState()
} ?: PeerConnection.PeerConnectionState.CLOSED
}
suspend fun signalingState(): SignalingState {
return withNotClosedLock {
peerConnection.signalingState()
} ?: SignalingState.CLOSED
}
@OptIn(ExperimentalContracts::class)
private suspend inline fun <T> withNotClosedLock(crossinline action: suspend () -> T): T? {
contract { callsInPlace(action, InvocationKind.AT_MOST_ONCE) }
if (isClosed()) {
return null
}
return mutex.withReentrantLock {
if (isClosed()) {
return@withReentrantLock null
}
return@withReentrantLock action()
}
}
@AssistedFactory
interface Factory {
fun create(
... ... @@ -296,7 +368,7 @@ internal fun ensureVideoDDExtensionForSVC(mediaDesc: MediaDescription) {
}
}
/* The svc codec (av1/vp9) would use a very low bitrate at the begining and
/* The svc codec (av1/vp9) would use a very low bitrate at the beginning and
increase slowly by the bandwidth estimator until it reach the target bitrate. The
process commonly cost more than 10 seconds cause subscriber will get blur video at
the first few seconds. So we use a 70% of target bitrate here as the start bitrate to
... ...
... ... @@ -17,6 +17,7 @@
package io.livekit.android.room
import android.os.SystemClock
import androidx.annotation.VisibleForTesting
import com.google.protobuf.ByteString
import io.livekit.android.ConnectOptions
import io.livekit.android.RoomOptions
... ... @@ -31,6 +32,7 @@ import io.livekit.android.room.util.setLocalDescription
import io.livekit.android.util.CloseableCoroutineScope
import io.livekit.android.util.Either
import io.livekit.android.util.LKLog
import io.livekit.android.webrtc.RTCStatsGetter
import io.livekit.android.webrtc.copy
import io.livekit.android.webrtc.isConnected
import io.livekit.android.webrtc.isDisconnected
... ... @@ -114,18 +116,8 @@ internal constructor(
private val publisherObserver = PublisherTransportObserver(this, client)
private val subscriberObserver = SubscriberTransportObserver(this, client)
private var _publisher: PeerConnectionTransport? = null
internal val publisher: PeerConnectionTransport
get() {
return _publisher
?: throw UninitializedPropertyAccessException("publisher has not been initialized yet.")
}
private var _subscriber: PeerConnectionTransport? = null
internal val subscriber: PeerConnectionTransport
get() {
return _subscriber
?: throw UninitializedPropertyAccessException("subscriber has not been initialized yet.")
}
private var publisher: PeerConnectionTransport? = null
private var subscriber: PeerConnectionTransport? = null
private var reliableDataChannel: DataChannel? = null
private var reliableDataChannelSub: DataChannel? = null
... ... @@ -181,8 +173,8 @@ internal constructor(
return joinResponse
}
private fun configure(joinResponse: JoinResponse, connectOptions: ConnectOptions) {
if (_publisher != null && _subscriber != null) {
private suspend fun configure(joinResponse: JoinResponse, connectOptions: ConnectOptions) {
if (publisher != null && subscriber != null) {
// already configured
return
}
... ... @@ -196,14 +188,14 @@ internal constructor(
// Setup peer connections
val rtcConfig = makeRTCConfig(Either.Left(joinResponse), connectOptions)
_publisher?.close()
_publisher = pctFactory.create(
publisher?.close()
publisher = pctFactory.create(
rtcConfig,
publisherObserver,
publisherObserver,
)
_subscriber?.close()
_subscriber = pctFactory.create(
subscriber?.close()
subscriber = pctFactory.create(
rtcConfig,
subscriberObserver,
null,
... ... @@ -243,19 +235,22 @@ internal constructor(
// data channels
val reliableInit = DataChannel.Init()
reliableInit.ordered = true
reliableDataChannel = publisher.peerConnection.createDataChannel(
RELIABLE_DATA_CHANNEL_LABEL,
reliableInit,
)
reliableDataChannel!!.registerObserver(this)
reliableDataChannel = publisher?.withPeerConnection {
createDataChannel(
RELIABLE_DATA_CHANNEL_LABEL,
reliableInit,
).apply { registerObserver(this@RTCEngine) }
}
val lossyInit = DataChannel.Init()
lossyInit.ordered = true
lossyInit.maxRetransmits = 0
lossyDataChannel = publisher.peerConnection.createDataChannel(
LOSSY_DATA_CHANNEL_LABEL,
lossyInit,
)
lossyDataChannel!!.registerObserver(this)
lossyDataChannel = publisher?.withPeerConnection {
createDataChannel(
LOSSY_DATA_CHANNEL_LABEL,
lossyInit,
).apply { registerObserver(this@RTCEngine) }
}
}
/**
... ... @@ -277,11 +272,13 @@ internal constructor(
}
}
internal fun createSenderTransceiver(
internal suspend fun createSenderTransceiver(
rtcTrack: MediaStreamTrack,
transInit: RtpTransceiverInit,
): RtpTransceiver? {
return publisher.peerConnection.addTransceiver(rtcTrack, transInit)
return publisher?.withPeerConnection {
addTransceiver(rtcTrack, transInit)
}
}
fun updateSubscriptionPermissions(
... ... @@ -301,15 +298,15 @@ internal constructor(
}
LKLog.v { "Close - $reason" }
isClosed = true
reconnectingJob?.cancel()
reconnectingJob = null
coroutineScope.close()
hasPublished = false
sessionUrl = null
sessionToken = null
connectOptions = null
lastRoomOptions = null
participantSid = null
reconnectingJob?.cancel()
reconnectingJob = null
coroutineScope.close()
closeResources(reason)
connectionState = ConnectionState.DISCONNECTED
}
... ... @@ -317,10 +314,10 @@ internal constructor(
private fun closeResources(reason: String) {
publisherObserver.connectionChangeListener = null
subscriberObserver.connectionChangeListener = null
_publisher?.close()
_publisher = null
_subscriber?.close()
_subscriber = null
publisher?.closeBlocking()
publisher = null
subscriber?.closeBlocking()
subscriber = null
fun DataChannel?.completeDispose() {
this?.unregisterObserver()
... ... @@ -366,6 +363,15 @@ internal constructor(
val reconnectStartTime = SystemClock.elapsedRealtime()
for (retries in 0 until MAX_RECONNECT_RETRIES) {
if (retries != 0) {
yield()
}
if (isClosed) {
LKLog.v { "RTCEngine closed, aborting reconnection" }
break
}
var startDelay = 100 + retries.toLong() * retries * 500
if (startDelay > 5000) {
startDelay = 5000
... ... @@ -395,14 +401,14 @@ internal constructor(
}
} else {
LKLog.v { "Attempting soft reconnect." }
subscriber.prepareForIceRestart()
subscriber?.prepareForIceRestart()
try {
val response = client.reconnect(url, token, participantSid)
if (response is Either.Left) {
val reconnectResponse = response.value
val rtcConfig = makeRTCConfig(Either.Right(reconnectResponse), connectOptions)
_subscriber?.updateRTCConfig(rtcConfig)
_publisher?.updateRTCConfig(rtcConfig)
subscriber?.updateRTCConfig(rtcConfig)
publisher?.updateRTCConfig(rtcConfig)
}
client.onReadyForResponses()
} catch (e: Exception) {
... ... @@ -420,11 +426,17 @@ internal constructor(
negotiatePublisher()
}
}
if (isClosed) {
LKLog.v { "RTCEngine closed, aborting reconnection" }
break
}
// wait until ICE connected
val endTime = SystemClock.elapsedRealtime() + MAX_ICE_CONNECT_TIMEOUT_MS
if (hasPublished) {
while (SystemClock.elapsedRealtime() < endTime) {
if (publisher.peerConnection.connectionState().isConnected()) {
if (publisher?.isConnected() == true) {
LKLog.v { "publisher reconnected to ICE" }
break
}
... ... @@ -432,8 +444,13 @@ internal constructor(
}
}
if (isClosed) {
LKLog.v { "RTCEngine closed, aborting reconnection" }
break
}
while (SystemClock.elapsedRealtime() < endTime) {
if (subscriber.peerConnection.connectionState().isConnected()) {
if (subscriber?.isConnected() == true) {
LKLog.v { "reconnected to ICE" }
connectionState = ConnectionState.CONNECTED
break
... ... @@ -441,8 +458,12 @@ internal constructor(
delay(100)
}
if (isClosed) {
LKLog.v { "RTCEngine closed, aborting reconnection" }
break
}
if (connectionState == ConnectionState.CONNECTED &&
(!hasPublished || publisher.peerConnection.connectionState().isConnected())
(!hasPublished || publisher?.isConnected() == true)
) {
client.onPCConnected()
listener?.onPostReconnect(isFullReconnect)
... ... @@ -475,7 +496,7 @@ internal constructor(
hasPublished = true
coroutineScope.launch {
publisher.negotiate(getPublisherOfferConstraints())
publisher?.negotiate?.invoke(getPublisherOfferConstraints())
}
}
... ... @@ -498,12 +519,12 @@ internal constructor(
return
}
if (_publisher == null) {
if (publisher == null) {
throw RoomException.ConnectException("Publisher isn't setup yet! Is room not connected?!")
}
if (!publisher.peerConnection.isConnected() &&
publisher.peerConnection.iceConnectionState() != PeerConnection.IceConnectionState.CHECKING
if (publisher?.isConnected() != true &&
publisher?.iceConnectionState() != PeerConnection.IceConnectionState.CHECKING
) {
// start negotiation
this.negotiatePublisher()
... ... @@ -517,7 +538,7 @@ internal constructor(
// wait until publisher ICE connected
val endTime = SystemClock.elapsedRealtime() + MAX_ICE_CONNECT_TIMEOUT_MS
while (SystemClock.elapsedRealtime() < endTime) {
if (this.publisher.peerConnection.isConnected() && targetChannel.state() == DataChannel.State.OPEN) {
if (publisher?.isConnected() == true && targetChannel.state() == DataChannel.State.OPEN) {
return
}
delay(50)
... ... @@ -676,10 +697,11 @@ internal constructor(
// ---------------------------------- SignalClient.Listener --------------------------------------//
override fun onAnswer(sessionDescription: SessionDescription) {
LKLog.v { "received server answer: ${sessionDescription.type}, ${publisher.peerConnection.signalingState()}" }
val signalingState = runBlocking { publisher?.signalingState() }
LKLog.v { "received server answer: ${sessionDescription.type}, $signalingState" }
coroutineScope.launch {
LKLog.i { sessionDescription.toString() }
when (val outcome = publisher.setRemoteDescription(sessionDescription)) {
when (val outcome = publisher?.setRemoteDescription(sessionDescription)) {
is Either.Left -> {
// do nothing.
}
... ... @@ -687,49 +709,71 @@ internal constructor(
is Either.Right -> {
LKLog.e { "error setting remote description for answer: ${outcome.value} " }
}
else -> {
LKLog.w { "publisher is null, can't set remote description." }
}
}
}
}
override fun onOffer(sessionDescription: SessionDescription) {
LKLog.v { "received server offer: ${sessionDescription.type}, ${subscriber.peerConnection.signalingState()}" }
val signalingState = runBlocking { publisher?.signalingState() }
LKLog.v { "received server offer: ${sessionDescription.type}, $signalingState" }
coroutineScope.launch {
run<Unit> {
when (
val outcome =
subscriber.setRemoteDescription(sessionDescription)
) {
is Either.Right -> {
LKLog.e { "error setting remote description for answer: ${outcome.value} " }
return@launch
// TODO: This is a potentially very long lock hold. May need to break up.
val answer = subscriber?.withPeerConnection {
run {
when (
val outcome =
subscriber?.setRemoteDescription(sessionDescription)
) {
is Either.Right -> {
LKLog.e { "error setting remote description for answer: ${outcome.value} " }
return@withPeerConnection null
}
else -> {}
}
}
else -> {}
if (isClosed) {
return@withPeerConnection null
}
}
val answer = run {
when (val outcome = subscriber.peerConnection.createAnswer(MediaConstraints())) {
is Either.Left -> outcome.value
is Either.Right -> {
LKLog.e { "error creating answer: ${outcome.value}" }
return@launch
val answer = run {
when (val outcome = createAnswer(MediaConstraints())) {
is Either.Left -> outcome.value
is Either.Right -> {
LKLog.e { "error creating answer: ${outcome.value}" }
return@withPeerConnection null
}
}
}
}
run<Unit> {
when (val outcome = subscriber.peerConnection.setLocalDescription(answer)) {
is Either.Right -> {
LKLog.e { "error setting local description for answer: ${outcome.value}" }
return@launch
if (isClosed) {
return@withPeerConnection null
}
run<Unit> {
when (val outcome = setLocalDescription(answer)) {
is Either.Right -> {
LKLog.e { "error setting local description for answer: ${outcome.value}" }
return@withPeerConnection null
}
else -> {}
}
}
else -> {}
if (isClosed) {
return@withPeerConnection null
}
return@withPeerConnection answer
}
answer?.let {
client.sendAnswer(it)
}
client.sendAnswer(answer)
}
}
... ... @@ -737,14 +781,15 @@ internal constructor(
LKLog.v { "received ice candidate from peer: $candidate, $target" }
when (target) {
LivekitRtc.SignalTarget.PUBLISHER -> {
if (_publisher != null) {
publisher.addIceCandidate(candidate)
} else {
LKLog.w { "received candidate for publisher when we don't have one. ignoring." }
}
publisher?.addIceCandidate(candidate)
?: LKLog.w { "received candidate for publisher when we don't have one. ignoring." }
}
LivekitRtc.SignalTarget.SUBSCRIBER -> {
subscriber?.addIceCandidate(candidate)
?: LKLog.w { "received candidate for subscriber when we don't have one. ignoring." }
}
LivekitRtc.SignalTarget.SUBSCRIBER -> subscriber.addIceCandidate(candidate)
else -> LKLog.i { "unknown ice candidate target?" }
}
}
... ... @@ -866,7 +911,9 @@ internal constructor(
subscription: LivekitRtc.UpdateSubscription,
publishedTracks: List<LivekitRtc.TrackPublishedResponse>,
) {
val answer = subscriber.peerConnection.localDescription?.toProtoSessionDescription()
val answer = runBlocking {
subscriber?.withPeerConnection { localDescription?.toProtoSessionDescription() }
}
val dataChannelInfos = LivekitModels.DataPacket.Kind.values()
.toList()
... ... @@ -892,12 +939,66 @@ internal constructor(
}
fun getPublisherRTCStats(callback: RTCStatsCollectorCallback) {
_publisher?.peerConnection?.getStats(callback) ?: callback.onStatsDelivered(RTCStatsReport(0, emptyMap()))
runBlocking {
publisher?.withPeerConnection { getStats(callback) }
?: callback.onStatsDelivered(RTCStatsReport(0, emptyMap()))
}
}
fun getSubscriberRTCStats(callback: RTCStatsCollectorCallback) {
_subscriber?.peerConnection?.getStats(callback) ?: callback.onStatsDelivered(RTCStatsReport(0, emptyMap()))
runBlocking {
subscriber?.withPeerConnection { getStats(callback) }
?: callback.onStatsDelivered(RTCStatsReport(0, emptyMap()))
}
}
fun createStatsGetter(sender: RtpSender): RTCStatsGetter {
val p = publisher
return { statsCallback: RTCStatsCollectorCallback ->
runBlocking {
p?.withPeerConnection {
getStats(sender, statsCallback)
} ?: statsCallback.onStatsDelivered(RTCStatsReport(0, emptyMap()))
}
}
}
fun createStatsGetter(receiver: RtpReceiver): RTCStatsGetter {
val p = subscriber
return { statsCallback: RTCStatsCollectorCallback ->
runBlocking {
p?.withPeerConnection {
getStats(receiver, statsCallback)
} ?: statsCallback.onStatsDelivered(RTCStatsReport(0, emptyMap()))
}
}
}
internal fun registerTrackBitrateInfo(cid: String, trackBitrateInfo: TrackBitrateInfo) {
publisher?.registerTrackBitrateInfo(cid, trackBitrateInfo)
}
internal fun removeTrack(rtcTrack: MediaStreamTrack) {
runBlocking {
publisher?.withPeerConnection {
val senders = this.senders
for (sender in senders) {
val t = sender.track() ?: continue
if (t.id() == rtcTrack.id()) {
this@withPeerConnection.removeTrack(sender)
}
}
}
}
}
@VisibleForTesting
internal suspend fun getPublisherPeerConnection() =
publisher?.withPeerConnection { this }!!
@VisibleForTesting
internal suspend fun getSubscriberPeerConnection() =
subscriber?.withPeerConnection { this }!!
}
/**
... ...
... ... @@ -42,7 +42,6 @@ import io.livekit.android.util.FlowObservable
import io.livekit.android.util.LKLog
import io.livekit.android.util.flowDelegate
import io.livekit.android.util.invoke
import io.livekit.android.webrtc.createStatsGetter
import io.livekit.android.webrtc.getFilteredStats
import kotlinx.coroutines.*
import livekit.LivekitModels
... ... @@ -708,7 +707,7 @@ constructor(
trackSid = track.id()
}
val participant = getOrCreateRemoteParticipant(participantSid)
val statsGetter = createStatsGetter(engine.subscriber.peerConnection, receiver)
val statsGetter = engine.createStatsGetter(receiver)
participant.addSubscribedMediaTrack(
track,
trackSid!!,
... ...
... ... @@ -88,6 +88,8 @@ constructor(
private var requestFlowJob: Job? = null
private val requestFlow = MutableSharedFlow<LivekitRtc.SignalRequest>(Int.MAX_VALUE)
private val responseFlowJobLock = Object()
private var responseFlowJob: Job? = null
private val responseFlow = MutableSharedFlow<LivekitRtc.SignalResponse>(Int.MAX_VALUE)
private var pingJob: Job? = null
... ... @@ -202,10 +204,17 @@ constructor(
* Should be called after resolving the join message.
*/
fun onReadyForResponses() {
coroutineScope.launch {
responseFlow.collect {
responseFlow.resetReplayCache()
handleSignalResponseImpl(it)
if (responseFlowJob != null) {
return
}
synchronized(responseFlowJobLock) {
if (responseFlowJob == null) {
responseFlowJob = coroutineScope.launch {
responseFlow.collect {
responseFlow.resetReplayCache()
handleSignalResponseImpl(it)
}
}
}
}
}
... ... @@ -378,7 +387,7 @@ constructor(
type: LivekitModels.TrackType,
builder: LivekitRtc.AddTrackRequest.Builder = LivekitRtc.AddTrackRequest.newBuilder(),
) {
var encryptionType = lastRoomOptions?.e2eeOptions?.encryptionType ?: LivekitModels.Encryption.Type.NONE
val encryptionType = lastRoomOptions?.e2eeOptions?.encryptionType ?: LivekitModels.Encryption.Type.NONE
val addTrackRequest = builder
.setCid(cid)
.setName(name)
... ... @@ -731,6 +740,8 @@ constructor(
}
requestFlowJob?.cancel()
requestFlowJob = null
responseFlowJob?.cancel()
responseFlowJob = null
pingJob?.cancel()
pingJob = null
pongJob?.cancel()
... ...
... ... @@ -34,7 +34,6 @@ import io.livekit.android.room.isSVCCodec
import io.livekit.android.room.track.*
import io.livekit.android.room.util.EncodingUtils
import io.livekit.android.util.LKLog
import io.livekit.android.webrtc.createStatsGetter
import io.livekit.android.webrtc.sortVideoCodecPreferences
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.launch
... ... @@ -393,12 +392,12 @@ internal constructor(
return false
}
track.statsGetter = createStatsGetter(engine.publisher.peerConnection, transceiver.sender)
track.statsGetter = engine.createStatsGetter(transceiver.sender)
// Handle trackBitrates
if (encodings.isNotEmpty()) {
if (options is VideoTrackPublishOptions && isSVCCodec(options.videoCodec) && encodings.firstOrNull()?.maxBitrateBps != null) {
engine.publisher.registerTrackBitrateInfo(
engine.registerTrackBitrateInfo(
cid = cid,
TrackBitrateInfo(
codec = options.videoCodec,
... ... @@ -556,13 +555,7 @@ internal constructor(
tracks = tracks.toMutableMap().apply { remove(sid) }
if (engine.connectionState == ConnectionState.CONNECTED) {
val senders = engine.publisher.peerConnection.senders
for (sender in senders) {
val t = sender.track() ?: continue
if (t.id() == track.rtcTrack.id()) {
engine.publisher.peerConnection.removeTrack(sender)
}
}
engine.removeTrack(track.rtcTrack)
}
if (stopOnUnpublish) {
track.stop()
... ...
... ... @@ -17,6 +17,10 @@
package io.livekit.android.room.util
import io.livekit.android.util.Either
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.suspendCancellableCoroutine
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import org.webrtc.MediaConstraints
import org.webrtc.PeerConnection
import org.webrtc.SdpObserver
... ... @@ -26,26 +30,47 @@ import kotlin.coroutines.resume
import kotlin.coroutines.suspendCoroutine
open class CoroutineSdpObserver : SdpObserver {
private val stateLock = Mutex()
private var createOutcome: Either<SessionDescription, String?>? = null
set(value) {
field = value
val conts = runBlocking {
stateLock.withLock {
field = value
if (value != null) {
val conts = pendingCreate.toList()
pendingCreate.clear()
conts
} else {
null
}
}
}
if (value != null) {
val conts = pendingCreate.toList()
pendingCreate.clear()
conts.forEach {
conts?.forEach {
it.resume(value)
}
}
}
private var pendingCreate = mutableListOf<Continuation<Either<SessionDescription, String?>>>()
private var setOutcome: Either<Unit, String?>? = null
set(value) {
field = value
val conts = runBlocking {
stateLock.withLock {
field = value
if (value != null) {
val conts = pendingSets.toList()
pendingSets.clear()
conts
} else {
null
}
}
}
if (value != null) {
val conts = pendingSets.toList()
pendingSets.clear()
conts.forEach {
conts?.forEach {
it.resume(value)
}
}
... ... @@ -72,21 +97,41 @@ open class CoroutineSdpObserver : SdpObserver {
setOutcome = Either.Right(message)
}
suspend fun awaitCreate() = suspendCoroutine { cont ->
val curOutcome = createOutcome
if (curOutcome != null) {
cont.resume(curOutcome)
suspend fun awaitCreate() = suspendCancellableCoroutine { cont ->
val unlockedOutcome = createOutcome
if (unlockedOutcome != null) {
cont.resume(unlockedOutcome)
} else {
pendingCreate.add(cont)
runBlocking {
stateLock.lock()
val lockedOutcome = createOutcome
if (lockedOutcome != null) {
stateLock.unlock()
cont.resume(lockedOutcome)
} else {
pendingCreate.add(cont)
stateLock.unlock()
}
}
}
}
suspend fun awaitSet() = suspendCoroutine { cont ->
val curOutcome = setOutcome
if (curOutcome != null) {
cont.resume(curOutcome)
val unlockedOutcome = setOutcome
if (unlockedOutcome != null) {
cont.resume(unlockedOutcome)
} else {
pendingSets.add(cont)
runBlocking {
stateLock.lock()
val lockedOutcome = setOutcome
if (lockedOutcome != null) {
stateLock.unlock()
cont.resume(lockedOutcome)
} else {
pendingSets.add(cont)
stateLock.unlock()
}
}
}
}
}
... ...
... ... @@ -19,12 +19,9 @@ package io.livekit.android.webrtc
import io.livekit.android.util.LKLog
import kotlinx.coroutines.suspendCancellableCoroutine
import org.webrtc.MediaStreamTrack
import org.webrtc.PeerConnection
import org.webrtc.RTCStats
import org.webrtc.RTCStatsCollectorCallback
import org.webrtc.RTCStatsReport
import org.webrtc.RtpReceiver
import org.webrtc.RtpSender
import kotlin.coroutines.resume
/**
... ... @@ -174,13 +171,3 @@ suspend fun RTCStatsGetter.getStats(): RTCStatsReport = suspendCancellableCorout
}
this.invoke(listener)
}
fun createStatsGetter(peerConnection: PeerConnection, sender: RtpSender): RTCStatsGetter =
{ statsCallback: RTCStatsCollectorCallback ->
peerConnection.getStats(sender, statsCallback)
}
fun createStatsGetter(peerConnection: PeerConnection, receiver: RtpReceiver): RTCStatsGetter =
{ statsCallback: RTCStatsCollectorCallback ->
peerConnection.getStats(receiver, statsCallback)
}
... ...
... ... @@ -24,7 +24,6 @@ import io.livekit.android.mock.MockWebSocketFactory
import io.livekit.android.mock.dagger.DaggerTestLiveKitComponent
import io.livekit.android.mock.dagger.TestCoroutinesModule
import io.livekit.android.mock.dagger.TestLiveKitComponent
import io.livekit.android.room.PeerConnectionTransport
import io.livekit.android.room.Room
import io.livekit.android.room.SignalClientTest
import io.livekit.android.util.toOkioByteString
... ... @@ -45,7 +44,6 @@ abstract class MockE2ETest : BaseTest() {
internal lateinit var context: Context
internal lateinit var room: Room
internal lateinit var wsFactory: MockWebSocketFactory
internal lateinit var subscriber: PeerConnectionTransport
@Before
fun mocksSetup() {
... ... @@ -77,16 +75,26 @@ abstract class MockE2ETest : BaseTest() {
job.join()
}
fun connectPeerConnection() {
subscriber = component.rtcEngine().subscriber
suspend fun getSubscriberPeerConnection() =
component
.rtcEngine()
.getSubscriberPeerConnection() as MockPeerConnection
suspend fun getPublisherPeerConnection() =
component
.rtcEngine()
.getPublisherPeerConnection() as MockPeerConnection
suspend fun connectPeerConnection() {
simulateMessageFromServer(SignalClientTest.OFFER)
val subPeerConnection = subscriber.peerConnection as MockPeerConnection
val subPeerConnection = getSubscriberPeerConnection()
subPeerConnection.moveToIceConnectionState(PeerConnection.IceConnectionState.CONNECTED)
}
fun disconnectPeerConnection() {
subscriber = component.rtcEngine().subscriber
val subPeerConnection = subscriber.peerConnection as MockPeerConnection
suspend fun disconnectPeerConnection() {
val subPeerConnection = component
.rtcEngine()
.getSubscriberPeerConnection() as MockPeerConnection
subPeerConnection.moveToIceConnectionState(PeerConnection.IceConnectionState.FAILED)
}
... ...
... ... @@ -16,10 +16,14 @@
package io.livekit.android.mock
import com.google.protobuf.MessageLite
import io.livekit.android.util.toOkioByteString
import io.livekit.android.util.toPBByteString
import livekit.LivekitModels
import livekit.LivekitRtc
import livekit.LivekitRtc.LeaveRequest
import livekit.LivekitRtc.SignalRequest
import livekit.LivekitRtc.SignalResponse
import livekit.LivekitRtc.TrackPublishedResponse
import okhttp3.Request
import okhttp3.WebSocket
import okhttp3.WebSocketListener
... ... @@ -42,34 +46,80 @@ class MockWebSocketFactory : WebSocket.Factory {
lateinit var listener: WebSocketListener
override fun newWebSocket(request: Request, listener: WebSocketListener): WebSocket {
this.ws = MockWebSocket(request, listener) { byteString ->
val signalRequest = LivekitRtc.SignalRequest.parseFrom(byteString.toPBByteString())
if (signalRequest.hasAddTrack()) {
val signalRequest = SignalRequest.parseFrom(byteString.toPBByteString())
handleSignalRequest(signalRequest)
}
this.listener = listener
this.request = request
onOpen?.invoke(this)
return ws
}
private val signalRequestHandlers = mutableListOf<SignalRequestHandler>(
{ signalRequest -> defaultHandleSignalRequest(signalRequest) },
)
fun registerSignalRequestHandler(handler: SignalRequestHandler) {
signalRequestHandlers.add(0, handler)
}
private fun handleSignalRequest(signalRequest: SignalRequest) {
for (handler in signalRequestHandlers) {
if (handler.invoke(signalRequest)) {
break
}
}
}
private fun defaultHandleSignalRequest(signalRequest: SignalRequest): Boolean {
when (signalRequest.messageCase) {
SignalRequest.MessageCase.ADD_TRACK -> {
val addTrack = signalRequest.addTrack
val trackPublished = with(LivekitRtc.SignalResponse.newBuilder()) {
trackPublished = with(LivekitRtc.TrackPublishedResponse.newBuilder()) {
val trackPublished = with(SignalResponse.newBuilder()) {
trackPublished = with(TrackPublishedResponse.newBuilder()) {
cid = addTrack.cid
if (addTrack.type == LivekitModels.TrackType.AUDIO) {
track = TestData.LOCAL_AUDIO_TRACK
track = if (addTrack.type == LivekitModels.TrackType.AUDIO) {
TestData.LOCAL_AUDIO_TRACK
} else {
track = TestData.LOCAL_VIDEO_TRACK
TestData.LOCAL_VIDEO_TRACK
}
build()
}
build()
}
this.listener.onMessage(this.ws, trackPublished.toOkioByteString())
receiveMessage(trackPublished)
return true
}
}
this.listener = listener
this.request = request
onOpen?.invoke(this)
return ws
SignalRequest.MessageCase.LEAVE -> {
val leaveResponse = with(SignalResponse.newBuilder()) {
leave = with(LeaveRequest.newBuilder()) {
canReconnect = false
reason = LivekitModels.DisconnectReason.CLIENT_INITIATED
build()
}
build()
}
receiveMessage(leaveResponse)
return true
}
else -> {
return false
}
}
}
var onOpen: ((MockWebSocketFactory) -> Unit)? = null
fun receiveMessage(message: MessageLite) {
receiveMessage(message.toOkioByteString())
}
fun receiveMessage(byteString: ByteString) {
listener.onMessage(ws, byteString)
}
}
typealias SignalRequestHandler = (SignalRequest) -> Boolean
... ...
... ... @@ -17,7 +17,6 @@
package io.livekit.android.room
import io.livekit.android.MockE2ETest
import io.livekit.android.mock.MockPeerConnection
import io.livekit.android.util.toOkioByteString
import io.livekit.android.util.toPBByteString
import kotlinx.coroutines.ExperimentalCoroutinesApi
... ... @@ -47,7 +46,7 @@ class RTCEngineMockE2ETest : MockE2ETest() {
connect()
val sentIceServers = SignalClientTest.JOIN.join.iceServersList
.map { it.toWebrtc() }
val subPeerConnection = rtcEngine.subscriber.peerConnection as MockPeerConnection
val subPeerConnection = getSubscriberPeerConnection()
assertEquals(sentIceServers, subPeerConnection.rtcConfig.iceServers)
}
... ... @@ -57,7 +56,7 @@ class RTCEngineMockE2ETest : MockE2ETest() {
connect()
assertEquals(
SignalClientTest.OFFER.offer.sdp,
rtcEngine.subscriber.peerConnection.remoteDescription.description,
getSubscriberPeerConnection().remoteDescription?.description,
)
val ws = wsFactory.ws
... ... @@ -65,7 +64,7 @@ class RTCEngineMockE2ETest : MockE2ETest() {
.mergeFrom(ws.sentRequests[0].toPBByteString())
.build()
val subPeerConnection = rtcEngine.subscriber.peerConnection as MockPeerConnection
val subPeerConnection = getSubscriberPeerConnection()
val localAnswer = subPeerConnection.localDescription ?: throw IllegalStateException("no answer was created.")
Assert.assertTrue(sentRequest.hasAnswer())
assertEquals(localAnswer.description, sentRequest.answer.sdp)
... ... @@ -88,7 +87,7 @@ class RTCEngineMockE2ETest : MockE2ETest() {
connect()
val oldWs = wsFactory.ws
val subPeerConnection = rtcEngine.subscriber.peerConnection as MockPeerConnection
val subPeerConnection = getSubscriberPeerConnection()
subPeerConnection.moveToIceConnectionState(PeerConnection.IceConnectionState.FAILED)
testScheduler.advanceTimeBy(1000)
... ... @@ -101,7 +100,7 @@ class RTCEngineMockE2ETest : MockE2ETest() {
connect()
val oldWs = wsFactory.ws
val pubPeerConnection = rtcEngine.publisher.peerConnection as MockPeerConnection
val pubPeerConnection = getPublisherPeerConnection()
pubPeerConnection.moveToIceConnectionState(PeerConnection.IceConnectionState.FAILED)
testScheduler.advanceTimeBy(1000)
... ... @@ -138,7 +137,7 @@ class RTCEngineMockE2ETest : MockE2ETest() {
},
)
val subPeerConnection = rtcEngine.subscriber.peerConnection as MockPeerConnection
val subPeerConnection = getSubscriberPeerConnection()
assertEquals(PeerConnection.IceTransportsType.RELAY, subPeerConnection.rtcConfig.iceTransportsType)
}
... ...
... ... @@ -34,8 +34,12 @@ import io.livekit.android.room.track.Track
import io.livekit.android.util.flow
import io.livekit.android.util.toOkioByteString
import junit.framework.Assert.assertEquals
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.launch
import livekit.LivekitRtc
import org.junit.Assert
import org.junit.Test
import org.junit.runner.RunWith
... ... @@ -336,6 +340,52 @@ class RoomMockE2ETest : MockE2ETest() {
)
val eventCollector = EventCollector(room.events, coroutineRule.scope)
wsFactory.listener.onMessage(
wsFactory.ws,
SignalClientTest.LEAVE.toOkioByteString(),
)
room.disconnect()
val events = eventCollector.stopCollecting()
assertEquals(2, events.size)
assertEquals(true, events[0] is RoomEvent.TrackUnpublished)
assertEquals(true, events[1] is RoomEvent.Disconnected)
}
/**
*
*/
@Test
fun disconnectWithTracks() = runTest {
connect()
val differentThread = CoroutineScope(Dispatchers.IO + SupervisorJob())
wsFactory.registerSignalRequestHandler {
if (it.hasLeave()) {
differentThread.launch {
val leaveResponse = with(LivekitRtc.SignalResponse.newBuilder()) {
leave = with(LivekitRtc.LeaveRequest.newBuilder()) {
canReconnect = false
reason = livekit.LivekitModels.DisconnectReason.CLIENT_INITIATED
build()
}
build()
}
wsFactory.receiveMessage(leaveResponse)
}
return@registerSignalRequestHandler true
}
return@registerSignalRequestHandler false
}
room.localParticipant.publishAudioTrack(
LocalAudioTrack(
"",
MockAudioStreamTrack(id = SignalClientTest.LOCAL_TRACK_PUBLISHED.trackPublished.cid),
),
)
val eventCollector = EventCollector(room.events, coroutineRule.scope)
room.disconnect()
val events = eventCollector.stopCollecting()
... ...
... ... @@ -18,7 +18,6 @@ package io.livekit.android.room
import io.livekit.android.MockE2ETest
import io.livekit.android.mock.MockAudioStreamTrack
import io.livekit.android.mock.MockPeerConnection
import io.livekit.android.room.track.LocalAudioTrack
import io.livekit.android.util.toPBByteString
import kotlinx.coroutines.ExperimentalCoroutinesApi
... ... @@ -89,8 +88,7 @@ class RoomReconnectionMockE2ETest : MockE2ETest() {
testScheduler.advanceTimeBy(1000)
connectPeerConnection()
val rtcEngine = component.rtcEngine()
val rtcConfig = (rtcEngine.subscriber.peerConnection as MockPeerConnection).rtcConfig
val rtcConfig = getSubscriberPeerConnection().rtcConfig
assertEquals(PeerConnection.IceTransportsType.RELAY, rtcConfig.iceTransportsType)
val sentIceServers = SignalClientTest.RECONNECT.reconnect.iceServersList
... ...
... ... @@ -23,7 +23,6 @@ import io.livekit.android.events.ParticipantEvent
import io.livekit.android.events.RoomEvent
import io.livekit.android.mock.MockAudioStreamTrack
import io.livekit.android.mock.MockEglBase
import io.livekit.android.mock.MockPeerConnection
import io.livekit.android.mock.MockVideoCapturer
import io.livekit.android.mock.MockVideoStreamTrack
import io.livekit.android.room.DefaultsManager
... ... @@ -176,7 +175,7 @@ class LocalParticipantMockE2ETest : MockE2ETest() {
room.localParticipant.publishVideoTrack(track = createLocalTrack())
val peerConnection = component.rtcEngine().publisher.peerConnection
val peerConnection = getPublisherPeerConnection()
val transceiver = peerConnection.transceivers.first()
Mockito.verify(transceiver).setCodecPreferences(
... ... @@ -195,7 +194,7 @@ class LocalParticipantMockE2ETest : MockE2ETest() {
room.localParticipant.publishVideoTrack(track = createLocalTrack())
val peerConnection = component.rtcEngine().publisher.peerConnection
val peerConnection = getPublisherPeerConnection()
val transceiver = peerConnection.transceivers.first()
Mockito.verify(transceiver).setCodecPreferences(
... ... @@ -236,7 +235,7 @@ class LocalParticipantMockE2ETest : MockE2ETest() {
val vp8Codec = addTrackRequest.simulcastCodecsList[1]
assertEquals("vp8", vp8Codec.codec)
val publisherConn = component.rtcEngine().publisher.peerConnection as MockPeerConnection
val publisherConn = getPublisherPeerConnection()
assertEquals(1, publisherConn.transceivers.size)
Mockito.verify(publisherConn.transceivers.first()).setCodecPreferences(
... ...