2 * Copyright 2000-2015 JetBrains s.r.o.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
16 package com.jetbrains.python.refactoring.move.makeFunctionTopLevel;
18 import com.google.common.collect.Lists;
19 import com.intellij.openapi.util.Comparing;
20 import com.intellij.openapi.util.text.StringUtil;
21 import com.intellij.psi.PsiElement;
22 import com.intellij.psi.util.PsiTreeUtil;
23 import com.intellij.usageView.UsageInfo;
24 import com.intellij.util.IncorrectOperationException;
25 import com.intellij.util.containers.ContainerUtil;
26 import com.intellij.util.containers.HashSet;
27 import com.intellij.util.containers.MultiMap;
28 import com.jetbrains.python.PyBundle;
29 import com.jetbrains.python.PyNames;
30 import com.jetbrains.python.codeInsight.controlflow.ScopeOwner;
31 import com.jetbrains.python.psi.*;
32 import com.jetbrains.python.refactoring.PyRefactoringUtil;
33 import org.jetbrains.annotations.NotNull;
37 import static com.jetbrains.python.psi.PyUtil.as;
40 * @author Mikhail Golubev
42 public class PyMakeMethodTopLevelProcessor extends PyBaseMakeFunctionTopLevelProcessor {
/** Every {@code self.<attr>} reference expression found in the method body, grouped by attribute name. */
private final MultiMap<String, PyReferenceExpression> myAttributeReferences = MultiMap.create();

/**
 * Parameter name chosen for each instance attribute read through the self parameter.
 * Insertion order is preserved and defines the order of the new function's parameters.
 */
private final LinkedHashMap<String, String> myAttributeToParameterName = new LinkedHashMap<>();

/** Reads of the self parameter itself; consulted when rewriting call sites to spot recursive calls. */
private final Set<PsiElement> myReadsOfSelfParam = new HashSet<>();

/**
 * Creates a processor that turns {@code targetFunction} (a method) into a top-level function.
 *
 * @param targetFunction the method to convert
 * @param destination    where the new function should be placed; presumably a module/file path,
 *                       interpreted by the base processor -- confirm there
 */
public PyMakeMethodTopLevelProcessor(@NotNull PyFunction targetFunction, @NotNull String destination) {
  super(targetFunction, destination);
}
54 protected String getRefactoringName() {
55 return PyBundle.message("refactoring.make.method.top.level.dialog.title");
// Rewrites code after the method has been turned into a top-level function.
// Phase 1: every recorded `self.<attr>` read in the moved body (myAttributeReferences) is rewritten
// to use the parameter chosen for that attribute (myAttributeToParameterName).
// Phase 2: every call site in `usages` gets the attribute values passed to addArguments and its
// method reference stripped of the qualifier, e.g.
// `module.inst.method()` -> `method(module.inst.foo, module.inst.bar)`.
//
// newParamNames: names of the parameters added to the new function; presumably the list returned
//                by collectNewParameterNames() -- confirm against the base class.
// usages:        usages of the original method collected by the refactoring.
//
// NOTE(review): this listing omits short lines (closing braces, `else`, `continue`, annotations),
// so the comments below describe only the statements visible here.
59 protected void updateUsages(@NotNull Collection<String> newParamNames, @NotNull UsageInfo[] usages) {
// Phase 1: attribute reads inside the moved body.
61 for (String attrName : myAttributeReferences.keySet()) {
62 final Collection<PyReferenceExpression> reads = myAttributeReferences.get(attrName);
63 final String paramName = myAttributeToParameterName.get(attrName);
// The parameter was renamed (the bare attribute name was not a valid new name): replace the whole
// `self.attr` expression with the new parameter name.
64 if (!attrName.equals(paramName)) {
65 for (PyReferenceExpression read : reads) {
66 read.replace(myGenerator.createExpressionFromText(LanguageLevel.forElement(read), paramName));
// The parameter keeps the attribute's name: it is enough to drop the `self.` qualifier
// (presumably the elided `else` branch of the check above).
70 for (PyReferenceExpression read : reads) {
71 removeQualifier(read);
// Phase 2: call sites. keySet() of the LinkedHashMap iterates in the same insertion order as the
// parameter names, so the generated arguments line up with the new parameters.
77 final Collection<String> attrNames = myAttributeToParameterName.keySet();
78 for (UsageInfo usage : usages) {
79 final PsiElement usageElem = usage.getElement();
// Usages without a PSI element are not processed further (branch body elided; presumably `continue`).
80 if (usageElem == null) {
// Only reference-expression usages are rewritten.
84 if (usageElem instanceof PyReferenceExpression) {
85 final PyExpression qualifier = ((PyReferenceExpression)usageElem).getQualifier();
86 final PyCallExpression callExpr = as(usageElem.getParent(), PyCallExpression.class);
// Only a qualified call `qualifier.method(...)` with an argument list receives extra arguments.
87 if (qualifier != null && callExpr != null && callExpr.getArgumentList() != null) {
88 PyExpression instanceExpr = qualifier;
89 final PyArgumentList argumentList = callExpr.getArgumentList();
91 // Class.method(instance) -> method(instance.attr)
// The explicit instance argument is removed from the call and becomes the instance expression.
// NOTE(review): instanceExpr is used again below (contains(), getText()) after delete(); this relies
// on the detached PSI element still reporting its text -- confirm.
92 if (resolvesToClass(qualifier)) {
93 final PyExpression[] arguments = argumentList.getArguments();
94 if (arguments.length > 0) {
95 instanceExpr = arguments[0];
96 instanceExpr.delete();
99 // It's not clear how to handle usages like Class.method(), since there is no suitable instance
// (presumably instanceExpr is set to null on the elided line after this comment; the null check
// below relies on that)
104 if (instanceExpr != null) {
105 // module.inst.method() -> method(module.inst.foo, module.inst.bar)
106 if (isPureReferenceExpression(instanceExpr)) {
107 // recursive call inside the method
// i.e. the instance is a recorded read of the self parameter: forward the new parameters themselves.
108 if (myReadsOfSelfParam.contains(instanceExpr)) {
109 addArguments(argumentList, newParamNames);
// Otherwise pass `<instance>.<attr>` for every attribute, repeating the plain reference chain.
112 final String instanceExprText = instanceExpr.getText();
113 addArguments(argumentList, ContainerUtil.map(attrNames, attribute -> instanceExprText + "." + attribute));
116 // Class().method() -> method(Class().foo)
// With a single attribute the complex instance expression can be inlined: it appears only once.
117 else if (newParamNames.size() == 1) {
118 addArguments(argumentList, Collections.singleton(instanceExpr.getText() + "." + ContainerUtil.getFirstItem(attrNames)));
120 // Class().method() -> inst = Class(); method(inst.foo, inst.bar)
121 else if (!newParamNames.isEmpty()) {
// Store the instance in a fresh variable (unique name derived from the class name) inserted right
// before the enclosing statement, so the expression is not duplicated per attribute.
122 final PyStatement anchor = PsiTreeUtil.getParentOfType(callExpr, PyStatement.class);
123 //noinspection ConstantConditions
124 final String className = StringUtil.notNullize(myFunction.getContainingClass().getName(), PyNames.OBJECT);
125 final String targetName = PyRefactoringUtil.selectUniqueNameFromType(className, usageElem);
126 final String assignmentText = targetName + " = " + instanceExpr.getText();
127 final PyAssignmentStatement assignment = myGenerator.createFromText(LanguageLevel.forElement(callExpr),
128 PyAssignmentStatement.class,
130 //noinspection ConstantConditions
131 anchor.getParent().addBefore(assignment, anchor);
132 addArguments(argumentList, ContainerUtil.map(attrNames, attribute -> targetName + "." + attribute));
// Strip the qualifier from the method reference (`inst.method` -> `method`).
137 // Will replace/invalidate entire expression
138 removeQualifier((PyReferenceExpression)usageElem);
143 private boolean resolvesToClass(@NotNull PyExpression qualifier) {
144 for (PsiElement element : PyUtil.multiResolveTopPriority(qualifier, myResolveContext)) {
145 if (element == myFunction.getContainingClass()) {
152 private static boolean isPureReferenceExpression(@NotNull PyExpression expr) {
153 if (!(expr instanceof PyReferenceExpression)) {
156 final PyExpression qualifier = ((PyReferenceExpression)expr).getQualifier();
157 return qualifier == null || isPureReferenceExpression(qualifier);
161 private PyReferenceExpression removeQualifier(@NotNull PyReferenceExpression expr) {
162 if (!expr.isQualified()) {
165 final PyExpression newExpression = myGenerator.createExpressionFromText(LanguageLevel.forElement(expr), expr.getLastChild().getText());
166 return (PyReferenceExpression)expr.replace(newExpression);
171 protected PyFunction createNewFunction(@NotNull Collection<String> newParams) {
172 final PyFunction copied = (PyFunction)myFunction.copy();
173 final PyParameter[] params = copied.getParameterList().getParameters();
174 if (params.length > 0) {
177 addParameters(copied.getParameterList(), newParams);
// Collects the instance attributes the method reads through its self parameter; each one becomes a
// parameter of the new top-level function.
// Side effects: fills myReadsOfSelfParam, myAttributeReferences and myAttributeToParameterName.
// NOTE(review): these fields are never cleared here, so this presumably runs once per processor -- confirm.
// Returns the chosen parameter names in the insertion order of myAttributeToParameterName.
// Throws IncorrectOperationException when the method uses self or enclosing scopes in a way a plain
// function cannot express (see the individual checks below).
// NOTE(review): this listing omits short lines (closing braces, `else`, `continue`, annotations).
183 protected List<String> collectNewParameterNames() {
184 final Set<String> attributeNames = new LinkedHashSet<>();
// Analyse every ScopeOwner element PsiTreeUtil finds under the method.
185 for (ScopeOwner owner : PsiTreeUtil.collectElementsOfType(myFunction, ScopeOwner.class)) {
186 final AnalysisResult result = analyseScope(owner);
// Reject nonlocal writes, reads of an enclosing scope's self parameter or of other enclosing-scope
// names, and writes to self itself.
187 if (!result.nonlocalWritesToEnclosingScope.isEmpty()) {
188 throw new IncorrectOperationException(PyBundle.message("refactoring.make.function.top.level.error.nonlocal.writes"));
190 if (!result.readsOfSelfParametersFromEnclosingScope.isEmpty()) {
191 throw new IncorrectOperationException(PyBundle.message("refactoring.make.function.top.level.error.self.reads"));
193 if (!result.readsFromEnclosingScope.isEmpty()) {
194 throw new IncorrectOperationException(PyBundle.message("refactoring.make.function.top.level.error.outer.scope.reads"));
196 if (!result.writesToSelfParameter.isEmpty()) {
197 throw new IncorrectOperationException(PyBundle.message("refactoring.make.function.top.level.error.special.usage.of.self"));
// Remembered so that call-site rewriting can recognize recursive calls made through self.
199 myReadsOfSelfParam.addAll(result.readsOfSelfParameter);
200 for (PsiElement usage : result.readsOfSelfParameter) {
// self qualifies an assignment target (presumably `self.attr = ...`): attribute writes are rejected.
201 if (usage.getParent() instanceof PyTargetExpression) {
202 throw new IncorrectOperationException(PyBundle.message("refactoring.make.function.top.level.error.attribute.writes"));
204 final PyReferenceExpression parentReference = as(usage.getParent(), PyReferenceExpression.class);
205 if (parentReference != null) {
206 final String attrName = parentReference.getName();
// Class-private attribute names (presumably name-mangled `__attr`) are rejected.
207 if (attrName != null && PyUtil.isClassPrivateName(attrName)) {
208 throw new IncorrectOperationException(PyBundle.message("refactoring.make.function.top.level.error.private.attributes"));
// `self.<name>(...)`: calling any other method is rejected; only a call of this very method
// (recursion) is tolerated.
210 if (parentReference.getParent() instanceof PyCallExpression) {
211 if (!(Comparing.equal(myFunction.getName(), parentReference.getName()))) {
212 throw new IncorrectOperationException(PyBundle.message("refactoring.make.function.top.level.error.method.calls"));
215 // do not add method itself to its parameters
// Plain attribute read: record the name and the reference for later rewriting.
// NOTE(review): attrName may be null here (see the null check above); a null key would then be
// passed on to isValidNewName/appendNumberUntilValid below -- consider rejecting or skipping it.
219 attributeNames.add(attrName);
220 myAttributeReferences.putValue(attrName, parentReference);
// Any other use of self (presumably the elided `else` branch of `parentReference != null`).
223 throw new IncorrectOperationException(PyBundle.message("refactoring.make.function.top.level.error.special.usage.of.self"));
// Choose a parameter name per attribute. If the bare attribute name is not a valid new name at its
// first recorded read, a variant with an appended number is used instead.
227 for (String name : attributeNames) {
228 final Collection<PyReferenceExpression> reads = myAttributeReferences.get(name);
229 final PsiElement anchor = ContainerUtil.getFirstItem(reads);
230 //noinspection ConstantConditions
231 if (!PyRefactoringUtil.isValidNewName(name, anchor)) {
232 final String indexedName = PyRefactoringUtil.appendNumberUntilValid(name, anchor);
233 myAttributeToParameterName.put(name, indexedName);
236 myAttributeToParameterName.put(name, name);
239 return Lists.newArrayList(myAttributeToParameterName.values());