davidliu
Committed by GitHub

Transcription events feature (#440)

@@ -60,7 +60,7 @@ android { @@ -60,7 +60,7 @@ android {
60 buildConfig = true 60 buildConfig = true
61 } 61 }
62 kotlinOptions { 62 kotlinOptions {
63 - freeCompilerArgs = ["-Xinline-classes", "-opt-in=kotlin.RequiresOptIn"] 63 + freeCompilerArgs = ["-Xinline-classes", "-opt-in=kotlin.RequiresOptIn", "-opt-in=io.livekit.android.annotations.Beta"]
64 jvmTarget = java_version 64 jvmTarget = java_version
65 } 65 }
66 66
  1 +/*
  2 + * Copyright 2024 LiveKit, Inc.
  3 + *
  4 + * Licensed under the Apache License, Version 2.0 (the "License");
  5 + * you may not use this file except in compliance with the License.
  6 + * You may obtain a copy of the License at
  7 + *
  8 + * http://www.apache.org/licenses/LICENSE-2.0
  9 + *
  10 + * Unless required by applicable law or agreed to in writing, software
  11 + * distributed under the License is distributed on an "AS IS" BASIS,
  12 + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13 + * See the License for the specific language governing permissions and
  14 + * limitations under the License.
  15 + */
  16 +
  17 +package io.livekit.android.annotations
  18 +
  19 +@Retention(AnnotationRetention.BINARY)
  20 +@RequiresOptIn
  21 +annotation class Experimental
  22 +
  23 +@Retention(AnnotationRetention.BINARY)
  24 +@RequiresOptIn
  25 +annotation class Alpha
  26 +
  27 +@Retention(AnnotationRetention.BINARY)
  28 +@RequiresOptIn
  29 +annotation class Beta
1 /* 1 /*
2 - * Copyright 2023 LiveKit, Inc. 2 + * Copyright 2023-2024 LiveKit, Inc.
3 * 3 *
4 * Licensed under the Apache License, Version 2.0 (the "License"); 4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License. 5 * you may not use this file except in compliance with the License.
@@ -24,6 +24,7 @@ import io.livekit.android.room.track.LocalTrackPublication @@ -24,6 +24,7 @@ import io.livekit.android.room.track.LocalTrackPublication
24 import io.livekit.android.room.track.RemoteTrackPublication 24 import io.livekit.android.room.track.RemoteTrackPublication
25 import io.livekit.android.room.track.Track 25 import io.livekit.android.room.track.Track
26 import io.livekit.android.room.track.TrackPublication 26 import io.livekit.android.room.track.TrackPublication
  27 +import io.livekit.android.room.types.TranscriptionSegment
27 28
28 sealed class ParticipantEvent(open val participant: Participant) : Event() { 29 sealed class ParticipantEvent(open val participant: Participant) : Event() {
29 // all participants 30 // all participants
@@ -152,4 +153,16 @@ sealed class ParticipantEvent(open val participant: Participant) : Event() { @@ -152,4 +153,16 @@ sealed class ParticipantEvent(open val participant: Participant) : Event() {
152 val newPermissions: ParticipantPermission?, 153 val newPermissions: ParticipantPermission?,
153 val oldPermissions: ParticipantPermission?, 154 val oldPermissions: ParticipantPermission?,
154 ) : ParticipantEvent(participant) 155 ) : ParticipantEvent(participant)
  156 +
  157 + class TranscriptionReceived(
  158 + override val participant: Participant,
  159 + /**
  160 + * The transcription segments.
  161 + */
  162 + val transcriptions: List<TranscriptionSegment>,
  163 + /**
  164 + * The applicable track publication these transcriptions apply to.
  165 + */
  166 + val publication: TrackPublication?,
  167 + ) : ParticipantEvent(participant)
155 } 168 }
@@ -16,6 +16,7 @@ @@ -16,6 +16,7 @@
16 16
17 package io.livekit.android.events 17 package io.livekit.android.events
18 18
  19 +import io.livekit.android.annotations.Beta
