1 // Copyright 2000-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
2 package com.jetbrains.python.psi.types;
4 import com.intellij.openapi.util.Pair;
5 import com.intellij.openapi.util.RecursionManager;
6 import com.intellij.psi.PsiElement;
7 import com.intellij.psi.PsiFile;
8 import com.intellij.psi.ResolveResult;
9 import com.intellij.util.ArrayUtil;
10 import com.intellij.util.containers.ContainerUtil;
11 import com.jetbrains.python.PyNames;
12 import com.jetbrains.python.PythonRuntimeService;
13 import com.jetbrains.python.codeInsight.dataflow.scope.ScopeUtil;
14 import com.jetbrains.python.codeInsight.typing.PyProtocolsKt;
15 import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider;
16 import com.jetbrains.python.psi.*;
17 import com.jetbrains.python.psi.impl.PyBuiltinCache;
18 import com.jetbrains.python.psi.impl.PyTypeProvider;
19 import com.jetbrains.python.psi.resolve.PyResolveContext;
20 import com.jetbrains.python.psi.resolve.RatedResolveResult;
21 import com.jetbrains.python.pyi.PyiFile;
22 import com.jetbrains.python.sdk.PythonSdkUtil;
23 import one.util.streamex.StreamEx;
24 import org.jetbrains.annotations.NotNull;
25 import org.jetbrains.annotations.Nullable;
28 import java.util.function.Function;
30 import static com.jetbrains.python.PyNames.FUNCTION;
31 import static com.jetbrains.python.psi.PyUtil.as;
32 import static com.jetbrains.python.psi.PyUtil.getReturnTypeToAnalyzeAsCallType;
33 import static com.jetbrains.python.psi.impl.PyCallExpressionHelper.*;
38 public final class PyTypeChecker {
// Private constructor: the class is declared final and the members visible here are static,
// so it is not meant to be instantiated.
39 private PyTypeChecker() {
43 * See {@link PyTypeChecker#match(PyType, PyType, TypeEvalContext, Map)} for description.
// Convenience overload: matches actual against expected using a fresh, empty substitutions map.
// An empty Optional from the internal matcher (no definite verdict) is reported as true.
45 public static boolean match(@Nullable PyType expected, @Nullable PyType actual, @NotNull TypeEvalContext context) {
46 return match(expected, actual, new MatchContext(context, new HashMap<>())).orElse(true);
50 * Checks whether a type {@code actual} can be placed where {@code expected} is expected.
52 * For example {@code int} matches {@code object}, while {@code str} doesn't match {@code int}.
53 * Works for builtin types, classes, tuples, etc.
55 * If it is unknown whether {@code actual} matches {@code expected}, the method returns {@code true}.
57 * @implNote This behavior may be changed in the future by replacing {@code boolean} with {@code Optional<Boolean>} and updating the clients.
59 * @param expected expected type
60 * @param actual type to be matched against expected
61 * @param context type evaluation context
62 * @param substitutions map of substitutions for {@code expected} type
63 * @return {@code false} if {@code expected} and {@code actual} don't match, {@code true} otherwise
// Entry point that takes caller-supplied substitutions: bundles them with the evaluation context
// into a MatchContext and delegates to the Optional-returning overload.
65 public static boolean match(@Nullable PyType expected,
66 @Nullable PyType actual,
67 @NotNull TypeEvalContext context,
68 @NotNull Map<PyGenericType, PyType> substitutions) {
// An empty Optional (no definite verdict) is reported to the caller as a match.
69 return match(expected, actual, new MatchContext(context, substitutions)).orElse(true);
// Wrapper around matchImpl that runs it through RecursionManager.doPreventingRecursion,
// keyed by the (expected, actual) pair.
73 private static Optional<Boolean> match(@Nullable PyType expected, @Nullable PyType actual, @NotNull MatchContext context) {
74 final Optional<Boolean> result = RecursionManager.doPreventingRecursion(
75 Pair.create(expected, actual),
77 () -> matchImpl(expected, actual, context)
// A null result from the guard is reported as a match.
// NOTE(review): presumably null means a recursive re-entry was detected - confirm against RecursionManager.
80 return result == null ? Optional.of(true) : result;
84 * Perform type matching.
86 * Implementation details:
88 * <li>The method mutates {@code context.substitutions} map adding new entries into it
89 * <li>The order of match subroutine calls is important
90 * <li>The method may recursively call itself
// Core matching routine: an ordered chain of checks in which the first applicable rule decides.
// The order of the checks matters (see the implementation notes above this method).
94 private static Optional<Boolean> matchImpl(@Nullable PyType expected, @Nullable PyType actual, @NotNull MatchContext context) {
// Registered PyTypeCheckerExtension implementations are consulted before any built-in rule.
95 for (PyTypeCheckerExtension extension : PyTypeCheckerExtension.EP_NAME.getExtensionList()) {
96 final Optional<Boolean> result = extension.match(expected, actual, context.context, context.substitutions);
97 if (result.isPresent()) {
// Class-typed expectations first go through the object/type special cases in matchObject.
102 if (expected instanceof PyClassType) {
103 Optional<Boolean> match = matchObject((PyClassType)expected, actual);
104 if (match.isPresent()) {
// With reversed substitutions a type variable on the actual side is matched in the expected position.
109 if (actual instanceof PyGenericType && context.reversedSubstitutions) {
110 return Optional.of(match((PyGenericType)actual, expected, context));
113 if (expected instanceof PyGenericType) {
114 return Optional.of(match((PyGenericType)expected, actual, context));
// A missing type on either side, or an actual type reported as unknown, matches.
117 if (expected == null || actual == null || isUnknown(actual, context.context)) {
118 return Optional.of(true);
// Unions are handled next: the actual side is checked before the expected side.
121 if (actual instanceof PyUnionType) {
122 return Optional.of(match(expected, (PyUnionType)actual, context));
125 if (expected instanceof PyUnionType) {
126 return Optional.of(match((PyUnionType)expected, actual, context));
// Class vs. class; the verdict may be empty, in which case later rules still get a chance.
129 if (expected instanceof PyClassType && actual instanceof PyClassType) {
130 Optional<Boolean> match = match((PyClassType)expected, (PyClassType)actual, context);
131 if (match.isPresent()) {
// A structural type inferred from usages on the actual side always matches.
136 if (actual instanceof PyStructuralType && ((PyStructuralType)actual).isInferredFromUsages()) {
137 return Optional.of(true);
140 if (expected instanceof PyStructuralType) {
141 return Optional.of(match((PyStructuralType)expected, actual, context.context));
// Structural actual vs. class expected: every attribute name of the structural type
// must be among the member names of the expected class.
144 if (actual instanceof PyStructuralType && expected instanceof PyClassType) {
145 final Set<String> expectedAttributes = ((PyClassType)expected).getMemberNames(true, context.context);
146 return Optional.of(expectedAttributes.containsAll(((PyStructuralType)actual).getAttributeNames()));
149 if (actual instanceof PyCallableType && expected instanceof PyCallableType) {
150 final PyCallableType expectedCallable = (PyCallableType)expected;
151 final PyCallableType actualCallable = (PyCallableType)actual;
152 final Optional<Boolean> match = match(expectedCallable, actualCallable, context);
153 if (match.isPresent()) {
158 // remove after making PyNoneType inheriting PyClassType
159 if (expected instanceof PyNoneType) {
160 return Optional.of(actual instanceof PyNoneType);
// Module types match only when both refer to the very same module object (reference comparison).
163 if (expected instanceof PyModuleType) {
164 return Optional.of(actual instanceof PyModuleType && ((PyModuleType)expected).getModule() == ((PyModuleType)actual).getModule());
// A module used where a class type is expected is matched through its module class type.
167 if (expected instanceof PyClassType && actual instanceof PyModuleType) {
168 return match(expected, ((PyModuleType)actual).getModuleClassType(), context);
// Final fallback: the class-name based rules in matchNumericTypes.
171 return Optional.of(matchNumericTypes(expected, actual));
175 * Check whether {@code expected} is Python *object* or *type*.
177 * @see PyTypeChecker#match(PyType, PyType, TypeEvalContext, Map)
// Handles expectations whose class name is PyNames.OBJECT or PyNames.TYPE.
// Returns an empty Optional when none of the rules here decides the match.
180 private static Optional<Boolean> matchObject(@NotNull PyClassType expected, @Nullable PyType actual) {
// Name check first; the builtin cache is queried only for the two candidate names.
181 if (ArrayUtil.contains(expected.getName(), PyNames.OBJECT, PyNames.TYPE)) {
182 final PyBuiltinCache builtinCache = PyBuiltinCache.getInstance(expected.getPyClass());
// The builtin object type accepts any actual type.
183 if (expected.equals(builtinCache.getObjectType())) {
184 return Optional.of(true);
// The builtin type type accepts instantiable types that are definitions.
186 if (expected.equals(builtinCache.getTypeType()) &&
187 actual instanceof PyInstantiableType && ((PyInstantiableType)actual).isDefinition()) {
188 return Optional.of(true);
191 return Optional.empty();
195 * Match {@code actual} versus {@code PyGenericType expected}.
197 * The method mutates {@code context.substitutions} map adding new entries into it
// Matches actual against a type variable, reading and extending context.substitutions.
199 private static boolean match(@NotNull PyGenericType expected, @Nullable PyType actual, @NotNull MatchContext context) {
// Guard for a definition-form type variable paired with an instantiable actual type that is not a definition.
200 if (expected.isDefinition() && actual instanceof PyInstantiableType && !((PyInstantiableType<?>)actual).isDefinition()) {
204 final PyType substitution = context.substitutions.get(expected);
205 PyType bound = expected.getBound();
206 // Promote int in Type[TypeVar('T', int)] to Type[int] before checking that bounds match
207 if (expected.isDefinition()) {
208 final Function<PyType, PyType> toDefinition = t -> t instanceof PyInstantiableType ? ((PyInstantiableType<?>)t).toClass() : t;
209 bound = PyUnionType.union(PyTypeUtil.toStream(bound).map(toDefinition).toList());
// The actual type is checked against the (possibly promoted) bound.
212 Optional<Boolean> match = match(bound, actual, context);
213 if (match.isPresent() && !match.get()) {
// A substitution already exists for this type variable.
217 if (substitution != null) {
218 if (expected.equals(actual) || substitution.equals(expected)) {
// Compare the recorded substitution with the actual type under a recursion guard keyed by the
// type variable; the direction of the comparison follows context.reversedSubstitutions.
222 Optional<Boolean> recursiveMatch = RecursionManager.doPreventingRecursion(
223 expected, false, context.reversedSubstitutions
224 ? () -> match(actual, substitution, context)
225 : () -> match(substitution, actual, context)
// A null or empty result counts as a mismatch here.
227 return recursiveMatch != null ? recursiveMatch.orElse(false) : false;
// No substitution yet: record the actual type, or a weak type built from the bound when actual is null.
230 if (actual != null) {
231 context.substitutions.put(expected, actual);
233 else if (bound != null) {
234 context.substitutions.put(expected, PyUnionType.createWeakType(bound));
// Matches a union on the actual side against expected.
240 private static boolean match(@NotNull PyType expected, @NotNull PyUnionType actual, @NotNull MatchContext context) {
// Tuple expectations get the dedicated tuple-vs-union check first.
241 if (expected instanceof PyTupleType) {
242 Optional<Boolean> match = match((PyTupleType)expected, actual, context);
243 if (match.isPresent()) {
// Otherwise one matching member is enough; an undetermined member result counts as a mismatch.
248 return ContainerUtil.or(actual.getMembers(), type -> match(expected, type, context).orElse(false));
// Tuple-vs-union matching. Produces a verdict only when expected is not homogeneous and
// consistsOfSameElementNumberTuples accepts the union for the expected element count;
// otherwise the result is empty.
252 private static Optional<Boolean> match(@NotNull PyTupleType expected, @NotNull PyUnionType actual, @NotNull MatchContext context) {
253 final int elementCount = expected.getElementCount();
255 if (!expected.isHomogeneous() && consistsOfSameElementNumberTuples(actual, elementCount)) {
256 return Optional.of(substituteExpectedElementsWithUnions(expected, elementCount, actual, context));
259 return Optional.empty();
// Matches actual against a union on the expected side: one matching member is enough,
// and an undetermined member result counts as a match.
262 private static boolean match(@NotNull PyUnionType expected, @NotNull PyType actual, @NotNull MatchContext context) {
263 return ContainerUtil.or(expected.getMembers(), type -> match(type, actual, context).orElse(true));
// Matches two class types. Returns an empty Optional when none of the rules below decides.
267 private static Optional<Boolean> match(@NotNull PyClassType expected, @NotNull PyClassType actual, @NotNull MatchContext matchContext) {
268 if (expected.equals(actual)) {
269 return Optional.of(true);
272 final TypeEvalContext context = matchContext.context;
// Exactly one side is a definition (class object) and expected is not a protocol.
274 if (expected.isDefinition() ^ actual.isDefinition() && !PyProtocolsKt.isProtocol(expected, context)) {
// A class object where an instance is expected is matched through the instance of its metaclass.
275 if (!expected.isDefinition() && actual.isDefinition()) {
276 final PyClassLikeType metaClass = actual.getMetaClassType(context, true);
277 return Optional.of(metaClass != null && match((PyType)expected, metaClass.toInstance(), matchContext).orElse(true));
279 return Optional.of(false);
282 if (expected instanceof PyTupleType && actual instanceof PyTupleType) {
283 return match((PyTupleType)expected, (PyTupleType)actual, matchContext);
// A literal expectation accepts only literal types, compared via PyLiteralType.Companion.match.
286 if (expected instanceof PyLiteralType) {
287 return Optional.of(actual instanceof PyLiteralType && PyLiteralType.Companion.match((PyLiteralType)expected, (PyLiteralType)actual));
// Typed dict handling: the structural compatibility check applies only when actual is not inferred.
290 if (actual instanceof PyTypedDictType) {
291 if (!((PyTypedDictType)actual).isInferred()) {
292 final Optional<Boolean> match = PyTypedDictType.Companion.checkStructuralCompatibility(expected, (PyTypedDictType)actual, context);
293 if (match.isPresent()) {
297 if (expected instanceof PyTypedDictType) {
298 return Optional.of(PyTypedDictType.Companion.match((PyTypedDictType)expected, (PyTypedDictType)actual, context));
// Nominal relationship between the underlying classes, computed once and reused below.
302 final PyClass superClass = expected.getPyClass();
303 final PyClass subClass = actual.getPyClass();
304 final boolean matchClasses = matchClasses(superClass, subClass, context);
// Protocol expectation without a nominal relationship: compare members structurally.
306 if (PyProtocolsKt.isProtocol(expected, context) && !matchClasses) {
307 if (expected instanceof PyCollectionType && !matchGenerics((PyCollectionType)expected, actual, matchContext)) {
308 return Optional.of(false);
311 // methods from the actual will be matched against method definitions from the expected below
312 // here we make substitutions from expected definition to its usage
314 .of(PyTypeProvider.EP_NAME.getExtensionList())
315 .map(provider -> provider.getGenericType(superClass, context))
316 .select(PyCollectionType.class)
318 .ifPresent(it -> matchGenerics(it, expected, matchContext));
// Each pair holds a protocol member and the resolve results found for it in the actual class.
320 for (kotlin.Pair<PyTypedElement, List<RatedResolveResult>> pair : PyProtocolsKt.inspectProtocolSubclass(expected, actual, context)) {
321 final List<RatedResolveResult> subclassElements = pair.getSecond();
// A protocol member with no counterpart in the actual class is a mismatch.
322 if (ContainerUtil.isEmpty(subclassElements)) {
323 return Optional.of(false);
326 final PyType protocolElementType = dropSelfIfNeeded(expected, context.getType(pair.getFirst()), context);
// Types of the counterpart elements are compared with the protocol member type;
// an undetermined comparison counts as a match.
328 final boolean elementResult = StreamEx
329 .of(subclassElements)
330 .map(ResolveResult::getElement)
331 .select(PyTypedElement.class)
332 .map(context::getType)
334 subclassElementType -> {
335 return match(protocolElementType,
336 dropSelfIfNeeded(actual, subclassElementType, context), matchContext).orElse(true);
340 if (!elementResult) {
341 return Optional.of(false);
345 return Optional.of(true);
348 if (expected instanceof PyCollectionType) {
349 return Optional.of(match((PyCollectionType)expected, actual, matchContext));
// PyTypingNewType expectation over the same underlying class: actual must list expected among its ancestor types.
353 if (expected instanceof PyTypingNewType && !expected.equals(actual) && superClass.equals(subClass)) {
354 return Optional.of(actual.getAncestorTypes(context).contains(expected));
356 return Optional.of(true);
359 if (expected.equals(actual)) {
360 return Optional.of(true);
// No rule applied: leave the decision to the caller.
362 return Optional.empty();
// For a function-typed member, returns the function type with self dropped (PyFunctionType.dropSelf)
// when PyUtil.isInitOrNewMethod accepts the callable or when classType is not a definition.
365 private static @Nullable PyType dropSelfIfNeeded(@NotNull PyClassType classType,
366 @Nullable PyType elementType,
367 @NotNull TypeEvalContext context) {
368 if (elementType instanceof PyFunctionType) {
369 PyFunctionType functionType = (PyFunctionType)elementType;
370 if (PyUtil.isInitOrNewMethod(functionType.getCallable()) || !classType.isDefinition()) {
371 return functionType.dropSelf(context);
// Tuple-vs-tuple matching; the homogeneous / non-homogeneous combinations are handled separately.
// Undetermined element comparisons count as matches.
379 private static Optional<Boolean> match(@NotNull PyTupleType expected, @NotNull PyTupleType actual, @NotNull MatchContext context) {
// Both non-homogeneous: element counts must agree and elements must match position by position.
380 if (!expected.isHomogeneous() && !actual.isHomogeneous()) {
381 if (expected.getElementCount() != actual.getElementCount()) {
382 return Optional.of(false);
385 for (int i = 0; i < expected.getElementCount(); i++) {
386 if (!match(expected.getElementType(i), actual.getElementType(i), context).orElse(true)) {
387 return Optional.of(false);
390 return Optional.of(true);
// Homogeneous expected, non-homogeneous actual: every actual element must match the expected item type.
393 if (expected.isHomogeneous() && !actual.isHomogeneous()) {
394 final PyType expectedElementType = expected.getIteratedItemType();
395 for (int i = 0; i < actual.getElementCount(); i++) {
396 if (!match(expectedElementType, actual.getElementType(i), context).orElse(true)) {
397 return Optional.of(false);
400 return Optional.of(true);
// A non-homogeneous expectation never accepts a homogeneous actual tuple.
403 if (!expected.isHomogeneous() && actual.isHomogeneous()) {
404 return Optional.of(false);
// Both homogeneous: compare the iterated item types.
407 return match(expected.getIteratedItemType(), actual.getIteratedItemType(), context);
// Matches a class type against a parameterized collection type.
410 private static boolean match(@NotNull PyCollectionType expected, @NotNull PyClassType actual, @NotNull MatchContext context) {
// Tuples are delegated to the collection-vs-tuple overload.
411 if (actual instanceof PyTupleType) {
412 return match(expected, (PyTupleType)actual, context);
415 final PyClass superClass = expected.getPyClass();
416 final PyClass subClass = actual.getPyClass();
// Both the classes and the generic parameters have to match.
418 return matchClasses(superClass, subClass, context.context) && matchGenerics(expected, actual, context);
// Matches a tuple type against a parameterized collection type: after the class check,
// the iterated item types of both sides are compared.
421 private static boolean match(@NotNull PyCollectionType expected, @NotNull PyTupleType actual, @NotNull MatchContext context) {
422 if (!matchClasses(expected.getPyClass(), actual.getPyClass(), context.context)) {
426 final PyType superElementType = expected.getIteratedItemType();
427 final PyType subElementType = actual.getIteratedItemType();
// An undetermined item comparison counts as a match.
429 return match(superElementType, subElementType, context).orElse(true);
// Matches actual against a structural (attribute-set) expectation, dispatching on the kind of actual.
432 private static boolean match(@NotNull PyStructuralType expected, @NotNull PyType actual, @NotNull TypeEvalContext context) {
433 if (actual instanceof PyStructuralType) {
434 return match(expected, (PyStructuralType)actual);
436 if (actual instanceof PyClassType) {
437 return match(expected, (PyClassType)actual, context);
// Modules at language level PYTHON37 or later for which definesGetAttr holds are special-cased.
439 if (actual instanceof PyModuleType) {
440 final PyFile module = ((PyModuleType)actual).getModule();
441 if (module.getLanguageLevel().isAtLeast(LanguageLevel.PYTHON37) && definesGetAttr(module, context)) {
// General case: every expected attribute must resolve to at least one member of actual for READ access.
446 final PyResolveContext resolveContext = PyResolveContext.defaultContext().withTypeEvalContext(context);
447 return !ContainerUtil.exists(expected.getAttributeNames(), attribute -> ContainerUtil
448 .isEmpty(actual.resolveMember(attribute, null, AccessDirection.READ, resolveContext)));
// Structural-vs-structural matching. Expectations inferred from usages are special-cased;
// otherwise the attribute names of actual must all be among the attribute names of expected.
451 private static boolean match(@NotNull PyStructuralType expected, @NotNull PyStructuralType actual) {
452 if (expected.isInferredFromUsages()) {
455 return expected.getAttributeNames().containsAll(actual.getAttributeNames());
// Structural-vs-class matching. Classes for which overridesGetAttr holds are special-cased;
// otherwise the member names of the class must include every expected attribute name.
458 private static boolean match(@NotNull PyStructuralType expected, @NotNull PyClassType actual, @NotNull TypeEvalContext context) {
459 if (overridesGetAttr(actual.getPyClass(), context)) {
462 final Set<String> actualAttributes = actual.getMemberNames(true, context);
463 return actualAttributes.containsAll(expected.getAttributeNames());
// Matches two callable types. Returns an empty Optional when neither the class-based shortcuts
// nor the signature comparison applies.
467 private static Optional<Boolean> match(@NotNull PyCallableType expected,
468 @NotNull PyCallableType actual,
469 @NotNull MatchContext matchContext) {
// A function type matches the builtin class registered under the FUNCTION name.
470 if (actual instanceof PyFunctionType && expected instanceof PyClassType && FUNCTION.equals(expected.getName())
471 && expected.equals(PyBuiltinCache.getInstance(actual.getCallable()).getObjectType(FUNCTION))) {
472 return Optional.of(true);
// Class-like expectation: when its qualified name is PyTypingTypeProvider.CALLABLE,
// the verdict is simply whether actual is callable.
475 if (expected instanceof PyClassLikeType) {
476 return PyTypingTypeProvider.CALLABLE.equals(((PyClassLikeType)expected).getClassQName())
477 ? Optional.of(actual.isCallable())
481 if (expected.isCallable() && actual.isCallable()) {
482 final TypeEvalContext context = matchContext.context;
483 final List<PyCallableParameter> expectedParameters = expected.getParameters(context);
484 final List<PyCallableParameter> actualParameters = actual.getParameters(context);
// Parameters are compared only when both lists are known, and only over their common prefix.
485 if (expectedParameters != null && actualParameters != null) {
486 final int size = Math.min(expectedParameters.size(), actualParameters.size());
487 for (int i = 0; i < size; i++) {
488 final PyCallableParameter expectedParam = expectedParameters.get(i);
489 final PyCallableParameter actualParam = actualParameters.get(i);
490 // TODO: Check named and star params, not only positional ones
// Two self parameters are compared in the regular (expected, actual) direction.
491 if (expectedParam.isSelf() && actualParam.isSelf()) {
492 if (!match(expectedParam.getType(context), actualParam.getType(context), matchContext).orElse(true)) {
493 return Optional.of(false);
// For a positional container on the actual side that the expected parameter can be mapped onto,
// the argument type of the container is used instead of its declared type.
497 final PyType actualParamType =
498 actualParam.isPositionalContainer() && couldBeMappedOntoPositionalContainer(expectedParam)
499 ? actualParam.getArgumentType(context)
500 : actualParam.getType(context);
502 // actual callable type could accept more general parameter type
// Hence the arguments are swapped and the match runs with reversed substitutions.
503 if (!match(actualParamType, expectedParam.getType(context), matchContext.reverseSubstitutions()).orElse(true)) {
504 return Optional.of(false);
// Return types are compared in the regular direction.
509 if (!match(expected.getReturnType(context), getActualReturnType(actual, context), matchContext).orElse(true)) {
510 return Optional.of(false);
512 return Optional.of(true);
514 return Optional.empty();
// Return type used for the actual side of a callable comparison: for a PyFunction callable it comes
// from getReturnTypeToAnalyzeAsCallType, otherwise from the callable type itself.
517 private static @Nullable PyType getActualReturnType(@NotNull PyCallableType actual, @NotNull TypeEvalContext context) {
518 PyCallable callable = actual.getCallable();
519 if (callable instanceof PyFunction) {
520 return getReturnTypeToAnalyzeAsCallType((PyFunction)callable, context);
522 return actual.getReturnType(context);
// Tells whether the given parameter could be mapped onto a positional container of another signature.
// Positional and keyword containers are rejected outright; keyword-only named parameters are inspected next.
525 private static boolean couldBeMappedOntoPositionalContainer(@NotNull PyCallableParameter parameter) {
526 if (parameter.isPositionalContainer() || parameter.isKeywordContainer()) return false;
528 final var psi = parameter.getParameter();
530 final var namedPsi = psi.getAsNamed();
531 if (namedPsi != null && namedPsi.isKeywordOnly()) {
// Inspects the members of the union; for tuple members, a non-homogeneous tuple whose element count
// differs from elementCount is the condition being looked for.
539 private static boolean consistsOfSameElementNumberTuples(@NotNull PyUnionType unionType, int elementCount) {
540 for (PyType type : unionType.getMembers()) {
541 if (type instanceof PyTupleType) {
542 final PyTupleType tupleType = (PyTupleType)type;
544 if (!tupleType.isHomogeneous() && elementCount != tupleType.getElementCount()) {
// For each position, builds the union of the element types found at that position in the tuple
// members of actual and matches it against the expected element type at the same position.
556 private static boolean substituteExpectedElementsWithUnions(@NotNull PyTupleType expected,
558 @NotNull PyUnionType actual,
559 @NotNull MatchContext context) {
560 for (int i = 0; i < elementCount; i++) {
// Copy of the loop index that the lambda below can capture.
561 final int currentIndex = i;
563 final PyType elementType = PyUnionType.union(
565 .of(actual.getMembers())
566 .select(PyTupleType.class)
567 .map(type -> type.getElementType(currentIndex))
// An undetermined comparison counts as a match.
571 if (!match(expected.getElementType(i), elementType, context).orElse(true)) {
// Matches the element (generic parameter) types of expected against those of actual, position by position.
579 private static boolean matchGenerics(@NotNull PyCollectionType expected, @NotNull PyType actual, @NotNull MatchContext context) {
580 // TODO: Match generic parameters based on the correspondence between the generic parameters of subClass and its base classes
581 final List<PyType> superElementTypes = expected.getElementTypes();
// When actual is not a collection type it contributes no element types.
582 final PyCollectionType actualCollectionType = as(actual, PyCollectionType.class);
583 final List<PyType> subElementTypes = actualCollectionType != null ?
584 actualCollectionType.getElementTypes() :
585 Collections.emptyList();
586 for (int i = 0; i < superElementTypes.size(); i++) {
// Positions missing on the actual side are matched as null.
587 final PyType subElementType = i < subElementTypes.size() ? subElementTypes.get(i) : null;
588 if (!match(superElementTypes.get(i), subElementType, context).orElse(true)) {
// Class-name based rules for numeric types; applies only when both sides are class types.
595 private static boolean matchNumericTypes(PyType expected, PyType actual) {
596 if (expected instanceof PyClassType && actual instanceof PyClassType) {
597 final String superName = ((PyClassType)expected).getPyClass().getName();
598 final String subName = ((PyClassType)actual).getPyClass().getName();
599 final boolean subIsBool = "bool".equals(subName);
600 final boolean subIsInt = PyNames.TYPE_INT.equals(subName);
601 final boolean subIsLong = PyNames.TYPE_LONG.equals(subName);
602 final boolean subIsFloat = "float".equals(subName);
603 final boolean subIsComplex = "complex".equals(subName);
// The condition holds when a name is missing, when the names are equal, or when the sub name is
// accepted by the super name: int takes bool; long/Integral take bool, int; float/Real additionally
// take long; complex/Complex additionally take float; Number additionally takes complex.
604 if (superName == null || subName == null ||
605 superName.equals(subName) ||
606 (PyNames.TYPE_INT.equals(superName) && subIsBool) ||
607 ((PyNames.TYPE_LONG.equals(superName) || PyNames.ABC_INTEGRAL.equals(superName)) && (subIsBool || subIsInt)) ||
608 (("float".equals(superName) || PyNames.ABC_REAL.equals(superName)) && (subIsBool || subIsInt || subIsLong)) ||
609 (("complex".equals(superName) || PyNames.ABC_COMPLEX.equals(superName)) && (subIsBool || subIsInt || subIsLong || subIsFloat)) ||
610 (PyNames.ABC_NUMBER.equals(superName) && (subIsBool || subIsInt || subIsLong || subIsFloat || subIsComplex))) {
// Shorthand for isUnknown(type, true, context), i.e. with genericsAreUnknown enabled.
617 public static boolean isUnknown(@Nullable PyType type, @NotNull TypeEvalContext context) {
618 return isUnknown(type, true, context);
// Checks the conditions under which a type is regarded as unknown: a null type, a type variable
// (only when genericsAreUnknown is set), a function type whose callable has a decorator flagged by
// PyKnownDecoratorUtil.hasUnknownOrChangingReturnTypeDecorator, or a union with an unknown member.
621 public static boolean isUnknown(@Nullable PyType type, boolean genericsAreUnknown, @NotNull TypeEvalContext context) {
622 if (type == null || (genericsAreUnknown && type instanceof PyGenericType)) {
625 if (type instanceof PyFunctionType) {
626 final PyCallable callable = ((PyFunctionType)type).getCallable();
627 if (callable instanceof PyDecoratable &&
628 PyKnownDecoratorUtil.hasUnknownOrChangingReturnTypeDecorator((PyDecoratable)callable, context)) {
// Union members are checked recursively with the same genericsAreUnknown setting.
632 if (type instanceof PyUnionType) {
633 final PyUnionType union = (PyUnionType)type;
634 for (PyType t : union.getMembers()) {
635 if (isUnknown(t, genericsAreUnknown, context)) {
// Returns true when collectGenerics finds at least one type variable inside the given type.
643 public static boolean hasGenerics(@Nullable PyType type, @NotNull TypeEvalContext context) {
644 final Set<PyGenericType> collected = new HashSet<>();
645 collectGenerics(type, context, collected, new HashSet<>());
646 return !collected.isEmpty();
// Recursively walks the given type and adds every type variable found to collected.
// The visited set is consulted on entry so that already seen types are detected.
649 private static void collectGenerics(@Nullable PyType type, @NotNull TypeEvalContext context, @NotNull Set<? super PyGenericType> collected,
650 @NotNull Set<? super PyType> visited) {
651 if (visited.contains(type)) {
655 if (type instanceof PyGenericType) {
656 collected.add((PyGenericType)type);
658 else if (type instanceof PyUnionType) {
659 final PyUnionType union = (PyUnionType)type;
660 for (PyType t : union.getMembers()) {
661 collectGenerics(t, context, collected, visited);
// Tuples are tested before the collection branch; a homogeneous tuple is inspected
// through a single element (index 0).
664 else if (type instanceof PyTupleType) {
665 final PyTupleType tuple = (PyTupleType)type;
666 final int n = tuple.isHomogeneous() ? 1 : tuple.getElementCount();
667 for (int i = 0; i < n; i++) {
668 collectGenerics(tuple.getElementType(i), context, collected, visited);
671 else if (type instanceof PyCollectionType) {
672 final PyCollectionType collection = (PyCollectionType)type;
673 for (PyType elementType : collection.getElementTypes()) {
674 collectGenerics(elementType, context, collected, visited);
// Callable types that are not class-like: inspect parameter types and the return type.
677 else if (type instanceof PyCallableType && !(type instanceof PyClassLikeType)) {
678 final PyCallableType callable = (PyCallableType)type;
679 final List<PyCallableParameter> parameters = callable.getParameters(context);
680 if (parameters != null) {
681 for (PyCallableParameter parameter : parameters) {
682 if (parameter != null) {
683 collectGenerics(parameter.getType(context), context, collected, visited);
687 collectGenerics(callable.getReturnType(context), context, collected, visited);
// Replaces type variables inside the given type according to substitutions.
// The rewriting branches below run only when hasGenerics reports type variables in the type.
692 public static PyType substitute(@Nullable PyType type, @NotNull Map<PyGenericType, PyType> substitutions,
693 @NotNull TypeEvalContext context) {
694 if (hasGenerics(type, context)) {
// The type is itself a type variable.
695 if (type instanceof PyGenericType) {
696 final PyGenericType typeVar = (PyGenericType)type;
697 PyType substitution = substitutions.get(typeVar);
// No direct entry: try the entry registered for the other form of the type variable
// (class form for an instance, instance form for a definition) and convert the result back.
698 if (substitution == null) {
699 if (!typeVar.isDefinition()) {
700 final PyInstantiableType<?> classType = as(substitutions.get(typeVar.toClass()), PyInstantiableType.class);
701 if (classType != null) {
702 substitution = classType.toInstance();
706 final PyInstantiableType<?> instanceType = as(substitutions.get(typeVar.toInstance()), PyInstantiableType.class);
707 if (instanceType != null) {
708 substitution = instanceType.toClass();
// A substitution that is another type variable with its own entry is resolved recursively.
712 if (substitution instanceof PyGenericType && !typeVar.equals(substitution) && substitutions.containsKey(substitution)) {
713 return substitute(substitution, substitutions, context);
717 else if (type instanceof PyUnionType) {
718 return ((PyUnionType)type).map(member -> substitute(member, substitutions, context));
// Collection types are rebuilt with substituted element types.
720 else if (type instanceof PyCollectionTypeImpl) {
721 final PyCollectionTypeImpl collection = (PyCollectionTypeImpl)type;
722 final List<PyType> elementTypes = collection.getElementTypes();
723 final List<PyType> substitutes = new ArrayList<>();
724 for (PyType elementType : elementTypes) {
725 substitutes.add(substitute(elementType, substitutions, context));
727 return new PyCollectionTypeImpl(collection.getPyClass(), collection.isDefinition(), substitutes);
// Tuples are rebuilt likewise; a homogeneous tuple is represented by its single iterated item type.
729 else if (type instanceof PyTupleType) {
730 final PyTupleType tupleType = (PyTupleType)type;
731 final PyClass tupleClass = tupleType.getPyClass();
733 final List<PyType> oldElementTypes = tupleType.isHomogeneous()
734 ? Collections.singletonList(tupleType.getIteratedItemType())
735 : tupleType.getElementTypes();
737 final List<PyType> newElementTypes =
738 ContainerUtil.map(oldElementTypes, elementType -> substitute(elementType, substitutions, context));
740 return new PyTupleType(tupleClass, newElementTypes, tupleType.isHomogeneous());
// Callable types that are not class-like: parameter types and the return type are substituted.
742 else if (type instanceof PyCallableType && !(type instanceof PyClassLikeType)) {
743 final PyCallableType callable = (PyCallableType)type;
// Stays null when the parameter list of the callable is unknown.
744 List<PyCallableParameter> substParams = null;
745 final List<PyCallableParameter> parameters = callable.getParameters(context);
746 if (parameters != null) {
747 substParams = new ArrayList<>();
748 for (PyCallableParameter parameter : parameters) {
749 final PyType substType = substitute(parameter.getType(context), substitutions, context);
// Parameters backed by a PSI element keep that element; others are rebuilt from name and default value.
750 final PyParameter psi = parameter.getParameter();
751 final PyCallableParameter subst = psi != null ?
752 PyCallableParameterImpl.psi(psi, substType) :
753 PyCallableParameterImpl.nonPsi(parameter.getName(), substType, parameter.getDefaultValue());
754 substParams.add(subst);
757 final PyType substResult = substitute(callable.getReturnType(context), substitutions, context);
758 return new PyCallableTypeImpl(substParams, substResult);
// Builds the substitutions map for a call: starts from unifyReceiver, then matches every regularly
// mapped argument against its parameter, then the positional and keyword containers.
765 public static Map<PyGenericType, PyType> unifyGenericCall(@Nullable PyExpression receiver,
766 @NotNull Map<PyExpression, PyCallableParameter> arguments,
767 @NotNull TypeEvalContext context) {
768 final Map<PyGenericType, PyType> substitutions = unifyReceiver(receiver, context);
769 for (Map.Entry<PyExpression, PyCallableParameter> entry : getRegularMappedParameters(arguments).entrySet()) {
770 final PyCallableParameter paramWrapper = entry.getValue();
771 final PyType expectedType = paramWrapper.getArgumentType(context);
// The actual type of the argument expression is obtained through PyLiteralType.Companion.promoteToLiteral.
772 PyType actualType = PyLiteralType.Companion.promoteToLiteral(entry.getKey(), expectedType, context);
// The self parameter gets an adjusted actual type depending on the kind of the owning function.
773 if (paramWrapper.isSelf()) {
774 // TODO find out a better way to pass the corresponding function inside
775 final PyParameter param = paramWrapper.getParameter();
776 final PyFunction function = as(ScopeUtil.getScopeOwner(param), PyFunction.class);
// Class methods: class-like members of the actual type are converted to their class form.
777 if (function != null && function.getModifier() == PyFunction.Modifier.CLASSMETHOD) {
778 actualType = PyTypeUtil.toStream(actualType)
779 .select(PyClassLikeType.class)
780 .map(PyClassLikeType::toClass)
781 .select(PyType.class)
782 .foldLeft(PyUnionType::union)
// Functions accepted by PyUtil.isInitMethod: instantiable members are converted to their instance form.
785 else if (PyUtil.isInitMethod(function)) {
786 actualType = PyTypeUtil.toStream(actualType)
787 .select(PyInstantiableType.class)
788 .map(PyInstantiableType::toInstance)
789 .select(PyType.class)
790 .foldLeft(PyUnionType::union)
// Matching fills the substitutions map as a side effect.
794 if (!match(expectedType, actualType, context, substitutions)) {
798 if (!matchContainer(getMappedPositionalContainer(arguments), getArgumentsMappedToPositionalContainer(arguments), substitutions,
802 if (!matchContainer(getMappedKeywordContainer(arguments), getArgumentsMappedToKeywordContainer(arguments), substitutions, context)) {
805 return substitutions;
// Matches the union of the types of all given arguments against the argument type of the container
// parameter. A null container is handled separately at the top.
808 private static boolean matchContainer(@Nullable PyCallableParameter container, @NotNull List<? extends PyExpression> arguments,
809 @NotNull Map<PyGenericType, PyType> substitutions, @NotNull TypeEvalContext context) {
810 if (container == null) {
813 final List<PyType> types = ContainerUtil.map(arguments, context::getType);
814 return match(container.getArgumentType(context), PyUnionType.union(types), context, substitutions);
// Builds the initial substitutions map from the type of the receiver expression (if any).
818 public static Map<PyGenericType, PyType> unifyReceiver(@Nullable PyExpression receiver, @NotNull TypeEvalContext context) {
// Insertion-ordered map and set are used for the collected data.
819 final Map<PyGenericType, PyType> substitutions = new LinkedHashMap<>();
820 // Collect generic params of object type
821 final Set<PyGenericType> generics = new LinkedHashSet<>();
822 final PyType qualifierType = receiver != null ? context.getType(receiver) : null;
823 collectGenerics(qualifierType, context, generics, new HashSet<>());
// Every type variable found in the receiver type initially maps to itself.
824 for (PyGenericType t : generics) {
825 substitutions.put(t, t);
827 if (qualifierType != null) {
// For each class type in the receiver type, every registered type provider is asked for generic information.
828 for (PyClassType type : PyTypeUtil.toStream(qualifierType).select(PyClassType.class)) {
829 for (PyTypeProvider provider : PyTypeProvider.EP_NAME.getExtensionList()) {
830 final PyType genericType = provider.getGenericType(type.getPyClass(), context);
831 final Set<PyGenericType> providedTypeGenerics = new LinkedHashSet<>();
// Matching the generic form of the class against the receiver type fills substitutions;
// the type variables of the generic form are remembered for the filter below.
833 if (genericType != null) {
834 match(genericType, type, context, substitutions);
835 collectGenerics(genericType, context, providedTypeGenerics, new HashSet<>());
// Provider-declared substitutions do not override existing entries and are skipped
// for the type variables of the generic form collected above.
838 for (Map.Entry<PyType, PyType> entry : provider.getGenericSubstitutions(type.getPyClass(), context).entrySet()) {
839 final PyGenericType genericKey = as(entry.getKey(), PyGenericType.class);
840 final PyType value = entry.getValue();
842 if (genericKey != null &&
844 !substitutions.containsKey(genericKey) &&
845 !providedTypeGenerics.contains(genericKey)) {
846 substitutions.put(genericKey, value);
853 replaceUnresolvedGenericsWithAny(substitutions);
854 return substitutions;
// Type variables that occur among the values of the map but have no entry of their own
// are added as keys mapped to null.
857 private static void replaceUnresolvedGenericsWithAny(@NotNull Map<PyGenericType, PyType> substitutions) {
// The candidates are collected into a separate list before the map is modified.
858 final List<PyType> unresolvedGenerics =
859 ContainerUtil.filter(substitutions.values(), type -> type instanceof PyGenericType && !substitutions.containsKey(type));
861 for (PyType unresolvedGeneric : unresolvedGenerics) {
862 substitutions.put((PyGenericType)unresolvedGeneric, null);
866 private static boolean matchClasses(@Nullable PyClass superClass, @Nullable PyClass subClass, @NotNull TypeEvalContext context) {
867 if (superClass == null ||
869 subClass.isSubclass(superClass, context) ||
870 PyABCUtil.isSubclass(subClass, superClass, context) ||
871 isStrUnicodeMatch(subClass, superClass) ||
872 isBytearrayBytesStringMatch(subClass, superClass) ||
873 PyUtil.hasUnresolvedAncestors(subClass, context)) {
877 final String superName = superClass.getName();
878 return superName != null && superName.equals(subClass.getName());
882 private static boolean isStrUnicodeMatch(@NotNull PyClass subClass, @NotNull PyClass superClass) {
883 // TODO: Check for subclasses as well
884 return PyNames.TYPE_STR.equals(subClass.getName()) && PyNames.TYPE_UNICODE.equals(superClass.getName());
887 private static boolean isBytearrayBytesStringMatch(@NotNull PyClass subClass, @NotNull PyClass superClass) {
888 if (!PyNames.TYPE_BYTEARRAY.equals(subClass.getName())) return false;
890 final PsiFile subClassFile = subClass.getContainingFile();
892 final boolean isPy2 = subClassFile instanceof PyiFile
893 ? PythonRuntimeService.getInstance().getLanguageLevelForSdk(PythonSdkUtil.findPythonSdk(subClassFile)).isPython2()
894 : LanguageLevel.forElement(subClass).isPython2();
896 final String superClassName = superClass.getName();
897 return isPy2 && PyNames.TYPE_STR.equals(superClassName) || !isPy2 && PyNames.TYPE_BYTES.equals(superClassName);
/**
 * Tells whether values of the given type can be called.
 * <p>
 * The result is a boxed {@link Boolean}. Union types are delegated to {@code isUnionCallable}, a
 * {@link PyCallableType} answers through its own {@code isCallable()}, and a generic type can fall back to the
 * callability of its bound.
 * NOTE(review): the outcomes for structural types inferred from usages, for generic definitions and for all
 * remaining types are not visible in this chunk -- confirm against the full source.
 *
 * @param type type to inspect, may be {@code null}
 */
901 public static Boolean isCallable(@Nullable PyType type) {
// A union is judged by its members.
905 if (type instanceof PyUnionType) {
906 return isUnionCallable((PyUnionType)type);
// Callable types know the answer themselves.
908 if (type instanceof PyCallableType) {
909 return ((PyCallableType)type).isCallable();
// Structural types that were inferred from usages are handled separately.
911 if (type instanceof PyStructuralType && ((PyStructuralType)type).isInferredFromUsages()) {
914 if (type instanceof PyGenericType) {
// Generic definitions are handled separately from other generic types.
915 if (((PyGenericType)type).isDefinition()) {
// Fall back to the callability of the generic's bound.
919 return isCallable(((PyGenericType)type).getBound());
925 * If at least one is callable -- it is callable.
926 * If at least one is unknown -- it is unknown.
927 * It is false otherwise.
// Derives the callability of a union type from the callability of its members.
930 private static Boolean isUnionCallable(@NotNull final PyUnionType type) {
// Each member is evaluated on its own through isCallable().
931 for (final PyType member : type.getMembers()) {
932 final Boolean callable = isCallable(member);
// Presumably a null result means the member's callability is unknown -- TODO confirm.
933 if (callable == null) {
943 public static boolean definesGetAttr(@NotNull PyFile file, @NotNull TypeEvalContext context) {
944 if (file instanceof PyTypedElement) {
945 final PyType type = context.getType((PyTypedElement)file);
947 return resolveTypeMember(type, PyNames.GETATTR, context) != null;
954 public static boolean overridesGetAttr(@NotNull PyClass cls, @NotNull TypeEvalContext context) {
955 final PyType type = context.getType(cls);
957 if (resolveTypeMember(type, PyNames.GETATTR, context) != null) {
960 final PsiElement method = resolveTypeMember(type, PyNames.GETATTRIBUTE, context);
961 if (method != null && !PyBuiltinCache.getInstance(cls).isBuiltin(method)) {
969 private static PsiElement resolveTypeMember(@NotNull PyType type, @NotNull String name, @NotNull TypeEvalContext context) {
970 final PyResolveContext resolveContext = PyResolveContext.defaultContext().withTypeEvalContext(context);
971 final List<? extends RatedResolveResult> results = type.resolveMember(name, null, AccessDirection.READ, resolveContext);
972 return !ContainerUtil.isEmpty(results) ? results.get(0).getElement() : null;
/**
 * Finds the type for {@code target}, an element of {@code parentTuple}, given the type assigned to the whole tuple.
 * <p>
 * A {@link PyTupleType} is passed straight to the {@link PyTupleType} overload. For any other
 * {@link PyClassLikeType} the ancestor types are searched for {@link PyNamedTupleType}s, which are then passed to
 * the same overload.
 * NOTE(review): part of the stream chain and the fall-through result are not visible in this chunk.
 *
 * @param target       target expression whose type is requested
 * @param parentTuple  tuple expression containing the target
 * @param assignedType type assigned to the tuple
 * @param context      type evaluation context used to obtain ancestor types
 */
976 public static PyType getTargetTypeFromTupleAssignment(@NotNull PyTargetExpression target,
977 @NotNull PyTupleExpression parentTuple,
978 @NotNull PyType assignedType,
979 @NotNull TypeEvalContext context) {
// Plain tuple type: delegate to the overload that works on element types.
980 if (assignedType instanceof PyTupleType) {
981 return getTargetTypeFromTupleAssignment(target, parentTuple, (PyTupleType)assignedType);
// Other class-like types: look for named tuple types among the ancestors.
983 else if (assignedType instanceof PyClassLikeType) {
985 .of(((PyClassLikeType)assignedType).getAncestorTypes(context))
986 .select(PyNamedTupleType.class)
988 .map(t -> getTargetTypeFromTupleAssignment(target, parentTuple, t))
/**
 * Finds the type of {@code target} inside {@code parentTuple} from the element types of {@code assignedTupleType}.
 * <p>
 * The target is looked up among the direct elements of the tuple expression first; nested tuple expressions
 * (possibly wrapped in parentheses) are searched recursively when the corresponding element type is a tuple type.
 * NOTE(review): the guard around the direct lookup and the final results are not visible in this chunk.
 *
 * @param target            target expression whose type is requested
 * @param parentTuple       tuple expression searched for {@code target}
 * @param assignedTupleType tuple type supplying the element types
 */
996 public static PyType getTargetTypeFromTupleAssignment(@NotNull PyTargetExpression target, @NotNull PyTupleExpression parentTuple,
997 @NotNull PyTupleType assignedTupleType) {
998 final int count = assignedTupleType.getElementCount();
999 final PyExpression[] elements = parentTuple.getElements();
// Proceed when the expression has as many elements as the type, or the tuple type is homogeneous.
1000 if (elements.length == count || assignedTupleType.isHomogeneous()) {
// Direct lookup: is target itself one of the tuple's elements?
1001 final int index = ArrayUtil.indexOf(elements, target);
1003 return assignedTupleType.getElementType(index);
// Look into each element for nested tuple expressions.
1005 for (int i = 0; i < count; i++) {
1006 PyExpression element = elements[i];
// Strip any number of enclosing parentheses.
1007 while (element instanceof PyParenthesizedExpression) {
1008 element = ((PyParenthesizedExpression)element).getContainedExpression();
1010 if (element instanceof PyTupleExpression) {
1011 final PyType elementType = assignedTupleType.getElementType(i);
// Recurse only when the matching element type is a tuple type as well.
1012 if (elementType instanceof PyTupleType) {
1013 final PyType result = getTargetTypeFromTupleAssignment(target, (PyTupleExpression)element, (PyTupleType)elementType);
1014 if (result != null) {
/**
 * State passed along while matching types: the type evaluation context, the substitutions map and a flag telling
 * whether substitutions are currently reversed.
 */
1024 private static class MatchContext {
// Type evaluation context supplied by the caller.
1027 private final TypeEvalContext context;
// The reference is final, but the map itself is modified during matching and is shared between derived contexts.
1030 private final Map<PyGenericType, PyType> substitutions; // mutable
// Toggled by reverseSubstitutions().
1032 private final boolean reversedSubstitutions;
// Creates a context whose substitutions are not reversed.
1034 MatchContext(@NotNull TypeEvalContext context,
1035 @NotNull Map<PyGenericType, PyType> substitutions) {
1036 this(context, substitutions, false);
// Full constructor; private, so outside this class the reversed flag can only be set via reverseSubstitutions().
1039 private MatchContext(@NotNull TypeEvalContext context,
1040 @NotNull Map<PyGenericType, PyType> substitutions,
1041 boolean reversedSubstitutions) {
1042 this.context = context;
1043 this.substitutions = substitutions;
1044 this.reversedSubstitutions = reversedSubstitutions;
// Returns a new MatchContext with the same context and the same substitutions map, and the reversed flag flipped.
1048 public MatchContext reverseSubstitutions() {
1049 return new MatchContext(context, substitutions, !reversedSubstitutions);