IDEA-285172 - [decompiler] - StrongConnectivityHelper refactoring
[idea/community.git] / plugins / java-decompiler / engine / src / org / jetbrains / java / decompiler / modules / decompiler / DomHelper.java
1 // Copyright 2000-2021 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
2 package org.jetbrains.java.decompiler.modules.decompiler;
3
4 import org.jetbrains.java.decompiler.code.cfg.BasicBlock;
5 import org.jetbrains.java.decompiler.code.cfg.ControlFlowGraph;
6 import org.jetbrains.java.decompiler.code.cfg.ExceptionRangeCFG;
7 import org.jetbrains.java.decompiler.main.DecompilerContext;
8 import org.jetbrains.java.decompiler.main.extern.IFernflowerLogger;
9 import org.jetbrains.java.decompiler.modules.decompiler.decompose.FastExtendedPostdominanceHelper;
10 import org.jetbrains.java.decompiler.modules.decompiler.deobfuscator.IrreducibleCFGDeobfuscator;
11 import org.jetbrains.java.decompiler.modules.decompiler.stats.*;
12 import org.jetbrains.java.decompiler.util.FastFixedSetFactory;
13 import org.jetbrains.java.decompiler.util.FastFixedSetFactory.FastFixedSet;
14 import org.jetbrains.java.decompiler.util.VBStyleCollection;
15
16 import java.util.*;
17
18 public final class DomHelper {
19
20
21   private static RootStatement graphToStatement(ControlFlowGraph graph) {
22
23     VBStyleCollection<Statement, Integer> stats = new VBStyleCollection<>();
24     VBStyleCollection<BasicBlock, Integer> blocks = graph.getBlocks();
25
26     for (BasicBlock block : blocks) {
27       stats.addWithKey(new BasicBlockStatement(block), block.id);
28     }
29
30     BasicBlock firstblock = graph.getFirst();
31     // head statement
32     Statement firstst = stats.getWithKey(firstblock.id);
33     // dummy exit statement
34     DummyExitStatement dummyexit = new DummyExitStatement();
35
36     Statement general;
37     if (stats.size() > 1 || firstblock.isSuccessor(firstblock)) { // multiple basic blocks or an infinite loop of one block
38       general = new GeneralStatement(firstst, stats, null);
39     }
40     else { // one straightforward basic block
41       RootStatement root = new RootStatement(firstst, dummyexit);
42       firstst.addSuccessor(new StatEdge(StatEdge.TYPE_BREAK, firstst, dummyexit, root));
43
44       return root;
45     }
46
47     for (BasicBlock block : blocks) {
48       Statement stat = stats.getWithKey(block.id);
49
50       for (BasicBlock succ : block.getSuccessors()) {
51         Statement stsucc = stats.getWithKey(succ.id);
52
53         int type;
54         if (stsucc == firstst) {
55           type = StatEdge.TYPE_CONTINUE;
56         }
57         else if (graph.getFinallyExits().contains(block)) {
58           type = StatEdge.TYPE_FINALLYEXIT;
59           stsucc = dummyexit;
60         }
61         else if (succ.id == graph.getLast().id) {
62           type = StatEdge.TYPE_BREAK;
63           stsucc = dummyexit;
64         }
65         else {
66           type = StatEdge.TYPE_REGULAR;
67         }
68
69         stat.addSuccessor(new StatEdge(type, stat, (type == StatEdge.TYPE_CONTINUE) ? general : stsucc,
70                                        (type == StatEdge.TYPE_REGULAR) ? null : general));
71       }
72
73       // exceptions edges
74       for (BasicBlock succex : block.getSuccessorExceptions()) {
75         Statement stsuccex = stats.getWithKey(succex.id);
76
77         ExceptionRangeCFG range = graph.getExceptionRange(succex, block);
78         if (!range.isCircular()) {
79           stat.addSuccessor(new StatEdge(stat, stsuccex, range.getExceptionTypes()));
80         }
81       }
82     }
83
84     general.buildContinueSet();
85     general.buildMonitorFlags();
86     return new RootStatement(general, dummyexit);
87   }
88
89   public static VBStyleCollection<List<Integer>, Integer> calcPostDominators(Statement container) {
90
91     HashMap<Statement, FastFixedSet<Statement>> lists = new HashMap<>();
92
93     StrongConnectivityHelper connectivityHelper = new StrongConnectivityHelper(container);
94
95     List<Statement> lstStats = container.getPostReversePostOrderList(connectivityHelper.getExitReps());
96
97     FastFixedSetFactory<Statement> factory = new FastFixedSetFactory<>(lstStats);
98
99     FastFixedSet<Statement> setFlagNodes = factory.spawnEmptySet();
100     setFlagNodes.setAllElements();
101
102     FastFixedSet<Statement> initSet = factory.spawnEmptySet();
103     initSet.setAllElements();
104
105     for (List<Statement> lst : connectivityHelper.getComponents()) {
106       FastFixedSet<Statement> tmpSet;
107
108       if (StrongConnectivityHelper.isExitComponent(lst)) {
109         tmpSet = factory.spawnEmptySet();
110         tmpSet.addAll(lst);
111       }
112       else {
113         tmpSet = initSet.getCopy();
114       }
115
116       for (Statement stat : lst) {
117         lists.put(stat, tmpSet);
118       }
119     }
120
121     do {
122
123       for (Statement stat : lstStats) {
124
125         if (!setFlagNodes.contains(stat)) {
126           continue;
127         }
128         setFlagNodes.remove(stat);
129
130         FastFixedSet<Statement> doms = lists.get(stat);
131         FastFixedSet<Statement> domsSuccs = factory.spawnEmptySet();
132
133         List<Statement> lstSuccs = stat.getNeighbours(StatEdge.TYPE_REGULAR, Statement.DIRECTION_FORWARD);
134         for (int j = 0; j < lstSuccs.size(); j++) {
135           Statement succ = lstSuccs.get(j);
136           FastFixedSet<Statement> succlst = lists.get(succ);
137
138           if (j == 0) {
139             domsSuccs.union(succlst);
140           }
141           else {
142             domsSuccs.intersection(succlst);
143           }
144         }
145
146         if (!domsSuccs.contains(stat)) {
147           domsSuccs.add(stat);
148         }
149
150         if (!Objects.equals(domsSuccs, doms)) {
151
152           lists.put(stat, domsSuccs);
153
154           List<Statement> lstPreds = stat.getNeighbours(StatEdge.TYPE_REGULAR, Statement.DIRECTION_BACKWARD);
155           for (Statement pred : lstPreds) {
156             setFlagNodes.add(pred);
157           }
158         }
159       }
160     }
161     while (!setFlagNodes.isEmpty());
162
163     VBStyleCollection<List<Integer>, Integer> ret = new VBStyleCollection<>();
164     List<Statement> lstRevPost = container.getReversePostOrderList(); // sort order crucial!
165
166     final HashMap<Integer, Integer> mapSortOrder = new HashMap<>();
167     for (int i = 0; i < lstRevPost.size(); i++) {
168       mapSortOrder.put(lstRevPost.get(i).id, i);
169     }
170
171     for (Statement st : lstStats) {
172
173       List<Integer> lstPosts = new ArrayList<>();
174       for (Statement stt : lists.get(st)) {
175         lstPosts.add(stt.id);
176       }
177
178       lstPosts.sort(Comparator.comparing(mapSortOrder::get));
179
180       if (lstPosts.size() > 1 && lstPosts.get(0).intValue() == st.id) {
181         lstPosts.add(lstPosts.remove(0));
182       }
183
184       ret.addWithKey(lstPosts, st.id);
185     }
186
187     return ret;
188   }
189
190   public static RootStatement parseGraph(ControlFlowGraph graph) {
191
192     RootStatement root = graphToStatement(graph);
193
194     if (!processStatement(root, new HashMap<>())) {
195
196       //                        try {
197       //                                DotExporter.toDotFile(root.getFirst().getStats().get(13), new File("c:\\Temp\\stat1.dot"));
198       //                        } catch (Exception ex) {
199       //                                ex.printStackTrace();
200       //                        }
201       throw new RuntimeException("parsing failure!");
202     }
203
204     LabelHelper.lowContinueLabels(root, new HashSet<>());
205
206     SequenceHelper.condenseSequences(root);
207     root.buildMonitorFlags();
208
209     // build synchronized statements
210     buildSynchronized(root);
211
212     return root;
213   }
214
215   public static void removeSynchronizedHandler(Statement stat) {
216
217     for (Statement st : stat.getStats()) {
218       removeSynchronizedHandler(st);
219     }
220
221     if (stat.type == Statement.TYPE_SYNCHRONIZED) {
222       ((SynchronizedStatement)stat).removeExc();
223     }
224   }
225
226
227   private static void buildSynchronized(Statement stat) {
228
229     for (Statement st : stat.getStats()) {
230       buildSynchronized(st);
231     }
232
233     if (stat.type == Statement.TYPE_SEQUENCE) {
234
235       while (true) {
236
237         boolean found = false;
238
239         List<Statement> lst = stat.getStats();
240         for (int i = 0; i < lst.size() - 1; i++) {
241           Statement current = lst.get(i);  // basic block
242
243           if (current.isMonitorEnter()) {
244
245             Statement next = lst.get(i + 1);
246             Statement nextDirect = next;
247
248             while (next.type == Statement.TYPE_SEQUENCE) {
249               next = next.getFirst();
250             }
251
252             if (next.type == Statement.TYPE_CATCH_ALL) {
253
254               CatchAllStatement ca = (CatchAllStatement)next;
255
256               if (ca.getFirst().isContainsMonitorExit() && ca.getHandler().isContainsMonitorExit()) {
257
258                 // remove the head block from sequence
259                 current.removeSuccessor(current.getSuccessorEdges(Statement.STATEDGE_DIRECT_ALL).get(0));
260
261                 for (StatEdge edge : current.getPredecessorEdges(Statement.STATEDGE_DIRECT_ALL)) {
262                   current.removePredecessor(edge);
263                   edge.getSource().changeEdgeNode(Statement.DIRECTION_FORWARD, edge, nextDirect);
264                   nextDirect.addPredecessor(edge);
265                 }
266
267                 stat.getStats().removeWithKey(current.id);
268                 stat.setFirst(stat.getStats().get(0));
269
270                 // new statement
271                 SynchronizedStatement sync = new SynchronizedStatement(current, ca.getFirst(), ca.getHandler());
272                 sync.setAllParent();
273
274                 for (StatEdge edge : new HashSet<>(ca.getLabelEdges())) {
275                   sync.addLabeledEdge(edge);
276                 }
277
278                 current.addSuccessor(new StatEdge(StatEdge.TYPE_REGULAR, current, ca.getFirst()));
279
280                 ca.getParent().replaceStatement(ca, sync);
281                 found = true;
282                 break;
283               }
284             }
285           }
286         }
287
288         if (!found) {
289           break;
290         }
291       }
292     }
293   }
294
295   private static boolean processStatement(Statement general, HashMap<Integer, Set<Integer>> mapExtPost) {
296
297     if (general.type == Statement.TYPE_ROOT) {
298       Statement stat = general.getFirst();
299       if (stat.type != Statement.TYPE_GENERAL) {
300         return true;
301       }
302       else {
303         boolean complete = processStatement(stat, mapExtPost);
304         if (complete) {
305           // replace general purpose statement with simple one
306           general.replaceStatement(stat, stat.getFirst());
307         }
308         return complete;
309       }
310     }
311
312     boolean mapRefreshed = mapExtPost.isEmpty();
313
314     for (int mapstage = 0; mapstage < 2; mapstage++) {
315
316       for (int reducibility = 0;
317            reducibility < 5;
318            reducibility++) { // FIXME: implement proper node splitting. For now up to 5 nodes in sequence are splitted.
319
320         if (reducibility > 0) {
321
322           //                                    try {
323           //                                            DotExporter.toDotFile(general, new File("c:\\Temp\\stat1.dot"));
324           //                                    } catch(Exception ex) {ex.printStackTrace();}
325
326           // take care of irreducible control flow graphs
327           if (IrreducibleCFGDeobfuscator.isStatementIrreducible(general)) {
328             if (!IrreducibleCFGDeobfuscator.splitIrreducibleNode(general)) {
329               DecompilerContext.getLogger().writeMessage("Irreducible statement cannot be decomposed!", IFernflowerLogger.Severity.ERROR);
330               break;
331             }
332           }
333           else {
334             if (mapstage == 2 || mapRefreshed) { // last chance lost
335               DecompilerContext.getLogger().writeMessage("Statement cannot be decomposed although reducible!", IFernflowerLogger.Severity.ERROR);
336             }
337             break;
338           }
339
340           //                                    try {
341           //                                            DotExporter.toDotFile(general, new File("c:\\Temp\\stat1.dot"));
342           //                                    } catch(Exception ex) {ex.printStackTrace();}
343
344           mapExtPost = new HashMap<>();
345           mapRefreshed = true;
346         }
347
348         for (int i = 0; i < 2; i++) {
349
350           boolean forceall = i != 0;
351
352           while (true) {
353
354             if (findSimpleStatements(general, mapExtPost)) {
355               reducibility = 0;
356             }
357
358             if (general.type == Statement.TYPE_PLACEHOLDER) {
359               return true;
360             }
361
362             Statement stat = findGeneralStatement(general, forceall, mapExtPost);
363
364             if (stat != null) {
365               boolean complete = processStatement(stat, general.getFirst() == stat ? mapExtPost : new HashMap<>());
366
367               if (complete) {
368                 // replace general purpose statement with simple one
369                 general.replaceStatement(stat, stat.getFirst());
370               }
371               else {
372                 return false;
373               }
374
375               mapExtPost = new HashMap<>();
376               mapRefreshed = true;
377               reducibility = 0;
378             }
379             else {
380               break;
381             }
382           }
383         }
384
385         //                              try {
386         //                                      DotExporter.toDotFile(general, new File("c:\\Temp\\stat1.dot"));
387         //                              } catch (Exception ex) {
388         //                                      ex.printStackTrace();
389         //                              }
390       }
391
392       if (mapRefreshed) {
393         break;
394       }
395       else {
396         mapExtPost = new HashMap<>();
397       }
398     }
399
400     return false;
401   }
402
403   private static Statement findGeneralStatement(Statement stat, boolean forceall, HashMap<Integer, Set<Integer>> mapExtPost) {
404
405     VBStyleCollection<Statement, Integer> stats = stat.getStats();
406     VBStyleCollection<List<Integer>, Integer> vbPost;
407
408     if (mapExtPost.isEmpty()) {
409       FastExtendedPostdominanceHelper extpost = new FastExtendedPostdominanceHelper();
410       mapExtPost.putAll(extpost.getExtendedPostdominators(stat));
411     }
412
413     if (forceall) {
414       vbPost = new VBStyleCollection<>();
415       List<Statement> lstAll = stat.getPostReversePostOrderList();
416
417       for (Statement st : lstAll) {
418         Set<Integer> set = mapExtPost.get(st.id);
419         if (set != null) {
420           vbPost.addWithKey(new ArrayList<>(set), st.id); // FIXME: sort order!!
421         }
422       }
423
424       // tail statements
425       Set<Integer> setFirst = mapExtPost.get(stat.getFirst().id);
426       if (setFirst != null) {
427         for (Integer id : setFirst) {
428           List<Integer> lst = vbPost.getWithKey(id);
429           if (lst == null) {
430             vbPost.addWithKey(lst = new ArrayList<>(), id);
431           }
432           lst.add(id);
433         }
434       }
435     }
436     else {
437       vbPost = calcPostDominators(stat);
438     }
439
440     for (int k = 0; k < vbPost.size(); k++) {
441
442       Integer headid = vbPost.getKey(k);
443       List<Integer> posts = vbPost.get(k);
444
445       if (!mapExtPost.containsKey(headid) &&
446           !(posts.size() == 1 && posts.get(0).equals(headid))) {
447         continue;
448       }
449
450       Statement head = stats.getWithKey(headid);
451
452       Set<Integer> setExtPosts = mapExtPost.get(headid);
453
454       for (Integer postId : posts) {
455         if (!postId.equals(headid) && !setExtPosts.contains(postId)) {
456           continue;
457         }
458
459         Statement post = stats.getWithKey(postId);
460
461         if (post == null) { // possible in case of an inherited postdominance set
462           continue;
463         }
464
465         boolean same = (post == head);
466
467         HashSet<Statement> setNodes = new HashSet<>();
468         HashSet<Statement> setPreds = new HashSet<>();
469
470         // collect statement nodes
471         HashSet<Statement> setHandlers = new HashSet<>();
472         setHandlers.add(head);
473         while (true) {
474
475           boolean hdfound = false;
476           for (Statement handler : setHandlers) {
477             if (setNodes.contains(handler)) {
478               continue;
479             }
480
481             boolean addhd = (setNodes.size() == 0); // first handler == head
482             if (!addhd) {
483               List<Statement> hdsupp = handler.getNeighbours(StatEdge.TYPE_EXCEPTION, Statement.DIRECTION_BACKWARD);
484               addhd = (setNodes.containsAll(hdsupp) && (setNodes.size() > hdsupp.size()
485                                                         || setNodes.size() == 1)); // strict subset
486             }
487
488             if (addhd) {
489               LinkedList<Statement> lstStack = new LinkedList<>();
490               lstStack.add(handler);
491
492               while (!lstStack.isEmpty()) {
493                 Statement st = lstStack.remove(0);
494
495                 if (!(setNodes.contains(st) || (!same && st == post))) {
496                   setNodes.add(st);
497                   if (st != head) {
498                     // record predeccessors except for the head
499                     setPreds.addAll(st.getNeighbours(StatEdge.TYPE_REGULAR, Statement.DIRECTION_BACKWARD));
500                   }
501
502                   // put successors on the stack
503                   lstStack.addAll(st.getNeighbours(StatEdge.TYPE_REGULAR, Statement.DIRECTION_FORWARD));
504
505                   // exception edges
506                   setHandlers.addAll(st.getNeighbours(StatEdge.TYPE_EXCEPTION, Statement.DIRECTION_FORWARD));
507                 }
508               }
509
510               hdfound = true;
511               setHandlers.remove(handler);
512               break;
513             }
514           }
515
516           if (!hdfound) {
517             break;
518           }
519         }
520
521         // check exception handlers
522         setHandlers.clear();
523         for (Statement st : setNodes) {
524           setHandlers.addAll(st.getNeighbours(StatEdge.TYPE_EXCEPTION, Statement.DIRECTION_FORWARD));
525         }
526         setHandlers.removeAll(setNodes);
527
528         boolean excok = true;
529         for (Statement handler : setHandlers) {
530           if (!handler.getNeighbours(StatEdge.TYPE_EXCEPTION, Statement.DIRECTION_BACKWARD).containsAll(setNodes)) {
531             excok = false;
532             break;
533           }
534         }
535
536         // build statement and return
537         if (excok) {
538           Statement res;
539
540           setPreds.removeAll(setNodes);
541           if (setPreds.size() == 0) {
542             if ((setNodes.size() > 1 ||
543                  head.getNeighbours(StatEdge.TYPE_REGULAR, Statement.DIRECTION_BACKWARD).contains(head))
544                 && setNodes.size() < stats.size()) {
545               if (checkSynchronizedCompleteness(setNodes)) {
546                 res = new GeneralStatement(head, setNodes, same ? null : post);
547                 stat.collapseNodesToStatement(res);
548
549                 return res;
550               }
551             }
552           }
553         }
554       }
555     }
556
557     return null;
558   }
559
560   private static boolean checkSynchronizedCompleteness(Set<Statement> setNodes) {
561     // check exit nodes
562     for (Statement stat : setNodes) {
563       if (stat.isMonitorEnter()) {
564         List<StatEdge> lstSuccs = stat.getSuccessorEdges(Statement.STATEDGE_DIRECT_ALL);
565         if (lstSuccs.size() != 1 || lstSuccs.get(0).getType() != StatEdge.TYPE_REGULAR) {
566           return false;
567         }
568
569         if (!setNodes.contains(lstSuccs.get(0).getDestination())) {
570           return false;
571         }
572       }
573     }
574
575     return true;
576   }
577
578   private static boolean findSimpleStatements(Statement stat, HashMap<Integer, Set<Integer>> mapExtPost) {
579
580     boolean found, success = false;
581
582     do {
583       found = false;
584
585       List<Statement> lstStats = stat.getPostReversePostOrderList();
586       for (Statement st : lstStats) {
587
588         Statement result = detectStatement(st);
589
590         if (result != null) {
591
592           if (stat.type == Statement.TYPE_GENERAL && result.getFirst() == stat.getFirst() &&
593               stat.getStats().size() == result.getStats().size()) {
594             // mark general statement
595             stat.type = Statement.TYPE_PLACEHOLDER;
596           }
597
598           stat.collapseNodesToStatement(result);
599
600           // update the postdominator map
601           if (!mapExtPost.isEmpty()) {
602             HashSet<Integer> setOldNodes = new HashSet<>();
603             for (Statement old : result.getStats()) {
604               setOldNodes.add(old.id);
605             }
606
607             Integer newid = result.id;
608
609             for (Integer key : new ArrayList<>(mapExtPost.keySet())) {
610               Set<Integer> set = mapExtPost.get(key);
611
612               int oldsize = set.size();
613               set.removeAll(setOldNodes);
614
615               if (setOldNodes.contains(key)) {
616                 mapExtPost.computeIfAbsent(newid, k -> new HashSet<>()).addAll(set);
617                 mapExtPost.remove(key);
618               }
619               else {
620                 if (set.size() < oldsize) {
621                   set.add(newid);
622                 }
623               }
624             }
625           }
626
627
628           found = true;
629           break;
630         }
631       }
632
633       if (found) {
634         success = true;
635       }
636     }
637     while (found);
638
639     return success;
640   }
641
642
643   private static Statement detectStatement(Statement head) {
644
645     Statement res;
646
647     if ((res = DoStatement.isHead(head)) != null) {
648       return res;
649     }
650
651     if ((res = SwitchStatement.isHead(head)) != null) {
652       return res;
653     }
654
655     if ((res = IfStatement.isHead(head)) != null) {
656       return res;
657     }
658
659     // synchronized statements will be identified later
660     // right now they are recognized as catchall
661
662     if ((res = SequenceStatement.isHead2Block(head)) != null) {
663       return res;
664     }
665
666     if ((res = CatchStatement.isHead(head)) != null) {
667       return res;
668     }
669
670     if ((res = CatchAllStatement.isHead(head)) != null) {
671       return res;
672     }
673
674     return null;
675   }
676 }