PY-20932 Pandas DataFrame viewer for multiindex frames raises TypeError
[idea/community.git] / python / testSrc / com / jetbrains / env / python / PythonDataViewerTest.java
index d988bb361686a6c0aa1e8116aa71e9d9c1568f03..c818157803c3f27a338ba544655084d4acee9eb3 100644 (file)
@@ -16,6 +16,8 @@
 package com.jetbrains.env.python;
 
 import com.google.common.collect.ImmutableSet;
+import com.intellij.util.Consumer;
+import com.intellij.xdebugger.XDebugSession;
 import com.intellij.xdebugger.XDebuggerTestUtil;
 import com.jetbrains.env.PyEnvTestCase;
 import com.jetbrains.env.Staging;
@@ -23,10 +25,12 @@ import com.jetbrains.env.python.debug.PyDebuggerTask;
 import com.jetbrains.python.debugger.ArrayChunk;
 import com.jetbrains.python.debugger.PyDebugValue;
 import com.jetbrains.python.debugger.PyDebuggerException;
-import com.jetbrains.python.debugger.array.AsyncArrayTableModel;
 import org.jetbrains.annotations.NotNull;
+import org.jetbrains.annotations.Nullable;
 import org.junit.Test;
 
+import java.lang.reflect.InvocationTargetException;
+import java.util.List;
 import java.util.Set;
 
 import static com.intellij.testFramework.UsefulTestCase.assertSameElements;
@@ -40,53 +44,88 @@ public class PythonDataViewerTest extends PyEnvTestCase {
   @Test
   @Staging
   public void testDataFrameChunkRetrieval() throws Exception {
-    runPythonTest(new PyDebuggerTask("/debug", "test_dataframe.py") {
+    runPythonTest(new PyDataFrameDebuggerTask(getRelativeTestDataPath(), "test_dataframe.py", ImmutableSet.of(7, 15, 22)) {
       @Override
-      public void before() throws Exception {
-        toggleBreakpoint(getScriptName(), 7);
-        toggleBreakpoint(getScriptName(), 15);
-        toggleBreakpoint(getScriptName(), 22);
+      public void testing() throws Exception {
+        doTest("df1", 3, 5, null);
+
+        doTest("df2", 3, 6, arrayChunk -> {
+          List<ArrayChunk.ColHeader> colHeaders = arrayChunk.getColHeaders();
+          assertSameElements(colHeaders.stream().map(ArrayChunk.ColHeader::getLabel).toArray(),
+                             "LABELS", "One_X", "One_Y", "Two_X", "Two_Y", "row");
+        });
+
+        doTest("df3", 7, 3, arrayChunk -> {
+          ArrayChunk.ColHeader header = arrayChunk.getColHeaders().get(2);
+          assertEquals("Sales", header.getLabel());
+          assertEquals(16, (int)Integer.valueOf(header.getMax()));
+          assertEquals(1, (int)Integer.valueOf(header.getMin()));
+        });
       }
+    });
+  }
 
+  @Test
+  @Staging
+  public void testMultiIndexDataFrame() throws Exception {
+    runPythonTest(new PyDataFrameDebuggerTask(getRelativeTestDataPath(), "test_dataframe_multiindex.py", ImmutableSet.of(5, 10)) {
       @Override
       public void testing() throws Exception {
-        waitForPause();
-        ArrayChunk df1 = getDefaultChunk("df1");
-        assertEquals(5, df1.getColumns());
-        assertEquals(3, df1.getRows());
-        resume();
-
-        waitForPause();
-        ArrayChunk df2 = getDefaultChunk("df2");
-        assertEquals(6, df2.getColumns());
-        assertEquals(3, df2.getRows());
-        assertSameElements(df2.getColHeaders().stream().map((header -> header.getLabel())).toArray(),
-                           new String[]{"LABELS", "One_X", "One_Y", "Two_X", "Two_Y", "row"});
-        resume();
-
-        waitForPause();
-        ArrayChunk df3 = getDefaultChunk("df3");
-        assertEquals(3, df3.getColumns());
-        assertEquals(7, df3.getRows());
-        ArrayChunk.ColHeader header = df3.getColHeaders().get(2);
-        assertEquals("Sales", header.getLabel());
-        assertEquals(16, (int)Integer.valueOf(header.getMax()));
-        assertEquals(1, (int)Integer.valueOf(header.getMin()));
-        resume();
+        doTest("frame1", 4, 2, arrayChunk -> assertSameElements(arrayChunk.getRowLabels(),
+                                                                "s/2", "s/3", "d/2", "d/3"));
+        doTest("frame2", 4, 4, arrayChunk -> {
+          List<ArrayChunk.ColHeader> headers = arrayChunk.getColHeaders();
+          assertSameElements(headers.stream().map(ArrayChunk.ColHeader::getLabel).toArray(), "1/1", "1/B", "2/1", "2/B");
+        });
       }
+    });
+  }
 
-      private ArrayChunk getDefaultChunk(String varName) throws PyDebuggerException {
-        PyDebugValue dbgVal = (PyDebugValue)XDebuggerTestUtil.evaluate(mySession, varName).first;
-        return dbgVal.getFrameAccessor()
-          .getArrayItems(dbgVal, 0, 0, -1, -1, ".%5f");
-      }
+  private static class PyDataFrameDebuggerTask extends PyDebuggerTask {
 
+    private Set<Integer> myLines;
 
-      @NotNull
-      @Override
-      public Set<String> getTags() {
-        return ImmutableSet.of("pandas");
+    public PyDataFrameDebuggerTask(@Nullable String relativeTestDataPath, String scriptName, Set<Integer> lines) {
+      super(relativeTestDataPath, scriptName);
+      myLines = lines;
+    }
+
+    protected void testShape(ArrayChunk arrayChunk, int expectedRows, int expectedColumns) {
+      assertEquals(expectedRows, arrayChunk.getRows());
+      assertEquals(expectedColumns, arrayChunk.getColumns());
+    }
+
+    protected void doTest(String name, int expectedRows, int expectedColumns, @Nullable Consumer<ArrayChunk> test)
+      throws InvocationTargetException, InterruptedException, PyDebuggerException {
+      waitForPause();
+      ArrayChunk arrayChunk = getDefaultChunk(name, mySession);
+      testShape(arrayChunk, expectedRows, expectedColumns);
+      if (test != null) {
+        test.consume(arrayChunk);
       }
-    });
+      resume();
+    }
+
+    @Override
+    public void before() throws Exception {
+      for (Integer line : myLines) {
+        toggleBreakpoint(getScriptName(), line);
+      }
+    }
+
+    @NotNull
+    @Override
+    public Set<String> getTags() {
+      return ImmutableSet.of("pandas");
+    }
+  }
+
+  private static ArrayChunk getDefaultChunk(String varName, XDebugSession session) throws PyDebuggerException {
+    PyDebugValue dbgVal = (PyDebugValue)XDebuggerTestUtil.evaluate(session, varName).first;
+    return dbgVal.getFrameAccessor().getArrayItems(dbgVal, 0, 0, -1, -1, ".%5f");
+  }
+
+  private static String getRelativeTestDataPath() {
+    return "/debug";
   }
 }