Commit bbf9daf2 authored by Robert Kimball, committed by Scott Cyphers

CSE optimizations (#4067)

* CSE optimizations

* Get TI working with clang 6
Co-authored-by: Scott Cyphers <diyessi@users.noreply.github.com>
parent 0fc8f4d8
......@@ -75,36 +75,56 @@ static bool cse_constant(shared_ptr<Node> a, shared_ptr<Node> b)
return false;
}
auto ca = static_pointer_cast<op::Constant>(a);
auto cb = static_pointer_cast<op::Constant>(b);
const op::Constant* ca = static_cast<op::Constant*>(a.get());
const op::Constant* cb = static_cast<op::Constant*>(b.get());
size_t size = shape_size(a->get_shape()) * a->get_element_type().size();
return !memcmp(ca->get_data_ptr(), cb->get_data_ptr(), size);
if (ca->get_all_data_elements_bitwise_identical() ||
cb->get_all_data_elements_bitwise_identical())
{
if (ca->get_all_data_elements_bitwise_identical() &&
cb->get_all_data_elements_bitwise_identical())
{
// Since both Constants are uniform we only need to compare a single element
return !memcmp(ca->get_data_ptr(), cb->get_data_ptr(), a->get_element_type().size());
}
else
{
return false;
}
}
else
{
// Neither Constant is uniform so compare all elements
return !memcmp(ca->get_data_ptr(), cb->get_data_ptr(), size);
}
}
// Equality test for two Reshape nodes.
//
// Returns true when input 0 of both nodes has the same source output and the two
// nodes report equal input order and equal output shape.
//
// The downcasts below are unchecked static_casts, so both `a` and `b` are assumed to
// already be op::Reshape nodes -- NOTE(review): confirm the dispatch table only routes
// Reshape nodes here.
static bool cse_reshape(shared_ptr<Node> a, shared_ptr<Node> b)
{
    NGRAPH_DEBUG << "In cse_reshape for " << a->get_name() << " and " << b->get_name();

    // Borrow raw pointers instead of creating additional shared_ptr copies with
    // static_pointer_cast; `a` and `b` keep both nodes alive for the whole call.
    const op::Reshape* reshape_a = static_cast<ngraph::op::Reshape*>(a.get());
    const op::Reshape* reshape_b = static_cast<ngraph::op::Reshape*>(b.get());

    return (a->input(0).get_source_output() == b->input(0).get_source_output()) &&
           (reshape_a->get_input_order() == reshape_b->get_input_order()) &&
           (reshape_a->get_output_shape() == reshape_b->get_output_shape());
}
// Equality test for two Broadcast nodes.
//
// Returns true when input 0 of both nodes has the same source output and the two
// nodes report equal broadcast axes and equal broadcast shape.
//
// The downcasts below are unchecked static_casts, so both `a` and `b` are assumed to
// already be op::Broadcast nodes -- NOTE(review): confirm the dispatch table only routes
// Broadcast nodes here.
static bool cse_broadcast(shared_ptr<Node> a, shared_ptr<Node> b)
{
    NGRAPH_DEBUG << "In cse_broadcast for " << a->get_name() << " and " << b->get_name();

    // Borrow raw pointers instead of creating additional shared_ptr copies with
    // static_pointer_cast; `a` and `b` keep both nodes alive for the whole call.
    const op::Broadcast* broadcast_a = static_cast<ngraph::op::Broadcast*>(a.get());
    const op::Broadcast* broadcast_b = static_cast<ngraph::op::Broadcast*>(b.get());

    return (a->input(0).get_source_output() == b->input(0).get_source_output()) &&
           (broadcast_a->get_broadcast_axes() == broadcast_b->get_broadcast_axes()) &&
           (broadcast_a->get_broadcast_shape() == broadcast_b->get_broadcast_shape());
}
static bool cse_unarywise(shared_ptr<Node> a, shared_ptr<Node> b)
{
NGRAPH_DEBUG << "In cse_unarywise for " << a->get_name() << " and " << b->get_name();
......@@ -126,8 +146,10 @@ static bool cse_reduction(shared_ptr<Node> a, shared_ptr<Node> b)
{
NGRAPH_DEBUG << "In cse_reduction for " << a->get_name() << " and " << b->get_name();
auto ar_a = static_pointer_cast<op::util::ArithmeticReduction>(a);
auto ar_b = static_pointer_cast<op::util::ArithmeticReduction>(b);
const op::util::ArithmeticReduction* ar_a =
static_cast<op::util::ArithmeticReduction*>(a.get());
const op::util::ArithmeticReduction* ar_b =
static_cast<op::util::ArithmeticReduction*>(b.get());
return ar_a->input(0).get_source_output() == ar_b->input(0).get_source_output() &&
ar_a->get_reduction_axes() == ar_b->get_reduction_axes();
......@@ -137,13 +159,14 @@ static bool cse_one_hot(shared_ptr<Node> a, shared_ptr<Node> b)
{
NGRAPH_DEBUG << "In cse_one_hot for " << a->get_name() << " and " << b->get_name();
auto one_hot_a = static_pointer_cast<ngraph::op::OneHot>(a);
auto one_hot_b = static_pointer_cast<ngraph::op::OneHot>(b);
const op::OneHot* one_hot_a = static_cast<ngraph::op::OneHot*>(a.get());
const op::OneHot* one_hot_b = static_cast<ngraph::op::OneHot*>(b.get());
return (a->input(0).get_source_output() == b->input(0).get_source_output()) &&
(one_hot_a->get_one_hot_axis() == one_hot_b->get_one_hot_axis()) &&
(a->get_shape() == b->get_shape());
}
static unordered_map<type_index, function<bool(shared_ptr<Node>, shared_ptr<Node>)>>
initialize_ops_to_cse_handlers()
{
......@@ -190,10 +213,12 @@ static unordered_map<type_index, function<bool(shared_ptr<Node>, shared_ptr<Node
class NodeKey
{
public:
NodeKey(shared_ptr<Node> n,
NodeKey(const shared_ptr<Node>& n,
unordered_map<type_index, function<bool(shared_ptr<Node>, shared_ptr<Node>)>>&
backend_handlers)
: m_node(n)
, m_node_ref(*n)
, m_ti(TI(m_node_ref))
, m_backend_handlers(backend_handlers)
{
}
......@@ -201,27 +226,18 @@ public:
shared_ptr<Node> get_node() const { return m_node; }
bool operator==(const NodeKey& other) const
{
Node& p_this = *m_node.get();
Node& p_other = *other.get_node().get();
if (TI(p_this) != TI(p_other))
{
return false;
}
if (m_ti == other.m_ti)
{
auto eh = ops_to_cse_handlers.find(TI(p_this));
auto eh = ops_to_cse_handlers.find(m_ti);
if (eh != ops_to_cse_handlers.end())
{
return eh->second(m_node, other.get_node());
return eh->second(m_node, other.m_node);
}
}
{
auto eh = m_backend_handlers.find(TI(p_this));
eh = m_backend_handlers.find(m_ti);
if (eh != m_backend_handlers.end())
{
return eh->second(m_node, other.get_node());
return eh->second(m_node, other.m_node);
}
}
......@@ -230,6 +246,9 @@ public:
private:
shared_ptr<Node> m_node;
// m_node_ref is only to allow getting the type_index in the ctor
Node& m_node_ref;
std::type_index m_ti;
unordered_map<type_index, function<bool(shared_ptr<Node>, shared_ptr<Node>)>>&
m_backend_handlers;
};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment