d6dd06ae2709329b0fc1d99dc8c9fd81684b939a
[idea/community.git] / python / src / com / jetbrains / python / codeInsight / intentions / PyBaseConvertCollectionLiteralIntention.java
1 /*
2  * Copyright 2000-2015 JetBrains s.r.o.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package com.jetbrains.python.codeInsight.intentions;
17
18 import com.intellij.codeInsight.intention.impl.BaseIntentionAction;
19 import com.intellij.openapi.editor.Editor;
20 import com.intellij.openapi.project.Project;
21 import com.intellij.openapi.util.TextRange;
22 import com.intellij.psi.PsiElement;
23 import com.intellij.psi.PsiFile;
24 import com.intellij.psi.util.PsiTreeUtil;
25 import com.intellij.util.IncorrectOperationException;
26 import com.jetbrains.python.PyBundle;
27 import com.jetbrains.python.PyTokenTypes;
28 import com.jetbrains.python.psi.*;
29 import com.jetbrains.python.psi.impl.PyPsiUtils;
30 import org.jetbrains.annotations.Nls;
31 import org.jetbrains.annotations.NotNull;
32 import org.jetbrains.annotations.Nullable;
33
34 import static com.jetbrains.python.psi.PyUtil.as;
35
36 /**
37  * @author Mikhail Golubev
38  */
39 public abstract class PyBaseConvertCollectionLiteralIntention extends BaseIntentionAction {
40   private final Class<? extends PySequenceExpression> myTargetCollectionClass;
41   private final String myTargetCollectionName;
42   private final String myRightBrace;
43   private final String myLeftBrace;
44
45   public PyBaseConvertCollectionLiteralIntention(@NotNull Class<? extends PySequenceExpression> targetCollectionClass,
46                                                  @NotNull String targetCollectionName,
47                                                  @NotNull String leftBrace, @NotNull String rightBrace) {
48     myTargetCollectionClass = targetCollectionClass;
49     myTargetCollectionName = targetCollectionName;
50     myLeftBrace = leftBrace;
51     myRightBrace = rightBrace;
52   }
53
54   @Nls
55   @NotNull
56   @Override
57   public String getFamilyName() {
58     return PyBundle.message("INTN.convert.collection.literal.family", myTargetCollectionName);
59   }
60
61   @Override
62   public boolean isAvailable(@NotNull Project project, Editor editor, PsiFile file) {
63     if (!(file instanceof PyFile)) {
64       return false;
65     }
66     final PySequenceExpression literal = findCollectionLiteralUnderCaret(editor, file);
67     if (myTargetCollectionClass.isInstance(literal)) {
68       return false;
69     }
70     if (literal instanceof PyTupleExpression) {
71       setText(PyBundle.message("INTN.convert.collection.literal.text", "tuple", myTargetCollectionName));
72     }
73     else if (literal instanceof PyListLiteralExpression) {
74       setText(PyBundle.message("INTN.convert.collection.literal.text", "list", myTargetCollectionName));
75     }
76     else if (literal instanceof PySetLiteralExpression) {
77       setText(PyBundle.message("INTN.convert.collection.literal.text", "set", myTargetCollectionName));
78     }
79     else {
80       return false;
81     }
82     return isAvailableForCollection(literal);
83   }
84
85   protected boolean isAvailableForCollection(PySequenceExpression literal) {
86     return true;
87   }
88
89   @Override
90   public void invoke(@NotNull Project project, Editor editor, PsiFile file) throws IncorrectOperationException {
91     final PySequenceExpression literal = findCollectionLiteralUnderCaret(editor, file);
92     assert literal != null;
93
94     final PsiElement replacedElement = wrapCollection(literal);
95     final PsiElement copy = prepareOriginalElementCopy(replacedElement.copy());
96
97     final TextRange contentRange = getRangeOfContentWithoutBraces(copy);
98     final String contentToWrap = contentRange.substring(copy.getText());
99     final PyElementGenerator elementGenerator = PyElementGenerator.getInstance(project);
100     final PyExpression newLiteral = elementGenerator.createExpressionFromText(LanguageLevel.forElement(file),
101                                                                               myLeftBrace + contentToWrap + myRightBrace);
102     replacedElement.replace(newLiteral);
103   }
104
105   @NotNull
106   protected PsiElement prepareOriginalElementCopy(@NotNull PsiElement copy) {
107     final PySequenceExpression sequence = unwrapCollection(copy);
108     if (sequence instanceof PyTupleExpression) {
109       final PyExpression[] elements = sequence.getElements();
110       if (elements.length == 1) {
111         final PsiElement next = PyPsiUtils.getNextNonCommentSibling(elements[0], true);
112         // Strictly speaking single element tuple must contain trailing comma, but lets check explicitly nonetheless
113         if (next != null && next.getNode().getElementType() == PyTokenTypes.COMMA) {
114           next.delete();
115         }
116       }
117     }
118     return copy;
119   }
120
121   @NotNull
122   protected static PySequenceExpression unwrapCollection(@NotNull PsiElement literal) {
123     final PyParenthesizedExpression parenthesizedExpression = as(literal, PyParenthesizedExpression.class);
124     if (parenthesizedExpression != null) {
125       final PyExpression containedExpression = parenthesizedExpression.getContainedExpression();
126       assert containedExpression != null;
127       return (PyTupleExpression)containedExpression;
128     }
129     return (PySequenceExpression)literal;
130   }
131
132   @NotNull
133   protected static PsiElement wrapCollection(@NotNull PySequenceExpression literal) {
134     if (literal instanceof PyTupleExpression && literal.getParent() instanceof PyParenthesizedExpression) {
135       return literal.getParent();
136     }
137     return literal;
138   }
139
140   @NotNull
141   private static TextRange getRangeOfContentWithoutBraces(@NotNull PsiElement literal) {
142     if (literal instanceof PyTupleExpression) {
143       return TextRange.create(0, literal.getTextLength());
144     }
145
146     final String replacedText = literal.getText();
147     
148     final PsiElement firstChild = literal.getFirstChild();
149     final int contentStartOffset;
150     if (PyTokenTypes.OPEN_BRACES.contains(firstChild.getNode().getElementType())) {
151       contentStartOffset = firstChild.getTextLength();
152     }
153     else {
154       contentStartOffset = 0;
155     }
156
157     final PsiElement lastChild = literal.getLastChild();
158     final int contentEndOffset;
159     if (PyTokenTypes.CLOSE_BRACES.contains(lastChild.getNode().getElementType())) {
160       contentEndOffset = replacedText.length() - lastChild.getTextLength();
161     }
162     else {
163       contentEndOffset = replacedText.length();
164     }
165
166     return TextRange.create(contentStartOffset, contentEndOffset);
167   }
168
169   @Nullable
170   private static PySequenceExpression findCollectionLiteralUnderCaret(@NotNull Editor editor, @NotNull PsiFile psiFile) {
171     final int caretOffset = editor.getCaretModel().getOffset();
172     final PsiElement curElem = psiFile.findElementAt(caretOffset);
173     final PySequenceExpression seqExpr = PsiTreeUtil.getParentOfType(curElem, PySequenceExpression.class);
174     if (seqExpr != null) {
175       return seqExpr;
176     }
177     final PyParenthesizedExpression paren = (PyParenthesizedExpression)PsiTreeUtil.findFirstParent(curElem, element -> {
178       final PyParenthesizedExpression parenthesizedExpr = as(element, PyParenthesizedExpression.class);
179       return parenthesizedExpr != null && parenthesizedExpr.getContainedExpression() instanceof PyTupleExpression;
180     });
181     return paren != null ? ((PyTupleExpression)paren.getContainedExpression()) : null;
182   }
183 }