davidliu
Committed by GitHub

Transcription events feature (#440)

... ... @@ -60,7 +60,7 @@ android {
buildConfig = true
}
kotlinOptions {
freeCompilerArgs = ["-Xinline-classes", "-opt-in=kotlin.RequiresOptIn"]
freeCompilerArgs = ["-Xinline-classes", "-opt-in=kotlin.RequiresOptIn", "-opt-in=io.livekit.android.annotations.Beta"]
jvmTarget = java_version
}
... ...
/*
* Copyright 2024 LiveKit, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.livekit.android.annotations
@Retention(AnnotationRetention.BINARY)
@RequiresOptIn
annotation class Experimental
@Retention(AnnotationRetention.BINARY)
@RequiresOptIn
annotation class Alpha
@Retention(AnnotationRetention.BINARY)
@RequiresOptIn
annotation class Beta
... ...
/*
* Copyright 2023 LiveKit, Inc.
* Copyright 2023-2024 LiveKit, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
... ... @@ -24,6 +24,7 @@ import io.livekit.android.room.track.LocalTrackPublication
import io.livekit.android.room.track.RemoteTrackPublication
import io.livekit.android.room.track.Track
import io.livekit.android.room.track.TrackPublication
import io.livekit.android.room.types.TranscriptionSegment
sealed class ParticipantEvent(open val participant: Participant) : Event() {
// all participants
... ... @@ -152,4 +153,16 @@ sealed class ParticipantEvent(open val participant: Participant) : Event() {
val newPermissions: ParticipantPermission?,
val oldPermissions: ParticipantPermission?,
) : ParticipantEvent(participant)
class TranscriptionReceived(
override val participant: Participant,
/**
* The transcription segments.
*/
val transcriptions: List<TranscriptionSegment>,
/**
* The applicable track publication these transcriptions apply to.
*/
val publication: TrackPublication?,
) : ParticipantEvent(participant)
}
... ...
... ... @@ -16,6 +16,7 @@
package io.livekit.android.events
import io.livekit.android.annotations.Beta
import io.livekit.android.e2ee.E2EEState
import io.livekit.android.room.Room
import io.livekit.android.room.participant.ConnectionQuality
... ... @@ -27,6 +28,7 @@ import io.livekit.android.room.track.LocalTrackPublication
import io.livekit.android.room.track.RemoteTrackPublication
import io.livekit.android.room.track.Track
import io.livekit.android.room.track.TrackPublication
import io.livekit.android.room.types.TranscriptionSegment
import livekit.LivekitModels
sealed class RoomEvent(val room: Room) : Event() {
... ... @@ -219,6 +221,23 @@ sealed class RoomEvent(val room: Room) : Event() {
val participant: Participant,
var state: E2EEState,
) : RoomEvent(room)
@Beta
class TranscriptionReceived(
room: Room,
/**
* The transcription segments.
*/
val transcriptionSegments: List<TranscriptionSegment>,
/**
* The applicable participant these transcriptions apply to.
*/
val participant: Participant?,
/**
* The applicable track publication these transcriptions apply to.
*/
val publication: TrackPublication?,
) : RoomEvent(room)
}
enum class DisconnectReason {
... ...
... ... @@ -758,6 +758,7 @@ internal constructor(
fun onFullReconnecting()
suspend fun onPostReconnect(isFullReconnect: Boolean)
fun onLocalTrackUnpublished(trackUnpublished: LivekitRtc.TrackUnpublishedResponse)
fun onTranscriptionReceived(transcription: LivekitModels.Transcription)
}
companion object {
... ... @@ -981,7 +982,7 @@ internal constructor(
}
LivekitModels.DataPacket.ValueCase.TRANSCRIPTION -> {
// TODO
listener?.onTranscriptionReceived(dp.transcription)
}
LivekitModels.DataPacket.ValueCase.VALUE_NOT_SET,
... ...
... ... @@ -42,6 +42,7 @@ import io.livekit.android.room.network.NetworkCallbackManagerFactory
import io.livekit.android.room.participant.*
import io.livekit.android.room.provisions.LKObjects
import io.livekit.android.room.track.*
import io.livekit.android.room.types.toSDKType
import io.livekit.android.util.FlowObservable
import io.livekit.android.util.LKLog
import io.livekit.android.util.flow
... ... @@ -1007,6 +1008,26 @@ constructor(
/**
* @suppress
*/
override fun onTranscriptionReceived(transcription: LivekitModels.Transcription) {
val participant = getParticipantByIdentity(transcription.transcribedParticipantIdentity)
val publication = participant?.trackPublications?.get(transcription.trackId)
val segments = transcription.segmentsList
.map { it.toSDKType() }
val event = RoomEvent.TranscriptionReceived(
room = this,
transcriptionSegments = segments,
participant = participant,
publication = publication,
)
eventBus.tryPostEvent(event)
participant?.onTranscriptionReceived(event)
// TODO: Emit for publication
}
/**
* @suppress
*/
override fun onStreamStateUpdate(streamStates: List<LivekitRtc.StreamStateInfo>) {
for (streamState in streamStates) {
val participant = getParticipantBySid(streamState.participantSid) ?: continue
... ...
... ... @@ -20,6 +20,7 @@ import androidx.annotation.VisibleForTesting
import io.livekit.android.dagger.InjectionNames
import io.livekit.android.events.BroadcastEventBus
import io.livekit.android.events.ParticipantEvent
import io.livekit.android.events.RoomEvent
import io.livekit.android.events.TrackEvent
import io.livekit.android.room.track.LocalTrackPublication
import io.livekit.android.room.track.RemoteTrackPublication
... ... @@ -366,6 +367,20 @@ open class Participant(
)
}
internal fun onTranscriptionReceived(transcription: RoomEvent.TranscriptionReceived) {
if (transcription.participant != this) {
return
}
eventBus.postEvent(
ParticipantEvent.TranscriptionReceived(
this,
transcriptions = transcription.transcriptionSegments,
publication = transcription.publication,
),
scope,
)
}
internal fun reinitialize() {
if (!scope.isActive) {
scope = createScope()
... ...
/*
* Copyright 2024 LiveKit, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.livekit.android.room.types
import io.livekit.android.util.LKLog
import livekit.LivekitModels
data class TranscriptionSegment(
val id: String,
val text: String,
val language: String,
val startTime: Long,
val endTime: Long,
val final: Boolean,
) {
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (javaClass != other?.javaClass) return false
other as TranscriptionSegment
return id == other.id
}
override fun hashCode(): Int {
return id.hashCode()
}
}
/**
* Merges new segments into the map. The key should correspond to the segment id.
*/
fun MutableMap<String, TranscriptionSegment>.mergeNewSegments(newSegments: Collection<TranscriptionSegment>) {
for (segment in newSegments) {
val existingSegment = get(segment.id)
if (existingSegment?.final == true) {
LKLog.d { "new segment for ${segment.id} overwriting final segment?" }
}
put(segment.id, segment)
}
}
/**
* @suppress
*/
fun LivekitModels.TranscriptionSegment.toSDKType() =
TranscriptionSegment(
id = id,
text = text,
language = language,
startTime = startTime,
endTime = endTime,
final = final,
)
... ...
... ... @@ -31,7 +31,7 @@ android {
targetCompatibility java_version
}
kotlinOptions {
freeCompilerArgs = ["-Xinline-classes", "-opt-in=kotlin.RequiresOptIn"]
freeCompilerArgs = ["-Xinline-classes", "-opt-in=kotlin.RequiresOptIn", "-opt-in=io.livekit.android.annotations.Beta"]
jvmTarget = java_version
}
testOptions {
... ...
... ... @@ -302,4 +302,25 @@ object TestData {
}
build()
}
// Data packets
val DATA_PACKET_TRANSCRIPTION = with(LivekitModels.DataPacket.newBuilder()) {
transcription = with(LivekitModels.Transcription.newBuilder()) {
transcribedParticipantIdentity = JOIN.join.participant.identity // Local participant's identity
addSegments(
with(LivekitModels.TranscriptionSegment.newBuilder()) {
id = "id"
language = "enUS"
text = "This is a transcription."
startTime = 1
endTime = 10
final = true
build()
},
)
build()
}
build()
}
}
... ...
/*
* Copyright 2024 LiveKit, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.livekit.android.test.util
import com.google.protobuf.MessageLite
import livekit.org.webrtc.DataChannel
import java.nio.ByteBuffer
fun MessageLite.toDataChannelBuffer() =
DataChannel.Buffer(
ByteBuffer.wrap(toByteArray()),
true,
)
... ...
/*
* Copyright 2023-2024 LiveKit, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.livekit.android.room
import io.livekit.android.events.RoomEvent
import io.livekit.android.test.MockE2ETest
import io.livekit.android.test.assert.assertIsClass
import io.livekit.android.test.events.EventCollector
import io.livekit.android.test.mock.MockDataChannel
import io.livekit.android.test.mock.MockPeerConnection
import io.livekit.android.test.mock.TestData
import io.livekit.android.test.util.toDataChannelBuffer
import kotlinx.coroutines.ExperimentalCoroutinesApi
import org.junit.Assert.assertEquals
import org.junit.Test
@OptIn(ExperimentalCoroutinesApi::class)
class RoomTranscriptionMockE2ETest : MockE2ETest() {
@Test
fun transcriptionReceived() = runTest {
connect()
val subPeerConnection = component.rtcEngine().getSubscriberPeerConnection() as MockPeerConnection
val subDataChannel = MockDataChannel(RTCEngine.RELIABLE_DATA_CHANNEL_LABEL)
subPeerConnection.observer?.onDataChannel(subDataChannel)
val collector = EventCollector(room.events, coroutineRule.scope)
val dataBuffer = TestData.DATA_PACKET_TRANSCRIPTION.toDataChannelBuffer()
subDataChannel.observer?.onMessage(dataBuffer)
val events = collector.stopCollecting()
assertEquals(1, events.size)
assertIsClass(RoomEvent.TranscriptionReceived::class.java, events[0])
val event = events.first() as RoomEvent.TranscriptionReceived
assertEquals(room, event.room)
assertEquals(room.localParticipant, event.participant)
val expectedSegment = TestData.DATA_PACKET_TRANSCRIPTION.transcription.getSegments(0)
val receivedSegment = event.transcriptionSegments.first()
assertEquals(expectedSegment.id, receivedSegment.id)
assertEquals(expectedSegment.text, receivedSegment.text)
}
}
... ...