Skip to content

Commit

Permalink
[inspection] re-enable comparing unrelated types inspection for scala 3
Browse files Browse the repository at this point in the history
Adds support for '--language:strictEquality' compiler flag and 'CanEqual' when running inspection on scala 3 code
  • Loading branch information
disordered committed Mar 25, 2024
1 parent fb6b8f2 commit 2499114
Show file tree
Hide file tree
Showing 5 changed files with 294 additions and 12 deletions.
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
package org.jetbrains.plugins.scala.codeInspection.typeChecking

import com.intellij.codeInspection.{LocalInspectionTool, ProblemHighlightType, ProblemsHolder}
import com.intellij.codeInspection.{LocalInspectionTool, ProblemsHolder}
import com.intellij.psi.PsiMethod
import com.siyeh.ig.psiutils.MethodUtils
import org.jetbrains.annotations.Nls
import org.jetbrains.plugins.scala.codeInspection.collections.MethodRepr
import org.jetbrains.plugins.scala.codeInspection.typeChecking.ComparingUnrelatedTypesInspection._
import org.jetbrains.plugins.scala.codeInspection.{PsiElementVisitorSimple, ScalaInspectionBundle}
import org.jetbrains.plugins.scala.extensions._
import org.jetbrains.plugins.scala.lang.psi.api.base.types.ScParameterizedTypeElement
import org.jetbrains.plugins.scala.lang.psi.api.expr.{ScExpression, ScReferenceExpression}
import org.jetbrains.plugins.scala.lang.psi.api.statements.ScFunction
import org.jetbrains.plugins.scala.lang.psi.api.toplevel.typedef.ScClass
import org.jetbrains.plugins.scala.lang.psi.api.toplevel.typedef.{ScClass, ScGiven}
import org.jetbrains.plugins.scala.lang.psi.impl.toplevel.synthetic.ScSyntheticFunction
import org.jetbrains.plugins.scala.lang.psi.types._
import org.jetbrains.plugins.scala.lang.psi.types.api._
Expand Down Expand Up @@ -127,12 +128,36 @@ object ComparingUnrelatedTypesInspection {
}
}
}

private def hasCanEqual(expr: ScExpression, source: ScType, target: ScType): Boolean = {
lazy val expressionTypes: Seq[ScType] = List(source, target)
lazy val canEqualExists: Boolean = expr
.contexts
.flatMap(_.children)
.filterByType[ScGiven]
.filter(_.`type`().map(_.canonicalText.matches("_root_\\.scala\\.CanEqual\\[.+?, .+?]")).getOrElse(false))
.flatMap(_.children.filterByType[ScParameterizedTypeElement])
.map(_.typeArgList.typeArgs.flatMap(_.`type`().map(_.tryExtractDesignatorSingleton).toSeq))
.exists(_
.zip(expressionTypes)
.forall {
case (givenType, compType) =>
!checkComparability(givenType, compType, isBuiltinOperation = true).shouldNotBeCompared
}
)

val wideSource: ScType = source.widenIfLiteral
// Even though CanEqual[Primitive | String, _] can be defined and will satisfy compiler in strictEquals mode,
// it is not possible to override equals method on the primitives or Strings
!wideSource.isPrimitive &&
!wideSource.canonicalText.matches("_root_\\.java\\.lang\\.String") &&
(expr.isCompilerStrictEqualityMode || canEqualExists)
}
}

