type migration tests fixed
[idea/community.git] / python / helpers / pycharm / tcunittest.py
1 import traceback, sys
2 from unittest import TestResult
3 import datetime
4
5 from tcmessages import TeamcityServiceMessages
6
PYTHON_VERSION_MAJOR = sys.version_info[0]


def strclass(cls):
  """Return the dotted "module.ClassName" name of *cls*.

  When ``cls.__name__`` is empty, only the module name is returned.
  """
  if not cls.__name__:
    return cls.__module__
  return "%s.%s" % (cls.__module__, cls.__name__)


def smart_str(s):
  """Return *s* as a native ``str`` suitable for writing to the message stream.

  - Native strings are returned unchanged (Python 2 byte strings, Python 3 str).
  - Python 2 ``unicode`` text is encoded to UTF-8.
  - Any other object goes through ``str()``. If that raises
    UnicodeEncodeError, an Exception is rendered as its space-joined ``args``.
    On Python 2 any other object is rendered via ``unicode()`` encoded to
    UTF-8. On Python 3 the error is re-raised, since there is no separate text
    type to fall back to.
  """
  encoding = 'utf-8'
  errors = 'strict'
  py2 = PYTHON_VERSION_MAJOR < 3
  if py2:
    is_string = isinstance(s, basestring)
  else:
    is_string = isinstance(s, str)

  if is_string:
    # Only Python 2 has a separate text type. The name ``unicode`` does not
    # exist on Python 3, so it must never be evaluated there.
    if py2 and isinstance(s, unicode):
      return s.encode(encoding, errors)
    return s

  try:
    return str(s)
  except UnicodeEncodeError:
    if isinstance(s, Exception):
      # An Exception subclass containing non-ASCII data that doesn't
      # know how to print itself properly. We shouldn't raise a
      # further exception. Use ``args`` explicitly: iterating the exception
      # object itself only works on Python 2.
      return ' '.join([smart_str(arg) for arg in s.args])
    if py2:
      return unicode(s).encode(encoding, errors)
    raise
