2e42facb8d6902b928eb2cb0d42e7a4777d45272
[idea/community.git] / python / src / com / jetbrains / python / refactoring / move / makeFunctionTopLevel / PyMakeMethodTopLevelProcessor.java
1 /*
2  * Copyright 2000-2015 JetBrains s.r.o.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package com.jetbrains.python.refactoring.move.makeFunctionTopLevel;
17
18 import com.google.common.collect.Lists;
19 import com.intellij.openapi.util.Comparing;
20 import com.intellij.openapi.util.text.StringUtil;
21 import com.intellij.psi.PsiElement;
22 import com.intellij.psi.util.PsiTreeUtil;
23 import com.intellij.usageView.UsageInfo;
24 import com.intellij.util.Function;
25 import com.intellij.util.IncorrectOperationException;
26 import com.intellij.util.containers.ContainerUtil;
27 import com.intellij.util.containers.HashSet;
28 import com.intellij.util.containers.MultiMap;
29 import com.jetbrains.python.PyBundle;
30 import com.jetbrains.python.PyNames;
31 import com.jetbrains.python.codeInsight.controlflow.ScopeOwner;
32 import com.jetbrains.python.psi.*;
33 import com.jetbrains.python.refactoring.PyRefactoringUtil;
34 import org.jetbrains.annotations.NotNull;
35
36 import java.util.*;
37
38 import static com.jetbrains.python.psi.PyUtil.as;
39
40 /**
41  * @author Mikhail Golubev
42  */
43 public class PyMakeMethodTopLevelProcessor extends PyBaseMakeFunctionTopLevelProcessor {
44
45   private final LinkedHashMap<String, String> myAttributeToParameterName = new LinkedHashMap<>();
46   private final MultiMap<String, PyReferenceExpression> myAttributeReferences = MultiMap.create();
47   private final Set<PsiElement> myReadsOfSelfParam = new HashSet<>();
48
49   public PyMakeMethodTopLevelProcessor(@NotNull PyFunction targetFunction, @NotNull String destination) {
50     super(targetFunction, destination);
51   }
52
53   @NotNull
54   @Override
55   protected String getRefactoringName() {
56     return PyBundle.message("refactoring.make.method.top.level.dialog.title");
57   }
58
59   @Override
60   protected void updateUsages(@NotNull Collection<String> newParamNames, @NotNull UsageInfo[] usages) {
61     // Field usages
62     for (String attrName : myAttributeReferences.keySet()) {
63       final Collection<PyReferenceExpression> reads = myAttributeReferences.get(attrName);
64       final String paramName = myAttributeToParameterName.get(attrName);
65       if (!attrName.equals(paramName)) {
66         for (PyReferenceExpression read : reads) {
67           read.replace(myGenerator.createExpressionFromText(LanguageLevel.forElement(read), paramName));
68         }
69       }
70       else {
71         for (PyReferenceExpression read : reads) {
72           removeQualifier(read);
73         }
74       }
75     }
76
77     // Function usages
78     final Collection<String> attrNames = myAttributeToParameterName.keySet();
79     for (UsageInfo usage : usages) {
80       final PsiElement usageElem = usage.getElement();
81       if (usageElem == null) {
82         continue;
83       }
84
85       if (usageElem instanceof PyReferenceExpression) {
86         final PyExpression qualifier = ((PyReferenceExpression)usageElem).getQualifier();
87         final PyCallExpression callExpr = as(usageElem.getParent(), PyCallExpression.class);
88         if (qualifier != null && callExpr != null && callExpr.getArgumentList() != null) {
89           PyExpression instanceExpr = qualifier;
90           final PyArgumentList argumentList = callExpr.getArgumentList();
91           
92           // Class.method(instance) -> method(instance)
93           if (resolvesToClass(qualifier)) {
94             final PyExpression[] arguments = argumentList.getArguments();
95             if (arguments.length > 0) {
96               instanceExpr = arguments[0];
97               instanceExpr.delete();
98             }
99             else {
100               // It's not clear how to handle usages like Class.method(), since there is no suitable instance
101               instanceExpr = null;
102             }
103           }
104
105           if (instanceExpr != null) {
106             // module.inst.method() -> method(module.inst.foo, module.inst.bar)
107             if (isPureReferenceExpression(instanceExpr)) {
108               // recursive call inside the method
109               if (myReadsOfSelfParam.contains(instanceExpr)) {
110                 addArguments(argumentList, newParamNames);
111               }
112               else {
113                 final String instanceExprText = instanceExpr.getText();
114                 addArguments(argumentList, ContainerUtil.map(attrNames, attribute -> instanceExprText + "." + attribute));
115               }
116             }
117             // Class().method() -> method(Class().foo)
118             else if (newParamNames.size() == 1) {
119               addArguments(argumentList, Collections.singleton(instanceExpr.getText() + "." + ContainerUtil.getFirstItem(attrNames)));
120             }
121             // Class().method() -> inst = Class(); method(inst.foo, inst.bar)
122             else if (!newParamNames.isEmpty()) {
123               final PyStatement anchor = PsiTreeUtil.getParentOfType(callExpr, PyStatement.class);
124               //noinspection ConstantConditions
125               final String className = StringUtil.notNullize(myFunction.getContainingClass().getName(), PyNames.OBJECT);
126               final String targetName = PyRefactoringUtil.selectUniqueNameFromType(className, usageElem);
127               final String assignmentText = targetName + " = " + instanceExpr.getText();
128               final PyAssignmentStatement assignment = myGenerator.createFromText(LanguageLevel.forElement(callExpr),
129                                                                                   PyAssignmentStatement.class,
130                                                                                   assignmentText);
131               //noinspection ConstantConditions
132               anchor.getParent().addBefore(assignment, anchor);
133               addArguments(argumentList, ContainerUtil.map(attrNames, attribute -> targetName + "." + attribute));
134             }
135           }
136         }
137         
138         // Will replace/invalidate entire expression
139         removeQualifier((PyReferenceExpression)usageElem);
140       }
141     }
142   }
143
144   private boolean resolvesToClass(@NotNull PyExpression qualifier) {
145     for (PsiElement element : PyUtil.multiResolveTopPriority(qualifier, myResolveContext)) {
146       if (element == myFunction.getContainingClass()) {
147         return true;
148       }
149     }
150     return false;
151   }
152
153   private static boolean isPureReferenceExpression(@NotNull PyExpression expr) {
154     if (!(expr instanceof PyReferenceExpression)) {
155       return false;
156     }
157     final PyExpression qualifier = ((PyReferenceExpression)expr).getQualifier();
158     return qualifier == null || isPureReferenceExpression(qualifier);
159   }
160
161   @NotNull
162   private PyReferenceExpression removeQualifier(@NotNull PyReferenceExpression expr) {
163     if (!expr.isQualified()) {
164       return expr;
165     }
166     final PyExpression newExpression = myGenerator.createExpressionFromText(LanguageLevel.forElement(expr), expr.getLastChild().getText());
167     return (PyReferenceExpression)expr.replace(newExpression);
168   }
169
170   @NotNull
171   @Override
172   protected PyFunction createNewFunction(@NotNull Collection<String> newParams) {
173     final PyFunction copied = (PyFunction)myFunction.copy();
174     final PyParameter[] params = copied.getParameterList().getParameters();
175     if (params.length > 0) {
176       params[0].delete();
177     }
178     addParameters(copied.getParameterList(), newParams);
179     return copied;
180   }
181
182   @NotNull
183   @Override
184   protected List<String> collectNewParameterNames() {
185     final Set<String> attributeNames = new LinkedHashSet<>();
186     for (ScopeOwner owner : PsiTreeUtil.collectElementsOfType(myFunction, ScopeOwner.class)) {
187       final AnalysisResult result =  analyseScope(owner);
188       if (!result.nonlocalWritesToEnclosingScope.isEmpty()) {
189         throw new IncorrectOperationException(PyBundle.message("refactoring.make.function.top.level.error.nonlocal.writes"));
190       }
191       if (!result.readsOfSelfParametersFromEnclosingScope.isEmpty()) {
192         throw new IncorrectOperationException(PyBundle.message("refactoring.make.function.top.level.error.self.reads"));
193       }
194       if (!result.readsFromEnclosingScope.isEmpty()) {
195         throw new IncorrectOperationException(PyBundle.message("refactoring.make.function.top.level.error.outer.scope.reads"));
196       }
197       if (!result.writesToSelfParameter.isEmpty()) {
198         throw new IncorrectOperationException(PyBundle.message("refactoring.make.function.top.level.error.special.usage.of.self"));
199       }
200       myReadsOfSelfParam.addAll(result.readsOfSelfParameter);
201       for (PsiElement usage : result.readsOfSelfParameter) {
202         if (usage.getParent() instanceof PyTargetExpression) {
203           throw new IncorrectOperationException(PyBundle.message("refactoring.make.function.top.level.error.attribute.writes"));
204         }
205         final PyReferenceExpression parentReference = as(usage.getParent(), PyReferenceExpression.class);
206         if (parentReference != null) {
207           final String attrName = parentReference.getName();
208           if (attrName != null && PyUtil.isClassPrivateName(attrName)) {
209             throw new IncorrectOperationException(PyBundle.message("refactoring.make.function.top.level.error.private.attributes"));
210           }
211           if (parentReference.getParent() instanceof PyCallExpression) {
212             if (!(Comparing.equal(myFunction.getName(), parentReference.getName()))) {
213               throw new IncorrectOperationException(PyBundle.message("refactoring.make.function.top.level.error.method.calls"));
214             }
215             else {
216               // do not add method itself to its parameters
217               continue;
218             }
219           }
220           attributeNames.add(attrName);
221           myAttributeReferences.putValue(attrName, parentReference);
222         }
223         else {
224           throw new IncorrectOperationException(PyBundle.message("refactoring.make.function.top.level.error.special.usage.of.self"));
225         }
226       }
227     }
228     for (String name : attributeNames) {
229       final Collection<PyReferenceExpression> reads = myAttributeReferences.get(name);
230       final PsiElement anchor = ContainerUtil.getFirstItem(reads);
231       //noinspection ConstantConditions
232       if (!PyRefactoringUtil.isValidNewName(name, anchor)) {
233         final String indexedName = PyRefactoringUtil.appendNumberUntilValid(name, anchor);
234         myAttributeToParameterName.put(name, indexedName);
235       }
236       else {
237         myAttributeToParameterName.put(name, name);
238       }
239     }
240     return Lists.newArrayList(myAttributeToParameterName.values());
241   }
242 }