Commit 13fc556e authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

Add ngraph and std namespaces to c++ files (#2549)

* add ngraph and std namespaces to c++ files

* style
parent ef3378c1
...@@ -39,29 +39,26 @@ ...@@ -39,29 +39,26 @@
#include "ngraph/pattern/matcher.hpp" #include "ngraph/pattern/matcher.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
using namespace std;
using namespace ngraph; using namespace ngraph;
#define TI(x) std::type_index(typeid(x)) #define TI(x) type_index(typeid(x))
extern template ngraph::Shape ngraph::apply_permutation<ngraph::Shape>(ngraph::Shape input, extern template Shape ngraph::apply_permutation<Shape>(Shape input, AxisVector order);
ngraph::AxisVector order);
template <typename T> template <typename T>
static std::shared_ptr<pattern::Matcher> static shared_ptr<pattern::Matcher>
create_binary_matcher(std::shared_ptr<pattern::op::Label> label, create_binary_matcher(shared_ptr<pattern::op::Label> label,
std::shared_ptr<pattern::op::Label> const_label) shared_ptr<pattern::op::Label> const_label)
{ {
auto bcst = auto bcst = make_shared<pattern::op::Skip>(const_label, pattern::has_class<op::Broadcast>());
std::make_shared<pattern::op::Skip>(const_label, pattern::has_class<op::Broadcast>()); auto bcst_label = make_shared<pattern::op::Label>(bcst, nullptr, NodeVector{bcst});
auto bcst_label = std::make_shared<pattern::op::Label>(bcst, nullptr, NodeVector{bcst}); auto matcher = make_shared<pattern::Matcher>(make_shared<T>(label, bcst_label), nullptr);
auto matcher =
std::make_shared<pattern::Matcher>(std::make_shared<T>(label, bcst_label), nullptr);
return matcher; return matcher;
} }
static std::shared_ptr<pattern::op::Label> static shared_ptr<pattern::op::Label> get_broadcast_label(shared_ptr<pattern::Matcher> matcher)
get_broadcast_label(std::shared_ptr<pattern::Matcher> matcher)
{ {
return std::dynamic_pointer_cast<pattern::op::Label>(matcher->get_pattern()->get_argument(1)); return dynamic_pointer_cast<pattern::op::Label>(matcher->get_pattern()->get_argument(1));
} }
//`simplify_concat` identifies slices-concat sequences //`simplify_concat` identifies slices-concat sequences
...@@ -75,23 +72,21 @@ static std::shared_ptr<pattern::op::Label> ...@@ -75,23 +72,21 @@ static std::shared_ptr<pattern::op::Label>
// +-------+ | +----------+ | +-----------+ // +-------+ | +----------+ | +-----------+
// +----+slice(0..n/2)---+ // +----+slice(0..n/2)---+
// +----------+ // +----------+
static bool simplify_concat(std::shared_ptr<Node> n) static bool simplify_concat(shared_ptr<Node> n)
{ {
NGRAPH_DEBUG << "In simplify_concat for " << n->get_name(); NGRAPH_DEBUG << "In simplify_concat for " << n->get_name();
std::shared_ptr<Node> branch_tip; shared_ptr<Node> branch_tip;
auto ltip = std::make_shared<pattern::op::Label>(element::i32, Shape{2, 1}); auto ltip = make_shared<pattern::op::Label>(element::i32, Shape{2, 1});
auto pslice = auto pslice = make_shared<op::Slice>(ltip, Coordinate{0, 0}, Coordinate{2, 1}, Strides{1, 1});
std::make_shared<op::Slice>(ltip, Coordinate{0, 0}, Coordinate{2, 1}, Strides{1, 1});
auto lslice = std::make_shared<pattern::op::Label>(pslice, nullptr, NodeVector{pslice}); auto lslice = make_shared<pattern::op::Label>(pslice, nullptr, NodeVector{pslice});
auto skip_reshape = auto skip_reshape = make_shared<pattern::op::Skip>(lslice, pattern::has_class<op::Reshape>());
std::make_shared<pattern::op::Skip>(lslice, pattern::has_class<op::Reshape>());
auto matcher = std::make_shared<pattern::Matcher>(skip_reshape, nullptr); auto matcher = make_shared<pattern::Matcher>(skip_reshape, nullptr);
Coordinate prev_lower_bounds; Coordinate prev_lower_bounds;
Shape prev_slice_shape; Shape prev_slice_shape;
...@@ -104,7 +99,7 @@ static bool simplify_concat(std::shared_ptr<Node> n) ...@@ -104,7 +99,7 @@ static bool simplify_concat(std::shared_ptr<Node> n)
return false; return false;
} }
auto slice = std::static_pointer_cast<op::Slice>(matcher->get_pattern_map()[lslice]); auto slice = static_pointer_cast<op::Slice>(matcher->get_pattern_map()[lslice]);
if (branch_tip) if (branch_tip)
{ {
if (branch_tip != matcher->get_pattern_map()[ltip]) if (branch_tip != matcher->get_pattern_map()[ltip])
...@@ -153,9 +148,9 @@ static bool simplify_concat(std::shared_ptr<Node> n) ...@@ -153,9 +148,9 @@ static bool simplify_concat(std::shared_ptr<Node> n)
} }
//check that no other node uses slices and reshapes //check that no other node uses slices and reshapes
if (auto rcarg = std::dynamic_pointer_cast<op::Reshape>(carg)) if (auto rcarg = dynamic_pointer_cast<op::Reshape>(carg))
{ {
auto default_shape = ngraph::get_default_order(rcarg->get_argument(0)->get_shape()); auto default_shape = get_default_order(rcarg->get_argument(0)->get_shape());
if (default_shape != rcarg->get_input_order()) if (default_shape != rcarg->get_input_order())
{ {
NGRAPH_DEBUG << carg->get_name() << " reshape also does transposes"; NGRAPH_DEBUG << carg->get_name() << " reshape also does transposes";
...@@ -170,11 +165,11 @@ static bool simplify_concat(std::shared_ptr<Node> n) ...@@ -170,11 +165,11 @@ static bool simplify_concat(std::shared_ptr<Node> n)
} }
} }
auto concat = std::static_pointer_cast<op::Concat>(n); auto concat = static_pointer_cast<op::Concat>(n);
size_t concat_axis = concat->get_concatenation_axis(); size_t concat_axis = concat->get_concatenation_axis();
auto slice_shape = branch_tip->get_users(true).at(0)->get_shape(); auto slice_shape = branch_tip->get_users(true).at(0)->get_shape();
size_t slice_axis = std::numeric_limits<size_t>::max(); size_t slice_axis = numeric_limits<size_t>::max();
auto btip_shape = branch_tip->get_shape(); auto btip_shape = branch_tip->get_shape();
...@@ -191,7 +186,7 @@ static bool simplify_concat(std::shared_ptr<Node> n) ...@@ -191,7 +186,7 @@ static bool simplify_concat(std::shared_ptr<Node> n)
{ {
if (btip_shape[i] != slice_shape[i]) if (btip_shape[i] != slice_shape[i])
{ {
if (slice_axis != std::numeric_limits<size_t>::max()) if (slice_axis != numeric_limits<size_t>::max())
{ {
// multi-axis slice + concat do not cancel // multi-axis slice + concat do not cancel
return false; return false;
...@@ -200,19 +195,18 @@ static bool simplify_concat(std::shared_ptr<Node> n) ...@@ -200,19 +195,18 @@ static bool simplify_concat(std::shared_ptr<Node> n)
} }
} }
if (slice_axis == std::numeric_limits<size_t>::max()) if (slice_axis == numeric_limits<size_t>::max())
{ {
return false; return false;
} }
auto replacement = branch_tip; auto replacement = branch_tip;
if (btip_shape != n->get_shape()) if (btip_shape != n->get_shape())
{ {
auto default_order = ngraph::get_default_order(btip_shape); auto default_order = get_default_order(btip_shape);
if (concat_axis == slice_axis) if (concat_axis == slice_axis)
{ {
// logical reshape only // logical reshape only
replacement = replacement = make_shared<op::Reshape>(branch_tip, default_order, concat->get_shape());
std::make_shared<op::Reshape>(branch_tip, default_order, concat->get_shape());
} }
else else
{ {
...@@ -221,30 +215,29 @@ static bool simplify_concat(std::shared_ptr<Node> n) ...@@ -221,30 +215,29 @@ static bool simplify_concat(std::shared_ptr<Node> n)
if (btip_shape.size() >= transposed_shape.size()) if (btip_shape.size() >= transposed_shape.size())
{ {
AxisVector order = ngraph::get_default_order(btip_shape); AxisVector order = get_default_order(btip_shape);
auto ax = order[slice_axis]; auto ax = order[slice_axis];
order[slice_axis] = order[concat_axis]; order[slice_axis] = order[concat_axis];
order[concat_axis] = ax; order[concat_axis] = ax;
replacement = std::make_shared<op::Reshape>(branch_tip, order, transposed_shape); replacement = make_shared<op::Reshape>(branch_tip, order, transposed_shape);
} }
else if (btip_shape.size() < transposed_shape.size()) else if (btip_shape.size() < transposed_shape.size())
{ {
// intermediate logical reshape // intermediate logical reshape
AxisVector order = ngraph::get_default_order(transposed_shape); AxisVector order = get_default_order(transposed_shape);
auto ax = order[slice_axis]; auto ax = order[slice_axis];
order[slice_axis] = order[concat_axis]; order[slice_axis] = order[concat_axis];
order[concat_axis] = ax; order[concat_axis] = ax;
auto output_shape = ngraph::apply_permutation(transposed_shape, order); auto output_shape = apply_permutation(transposed_shape, order);
auto logical_reshape = auto logical_reshape =
std::make_shared<op::Reshape>(branch_tip, default_order, output_shape); make_shared<op::Reshape>(branch_tip, default_order, output_shape);
// transpose to final concatenated shape // transpose to final concatenated shape
replacement = replacement = make_shared<op::Reshape>(logical_reshape, order, transposed_shape);
std::make_shared<op::Reshape>(logical_reshape, order, transposed_shape);
} }
} }
} }
ngraph::replace_node(n, replacement); replace_node(n, replacement);
return true; return true;
} }
...@@ -255,15 +248,13 @@ static bool simplify_concat(std::shared_ptr<Node> n) ...@@ -255,15 +248,13 @@ static bool simplify_concat(std::shared_ptr<Node> n)
//a * broadcast(0) -> broadcast(0) //a * broadcast(0) -> broadcast(0)
//a * 1 -> a //a * 1 -> a
//a * broadcast(1) -> a //a * broadcast(1) -> a
static bool simplify_multiply(std::shared_ptr<Node> n) static bool simplify_multiply(shared_ptr<Node> n)
{ {
NGRAPH_DEBUG << "In simplify_multiply for " << n->get_name(); NGRAPH_DEBUG << "In simplify_multiply for " << n->get_name();
auto iconst = ngraph::make_zero(element::i32, Shape{}); auto iconst = make_zero(element::i32, Shape{});
auto label = std::make_shared<pattern::op::Label>(iconst); auto label = make_shared<pattern::op::Label>(iconst);
auto const_label_zero = auto const_label_zero = make_shared<pattern::op::Label>(iconst, is_zero, NodeVector{iconst});
std::make_shared<pattern::op::Label>(iconst, ngraph::is_zero, NodeVector{iconst}); auto const_label_one = make_shared<pattern::op::Label>(iconst, is_one, NodeVector{iconst});
auto const_label_one =
std::make_shared<pattern::op::Label>(iconst, ngraph::is_one, NodeVector{iconst});
auto matcher_const_zero = create_binary_matcher<op::Multiply>(label, const_label_zero); auto matcher_const_zero = create_binary_matcher<op::Multiply>(label, const_label_zero);
auto matcher_const_one = create_binary_matcher<op::Multiply>(label, const_label_one); auto matcher_const_one = create_binary_matcher<op::Multiply>(label, const_label_one);
...@@ -273,7 +264,7 @@ static bool simplify_multiply(std::shared_ptr<Node> n) ...@@ -273,7 +264,7 @@ static bool simplify_multiply(std::shared_ptr<Node> n)
auto bcst_label = get_broadcast_label(matcher_const_zero); auto bcst_label = get_broadcast_label(matcher_const_zero);
auto bcst_or_cnst = matcher_const_zero->get_pattern_map()[bcst_label]; auto bcst_or_cnst = matcher_const_zero->get_pattern_map()[bcst_label];
NGRAPH_DEBUG << " Replacing " << n->get_name() << " with " << bcst_or_cnst->get_name(); NGRAPH_DEBUG << " Replacing " << n->get_name() << " with " << bcst_or_cnst->get_name();
ngraph::replace_node(n, bcst_or_cnst); replace_node(n, bcst_or_cnst);
return true; return true;
} }
...@@ -281,7 +272,7 @@ static bool simplify_multiply(std::shared_ptr<Node> n) ...@@ -281,7 +272,7 @@ static bool simplify_multiply(std::shared_ptr<Node> n)
{ {
auto x = matcher_const_one->get_pattern_map()[label]; auto x = matcher_const_one->get_pattern_map()[label];
NGRAPH_DEBUG << " Replacing " << n->get_name() << " with " << x->get_name(); NGRAPH_DEBUG << " Replacing " << n->get_name() << " with " << x->get_name();
ngraph::replace_node(n, x); replace_node(n, x);
return true; return true;
} }
...@@ -293,12 +284,12 @@ static bool simplify_multiply(std::shared_ptr<Node> n) ...@@ -293,12 +284,12 @@ static bool simplify_multiply(std::shared_ptr<Node> n)
// //
//a + 0 -> a //a + 0 -> a
//a + broadcast(0) -> a //a + broadcast(0) -> a
static bool simplify_add(std::shared_ptr<Node> n) static bool simplify_add(shared_ptr<Node> n)
{ {
NGRAPH_DEBUG << "In simplify_add for " << n->get_name(); NGRAPH_DEBUG << "In simplify_add for " << n->get_name();
auto iconst = ngraph::make_zero(element::i32, Shape{}); auto iconst = make_zero(element::i32, Shape{});
auto label = std::make_shared<pattern::op::Label>(iconst); auto label = make_shared<pattern::op::Label>(iconst);
auto const_label = std::make_shared<pattern::op::Label>(iconst, nullptr, NodeVector{iconst}); auto const_label = make_shared<pattern::op::Label>(iconst, nullptr, NodeVector{iconst});
auto matcher = create_binary_matcher<op::Add>(label, const_label); auto matcher = create_binary_matcher<op::Add>(label, const_label);
if (matcher->match(n)) if (matcher->match(n))
...@@ -309,10 +300,10 @@ static bool simplify_add(std::shared_ptr<Node> n) ...@@ -309,10 +300,10 @@ static bool simplify_add(std::shared_ptr<Node> n)
NGRAPH_DEBUG << "Node " << n->get_name() << " matched \" arg + 0 \" \n" NGRAPH_DEBUG << "Node " << n->get_name() << " matched \" arg + 0 \" \n"
<< " arg : " << x->get_name() << " , const : " << cnst->get_name(); << " arg : " << x->get_name() << " , const : " << cnst->get_name();
if (ngraph::is_zero(cnst)) if (is_zero(cnst))
{ {
NGRAPH_DEBUG << " Replacing " << n->get_name() << " with " << x->get_name(); NGRAPH_DEBUG << " Replacing " << n->get_name() << " with " << x->get_name();
ngraph::replace_node(n, x); replace_node(n, x);
return true; return true;
} }
else else
...@@ -324,16 +315,16 @@ static bool simplify_add(std::shared_ptr<Node> n) ...@@ -324,16 +315,16 @@ static bool simplify_add(std::shared_ptr<Node> n)
} }
//`simplify_log` optimizes `log(exp(x)/y)` into `x - log(y)` //`simplify_log` optimizes `log(exp(x)/y)` into `x - log(y)`
static bool simplify_log(std::shared_ptr<Node> n) static bool simplify_log(shared_ptr<Node> n)
{ {
if (auto div = std::dynamic_pointer_cast<op::Divide>(n->get_argument(0))) if (auto div = dynamic_pointer_cast<op::Divide>(n->get_argument(0)))
{ {
if (auto exp = std::dynamic_pointer_cast<op::Exp>(div->get_argument(0))) if (auto exp = dynamic_pointer_cast<op::Exp>(div->get_argument(0)))
{ {
auto denom = div->get_argument(1); auto denom = div->get_argument(1);
auto diff = std::make_shared<op::Subtract>(exp->get_argument(0), auto diff =
std::make_shared<op::Log>(denom)); make_shared<op::Subtract>(exp->get_argument(0), make_shared<op::Log>(denom));
ngraph::replace_node(n, diff); replace_node(n, diff);
return true; return true;
} }
} }
...@@ -353,16 +344,15 @@ static size_t reduction_shape_size(const AxisSet& axes, const Shape& shape) ...@@ -353,16 +344,15 @@ static size_t reduction_shape_size(const AxisSet& axes, const Shape& shape)
} }
template <typename T> template <typename T>
static std::shared_ptr<Node> static shared_ptr<Node>
multiply_by(element::Type type, size_t multiplier, std::shared_ptr<op::Constant> cnst) multiply_by(element::Type type, size_t multiplier, shared_ptr<op::Constant> cnst)
{ {
T sum_cnst = static_cast<T>(cnst->get_vector<T>().at(0) * multiplier); T sum_cnst = static_cast<T>(cnst->get_vector<T>().at(0) * multiplier);
return op::Constant::create<T>(type, Shape{}, {sum_cnst}); return op::Constant::create<T>(type, Shape{}, {sum_cnst});
} }
template <typename T> template <typename T>
static std::shared_ptr<Node> static shared_ptr<Node> pow_by(element::Type type, size_t multiplier, shared_ptr<op::Constant> cnst)
pow_by(element::Type type, size_t multiplier, std::shared_ptr<op::Constant> cnst)
{ {
T prod = static_cast<T>(1); T prod = static_cast<T>(1);
T val = cnst->get_vector<T>().at(0); T val = cnst->get_vector<T>().at(0);
...@@ -373,7 +363,7 @@ static std::shared_ptr<Node> ...@@ -373,7 +363,7 @@ static std::shared_ptr<Node>
return op::Constant::create<T>(type, Shape{}, {prod}); return op::Constant::create<T>(type, Shape{}, {prod});
} }
static std::shared_ptr<Node> get_sum_constant(std::shared_ptr<op::Constant> cnst, size_t multiplier) static shared_ptr<Node> get_sum_constant(shared_ptr<op::Constant> cnst, size_t multiplier)
{ {
if (cnst->get_element_type() == element::i32) if (cnst->get_element_type() == element::i32)
{ {
...@@ -395,8 +385,7 @@ static std::shared_ptr<Node> get_sum_constant(std::shared_ptr<op::Constant> cnst ...@@ -395,8 +385,7 @@ static std::shared_ptr<Node> get_sum_constant(std::shared_ptr<op::Constant> cnst
return nullptr; return nullptr;
} }
static std::shared_ptr<Node> get_prod_constant(std::shared_ptr<op::Constant> cnst, static shared_ptr<Node> get_prod_constant(shared_ptr<op::Constant> cnst, size_t multiplier)
size_t multiplier)
{ {
if (cnst->get_element_type() == element::i32) if (cnst->get_element_type() == element::i32)
{ {
...@@ -423,21 +412,20 @@ static std::shared_ptr<Node> get_prod_constant(std::shared_ptr<op::Constant> cns ...@@ -423,21 +412,20 @@ static std::shared_ptr<Node> get_prod_constant(std::shared_ptr<op::Constant> cns
//where constant2's values are equal to scalar_constant * shape_size(reduction_axes) //where constant2's values are equal to scalar_constant * shape_size(reduction_axes)
//product(broadcast(scalar_constant), reduction_axes = ...) -> constant2 (or scalar constant) //product(broadcast(scalar_constant), reduction_axes = ...) -> constant2 (or scalar constant)
//where constant2's values are equal to scalar_constant ^ shape_size(reduction_axes) //where constant2's values are equal to scalar_constant ^ shape_size(reduction_axes)
template <typename T, template <typename T, shared_ptr<Node> (*F)(shared_ptr<op::Constant> cnst, size_t multiplier)>
std::shared_ptr<Node> (*F)(std::shared_ptr<op::Constant> cnst, size_t multiplier)> static bool simplify_reduction(shared_ptr<Node> n)
static bool simplify_reduction(std::shared_ptr<Node> n)
{ {
NGRAPH_DEBUG << "In simplify_reduction for " << n->get_name(); NGRAPH_DEBUG << "In simplify_reduction for " << n->get_name();
auto reduction = std::static_pointer_cast<T>(n); auto reduction = static_pointer_cast<T>(n);
auto broadcast = std::dynamic_pointer_cast<op::Broadcast>(n->get_argument(0)); auto broadcast = dynamic_pointer_cast<op::Broadcast>(n->get_argument(0));
if (!broadcast) if (!broadcast)
{ {
NGRAPH_DEBUG << n->get_name() << " isn't Broadcast"; NGRAPH_DEBUG << n->get_name() << " isn't Broadcast";
return false; return false;
} }
auto cnst = std::dynamic_pointer_cast<op::Constant>(broadcast->get_argument(0)); auto cnst = dynamic_pointer_cast<op::Constant>(broadcast->get_argument(0));
if (!cnst || cnst->get_shape().size() > 0 /*not a scalar*/) if (!cnst || cnst->get_shape().size() > 0 /*not a scalar*/)
{ {
NGRAPH_DEBUG << broadcast->get_argument(0)->get_name() << " isn't a scalar constant"; NGRAPH_DEBUG << broadcast->get_argument(0)->get_name() << " isn't a scalar constant";
...@@ -456,39 +444,35 @@ static bool simplify_reduction(std::shared_ptr<Node> n) ...@@ -456,39 +444,35 @@ static bool simplify_reduction(std::shared_ptr<Node> n)
if (reduction->get_shape().size() > 0) if (reduction->get_shape().size() > 0)
{ {
ngraph::AxisSet axes{}; AxisSet axes{};
for (size_t i = 0; i < reduction->get_shape().size(); i++) for (size_t i = 0; i < reduction->get_shape().size(); i++)
{ {
axes.insert(i); axes.insert(i);
} }
reduction_cnst = reduction_cnst = make_shared<op::Broadcast>(reduction_cnst, reduction->get_shape(), axes);
std::make_shared<op::Broadcast>(reduction_cnst, reduction->get_shape(), axes);
} }
ngraph::replace_node(n, reduction_cnst); replace_node(n, reduction_cnst);
return true; return true;
} }
static std::unordered_map<std::type_index, std::function<bool(std::shared_ptr<Node>)>> static unordered_map<type_index, function<bool(shared_ptr<Node>)>> initialize_ops_to_simplifiers()
initialize_ops_to_simplifiers()
{ {
return std::unordered_map<std::type_index, std::function<bool(std::shared_ptr<Node>)>>( return unordered_map<type_index, function<bool(shared_ptr<Node>)>>(
{{TI(op::Add), simplify_add}, {{TI(op::Add), simplify_add},
{TI(op::Multiply), simplify_multiply}, {TI(op::Multiply), simplify_multiply},
{TI(op::Concat), simplify_concat}, {TI(op::Concat), simplify_concat},
{TI(op::Sum), {TI(op::Sum),
std::function<bool(std::shared_ptr<Node>)>{ function<bool(shared_ptr<Node>)>{simplify_reduction<op::Sum, get_sum_constant>}},
simplify_reduction<op::Sum, get_sum_constant>}},
{TI(op::Product), {TI(op::Product),
std::function<bool(std::shared_ptr<Node>)>{ function<bool(shared_ptr<Node>)>{simplify_reduction<op::Product, get_prod_constant>}},
simplify_reduction<op::Product, get_prod_constant>}},
{TI(op::Log), simplify_log}}); {TI(op::Log), simplify_log}});
} }
static std::unordered_map<std::type_index, std::function<bool(std::shared_ptr<Node>)>> static unordered_map<type_index, function<bool(shared_ptr<Node>)>> ops_to_simplifiers =
ops_to_simplifiers = initialize_ops_to_simplifiers(); initialize_ops_to_simplifiers();
bool ngraph::pass::AlgebraicSimplification::run_on_function(std::shared_ptr<ngraph::Function> f) bool pass::AlgebraicSimplification::run_on_function(shared_ptr<Function> f)
{ {
bool replaced = false; bool replaced = false;
for (auto n : f->get_ordered_ops()) for (auto n : f->get_ordered_ops())
......
...@@ -89,7 +89,7 @@ shared_ptr<op::Constant> make_constant_pad(shared_ptr<op::Constant> constant, ...@@ -89,7 +89,7 @@ shared_ptr<op::Constant> make_constant_pad(shared_ptr<op::Constant> constant,
return make_shared<op::Constant>(constant->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(constant->get_element_type(), out_shape, out_vec);
} }
void ngraph::pass::ConstantFolding::construct_constant_pad() void pass::ConstantFolding::construct_constant_pad()
{ {
auto is_constant = pattern::has_class<op::Constant>(); auto is_constant = pattern::has_class<op::Constant>();
auto constant_label = make_shared<pattern::op::Label>(element::f32, Shape{6}, is_constant); auto constant_label = make_shared<pattern::op::Label>(element::f32, Shape{6}, is_constant);
...@@ -142,7 +142,7 @@ void ngraph::pass::ConstantFolding::construct_constant_pad() ...@@ -142,7 +142,7 @@ void ngraph::pass::ConstantFolding::construct_constant_pad()
this->add_matcher(pad_matcher); this->add_matcher(pad_matcher);
} }
void ngraph::pass::ConstantFolding::construct_constant_reshape() void pass::ConstantFolding::construct_constant_reshape()
{ {
auto constant_label = make_shared<pattern::op::Label>( auto constant_label = make_shared<pattern::op::Label>(
element::f32, Shape{2, 4}, pattern::has_class<op::Constant>()); element::f32, Shape{2, 4}, pattern::has_class<op::Constant>());
...@@ -207,7 +207,7 @@ shared_ptr<op::Constant> make_constant_broadcast(shared_ptr<op::Constant> consta ...@@ -207,7 +207,7 @@ shared_ptr<op::Constant> make_constant_broadcast(shared_ptr<op::Constant> consta
return make_shared<op::Constant>(constant->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(constant->get_element_type(), out_shape, out_vec);
} }
void ngraph::pass::ConstantFolding::construct_constant_broadcast() void pass::ConstantFolding::construct_constant_broadcast()
{ {
auto constant_label = auto constant_label =
make_shared<pattern::op::Label>(element::f32, Shape{2}, pattern::has_class<op::Constant>()); make_shared<pattern::op::Label>(element::f32, Shape{2}, pattern::has_class<op::Constant>());
...@@ -324,7 +324,7 @@ bool is_supported_binary_op(std::shared_ptr<Node> n) ...@@ -324,7 +324,7 @@ bool is_supported_binary_op(std::shared_ptr<Node> n)
std::dynamic_pointer_cast<op::Minimum>(n)); std::dynamic_pointer_cast<op::Minimum>(n));
} }
void ngraph::pass::ConstantFolding::construct_constant_binary() void pass::ConstantFolding::construct_constant_binary()
{ {
auto a = make_shared<pattern::op::Label>( auto a = make_shared<pattern::op::Label>(
element::f32, Shape{2, 4}, pattern::has_class<op::Constant>()); element::f32, Shape{2, 4}, pattern::has_class<op::Constant>());
...@@ -418,7 +418,7 @@ shared_ptr<op::Constant> make_constant_unary(shared_ptr<op::Constant> constant, ...@@ -418,7 +418,7 @@ shared_ptr<op::Constant> make_constant_unary(shared_ptr<op::Constant> constant,
return make_shared<op::Constant>(constant->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(constant->get_element_type(), out_shape, out_vec);
} }
void ngraph::pass::ConstantFolding::construct_constant_unary() void pass::ConstantFolding::construct_constant_unary()
{ {
auto constant_label = make_shared<pattern::op::Label>( auto constant_label = make_shared<pattern::op::Label>(
element::f32, Shape{2, 4}, pattern::has_class<op::Constant>()); element::f32, Shape{2, 4}, pattern::has_class<op::Constant>());
...@@ -493,7 +493,7 @@ shared_ptr<op::Constant> make_constant_dequantize(shared_ptr<op::Constant> const ...@@ -493,7 +493,7 @@ shared_ptr<op::Constant> make_constant_dequantize(shared_ptr<op::Constant> const
return make_shared<op::Constant>(dequant->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(dequant->get_element_type(), out_shape, out_vec);
} }
void ngraph::pass::ConstantFolding::construct_constant_dequantize() void pass::ConstantFolding::construct_constant_dequantize()
{ {
auto constant_label = auto constant_label =
make_shared<pattern::op::Label>(element::u8, Shape{2}, pattern::has_class<op::Constant>()); make_shared<pattern::op::Label>(element::u8, Shape{2}, pattern::has_class<op::Constant>());
...@@ -567,7 +567,7 @@ shared_ptr<op::Constant> make_constant_quantize(shared_ptr<op::Constant> constan ...@@ -567,7 +567,7 @@ shared_ptr<op::Constant> make_constant_quantize(shared_ptr<op::Constant> constan
return make_shared<op::Constant>(quant->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(quant->get_element_type(), out_shape, out_vec);
} }
void ngraph::pass::ConstantFolding::construct_constant_quantize() void pass::ConstantFolding::construct_constant_quantize()
{ {
auto constant_label = auto constant_label =
make_shared<pattern::op::Label>(element::f32, Shape{2}, pattern::has_class<op::Constant>()); make_shared<pattern::op::Label>(element::f32, Shape{2}, pattern::has_class<op::Constant>());
......
...@@ -69,7 +69,7 @@ void pass::CoreFusion::construct_relu() ...@@ -69,7 +69,7 @@ void pass::CoreFusion::construct_relu()
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
auto mzero = m.get_pattern_map()[zero]; auto mzero = m.get_pattern_map()[zero];
if (!ngraph::is_zero(mzero)) if (!is_zero(mzero))
{ {
NGRAPH_DEBUG << "zero constant = " << mzero->get_name() << " not equal to 0\n"; NGRAPH_DEBUG << "zero constant = " << mzero->get_name() << " not equal to 0\n";
return false; return false;
...@@ -77,7 +77,7 @@ void pass::CoreFusion::construct_relu() ...@@ -77,7 +77,7 @@ void pass::CoreFusion::construct_relu()
auto mpattern = m.get_match_root(); auto mpattern = m.get_match_root();
auto cg = shared_ptr<Node>(new op::Relu(pattern_map[val])); auto cg = shared_ptr<Node>(new op::Relu(pattern_map[val]));
ngraph::replace_node(m.get_match_root(), cg); replace_node(m.get_match_root(), cg);
return true; return true;
}; };
...@@ -100,7 +100,7 @@ void pass::CoreFusion::construct_sigmoid() ...@@ -100,7 +100,7 @@ void pass::CoreFusion::construct_sigmoid()
auto divide_1_over_exp = std::make_shared<op::Divide>(skip_broadcast, add_exp); auto divide_1_over_exp = std::make_shared<op::Divide>(skip_broadcast, add_exp);
// Define a call back that needs to called once the DFG matches the pattern // Define a call back that needs to called once the DFG matches the pattern
ngraph::pattern::graph_rewrite_callback callback = [input, constant](pattern::Matcher& m) { pattern::graph_rewrite_callback callback = [input, constant](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_fprop_sigmoid pattern against " NGRAPH_DEBUG << "In a callback for construct_fprop_sigmoid pattern against "
<< m.get_match_root()->get_name(); << m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
...@@ -125,12 +125,11 @@ void pass::CoreFusion::construct_sigmoid() ...@@ -125,12 +125,11 @@ void pass::CoreFusion::construct_sigmoid()
return false; return false;
} }
auto sigmoid_node = std::make_shared<op::Sigmoid>(pattern_map[input]); auto sigmoid_node = std::make_shared<op::Sigmoid>(pattern_map[input]);
ngraph::replace_node(m.get_match_root(), sigmoid_node); replace_node(m.get_match_root(), sigmoid_node);
return true; return true;
}; };
auto m = std::make_shared<ngraph::pattern::Matcher>( auto m = std::make_shared<pattern::Matcher>(divide_1_over_exp, callback, "CoreFusion.Sigmoid");
divide_1_over_exp, callback, "CoreFusion.Sigmoid");
this->add_matcher(m); this->add_matcher(m);
} }
...@@ -159,7 +158,7 @@ void pass::CoreFusion::construct_sigmoid_bprop() ...@@ -159,7 +158,7 @@ void pass::CoreFusion::construct_sigmoid_bprop()
auto negative_2 = std::make_shared<op::Negative>(multiply_2); auto negative_2 = std::make_shared<op::Negative>(multiply_2);
// Define a call back that needs to called once the DFG matches the pattern // Define a call back that needs to called once the DFG matches the pattern
ngraph::pattern::graph_rewrite_callback callback = [input, delta](pattern::Matcher& m) { pattern::graph_rewrite_callback callback = [input, delta](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_bprop_sigmoid pattern against " NGRAPH_DEBUG << "In a callback for construct_bprop_sigmoid pattern against "
<< m.get_match_root()->get_name(); << m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
...@@ -178,12 +177,11 @@ void pass::CoreFusion::construct_sigmoid_bprop() ...@@ -178,12 +177,11 @@ void pass::CoreFusion::construct_sigmoid_bprop()
} }
auto dsigmoid = auto dsigmoid =
std::make_shared<op::SigmoidBackprop>(pattern_map[input], pattern_map[delta]); std::make_shared<op::SigmoidBackprop>(pattern_map[input], pattern_map[delta]);
ngraph::replace_node(m.get_match_root(), dsigmoid); replace_node(m.get_match_root(), dsigmoid);
return true; return true;
}; };
auto m = auto m = std::make_shared<pattern::Matcher>(negative_2, callback, "CoreFusion.SigmoidBprop");
std::make_shared<ngraph::pattern::Matcher>(negative_2, callback, "CoreFusion.SigmoidBprop");
this->add_matcher(m); this->add_matcher(m);
} }
...@@ -212,7 +210,7 @@ void pass::CoreFusion::construct_folded_batch_norm() ...@@ -212,7 +210,7 @@ void pass::CoreFusion::construct_folded_batch_norm()
auto shape_r = Shape{1, 2, 2, 2}; auto shape_r = Shape{1, 2, 2, 2};
auto bn = std::make_shared<op::BatchNormInference>(eps, gamma, beta, pconv, mean, var); auto bn = std::make_shared<op::BatchNormInference>(eps, gamma, beta, pconv, mean, var);
ngraph::pattern::graph_rewrite_callback callback = [input, filters, mean, var, gamma, beta]( pattern::graph_rewrite_callback callback = [input, filters, mean, var, gamma, beta](
pattern::Matcher& m) { pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for folded batch norm against node = " NGRAPH_DEBUG << "In callback for folded batch norm against node = "
<< m.get_match_root()->get_name(); << m.get_match_root()->get_name();
...@@ -258,13 +256,13 @@ void pass::CoreFusion::construct_folded_batch_norm() ...@@ -258,13 +256,13 @@ void pass::CoreFusion::construct_folded_batch_norm()
m_conv->get_data_dilation_strides()); m_conv->get_data_dilation_strides());
auto conv_bias = auto conv_bias =
conv + std::make_shared<op::Broadcast>(new_biases, conv->get_shape(), AxisSet{0, 2, 3}); conv + std::make_shared<op::Broadcast>(new_biases, conv->get_shape(), AxisSet{0, 2, 3});
ngraph::replace_node(m.get_match_root(), conv_bias); replace_node(m.get_match_root(), conv_bias);
return true; return true;
}; };
auto m = std::make_shared<ngraph::pattern::Matcher>(bn, callback, "CoreFusion.FoldedBatchNorm"); auto m = std::make_shared<pattern::Matcher>(bn, callback, "CoreFusion.FoldedBatchNorm");
this->add_matcher(m); this->add_matcher(m);
} }
...@@ -293,7 +291,7 @@ void pass::CoreFusion::construct_conv_affine_folding() ...@@ -293,7 +291,7 @@ void pass::CoreFusion::construct_conv_affine_folding()
auto multiply = std::make_shared<op::Multiply>(conv_label, A_label); auto multiply = std::make_shared<op::Multiply>(conv_label, A_label);
auto add = std::make_shared<op::Add>(multiply, B_label); auto add = std::make_shared<op::Add>(multiply, B_label);
ngraph::pattern::graph_rewrite_callback callback = pattern::graph_rewrite_callback callback =
[input, filters, conv_label, A_label, B_label](pattern::Matcher& m) { [input, filters, conv_label, A_label, B_label](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for conv affine folding against node = " NGRAPH_DEBUG << "In callback for conv affine folding against node = "
<< m.get_match_root()->get_name(); << m.get_match_root()->get_name();
...@@ -345,7 +343,7 @@ void pass::CoreFusion::construct_conv_affine_folding() ...@@ -345,7 +343,7 @@ void pass::CoreFusion::construct_conv_affine_folding()
if (bcast->get_argument(0)->get_shape().size() == 2) if (bcast->get_argument(0)->get_shape().size() == 2)
{ {
Shape bshape{bcast->get_argument(0)->get_shape()[1]}; Shape bshape{bcast->get_argument(0)->get_shape()[1]};
return static_pointer_cast<ngraph::Node>(std::make_shared<op::Reshape>( return static_pointer_cast<Node>(std::make_shared<op::Reshape>(
bcast->get_argument(0), AxisVector{0, 1}, bshape)); bcast->get_argument(0), AxisVector{0, 1}, bshape));
} }
throw ngraph_error("Unexpected shape for bcast input"); throw ngraph_error("Unexpected shape for bcast input");
...@@ -369,14 +367,13 @@ void pass::CoreFusion::construct_conv_affine_folding() ...@@ -369,14 +367,13 @@ void pass::CoreFusion::construct_conv_affine_folding()
conv_m->get_padding_above(), conv_m->get_padding_above(),
conv_m->get_data_dilation_strides()); conv_m->get_data_dilation_strides());
auto convbias_n = conv_n + B_m; auto convbias_n = conv_n + B_m;
ngraph::replace_node(m.get_match_root(), convbias_n); replace_node(m.get_match_root(), convbias_n);
return true; return true;
}; };
auto m = auto m = std::make_shared<pattern::Matcher>(add, callback, "CoreFusion.ConvAffineFolding");
std::make_shared<ngraph::pattern::Matcher>(add, callback, "CoreFusion.ConvAffineFolding");
this->add_matcher(m); this->add_matcher(m);
} }
...@@ -440,7 +437,7 @@ static size_t shape_to_index(Shape shape) ...@@ -440,7 +437,7 @@ static size_t shape_to_index(Shape shape)
} }
} }
void ngraph::pass::CoreFusion::construct_reshape_broadcast() void pass::CoreFusion::construct_reshape_broadcast()
{ {
Shape input_shape{10}; Shape input_shape{10};
auto input = make_shared<pattern::op::Label>(element::f32, input_shape); auto input = make_shared<pattern::op::Label>(element::f32, input_shape);
...@@ -473,7 +470,7 @@ void ngraph::pass::CoreFusion::construct_reshape_broadcast() ...@@ -473,7 +470,7 @@ void ngraph::pass::CoreFusion::construct_reshape_broadcast()
if (d != 1 && d != dim) if (d != 1 && d != dim)
{ {
NGRAPH_DEBUG << "Input is reshaped in a way we can't directly broadcast ( shape = " NGRAPH_DEBUG << "Input is reshaped in a way we can't directly broadcast ( shape = "
<< ngraph::vector_to_string(reshape1_m->get_shape()) << ")"; << vector_to_string(reshape1_m->get_shape()) << ")";
return false; return false;
} }
...@@ -502,7 +499,7 @@ void ngraph::pass::CoreFusion::construct_reshape_broadcast() ...@@ -502,7 +499,7 @@ void ngraph::pass::CoreFusion::construct_reshape_broadcast()
auto new_broadcast = auto new_broadcast =
make_shared<op::Broadcast>(input_m, broadcast_m->get_shape(), new_axes); make_shared<op::Broadcast>(input_m, broadcast_m->get_shape(), new_axes);
ngraph::replace_node(m.get_match_root(), new_broadcast); replace_node(m.get_match_root(), new_broadcast);
return true; return true;
}; };
...@@ -520,7 +517,7 @@ void ngraph::pass::CoreFusion::construct_reshape_broadcast() ...@@ -520,7 +517,7 @@ void ngraph::pass::CoreFusion::construct_reshape_broadcast()
void pass::CoreFusion::construct_optimized_strided_conv() void pass::CoreFusion::construct_optimized_strided_conv()
{ {
Shape win_size_1{1, 1, 1, 1}; Shape win_size_1{1, 1, 1, 1};
auto is_bc = ngraph::pattern::has_class<op::Broadcast>(); auto is_bc = pattern::has_class<op::Broadcast>();
auto data_stride3 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 128, 128}); auto data_stride3 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 128, 128});
auto weights_stride3 = std::make_shared<pattern::op::Label>(element::f32, win_size_1); auto weights_stride3 = std::make_shared<pattern::op::Label>(element::f32, win_size_1);
...@@ -689,7 +686,7 @@ void pass::CoreFusion::construct_optimized_strided_conv() ...@@ -689,7 +686,7 @@ void pass::CoreFusion::construct_optimized_strided_conv()
new_relu_two_convs, sconv->get_argument(1), stride_1, stride_1); new_relu_two_convs, sconv->get_argument(1), stride_1, stride_1);
NGRAPH_DEBUG << "Replacing " << sconv->get_name() << " with " NGRAPH_DEBUG << "Replacing " << sconv->get_name() << " with "
<< sconv_28w1s1->get_name(); << sconv_28w1s1->get_name();
ngraph::replace_node(sconv, sconv_28w1s1); replace_node(sconv, sconv_28w1s1);
} }
return true; return true;
}; };
...@@ -699,7 +696,7 @@ void pass::CoreFusion::construct_optimized_strided_conv() ...@@ -699,7 +696,7 @@ void pass::CoreFusion::construct_optimized_strided_conv()
this->add_matcher(m); this->add_matcher(m);
} }
void ngraph::pass::CoreFusion::construct_reshape_softmax_reshape() void pass::CoreFusion::construct_reshape_softmax_reshape()
{ {
Shape input_shape{10, 20}; Shape input_shape{10, 20};
AxisVector io{1, 0}; AxisVector io{1, 0};
...@@ -738,7 +735,7 @@ void ngraph::pass::CoreFusion::construct_reshape_softmax_reshape() ...@@ -738,7 +735,7 @@ void ngraph::pass::CoreFusion::construct_reshape_softmax_reshape()
} }
auto new_softmax = make_shared<op::Softmax>(input_m, new_axes); auto new_softmax = make_shared<op::Softmax>(input_m, new_axes);
ngraph::replace_node(m.get_match_root(), new_softmax); replace_node(m.get_match_root(), new_softmax);
return true; return true;
}; };
......
...@@ -59,11 +59,12 @@ ...@@ -59,11 +59,12 @@
#include "ngraph/op/tanh.hpp" #include "ngraph/op/tanh.hpp"
#include "ngraph/pattern/matcher.hpp" #include "ngraph/pattern/matcher.hpp"
using namespace std;
using namespace ngraph; using namespace ngraph;
#define TI(x) std::type_index(typeid(x)) #define TI(x) type_index(typeid(x))
static bool cse_constant(std::shared_ptr<Node> a, std::shared_ptr<Node> b) static bool cse_constant(shared_ptr<Node> a, shared_ptr<Node> b)
{ {
NGRAPH_DEBUG << "In cse_constant for " << a->get_name() << " and " << b->get_name(); NGRAPH_DEBUG << "In cse_constant for " << a->get_name() << " and " << b->get_name();
...@@ -72,44 +73,44 @@ static bool cse_constant(std::shared_ptr<Node> a, std::shared_ptr<Node> b) ...@@ -72,44 +73,44 @@ static bool cse_constant(std::shared_ptr<Node> a, std::shared_ptr<Node> b)
return false; return false;
} }
auto ca = std::static_pointer_cast<op::Constant>(a); auto ca = static_pointer_cast<op::Constant>(a);
auto cb = std::static_pointer_cast<op::Constant>(b); auto cb = static_pointer_cast<op::Constant>(b);
size_t size = shape_size(a->get_shape()) * a->get_element_type().size(); size_t size = shape_size(a->get_shape()) * a->get_element_type().size();
return !memcmp(ca->get_data_ptr(), cb->get_data_ptr(), size); return !memcmp(ca->get_data_ptr(), cb->get_data_ptr(), size);
} }
static bool cse_reshape(std::shared_ptr<Node> a, std::shared_ptr<Node> b) static bool cse_reshape(shared_ptr<Node> a, shared_ptr<Node> b)
{ {
NGRAPH_DEBUG << "In cse_reshape for " << a->get_name() << " and " << b->get_name(); NGRAPH_DEBUG << "In cse_reshape for " << a->get_name() << " and " << b->get_name();
auto reshape_a = std::static_pointer_cast<ngraph::op::Reshape>(a); auto reshape_a = static_pointer_cast<ngraph::op::Reshape>(a);
auto reshape_b = std::static_pointer_cast<ngraph::op::Reshape>(b); auto reshape_b = static_pointer_cast<ngraph::op::Reshape>(b);
return (a->get_argument(0) == b->get_argument(0)) && return (a->get_argument(0) == b->get_argument(0)) &&
(reshape_a->get_input_order() == reshape_b->get_input_order()) && (reshape_a->get_input_order() == reshape_b->get_input_order()) &&
(reshape_a->get_output_shape() == reshape_b->get_output_shape()); (reshape_a->get_output_shape() == reshape_b->get_output_shape());
} }
static bool cse_broadcast(std::shared_ptr<Node> a, std::shared_ptr<Node> b) static bool cse_broadcast(shared_ptr<Node> a, shared_ptr<Node> b)
{ {
NGRAPH_DEBUG << "In cse_broadcast for " << a->get_name() << " and " << b->get_name(); NGRAPH_DEBUG << "In cse_broadcast for " << a->get_name() << " and " << b->get_name();
auto broadcast_a = std::static_pointer_cast<ngraph::op::Broadcast>(a); auto broadcast_a = static_pointer_cast<ngraph::op::Broadcast>(a);
auto broadcast_b = std::static_pointer_cast<ngraph::op::Broadcast>(b); auto broadcast_b = static_pointer_cast<ngraph::op::Broadcast>(b);
return (a->get_argument(0) == b->get_argument(0)) && return (a->get_argument(0) == b->get_argument(0)) &&
(broadcast_a->get_broadcast_axes() == broadcast_b->get_broadcast_axes()) && (broadcast_a->get_broadcast_axes() == broadcast_b->get_broadcast_axes()) &&
(broadcast_a->get_broadcast_shape() == broadcast_b->get_broadcast_shape()); (broadcast_a->get_broadcast_shape() == broadcast_b->get_broadcast_shape());
} }
static bool cse_unarywise(std::shared_ptr<Node> a, std::shared_ptr<Node> b) static bool cse_unarywise(shared_ptr<Node> a, shared_ptr<Node> b)
{ {
NGRAPH_DEBUG << "In cse_unarywise for " << a->get_name() << " and " << b->get_name(); NGRAPH_DEBUG << "In cse_unarywise for " << a->get_name() << " and " << b->get_name();
return a->get_argument(0) == b->get_argument(0); return a->get_argument(0) == b->get_argument(0);
} }
static bool cse_binarywise(std::shared_ptr<Node> a, std::shared_ptr<Node> b) static bool cse_binarywise(shared_ptr<Node> a, shared_ptr<Node> b)
{ {
NGRAPH_DEBUG << "In cse_binary for " << a->get_name() << " and " << b->get_name(); NGRAPH_DEBUG << "In cse_binary for " << a->get_name() << " and " << b->get_name();
...@@ -117,23 +118,21 @@ static bool cse_binarywise(std::shared_ptr<Node> a, std::shared_ptr<Node> b) ...@@ -117,23 +118,21 @@ static bool cse_binarywise(std::shared_ptr<Node> a, std::shared_ptr<Node> b)
(a->get_argument(1) == b->get_argument(0) && a->get_argument(0) == b->get_argument(1)); (a->get_argument(1) == b->get_argument(0) && a->get_argument(0) == b->get_argument(1));
} }
static bool cse_reduction(std::shared_ptr<Node> a, std::shared_ptr<Node> b) static bool cse_reduction(shared_ptr<Node> a, shared_ptr<Node> b)
{ {
NGRAPH_DEBUG << "In cse_reduction for " << a->get_name() << " and " << b->get_name(); NGRAPH_DEBUG << "In cse_reduction for " << a->get_name() << " and " << b->get_name();
auto ar_a = std::static_pointer_cast<op::util::ArithmeticReduction>(a); auto ar_a = static_pointer_cast<op::util::ArithmeticReduction>(a);
auto ar_b = std::static_pointer_cast<op::util::ArithmeticReduction>(b); auto ar_b = static_pointer_cast<op::util::ArithmeticReduction>(b);
return ar_a->get_argument(0) == ar_b->get_argument(0) && return ar_a->get_argument(0) == ar_b->get_argument(0) &&
ar_a->get_reduction_axes() == ar_b->get_reduction_axes(); ar_a->get_reduction_axes() == ar_b->get_reduction_axes();
} }
static std::unordered_map<std::type_index, static unordered_map<type_index, function<bool(shared_ptr<Node>, shared_ptr<Node>)>>
std::function<bool(std::shared_ptr<Node>, std::shared_ptr<Node>)>>
initialize_ops_to_cse_handlers() initialize_ops_to_cse_handlers()
{ {
return std::unordered_map<std::type_index, return unordered_map<type_index, function<bool(shared_ptr<Node>, shared_ptr<Node>)>>(
std::function<bool(std::shared_ptr<Node>, std::shared_ptr<Node>)>>(
{{TI(op::Abs), cse_unarywise}, {{TI(op::Abs), cse_unarywise},
{TI(op::Acos), cse_unarywise}, {TI(op::Acos), cse_unarywise},
{TI(op::Asin), cse_unarywise}, {TI(op::Asin), cse_unarywise},
...@@ -168,23 +167,21 @@ static std::unordered_map<std::type_index, ...@@ -168,23 +167,21 @@ static std::unordered_map<std::type_index,
{TI(op::Broadcast), cse_broadcast}}); {TI(op::Broadcast), cse_broadcast}});
} }
static std::unordered_map<std::type_index, static unordered_map<type_index, function<bool(shared_ptr<Node>, shared_ptr<Node>)>>
std::function<bool(std::shared_ptr<Node>, std::shared_ptr<Node>)>>
ops_to_cse_handlers = initialize_ops_to_cse_handlers(); ops_to_cse_handlers = initialize_ops_to_cse_handlers();
class NodeKey class NodeKey
{ {
public: public:
NodeKey(std::shared_ptr<Node> n, NodeKey(shared_ptr<Node> n,
std::unordered_map<std::type_index, unordered_map<type_index, function<bool(shared_ptr<Node>, shared_ptr<Node>)>>&
std::function<bool(std::shared_ptr<Node>, std::shared_ptr<Node>)>>&
backend_handlers) backend_handlers)
: m_node(n) : m_node(n)
, m_backend_handlers(backend_handlers) , m_backend_handlers(backend_handlers)
{ {
} }
std::shared_ptr<Node> get_node() const { return m_node; } shared_ptr<Node> get_node() const { return m_node; }
bool operator==(const NodeKey& other) const bool operator==(const NodeKey& other) const
{ {
Node& p_this = *m_node.get(); Node& p_this = *m_node.get();
...@@ -215,9 +212,8 @@ public: ...@@ -215,9 +212,8 @@ public:
} }
private: private:
std::shared_ptr<Node> m_node; shared_ptr<Node> m_node;
std::unordered_map<std::type_index, unordered_map<type_index, function<bool(shared_ptr<Node>, shared_ptr<Node>)>>&
std::function<bool(std::shared_ptr<Node>, std::shared_ptr<Node>)>>&
m_backend_handlers; m_backend_handlers;
}; };
...@@ -226,15 +222,15 @@ namespace std ...@@ -226,15 +222,15 @@ namespace std
template <> template <>
struct hash<NodeKey> struct hash<NodeKey>
{ {
std::size_t operator()(const NodeKey& k) const size_t operator()(const NodeKey& k) const
{ {
Node& p_this = *k.get_node().get(); Node& p_this = *k.get_node().get();
auto ti = TI(p_this); auto ti = TI(p_this);
std::hash<std::type_index> type_hash_compute{}; hash<type_index> type_hash_compute{};
auto type_hash = type_hash_compute(ti); auto type_hash = type_hash_compute(ti);
std::vector<size_t> arg_ids; vector<size_t> arg_ids;
arg_ids.push_back(type_hash); arg_ids.push_back(type_hash);
...@@ -244,7 +240,7 @@ namespace std ...@@ -244,7 +240,7 @@ namespace std
// specify how to compute hash for each op? // specify how to compute hash for each op?
if (p_this.is_commutative()) if (p_this.is_commutative())
{ {
std::sort(begin(cargs), end(cargs)); sort(begin(cargs), end(cargs));
} }
for (auto arg : cargs) for (auto arg : cargs)
...@@ -258,11 +254,10 @@ namespace std ...@@ -258,11 +254,10 @@ namespace std
}; };
} }
bool ngraph::pass::CommonSubexpressionElimination::run_on_function( bool ngraph::pass::CommonSubexpressionElimination::run_on_function(shared_ptr<ngraph::Function> f)
std::shared_ptr<ngraph::Function> f)
{ {
bool replaced = false; bool replaced = false;
std::unordered_map<NodeKey, std::shared_ptr<Node>> expressions{}; unordered_map<NodeKey, shared_ptr<Node>> expressions{};
for (auto n : f->get_ordered_ops()) for (auto n : f->get_ordered_ops())
{ {
...@@ -279,7 +274,7 @@ bool ngraph::pass::CommonSubexpressionElimination::run_on_function( ...@@ -279,7 +274,7 @@ bool ngraph::pass::CommonSubexpressionElimination::run_on_function(
} }
else else
{ {
expressions.insert(std::make_pair(n_key, n)); expressions.insert(make_pair(n_key, n));
} }
} }
......
...@@ -24,6 +24,9 @@ ...@@ -24,6 +24,9 @@
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/pattern/matcher.hpp" #include "ngraph/pattern/matcher.hpp"
using namespace std;
using namespace ngraph;
// GraphRewrite algorithm: // GraphRewrite algorithm:
// GraphRewrite processes an input graph in an topological order(i.e. args before users) // GraphRewrite processes an input graph in an topological order(i.e. args before users)
// Given the following graph: Abs2 // Given the following graph: Abs2
...@@ -56,16 +59,16 @@ ...@@ -56,16 +59,16 @@
// c) there's no linear order of fusions which will give // c) there's no linear order of fusions which will give
// the correct final fusion. i.e. the same fusion needs to occur before and after some other fusion // the correct final fusion. i.e. the same fusion needs to occur before and after some other fusion
bool ngraph::pass::GraphRewrite::run_on_function(std::shared_ptr<ngraph::Function> f) bool pass::GraphRewrite::run_on_function(shared_ptr<Function> f)
{ {
bool rewritten = false; bool rewritten = false;
const size_t NUM_TRIES = 10; const size_t NUM_TRIES = 10;
size_t tries = NUM_TRIES; size_t tries = NUM_TRIES;
std::vector<std::shared_ptr<pattern::Matcher>> original_matchers{m_matchers}; vector<shared_ptr<pattern::Matcher>> original_matchers{m_matchers};
do do
{ {
rewritten = false; rewritten = false;
std::vector<std::shared_ptr<pattern::Matcher>> matchers{m_matchers}; vector<shared_ptr<pattern::Matcher>> matchers{m_matchers};
m_matchers.clear(); m_matchers.clear();
for (auto node : f->get_ordered_ops()) for (auto node : f->get_ordered_ops())
{ {
...@@ -92,31 +95,31 @@ bool ngraph::pass::GraphRewrite::run_on_function(std::shared_ptr<ngraph::Functio ...@@ -92,31 +95,31 @@ bool ngraph::pass::GraphRewrite::run_on_function(std::shared_ptr<ngraph::Functio
return (NUM_TRIES - tries) > 1; //this means a graph was transformed return (NUM_TRIES - tries) > 1; //this means a graph was transformed
} }
static const std::vector<std::regex> initialize_fusion_regexes() static const vector<regex> initialize_fusion_regexes()
{ {
const char* cnsf = std::getenv("NGRAPH_DISABLED_FUSIONS"); const char* cnsf = getenv("NGRAPH_DISABLED_FUSIONS");
std::vector<std::regex> regexes; vector<regex> regexes;
if (cnsf) if (cnsf)
{ {
const std::string nsf = cnsf; const string nsf = cnsf;
const auto sregexes = ngraph::split(nsf, ';'); const auto sregexes = split(nsf, ';');
std::transform(sregexes.begin(), transform(sregexes.begin(),
sregexes.end(), sregexes.end(),
std::back_inserter(regexes), back_inserter(regexes),
[](const std::string& c) -> std::regex { return std::regex(c); }); [](const string& c) -> regex { return regex(c); });
} }
return regexes; return regexes;
} }
bool ngraph::pass::GraphRewrite::is_enabled(std::shared_ptr<pattern::Matcher> m) bool pass::GraphRewrite::is_enabled(shared_ptr<pattern::Matcher> m)
{ {
//note, regexes are static to avoid re-initialization //note, regexes are static to avoid re-initialization
static const auto regexes = initialize_fusion_regexes(); static const auto regexes = initialize_fusion_regexes();
for (const auto& regex : regexes) for (const auto& regex : regexes)
{ {
if (std::regex_match(m->get_name(), regex)) if (regex_match(m->get_name(), regex))
{ {
NGRAPH_DEBUG << "Disabling matcher " << m->get_name(); NGRAPH_DEBUG << "Disabling matcher " << m->get_name();
return false; return false;
...@@ -126,7 +129,7 @@ bool ngraph::pass::GraphRewrite::is_enabled(std::shared_ptr<pattern::Matcher> m) ...@@ -126,7 +129,7 @@ bool ngraph::pass::GraphRewrite::is_enabled(std::shared_ptr<pattern::Matcher> m)
return true; return true;
} }
void ngraph::pass::GraphRewrite::add_matcher(std::shared_ptr<pattern::Matcher> m) void pass::GraphRewrite::add_matcher(shared_ptr<pattern::Matcher> m)
{ {
if (is_enabled(m)) if (is_enabled(m))
{ {
...@@ -134,7 +137,7 @@ void ngraph::pass::GraphRewrite::add_matcher(std::shared_ptr<pattern::Matcher> m ...@@ -134,7 +137,7 @@ void ngraph::pass::GraphRewrite::add_matcher(std::shared_ptr<pattern::Matcher> m
} }
} }
bool ngraph::pass::RecurrentGraphRewrite::run_on_function(std::shared_ptr<ngraph::Function> f) bool pass::RecurrentGraphRewrite::run_on_function(shared_ptr<Function> f)
{ {
bool changed = false; bool changed = false;
size_t i = 0; size_t i = 0;
......
...@@ -30,27 +30,28 @@ ...@@ -30,27 +30,28 @@
#include "ngraph/op/sum.hpp" #include "ngraph/op/sum.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
#define TI(x) std::type_index(typeid(x)) using namespace std;
using namespace ngraph;
#define HANDLER_DECL(x) static bool x(const std::shared_ptr<ngraph::Node>& node) #define TI(x) type_index(typeid(x))
#define HANDLER_DECL(x) static bool x(const shared_ptr<Node>& node)
HANDLER_DECL(replace_broadcast_like) HANDLER_DECL(replace_broadcast_like)
{ {
// Replace a broadcast like with the broadcast to eliminate the pseudo-dependency on the "like" argument // Replace a broadcast like with the broadcast to eliminate the pseudo-dependency on the "like" argument
auto broadcast_like = std::static_pointer_cast<ngraph::op::BroadcastLike>(node); auto broadcast_like = static_pointer_cast<op::BroadcastLike>(node);
ngraph::replace_node( replace_node(node,
node, make_shared<op::Broadcast>(broadcast_like->get_argument(0),
std::make_shared<ngraph::op::Broadcast>(broadcast_like->get_argument(0), broadcast_like->get_broadcast_shape(),
broadcast_like->get_broadcast_shape(), broadcast_like->get_broadcast_axes()));
broadcast_like->get_broadcast_axes()));
return true; return true;
} }
static const std::unordered_map<std::type_index, static const unordered_map<type_index, function<bool(const shared_ptr<Node>&)>> dispatcher{
std::function<bool(const std::shared_ptr<ngraph::Node>&)>> {TI(op::BroadcastLike), &replace_broadcast_like}};
dispatcher{{TI(ngraph::op::BroadcastLike), &replace_broadcast_like}};
bool ngraph::pass::LikeReplacement::run_on_function(std::shared_ptr<ngraph::Function> function) bool pass::LikeReplacement::run_on_function(shared_ptr<Function> function)
{ {
bool clobbered = false; bool clobbered = false;
...@@ -66,10 +67,10 @@ bool ngraph::pass::LikeReplacement::run_on_function(std::shared_ptr<ngraph::Func ...@@ -66,10 +67,10 @@ bool ngraph::pass::LikeReplacement::run_on_function(std::shared_ptr<ngraph::Func
// Here we're checking on a common base class of a family of template classes, // Here we're checking on a common base class of a family of template classes,
// which is more than type info can handle. // which is more than type info can handle.
auto sclb = std::dynamic_pointer_cast<ngraph::op::ScalarConstantLikeBase>(n); auto sclb = dynamic_pointer_cast<op::ScalarConstantLikeBase>(n);
if (sclb != nullptr) if (sclb != nullptr)
{ {
ngraph::replace_node(sclb, sclb->as_constant()); replace_node(sclb, sclb->as_constant());
clobbered = true; clobbered = true;
} }
} }
......
...@@ -33,7 +33,7 @@ ...@@ -33,7 +33,7 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
bool pass::Liveness::run_on_function(shared_ptr<ngraph::Function> function) bool pass::Liveness::run_on_function(shared_ptr<Function> function)
{ {
list<shared_ptr<Node>> ops = function->get_ordered_ops(); list<shared_ptr<Node>> ops = function->get_ordered_ops();
......
...@@ -35,7 +35,7 @@ ...@@ -35,7 +35,7 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
ngraph::pass::Manager::Manager() pass::Manager::Manager()
{ {
static const auto nevt = std::getenv("NGRAPH_ENABLE_VISUALIZE_TRACING"); static const auto nevt = std::getenv("NGRAPH_ENABLE_VISUALIZE_TRACING");
if (nevt) if (nevt)
...@@ -49,15 +49,15 @@ ngraph::pass::Manager::Manager() ...@@ -49,15 +49,15 @@ ngraph::pass::Manager::Manager()
} }
} }
ngraph::pass::Manager::~Manager() pass::Manager::~Manager()
{ {
} }
void ngraph::pass::Manager::initialize_default_passes() void pass::Manager::initialize_default_passes()
{ {
} }
void ngraph::pass::Manager::run_passes(shared_ptr<Function> func, bool transitive) void pass::Manager::run_passes(shared_ptr<Function> func, bool transitive)
{ {
bool profile_enabled = getenv("NGRAPH_PROFILE_PASS_ENABLE") != nullptr; bool profile_enabled = getenv("NGRAPH_PROFILE_PASS_ENABLE") != nullptr;
...@@ -167,7 +167,7 @@ void ngraph::pass::Manager::run_passes(shared_ptr<Function> func, bool transitiv ...@@ -167,7 +167,7 @@ void ngraph::pass::Manager::run_passes(shared_ptr<Function> func, bool transitiv
} }
} }
ngraph::pass::ManagerState& ngraph::pass::Manager::get_state() pass::ManagerState& pass::Manager::get_state()
{ {
return m_state; return m_state;
} }
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
const vector<shared_ptr<Function>>& ngraph::pass::ManagerState::get_functions() const vector<shared_ptr<Function>>& pass::ManagerState::get_functions()
{ {
return m_function_list; return m_function_list;
} }
...@@ -40,7 +40,7 @@ pass::MemoryLayout::MemoryLayout(size_t alignment, bool disable_memory_sharing) ...@@ -40,7 +40,7 @@ pass::MemoryLayout::MemoryLayout(size_t alignment, bool disable_memory_sharing)
} }
} }
bool pass::MemoryLayout::run_on_function(shared_ptr<ngraph::Function> function) bool pass::MemoryLayout::run_on_function(shared_ptr<Function> function)
{ {
MemoryManager mm(m_alignment, m_disable_memory_sharing); MemoryManager mm(m_alignment, m_disable_memory_sharing);
for (shared_ptr<Node> node : function->get_ordered_ops()) for (shared_ptr<Node> node : function->get_ordered_ops())
......
...@@ -34,7 +34,7 @@ pass::MemoryVisualize::MemoryVisualize(const string& filename) ...@@ -34,7 +34,7 @@ pass::MemoryVisualize::MemoryVisualize(const string& filename)
{ {
} }
bool pass::MemoryVisualize::run_on_module(vector<shared_ptr<ngraph::Function>>& functions) bool pass::MemoryVisualize::run_on_module(vector<shared_ptr<Function>>& functions)
{ {
ofstream file(m_filename); ofstream file(m_filename);
{ {
......
...@@ -30,94 +30,93 @@ ...@@ -30,94 +30,93 @@
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
#include "nop_elimination.hpp" #include "nop_elimination.hpp"
#define TI(x) std::type_index(typeid(x)) using namespace std;
using namespace ngraph;
#define HANDLER_DECL(x) static bool x(const std::shared_ptr<ngraph::Node>& node) #define TI(x) std::type_index(typeid(x))
HANDLER_DECL(eliminate_pad) static bool eliminate_pad(const std::shared_ptr<Node>& node)
{ {
auto pad = std::static_pointer_cast<ngraph::op::Pad>(node); auto pad = std::static_pointer_cast<op::Pad>(node);
if (pad->get_input_shape(0) == pad->get_output_shape(0)) if (pad->get_input_shape(0) == pad->get_output_shape(0))
{ {
ngraph::replace_node(node, node->get_argument(0)); replace_node(node, node->get_argument(0));
return true; return true;
} }
return false; return false;
} }
HANDLER_DECL(eliminate_sum) static bool eliminate_sum(const std::shared_ptr<Node>& node)
{ {
auto sum = std::static_pointer_cast<ngraph::op::Sum>(node); auto sum = std::static_pointer_cast<op::Sum>(node);
if (sum->get_reduction_axes().empty()) if (sum->get_reduction_axes().empty())
{ {
ngraph::replace_node(node, node->get_argument(0)); replace_node(node, node->get_argument(0));
return true; return true;
} }
return false; return false;
} }
HANDLER_DECL(eliminate_convert) static bool eliminate_convert(const std::shared_ptr<Node>& node)
{ {
auto convert = std::static_pointer_cast<ngraph::op::Convert>(node); auto convert = std::static_pointer_cast<op::Convert>(node);
if (convert->get_convert_element_type() == convert->get_argument(0)->get_element_type()) if (convert->get_convert_element_type() == convert->get_argument(0)->get_element_type())
{ {
ngraph::replace_node(node, node->get_argument(0)); replace_node(node, node->get_argument(0));
return true; return true;
} }
return false; return false;
} }
HANDLER_DECL(eliminate_slice) static bool eliminate_slice(const std::shared_ptr<Node>& node)
{ {
auto slice = std::static_pointer_cast<ngraph::op::Slice>(node); auto slice = std::static_pointer_cast<op::Slice>(node);
if (slice->get_input_shape(0) == slice->get_output_shape(0)) if (slice->get_input_shape(0) == slice->get_output_shape(0))
{ {
ngraph::replace_node(node, node->get_argument(0)); replace_node(node, node->get_argument(0));
return true; return true;
} }
return false; return false;
} }
HANDLER_DECL(replace_broadcast_like) static bool replace_broadcast_like(const std::shared_ptr<Node>& node)
{ {
// Replace a broadcast like with the broadcast to eliminate the pseudo-dependency on the "like" argument // Replace a broadcast like with the broadcast to eliminate the pseudo-dependency on the "like" argument
auto broadcast_like = std::static_pointer_cast<ngraph::op::BroadcastLike>(node); auto broadcast_like = std::static_pointer_cast<op::BroadcastLike>(node);
ngraph::replace_node( replace_node(node,
node, std::make_shared<op::Broadcast>(broadcast_like->get_argument(0),
std::make_shared<ngraph::op::Broadcast>(broadcast_like->get_argument(0), broadcast_like->get_broadcast_shape(),
broadcast_like->get_broadcast_shape(), broadcast_like->get_broadcast_axes()));
broadcast_like->get_broadcast_axes()));
return true; return true;
} }
HANDLER_DECL(eliminate_broadcast) static bool eliminate_broadcast(const std::shared_ptr<Node>& node)
{ {
auto broadcast = std::static_pointer_cast<ngraph::op::Broadcast>(node); auto broadcast = std::static_pointer_cast<op::Broadcast>(node);
if (broadcast->get_input_shape(0) == broadcast->get_output_shape(0)) if (broadcast->get_input_shape(0) == broadcast->get_output_shape(0))
{ {
ngraph::replace_node(node, node->get_argument(0)); replace_node(node, node->get_argument(0));
return true; return true;
} }
return false; return false;
} }
HANDLER_DECL(eliminate_stop_gradient) static bool eliminate_stop_gradient(const std::shared_ptr<Node>& node)
{ {
ngraph::replace_node(node, node->get_argument(0)); replace_node(node, node->get_argument(0));
return true; return true;
} }
static const std::unordered_map<std::type_index, static const std::unordered_map<std::type_index, std::function<bool(const std::shared_ptr<Node>&)>>
std::function<bool(const std::shared_ptr<ngraph::Node>&)>> dispatcher{{TI(op::Pad), &eliminate_pad},
dispatcher{{TI(ngraph::op::Pad), &eliminate_pad}, {TI(op::Sum), &eliminate_sum},
{TI(ngraph::op::Sum), &eliminate_sum}, {TI(op::Convert), &eliminate_convert},
{TI(ngraph::op::Convert), &eliminate_convert}, {TI(op::Slice), &eliminate_slice},
{TI(ngraph::op::Slice), &eliminate_slice}, {TI(op::StopGradient), &eliminate_stop_gradient},
{TI(ngraph::op::StopGradient), &eliminate_stop_gradient}, {TI(op::BroadcastLike), &replace_broadcast_like},
{TI(ngraph::op::BroadcastLike), &replace_broadcast_like}, {TI(op::Broadcast), &eliminate_broadcast}};
{TI(ngraph::op::Broadcast), &eliminate_broadcast}};
bool pass::NopElimination::run_on_function(std::shared_ptr<Function> function)
bool ngraph::pass::NopElimination::run_on_function(std::shared_ptr<ngraph::Function> function)
{ {
bool clobbered = false; bool clobbered = false;
...@@ -133,10 +132,10 @@ bool ngraph::pass::NopElimination::run_on_function(std::shared_ptr<ngraph::Funct ...@@ -133,10 +132,10 @@ bool ngraph::pass::NopElimination::run_on_function(std::shared_ptr<ngraph::Funct
// Here we're checking on a common base class of a family of template classes, // Here we're checking on a common base class of a family of template classes,
// which is more than type info can handle. // which is more than type info can handle.
auto sclb = std::dynamic_pointer_cast<ngraph::op::ScalarConstantLikeBase>(n); auto sclb = std::dynamic_pointer_cast<op::ScalarConstantLikeBase>(n);
if (sclb != nullptr) if (sclb != nullptr)
{ {
ngraph::replace_node(sclb, sclb->as_constant()); replace_node(sclb, sclb->as_constant());
clobbered = true; clobbered = true;
} }
} }
......
...@@ -17,12 +17,15 @@ ...@@ -17,12 +17,15 @@
#include "ngraph/pass/pass.hpp" #include "ngraph/pass/pass.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
ngraph::pass::ManagerState& ngraph::pass::PassBase::get_state() using namespace std;
using namespace ngraph;
pass::ManagerState& pass::PassBase::get_state()
{ {
return *m_state; return *m_state;
} }
void ngraph::pass::PassBase::set_state(ManagerState& state) void pass::PassBase::set_state(ManagerState& state)
{ {
m_state = &state; m_state = &state;
} }
...@@ -19,10 +19,11 @@ ...@@ -19,10 +19,11 @@
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
using namespace std;
using namespace ngraph; using namespace ngraph;
// TODO: Add file-based configuration support // TODO: Add file-based configuration support
ngraph::pass::PassConfig::PassConfig(ngraph::pass::CompilationMode mode) pass::PassConfig::PassConfig(pass::CompilationMode mode)
: m_compilation_mode(mode) : m_compilation_mode(mode)
{ {
/** /**
...@@ -32,15 +33,15 @@ ngraph::pass::PassConfig::PassConfig(ngraph::pass::CompilationMode mode) ...@@ -32,15 +33,15 @@ ngraph::pass::PassConfig::PassConfig(ngraph::pass::CompilationMode mode)
* E.g., NGRAPH_PASS_ENABLES="CoreFusion:0;LikeReplacement:1;CPUCollapseDims" would * E.g., NGRAPH_PASS_ENABLES="CoreFusion:0;LikeReplacement:1;CPUCollapseDims" would
* set disables on CoreFusion and enables on LikeReplacement and CPUCollapseDims * set disables on CoreFusion and enables on LikeReplacement and CPUCollapseDims
**/ **/
const char* env_str = std::getenv("NGRAPH_PASS_ENABLES"); const char* env_str = getenv("NGRAPH_PASS_ENABLES");
if (env_str) if (env_str)
{ {
std::stringstream ss; stringstream ss;
ss << env_str; ss << env_str;
while (ss.good()) while (ss.good())
{ {
std::string substr; string substr;
std::getline(ss, substr, ';'); getline(ss, substr, ';');
auto split_str = split(substr, ':', false); auto split_str = split(substr, ':', false);
switch (split_str.size()) switch (split_str.size())
{ {
...@@ -58,15 +59,15 @@ ngraph::pass::PassConfig::PassConfig(ngraph::pass::CompilationMode mode) ...@@ -58,15 +59,15 @@ ngraph::pass::PassConfig::PassConfig(ngraph::pass::CompilationMode mode)
* would set false on "OptimizeForMemory", true on "MemoryAssignment::ReuseMemory" and true on * would set false on "OptimizeForMemory", true on "MemoryAssignment::ReuseMemory" and true on
* "UseDefaultLayouts" * "UseDefaultLayouts"
**/ **/
env_str = std::getenv("NGRAPH_PASS_ATTRIBUTES"); env_str = getenv("NGRAPH_PASS_ATTRIBUTES");
if (env_str) if (env_str)
{ {
std::stringstream ss; stringstream ss;
ss << env_str; ss << env_str;
while (ss.good()) while (ss.good())
{ {
std::string substr; string substr;
std::getline(ss, substr, ';'); getline(ss, substr, ';');
auto split_str = split(substr, '=', false); auto split_str = split(substr, '=', false);
switch (split_str.size()) switch (split_str.size())
{ {
...@@ -80,12 +81,12 @@ ngraph::pass::PassConfig::PassConfig(ngraph::pass::CompilationMode mode) ...@@ -80,12 +81,12 @@ ngraph::pass::PassConfig::PassConfig(ngraph::pass::CompilationMode mode)
} }
} }
void ngraph::pass::PassConfig::set_pass_enable(std::string name, bool enable) void pass::PassConfig::set_pass_enable(string name, bool enable)
{ {
m_pass_enables[name] = enable; m_pass_enables[name] = enable;
} }
bool ngraph::pass::PassConfig::get_pass_enable(std::string name) bool pass::PassConfig::get_pass_enable(string name)
{ {
if (m_pass_enables.find(name) == m_pass_enables.end()) if (m_pass_enables.find(name) == m_pass_enables.end())
{ {
...@@ -94,12 +95,12 @@ bool ngraph::pass::PassConfig::get_pass_enable(std::string name) ...@@ -94,12 +95,12 @@ bool ngraph::pass::PassConfig::get_pass_enable(std::string name)
return m_pass_enables[name]; return m_pass_enables[name];
} }
void ngraph::pass::PassConfig::set_pass_attribute(std::string name, bool enable) void pass::PassConfig::set_pass_attribute(string name, bool enable)
{ {
m_pass_attributes[name] = enable; m_pass_attributes[name] = enable;
} }
bool ngraph::pass::PassConfig::get_pass_attribute(std::string name) bool pass::PassConfig::get_pass_attribute(string name)
{ {
if (m_pass_attributes.find(name) == m_pass_attributes.end()) if (m_pass_attributes.find(name) == m_pass_attributes.end())
{ {
......
...@@ -24,14 +24,17 @@ ...@@ -24,14 +24,17 @@
#include "ngraph/pattern/op/any_of.hpp" #include "ngraph/pattern/op/any_of.hpp"
#include "ngraph/pattern/op/label.hpp" #include "ngraph/pattern/op/label.hpp"
ngraph::pass::PrefixReshapeElimination::PrefixReshapeElimination() using namespace std;
using namespace ngraph;
pass::PrefixReshapeElimination::PrefixReshapeElimination()
{ {
auto src_op = std::make_shared<pattern::op::Label>( auto src_op = make_shared<pattern::op::Label>(
element::i8, Shape{}, [](std::shared_ptr<Node>) { return true; }); element::i8, Shape{}, [](shared_ptr<Node>) { return true; });
auto reshape_op = std::make_shared<pattern::op::Any>( auto reshape_op = make_shared<pattern::op::Any>(
element::i8, element::i8,
Shape{}, Shape{},
[](std::shared_ptr<Node> node) { [](shared_ptr<Node> node) {
op::Reshape* reshape = dynamic_cast<op::Reshape*>(node.get()); op::Reshape* reshape = dynamic_cast<op::Reshape*>(node.get());
if (!reshape) if (!reshape)
{ {
...@@ -46,14 +49,14 @@ ngraph::pass::PrefixReshapeElimination::PrefixReshapeElimination() ...@@ -46,14 +49,14 @@ ngraph::pass::PrefixReshapeElimination::PrefixReshapeElimination()
// Make sure that logical dimension sizes match. // Make sure that logical dimension sizes match.
const Shape& src_shape = reshape->get_input_shape(0); const Shape& src_shape = reshape->get_input_shape(0);
for (std::size_t idx = 0; idx < reshape->get_output_shape().size(); ++idx) for (size_t idx = 0; idx < reshape->get_output_shape().size(); ++idx)
{ {
std::size_t src_size = 1; size_t src_size = 1;
if (idx < src_shape.size()) if (idx < src_shape.size())
{ {
src_size = src_shape.at(src_shape.size() - 1 - idx); src_size = src_shape.at(src_shape.size() - 1 - idx);
} }
std::size_t dest_size = size_t dest_size =
reshape->get_output_shape().at(reshape->get_output_shape().size() - 1 - idx); reshape->get_output_shape().at(reshape->get_output_shape().size() - 1 - idx);
if (dest_size != src_size) if (dest_size != src_size)
{ {
...@@ -64,10 +67,10 @@ ngraph::pass::PrefixReshapeElimination::PrefixReshapeElimination() ...@@ -64,10 +67,10 @@ ngraph::pass::PrefixReshapeElimination::PrefixReshapeElimination()
return true; return true;
}, },
NodeVector{src_op}); NodeVector{src_op});
auto target_op = std::make_shared<pattern::op::AnyOf>( auto target_op = make_shared<pattern::op::AnyOf>(
element::i8, element::i8,
Shape{}, Shape{},
[](std::shared_ptr<Node> node) { [](shared_ptr<Node> node) {
return pattern::has_class<op::Reshape>()(node) || return pattern::has_class<op::Reshape>()(node) ||
pattern::has_class<op::util::UnaryElementwiseArithmetic>()(node) || pattern::has_class<op::util::UnaryElementwiseArithmetic>()(node) ||
pattern::has_class<op::util::BinaryElementwiseArithmetic>()(node); pattern::has_class<op::util::BinaryElementwiseArithmetic>()(node);
...@@ -78,5 +81,5 @@ ngraph::pass::PrefixReshapeElimination::PrefixReshapeElimination() ...@@ -78,5 +81,5 @@ ngraph::pass::PrefixReshapeElimination::PrefixReshapeElimination()
replace_node(m.get_matched_nodes().at(1), m.get_matched_nodes().at(2)); replace_node(m.get_matched_nodes().at(1), m.get_matched_nodes().at(2));
return true; return true;
}; };
add_matcher(std::make_shared<pattern::Matcher>(target_op, callback)); add_matcher(make_shared<pattern::Matcher>(target_op, callback));
} }
...@@ -22,15 +22,16 @@ ...@@ -22,15 +22,16 @@
#include "ngraph/op/util/op_annotations.hpp" #include "ngraph/op/util/op_annotations.hpp"
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp" #include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
using namespace std;
using namespace ngraph; using namespace ngraph;
bool ngraph::pass::PropagateCacheability::run_on_function(std::shared_ptr<Function> function) bool pass::PropagateCacheability::run_on_function(shared_ptr<Function> function)
{ {
for (auto& node : function->get_ordered_ops()) for (auto& node : function->get_ordered_ops())
{ {
if (node->is_op()) if (node->is_op())
{ {
auto op = std::static_pointer_cast<op::Op>(node); auto op = static_pointer_cast<op::Op>(node);
NGRAPH_DEBUG << "propagate cacheability: node is " << node->get_name(); NGRAPH_DEBUG << "propagate cacheability: node is " << node->get_name();
auto op_annotations = op->get_op_annotations(); auto op_annotations = op->get_op_annotations();
if (!op_annotations) if (!op_annotations)
...@@ -41,7 +42,7 @@ bool ngraph::pass::PropagateCacheability::run_on_function(std::shared_ptr<Functi ...@@ -41,7 +42,7 @@ bool ngraph::pass::PropagateCacheability::run_on_function(std::shared_ptr<Functi
} }
if (node->is_parameter()) if (node->is_parameter())
{ {
auto parameter = std::static_pointer_cast<op::Parameter>(node); auto parameter = static_pointer_cast<op::Parameter>(node);
op_annotations->set_cacheable(parameter->get_cacheable()); op_annotations->set_cacheable(parameter->get_cacheable());
NGRAPH_DEBUG << "propagate cacheability: cacheability is " NGRAPH_DEBUG << "propagate cacheability: cacheability is "
<< parameter->get_cacheable(); << parameter->get_cacheable();
...@@ -54,7 +55,7 @@ bool ngraph::pass::PropagateCacheability::run_on_function(std::shared_ptr<Functi ...@@ -54,7 +55,7 @@ bool ngraph::pass::PropagateCacheability::run_on_function(std::shared_ptr<Functi
NGRAPH_DEBUG << "propagate cacheability: arg is " << arg->get_name(); NGRAPH_DEBUG << "propagate cacheability: arg is " << arg->get_name();
if (arg->is_op()) if (arg->is_op())
{ {
auto arg_op = std::static_pointer_cast<op::Op>(arg); auto arg_op = static_pointer_cast<op::Op>(arg);
auto arg_op_annotations = arg_op->get_op_annotations(); auto arg_op_annotations = arg_op->get_op_annotations();
NGRAPH_ASSERT(arg_op_annotations); NGRAPH_ASSERT(arg_op_annotations);
if (!arg_op_annotations->is_cacheable()) if (!arg_op_annotations->is_cacheable())
......
...@@ -33,17 +33,19 @@ ...@@ -33,17 +33,19 @@
#include "ngraph/pattern/op/skip.hpp" #include "ngraph/pattern/op/skip.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
extern template ngraph::AxisVector using namespace std;
ngraph::apply_permutation<ngraph::AxisVector>(ngraph::AxisVector input, using namespace ngraph;
ngraph::AxisVector order);
void ngraph::pass::ReshapeElimination::construct_identity_reshape_pattern() extern template AxisVector ngraph::apply_permutation<AxisVector>(AxisVector input,
AxisVector order);
void pass::ReshapeElimination::construct_identity_reshape_pattern()
{ {
Shape shape_op{3}; Shape shape_op{3};
Shape shape_r1{1, 3}; Shape shape_r1{1, 3};
auto op = std::make_shared<pattern::op::Label>(element::f32, shape_op); auto op = make_shared<pattern::op::Label>(element::f32, shape_op);
auto reshape1 = std::make_shared<op::Reshape>(op, AxisVector{0}, shape_r1); auto reshape1 = make_shared<op::Reshape>(op, AxisVector{0}, shape_r1);
auto callback = [op](pattern::Matcher& m) { auto callback = [op](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for construct_identity_reshape_pattern against node = " NGRAPH_DEBUG << "In callback for construct_identity_reshape_pattern against node = "
...@@ -51,7 +53,7 @@ void ngraph::pass::ReshapeElimination::construct_identity_reshape_pattern() ...@@ -51,7 +53,7 @@ void ngraph::pass::ReshapeElimination::construct_identity_reshape_pattern()
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
auto gop = pattern_map[op]; auto gop = pattern_map[op];
auto r1 = std::dynamic_pointer_cast<op::Reshape>(m.get_match_root()); auto r1 = dynamic_pointer_cast<op::Reshape>(m.get_match_root());
if (r1->get_shape() != gop->get_shape()) if (r1->get_shape() != gop->get_shape())
{ {
...@@ -59,7 +61,7 @@ void ngraph::pass::ReshapeElimination::construct_identity_reshape_pattern() ...@@ -59,7 +61,7 @@ void ngraph::pass::ReshapeElimination::construct_identity_reshape_pattern()
return false; return false;
} }
auto do_r1 = ngraph::get_default_order(r1->get_shape()); auto do_r1 = get_default_order(r1->get_shape());
if (do_r1 != r1->get_input_order()) if (do_r1 != r1->get_input_order())
{ {
...@@ -67,22 +69,22 @@ void ngraph::pass::ReshapeElimination::construct_identity_reshape_pattern() ...@@ -67,22 +69,22 @@ void ngraph::pass::ReshapeElimination::construct_identity_reshape_pattern()
return false; return false;
} }
ngraph::replace_node(m.get_match_root(), gop); replace_node(m.get_match_root(), gop);
return true; return true;
}; };
auto m = std::make_shared<ngraph::pattern::Matcher>(reshape1, callback); auto m = make_shared<pattern::Matcher>(reshape1, callback);
this->add_matcher(m); this->add_matcher(m);
} }
void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern() void pass::ReshapeElimination::construct_reshapex2_pattern()
{ {
Shape shape_op{3}; Shape shape_op{3};
Shape shape_r1{1, 3}; Shape shape_r1{1, 3};
auto op = std::make_shared<pattern::op::Label>(element::f32, shape_op); auto op = make_shared<pattern::op::Label>(element::f32, shape_op);
auto reshape1 = std::make_shared<op::Reshape>(op, AxisVector{0}, shape_r1); auto reshape1 = make_shared<op::Reshape>(op, AxisVector{0}, shape_r1);
auto reshape2 = std::make_shared<op::Reshape>(reshape1, AxisVector{0, 1}, shape_op); auto reshape2 = make_shared<op::Reshape>(reshape1, AxisVector{0, 1}, shape_op);
auto callback = [op](pattern::Matcher& m) { auto callback = [op](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for construct_reshapex2_pattern against node = " NGRAPH_DEBUG << "In callback for construct_reshapex2_pattern against node = "
...@@ -101,11 +103,11 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern() ...@@ -101,11 +103,11 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern()
return false; return false;
} }
auto r2 = std::dynamic_pointer_cast<op::Reshape>(m.get_match_root()); auto r2 = dynamic_pointer_cast<op::Reshape>(m.get_match_root());
auto r1 = std::dynamic_pointer_cast<op::Reshape>(r2->get_argument(0)); auto r1 = dynamic_pointer_cast<op::Reshape>(r2->get_argument(0));
auto do_r2 = ngraph::get_default_order(r1->get_shape()); auto do_r2 = get_default_order(r1->get_shape());
auto do_r1 = ngraph::get_default_order(gop->get_shape()); auto do_r1 = get_default_order(gop->get_shape());
NGRAPH_DEBUG << "r1's i/o = " << vector_to_string(r1->get_input_order()) NGRAPH_DEBUG << "r1's i/o = " << vector_to_string(r1->get_input_order())
<< "do_r1 = " << vector_to_string(do_r1); << "do_r1 = " << vector_to_string(do_r1);
...@@ -115,40 +117,40 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern() ...@@ -115,40 +117,40 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern()
if (r1->get_input_order() == do_r1 && r2->get_input_order() == do_r2) if (r1->get_input_order() == do_r1 && r2->get_input_order() == do_r2)
{ {
NGRAPH_DEBUG << "Two reshapes were removed!"; NGRAPH_DEBUG << "Two reshapes were removed!";
ngraph::replace_node(m.get_match_root(), gop); replace_node(m.get_match_root(), gop);
return true; return true;
} }
auto perm1 = ngraph::apply_permutation(do_r1, r1->get_input_order()); auto perm1 = apply_permutation(do_r1, r1->get_input_order());
auto perm2 = ngraph::apply_permutation(perm1, r2->get_input_order()); auto perm2 = apply_permutation(perm1, r2->get_input_order());
if (perm2 == do_r1) if (perm2 == do_r1)
{ {
NGRAPH_DEBUG << "Two transposes were removed!"; NGRAPH_DEBUG << "Two transposes were removed!";
ngraph::replace_node(m.get_match_root(), gop); replace_node(m.get_match_root(), gop);
return true; return true;
} }
return false; return false;
}; };
auto m = std::make_shared<ngraph::pattern::Matcher>(reshape2, callback); auto m = make_shared<pattern::Matcher>(reshape2, callback);
this->add_matcher(m); this->add_matcher(m);
} }
void ngraph::pass::ReshapeElimination::construct_dot_transpose_pattern() void pass::ReshapeElimination::construct_dot_transpose_pattern()
{ {
// dot(A,B).T = dot (B.T, A.T) // dot(A,B).T = dot (B.T, A.T)
auto dot_pred = [](std::shared_ptr<Node> n) { auto dot_pred = [](shared_ptr<Node> n) {
return static_cast<bool>(std::dynamic_pointer_cast<op::Dot>(n)); return static_cast<bool>(dynamic_pointer_cast<op::Dot>(n));
}; };
auto pdot = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 1}, dot_pred); auto pdot = make_shared<pattern::op::Label>(element::f32, Shape{2, 1}, dot_pred);
auto preshape = std::make_shared<op::Reshape>(pdot, AxisVector{1, 0}, Shape{1, 2}); auto preshape = make_shared<op::Reshape>(pdot, AxisVector{1, 0}, Shape{1, 2});
ngraph::pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) { pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for construct_dot_transpose_pattern against node = " NGRAPH_DEBUG << "In callback for construct_dot_transpose_pattern against node = "
<< m.get_match_root()->get_name(); << m.get_match_root()->get_name();
auto mtranspose = std::static_pointer_cast<op::Reshape>(m.get_match_root()); auto mtranspose = static_pointer_cast<op::Reshape>(m.get_match_root());
// this also checks the rank // this also checks the rank
if (mtranspose->get_input_order() != AxisVector{1, 0}) if (mtranspose->get_input_order() != AxisVector{1, 0})
{ {
...@@ -171,7 +173,7 @@ void ngraph::pass::ReshapeElimination::construct_dot_transpose_pattern() ...@@ -171,7 +173,7 @@ void ngraph::pass::ReshapeElimination::construct_dot_transpose_pattern()
return false; return false;
} }
auto reshape0_shape = Shape{arg0->get_shape().at(1), arg0->get_shape().at(0)}; auto reshape0_shape = Shape{arg0->get_shape().at(1), arg0->get_shape().at(0)};
auto reshape0 = std::make_shared<op::Reshape>(arg0, AxisVector{1, 0}, reshape0_shape); auto reshape0 = make_shared<op::Reshape>(arg0, AxisVector{1, 0}, reshape0_shape);
auto arg1 = mdot->get_argument(1); auto arg1 = mdot->get_argument(1);
if (arg1->get_shape().size() != 2) if (arg1->get_shape().size() != 2)
...@@ -180,13 +182,13 @@ void ngraph::pass::ReshapeElimination::construct_dot_transpose_pattern() ...@@ -180,13 +182,13 @@ void ngraph::pass::ReshapeElimination::construct_dot_transpose_pattern()
return false; return false;
} }
auto reshape1_shape = Shape{arg1->get_shape().at(1), arg1->get_shape().at(0)}; auto reshape1_shape = Shape{arg1->get_shape().at(1), arg1->get_shape().at(0)};
auto reshape1 = std::make_shared<op::Reshape>(arg1, AxisVector{1, 0}, reshape1_shape); auto reshape1 = make_shared<op::Reshape>(arg1, AxisVector{1, 0}, reshape1_shape);
auto tdot = std::shared_ptr<Node>(new op::Dot(reshape1, reshape0)); auto tdot = shared_ptr<Node>(new op::Dot(reshape1, reshape0));
ngraph::replace_node(m.get_match_root(), tdot); replace_node(m.get_match_root(), tdot);
return true; return true;
}; };
auto m = std::make_shared<ngraph::pattern::Matcher>(preshape, callback); auto m = make_shared<pattern::Matcher>(preshape, callback);
this->add_matcher(m); this->add_matcher(m);
} }
...@@ -39,14 +39,15 @@ ...@@ -39,14 +39,15 @@
#include "ngraph/pattern/op/label.hpp" #include "ngraph/pattern/op/label.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
using namespace std;
using namespace ngraph; using namespace ngraph;
using ReshapeMap = std::unordered_map<std::shared_ptr<Node>, std::shared_ptr<op::Reshape>>; using ReshapeMap = unordered_map<shared_ptr<Node>, shared_ptr<op::Reshape>>;
static std::string describe_reshape(std::shared_ptr<Node> node) static string describe_reshape(shared_ptr<Node> node)
{ {
std::stringstream ss; stringstream ss;
auto reshape = std::dynamic_pointer_cast<op::Reshape>(node); auto reshape = dynamic_pointer_cast<op::Reshape>(node);
ss << reshape->get_name() ss << reshape->get_name()
<< " ( axis order = " << ngraph::vector_to_string(reshape->get_input_order()) << " ( axis order = " << ngraph::vector_to_string(reshape->get_input_order())
<< " , shape = " << vector_to_string(reshape->get_shape()) << " ) " << " , shape = " << vector_to_string(reshape->get_shape()) << " ) "
...@@ -55,25 +56,24 @@ static std::string describe_reshape(std::shared_ptr<Node> node) ...@@ -55,25 +56,24 @@ static std::string describe_reshape(std::shared_ptr<Node> node)
return ss.str(); return ss.str();
} }
static std::shared_ptr<op::Reshape> combine_reshapes(std::shared_ptr<op::Reshape> r1, static shared_ptr<op::Reshape> combine_reshapes(shared_ptr<op::Reshape> r1,
std::shared_ptr<op::Reshape> r2) shared_ptr<op::Reshape> r2)
{ {
auto default_order = ngraph::get_default_order(r1->get_shape()); auto default_order = ngraph::get_default_order(r1->get_shape());
auto perm_r1 = apply_permutation(default_order, r1->get_input_order()); auto perm_r1 = apply_permutation(default_order, r1->get_input_order());
auto perm_r2 = apply_permutation(perm_r1, r2->get_input_order()); auto perm_r2 = apply_permutation(perm_r1, r2->get_input_order());
auto rreshape = std::make_shared<op::Reshape>(r2->get_argument(0), perm_r2, r2->get_shape()); auto rreshape = make_shared<op::Reshape>(r2->get_argument(0), perm_r2, r2->get_shape());
return rreshape; return rreshape;
} }
static void static void insert_reshape(shared_ptr<Node> target, shared_ptr<Node> reshape, size_t input_index)
insert_reshape(std::shared_ptr<Node> target, std::shared_ptr<Node> reshape, size_t input_index)
{ {
auto arg = target->get_inputs().at(input_index).get_output().get_node(); auto arg = target->get_inputs().at(input_index).get_output().get_node();
auto new_reshape = reshape->copy_with_new_args({arg}); auto new_reshape = reshape->copy_with_new_args({arg});
target->get_inputs().at(input_index).replace_output(new_reshape->get_outputs().at(0)); target->get_inputs().at(input_index).replace_output(new_reshape->get_outputs().at(0));
} }
static void delete_reshape(std::shared_ptr<Node> reshape) static void delete_reshape(shared_ptr<Node> reshape)
{ {
NGRAPH_DEBUG << "Removing reshape " << reshape->get_name(); NGRAPH_DEBUG << "Removing reshape " << reshape->get_name();
if (!reshape->get_users().empty()) if (!reshape->get_users().empty())
...@@ -82,22 +82,22 @@ static void delete_reshape(std::shared_ptr<Node> reshape) ...@@ -82,22 +82,22 @@ static void delete_reshape(std::shared_ptr<Node> reshape)
} }
} }
static void mark_reshape_for_deletion(std::shared_ptr<Node> reshape, static void mark_reshape_for_deletion(shared_ptr<Node> reshape,
std::set<std::shared_ptr<Node>>& reshapes_to_delete) set<shared_ptr<Node>>& reshapes_to_delete)
{ {
NGRAPH_DEBUG << "Marking reshape " << reshape->get_name() << " for deletion"; NGRAPH_DEBUG << "Marking reshape " << reshape->get_name() << " for deletion";
reshapes_to_delete.insert(reshape); reshapes_to_delete.insert(reshape);
} }
static std::shared_ptr<op::Reshape> create_default_reshape(std::shared_ptr<Node> n) static shared_ptr<op::Reshape> create_default_reshape(shared_ptr<Node> n)
{ {
auto default_order = ngraph::get_default_order(n->get_shape()); auto default_order = ngraph::get_default_order(n->get_shape());
auto default_reshape = std::make_shared<op::Reshape>(n, default_order, n->get_shape()); auto default_reshape = make_shared<op::Reshape>(n, default_order, n->get_shape());
return default_reshape; return default_reshape;
} }
//compute an axis order that converts the given axis order to default //compute an axis order that converts the given axis order to default
static AxisSet get_quantization_axes_in_default_order(std::shared_ptr<op::Reshape> arg_reshape, static AxisSet get_quantization_axes_in_default_order(shared_ptr<op::Reshape> arg_reshape,
const AxisSet& old_axis_set) const AxisSet& old_axis_set)
{ {
auto perm_to_def = ngraph::get_permutation_to_default_order(arg_reshape->get_input_order()); auto perm_to_def = ngraph::get_permutation_to_default_order(arg_reshape->get_input_order());
...@@ -112,7 +112,7 @@ static AxisSet get_quantization_axes_in_default_order(std::shared_ptr<op::Reshap ...@@ -112,7 +112,7 @@ static AxisSet get_quantization_axes_in_default_order(std::shared_ptr<op::Reshap
struct Swimmer struct Swimmer
{ {
descriptor::Input* input; descriptor::Input* input;
std::shared_ptr<op::Reshape> reshape; shared_ptr<op::Reshape> reshape;
}; };
//Swim is used to push/"swim" reshapes towards paramaters. //Swim is used to push/"swim" reshapes towards paramaters.
...@@ -121,10 +121,10 @@ struct Swimmer ...@@ -121,10 +121,10 @@ struct Swimmer
//we prefer nchw since a lot of ngraph ops require this format, //we prefer nchw since a lot of ngraph ops require this format,
//so keeping things in nchw allows us to eliminate as many reshapes //so keeping things in nchw allows us to eliminate as many reshapes
//as possible //as possible
void swim(descriptor::Input* input, std::shared_ptr<op::Reshape> reshape) void swim(descriptor::Input* input, shared_ptr<op::Reshape> reshape)
{ {
Swimmer sw{input, reshape}; Swimmer sw{input, reshape};
std::list<Swimmer> work_queue; list<Swimmer> work_queue;
work_queue.push_back(sw); work_queue.push_back(sw);
//TODO: if we support more ops (especially, with >1 args) //TODO: if we support more ops (especially, with >1 args)
...@@ -135,21 +135,21 @@ void swim(descriptor::Input* input, std::shared_ptr<op::Reshape> reshape) ...@@ -135,21 +135,21 @@ void swim(descriptor::Input* input, std::shared_ptr<op::Reshape> reshape)
work_queue.pop_front(); work_queue.pop_front();
auto n = csw.input->get_output().get_node(); auto n = csw.input->get_output().get_node();
NGRAPH_DEBUG << "Processing (swimming) " << n->get_name(); NGRAPH_DEBUG << "Processing (swimming) " << n->get_name();
if (auto unary = std::dynamic_pointer_cast<op::util::UnaryElementwiseArithmetic>(n)) if (auto unary = dynamic_pointer_cast<op::util::UnaryElementwiseArithmetic>(n))
{ {
Swimmer nsw{&unary->get_inputs().at(0), csw.reshape}; Swimmer nsw{&unary->get_inputs().at(0), csw.reshape};
work_queue.push_back(nsw); work_queue.push_back(nsw);
NGRAPH_DEBUG << "Propagating reshape " << describe_reshape(csw.reshape) << " for " NGRAPH_DEBUG << "Propagating reshape " << describe_reshape(csw.reshape) << " for "
<< n->get_name() << " to " << unary->get_argument(0); << n->get_name() << " to " << unary->get_argument(0);
} }
else if (std::dynamic_pointer_cast<op::Broadcast>(n)) else if (dynamic_pointer_cast<op::Broadcast>(n))
{ {
auto old_broadcast = std::static_pointer_cast<op::Broadcast>(n); auto old_broadcast = static_pointer_cast<op::Broadcast>(n);
auto broadcast_axes = old_broadcast->get_broadcast_axes(); auto broadcast_axes = old_broadcast->get_broadcast_axes();
auto broadcast_reshape = csw.reshape; auto broadcast_reshape = csw.reshape;
bool in_order = true; bool in_order = true;
AxisSet new_broadcast_axes; AxisSet new_broadcast_axes;
std::vector<size_t> new_source_axes; vector<size_t> new_source_axes;
auto input_order = broadcast_reshape->get_input_order(); auto input_order = broadcast_reshape->get_input_order();
for (size_t i = 0; i < input_order.size(); i++) for (size_t i = 0; i < input_order.size(); i++)
{ {
...@@ -171,8 +171,8 @@ void swim(descriptor::Input* input, std::shared_ptr<op::Reshape> reshape) ...@@ -171,8 +171,8 @@ void swim(descriptor::Input* input, std::shared_ptr<op::Reshape> reshape)
if (!in_order) if (!in_order)
{ {
AxisVector new_source_axes_sorted{new_source_axes}; AxisVector new_source_axes_sorted{new_source_axes};
std::sort(new_source_axes_sorted.begin(), new_source_axes_sorted.end()); sort(new_source_axes_sorted.begin(), new_source_axes_sorted.end());
std::map<size_t, size_t> old_new_source_axes; map<size_t, size_t> old_new_source_axes;
for (size_t i = 0; new_source_axes_sorted.size(); i++) for (size_t i = 0; new_source_axes_sorted.size(); i++)
{ {
old_new_source_axes.insert({new_source_axes.at(i), i}); old_new_source_axes.insert({new_source_axes.at(i), i});
...@@ -186,11 +186,11 @@ void swim(descriptor::Input* input, std::shared_ptr<op::Reshape> reshape) ...@@ -186,11 +186,11 @@ void swim(descriptor::Input* input, std::shared_ptr<op::Reshape> reshape)
auto new_arg_shape = auto new_arg_shape =
ngraph::apply_permutation(broadcast_input->get_shape(), new_source_axis_order); ngraph::apply_permutation(broadcast_input->get_shape(), new_source_axis_order);
broadcast_input = std::make_shared<op::Reshape>( broadcast_input =
broadcast_input, new_source_axis_order, new_arg_shape); make_shared<op::Reshape>(broadcast_input, new_source_axis_order, new_arg_shape);
} }
auto new_broadcast = std::make_shared<op::Broadcast>( auto new_broadcast = make_shared<op::Broadcast>(
broadcast_input, broadcast_reshape->get_shape(), new_broadcast_axes); broadcast_input, broadcast_reshape->get_shape(), new_broadcast_axes);
csw.input->replace_output(new_broadcast->get_outputs().at(0)); csw.input->replace_output(new_broadcast->get_outputs().at(0));
} }
...@@ -210,11 +210,11 @@ void swim(descriptor::Input* input, std::shared_ptr<op::Reshape> reshape) ...@@ -210,11 +210,11 @@ void swim(descriptor::Input* input, std::shared_ptr<op::Reshape> reshape)
//We have to normalize this other argument to nchw by swimming nchw towards parameters //We have to normalize this other argument to nchw by swimming nchw towards parameters
//as far as we can //as far as we can
static void convert_binary_to_default_order( static void convert_binary_to_default_order(
std::shared_ptr<Node> binary, shared_ptr<Node> binary,
descriptor::Input& input, descriptor::Input& input,
std::shared_ptr<Node> right, shared_ptr<Node> right,
std::unordered_map<std::shared_ptr<Node>, std::shared_ptr<op::Reshape>>& reorders, unordered_map<shared_ptr<Node>, shared_ptr<op::Reshape>>& reorders,
std::set<std::shared_ptr<Node>>& reshapes_to_delete) set<shared_ptr<Node>>& reshapes_to_delete)
{ {
auto left = input.get_output().get_node(); auto left = input.get_output().get_node();
auto perm_to_def = auto perm_to_def =
...@@ -222,7 +222,7 @@ static void convert_binary_to_default_order( ...@@ -222,7 +222,7 @@ static void convert_binary_to_default_order(
auto new_shape = apply_permutation(left->get_shape(), perm_to_def); auto new_shape = apply_permutation(left->get_shape(), perm_to_def);
NGRAPH_DEBUG << "right = " << ngraph::vector_to_string(right->get_shape()) << ", " NGRAPH_DEBUG << "right = " << ngraph::vector_to_string(right->get_shape()) << ", "
<< right->get_name(); << right->get_name();
auto new_reshape = std::make_shared<op::Reshape>(left, perm_to_def, new_shape); auto new_reshape = make_shared<op::Reshape>(left, perm_to_def, new_shape);
NGRAPH_DEBUG << "left : About to swim " << describe_reshape(new_reshape) << " up to " NGRAPH_DEBUG << "left : About to swim " << describe_reshape(new_reshape) << " up to "
<< left->get_name(); << left->get_name();
//this should now insert and swim reshape on right //this should now insert and swim reshape on right
...@@ -231,9 +231,9 @@ static void convert_binary_to_default_order( ...@@ -231,9 +231,9 @@ static void convert_binary_to_default_order(
reorders[binary] = reorders.at(right); reorders[binary] = reorders.at(right);
} }
static void materialize_shapes(std::shared_ptr<Node> n, static void materialize_shapes(shared_ptr<Node> n,
ReshapeMap& reorders, ReshapeMap& reorders,
std::set<std::shared_ptr<Node>>& reshapes_to_delete) set<shared_ptr<Node>>& reshapes_to_delete)
{ {
//skip multiple output nodes and deal with GOEs exclusively //skip multiple output nodes and deal with GOEs exclusively
if (n->get_outputs().size() > 1) if (n->get_outputs().size() > 1)
...@@ -257,9 +257,9 @@ static void materialize_shapes(std::shared_ptr<Node> n, ...@@ -257,9 +257,9 @@ static void materialize_shapes(std::shared_ptr<Node> n,
reorders[n] = create_default_reshape(n); reorders[n] = create_default_reshape(n);
} }
static void sink_reshape(std::shared_ptr<op::Reshape> reshape, static void sink_reshape(shared_ptr<op::Reshape> reshape,
ReshapeMap& reorders, ReshapeMap& reorders,
std::set<std::shared_ptr<Node>>& reshapes_to_delete) set<shared_ptr<Node>>& reshapes_to_delete)
{ {
auto orig_reshape = reorders.at(reshape->get_argument(0)); auto orig_reshape = reorders.at(reshape->get_argument(0));
if (!reshape->get_is_transpose()) if (!reshape->get_is_transpose())
...@@ -286,18 +286,18 @@ static void sink_reshape(std::shared_ptr<op::Reshape> reshape, ...@@ -286,18 +286,18 @@ static void sink_reshape(std::shared_ptr<op::Reshape> reshape,
} }
} }
static void sink_unary(std::shared_ptr<op::util::UnaryElementwiseArithmetic> n, static void sink_unary(shared_ptr<op::util::UnaryElementwiseArithmetic> n,
ReshapeMap& reorders, ReshapeMap& reorders,
std::set<std::shared_ptr<Node>>& reshapes_to_delete) set<shared_ptr<Node>>& reshapes_to_delete)
{ {
auto arg_reshape = reorders.at(n->get_argument(0)); auto arg_reshape = reorders.at(n->get_argument(0));
NGRAPH_DEBUG << "Propagating " << describe_reshape(arg_reshape) << " for " << n->get_name(); NGRAPH_DEBUG << "Propagating " << describe_reshape(arg_reshape) << " for " << n->get_name();
reorders[n] = reorders[n->get_argument(0)]; reorders[n] = reorders[n->get_argument(0)];
} }
static void sink_binary(std::shared_ptr<op::util::BinaryElementwiseArithmetic> binary, static void sink_binary(shared_ptr<op::util::BinaryElementwiseArithmetic> binary,
ReshapeMap& reorders, ReshapeMap& reorders,
std::set<std::shared_ptr<Node>>& reshapes_to_delete) set<shared_ptr<Node>>& reshapes_to_delete)
{ {
auto left = binary->get_argument(0); auto left = binary->get_argument(0);
auto right = binary->get_argument(1); auto right = binary->get_argument(1);
...@@ -333,9 +333,9 @@ static void sink_binary(std::shared_ptr<op::util::BinaryElementwiseArithmetic> b ...@@ -333,9 +333,9 @@ static void sink_binary(std::shared_ptr<op::util::BinaryElementwiseArithmetic> b
} }
} }
static void sink_slice(std::shared_ptr<op::Slice> n, static void sink_slice(shared_ptr<op::Slice> n,
ReshapeMap& reorders, ReshapeMap& reorders,
std::set<std::shared_ptr<Node>>& reshapes_to_delete) set<shared_ptr<Node>>& reshapes_to_delete)
{ {
auto arg_reshape = reorders.at(n->get_argument(0)); auto arg_reshape = reorders.at(n->get_argument(0));
auto order = arg_reshape->get_input_order(); auto order = arg_reshape->get_input_order();
...@@ -346,25 +346,23 @@ static void sink_slice(std::shared_ptr<op::Slice> n, ...@@ -346,25 +346,23 @@ static void sink_slice(std::shared_ptr<op::Slice> n,
auto def_order = ngraph::get_permutation_to_default_order(order); auto def_order = ngraph::get_permutation_to_default_order(order);
auto input_shape = ngraph::apply_permutation(arg_reshape->get_shape(), def_order); auto input_shape = ngraph::apply_permutation(arg_reshape->get_shape(), def_order);
auto dummy_correct_shape = auto dummy_correct_shape =
std::make_shared<pattern::op::Label>(arg_reshape->get_element_type(), input_shape); make_shared<pattern::op::Label>(arg_reshape->get_element_type(), input_shape);
auto new_lower = ngraph::apply_permutation(n->get_lower_bounds(), def_order); auto new_lower = ngraph::apply_permutation(n->get_lower_bounds(), def_order);
auto new_upper = ngraph::apply_permutation(n->get_upper_bounds(), def_order); auto new_upper = ngraph::apply_permutation(n->get_upper_bounds(), def_order);
auto new_strides = ngraph::apply_permutation(n->get_strides(), def_order); auto new_strides = ngraph::apply_permutation(n->get_strides(), def_order);
auto new_slice = auto new_slice = make_shared<op::Slice>(dummy_correct_shape, new_lower, new_upper, new_strides);
std::make_shared<op::Slice>(dummy_correct_shape, new_lower, new_upper, new_strides);
ngraph::replace_node(dummy_correct_shape, n->get_argument(0)); ngraph::replace_node(dummy_correct_shape, n->get_argument(0));
NGRAPH_DEBUG << "Replacing " << n->get_name() << " with " << new_slice->get_name(); NGRAPH_DEBUG << "Replacing " << n->get_name() << " with " << new_slice->get_name();
ngraph::replace_node(n, new_slice); ngraph::replace_node(n, new_slice);
auto new_reshape = std::make_shared<op::Reshape>(new_slice, order, n->get_shape()); auto new_reshape = make_shared<op::Reshape>(new_slice, order, n->get_shape());
NGRAPH_DEBUG << "Propagating " << describe_reshape(new_reshape) << " for " << n->get_name(); NGRAPH_DEBUG << "Propagating " << describe_reshape(new_reshape) << " for " << n->get_name();
reorders[new_slice] = new_reshape; reorders[new_slice] = new_reshape;
} }
static void sink_pad(std::shared_ptr<op::Pad> n, static void
ReshapeMap& reorders, sink_pad(shared_ptr<op::Pad> n, ReshapeMap& reorders, set<shared_ptr<Node>>& reshapes_to_delete)
std::set<std::shared_ptr<Node>>& reshapes_to_delete)
{ {
auto arg_reshape = reorders.at(n->get_argument(0)); auto arg_reshape = reorders.at(n->get_argument(0));
auto order = arg_reshape->get_input_order(); auto order = arg_reshape->get_input_order();
...@@ -374,41 +372,41 @@ static void sink_pad(std::shared_ptr<op::Pad> n, ...@@ -374,41 +372,41 @@ static void sink_pad(std::shared_ptr<op::Pad> n,
auto def_order = ngraph::get_permutation_to_default_order(order); auto def_order = ngraph::get_permutation_to_default_order(order);
auto input_shape = ngraph::apply_permutation(arg_reshape->get_shape(), def_order); auto input_shape = ngraph::apply_permutation(arg_reshape->get_shape(), def_order);
auto dummy_correct_shape = auto dummy_correct_shape =
std::make_shared<pattern::op::Label>(arg_reshape->get_element_type(), input_shape); make_shared<pattern::op::Label>(arg_reshape->get_element_type(), input_shape);
auto new_lower = ngraph::apply_permutation(n->get_padding_below(), def_order); auto new_lower = ngraph::apply_permutation(n->get_padding_below(), def_order);
auto new_upper = ngraph::apply_permutation(n->get_padding_above(), def_order); auto new_upper = ngraph::apply_permutation(n->get_padding_above(), def_order);
auto new_interior = ngraph::apply_permutation(n->get_padding_interior(), def_order); auto new_interior = ngraph::apply_permutation(n->get_padding_interior(), def_order);
auto new_pad = std::make_shared<op::Pad>( auto new_pad = make_shared<op::Pad>(
dummy_correct_shape, n->get_argument(1), new_lower, new_upper, new_interior); dummy_correct_shape, n->get_argument(1), new_lower, new_upper, new_interior);
ngraph::replace_node(dummy_correct_shape, n->get_argument(0)); ngraph::replace_node(dummy_correct_shape, n->get_argument(0));
NGRAPH_DEBUG << "Replacing " << n->get_name() << " with " << new_pad->get_name(); NGRAPH_DEBUG << "Replacing " << n->get_name() << " with " << new_pad->get_name();
ngraph::replace_node(n, new_pad); ngraph::replace_node(n, new_pad);
auto new_reshape = std::make_shared<op::Reshape>(new_pad, order, n->get_shape()); auto new_reshape = make_shared<op::Reshape>(new_pad, order, n->get_shape());
NGRAPH_DEBUG << "Propagating " << describe_reshape(new_reshape) << " for " << n->get_name(); NGRAPH_DEBUG << "Propagating " << describe_reshape(new_reshape) << " for " << n->get_name();
reorders[new_pad] = new_reshape; reorders[new_pad] = new_reshape;
} }
static void sink_quantize(std::shared_ptr<op::Quantize> quantize, static void sink_quantize(shared_ptr<op::Quantize> quantize,
ReshapeMap& reorders, ReshapeMap& reorders,
std::set<std::shared_ptr<Node>>& reshapes_to_delete) set<shared_ptr<Node>>& reshapes_to_delete)
{ {
auto arg_reshape = reorders.at(quantize->get_argument(0)); auto arg_reshape = reorders.at(quantize->get_argument(0));
AxisSet axes_in_def_order = AxisSet axes_in_def_order =
get_quantization_axes_in_default_order(arg_reshape, quantize->get_axes()); get_quantization_axes_in_default_order(arg_reshape, quantize->get_axes());
auto new_quantize = std::make_shared<op::Quantize>(quantize->get_argument(0), auto new_quantize = make_shared<op::Quantize>(quantize->get_argument(0),
quantize->get_argument(1), quantize->get_argument(1),
quantize->get_argument(2), quantize->get_argument(2),
quantize->get_element_type(), quantize->get_element_type(),
axes_in_def_order, axes_in_def_order,
quantize->get_round_mode()); quantize->get_round_mode());
ngraph::replace_node(quantize, new_quantize); ngraph::replace_node(quantize, new_quantize);
reorders[new_quantize] = arg_reshape; reorders[new_quantize] = arg_reshape;
} }
static void sink_concat(std::shared_ptr<op::Concat> n, static void sink_concat(shared_ptr<op::Concat> n,
ReshapeMap& reorders, ReshapeMap& reorders,
std::set<std::shared_ptr<Node>>& reshapes_to_delete) set<shared_ptr<Node>>& reshapes_to_delete)
{ {
auto arg_reshape = reorders.at(n->get_argument(0)); auto arg_reshape = reorders.at(n->get_argument(0));
auto order = arg_reshape->get_input_order(); auto order = arg_reshape->get_input_order();
...@@ -418,7 +416,7 @@ static void sink_concat(std::shared_ptr<op::Concat> n, ...@@ -418,7 +416,7 @@ static void sink_concat(std::shared_ptr<op::Concat> n,
auto def_order = ngraph::get_permutation_to_default_order(order); auto def_order = ngraph::get_permutation_to_default_order(order);
auto input_shape = ngraph::apply_permutation(arg_reshape->get_shape(), def_order); auto input_shape = ngraph::apply_permutation(arg_reshape->get_shape(), def_order);
auto dummy_correct_shape = auto dummy_correct_shape =
std::make_shared<pattern::op::Label>(arg_reshape->get_element_type(), input_shape); make_shared<pattern::op::Label>(arg_reshape->get_element_type(), input_shape);
NodeVector new_args; NodeVector new_args;
new_args.push_back(dummy_correct_shape); new_args.push_back(dummy_correct_shape);
...@@ -436,12 +434,12 @@ static void sink_concat(std::shared_ptr<op::Concat> n, ...@@ -436,12 +434,12 @@ static void sink_concat(std::shared_ptr<op::Concat> n,
auto iinput_shape = ngraph::apply_permutation(iarg_reshape->get_shape(), def_order); auto iinput_shape = ngraph::apply_permutation(iarg_reshape->get_shape(), def_order);
auto idummy_correct_shape = auto idummy_correct_shape =
std::make_shared<pattern::op::Label>(iarg_reshape->get_element_type(), iinput_shape); make_shared<pattern::op::Label>(iarg_reshape->get_element_type(), iinput_shape);
new_args.push_back(idummy_correct_shape); new_args.push_back(idummy_correct_shape);
} }
auto new_axis = order.at(n->get_concatenation_axis()); auto new_axis = order.at(n->get_concatenation_axis());
auto new_concat = std::make_shared<op::Concat>(new_args, new_axis); auto new_concat = make_shared<op::Concat>(new_args, new_axis);
//put back the original arguments //put back the original arguments
for (size_t i = 0; i < new_concat->get_input_size(); i++) for (size_t i = 0; i < new_concat->get_input_size(); i++)
{ {
...@@ -450,23 +448,23 @@ static void sink_concat(std::shared_ptr<op::Concat> n, ...@@ -450,23 +448,23 @@ static void sink_concat(std::shared_ptr<op::Concat> n,
NGRAPH_DEBUG << "Replacing " << n->get_name() << " with " << new_concat->get_name(); NGRAPH_DEBUG << "Replacing " << n->get_name() << " with " << new_concat->get_name();
ngraph::replace_node(n, new_concat); ngraph::replace_node(n, new_concat);
auto new_reshape = std::make_shared<op::Reshape>(new_concat, order, n->get_shape()); auto new_reshape = make_shared<op::Reshape>(new_concat, order, n->get_shape());
NGRAPH_DEBUG << "Propagating " << describe_reshape(new_reshape) << " for " << n->get_name(); NGRAPH_DEBUG << "Propagating " << describe_reshape(new_reshape) << " for " << n->get_name();
reorders[new_concat] = new_reshape; reorders[new_concat] = new_reshape;
} }
static void sink_dequantize(std::shared_ptr<op::Dequantize> dequantize, static void sink_dequantize(shared_ptr<op::Dequantize> dequantize,
ReshapeMap& reorders, ReshapeMap& reorders,
std::set<std::shared_ptr<Node>>& reshapes_to_delete) set<shared_ptr<Node>>& reshapes_to_delete)
{ {
auto arg_reshape = reorders.at(dequantize->get_argument(0)); auto arg_reshape = reorders.at(dequantize->get_argument(0));
AxisSet axes_in_def_order = AxisSet axes_in_def_order =
get_quantization_axes_in_default_order(arg_reshape, dequantize->get_axes()); get_quantization_axes_in_default_order(arg_reshape, dequantize->get_axes());
auto new_dequantize = std::make_shared<op::Dequantize>(dequantize->get_argument(0), auto new_dequantize = make_shared<op::Dequantize>(dequantize->get_argument(0),
dequantize->get_argument(1), dequantize->get_argument(1),
dequantize->get_argument(2), dequantize->get_argument(2),
dequantize->get_element_type(), dequantize->get_element_type(),
axes_in_def_order); axes_in_def_order);
ngraph::replace_node(dequantize, new_dequantize); ngraph::replace_node(dequantize, new_dequantize);
reorders[new_dequantize] = arg_reshape; reorders[new_dequantize] = arg_reshape;
...@@ -481,11 +479,11 @@ static void sink_dequantize(std::shared_ptr<op::Dequantize> dequantize, ...@@ -481,11 +479,11 @@ static void sink_dequantize(std::shared_ptr<op::Dequantize> dequantize,
//For each op type we support we can either combine //For each op type we support we can either combine
//two reshapes by replacing the existing Reshape, //two reshapes by replacing the existing Reshape,
//materialize pending reshapes if they can't be propagated through op //materialize pending reshapes if they can't be propagated through op
bool ngraph::pass::ReshapeSinking::run_on_function(std::shared_ptr<ngraph::Function> f) bool ngraph::pass::ReshapeSinking::run_on_function(shared_ptr<ngraph::Function> f)
{ {
ReshapeMap reorders; ReshapeMap reorders;
NodeVector results; NodeVector results;
std::set<std::shared_ptr<Node>> reshapes_to_delete; set<shared_ptr<Node>> reshapes_to_delete;
//STEP 1 : Sink or Swim reshapes away for op clusters //STEP 1 : Sink or Swim reshapes away for op clusters
for (auto n : f->get_ordered_ops()) for (auto n : f->get_ordered_ops())
...@@ -497,31 +495,31 @@ bool ngraph::pass::ReshapeSinking::run_on_function(std::shared_ptr<ngraph::Funct ...@@ -497,31 +495,31 @@ bool ngraph::pass::ReshapeSinking::run_on_function(std::shared_ptr<ngraph::Funct
results.push_back(n); results.push_back(n);
} }
if (auto reshape = std::dynamic_pointer_cast<op::Reshape>(n)) if (auto reshape = dynamic_pointer_cast<op::Reshape>(n))
{ {
sink_reshape(reshape, reorders, reshapes_to_delete); sink_reshape(reshape, reorders, reshapes_to_delete);
} }
else if (auto unary = std::dynamic_pointer_cast<op::util::UnaryElementwiseArithmetic>(n)) else if (auto unary = dynamic_pointer_cast<op::util::UnaryElementwiseArithmetic>(n))
{ {
sink_unary(unary, reorders, reshapes_to_delete); sink_unary(unary, reorders, reshapes_to_delete);
} }
else if (auto binary = std::dynamic_pointer_cast<op::util::BinaryElementwiseArithmetic>(n)) else if (auto binary = dynamic_pointer_cast<op::util::BinaryElementwiseArithmetic>(n))
{ {
sink_binary(binary, reorders, reshapes_to_delete); sink_binary(binary, reorders, reshapes_to_delete);
} }
else if (auto goe = std::dynamic_pointer_cast<op::GetOutputElement>(n)) else if (auto goe = dynamic_pointer_cast<op::GetOutputElement>(n))
{ {
reorders[goe] = create_default_reshape(goe); reorders[goe] = create_default_reshape(goe);
} }
else if (auto quantize = std::dynamic_pointer_cast<op::Quantize>(n)) else if (auto quantize = dynamic_pointer_cast<op::Quantize>(n))
{ {
sink_quantize(quantize, reorders, reshapes_to_delete); sink_quantize(quantize, reorders, reshapes_to_delete);
} }
else if (auto dequantize = std::dynamic_pointer_cast<op::Dequantize>(n)) else if (auto dequantize = dynamic_pointer_cast<op::Dequantize>(n))
{ {
sink_dequantize(dequantize, reorders, reshapes_to_delete); sink_dequantize(dequantize, reorders, reshapes_to_delete);
} }
else if (auto slice = std::dynamic_pointer_cast<op::Slice>(n)) else if (auto slice = dynamic_pointer_cast<op::Slice>(n))
{ {
// A heuristic. If Reshape has multiple slice users, if sunk // A heuristic. If Reshape has multiple slice users, if sunk
// it will be replicated by the number of its users // it will be replicated by the number of its users
...@@ -542,11 +540,11 @@ bool ngraph::pass::ReshapeSinking::run_on_function(std::shared_ptr<ngraph::Funct ...@@ -542,11 +540,11 @@ bool ngraph::pass::ReshapeSinking::run_on_function(std::shared_ptr<ngraph::Funct
materialize_shapes(n, reorders, reshapes_to_delete); materialize_shapes(n, reorders, reshapes_to_delete);
} }
} }
else if (auto pad = std::dynamic_pointer_cast<op::Pad>(n)) else if (auto pad = dynamic_pointer_cast<op::Pad>(n))
{ {
sink_pad(pad, reorders, reshapes_to_delete); sink_pad(pad, reorders, reshapes_to_delete);
} }
else if (auto concat = std::dynamic_pointer_cast<op::Concat>(n)) else if (auto concat = dynamic_pointer_cast<op::Concat>(n))
{ {
sink_concat(concat, reorders, reshapes_to_delete); sink_concat(concat, reorders, reshapes_to_delete);
} }
......
...@@ -27,9 +27,9 @@ ...@@ -27,9 +27,9 @@
using namespace ngraph; using namespace ngraph;
using namespace std; using namespace std;
#define TI(x) std::type_index(typeid(x)) #define TI(x) type_index(typeid(x))
bool pass::VisualizeTree::run_on_module(vector<shared_ptr<ngraph::Function>>& functions) bool pass::VisualizeTree::run_on_module(vector<shared_ptr<Function>>& functions)
{ {
for (shared_ptr<Function> f : functions) for (shared_ptr<Function> f : functions)
{ {
...@@ -42,10 +42,10 @@ bool pass::VisualizeTree::run_on_module(vector<shared_ptr<ngraph::Function>>& fu ...@@ -42,10 +42,10 @@ bool pass::VisualizeTree::run_on_module(vector<shared_ptr<ngraph::Function>>& fu
m_ss << add_attributes(node); m_ss << add_attributes(node);
m_ss << " " << arg->get_name() << " -> " << node->get_name(); m_ss << " " << arg->get_name() << " -> " << node->get_name();
if (std::getenv("NGRAPH_VISUALIZE_EDGE_LABELS") != nullptr) if (getenv("NGRAPH_VISUALIZE_EDGE_LABELS") != nullptr)
{ {
size_t output = 0; size_t output = 0;
if (auto goe = std::dynamic_pointer_cast<op::GetOutputElement>(node)) if (auto goe = dynamic_pointer_cast<op::GetOutputElement>(node))
{ {
output = goe->get_n(); output = goe->get_n();
} }
...@@ -71,7 +71,7 @@ pass::VisualizeTree::VisualizeTree(const string& file_name, node_modifiers_t nm) ...@@ -71,7 +71,7 @@ pass::VisualizeTree::VisualizeTree(const string& file_name, node_modifiers_t nm)
{ {
} }
std::string pass::VisualizeTree::add_attributes(shared_ptr<Node> node) string pass::VisualizeTree::add_attributes(shared_ptr<Node> node)
{ {
string rc; string rc;
if (m_nodes_with_attributes.find(node) == m_nodes_with_attributes.end()) if (m_nodes_with_attributes.find(node) == m_nodes_with_attributes.end())
...@@ -82,7 +82,7 @@ std::string pass::VisualizeTree::add_attributes(shared_ptr<Node> node) ...@@ -82,7 +82,7 @@ std::string pass::VisualizeTree::add_attributes(shared_ptr<Node> node)
return rc; return rc;
} }
std::string pass::VisualizeTree::get_attributes(shared_ptr<Node> node) string pass::VisualizeTree::get_attributes(shared_ptr<Node> node)
{ {
vector<string> attributes; vector<string> attributes;
if (node->is_parameter() || node->is_output()) if (node->is_parameter() || node->is_output())
...@@ -110,22 +110,22 @@ std::string pass::VisualizeTree::get_attributes(shared_ptr<Node> node) ...@@ -110,22 +110,22 @@ std::string pass::VisualizeTree::get_attributes(shared_ptr<Node> node)
stringstream label; stringstream label;
label << "label=\"" << node->get_friendly_name(); label << "label=\"" << node->get_friendly_name();
static const char* nvtos = std::getenv("NGRAPH_VISUALIZE_TREE_OUTPUT_SHAPES"); static const char* nvtos = getenv("NGRAPH_VISUALIZE_TREE_OUTPUT_SHAPES");
if (nvtos != nullptr) if (nvtos != nullptr)
{ {
// The shapes of the Outputs of a multi-output op // The shapes of the Outputs of a multi-output op
// will be printed for its corresponding `GetOutputElement`s // will be printed for its corresponding `GetOutputElement`s
label << " " << (node->get_outputs().size() != 1 ? std::string("[skipped]") label << " " << (node->get_outputs().size() != 1 ? string("[skipped]")
: vector_to_string(node->get_shape())); : vector_to_string(node->get_shape()));
} }
static const char* nvtot = std::getenv("NGRAPH_VISUALIZE_TREE_OUTPUT_TYPES"); static const char* nvtot = getenv("NGRAPH_VISUALIZE_TREE_OUTPUT_TYPES");
if (nvtot != nullptr) if (nvtot != nullptr)
{ {
// The types of the Outputs of a multi-output op // The types of the Outputs of a multi-output op
// will be printed for its corresponding `GetOutputElement`s // will be printed for its corresponding `GetOutputElement`s
label << " " label << " "
<< ((node->get_outputs().size() != 1) ? std::string("[skipped]") << ((node->get_outputs().size() != 1) ? string("[skipped]")
: node->get_element_type().c_type_string()); : node->get_element_type().c_type_string());
} }
...@@ -150,9 +150,9 @@ std::string pass::VisualizeTree::get_attributes(shared_ptr<Node> node) ...@@ -150,9 +150,9 @@ std::string pass::VisualizeTree::get_attributes(shared_ptr<Node> node)
return ss.str(); return ss.str();
} }
std::string pass::VisualizeTree::get_file_ext() string pass::VisualizeTree::get_file_ext()
{ {
const char* format = std::getenv("NGRAPH_VISUALIZE_TREE_OUTPUT_FORMAT"); const char* format = getenv("NGRAPH_VISUALIZE_TREE_OUTPUT_FORMAT");
if (!format) if (!format)
{ {
format = "png"; format = "png";
...@@ -163,7 +163,7 @@ std::string pass::VisualizeTree::get_file_ext() ...@@ -163,7 +163,7 @@ std::string pass::VisualizeTree::get_file_ext()
format += 1; format += 1;
} }
return std::string(format); return string(format);
} }
void pass::VisualizeTree::render() const void pass::VisualizeTree::render() const
......
...@@ -30,9 +30,10 @@ ...@@ -30,9 +30,10 @@
#include "ngraph/op/sum.hpp" #include "ngraph/op/sum.hpp"
#include "zero_dim_tensor_elimination.hpp" #include "zero_dim_tensor_elimination.hpp"
using namespace std;
using namespace ngraph; using namespace ngraph;
static bool has_zero_dim(std::shared_ptr<Node> node) static bool has_zero_dim(shared_ptr<Node> node)
{ {
if (node->get_output_size() != 1) if (node->get_output_size() != 1)
{ {
...@@ -40,12 +41,12 @@ static bool has_zero_dim(std::shared_ptr<Node> node) ...@@ -40,12 +41,12 @@ static bool has_zero_dim(std::shared_ptr<Node> node)
} }
const auto& shape = node->get_shape(); const auto& shape = node->get_shape();
return std::find(shape.begin(), shape.end(), 0) != shape.end(); return find(shape.begin(), shape.end(), 0) != shape.end();
} }
static bool verify_no_internal_zero_length_ops(std::shared_ptr<ngraph::Function> f) static bool verify_no_internal_zero_length_ops(shared_ptr<Function> f)
{ {
std::set<std::shared_ptr<Node>> zero_length_nodes; set<shared_ptr<Node>> zero_length_nodes;
for (auto n : f->get_ordered_ops()) for (auto n : f->get_ordered_ops())
{ {
if (n->is_output() || n->is_parameter() || n->get_outputs().size() > 1) if (n->is_output() || n->is_parameter() || n->get_outputs().size() > 1)
...@@ -76,10 +77,10 @@ static bool verify_no_internal_zero_length_ops(std::shared_ptr<ngraph::Function> ...@@ -76,10 +77,10 @@ static bool verify_no_internal_zero_length_ops(std::shared_ptr<ngraph::Function>
return zero_length_nodes.size() > 0; return zero_length_nodes.size() > 0;
} }
bool ngraph::pass::ZeroDimTensorElimination::run_on_function(std::shared_ptr<ngraph::Function> f) bool pass::ZeroDimTensorElimination::run_on_function(shared_ptr<Function> f)
{ {
bool replaced = false; bool replaced = false;
auto cvals = std::vector<std::string>(0); auto cvals = vector<string>(0);
// we need to go over all nodes since we could have sum or any other 0-length-tensor-to scalar op // we need to go over all nodes since we could have sum or any other 0-length-tensor-to scalar op
// as an internal node (i.e. a node that isn't an argument to `op::Result`) // as an internal node (i.e. a node that isn't an argument to `op::Result`)
for (auto n : f->get_ordered_ops()) for (auto n : f->get_ordered_ops())
...@@ -98,8 +99,7 @@ bool ngraph::pass::ZeroDimTensorElimination::run_on_function(std::shared_ptr<ngr ...@@ -98,8 +99,7 @@ bool ngraph::pass::ZeroDimTensorElimination::run_on_function(std::shared_ptr<ngr
{ {
// we don't have to create constants every time but this is the easiest // we don't have to create constants every time but this is the easiest
// and it's CSE's job to eliminate the same ones // and it's CSE's job to eliminate the same ones
auto constant = auto constant = make_shared<op::Constant>(n->get_element_type(), n->get_shape(), cvals);
std::make_shared<op::Constant>(n->get_element_type(), n->get_shape(), cvals);
replace_node(n, constant); replace_node(n, constant);
NGRAPH_DEBUG << " Replacing " << n->get_name() << " with " << constant->get_name(); NGRAPH_DEBUG << " Replacing " << n->get_name() << " with " << constant->get_name();
replaced = true; replaced = true;
...@@ -111,7 +111,7 @@ bool ngraph::pass::ZeroDimTensorElimination::run_on_function(std::shared_ptr<ngr ...@@ -111,7 +111,7 @@ bool ngraph::pass::ZeroDimTensorElimination::run_on_function(std::shared_ptr<ngr
continue; continue;
} }
if (auto concat = std::dynamic_pointer_cast<op::Concat>(n)) if (auto concat = dynamic_pointer_cast<op::Concat>(n))
{ {
NodeVector non_zero_dim_args; NodeVector non_zero_dim_args;
for (auto arg : concat->get_arguments()) for (auto arg : concat->get_arguments())
...@@ -127,7 +127,7 @@ bool ngraph::pass::ZeroDimTensorElimination::run_on_function(std::shared_ptr<ngr ...@@ -127,7 +127,7 @@ bool ngraph::pass::ZeroDimTensorElimination::run_on_function(std::shared_ptr<ngr
auto new_concat = concat->copy_with_new_args(non_zero_dim_args); auto new_concat = concat->copy_with_new_args(non_zero_dim_args);
NGRAPH_DEBUG << " Replacing " << n->get_name() << " with " NGRAPH_DEBUG << " Replacing " << n->get_name() << " with "
<< new_concat->get_name(); << new_concat->get_name();
ngraph::replace_node(concat, new_concat); replace_node(concat, new_concat);
continue; continue;
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment