Commit 13fc556e authored by Robert Kimball, committed by Scott Cyphers

Add ngraph and std namespaces to c++ files (#2549)

* add ngraph and std namespaces to c++ files

* style
parent ef3378c1
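
The change is mechanical: each affected .cpp file gains file-scope using-directives for std and ngraph, and the now-redundant qualifiers are dropped from the code below them. A minimal sketch of the pattern (hypothetical file, not one touched by this commit); the directives are safe at file scope in a .cpp, but would leak into every includer if placed in a header:

    #include <memory>

    // Before: every standard-library name fully qualified.
    std::shared_ptr<int> make_value_qualified() { return std::make_shared<int>(42); }

    // After: a file-scope using-directive, unqualified names.
    using namespace std;
    shared_ptr<int> make_value_unqualified() { return make_shared<int>(42); }
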
@@ -39,29 +39,26 @@
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
#define TI(x) std::type_index(typeid(x))
#define TI(x) type_index(typeid(x))
extern template ngraph::Shape ngraph::apply_permutation<ngraph::Shape>(ngraph::Shape input,
ngraph::AxisVector order);
extern template Shape ngraph::apply_permutation<Shape>(Shape input, AxisVector order);
template <typename T>
static std::shared_ptr<pattern::Matcher>
create_binary_matcher(std::shared_ptr<pattern::op::Label> label,
std::shared_ptr<pattern::op::Label> const_label)
static shared_ptr<pattern::Matcher>
create_binary_matcher(shared_ptr<pattern::op::Label> label,
shared_ptr<pattern::op::Label> const_label)
{
auto bcst =
std::make_shared<pattern::op::Skip>(const_label, pattern::has_class<op::Broadcast>());
auto bcst_label = std::make_shared<pattern::op::Label>(bcst, nullptr, NodeVector{bcst});
auto matcher =
std::make_shared<pattern::Matcher>(std::make_shared<T>(label, bcst_label), nullptr);
auto bcst = make_shared<pattern::op::Skip>(const_label, pattern::has_class<op::Broadcast>());
auto bcst_label = make_shared<pattern::op::Label>(bcst, nullptr, NodeVector{bcst});
auto matcher = make_shared<pattern::Matcher>(make_shared<T>(label, bcst_label), nullptr);
return matcher;
}
static std::shared_ptr<pattern::op::Label>
get_broadcast_label(std::shared_ptr<pattern::Matcher> matcher)
static shared_ptr<pattern::op::Label> get_broadcast_label(shared_ptr<pattern::Matcher> matcher)
{
return std::dynamic_pointer_cast<pattern::op::Label>(matcher->get_pattern()->get_argument(1));
return dynamic_pointer_cast<pattern::op::Label>(matcher->get_pattern()->get_argument(1));
}
//`simplify_concat` identifies slices-concat sequences
@@ -75,23 +72,21 @@ static std::shared_ptr<pattern::op::Label>
// +-------+ | +----------+ | +-----------+
// +----+slice(0..n/2)---+
// +----------+
static bool simplify_concat(std::shared_ptr<Node> n)
static bool simplify_concat(shared_ptr<Node> n)
{
NGRAPH_DEBUG << "In simplify_concat for " << n->get_name();
std::shared_ptr<Node> branch_tip;
shared_ptr<Node> branch_tip;
auto ltip = std::make_shared<pattern::op::Label>(element::i32, Shape{2, 1});
auto ltip = make_shared<pattern::op::Label>(element::i32, Shape{2, 1});
auto pslice =
std::make_shared<op::Slice>(ltip, Coordinate{0, 0}, Coordinate{2, 1}, Strides{1, 1});
auto pslice = make_shared<op::Slice>(ltip, Coordinate{0, 0}, Coordinate{2, 1}, Strides{1, 1});
auto lslice = std::make_shared<pattern::op::Label>(pslice, nullptr, NodeVector{pslice});
auto lslice = make_shared<pattern::op::Label>(pslice, nullptr, NodeVector{pslice});
auto skip_reshape =
std::make_shared<pattern::op::Skip>(lslice, pattern::has_class<op::Reshape>());
auto skip_reshape = make_shared<pattern::op::Skip>(lslice, pattern::has_class<op::Reshape>());
auto matcher = std::make_shared<pattern::Matcher>(skip_reshape, nullptr);
auto matcher = make_shared<pattern::Matcher>(skip_reshape, nullptr);
Coordinate prev_lower_bounds;
Shape prev_slice_shape;
@@ -104,7 +99,7 @@ static bool simplify_concat(std::shared_ptr<Node> n)
return false;
}
auto slice = std::static_pointer_cast<op::Slice>(matcher->get_pattern_map()[lslice]);
auto slice = static_pointer_cast<op::Slice>(matcher->get_pattern_map()[lslice]);
if (branch_tip)
{
if (branch_tip != matcher->get_pattern_map()[ltip])
@@ -153,9 +148,9 @@ static bool simplify_concat(std::shared_ptr<Node> n)
}
//check that no other node uses slices and reshapes
if (auto rcarg = std::dynamic_pointer_cast<op::Reshape>(carg))
if (auto rcarg = dynamic_pointer_cast<op::Reshape>(carg))
{
auto default_shape = ngraph::get_default_order(rcarg->get_argument(0)->get_shape());
auto default_shape = get_default_order(rcarg->get_argument(0)->get_shape());
if (default_shape != rcarg->get_input_order())
{
NGRAPH_DEBUG << carg->get_name() << " reshape also does transposes";
@@ -170,11 +165,11 @@ static bool simplify_concat(std::shared_ptr<Node> n)
}
}
auto concat = std::static_pointer_cast<op::Concat>(n);
auto concat = static_pointer_cast<op::Concat>(n);
size_t concat_axis = concat->get_concatenation_axis();
auto slice_shape = branch_tip->get_users(true).at(0)->get_shape();
size_t slice_axis = std::numeric_limits<size_t>::max();
size_t slice_axis = numeric_limits<size_t>::max();
auto btip_shape = branch_tip->get_shape();
@@ -191,7 +186,7 @@ static bool simplify_concat(std::shared_ptr<Node> n)
{
if (btip_shape[i] != slice_shape[i])
{
if (slice_axis != std::numeric_limits<size_t>::max())
if (slice_axis != numeric_limits<size_t>::max())
{
// multi-axis slice + concat do not cancel
return false;
@@ -200,19 +195,18 @@ static bool simplify_concat(std::shared_ptr<Node> n)
}
}
if (slice_axis == std::numeric_limits<size_t>::max())
if (slice_axis == numeric_limits<size_t>::max())
{
return false;
}
auto replacement = branch_tip;
if (btip_shape != n->get_shape())
{
auto default_order = ngraph::get_default_order(btip_shape);
auto default_order = get_default_order(btip_shape);
if (concat_axis == slice_axis)
{
// logical reshape only
replacement =
std::make_shared<op::Reshape>(branch_tip, default_order, concat->get_shape());
replacement = make_shared<op::Reshape>(branch_tip, default_order, concat->get_shape());
}
else
{
@@ -221,30 +215,29 @@ static bool simplify_concat(std::shared_ptr<Node> n)
if (btip_shape.size() >= transposed_shape.size())
{
AxisVector order = ngraph::get_default_order(btip_shape);
AxisVector order = get_default_order(btip_shape);
auto ax = order[slice_axis];
order[slice_axis] = order[concat_axis];
order[concat_axis] = ax;
replacement = std::make_shared<op::Reshape>(branch_tip, order, transposed_shape);
replacement = make_shared<op::Reshape>(branch_tip, order, transposed_shape);
}
else if (btip_shape.size() < transposed_shape.size())
{
// intermediate logical reshape
AxisVector order = ngraph::get_default_order(transposed_shape);
AxisVector order = get_default_order(transposed_shape);
auto ax = order[slice_axis];
order[slice_axis] = order[concat_axis];
order[concat_axis] = ax;
auto output_shape = ngraph::apply_permutation(transposed_shape, order);
auto output_shape = apply_permutation(transposed_shape, order);
auto logical_reshape =
std::make_shared<op::Reshape>(branch_tip, default_order, output_shape);
make_shared<op::Reshape>(branch_tip, default_order, output_shape);
// transpose to final concatenated shape
replacement =
std::make_shared<op::Reshape>(logical_reshape, order, transposed_shape);
replacement = make_shared<op::Reshape>(logical_reshape, order, transposed_shape);
}
}
}
ngraph::replace_node(n, replacement);
replace_node(n, replacement);
return true;
}
@@ -255,15 +248,13 @@ static bool simplify_concat(std::shared_ptr<Node> n)
//a * broadcast(0) -> broadcast(0)
//a * 1 -> a
//a * broadcast(1) -> a
static bool simplify_multiply(std::shared_ptr<Node> n)
static bool simplify_multiply(shared_ptr<Node> n)
{
NGRAPH_DEBUG << "In simplify_multiply for " << n->get_name();
auto iconst = ngraph::make_zero(element::i32, Shape{});
auto label = std::make_shared<pattern::op::Label>(iconst);
auto const_label_zero =
std::make_shared<pattern::op::Label>(iconst, ngraph::is_zero, NodeVector{iconst});
auto const_label_one =
std::make_shared<pattern::op::Label>(iconst, ngraph::is_one, NodeVector{iconst});
auto iconst = make_zero(element::i32, Shape{});
auto label = make_shared<pattern::op::Label>(iconst);
auto const_label_zero = make_shared<pattern::op::Label>(iconst, is_zero, NodeVector{iconst});
auto const_label_one = make_shared<pattern::op::Label>(iconst, is_one, NodeVector{iconst});
auto matcher_const_zero = create_binary_matcher<op::Multiply>(label, const_label_zero);
auto matcher_const_one = create_binary_matcher<op::Multiply>(label, const_label_one);
@@ -273,7 +264,7 @@ static bool simplify_multiply(std::shared_ptr<Node> n)
auto bcst_label = get_broadcast_label(matcher_const_zero);
auto bcst_or_cnst = matcher_const_zero->get_pattern_map()[bcst_label];
NGRAPH_DEBUG << " Replacing " << n->get_name() << " with " << bcst_or_cnst->get_name();
ngraph::replace_node(n, bcst_or_cnst);
replace_node(n, bcst_or_cnst);
return true;
}
@@ -281,7 +272,7 @@ static bool simplify_multiply(std::shared_ptr<Node> n)
{
auto x = matcher_const_one->get_pattern_map()[label];
NGRAPH_DEBUG << " Replacing " << n->get_name() << " with " << x->get_name();
ngraph::replace_node(n, x);
replace_node(n, x);
return true;
}
@@ -293,12 +284,12 @@ static bool simplify_multiply(std::shared_ptr<Node> n)
//
//a + 0 -> a
//a + broadcast(0) -> a
static bool simplify_add(std::shared_ptr<Node> n)
static bool simplify_add(shared_ptr<Node> n)
{
NGRAPH_DEBUG << "In simplify_add for " << n->get_name();
auto iconst = ngraph::make_zero(element::i32, Shape{});
auto label = std::make_shared<pattern::op::Label>(iconst);
auto const_label = std::make_shared<pattern::op::Label>(iconst, nullptr, NodeVector{iconst});
auto iconst = make_zero(element::i32, Shape{});
auto label = make_shared<pattern::op::Label>(iconst);
auto const_label = make_shared<pattern::op::Label>(iconst, nullptr, NodeVector{iconst});
auto matcher = create_binary_matcher<op::Add>(label, const_label);
if (matcher->match(n))
@@ -309,10 +300,10 @@ static bool simplify_add(std::shared_ptr<Node> n)
NGRAPH_DEBUG << "Node " << n->get_name() << " matched \" arg + 0 \" \n"
<< " arg : " << x->get_name() << " , const : " << cnst->get_name();
if (ngraph::is_zero(cnst))
if (is_zero(cnst))
{
NGRAPH_DEBUG << " Replacing " << n->get_name() << " with " << x->get_name();
ngraph::replace_node(n, x);
replace_node(n, x);
return true;
}
else
@@ -324,16 +315,16 @@ static bool simplify_add(std::shared_ptr<Node> n)
}
//`simplify_log` optimizes `log(exp(x)/y)` into `x - log(y)`
static bool simplify_log(std::shared_ptr<Node> n)
static bool simplify_log(shared_ptr<Node> n)
{
if (auto div = std::dynamic_pointer_cast<op::Divide>(n->get_argument(0)))
if (auto div = dynamic_pointer_cast<op::Divide>(n->get_argument(0)))
{
if (auto exp = std::dynamic_pointer_cast<op::Exp>(div->get_argument(0)))
if (auto exp = dynamic_pointer_cast<op::Exp>(div->get_argument(0)))
{
auto denom = div->get_argument(1);
auto diff = std::make_shared<op::Subtract>(exp->get_argument(0),
std::make_shared<op::Log>(denom));
ngraph::replace_node(n, diff);
auto diff =
make_shared<op::Subtract>(exp->get_argument(0), make_shared<op::Log>(denom));
replace_node(n, diff);
return true;
}
}
@@ -353,16 +344,15 @@ static size_t reduction_shape_size(const AxisSet& axes, const Shape& shape)
}
template <typename T>
static std::shared_ptr<Node>
multiply_by(element::Type type, size_t multiplier, std::shared_ptr<op::Constant> cnst)
static shared_ptr<Node>
multiply_by(element::Type type, size_t multiplier, shared_ptr<op::Constant> cnst)
{
T sum_cnst = static_cast<T>(cnst->get_vector<T>().at(0) * multiplier);
return op::Constant::create<T>(type, Shape{}, {sum_cnst});
}
template <typename T>
static std::shared_ptr<Node>
pow_by(element::Type type, size_t multiplier, std::shared_ptr<op::Constant> cnst)
static shared_ptr<Node> pow_by(element::Type type, size_t multiplier, shared_ptr<op::Constant> cnst)
{
T prod = static_cast<T>(1);
T val = cnst->get_vector<T>().at(0);
@@ -373,7 +363,7 @@ static std::shared_ptr<Node>
return op::Constant::create<T>(type, Shape{}, {prod});
}
static std::shared_ptr<Node> get_sum_constant(std::shared_ptr<op::Constant> cnst, size_t multiplier)
static shared_ptr<Node> get_sum_constant(shared_ptr<op::Constant> cnst, size_t multiplier)
{
if (cnst->get_element_type() == element::i32)
{
@@ -395,8 +385,7 @@ static std::shared_ptr<Node> get_sum_constant(std::shared_ptr<op::Constant> cnst
return nullptr;
}
static std::shared_ptr<Node> get_prod_constant(std::shared_ptr<op::Constant> cnst,
size_t multiplier)
static shared_ptr<Node> get_prod_constant(shared_ptr<op::Constant> cnst, size_t multiplier)
{
if (cnst->get_element_type() == element::i32)
{
@@ -423,21 +412,20 @@ static std::shared_ptr<Node> get_prod_constant(std::shared_ptr<op::Constant> cns
//where constant2's values are equal to scalar_constant * shape_size(reduction_axes)
//product(broadcast(scalar_constant), reduction_axes = ...) -> constant2 (or scalar constant)
//where constant2's values are equal to scalar_constant ^ shape_size(reduction_axes)
template <typename T,
std::shared_ptr<Node> (*F)(std::shared_ptr<op::Constant> cnst, size_t multiplier)>
static bool simplify_reduction(std::shared_ptr<Node> n)
template <typename T, shared_ptr<Node> (*F)(shared_ptr<op::Constant> cnst, size_t multiplier)>
static bool simplify_reduction(shared_ptr<Node> n)
{
NGRAPH_DEBUG << "In simplify_reduction for " << n->get_name();
auto reduction = std::static_pointer_cast<T>(n);
auto reduction = static_pointer_cast<T>(n);
auto broadcast = std::dynamic_pointer_cast<op::Broadcast>(n->get_argument(0));
auto broadcast = dynamic_pointer_cast<op::Broadcast>(n->get_argument(0));
if (!broadcast)
{
NGRAPH_DEBUG << n->get_name() << " isn't Broadcast";
return false;
}
auto cnst = std::dynamic_pointer_cast<op::Constant>(broadcast->get_argument(0));
auto cnst = dynamic_pointer_cast<op::Constant>(broadcast->get_argument(0));
if (!cnst || cnst->get_shape().size() > 0 /*not a scalar*/)
{
NGRAPH_DEBUG << broadcast->get_argument(0)->get_name() << " isn't a scalar constant";
@@ -456,39 +444,35 @@ static bool simplify_reduction(std::shared_ptr<Node> n)
if (reduction->get_shape().size() > 0)
{
ngraph::AxisSet axes{};
AxisSet axes{};
for (size_t i = 0; i < reduction->get_shape().size(); i++)
{
axes.insert(i);
}
reduction_cnst =
std::make_shared<op::Broadcast>(reduction_cnst, reduction->get_shape(), axes);
reduction_cnst = make_shared<op::Broadcast>(reduction_cnst, reduction->get_shape(), axes);
}
ngraph::replace_node(n, reduction_cnst);
replace_node(n, reduction_cnst);
return true;
}
static std::unordered_map<std::type_index, std::function<bool(std::shared_ptr<Node>)>>
initialize_ops_to_simplifiers()
static unordered_map<type_index, function<bool(shared_ptr<Node>)>> initialize_ops_to_simplifiers()
{
return std::unordered_map<std::type_index, std::function<bool(std::shared_ptr<Node>)>>(
return unordered_map<type_index, function<bool(shared_ptr<Node>)>>(
{{TI(op::Add), simplify_add},
{TI(op::Multiply), simplify_multiply},
{TI(op::Concat), simplify_concat},
{TI(op::Sum),
std::function<bool(std::shared_ptr<Node>)>{
simplify_reduction<op::Sum, get_sum_constant>}},
function<bool(shared_ptr<Node>)>{simplify_reduction<op::Sum, get_sum_constant>}},
{TI(op::Product),
std::function<bool(std::shared_ptr<Node>)>{
simplify_reduction<op::Product, get_prod_constant>}},
function<bool(shared_ptr<Node>)>{simplify_reduction<op::Product, get_prod_constant>}},
{TI(op::Log), simplify_log}});
}
static std::unordered_map<std::type_index, std::function<bool(std::shared_ptr<Node>)>>
ops_to_simplifiers = initialize_ops_to_simplifiers();
static unordered_map<type_index, function<bool(shared_ptr<Node>)>> ops_to_simplifiers =
initialize_ops_to_simplifiers();
bool ngraph::pass::AlgebraicSimplification::run_on_function(std::shared_ptr<ngraph::Function> f)
bool pass::AlgebraicSimplification::run_on_function(shared_ptr<Function> f)
{
bool replaced = false;
for (auto n : f->get_ordered_ops())
......
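
ops_to_simplifiers above is a dispatch table keyed on std::type_index, built with the TI macro, mapping each op's dynamic type to its simplifier. A self-contained sketch of that dispatch technique, with placeholder node types standing in for the ngraph ops:

    #include <functional>
    #include <iostream>
    #include <memory>
    #include <typeindex>
    #include <unordered_map>

    using namespace std;

    #define TI(x) type_index(typeid(x))

    struct Node { virtual ~Node() = default; }; // stand-in for ngraph::Node
    struct Add : Node {};
    struct Multiply : Node {};

    static bool simplify_add(shared_ptr<Node>) { cout << "simplify_add\n"; return true; }
    static bool simplify_multiply(shared_ptr<Node>) { cout << "simplify_multiply\n"; return true; }

    // Map the dynamic type of a node to its handler, as ops_to_simplifiers does.
    static const unordered_map<type_index, function<bool(shared_ptr<Node>)>> simplifiers{
        {TI(Add), simplify_add}, {TI(Multiply), simplify_multiply}};

    int main()
    {
        shared_ptr<Node> n = make_shared<Multiply>();
        auto it = simplifiers.find(TI(*n)); // typeid(*n) yields the dynamic type
        if (it != simplifiers.end())
            it->second(n); // prints "simplify_multiply"
    }
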
@@ -89,7 +89,7 @@ shared_ptr<op::Constant> make_constant_pad(shared_ptr<op::Constant> constant,
return make_shared<op::Constant>(constant->get_element_type(), out_shape, out_vec);
}
void ngraph::pass::ConstantFolding::construct_constant_pad()
void pass::ConstantFolding::construct_constant_pad()
{
auto is_constant = pattern::has_class<op::Constant>();
auto constant_label = make_shared<pattern::op::Label>(element::f32, Shape{6}, is_constant);
@@ -142,7 +142,7 @@ void ngraph::pass::ConstantFolding::construct_constant_pad()
this->add_matcher(pad_matcher);
}
void ngraph::pass::ConstantFolding::construct_constant_reshape()
void pass::ConstantFolding::construct_constant_reshape()
{
auto constant_label = make_shared<pattern::op::Label>(
element::f32, Shape{2, 4}, pattern::has_class<op::Constant>());
@@ -207,7 +207,7 @@ shared_ptr<op::Constant> make_constant_broadcast(shared_ptr<op::Constant> consta
return make_shared<op::Constant>(constant->get_element_type(), out_shape, out_vec);
}
void ngraph::pass::ConstantFolding::construct_constant_broadcast()
void pass::ConstantFolding::construct_constant_broadcast()
{
auto constant_label =
make_shared<pattern::op::Label>(element::f32, Shape{2}, pattern::has_class<op::Constant>());
@@ -324,7 +324,7 @@ bool is_supported_binary_op(std::shared_ptr<Node> n)
std::dynamic_pointer_cast<op::Minimum>(n));
}
void ngraph::pass::ConstantFolding::construct_constant_binary()
void pass::ConstantFolding::construct_constant_binary()
{
auto a = make_shared<pattern::op::Label>(
element::f32, Shape{2, 4}, pattern::has_class<op::Constant>());
@@ -418,7 +418,7 @@ shared_ptr<op::Constant> make_constant_unary(shared_ptr<op::Constant> constant,
return make_shared<op::Constant>(constant->get_element_type(), out_shape, out_vec);
}
void ngraph::pass::ConstantFolding::construct_constant_unary()
void pass::ConstantFolding::construct_constant_unary()
{
auto constant_label = make_shared<pattern::op::Label>(
element::f32, Shape{2, 4}, pattern::has_class<op::Constant>());
@@ -493,7 +493,7 @@ shared_ptr<op::Constant> make_constant_dequantize(shared_ptr<op::Constant> const
return make_shared<op::Constant>(dequant->get_element_type(), out_shape, out_vec);
}
void ngraph::pass::ConstantFolding::construct_constant_dequantize()
void pass::ConstantFolding::construct_constant_dequantize()
{
auto constant_label =
make_shared<pattern::op::Label>(element::u8, Shape{2}, pattern::has_class<op::Constant>());
@@ -567,7 +567,7 @@ shared_ptr<op::Constant> make_constant_quantize(shared_ptr<op::Constant> constan
return make_shared<op::Constant>(quant->get_element_type(), out_shape, out_vec);
}
void ngraph::pass::ConstantFolding::construct_constant_quantize()
void pass::ConstantFolding::construct_constant_quantize()
{
auto constant_label =
make_shared<pattern::op::Label>(element::f32, Shape{2}, pattern::has_class<op::Constant>());
......
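
Each construct_constant_* matcher above folds an op whose inputs are all op::Constant nodes into a single precomputed Constant (is_supported_binary_op in the hunk above gates which binary ops qualify; op::Minimum is visible in the excerpt). Independent of the pattern machinery, the folding step amounts to evaluating the op over the constant buffers; a minimal sketch for an elementwise binary op, with Add as the example:

    #include <cstddef>
    #include <vector>

    // Evaluate a + b elementwise over two equally-shaped constant buffers; the
    // result becomes the data of the replacement op::Constant.
    std::vector<float> fold_constant_add(const std::vector<float>& a, const std::vector<float>& b)
    {
        std::vector<float> out(a.size());
        for (std::size_t i = 0; i < a.size(); ++i)
            out[i] = a[i] + b[i];
        return out;
    }
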
@@ -69,7 +69,7 @@ void pass::CoreFusion::construct_relu()
auto pattern_map = m.get_pattern_map();
auto mzero = m.get_pattern_map()[zero];
if (!ngraph::is_zero(mzero))
if (!is_zero(mzero))
{
NGRAPH_DEBUG << "zero constant = " << mzero->get_name() << " not equal to 0\n";
return false;
@@ -77,7 +77,7 @@ void pass::CoreFusion::construct_relu()
auto mpattern = m.get_match_root();
auto cg = shared_ptr<Node>(new op::Relu(pattern_map[val]));
ngraph::replace_node(m.get_match_root(), cg);
replace_node(m.get_match_root(), cg);
return true;
};
@@ -100,7 +100,7 @@ void pass::CoreFusion::construct_sigmoid()
auto divide_1_over_exp = std::make_shared<op::Divide>(skip_broadcast, add_exp);
// Define a callback to be called once the DFG matches the pattern
ngraph::pattern::graph_rewrite_callback callback = [input, constant](pattern::Matcher& m) {
pattern::graph_rewrite_callback callback = [input, constant](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_fprop_sigmoid pattern against "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
@@ -125,12 +125,11 @@ void pass::CoreFusion::construct_sigmoid()
return false;
}
auto sigmoid_node = std::make_shared<op::Sigmoid>(pattern_map[input]);
ngraph::replace_node(m.get_match_root(), sigmoid_node);
replace_node(m.get_match_root(), sigmoid_node);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(
divide_1_over_exp, callback, "CoreFusion.Sigmoid");
auto m = std::make_shared<pattern::Matcher>(divide_1_over_exp, callback, "CoreFusion.Sigmoid");
this->add_matcher(m);
}
@@ -159,7 +158,7 @@ void pass::CoreFusion::construct_sigmoid_bprop()
auto negative_2 = std::make_shared<op::Negative>(multiply_2);
// Define a callback to be called once the DFG matches the pattern
ngraph::pattern::graph_rewrite_callback callback = [input, delta](pattern::Matcher& m) {
pattern::graph_rewrite_callback callback = [input, delta](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_bprop_sigmoid pattern against "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
@@ -178,12 +177,11 @@ void pass::CoreFusion::construct_sigmoid_bprop()
}
auto dsigmoid =
std::make_shared<op::SigmoidBackprop>(pattern_map[input], pattern_map[delta]);
ngraph::replace_node(m.get_match_root(), dsigmoid);
replace_node(m.get_match_root(), dsigmoid);
return true;
};
auto m =
std::make_shared<ngraph::pattern::Matcher>(negative_2, callback, "CoreFusion.SigmoidBprop");
auto m = std::make_shared<pattern::Matcher>(negative_2, callback, "CoreFusion.SigmoidBprop");
this->add_matcher(m);
}
@@ -212,7 +210,7 @@ void pass::CoreFusion::construct_folded_batch_norm()
auto shape_r = Shape{1, 2, 2, 2};
auto bn = std::make_shared<op::BatchNormInference>(eps, gamma, beta, pconv, mean, var);
ngraph::pattern::graph_rewrite_callback callback = [input, filters, mean, var, gamma, beta](
pattern::graph_rewrite_callback callback = [input, filters, mean, var, gamma, beta](
pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for folded batch norm against node = "
<< m.get_match_root()->get_name();
@@ -258,13 +256,13 @@ void pass::CoreFusion::construct_folded_batch_norm()
m_conv->get_data_dilation_strides());
auto conv_bias =
conv + std::make_shared<op::Broadcast>(new_biases, conv->get_shape(), AxisSet{0, 2, 3});
ngraph::replace_node(m.get_match_root(), conv_bias);
replace_node(m.get_match_root(), conv_bias);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(bn, callback, "CoreFusion.FoldedBatchNorm");
auto m = std::make_shared<pattern::Matcher>(bn, callback, "CoreFusion.FoldedBatchNorm");
this->add_matcher(m);
}
@@ -293,7 +291,7 @@ void pass::CoreFusion::construct_conv_affine_folding()
auto multiply = std::make_shared<op::Multiply>(conv_label, A_label);
auto add = std::make_shared<op::Add>(multiply, B_label);
ngraph::pattern::graph_rewrite_callback callback =
pattern::graph_rewrite_callback callback =
[input, filters, conv_label, A_label, B_label](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for conv affine folding against node = "
<< m.get_match_root()->get_name();
@@ -345,7 +343,7 @@ void pass::CoreFusion::construct_conv_affine_folding()
if (bcast->get_argument(0)->get_shape().size() == 2)
{
Shape bshape{bcast->get_argument(0)->get_shape()[1]};
return static_pointer_cast<ngraph::Node>(std::make_shared<op::Reshape>(
return static_pointer_cast<Node>(std::make_shared<op::Reshape>(
bcast->get_argument(0), AxisVector{0, 1}, bshape));
}
throw ngraph_error("Unexpected shape for bcast input");
@@ -369,14 +367,13 @@ void pass::CoreFusion::construct_conv_affine_folding()
conv_m->get_padding_above(),
conv_m->get_data_dilation_strides());
auto convbias_n = conv_n + B_m;
ngraph::replace_node(m.get_match_root(), convbias_n);
replace_node(m.get_match_root(), convbias_n);
return true;
};
auto m =
std::make_shared<ngraph::pattern::Matcher>(add, callback, "CoreFusion.ConvAffineFolding");
auto m = std::make_shared<pattern::Matcher>(add, callback, "CoreFusion.ConvAffineFolding");
this->add_matcher(m);
}
@@ -440,7 +437,7 @@ static size_t shape_to_index(Shape shape)
}
}
void ngraph::pass::CoreFusion::construct_reshape_broadcast()
void pass::CoreFusion::construct_reshape_broadcast()
{
Shape input_shape{10};
auto input = make_shared<pattern::op::Label>(element::f32, input_shape);
@@ -473,7 +470,7 @@ void ngraph::pass::CoreFusion::construct_reshape_broadcast()
if (d != 1 && d != dim)
{
NGRAPH_DEBUG << "Input is reshaped in a way we can't directly broadcast ( shape = "
<< ngraph::vector_to_string(reshape1_m->get_shape()) << ")";
<< vector_to_string(reshape1_m->get_shape()) << ")";
return false;
}
@@ -502,7 +499,7 @@ void ngraph::pass::CoreFusion::construct_reshape_broadcast()
auto new_broadcast =
make_shared<op::Broadcast>(input_m, broadcast_m->get_shape(), new_axes);
ngraph::replace_node(m.get_match_root(), new_broadcast);
replace_node(m.get_match_root(), new_broadcast);
return true;
};
@@ -520,7 +517,7 @@ void ngraph::pass::CoreFusion::construct_reshape_broadcast()
void pass::CoreFusion::construct_optimized_strided_conv()
{
Shape win_size_1{1, 1, 1, 1};
auto is_bc = ngraph::pattern::has_class<op::Broadcast>();
auto is_bc = pattern::has_class<op::Broadcast>();
auto data_stride3 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 128, 128});
auto weights_stride3 = std::make_shared<pattern::op::Label>(element::f32, win_size_1);
@@ -689,7 +686,7 @@ void pass::CoreFusion::construct_optimized_strided_conv()
new_relu_two_convs, sconv->get_argument(1), stride_1, stride_1);
NGRAPH_DEBUG << "Replacing " << sconv->get_name() << " with "
<< sconv_28w1s1->get_name();
ngraph::replace_node(sconv, sconv_28w1s1);
replace_node(sconv, sconv_28w1s1);
}
return true;
};
@@ -699,7 +696,7 @@ void pass::CoreFusion::construct_optimized_strided_conv()
this->add_matcher(m);
}
void ngraph::pass::CoreFusion::construct_reshape_softmax_reshape()
void pass::CoreFusion::construct_reshape_softmax_reshape()
{
Shape input_shape{10, 20};
AxisVector io{1, 0};
@@ -738,7 +735,7 @@ void ngraph::pass::CoreFusion::construct_reshape_softmax_reshape()
}
auto new_softmax = make_shared<op::Softmax>(input_m, new_axes);
ngraph::replace_node(m.get_match_root(), new_softmax);
replace_node(m.get_match_root(), new_softmax);
return true;
};
......
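
construct_sigmoid and construct_sigmoid_bprop above match the expanded graphs for the sigmoid and its derivative and replace them with single op::Sigmoid / op::SigmoidBackprop nodes. The identities the matchers rely on, checked as a scalar sketch:

    #include <cassert>
    #include <cmath>

    // Forward: the matched subgraph computes 1 / (1 + exp(-x)).
    static float sigmoid(float x) { return 1.0f / (1.0f + std::exp(-x)); }

    // Backward: d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x)), scaled by the
    // incoming delta; the bprop fusion folds this product into one node.
    static float sigmoid_bprop(float x, float delta)
    {
        float s = sigmoid(x);
        return delta * s * (1.0f - s);
    }

    int main()
    {
        // Finite-difference check of the derivative identity.
        float x = 0.7f, h = 1e-3f;
        float fd = (sigmoid(x + h) - sigmoid(x - h)) / (2.0f * h);
        assert(std::abs(fd - sigmoid_bprop(x, 1.0f)) < 1e-3f);
    }
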
@@ -59,11 +59,12 @@
#include "ngraph/op/tanh.hpp"
#include "ngraph/pattern/matcher.hpp"
using namespace std;
using namespace ngraph;
#define TI(x) std::type_index(typeid(x))
#define TI(x) type_index(typeid(x))
static bool cse_constant(std::shared_ptr<Node> a, std::shared_ptr<Node> b)
static bool cse_constant(shared_ptr<Node> a, shared_ptr<Node> b)
{
NGRAPH_DEBUG << "In cse_constant for " << a->get_name() << " and " << b->get_name();
@@ -72,44 +73,44 @@ static bool cse_constant(std::shared_ptr<Node> a, std::shared_ptr<Node> b)
return false;
}
auto ca = std::static_pointer_cast<op::Constant>(a);
auto cb = std::static_pointer_cast<op::Constant>(b);
auto ca = static_pointer_cast<op::Constant>(a);
auto cb = static_pointer_cast<op::Constant>(b);
size_t size = shape_size(a->get_shape()) * a->get_element_type().size();
return !memcmp(ca->get_data_ptr(), cb->get_data_ptr(), size);
}
static bool cse_reshape(std::shared_ptr<Node> a, std::shared_ptr<Node> b)
static bool cse_reshape(shared_ptr<Node> a, shared_ptr<Node> b)
{
NGRAPH_DEBUG << "In cse_reshape for " << a->get_name() << " and " << b->get_name();
auto reshape_a = std::static_pointer_cast<ngraph::op::Reshape>(a);
auto reshape_b = std::static_pointer_cast<ngraph::op::Reshape>(b);
auto reshape_a = static_pointer_cast<ngraph::op::Reshape>(a);
auto reshape_b = static_pointer_cast<ngraph::op::Reshape>(b);
return (a->get_argument(0) == b->get_argument(0)) &&
(reshape_a->get_input_order() == reshape_b->get_input_order()) &&
(reshape_a->get_output_shape() == reshape_b->get_output_shape());
}
static bool cse_broadcast(std::shared_ptr<Node> a, std::shared_ptr<Node> b)
static bool cse_broadcast(shared_ptr<Node> a, shared_ptr<Node> b)
{
NGRAPH_DEBUG << "In cse_broadcast for " << a->get_name() << " and " << b->get_name();
auto broadcast_a = std::static_pointer_cast<ngraph::op::Broadcast>(a);
auto broadcast_b = std::static_pointer_cast<ngraph::op::Broadcast>(b);
auto broadcast_a = static_pointer_cast<ngraph::op::Broadcast>(a);
auto broadcast_b = static_pointer_cast<ngraph::op::Broadcast>(b);
return (a->get_argument(0) == b->get_argument(0)) &&
(broadcast_a->get_broadcast_axes() == broadcast_b->get_broadcast_axes()) &&
(broadcast_a->get_broadcast_shape() == broadcast_b->get_broadcast_shape());
}
static bool cse_unarywise(std::shared_ptr<Node> a, std::shared_ptr<Node> b)
static bool cse_unarywise(shared_ptr<Node> a, shared_ptr<Node> b)
{
NGRAPH_DEBUG << "In cse_unarywise for " << a->get_name() << " and " << b->get_name();
return a->get_argument(0) == b->get_argument(0);
}
static bool cse_binarywise(std::shared_ptr<Node> a, std::shared_ptr<Node> b)
static bool cse_binarywise(shared_ptr<Node> a, shared_ptr<Node> b)
{
NGRAPH_DEBUG << "In cse_binary for " << a->get_name() << " and " << b->get_name();
@@ -117,23 +118,21 @@ static bool cse_binarywise(std::shared_ptr<Node> a, std::shared_ptr<Node> b)
(a->get_argument(1) == b->get_argument(0) && a->get_argument(0) == b->get_argument(1));
}
static bool cse_reduction(std::shared_ptr<Node> a, std::shared_ptr<Node> b)
static bool cse_reduction(shared_ptr<Node> a, shared_ptr<Node> b)
{
NGRAPH_DEBUG << "In cse_reduction for " << a->get_name() << " and " << b->get_name();
auto ar_a = std::static_pointer_cast<op::util::ArithmeticReduction>(a);
auto ar_b = std::static_pointer_cast<op::util::ArithmeticReduction>(b);
auto ar_a = static_pointer_cast<op::util::ArithmeticReduction>(a);
auto ar_b = static_pointer_cast<op::util::ArithmeticReduction>(b);
return ar_a->get_argument(0) == ar_b->get_argument(0) &&
ar_a->get_reduction_axes() == ar_b->get_reduction_axes();
}
static std::unordered_map<std::type_index,
std::function<bool(std::shared_ptr<Node>, std::shared_ptr<Node>)>>
static unordered_map<type_index, function<bool(shared_ptr<Node>, shared_ptr<Node>)>>
initialize_ops_to_cse_handlers()
{
return std::unordered_map<std::type_index,
std::function<bool(std::shared_ptr<Node>, std::shared_ptr<Node>)>>(
return unordered_map<type_index, function<bool(shared_ptr<Node>, shared_ptr<Node>)>>(
{{TI(op::Abs), cse_unarywise},
{TI(op::Acos), cse_unarywise},
{TI(op::Asin), cse_unarywise},
@@ -168,23 +167,21 @@ static std::unordered_map<std::type_index,
{TI(op::Broadcast), cse_broadcast}});
}
static std::unordered_map<std::type_index,
std::function<bool(std::shared_ptr<Node>, std::shared_ptr<Node>)>>
static unordered_map<type_index, function<bool(shared_ptr<Node>, shared_ptr<Node>)>>
ops_to_cse_handlers = initialize_ops_to_cse_handlers();
class NodeKey
{
public:
NodeKey(std::shared_ptr<Node> n,
std::unordered_map<std::type_index,
std::function<bool(std::shared_ptr<Node>, std::shared_ptr<Node>)>>&
NodeKey(shared_ptr<Node> n,
unordered_map<type_index, function<bool(shared_ptr<Node>, shared_ptr<Node>)>>&
backend_handlers)
: m_node(n)
, m_backend_handlers(backend_handlers)
{
}
std::shared_ptr<Node> get_node() const { return m_node; }
shared_ptr<Node> get_node() const { return m_node; }
bool operator==(const NodeKey& other) const
{
Node& p_this = *m_node.get();
@@ -215,9 +212,8 @@ public:
}
private:
std::shared_ptr<Node> m_node;
std::unordered_map<std::type_index,
std::function<bool(std::shared_ptr<Node>, std::shared_ptr<Node>)>>&
shared_ptr<Node> m_node;
unordered_map<type_index, function<bool(shared_ptr<Node>, shared_ptr<Node>)>>&
m_backend_handlers;
};
@@ -226,15 +222,15 @@ namespace std
template <>
struct hash<NodeKey>
{
std::size_t operator()(const NodeKey& k) const
size_t operator()(const NodeKey& k) const
{
Node& p_this = *k.get_node().get();
auto ti = TI(p_this);
std::hash<std::type_index> type_hash_compute{};
hash<type_index> type_hash_compute{};
auto type_hash = type_hash_compute(ti);
std::vector<size_t> arg_ids;
vector<size_t> arg_ids;
arg_ids.push_back(type_hash);
@@ -244,7 +240,7 @@ namespace std
// specify how to compute hash for each op?
if (p_this.is_commutative())
{
std::sort(begin(cargs), end(cargs));
sort(begin(cargs), end(cargs));
}
for (auto arg : cargs)
@@ -258,11 +254,10 @@ namespace std
};
}
bool ngraph::pass::CommonSubexpressionElimination::run_on_function(
std::shared_ptr<ngraph::Function> f)
bool ngraph::pass::CommonSubexpressionElimination::run_on_function(shared_ptr<ngraph::Function> f)
{
bool replaced = false;
std::unordered_map<NodeKey, std::shared_ptr<Node>> expressions{};
unordered_map<NodeKey, shared_ptr<Node>> expressions{};
for (auto n : f->get_ordered_ops())
{
@@ -279,7 +274,7 @@ bool ngraph::pass::CommonSubexpressionElimination::run_on_function(
}
else
{
expressions.insert(std::make_pair(n_key, n));
expressions.insert(make_pair(n_key, n));
}
}
......
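
The CSE pass keys an unordered_map on a NodeKey whose operator== defers to the per-op handlers above and whose std::hash specialization combines the op's type_index with its argument ids, sorting the arguments first for commutative ops so that a+b and b+a land in the same bucket. A simplified, self-contained sketch of that keying scheme (not the nGraph class):

    #include <algorithm>
    #include <cstddef>
    #include <functional>
    #include <typeindex>
    #include <unordered_map>
    #include <vector>

    struct ExprKey
    {
        std::type_index op;       // dynamic type of the node
        std::vector<size_t> args; // identities of the argument nodes
        bool commutative;

        bool operator==(const ExprKey& o) const
        {
            if (op != o.op)
                return false;
            if (!commutative)
                return args == o.args;
            auto a = args, b = o.args; // compare order-insensitively
            std::sort(a.begin(), a.end());
            std::sort(b.begin(), b.end());
            return a == b;
        }
    };

    namespace std
    {
        template <>
        struct hash<ExprKey>
        {
            size_t operator()(const ExprKey& k) const
            {
                auto a = k.args;
                if (k.commutative)
                    std::sort(a.begin(), a.end()); // a+b hashes like b+a
                size_t h = std::hash<std::type_index>{}(k.op);
                for (size_t id : a)
                    h = h * 31 + id; // simple hash combine
                return h;
            }
        };
    }

    // Usage mirrors run_on_function: keep an unordered_map<ExprKey, NodePtr>,
    // and when a node's key is already present, replace the node with the
    // mapped one instead of inserting.
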
@@ -24,6 +24,9 @@
#include "ngraph/log.hpp"
#include "ngraph/pattern/matcher.hpp"
using namespace std;
using namespace ngraph;
// GraphRewrite algorithm:
// GraphRewrite processes an input graph in topological order (i.e. args before users)
// Given the following graph: Abs2
@@ -56,16 +59,16 @@
// c) there's no linear order of fusions which will give
// the correct final fusion. i.e. the same fusion needs to occur before and after some other fusion
bool ngraph::pass::GraphRewrite::run_on_function(std::shared_ptr<ngraph::Function> f)
bool pass::GraphRewrite::run_on_function(shared_ptr<Function> f)
{
bool rewritten = false;
const size_t NUM_TRIES = 10;
size_t tries = NUM_TRIES;
std::vector<std::shared_ptr<pattern::Matcher>> original_matchers{m_matchers};
vector<shared_ptr<pattern::Matcher>> original_matchers{m_matchers};
do
{
rewritten = false;
std::vector<std::shared_ptr<pattern::Matcher>> matchers{m_matchers};
vector<shared_ptr<pattern::Matcher>> matchers{m_matchers};
m_matchers.clear();
for (auto node : f->get_ordered_ops())
{
@@ -92,31 +95,31 @@ bool ngraph::pass::GraphRewrite::run_on_function(std::shared_ptr<ngraph::Functio
return (NUM_TRIES - tries) > 1; //this means a graph was transformed
}
static const std::vector<std::regex> initialize_fusion_regexes()
static const vector<regex> initialize_fusion_regexes()
{
const char* cnsf = std::getenv("NGRAPH_DISABLED_FUSIONS");
std::vector<std::regex> regexes;
const char* cnsf = getenv("NGRAPH_DISABLED_FUSIONS");
vector<regex> regexes;
if (cnsf)
{
const std::string nsf = cnsf;
const auto sregexes = ngraph::split(nsf, ';');
const string nsf = cnsf;
const auto sregexes = split(nsf, ';');
std::transform(sregexes.begin(),
sregexes.end(),
std::back_inserter(regexes),
[](const std::string& c) -> std::regex { return std::regex(c); });
transform(sregexes.begin(),
sregexes.end(),
back_inserter(regexes),
[](const string& c) -> regex { return regex(c); });
}
return regexes;
}
bool ngraph::pass::GraphRewrite::is_enabled(std::shared_ptr<pattern::Matcher> m)
bool pass::GraphRewrite::is_enabled(shared_ptr<pattern::Matcher> m)
{
//note, regexes are static to avoid re-initialization
static const auto regexes = initialize_fusion_regexes();
for (const auto& regex : regexes)
{
if (std::regex_match(m->get_name(), regex))
if (regex_match(m->get_name(), regex))
{
NGRAPH_DEBUG << "Disabling matcher " << m->get_name();
return false;
@@ -126,7 +129,7 @@ bool ngraph::pass::GraphRewrite::is_enabled(std::shared_ptr<pattern::Matcher> m)
return true;
}
void ngraph::pass::GraphRewrite::add_matcher(std::shared_ptr<pattern::Matcher> m)
void pass::GraphRewrite::add_matcher(shared_ptr<pattern::Matcher> m)
{
if (is_enabled(m))
{
@@ -134,7 +137,7 @@ void ngraph::pass::GraphRewrite::add_matcher(std::shared_ptr<pattern::Matcher> m
}
}
bool ngraph::pass::RecurrentGraphRewrite::run_on_function(std::shared_ptr<ngraph::Function> f)
bool pass::RecurrentGraphRewrite::run_on_function(shared_ptr<Function> f)
{
bool changed = false;
size_t i = 0;
......
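
is_enabled above filters matchers against NGRAPH_DISABLED_FUSIONS, a semicolon-separated list of regexes parsed once into a static vector. A standalone sketch of the mechanism; the ngraph::split helper is replaced here by a plain std::getline loop:

    #include <cstdlib>
    #include <regex>
    #include <sstream>
    #include <string>
    #include <vector>

    static std::vector<std::regex> initialize_fusion_regexes()
    {
        std::vector<std::regex> regexes;
        if (const char* cnsf = std::getenv("NGRAPH_DISABLED_FUSIONS"))
        {
            std::stringstream ss(cnsf);
            std::string pattern;
            while (std::getline(ss, pattern, ';')) // e.g. "CoreFusion.*;.*Sigmoid"
                regexes.emplace_back(pattern);
        }
        return regexes;
    }

    static bool is_enabled(const std::string& matcher_name)
    {
        static const auto regexes = initialize_fusion_regexes(); // built once
        for (const auto& re : regexes)
            if (std::regex_match(matcher_name, re))
                return false; // disabled via the environment
        return true;
    }
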
@@ -30,27 +30,28 @@
#include "ngraph/op/sum.hpp"
#include "ngraph/util.hpp"
#define TI(x) std::type_index(typeid(x))
using namespace std;
using namespace ngraph;
#define HANDLER_DECL(x) static bool x(const std::shared_ptr<ngraph::Node>& node)
#define TI(x) type_index(typeid(x))
#define HANDLER_DECL(x) static bool x(const shared_ptr<Node>& node)
HANDLER_DECL(replace_broadcast_like)
{
// Replace a BroadcastLike with an equivalent Broadcast to eliminate the pseudo-dependency on the "like" argument
auto broadcast_like = std::static_pointer_cast<ngraph::op::BroadcastLike>(node);
ngraph::replace_node(
node,
std::make_shared<ngraph::op::Broadcast>(broadcast_like->get_argument(0),
broadcast_like->get_broadcast_shape(),
broadcast_like->get_broadcast_axes()));
auto broadcast_like = static_pointer_cast<op::BroadcastLike>(node);
replace_node(node,
make_shared<op::Broadcast>(broadcast_like->get_argument(0),
broadcast_like->get_broadcast_shape(),
broadcast_like->get_broadcast_axes()));
return true;
}
static const std::unordered_map<std::type_index,
std::function<bool(const std::shared_ptr<ngraph::Node>&)>>
dispatcher{{TI(ngraph::op::BroadcastLike), &replace_broadcast_like}};
static const unordered_map<type_index, function<bool(const shared_ptr<Node>&)>> dispatcher{
{TI(op::BroadcastLike), &replace_broadcast_like}};
bool ngraph::pass::LikeReplacement::run_on_function(std::shared_ptr<ngraph::Function> function)
bool pass::LikeReplacement::run_on_function(shared_ptr<Function> function)
{
bool clobbered = false;
@@ -66,10 +67,10 @@ bool ngraph::pass::LikeReplacement::run_on_function(std::shared_ptr<ngraph::Func
// Here we're checking on a common base class of a family of template classes,
// which is more than type info can handle.
auto sclb = std::dynamic_pointer_cast<ngraph::op::ScalarConstantLikeBase>(n);
auto sclb = dynamic_pointer_cast<op::ScalarConstantLikeBase>(n);
if (sclb != nullptr)
{
ngraph::replace_node(sclb, sclb->as_constant());
replace_node(sclb, sclb->as_constant());
clobbered = true;
}
}
......
@@ -33,7 +33,7 @@
using namespace std;
using namespace ngraph;
bool pass::Liveness::run_on_function(shared_ptr<ngraph::Function> function)
bool pass::Liveness::run_on_function(shared_ptr<Function> function)
{
list<shared_ptr<Node>> ops = function->get_ordered_ops();
......
@@ -35,7 +35,7 @@
using namespace std;
using namespace ngraph;
ngraph::pass::Manager::Manager()
pass::Manager::Manager()
{
static const auto nevt = std::getenv("NGRAPH_ENABLE_VISUALIZE_TRACING");
if (nevt)
@@ -49,15 +49,15 @@ ngraph::pass::Manager::Manager()
}
}
ngraph::pass::Manager::~Manager()
pass::Manager::~Manager()
{
}
void ngraph::pass::Manager::initialize_default_passes()
void pass::Manager::initialize_default_passes()
{
}
void ngraph::pass::Manager::run_passes(shared_ptr<Function> func, bool transitive)
void pass::Manager::run_passes(shared_ptr<Function> func, bool transitive)
{
bool profile_enabled = getenv("NGRAPH_PROFILE_PASS_ENABLE") != nullptr;
@@ -167,7 +167,7 @@ void ngraph::pass::Manager::run_passes(shared_ptr<Function> func, bool transitiv
}
}
ngraph::pass::ManagerState& ngraph::pass::Manager::get_state()
pass::ManagerState& pass::Manager::get_state()
{
return m_state;
}
@@ -25,7 +25,7 @@
using namespace std;
using namespace ngraph;
const vector<shared_ptr<Function>>& ngraph::pass::ManagerState::get_functions()
const vector<shared_ptr<Function>>& pass::ManagerState::get_functions()
{
return m_function_list;
}
@@ -40,7 +40,7 @@ pass::MemoryLayout::MemoryLayout(size_t alignment, bool disable_memory_sharing)
}
}
bool pass::MemoryLayout::run_on_function(shared_ptr<ngraph::Function> function)
bool pass::MemoryLayout::run_on_function(shared_ptr<Function> function)
{
MemoryManager mm(m_alignment, m_disable_memory_sharing);
for (shared_ptr<Node> node : function->get_ordered_ops())
......
@@ -34,7 +34,7 @@ pass::MemoryVisualize::MemoryVisualize(const string& filename)
{
}
bool pass::MemoryVisualize::run_on_module(vector<shared_ptr<ngraph::Function>>& functions)
bool pass::MemoryVisualize::run_on_module(vector<shared_ptr<Function>>& functions)
{
ofstream file(m_filename);
{
......
@@ -30,94 +30,93 @@
#include "ngraph/util.hpp"
#include "nop_elimination.hpp"
#define TI(x) std::type_index(typeid(x))
using namespace std;
using namespace ngraph;
#define HANDLER_DECL(x) static bool x(const std::shared_ptr<ngraph::Node>& node)
#define TI(x) std::type_index(typeid(x))
HANDLER_DECL(eliminate_pad)
static bool eliminate_pad(const std::shared_ptr<Node>& node)
{
auto pad = std::static_pointer_cast<ngraph::op::Pad>(node);
auto pad = std::static_pointer_cast<op::Pad>(node);
if (pad->get_input_shape(0) == pad->get_output_shape(0))
{
ngraph::replace_node(node, node->get_argument(0));
replace_node(node, node->get_argument(0));
return true;
}
return false;
}
HANDLER_DECL(eliminate_sum)
static bool eliminate_sum(const std::shared_ptr<Node>& node)
{
auto sum = std::static_pointer_cast<ngraph::op::Sum>(node);
auto sum = std::static_pointer_cast<op::Sum>(node);
if (sum->get_reduction_axes().empty())
{
ngraph::replace_node(node, node->get_argument(0));
replace_node(node, node->get_argument(0));
return true;
}
return false;
}
HANDLER_DECL(eliminate_convert)
static bool eliminate_convert(const std::shared_ptr<Node>& node)
{
auto convert = std::static_pointer_cast<ngraph::op::Convert>(node);
auto convert = std::static_pointer_cast<op::Convert>(node);
if (convert->get_convert_element_type() == convert->get_argument(0)->get_element_type())
{
ngraph::replace_node(node, node->get_argument(0));
replace_node(node, node->get_argument(0));
return true;
}
return false;
}
HANDLER_DECL(eliminate_slice)
static bool eliminate_slice(const std::shared_ptr<Node>& node)
{
auto slice = std::static_pointer_cast<ngraph::op::Slice>(node);
auto slice = std::static_pointer_cast<op::Slice>(node);
if (slice->get_input_shape(0) == slice->get_output_shape(0))
{
ngraph::replace_node(node, node->get_argument(0));
replace_node(node, node->get_argument(0));
return true;
}
return false;
}
HANDLER_DECL(replace_broadcast_like)
static bool replace_broadcast_like(const std::shared_ptr<Node>& node)
{
// Replace a BroadcastLike with an equivalent Broadcast to eliminate the pseudo-dependency on the "like" argument
auto broadcast_like = std::static_pointer_cast<ngraph::op::BroadcastLike>(node);
ngraph::replace_node(
node,
std::make_shared<ngraph::op::Broadcast>(broadcast_like->get_argument(0),
broadcast_like->get_broadcast_shape(),
broadcast_like->get_broadcast_axes()));
auto broadcast_like = std::static_pointer_cast<op::BroadcastLike>(node);
replace_node(node,
std::make_shared<op::Broadcast>(broadcast_like->get_argument(0),
broadcast_like->get_broadcast_shape(),
broadcast_like->get_broadcast_axes()));
return true;
}
HANDLER_DECL(eliminate_broadcast)
static bool eliminate_broadcast(const std::shared_ptr<Node>& node)
{
auto broadcast = std::static_pointer_cast<ngraph::op::Broadcast>(node);
auto broadcast = std::static_pointer_cast<op::Broadcast>(node);
if (broadcast->get_input_shape(0) == broadcast->get_output_shape(0))
{
ngraph::replace_node(node, node->get_argument(0));
replace_node(node, node->get_argument(0));
return true;
}
return false;
}
HANDLER_DECL(eliminate_stop_gradient)
static bool eliminate_stop_gradient(const std::shared_ptr<Node>& node)
{
ngraph::replace_node(node, node->get_argument(0));
replace_node(node, node->get_argument(0));
return true;
}
static const std::unordered_map<std::type_index,
std::function<bool(const std::shared_ptr<ngraph::Node>&)>>
dispatcher{{TI(ngraph::op::Pad), &eliminate_pad},
{TI(ngraph::op::Sum), &eliminate_sum},
{TI(ngraph::op::Convert), &eliminate_convert},
{TI(ngraph::op::Slice), &eliminate_slice},
{TI(ngraph::op::StopGradient), &eliminate_stop_gradient},
{TI(ngraph::op::BroadcastLike), &replace_broadcast_like},
{TI(ngraph::op::Broadcast), &eliminate_broadcast}};
bool ngraph::pass::NopElimination::run_on_function(std::shared_ptr<ngraph::Function> function)
static const std::unordered_map<std::type_index, std::function<bool(const std::shared_ptr<Node>&)>>
dispatcher{{TI(op::Pad), &eliminate_pad},
{TI(op::Sum), &eliminate_sum},
{TI(op::Convert), &eliminate_convert},
{TI(op::Slice), &eliminate_slice},
{TI(op::StopGradient), &eliminate_stop_gradient},
{TI(op::BroadcastLike), &replace_broadcast_like},
{TI(op::Broadcast), &eliminate_broadcast}};
bool pass::NopElimination::run_on_function(std::shared_ptr<Function> function)
{
bool clobbered = false;
@@ -133,10 +132,10 @@ bool ngraph::pass::NopElimination::run_on_function(std::shared_ptr<ngraph::Funct
// Here we're checking on a common base class of a family of template classes,
// which is more than type info can handle.
auto sclb = std::dynamic_pointer_cast<ngraph::op::ScalarConstantLikeBase>(n);
auto sclb = std::dynamic_pointer_cast<op::ScalarConstantLikeBase>(n);
if (sclb != nullptr)
{
ngraph::replace_node(sclb, sclb->as_constant());
replace_node(sclb, sclb->as_constant());
clobbered = true;
}
}
......
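
This file also expands the HANDLER_DECL macro at each definition site: HANDLER_DECL(eliminate_pad) and the spelled-out static bool eliminate_pad(const std::shared_ptr<Node>& node) declare exactly the same function. A minimal illustration of the equivalence, using a stand-in Node type:

    #include <memory>

    struct Node {}; // stand-in for ngraph::Node

    #define HANDLER_DECL(x) static bool x(const std::shared_ptr<Node>& node)

    // The macro form used before this commit...
    HANDLER_DECL(eliminate_pad_macro)
    {
        return node != nullptr;
    }

    // ...and the equivalent signature now written out by hand.
    static bool eliminate_pad_spelled(const std::shared_ptr<Node>& node)
    {
        return node != nullptr;
    }

    int main()
    {
        auto n = std::make_shared<Node>();
        return eliminate_pad_macro(n) == eliminate_pad_spelled(n) ? 0 : 1;
    }
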
@@ -17,12 +17,15 @@
#include "ngraph/pass/pass.hpp"
#include "ngraph/pass/manager.hpp"
ngraph::pass::ManagerState& ngraph::pass::PassBase::get_state()
using namespace std;
using namespace ngraph;
pass::ManagerState& pass::PassBase::get_state()
{
return *m_state;
}
void ngraph::pass::PassBase::set_state(ManagerState& state)
void pass::PassBase::set_state(ManagerState& state)
{
m_state = &state;
}
@@ -19,10 +19,11 @@
#include "ngraph/log.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
// TODO: Add file-based configuration support
ngraph::pass::PassConfig::PassConfig(ngraph::pass::CompilationMode mode)
pass::PassConfig::PassConfig(pass::CompilationMode mode)
: m_compilation_mode(mode)
{
/**
@@ -32,15 +33,15 @@ ngraph::pass::PassConfig::PassConfig(ngraph::pass::CompilationMode mode)
* E.g., NGRAPH_PASS_ENABLES="CoreFusion:0;LikeReplacement:1;CPUCollapseDims" would
* set disables on CoreFusion and enables on LikeReplacement and CPUCollapseDims
**/
const char* env_str = std::getenv("NGRAPH_PASS_ENABLES");
const char* env_str = getenv("NGRAPH_PASS_ENABLES");
if (env_str)
{
std::stringstream ss;
stringstream ss;
ss << env_str;
while (ss.good())
{
std::string substr;
std::getline(ss, substr, ';');
string substr;
getline(ss, substr, ';');
auto split_str = split(substr, ':', false);
switch (split_str.size())
{
@@ -58,15 +59,15 @@ ngraph::pass::PassConfig::PassConfig(ngraph::pass::CompilationMode mode)
* would set false on "OptimizeForMemory", true on "MemoryAssignment::ReuseMemory" and true on
* "UseDefaultLayouts"
**/
env_str = std::getenv("NGRAPH_PASS_ATTRIBUTES");
env_str = getenv("NGRAPH_PASS_ATTRIBUTES");
if (env_str)
{
std::stringstream ss;
stringstream ss;
ss << env_str;
while (ss.good())
{
std::string substr;
std::getline(ss, substr, ';');
string substr;
getline(ss, substr, ';');
auto split_str = split(substr, '=', false);
switch (split_str.size())
{
@@ -80,12 +81,12 @@ ngraph::pass::PassConfig::PassConfig(ngraph::pass::CompilationMode mode)
}
}
void ngraph::pass::PassConfig::set_pass_enable(std::string name, bool enable)
void pass::PassConfig::set_pass_enable(string name, bool enable)
{
m_pass_enables[name] = enable;
}
bool ngraph::pass::PassConfig::get_pass_enable(std::string name)
bool pass::PassConfig::get_pass_enable(string name)
{
if (m_pass_enables.find(name) == m_pass_enables.end())
{
@@ -94,12 +95,12 @@ bool ngraph::pass::PassConfig::get_pass_enable(std::string name)
return m_pass_enables[name];
}
void ngraph::pass::PassConfig::set_pass_attribute(std::string name, bool enable)
void pass::PassConfig::set_pass_attribute(string name, bool enable)
{
m_pass_attributes[name] = enable;
}
bool ngraph::pass::PassConfig::get_pass_attribute(std::string name)
bool pass::PassConfig::get_pass_attribute(string name)
{
if (m_pass_attributes.find(name) == m_pass_attributes.end())
{
......
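
Per the comment block above, NGRAPH_PASS_ENABLES is a semicolon-separated list of pass names with an optional :0/:1 suffix, e.g. "CoreFusion:0;LikeReplacement:1;CPUCollapseDims", and a bare name counts as enabled. A standalone sketch of the parsing; the hunks elide the switch over split sizes, so the bare-name behavior here is inferred from that comment:

    #include <map>
    #include <sstream>
    #include <string>

    // Parse "CoreFusion:0;LikeReplacement:1;CPUCollapseDims" into per-pass flags.
    static std::map<std::string, bool> parse_pass_enables(const std::string& env_str)
    {
        std::map<std::string, bool> enables;
        std::stringstream ss(env_str);
        std::string entry;
        while (std::getline(ss, entry, ';'))
        {
            if (entry.empty())
                continue;
            auto colon = entry.find(':');
            if (colon == std::string::npos)
                enables[entry] = true; // bare name => enabled
            else
                enables[entry.substr(0, colon)] = entry.substr(colon + 1) != "0";
        }
        return enables;
    }

    // parse_pass_enables("CoreFusion:0;LikeReplacement:1;CPUCollapseDims")
    // => {CoreFusion: false, LikeReplacement: true, CPUCollapseDims: true}
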
@@ -24,14 +24,17 @@
#include "ngraph/pattern/op/any_of.hpp"
#include "ngraph/pattern/op/label.hpp"
ngraph::pass::PrefixReshapeElimination::PrefixReshapeElimination()
using namespace std;
using namespace ngraph;
pass::PrefixReshapeElimination::PrefixReshapeElimination()
{
auto src_op = std::make_shared<pattern::op::Label>(
element::i8, Shape{}, [](std::shared_ptr<Node>) { return true; });
auto reshape_op = std::make_shared<pattern::op::Any>(
auto src_op = make_shared<pattern::op::Label>(
element::i8, Shape{}, [](shared_ptr<Node>) { return true; });
auto reshape_op = make_shared<pattern::op::Any>(
element::i8,
Shape{},
[](std::shared_ptr<Node> node) {
[](shared_ptr<Node> node) {
op::Reshape* reshape = dynamic_cast<op::Reshape*>(node.get());
if (!reshape)
{
@@ -46,14 +49,14 @@ ngraph::pass::PrefixReshapeElimination::PrefixReshapeElimination()
// Make sure that logical dimension sizes match.
const Shape& src_shape = reshape->get_input_shape(0);
for (std::size_t idx = 0; idx < reshape->get_output_shape().size(); ++idx)
for (size_t idx = 0; idx < reshape->get_output_shape().size(); ++idx)
{
std::size_t src_size = 1;
size_t src_size = 1;
if (idx < src_shape.size())
{
src_size = src_shape.at(src_shape.size() - 1 - idx);
}
std::size_t dest_size =
size_t dest_size =
reshape->get_output_shape().at(reshape->get_output_shape().size() - 1 - idx);
if (dest_size != src_size)
{
@@ -64,10 +67,10 @@ ngraph::pass::PrefixReshapeElimination::PrefixReshapeElimination()
return true;
},
NodeVector{src_op});
auto target_op = std::make_shared<pattern::op::AnyOf>(
auto target_op = make_shared<pattern::op::AnyOf>(
element::i8,
Shape{},
[](std::shared_ptr<Node> node) {
[](shared_ptr<Node> node) {
return pattern::has_class<op::Reshape>()(node) ||
pattern::has_class<op::util::UnaryElementwiseArithmetic>()(node) ||
pattern::has_class<op::util::BinaryElementwiseArithmetic>()(node);
@@ -78,5 +81,5 @@ ngraph::pass::PrefixReshapeElimination::PrefixReshapeElimination()
replace_node(m.get_matched_nodes().at(1), m.get_matched_nodes().at(2));
return true;
};
add_matcher(std::make_shared<pattern::Matcher>(target_op, callback));
add_matcher(make_shared<pattern::Matcher>(target_op, callback));
}
@@ -22,15 +22,16 @@
#include "ngraph/op/util/op_annotations.hpp"
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
using namespace std;
using namespace ngraph;
bool ngraph::pass::PropagateCacheability::run_on_function(std::shared_ptr<Function> function)
bool pass::PropagateCacheability::run_on_function(shared_ptr<Function> function)
{
for (auto& node : function->get_ordered_ops())
{
if (node->is_op())
{
auto op = std::static_pointer_cast<op::Op>(node);
auto op = static_pointer_cast<op::Op>(node);
NGRAPH_DEBUG << "propagate cacheability: node is " << node->get_name();
auto op_annotations = op->get_op_annotations();
if (!op_annotations)
@@ -41,7 +42,7 @@ bool ngraph::pass::PropagateCacheability::run_on_function(std::shared_ptr<Functi
}
if (node->is_parameter())
{
auto parameter = std::static_pointer_cast<op::Parameter>(node);
auto parameter = static_pointer_cast<op::Parameter>(node);
op_annotations->set_cacheable(parameter->get_cacheable());
NGRAPH_DEBUG << "propagate cacheability: cacheability is "
<< parameter->get_cacheable();
@@ -54,7 +55,7 @@ bool ngraph::pass::PropagateCacheability::run_on_function(std::shared_ptr<Functi
NGRAPH_DEBUG << "propagate cacheability: arg is " << arg->get_name();
if (arg->is_op())
{
auto arg_op = std::static_pointer_cast<op::Op>(arg);
auto arg_op = static_pointer_cast<op::Op>(arg);
auto arg_op_annotations = arg_op->get_op_annotations();
NGRAPH_ASSERT(arg_op_annotations);
if (!arg_op_annotations->is_cacheable())
......
@@ -33,17 +33,19 @@
#include "ngraph/pattern/op/skip.hpp"
#include "ngraph/util.hpp"
extern template ngraph::AxisVector
ngraph::apply_permutation<ngraph::AxisVector>(ngraph::AxisVector input,
ngraph::AxisVector order);
using namespace std;
using namespace ngraph;
void ngraph::pass::ReshapeElimination::construct_identity_reshape_pattern()
extern template AxisVector ngraph::apply_permutation<AxisVector>(AxisVector input,
AxisVector order);
void pass::ReshapeElimination::construct_identity_reshape_pattern()
{
Shape shape_op{3};
Shape shape_r1{1, 3};
auto op = std::make_shared<pattern::op::Label>(element::f32, shape_op);
auto reshape1 = std::make_shared<op::Reshape>(op, AxisVector{0}, shape_r1);
auto op = make_shared<pattern::op::Label>(element::f32, shape_op);
auto reshape1 = make_shared<op::Reshape>(op, AxisVector{0}, shape_r1);
auto callback = [op](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for construct_identity_reshape_pattern against node = "
@@ -51,7 +53,7 @@ void ngraph::pass::ReshapeElimination::construct_identity_reshape_pattern()
auto pattern_map = m.get_pattern_map();
auto gop = pattern_map[op];
auto r1 = std::dynamic_pointer_cast<op::Reshape>(m.get_match_root());
auto r1 = dynamic_pointer_cast<op::Reshape>(m.get_match_root());
if (r1->get_shape() != gop->get_shape())
{
@@ -59,7 +61,7 @@ void ngraph::pass::ReshapeElimination::construct_identity_reshape_pattern()
return false;
}
auto do_r1 = ngraph::get_default_order(r1->get_shape());
auto do_r1 = get_default_order(r1->get_shape());
if (do_r1 != r1->get_input_order())
{
@@ -67,22 +69,22 @@ void ngraph::pass::ReshapeElimination::construct_identity_reshape_pattern()
return false;
}
ngraph::replace_node(m.get_match_root(), gop);
replace_node(m.get_match_root(), gop);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(reshape1, callback);
auto m = make_shared<pattern::Matcher>(reshape1, callback);
this->add_matcher(m);
}
void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern()
void pass::ReshapeElimination::construct_reshapex2_pattern()
{
Shape shape_op{3};
Shape shape_r1{1, 3};
auto op = std::make_shared<pattern::op::Label>(element::f32, shape_op);
auto reshape1 = std::make_shared<op::Reshape>(op, AxisVector{0}, shape_r1);
auto reshape2 = std::make_shared<op::Reshape>(reshape1, AxisVector{0, 1}, shape_op);
auto op = make_shared<pattern::op::Label>(element::f32, shape_op);
auto reshape1 = make_shared<op::Reshape>(op, AxisVector{0}, shape_r1);
auto reshape2 = make_shared<op::Reshape>(reshape1, AxisVector{0, 1}, shape_op);
auto callback = [op](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for construct_reshapex2_pattern against node = "
@@ -101,11 +103,11 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern()
return false;
}
auto r2 = std::dynamic_pointer_cast<op::Reshape>(m.get_match_root());
auto r1 = std::dynamic_pointer_cast<op::Reshape>(r2->get_argument(0));
auto r2 = dynamic_pointer_cast<op::Reshape>(m.get_match_root());
auto r1 = dynamic_pointer_cast<op::Reshape>(r2->get_argument(0));
auto do_r2 = ngraph::get_default_order(r1->get_shape());
auto do_r1 = ngraph::get_default_order(gop->get_shape());
auto do_r2 = get_default_order(r1->get_shape());
auto do_r1 = get_default_order(gop->get_shape());
NGRAPH_DEBUG << "r1's i/o = " << vector_to_string(r1->get_input_order())
<< "do_r1 = " << vector_to_string(do_r1);
@@ -115,40 +117,40 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern()
if (r1->get_input_order() == do_r1 && r2->get_input_order() == do_r2)
{
NGRAPH_DEBUG << "Two reshapes were removed!";
ngraph::replace_node(m.get_match_root(), gop);
replace_node(m.get_match_root(), gop);
return true;
}
auto perm1 = ngraph::apply_permutation(do_r1, r1->get_input_order());
auto perm2 = ngraph::apply_permutation(perm1, r2->get_input_order());
auto perm1 = apply_permutation(do_r1, r1->get_input_order());
auto perm2 = apply_permutation(perm1, r2->get_input_order());
if (perm2 == do_r1)
{
NGRAPH_DEBUG << "Two transposes were removed!";
ngraph::replace_node(m.get_match_root(), gop);
replace_node(m.get_match_root(), gop);
return true;
}
return false;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(reshape2, callback);
auto m = make_shared<pattern::Matcher>(reshape2, callback);
this->add_matcher(m);
}
void ngraph::pass::ReshapeElimination::construct_dot_transpose_pattern()
void pass::ReshapeElimination::construct_dot_transpose_pattern()
{
// dot(A,B).T = dot (B.T, A.T)
auto dot_pred = [](std::shared_ptr<Node> n) {
return static_cast<bool>(std::dynamic_pointer_cast<op::Dot>(n));
auto dot_pred = [](shared_ptr<Node> n) {
return static_cast<bool>(dynamic_pointer_cast<op::Dot>(n));
};
auto pdot = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 1}, dot_pred);
auto preshape = std::make_shared<op::Reshape>(pdot, AxisVector{1, 0}, Shape{1, 2});
auto pdot = make_shared<pattern::op::Label>(element::f32, Shape{2, 1}, dot_pred);
auto preshape = make_shared<op::Reshape>(pdot, AxisVector{1, 0}, Shape{1, 2});
ngraph::pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for construct_dot_transpose_pattern against node = "
<< m.get_match_root()->get_name();
auto mtranspose = std::static_pointer_cast<op::Reshape>(m.get_match_root());
auto mtranspose = static_pointer_cast<op::Reshape>(m.get_match_root());
// this also checks the rank
if (mtranspose->get_input_order() != AxisVector{1, 0})
{
@@ -171,7 +173,7 @@ void ngraph::pass::ReshapeElimination::construct_dot_transpose_pattern()
return false;
}
auto reshape0_shape = Shape{arg0->get_shape().at(1), arg0->get_shape().at(0)};
auto reshape0 = std::make_shared<op::Reshape>(arg0, AxisVector{1, 0}, reshape0_shape);
auto reshape0 = make_shared<op::Reshape>(arg0, AxisVector{1, 0}, reshape0_shape);
auto arg1 = mdot->get_argument(1);
if (arg1->get_shape().size() != 2)
@@ -180,13 +182,13 @@ void ngraph::pass::ReshapeElimination::construct_dot_transpose_pattern()
return false;
}
auto reshape1_shape = Shape{arg1->get_shape().at(1), arg1->get_shape().at(0)};
auto reshape1 = std::make_shared<op::Reshape>(arg1, AxisVector{1, 0}, reshape1_shape);
auto reshape1 = make_shared<op::Reshape>(arg1, AxisVector{1, 0}, reshape1_shape);
auto tdot = std::shared_ptr<Node>(new op::Dot(reshape1, reshape0));
ngraph::replace_node(m.get_match_root(), tdot);
auto tdot = shared_ptr<Node>(new op::Dot(reshape1, reshape0));
replace_node(m.get_match_root(), tdot);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(preshape, callback);
auto m = make_shared<pattern::Matcher>(preshape, callback);
this->add_matcher(m);
}
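// A self-contained numeric check of the identity the pass relies on,
// dot(A, B).T == dot(B.T, A.T), using plain nested vectors (illustration only):
#include <cassert>
#include <cstddef>
#include <vector>

using Mat = std::vector<std::vector<double>>;

static Mat dot(const Mat& a, const Mat& b)
{
    Mat c(a.size(), std::vector<double>(b[0].size(), 0.0));
    for (size_t i = 0; i < a.size(); i++)
        for (size_t j = 0; j < b[0].size(); j++)
            for (size_t k = 0; k < b.size(); k++)
                c[i][j] += a[i][k] * b[k][j];
    return c;
}

static Mat transpose(const Mat& m)
{
    Mat t(m[0].size(), std::vector<double>(m.size()));
    for (size_t i = 0; i < m.size(); i++)
        for (size_t j = 0; j < m[0].size(); j++)
            t[j][i] = m[i][j];
    return t;
}

int main()
{
    Mat a{{1, 2, 3}, {4, 5, 6}};      // 2x3
    Mat b{{7, 8}, {9, 10}, {11, 12}}; // 3x2
    assert(transpose(dot(a, b)) == dot(transpose(b), transpose(a)));
}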
@@ -39,14 +39,15 @@
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
using ReshapeMap = std::unordered_map<std::shared_ptr<Node>, std::shared_ptr<op::Reshape>>;
using ReshapeMap = unordered_map<shared_ptr<Node>, shared_ptr<op::Reshape>>;
static std::string describe_reshape(std::shared_ptr<Node> node)
static string describe_reshape(shared_ptr<Node> node)
{
std::stringstream ss;
auto reshape = std::dynamic_pointer_cast<op::Reshape>(node);
stringstream ss;
auto reshape = dynamic_pointer_cast<op::Reshape>(node);
ss << reshape->get_name()
<< " ( axis order = " << ngraph::vector_to_string(reshape->get_input_order())
<< " , shape = " << vector_to_string(reshape->get_shape()) << " ) "
@@ -55,25 +56,24 @@ static std::string describe_reshape(std::shared_ptr<Node> node)
return ss.str();
}
static std::shared_ptr<op::Reshape> combine_reshapes(std::shared_ptr<op::Reshape> r1,
std::shared_ptr<op::Reshape> r2)
static shared_ptr<op::Reshape> combine_reshapes(shared_ptr<op::Reshape> r1,
shared_ptr<op::Reshape> r2)
{
auto default_order = ngraph::get_default_order(r1->get_shape());
auto perm_r1 = apply_permutation(default_order, r1->get_input_order());
auto perm_r2 = apply_permutation(perm_r1, r2->get_input_order());
auto rreshape = std::make_shared<op::Reshape>(r2->get_argument(0), perm_r2, r2->get_shape());
auto rreshape = make_shared<op::Reshape>(r2->get_argument(0), perm_r2, r2->get_shape());
return rreshape;
}
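// Worked example of the composition above (assuming apply_permutation
// semantics out[i] = in[order[i]]): with r1 input order {1, 0, 2} and
// r2 input order {2, 1, 0},
//   perm_r1 = apply_permutation({0, 1, 2}, {1, 0, 2}) = {1, 0, 2}
//   perm_r2 = apply_permutation({1, 0, 2}, {2, 1, 0}) = {2, 0, 1}
// so the pair collapses into a single Reshape with input order {2, 0, 1}
// and r2's output shape.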
static void
insert_reshape(std::shared_ptr<Node> target, std::shared_ptr<Node> reshape, size_t input_index)
static void insert_reshape(shared_ptr<Node> target, shared_ptr<Node> reshape, size_t input_index)
{
auto arg = target->get_inputs().at(input_index).get_output().get_node();
auto new_reshape = reshape->copy_with_new_args({arg});
target->get_inputs().at(input_index).replace_output(new_reshape->get_outputs().at(0));
}
static void delete_reshape(std::shared_ptr<Node> reshape)
static void delete_reshape(shared_ptr<Node> reshape)
{
NGRAPH_DEBUG << "Removing reshape " << reshape->get_name();
if (!reshape->get_users().empty())
@@ -82,22 +82,22 @@ static void delete_reshape(std::shared_ptr<Node> reshape)
}
}
static void mark_reshape_for_deletion(std::shared_ptr<Node> reshape,
std::set<std::shared_ptr<Node>>& reshapes_to_delete)
static void mark_reshape_for_deletion(shared_ptr<Node> reshape,
set<shared_ptr<Node>>& reshapes_to_delete)
{
NGRAPH_DEBUG << "Marking reshape " << reshape->get_name() << " for deletion";
reshapes_to_delete.insert(reshape);
}
static std::shared_ptr<op::Reshape> create_default_reshape(std::shared_ptr<Node> n)
static shared_ptr<op::Reshape> create_default_reshape(shared_ptr<Node> n)
{
auto default_order = ngraph::get_default_order(n->get_shape());
auto default_reshape = std::make_shared<op::Reshape>(n, default_order, n->get_shape());
auto default_reshape = make_shared<op::Reshape>(n, default_order, n->get_shape());
return default_reshape;
}
//compute an axis order that converts the given axis order to default
static AxisSet get_quantization_axes_in_default_order(std::shared_ptr<op::Reshape> arg_reshape,
static AxisSet get_quantization_axes_in_default_order(shared_ptr<op::Reshape> arg_reshape,
const AxisSet& old_axis_set)
{
auto perm_to_def = ngraph::get_permutation_to_default_order(arg_reshape->get_input_order());
@@ -112,7 +112,7 @@ static AxisSet get_quantization_axes_in_default_order(std::shared_ptr<op::Reshap
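// A standalone sketch of the remap this helper performs (assumptions: the
// permutation back to default order is the inverse q of the input order p,
// i.e. q[p[i]] = i, and each old quantization axis a maps to q[a]; the real
// body is elided in this hunk):
#include <cstddef>
#include <set>
#include <vector>

static std::vector<size_t> perm_to_default(const std::vector<size_t>& order)
{
    std::vector<size_t> inv(order.size());
    for (size_t i = 0; i < order.size(); i++)
    {
        inv[order[i]] = i;
    }
    return inv;
}

int main()
{
    std::vector<size_t> input_order{0, 2, 3, 1}; // an NCHW -> NHWC transpose
    std::set<size_t> old_axes{1};                // quantized along channels
    auto to_def = perm_to_default(input_order);  // {0, 3, 1, 2}
    std::set<size_t> new_axes;
    for (auto a : old_axes)
    {
        new_axes.insert(to_def[a]); // old axis 1 becomes axis 3 here
    }
    return new_axes.count(3) ? 0 : 1;
}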
struct Swimmer
{
descriptor::Input* input;
std::shared_ptr<op::Reshape> reshape;
shared_ptr<op::Reshape> reshape;
};
//Swim is used to push/"swim" reshapes towards parameters.
@@ -121,10 +121,10 @@ struct Swimmer
//we prefer nchw since a lot of ngraph ops require this format,
//so keeping things in nchw allows us to eliminate as many reshapes
//as possible
void swim(descriptor::Input* input, std::shared_ptr<op::Reshape> reshape)
void swim(descriptor::Input* input, shared_ptr<op::Reshape> reshape)
{
Swimmer sw{input, reshape};
std::list<Swimmer> work_queue;
list<Swimmer> work_queue;
work_queue.push_back(sw);
//TODO: if we support more ops (especially, with >1 args)
@@ -135,21 +135,21 @@ void swim(descriptor::Input* input, std::shared_ptr<op::Reshape> reshape)
work_queue.pop_front();
auto n = csw.input->get_output().get_node();
NGRAPH_DEBUG << "Processing (swimming) " << n->get_name();
if (auto unary = std::dynamic_pointer_cast<op::util::UnaryElementwiseArithmetic>(n))
if (auto unary = dynamic_pointer_cast<op::util::UnaryElementwiseArithmetic>(n))
{
Swimmer nsw{&unary->get_inputs().at(0), csw.reshape};
work_queue.push_back(nsw);
NGRAPH_DEBUG << "Propagating reshape " << describe_reshape(csw.reshape) << " for "
<< n->get_name() << " to " << unary->get_argument(0);
}
else if (std::dynamic_pointer_cast<op::Broadcast>(n))
else if (dynamic_pointer_cast<op::Broadcast>(n))
{
auto old_broadcast = std::static_pointer_cast<op::Broadcast>(n);
auto old_broadcast = static_pointer_cast<op::Broadcast>(n);
auto broadcast_axes = old_broadcast->get_broadcast_axes();
auto broadcast_reshape = csw.reshape;
bool in_order = true;
AxisSet new_broadcast_axes;
std::vector<size_t> new_source_axes;
vector<size_t> new_source_axes;
auto input_order = broadcast_reshape->get_input_order();
for (size_t i = 0; i < input_order.size(); i++)
{
@@ -171,8 +171,8 @@ void swim(descriptor::Input* input, std::shared_ptr<op::Reshape> reshape)
if (!in_order)
{
AxisVector new_source_axes_sorted{new_source_axes};
std::sort(new_source_axes_sorted.begin(), new_source_axes_sorted.end());
std::map<size_t, size_t> old_new_source_axes;
sort(new_source_axes_sorted.begin(), new_source_axes_sorted.end());
map<size_t, size_t> old_new_source_axes;
for (size_t i = 0; i < new_source_axes_sorted.size(); i++)
{
old_new_source_axes.insert({new_source_axes.at(i), i});
@@ -186,11 +186,11 @@ void swim(descriptor::Input* input, std::shared_ptr<op::Reshape> reshape)
auto new_arg_shape =
ngraph::apply_permutation(broadcast_input->get_shape(), new_source_axis_order);
broadcast_input = std::make_shared<op::Reshape>(
broadcast_input, new_source_axis_order, new_arg_shape);
broadcast_input =
make_shared<op::Reshape>(broadcast_input, new_source_axis_order, new_arg_shape);
}
auto new_broadcast = std::make_shared<op::Broadcast>(
auto new_broadcast = make_shared<op::Broadcast>(
broadcast_input, broadcast_reshape->get_shape(), new_broadcast_axes);
csw.input->replace_output(new_broadcast->get_outputs().at(0));
}
@@ -210,11 +210,11 @@ void swim(descriptor::Input* input, std::shared_ptr<op::Reshape> reshape)
//We have to normalize this other argument to nchw by swimming nchw towards parameters
//as far as we can
static void convert_binary_to_default_order(
std::shared_ptr<Node> binary,
shared_ptr<Node> binary,
descriptor::Input& input,
std::shared_ptr<Node> right,
std::unordered_map<std::shared_ptr<Node>, std::shared_ptr<op::Reshape>>& reorders,
std::set<std::shared_ptr<Node>>& reshapes_to_delete)
shared_ptr<Node> right,
unordered_map<shared_ptr<Node>, shared_ptr<op::Reshape>>& reorders,
set<shared_ptr<Node>>& reshapes_to_delete)
{
auto left = input.get_output().get_node();
auto perm_to_def =
@@ -222,7 +222,7 @@ static void convert_binary_to_default_order(
auto new_shape = apply_permutation(left->get_shape(), perm_to_def);
NGRAPH_DEBUG << "right = " << ngraph::vector_to_string(right->get_shape()) << ", "
<< right->get_name();
auto new_reshape = std::make_shared<op::Reshape>(left, perm_to_def, new_shape);
auto new_reshape = make_shared<op::Reshape>(left, perm_to_def, new_shape);
NGRAPH_DEBUG << "left : About to swim " << describe_reshape(new_reshape) << " up to "
<< left->get_name();
//this should now insert and swim reshape on right
@@ -231,9 +231,9 @@ static void convert_binary_to_default_order(
reorders[binary] = reorders.at(right);
}
static void materialize_shapes(std::shared_ptr<Node> n,
static void materialize_shapes(shared_ptr<Node> n,
ReshapeMap& reorders,
std::set<std::shared_ptr<Node>>& reshapes_to_delete)
set<shared_ptr<Node>>& reshapes_to_delete)
{
//skip multiple output nodes and deal with GOEs exclusively
if (n->get_outputs().size() > 1)
@@ -257,9 +257,9 @@ static void materialize_shapes(std::shared_ptr<Node> n,
reorders[n] = create_default_reshape(n);
}
static void sink_reshape(std::shared_ptr<op::Reshape> reshape,
static void sink_reshape(shared_ptr<op::Reshape> reshape,
ReshapeMap& reorders,
std::set<std::shared_ptr<Node>>& reshapes_to_delete)
set<shared_ptr<Node>>& reshapes_to_delete)
{
auto orig_reshape = reorders.at(reshape->get_argument(0));
if (!reshape->get_is_transpose())
@@ -286,18 +286,18 @@ static void sink_reshape(std::shared_ptr<op::Reshape> reshape,
}
}
static void sink_unary(std::shared_ptr<op::util::UnaryElementwiseArithmetic> n,
static void sink_unary(shared_ptr<op::util::UnaryElementwiseArithmetic> n,
ReshapeMap& reorders,
std::set<std::shared_ptr<Node>>& reshapes_to_delete)
set<shared_ptr<Node>>& reshapes_to_delete)
{
auto arg_reshape = reorders.at(n->get_argument(0));
NGRAPH_DEBUG << "Propagating " << describe_reshape(arg_reshape) << " for " << n->get_name();
reorders[n] = reorders[n->get_argument(0)];
}
static void sink_binary(std::shared_ptr<op::util::BinaryElementwiseArithmetic> binary,
static void sink_binary(shared_ptr<op::util::BinaryElementwiseArithmetic> binary,
ReshapeMap& reorders,
std::set<std::shared_ptr<Node>>& reshapes_to_delete)
set<shared_ptr<Node>>& reshapes_to_delete)
{
auto left = binary->get_argument(0);
auto right = binary->get_argument(1);
@@ -333,9 +333,9 @@ static void sink_binary(std::shared_ptr<op::util::BinaryElementwiseArithmetic> b
}
}
static void sink_slice(std::shared_ptr<op::Slice> n,
static void sink_slice(shared_ptr<op::Slice> n,
ReshapeMap& reorders,
std::set<std::shared_ptr<Node>>& reshapes_to_delete)
set<shared_ptr<Node>>& reshapes_to_delete)
{
auto arg_reshape = reorders.at(n->get_argument(0));
auto order = arg_reshape->get_input_order();
@@ -346,25 +346,23 @@ static void sink_slice(std::shared_ptr<op::Slice> n,
auto def_order = ngraph::get_permutation_to_default_order(order);
auto input_shape = ngraph::apply_permutation(arg_reshape->get_shape(), def_order);
auto dummy_correct_shape =
std::make_shared<pattern::op::Label>(arg_reshape->get_element_type(), input_shape);
make_shared<pattern::op::Label>(arg_reshape->get_element_type(), input_shape);
auto new_lower = ngraph::apply_permutation(n->get_lower_bounds(), def_order);
auto new_upper = ngraph::apply_permutation(n->get_upper_bounds(), def_order);
auto new_strides = ngraph::apply_permutation(n->get_strides(), def_order);
auto new_slice =
std::make_shared<op::Slice>(dummy_correct_shape, new_lower, new_upper, new_strides);
auto new_slice = make_shared<op::Slice>(dummy_correct_shape, new_lower, new_upper, new_strides);
ngraph::replace_node(dummy_correct_shape, n->get_argument(0));
NGRAPH_DEBUG << "Replacing " << n->get_name() << " with " << new_slice->get_name();
ngraph::replace_node(n, new_slice);
auto new_reshape = std::make_shared<op::Reshape>(new_slice, order, n->get_shape());
auto new_reshape = make_shared<op::Reshape>(new_slice, order, n->get_shape());
NGRAPH_DEBUG << "Propagating " << describe_reshape(new_reshape) << " for " << n->get_name();
reorders[new_slice] = new_reshape;
}
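// A standalone sketch of re-indexing the Slice bounds into default order as
// above (apply_perm and inverse are illustrations with the assumed semantics
// out[i] = in[order[i]], not the nGraph helpers):
#include <cassert>
#include <cstddef>
#include <vector>

using Vec = std::vector<size_t>;

static Vec apply_perm(const Vec& in, const Vec& order)
{
    Vec out(in.size());
    for (size_t i = 0; i < in.size(); i++)
        out[i] = in[order[i]];
    return out;
}

static Vec inverse(const Vec& order)
{
    Vec inv(order.size());
    for (size_t i = 0; i < order.size(); i++)
        inv[order[i]] = i;
    return inv;
}

int main()
{
    Vec order{2, 0, 1}; // pending transpose on the slice input
    Vec lower{1, 0, 3};
    Vec upper{4, 2, 5};
    auto def_order = inverse(order); // permutation back to default order
    auto new_lower = apply_perm(lower, def_order);
    auto new_upper = apply_perm(upper, def_order);
    assert(new_lower == (Vec{0, 3, 1}) && new_upper == (Vec{2, 5, 4}));
}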
static void sink_pad(std::shared_ptr<op::Pad> n,
ReshapeMap& reorders,
std::set<std::shared_ptr<Node>>& reshapes_to_delete)
static void
sink_pad(shared_ptr<op::Pad> n, ReshapeMap& reorders, set<shared_ptr<Node>>& reshapes_to_delete)
{
auto arg_reshape = reorders.at(n->get_argument(0));
auto order = arg_reshape->get_input_order();
@@ -374,41 +372,41 @@ static void sink_pad(std::shared_ptr<op::Pad> n,
auto def_order = ngraph::get_permutation_to_default_order(order);
auto input_shape = ngraph::apply_permutation(arg_reshape->get_shape(), def_order);
auto dummy_correct_shape =
std::make_shared<pattern::op::Label>(arg_reshape->get_element_type(), input_shape);
make_shared<pattern::op::Label>(arg_reshape->get_element_type(), input_shape);
auto new_lower = ngraph::apply_permutation(n->get_padding_below(), def_order);
auto new_upper = ngraph::apply_permutation(n->get_padding_above(), def_order);
auto new_interior = ngraph::apply_permutation(n->get_padding_interior(), def_order);
auto new_pad = std::make_shared<op::Pad>(
auto new_pad = make_shared<op::Pad>(
dummy_correct_shape, n->get_argument(1), new_lower, new_upper, new_interior);
ngraph::replace_node(dummy_correct_shape, n->get_argument(0));
NGRAPH_DEBUG << "Replacing " << n->get_name() << " with " << new_pad->get_name();
ngraph::replace_node(n, new_pad);
auto new_reshape = std::make_shared<op::Reshape>(new_pad, order, n->get_shape());
auto new_reshape = make_shared<op::Reshape>(new_pad, order, n->get_shape());
NGRAPH_DEBUG << "Propagating " << describe_reshape(new_reshape) << " for " << n->get_name();
reorders[new_pad] = new_reshape;
}
static void sink_quantize(std::shared_ptr<op::Quantize> quantize,
static void sink_quantize(shared_ptr<op::Quantize> quantize,
ReshapeMap& reorders,
std::set<std::shared_ptr<Node>>& reshapes_to_delete)
set<shared_ptr<Node>>& reshapes_to_delete)
{
auto arg_reshape = reorders.at(quantize->get_argument(0));
AxisSet axes_in_def_order =
get_quantization_axes_in_default_order(arg_reshape, quantize->get_axes());
auto new_quantize = std::make_shared<op::Quantize>(quantize->get_argument(0),
quantize->get_argument(1),
quantize->get_argument(2),
quantize->get_element_type(),
axes_in_def_order,
quantize->get_round_mode());
auto new_quantize = make_shared<op::Quantize>(quantize->get_argument(0),
quantize->get_argument(1),
quantize->get_argument(2),
quantize->get_element_type(),
axes_in_def_order,
quantize->get_round_mode());
ngraph::replace_node(quantize, new_quantize);
reorders[new_quantize] = arg_reshape;
}
static void sink_concat(std::shared_ptr<op::Concat> n,
static void sink_concat(shared_ptr<op::Concat> n,
ReshapeMap& reorders,
std::set<std::shared_ptr<Node>>& reshapes_to_delete)
set<shared_ptr<Node>>& reshapes_to_delete)
{
auto arg_reshape = reorders.at(n->get_argument(0));
auto order = arg_reshape->get_input_order();
@@ -418,7 +416,7 @@ static void sink_concat(std::shared_ptr<op::Concat> n,
auto def_order = ngraph::get_permutation_to_default_order(order);
auto input_shape = ngraph::apply_permutation(arg_reshape->get_shape(), def_order);
auto dummy_correct_shape =
std::make_shared<pattern::op::Label>(arg_reshape->get_element_type(), input_shape);
make_shared<pattern::op::Label>(arg_reshape->get_element_type(), input_shape);
NodeVector new_args;
new_args.push_back(dummy_correct_shape);
@@ -436,12 +434,12 @@ static void sink_concat(std::shared_ptr<op::Concat> n,
auto iinput_shape = ngraph::apply_permutation(iarg_reshape->get_shape(), def_order);
auto idummy_correct_shape =
std::make_shared<pattern::op::Label>(iarg_reshape->get_element_type(), iinput_shape);
make_shared<pattern::op::Label>(iarg_reshape->get_element_type(), iinput_shape);
new_args.push_back(idummy_correct_shape);
}
auto new_axis = order.at(n->get_concatenation_axis());
auto new_concat = std::make_shared<op::Concat>(new_args, new_axis);
auto new_concat = make_shared<op::Concat>(new_args, new_axis);
//put back the original arguments
for (size_t i = 0; i < new_concat->get_input_size(); i++)
{
@@ -450,23 +448,23 @@ static void sink_concat(std::shared_ptr<op::Concat> n,
NGRAPH_DEBUG << "Replacing " << n->get_name() << " with " << new_concat->get_name();
ngraph::replace_node(n, new_concat);
auto new_reshape = std::make_shared<op::Reshape>(new_concat, order, n->get_shape());
auto new_reshape = make_shared<op::Reshape>(new_concat, order, n->get_shape());
NGRAPH_DEBUG << "Propagating " << describe_reshape(new_reshape) << " for " << n->get_name();
reorders[new_concat] = new_reshape;
}
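// Worked example of the axis remap above: with input order {0, 2, 3, 1}
// and an original concatenation axis of 1, the sunk Concat uses
// new_axis = order.at(1) == 2.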
static void sink_dequantize(std::shared_ptr<op::Dequantize> dequantize,
static void sink_dequantize(shared_ptr<op::Dequantize> dequantize,
ReshapeMap& reorders,
std::set<std::shared_ptr<Node>>& reshapes_to_delete)
set<shared_ptr<Node>>& reshapes_to_delete)
{
auto arg_reshape = reorders.at(dequantize->get_argument(0));
AxisSet axes_in_def_order =
get_quantization_axes_in_default_order(arg_reshape, dequantize->get_axes());
auto new_dequantize = std::make_shared<op::Dequantize>(dequantize->get_argument(0),
dequantize->get_argument(1),
dequantize->get_argument(2),
dequantize->get_element_type(),
axes_in_def_order);
auto new_dequantize = make_shared<op::Dequantize>(dequantize->get_argument(0),
dequantize->get_argument(1),
dequantize->get_argument(2),
dequantize->get_element_type(),
axes_in_def_order);
ngraph::replace_node(dequantize, new_dequantize);
reorders[new_dequantize] = arg_reshape;
@@ -481,11 +479,11 @@ static void sink_dequantize(std::shared_ptr<op::Dequantize> dequantize,
//For each op type we support we can either combine
//two reshapes by replacing the existing Reshape,
//materialize pending reshapes if they can't be propagated through op
bool ngraph::pass::ReshapeSinking::run_on_function(std::shared_ptr<ngraph::Function> f)
bool ngraph::pass::ReshapeSinking::run_on_function(shared_ptr<ngraph::Function> f)
{
ReshapeMap reorders;
NodeVector results;
std::set<std::shared_ptr<Node>> reshapes_to_delete;
set<shared_ptr<Node>> reshapes_to_delete;
//STEP 1: Sink or Swim reshapes away for op clusters
for (auto n : f->get_ordered_ops())
@@ -497,31 +495,31 @@ bool ngraph::pass::ReshapeSinking::run_on_function(std::shared_ptr<ngraph::Funct
results.push_back(n);
}
if (auto reshape = std::dynamic_pointer_cast<op::Reshape>(n))
if (auto reshape = dynamic_pointer_cast<op::Reshape>(n))
{
sink_reshape(reshape, reorders, reshapes_to_delete);
}
else if (auto unary = std::dynamic_pointer_cast<op::util::UnaryElementwiseArithmetic>(n))
else if (auto unary = dynamic_pointer_cast<op::util::UnaryElementwiseArithmetic>(n))
{
sink_unary(unary, reorders, reshapes_to_delete);
}
else if (auto binary = std::dynamic_pointer_cast<op::util::BinaryElementwiseArithmetic>(n))
else if (auto binary = dynamic_pointer_cast<op::util::BinaryElementwiseArithmetic>(n))
{
sink_binary(binary, reorders, reshapes_to_delete);
}
else if (auto goe = std::dynamic_pointer_cast<op::GetOutputElement>(n))
else if (auto goe = dynamic_pointer_cast<op::GetOutputElement>(n))
{
reorders[goe] = create_default_reshape(goe);
}
else if (auto quantize = std::dynamic_pointer_cast<op::Quantize>(n))
else if (auto quantize = dynamic_pointer_cast<op::Quantize>(n))
{
sink_quantize(quantize, reorders, reshapes_to_delete);
}
else if (auto dequantize = std::dynamic_pointer_cast<op::Dequantize>(n))
else if (auto dequantize = dynamic_pointer_cast<op::Dequantize>(n))
{
sink_dequantize(dequantize, reorders, reshapes_to_delete);
}
else if (auto slice = std::dynamic_pointer_cast<op::Slice>(n))
else if (auto slice = dynamic_pointer_cast<op::Slice>(n))
{
// A heuristic: if a Reshape has multiple Slice users, sinking it
// would replicate the Reshape once per user
@@ -542,11 +540,11 @@ bool ngraph::pass::ReshapeSinking::run_on_function(std::shared_ptr<ngraph::Funct
materialize_shapes(n, reorders, reshapes_to_delete);
}
}
else if (auto pad = std::dynamic_pointer_cast<op::Pad>(n))
else if (auto pad = dynamic_pointer_cast<op::Pad>(n))
{
sink_pad(pad, reorders, reshapes_to_delete);
}
else if (auto concat = std::dynamic_pointer_cast<op::Concat>(n))
else if (auto concat = dynamic_pointer_cast<op::Concat>(n))
{
sink_concat(concat, reorders, reshapes_to_delete);
}
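// A typical way to run this pass, sketched under the assumption that the
// usual ngraph::pass::Manager API and header layout apply:
#include <memory>

#include "ngraph/function.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/reshape_sinking.hpp"

void optimize(std::shared_ptr<ngraph::Function> f)
{
    ngraph::pass::Manager pass_manager;
    pass_manager.register_pass<ngraph::pass::ReshapeSinking>();
    pass_manager.run_passes(f); // sinks/swims reshapes, then deletes dead ones
}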
@@ -27,9 +27,9 @@
using namespace ngraph;
using namespace std;
#define TI(x) std::type_index(typeid(x))
#define TI(x) type_index(typeid(x))
bool pass::VisualizeTree::run_on_module(vector<shared_ptr<ngraph::Function>>& functions)
bool pass::VisualizeTree::run_on_module(vector<shared_ptr<Function>>& functions)
{
for (shared_ptr<Function> f : functions)
{
@@ -42,10 +42,10 @@ bool pass::VisualizeTree::run_on_module(vector<shared_ptr<ngraph::Function>>& fu
m_ss << add_attributes(node);
m_ss << " " << arg->get_name() << " -> " << node->get_name();
if (std::getenv("NGRAPH_VISUALIZE_EDGE_LABELS") != nullptr)
if (getenv("NGRAPH_VISUALIZE_EDGE_LABELS") != nullptr)
{
size_t output = 0;
if (auto goe = std::dynamic_pointer_cast<op::GetOutputElement>(node))
if (auto goe = dynamic_pointer_cast<op::GetOutputElement>(node))
{
output = goe->get_n();
}
@@ -71,7 +71,7 @@ pass::VisualizeTree::VisualizeTree(const string& file_name, node_modifiers_t nm)
{
}
std::string pass::VisualizeTree::add_attributes(shared_ptr<Node> node)
string pass::VisualizeTree::add_attributes(shared_ptr<Node> node)
{
string rc;
if (m_nodes_with_attributes.find(node) == m_nodes_with_attributes.end())
@@ -82,7 +82,7 @@ std::string pass::VisualizeTree::add_attributes(shared_ptr<Node> node)
return rc;
}
std::string pass::VisualizeTree::get_attributes(shared_ptr<Node> node)
string pass::VisualizeTree::get_attributes(shared_ptr<Node> node)
{
vector<string> attributes;
if (node->is_parameter() || node->is_output())
@@ -110,22 +110,22 @@ std::string pass::VisualizeTree::get_attributes(shared_ptr<Node> node)
stringstream label;
label << "label=\"" << node->get_friendly_name();
static const char* nvtos = std::getenv("NGRAPH_VISUALIZE_TREE_OUTPUT_SHAPES");
static const char* nvtos = getenv("NGRAPH_VISUALIZE_TREE_OUTPUT_SHAPES");
if (nvtos != nullptr)
{
// The shapes of the Outputs of a multi-output op
// will be printed for its corresponding `GetOutputElement`s
label << " " << (node->get_outputs().size() != 1 ? std::string("[skipped]")
label << " " << (node->get_outputs().size() != 1 ? string("[skipped]")
: vector_to_string(node->get_shape()));
}
static const char* nvtot = std::getenv("NGRAPH_VISUALIZE_TREE_OUTPUT_TYPES");
static const char* nvtot = getenv("NGRAPH_VISUALIZE_TREE_OUTPUT_TYPES");
if (nvtot != nullptr)
{
// The types of the Outputs of a multi-output op
// will be printed for its corresponding `GetOutputElement`s
label << " "
<< ((node->get_outputs().size() != 1) ? std::string("[skipped]")
<< ((node->get_outputs().size() != 1) ? string("[skipped]")
: node->get_element_type().c_type_string());
}
@@ -150,9 +150,9 @@ std::string pass::VisualizeTree::get_attributes(shared_ptr<Node> node)
return ss.str();
}
std::string pass::VisualizeTree::get_file_ext()
string pass::VisualizeTree::get_file_ext()
{
const char* format = std::getenv("NGRAPH_VISUALIZE_TREE_OUTPUT_FORMAT");
const char* format = getenv("NGRAPH_VISUALIZE_TREE_OUTPUT_FORMAT");
if (!format)
{
format = "png";
@@ -163,7 +163,7 @@ std::string pass::VisualizeTree::get_file_ext()
format += 1;
}
return std::string(format);
return string(format);
}
void pass::VisualizeTree::render() const
@@ -30,9 +30,10 @@
#include "ngraph/op/sum.hpp"
#include "zero_dim_tensor_elimination.hpp"
using namespace std;
using namespace ngraph;
static bool has_zero_dim(std::shared_ptr<Node> node)
static bool has_zero_dim(shared_ptr<Node> node)
{
if (node->get_output_size() != 1)
{
@@ -40,12 +41,12 @@ static bool has_zero_dim(std::shared_ptr<Node> node)
}
const auto& shape = node->get_shape();
return std::find(shape.begin(), shape.end(), 0) != shape.end();
return find(shape.begin(), shape.end(), 0) != shape.end();
}
static bool verify_no_internal_zero_length_ops(std::shared_ptr<ngraph::Function> f)
static bool verify_no_internal_zero_length_ops(shared_ptr<Function> f)
{
std::set<std::shared_ptr<Node>> zero_length_nodes;
set<shared_ptr<Node>> zero_length_nodes;
for (auto n : f->get_ordered_ops())
{
if (n->is_output() || n->is_parameter() || n->get_outputs().size() > 1)
@@ -76,10 +77,10 @@ static bool verify_no_internal_zero_length_ops(std::shared_ptr<ngraph::Function>
return zero_length_nodes.size() > 0;
}
bool ngraph::pass::ZeroDimTensorElimination::run_on_function(std::shared_ptr<ngraph::Function> f)
bool pass::ZeroDimTensorElimination::run_on_function(shared_ptr<Function> f)
{
bool replaced = false;
auto cvals = std::vector<std::string>(0);
auto cvals = vector<string>(0);
// we need to go over all nodes since we could have a Sum or any other 0-length-tensor-to-scalar op
// as an internal node (i.e. a node that isn't an argument to `op::Result`)
for (auto n : f->get_ordered_ops())
@@ -98,8 +99,7 @@ bool ngraph::pass::ZeroDimTensorElimination::run_on_function(std::shared_ptr<ngr
{
// we don't have to create constants every time, but this is the easiest
// approach, and it's CSE's job to eliminate duplicates
auto constant =
std::make_shared<op::Constant>(n->get_element_type(), n->get_shape(), cvals);
auto constant = make_shared<op::Constant>(n->get_element_type(), n->get_shape(), cvals);
replace_node(n, constant);
NGRAPH_DEBUG << " Replacing " << n->get_name() << " with " << constant->get_name();
replaced = true;
@@ -111,7 +111,7 @@ bool ngraph::pass::ZeroDimTensorElimination::run_on_function(std::shared_ptr<ngr
continue;
}
if (auto concat = std::dynamic_pointer_cast<op::Concat>(n))
if (auto concat = dynamic_pointer_cast<op::Concat>(n))
{
NodeVector non_zero_dim_args;
for (auto arg : concat->get_arguments())
@@ -127,7 +127,7 @@ bool ngraph::pass::ZeroDimTensorElimination::run_on_function(std::shared_ptr<ngr
auto new_concat = concat->copy_with_new_args(non_zero_dim_args);
NGRAPH_DEBUG << " Replacing " << n->get_name() << " with "
<< new_concat->get_name();
ngraph::replace_node(concat, new_concat);
replace_node(concat, new_concat);
continue;
}
}
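// A standalone sketch of the Concat filtering idea above: keep only the
// arguments whose shape has no zero extent (plain shapes for illustration):
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

using Shape = std::vector<size_t>;

static bool has_zero_extent(const Shape& s)
{
    return std::find(s.begin(), s.end(), 0) != s.end();
}

int main()
{
    std::vector<Shape> args{{2, 0}, {2, 3}, {2, 0}};
    std::vector<Shape> kept;
    for (const auto& s : args)
    {
        if (!has_zero_extent(s))
        {
            kept.push_back(s); // zero-length arguments contribute nothing
        }
    }
    assert(kept.size() == 1);
}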