PY-43388 Add basic type checker support in Cython files
[idea/community.git] / python / python-psi-impl / src / com / jetbrains / python / psi / types / PyTypeChecker.java
index 1315c292a1a31ad5f565c0556c92e2abdb54079e..361de975e3bfd10a7f6b8dd718dc0c4c49032f7b 100644 (file)
@@ -2,6 +2,7 @@
 package com.jetbrains.python.psi.types;
 
 import com.intellij.openapi.util.Pair;
+import com.intellij.openapi.util.RecursionManager;
 import com.intellij.psi.PsiElement;
 import com.intellij.psi.PsiFile;
 import com.intellij.psi.ResolveResult;
@@ -70,14 +71,13 @@ public final class PyTypeChecker {
 
   @NotNull
   private static Optional<Boolean> match(@Nullable PyType expected, @Nullable PyType actual, @NotNull MatchContext context) {
-    final Pair<PyType, PyType> types = Pair.create(expected, actual);
-    if (context.matching.contains(types)) return Optional.of(true);
+    final Optional<Boolean> result = RecursionManager.doPreventingRecursion(
+      Pair.create(expected, actual),
+      false,
+      () -> matchImpl(expected, actual, context)
+    );
 
-    context.matching.add(types);
-    final Optional<Boolean> result = matchImpl(expected, actual, context);
-    context.matching.remove(types);
-
-    return result;
+    return result == null ? Optional.of(true) : result;
   }
 
   /**
@@ -92,6 +92,13 @@ public final class PyTypeChecker {
    */
   @NotNull
   private static Optional<Boolean> matchImpl(@Nullable PyType expected, @Nullable PyType actual, @NotNull MatchContext context) {
+    for (PyTypeCheckerExtension extension : PyTypeCheckerExtension.EP_NAME.getExtensionList()) {
+      final Optional<Boolean> result = extension.match(expected, actual, context.context, context.substitutions);
+      if (result.isPresent()) {
+        return result;
+      }
+    }
+
     if (expected instanceof PyClassType) {
       Optional<Boolean> match = matchObject((PyClassType)expected, actual);
       if (match.isPresent()) {
@@ -211,16 +218,13 @@ public final class PyTypeChecker {
       if (expected.equals(actual) || substitution.equals(expected)) {
         return true;
       }
-      if (context.typeVarsInMatching.add(expected)) {
-        Optional<Boolean> recursiveMatch = context.reversedSubstitutions
-                                           ? match(actual, substitution, context)
-                                           : match(substitution, actual, context);
-        context.typeVarsInMatching.remove(expected);
-        if (recursiveMatch.isPresent()) {
-          return recursiveMatch.get();
-        }
-      }
-      return false;
+
+      Optional<Boolean> recursiveMatch = RecursionManager.doPreventingRecursion(
+        expected, false, context.reversedSubstitutions
+                         ? () -> match(actual, substitution, context)
+                         : () -> match(substitution, actual, context)
+      );
+      return recursiveMatch != null ? recursiveMatch.orElse(false) : false;
     }
 
     if (actual != null) {
@@ -1025,34 +1029,24 @@ public final class PyTypeChecker {
     @NotNull
     private final Map<PyGenericType, PyType> substitutions; // mutable
 
-    @NotNull
-    private final Set<PyGenericType> typeVarsInMatching; // mutable
-
-    @NotNull
-    private final Set<Pair<PyType, PyType>> matching; // mutable
-
     private final boolean reversedSubstitutions;
 
     MatchContext(@NotNull TypeEvalContext context,
                  @NotNull Map<PyGenericType, PyType> substitutions) {
-      this(context, substitutions, new HashSet<>(), new HashSet<>(), false);
+      this(context, substitutions, false);
     }
 
     private MatchContext(@NotNull TypeEvalContext context,
                          @NotNull Map<PyGenericType, PyType> substitutions,
-                         @NotNull Set<PyGenericType> typeVarsInMatching,
-                         @NotNull Set<Pair<PyType, PyType>> matching,
                          boolean reversedSubstitutions) {
       this.context = context;
       this.substitutions = substitutions;
-      this.typeVarsInMatching = typeVarsInMatching;
-      this.matching = matching;
       this.reversedSubstitutions = reversedSubstitutions;
     }
 
     @NotNull
     public MatchContext reverseSubstitutions() {
-      return new MatchContext(context, substitutions, typeVarsInMatching, matching, !reversedSubstitutions);
+      return new MatchContext(context, substitutions, !reversedSubstitutions);
     }
   }
 }