PY-6637 Use CommonRefactoringUtil#showErrorHint to display error
[idea/community.git] / python / src / com / jetbrains / python / refactoring / convertTopLevelFunction / PyConvertLocalFunctionToTopLevelFunctionAction.java
1 /*
2  * Copyright 2000-2015 JetBrains s.r.o.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package com.jetbrains.python.refactoring.convertTopLevelFunction;
17
18 import com.google.common.annotations.VisibleForTesting;
19 import com.intellij.codeInsight.controlflow.ControlFlow;
20 import com.intellij.codeInsight.controlflow.Instruction;
21 import com.intellij.openapi.actionSystem.CommonDataKeys;
22 import com.intellij.openapi.actionSystem.DataContext;
23 import com.intellij.openapi.command.WriteCommandAction;
24 import com.intellij.openapi.editor.Editor;
25 import com.intellij.openapi.project.Project;
26 import com.intellij.openapi.util.text.StringUtil;
27 import com.intellij.psi.PsiElement;
28 import com.intellij.psi.PsiFile;
29 import com.intellij.psi.util.PsiTreeUtil;
30 import com.intellij.refactoring.RefactoringActionHandler;
31 import com.intellij.refactoring.util.CommonRefactoringUtil;
32 import com.intellij.usageView.UsageInfo;
33 import com.intellij.util.IncorrectOperationException;
34 import com.intellij.util.containers.ContainerUtil;
35 import com.jetbrains.python.PyBundle;
36 import com.jetbrains.python.codeInsight.controlflow.ControlFlowCache;
37 import com.jetbrains.python.codeInsight.controlflow.ReadWriteInstruction;
38 import com.jetbrains.python.codeInsight.controlflow.ScopeOwner;
39 import com.jetbrains.python.codeInsight.dataflow.scope.ScopeUtil;
40 import com.jetbrains.python.psi.*;
41 import com.jetbrains.python.psi.impl.PyPsiUtils;
42 import com.jetbrains.python.psi.resolve.PyResolveContext;
43 import com.jetbrains.python.psi.types.TypeEvalContext;
44 import com.jetbrains.python.refactoring.PyBaseRefactoringAction;
45 import com.jetbrains.python.refactoring.PyRefactoringUtil;
46 import org.jetbrains.annotations.NotNull;
47 import org.jetbrains.annotations.Nullable;
48
49 import java.util.*;
50
51 import static com.jetbrains.python.psi.PyUtil.as;
52
53 /**
54  * @author Mikhail Golubev
55  */
56 public class PyConvertLocalFunctionToTopLevelFunctionAction extends PyBaseRefactoringAction {
57   public static final String ID = "py.convert.local.function.to.top.level.function";
58
59   @Override
60   protected boolean isAvailableInEditorOnly() {
61     return true;
62   }
63
64   @Override
65   protected boolean isEnabledOnElementInsideEditor(@NotNull PsiElement element,
66                                                    @NotNull Editor editor,
67                                                    @NotNull PsiFile file,
68                                                    @NotNull DataContext context) {
69     return findNestedFunction(element) != null;
70   }
71
72   @Override
73   protected boolean isEnabledOnElementsOutsideEditor(@NotNull PsiElement[] elements) {
74     return false;
75   }
76
77   @Nullable
78   private static PyFunction findNestedFunction(@NotNull PsiElement element) {
79     PyFunction result = null;
80     if (isLocalFunction(element)) {
81       result = (PyFunction)element;
82     }
83     else {
84       final PyReferenceExpression refExpr = PsiTreeUtil.getParentOfType(element, PyReferenceExpression.class);
85       if (refExpr == null) {
86         return null;
87       }
88       final PsiElement resolved = refExpr.getReference().resolve();
89       if (isLocalFunction(resolved)) {
90         result = (PyFunction)resolved;
91       }
92     }
93     //if (result != null) {
94     //  final VirtualFile virtualFile = result.getContainingFile().getVirtualFile();
95     //  if (virtualFile != null && ProjectRootManager.getInstance(element.getProject()).getFileIndex().isInLibraryClasses(virtualFile)) {
96     //    return null;
97     //  }
98     //}
99     return result;
100   }
101
102   @Nullable
103   @Override
104   protected RefactoringActionHandler getHandler(@NotNull DataContext dataContext) {
105     return new RefactoringActionHandler() {
106       @Override
107       public void invoke(@NotNull Project project, Editor editor, PsiFile file, DataContext dataContext) {
108         final PsiElement element = CommonDataKeys.PSI_ELEMENT.getData(dataContext);
109         if (element != null) {
110           escalateFunction(project, file, editor, element);
111         }
112       }
113
114       @Override
115       public void invoke(@NotNull Project project, @NotNull PsiElement[] elements, DataContext dataContext) {
116         final Editor editor = CommonDataKeys.EDITOR.getData(dataContext);
117         if (editor != null && elements.length == 1) {
118           escalateFunction(project, elements[0].getContainingFile(), editor, elements[0]);
119         }
120       }
121     };
122   }
123
124   private static boolean isLocalFunction(@Nullable PsiElement resolved) {
125     if (resolved instanceof PyFunction && PsiTreeUtil.getParentOfType(resolved, ScopeOwner.class, true) instanceof PyFunction) {
126       return true;
127     }
128     return false;
129   }
130
131   @VisibleForTesting
132   public void escalateFunction(@NotNull Project project,
133                                @NotNull PsiFile file,
134                                @NotNull final Editor editor,
135                                @NotNull PsiElement targetElement) throws IncorrectOperationException {
136     final PyResolveContext context = PyResolveContext.defaultContext().withTypeEvalContext(TypeEvalContext.userInitiated(project, file));
137     final PyFunction function = findNestedFunction(targetElement);
138     assert function != null;
139     final Set<String> enclosingScopeReads = new LinkedHashSet<String>();
140     final Collection<ScopeOwner> scopeOwners = PsiTreeUtil.collectElementsOfType(function, ScopeOwner.class);
141     for (ScopeOwner owner : scopeOwners) {
142       final AnalysisResult scope = findReadsFromEnclosingScope(owner, function, context);
143       if (!scope.nonlocalWritesToEnclosingScope.isEmpty()) {
144         final String errMsg = PyBundle.message("INTN.convert.local.function.to.top.level.function.nonlocal");
145         CommonRefactoringUtil.showErrorHint(project, editor, errMsg, null, ID);
146         return;
147       }
148       for (PsiElement element : scope.readFromEnclosingScope) {
149         if (element instanceof PyElement) {
150           ContainerUtil.addIfNotNull(enclosingScopeReads, ((PyElement)element).getName());
151         }
152       }
153     }
154
155     WriteCommandAction.runWriteCommandAction(project, new Runnable() {
156       @Override
157       public void run() {
158         updateUsagesAndFunction(editor, function, enclosingScopeReads);
159       }
160     });
161   }
162
163   private static void updateUsagesAndFunction(@NotNull Editor editor, 
164                                               @NotNull PyFunction targetFunction,
165                                               @NotNull Set<String> enclosingScopeReads) {
166     final String commaSeparatedNames = StringUtil.join(enclosingScopeReads, ", ");
167     final Project project = targetFunction.getProject();
168
169     // Update existing usages
170     final PyElementGenerator elementGenerator = PyElementGenerator.getInstance(project);
171     for (UsageInfo usage : PyRefactoringUtil.findUsages(targetFunction, false)) {
172       final PsiElement element = usage.getElement();
173       if (element != null) {
174         final PyCallExpression parentCall = as(element.getParent(), PyCallExpression.class);
175         if (parentCall != null) {
176           final PyArgumentList argList = parentCall.getArgumentList();
177           if (argList != null) {
178             final StringBuilder argListText = new StringBuilder(argList.getText());
179             argListText.insert(1, commaSeparatedNames + (argList.getArguments().length > 0 ? ", " : ""));
180             argList.replace(elementGenerator.createArgumentList(LanguageLevel.forElement(element), argListText.toString()));
181           }
182         }
183       }
184     }
185
186     // Replace function
187     PyFunction copiedFunction = (PyFunction)targetFunction.copy();
188     final PyParameterList paramList = copiedFunction.getParameterList();
189     final StringBuilder paramListText = new StringBuilder(paramList.getText());
190     paramListText.insert(1, commaSeparatedNames + (paramList.getParameters().length > 0 ? ", " : ""));
191     paramList.replace(elementGenerator.createParameterList(LanguageLevel.forElement(targetFunction), paramListText.toString()));
192
193     // See AddImportHelper.getFileInsertPosition()
194     final PsiFile file = targetFunction.getContainingFile();
195     final PsiElement anchor = PyPsiUtils.getParentRightBefore(targetFunction, file);
196
197     copiedFunction = (PyFunction)file.addAfter(copiedFunction, anchor);
198     targetFunction.delete();
199
200     editor.getSelectionModel().removeSelection();
201     editor.getCaretModel().moveToOffset(copiedFunction.getTextOffset());
202   }
203
204   @NotNull
205   private static AnalysisResult findReadsFromEnclosingScope(@NotNull ScopeOwner owner,
206                                                             @NotNull PyFunction targetFunction,
207                                                             @NotNull PyResolveContext context) {
208     final ControlFlow controlFlow = ControlFlowCache.getControlFlow(owner);
209     final List<PsiElement> readFromEnclosingScope = new ArrayList<PsiElement>();
210     final List<PyTargetExpression> nonlocalWrites = new ArrayList<PyTargetExpression>(); 
211     for (Instruction instruction : controlFlow.getInstructions()) {
212       if (instruction instanceof ReadWriteInstruction) {
213         final ReadWriteInstruction readWriteInstruction = (ReadWriteInstruction)instruction;
214         final PsiElement element = readWriteInstruction.getElement();
215         if (element == null) {
216           continue;
217         }
218         if (readWriteInstruction.getAccess().isReadAccess()) {
219           for (PsiElement resolved : PyUtil.multiResolveTopPriority(element, context)) {
220             if (resolved != null && isFromEnclosingScope(resolved, targetFunction)) {
221               readFromEnclosingScope.add(element);
222               break;
223             }
224           }
225         }
226         if (readWriteInstruction.getAccess().isWriteAccess()) {
227           if (element instanceof PyTargetExpression && element.getParent() instanceof PyNonlocalStatement) {
228             for (PsiElement resolved : PyUtil.multiResolveTopPriority(element, context)) {
229               if (resolved != null && isFromEnclosingScope(resolved, targetFunction)) {
230                 nonlocalWrites.add((PyTargetExpression)element);
231                 break;
232               }
233             }
234           }
235         }
236       }
237     }
238     return new AnalysisResult(readFromEnclosingScope, nonlocalWrites); 
239   }
240
241   private static class AnalysisResult {
242     final List<PsiElement> readFromEnclosingScope;
243     final List<PyTargetExpression> nonlocalWritesToEnclosingScope;
244
245     public AnalysisResult(@NotNull List<PsiElement> readFromEnclosingScope, @NotNull List<PyTargetExpression> nonlocalWrites) {
246       this.readFromEnclosingScope = readFromEnclosingScope;
247       this.nonlocalWritesToEnclosingScope = nonlocalWrites;
248     }
249   }
250
251   private static boolean isFromEnclosingScope(@NotNull PsiElement element, @NotNull PyFunction targetFunction) {
252     return !PsiTreeUtil.isAncestor(targetFunction, element, false) && !(ScopeUtil.getScopeOwner(element) instanceof PsiFile);
253   }
254 }