cdfffec888d0efefd487dbeaa708702c5d396f29
[idea/community.git] / python / testSrc / com / jetbrains / python / PyTypeFromUsedAttributesTest.java
1 package com.jetbrains.python;
2
3 import com.intellij.psi.PsiElement;
4 import com.intellij.psi.util.PsiElementFilter;
5 import com.intellij.psi.util.PsiTreeUtil;
6 import com.intellij.psi.util.QualifiedName;
7 import com.intellij.util.ArrayUtil;
8 import com.intellij.util.Function;
9 import com.intellij.util.containers.ContainerUtil;
10 import com.jetbrains.python.documentation.PythonDocumentationProvider;
11 import com.jetbrains.python.fixtures.PyTestCase;
12 import com.jetbrains.python.psi.PyReferenceExpression;
13 import com.jetbrains.python.psi.types.PyType;
14 import com.jetbrains.python.psi.types.PyTypeFromUsedAttributesHelper;
15 import com.jetbrains.python.psi.types.TypeEvalContext;
16 import org.jetbrains.annotations.NotNull;
17 import org.jetbrains.annotations.Nullable;
18
19 import java.util.Set;
20
21 /**
22  * @author Mikhail Golubev
23  */
24 public class PyTypeFromUsedAttributesTest extends PyTestCase {
25
26   public void testCollectAttributeOfParameter() {
27     doTestUsedAttributes("def func(x):\n" +
28                          "    if x.baz():\n" +
29                          "        x.foo = x.bar",
30                          "foo", "bar", "baz");
31   }
32
33   public void testCollectAttributesInOuterScopes() {
34     doTestUsedAttributes("x = undefined()" +
35                          "x.quux\n" +
36                          "def func2():\n" +
37                          "    x.bar\n" +
38                          "    def nested():\n" +
39                          "        x.baz",
40                          "quux", "baz", "bar");
41   }
42
43   public void testIgnoreAttributesOfParameterInOtherFunctions() {
44     doTestUsedAttributes("def func1(x):\n" +
45                          "    x.foo\n" +
46                          "    \n" +
47                          "def func2(x):\n" +
48                          "    x.bar",
49                          "bar");
50   }
51
52
53   public void testCollectSpecialMethodNames() {
54     doTestUsedAttributes("x = undefined()\n" +
55                          "x[0] = x[...]\n" +
56                          "x + 42",
57                          "__getitem__", "__setitem__", "__add__");
58   }
59
60   public void testOnlyBaseClassesRetained() {
61     doTestType("class Base(object):\n" +
62                "    attr_a = None\n" +
63                "    attr_b = None\n" +
64                "\n" +
65                "class C2(Base):\n" +
66                "    pass\n" +
67                "\n" +
68                "class C3(Base):\n" +
69                "    attr_a = None\n" +
70                "\n" +
71                "class C4(Base):\n" +
72                "    attr_a = None\n" +
73                "    attr_b = None\n" +
74                "\n" +
75                "x = undefined()\n" +
76                "x.attr_a\n" +
77                "x.attr_b",
78                "Base | unknown");
79   }
80
81   public void testDiamondHierarchyBottom() {
82     doTestType("class D(object):\n" +
83                "    pass\n" +
84                "class B(D):\n" +
85                "    pass\n" +
86                "class C(D):\n" +
87                "    pass\n" +
88                "class A(B, C):\n" +
89                "    foo = None\n" +
90                "    bar = None\n" +
91                "\n" +
92                "def func(x):\n" +
93                "    x.foo\n" +
94                "    x.bar",
95                "A | unknown");
96   }
97
98   public void testDiamondHierarchySiblings() {
99     doTestType("class D(object):\n" +
100                "    bar = None\n" +
101                "class B(D):\n" +
102                "    foo = None\n" +
103                "class C(D):\n" +
104                "    foo = None\n" +
105                "    bar = None\n" +
106                "class A(B, C):\n" +
107                "    foo = None\n" +
108                "    bar = None\n" +
109                "\n" +
110                "def func(x):\n" +
111                "    x.foo()\n" +
112                "    x.bar()\n",
113                "B | C | unknown");
114   }
115
116   public void testDiamondHierarchyTop() {
117     doTestType("class D(object):\n" +
118                "    foo = None\n" +
119                "    bar = None\n" +
120                "\n" +
121                "class B(D):\n" +
122                "    foo = None\n" +
123                "    bar = None\n" +
124                "\n" +
125                "class C(D):\n" +
126                "    foo = None\n" +
127                "\n" +
128                "class A(B, C):\n" +
129                "    foo = None\n" +
130                "    bar = None\n" +
131                "\n" +
132                "def func(x):\n" +
133                "    x.foo()\n" +
134                "    x.bar()",
135                "D | unknown");
136   }
137
138   public void testDiamondHierarchyLeft() {
139     doTestType("class D(object):\n" +
140                "    foo = None\n" +
141                "class B(D):\n" +
142                "    bar = None\n" +
143                "class C(D):\n" +
144                "    pass\n" +
145                "class A(B, C):\n" +
146                "    foo = None\n" +
147                "    bar = None\n" +
148                "def func(x):\n" +
149                "    x.foo()\n" +
150                "    x.bar()",
151                "B | unknown");
152   }
153
154   public void testBuiltinTypes() {
155     doTestType("def func(x):\n" +
156                "    x.upper()\n" +
157                "    x.decode()",
158                "bytearray | str | unicode | unknown");
159
160     doTestType("def func(x):\n" +
161                "    x.pop() and x.update()",
162                "dict | set | unknown");
163   }
164
165   public void testFunctionType() {
166     doTestType("class A:\n" +
167                "    def method_a(self):\n" +
168                "        pass\n" +
169                "class B:\n" +
170                "    def method_b(self):\n" +
171                "        pass\n" +
172                "def func(a, b):\n" +
173                "    a.method_a\n" +
174                "    b.method_b\n" +
175                "    return a\n" +
176                "x = func\n" +
177                "x",
178                "(a: A | unknown, b: B | unknown) -> A | unknown");
179   }
180
181   public void testFastInferenceForObjectAttributes() {
182     doTestType("x = undefined()\n" +
183                "x.__init__(1)\n" +
184                "x",
185                "object | unknown");
186   }
187
188   public void testResultsOrdering() {
189     myFixture.copyDirectoryToProject(getTestName(true), "");
190     doTestType("import module\n" +
191                "class MySortable(object):\n" +
192                "    def sort(self):\n" +
193                "        pass\n" +
194                "def f(x):\n" +
195                "    x.sort()",
196                "list | MySortable | OtherClassA | OtherClassB | unknown");
197   }
198
199   public void testCyclicInheritance() {
200     myFixture.copyDirectoryToProject(getTestName(true), "");
201     myFixture.configureByFile("main.py");
202     final PyReferenceExpression referenceExpression = findLastReferenceByText("x");
203     assertNotNull(referenceExpression);
204     final TypeEvalContext context =
205       TypeEvalContext.userInitiated(myFixture.getProject(), referenceExpression.getContainingFile()).withTracing();
206     final PyType actual = context.getType(referenceExpression);
207     final String actualType = PythonDocumentationProvider.getTypeName(actual, context);
208     assertEquals("B | unknown", actualType);
209   }
210
211   public void testLongInheritanceChain() {
212     // This obvious test is needed because of custom method for resolution of class ancestors
213     doTestType("class C1(object):\n" +
214                "    attr = 'top'\n" +
215                "class C2(C1):\n" +
216                "    pass\n" +
217                "class C3(C2):\n" +
218                "    pass\n" +
219                "class C4(C3):\n" +
220                "    pass\n" +
221                "class C5(C4):\n" +
222                "    attr = 'bottom'\n" +
223                "def f(x):\n" +
224                "    x.attr\n",
225                "C1 | unknown");
226   }
227
228   public void testImportQualifiers() {
229     myFixture.copyDirectoryToProject(getTestName(true), "");
230     myFixture.configureByFile("pkg1/pkg2/main.py");
231     final Set<QualifiedName> qualifiers = PyTypeFromUsedAttributesHelper.collectImportQualifiers(myFixture.getFile());
232     final Set<String> qualifiedNames = ContainerUtil.map2Set(qualifiers, new Function<QualifiedName, String>() {
233       @Override
234       public String fun(QualifiedName name) {
235         return name.toString();
236       }
237     });
238     assertSameElements(qualifiedNames,
239                        "root",
240                        "pkg1.pkg2.module2a",
241                        "pkg1.module1a",
242                        "pkg1.pkg2.module2b.C1",
243                        "pkg1.module1b.B1",
244                        "pkg1.pkg2.module2b",
245                        "pkg1.pkg2");
246   }
247
248   private void doTestType(@NotNull String text, @NotNull String expectedType) {
249     myFixture.configureByText(PythonFileType.INSTANCE, text);
250     final PyReferenceExpression referenceExpression = findLastReferenceByText("x");
251     assertNotNull(referenceExpression);
252     final TypeEvalContext context = TypeEvalContext.userInitiated(myFixture.getProject(), referenceExpression.getContainingFile());
253     final PyType actual = context.getType(referenceExpression);
254     final String actualType = PythonDocumentationProvider.getTypeName(actual, context);
255     assertEquals(expectedType, actualType);
256   }
257
258   private void doTestUsedAttributes(@NotNull String text, @NotNull String... attributesExpected) {
259     myFixture.configureByText(PythonFileType.INSTANCE, text);
260     final PyReferenceExpression referenceExpression = findLastReferenceByText("x");
261     assertNotNull(referenceExpression);
262     assertSameElements(PyTypeFromUsedAttributesHelper.collectUsedAttributes(referenceExpression), attributesExpected);
263   }
264
265   @Nullable
266   private PyReferenceExpression findLastReferenceByText(@NotNull final String text) {
267     final PsiElement[] elements = PsiTreeUtil.collectElements(myFixture.getFile(), new PsiElementFilter() {
268       @Override
269       public boolean isAccepted(PsiElement element) {
270         return element instanceof PyReferenceExpression && element.getText().equals(text);
271       }
272     });
273     return (PyReferenceExpression)ArrayUtil.getLastElement(elements);
274   }
275
276   @Override
277   protected String getTestDataPath() {
278     return super.getTestDataPath() + "/typesFromAttributes/";
279   }
280 }