Resolve conflicts
[idea/community.git] / python / educational-python / resources / fileTemplates / internal / test_helper.py.ft
index 687bdd580a95062e980762d0f5ceae79357df4df..250f253d8decbd6f57be9180d7a296b08df184f5 100644 (file)
@@ -9,7 +9,7 @@ def get_file_text(path):
     return text
 
 
-def get_file_output(encoding="utf-8", path=sys.argv[-1]):
+def get_file_output(encoding="utf-8", path=sys.argv[-1], arg_string=""):
     """
     Returns answer file output
     :param encoding: to decode output in python3
@@ -18,7 +18,13 @@ def get_file_output(encoding="utf-8", path=sys.argv[-1]):
     """
     import subprocess
 
-    proc = subprocess.Popen([sys.executable, path], stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+    proc = subprocess.Popen([sys.executable, path], stdin=subprocess.PIPE, stdout=subprocess.PIPE,
+                            stderr=subprocess.STDOUT)
+    if arg_string:
+        for arg in arg_string.split("\n"):
+            proc.stdin.write(bytearray(str(arg) + "\n", encoding))
+            proc.stdin.flush()
+
     return list(map(lambda x: str(x.decode(encoding)), proc.communicate()[0].splitlines()))
 
 
@@ -31,7 +37,8 @@ def test_file_importable():
         parent = os.path.abspath(os.path.join(path, os.pardir))
         python_files = [f for f in os.listdir(parent) if os.path.isfile(os.path.join(parent, f)) and f.endswith(".py")]
         for python_file in python_files:
-            if python_file == "tests.py": continue
+            if python_file == "tests.py":
+                continue
             check_importable_path(os.path.join(parent, python_file))
         return
     check_importable_path(path)
@@ -187,6 +194,30 @@ def get_answer_placeholders():
     return windows
 
 
+def check_samples(samples=()):
+    """
+      Check script output for all samples. Sample is a two element list, where the first is input and
+      the second is output.
+    """
+    for sample in samples:
+        if len(sample) == 2:
+            output = get_file_output(arg_string=str(sample[0]))
+            if "\n".join(output) != sample[1]:
+                failed(
+                    "Test from samples failed: \n \n"
+                    "Input:\n {}"
+                    "\n \n"
+                    "Expected:\n {}"
+                    "\n \n"
+                    "Your result:\n {}".format(sample[0], 
+                                                        sample[1], 
+                                                        "\n".join(output)))
+                return
+        set_congratulation_message("All test from samples passed. Now we are checking your solution on Stepic server.")
+
+    passed()
+
+
 def run_common_tests(error_text="Please, reload file and try again"):
     test_is_initial_text()
     test_is_not_empty()