Add PyElementGenerator#createParameterList and PyElementGenerator#createParameterList
[idea/community.git] / python / src / com / jetbrains / python / psi / impl / PyElementGeneratorImpl.java
1 /*
2  * Copyright 2000-2014 JetBrains s.r.o.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package com.jetbrains.python.psi.impl;
17
18 import com.google.common.collect.Collections2;
19 import com.google.common.collect.Queues;
20 import com.intellij.lang.ASTNode;
21 import com.intellij.openapi.project.Project;
22 import com.intellij.openapi.util.Pair;
23 import com.intellij.openapi.util.text.StringUtil;
24 import com.intellij.openapi.vfs.CharsetToolkit;
25 import com.intellij.openapi.vfs.VirtualFile;
26 import com.intellij.psi.PsiElement;
27 import com.intellij.psi.PsiFile;
28 import com.intellij.psi.PsiFileFactory;
29 import com.intellij.psi.PsiWhiteSpace;
30 import com.intellij.psi.impl.PsiFileFactoryImpl;
31 import com.intellij.psi.impl.source.tree.LeafPsiElement;
32 import com.intellij.psi.tree.TokenSet;
33 import com.intellij.psi.util.PsiTreeUtil;
34 import com.intellij.testFramework.LightVirtualFile;
35 import com.intellij.util.IncorrectOperationException;
36 import com.jetbrains.NotNullPredicate;
37 import com.jetbrains.python.PyTokenTypes;
38 import com.jetbrains.python.PythonFileType;
39 import com.jetbrains.python.PythonLanguage;
40 import com.jetbrains.python.PythonStringUtil;
41 import com.jetbrains.python.psi.*;
42 import org.jetbrains.annotations.NotNull;
43 import org.jetbrains.annotations.Nullable;
44
45 import java.nio.charset.Charset;
46 import java.nio.charset.CharsetEncoder;
47 import java.util.Arrays;
48 import java.util.Deque;
49 import java.util.Formatter;
50
51 /**
52  * @author yole
53  */
54 public class PyElementGeneratorImpl extends PyElementGenerator {
55   private static final CommasOnly COMMAS_ONLY = new CommasOnly();
56   private final Project myProject;
57
58   public PyElementGeneratorImpl(Project project) {
59     myProject = project;
60   }
61
62   public ASTNode createNameIdentifier(String name, LanguageLevel languageLevel) {
63     final PsiFile dummyFile = createDummyFile(languageLevel, name);
64     final PyExpressionStatement expressionStatement = (PyExpressionStatement)dummyFile.getFirstChild();
65     final PyReferenceExpression refExpression = (PyReferenceExpression)expressionStatement.getFirstChild();
66
67     return refExpression.getNode().getFirstChildNode();
68   }
69
70   @Override
71   public PsiFile createDummyFile(LanguageLevel langLevel, String contents) {
72     return createDummyFile(langLevel, contents, false);
73   }
74
75   public PsiFile createDummyFile(LanguageLevel langLevel, String contents, boolean physical) {
76     final PsiFileFactory factory = PsiFileFactory.getInstance(myProject);
77     final String name = "dummy." + PythonFileType.INSTANCE.getDefaultExtension();
78     final LightVirtualFile virtualFile = new LightVirtualFile(name, PythonFileType.INSTANCE, contents);
79     virtualFile.putUserData(LanguageLevel.KEY, langLevel);
80     final PsiFile psiFile = ((PsiFileFactoryImpl)factory).trySetupPsiForFile(virtualFile, PythonLanguage.getInstance(), physical, true);
81     assert psiFile != null;
82     return psiFile;
83   }
84
85   public PyStringLiteralExpression createStringLiteralAlreadyEscaped(String str) {
86     final PsiFile dummyFile = createDummyFile(LanguageLevel.getDefault(), "a=(" + str + ")");
87     final PyAssignmentStatement expressionStatement = (PyAssignmentStatement)dummyFile.getFirstChild();
88     return (PyStringLiteralExpression)((PyParenthesizedExpression)expressionStatement.getAssignedValue()).getContainedExpression();
89   }
90
91
92   @Override
93   public PyStringLiteralExpression createStringLiteralFromString(@NotNull String unescaped) {
94     return createStringLiteralFromString(null, unescaped, true);
95   }
96
97   public PyStringLiteralExpression createStringLiteral(@NotNull PyStringLiteralExpression oldElement, @NotNull String unescaped) {
98     Pair<String, String> quotes = PythonStringUtil.getQuotes(oldElement.getText());
99     if (quotes != null) {
100       return createStringLiteralAlreadyEscaped(quotes.first + unescaped + quotes.second);
101     }
102     else {
103       return createStringLiteralFromString(unescaped);
104     }
105   }
106
107
108   @Override
109   public PyStringLiteralExpression createStringLiteralFromString(@Nullable PsiFile destination,
110                                                                  @NotNull String unescaped,
111                                                                  final boolean preferUTF8) {
112     boolean useDouble = !unescaped.contains("\"");
113     boolean useMulti = unescaped.matches(".*(\r|\n).*");
114     String quotes;
115     if (useMulti) {
116       quotes = useDouble ? "\"\"\"" : "'''";
117     }
118     else {
119       quotes = useDouble ? "\"" : "'";
120     }
121     StringBuilder buf = new StringBuilder(unescaped.length() * 2);
122     buf.append(quotes);
123     VirtualFile vfile = destination == null ? null : destination.getVirtualFile();
124     Charset charset;
125     if (vfile == null) {
126       charset = (preferUTF8 ? CharsetToolkit.UTF8_CHARSET : Charset.forName("US-ASCII"));
127     }
128     else {
129       charset = vfile.getCharset();
130     }
131     CharsetEncoder encoder = charset.newEncoder();
132     Formatter formatter = new Formatter(buf);
133     boolean unicode = false;
134     for (int i = 0; i < unescaped.length(); i++) {
135       int c = unescaped.codePointAt(i);
136       if (c == '"' && useDouble) {
137         buf.append("\\\"");
138       }
139       else if (c == '\'' && !useDouble) {
140         buf.append("\\'");
141       }
142       else if ((c == '\r' || c == '\n') && !useMulti) {
143         if (c == '\r') {
144           buf.append("\\r");
145         }
146         else if (c == '\n') buf.append("\\n");
147       }
148       else if (!encoder.canEncode(new String(Character.toChars(c)))) {
149         if (c <= 0xff) {
150           formatter.format("\\x%02x", c);
151         }
152         else if (c < 0xffff) {
153           unicode = true;
154           formatter.format("\\u%04x", c);
155         }
156         else {
157           unicode = true;
158           formatter.format("\\U%08x", c);
159         }
160       }
161       else {
162         buf.appendCodePoint(c);
163       }
164     }
165     buf.append(quotes);
166     if (unicode) buf.insert(0, "u");
167
168     return createStringLiteralAlreadyEscaped(buf.toString());
169   }
170
171   public PyListLiteralExpression createListLiteral() {
172     final PsiFile dummyFile = createDummyFile(LanguageLevel.getDefault(), "[]");
173     final PyExpressionStatement expressionStatement = (PyExpressionStatement)dummyFile.getFirstChild();
174     return (PyListLiteralExpression)expressionStatement.getFirstChild();
175   }
176
177   public ASTNode createComma() {
178     final PsiFile dummyFile = createDummyFile(LanguageLevel.getDefault(), "[0,]");
179     final PyExpressionStatement expressionStatement = (PyExpressionStatement)dummyFile.getFirstChild();
180     ASTNode zero = expressionStatement.getFirstChild().getNode().getFirstChildNode().getTreeNext();
181     return zero.getTreeNext().copyElement();
182   }
183
184   public ASTNode createDot() {
185     final PsiFile dummyFile = createDummyFile(LanguageLevel.getDefault(), "a.b");
186     final PyExpressionStatement expressionStatement = (PyExpressionStatement)dummyFile.getFirstChild();
187     ASTNode dot = expressionStatement.getFirstChild().getNode().getFirstChildNode().getTreeNext();
188     return dot.copyElement();
189   }
190
191   @Override
192   @NotNull
193   public PsiElement insertItemIntoListRemoveRedundantCommas(
194     @NotNull final PyElement list,
195     @Nullable final PyExpression afterThis,
196     @NotNull final PyExpression toInsert) {
197     // TODO: #insertItemIntoList is probably buggy. In such case, fix it and get rid of this method
198     final PsiElement result = insertItemIntoList(list, afterThis, toInsert);
199     final LeafPsiElement[] leafs = PsiTreeUtil.getChildrenOfType(list, LeafPsiElement.class);
200     if (leafs != null) {
201       final Deque<LeafPsiElement> commas = Queues.newArrayDeque(Collections2.filter(Arrays.asList(leafs), COMMAS_ONLY));
202       if (!commas.isEmpty()) {
203         final LeafPsiElement lastComma = commas.getLast();
204         if (PsiTreeUtil.getNextSiblingOfType(lastComma, PyExpression.class) == null) { //Comma has no expression after it
205           lastComma.delete();
206         }
207       }
208     }
209
210     return result;
211   }
212
213   // TODO: Adds comma to empty list: adding "foo" to () will create (foo,). That is why "insertItemIntoListRemoveRedundantCommas" was created.
214   // We probably need to fix this method and delete insertItemIntoListRemoveRedundantCommas
215   public PsiElement insertItemIntoList(PyElement list, @Nullable PyExpression afterThis, PyExpression toInsert)
216     throws IncorrectOperationException {
217     ASTNode add = toInsert.getNode().copyElement();
218     if (afterThis == null) {
219       ASTNode exprNode = list.getNode();
220       ASTNode[] closingTokens = exprNode.getChildren(TokenSet.create(PyTokenTypes.LBRACKET, PyTokenTypes.LPAR));
221       if (closingTokens.length == 0) {
222         // we tried our best. let's just insert it at the end
223         exprNode.addChild(add);
224       }
225       else {
226         ASTNode next = PyPsiUtils.getNextNonWhitespaceSibling(closingTokens[closingTokens.length - 1]);
227         if (next != null) {
228           ASTNode comma = createComma();
229           exprNode.addChild(comma, next);
230           exprNode.addChild(add, comma);
231         }
232         else {
233           exprNode.addChild(add);
234         }
235       }
236     }
237     else {
238       ASTNode lastArgNode = afterThis.getNode();
239       ASTNode comma = createComma();
240       ASTNode parent = lastArgNode.getTreeParent();
241       ASTNode afterLast = lastArgNode.getTreeNext();
242       if (afterLast == null) {
243         parent.addChild(add);
244       }
245       else {
246         parent.addChild(add, afterLast);
247       }
248       parent.addChild(comma, add);
249     }
250     return add.getPsi();
251   }
252
253   public PyBinaryExpression createBinaryExpression(String s, PyExpression expr, PyExpression listLiteral) {
254     final PsiFile dummyFile = createDummyFile(LanguageLevel.getDefault(), "a " + s + " b");
255     final PyExpressionStatement expressionStatement = (PyExpressionStatement)dummyFile.getFirstChild();
256     PyBinaryExpression binExpr = (PyBinaryExpression)expressionStatement.getExpression();
257     ASTNode binnode = binExpr.getNode();
258     binnode.replaceChild(binExpr.getLeftExpression().getNode(), expr.getNode().copyElement());
259     binnode.replaceChild(binExpr.getRightExpression().getNode(), listLiteral.getNode().copyElement());
260     return binExpr;
261   }
262
263   public PyExpression createExpressionFromText(final String text) {
264     return createExpressionFromText(LanguageLevel.getDefault(), text);
265   }
266
267   @NotNull
268   public PyExpression createExpressionFromText(final LanguageLevel languageLevel, final String text) {
269     final PsiFile dummyFile = createDummyFile(languageLevel, text);
270     final PsiElement element = dummyFile.getFirstChild();
271     if (element instanceof PyExpressionStatement) {
272       return ((PyExpressionStatement)element).getExpression();
273     }
274     throw new IncorrectOperationException("could not parse text as expression: " + text);
275   }
276
277   @NotNull
278   public PyCallExpression createCallExpression(final LanguageLevel langLevel, String functionName) {
279     final PsiFile dummyFile = createDummyFile(langLevel, functionName + "()");
280     final PsiElement child = dummyFile.getFirstChild();
281     if (child != null) {
282       final PsiElement element = child.getFirstChild();
283       if (element instanceof PyCallExpression) {
284         return (PyCallExpression)element;
285       }
286     }
287     throw new IllegalArgumentException("Invalid call expression text " + functionName);
288   }
289
290   @Override
291   public PyImportElement createImportElement(final LanguageLevel languageLevel, String name) {
292     return createFromText(languageLevel, PyImportElement.class, "from foo import " + name, new int[]{0, 6});
293   }
294
295   @Override
296   public PyFunction createProperty(LanguageLevel languageLevel,
297                                    String propertyName,
298                                    String fieldName,
299                                    AccessDirection accessDirection) {
300     String propertyText;
301     if (accessDirection == AccessDirection.DELETE) {
302       propertyText = "@" + propertyName + ".deleter\ndef " + propertyName + "(self):\n  del self." + fieldName;
303     }
304     else if (accessDirection == AccessDirection.WRITE) {
305       propertyText = "@" + propertyName + ".setter\ndef " + propertyName + "(self, value):\n  self." + fieldName + " = value";
306     }
307     else {
308       propertyText = "@property\ndef " + propertyName + "(self):\n  return self." + fieldName;
309     }
310     return createFromText(languageLevel, PyFunction.class, propertyText);
311   }
312
313   static final int[] FROM_ROOT = new int[]{0};
314
315   @NotNull
316   public <T> T createFromText(LanguageLevel langLevel, Class<T> aClass, final String text) {
317     return createFromText(langLevel, aClass, text, FROM_ROOT);
318   }
319
320   @NotNull
321   @Override
322   public <T> T createPhysicalFromText(LanguageLevel langLevel, Class<T> aClass, String text) {
323     return createFromText(langLevel, aClass, text, FROM_ROOT, true);
324   }
325
326   static int[] PATH_PARAMETER = {0, 3, 1};
327
328   public PyNamedParameter createParameter(@NotNull String name) {
329     return createParameter(name, null, null, LanguageLevel.getDefault());
330   }
331
332   @NotNull
333   @Override
334   public PyParameterList createParameterList(@NotNull LanguageLevel languageLevel, @NotNull String text) {
335     return createFromText(languageLevel, PyParameterList.class, "def f" + text + ": pass", new int[]{0, 3});
336   }
337
338   @NotNull
339   @Override
340   public PyArgumentList createArgumentList(@NotNull LanguageLevel languageLevel, @NotNull String text) {
341     return createFromText(languageLevel, PyArgumentList.class, "f" + text, new int[]{0, 0, 1});
342   }
343
344
345   public PyNamedParameter createParameter(@NotNull String name, @Nullable String defaultValue, @Nullable String annotation,
346                                           @NotNull LanguageLevel languageLevel) {
347     String parameterText = name;
348     if (annotation != null) {
349       parameterText += ": " + annotation;
350     }
351     if (defaultValue != null) {
352       parameterText += " = " + defaultValue;
353     }
354
355     return createFromText(languageLevel, PyNamedParameter.class, "def f(" + parameterText + "): pass", PATH_PARAMETER);
356   }
357
358   @Override
359   public PyKeywordArgument createKeywordArgument(LanguageLevel languageLevel, String keyword, String value) {
360     PyCallExpression callExpression = (PyCallExpression)createExpressionFromText(languageLevel, "foo(" + keyword + "=" + value + ")");
361     return (PyKeywordArgument)callExpression.getArguments()[0];
362   }
363
364   @NotNull
365   public <T> T createFromText(LanguageLevel langLevel, Class<T> aClass, final String text, final int[] path) {
366     return createFromText(langLevel, aClass, text, path, false);
367   }
368
369   @NotNull
370   public <T> T createFromText(LanguageLevel langLevel, Class<T> aClass, final String text, final int[] path, boolean physical) {
371     PsiElement ret = createDummyFile(langLevel, text, physical);
372     for (int skip : path) {
373       if (ret != null) {
374         ret = ret.getFirstChild();
375         for (int i = 0; i < skip; i += 1) {
376           if (ret != null) {
377             ret = ret.getNextSibling();
378           }
379           else {
380             ret = null;
381             break;
382           }
383         }
384       }
385       else {
386         break;
387       }
388     }
389     if (ret == null) {
390       throw new IllegalArgumentException("Can't find element matching path " + Arrays.toString(path) + " in text '" + text + "'");
391     }
392     try {
393       //noinspection unchecked
394       return (T)ret;
395     }
396     catch (ClassCastException e) {
397       throw new IllegalArgumentException("Can't create an expression of type " + aClass + " from text '" + text + "'");
398     }
399   }
400
401   @Override
402   public PyPassStatement createPassStatement() {
403     final PyStatementList statementList = createPassStatementList();
404     return (PyPassStatement)statementList.getStatements()[0];
405   }
406
407   @NotNull
408   @Override
409   public PyDecoratorList createDecoratorList(@NotNull final String... decoratorTexts) {
410     assert decoratorTexts.length > 0;
411     StringBuilder functionText = new StringBuilder();
412     for (String decoText : decoratorTexts) {
413       functionText.append(decoText).append("\n");
414     }
415     functionText.append("def foo():\n\tpass");
416     final PyFunction function = createFromText(LanguageLevel.getDefault(), PyFunction.class,
417                                                functionText.toString());
418     final PyDecoratorList decoratorList = function.getDecoratorList();
419     assert decoratorList != null;
420     return decoratorList;
421   }
422
423   private PyStatementList createPassStatementList() {
424     final PyFunction function = createFromText(LanguageLevel.getDefault(), PyFunction.class, "def foo():\n\tpass");
425     return function.getStatementList();
426   }
427
428   public PyExpressionStatement createDocstring(String content) {
429     return createFromText(LanguageLevel.getDefault(),
430                           PyExpressionStatement.class, content + "\n");
431   }
432
433   @NotNull
434   @Override
435   public PsiElement createNewLine() {
436     return createFromText(LanguageLevel.getDefault(), PsiWhiteSpace.class, " \n\n ");
437   }
438
439   @NotNull
440   @Override
441   public PyFromImportStatement createFromImportStatement(@NotNull LanguageLevel languageLevel, @NotNull String qualifier,
442                                                          @NotNull String name, @Nullable String alias) {
443     final String asClause = StringUtil.isNotEmpty(alias) ? " as " + alias : "";
444     final String statement = "from " + qualifier + " import " + name + asClause;
445     return createFromText(languageLevel, PyFromImportStatement.class, statement);
446   }
447
448   @NotNull
449   @Override
450   public PyImportStatement createImportStatement(@NotNull LanguageLevel languageLevel, @NotNull String name, @Nullable String alias) {
451     final String asClause = StringUtil.isNotEmpty(alias) ? " as " + alias : "";
452     final String statement = "import " + name + asClause;
453     return createFromText(languageLevel, PyImportStatement.class, statement);
454   }
455
456   private static class CommasOnly extends NotNullPredicate<LeafPsiElement> {
457     @Override
458     protected boolean applyNotNull(@NotNull final LeafPsiElement input) {
459       return input.getNode().getElementType().equals(PyTokenTypes.COMMA);
460     }
461   }
462 }