PY-6637 Turn intention into refactoring
[idea/community.git] / python / src / com / jetbrains / python / refactoring / convertTopLevelFunction / PyConvertLocalFunctionToTopLevelFunctionAction.java
1 /*
2  * Copyright 2000-2015 JetBrains s.r.o.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package com.jetbrains.python.refactoring.convertTopLevelFunction;
17
18 import com.google.common.annotations.VisibleForTesting;
19 import com.intellij.codeInsight.controlflow.ControlFlow;
20 import com.intellij.codeInsight.controlflow.Instruction;
21 import com.intellij.openapi.actionSystem.CommonDataKeys;
22 import com.intellij.openapi.actionSystem.DataContext;
23 import com.intellij.openapi.command.WriteCommandAction;
24 import com.intellij.openapi.editor.Editor;
25 import com.intellij.openapi.project.Project;
26 import com.intellij.openapi.ui.MessageType;
27 import com.intellij.openapi.ui.popup.Balloon;
28 import com.intellij.openapi.ui.popup.JBPopupFactory;
29 import com.intellij.openapi.util.text.StringUtil;
30 import com.intellij.psi.PsiElement;
31 import com.intellij.psi.PsiFile;
32 import com.intellij.psi.util.PsiTreeUtil;
33 import com.intellij.refactoring.RefactoringActionHandler;
34 import com.intellij.usageView.UsageInfo;
35 import com.intellij.util.IncorrectOperationException;
36 import com.intellij.util.containers.ContainerUtil;
37 import com.jetbrains.python.PyBundle;
38 import com.jetbrains.python.codeInsight.controlflow.ControlFlowCache;
39 import com.jetbrains.python.codeInsight.controlflow.ReadWriteInstruction;
40 import com.jetbrains.python.codeInsight.controlflow.ScopeOwner;
41 import com.jetbrains.python.codeInsight.dataflow.scope.ScopeUtil;
42 import com.jetbrains.python.psi.*;
43 import com.jetbrains.python.psi.impl.PyPsiUtils;
44 import com.jetbrains.python.psi.resolve.PyResolveContext;
45 import com.jetbrains.python.psi.types.TypeEvalContext;
46 import com.jetbrains.python.refactoring.PyBaseRefactoringAction;
47 import com.jetbrains.python.refactoring.PyRefactoringUtil;
48 import org.jetbrains.annotations.NotNull;
49 import org.jetbrains.annotations.Nullable;
50
51 import java.util.*;
52
53 import static com.jetbrains.python.psi.PyUtil.as;
54
55 /**
56  * @author Mikhail Golubev
57  */
58 public class PyConvertLocalFunctionToTopLevelFunctionAction extends PyBaseRefactoringAction {
59
60   @Override
61   protected boolean isAvailableInEditorOnly() {
62     return true;
63   }
64
65   @Override
66   protected boolean isEnabledOnElements(@NotNull PsiElement[] elements) {
67     return elements.length == 1 && findNestedFunction(elements[0]) != null;
68   }
69
70   @VisibleForTesting
71   @Override
72   public boolean isAvailableOnElementInEditorAndFile(@NotNull PsiElement element,
73                                                         @NotNull Editor editor,
74                                                         @NotNull PsiFile file,
75                                                         @NotNull DataContext context) {
76     return findNestedFunction(element) != null;
77   }
78
79   @Nullable
80   private static PyFunction findNestedFunction(@NotNull PsiElement element) {
81     PyFunction result = null;
82     if (isLocalFunction(element)) {
83       result = (PyFunction)element;
84     }
85     else {
86       final PyReferenceExpression refExpr = PsiTreeUtil.getParentOfType(element, PyReferenceExpression.class);
87       if (refExpr == null) {
88         return null;
89       }
90       final PsiElement resolved = refExpr.getReference().resolve();
91       if (isLocalFunction(resolved)) {
92         result = (PyFunction)resolved;
93       }
94     }
95     //if (result != null) {
96     //  final VirtualFile virtualFile = result.getContainingFile().getVirtualFile();
97     //  if (virtualFile != null && ProjectRootManager.getInstance(element.getProject()).getFileIndex().isInLibraryClasses(virtualFile)) {
98     //    return null;
99     //  }
100     //}
101     return result;
102   }
103
104   @Nullable
105   @Override
106   protected RefactoringActionHandler getHandler(@NotNull DataContext dataContext) {
107     return new RefactoringActionHandler() {
108       @Override
109       public void invoke(@NotNull Project project, Editor editor, PsiFile file, DataContext dataContext) {
110         final PsiElement element = CommonDataKeys.PSI_ELEMENT.getData(dataContext);
111         if (element != null) {
112           escalateFunction(project, file, editor, element);
113         }
114       }
115
116       @Override
117       public void invoke(@NotNull Project project, @NotNull PsiElement[] elements, DataContext dataContext) {
118         final Editor editor = CommonDataKeys.EDITOR.getData(dataContext);
119         if (editor != null && elements.length == 1) {
120           escalateFunction(project, elements[0].getContainingFile(), editor, elements[0]);
121         }
122       }
123     };
124   }
125
126   private static boolean isLocalFunction(@Nullable PsiElement resolved) {
127     if (resolved instanceof PyFunction && PsiTreeUtil.getParentOfType(resolved, ScopeOwner.class, true) instanceof PyFunction) {
128       return true;
129     }
130     return false;
131   }
132
133   @VisibleForTesting
134   public void escalateFunction(@NotNull Project project,
135                                @NotNull PsiFile file,
136                                @NotNull final Editor editor,
137                                @NotNull PsiElement targetElement) throws IncorrectOperationException {
138     final PyResolveContext context = PyResolveContext.defaultContext().withTypeEvalContext(TypeEvalContext.userInitiated(project, file));
139     final PyFunction function = findNestedFunction(targetElement);
140     assert function != null;
141     final Set<String> enclosingScopeReads = new LinkedHashSet<String>();
142     final Collection<ScopeOwner> scopeOwners = PsiTreeUtil.collectElementsOfType(function, ScopeOwner.class);
143     for (ScopeOwner owner : scopeOwners) {
144       final AnalysisResult scope = findReadsFromEnclosingScope(owner, function, context);
145       if (!scope.nonlocalWritesToEnclosingScope.isEmpty()) {
146         showErrorBalloon(editor, PyBundle.message("INTN.convert.local.function.to.top.level.function.nonlocal"));
147         return;
148       }
149       for (PsiElement element : scope.readFromEnclosingScope) {
150         if (element instanceof PyElement) {
151           ContainerUtil.addIfNotNull(enclosingScopeReads, ((PyElement)element).getName());
152         }
153       }
154     }
155
156     WriteCommandAction.runWriteCommandAction(project, new Runnable() {
157       @Override
158       public void run() {
159         updateUsagesAndFunction(editor, function, enclosingScopeReads);
160       }
161     });
162   }
163
164   private static void updateUsagesAndFunction(@NotNull Editor editor, 
165                                               @NotNull PyFunction targetFunction,
166                                               @NotNull Set<String> enclosingScopeReads) {
167     final String commaSeparatedNames = StringUtil.join(enclosingScopeReads, ", ");
168     final Project project = targetFunction.getProject();
169
170     // Update existing usages
171     final PyElementGenerator elementGenerator = PyElementGenerator.getInstance(project);
172     for (UsageInfo usage : PyRefactoringUtil.findUsages(targetFunction, false)) {
173       final PsiElement element = usage.getElement();
174       if (element != null) {
175         final PyCallExpression parentCall = as(element.getParent(), PyCallExpression.class);
176         if (parentCall != null) {
177           final PyArgumentList argList = parentCall.getArgumentList();
178           if (argList != null) {
179             final StringBuilder argListText = new StringBuilder(argList.getText());
180             argListText.insert(1, commaSeparatedNames + (argList.getArguments().length > 0 ? ", " : ""));
181             argList.replace(elementGenerator.createArgumentList(LanguageLevel.forElement(element), argListText.toString()));
182           }
183         }
184       }
185     }
186
187     // Replace function
188     PyFunction copiedFunction = (PyFunction)targetFunction.copy();
189     final PyParameterList paramList = copiedFunction.getParameterList();
190     final StringBuilder paramListText = new StringBuilder(paramList.getText());
191     paramListText.insert(1, commaSeparatedNames + (paramList.getParameters().length > 0 ? ", " : ""));
192     paramList.replace(elementGenerator.createParameterList(LanguageLevel.forElement(targetFunction), paramListText.toString()));
193
194     // See AddImportHelper.getFileInsertPosition()
195     final PsiFile file = targetFunction.getContainingFile();
196     final PsiElement anchor = PyPsiUtils.getParentRightBefore(targetFunction, file);
197
198     copiedFunction = (PyFunction)file.addAfter(copiedFunction, anchor);
199     targetFunction.delete();
200
201     editor.getSelectionModel().removeSelection();
202     editor.getCaretModel().moveToOffset(copiedFunction.getTextOffset());
203   }
204
205   @NotNull
206   private static AnalysisResult findReadsFromEnclosingScope(@NotNull ScopeOwner owner,
207                                                             @NotNull PyFunction targetFunction,
208                                                             @NotNull PyResolveContext context) {
209     final ControlFlow controlFlow = ControlFlowCache.getControlFlow(owner);
210     final List<PsiElement> readFromEnclosingScope = new ArrayList<PsiElement>();
211     final List<PyTargetExpression> nonlocalWrites = new ArrayList<PyTargetExpression>(); 
212     for (Instruction instruction : controlFlow.getInstructions()) {
213       if (instruction instanceof ReadWriteInstruction) {
214         final ReadWriteInstruction readWriteInstruction = (ReadWriteInstruction)instruction;
215         final PsiElement element = readWriteInstruction.getElement();
216         if (element == null) {
217           continue;
218         }
219         if (readWriteInstruction.getAccess().isReadAccess()) {
220           for (PsiElement resolved : PyUtil.multiResolveTopPriority(element, context)) {
221             if (resolved != null && isFromEnclosingScope(resolved, targetFunction)) {
222               readFromEnclosingScope.add(element);
223               break;
224             }
225           }
226         }
227         if (readWriteInstruction.getAccess().isWriteAccess()) {
228           if (element instanceof PyTargetExpression && element.getParent() instanceof PyNonlocalStatement) {
229             for (PsiElement resolved : PyUtil.multiResolveTopPriority(element, context)) {
230               if (resolved != null && isFromEnclosingScope(resolved, targetFunction)) {
231                 nonlocalWrites.add((PyTargetExpression)element);
232                 break;
233               }
234             }
235           }
236         }
237       }
238     }
239     return new AnalysisResult(readFromEnclosingScope, nonlocalWrites); 
240   }
241
242   private static class AnalysisResult {
243     final List<PsiElement> readFromEnclosingScope;
244     final List<PyTargetExpression> nonlocalWritesToEnclosingScope;
245
246     public AnalysisResult(@NotNull List<PsiElement> readFromEnclosingScope, @NotNull List<PyTargetExpression> nonlocalWrites) {
247       this.readFromEnclosingScope = readFromEnclosingScope;
248       this.nonlocalWritesToEnclosingScope = nonlocalWrites;
249     }
250   }
251
252   private static boolean isFromEnclosingScope(@NotNull PsiElement element, @NotNull PyFunction targetFunction) {
253     return !PsiTreeUtil.isAncestor(targetFunction, element, false) && !(ScopeUtil.getScopeOwner(element) instanceof PsiFile);
254   }
255
256   private static void showErrorBalloon(@NotNull Editor editor, @NotNull String message) {
257     final JBPopupFactory popupFactory = JBPopupFactory.getInstance();
258     popupFactory.createHtmlTextBalloonBuilder(message, MessageType.ERROR, null)
259                 .setDisposable(editor.getProject())
260                 .createBalloon()
261                 .show(popupFactory.guessBestPopupLocation(editor), Balloon.Position.below);
262   }
263 }