summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--utils/TableGen/DAGISelMatcher.h49
-rw-r--r--utils/TableGen/DAGISelMatcherOpt.cpp68
2 files changed, 117 insertions, 0 deletions
diff --git a/utils/TableGen/DAGISelMatcher.h b/utils/TableGen/DAGISelMatcher.h
index ec61fcd1da..9af98f77f3 100644
--- a/utils/TableGen/DAGISelMatcher.h
+++ b/utils/TableGen/DAGISelMatcher.h
@@ -104,6 +104,12 @@ public:
return ((getHashImpl() << 4) ^ getKind()) & (~0U>>1);
}
+ /// isSafeToReorderWithPatternPredicate - Return true if it is safe to sink a
+ /// PatternPredicate node past this one.
+ virtual bool isSafeToReorderWithPatternPredicate() const {
+ return false;
+ }
+
void print(raw_ostream &OS, unsigned indent = 0) const;
void dump() const;
protected:
@@ -173,6 +179,7 @@ public:
return N->getKind() == RecordNode;
}
+ virtual bool isSafeToReorderWithPatternPredicate() const { return true; }
private:
virtual void printImpl(raw_ostream &OS, unsigned indent) const;
virtual bool isEqualImpl(const Matcher *M) const { return true; }
@@ -199,6 +206,8 @@ public:
return N->getKind() == RecordChild;
}
+ virtual bool isSafeToReorderWithPatternPredicate() const { return true; }
+
private:
virtual void printImpl(raw_ostream &OS, unsigned indent) const;
virtual bool isEqualImpl(const Matcher *M) const {
@@ -216,6 +225,8 @@ public:
return N->getKind() == RecordMemRef;
}
+ virtual bool isSafeToReorderWithPatternPredicate() const { return true; }
+
private:
virtual void printImpl(raw_ostream &OS, unsigned indent) const;
virtual bool isEqualImpl(const Matcher *M) const { return true; }
@@ -233,6 +244,8 @@ public:
return N->getKind() == CaptureFlagInput;
}
+ virtual bool isSafeToReorderWithPatternPredicate() const { return true; }
+
private:
virtual void printImpl(raw_ostream &OS, unsigned indent) const;
virtual bool isEqualImpl(const Matcher *M) const { return true; }
@@ -252,6 +265,8 @@ public:
return N->getKind() == MoveChild;
}
+ virtual bool isSafeToReorderWithPatternPredicate() const { return true; }
+
private:
virtual void printImpl(raw_ostream &OS, unsigned indent) const;
virtual bool isEqualImpl(const Matcher *M) const {
@@ -270,6 +285,8 @@ public:
return N->getKind() == MoveParent;
}
+ virtual bool isSafeToReorderWithPatternPredicate() const { return true; }
+
private:
virtual void printImpl(raw_ostream &OS, unsigned indent) const;
virtual bool isEqualImpl(const Matcher *M) const { return true; }
@@ -291,6 +308,8 @@ public:
return N->getKind() == CheckSame;
}
+ virtual bool isSafeToReorderWithPatternPredicate() const { return true; }
+
private:
virtual void printImpl(raw_ostream &OS, unsigned indent) const;
virtual bool isEqualImpl(const Matcher *M) const {
@@ -314,6 +333,8 @@ public:
return N->getKind() == CheckPatternPredicate;
}
+ virtual bool isSafeToReorderWithPatternPredicate() const { return true; }
+
private:
virtual void printImpl(raw_ostream &OS, unsigned indent) const;
virtual bool isEqualImpl(const Matcher *M) const {
@@ -336,6 +357,9 @@ public:
return N->getKind() == CheckPredicate;
}
+ // TODO: Ok?
+ //virtual bool isSafeToReorderWithPatternPredicate() const { return true; }
+
private:
virtual void printImpl(raw_ostream &OS, unsigned indent) const;
virtual bool isEqualImpl(const Matcher *M) const {
@@ -359,6 +383,8 @@ public:
return N->getKind() == CheckOpcode;
}
+ virtual bool isSafeToReorderWithPatternPredicate() const { return true; }
+
private:
virtual void printImpl(raw_ostream &OS, unsigned indent) const;
virtual bool isEqualImpl(const Matcher *M) const {
@@ -382,6 +408,8 @@ public:
return N->getKind() == CheckMultiOpcode;
}
+ virtual bool isSafeToReorderWithPatternPredicate() const { return true; }
+
private:
virtual void printImpl(raw_ostream &OS, unsigned indent) const;
virtual bool isEqualImpl(const Matcher *M) const {
@@ -406,6 +434,8 @@ public:
return N->getKind() == CheckType;
}
+ virtual bool isSafeToReorderWithPatternPredicate() const { return true; }
+
private:
virtual void printImpl(raw_ostream &OS, unsigned indent) const;
virtual bool isEqualImpl(const Matcher *M) const {
@@ -430,6 +460,8 @@ public:
return N->getKind() == CheckChildType;
}
+ virtual bool isSafeToReorderWithPatternPredicate() const { return true; }
+
private:
virtual void printImpl(raw_ostream &OS, unsigned indent) const;
virtual bool isEqualImpl(const Matcher *M) const {
@@ -454,6 +486,8 @@ public:
return N->getKind() == CheckInteger;
}
+ virtual bool isSafeToReorderWithPatternPredicate() const { return true; }
+
private:
virtual void printImpl(raw_ostream &OS, unsigned indent) const;
virtual bool isEqualImpl(const Matcher *M) const {
@@ -476,6 +510,8 @@ public:
return N->getKind() == CheckCondCode;
}
+ virtual bool isSafeToReorderWithPatternPredicate() const { return true; }
+
private:
virtual void printImpl(raw_ostream &OS, unsigned indent) const;
virtual bool isEqualImpl(const Matcher *M) const {
@@ -498,6 +534,8 @@ public:
return N->getKind() == CheckValueType;
}
+ virtual bool isSafeToReorderWithPatternPredicate() const { return true; }
+
private:
virtual void printImpl(raw_ostream &OS, unsigned indent) const;
virtual bool isEqualImpl(const Matcher *M) const {
@@ -522,6 +560,9 @@ public:
return N->getKind() == CheckComplexPat;
}
+ // Not safe to move a pattern predicate past a complex pattern.
+ virtual bool isSafeToReorderWithPatternPredicate() const { return false; }
+
private:
virtual void printImpl(raw_ostream &OS, unsigned indent) const;
virtual bool isEqualImpl(const Matcher *M) const {
@@ -546,6 +587,8 @@ public:
return N->getKind() == CheckAndImm;
}
+ virtual bool isSafeToReorderWithPatternPredicate() const { return true; }
+
private:
virtual void printImpl(raw_ostream &OS, unsigned indent) const;
virtual bool isEqualImpl(const Matcher *M) const {
@@ -568,6 +611,8 @@ public:
return N->getKind() == CheckOrImm;
}
+ virtual bool isSafeToReorderWithPatternPredicate() const { return true; }
+
private:
virtual void printImpl(raw_ostream &OS, unsigned indent) const;
virtual bool isEqualImpl(const Matcher *M) const {
@@ -587,6 +632,8 @@ public:
return N->getKind() == CheckFoldableChainNode;
}
+ virtual bool isSafeToReorderWithPatternPredicate() const { return true; }
+
private:
virtual void printImpl(raw_ostream &OS, unsigned indent) const;
virtual bool isEqualImpl(const Matcher *M) const { return true; }
@@ -607,6 +654,8 @@ public:
return N->getKind() == CheckChainCompatible;
}
+ virtual bool isSafeToReorderWithPatternPredicate() const { return true; }
+
private:
virtual void printImpl(raw_ostream &OS, unsigned indent) const;
virtual bool isEqualImpl(const Matcher *M) const {
diff --git a/utils/TableGen/DAGISelMatcherOpt.cpp b/utils/TableGen/DAGISelMatcherOpt.cpp
index 5aaa51f97c..55c2538933 100644
--- a/utils/TableGen/DAGISelMatcherOpt.cpp
+++ b/utils/TableGen/DAGISelMatcherOpt.cpp
@@ -16,6 +16,8 @@
#include <vector>
using namespace llvm;
+/// ContractNodes - Turn multiple matcher node patterns like 'MoveChild+Record'
+/// into single compound nodes like RecordChild.
static void ContractNodes(OwningPtr<Matcher> &MatcherPtr) {
// If we reached the end of the chain, we're done.
Matcher *N = MatcherPtr.get();
@@ -61,6 +63,71 @@ static void ContractNodes(OwningPtr<Matcher> &MatcherPtr) {
ContractNodes(N->getNextPtr());
}
+/// SinkPatternPredicates - Pattern predicates can be checked at any level of
+/// the matching tree. The generator dumps them at the top level of the pattern
+/// though, which prevents factoring from being able to see past them. This
+/// optimization sinks them as far down into the pattern as possible.
+///
+/// Conceptually, we'd like to sink these predicates all the way to the last
+/// matcher predicate in the series. However, it turns out that some
+/// ComplexPatterns have side effects on the graph, so we really don't want to
+/// run a the complex pattern if the pattern predicate will fail. For this
+/// reason, we refuse to sink the pattern predicate past a ComplexPattern.
+///
+static void SinkPatternPredicates(OwningPtr<Matcher> &MatcherPtr) {
+ // Recursively scan for a PatternPredicate.
+ // If we reached the end of the chain, we're done.
+ Matcher *N = MatcherPtr.get();
+ if (N == 0) return;
+
+ // Walk down all members of a scope node.
+ if (ScopeMatcher *Scope = dyn_cast<ScopeMatcher>(N)) {
+ for (unsigned i = 0, e = Scope->getNumChildren(); i != e; ++i) {
+ OwningPtr<Matcher> Child(Scope->takeChild(i));
+ SinkPatternPredicates(Child);
+ Scope->resetChild(i, Child.take());
+ }
+ return;
+ }
+
+ // If this node isn't a CheckPatternPredicateMatcher we keep scanning until
+ // we find one.
+ CheckPatternPredicateMatcher *CPPM =dyn_cast<CheckPatternPredicateMatcher>(N);
+ if (CPPM == 0)
+ return SinkPatternPredicates(N->getNextPtr());
+
+ // Ok, we found one, lets try to sink it. Check if we can sink it past the
+ // next node in the chain. If not, we won't be able to change anything and
+ // might as well bail.
+ if (!CPPM->getNext()->isSafeToReorderWithPatternPredicate())
+ return;
+
+ // Okay, we know we can sink it past at least one node. Unlink it from the
+ // chain and scan for the new insertion point.
+ MatcherPtr.take(); // Don't delete CPPM.
+ MatcherPtr.reset(CPPM->takeNext());
+
+ N = MatcherPtr.get();
+ while (N->getNext()->isSafeToReorderWithPatternPredicate())
+ N = N->getNext();
+
+ // At this point, we want to insert CPPM after N.
+ CPPM->setNext(N->takeNext());
+ N->setNext(CPPM);
+}
+
+/// FactorNodes - Turn matches like this:
+/// Scope
+/// OPC_CheckType i32
+/// ABC
+/// OPC_CheckType i32
+/// XYZ
+/// into:
+/// OPC_CheckType i32
+/// Scope
+/// ABC
+/// XYZ
+///
static void FactorNodes(OwningPtr<Matcher> &MatcherPtr) {
// If we reached the end of the chain, we're done.
Matcher *N = MatcherPtr.get();
@@ -145,6 +212,7 @@ static void FactorNodes(OwningPtr<Matcher> &MatcherPtr) {
Matcher *llvm::OptimizeMatcher(Matcher *TheMatcher) {
OwningPtr<Matcher> MatcherPtr(TheMatcher);
ContractNodes(MatcherPtr);
+ SinkPatternPredicates(MatcherPtr);
FactorNodes(MatcherPtr);
return MatcherPtr.take();
}