David Zhao

Handle network changes with reconnection, v0.6.0

... ... @@ -23,7 +23,7 @@ kotlin.code.style=official
###############################################################
GROUP=io.livekit
VERSION_NAME=0.5.1
VERSION_NAME=0.6.0
POM_DESCRIPTION=Android SDK for WebRTC communication
... ...
<manifest package="io.livekit.android" />
<manifest package="io.livekit.android"
xmlns:android="http://schemas.android.com/apk/res/android">
<uses-permission android:name="android.permission.ACCESS_NETWORK_STATE" />
<uses-permission android:name="android.permission.INTERNET" />
</manifest>
... ...
... ... @@ -19,13 +19,13 @@ class LiveKit {
options: ConnectOptions,
listener: RoomListener?
): Room {
val ctx = appContext.applicationContext
val component = DaggerLiveKitComponent
.factory()
.create(appContext.applicationContext)
.create(ctx)
val room = component.roomFactory()
.create(options)
.create(options, ctx)
room.listener = listener
room.connect(url, token)
... ...
... ... @@ -23,12 +23,11 @@ class PublisherTransportObserver(
override fun onIceConnectionChange(newState: PeerConnection.IceConnectionState?) {
val state = newState ?: throw NullPointerException("unexpected null new state, what do?")
Timber.v { "onIceConnection new state: $newState" }
if (state == PeerConnection.IceConnectionState.CONNECTED && !engine.iceConnected) {
engine.iceConnected = true
engine.listener?.onICEConnected()
if (state == PeerConnection.IceConnectionState.CONNECTED) {
engine.iceState = IceState.CONNECTED
} else if (state == PeerConnection.IceConnectionState.FAILED) {
// when we publish tracks, some WebRTC versions will send out disconnected events periodically
engine.iceConnected = false
engine.iceState = IceState.DISCONNECTED
engine.listener?.onDisconnect("Peer connection disconnected")
}
}
... ...
... ... @@ -22,6 +22,7 @@ import javax.inject.Inject
import javax.inject.Named
/**
* SignalClient to LiveKit WS servers
* @suppress
*/
class RTCClient
... ... @@ -37,24 +38,39 @@ constructor(
var isConnected = false
private set
private var currentWs: WebSocket? = null
private var isReconnecting: Boolean = false
var listener: Listener? = null
fun join(
url: String,
token: String,
reconnect: Boolean = false
) {
val wsUrlString = "$url/rtc?protocol=$PROTOCOL_VERSION&access_token=$token"
var wsUrlString = "$url/rtc?protocol=$PROTOCOL_VERSION&access_token=$token"
if (reconnect) {
wsUrlString += "&reconnect=1"
}
Timber.i { "connecting to $wsUrlString" }
isReconnecting = reconnect
isConnected = false
currentWs?.cancel()
val request = Request.Builder()
.url(wsUrlString)
.build()
currentWs = websocketFactory.newWebSocket(request, this)
}
//--------------------------------- WebSocket Listener --------------------------------------//
override fun onOpen(webSocket: WebSocket, response: Response) {
Timber.v { response.message }
super.onOpen(webSocket, response)
if (isReconnecting) {
isReconnecting = false
isConnected = true
listener?.onReconnected()
}
}
override fun onMessage(webSocket: WebSocket, text: String) {
... ... @@ -91,6 +107,7 @@ constructor(
super.onFailure(webSocket, t, response)
}
//------------------------------- End WebSocket Listener ------------------------------------//
fun fromProtoSessionDescription(sd: LivekitRtc.SessionDescription): SessionDescription {
val rtcSdpType = when (sd.type) {
... ... @@ -292,6 +309,7 @@ constructor(
interface Listener {
fun onJoin(info: LivekitRtc.JoinResponse)
fun onReconnected()
fun onAnswer(sessionDescription: SessionDescription)
fun onOffer(sessionDescription: SessionDescription)
fun onTrickle(candidate: IceCandidate, target: LivekitRtc.SignalTarget)
... ...
... ... @@ -8,6 +8,7 @@ import io.livekit.android.util.CloseableCoroutineScope
import io.livekit.android.util.Either
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import livekit.LivekitModels
import livekit.LivekitRtc
... ... @@ -30,12 +31,36 @@ constructor(
private val pctFactory: PeerConnectionTransport.Factory,
@Named(InjectionNames.DISPATCHER_IO) ioDispatcher: CoroutineDispatcher,
) : RTCClient.Listener, DataChannel.Observer {
var listener: Listener? = null
var rtcConnected: Boolean = false
var iceConnected: Boolean = false
internal var iceState: IceState = IceState.DISCONNECTED
set(value) {
val oldVal = field
field = value
if (value == oldVal) {
return
}
when (value) {
IceState.CONNECTED -> {
if (oldVal == IceState.DISCONNECTED) {
Timber.d { "publisher ICE connected" }
listener?.onIceConnected()
} else if (oldVal == IceState.RECONNECTING) {
Timber.d { "publisher ICE reconnected" }
listener?.onIceReconnected()
}
}
IceState.DISCONNECTED -> {
Timber.d { "publisher ICE disconnected" }
listener?.onDisconnect("Peer connection disconnected")
}
else -> {}
}
}
private var wsRetries: Int = 0
private val pendingTrackResolvers: MutableMap<String, Continuation<LivekitModels.TrackInfo>> =
mutableMapOf()
private var sessionUrl: String? = null
private var sessionToken: String? = null
private val publisherObserver = PublisherTransportObserver(this)
private val subscriberObserver = SubscriberTransportObserver(this)
... ... @@ -50,6 +75,8 @@ constructor(
}
fun join(url: String, token: String) {
sessionUrl = url
sessionToken = token
client.join(url, token)
}
... ... @@ -75,13 +102,39 @@ constructor(
client.close()
}
fun negotiate() {
/**
* reconnect Signal and PeerConnections
*/
internal fun reconnect() {
if (sessionUrl == null || sessionToken == null) {
return
}
if (iceState == IceState.DISCONNECTED || wsRetries >= MAX_SIGNAL_RETRIES) {
Timber.w { "could not connect to signal after max attempts, giving up" }
close()
listener?.onDisconnect("could not reconnect after limit")
return
}
var startDelay = wsRetries.toLong() * wsRetries * 500
if (startDelay > 5000) {
startDelay = 5000
}
coroutineScope.launch {
delay(startDelay)
if (iceState != IceState.DISCONNECTED && sessionUrl != null && sessionToken != null) {
client.join(sessionUrl!!, sessionToken!!, true)
}
}
}
internal fun negotiate() {
if (!client.isConnected) {
return
}
coroutineScope.launch {
val sdpOffer =
when (val outcome = publisher.peerConnection.createOffer(OFFER_CONSTRAINTS)) {
when (val outcome = publisher.peerConnection.createOffer(getOfferConstraints())) {
is Either.Left -> outcome.value
is Either.Right -> {
Timber.d { "error creating offer: ${outcome.value}" }
... ... @@ -101,16 +154,23 @@ constructor(
}
}
private fun onRTCConnected() {
Timber.v { "RTC Connected" }
rtcConnected = true
private fun getOfferConstraints(): MediaConstraints {
return MediaConstraints().apply {
with(mandatory) {
add(MediaConstraints.KeyValuePair("OfferToReceiveAudio", "false"))
add(MediaConstraints.KeyValuePair("OfferToReceiveVideo", "false"))
if (iceState == IceState.RECONNECTING) {
add(MediaConstraints.KeyValuePair("IceRestart", "true"))
}
}
}
}
interface Listener {
fun onJoin(response: LivekitRtc.JoinResponse)
fun onICEConnected()
fun onIceConnected()
fun onIceReconnected()
fun onAddTrack(track: MediaStreamTrack, streams: Array<out MediaStream>)
// fun onPublishLocalTrack(cid: String, track: LivekitModels.TrackInfo)
fun onUpdateParticipants(updates: List<LivekitModels.ParticipantInfo>)
fun onUpdateSpeakers(speakers: List<LivekitRtc.SpeakerInfo>)
fun onDisconnect(reason: String)
... ... @@ -122,15 +182,7 @@ constructor(
private const val RELIABLE_DATA_CHANNEL_LABEL = "_reliable"
private const val LOSSY_DATA_CHANNEL_LABEL = "_lossy"
internal const val MAX_DATA_PACKET_SIZE = 15000
private val OFFER_CONSTRAINTS = MediaConstraints().apply {
with(mandatory) {
add(MediaConstraints.KeyValuePair("OfferToReceiveAudio", "false"))
add(MediaConstraints.KeyValuePair("OfferToReceiveVideo", "false"))
}
}
private val MEDIA_CONSTRAINTS = MediaConstraints()
private const val MAX_SIGNAL_RETRIES = 5
internal val CONN_CONSTRAINTS = MediaConstraints().apply {
with(optional) {
... ... @@ -155,7 +207,7 @@ constructor(
)
}
if(iceServers.isEmpty()){
if (iceServers.isEmpty()) {
iceServers.addAll(RTCClient.DEFAULT_ICE_SERVERS)
}
info.iceServersList.forEach {
... ... @@ -192,7 +244,7 @@ constructor(
coroutineScope.launch {
val sdpOffer =
when (val outcome = publisher.peerConnection.createOffer(OFFER_CONSTRAINTS)) {
when (val outcome = publisher.peerConnection.createOffer(getOfferConstraints())) {
is Either.Left -> outcome.value
is Either.Right -> {
Timber.d { "error creating offer: ${outcome.value}" }
... ... @@ -212,14 +264,25 @@ constructor(
listener?.onJoin(info)
}
override fun onReconnected() {
Timber.v { "reconnected, restarting ICE" }
wsRetries = 0
// trigger ICE restart
iceState = IceState.RECONNECTING
negotiate()
}
override fun onAnswer(sessionDescription: SessionDescription) {
Timber.v { "received server answer: ${sessionDescription.type}, ${publisher.peerConnection.signalingState()}" }
coroutineScope.launch {
Timber.i { sessionDescription.toString() }
when (val outcome = publisher.setRemoteDescription(sessionDescription)) {
is Either.Left -> {
if (!rtcConnected) {
onRTCConnected()
// when reconnecting, ICE might not have disconnected and won't trigger
// our connected callback, so we'll take a shortcut and set it to active
if (iceState == IceState.RECONNECTING) {
iceState = IceState.CONNECTED
}
}
is Either.Right -> {
... ... @@ -243,7 +306,7 @@ constructor(
}
val answer = run {
when (val outcome = subscriber.peerConnection.createAnswer(OFFER_CONSTRAINTS)) {
when (val outcome = subscriber.peerConnection.createAnswer(MediaConstraints())) {
is Either.Left -> outcome.value
is Either.Right -> {
Timber.e { "error creating answer: ${outcome.value}" }
... ... @@ -275,11 +338,10 @@ constructor(
}
override fun onLocalTrackPublished(response: LivekitRtc.TrackPublishedResponse) {
val signalCid = response.cid ?: run {
val cid = response.cid ?: run {
Timber.e { "local track published with null cid?" }
return
}
val cid = signalCid
val track = response.track
if (track == null) {
... ... @@ -293,7 +355,6 @@ constructor(
return
}
cont.resume(response.track)
// listener?.onPublishLocalTrack(cid, track)
}
override fun onParticipantUpdate(updates: List<LivekitModels.ParticipantInfo>) {
... ... @@ -346,3 +407,9 @@ constructor(
}
}
}
internal enum class IceState {
DISCONNECTED,
RECONNECTING,
CONNECTED,
}
\ No newline at end of file
... ...
package io.livekit.android.room
import android.content.Context
import android.net.ConnectivityManager
import android.net.Network
import android.net.NetworkCapabilities
import android.net.NetworkRequest
import com.github.ajalt.timberkt.Timber
import dagger.assisted.Assisted
import dagger.assisted.AssistedFactory
... ... @@ -22,10 +27,11 @@ class Room
@AssistedInject
constructor(
@Assisted private val connectOptions: ConnectOptions,
@Assisted private val context: Context,
private val engine: RTCEngine,
private val eglBase: EglBase,
private val localParticipantFactory: LocalParticipant.Factory
) : RTCEngine.Listener, ParticipantListener {
) : RTCEngine.Listener, ParticipantListener, ConnectivityManager.NetworkCallback() {
init {
engine.listener = this
}
... ... @@ -57,10 +63,16 @@ constructor(
val activeSpeakers: List<Participant>
get() = mutableActiveSpeakers
private var hasLostConnectivity: Boolean = false
private var connectContinuation: Continuation<Unit>? = null
suspend fun connect(url: String, token: String) {
state = State.CONNECTING
engine.join(url, token)
val cm = context.getSystemService(Context.CONNECTIVITY_SERVICE) as ConnectivityManager
val networkRequest = NetworkRequest.Builder()
.addCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET)
.build()
cm.registerNetworkCallback(networkRequest, this)
return suspendCoroutine { connectContinuation = it }
}
... ... @@ -136,7 +148,18 @@ constructor(
listener?.onActiveSpeakersChanged(speakers, this)
}
private fun reconnect() {
if (state == State.RECONNECTING) {
return
}
state = State.RECONNECTING
engine.reconnect()
listener?.onReconnecting(this)
}
private fun handleDisconnect() {
val cm = context.getSystemService(Context.CONNECTIVITY_SERVICE) as ConnectivityManager
cm.unregisterNetworkCallback(this)
for (pub in localParticipant.tracks.values) {
pub.track?.stop()
}
... ... @@ -149,6 +172,7 @@ constructor(
engine.close()
state = State.DISCONNECTED
listener?.onDisconnect(this, null)
listener = null
}
/**
... ... @@ -156,9 +180,33 @@ constructor(
*/
@AssistedFactory
interface Factory {
fun create(connectOptions: ConnectOptions): Room
fun create(connectOptions: ConnectOptions, context: Context): Room
}
//------------------------------------- NetworkCallback -------------------------------------//
/**
* @suppress
*/
override fun onLost(network: Network) {
// lost connection, flip to reconnecting
hasLostConnectivity = true
}
/**
* @suppress
*/
override fun onAvailable(network: Network) {
// only actually reconnect after connection is re-established
if (!hasLostConnectivity) {
return
}
Timber.i { "network connection available, reconnecting" }
reconnect()
hasLostConnectivity = false
}
//----------------------------------- RTCEngine.Listener ------------------------------------//
/**
* @suppress
... ... @@ -186,12 +234,17 @@ constructor(
}
}
override fun onICEConnected() {
override fun onIceConnected() {
state = State.CONNECTED
connectContinuation?.resume(Unit)
connectContinuation = null
}
override fun onIceReconnected() {
state = State.CONNECTED
listener?.onReconnected(this)
}
/**
* @suppress
*/
... ... @@ -346,6 +399,17 @@ constructor(
*/
interface RoomListener {
/**
* A network change has been detected and LiveKit attempts to reconnect to the room
* When reconnect attempts succeed, the room state will be kept, including tracks that are subscribed/published
*/
fun onReconnecting(room: Room) {}
/**
* The reconnect attempt had been successful
*/
fun onReconnected(room: Room) {}
/**
* Disconnected from room
*/
fun onDisconnect(room: Room, error: Exception?) {}
... ...