Merge branch 'east825/py-move-to-toplevel'
[idea/community.git] / python / src / com / jetbrains / python / psi / impl / PyElementGeneratorImpl.java
1 /*
2  * Copyright 2000-2014 JetBrains s.r.o.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package com.jetbrains.python.psi.impl;
17
18 import com.google.common.collect.Collections2;
19 import com.google.common.collect.Queues;
20 import com.intellij.lang.ASTNode;
21 import com.intellij.openapi.project.Project;
22 import com.intellij.openapi.util.Pair;
23 import com.intellij.openapi.util.text.StringUtil;
24 import com.intellij.openapi.vfs.CharsetToolkit;
25 import com.intellij.openapi.vfs.VirtualFile;
26 import com.intellij.psi.PsiElement;
27 import com.intellij.psi.PsiFile;
28 import com.intellij.psi.PsiFileFactory;
29 import com.intellij.psi.PsiWhiteSpace;
30 import com.intellij.psi.impl.PsiFileFactoryImpl;
31 import com.intellij.psi.impl.source.tree.LeafPsiElement;
32 import com.intellij.psi.tree.TokenSet;
33 import com.intellij.psi.util.PsiTreeUtil;
34 import com.intellij.testFramework.LightVirtualFile;
35 import com.intellij.util.IncorrectOperationException;
36 import com.jetbrains.NotNullPredicate;
37 import com.jetbrains.python.PyTokenTypes;
38 import com.jetbrains.python.PythonFileType;
39 import com.jetbrains.python.PythonLanguage;
40 import com.jetbrains.python.PythonStringUtil;
41 import com.jetbrains.python.psi.*;
42 import org.jetbrains.annotations.NotNull;
43 import org.jetbrains.annotations.Nullable;
44
45 import java.nio.charset.Charset;
46 import java.nio.charset.CharsetEncoder;
47 import java.util.Arrays;
48 import java.util.Deque;
49 import java.util.Formatter;
50
51 /**
52  * @author yole
53  */
54 public class PyElementGeneratorImpl extends PyElementGenerator {
55   private static final CommasOnly COMMAS_ONLY = new CommasOnly();
56   private final Project myProject;
57
58   public PyElementGeneratorImpl(Project project) {
59     myProject = project;
60   }
61
62   public ASTNode createNameIdentifier(String name, LanguageLevel languageLevel) {
63     final PsiFile dummyFile = createDummyFile(languageLevel, name);
64     final PyExpressionStatement expressionStatement = (PyExpressionStatement)dummyFile.getFirstChild();
65     final PyReferenceExpression refExpression = (PyReferenceExpression)expressionStatement.getFirstChild();
66
67     return refExpression.getNode().getFirstChildNode();
68   }
69
70   @Override
71   public PsiFile createDummyFile(LanguageLevel langLevel, String contents) {
72     return createDummyFile(langLevel, contents, false);
73   }
74
75   public PsiFile createDummyFile(LanguageLevel langLevel, String contents, boolean physical) {
76     final PsiFileFactory factory = PsiFileFactory.getInstance(myProject);
77     final String name = "dummy." + PythonFileType.INSTANCE.getDefaultExtension();
78     final LightVirtualFile virtualFile = new LightVirtualFile(name, PythonFileType.INSTANCE, contents);
79     virtualFile.putUserData(LanguageLevel.KEY, langLevel);
80     final PsiFile psiFile = ((PsiFileFactoryImpl)factory).trySetupPsiForFile(virtualFile, PythonLanguage.getInstance(), physical, true);
81     assert psiFile != null;
82     return psiFile;
83   }
84
85   public PyStringLiteralExpression createStringLiteralAlreadyEscaped(String str) {
86     final PsiFile dummyFile = createDummyFile(LanguageLevel.getDefault(), "a=(" + str + ")");
87     final PyAssignmentStatement expressionStatement = (PyAssignmentStatement)dummyFile.getFirstChild();
88     final PyExpression assignedValue = expressionStatement.getAssignedValue();
89     if (assignedValue != null) {
90       return (PyStringLiteralExpression)((PyParenthesizedExpression)assignedValue).getContainedExpression();
91     }
92     return createStringLiteralFromString(str);
93   }
94
95
96   @Override
97   public PyStringLiteralExpression createStringLiteralFromString(@NotNull String unescaped) {
98     return createStringLiteralFromString(null, unescaped, true);
99   }
100
101   public PyStringLiteralExpression createStringLiteral(@NotNull PyStringLiteralExpression oldElement, @NotNull String unescaped) {
102     Pair<String, String> quotes = PythonStringUtil.getQuotes(oldElement.getText());
103     if (quotes != null) {
104       return createStringLiteralAlreadyEscaped(quotes.first + unescaped + quotes.second);
105     }
106     else {
107       return createStringLiteralFromString(unescaped);
108     }
109   }
110
111
112   @Override
113   public PyStringLiteralExpression createStringLiteralFromString(@Nullable PsiFile destination,
114                                                                  @NotNull String unescaped,
115                                                                  final boolean preferUTF8) {
116     boolean useDouble = !unescaped.contains("\"");
117     boolean useMulti = unescaped.matches(".*(\r|\n).*");
118     String quotes;
119     if (useMulti) {
120       quotes = useDouble ? "\"\"\"" : "'''";
121     }
122     else {
123       quotes = useDouble ? "\"" : "'";
124     }
125     StringBuilder buf = new StringBuilder(unescaped.length() * 2);
126     buf.append(quotes);
127     VirtualFile vfile = destination == null ? null : destination.getVirtualFile();
128     Charset charset;
129     if (vfile == null) {
130       charset = (preferUTF8 ? CharsetToolkit.UTF8_CHARSET : Charset.forName("US-ASCII"));
131     }
132     else {
133       charset = vfile.getCharset();
134     }
135     CharsetEncoder encoder = charset.newEncoder();
136     Formatter formatter = new Formatter(buf);
137     boolean unicode = false;
138     for (int i = 0; i < unescaped.length(); i++) {
139       int c = unescaped.codePointAt(i);
140       if (c == '"' && useDouble) {
141         buf.append("\\\"");
142       }
143       else if (c == '\'' && !useDouble) {
144         buf.append("\\'");
145       }
146       else if ((c == '\r' || c == '\n') && !useMulti) {
147         if (c == '\r') {
148           buf.append("\\r");
149         }
150         else if (c == '\n') buf.append("\\n");
151       }
152       else if (!encoder.canEncode(new String(Character.toChars(c)))) {
153         if (c <= 0xff) {
154           formatter.format("\\x%02x", c);
155         }
156         else if (c < 0xffff) {
157           unicode = true;
158           formatter.format("\\u%04x", c);
159         }
160         else {
161           unicode = true;
162           formatter.format("\\U%08x", c);
163         }
164       }
165       else {
166         buf.appendCodePoint(c);
167       }
168     }
169     buf.append(quotes);
170     if (unicode) buf.insert(0, "u");
171
172     return createStringLiteralAlreadyEscaped(buf.toString());
173   }
174
175   public PyListLiteralExpression createListLiteral() {
176     final PsiFile dummyFile = createDummyFile(LanguageLevel.getDefault(), "[]");
177     final PyExpressionStatement expressionStatement = (PyExpressionStatement)dummyFile.getFirstChild();
178     return (PyListLiteralExpression)expressionStatement.getFirstChild();
179   }
180
181   public ASTNode createComma() {
182     final PsiFile dummyFile = createDummyFile(LanguageLevel.getDefault(), "[0,]");
183     final PyExpressionStatement expressionStatement = (PyExpressionStatement)dummyFile.getFirstChild();
184     ASTNode zero = expressionStatement.getFirstChild().getNode().getFirstChildNode().getTreeNext();
185     return zero.getTreeNext().copyElement();
186   }
187
188   public ASTNode createDot() {
189     final PsiFile dummyFile = createDummyFile(LanguageLevel.getDefault(), "a.b");
190     final PyExpressionStatement expressionStatement = (PyExpressionStatement)dummyFile.getFirstChild();
191     ASTNode dot = expressionStatement.getFirstChild().getNode().getFirstChildNode().getTreeNext();
192     return dot.copyElement();
193   }
194
195   @Override
196   @NotNull
197   public PsiElement insertItemIntoListRemoveRedundantCommas(
198     @NotNull final PyElement list,
199     @Nullable final PyExpression afterThis,
200     @NotNull final PyExpression toInsert) {
201     // TODO: #insertItemIntoList is probably buggy. In such case, fix it and get rid of this method
202     final PsiElement result = insertItemIntoList(list, afterThis, toInsert);
203     final LeafPsiElement[] leafs = PsiTreeUtil.getChildrenOfType(list, LeafPsiElement.class);
204     if (leafs != null) {
205       final Deque<LeafPsiElement> commas = Queues.newArrayDeque(Collections2.filter(Arrays.asList(leafs), COMMAS_ONLY));
206       if (!commas.isEmpty()) {
207         final LeafPsiElement lastComma = commas.getLast();
208         if (PsiTreeUtil.getNextSiblingOfType(lastComma, PyExpression.class) == null) { //Comma has no expression after it
209           lastComma.delete();
210         }
211       }
212     }
213
214     return result;
215   }
216
217   // TODO: Adds comma to empty list: adding "foo" to () will create (foo,). That is why "insertItemIntoListRemoveRedundantCommas" was created.
218   // We probably need to fix this method and delete insertItemIntoListRemoveRedundantCommas
219   public PsiElement insertItemIntoList(PyElement list, @Nullable PyExpression afterThis, PyExpression toInsert)
220     throws IncorrectOperationException {
221     ASTNode add = toInsert.getNode().copyElement();
222     if (afterThis == null) {
223       ASTNode exprNode = list.getNode();
224       ASTNode[] closingTokens = exprNode.getChildren(TokenSet.create(PyTokenTypes.LBRACKET, PyTokenTypes.LPAR));
225       if (closingTokens.length == 0) {
226         // we tried our best. let's just insert it at the end
227         exprNode.addChild(add);
228       }
229       else {
230         ASTNode next = PyPsiUtils.getNextNonWhitespaceSibling(closingTokens[closingTokens.length - 1]);
231         if (next != null) {
232           ASTNode comma = createComma();
233           exprNode.addChild(comma, next);
234           exprNode.addChild(add, comma);
235         }
236         else {
237           exprNode.addChild(add);
238         }
239       }
240     }
241     else {
242       ASTNode lastArgNode = afterThis.getNode();
243       ASTNode comma = createComma();
244       ASTNode parent = lastArgNode.getTreeParent();
245       ASTNode afterLast = lastArgNode.getTreeNext();
246       if (afterLast == null) {
247         parent.addChild(add);
248       }
249       else {
250         parent.addChild(add, afterLast);
251       }
252       parent.addChild(comma, add);
253     }
254     return add.getPsi();
255   }
256
257   public PyBinaryExpression createBinaryExpression(String s, PyExpression expr, PyExpression listLiteral) {
258     final PsiFile dummyFile = createDummyFile(LanguageLevel.getDefault(), "a " + s + " b");
259     final PyExpressionStatement expressionStatement = (PyExpressionStatement)dummyFile.getFirstChild();
260     PyBinaryExpression binExpr = (PyBinaryExpression)expressionStatement.getExpression();
261     ASTNode binnode = binExpr.getNode();
262     binnode.replaceChild(binExpr.getLeftExpression().getNode(), expr.getNode().copyElement());
263     binnode.replaceChild(binExpr.getRightExpression().getNode(), listLiteral.getNode().copyElement());
264     return binExpr;
265   }
266
267   public PyExpression createExpressionFromText(final String text) {
268     return createExpressionFromText(LanguageLevel.getDefault(), text);
269   }
270
271   @NotNull
272   public PyExpression createExpressionFromText(final LanguageLevel languageLevel, final String text) {
273     final PsiFile dummyFile = createDummyFile(languageLevel, text);
274     final PsiElement element = dummyFile.getFirstChild();
275     if (element instanceof PyExpressionStatement) {
276       return ((PyExpressionStatement)element).getExpression();
277     }
278     throw new IncorrectOperationException("could not parse text as expression: " + text);
279   }
280
281   @NotNull
282   public PyCallExpression createCallExpression(final LanguageLevel langLevel, String functionName) {
283     final PsiFile dummyFile = createDummyFile(langLevel, functionName + "()");
284     final PsiElement child = dummyFile.getFirstChild();
285     if (child != null) {
286       final PsiElement element = child.getFirstChild();
287       if (element instanceof PyCallExpression) {
288         return (PyCallExpression)element;
289       }
290     }
291     throw new IllegalArgumentException("Invalid call expression text " + functionName);
292   }
293
294   @Override
295   public PyImportElement createImportElement(final LanguageLevel languageLevel, String name) {
296     return createFromText(languageLevel, PyImportElement.class, "from foo import " + name, new int[]{0, 6});
297   }
298
299   @Override
300   public PyFunction createProperty(LanguageLevel languageLevel,
301                                    String propertyName,
302                                    String fieldName,
303                                    AccessDirection accessDirection) {
304     String propertyText;
305     if (accessDirection == AccessDirection.DELETE) {
306       propertyText = "@" + propertyName + ".deleter\ndef " + propertyName + "(self):\n  del self." + fieldName;
307     }
308     else if (accessDirection == AccessDirection.WRITE) {
309       propertyText = "@" + propertyName + ".setter\ndef " + propertyName + "(self, value):\n  self." + fieldName + " = value";
310     }
311     else {
312       propertyText = "@property\ndef " + propertyName + "(self):\n  return self." + fieldName;
313     }
314     return createFromText(languageLevel, PyFunction.class, propertyText);
315   }
316
317   static final int[] FROM_ROOT = new int[]{0};
318
319   @NotNull
320   public <T> T createFromText(LanguageLevel langLevel, Class<T> aClass, final String text) {
321     return createFromText(langLevel, aClass, text, FROM_ROOT);
322   }
323
324   @NotNull
325   @Override
326   public <T> T createPhysicalFromText(LanguageLevel langLevel, Class<T> aClass, String text) {
327     return createFromText(langLevel, aClass, text, FROM_ROOT, true);
328   }
329
330   static int[] PATH_PARAMETER = {0, 3, 1};
331
332   public PyNamedParameter createParameter(@NotNull String name) {
333     return createParameter(name, null, null, LanguageLevel.getDefault());
334   }
335
336   @NotNull
337   @Override
338   public PyParameterList createParameterList(@NotNull LanguageLevel languageLevel, @NotNull String text) {
339     return createFromText(languageLevel, PyParameterList.class, "def f" + text + ": pass", new int[]{0, 3});
340   }
341
342   @NotNull
343   @Override
344   public PyArgumentList createArgumentList(@NotNull LanguageLevel languageLevel, @NotNull String text) {
345     return createFromText(languageLevel, PyArgumentList.class, "f" + text, new int[]{0, 0, 1});
346   }
347
348
349   public PyNamedParameter createParameter(@NotNull String name, @Nullable String defaultValue, @Nullable String annotation,
350                                           @NotNull LanguageLevel languageLevel) {
351     String parameterText = name;
352     if (annotation != null) {
353       parameterText += ": " + annotation;
354     }
355     if (defaultValue != null) {
356       parameterText += " = " + defaultValue;
357     }
358
359     return createFromText(languageLevel, PyNamedParameter.class, "def f(" + parameterText + "): pass", PATH_PARAMETER);
360   }
361
362   @Override
363   public PyKeywordArgument createKeywordArgument(LanguageLevel languageLevel, String keyword, String value) {
364     PyCallExpression callExpression = (PyCallExpression)createExpressionFromText(languageLevel, "foo(" + keyword + "=" + value + ")");
365     return (PyKeywordArgument)callExpression.getArguments()[0];
366   }
367
368   @NotNull
369   public <T> T createFromText(LanguageLevel langLevel, Class<T> aClass, final String text, final int[] path) {
370     return createFromText(langLevel, aClass, text, path, false);
371   }
372
373   @NotNull
374   public <T> T createFromText(LanguageLevel langLevel, Class<T> aClass, final String text, final int[] path, boolean physical) {
375     PsiElement ret = createDummyFile(langLevel, text, physical);
376     for (int skip : path) {
377       if (ret != null) {
378         ret = ret.getFirstChild();
379         for (int i = 0; i < skip; i += 1) {
380           if (ret != null) {
381             ret = ret.getNextSibling();
382           }
383           else {
384             ret = null;
385             break;
386           }
387         }
388       }
389       else {
390         break;
391       }
392     }
393     if (ret == null) {
394       throw new IllegalArgumentException("Can't find element matching path " + Arrays.toString(path) + " in text '" + text + "'");
395     }
396     try {
397       //noinspection unchecked
398       return (T)ret;
399     }
400     catch (ClassCastException e) {
401       throw new IllegalArgumentException("Can't create an expression of type " + aClass + " from text '" + text + "'");
402     }
403   }
404
405   @Override
406   public PyPassStatement createPassStatement() {
407     final PyStatementList statementList = createPassStatementList();
408     return (PyPassStatement)statementList.getStatements()[0];
409   }
410
411   @NotNull
412   @Override
413   public PyDecoratorList createDecoratorList(@NotNull final String... decoratorTexts) {
414     assert decoratorTexts.length > 0;
415     StringBuilder functionText = new StringBuilder();
416     for (String decoText : decoratorTexts) {
417       functionText.append(decoText).append("\n");
418     }
419     functionText.append("def foo():\n\tpass");
420     final PyFunction function = createFromText(LanguageLevel.getDefault(), PyFunction.class,
421                                                functionText.toString());
422     final PyDecoratorList decoratorList = function.getDecoratorList();
423     assert decoratorList != null;
424     return decoratorList;
425   }
426
427   private PyStatementList createPassStatementList() {
428     final PyFunction function = createFromText(LanguageLevel.getDefault(), PyFunction.class, "def foo():\n\tpass");
429     return function.getStatementList();
430   }
431
432   public PyExpressionStatement createDocstring(String content) {
433     return createFromText(LanguageLevel.getDefault(),
434                           PyExpressionStatement.class, content + "\n");
435   }
436
437   @NotNull
438   @Override
439   public PsiElement createNewLine() {
440     return createFromText(LanguageLevel.getDefault(), PsiWhiteSpace.class, " \n\n ");
441   }
442
443   @NotNull
444   @Override
445   public PyFromImportStatement createFromImportStatement(@NotNull LanguageLevel languageLevel, @NotNull String qualifier,
446                                                          @NotNull String name, @Nullable String alias) {
447     final String asClause = StringUtil.isNotEmpty(alias) ? " as " + alias : "";
448     final String statement = "from " + qualifier + " import " + name + asClause;
449     return createFromText(languageLevel, PyFromImportStatement.class, statement);
450   }
451
452   @NotNull
453   @Override
454   public PyImportStatement createImportStatement(@NotNull LanguageLevel languageLevel, @NotNull String name, @Nullable String alias) {
455     final String asClause = StringUtil.isNotEmpty(alias) ? " as " + alias : "";
456     final String statement = "import " + name + asClause;
457     return createFromText(languageLevel, PyImportStatement.class, statement);
458   }
459
460   private static class CommasOnly extends NotNullPredicate<LeafPsiElement> {
461     @Override
462     protected boolean applyNotNull(@NotNull final LeafPsiElement input) {
463       return input.getNode().getElementType().equals(PyTokenTypes.COMMA);
464     }
465   }
466 }