Make PyTypeProvider.getCallType returns Ref<PyType> instead of PyType. Attempt to...
authorSemyon Proshev <Semyon.Proshev@jetbrains.com>
Mon, 1 Aug 2016 15:40:38 +0000 (18:40 +0300)
committerSemyon Proshev <Semyon.Proshev@jetbrains.com>
Wed, 10 Aug 2016 14:00:20 +0000 (17:00 +0300)
python/psi-api/src/com/jetbrains/python/psi/impl/PyTypeProvider.java
python/psi-api/src/com/jetbrains/python/psi/types/PyTypeProviderBase.java
python/src/com/jetbrains/numpy/codeInsight/NumpyDocStringTypeProvider.java
python/src/com/jetbrains/python/codeInsight/PyTypingTypeProvider.java
python/src/com/jetbrains/python/codeInsight/stdlib/PyStdlibTypeProvider.java
python/src/com/jetbrains/python/psi/impl/PyFunctionImpl.java
python/src/com/jetbrains/python/pyi/PyiTypeProvider.java

index 3a74030dcb17b7ba282fa20a97ef17826b3194b9..337fb6e23555299bb5c92b92c998e9ff83135fff 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2000-2014 JetBrains s.r.o.
+ * Copyright 2000-2016 JetBrains s.r.o.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -43,7 +43,7 @@ public interface PyTypeProvider {
   Ref<PyType> getReturnType(@NotNull PyCallable callable, @NotNull TypeEvalContext context);
 
   @Nullable
-  PyType getCallType(@NotNull PyFunction function, @Nullable PyCallSiteExpression callSite, @NotNull TypeEvalContext context);
+  Ref<PyType> getCallType(@NotNull PyFunction function, @Nullable PyCallSiteExpression callSite, @NotNull TypeEvalContext context);
 
   @Nullable
   PyType getContextManagerVariableType(PyClass contextManager, PyExpression withExpression, TypeEvalContext context);
index 07789a71c13db1fd70c31c022602ee34307dbf10..ce4bc351c8325d077ed6bf517e0eef17142506e8 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2000-2014 JetBrains s.r.o.
+ * Copyright 2000-2016 JetBrains s.r.o.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -17,6 +17,7 @@ package com.jetbrains.python.psi.types;
 
 import com.intellij.openapi.util.Ref;
 import com.intellij.psi.PsiElement;
+import com.intellij.util.ObjectUtils;
 import com.intellij.util.containers.FactoryMap;
 import com.jetbrains.python.psi.*;
 import com.jetbrains.python.psi.impl.PyTypeProvider;
@@ -26,52 +27,18 @@ import org.jetbrains.annotations.Nullable;
 import java.util.Collection;
 import java.util.HashMap;
 import java.util.Map;
+import java.util.Optional;
 
 /**
  * @author yole
  */
 public class PyTypeProviderBase implements PyTypeProvider {
-  public PyTypeProviderBase() {
-  }
-
-  protected interface ReturnTypeCallback {
-    @Nullable
-    PyType getType(@Nullable PyCallSiteExpression callSite, @Nullable PyType qualifierType, TypeEvalContext context);
-  }
-
-  private static class ReturnTypeDescriptor {
-    private final Map<String, ReturnTypeCallback> myStringToReturnTypeMap = new HashMap<>();
-
-    void put(String className, ReturnTypeCallback callback) {
-      myStringToReturnTypeMap.put(className, callback);
-    }
 
-    @Nullable
-    public PyType get(PyFunction function, @Nullable PyCallSiteExpression callSite, TypeEvalContext context) {
-      PyClass containingClass = function.getContainingClass();
-      if (containingClass != null) {
-        final ReturnTypeCallback typeCallback = myStringToReturnTypeMap.get(containingClass.getQualifiedName());
-        if (typeCallback != null) {
-          final PyExpression callee = callSite instanceof PyCallExpression ? ((PyCallExpression)callSite).getCallee() : null;
-          final PyExpression qualifier = callee instanceof PyQualifiedExpression ? ((PyQualifiedExpression)callee).getQualifier() : null;
-          PyType qualifierType = qualifier != null ? context.getType(qualifier) : null;
-          return typeCallback.getType(callSite, qualifierType, context);
-        }
-      }
-      return null;
-    }
-  }
-
-  private final ReturnTypeCallback mySelfTypeCallback = new ReturnTypeCallback() {
-    @Override
-    public PyType getType(@Nullable PyCallSiteExpression callSite, @Nullable PyType qualifierType, TypeEvalContext context) {
-      if (qualifierType instanceof PyClassType) {
-        PyClass aClass = ((PyClassType)qualifierType).getPyClass();
-        return PyPsiFacade.getInstance(aClass.getProject()).createClassType(aClass, false);
-      }
-      return null;
-    }
-  };
+  private final ReturnTypeCallback mySelfTypeCallback = (callSite, qualifierType, context) -> Optional
+    .ofNullable(ObjectUtils.tryCast(qualifierType, PyClassType.class))
+    .map(PyClassType::getPyClass)
+    .map(pyClass -> PyPsiFacade.getInstance(pyClass.getProject()).createClassType(pyClass, false))
+    .orElse(null);
 
   @SuppressWarnings({"MismatchedQueryAndUpdateOfCollection"})
   private final Map<String, ReturnTypeDescriptor> myMethodToReturnTypeMap = new FactoryMap<String, ReturnTypeDescriptor>() {
@@ -105,8 +72,8 @@ public class PyTypeProviderBase implements PyTypeProvider {
 
   @Nullable
   @Override
-  public PyType getCallType(@NotNull PyFunction function, @Nullable PyCallSiteExpression callSite, @NotNull TypeEvalContext context) {
-    ReturnTypeDescriptor descriptor;
+  public Ref<PyType> getCallType(@NotNull PyFunction function, @Nullable PyCallSiteExpression callSite, @NotNull TypeEvalContext context) {
+    final ReturnTypeDescriptor descriptor;
     synchronized (myMethodToReturnTypeMap) {
       descriptor = myMethodToReturnTypeMap.get(function.getName());
     }
@@ -128,17 +95,50 @@ public class PyTypeProviderBase implements PyTypeProvider {
     return null;
   }
 
-  protected void registerSelfReturnType(String classQualifiedName, Collection<String> methods) {
+  protected void registerSelfReturnType(@NotNull String classQualifiedName, @NotNull Collection<String> methods) {
     registerReturnType(classQualifiedName, methods, mySelfTypeCallback);
   }
 
-  protected void registerReturnType(String classQualifiedName,
-                                    Collection<String> methods,
-                                    final ReturnTypeCallback callback) {
+  protected void registerReturnType(@NotNull String classQualifiedName,
+                                    @NotNull Collection<String> methods,
+                                    @NotNull ReturnTypeCallback callback) {
     synchronized (myMethodToReturnTypeMap) {
       for (String method : methods) {
         myMethodToReturnTypeMap.get(method).put(classQualifiedName, callback);
       }
     }
   }
+
+  protected interface ReturnTypeCallback {
+
+    @Nullable
+    PyType getType(@Nullable PyCallSiteExpression callSite, @Nullable PyType qualifierType, @NotNull TypeEvalContext context);
+  }
+
+  private static class ReturnTypeDescriptor {
+
+    private final Map<String, ReturnTypeCallback> myStringToReturnTypeMap = new HashMap<>();
+
+    public void put(@NotNull String classQualifiedName, @NotNull ReturnTypeCallback callback) {
+      myStringToReturnTypeMap.put(classQualifiedName, callback);
+    }
+
+    @Nullable
+    public Ref<PyType> get(@NotNull PyFunction function, @Nullable PyCallSiteExpression callSite, @NotNull TypeEvalContext context) {
+      return Optional
+        .ofNullable(function.getContainingClass())
+        .map(pyClass -> myStringToReturnTypeMap.get(pyClass.getQualifiedName()))
+        .map(typeCallback -> typeCallback.getType(callSite, getQualifierType(callSite, context), context))
+        .map(Ref::create)
+        .orElse(null);
+    }
+
+    @Nullable
+    private static PyType getQualifierType(@Nullable PyCallSiteExpression callSite, @NotNull TypeEvalContext context) {
+      final PyExpression callee = callSite instanceof PyCallExpression ? ((PyCallExpression)callSite).getCallee() : null;
+      final PyExpression qualifier = callee instanceof PyQualifiedExpression ? ((PyQualifiedExpression)callee).getQualifier() : null;
+
+      return qualifier != null ? context.getType(qualifier) : null;
+    }
+  }
 }
index 7af729d7d3778a997ee287fdd67877de5bec0cd8..25190e48cdb2102f9801b891fd3034f93f7a9282 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2000-2014 JetBrains s.r.o.
+ * Copyright 2000-2016 JetBrains s.r.o.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -205,34 +205,34 @@ public class NumpyDocStringTypeProvider extends PyTypeProviderBase {
 
   @Nullable
   @Override
-  public PyType getCallType(@NotNull PyFunction function, @Nullable PyCallSiteExpression callSite, @NotNull TypeEvalContext context) {
+  public Ref<PyType> getCallType(@NotNull PyFunction function, @Nullable PyCallSiteExpression callSite, @NotNull TypeEvalContext context) {
     if (isApplicable(function)) {
       final PyExpression callee = callSite instanceof PyCallExpression ? ((PyCallExpression)callSite).getCallee() : null;
       final NumpyDocString docString = forFunction(function, callee);
       if (docString != null) {
         final List<SectionField> returns = docString.getReturnFields();
         final PyPsiFacade facade = getPsiFacade(function);
+
         switch (returns.size()) {
           case 0:
             return null;
           case 1:
             // Function returns single value
-            final String typeName = returns.get(0).getType();
-            if (StringUtil.isNotEmpty(typeName)) {
-              final PyType genericType = getPsiFacade(function).parseTypeAnnotation("T", function);
-              if (isUfuncType(function, typeName)) return genericType;
-              return parseNumpyDocType(function, typeName);
-            }
-            return null;
+            return Optional
+              .ofNullable(returns.get(0).getType())
+              .filter(StringUtil::isNotEmpty)
+              .map(typeName -> isUfuncType(function, typeName)
+                               ? facade.parseTypeAnnotation("T", function)
+                               : parseNumpyDocType(function, typeName))
+              .map(Ref::create)
+              .orElse(null);
           default:
             // Function returns a tuple
-            final ArrayList<PyType> unionMembers = new ArrayList<>();
-
+            final List<PyType> unionMembers = new ArrayList<>();
             final List<PyType> members = new ArrayList<>();
 
             for (int i = 0; i < returns.size(); i++) {
-              SectionField ret = returns.get(i);
-              final String memberTypeName = ret.getType();
+              final String memberTypeName = returns.get(i).getType();
               final PyType returnType = StringUtil.isNotEmpty(memberTypeName) ? parseNumpyDocType(function, memberTypeName) : null;
               final boolean isOptional = StringUtil.isNotEmpty(memberTypeName) && memberTypeName.contains("optional");
 
@@ -250,13 +250,13 @@ public class NumpyDocStringTypeProvider extends PyTypeProviderBase {
                 unionMembers.add(facade.createTupleType(members, function));
               }
             }
-            if (unionMembers.isEmpty()) {
-              return facade.createTupleType(members, function);
-            }
-            return facade.createUnionType(unionMembers);
+
+            final PyType type = unionMembers.isEmpty() ? facade.createTupleType(members, function) : facade.createUnionType(unionMembers);
+            return Ref.create(type);
         }
       }
     }
+
     return null;
   }
 
@@ -414,12 +414,9 @@ public class NumpyDocStringTypeProvider extends PyTypeProviderBase {
   @Nullable
   @Override
   public Ref<PyType> getReturnType(@NotNull PyCallable callable, @NotNull TypeEvalContext context) {
-    if (callable instanceof PyFunction) {
-      final PyType type = getCallType((PyFunction)callable, null, context);
-      if (type != null) {
-        return Ref.create(type);
-      }
-    }
-    return null;
+    return Optional
+      .ofNullable(PyUtil.as(callable, PyFunction.class))
+      .map(function -> getCallType(function, null, context))
+      .orElse(null);
   }
 }
index 528203a7146b433a418a68a97d43066d14db77cd..8455e659e4613b1e1a331ef7753efe2684f47a86 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2000-2014 JetBrains s.r.o.
+ * Copyright 2000-2016 JetBrains s.r.o.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -158,15 +158,17 @@ public class PyTypingTypeProvider extends PyTypeProviderBase {
 
   @Nullable
   @Override
-  public PyType getCallType(@NotNull PyFunction function, @Nullable PyCallSiteExpression callSite, @NotNull TypeEvalContext context) {
-    if ("typing.cast".equals(function.getQualifiedName()) && callSite instanceof PyCallExpression) {
-      final PyCallExpression callExpr = (PyCallExpression)callSite;
-      final PyExpression[] args = callExpr.getArguments();
-      if (args.length > 0) {
-        final PyExpression typeExpr = args[0];
-        return getType(typeExpr, new Context(context));
-      }
+  public Ref<PyType> getCallType(@NotNull PyFunction function, @Nullable PyCallSiteExpression callSite, @NotNull TypeEvalContext context) {
+    if ("typing.cast".equals(function.getQualifiedName())) {
+      return Optional
+        .ofNullable(as(callSite, PyCallExpression.class))
+        .map(PyCallExpression::getArguments)
+        .filter(args -> args.length > 0)
+        .map(args -> getType(args[0], new Context(context)))
+        .map(Ref::create)
+        .orElse(null);
     }
+
     return null;
   }
 
index 7c64f251772354de9b00269504a62972d3c08f74..04c72473709a0c5e344c1c37d1a7105cae387f7e 100644 (file)
@@ -17,6 +17,7 @@ package com.jetbrains.python.codeInsight.stdlib;
 
 import com.google.common.collect.ImmutableSet;
 import com.intellij.openapi.extensions.Extensions;
+import com.intellij.openapi.util.Ref;
 import com.intellij.openapi.vfs.VirtualFile;
 import com.intellij.psi.PsiElement;
 import com.intellij.psi.util.QualifiedName;
@@ -144,7 +145,16 @@ public class PyStdlibTypeProvider extends PyTypeProviderBase {
 
   @Nullable
   @Override
-  public PyType getCallType(@NotNull PyFunction function, @Nullable PyCallSiteExpression callSite, @NotNull TypeEvalContext context) {
+  public Ref<PyType> getCallType(@NotNull PyFunction function, @Nullable PyCallSiteExpression callSite, @NotNull TypeEvalContext context) {
+    if (callSite != null && isListGetItem(function)) {
+      final PyExpression receiver = PyTypeChecker.getReceiver(callSite, function);
+      final Map<PyExpression, PyNamedParameter> mapping = PyCallExpressionHelper.mapArguments(callSite, function, context);
+      final Map<PyGenericType, PyType> substitutions = PyTypeChecker.unifyGenericCall(receiver, mapping, context);
+      if (substitutions != null) {
+        return analyzeListGetItemCallType(receiver, mapping, substitutions, context);
+      }
+    }
+
     final String qname = getQualifiedName(function, callSite);
     if (qname != null) {
       if (OPEN_FUNCTIONS.contains(qname) && callSite instanceof PyCallExpression) {
@@ -152,10 +162,7 @@ public class PyStdlibTypeProvider extends PyTypeProviderBase {
         final PyResolveContext resolveContext = PyResolveContext.noImplicits().withTypeEvalContext(context);
         final PyCallExpression.PyArgumentsMapping mapping = callExpr.mapArguments(resolveContext);
         if (mapping.getMarkedCallee() != null) {
-          final PyType type = getOpenFunctionType(qname, mapping.getMappedParameters(), callSite);
-          if (type != null) {
-            return type;
-          }
+          return getOpenFunctionType(qname, mapping.getMappedParameters(), callSite);
         }
       }
       else if ("__builtin__.tuple.__add__".equals(qname) && callSite instanceof PyBinaryExpression) {
@@ -164,15 +171,8 @@ public class PyStdlibTypeProvider extends PyTypeProviderBase {
       else if ("__builtin__.tuple.__mul__".equals(qname) && callSite instanceof PyBinaryExpression) {
         return getTupleMultiplicationResultType((PyBinaryExpression)callSite, context);
       }
-      else if (callSite != null && isListGetItem(function)) {
-        final PyExpression receiver = PyTypeChecker.getReceiver(callSite, function);
-        final Map<PyExpression, PyNamedParameter> mapping = PyCallExpressionHelper.mapArguments(callSite, function, context);
-        final Map<PyGenericType, PyType> substitutions = PyTypeChecker.unifyGenericCall(receiver, mapping, context);
-        if (substitutions != null) {
-          return analyzeListGetItemCallType(receiver, mapping, substitutions, context);
-        }
-      }
     }
+
     return null;
   }
 
@@ -186,10 +186,10 @@ public class PyStdlibTypeProvider extends PyTypeProviderBase {
   }
 
   @Nullable
-  private static PyType analyzeListGetItemCallType(@Nullable PyExpression receiver,
-                                                   @NotNull Map<PyExpression, PyNamedParameter> parameters,
-                                                   @NotNull Map<PyGenericType, PyType> substitutions,
-                                                   @NotNull TypeEvalContext context) {
+  private static Ref<PyType> analyzeListGetItemCallType(@Nullable PyExpression receiver,
+                                                        @NotNull Map<PyExpression, PyNamedParameter> parameters,
+                                                        @NotNull Map<PyGenericType, PyType> substitutions,
+                                                        @NotNull TypeEvalContext context) {
     if (parameters.size() != 1 || substitutions.size() != 1) {
       return null;
     }
@@ -204,25 +204,28 @@ public class PyStdlibTypeProvider extends PyTypeProviderBase {
     }
 
     if (PyABCUtil.isSubtype(firstArgumentType, PyNames.ABC_INTEGRAL, context)) {
-      return substitutions.values().iterator().next();
+      return Ref.create(substitutions.values().iterator().next());
     }
 
     if (PyNames.SLICE.equals(firstArgumentType.getName()) && firstArgumentType.isBuiltin()) {
-      return Optional
-        .ofNullable(receiver)
-        .map(context::getType)
-        .orElseGet(() -> PyTypeChecker.substitute(PyBuiltinCache.getInstance(receiver).getListType(), substitutions, context));
+      return Ref.create(
+        Optional
+          .ofNullable(receiver)
+          .map(context::getType)
+          .orElseGet(() -> PyTypeChecker.substitute(PyBuiltinCache.getInstance(receiver).getListType(), substitutions, context))
+      );
     }
 
     return null;
   }
 
   @Nullable
-  private static PyType getTupleMultiplicationResultType(@NotNull PyBinaryExpression multiplication, @NotNull TypeEvalContext context) {
+  private static Ref<PyType> getTupleMultiplicationResultType(@NotNull PyBinaryExpression multiplication, @NotNull TypeEvalContext context) {
     final PyTupleType leftTupleType = as(context.getType(multiplication.getLeftExpression()), PyTupleType.class);
     if (leftTupleType == null) {
       return null;
     }
+
     PyExpression rightExpression = multiplication.getRightExpression();
     if (rightExpression instanceof PyReferenceExpression) {
       final PsiElement target = ((PyReferenceExpression)rightExpression).getReference().resolve();
@@ -230,10 +233,12 @@ public class PyStdlibTypeProvider extends PyTypeProviderBase {
         rightExpression = ((PyTargetExpression)target).findAssignedValue();
       }
     }
+
     if (rightExpression instanceof PyNumericLiteralExpression && ((PyNumericLiteralExpression)rightExpression).isIntegerLiteral()) {
       if (leftTupleType.isHomogeneous()) {
-        return leftTupleType;
+        return Ref.create(leftTupleType);
       }
+
       final int multiplier = ((PyNumericLiteralExpression)rightExpression).getBigIntegerValue().intValue();
       final int originalSize = leftTupleType.getElementCount();
       // Heuristic
@@ -244,17 +249,19 @@ public class PyStdlibTypeProvider extends PyTypeProviderBase {
             elementTypes[i * originalSize + j] = leftTupleType.getElementType(j);
           }
         }
-        return PyTupleType.create(multiplication, elementTypes);
+        return Ref.create(PyTupleType.create(multiplication, elementTypes));
       }
     }
+
     return null;
   }
 
   @Nullable
-  private static PyType getTupleConcatenationResultType(@NotNull PyBinaryExpression addition, @NotNull TypeEvalContext context) {
-    final PyTupleType leftTupleType = as(context.getType(addition.getLeftExpression()), PyTupleType.class);
+  private static Ref<PyType> getTupleConcatenationResultType(@NotNull PyBinaryExpression addition, @NotNull TypeEvalContext context) {
     if (addition.getRightExpression() != null) {
+      final PyTupleType leftTupleType = as(context.getType(addition.getLeftExpression()), PyTupleType.class);
       final PyTupleType rightTupleType = as(context.getType(addition.getRightExpression()), PyTupleType.class);
+
       if (leftTupleType != null && rightTupleType != null) {
         if (leftTupleType.isHomogeneous() || rightTupleType.isHomogeneous()) {
           // We may try to find the common type of elements of two homogeneous tuple as an alternative
@@ -268,9 +275,11 @@ public class PyStdlibTypeProvider extends PyTypeProviderBase {
         for (int i = 0; i < rightTupleType.getElementCount(); i++) {
           elementTypes[i + leftTupleType.getElementCount()] = rightTupleType.getElementType(i);
         }
-        return PyTupleType.create(addition, elementTypes);
+
+        return Ref.create(PyTupleType.create(addition, elementTypes));
       }
     }
+
     return null;
   }
 
@@ -310,10 +319,10 @@ public class PyStdlibTypeProvider extends PyTypeProviderBase {
     return null;
   }
 
-  @Nullable
-  private static PyType getOpenFunctionType(@NotNull String callQName,
-                                            @NotNull Map<PyExpression, PyNamedParameter> arguments,
-                                            @NotNull PsiElement anchor) {
+  @NotNull
+  private static Ref<PyType> getOpenFunctionType(@NotNull String callQName,
+                                                 @NotNull Map<PyExpression, PyNamedParameter> arguments,
+                                                 @NotNull PsiElement anchor) {
     String mode = "r";
     for (Map.Entry<PyExpression, PyNamedParameter> entry : arguments.entrySet()) {
       final PyNamedParameter parameter = entry.getValue();
@@ -328,17 +337,17 @@ public class PyStdlibTypeProvider extends PyTypeProviderBase {
         }
       }
     }
-    final LanguageLevel level = LanguageLevel.forElement(anchor);
 
-    if (level.isPy3K() || "io.open".equals(callQName)) {
+    if (LanguageLevel.forElement(anchor).isAtLeast(LanguageLevel.PYTHON30) || "io.open".equals(callQName)) {
       if (mode.contains("b")) {
-        return PyTypeParser.getTypeByName(anchor, PY3K_BINARY_FILE_TYPE);
-      } else {
-        return PyTypeParser.getTypeByName(anchor, PY3K_TEXT_FILE_TYPE);
+        return Ref.create(PyTypeParser.getTypeByName(anchor, PY3K_BINARY_FILE_TYPE));
+      }
+      else {
+        return Ref.create(PyTypeParser.getTypeByName(anchor, PY3K_TEXT_FILE_TYPE));
       }
     }
 
-    return PyTypeParser.getTypeByName(anchor, PY2K_FILE_TYPE);
+    return Ref.create(PyTypeParser.getTypeByName(anchor, PY2K_FILE_TYPE));
   }
 
   @Nullable
index dca8496cb54f338c9f291df6b6ea017740cf36bd..e88304ef287ad8e4b77462edce2472d0ec2a0fd2 100644 (file)
@@ -194,13 +194,10 @@ public class PyFunctionImpl extends PyBaseElementImpl<PyFunctionStub> implements
     for (PyTypeProvider typeProvider : Extensions.getExtensions(PyTypeProvider.EP_NAME)) {
       final Ref<PyType> returnTypeRef = typeProvider.getReturnType(this, context);
       if (returnTypeRef != null) {
-        final PyType returnType = returnTypeRef.get();
-        if (returnType != null) {
-          returnType.assertValid(typeProvider.toString());
-        }
-        return returnType;
+        return derefType(returnTypeRef, typeProvider);
       }
     }
+
     if (context.allowReturnTypes(this)) {
       final Ref<? extends PyType> yieldTypeRef = getYieldStatementType(context);
       if (yieldTypeRef != null) {
@@ -208,6 +205,7 @@ public class PyFunctionImpl extends PyBaseElementImpl<PyFunctionStub> implements
       }
       return getReturnStatementType(context);
     }
+
     return null;
   }
 
@@ -215,17 +213,26 @@ public class PyFunctionImpl extends PyBaseElementImpl<PyFunctionStub> implements
   @Override
   public PyType getCallType(@NotNull TypeEvalContext context, @NotNull PyCallSiteExpression callSite) {
     for (PyTypeProvider typeProvider : Extensions.getExtensions(PyTypeProvider.EP_NAME)) {
-      final PyType type = typeProvider.getCallType(this, callSite, context);
-      if (type != null) {
-        type.assertValid(typeProvider.toString());
-        return type;
+      final Ref<PyType> typeRef = typeProvider.getCallType(this, callSite, context);
+      if (typeRef != null) {
+        return derefType(typeRef, typeProvider);
       }
     }
+
     final PyExpression receiver = PyTypeChecker.getReceiver(callSite, this);
     final Map<PyExpression, PyNamedParameter> mapping = PyCallExpressionHelper.mapArguments(callSite, this, context);
     return getCallType(receiver, mapping, context);
   }
 
+  @Nullable
+  private static PyType derefType(@NotNull Ref<PyType> typeRef, @NotNull PyTypeProvider typeProvider) {
+    final PyType type = typeRef.get();
+    if (type != null) {
+      type.assertValid(typeProvider.toString());
+    }
+    return type;
+  }
+
   @Nullable
   @Override
   public PyType getCallType(@Nullable PyExpression receiver,
index 923b1f6521db926b48a7e4e140567e9e34fdeea5..dc7cf7eb5745e90f19aaf7cb6b2c4e7609da00c5 100644 (file)
@@ -91,42 +91,49 @@ public class PyiTypeProvider extends PyTypeProviderBase {
 
   @Nullable
   @Override
-  public PyType getCallType(@NotNull PyFunction function, @Nullable PyCallSiteExpression callSite, @NotNull TypeEvalContext context) {
+  public Ref<PyType> getCallType(@NotNull PyFunction function, @Nullable PyCallSiteExpression callSite, @NotNull TypeEvalContext context) {
     if (callSite != null) {
       final PsiElement pythonStub = PyiUtil.getPythonStub(function);
+
       if (pythonStub instanceof PyFunction) {
-        final PyFunction functionStub = (PyFunction)pythonStub;
-        return getOverloadedCallType(functionStub, callSite, context);
+        return getOverloadedCallType((PyFunction)pythonStub, callSite, context);
       }
       else if (function.getContainingFile() instanceof PyiFile) {
         return getOverloadedCallType(function, callSite, context);
       }
     }
+
     return null;
   }
 
   @Nullable
-  private static PyType getOverloadedCallType(@NotNull PyFunction function, @NotNull PyCallSiteExpression callSite,
-                                              @NotNull TypeEvalContext context) {
+  private static Ref<PyType> getOverloadedCallType(@NotNull PyFunction function,
+                                                   @NotNull PyCallSiteExpression callSite,
+                                                   @NotNull TypeEvalContext context) {
     if (isOverload(function, context)) {
-      final List<PyType> matchedReturnTypes = new ArrayList<>();
-      final List<PyType> allReturnTypes = new ArrayList<>();
       final List<PyFunction> overloads = getOverloads(function, context);
+      final List<PyType> allReturnTypes = new ArrayList<>();
+      final List<PyType> matchedReturnTypes = new ArrayList<>();
+
       for (PyFunction overload : overloads) {
-        final Map<PyExpression, PyNamedParameter> mapping = mapArguments(callSite, overload, context);
-        final PyExpression receiver = PyTypeChecker.getReceiver(callSite, overload);
-        final Map<PyGenericType, PyType> substitutions = PyTypeChecker.unifyGenericCall(receiver, mapping, context);
         final PyType returnType = context.getReturnType(overload);
         if (!PyTypeChecker.hasGenerics(returnType, context)) {
           allReturnTypes.add(returnType);
         }
+
+        final PyExpression receiver = PyTypeChecker.getReceiver(callSite, overload);
+        final Map<PyExpression, PyNamedParameter> mapping = mapArguments(callSite, overload, context);
+        final Map<PyGenericType, PyType> substitutions = PyTypeChecker.unifyGenericCall(receiver, mapping, context);
+
         final PyType unifiedType = substitutions != null ? PyTypeChecker.substitute(returnType, substitutions, context) : null;
         if (unifiedType != null) {
           matchedReturnTypes.add(unifiedType);
         }
       }
-      return PyUnionType.union(matchedReturnTypes.isEmpty() ? allReturnTypes : matchedReturnTypes);
+
+      return Ref.create(PyUnionType.union(matchedReturnTypes.isEmpty() ? allReturnTypes : matchedReturnTypes));
     }
+
     return null;
   }