davidliu
Committed by GitHub

Allow setting of preferred video codec when publishing (#223)

<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="KotlinJpsPluginSettings">
<option name="version" value="1.7.10" />
</component>
</project>
\ No newline at end of file
... ...
... ... @@ -141,7 +141,7 @@ dependencies {
lintPublish project(':livekit-lint')
testImplementation 'junit:junit:4.13.2'
testImplementation 'org.robolectric:robolectric:4.6'
testImplementation 'org.robolectric:robolectric:4.10.2'
testImplementation 'org.mockito:mockito-core:4.0.0'
testImplementation "org.mockito.kotlin:mockito-kotlin:4.0.0"
testImplementation 'androidx.test:core:1.4.0'
... ...
... ... @@ -22,6 +22,8 @@ object InjectionNames {
*/
internal const val DISPATCHER_UNCONFINED = "dispatcher_unconfined"
internal const val SENDER = "sender"
internal const val OPTIONS_VIDEO_HW_ACCEL = "options_video_hw_accel"
// Overrides
... ...
... ... @@ -17,6 +17,7 @@ import timber.log.Timber
import javax.inject.Named
import javax.inject.Singleton
typealias CapabilitiesGetter = @JvmSuppressWildcards (MediaStreamTrack.MediaType) -> RtpCapabilities
@Module
object RTCModule {
... ... @@ -193,6 +194,14 @@ object RTCModule {
}
@Provides
@Named(InjectionNames.SENDER)
fun senderCapabilitiesGetter(peerConnectionFactory: PeerConnectionFactory): CapabilitiesGetter {
return { mediaType: MediaStreamTrack.MediaType ->
peerConnectionFactory.getRtpSenderCapabilities(mediaType)
}
}
@Provides
@Named(InjectionNames.OPTIONS_VIDEO_HW_ACCEL)
fun videoHwAccel() = true
}
\ No newline at end of file
... ...
... ... @@ -7,6 +7,7 @@ import com.google.protobuf.ByteString
import dagger.assisted.Assisted
import dagger.assisted.AssistedFactory
import dagger.assisted.AssistedInject
import io.livekit.android.dagger.CapabilitiesGetter
import io.livekit.android.dagger.InjectionNames
import io.livekit.android.events.ParticipantEvent
import io.livekit.android.room.ConnectionState
... ... @@ -19,6 +20,7 @@ import kotlinx.coroutines.CoroutineDispatcher
import livekit.LivekitModels
import livekit.LivekitRtc
import org.webrtc.*
import org.webrtc.RtpCapabilities.CodecCapability
import javax.inject.Named
import kotlin.math.max
... ... @@ -36,6 +38,8 @@ internal constructor(
private val defaultsManager: DefaultsManager,
@Named(InjectionNames.DISPATCHER_DEFAULT)
coroutineDispatcher: CoroutineDispatcher,
@Named(InjectionNames.SENDER)
private val capabilitiesGetter: CapabilitiesGetter,
) : Participant("", null, coroutineDispatcher) {
var audioTrackCaptureDefaults: LocalAudioTrackOptions by defaultsManager::audioTrackCaptureDefaults
... ... @@ -303,7 +307,46 @@ internal constructor(
return false
}
// TODO: enable setting preferred codec
if (options is VideoTrackPublishOptions && options.videoCodec != null) {
val targetCodec = options.videoCodec.lowercase()
val capabilities = capabilitiesGetter(MediaStreamTrack.MediaType.MEDIA_TYPE_VIDEO)
LKLog.v { "capabilities:" }
capabilities.codecs.forEach { codec ->
LKLog.v { "codec: ${codec.name}, ${codec.kind}, ${codec.mimeType}, ${codec.parameters}, ${codec.preferredPayloadType}" }
}
val matched = mutableListOf<CodecCapability>()
val partialMatched = mutableListOf<CodecCapability>()
val unmatched = mutableListOf<CodecCapability>()
for (codec in capabilities.codecs) {
val mimeType = codec.mimeType.lowercase()
if (mimeType == "audio/opus") {
matched.add(codec)
continue
}
if (mimeType != "video/$targetCodec") {
unmatched.add(codec)
continue
}
// for h264 codecs that have sdpFmtpLine available, use only if the
// profile-level-id is 42e01f for cross-browser compatibility
if (targetCodec == "h264") {
if (codec.parameters["profile-level-id"] == "42e01f") {
matched.add(codec)
} else {
partialMatched.add(codec)
}
continue
} else {
matched.add(codec)
}
}
transceiver.setCodecPreferences(matched.plus(partialMatched).plus(unmatched))
}
val publication = LocalTrackPublication(
info = trackInfo,
... ... @@ -620,9 +663,6 @@ internal constructor(
interface Factory {
fun create(dynacast: Boolean): LocalParticipant
}
companion object {
}
}
internal fun LocalParticipant.publishTracksInfo(): List<LivekitRtc.TrackPublishedResponse> {
... ... @@ -643,18 +683,24 @@ interface TrackPublishOptions {
abstract class BaseVideoTrackPublishOptions {
abstract val videoEncoding: VideoEncoding?
abstract val simulcast: Boolean
//val videoCodec: VideoCodec? = null,
/**
* The video codec to use if available.
*/
abstract val videoCodec: String?
}
data class VideoTrackPublishDefaults(
override val videoEncoding: VideoEncoding? = null,
override val simulcast: Boolean = true
override val simulcast: Boolean = true,
override val videoCodec: String? = null,
) : BaseVideoTrackPublishOptions()
data class VideoTrackPublishOptions(
override val name: String? = null,
override val videoEncoding: VideoEncoding? = null,
override val simulcast: Boolean = true
override val simulcast: Boolean = true,
override val videoCodec: String? = null,
) : BaseVideoTrackPublishOptions(), TrackPublishOptions {
constructor(
name: String? = null,
... ... @@ -662,7 +708,8 @@ data class VideoTrackPublishOptions(
) : this(
name,
base.videoEncoding,
base.simulcast
base.simulcast,
base.videoCodec,
)
}
... ...
package io.livekit.android.mock
import org.webrtc.*
import org.webrtc.DataChannel
import org.webrtc.IceCandidate
import org.webrtc.MediaConstraints
import org.webrtc.MediaStream
import org.webrtc.MediaStreamTrack
import org.webrtc.NativePeerConnectionFactory
import org.webrtc.PeerConnection
import org.webrtc.RTCStatsCollectorCallback
import org.webrtc.RTCStatsReport
import org.webrtc.RtcCertificatePem
import org.webrtc.RtpReceiver
import org.webrtc.RtpSender
import org.webrtc.RtpTransceiver
import org.webrtc.SdpObserver
import org.webrtc.SessionDescription
import org.webrtc.StatsObserver
private class MockNativePeerConnectionFactory : NativePeerConnectionFactory {
override fun createNativePeerConnection(): Long = 0L
... ... @@ -14,6 +29,8 @@ class MockPeerConnection(
private var closed = false
var localDesc: SessionDescription? = null
var remoteDesc: SessionDescription? = null
private val transceivers = mutableListOf<RtpTransceiver>()
override fun getLocalDescription(): SessionDescription? = localDesc
override fun setLocalDescription(observer: SdpObserver?, sdp: SessionDescription?) {
localDesc = sdp
... ... @@ -85,7 +102,7 @@ class MockPeerConnection(
}
override fun getTransceivers(): List<RtpTransceiver> {
return emptyList()
return transceivers
}
override fun addTrack(track: MediaStreamTrack?): RtpSender {
... ... @@ -100,15 +117,19 @@ class MockPeerConnection(
return super.removeTrack(sender)
}
override fun addTransceiver(track: MediaStreamTrack?): RtpTransceiver {
return super.addTransceiver(track)
override fun addTransceiver(track: MediaStreamTrack): RtpTransceiver {
val transceiver = MockRtpTransceiver.create(track, RtpTransceiver.RtpTransceiverInit())
transceivers.add(transceiver)
return transceiver
}
override fun addTransceiver(
track: MediaStreamTrack,
init: RtpTransceiver.RtpTransceiverInit?
): RtpTransceiver {
return MockRtpTransceiver.create(track, init ?: RtpTransceiver.RtpTransceiverInit())
val transceiver = MockRtpTransceiver.create(track, init ?: RtpTransceiver.RtpTransceiverInit())
transceivers.add(transceiver)
return transceiver
}
override fun addTransceiver(mediaType: MediaStreamTrack.MediaType?): RtpTransceiver {
... ... @@ -177,6 +198,7 @@ class MockPeerConnection(
IceConnectionState.CHECKING -> PeerConnectionState.CONNECTING
IceConnectionState.CONNECTED,
IceConnectionState.COMPLETED -> PeerConnectionState.CONNECTED
IceConnectionState.DISCONNECTED -> PeerConnectionState.DISCONNECTED
IceConnectionState.FAILED -> PeerConnectionState.FAILED
IceConnectionState.CLOSED -> PeerConnectionState.CLOSED
... ... @@ -216,6 +238,7 @@ class MockPeerConnection(
iceConnectionState = newState
}
}
IceConnectionState.FAILED,
IceConnectionState.DISCONNECTED,
IceConnectionState.CLOSED -> {
... ...
package io.livekit.android.mock
import android.content.Context
import org.webrtc.CapturerObserver
import org.webrtc.SurfaceTextureHelper
import org.webrtc.VideoCapturer
class MockVideoCapturer : VideoCapturer {
override fun initialize(p0: SurfaceTextureHelper?, p1: Context?, p2: CapturerObserver?) {
}
override fun startCapture(p0: Int, p1: Int, p2: Int) {
}
override fun stopCapture() {
}
override fun changeCaptureFormat(p0: Int, p1: Int, p2: Int) {
}
override fun dispose() {
}
override fun isScreencast(): Boolean {
return false
}
}
\ No newline at end of file
... ...
package io.livekit.android.mock
import org.webrtc.VideoSource
class MockVideoSource(nativeSource: Long = 100) : VideoSource(nativeSource) {
}
\ No newline at end of file
... ...
... ... @@ -8,7 +8,8 @@ import okio.IOException
class MockWebSocket(
private val request: Request,
private val listener: WebSocketListener
private val listener: WebSocketListener,
private val onSend: ((ByteString) -> Unit)?
) : WebSocket {
var isClosed = false
... ... @@ -45,6 +46,7 @@ class MockWebSocket(
return false
}
mutableSentRequests.add(bytes)
onSend?.invoke(bytes)
return !isClosed
}
... ...
package io.livekit.android.mock
import io.livekit.android.util.toOkioByteString
import io.livekit.android.util.toPBByteString
import livekit.LivekitModels
import livekit.LivekitRtc
import okhttp3.Request
import okhttp3.WebSocket
import okhttp3.WebSocketListener
... ... @@ -20,7 +24,25 @@ class MockWebSocketFactory : WebSocket.Factory {
*/
lateinit var listener: WebSocketListener
override fun newWebSocket(request: Request, listener: WebSocketListener): WebSocket {
this.ws = MockWebSocket(request, listener)
this.ws = MockWebSocket(request, listener) { byteString ->
val signalRequest = LivekitRtc.SignalRequest.parseFrom(byteString.toPBByteString())
if (signalRequest.hasAddTrack()) {
val addTrack = signalRequest.addTrack
val trackPublished = with(LivekitRtc.SignalResponse.newBuilder()) {
trackPublished = with(LivekitRtc.TrackPublishedResponse.newBuilder()) {
cid = addTrack.cid
if (addTrack.type == LivekitModels.TrackType.AUDIO) {
track = TestData.LOCAL_AUDIO_TRACK
} else {
track = TestData.LOCAL_VIDEO_TRACK
}
build()
}
build()
}
this.listener.onMessage(this.ws, trackPublished.toOkioByteString())
}
}
this.listener = listener
this.request = request
... ...
... ... @@ -9,6 +9,11 @@ object TestData {
type = LivekitModels.TrackType.AUDIO
build()
}
val LOCAL_VIDEO_TRACK = with(LivekitModels.TrackInfo.newBuilder()) {
sid = "local_video_track_sid"
type = LivekitModels.TrackType.VIDEO
build()
}
val REMOTE_AUDIO_TRACK = with(LivekitModels.TrackInfo.newBuilder()) {
sid = "remote_audio_track_sid"
... ...
... ... @@ -3,6 +3,7 @@ package io.livekit.android.mock.dagger
import android.content.Context
import dagger.Module
import dagger.Provides
import io.livekit.android.dagger.CapabilitiesGetter
import io.livekit.android.dagger.InjectionNames
import io.livekit.android.mock.MockEglBase
import org.webrtc.*
... ... @@ -34,6 +35,14 @@ object TestRTCModule {
}
@Provides
@Named(InjectionNames.SENDER)
fun senderCapabilitiesGetter(peerConnectionFactory: PeerConnectionFactory): CapabilitiesGetter {
return { mediaType: MediaStreamTrack.MediaType ->
peerConnectionFactory.getRtpSenderCapabilities(mediaType)
}
}
@Provides
@Named(InjectionNames.OPTIONS_VIDEO_HW_ACCEL)
fun videoHwAccel() = true
}
\ No newline at end of file
... ...
... ... @@ -300,16 +300,12 @@ class RoomMockE2ETest : MockE2ETest() {
fun disconnectCleansLocalParticipant() = runTest {
connect()
val publishJob = launch {
room.localParticipant.publishAudioTrack(
LocalAudioTrack(
"",
MockAudioStreamTrack(id = SignalClientTest.LOCAL_TRACK_PUBLISHED.trackPublished.cid)
)
room.localParticipant.publishAudioTrack(
LocalAudioTrack(
"",
MockAudioStreamTrack(id = SignalClientTest.LOCAL_TRACK_PUBLISHED.trackPublished.cid)
)
}
wsFactory.listener.onMessage(wsFactory.ws, SignalClientTest.LOCAL_TRACK_PUBLISHED.toOkioByteString())
publishJob.join()
)
val eventCollector = EventCollector(room.events, coroutineRule.scope)
room.disconnect()
... ...
... ... @@ -6,7 +6,6 @@ import io.livekit.android.mock.MockPeerConnection
import io.livekit.android.room.track.LocalAudioTrack
import io.livekit.android.util.toPBByteString
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.launch
import livekit.LivekitRtc
import org.junit.Assert
import org.junit.Assert.assertEquals
... ... @@ -91,16 +90,12 @@ class RoomReconnectionMockE2ETest : MockE2ETest() {
connect()
// publish track
val publishJob = launch {
room.localParticipant.publishAudioTrack(
LocalAudioTrack(
"",
MockAudioStreamTrack(id = SignalClientTest.LOCAL_TRACK_PUBLISHED.trackPublished.cid)
)
room.localParticipant.publishAudioTrack(
LocalAudioTrack(
"",
MockAudioStreamTrack(id = SignalClientTest.LOCAL_TRACK_PUBLISHED.trackPublished.cid)
)
}
simulateMessageFromServer(SignalClientTest.LOCAL_TRACK_PUBLISHED)
publishJob.join()
)
prepareForReconnect()
disconnectPeerConnection()
... ...
... ... @@ -6,17 +6,27 @@ import io.livekit.android.events.EventCollector
import io.livekit.android.events.ParticipantEvent
import io.livekit.android.events.RoomEvent
import io.livekit.android.mock.MockAudioStreamTrack
import io.livekit.android.mock.MockEglBase
import io.livekit.android.mock.MockVideoCapturer
import io.livekit.android.mock.MockVideoStreamTrack
import io.livekit.android.room.DefaultsManager
import io.livekit.android.room.SignalClientTest
import io.livekit.android.room.track.LocalAudioTrack
import io.livekit.android.room.track.LocalVideoTrack
import io.livekit.android.room.track.LocalVideoTrackOptions
import io.livekit.android.room.track.VideoCaptureParameter
import io.livekit.android.util.toOkioByteString
import io.livekit.android.util.toPBByteString
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.launch
import livekit.LivekitRtc
import org.junit.Assert.*
import org.junit.Test
import org.junit.runner.RunWith
import org.mockito.Mockito
import org.mockito.Mockito.mock
import org.mockito.kotlin.argThat
import org.robolectric.RobolectricTestRunner
import org.webrtc.VideoSource
@ExperimentalCoroutinesApi
@RunWith(RobolectricTestRunner::class)
... ... @@ -26,16 +36,12 @@ class LocalParticipantMockE2ETest : MockE2ETest() {
fun disconnectCleansLocalParticipant() = runTest {
connect()
val publishJob = launch {
room.localParticipant.publishAudioTrack(
LocalAudioTrack(
"",
MockAudioStreamTrack(id = SignalClientTest.LOCAL_TRACK_PUBLISHED.trackPublished.cid)
)
room.localParticipant.publishAudioTrack(
LocalAudioTrack(
"",
MockAudioStreamTrack(id = SignalClientTest.LOCAL_TRACK_PUBLISHED.trackPublished.cid)
)
}
wsFactory.listener.onMessage(wsFactory.ws, SignalClientTest.LOCAL_TRACK_PUBLISHED.toOkioByteString())
publishJob.join()
)
room.disconnect()
... ... @@ -123,4 +129,55 @@ class LocalParticipantMockE2ETest : MockE2ETest() {
participantEvents
)
}
private fun createLocalTrack() = LocalVideoTrack(
capturer = MockVideoCapturer(),
source = mock(VideoSource::class.java),
name = "",
options = LocalVideoTrackOptions(
isScreencast = false,
deviceId = null,
position = null,
captureParams = VideoCaptureParameter(width = 0, height = 0, maxFps = 0)
),
rtcTrack = MockVideoStreamTrack(),
peerConnectionFactory = component.peerConnectionFactory(),
context = context,
eglBase = MockEglBase(),
defaultsManager = DefaultsManager(),
trackFactory = mock(LocalVideoTrack.Factory::class.java)
)
@Test
fun publishSetCodecPreferencesH264() = runTest {
room.videoTrackPublishDefaults = room.videoTrackPublishDefaults.copy(videoCodec = "h264")
connect()
room.localParticipant.publishVideoTrack(track = createLocalTrack())
val peerConnection = component.rtcEngine().publisher.peerConnection
val transceiver = peerConnection.transceivers.first()
Mockito.verify(transceiver).setCodecPreferences(argThat { codecs ->
val preferredCodec = codecs.first()
return@argThat preferredCodec.name.lowercase() == "h264" &&
preferredCodec.parameters["profile-level-id"] == "42e01f"
})
}
@Test
fun publishSetCodecPreferencesVP8() = runTest {
room.videoTrackPublishDefaults = room.videoTrackPublishDefaults.copy(videoCodec = "vp8")
connect()
room.localParticipant.publishVideoTrack(track = createLocalTrack())
val peerConnection = component.rtcEngine().publisher.peerConnection
val transceiver = peerConnection.transceivers.first()
Mockito.verify(transceiver).setCodecPreferences(argThat { codecs ->
val preferredCodec = codecs.first()
return@argThat preferredCodec.name.lowercase() == "vp8"
})
}
}
\ No newline at end of file
... ...
... ... @@ -10,7 +10,6 @@ import io.livekit.android.room.SignalClientTest
import io.livekit.android.room.track.LocalAudioTrack
import io.livekit.android.util.toOkioByteString
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.launch
import org.junit.Assert.assertEquals
import org.junit.Test
import org.junit.runner.RunWith
... ... @@ -26,16 +25,12 @@ class ParticipantMockE2ETest : MockE2ETest() {
connect()
// publish track
val publishJob = launch {
room.localParticipant.publishAudioTrack(
LocalAudioTrack(
"",
MockAudioStreamTrack(id = SignalClientTest.LOCAL_TRACK_PUBLISHED.trackPublished.cid)
)
room.localParticipant.publishAudioTrack(
LocalAudioTrack(
"",
MockAudioStreamTrack(id = SignalClientTest.LOCAL_TRACK_PUBLISHED.trackPublished.cid)
)
}
simulateMessageFromServer(SignalClientTest.LOCAL_TRACK_PUBLISHED)
publishJob.join()
)
val eventCollector = EventCollector(room.events, coroutineRule.scope)
// remote unpublish
... ...
package org.webrtc
import io.livekit.android.mock.MockPeerConnection
import io.livekit.android.mock.MockVideoSource
import io.livekit.android.mock.MockVideoStreamTrack
class MockPeerConnectionFactory : PeerConnectionFactory(1L) {
override fun createPeerConnectionInternal(
... ... @@ -11,4 +13,42 @@ class MockPeerConnectionFactory : PeerConnectionFactory(1L) {
): PeerConnection {
return MockPeerConnection(rtcConfig, observer)
}
override fun createVideoSource(isScreencast: Boolean, alignTimestamps: Boolean): VideoSource {
return MockVideoSource()
}
override fun createVideoSource(isScreencast: Boolean): VideoSource {
return MockVideoSource()
}
override fun createVideoTrack(id: String, source: VideoSource?): VideoTrack {
return MockVideoStreamTrack(id = id)
}
override fun getRtpSenderCapabilities(mediaType: MediaStreamTrack.MediaType): RtpCapabilities {
return RtpCapabilities(
listOf(
RtpCapabilities.CodecCapability().apply {
name = "VP8"
mimeType = "video/VP8"
kind = MediaStreamTrack.MediaType.MEDIA_TYPE_VIDEO
parameters = emptyMap()
},
RtpCapabilities.CodecCapability().apply {
name = "H264"
mimeType = "video/H264"
kind = MediaStreamTrack.MediaType.MEDIA_TYPE_VIDEO
parameters = mapOf("profile-level-id" to "640c1f")
},
RtpCapabilities.CodecCapability().apply {
name = "H264"
mimeType = "video/H264"
kind = MediaStreamTrack.MediaType.MEDIA_TYPE_VIDEO
parameters = mapOf("profile-level-id" to "42e01f")
},
),
emptyList()
)
}
}
\ No newline at end of file
... ...