Commit b0ed81b4 authored by Louis Feng's avatar Louis Feng Committed by Robert Kimball

[Dynamic Shape] Add Properties to PassBase (#2768)

* Added EnumMask.

* added unit tests for EnumMask.

* fix.

* forgot to save file.

* add pass property to matcher.

* Clean up and made interface more restrictive, easier to use.

* addressed PR feedbacks.
parent db594969
......@@ -20,6 +20,11 @@
using namespace std;
using namespace ngraph;
pass::PassBase::PassBase()
{
set_property(PassProperty::REGULAR_FUSIONS, true);
}
pass::ManagerState& pass::PassBase::get_state()
{
return *m_state;
......@@ -29,3 +34,20 @@ void pass::PassBase::set_state(ManagerState& state)
{
m_state = &state;
}
bool pass::PassBase::get_property(const PassPropertyMask& prop) const
{
return m_property.is_set(prop);
}
void pass::PassBase::set_property(const PassPropertyMask& prop, bool value)
{
if (value)
{
m_property.set(prop);
}
else
{
m_property.clear(prop);
}
}
......@@ -23,6 +23,7 @@
#include "ngraph/function.hpp"
#include "ngraph/node.hpp"
#include "ngraph/pass/manager_state.hpp"
#include "ngraph/util.hpp"
namespace ngraph
{
......@@ -45,6 +46,13 @@ namespace ngraph
FOP_FUSIONS = 0x4,
ALL_FUSIONS = 0xFFFFFFFF
};
enum class PassProperty : uint32_t
{
REGULAR_FUSIONS = 1 << 1,
REQUIRE_STATIC_SHAPE = 1 << 2,
CHANGE_FUNCTION_STATE = 1 << 3
};
typedef EnumMask<PassProperty> PassPropertyMask;
}
}
......@@ -53,12 +61,18 @@ class ngraph::pass::PassBase
friend class Manager;
public:
PassBase();
virtual ~PassBase() {}
/// Check if this pass has all the pass properties.
bool get_property(const PassPropertyMask& prop_mask) const;
protected:
ManagerState& get_state();
void set_state(ManagerState&);
void set_property(const PassPropertyMask& prop, bool value);
private:
PassPropertyMask m_property;
ManagerState* m_state;
};
......
......@@ -25,6 +25,7 @@
#include "ngraph/pattern/op/any_of.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/pattern/op/skip.hpp"
#include "ngraph/util.hpp"
namespace ngraph
{
......@@ -68,11 +69,13 @@ namespace ngraph
Matcher(const std::shared_ptr<Node> pattern_node = nullptr,
graph_rewrite_callback callback = nullptr,
const std::string& name = "Unnamed",
pass::PassPropertyMask property = pass::PassProperty::REGULAR_FUSIONS,
bool strict_mode = false)
: m_pattern_node(pattern_node)
, m_callback(callback)
, m_depth(0)
, m_name(name)
, m_property(property)
, m_strict_mode(strict_mode)
{
}
......@@ -110,6 +113,7 @@ namespace ngraph
return matched;
}
bool get_property(const pass::PassPropertyMask& prop) const;
bool is_contained_match(const NodeVector& exclusions = {}, bool ignore_unused = true);
bool process_match(graph_rewrite_callback callback = nullptr);
......@@ -172,6 +176,7 @@ namespace ngraph
graph_rewrite_callback m_callback;
size_t m_depth;
std::string m_name;
pass::PassPropertyMask m_property;
bool m_strict_mode;
};
......
......@@ -236,6 +236,105 @@ namespace ngraph
OUTPUT,
INTERMEDIATE
};
/**
* EnumMask is intended to work with a scoped enum type. It's used to store
* a combination of enum values and provides easy access and manipulation
* of these enum values as a mask.
*
* EnumMask does not provide a set_all() or invert() operator because they
* could do things unexpected by the user, i.e. for enum with 4 bit values,
* invert(001000...) != 110100..., due to the extra bits.
*/
template <typename T>
class EnumMask
{
public:
/// Make sure the template type is an enum.
static_assert(std::is_enum<T>::value, "EnumMask template type must be an enum");
/// Extract the underlying type of the enum.
typedef typename std::underlying_type<T>::type value_type;
/// Some bit operations are not safe for signed values, we require enum
/// type to use unsigned underlying type.
static_assert(std::is_unsigned<value_type>::value, "EnumMask enum must use unsigned type.");
EnumMask()
: m_value{0}
{
}
EnumMask(const T& enum_value)
: m_value{static_cast<value_type>(enum_value)}
{
}
EnumMask(const EnumMask& other)
: m_value{other.m_value}
{
}
EnumMask(std::initializer_list<T> enum_values)
: m_value{0}
{
for (auto& v : enum_values)
{
m_value |= static_cast<value_type>(v);
}
}
value_type value() const { return m_value; }
/// Check if any of the enum bit mask match
bool is_any_set(const EnumMask& p) const { return m_value & p.m_value; }
/// Check if all of the enum bit mask match
bool is_set(const EnumMask& p) const { return (m_value & p.m_value) == p.m_value; }
/// Check if any of the enum bit mask does not match
bool is_any_clear(const EnumMask& p) const { return !is_set(p); }
/// Check if all of the enum bit mask do not match
bool is_clear(const EnumMask& p) const { return !is_any_set(p); }
void set(const EnumMask& p) { m_value |= p.m_value; }
void clear(const EnumMask& p) { m_value &= ~p.m_value; }
void clear_all() { m_value = 0; }
bool operator[](const EnumMask& p) const { return is_set(p); }
bool operator==(const EnumMask& other) const { return m_value == other.m_value; }
bool operator!=(const EnumMask& other) const { return m_value != other.m_value; }
EnumMask& operator=(const EnumMask& other)
{
m_value = other.m_value;
return *this;
}
EnumMask& operator&=(const EnumMask& other)
{
m_value &= other.m_value;
return *this;
}
EnumMask& operator|=(const EnumMask& other)
{
m_value |= other.m_value;
return *this;
}
EnumMask operator&(const EnumMask& other) const
{
return EnumMask(m_value & other.m_value);
}
EnumMask operator|(const EnumMask& other) const
{
return EnumMask(m_value | other.m_value);
}
friend std::ostream& operator<<(std::ostream& os, const EnumMask& m)
{
os << m.m_value;
return os;
}
private:
/// Only used internally
explicit EnumMask(const value_type& value)
: m_value{value}
{
}
value_type m_value;
};
} // end namespace ngraph
std::ostream& operator<<(std::ostream& os, const ngraph::NodeVector& nv);
......@@ -446,7 +446,7 @@ TEST(pattern, matcher)
// strict mode
{
TestMatcher sm(nullptr, nullptr, "TestMatcher", true);
TestMatcher sm(nullptr, nullptr, "TestMatcher", pass::PassProperty::REGULAR_FUSIONS, true);
// exact shape and type
auto scalar_param = make_shared<op::Parameter>(element::i32, Shape{});
auto label_dynamic_shape =
......
......@@ -406,3 +406,147 @@ TEST(pass, visualize_tree)
pm.register_pass<pass::VisualizeTree>("test_viz.png");
pm.run_passes(f);
}
TEST(util, enum_mask_construction)
{
enum class Type : uint32_t
{
a = 0x1,
b = 1 << 1,
c = 1 << 2,
d = 1 << 3
};
{
EnumMask<Type> m;
EXPECT_EQ(0, m.value());
}
{
EnumMask<Type> m(Type::c);
EXPECT_EQ(static_cast<uint32_t>(Type::c), m.value());
}
{
EnumMask<Type> a(Type::c);
EnumMask<Type> b{a};
EXPECT_EQ(a.value(), b.value());
}
{
EnumMask<Type> a{Type::a, Type::c, Type::d};
EXPECT_EQ((static_cast<uint32_t>(Type::a) | static_cast<uint32_t>(Type::c) |
static_cast<uint32_t>(Type::d)),
a.value());
}
}
TEST(util, enum_mask_set_clear)
{
enum class Type : uint32_t
{
a = 0x1,
b = 1 << 1,
c = 1 << 2,
d = 1 << 3
};
EnumMask<Type> m;
m.set(Type::b);
EXPECT_EQ(static_cast<uint32_t>(Type::b), m.value());
m.set(Type::c);
EXPECT_EQ(static_cast<uint32_t>(Type::b) | static_cast<uint32_t>(Type::c), m.value());
m.clear(Type::b);
EXPECT_EQ(static_cast<uint32_t>(Type::c), m.value());
m.clear_all();
EXPECT_EQ(0, m.value());
m.set(Type::d);
m.set(Type::b);
EXPECT_EQ(true, m.is_set(Type::d));
EXPECT_EQ(false, m.is_set(Type::a));
EXPECT_EQ(true, m.is_set(Type::b));
EXPECT_EQ(false, m.is_set(Type::c));
EXPECT_EQ(false, m.is_set({Type::a, Type::b}));
EXPECT_EQ(false, m.is_set({Type::c, Type::d}));
EXPECT_EQ(false, m.is_set({Type::a, Type::c}));
EXPECT_EQ(true, m.is_set({Type::b, Type::d}));
EXPECT_EQ(false, m.is_clear(Type::d));
EXPECT_EQ(true, m.is_clear(Type::a));
EXPECT_EQ(false, m.is_clear(Type::b));
EXPECT_EQ(true, m.is_clear(Type::c));
EXPECT_EQ(false, m.is_clear({Type::c, Type::d}));
EXPECT_EQ(false, m.is_clear({Type::a, Type::b}));
EXPECT_EQ(true, m.is_clear({Type::a, Type::c}));
EXPECT_EQ(false, m.is_clear({Type::b, Type::d}));
EXPECT_EQ(true, m.is_any_set({Type::a, Type::b}));
EXPECT_EQ(true, m.is_any_set({Type::a, Type::d}));
EXPECT_EQ(true, m.is_any_set({Type::b, Type::c}));
EXPECT_EQ(true, m.is_any_set({Type::c, Type::d}));
EXPECT_EQ(false, m.is_any_set({Type::a, Type::c}));
EXPECT_EQ(true, m.is_any_clear({Type::c, Type::d}));
EXPECT_EQ(true, m.is_any_clear({Type::a, Type::b}));
EXPECT_EQ(true, m.is_any_clear({Type::a, Type::c}));
EXPECT_EQ(true, m.is_any_clear({Type::b, Type::c}));
EXPECT_EQ(false, m.is_any_clear({Type::b, Type::d}));
m.set(Type::a);
EXPECT_EQ(false, m.is_clear(Type::a));
EXPECT_EQ(false, m.is_clear(Type::b));
EXPECT_EQ(true, m.is_clear(Type::c));
EXPECT_EQ(false, m.is_clear(Type::d));
}
TEST(util, enum_mask_operators)
{
enum class Type : uint32_t
{
a = 0x1,
b = 1 << 1,
c = 1 << 2,
d = 1 << 3
};
EnumMask<Type> m;
m = Type::b;
EXPECT_EQ(static_cast<uint32_t>(Type::b), m.value());
EXPECT_EQ(true, m[Type::b]);
EXPECT_EQ(false, m[Type::a]);
EXPECT_EQ(false, m[Type::c]);
m |= Type::c;
EXPECT_EQ(static_cast<uint32_t>(Type::b) | static_cast<uint32_t>(Type::c), m.value());
m &= Type::d;
EXPECT_EQ(0, m.value());
m |= Type::a;
m |= Type::c;
EXPECT_EQ(true, m.is_set(Type::a));
EXPECT_EQ(false, m.is_set(Type::b));
EXPECT_EQ(true, m.is_set(Type::c));
EXPECT_EQ(false, m.is_set(Type::d));
EXPECT_EQ(true, m.is_any_set(Type::a));
EXPECT_EQ(false, m.is_any_set(Type::b));
EXPECT_EQ(true, m.is_any_set(Type::c));
EXPECT_EQ(false, m.is_any_set(Type::d));
EXPECT_EQ(true, m.is_any_set({Type::a, Type::c}));
EXPECT_EQ(false, m.is_any_set({Type::b, Type::d}));
EnumMask<Type> n;
n = m | n;
EXPECT_EQ(m, n);
n = m & n;
EXPECT_EQ(m, n);
bool r = (n == m);
EXPECT_EQ(true, r);
r = (n != m);
EXPECT_EQ(false, r);
n.clear_all();
n = {Type::a, Type::b};
r = (n == m);
EXPECT_EQ(false, r);
r = (n != m);
EXPECT_EQ(true, r);
n = m & n;
EXPECT_EQ(static_cast<uint32_t>(Type::a), n.value());
n = m | Type::b;
EXPECT_EQ(true, n.is_set(Type::a));
EXPECT_EQ(true, n.is_set(Type::b));
EXPECT_EQ(true, n.is_set(Type::c));
EXPECT_EQ(false, n.is_set(Type::d));
EXPECT_EQ(false, n[Type::d]);
EXPECT_EQ(true, n[Type::b]);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment