davidliu
Committed by GitHub

Signal local audio track feature updates (#456)

* Signal local audio track feature updates

* Fix npe

* Fix custom audio processing factory
... ... @@ -16,28 +16,62 @@
package io.livekit.android.audio
import io.livekit.android.util.FlowObservable
/**
* Interface for controlling external audio processing.
*/
interface AudioProcessingController {
    /**
     * The audio processor to be used for capture post processing.
     *
     * Changes can be observed via [io.livekit.android.util.flow].
     */
    @FlowObservable
    @get:FlowObservable
    var capturePostProcessor: AudioProcessorInterface?

    /**
     * The audio processor to be used for render pre processing.
     *
     * Changes can be observed via [io.livekit.android.util.flow].
     */
    @FlowObservable
    @get:FlowObservable
    var renderPreProcessor: AudioProcessorInterface?

    /**
     * Whether to bypass the render pre processing.
     */
    @FlowObservable
    @get:FlowObservable
    var bypassRenderPreProcessing: Boolean

    /**
     * Whether to bypass the capture post processing.
     */
    @FlowObservable
    @get:FlowObservable
    var bypassCapturePostProcessing: Boolean

    /**
     * Set the audio processor to be used for capture post processing.
     */
    @Deprecated("Use the capturePostProcessor variable directly instead")
    fun setCapturePostProcessing(processing: AudioProcessorInterface?)

    /**
     * Set whether to bypass the capture post processing.
     */
    @Deprecated("Use the bypassCapturePostProcessing variable directly instead")
    fun setBypassForCapturePostProcessing(bypass: Boolean)

    /**
     * Set the audio processor to be used for render pre processing.
     */
    @Deprecated("Use the renderPreProcessor variable directly instead")
    fun setRenderPreProcessing(processing: AudioProcessorInterface?)

    /**
     * Set whether to bypass the render pre processing.
     */
    @Deprecated("Use the bypassRenderPreProcessing variable directly instead")
    fun setBypassForRenderPreProcessing(bypass: Boolean)
}
... ...
... ... @@ -41,12 +41,3 @@ data class AudioProcessorOptions(
*/
val renderPreBypass: Boolean = false,
)
/**
 * Forwards the connection credentials to any configured processors that
 * support authentication ([AuthedAudioProcessorInterface]).
 */
internal fun AudioProcessorOptions.authenticateProcessors(url: String, token: String) {
    (capturePostProcessor as? AuthedAudioProcessorInterface)?.authenticate(url, token)
    (renderPreProcessor as? AuthedAudioProcessorInterface)?.authenticate(url, token)
}
... ...
... ... @@ -55,6 +55,7 @@ import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import kotlinx.coroutines.yield
import livekit.LivekitModels
import livekit.LivekitModels.AudioTrackFeature
import livekit.LivekitRtc
import livekit.LivekitRtc.JoinResponse
import livekit.LivekitRtc.ReconnectResponse
... ... @@ -348,6 +349,10 @@ internal constructor(
client.sendMuteTrack(sid, muted)
}
/**
 * Signals the server with the current set of features for a local audio track.
 *
 * @param sid the sid of the track to update
 * @param features the complete set of features currently active on the track
 */
fun updateLocalAudioTrack(sid: String, features: Collection<AudioTrackFeature>) {
client.sendUpdateLocalAudioTrack(sid, features)
}
fun close(reason: String = "Normal Closure") {
if (isClosed) {
return
... ...
... ... @@ -29,22 +29,34 @@ import io.livekit.android.util.CloseableCoroutineScope
import io.livekit.android.util.Either
import io.livekit.android.util.LKLog
import io.livekit.android.webrtc.toProtoSessionDescription
import kotlinx.coroutines.*
import kotlinx.coroutines.CancellableContinuation
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.Job
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.launch
import kotlinx.coroutines.suspendCancellableCoroutine
import kotlinx.serialization.decodeFromString
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.Json
import livekit.LivekitModels
import livekit.LivekitModels.AudioTrackFeature
import livekit.LivekitRtc
import livekit.LivekitRtc.JoinResponse
import livekit.LivekitRtc.ReconnectResponse
import livekit.org.webrtc.IceCandidate
import livekit.org.webrtc.PeerConnection
import livekit.org.webrtc.SessionDescription
import okhttp3.*
import okhttp3.OkHttpClient
import okhttp3.Request
import okhttp3.Response
import okhttp3.WebSocket
import okhttp3.WebSocketListener
import okio.ByteString
import okio.ByteString.Companion.toByteString
import java.util.*
import java.util.Date
import javax.inject.Inject
import javax.inject.Named
import javax.inject.Singleton
... ... @@ -552,6 +564,19 @@ constructor(
return time
}
/**
 * Sends an UpdateLocalAudioTrack signal request reporting the track's
 * currently active features.
 *
 * @param trackSid the sid of the track to update
 * @param features the complete set of features active on the track
 */
fun sendUpdateLocalAudioTrack(trackSid: String, features: Collection<AudioTrackFeature>) {
    val updateAudioTrackMessage = LivekitRtc.UpdateLocalAudioTrack.newBuilder()
        .setTrackSid(trackSid)
        .addAllFeatures(features)
        .build()
    val request = LivekitRtc.SignalRequest.newBuilder()
        .setUpdateAudioTrack(updateAudioTrackMessage)
        .build()
    sendRequest(request)
}
private fun sendRequest(request: LivekitRtc.SignalRequest) {
val skipQueue = skipQueueTypes.contains(request.messageCase)
... ...
... ... @@ -47,8 +47,10 @@ import io.livekit.android.room.track.VideoCodec
import io.livekit.android.room.track.VideoEncoding
import io.livekit.android.room.util.EncodingUtils
import io.livekit.android.util.LKLog
import io.livekit.android.util.flow
import io.livekit.android.webrtc.sortVideoCodecPreferences
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.Job
import kotlinx.coroutines.launch
import livekit.LivekitModels
import livekit.LivekitRtc
... ... @@ -76,6 +78,7 @@ internal constructor(
private val eglBase: EglBase,
private val screencastVideoTrackFactory: LocalScreencastVideoTrack.Factory,
private val videoTrackFactory: LocalVideoTrack.Factory,
private val audioTrackFactory: LocalAudioTrack.Factory,
private val defaultsManager: DefaultsManager,
@Named(InjectionNames.DISPATCHER_DEFAULT)
coroutineDispatcher: CoroutineDispatcher,
... ... @@ -94,6 +97,8 @@ internal constructor(
.mapNotNull { it as? LocalTrackPublication }
.toList()
private val jobs = mutableMapOf<Any, Job>()
/**
* Creates an audio track, recording audio through the microphone with the given [options].
*
... ... @@ -103,7 +108,7 @@ internal constructor(
name: String = "",
options: LocalAudioTrackOptions = audioTrackCaptureDefaults,
): LocalAudioTrack {
return LocalAudioTrack.createTrack(context, peerConnectionFactory, options, name)
return LocalAudioTrack.createTrack(context, peerConnectionFactory, options, audioTrackFactory, name)
}
/**
... ... @@ -295,7 +300,7 @@ internal constructor(
}
},
)
publishTrackImpl(
val publication = publishTrackImpl(
track = track,
options = options,
requestConfig = {
... ... @@ -306,6 +311,15 @@ internal constructor(
encodings = encodings,
publishListener = publishListener,
)
if (publication != null) {
val job = scope.launch {
track::features.flow.collect {
engine.updateLocalAudioTrack(publication.sid, it)
}
}
jobs[publication] = job
}
}
/**
... ... @@ -379,14 +393,14 @@ internal constructor(
requestConfig: AddTrackRequest.Builder.() -> Unit,
encodings: List<RtpParameters.Encoding> = emptyList(),
publishListener: PublishListener? = null,
): Boolean {
): LocalTrackPublication? {
@Suppress("NAME_SHADOWING") var options = options
@Suppress("NAME_SHADOWING") var encodings = encodings
if (localTrackPublications.any { it.track == track }) {
publishListener?.onPublishFailure(TrackException.PublishException("Track has already been published"))
return false
return null
}
val cid = track.rtcTrack.id()
... ... @@ -435,7 +449,7 @@ internal constructor(
if (transceiver == null) {
publishListener?.onPublishFailure(TrackException.PublishException("null sender returned from peer connection"))
return false
return null
}
track.statsGetter = engine.createStatsGetter(transceiver.sender)
... ... @@ -475,7 +489,7 @@ internal constructor(
internalListener?.onTrackPublished(publication, this)
eventBus.postEvent(ParticipantEvent.LocalTrackPublished(this, publication), scope)
return true
return publication
}
private fun computeVideoEncodings(
... ... @@ -606,6 +620,12 @@ internal constructor(
return
}
// Cancel and clean up the feature-update job tied to this publication.
// The jobs map is keyed by the publication itself (see where it is stored:
// jobs[publication] = job), so remove by publication. The previous code
// called jobs.remove(publicationJob) with the Job *value* as the key, which
// never matched and leaked the map entry.
jobs.remove(publication)?.cancel()
val sid = publication.sid
trackPublications = trackPublications.toMutableMap().apply { remove(sid) }
... ...
... ... @@ -20,23 +20,50 @@ import android.Manifest
import android.content.Context
import android.content.pm.PackageManager
import androidx.core.content.ContextCompat
import dagger.assisted.Assisted
import dagger.assisted.AssistedFactory
import dagger.assisted.AssistedInject
import io.livekit.android.audio.AudioProcessingController
import io.livekit.android.dagger.InjectionNames
import io.livekit.android.room.participant.LocalParticipant
import io.livekit.android.util.FlowObservable
import io.livekit.android.util.flow
import io.livekit.android.util.flowDelegate
import io.livekit.android.webrtc.peerconnection.executeBlockingOnRTCThread
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.flow.SharingStarted
import kotlinx.coroutines.flow.combine
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.stateIn
import livekit.LivekitModels.AudioTrackFeature
import livekit.org.webrtc.MediaConstraints
import livekit.org.webrtc.PeerConnectionFactory
import livekit.org.webrtc.RtpSender
import livekit.org.webrtc.RtpTransceiver
import java.util.UUID
import javax.inject.Named
/**
* Represents a local audio track (generally using the microphone as input).
*
* This class should not be constructed directly, but rather through [LocalParticipant.createAudioTrack].
*/
class LocalAudioTrack(
name: String,
mediaTrack: livekit.org.webrtc.AudioTrack
class LocalAudioTrack
@AssistedInject
constructor(
@Assisted name: String,
@Assisted mediaTrack: livekit.org.webrtc.AudioTrack,
@Assisted private val options: LocalAudioTrackOptions,
private val audioProcessingController: AudioProcessingController,
@Named(InjectionNames.DISPATCHER_DEFAULT)
private val dispatcher: CoroutineDispatcher,
) : AudioTrack(name, mediaTrack) {
/**
* To only be used for flow delegate scoping, and should not be cancelled.
**/
private val delegateScope = CoroutineScope(dispatcher + SupervisorJob())
var enabled: Boolean
get() = executeBlockingOnRTCThread { rtcTrack.enabled() }
set(value) {
... ... @@ -47,12 +74,52 @@ class LocalAudioTrack(
internal val sender: RtpSender?
get() = transceiver?.sender
/**
 * The set of [AudioTrackFeature]s currently active on this track, combining
 * the constant features derived from the capture options with the dynamic
 * enhanced-noise-cancellation state of the [audioProcessingController].
 *
 * Changes can be observed by using [io.livekit.android.util.flow]
 */
@FlowObservable
@get:FlowObservable
val features by flowDelegate(
stateFlow = combine(
audioProcessingController::capturePostProcessor.flow,
audioProcessingController::bypassCapturePostProcessing.flow,
) { processor, bypass ->
processor to bypass
}
.map {
val features = getConstantFeatures()
val (processor, bypass) = it
// Enhanced noise cancellation is only reported when the Krisp processor
// is installed and not bypassed. NOTE(review): matches on the processor
// name string — confirm it stays in sync with the Krisp plugin's getName().
if (!bypass && processor?.getName() == "krisp_noise_cancellation") {
features.add(AudioTrackFeature.TF_ENHANCED_NOISE_CANCELLATION)
}
return@map features
}
// Eagerly shared so the latest feature set is always available synchronously.
.stateIn(delegateScope, SharingStarted.Eagerly, emptySet()),
)
/**
 * Returns the features that are fixed at track-creation time, derived from
 * the capture [options]. The returned set is mutable so callers may add
 * dynamic features on top.
 */
private fun getConstantFeatures(): MutableSet<AudioTrackFeature> {
    val enabledFeatures = mutableSetOf<AudioTrackFeature>()
    val optionToFeature = listOf(
        options.echoCancellation to AudioTrackFeature.TF_ECHO_CANCELLATION,
        options.noiseSuppression to AudioTrackFeature.TF_NOISE_SUPPRESSION,
        options.autoGainControl to AudioTrackFeature.TF_AUTO_GAIN_CONTROL,
    )
    for ((enabled, feature) in optionToFeature) {
        if (enabled) {
            enabledFeatures.add(feature)
        }
    }
    // TODO: Handle getting other info from JavaAudioDeviceModule
    return enabledFeatures
}
companion object {
internal fun createTrack(
context: Context,
factory: PeerConnectionFactory,
options: LocalAudioTrackOptions = LocalAudioTrackOptions(),
name: String = ""
audioTrackFactory: Factory,
name: String = "",
): LocalAudioTrack {
if (ContextCompat.checkSelfPermission(context, Manifest.permission.RECORD_AUDIO) !=
PackageManager.PERMISSION_GRANTED
... ... @@ -74,7 +141,16 @@ class LocalAudioTrack(
val rtcAudioTrack =
factory.createAudioTrack(UUID.randomUUID().toString(), audioSource)
return LocalAudioTrack(name = name, mediaTrack = rtcAudioTrack)
return audioTrackFactory.create(name = name, mediaTrack = rtcAudioTrack, options = options)
}
}
/**
 * Dagger assisted-injection factory for [LocalAudioTrack]; supplies the
 * runtime arguments while injected dependencies (processing controller,
 * dispatcher) come from the graph.
 */
@AssistedFactory
interface Factory {
fun create(
name: String,
mediaTrack: livekit.org.webrtc.AudioTrack,
options: LocalAudioTrackOptions,
): LocalAudioTrack
}
}
... ...
... ... @@ -19,29 +19,43 @@ package io.livekit.android.webrtc
import io.livekit.android.audio.AudioProcessorInterface
import io.livekit.android.audio.AudioProcessorOptions
import io.livekit.android.audio.AuthedAudioProcessingController
import io.livekit.android.audio.authenticateProcessors
import io.livekit.android.audio.AuthedAudioProcessorInterface
import io.livekit.android.util.flowDelegate
import livekit.org.webrtc.AudioProcessingFactory
import livekit.org.webrtc.ExternalAudioProcessingFactory
import java.nio.ByteBuffer
class CustomAudioProcessingFactory(private var audioProcessorOptions: AudioProcessorOptions) : AuthedAudioProcessingController {
/**
* @suppress
*/
internal class CustomAudioProcessingFactory() : AuthedAudioProcessingController {
/**
 * Creates a factory preconfigured with the processors and bypass flags
 * from [audioProcessorOptions]. Assignments go through the observable
 * properties so the underlying processing factory is updated immediately.
 */
constructor(audioProcessorOptions: AudioProcessorOptions) : this() {
capturePostProcessor = audioProcessorOptions.capturePostProcessor
renderPreProcessor = audioProcessorOptions.renderPreProcessor
bypassCapturePostProcessing = audioProcessorOptions.capturePostBypass
bypassRenderPreProcessing = audioProcessorOptions.renderPreBypass
}
private val externalAudioProcessor = ExternalAudioProcessingFactory()
/**
 * The audio processor used for capture post processing.
 * Assignments are forwarded to the underlying [ExternalAudioProcessingFactory].
 */
override var capturePostProcessor: AudioProcessorInterface? by flowDelegate(null) { value, _ ->
    externalAudioProcessor.setCapturePostProcessing(
        value.toAudioProcessing(),
    )
}

/**
 * The audio processor used for render pre processing.
 * Assignments are forwarded to the underlying [ExternalAudioProcessingFactory].
 */
override var renderPreProcessor: AudioProcessorInterface? by flowDelegate(null) { value, _ ->
    externalAudioProcessor.setRenderPreProcessing(
        value.toAudioProcessing(),
    )
}

/** Whether capture post processing is bypassed; forwarded to the native factory. */
override var bypassCapturePostProcessing: Boolean by flowDelegate(false) { value, _ ->
    externalAudioProcessor.setBypassFlagForCapturePost(value)
}

/** Whether render pre processing is bypassed; forwarded to the native factory. */
override var bypassRenderPreProcessing: Boolean by flowDelegate(false) { value, _ ->
    externalAudioProcessor.setBypassFlagForRenderPre(value)
}
fun getAudioProcessingFactory(): AudioProcessingFactory {
... ... @@ -49,31 +63,28 @@ class CustomAudioProcessingFactory(private var audioProcessorOptions: AudioProce
}
/**
 * Forwards connection credentials to any installed processors that support
 * authentication ([AuthedAudioProcessorInterface]).
 */
override fun authenticate(url: String, token: String) {
    (capturePostProcessor as? AuthedAudioProcessorInterface)?.authenticate(url, token)
    (renderPreProcessor as? AuthedAudioProcessorInterface)?.authenticate(url, token)
}

@Deprecated("Use the capturePostProcessor variable directly instead", ReplaceWith("capturePostProcessor = processing"))
override fun setCapturePostProcessing(processing: AudioProcessorInterface?) {
    capturePostProcessor = processing
}

@Deprecated("Use the renderPreProcessor variable directly instead", ReplaceWith("renderPreProcessor = processing"))
override fun setRenderPreProcessing(processing: AudioProcessorInterface?) {
    renderPreProcessor = processing
}

@Deprecated("Use the bypassCapturePostProcessing variable directly instead", ReplaceWith("bypassCapturePostProcessing = bypass"))
override fun setBypassForCapturePostProcessing(bypass: Boolean) {
    bypassCapturePostProcessing = bypass
}

@Deprecated("Use the bypassRenderPreProcessing variable directly instead", ReplaceWith("bypassRenderPreProcessing = bypass"))
override fun setBypassForRenderPreProcessing(bypass: Boolean) {
    bypassRenderPreProcessing = bypass
}
private class AudioProcessingBridge(
... ...
... ... @@ -18,8 +18,27 @@ package io.livekit.android.test.mock
import io.livekit.android.audio.AudioProcessingController
import io.livekit.android.audio.AudioProcessorInterface
import io.livekit.android.util.FlowObservable
import io.livekit.android.util.flowDelegate
/**
 * A no-op [AudioProcessingController] for tests. Properties are backed by
 * [flowDelegate] so changes remain observable via flows, but no real audio
 * processing occurs.
 */
class MockAudioProcessingController : AudioProcessingController {
@FlowObservable
@get:FlowObservable
override var capturePostProcessor: AudioProcessorInterface? by flowDelegate(null)

@FlowObservable
@get:FlowObservable
override var renderPreProcessor: AudioProcessorInterface? by flowDelegate(null)

@FlowObservable
@get:FlowObservable
override var bypassRenderPreProcessing: Boolean by flowDelegate(false)

@FlowObservable
@get:FlowObservable
override var bypassCapturePostProcessing: Boolean by flowDelegate(false)

// Deprecated setter is intentionally a no-op here; tests assign the property directly.
override fun setCapturePostProcessing(processing: AudioProcessorInterface?) {
}
... ...
... ... @@ -17,14 +17,18 @@
package io.livekit.android.room
import android.net.Network
import io.livekit.android.events.*
import io.livekit.android.events.DisconnectReason
import io.livekit.android.events.RoomEvent
import io.livekit.android.events.convert
import io.livekit.android.room.participant.ConnectionQuality
import io.livekit.android.room.track.LocalAudioTrack
import io.livekit.android.room.track.LocalAudioTrackOptions
import io.livekit.android.room.track.Track
import io.livekit.android.test.MockE2ETest
import io.livekit.android.test.assert.assertIsClassList
import io.livekit.android.test.events.EventCollector
import io.livekit.android.test.events.FlowCollector
import io.livekit.android.test.mock.MockAudioProcessingController
import io.livekit.android.test.mock.MockAudioStreamTrack
import io.livekit.android.test.mock.MockMediaStream
import io.livekit.android.test.mock.MockRtpReceiver
... ... @@ -335,8 +339,11 @@ class RoomMockE2ETest : MockE2ETest() {
room.localParticipant.publishAudioTrack(
LocalAudioTrack(
"",
MockAudioStreamTrack(id = TestData.LOCAL_TRACK_PUBLISHED.trackPublished.cid),
name = "",
mediaTrack = MockAudioStreamTrack(id = TestData.LOCAL_TRACK_PUBLISHED.trackPublished.cid),
options = LocalAudioTrackOptions(),
audioProcessingController = MockAudioProcessingController(),
dispatcher = coroutineRule.dispatcher,
),
)
... ... @@ -381,8 +388,11 @@ class RoomMockE2ETest : MockE2ETest() {
}
room.localParticipant.publishAudioTrack(
LocalAudioTrack(
"",
MockAudioStreamTrack(id = TestData.LOCAL_TRACK_PUBLISHED.trackPublished.cid),
name = "",
mediaTrack = MockAudioStreamTrack(id = TestData.LOCAL_TRACK_PUBLISHED.trackPublished.cid),
options = LocalAudioTrackOptions(),
audioProcessingController = MockAudioProcessingController(),
dispatcher = coroutineRule.dispatcher,
),
)
... ...
... ... @@ -17,7 +17,9 @@
package io.livekit.android.room
import io.livekit.android.room.track.LocalAudioTrack
import io.livekit.android.room.track.LocalAudioTrackOptions
import io.livekit.android.test.MockE2ETest
import io.livekit.android.test.mock.MockAudioProcessingController
import io.livekit.android.test.mock.MockAudioStreamTrack
import io.livekit.android.test.mock.TestData
import io.livekit.android.test.util.toPBByteString
... ... @@ -103,8 +105,11 @@ class RoomReconnectionMockE2ETest : MockE2ETest() {
// publish track
room.localParticipant.publishAudioTrack(
LocalAudioTrack(
"",
MockAudioStreamTrack(id = TestData.LOCAL_TRACK_PUBLISHED.trackPublished.cid),
name = "",
mediaTrack = MockAudioStreamTrack(id = TestData.LOCAL_TRACK_PUBLISHED.trackPublished.cid),
options = LocalAudioTrackOptions(),
audioProcessingController = MockAudioProcessingController(),
dispatcher = coroutineRule.dispatcher,
),
)
... ...
... ... @@ -21,10 +21,12 @@ import io.livekit.android.events.RoomEvent
import io.livekit.android.events.TrackPublicationEvent
import io.livekit.android.room.participant.AudioTrackPublishOptions
import io.livekit.android.room.track.LocalAudioTrack
import io.livekit.android.room.track.LocalAudioTrackOptions
import io.livekit.android.room.track.Track
import io.livekit.android.test.MockE2ETest
import io.livekit.android.test.assert.assertIsClass
import io.livekit.android.test.events.EventCollector
import io.livekit.android.test.mock.MockAudioProcessingController
import io.livekit.android.test.mock.MockAudioStreamTrack
import io.livekit.android.test.mock.MockDataChannel
import io.livekit.android.test.mock.MockPeerConnection
... ... @@ -41,8 +43,11 @@ class RoomTranscriptionMockE2ETest : MockE2ETest() {
connect()
room.localParticipant.publishAudioTrack(
LocalAudioTrack(
"",
MockAudioStreamTrack(id = TestData.LOCAL_TRACK_PUBLISHED.trackPublished.cid),
name = "",
mediaTrack = MockAudioStreamTrack(id = TestData.LOCAL_TRACK_PUBLISHED.trackPublished.cid),
options = LocalAudioTrackOptions(),
audioProcessingController = MockAudioProcessingController(),
dispatcher = coroutineRule.dispatcher,
),
options = AudioTrackPublishOptions(
source = Track.Source.MICROPHONE,
... ...
... ... @@ -16,10 +16,12 @@
package io.livekit.android.room.participant
import io.livekit.android.audio.AudioProcessorInterface
import io.livekit.android.events.ParticipantEvent
import io.livekit.android.events.RoomEvent
import io.livekit.android.room.DefaultsManager
import io.livekit.android.room.track.LocalAudioTrack
import io.livekit.android.room.track.LocalAudioTrackOptions
import io.livekit.android.room.track.LocalVideoTrack
import io.livekit.android.room.track.LocalVideoTrackOptions
import io.livekit.android.room.track.Track
... ... @@ -28,6 +30,7 @@ import io.livekit.android.room.track.VideoCodec
import io.livekit.android.test.MockE2ETest
import io.livekit.android.test.assert.assertIsClassList
import io.livekit.android.test.events.EventCollector
import io.livekit.android.test.mock.MockAudioProcessingController
import io.livekit.android.test.mock.MockAudioStreamTrack
import io.livekit.android.test.mock.MockEglBase
import io.livekit.android.test.mock.MockVideoCapturer
... ... @@ -36,19 +39,25 @@ import io.livekit.android.test.mock.TestData
import io.livekit.android.test.util.toPBByteString
import io.livekit.android.util.toOkioByteString
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.test.advanceUntilIdle
import livekit.LivekitModels
import livekit.LivekitModels.AudioTrackFeature
import livekit.LivekitRtc
import livekit.LivekitRtc.SubscribedCodec
import livekit.LivekitRtc.SubscribedQuality
import livekit.org.webrtc.RtpParameters
import livekit.org.webrtc.VideoSource
import org.junit.Assert.*
import org.junit.Assert.assertEquals
import org.junit.Assert.assertFalse
import org.junit.Assert.assertNull
import org.junit.Assert.assertTrue
import org.junit.Test
import org.junit.runner.RunWith
import org.mockito.Mockito
import org.mockito.Mockito.mock
import org.mockito.kotlin.argThat
import org.robolectric.RobolectricTestRunner
import java.nio.ByteBuffer
@ExperimentalCoroutinesApi
@RunWith(RobolectricTestRunner::class)
... ... @@ -60,8 +69,11 @@ class LocalParticipantMockE2ETest : MockE2ETest() {
room.localParticipant.publishAudioTrack(
LocalAudioTrack(
"",
MockAudioStreamTrack(id = TestData.LOCAL_TRACK_PUBLISHED.trackPublished.cid),
name = "",
mediaTrack = MockAudioStreamTrack(id = TestData.LOCAL_TRACK_PUBLISHED.trackPublished.cid),
options = LocalAudioTrackOptions(),
audioProcessingController = MockAudioProcessingController(),
dispatcher = coroutineRule.dispatcher,
),
)
... ... @@ -348,4 +360,80 @@ class LocalParticipantMockE2ETest : MockE2ETest() {
assertEquals(preference, transceiver.sender.parameters.degradationPreference)
}
// Verifies that publishing an audio track immediately signals the server
// with the track's initial feature set derived from its capture options.
@Test
fun sendsInitialAudioTrackFeatures() = runTest {
connect()
wsFactory.ws.clearRequests()
room.localParticipant.publishAudioTrack(
LocalAudioTrack(
name = "",
mediaTrack = MockAudioStreamTrack(id = TestData.LOCAL_TRACK_PUBLISHED.trackPublished.cid),
options = LocalAudioTrackOptions(),
audioProcessingController = MockAudioProcessingController(),
dispatcher = coroutineRule.dispatcher,
),
)
advanceUntilIdle()
// Expect the AddTrack request followed by the initial UpdateLocalAudioTrack request.
assertEquals(2, wsFactory.ws.sentRequests.size)
// Verify the update audio track request carries the features enabled by the track options.
val requestString = wsFactory.ws.sentRequests[1].toPBByteString()
val sentRequest = LivekitRtc.SignalRequest.newBuilder()
.mergeFrom(requestString)
.build()
assertTrue(sentRequest.hasUpdateAudioTrack())
val features = sentRequest.updateAudioTrack.featuresList
assertTrue(features.contains(AudioTrackFeature.TF_ECHO_CANCELLATION))
assertTrue(features.contains(AudioTrackFeature.TF_NOISE_SUPPRESSION))
assertTrue(features.contains(AudioTrackFeature.TF_AUTO_GAIN_CONTROL))
}
// Verifies that installing a noise-cancellation processor after publishing
// triggers a follow-up UpdateLocalAudioTrack signal that includes
// TF_ENHANCED_NOISE_CANCELLATION alongside the constant option features.
@Test
fun sendsUpdatedAudioTrackFeatures() = runTest {
connect()
val audioProcessingController = MockAudioProcessingController()
room.localParticipant.publishAudioTrack(
LocalAudioTrack(
name = "",
mediaTrack = MockAudioStreamTrack(id = TestData.LOCAL_TRACK_PUBLISHED.trackPublished.cid),
options = LocalAudioTrackOptions(),
audioProcessingController = audioProcessingController,
dispatcher = coroutineRule.dispatcher,
),
)
advanceUntilIdle()
// Drop the initial publish/update traffic so only the feature-change update remains.
wsFactory.ws.clearRequests()
// Installing a processor named "krisp_noise_cancellation" should flip the
// enhanced-noise-cancellation feature on the track.
audioProcessingController.capturePostProcessor = object : AudioProcessorInterface {
override fun isEnabled(): Boolean = true
override fun getName(): String = "krisp_noise_cancellation"
override fun initializeAudioProcessing(sampleRateHz: Int, numChannels: Int) {}
override fun resetAudioProcessing(newRate: Int) {}
override fun processAudio(numBands: Int, numFrames: Int, buffer: ByteBuffer) {}
}
assertEquals(1, wsFactory.ws.sentRequests.size)
// Verify the update audio track request carries the expected feature set.
val requestString = wsFactory.ws.sentRequests[0].toPBByteString()
val sentRequest = LivekitRtc.SignalRequest.newBuilder()
.mergeFrom(requestString)
.build()
assertTrue(sentRequest.hasUpdateAudioTrack())
val features = sentRequest.updateAudioTrack.featuresList
assertTrue(features.contains(AudioTrackFeature.TF_ECHO_CANCELLATION))
assertTrue(features.contains(AudioTrackFeature.TF_NOISE_SUPPRESSION))
assertTrue(features.contains(AudioTrackFeature.TF_AUTO_GAIN_CONTROL))
assertTrue(features.contains(AudioTrackFeature.TF_ENHANCED_NOISE_CANCELLATION))
}
}
... ...
... ... @@ -19,12 +19,13 @@ package io.livekit.android.room.participant
import io.livekit.android.events.ParticipantEvent
import io.livekit.android.events.RoomEvent
import io.livekit.android.room.track.LocalAudioTrack
import io.livekit.android.room.track.LocalAudioTrackOptions
import io.livekit.android.test.MockE2ETest
import io.livekit.android.test.assert.assertIsClassList
import io.livekit.android.test.events.EventCollector
import io.livekit.android.test.mock.MockAudioProcessingController
import io.livekit.android.test.mock.MockAudioStreamTrack
import io.livekit.android.test.mock.TestData
import io.livekit.android.test.util.toOkioByteString
import io.livekit.android.util.toOkioByteString
import kotlinx.coroutines.ExperimentalCoroutinesApi
import org.junit.Assert.assertEquals
... ... @@ -43,8 +44,11 @@ class ParticipantMockE2ETest : MockE2ETest() {
// publish track
room.localParticipant.publishAudioTrack(
LocalAudioTrack(
"",
MockAudioStreamTrack(id = TestData.LOCAL_TRACK_PUBLISHED.trackPublished.cid),
name = "",
mediaTrack = MockAudioStreamTrack(id = TestData.LOCAL_TRACK_PUBLISHED.trackPublished.cid),
options = LocalAudioTrackOptions(),
audioProcessingController = MockAudioProcessingController(),
dispatcher = coroutineRule.dispatcher,
),
)
... ...