David Liu

Fix up media track equals detector

@@ -4,6 +4,8 @@ @@ -4,6 +4,8 @@
4 <bytecodeTargetLevel target="1.8"> 4 <bytecodeTargetLevel target="1.8">
5 <module name="livekit-android.livekit-android-sdk" target="11" /> 5 <module name="livekit-android.livekit-android-sdk" target="11" />
6 <module name="livekit-android.sample-app" target="11" /> 6 <module name="livekit-android.sample-app" target="11" />
  7 + <module name="livekit-android.sample-app-common" target="11" />
  8 + <module name="livekit-android.sample-app-compose" target="11" />
7 </bytecodeTargetLevel> 9 </bytecodeTargetLevel>
8 </component> 10 </component>
9 </project> 11 </project>
@@ -53,6 +53,7 @@ ext { @@ -53,6 +53,7 @@ ext {
53 versions = [ 53 versions = [
54 androidx_core : "1.6.0", 54 androidx_core : "1.6.0",
55 androidx_lifecycle: "2.4.0", 55 androidx_lifecycle: "2.4.0",
  56 + autoService : '1.0.1',
56 dagger : "2.27", 57 dagger : "2.27",
57 groupie : "2.9.0", 58 groupie : "2.9.0",
58 junit : "4.13.2", 59 junit : "4.13.2",
@@ -64,9 +65,12 @@ ext { @@ -64,9 +65,12 @@ ext {
64 protoSrc: "$projectDir/protocol", 65 protoSrc: "$projectDir/protocol",
65 ] 66 ]
66 deps = [ 67 deps = [
  68 + auto : [
  69 + 'service' : "com.google.auto.service:auto-service:${versions.autoService}",
  70 + 'serviceAnnotations': "com.google.auto.service:auto-service-annotations:${versions.autoService}",
  71 + ],
67 kotlinx_coroutines: "org.jetbrains.kotlinx:kotlinx-coroutines-android:1.5.2", 72 kotlinx_coroutines: "org.jetbrains.kotlinx:kotlinx-coroutines-android:1.5.2",
68 timber : "com.github.ajalt:timberkt:1.5.1", 73 timber : "com.github.ajalt:timberkt:1.5.1",
69 -  
70 // lint 74 // lint
71 lint : "com.android.tools.lint:lint:${versions.lint}", 75 lint : "com.android.tools.lint:lint:${versions.lint}",
72 lintApi : "com.android.tools.lint:lint-api:${versions.lint}", 76 lintApi : "com.android.tools.lint:lint-api:${versions.lint}",
@@ -118,6 +118,9 @@ dependencies { @@ -118,6 +118,9 @@ dependencies {
118 implementation deps.timber 118 implementation deps.timber
119 implementation 'com.vdurmont:semver4j:3.1.0' 119 implementation 'com.vdurmont:semver4j:3.1.0'
120 120
  121 + lintChecks project(':livekit-lint')
  122 + lintPublish project(':livekit-lint')
  123 +
121 testImplementation 'junit:junit:4.13.2' 124 testImplementation 'junit:junit:4.13.2'
122 testImplementation 'org.robolectric:robolectric:4.6' 125 testImplementation 'org.robolectric:robolectric:4.6'
123 testImplementation 'org.mockito:mockito-core:4.0.0' 126 testImplementation 'org.mockito:mockito-core:4.0.0'
1 plugins { 1 plugins {
2 id 'java-library' 2 id 'java-library'
3 id 'kotlin' 3 id 'kotlin'
  4 + id 'kotlin-kapt'
4 } 5 }
5 6
6 java { 7 java {
@@ -9,15 +10,22 @@ java { @@ -9,15 +10,22 @@ java {
9 } 10 }
10 11
11 dependencies { 12 dependencies {
  13 +
  14 + compileOnly "org.jetbrains.kotlin:kotlin-stdlib-jdk7:$kotlin_version"
12 // used for lint rules 15 // used for lint rules
13 compileOnly deps.lintApi 16 compileOnly deps.lintApi
14 compileOnly deps.lintChecks 17 compileOnly deps.lintChecks
15 compileOnly deps.lintTests 18 compileOnly deps.lintTests
16 19
  20 + // Handle creating manifests for lint checker
  21 + compileOnly deps.auto.serviceAnnotations
  22 + kapt deps.auto.service
  23 +
17 // test lint 24 // test lint
18 testImplementation deps.lint 25 testImplementation deps.lint
19 testImplementation deps.lintTests 26 testImplementation deps.lintTests
20 27
  28 + compileOnly "org.jetbrains.kotlin:kotlin-reflect:$kotlin_version"
21 // test runners 29 // test runners
22 testImplementation deps.junit 30 testImplementation deps.junit
23 testImplementation deps.junitJupiterApi 31 testImplementation deps.junitJupiterApi
@@ -3,7 +3,10 @@ package io.livekit.lint @@ -3,7 +3,10 @@ package io.livekit.lint
3 import com.android.tools.lint.client.api.IssueRegistry 3 import com.android.tools.lint.client.api.IssueRegistry
4 import com.android.tools.lint.detector.api.CURRENT_API 4 import com.android.tools.lint.detector.api.CURRENT_API
5 import com.android.tools.lint.detector.api.Issue 5 import com.android.tools.lint.detector.api.Issue
  6 +import com.google.auto.service.AutoService
6 7
  8 +@Suppress("UnstableApiUsage", "unused")
  9 +@AutoService(value = [IssueRegistry::class])
7 class IssueRegistry : IssueRegistry() { 10 class IssueRegistry : IssueRegistry() {
8 11
9 override val api: Int = CURRENT_API 12 override val api: Int = CURRENT_API
@@ -20,14 +20,14 @@ package io.livekit.lint @@ -20,14 +20,14 @@ package io.livekit.lint
20 20
21 import com.android.tools.lint.client.api.UElementHandler 21 import com.android.tools.lint.client.api.UElementHandler
22 import com.android.tools.lint.detector.api.* 22 import com.android.tools.lint.detector.api.*
23 -import com.intellij.psi.CommonClassNames.JAVA_LANG_OBJECT  
24 import com.intellij.psi.PsiClassType 23 import com.intellij.psi.PsiClassType
25 -import com.intellij.psi.PsiElement  
26 -import com.intellij.psi.PsiMethod  
27 -import org.jetbrains.uast.*  
28 -import org.jetbrains.uast.util.isMethodCall 24 +import org.jetbrains.uast.UBinaryExpression
  25 +import org.jetbrains.uast.UCallExpression
  26 +import org.jetbrains.uast.UastBinaryOperator
29 27
30 -/** Checks related to DiffUtil computation. */ 28 +/**
  29 + * Detects MediaStreamTrack.equals() usage. This is generally a mistake and should not be used.
  30 + */
31 class MediaTrackEqualsDetector : Detector(), SourceCodeScanner { 31 class MediaTrackEqualsDetector : Detector(), SourceCodeScanner {
32 32
33 override fun getApplicableUastTypes() = 33 override fun getApplicableUastTypes() =
@@ -46,138 +46,21 @@ class MediaTrackEqualsDetector : Detector(), SourceCodeScanner { @@ -46,138 +46,21 @@ class MediaTrackEqualsDetector : Detector(), SourceCodeScanner {
46 } 46 }
47 } 47 }
48 48
49 - private fun defaultEquals(context: JavaContext, node: UElement): Boolean {  
50 - val resolved: PsiMethod?  
51 -  
52 - when (node) {  
53 - is UBinaryExpression -> {  
54 - resolved = node.resolveOperator()  
55 - if (resolved == null) {  
56 - val left = node.leftOperand.getExpressionType() as? PsiClassType  
57 - return defaultEquals(context, left)  
58 - }  
59 - }  
60 - is UCallExpression -> {  
61 - resolved = node.takeIf { it.isMethodCall() }  
62 - ?.resolve()  
63 - }  
64 - is UParenthesizedExpression -> {  
65 - return defaultEquals(context, node.expression)  
66 - }  
67 - else -> {  
68 - // We don't know any better  
69 - return false  
70 - }  
71 - }  
72 -  
73 - return resolved?.containingClass?.qualifiedName == MEDIA_STREAM_TRACK  
74 - }  
75 -  
76 - private fun defaultEquals(  
77 - context: JavaContext,  
78 - type: PsiClassType?  
79 - ): Boolean {  
80 - val cls = type?.resolve() ?: return false  
81 -  
82 - if (isKotlin(cls) && (context.evaluator.isSealed(cls) || context.evaluator.isData(cls))) {  
83 - // Sealed class doesn't guarantee that it defines equals/hashCode  
84 - // but it's likely (we'd need to go look at each inner class)  
85 - return false  
86 - }  
87 -  
88 - for (m in cls.findMethodsByName("equals", true)) {  
89 - if (m is PsiMethod) {  
90 - val parameters = m.parameterList.parameters  
91 - if (parameters.size == 1 &&  
92 - parameters[0].type.canonicalText == MEDIA_STREAM_TRACK  
93 - ) {  
94 - return m.containingClass?.qualifiedName == MEDIA_STREAM_TRACK  
95 - }  
96 - }  
97 - }  
98 -  
99 - return false  
100 - }  
101 -  
102 private fun checkCall(context: JavaContext, node: UCallExpression) { 49 private fun checkCall(context: JavaContext, node: UCallExpression) {
103 - if (defaultEquals(context, node)) {  
104 - // Within cast or instanceof check which implies a more specific type  
105 - // which provides an equals implementation?  
106 - if (withinCastWithEquals(context, node)) {  
107 - return  
108 - }  
109 - 50 + if (node.methodName == "equals") {
  51 + val left = node.receiverType ?: return
  52 + val right = node.valueArguments.takeIf { it.isNotEmpty() }
  53 + ?.get(0)
  54 + ?.getExpressionType()
  55 + ?: return
  56 + if (left is PsiClassType && right is PsiClassType
  57 + && (left.canonicalText == MEDIA_STREAM_TRACK || right.canonicalText == MEDIA_STREAM_TRACK)
  58 + ) {
110 val message = DEFAULT_MSG 59 val message = DEFAULT_MSG
111 - val location = context.getCallLocation(  
112 - node,  
113 - includeReceiver = false,  
114 - includeArguments = true  
115 - ) 60 + val location = context.getLocation(node)
116 context.report(ISSUE, node, location, message) 61 context.report(ISSUE, node, location, message)
117 } 62 }
118 } 63 }
119 -  
120 - /**  
121 - * Is this .equals() call within another if check which checks  
122 - * instanceof on a more specific type than we're calling equals on?  
123 - * If so, does that more specific type define its own equals?  
124 - *  
125 - * Also handle an implicit check via short circuit evaluation; e.g.  
126 - * something like "return a is A && b is B && a.equals(b)".  
127 - */  
128 - private fun withinCastWithEquals(context: JavaContext, node: UExpression): Boolean {  
129 - var parent = skipParenthesizedExprUp(node.uastParent)  
130 - if (parent is UQualifiedReferenceExpression) {  
131 - parent = skipParenthesizedExprUp(parent.uastParent)  
132 - }  
133 - val target: PsiElement? = when (node) {  
134 - is UCallExpression -> node.receiver?.tryResolve()  
135 - is UBinaryExpression -> node.leftOperand.tryResolve()  
136 - else -> null  
137 - }  
138 -  
139 - if (parent is UPolyadicExpression && parent.operator == UastBinaryOperator.LOGICAL_AND) {  
140 - val operands = parent.operands  
141 - for (operand in operands) {  
142 - if (operand === node) {  
143 - break  
144 - }  
145 - if (isCastWithEquals(context, operand, target)) {  
146 - return true  
147 - }  
148 - }  
149 - }  
150 - val ifStatement = node.getParentOfType<UElement>(UIfExpression::class.java, false, UMethod::class.java)  
151 - as? UIfExpression ?: return false  
152 - val condition = ifStatement.condition  
153 - return isCastWithEquals(context, condition, target)  
154 - }  
155 -  
156 - private fun isCastWithEquals(context: JavaContext, node: UExpression, target: PsiElement?): Boolean {  
157 - when {  
158 - node is UBinaryExpressionWithType -> {  
159 - if (target != null) {  
160 - val resolved = node.operand.tryResolve()  
161 - // Unfortunately in some scenarios isEquivalentTo returns false for equal instances  
162 - //noinspection LintImplPsiEquals  
163 - if (resolved != null && !(target == resolved || target.isEquivalentTo(resolved))) {  
164 - return false  
165 - }  
166 - }  
167 - return !defaultEquals(context, node.type as? PsiClassType)  
168 - }  
169 - node is UPolyadicExpression && node.operator == UastBinaryOperator.LOGICAL_AND -> {  
170 - for (operand in node.operands) {  
171 - if (isCastWithEquals(context, operand, target)) {  
172 - return true  
173 - }  
174 - }  
175 - }  
176 - node is UParenthesizedExpression -> {  
177 - return isCastWithEquals(context, node.expression, target)  
178 - }  
179 - }  
180 - return false  
181 } 64 }
182 65
183 private fun checkExpression(context: JavaContext, node: UBinaryExpression) { 66 private fun checkExpression(context: JavaContext, node: UBinaryExpression) {
@@ -187,28 +70,14 @@ class MediaTrackEqualsDetector : Detector(), SourceCodeScanner { @@ -187,28 +70,14 @@ class MediaTrackEqualsDetector : Detector(), SourceCodeScanner {
187 val left = node.leftOperand.getExpressionType() ?: return 70 val left = node.leftOperand.getExpressionType() ?: return
188 val right = node.rightOperand.getExpressionType() ?: return 71 val right = node.rightOperand.getExpressionType() ?: return
189 if (left is PsiClassType && right is PsiClassType 72 if (left is PsiClassType && right is PsiClassType
190 - && (left.className == "MediaStreamTrack" || right.className == "MediaStreamTrack") 73 + && (left.canonicalText == MEDIA_STREAM_TRACK || right.canonicalText == MEDIA_STREAM_TRACK)
191 ) { 74 ) {
192 - if (node.operator == UastBinaryOperator.EQUALS) {  
193 - if (defaultEquals(context, node)) {  
194 - if (withinCastWithEquals(context, node)) {  
195 - return  
196 - }  
197 -  
198 val message = DEFAULT_MSG 75 val message = DEFAULT_MSG
199 val location = node.operatorIdentifier?.let { 76 val location = node.operatorIdentifier?.let {
200 context.getLocation(it) 77 context.getLocation(it)
201 } ?: context.getLocation(node) 78 } ?: context.getLocation(node)
202 context.report(ISSUE, node, location, message) 79 context.report(ISSUE, node, location, message)
203 } 80 }
204 - } else {  
205 - val message = DEFAULT_MSG  
206 - val location = node.operatorIdentifier?.let {  
207 - context.getLocation(it)  
208 - } ?: context.getLocation(node)  
209 - context.report(ISSUE, node, location, message)  
210 - }  
211 - }  
212 } 81 }
213 } 82 }
214 83
@@ -224,17 +93,16 @@ class MediaTrackEqualsDetector : Detector(), SourceCodeScanner { @@ -224,17 +93,16 @@ class MediaTrackEqualsDetector : Detector(), SourceCodeScanner {
224 @JvmField 93 @JvmField
225 val ISSUE = Issue.create( 94 val ISSUE = Issue.create(
226 id = "MediaTrackEqualsDetector", 95 id = "MediaTrackEqualsDetector",
227 - briefDescription = "Suspicious DiffUtil Equality", 96 + briefDescription = "Suspicious MediaStreamTrack Equality",
228 explanation = """ 97 explanation = """
229 - `areContentsTheSame` is used by `DiffUtil` to produce diffs. If the \  
230 - method is implemented incorrectly, such as using identity equals \  
231 - instead of equals, or calling equals on a class that has not implemented \  
232 - it, weird visual artifacts can occur. 98 + MediaStreamTrack does not implement `equals`, and therefore cannot be relied upon.
  99 + Additionally, many MediaStreamTrack objects may exist for the same underlying stream,
  100 + and therefore the identity operator `===` is unreliable.
233 """, 101 """,
234 category = Category.CORRECTNESS, 102 category = Category.CORRECTNESS,
235 priority = 4, 103 priority = 4,
236 androidSpecific = true, 104 androidSpecific = true,
237 - moreInfo = "https://issuetracker.google.com/116789824", 105 + moreInfo = "https://github.com/livekit/client-sdk-android/commit/01152f2ac01dae59759383d587cdc21035718b8e",
238 severity = Severity.ERROR, 106 severity = Severity.ERROR,
239 implementation = IMPLEMENTATION 107 implementation = IMPLEMENTATION
240 ) 108 )
@@ -4,7 +4,7 @@ package io.livekit.lint @@ -4,7 +4,7 @@ package io.livekit.lint
4 4
5 import com.android.tools.lint.checks.infrastructure.TestFile 5 import com.android.tools.lint.checks.infrastructure.TestFile
6 import com.android.tools.lint.checks.infrastructure.TestFiles.java 6 import com.android.tools.lint.checks.infrastructure.TestFiles.java
7 -import com.android.tools.lint.checks.infrastructure.TestLintTask 7 +import com.android.tools.lint.checks.infrastructure.TestFiles.kotlin
8 import com.android.tools.lint.checks.infrastructure.TestLintTask.lint 8 import com.android.tools.lint.checks.infrastructure.TestLintTask.lint
9 import org.junit.Test 9 import org.junit.Test
10 10
@@ -131,25 +131,85 @@ class MediaTrackEqualsDetectorTest { @@ -131,25 +131,85 @@ class MediaTrackEqualsDetectorTest {
131 .run() 131 .run()
132 .expectClean() 132 .expectClean()
133 } 133 }
134 -}  
135 134
136 -fun TestLintTask.mediaStreamTrack(): TestLintTask {  
137 - return this.files(  
138 - java( 135 + @Test
  136 + fun kotlinMediaTrackEqualityOperator() {
  137 + lint()
  138 + .allowMissingSdk()
  139 + .files(
  140 + mediaStreamTrack(),
  141 + kotlin(
139 """ 142 """
140 - package org.webrtc; 143 + package foo
  144 + import org.webrtc.MediaStreamTrack
141 145
142 - class MediaStreamTrack {  
143 - int getId(){  
144 - return 0; 146 + class Example {
  147 + fun foo() : Boolean {
  148 + val a = MediaStreamTrack()
  149 + val b = MediaStreamTrack()
  150 + return a === b;
  151 + }
  152 + }"""
  153 + ).indented()
  154 + )
  155 + .issues(MediaTrackEqualsDetector.ISSUE)
  156 + .run()
  157 + .expectErrorCount(1)
  158 + }
  159 +
  160 + @Test
  161 + fun kotlinMediaTrackIdentityEqualityOperator() {
  162 + lint()
  163 + .allowMissingSdk()
  164 + .files(
  165 + mediaStreamTrack(),
  166 + kotlin(
  167 + """
  168 + package foo
  169 + import org.webrtc.MediaStreamTrack
  170 +
  171 + class Example {
  172 + fun foo() : Boolean {
  173 + val a = MediaStreamTrack()
  174 + val b = MediaStreamTrack()
  175 + return a == b
145 } 176 }
  177 + }"""
  178 + ).indented()
  179 + )
  180 + .issues(MediaTrackEqualsDetector.ISSUE)
  181 + .run()
  182 + .expectErrorCount(1)
146 } 183 }
  184 +
  185 +
  186 + @Test
  187 + fun kotlinProperMediaTrackEquality() {
  188 + lint()
  189 + .allowMissingSdk()
  190 + .files(
  191 + mediaStreamTrack(),
  192 + kotlin(
147 """ 193 """
  194 + package foo
  195 + import org.webrtc.MediaStreamTrack
  196 +
  197 + class Example {
  198 + fun foo() : Boolean {
  199 + val a = MediaStreamTrack()
  200 + val b = MediaStreamTrack()
  201 + return a.id() == b.id()
  202 + }
  203 + }"""
148 ).indented() 204 ).indented()
149 ) 205 )
  206 + .issues(MediaTrackEqualsDetector.ISSUE)
  207 + .run()
  208 + .expectClean()
  209 + }
150 } 210 }
151 211
152 -fun Any.mediaStreamTrack(): TestFile { 212 +fun mediaStreamTrack(): TestFile {
153 return java( 213 return java(
154 """ 214 """
155 package org.webrtc; 215 package org.webrtc;
1 -package org.webrtc;  
2 -  
3 -class MediaStreamTrack {  
4 - int getId() {  
5 - return 0;  
6 - }  
7 -}