19 import io.livekit.android.e2ee.E2EEState 20 import io.livekit.android.e2ee.E2EEState
20 import io.livekit.android.room.Room 21 import io.livekit.android.room.Room
21 import io.livekit.android.room.participant.ConnectionQuality 22 import io.livekit.android.room.participant.ConnectionQuality
@@ -27,6 +28,7 @@ import io.livekit.android.room.track.LocalTrackPublication @@ -27,6 +28,7 @@ import io.livekit.android.room.track.LocalTrackPublication
27 import io.livekit.android.room.track.RemoteTrackPublication 28 import io.livekit.android.room.track.RemoteTrackPublication
28 import io.livekit.android.room.track.Track 29 import io.livekit.android.room.track.Track
29 import io.livekit.android.room.track.TrackPublication 30 import io.livekit.android.room.track.TrackPublication
  31 +import io.livekit.android.room.types.TranscriptionSegment
30 import livekit.LivekitModels 32 import livekit.LivekitModels
31 33
32 sealed class RoomEvent(val room: Room) : Event() { 34 sealed class RoomEvent(val room: Room) : Event() {
@@ -219,6 +221,23 @@ sealed class RoomEvent(val room: Room) : Event() { @@ -219,6 +221,23 @@ sealed class RoomEvent(val room: Room) : Event() {
219 val participant: Participant, 221 val participant: Participant,
220 var state: E2EEState, 222 var state: E2EEState,
221 ) : RoomEvent(room) 223 ) : RoomEvent(room)
  224 +
  225 + @Beta
  226 + class TranscriptionReceived(
  227 + room: Room,
  228 + /**
  229 + * The transcription segments.
  230 + */
  231 + val transcriptionSegments: List<TranscriptionSegment>,
  232 + /**
  233 + * The applicable participant these transcriptions apply to.
  234 + */
  235 + val participant: Participant?,
  236 + /**
  237 + * The applicable track publication these transcriptions apply to.
  238 + */
  239 + val publication: TrackPublication?,
  240 + ) : RoomEvent(room)
222 } 241 }
223 242
224 enum class DisconnectReason { 243 enum class DisconnectReason {
@@ -758,6 +758,7 @@ internal constructor( @@ -758,6 +758,7 @@ internal constructor(
758 fun onFullReconnecting() 758 fun onFullReconnecting()
759 suspend fun onPostReconnect(isFullReconnect: Boolean) 759 suspend fun onPostReconnect(isFullReconnect: Boolean)
760 fun onLocalTrackUnpublished(trackUnpublished: LivekitRtc.TrackUnpublishedResponse) 760 fun onLocalTrackUnpublished(trackUnpublished: LivekitRtc.TrackUnpublishedResponse)
  761 + fun onTranscriptionReceived(transcription: LivekitModels.Transcription)
761 } 762 }
762 763
763 companion object { 764 companion object {
@@ -981,7 +982,7 @@ internal constructor( @@ -981,7 +982,7 @@ internal constructor(
981 } 982 }
982 983
983 LivekitModels.DataPacket.ValueCase.TRANSCRIPTION -> { 984 LivekitModels.DataPacket.ValueCase.TRANSCRIPTION -> {
984 - // TODO 985 + listener?.onTranscriptionReceived(dp.transcription)
985 } 986 }
986 987
987 LivekitModels.DataPacket.ValueCase.VALUE_NOT_SET, 988 LivekitModels.DataPacket.ValueCase.VALUE_NOT_SET,
@@ -42,6 +42,7 @@ import io.livekit.android.room.network.NetworkCallbackManagerFactory @@ -42,6 +42,7 @@ import io.livekit.android.room.network.NetworkCallbackManagerFactory
42 import io.livekit.android.room.participant.* 42 import io.livekit.android.room.participant.*
43 import io.livekit.android.room.provisions.LKObjects 43 import io.livekit.android.room.provisions.LKObjects
44 import io.livekit.android.room.track.* 44 import io.livekit.android.room.track.*
  45 +import io.livekit.android.room.types.toSDKType
45 import io.livekit.android.util.FlowObservable 46 import io.livekit.android.util.FlowObservable
46 import io.livekit.android.util.LKLog 47 import io.livekit.android.util.LKLog
47 import io.livekit.android.util.flow 48 import io.livekit.android.util.flow
@@ -1007,6 +1008,26 @@ constructor( @@ -1007,6 +1008,26 @@ constructor(
1007 /** 1008 /**
1008 * @suppress 1009 * @suppress
1009 */ 1010 */
  1011 + override fun onTranscriptionReceived(transcription: LivekitModels.Transcription) {
  1012 + val participant = getParticipantByIdentity(transcription.transcribedParticipantIdentity)
  1013 + val publication = participant?.trackPublications?.get(transcription.trackId)
  1014 + val segments = transcription.segmentsList
  1015 + .map { it.toSDKType() }
  1016 +
  1017 + val event = RoomEvent.TranscriptionReceived(
  1018 + room = this,
  1019 + transcriptionSegments = segments,
  1020 + participant = participant,
  1021 + publication = publication,
  1022 + )
  1023 + eventBus.tryPostEvent(event)
  1024 + participant?.onTranscriptionReceived(event)
  1025 + // TODO: Emit for publication
  1026 + }
  1027 +
  1028 + /**
  1029 + * @suppress
  1030 + */
1010 override fun onStreamStateUpdate(streamStates: List<LivekitRtc.StreamStateInfo>) { 1031 override fun onStreamStateUpdate(streamStates: List<LivekitRtc.StreamStateInfo>) {
1011 for (streamState in streamStates) { 1032 for (streamState in streamStates) {
1012 val participant = getParticipantBySid(streamState.participantSid) ?: continue 1033 val participant = getParticipantBySid(streamState.participantSid) ?: continue
@@ -20,6 +20,7 @@ import androidx.annotation.VisibleForTesting @@ -20,6 +20,7 @@ import androidx.annotation.VisibleForTesting
20 import io.livekit.android.dagger.InjectionNames 20 import io.livekit.android.dagger.InjectionNames
21 import io.livekit.android.events.BroadcastEventBus 21 import io.livekit.android.events.BroadcastEventBus
22 import io.livekit.android.events.ParticipantEvent 22 import io.livekit.android.events.ParticipantEvent
  23 +import io.livekit.android.events.RoomEvent
23 import io.livekit.android.events.TrackEvent 24 import io.livekit.android.events.TrackEvent
24 import io.livekit.android.room.track.LocalTrackPublication 25 import io.livekit.android.room.track.LocalTrackPublication
25 import io.livekit.android.room.track.RemoteTrackPublication 26 import io.livekit.android.room.track.RemoteTrackPublication
@@ -366,6 +367,20 @@ open class Participant( @@ -366,6 +367,20 @@ open class Participant(
366 ) 367 )
367 } 368 }
368 369
  370 + internal fun onTranscriptionReceived(transcription: RoomEvent.TranscriptionReceived) {
  371 + if (transcription.participant != this) {
  372 + return
  373 + }
  374 + eventBus.postEvent(
  375 + ParticipantEvent.TranscriptionReceived(
  376 + this,
  377 + transcriptions = transcription.transcriptionSegments,
  378 + publication = transcription.publication,
  379 + ),
  380 + scope,
  381 + )
  382 + }
  383 +
369 internal fun reinitialize() { 384 internal fun reinitialize() {
370 if (!scope.isActive) { 385 if (!scope.isActive) {
371 scope = createScope() 386 scope = createScope()
  1 +/*
  2 + * Copyright 2024 LiveKit, Inc.
  3 + *
  4 + * Licensed under the Apache License, Version 2.0 (the "License");
  5 + * you may not use this file except in compliance with the License.
  6 + * You may obtain a copy of the License at
  7 + *
  8 + * http://www.apache.org/licenses/LICENSE-2.0
  9 + *
  10 + * Unless required by applicable law or agreed to in writing, software
  11 + * distributed under the License is distributed on an "AS IS" BASIS,
  12 + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13 + * See the License for the specific language governing permissions and
  14 + * limitations under the License.
  15 + */
  16 +
  17 +package io.livekit.android.room.types
  18 +
  19 +import io.livekit.android.util.LKLog
  20 +import livekit.LivekitModels
  21 +
  22 +data class TranscriptionSegment(
  23 + val id: String,
  24 + val text: String,
  25 + val language: String,
  26 + val startTime: Long,
  27 + val endTime: Long,
  28 + val final: Boolean,
  29 +) {
  30 + override fun equals(other: Any?): Boolean {
  31 + if (this === other) return true
  32 + if (javaClass != other?.javaClass) return false
  33 +
  34 + other as TranscriptionSegment
  35 +
  36 + return id == other.id
  37 + }
  38 +
  39 + override fun hashCode(): Int {
  40 + return id.hashCode()
  41 + }
  42 +}
  43 +
  44 +/**
  45 + * Merges new segments into the map. The key should correspond to the segment id.
  46 + */
  47 +fun MutableMap<String, TranscriptionSegment>.mergeNewSegments(newSegments: Collection<TranscriptionSegment>) {
  48 + for (segment in newSegments) {
  49 + val existingSegment = get(segment.id)
  50 + if (existingSegment?.final == true) {
  51 + LKLog.d { "new segment for ${segment.id} overwriting final segment?" }
  52 + }
  53 + put(segment.id, segment)
  54 + }
  55 +}
  56 +
  57 +/**
  58 + * @suppress
  59 + */
  60 +fun LivekitModels.TranscriptionSegment.toSDKType() =
  61 + TranscriptionSegment(
  62 + id = id,
  63 + text = text,
  64 + language = language,
  65 + startTime = startTime,
  66 + endTime = endTime,
  67 + final = final,
  68 + )
@@ -31,7 +31,7 @@ android { @@ -31,7 +31,7 @@ android {
31 targetCompatibility java_version 31 targetCompatibility java_version
32 } 32 }
33 kotlinOptions { 33 kotlinOptions {
34 - freeCompilerArgs = ["-Xinline-classes", "-opt-in=kotlin.RequiresOptIn"] 34 + freeCompilerArgs = ["-Xinline-classes", "-opt-in=kotlin.RequiresOptIn", "-opt-in=io.livekit.android.annotations.Beta"]
35 jvmTarget = java_version 35 jvmTarget = java_version
36 } 36 }
37 testOptions { 37 testOptions {
@@ -302,4 +302,25 @@ object TestData { @@ -302,4 +302,25 @@ object TestData {
302 } 302 }
303 build() 303 build()
304 } 304 }
  305 +
  306 + // Data packets
  307 +
  308 + val DATA_PACKET_TRANSCRIPTION = with(LivekitModels.DataPacket.newBuilder()) {
  309 + transcription = with(LivekitModels.Transcription.newBuilder()) {
  310 + transcribedParticipantIdentity = JOIN.join.participant.identity // Local participant's identity
  311 + addSegments(
  312 + with(LivekitModels.TranscriptionSegment.newBuilder()) {
  313 + id = "id"
  314 + language = "enUS"
  315 + text = "This is a transcription."
  316 + startTime = 1
  317 + endTime = 10
  318 + final = true
  319 + build()
  320 + },
  321 + )
  322 + build()
  323 + }
  324 + build()
  325 + }
305 } 326 }
  1 +/*
  2 + * Copyright 2024 LiveKit, Inc.
  3 + *
  4 + * Licensed under the Apache License, Version 2.0 (the "License");
  5 + * you may not use this file except in compliance with the License.
  6 + * You may obtain a copy of the License at
  7 + *
  8 + * http://www.apache.org/licenses/LICENSE-2.0
  9 + *
  10 + * Unless required by applicable law or agreed to in writing, software
  11 + * distributed under the License is distributed on an "AS IS" BASIS,
  12 + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13 + * See the License for the specific language governing permissions and
  14 + * limitations under the License.
  15 + */
  16 +
  17 +package io.livekit.android.test.util
  18 +
  19 +import com.google.protobuf.MessageLite
  20 +import livekit.org.webrtc.DataChannel
  21 +import java.nio.ByteBuffer
  22 +
  23 +fun MessageLite.toDataChannelBuffer() =
  24 + DataChannel.Buffer(
  25 + ByteBuffer.wrap(toByteArray()),
  26 + true,
  27 + )
  1 +/*
  2 + * Copyright 2023-2024 LiveKit, Inc.
  3 + *
  4 + * Licensed under the Apache License, Version 2.0 (the "License");
  5 + * you may not use this file except in compliance with the License.
  6 + * You may obtain a copy of the License at
  7 + *
  8 + * http://www.apache.org/licenses/LICENSE-2.0
  9 + *
  10 + * Unless required by applicable law or agreed to in writing, software
  11 + * distributed under the License is distributed on an "AS IS" BASIS,
  12 + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13 + * See the License for the specific language governing permissions and
  14 + * limitations under the License.
  15 + */
  16 +
  17 +package io.livekit.android.room
  18 +
  19 +import io.livekit.android.events.RoomEvent
  20 +import io.livekit.android.test.MockE2ETest
  21 +import io.livekit.android.test.assert.assertIsClass
  22 +import io.livekit.android.test.events.EventCollector
  23 +import io.livekit.android.test.mock.MockDataChannel
  24 +import io.livekit.android.test.mock.MockPeerConnection
  25 +import io.livekit.android.test.mock.TestData
  26 +import io.livekit.android.test.util.toDataChannelBuffer
  27 +import kotlinx.coroutines.ExperimentalCoroutinesApi
  28 +import org.junit.Assert.assertEquals
  29 +import org.junit.Test
  30 +
  31 +@OptIn(ExperimentalCoroutinesApi::class)
  32 +class RoomTranscriptionMockE2ETest : MockE2ETest() {
  33 + @Test
  34 + fun transcriptionReceived() = runTest {
  35 + connect()
  36 + val subPeerConnection = component.rtcEngine().getSubscriberPeerConnection() as MockPeerConnection
  37 + val subDataChannel = MockDataChannel(RTCEngine.RELIABLE_DATA_CHANNEL_LABEL)
  38 + subPeerConnection.observer?.onDataChannel(subDataChannel)
  39 +
  40 + val collector = EventCollector(room.events, coroutineRule.scope)
  41 + val dataBuffer = TestData.DATA_PACKET_TRANSCRIPTION.toDataChannelBuffer()
  42 +
  43 + subDataChannel.observer?.onMessage(dataBuffer)
  44 + val events = collector.stopCollecting()
  45 +
  46 + assertEquals(1, events.size)
  47 + assertIsClass(RoomEvent.TranscriptionReceived::class.java, events[0])
  48 +
  49 + val event = events.first() as RoomEvent.TranscriptionReceived
  50 + assertEquals(room, event.room)
  51 + assertEquals(room.localParticipant, event.participant)
  52 +
  53 + val expectedSegment = TestData.DATA_PACKET_TRANSCRIPTION.transcription.getSegments(0)
  54 + val receivedSegment = event.transcriptionSegments.first()
  55 + assertEquals(expectedSegment.id, receivedSegment.id)
  56 + assertEquals(expectedSegment.text, receivedSegment.text)
  57 + }
  58 +}