PY-6637 Use standard refactoring processor to collect and show usages
[idea/community.git] / python / src / com / jetbrains / python / refactoring / convertTopLevelFunction / PyMakeFunctionTopLevelProcessor.java
1 /*
2  * Copyright 2000-2015 JetBrains s.r.o.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package com.jetbrains.python.refactoring.convertTopLevelFunction;
17
18 import com.intellij.codeInsight.controlflow.ControlFlow;
19 import com.intellij.codeInsight.controlflow.Instruction;
20 import com.intellij.openapi.application.ApplicationManager;
21 import com.intellij.openapi.editor.Editor;
22 import com.intellij.openapi.project.Project;
23 import com.intellij.openapi.util.text.StringUtil;
24 import com.intellij.psi.PsiElement;
25 import com.intellij.psi.PsiFile;
26 import com.intellij.psi.util.PsiTreeUtil;
27 import com.intellij.refactoring.BaseRefactoringProcessor;
28 import com.intellij.refactoring.ui.UsageViewDescriptorAdapter;
29 import com.intellij.usageView.UsageInfo;
30 import com.intellij.usageView.UsageViewDescriptor;
31 import com.intellij.util.ArrayUtil;
32 import com.intellij.util.IncorrectOperationException;
33 import com.intellij.util.containers.ContainerUtil;
34 import com.jetbrains.python.PyBundle;
35 import com.jetbrains.python.codeInsight.controlflow.ControlFlowCache;
36 import com.jetbrains.python.codeInsight.controlflow.ReadWriteInstruction;
37 import com.jetbrains.python.codeInsight.controlflow.ScopeOwner;
38 import com.jetbrains.python.codeInsight.dataflow.scope.ScopeUtil;
39 import com.jetbrains.python.psi.*;
40 import com.jetbrains.python.psi.impl.PyPsiUtils;
41 import com.jetbrains.python.psi.resolve.PyResolveContext;
42 import com.jetbrains.python.psi.types.TypeEvalContext;
43 import com.jetbrains.python.refactoring.PyRefactoringUtil;
44 import org.jetbrains.annotations.NotNull;
45
46 import java.util.*;
47
48 import static com.jetbrains.python.psi.PyUtil.as;
49
50 /**
51  * @author Mikhail Golubev
52  */
53 public class PyMakeFunctionTopLevelProcessor extends BaseRefactoringProcessor {
54   private final PyFunction myFunction;
55   private final PyResolveContext myContext;
56   private final Editor myEditor;
57
58   protected PyMakeFunctionTopLevelProcessor(@NotNull PyFunction targetFunction, @NotNull Editor editor) {
59     super(targetFunction.getProject());
60     myFunction = targetFunction;
61     myEditor = editor;
62     final TypeEvalContext typeEvalContext = TypeEvalContext.userInitiated(myProject, targetFunction.getContainingFile());
63     myContext = PyResolveContext.defaultContext().withTypeEvalContext(typeEvalContext);
64     setPreviewUsages(isForMethod());
65   }
66
67   private boolean isForMethod() {
68     return myFunction.getContainingClass() != null;
69   }
70
71   @NotNull
72   @Override
73   protected UsageViewDescriptor createUsageViewDescriptor(@NotNull UsageInfo[] usages) {
74     return new UsageViewDescriptorAdapter() {
75       @NotNull
76       @Override
77       public PsiElement[] getElements() {
78         return new PsiElement[] {myFunction};
79       }
80
81       @Override
82       public String getProcessedElementsHeader() {
83         return getRefactoringName();
84       }
85     };
86   }
87
88   @NotNull
89   @Override
90   protected UsageInfo[] findUsages() {
91     return ArrayUtil.toObjectArray(PyRefactoringUtil.findUsages(myFunction, false), UsageInfo.class);
92   }
93
94   @Override
95   protected String getCommandName() {
96     return getRefactoringName();
97   }
98
99   @NotNull
100   private String getRefactoringName() {
101     return isForMethod() ? PyBundle.message("refactoring.make.method.top.level")
102                          : PyBundle.message("refactoring.make.local.function.top.level");
103   }
104
105   @Override
106   protected void performRefactoring(@NotNull UsageInfo[] usages) {
107     if (isForMethod()) {
108       // TODO escalate method
109     }
110     else {
111       escalateLocalFunction(usages);
112     }
113
114   }
115
116   private void escalateLocalFunction(@NotNull UsageInfo[] usages) {
117     final Set<String> enclosingScopeReads = new LinkedHashSet<String>();
118     final Collection<ScopeOwner> scopeOwners = PsiTreeUtil.collectElementsOfType(myFunction, ScopeOwner.class);
119     for (ScopeOwner owner : scopeOwners) {
120       final PyMakeFunctionTopLevelProcessor.AnalysisResult scope = findReadsFromEnclosingScope(owner);
121       if (!scope.nonlocalWritesToEnclosingScope.isEmpty()) {
122         throw new IncorrectOperationException(PyBundle.message("refactoring.make.method.top.level.error.nonlocal.writes"));
123       }
124       for (PsiElement element : scope.readFromEnclosingScope) {
125         if (element instanceof PyElement) {
126           ContainerUtil.addIfNotNull(enclosingScopeReads, ((PyElement)element).getName());
127         }
128       }
129     }
130
131     assert ApplicationManager.getApplication().isWriteAccessAllowed();
132     updateLocalFunctionAndUsages(myEditor, enclosingScopeReads, usages);
133   }
134
135   private void updateLocalFunctionAndUsages(@NotNull Editor editor, @NotNull Set<String> enclosingScopeReads, UsageInfo[] usages) {
136     final String commaSeparatedNames = StringUtil.join(enclosingScopeReads, ", ");
137     final Project project = myFunction.getProject();
138
139     // Update existing usages
140     final PyElementGenerator elementGenerator = PyElementGenerator.getInstance(project);
141     for (UsageInfo usage : usages) {
142       final PsiElement element = usage.getElement();
143       if (element != null) {
144         final PyCallExpression parentCall = as(element.getParent(), PyCallExpression.class);
145         if (parentCall != null) {
146           final PyArgumentList argList = parentCall.getArgumentList();
147           if (argList != null) {
148             final StringBuilder argListText = new StringBuilder(argList.getText());
149             argListText.insert(1, commaSeparatedNames + (argList.getArguments().length > 0 ? ", " : ""));
150             argList.replace(elementGenerator.createArgumentList(LanguageLevel.forElement(element), argListText.toString()));
151           }
152         }
153       }
154     }
155
156     // Replace function
157     PyFunction copiedFunction = (PyFunction)myFunction.copy();
158     final PyParameterList paramList = copiedFunction.getParameterList();
159     final StringBuilder paramListText = new StringBuilder(paramList.getText());
160     paramListText.insert(1, commaSeparatedNames + (paramList.getParameters().length > 0 ? ", " : ""));
161     paramList.replace(elementGenerator.createParameterList(LanguageLevel.forElement(myFunction), paramListText.toString()));
162
163     // See AddImportHelper.getFileInsertPosition()
164     final PsiFile file = myFunction.getContainingFile();
165     final PsiElement anchor = PyPsiUtils.getParentRightBefore(myFunction, file);
166
167     copiedFunction = (PyFunction)file.addAfter(copiedFunction, anchor);
168     myFunction.delete();
169
170     editor.getSelectionModel().removeSelection();
171     editor.getCaretModel().moveToOffset(copiedFunction.getTextOffset());
172   }
173
174   @NotNull
175   private AnalysisResult findReadsFromEnclosingScope(@NotNull ScopeOwner owner) {
176     final ControlFlow controlFlow = ControlFlowCache.getControlFlow(owner);
177     final List<PsiElement> readFromEnclosingScope = new ArrayList<PsiElement>();
178     final List<PyTargetExpression> nonlocalWrites = new ArrayList<PyTargetExpression>();
179     for (Instruction instruction : controlFlow.getInstructions()) {
180       if (instruction instanceof ReadWriteInstruction) {
181         final ReadWriteInstruction readWriteInstruction = (ReadWriteInstruction)instruction;
182         final PsiElement element = readWriteInstruction.getElement();
183         if (element == null) {
184           continue;
185         }
186         if (readWriteInstruction.getAccess().isReadAccess()) {
187           for (PsiElement resolved : PyUtil.multiResolveTopPriority(element, myContext)) {
188             if (resolved != null && isFromEnclosingScope(resolved)) {
189               readFromEnclosingScope.add(element);
190               break;
191             }
192           }
193         }
194         if (readWriteInstruction.getAccess().isWriteAccess()) {
195           if (element instanceof PyTargetExpression && element.getParent() instanceof PyNonlocalStatement) {
196             for (PsiElement resolved : PyUtil.multiResolveTopPriority(element, myContext)) {
197               if (resolved != null && isFromEnclosingScope(resolved)) {
198                 nonlocalWrites.add((PyTargetExpression)element);
199                 break;
200               }
201             }
202           }
203         }
204       }
205     }
206     return new AnalysisResult(readFromEnclosingScope, nonlocalWrites);
207   }
208
209   private boolean isFromEnclosingScope(@NotNull PsiElement element) {
210     return !PsiTreeUtil.isAncestor(myFunction, element, false) && !(ScopeUtil.getScopeOwner(element) instanceof PsiFile);
211   }
212
213   static class AnalysisResult {
214     final List<PsiElement> readFromEnclosingScope;
215     final List<PyTargetExpression> nonlocalWritesToEnclosingScope;
216
217     public AnalysisResult(@NotNull List<PsiElement> readFromEnclosingScope, @NotNull List<PyTargetExpression> nonlocalWrites) {
218       this.readFromEnclosingScope = readFromEnclosingScope;
219       this.nonlocalWritesToEnclosingScope = nonlocalWrites;
220     }
221   }
222 }