37
38
class TeamcityTestResult(TestResult):
  """unittest result object that reports progress as TeamCity service messages.

  Every event is written through a ``TeamcityServiceMessages`` instance
  wrapping *stream*. Suites are opened lazily: the first reported test of a
  new suite closes the previous suite (if any) and starts the new one. A test
  is reported as passed from ``stopTest`` only if no error, failure or skip
  was already reported for it.
  """

  def __init__(self, stream=sys.stdout, *args, **kwargs):
    TestResult.__init__(self)
    # Extra keyword arguments are stored verbatim as attributes on the result.
    for arg, value in kwargs.items():
      setattr(self, arg, value)
    self.output = stream
    self.messages = TeamcityServiceMessages(self.output, prepend_linebreak=True)
    self.messages.testMatrixEntered()
    # Set once the current test's outcome has already been reported (error,
    # failure or skip) so that stopTest does not also report it as passed.
    self.current_failed = False
    # Name of the suite currently opened with testSuiteStarted, or None.
    self.current_suite = None
    # Name of the suite grouping subTest results of the running test, or None.
    self.subtest_suite = None

  def find_first(self, val):
    """Return the quoted literal at the start of *val*, quotes included.

    The first character is used as the quote character, and closing quotes
    preceded by a backslash are skipped. If no closing quote exists, the
    prefix scanned so far is returned. Empty input yields "".
    """
    if not val:
      return ""
    quot = val[0]
    count = 1
    quote_ind = val[count:].find(quot)
    while quote_ind != -1 and val[count + quote_ind - 1] == "\\":
      count = count + quote_ind + 1
      quote_ind = val[count:].find(quot)

    return val[0:quote_ind + count + 1]

  def find_second(self, val):
    """Return the second quoted literal of an assertion message *val*.

    If *val* contains "!=", the literal right after it is parsed exactly like
    find_first. Otherwise the quoted literal ending *val* is returned. It is
    found by scanning backwards for the matching quote and skipping
    backslash-escaped ones. Returns "" when there is nothing left to parse.
    """
    val_index = val.find("!=")
    if val_index != -1:
      return self.find_first(val[val_index + 2:].strip())

    if not val:
      return ""
    quot = val[-1]
    quote_ind = val[:len(val) - 1].rfind(quot)
    # ``quote_ind > 0``: a quote at index 0 has no preceding character that
    # could escape it. Checking val[-1] there loops forever when quot is "\\".
    while quote_ind > 0 and val[quote_ind - 1] == "\\":
      quote_ind = val[:quote_ind - 1].rfind(quot)
    return val[quote_ind:]

  def formatErr(self, err):
    """Format an ``(exctype, value, tb)`` tuple as a traceback string."""
    exctype, value, tb = err
    return ''.join(traceback.format_exception(exctype, value, tb))

  def getTestName(self, test, is_subtest=False):
    """Return the display name of *test*.

    For subtests (*is_subtest*), the parent test name is followed by the
    subtest description. For TestCase instances this is the test method
    name, or ``str(test)`` for ``runTest``. For other objects it is
    ``str(test)`` up to the first space.
    """
    if is_subtest:
      test_name = self.getTestName(test.test_case)
      return "{} {}".format(test_name, test._subDescription())
    if hasattr(test, '_testMethodName'):
      if test._testMethodName == "runTest":
        return str(test)
      return test._testMethodName
    test_name = str(test)
    # find(), not index(): a name without a space must be kept whole rather
    # than raising ValueError.
    whitespace_index = test_name.find(" ")
    if whitespace_index != -1:
      test_name = test_name[:whitespace_index]
    return test_name

  def getTestId(self, test):
    """Return the id string of *test*.

    unittest exposes ``id`` as a method, so it is called rather than
    returning the bound method. A non-callable ``id`` attribute is returned
    as-is.
    """
    test_id = test.id
    return test_id() if callable(test_id) else test_id

  def addSuccess(self, test):
    """Record a success. Reporting happens later, in stopTest."""
    TestResult.addSuccess(self, test)

  def addError(self, test, err):
    """Report *test* as errored; *err* is the ``sys.exc_info()`` triple."""
    location = self.init_suite(test)
    self.current_failed = True
    TestResult.addError(self, test, err)

    details = self._exc_info_to_string(err, test)
    name = self.getTestName(test)

    self.messages.testStarted(name, location=location)
    self.messages.testError(name, message='Error', details=details,
                            duration=self.__getDuration(test))

  def find_error_value(self, err):
    """Return the text after the last 'assert' on the failing source line.

    *err* is a traceback object. Returns "" when there is no traceback, no
    frames, or the source line is unavailable.
    """
    if err is None:
      return ""
    frames = traceback.extract_tb(err)
    if not frames:
      return ""
    # Last element of the innermost frame entry is the source line text; it
    # is None when the source cannot be read.
    source_line = frames[-1][-1]
    if not source_line:
      return ""
    return source_line.split('assert')[-1].strip()

  @staticmethod
  def _is_quoted(literal):
    """True if *literal* has at least 2 chars and the same quote at both ends."""
    return len(literal) >= 2 and literal[0] == literal[-1] and literal[0] in ("'", '"')

  def addFailure(self, test, err):
    """Report *test* as failed, extracting expected/actual values when possible.

    When the assertion message holds two quoted literals (e.g.
    ``'a' != 'b'``), they are sent as expected/actual so the IDE can show
    a diff. Otherwise both are sent as "".
    """
    location = self.init_suite(test)
    self.current_failed = True
    TestResult.addFailure(self, test, err)

    error_value = smart_str(err[1])
    if not error_value:
      # means it's test function and we have to extract value from traceback
      error_value = self.find_error_value(err[2])

    first = second = ""
    if error_value:
      first_literal = self.find_first(error_value)
      second_literal = self.find_second(error_value)
      if self._is_quoted(first_literal) and self._is_quoted(second_literal):
        # let's unescape strings to show sexy multiline diff in PyCharm.
        # By default all caret return chars are escaped by testing framework
        first = self._unescape(first_literal)
        second = self._unescape(second_literal)
    details = self._exc_info_to_string(err, test)

    name = self.getTestName(test)
    self.messages.testStarted(name, location=location)
    duration = self.__getDuration(test)
    self.messages.testFailed(name, message='Failure', details=details,
                             expected=first, actual=second, duration=duration)

  def addSkip(self, test, reason):
    """Report *test* as ignored with *reason* and record it in ``skipped``."""
    self.init_suite(test)
    self.current_failed = True
    # Base-class bookkeeping (``self.skipped``), like the other add* methods.
    # The guard covers stdlib TestResult versions without addSkip.
    base_add_skip = getattr(TestResult, "addSkip", None)
    if base_add_skip is not None:
      base_add_skip(self, test, reason)
    self.messages.testIgnored(self.getTestName(test), message=reason)

  def _getSuite(self, test):
    """Return ``(suite_name, test_location, suite_location)`` for *test*.

    Tests carrying a ``suite`` attribute (with ``location`` and
    ``abs_location``) use it, with the line number taken from
    ``test.lineno`` or ``test.test.lineno``. If any of these attributes is
    missing, the test's class is used instead. Locations then become
    ``"python_uttestid://" + source_dir + dotted_name``.
    """
    try:
      suite = strclass(test.suite)
      suite_location = test.suite.location
      location = test.suite.abs_location
      if hasattr(test, "lineno"):
        location = location + ":" + str(test.lineno)
      else:
        location = location + ":" + str(test.test.lineno)
    except AttributeError:
      import inspect

      try:
        source_file = inspect.getsourcefile(test.__class__)
        if source_file:
          # Everything up to the last "/" of the path, plus a trailing "/".
          source_dir_splitted = source_file.split("/")[:-1]
          source_dir = "/".join(source_dir_splitted) + "/"
        else:
          source_dir = ""
      except TypeError:
        # getsourcefile raises TypeError for built-in classes.
        source_dir = ""

      suite = strclass(test.__class__)
      suite_location = "python_uttestid://" + source_dir + suite
      location = "python_uttestid://" + source_dir + str(test.id())

    return (suite, location, suite_location)

  def startTest(self, test):
    """Begin *test*: count it, reset the outcome flag and store its start time."""
    # Base class increments ``testsRun`` and sets up output buffering.
    TestResult.startTest(self, test)
    self.current_failed = False
    setattr(test, "startTime", datetime.datetime.now())

  def init_suite(self, test):
    """Ensure the suite of *test* is open and return the test's location.

    If *test* belongs to a different suite than the current one, the
    current suite is finished and the new one is started.
    """
    suite, location, suite_location = self._getSuite(test)
    if suite != self.current_suite:
      if self.current_suite:
        self.messages.testSuiteFinished(self.current_suite)
      self.current_suite = suite
      self.messages.testSuiteStarted(self.current_suite, location=suite_location)
    return location

  def stopTest(self, test):
    """Finish *test*: report it as passed unless an outcome was already reported.

    If subTest results were grouped into a suite for this test, that suite
    is closed instead and the test itself is not reported separately.
    """
    duration = self.__getDuration(test)
    if not self.subtest_suite:
      if not self.current_failed:
        location = self.init_suite(test)
        self.messages.testStarted(self.getTestName(test), location=location)
        self.messages.testFinished(self.getTestName(test), duration=int(duration))
    else:
      self.messages.testSuiteFinished(self.subtest_suite)
      self.subtest_suite = None
    # Base class restores stdout/stderr when output buffering is enabled.
    TestResult.stopTest(self, test)

  def __getDuration(self, test):
    """Return whole milliseconds elapsed since *test*'s ``startTime``.

    Without a ``startTime`` attribute the elapsed time is measured from
    "now", i.e. about 0.
    """
    start = getattr(test, "startTime", datetime.datetime.now())
    d = datetime.datetime.now() - start
    # Floor division keeps the result an int on Python 3 as well.
    return d.microseconds // 1000 + d.seconds * 1000 + d.days * 86400000

  def addSubTest(self, test, subtest, err):
    """Report one subTest of *test* inside a suite named after *test*.

    *err* is None on success, otherwise the ``sys.exc_info()`` triple.
    """
    # Base-class bookkeeping (failures/errors, failfast). Without it
    # wasSuccessful() ignores failing subtests.
    base_add_subtest = getattr(TestResult, "addSubTest", None)
    if base_add_subtest is not None:
      base_add_subtest(self, test, subtest, err)

    suite_name = self.getTestName(test)  # + " (subTests)"
    if not self.subtest_suite:
      self.subtest_suite = suite_name
      self.messages.testSuiteStarted(self.subtest_suite)
    else:
      if suite_name != self.subtest_suite:
        self.messages.testSuiteFinished(self.subtest_suite)
        self.subtest_suite = suite_name
        self.messages.testSuiteStarted(self.subtest_suite)

    name = self.getTestName(subtest, True)
    if err is not None:
      error = self._exc_info_to_string(err, test)
      self.messages.testStarted(name)
      self.messages.testFailed(name, message='Failure', details=error, duration=None)
    else:
      self.messages.testStarted(name)
      self.messages.testFinished(name)

  def endLastSuite(self):
    """Finish the currently open suite, if any."""
    if self.current_suite:
      self.messages.testSuiteFinished(self.current_suite)
      self.current_suite = None

  def _unescape(self, text):
    """Turn literal backslash-n sequences in *text* into real newlines."""
    # do not use text.decode('string_escape'), it leads to problems with different string encodings given
    return text.replace("\\n", "\n")
241
242
class TeamcityTestRunner(object):
  """Minimal unittest-style runner that reports through TeamcityTestResult."""

  def __init__(self, stream=sys.stdout):
    # Destination handed to every result object this runner creates.
    self.stream = stream

  def _makeResult(self, **kwargs):
    """Build a TeamcityTestResult on this runner's stream, forwarding *kwargs*."""
    return TeamcityTestResult(self.stream, **kwargs)

  def run(self, test, **kwargs):
    """Announce the test count, run *test*, close the last suite and return the result."""
    result = self._makeResult(**kwargs)
    messages = result.messages
    messages.testCount(test.countTestCases())
    test(result)
    result.endLastSuite()
    return result