c990839fc06fc6c2eff9bb5526ce193fb0ee5799
[idea/community.git] / python / src / com / jetbrains / python / refactoring / convertTopLevelFunction / PyConvertLocalFunctionToTopLevelFunctionAction.java
1 /*
2  * Copyright 2000-2015 JetBrains s.r.o.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package com.jetbrains.python.refactoring.convertTopLevelFunction;
17
18 import com.google.common.annotations.VisibleForTesting;
19 import com.intellij.codeInsight.controlflow.ControlFlow;
20 import com.intellij.codeInsight.controlflow.Instruction;
21 import com.intellij.openapi.actionSystem.CommonDataKeys;
22 import com.intellij.openapi.actionSystem.DataContext;
23 import com.intellij.openapi.command.WriteCommandAction;
24 import com.intellij.openapi.editor.Editor;
25 import com.intellij.openapi.project.Project;
26 import com.intellij.openapi.util.text.StringUtil;
27 import com.intellij.psi.PsiElement;
28 import com.intellij.psi.PsiFile;
29 import com.intellij.psi.util.PsiTreeUtil;
30 import com.intellij.refactoring.RefactoringActionHandler;
31 import com.intellij.refactoring.util.CommonRefactoringUtil;
32 import com.intellij.usageView.UsageInfo;
33 import com.intellij.util.IncorrectOperationException;
34 import com.intellij.util.containers.ContainerUtil;
35 import com.jetbrains.python.PyBundle;
36 import com.jetbrains.python.codeInsight.controlflow.ControlFlowCache;
37 import com.jetbrains.python.codeInsight.controlflow.ReadWriteInstruction;
38 import com.jetbrains.python.codeInsight.controlflow.ScopeOwner;
39 import com.jetbrains.python.codeInsight.dataflow.scope.ScopeUtil;
40 import com.jetbrains.python.psi.*;
41 import com.jetbrains.python.psi.impl.PyPsiUtils;
42 import com.jetbrains.python.psi.resolve.PyResolveContext;
43 import com.jetbrains.python.psi.types.TypeEvalContext;
44 import com.jetbrains.python.refactoring.PyBaseRefactoringAction;
45 import com.jetbrains.python.refactoring.PyRefactoringUtil;
46 import org.jetbrains.annotations.NotNull;
47 import org.jetbrains.annotations.Nullable;
48
49 import java.util.*;
50
51 import static com.jetbrains.python.psi.PyUtil.as;
52
53 /**
54  * @author Mikhail Golubev
55  */
56 public class PyConvertLocalFunctionToTopLevelFunctionAction extends PyBaseRefactoringAction {
57   public static final String ID = "py.convert.local.function.to.top.level.function";
58
59   @Override
60   protected boolean isAvailableInEditorOnly() {
61     return true;
62   }
63
64   @Override
65   protected boolean isEnabledOnElementInsideEditor(@NotNull PsiElement element,
66                                                    @NotNull Editor editor,
67                                                    @NotNull PsiFile file,
68                                                    @NotNull DataContext context) {
69     return findNestedFunction(element) != null;
70   }
71
72   @Override
73   protected boolean isEnabledOnElementsOutsideEditor(@NotNull PsiElement[] elements) {
74     return false;
75   }
76
77   @Nullable
78   private static PyFunction findNestedFunction(@NotNull PsiElement element) {
79     PyFunction result = null;
80     if (isLocalFunction(element)) {
81       result = (PyFunction)element;
82     }
83     else {
84       final PyReferenceExpression refExpr = PsiTreeUtil.getParentOfType(element, PyReferenceExpression.class);
85       if (refExpr == null) {
86         return null;
87       }
88       final PsiElement resolved = refExpr.getReference().resolve();
89       if (isLocalFunction(resolved)) {
90         result = (PyFunction)resolved;
91       }
92     }
93     return result;
94   }
95
96   @Nullable
97   @Override
98   protected RefactoringActionHandler getHandler(@NotNull DataContext dataContext) {
99     return new RefactoringActionHandler() {
100       @Override
101       public void invoke(@NotNull Project project, Editor editor, PsiFile file, DataContext dataContext) {
102         final PsiElement element = CommonDataKeys.PSI_ELEMENT.getData(dataContext);
103         if (element != null) {
104           escalateFunction(project, file, editor, element);
105         }
106       }
107
108       @Override
109       public void invoke(@NotNull Project project, @NotNull PsiElement[] elements, DataContext dataContext) {
110         final Editor editor = CommonDataKeys.EDITOR.getData(dataContext);
111         if (editor != null && elements.length == 1) {
112           escalateFunction(project, elements[0].getContainingFile(), editor, elements[0]);
113         }
114       }
115     };
116   }
117
118   private static boolean isLocalFunction(@Nullable PsiElement resolved) {
119     if (resolved instanceof PyFunction && PsiTreeUtil.getParentOfType(resolved, ScopeOwner.class, true) instanceof PyFunction) {
120       return true;
121     }
122     return false;
123   }
124
125   @VisibleForTesting
126   public void escalateFunction(@NotNull Project project,
127                                @NotNull PsiFile file,
128                                @NotNull final Editor editor,
129                                @NotNull PsiElement targetElement) throws IncorrectOperationException {
130     final PyResolveContext context = PyResolveContext.defaultContext().withTypeEvalContext(TypeEvalContext.userInitiated(project, file));
131     final PyFunction function = findNestedFunction(targetElement);
132     assert function != null;
133     final Set<String> enclosingScopeReads = new LinkedHashSet<String>();
134     final Collection<ScopeOwner> scopeOwners = PsiTreeUtil.collectElementsOfType(function, ScopeOwner.class);
135     for (ScopeOwner owner : scopeOwners) {
136       final AnalysisResult scope = findReadsFromEnclosingScope(owner, function, context);
137       if (!scope.nonlocalWritesToEnclosingScope.isEmpty()) {
138         final String errMsg = PyBundle.message("INTN.convert.local.function.to.top.level.function.nonlocal");
139         CommonRefactoringUtil.showErrorHint(project, editor, errMsg, null, ID);
140         return;
141       }
142       for (PsiElement element : scope.readFromEnclosingScope) {
143         if (element instanceof PyElement) {
144           ContainerUtil.addIfNotNull(enclosingScopeReads, ((PyElement)element).getName());
145         }
146       }
147     }
148
149     WriteCommandAction.runWriteCommandAction(project, new Runnable() {
150       @Override
151       public void run() {
152         updateUsagesAndFunction(editor, function, enclosingScopeReads);
153       }
154     });
155   }
156
157   private static void updateUsagesAndFunction(@NotNull Editor editor, 
158                                               @NotNull PyFunction targetFunction,
159                                               @NotNull Set<String> enclosingScopeReads) {
160     final String commaSeparatedNames = StringUtil.join(enclosingScopeReads, ", ");
161     final Project project = targetFunction.getProject();
162
163     // Update existing usages
164     final PyElementGenerator elementGenerator = PyElementGenerator.getInstance(project);
165     for (UsageInfo usage : PyRefactoringUtil.findUsages(targetFunction, false)) {
166       final PsiElement element = usage.getElement();
167       if (element != null) {
168         final PyCallExpression parentCall = as(element.getParent(), PyCallExpression.class);
169         if (parentCall != null) {
170           final PyArgumentList argList = parentCall.getArgumentList();
171           if (argList != null) {
172             final StringBuilder argListText = new StringBuilder(argList.getText());
173             argListText.insert(1, commaSeparatedNames + (argList.getArguments().length > 0 ? ", " : ""));
174             argList.replace(elementGenerator.createArgumentList(LanguageLevel.forElement(element), argListText.toString()));
175           }
176         }
177       }
178     }
179
180     // Replace function
181     PyFunction copiedFunction = (PyFunction)targetFunction.copy();
182     final PyParameterList paramList = copiedFunction.getParameterList();
183     final StringBuilder paramListText = new StringBuilder(paramList.getText());
184     paramListText.insert(1, commaSeparatedNames + (paramList.getParameters().length > 0 ? ", " : ""));
185     paramList.replace(elementGenerator.createParameterList(LanguageLevel.forElement(targetFunction), paramListText.toString()));
186
187     // See AddImportHelper.getFileInsertPosition()
188     final PsiFile file = targetFunction.getContainingFile();
189     final PsiElement anchor = PyPsiUtils.getParentRightBefore(targetFunction, file);
190
191     copiedFunction = (PyFunction)file.addAfter(copiedFunction, anchor);
192     targetFunction.delete();
193
194     editor.getSelectionModel().removeSelection();
195     editor.getCaretModel().moveToOffset(copiedFunction.getTextOffset());
196   }
197
198   @NotNull
199   private static AnalysisResult findReadsFromEnclosingScope(@NotNull ScopeOwner owner,
200                                                             @NotNull PyFunction targetFunction,
201                                                             @NotNull PyResolveContext context) {
202     final ControlFlow controlFlow = ControlFlowCache.getControlFlow(owner);
203     final List<PsiElement> readFromEnclosingScope = new ArrayList<PsiElement>();
204     final List<PyTargetExpression> nonlocalWrites = new ArrayList<PyTargetExpression>(); 
205     for (Instruction instruction : controlFlow.getInstructions()) {
206       if (instruction instanceof ReadWriteInstruction) {
207         final ReadWriteInstruction readWriteInstruction = (ReadWriteInstruction)instruction;
208         final PsiElement element = readWriteInstruction.getElement();
209         if (element == null) {
210           continue;
211         }
212         if (readWriteInstruction.getAccess().isReadAccess()) {
213           for (PsiElement resolved : PyUtil.multiResolveTopPriority(element, context)) {
214             if (resolved != null && isFromEnclosingScope(resolved, targetFunction)) {
215               readFromEnclosingScope.add(element);
216               break;
217             }
218           }
219         }
220         if (readWriteInstruction.getAccess().isWriteAccess()) {
221           if (element instanceof PyTargetExpression && element.getParent() instanceof PyNonlocalStatement) {
222             for (PsiElement resolved : PyUtil.multiResolveTopPriority(element, context)) {
223               if (resolved != null && isFromEnclosingScope(resolved, targetFunction)) {
224                 nonlocalWrites.add((PyTargetExpression)element);
225                 break;
226               }
227             }
228           }
229         }
230       }
231     }
232     return new AnalysisResult(readFromEnclosingScope, nonlocalWrites); 
233   }
234
235   private static class AnalysisResult {
236     final List<PsiElement> readFromEnclosingScope;
237     final List<PyTargetExpression> nonlocalWritesToEnclosingScope;
238
239     public AnalysisResult(@NotNull List<PsiElement> readFromEnclosingScope, @NotNull List<PyTargetExpression> nonlocalWrites) {
240       this.readFromEnclosingScope = readFromEnclosingScope;
241       this.nonlocalWritesToEnclosingScope = nonlocalWrites;
242     }
243   }
244
245   private static boolean isFromEnclosingScope(@NotNull PsiElement element, @NotNull PyFunction targetFunction) {
246     return !PsiTreeUtil.isAncestor(targetFunction, element, false) && !(ScopeUtil.getScopeOwner(element) instanceof PsiFile);
247   }
248 }