PY-6637 Extract helper superclass PyBaseRefactoringAction
[idea/community.git] / python / src / com / jetbrains / python / refactoring / convertModulePackage / PyConvertPackageToModuleAction.java
1 package com.jetbrains.python.refactoring.convertModulePackage;
2
3 import com.google.common.annotations.VisibleForTesting;
4 import com.intellij.openapi.actionSystem.DataContext;
5 import com.intellij.openapi.command.WriteCommandAction;
6 import com.intellij.openapi.diagnostic.Logger;
7 import com.intellij.openapi.editor.Editor;
8 import com.intellij.openapi.module.Module;
9 import com.intellij.openapi.module.ModuleUtilCore;
10 import com.intellij.openapi.project.Project;
11 import com.intellij.openapi.vfs.VirtualFile;
12 import com.intellij.psi.PsiDirectory;
13 import com.intellij.psi.PsiElement;
14 import com.intellij.psi.PsiFile;
15 import com.intellij.refactoring.RefactoringActionHandler;
16 import com.intellij.refactoring.RefactoringBundle;
17 import com.intellij.refactoring.util.CommonRefactoringUtil;
18 import com.jetbrains.python.PyBundle;
19 import com.jetbrains.python.PyNames;
20 import com.jetbrains.python.psi.PyFile;
21 import com.jetbrains.python.psi.PyUtil;
22 import org.jetbrains.annotations.NotNull;
23 import org.jetbrains.annotations.Nullable;
24
25 import java.io.IOException;
26
27 import static com.jetbrains.python.psi.PyUtil.as;
28
29 /**
30  * @author Mikhail Golubev
31  */
32 public class PyConvertPackageToModuleAction extends PyBaseConvertModulePackageAction {
33   private static final Logger LOG = Logger.getInstance(PyConvertPackageToModuleAction.class);
34   private static final String ID = "py.refactoring.convert.package.to.module";
35
36   @Override
37   protected boolean isEnabledOnElements(@NotNull PsiElement[] elements) {
38     if (elements.length == 1) {
39       final PsiDirectory pyPackage = getPackageDir(elements[0]);
40       return pyPackage != null && !isSpecialDirectory(pyPackage);
41
42     }
43     return false;
44   }
45
46   @Nullable
47   private static PsiDirectory getPackageDir(@NotNull PsiElement elem) {
48     if (elem instanceof PsiDirectory && PyUtil.isPackage(((PsiDirectory)elem), null)) {
49       return (PsiDirectory)elem;
50     }
51     else if (elem instanceof PsiFile && PyUtil.isPackage(((PsiFile)elem))) {
52       return ((PsiFile)elem).getParent();
53     }
54     return null;
55   }
56
57   private static boolean isSpecialDirectory(@NotNull PsiDirectory element) {
58     final Module module = ModuleUtilCore.findModuleForPsiElement(element);
59     return module == null || (PyUtil.getSourceRoots(module).contains(element.getVirtualFile()));
60   }
61
62   @Nullable
63   @Override
64   protected RefactoringActionHandler getHandler(@NotNull DataContext dataContext) {
65     return new RefactoringActionHandler() {
66       @Override
67       public void invoke(@NotNull Project project, Editor editor, PsiFile file, DataContext dataContext) {
68         final PsiDirectory pyPackage = getPackageDir(file);
69         if (pyPackage != null) {
70           createModuleFromPackage(pyPackage);
71         }
72       }
73
74       @Override
75       public void invoke(@NotNull Project project, @NotNull PsiElement[] elements, DataContext dataContext) {
76         if (elements.length == 1) {
77           final PsiDirectory pyPackage = getPackageDir(elements[0]);
78           if (pyPackage != null) {
79             createModuleFromPackage(pyPackage);
80           }
81         }
82       }
83     };
84   }
85
86   @VisibleForTesting
87   public void createModuleFromPackage(@NotNull final PsiDirectory pyPackage) {
88     if (pyPackage.getParent() == null) {
89       return;
90     }
91
92     final String packageName = pyPackage.getName();
93     if (!isEmptyPackage(pyPackage)) {
94       CommonRefactoringUtil.showErrorMessage(RefactoringBundle.message("error.title"),
95                                              PyBundle.message("refactoring.convert.package.to.module.error.not.empty.package", packageName),
96                                              ID, pyPackage.getProject());
97       return;
98     }
99     final VirtualFile parentDirVFile = pyPackage.getParent().getVirtualFile();
100     final String moduleName = packageName + PyNames.DOT_PY;
101     final VirtualFile existing = parentDirVFile.findChild(moduleName);
102     if (existing != null) {
103       showFileExistsErrorMessage(existing, ID, pyPackage.getProject());
104       return;
105     }
106     final PsiFile initPy = pyPackage.findFile(PyNames.INIT_DOT_PY);
107     WriteCommandAction.runWriteCommandAction(pyPackage.getProject(), new Runnable() {
108       public void run() {
109         try {
110           if (initPy != null) {
111             final VirtualFile initPyVFile = initPy.getVirtualFile();
112             initPyVFile.rename(PyConvertPackageToModuleAction.this, moduleName);
113             initPyVFile.move(PyConvertPackageToModuleAction.this, parentDirVFile);
114           }
115           else {
116             PyUtil.getOrCreateFile(parentDirVFile.getPath() + "/" + moduleName, pyPackage.getProject());
117           }
118           pyPackage.getVirtualFile().delete(PyConvertPackageToModuleAction.this);
119         }
120         catch (IOException e) {
121           LOG.error(e);
122         }
123       }
124     });
125   }
126
127   private static boolean isEmptyPackage(@NotNull PsiDirectory pyPackage) {
128     final PsiElement[] children = pyPackage.getChildren();
129     if (children.length == 1) {
130       final PyFile onlyFile = as(children[0], PyFile.class);
131       return onlyFile != null && onlyFile.getName().equals(PyNames.INIT_DOT_PY);
132     }
133     return children.length == 0;
134   }
135 }