* @author anna
*/
public class LambdaUtil {
- private static final ThreadLocal<Map<PsiElement, PsiType>> ourFunctionTypes = new ThreadLocal<>();
private static final Logger LOG = Logger.getInstance(LambdaUtil.class);
@Nullable
parent = parent.getParent();
}
- final Map<PsiElement, PsiType> map = ourFunctionTypes.get();
- if (map != null) {
- final PsiType type = ObjectUtils.chooseNotNull(map.get(expression), map.get(element));
- if (type != null) {
- return type;
- }
+ PsiType type = ThreadLocalTypes.getElementType(expression);
+ if (type == null) type = ThreadLocalTypes.getElementType(element);
+ if (type != null) {
+ return type;
}
if (parent instanceof PsiArrayInitializerExpression) {
return false;
}
- @NotNull
- public static Map<PsiElement, PsiType> getFunctionalTypeMap() {
- Map<PsiElement, PsiType> map = ourFunctionTypes.get();
- if (map == null) {
- map = new HashMap<>();
- ourFunctionTypes.set(map);
- }
- return map;
- }
-
public static Map<PsiElement, String> checkReturnTypeCompatible(PsiLambdaExpression lambdaExpression, PsiType functionalInterfaceReturnType) {
Map<PsiElement, String> errors = new LinkedHashMap<>();
if (PsiType.VOID.equals(functionalInterfaceReturnType)) {
break;
}
- if (getFunctionalTypeMap().containsKey(lambdaExpression)) {
+ if (ThreadLocalTypes.hasBindingFor(lambdaExpression)) {
break;
}
}
break;
}
- if (parent instanceof PsiLambdaExpression && getFunctionalTypeMap().containsKey(parent)) {
+ if (parent instanceof PsiLambdaExpression && ThreadLocalTypes.hasBindingFor(parent)) {
break;
}
break;
}
if (MethodCandidateInfo.isOverloadCheck(psiCall.getArgumentList()) ||
- lambdaExpression != null && getFunctionalTypeMap().containsKey(lambdaExpression)) {
+ lambdaExpression != null && ThreadLocalTypes.hasBindingFor(lambdaExpression)) {
break;
}
public static <T> T performWithSubstitutedParameterBounds(final PsiTypeParameter[] typeParameters,
final PsiSubstitutor substitutor,
final Supplier<? extends T> producer) {
- try {
+ return ThreadLocalTypes.performWithTypes(map -> {
for (PsiTypeParameter parameter : typeParameters) {
final PsiClassType[] types = parameter.getExtendsListTypes();
if (types.length > 0) {
//don't glb to avoid flattening = Object&Interface would be preserved
//otherwise methods with different signatures could get same erasure
final PsiType upperBound = PsiIntersectionType.createIntersection(false, conjuncts.toArray(PsiType.EMPTY_ARRAY));
- getFunctionalTypeMap().put(parameter, upperBound);
+ map.forceType(parameter, upperBound);
}
}
return producer.get();
- }
- finally {
- for (PsiTypeParameter parameter : typeParameters) {
- getFunctionalTypeMap().remove(parameter);
- }
- }
+ });
}
public static <T> T performWithTargetType(@NotNull PsiElement element, @NotNull PsiType targetType, @NotNull Supplier<? extends T> producer) {
- Map<PsiElement, PsiType> map = getFunctionalTypeMap();
- PsiType prev = map.put(element, targetType);
- try {
+ return ThreadLocalTypes.performWithTypes(types -> {
+ types.forceType(element, targetType);
return producer.get();
- }
- finally {
- if (prev == null) {
- map.remove(element);
- } else {
- map.put(element, prev);
- }
- }
+ });
}
/**
--- /dev/null
+// Copyright 2000-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
+package com.intellij.psi;
+
+import com.intellij.openapi.util.RecursionGuard;
+import com.intellij.openapi.util.RecursionManager;
+import org.jetbrains.annotations.NotNull;
+import org.jetbrains.annotations.Nullable;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+public class ThreadLocalTypes {
+ private static final RecursionGuard<ThreadLocalTypes> ourGuard = RecursionManager.createGuard("ThreadLocalTypes");
+ private final Map<PsiElement, PsiType> myMap = new HashMap<>();
+
+ private ThreadLocalTypes() {}
+
+ @Nullable
+ public static PsiType getElementType(@NotNull PsiElement psi) {
+ List<? extends ThreadLocalTypes> stack = ourGuard.currentStack();
+ for (int i = stack.size() - 1; i >= 0; i--) {
+ ThreadLocalTypes types = stack.get(i);
+ PsiType type = types.myMap.get(psi);
+ if (type != null) {
+ ourGuard.prohibitResultCaching(types);
+ return type;
+ }
+ }
+ return null;
+ }
+
+ public static boolean hasBindingFor(@NotNull PsiElement psi) {
+ List<? extends ThreadLocalTypes> stack = ourGuard.currentStack();
+ for (int i = stack.size() - 1; i >= 0; i--) {
+ ThreadLocalTypes types = stack.get(i);
+ if (types.myMap.containsKey(psi)) {
+ ourGuard.prohibitResultCaching(types);
+ return true;
+ }
+ }
+ return false;
+ }
+
+ public static <T> T performWithTypes(@NotNull Function<ThreadLocalTypes, T> action) {
+ ThreadLocalTypes types = new ThreadLocalTypes();
+ return ourGuard.doPreventingRecursion(types, false, () -> action.apply(types));
+ }
+
+ public void forceType(@NotNull PsiElement psi, @Nullable PsiType type) {
+ myMap.put(psi, type);
+ }
+
+}
import org.jetbrains.annotations.Nullable;
import java.util.Arrays;
-import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
.map(expression -> PsiUtil.skipParenthesizedExprDown(expression))
.filter(expression -> expression != null && !(expression instanceof PsiFunctionalExpression))
.toArray(PsiExpression[]::new);
- Map<PsiElement, PsiType> expressionTypes = LambdaUtil.getFunctionalTypeMap();
- try {
+ return ThreadLocalTypes.performWithTypes(expressionTypes -> {
PsiMethod method = getElement();
boolean varargs = isVarargs();
for (PsiExpression context : expressions) {
- expressionTypes.put(context,
- PsiTypesUtil.getTypeByMethod(context, argumentList, method, varargs, substitutor, false));
+ expressionTypes.forceType(context,
+ PsiTypesUtil.getTypeByMethod(context, argumentList, method, varargs, substitutor, false));
}
return computable.compute();
- }
- finally {
- for (PsiExpression context : expressions) {
- expressionTypes.remove(context);
- }
- }
+ });
}
else {
return computable.compute();
}
PsiType upperBound = psiClass instanceof PsiTypeParameter ? TypeConversionUtil.getInferredUpperBoundForSynthetic((PsiTypeParameter)psiClass) : null;
if (upperBound == null && psiClass instanceof PsiTypeParameter) {
- upperBound = LambdaUtil.getFunctionalTypeMap().get(psiClass);
+ upperBound = ThreadLocalTypes.getElementType(psiClass);
}
if (upperBound instanceof PsiIntersectionType) {
final PsiType[] conjuncts = ((PsiIntersectionType)upperBound).getConjuncts();
}
PsiType upperBound = psiClass instanceof PsiTypeParameter ? TypeConversionUtil.getInferredUpperBoundForSynthetic((PsiTypeParameter)psiClass) : null;
if (upperBound == null && psiClass instanceof PsiTypeParameter) {
- upperBound = LambdaUtil.getFunctionalTypeMap().get(psiClass);
+ upperBound = ThreadLocalTypes.getElementType(psiClass);
}
if (upperBound instanceof PsiIntersectionType) {
final PsiType[] conjuncts = ((PsiIntersectionType)upperBound).getConjuncts();
}
if (element instanceof PsiMethodReferenceExpression) {
// method refs: do not cache results during parent conflict resolving, acceptable checks, etc
- if (LambdaUtil.getFunctionalTypeMap().containsKey(element)) {
+ if (ThreadLocalTypes.hasBindingFor(element)) {
return (JavaResolveResult[])resolver.resolve(element, psiFile, incompleteCode);
}
}
private PsiSubstitutor myInferenceSubstitution = PsiSubstitutor.EMPTY;
private PsiSubstitutor myRestoreNameSubstitution = PsiSubstitutor.EMPTY;
private MethodCandidateInfo myCurrentMethod;
+ private ThreadLocalTypes myTempTypes;
public InferenceSession(InitialInferenceState initialState, ParameterTypeInferencePolicy policy) {
myContext = initialState.getContext();
@Nullable PsiElement parent,
@Nullable MethodCandidateInfo currentMethod,
@NotNull PsiSubstitutor initialSubstitutor) {
- try {
- doInfer(parameters, args, parent, currentMethod, initialSubstitutor);
- return prepareSubstitution();
- }
- finally {
- for (ConstraintFormula formula : myConstraintsCopy) {
- if (formula instanceof InputOutputConstraintFormula) {
- LambdaUtil.getFunctionalTypeMap().remove(((InputOutputConstraintFormula)formula).getExpression());
- }
- }
-
- if (currentMethod != null) {
- if (myErrorMessages != null) {
- currentMethod.setApplicabilityError(StringUtil.join(myErrorMessages, "\n"));
+ return ThreadLocalTypes.performWithTypes(types -> {
+ myTempTypes = types;
+ try {
+ doInfer(parameters, args, parent, currentMethod, initialSubstitutor);
+ return prepareSubstitution();
+ }
+ finally {
+ if (currentMethod != null) {
+ if (myErrorMessages != null) {
+ currentMethod.setApplicabilityError(StringUtil.join(myErrorMessages, "\n"));
+ }
+ currentMethod.setErased(myErased);
}
- currentMethod.setErased(myErased);
+ myTempTypes = null;
}
- }
+ });
}
private void doInfer(@Nullable PsiParameter[] parameters,
}
private static boolean isPertinentToApplicabilityCheckOnContainingCall(@NotNull PsiElement parent) {
- return LambdaUtil.getFunctionalTypeMap().containsKey(parent);
+ return ThreadLocalTypes.hasBindingFor(parent);
}
private void collectAdditionalConstraints(PsiParameter[] parameters,
final PsiExpressionList argumentList = ((PsiCall)gParent).getArgumentList();
if (argumentList != null) {
if (MethodCandidateInfo.isOverloadCheck(argumentList)) {
- return LambdaUtil.getFunctionalTypeMap().get(context);
+ return ThreadLocalTypes.getElementType(context);
}
final JavaResolveResult result = PsiDiamondType.getDiamondsAwareResolveResult((PsiCall)gParent);
//at this time, types of interface method parameter types must be already calculated
// that's why walkUp in InferenceSessionContainer stops at this point and
//that's why we can reuse this type here
- final PsiType cachedLambdaType = LambdaUtil.getFunctionalTypeMap().get(lambdaExpression);
+ PsiType cachedLambdaType = ThreadLocalTypes.getElementType(lambdaExpression);
if (cachedLambdaType != null) {
return LambdaUtil.getFunctionalInterfaceReturnType(lambdaExpression.getGroundTargetType(cachedLambdaType));
}
PsiSubstitutor substitutor,
Set<ConstraintFormula> ignoredConstraints) {
formula.apply(substitutor, true);
+ if (formula instanceof InputOutputConstraintFormula) {
+ myTempTypes.forceType(((InputOutputConstraintFormula)formula).getExpression(),
+ ((InputOutputConstraintFormula)formula).getCurrentType());
+ }
addConstraint(formula);
if (!repeatInferencePhases()) {
if (formula instanceof ExpressionCompatibilityConstraint) {
PsiExpression expression = ((ExpressionCompatibilityConstraint)formula).getExpression();
if (expression instanceof PsiLambdaExpression) {
- PsiType parameterType = ((PsiLambdaExpression)expression).getGroundTargetType(((ExpressionCompatibilityConstraint)formula).getT());
+ PsiType parameterType = ((PsiLambdaExpression)expression).getGroundTargetType(((ExpressionCompatibilityConstraint)formula).getCurrentType());
collectLambdaReturnExpression(additionalConstraints, ignoredConstraints, (PsiLambdaExpression)expression, parameterType, !isProperType(parameterType), substitutor);
}
}
package com.intellij.psi.impl.source.resolve.graphInference.constraints;
import com.intellij.codeInsight.ExceptionUtil;
-import com.intellij.openapi.diagnostic.Logger;
import com.intellij.psi.*;
import com.intellij.psi.impl.source.resolve.graphInference.FunctionalInterfaceParameterizationUtil;
import com.intellij.psi.impl.source.resolve.graphInference.InferenceSession;
import java.util.Set;
public class CheckedExceptionCompatibilityConstraint extends InputOutputConstraintFormula {
- private static final Logger LOG = Logger.getInstance(CheckedExceptionCompatibilityConstraint.class);
private final PsiExpression myExpression;
- private PsiType myT;
public CheckedExceptionCompatibilityConstraint(PsiExpression expression, PsiType t) {
+ super(t);
myExpression = expression;
- myT = t;
}
@Override
if (!PsiPolyExpressionUtil.isPolyExpression(myExpression)) {
return true;
}
+ PsiType myT = getCurrentType();
if (myExpression instanceof PsiParenthesizedExpression) {
constraints.add(new CheckedExceptionCompatibilityConstraint(((PsiParenthesizedExpression)myExpression).getExpression(), myT));
return true;
final PsiElement body = myExpression instanceof PsiLambdaExpression ? ((PsiLambdaExpression)myExpression).getBody() : myExpression;
if (body != null) {
final List<PsiClassType> exceptions = ExceptionUtil.getUnhandledExceptions(new PsiElement[] {body});
- if (exceptions != null) {
- thrownTypes.addAll(ContainerUtil.filter(exceptions, type -> !ExceptionUtil.isUncheckedException(type)));
- }
+ thrownTypes.addAll(ContainerUtil.filter(exceptions, type -> !ExceptionUtil.isUncheckedException(type)));
}
if (expectedNonProperThrownTypes.isEmpty()) {
return myExpression;
}
- @Override
- protected PsiType getT() {
- return myT;
- }
-
- @Override
- protected void setT(PsiType t) {
- myT = t;
- }
-
@Override
protected InputOutputConstraintFormula createSelfConstraint(PsiType type, PsiExpression expression) {
return new CheckedExceptionCompatibilityConstraint(expression, type);
public class ExpressionCompatibilityConstraint extends InputOutputConstraintFormula {
private final PsiExpression myExpression;
- private PsiType myT;
public ExpressionCompatibilityConstraint(@NotNull PsiExpression expression, @NotNull PsiType type) {
+ super(type);
myExpression = expression;
- myT = type;
}
@Override
public boolean reduce(InferenceSession session, List<ConstraintFormula> constraints) {
+ PsiType myT = getCurrentType();
if (!PsiPolyExpressionUtil.isPolyExpression(myExpression)) {
PsiType exprType = myExpression.getType();
return myExpression;
}
- @Override
- public PsiType getT() {
- return myT;
- }
-
- @Override
- protected void setT(PsiType t) {
- myT = t;
- }
-
@Override
protected InputOutputConstraintFormula createSelfConstraint(PsiType type, PsiExpression expression) {
return new ExpressionCompatibilityConstraint(expression, type);
import java.util.stream.Stream;
public abstract class InputOutputConstraintFormula implements ConstraintFormula {
+ private PsiType myT;
+
+ protected InputOutputConstraintFormula(PsiType t) {
+ myT = t;
+ }
public abstract PsiExpression getExpression();
- protected abstract PsiType getT();
- protected abstract void setT(PsiType t);
protected abstract InputOutputConstraintFormula createSelfConstraint(PsiType type, PsiExpression expression);
protected abstract void collectReturnTypeVariables(InferenceSession session,
PsiExpression psiExpression,
public Set<InferenceVariable> getInputVariables(InferenceSession session) {
final PsiExpression psiExpression = getExpression();
- final PsiType type = getT();
+ final PsiType type = myT;
if (psiExpression instanceof PsiFunctionalExpression) {
final InferenceVariable inferenceVariable = session.getInferenceVariable(type);
if (inferenceVariable != null) {
@Nullable
public Set<InferenceVariable> getOutputVariables(Set<InferenceVariable> inputVariables, InferenceSession session) {
final HashSet<InferenceVariable> mentionedVariables = new HashSet<>();
- session.collectDependencies(getT(), mentionedVariables);
+ session.collectDependencies(myT, mentionedVariables);
if (inputVariables != null) {
mentionedVariables.removeAll(inputVariables);
}
@Override
public void apply(PsiSubstitutor substitutor, boolean cache) {
- setT(substitutor.substitute(getT()));
- if (cache) {
- LambdaUtil.getFunctionalTypeMap().put(getExpression(), getT());
- }
+ myT = substitutor.substitute(myT);
+ }
+
+ public PsiType getCurrentType() {
+ return myT;
}
@Override
public String toString() {
- return getExpression().getText() + " -> " + getT().getPresentableText();
+ return getExpression().getText() + " -> " + myT.getPresentableText();
}
}
MethodCandidateInfo.isOverloadCheck(parentArgList) &&
Arrays.stream(parentArgList.getExpressions())
.map(expression -> PsiUtil.skipParenthesizedExprDown(expression))
- .noneMatch(expression -> LambdaUtil.getFunctionalTypeMap().containsKey(expression));
+ .noneMatch(expression -> expression != null && ThreadLocalTypes.hasBindingFor(expression));
PsiType theOnly = null;
for (int i = 0; i < results.length; i++) {
class IntStream {
private void foo(IntStream s) {
- s.<error descr="Ambiguous method call: both 'IntStream.map(IntUnaryOperator)' and 'IntStream.map(ObjIntFunction<Object>)' match">map</error>(i -> 1 << i);
+ s.<error descr="Ambiguous method call: both 'IntStream.map(IntUnaryOperator)' and 'IntStream.map(ObjIntFunction<Object>)' match">map</error>(i -> <error descr="Operator '<<' cannot be applied to 'int', '<lambda parameter>'">1 << i</error>);
s.<error descr="Ambiguous method call: both 'IntStream.map(IntUnaryOperator)' and 'IntStream.map(ObjIntFunction<Object>)' match">map</error>(i -> 1);
s.<error descr="Ambiguous method call: both 'IntStream.map(IntUnaryOperator)' and 'IntStream.map(ObjIntFunction<Object>)' match">map</error>(i -> i);
}
}
void fooBar(IntStream1 instr){
- Supplier<Stream<Integer>> si = () -> instr.<error descr="Ambiguous method call: both 'IntStream1.map(IntFunction<Integer>)' and 'IntStream1.map(IntUnaryOperator)' match">map</error> ((i) -> (( <error descr="Operator '%' cannot be applied to '<lambda parameter>', 'int'">i % 2</error>) == 0) ? i : -i).boxed();
+ Supplier<Stream<Integer>> si = () -> instr.<error descr="Ambiguous method call: both 'IntStream1.map(IntFunction<Integer>)' and 'IntStream1.map(IntUnaryOperator)' match">map</error> ((i) -> (( <error descr="Operator '%' cannot be applied to '<lambda parameter>', 'int'">i % 2</error>) == 0) ? i : <error descr="Operator '-' cannot be applied to '<lambda parameter>'">-i</error>).boxed();
System.out.println(si);
Supplier<Stream<Integer>> si1 = () -> instr.map <error descr="Ambiguous method call: both 'IntStream1.map(IntFunction<Integer>)' and 'IntStream1.map(IntUnaryOperator)' match">(null)</error>.boxed();
System.out.println(si1);
foo(x -> {
return x += 1;
});
- <error descr="Ambiguous method call: both 'Test.foo(A)' and 'Test.foo(B)' match">foo</error>(x -> x += 1);
+ <error descr="Ambiguous method call: both 'Test.foo(A)' and 'Test.foo(B)' match">foo</error>(x -> <error descr="Incompatible types. Found: 'int', required: '<lambda parameter>'">x += 1</error>);
foo(x -> 1);
foo(x -> <error descr="Operator '!' cannot be applied to 'int'">!x</error>);
- <error descr="Ambiguous method call: both 'Test.foo(A)' and 'Test.foo(B)' match">foo</error>(x -> ++x);
+ <error descr="Ambiguous method call: both 'Test.foo(A)' and 'Test.foo(B)' match">foo</error>(x -> <error descr="Operator '++' cannot be applied to '<lambda parameter>'">++x</error>);
foo(x -> o instanceof String ? 1 : 0);
}
}