David Liu

protocol 3: subscriber as primary

package io.livekit.android.room
import com.github.ajalt.timberkt.Timber
import dagger.assisted.Assisted
import dagger.assisted.AssistedFactory
import dagger.assisted.AssistedInject
import io.livekit.android.room.util.CoroutineSdpObserver
import io.livekit.android.dagger.InjectionNames
import io.livekit.android.room.util.*
import io.livekit.android.util.Either
import org.webrtc.IceCandidate
import org.webrtc.PeerConnection
import org.webrtc.PeerConnectionFactory
import org.webrtc.SessionDescription
import io.livekit.android.util.debounce
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.SupervisorJob
import org.webrtc.*
import javax.inject.Named
/**
* @suppress
... ... @@ -17,18 +21,28 @@ class PeerConnectionTransport
@AssistedInject
constructor(
@Assisted config: PeerConnection.RTCConfiguration,
@Assisted listener: PeerConnection.Observer,
@Assisted pcObserver: PeerConnection.Observer,
@Assisted private val listener: Listener?,
@Named(InjectionNames.DISPATCHER_IO)
private val ioDispatcher: CoroutineDispatcher,
connectionFactory: PeerConnectionFactory
) {
private val coroutineScope = CoroutineScope(ioDispatcher + SupervisorJob())
val peerConnection: PeerConnection = connectionFactory.createPeerConnection(
config,
listener
pcObserver
) ?: throw IllegalStateException("peer connection creation failed?")
val pendingCandidates = mutableListOf<IceCandidate>()
var iceRestart: Boolean = false
var restartingIce: Boolean = false
var renegotiate = false
interface Listener {
fun onOffer(sd: SessionDescription)
}
fun addIceCandidate(candidate: IceCandidate) {
if (peerConnection.remoteDescription != null && !iceRestart) {
if (peerConnection.remoteDescription != null && !restartingIce) {
peerConnection.addIceCandidate(candidate)
} else {
pendingCandidates.add(candidate)
... ... @@ -37,23 +51,62 @@ constructor(
suspend fun setRemoteDescription(sd: SessionDescription): Either<Unit, String?> {
val observer = object : CoroutineSdpObserver() {
override fun onSetSuccess() {
val result = peerConnection.setRemoteDescription(sd)
if (result is Either.Left) {
pendingCandidates.forEach { pending ->
peerConnection.addIceCandidate(pending)
}
pendingCandidates.clear()
iceRestart = false
super.onSetSuccess()
restartingIce = false
}
if (this.renegotiate) {
this.renegotiate = false
this.createAndSendOffer()
}
return result
}
val negotiate = debounce<Unit, Unit>(100, coroutineScope) { createAndSendOffer() }
suspend fun createAndSendOffer(constraints: MediaConstraints = MediaConstraints()) {
if (listener == null) {
return
}
val iceRestart =
constraints.findConstraint(MediaConstraintKeys.ICE_RESTART) == MediaConstraintKeys.TRUE
if (iceRestart) {
Timber.d { "restarting ice" }
restartingIce = true
}
if (this.peerConnection.signalingState() == PeerConnection.SignalingState.HAVE_LOCAL_OFFER) {
// we're waiting for the peer to accept our offer, so we'll just wait
// the only exception to this is when ICE restart is needed
val curSd = peerConnection.remoteDescription
if (iceRestart && curSd != null) {
// TODO: handle when ICE restart is needed but we don't have a remote description
// the best thing to do is to recreate the peerconnection
peerConnection.setRemoteDescription(curSd)
} else {
renegotiate = true
return
}
}
peerConnection.setRemoteDescription(observer, sd)
return observer.awaitSet()
// actually negotiate
Timber.d { "starting to negotiate" }
val offer = peerConnection.createOffer(constraints)
if (offer is Either.Left) {
peerConnection.setLocalDescription(offer.value)
listener?.onOffer(offer.value)
}
}
fun prepareForIceRestart() {
iceRestart = true
restartingIce = true
}
fun close() {
... ... @@ -64,7 +117,8 @@ constructor(
interface Factory {
fun create(
config: PeerConnection.RTCConfiguration,
listener: PeerConnection.Observer
pcObserver: PeerConnection.Observer,
listener: Listener?
): PeerConnectionTransport
}
}
\ No newline at end of file
... ...
... ... @@ -8,12 +8,18 @@ import org.webrtc.*
* @suppress
*/
class PublisherTransportObserver(
private val engine: RTCEngine
) : PeerConnection.Observer {
private val engine: RTCEngine,
private val client: SignalClient,
) : PeerConnection.Observer, PeerConnectionTransport.Listener {
var dataChannelListener: ((DataChannel?) -> Unit)? = null
var iceConnectionChangeListener: ((newState: PeerConnection.IceConnectionState?) -> Unit)? =
null
override fun onIceCandidate(iceCandidate: IceCandidate?) {
val candidate = iceCandidate ?: return
engine.client.sendCandidate(candidate, target = LivekitRtc.SignalTarget.PUBLISHER)
Timber.v { "onIceCandidate: $candidate" }
client.sendCandidate(candidate, target = LivekitRtc.SignalTarget.PUBLISHER)
}
override fun onRenegotiationNeeded() {
... ... @@ -21,15 +27,12 @@ class PublisherTransportObserver(
}
override fun onIceConnectionChange(newState: PeerConnection.IceConnectionState?) {
val state = newState ?: throw NullPointerException("unexpected null new state, what do?")
Timber.v { "onIceConnection new state: $newState" }
if (state == PeerConnection.IceConnectionState.CONNECTED) {
engine.iceState = IceState.CONNECTED
} else if (state == PeerConnection.IceConnectionState.FAILED) {
// when we publish tracks, some WebRTC versions will send out disconnected events periodically
engine.iceState = IceState.DISCONNECTED
engine.listener?.onDisconnect("Peer connection disconnected")
iceConnectionChangeListener?.invoke(newState)
}
override fun onOffer(sd: SessionDescription) {
client.sendOffer(sd)
}
override fun onStandardizedIceConnectionChange(newState: PeerConnection.IceConnectionState?) {
... ... @@ -41,7 +44,6 @@ class PublisherTransportObserver(
override fun onSelectedCandidatePairChanged(event: CandidatePairChangeEvent?) {
}
override fun onSignalingChange(p0: PeerConnection.SignalingState?) {
}
... ... @@ -60,7 +62,8 @@ class PublisherTransportObserver(
override fun onRemoveStream(p0: MediaStream?) {
}
override fun onDataChannel(p0: DataChannel?) {
override fun onDataChannel(dataChannel: DataChannel?) {
dataChannelListener?.invoke(dataChannel)
}
override fun onTrack(transceiver: RtpTransceiver?) {
... ... @@ -68,4 +71,5 @@ class PublisherTransportObserver(
override fun onAddTrack(p0: RtpReceiver?, p1: Array<out MediaStream>?) {
}
}
\ No newline at end of file
... ...
package io.livekit.android.room
import android.os.SystemClock
import com.github.ajalt.timberkt.Timber
import io.livekit.android.ConnectOptions
import io.livekit.android.dagger.InjectionNames
import io.livekit.android.room.track.DataPublishReliability
import io.livekit.android.room.track.Track
import io.livekit.android.room.track.TrackException
import io.livekit.android.room.track.TrackPublication
import io.livekit.android.room.util.*
import io.livekit.android.util.CloseableCoroutineScope
import io.livekit.android.util.Either
... ... @@ -15,6 +18,9 @@ import kotlinx.coroutines.launch
import livekit.LivekitModels
import livekit.LivekitRtc
import org.webrtc.*
import java.net.ConnectException
import java.nio.ByteBuffer
import java.util.concurrent.TimeUnit
import javax.inject.Inject
import javax.inject.Named
import javax.inject.Singleton
... ... @@ -33,7 +39,7 @@ constructor(
private val pctFactory: PeerConnectionTransport.Factory,
@Named(InjectionNames.DISPATCHER_IO) ioDispatcher: CoroutineDispatcher,
) : SignalClient.Listener, DataChannel.Observer {
var listener: Listener? = null
internal var listener: Listener? = null
internal var iceState: IceState = IceState.DISCONNECTED
set(value) {
val oldVal = field
... ... @@ -55,7 +61,8 @@ constructor(
Timber.d { "publisher ICE disconnected" }
listener?.onDisconnect("Peer connection disconnected")
}
else -> {}
else -> {
}
}
}
private var wsRetries: Int = 0
... ... @@ -64,25 +71,169 @@ constructor(
private var sessionUrl: String? = null
private var sessionToken: String? = null
private val publisherObserver = PublisherTransportObserver(this)
private val subscriberObserver = SubscriberTransportObserver(this)
private val publisherObserver = PublisherTransportObserver(this, client)
private val subscriberObserver = SubscriberTransportObserver(this, client)
internal lateinit var publisher: PeerConnectionTransport
private lateinit var subscriber: PeerConnectionTransport
internal var reliableDataChannel: DataChannel? = null
internal var lossyDataChannel: DataChannel? = null
private var reliableDataChannel: DataChannel? = null
private var reliableDataChannelSub: DataChannel? = null
private var lossyDataChannel: DataChannel? = null
private var lossyDataChannelSub: DataChannel? = null
private var isSubscriberPrimary = false
private var isClosed = true
private var hasPublished = false
private val coroutineScope = CloseableCoroutineScope(SupervisorJob() + ioDispatcher)
init {
client.listener = this
}
fun join(url: String, token: String, options: ConnectOptions?) {
suspend fun join(url: String, token: String, options: ConnectOptions?): LivekitRtc.JoinResponse {
sessionUrl = url
sessionToken = token
client.join(url, token, options)
val joinResponse = client.join(url, token, options)
isClosed = false
isSubscriberPrimary = joinResponse.subscriberPrimary
if (!this::publisher.isInitialized) {
configure(joinResponse)
}
// create offer
if (!this.isSubscriberPrimary) {
negotiate()
}
return joinResponse
}
private suspend fun configure(joinResponse: LivekitRtc.JoinResponse) {
if (this::publisher.isInitialized || this::subscriber.isInitialized) {
// already configured
return
}
// update ICE servers before creating PeerConnection
val iceServers = mutableListOf<PeerConnection.IceServer>()
for (serverInfo in joinResponse.iceServersList) {
val username = serverInfo.username ?: ""
val credential = serverInfo.credential ?: ""
iceServers.add(
PeerConnection.IceServer
.builder(serverInfo.urlsList)
.setUsername(username)
.setPassword(credential)
.createIceServer()
)
}
if (iceServers.isEmpty()) {
iceServers.addAll(SignalClient.DEFAULT_ICE_SERVERS)
}
joinResponse.iceServersList.forEach {
Timber.v { "username = \"${it.username}\"" }
Timber.v { "credential = \"${it.credential}\"" }
Timber.v { "urls: " }
it.urlsList.forEach {
Timber.v { " $it" }
}
}
// Setup peer connections
val rtcConfig = PeerConnection.RTCConfiguration(iceServers).apply {
sdpSemantics = PeerConnection.SdpSemantics.UNIFIED_PLAN
continualGatheringPolicy = PeerConnection.ContinualGatheringPolicy.GATHER_CONTINUALLY
enableDtlsSrtp = true
}
publisher = pctFactory.create(
rtcConfig,
publisherObserver,
publisherObserver,
)
subscriber = pctFactory.create(
rtcConfig,
subscriberObserver,
null,
)
val iceConnectionStateListener: (PeerConnection.IceConnectionState?) -> Unit = { newState ->
val state =
newState ?: throw NullPointerException("unexpected null new state, what do?")
Timber.v { "onIceConnection new state: $newState" }
if (state == PeerConnection.IceConnectionState.CONNECTED) {
iceState = IceState.CONNECTED
} else if (state == PeerConnection.IceConnectionState.FAILED) {
// when we publish tracks, some WebRTC versions will send out disconnected events periodically
iceState = IceState.DISCONNECTED
listener?.onDisconnect("Peer connection disconnected")
}
}
if (joinResponse.subscriberPrimary) {
// in subscriber primary mode, server side opens sub data channels.
publisherObserver.dataChannelListener = onDataChannel@{ dataChannel: DataChannel? ->
if (dataChannel == null) {
return@onDataChannel
}
when (dataChannel.label()) {
RELIABLE_DATA_CHANNEL_LABEL -> reliableDataChannelSub = dataChannel
LOSSY_DATA_CHANNEL_LABEL -> lossyDataChannelSub = dataChannel
else -> return@onDataChannel
}
dataChannel.registerObserver(this)
}
publisherObserver.iceConnectionChangeListener = iceConnectionStateListener
} else {
subscriberObserver.iceConnectionChangeListener = iceConnectionStateListener
}
// data channels
val reliableInit = DataChannel.Init()
reliableInit.ordered = true
reliableDataChannel = publisher.peerConnection.createDataChannel(
RELIABLE_DATA_CHANNEL_LABEL,
reliableInit
)
reliableDataChannel!!.registerObserver(this)
val lossyInit = DataChannel.Init()
lossyInit.ordered = true
lossyInit.maxRetransmits = 0
lossyDataChannel = publisher.peerConnection.createDataChannel(
LOSSY_DATA_CHANNEL_LABEL,
lossyInit
)
lossyDataChannel!!.registerObserver(this)
coroutineScope.launch {
val sdpOffer =
when (val outcome = publisher.peerConnection.createOffer(getPublisherOfferConstraints())) {
is Either.Left -> outcome.value
is Either.Right -> {
Timber.d { "error creating offer: ${outcome.value}" }
return@launch
}
}
when (val outcome = publisher.peerConnection.setLocalDescription(sdpOffer)) {
is Either.Right -> {
Timber.d { "error setting local description: ${outcome.value}" }
return@launch
}
}
client.sendOffer(sdpOffer)
}
}
suspend fun addTrack(cid: String, name: String, kind: LivekitModels.TrackType, dimensions: Track.Dimensions? = null): LivekitModels.TrackInfo {
suspend fun addTrack(
cid: String,
name: String,
kind: LivekitModels.TrackType,
dimensions: Track.Dimensions? = null
): LivekitModels.TrackInfo {
if (pendingTrackResolvers[cid] != null) {
throw TrackException.DuplicateTrackException("Track with same ID $cid has already been published!")
}
... ... @@ -108,7 +259,10 @@ constructor(
* reconnect Signal and PeerConnections
*/
internal fun reconnect() {
if (sessionUrl == null || sessionToken == null) {
val url = sessionUrl
val token = sessionToken
if (url == null || token == null) {
Timber.w { "couldn't reconnect, no url or no token" }
return
}
if (iceState == IceState.DISCONNECTED || wsRetries >= MAX_SIGNAL_RETRIES) {
... ... @@ -124,13 +278,34 @@ constructor(
}
coroutineScope.launch {
delay(startDelay)
val url = sessionUrl
val token = sessionToken
if (iceState != IceState.DISCONNECTED && url != null && token != null) {
val opts = ConnectOptions()
opts.reconnect = true
client.join(url, token, opts)
if (iceState == IceState.DISCONNECTED) {
Timber.e { "Ice is disconnected" }
return@launch
}
client.reconnect(url, token)
Timber.v { "reconnected, restarting ICE" }
wsRetries = 0
// trigger publisher reconnect
subscriber.restartingIce = true
// only restart publisher if it's needed
if (hasPublished) {
publisher.createAndSendOffer(
getPublisherOfferConstraints().apply {
with(mandatory){
add(
MediaConstraints.KeyValuePair(
MediaConstraintKeys.ICE_RESTART,
MediaConstraintKeys.TRUE
)
)
}
}
)
}
}
}
... ... @@ -140,7 +315,7 @@ constructor(
}
coroutineScope.launch {
val sdpOffer =
when (val outcome = publisher.peerConnection.createOffer(getOfferConstraints())) {
when (val outcome = publisher.peerConnection.createOffer(getPublisherOfferConstraints())) {
is Either.Left -> outcome.value
is Either.Right -> {
Timber.d { "error creating offer: ${outcome.value}" }
... ... @@ -160,20 +335,75 @@ constructor(
}
}
private fun getOfferConstraints(): MediaConstraints {
internal suspend fun sendData(dataPacket: LivekitModels.DataPacket) {
ensurePublisherConnected()
val buf = DataChannel.Buffer(
ByteBuffer.wrap(dataPacket.toByteArray()),
true,
)
val channel = when (dataPacket.kind) {
LivekitModels.DataPacket.Kind.RELIABLE -> reliableDataChannel
LivekitModels.DataPacket.Kind.LOSSY -> lossyDataChannel
else -> null
} ?: throw TrackException.PublishException("channel not established for ${dataPacket.kind.name}")
channel.send(buf)
}
private suspend fun ensurePublisherConnected(){
if (!isSubscriberPrimary) {
return
}
if (this.publisher.peerConnection.iceConnectionState() == PeerConnection.IceConnectionState.CONNECTED) {
return
}
// start negotiation
this.negotiate()
// wait until publisher ICE connected
val endTime = SystemClock.elapsedRealtime() + MAX_ICE_CONNECT_TIMEOUT_MS;
while (SystemClock.elapsedRealtime() < endTime) {
if (this.publisher.peerConnection.iceConnectionState() == PeerConnection.IceConnectionState.CONNECTED) {
return
}
delay(50)
}
throw ConnectException("could not establish publisher connection")
}
private fun getPublisherOfferConstraints(): MediaConstraints {
return MediaConstraints().apply {
with(mandatory) {
add(MediaConstraints.KeyValuePair("OfferToReceiveAudio", "false"))
add(MediaConstraints.KeyValuePair("OfferToReceiveVideo", "false"))
add(
MediaConstraints.KeyValuePair(
MediaConstraintKeys.OFFER_TO_RECV_AUDIO,
MediaConstraintKeys.FALSE
)
)
add(
MediaConstraints.KeyValuePair(
MediaConstraintKeys.OFFER_TO_RECV_VIDEO,
MediaConstraintKeys.FALSE
)
)
if (iceState == IceState.RECONNECTING) {
add(MediaConstraints.KeyValuePair("IceRestart", "true"))
add(
MediaConstraints.KeyValuePair(
MediaConstraintKeys.ICE_RESTART,
MediaConstraintKeys.TRUE
)
)
}
}
}
}
interface Listener {
fun onJoin(response: LivekitRtc.JoinResponse)
internal interface Listener {
fun onIceConnected()
fun onIceReconnected()
fun onAddTrack(track: MediaStreamTrack, streams: Array<out MediaStream>)
... ... @@ -190,6 +420,7 @@ constructor(
private const val LOSSY_DATA_CHANNEL_LABEL = "_lossy"
internal const val MAX_DATA_PACKET_SIZE = 15000
private const val MAX_SIGNAL_RETRIES = 5
private const val MAX_ICE_CONNECT_TIMEOUT_MS = 5000
internal val CONN_CONSTRAINTS = MediaConstraints().apply {
with(optional) {
... ... @@ -200,90 +431,6 @@ constructor(
//---------------------------------- SignalClient.Listener --------------------------------------//
override fun onJoin(info: LivekitRtc.JoinResponse) {
val iceServers = mutableListOf<PeerConnection.IceServer>()
for(serverInfo in info.iceServersList){
val username = serverInfo.username ?: ""
val credential = serverInfo.credential ?: ""
iceServers.add(
PeerConnection.IceServer
.builder(serverInfo.urlsList)
.setUsername(username)
.setPassword(credential)
.createIceServer()
)
}
if (iceServers.isEmpty()) {
iceServers.addAll(SignalClient.DEFAULT_ICE_SERVERS)
}
info.iceServersList.forEach {
Timber.e{ "username = \"${it.username}\""}
Timber.e{ "credential = \"${it.credential}\""}
Timber.e{ "urls: "}
it.urlsList.forEach{
Timber.e{" $it"}
}
}
val rtcConfig = PeerConnection.RTCConfiguration(iceServers).apply {
sdpSemantics = PeerConnection.SdpSemantics.UNIFIED_PLAN
continualGatheringPolicy = PeerConnection.ContinualGatheringPolicy.GATHER_CONTINUALLY
enableDtlsSrtp = true
}
publisher = pctFactory.create(rtcConfig, publisherObserver)
subscriber = pctFactory.create(rtcConfig, subscriberObserver)
val reliableInit = DataChannel.Init()
reliableInit.ordered = true
reliableDataChannel = publisher.peerConnection.createDataChannel(
RELIABLE_DATA_CHANNEL_LABEL,
reliableInit
)
reliableDataChannel!!.registerObserver(this)
val lossyInit = DataChannel.Init()
lossyInit.ordered = true
lossyInit.maxRetransmits = 1
lossyDataChannel = publisher.peerConnection.createDataChannel(
LOSSY_DATA_CHANNEL_LABEL,
lossyInit
)
lossyDataChannel!!.registerObserver(this)
coroutineScope.launch {
val sdpOffer =
when (val outcome = publisher.peerConnection.createOffer(getOfferConstraints())) {
is Either.Left -> outcome.value
is Either.Right -> {
Timber.d { "error creating offer: ${outcome.value}" }
return@launch
}
}
when (val outcome = publisher.peerConnection.setLocalDescription(sdpOffer)) {
is Either.Right -> {
Timber.d { "error setting local description: ${outcome.value}" }
return@launch
}
}
client.sendOffer(sdpOffer)
}
listener?.onJoin(info)
}
override fun onReconnected() {
Timber.v { "reconnected, restarting ICE" }
wsRetries = 0
// trigger ICE restart
iceState = IceState.RECONNECTING
publisher.prepareForIceRestart()
subscriber.prepareForIceRestart()
negotiate()
}
override fun onAnswer(sessionDescription: SessionDescription) {
Timber.v { "received server answer: ${sessionDescription.type}, ${publisher.peerConnection.signalingState()}" }
coroutineScope.launch {
... ...
... ... @@ -65,16 +65,32 @@ constructor(
get() = mutableActiveSpeakers
private var hasLostConnectivity: Boolean = false
private var connectContinuation: Continuation<Unit>? = null
suspend fun connect(url: String, token: String, options: ConnectOptions?) {
state = State.CONNECTING
engine.join(url, token, options)
val response = engine.join(url, token, options)
Timber.i { "Connected to server, server version: ${response.serverVersion}, client version: ${Version.CLIENT_VERSION}" }
sid = Sid(response.room.sid)
name = response.room.name
if (!response.hasParticipant()) {
listener?.onFailedToConnect(this, RoomException.ConnectException("server didn't return any participants"))
return
}
val lp = localParticipantFactory.create(response.participant)
lp.listener = this
localParticipant = lp
if (response.otherParticipantsList.isNotEmpty()) {
response.otherParticipantsList.forEach {
getOrCreateRemoteParticipant(it.sid, it)
}
}
val cm = context.getSystemService(Context.CONNECTIVITY_SERVICE) as ConnectivityManager
val networkRequest = NetworkRequest.Builder()
.addCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET)
.build()
cm.registerNetworkCallback(networkRequest, this)
return suspendCoroutine { connectContinuation = it }
}
fun disconnect() {
... ... @@ -240,36 +256,8 @@ constructor(
//----------------------------------- RTCEngine.Listener ------------------------------------//
/**
* @suppress
*/
override fun onJoin(response: LivekitRtc.JoinResponse) {
Timber.i { "Connected to server, server version: ${response.serverVersion}, client version: ${Version.CLIENT_VERSION}" }
sid = Sid(response.room.sid)
name = response.room.name
if (!response.hasParticipant()) {
listener?.onFailedToConnect(this, RoomException.ConnectException("server didn't return any participants"))
connectContinuation?.resume(Unit)
connectContinuation = null
return
}
val lp = localParticipantFactory.create(response.participant)
lp.listener = this
localParticipant = lp
if (response.otherParticipantsList.isNotEmpty()) {
response.otherParticipantsList.forEach {
getOrCreateRemoteParticipant(it.sid, it)
}
}
}
override fun onIceConnected() {
state = State.CONNECTED
connectContinuation?.resume(Unit)
connectContinuation = null
}
override fun onIceReconnected() {
... ...
... ... @@ -6,7 +6,10 @@ import io.livekit.android.ConnectOptions
import io.livekit.android.Version
import io.livekit.android.dagger.InjectionNames
import io.livekit.android.room.track.Track
import io.livekit.android.util.Either
import io.livekit.android.util.safe
import kotlinx.coroutines.CancellableContinuation
import kotlinx.coroutines.suspendCancellableCoroutine
import kotlinx.serialization.decodeFromString
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.Json
... ... @@ -20,6 +23,8 @@ import org.webrtc.PeerConnection
import org.webrtc.SessionDescription
import javax.inject.Inject
import javax.inject.Named
import kotlin.coroutines.Continuation
import kotlin.coroutines.suspendCoroutine
/**
* SignalClient to LiveKit WS servers
... ... @@ -32,6 +37,7 @@ constructor(
private val fromJsonProtobuf: JsonFormat.Parser,
private val toJsonProtobuf: JsonFormat.Printer,
private val json: Json,
private val okHttpClient: OkHttpClient,
@Named(InjectionNames.SIGNAL_JSON_ENABLED)
private val useJson: Boolean,
) : WebSocketListener() {
... ... @@ -42,11 +48,31 @@ constructor(
var listener: Listener? = null
private var lastUrl: String? = null
fun join(
private var joinContinuation: CancellableContinuation<Either<LivekitRtc.JoinResponse, Unit>>? = null
suspend fun join(
url: String,
token: String,
options: ConnectOptions?,
) {
) : LivekitRtc.JoinResponse {
val joinResponse = connect(url,token, options)
return (joinResponse as Either.Left).value
}
suspend fun reconnect(url: String, token: String){
connect(
url,
token,
ConnectOptions()
.apply { reconnect = true }
)
}
suspend fun connect(
url: String,
token: String,
options: ConnectOptions?
) : Either<LivekitRtc.JoinResponse, Unit> {
var wsUrlString = "$url/rtc" +
"?protocol=$PROTOCOL_VERSION" +
"&access_token=$token" +
... ... @@ -70,12 +96,22 @@ constructor(
isConnected = false
currentWs?.cancel()
currentWs = null
joinContinuation?.cancel()
joinContinuation = null
lastUrl = wsUrlString
val request = Request.Builder()
.url(wsUrlString)
.build()
currentWs = websocketFactory.newWebSocket(request, this)
return suspendCancellableCoroutine {
// Wait for join response through WebSocketListener
joinContinuation = it
}
}
//--------------------------------- WebSocket Listener --------------------------------------//
... ... @@ -83,7 +119,7 @@ constructor(
if (isReconnecting) {
isReconnecting = false
isConnected = true
listener?.onReconnected()
joinContinuation?.resumeWith(Result.success(Either.Right(Unit)))
}
}
... ... @@ -123,7 +159,7 @@ constructor(
substring(2).
replaceFirst("/rtc?", "/rtc/validate?")
val request = Request.Builder().url(validationUrl).build()
val resp = OkHttpClient().newCall(request).execute()
val resp = okHttpClient.newCall(request).execute()
if (!resp.isSuccessful) {
reason = resp.body?.string()
}
... ... @@ -290,7 +326,7 @@ constructor(
// Only handle joins if not connected.
if (response.hasJoin()) {
isConnected = true
listener?.onJoin(response.join)
joinContinuation?.resumeWith(Result.success(Either.Left(response.join)))
} else {
Timber.e { "Received response while not connected. ${toJsonProtobuf.print(response)}" }
}
... ... @@ -351,8 +387,6 @@ constructor(
}
interface Listener {
fun onJoin(info: LivekitRtc.JoinResponse)
fun onReconnected()
fun onAnswer(sessionDescription: SessionDescription)
fun onOffer(sessionDescription: SessionDescription)
fun onTrickle(candidate: IceCandidate, target: LivekitRtc.SignalTarget)
... ...
... ... @@ -8,13 +8,15 @@ import org.webrtc.*
* @suppress
*/
class SubscriberTransportObserver(
private val engine: RTCEngine
private val engine: RTCEngine,
private val client: SignalClient,
) : PeerConnection.Observer {
var iceConnectionChangeListener: ((PeerConnection.IceConnectionState?) -> Unit)? = null
override fun onIceCandidate(candidate: IceCandidate) {
Timber.v { "onIceCandidate: $candidate" }
engine.client.sendCandidate(candidate, LivekitRtc.SignalTarget.SUBSCRIBER)
client.sendCandidate(candidate, LivekitRtc.SignalTarget.SUBSCRIBER)
}
override fun onAddTrack(receiver: RtpReceiver, streams: Array<out MediaStream>) {
... ... @@ -48,8 +50,9 @@ class SubscriberTransportObserver(
override fun onSignalingChange(p0: PeerConnection.SignalingState?) {
}
override fun onIceConnectionChange(p0: PeerConnection.IceConnectionState?) {
Timber.v { "onIceConnection new state: $p0" }
override fun onIceConnectionChange(newState: PeerConnection.IceConnectionState?) {
Timber.v { "onIceConnection new state: $newState" }
iceConnectionChangeListener?.invoke(newState)
}
override fun onIceConnectionReceivingChange(p0: Boolean) {
... ...
... ... @@ -152,7 +152,8 @@ internal constructor(
* @param reliability for delivery guarantee, use RELIABLE. for fastest delivery without guarantee, use LOSSY
* @param destination list of participant SIDs to deliver the payload, null to deliver to everyone
*/
fun publishData(data: ByteArray, reliability: DataPublishReliability, destination: List<String>?) {
@Suppress("unused")
suspend fun publishData(data: ByteArray, reliability: DataPublishReliability, destination: List<String>?) {
if (data.size > RTCEngine.MAX_DATA_PACKET_SIZE) {
throw IllegalArgumentException("cannot publish data larger than " + RTCEngine.MAX_DATA_PACKET_SIZE)
}
... ... @@ -161,11 +162,6 @@ internal constructor(
DataPublishReliability.RELIABLE -> LivekitModels.DataPacket.Kind.RELIABLE
DataPublishReliability.LOSSY -> LivekitModels.DataPacket.Kind.LOSSY
}
val channel = when (reliability) {
DataPublishReliability.RELIABLE -> engine.reliableDataChannel
DataPublishReliability.LOSSY -> engine.lossyDataChannel
} ?: throw TrackException.PublishException("data channel not established")
val packetBuilder = LivekitModels.UserPacket.newBuilder().
setPayload(ByteString.copyFrom(data)).
setParticipantSid(sid)
... ... @@ -176,12 +172,8 @@ internal constructor(
setUser(packetBuilder).
setKind(kind).
build()
val buf = DataChannel.Buffer(
ByteBuffer.wrap(dataPacket.toByteArray()),
true,
)
channel.send(buf)
engine.sendData(dataPacket)
}
override fun updateFromInfo(info: LivekitModels.ParticipantInfo) {
... ...
package io.livekit.android.room.util
import org.webrtc.MediaConstraints
object MediaConstraintKeys {
const val OFFER_TO_RECV_AUDIO = "OfferToReceiveAudio"
const val OFFER_TO_RECV_VIDEO = "OfferToReceiveVideo"
const val ICE_RESTART = "IceRestart"
const val FALSE = "false"
const val TRUE = "true"
}
fun MediaConstraints.findConstraint(key: String): String? {
return mandatory.firstOrNull { it.key == key }?.value
?: optional.firstOrNull { it.key == key }?.value
}
\ No newline at end of file
... ...
package io.livekit.android.util
import kotlinx.coroutines.*
fun <T, R> debounce(
waitMs: Long = 300L,
coroutineScope: CoroutineScope,
destinationFunction: suspend (T) -> R
): (T) -> Unit {
var debounceJob: Deferred<R>? = null
return { param: T ->
debounceJob?.cancel()
debounceJob = coroutineScope.async {
delay(waitMs)
return@async destinationFunction(param)
}
}
}
\ No newline at end of file
... ...
package io.livekit.android.util
import com.google.protobuf.MessageLite
import okio.ByteString
import okio.ByteString.Companion.toByteString
fun MessageLite.toOkioByteString(): ByteString {
val byteArray = toByteArray()
return byteArray.toByteString(0, byteArray.size)
}
\ No newline at end of file
... ...
... ... @@ -6,11 +6,9 @@ import io.livekit.android.room.mock.MockEglBase
import io.livekit.android.room.participant.LocalParticipant
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.test.TestCoroutineScope
import kotlinx.coroutines.withTimeoutOrNull
import kotlinx.coroutines.test.runBlockingTest
import livekit.LivekitModels
import org.junit.Assert
import org.junit.Before
import org.junit.Rule
import org.junit.Test
... ... @@ -72,12 +70,8 @@ class RoomTest {
)
}
room.onIceConnected()
runBlocking {
Assert.assertNotNull(
withTimeoutOrNull(1000) {
runBlockingTest {
job.join()
}
)
}
}
}
\ No newline at end of file
... ...
package io.livekit.android.room
import com.google.protobuf.util.JsonFormat
import io.livekit.android.util.toOkioByteString
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.async
import kotlinx.coroutines.test.TestCoroutineScope
import kotlinx.coroutines.test.runBlockingTest
import kotlinx.serialization.json.Json
import livekit.LivekitRtc
import okhttp3.*
import okio.ByteString.Companion.toByteString
import org.junit.Assert
import org.junit.Before
import org.junit.Test
import org.mockito.Mockito
import org.mockito.kotlin.verify
@ExperimentalCoroutinesApi
class SignalClientTest {
lateinit var wsFactory: MockWebsocketFactory
lateinit var client: SignalClient
lateinit var listener: SignalClient.Listener
lateinit var okHttpClient: OkHttpClient
class MockWebsocketFactory : WebSocket.Factory {
lateinit var ws: WebSocket
... ... @@ -29,35 +35,64 @@ class SignalClientTest {
@Before
fun setup() {
wsFactory = MockWebsocketFactory()
okHttpClient = Mockito.mock(OkHttpClient::class.java)
client = SignalClient(
wsFactory,
JsonFormat.parser(),
JsonFormat.printer(),
Json,
useJson = false
useJson = false,
okHttpClient = okHttpClient,
)
listener = Mockito.mock(SignalClient.Listener::class.java)
client.listener = listener
}
fun join() {
client.join("http://www.example.com", "", null)
private fun createOpenResponse(request: Request): Response {
return Response.Builder()
.request(request)
.code(200)
.protocol(Protocol.HTTP_2)
.message("")
.build()
}
@Test
fun joinAndResponse() {
join()
val job = TestCoroutineScope().async {
client.join("http://www.example.com", "", null)
}
client.onOpen(
wsFactory.ws,
Response.Builder()
.request(wsFactory.request)
.code(200)
.protocol(Protocol.HTTP_2)
.message("")
.build()
createOpenResponse(wsFactory.request)
)
client.onMessage(wsFactory.ws, JOIN.toOkioByteString())
val response = with(LivekitRtc.SignalResponse.newBuilder()) {
runBlockingTest {
val response = job.await()
Assert.assertEquals(response, JOIN.join)
}
}
@Test
fun reconnect() {
val job = TestCoroutineScope().async {
client.reconnect("http://www.example.com", "")
}
client.onOpen(
wsFactory.ws,
createOpenResponse(wsFactory.request)
)
runBlockingTest {
job.await()
}
}
// mock data
companion object {
private val EXAMPLE_URL = "http://www.example.com"
private val JOIN = with(LivekitRtc.SignalResponse.newBuilder()) {
join = with(joinBuilder) {
room = with(roomBuilder) {
name = "roomname"
... ... @@ -68,11 +103,5 @@ class SignalClientTest {
}
build()
}
val byteArray = response.toByteArray()
val byteString = byteArray.toByteString(0, byteArray.size)
client.onMessage(wsFactory.ws, byteString)
verify(listener).onJoin(response.join)
}
}
\ No newline at end of file
... ...