class ComparingUnrelatedTypesInspection extends LocalInspectionTool {

override def buildVisitor(holder: ProblemsHolder, isOnTheFly: Boolean): PsiElementVisitorSimple = {
case e if e.isInScala3File => () // TODO Handle Scala 3 code (`CanEqual` instances, etc.), SCL-19722
case MethodRepr(expr, Some(left), Some(oper), Seq(right)) if isComparingFunctions(oper.refName) =>
// "blub" == 3
val needHighlighting = oper.resolve() match {
Expand All @@ -145,7 +170,8 @@ class ComparingUnrelatedTypesInspection extends LocalInspectionTool {
case Seq(Right(leftType), Right(rightType)) =>
val isBuiltinOperation = isIdentityFunction(oper.refName) || !hasNonDefaultEquals(leftType)
val comparability = checkComparability(leftType, rightType, isBuiltinOperation)
if (comparability.shouldNotBeCompared) {
if ((!expr.isInScala3File && comparability.shouldNotBeCompared) ||
(expr.isInScala3File && comparability.shouldNotBeCompared && !hasCanEqual(expr, leftType, rightType))) {
val message = generateComparingUnrelatedTypesMsg(leftType, rightType)(expr)
holder.registerProblem(expr, message)
}
Expand All @@ -158,7 +184,8 @@ class ComparingUnrelatedTypesInspection extends LocalInspectionTool {
ParameterizedType(_, Seq(elemType)) <- receiverType(baseExpr, ref).map(_.tryExtractDesignatorSingleton)
argType <- arg.`type`().toOption
comparability = checkComparability(elemType, argType, isBuiltinOperation = !hasNonDefaultEquals(elemType))
if comparability.shouldNotBeCompared
if (!baseExpr.isInScala3File && comparability.shouldNotBeCompared) ||
(baseExpr.isInScala3File && comparability.shouldNotBeCompared && !hasCanEqual(baseExpr, elemType, argType))
} {
val message = generateComparingUnrelatedTypesMsg(elemType, argType)(arg)
holder.registerProblem(arg, message)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import com.intellij.openapi.roots.{OrderEnumerator, OrderRootType, libraries}
import com.intellij.openapi.util.ModificationTracker
import com.intellij.openapi.util.io.JarUtil.{containsEntry, getJarAttribute}
import com.intellij.openapi.vfs.VirtualFile
import com.intellij.util.CommonProcessors.{CollectProcessor, FindProcessor}
import com.intellij.util.CommonProcessors.FindProcessor
import org.jetbrains.plugins.scala.ScalaVersion
import org.jetbrains.plugins.scala.caches.cached
import org.jetbrains.plugins.scala.project.ScalaFeatures.SerializableScalaFeatures
Expand All @@ -18,7 +18,6 @@ import org.jetbrains.sbt.project.SbtVersionProvider

import java.io.File
import java.util.jar.Attributes
import scala.jdk.CollectionConverters.IteratorHasAsScala

private class ScalaModuleSettings private(
module: Module,
Expand Down Expand Up @@ -142,6 +141,9 @@ private class ScalaModuleSettings private(
val isCompilerStrictMode: Boolean =
settingsForHighlighting.exists(_.strict)

val isCompilerStrictEqualityMode: Boolean =
settingsForHighlighting.exists(_.strictEquality)

val customDefaultImports: Option[Seq[String]] =
additionalCompilerOptions.collectFirst {
case Yimports(imports) if scalaLanguageLevel >= Scala_2_13 => imports
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,9 @@ package object project {
def isCompilerStrictMode: Boolean =
scalaModuleSettings.exists(_.isCompilerStrictMode)

def isCompilerStrictEqualityMode: Boolean =
scalaModuleSettings.exists(_.isCompilerStrictEqualityMode)

def scalaCompilerClasspath: Seq[File] = module.scalaSdk
.fold(throw new ScalaSdkNotConfiguredException(module)) {
_.properties.compilerClasspath
Expand Down Expand Up @@ -537,6 +540,8 @@ package object project {

def isCompilerStrictMode: Boolean = module.exists(_.isCompilerStrictMode)

def isCompilerStrictEqualityMode: Boolean = isInScala3Module && module.exists(_.isCompilerStrictEqualityMode)

def scalaLanguageLevel: Option[ScalaLanguageLevel] =
fromFeaturesOrModule(_.languageLevel, _.scalaLanguageLevel)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ case class ScalaCompilerSettings(compileOrder: CompileOrder,
val languageWildcard: Boolean = additionalCompilerOptions.contains("-language:_") ||
additionalCompilerOptions.contains("--language:_")
val strict: Boolean = additionalCompilerOptions.contains("-strict")
val strictEquality: Boolean = additionalCompilerOptions.contains("-language:strictEquality") ||
additionalCompilerOptions.contains("--language:strictEquality")

def getOptionsAsStrings(forScala3Compiler: Boolean): Seq[String] = {
val state = toState
Expand Down
Loading

0 comments on commit 2499114

Please sign in to comment.