Commit a4a3031b authored by Adam Procter, committed by Scott Cyphers

clang-format comments: /src/ngraph/pass (#3473)

* Plant an updated .clang-format in src/ngraph, with overrides in each subdir thereof

* Whoops, forgot to commit the actual .clang-formats

* Remove src/ngraph/type/.clang-format (it was having no effect)

* Remove src/ngraph/distributed/.clang-format (it was having no effect)

* Remove src/ngraph/codegen/.clang-format (only one file was affected, so it's a wash)

* Remove src/ngraph/autodiff/.clang-format (only one file was affected, so it's a wash)

* Un-relax comment wrapping in src/ngraph/state

* Revert "Un-relax comment wrapping in src/ngraph/state"

This reverts commit 41fc50fb92bffb7f5aca4126eb1267f00dcca727.

* Un-relax comment wrapping in src/ngraph/pass

* Remove .clang-format
parent 69c02f2c
#
# OVERRIDE TO STYLE: Comments do *not* wrap.
#
BasedOnStyle: LLVM
IndentWidth: 4
UseTab: Never
Language: Cpp
Standard: Cpp11
AccessModifierOffset: -4
AlignConsecutiveDeclarations: false
AlignConsecutiveAssignments: false
AlignTrailingComments: true
AllowShortBlocksOnASingleLine: true
AllowShortCaseLabelsOnASingleLine: true
AllowShortFunctionsOnASingleLine: Inline
AlwaysBreakBeforeMultilineStrings: true
AlwaysBreakTemplateDeclarations: true
BinPackArguments: false
BinPackParameters: false
BreakBeforeBraces: Allman
BreakConstructorInitializersBeforeComma: true
ColumnLimit: 100
CommentPragmas: '.*'
IndentCaseLabels: false
IndentWrappedFunctionNames: true
KeepEmptyLinesAtTheStartOfBlocks: false
NamespaceIndentation: All
PointerAlignment: Left
SpaceAfterCStyleCast: false
SpaceBeforeAssignmentOperators: true
SpaceBeforeParens: ControlStatements
SpaceInEmptyParentheses: false
SpacesInAngles: false
SpacesInCStyleCastParentheses: false
SpacesInParentheses: false
SpacesInSquareBrackets: false
SortIncludes: false
ReflowComments: true
IncludeCategories:
- Regex: '^".*'
Priority: 3
- Regex: '^<.*'
Priority: 2
SortIncludes: true
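For context, the override above works by setting CommentPragmas to '.*', which effectively tells clang-format to treat every comment as one it must not reflow or split; removing this file (what this commit does) lets the parent style's ReflowComments and ColumnLimit: 100 apply to src/ngraph/pass. The hunks below are the result. As an illustration, taken from the first file below rather than new code, a comment such as

//slice chunks should be slice in the same order as slice nodes in concat's argument list

is rewritten by clang-format to

// slice chunks should be slice in the same order as slice nodes in concat's argument
// list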
@@ -61,7 +61,7 @@ static shared_ptr<pattern::op::Label> get_broadcast_label(shared_ptr<pattern::Ma
//`simplify_concat` identifies slices-concat sequences
// that cancel each other. Namely it replaces subgraphs
-//similar to the one below with `arg`
+// similar to the one below with `arg`
//
// +----------+
// +----+slice(n/2..n)---+
@@ -107,7 +107,8 @@ static bool simplify_concat(shared_ptr<Node> n)
return false;
}
-//slice chunks should be slice in the same order as slice nodes in concat's argument list
+// slice chunks should be slice in the same order as slice nodes in concat's argument
+// list
auto cur_lower_bounds = slice->get_lower_bounds();
if (cur_lower_bounds < prev_lower_bounds)
{
@@ -116,7 +117,7 @@ static bool simplify_concat(shared_ptr<Node> n)
}
prev_lower_bounds.assign(cur_lower_bounds.begin(), cur_lower_bounds.end());
-//slice shapes need to match
+// slice shapes need to match
if (slice->get_shape() != prev_slice_shape)
{
NGRAPH_DEBUG << slice->get_name()
@@ -145,7 +146,7 @@ static bool simplify_concat(shared_ptr<Node> n)
return false;
}
-//check that no other node uses slices and reshapes
+// check that no other node uses slices and reshapes
if (auto rcarg = dynamic_pointer_cast<op::Reshape>(carg))
{
auto default_shape = get_default_order(rcarg->get_argument(0)->get_shape());
@@ -171,7 +172,7 @@ static bool simplify_concat(shared_ptr<Node> n)
auto btip_shape = branch_tip->get_shape();
-//slices should cover all elements
+// slices should cover all elements
if (shape_size(btip_shape) != shape_size(n->get_shape()))
{
NGRAPH_DEBUG << "The number of elements in Concat (" << shape_size(n->get_shape())
@@ -242,10 +243,10 @@ static bool simplify_concat(shared_ptr<Node> n)
//`simplify_multiply` optimizes the following 4 *base* cases
//(8 cases in total including variants due to commutativity)
//
-//a * 0 -> 0
-//a * broadcast(0) -> broadcast(0)
-//a * 1 -> a
-//a * broadcast(1) -> a
+// a * 0 -> 0
+// a * broadcast(0) -> broadcast(0)
+// a * 1 -> a
+// a * broadcast(1) -> a
static bool simplify_multiply(shared_ptr<Node> n)
{
NGRAPH_DEBUG << "In simplify_multiply for " << n->get_name();
@@ -280,8 +281,8 @@ static bool simplify_multiply(shared_ptr<Node> n)
//`simplify_add` optimizes the following 2 *base* cases
//(4 cases in total including variants due to commutativity)
//
-//a + 0 -> a
-//a + broadcast(0) -> a
+// a + 0 -> a
+// a + broadcast(0) -> a
static bool simplify_add(shared_ptr<Node> n)
{
NGRAPH_DEBUG << "In simplify_add for " << n->get_name();
@@ -406,10 +407,10 @@ static shared_ptr<Node> get_prod_constant(shared_ptr<op::Constant> cnst, size_t
}
//`simplify_reduction` optimizes the following case:
-//sum(broadcast(scalar_constant), reduction_axes = ...) -> constant2 (or scalar constant)
-//where constant2's values are equal to scalar_constant * shape_size(reduction_axes)
-//product(broadcast(scalar_constant), reduction_axes = ...) -> constant2 (or scalar constant)
-//where constant2's values are equal to scalar_constant ^ shape_size(reduction_axes)
+// sum(broadcast(scalar_constant), reduction_axes = ...) -> constant2 (or scalar constant)
+// where constant2's values are equal to scalar_constant * shape_size(reduction_axes)
+// product(broadcast(scalar_constant), reduction_axes = ...) -> constant2 (or scalar constant)
+// where constant2's values are equal to scalar_constant ^ shape_size(reduction_axes)
template <typename T, shared_ptr<Node> (*F)(shared_ptr<op::Constant> cnst, size_t multiplier)>
static bool simplify_reduction(shared_ptr<Node> n)
{
...
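The comments above spell out the base identities behind `simplify_multiply` and `simplify_add`. The following is a minimal, self-contained sketch of those identities using toy types; it is illustrative only and not the ngraph implementation (`ToyNode`, `all_equal_to`, and the `*_sketch` functions are made up for this example):

#include <algorithm>
#include <memory>
#include <vector>

// Toy stand-in for a graph node; `values` is non-empty only when the node is a constant.
struct ToyNode
{
    std::vector<double> values;
    bool all_equal_to(double v) const
    {
        return !values.empty() &&
               std::all_of(values.begin(), values.end(), [v](double x) { return x == v; });
    }
};

// a * 0 -> 0, a * broadcast(0) -> broadcast(0), a * 1 -> a, a * broadcast(1) -> a;
// checking both operands covers the commuted variants. nullptr means "no identity applies".
std::shared_ptr<ToyNode> simplify_multiply_sketch(std::shared_ptr<ToyNode> lhs,
                                                  std::shared_ptr<ToyNode> rhs)
{
    if (lhs->all_equal_to(0)) { return lhs; }
    if (rhs->all_equal_to(0)) { return rhs; }
    if (lhs->all_equal_to(1)) { return rhs; }
    if (rhs->all_equal_to(1)) { return lhs; }
    return nullptr;
}

// a + 0 -> a and a + broadcast(0) -> a, again in both operand orders.
std::shared_ptr<ToyNode> simplify_add_sketch(std::shared_ptr<ToyNode> lhs,
                                             std::shared_ptr<ToyNode> rhs)
{
    if (lhs->all_equal_to(0)) { return rhs; }
    if (rhs->all_equal_to(0)) { return lhs; }
    return nullptr;
}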
@@ -83,8 +83,8 @@ public:
construct_constant_select();
}
-//this allows to specify the order in which matchers will be run
-//and also allows to register the same matcher more than once
+// this allows to specify the order in which matchers will be run
+// and also allows to register the same matcher more than once
ConstantFolding(const std::vector<CFTransformations>& transformations,
const ngraph::BuildNodeExecutorMap& cfmap = ngraph::BuildNodeExecutorMap())
: GraphRewrite()
...
@@ -51,7 +51,7 @@ shared_ptr<op::Constant> fold_constant_unary(shared_ptr<op::Constant> constant,
shared_ptr<Node> unary,
NodeExecutorTy func)
{
-//check sqrt arg
+// check sqrt arg
if (std::dynamic_pointer_cast<op::Sqrt>(unary))
{
std::vector<T> values{constant->get_vector<T>()};
...
@@ -449,7 +449,7 @@ void pass::CoreFusion::construct_reshape_broadcast()
auto reshape1_m = static_pointer_cast<op::Reshape>(broadcast_m->get_argument(0));
auto input_m = m.get_pattern_map()[input];
-//it doesn't seem to make sense to support shapes : [0] or [1]
+// it doesn't seem to make sense to support shapes : [0] or [1]
if (input_m->get_shape().size() != 1 || input_m->get_shape().at(0) < 2)
{
NGRAPH_DEBUG << "input_m isn't a scalar or contains zero dimension";
@@ -458,8 +458,8 @@ void pass::CoreFusion::construct_reshape_broadcast()
size_t dim = input_m->get_shape().at(0);
-//We are going to support the most common case where broadcast doesn't add 1-dimensions
-//since it's also very simple to implement
+// We are going to support the most common case where broadcast doesn't add 1-dimensions
+// since it's also very simple to implement
size_t dim_one_count = 0;
for (auto d : reshape1_m->get_shape())
{
@@ -503,13 +503,13 @@ void pass::CoreFusion::construct_reshape_broadcast()
this->add_matcher(m, callback, PassProperty::REQUIRE_STATIC_SHAPE);
}
// conv(56w3s1) conv(28w3s2)
// | |
// conv(56w1s1) ==> conv(28w1s1)
// | |
-//elt------------56 elt------------pool(28s2)
+// elt------------56 elt------------pool(28s2)
// | | | |
-//conv(28w1s2) conv(28w1s2) conv(28w1s1) conv(28w1s1)
+// conv(28w1s2) conv(28w1s2) conv(28w1s1) conv(28w1s1)
void pass::CoreFusion::construct_optimized_strided_conv()
{
Shape win_size_1{1, 1, 1, 1};
...
@@ -44,7 +44,8 @@ using namespace ngraph;
// replace nodes `Abs2` and `Constant1` if needed
// This gives Matchers a nice cascading property. For example, if m1 folds `Abs2(Constant1)`
// and `m2` folds `Neg3(Constant1)` when `m3` is called on `Add4` it will discover that
-// both `Abs2` and `Neg3` were already replaced by constants, so `Add4` will also be folded into one.
+// both `Abs2` and `Neg3` were already replaced by constants, so `Add4` will also be folded into
+// one.
// If any Matcher succeeds the rest of the matchers will **not** be called.
// E.g. if `m1` succeeds and replaces `Abs2` with a new constant, nor `m2` or `m3` will be called
// However, sometimes, you will need more than one fusion occur on the same node.
@@ -56,7 +57,8 @@ using namespace ngraph;
// a) need more than one fusion occur on the same node
// b) you are modifying nodes after the current node in the topological order
// c) there's no linear order of fusions which will give
-// the correct final fusion. i.e. the same fusion needs to occur before and after some other fusion
+// the correct final fusion. i.e. the same fusion needs to occur before and after some other
+// fusion
bool pass::GraphRewrite::run_on_function(shared_ptr<Function> f)
{
@@ -113,7 +115,7 @@ bool pass::GraphRewrite::run_on_function(shared_ptr<Function> f)
} while (rewritten && m_matchers.size() > 0 && tries--);
m_matchers.assign(original_matchers.begin(), original_matchers.end());
-return (NUM_TRIES - tries) > 1; //this means a graph was transformed
+return (NUM_TRIES - tries) > 1; // this means a graph was transformed
}
static vector<regex> initialize_fusion_regexes()
@@ -135,7 +137,7 @@ static vector<regex> initialize_fusion_regexes()
bool pass::GraphRewrite::is_enabled(const shared_ptr<pattern::Matcher>& m) const
{
-//note, regexes are static to avoid re-initialization
+// note, regexes are static to avoid re-initialization
static const auto regexes = initialize_fusion_regexes();
for (const auto& regex : regexes)
...
@@ -37,7 +37,8 @@ using namespace ngraph;
static bool replace_broadcast_like(const std::shared_ptr<ngraph::Node>& node)
{
-// Replace a broadcast like with the broadcast to eliminate the pseudo-dependency on the "like" argument
+// Replace a broadcast like with the broadcast to eliminate the pseudo-dependency on the "like"
+// argument
auto broadcast_like = static_pointer_cast<op::BroadcastLike>(node);
replace_node(node,
make_shared<op::Broadcast>(broadcast_like->get_argument(0),
...
@@ -86,8 +86,8 @@ void pass::Manager::run_passes(shared_ptr<Function> func, bool transitive)
for (auto f_pair : fs)
{
shared_ptr<Function> f = f_pair.first;
-// This checks is to skip the graph optimization when the graph pass relies on static shape
-// but the function state is dynamic.
+// This checks is to skip the graph optimization when the graph pass relies on
+// static shape but the function state is dynamic.
// we update the function dynamic state only if we run the graph pass successfully.
if (function_pass->get_property(PassProperty::REQUIRE_STATIC_SHAPE) &&
f_pair.second)
...
@@ -81,7 +81,8 @@ static bool eliminate_slice(const std::shared_ptr<Node>& node)
static bool replace_broadcast_like(const std::shared_ptr<Node>& node)
{
-// Replace a broadcast like with the broadcast to eliminate the pseudo-dependency on the "like" argument
+// Replace a broadcast like with the broadcast to eliminate the pseudo-dependency on the "like"
+// argument
auto broadcast_like = std::static_pointer_cast<op::BroadcastLike>(node);
replace_node(node,
std::make_shared<op::Broadcast>(broadcast_like->get_argument(0),
...
@@ -53,8 +53,10 @@ pass::PassConfig::PassConfig()
//
// Parses the semi-colon separated environment string passed through NGRAPH_PASS_ATTRIBUTES
// and returns the pass attributes and whether they should be enabled or disabled in the
-// provided unordered_map. Naming of pass attributes is up to the backends
-// E.g., NGRAPH_PASS_ATTRIBUTES="OptimizeForMemory=0;MemoryAssignment::ReuseMemory=1;UseDefaultLayouts"
+// provided unordered_map. Naming of pass attributes is up to the backends.
+//
+// For example:
+// NGRAPH_PASS_ATTRIBUTES="OptimizeForMemory=0;MemoryAssignment::ReuseMemory=1;UseDefaultLayouts"
// would set false on "OptimizeForMemory", true on "MemoryAssignment::ReuseMemory" and true on
// "UseDefaultLayouts"
//
...
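The pass_config comment above documents the NGRAPH_PASS_ATTRIBUTES format. A standalone sketch of parsing such a semicolon-separated string (illustrative only; `parse_pass_attributes_sketch` is a made-up name, and the real ngraph parser may treat values differently):

#include <sstream>
#include <string>
#include <unordered_map>

// Turn "OptimizeForMemory=0;MemoryAssignment::ReuseMemory=1;UseDefaultLayouts" into
// {attribute -> enabled}; an attribute listed without "=value" defaults to enabled.
std::unordered_map<std::string, bool> parse_pass_attributes_sketch(const std::string& s)
{
    std::unordered_map<std::string, bool> attributes;
    std::istringstream stream(s);
    std::string token;
    while (std::getline(stream, token, ';'))
    {
        if (token.empty())
        {
            continue;
        }
        auto eq = token.find('=');
        if (eq == std::string::npos)
        {
            attributes[token] = true; // e.g. "UseDefaultLayouts"
        }
        else
        {
            attributes[token.substr(0, eq)] = token.substr(eq + 1) != "0";
        }
    }
    return attributes;
}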
@@ -229,8 +229,8 @@ void pass::RecurrentReshapeElimination::construct_recurrent_reshape()
auto driver_op = first_bound_reshape_op->get_argument(0);
auto last_bound_reshape_op = reshape_node_vector.back();
-// Need to check if the user of the last bound op is a reshape since the last reshape is allowed
-// to have fan-out but the matcher will discard any reshape if it has fan-out
+// Need to check if the user of the last bound op is a reshape since the last reshape is
+// allowed to have fan-out but the matcher will discard any reshape if it has fan-out
auto user_of_last_bound_reshape_op = last_bound_reshape_op->get_users(true)[0];
if (std::dynamic_pointer_cast<op::Reshape>(user_of_last_bound_reshape_op))
{
...
@@ -128,7 +128,7 @@ static shared_ptr<op::Reshape> create_default_reshape(shared_ptr<Node> n)
return default_reshape;
}
-//compute an axis order that converts the given axis order to default
+// compute an axis order that converts the given axis order to default
static AxisSet get_quantization_axes_in_default_order(shared_ptr<op::Reshape> arg_reshape,
const AxisSet& old_axis_set)
{
@@ -147,20 +147,20 @@ struct Swimmer
shared_ptr<op::Reshape> reshape;
};
-//Swim is used to push/"swim" reshapes towards paramaters.
-//This is typically done for binary ops when
-//one operand is in nchw, while the other one is nhwc
-//we prefer nchw since a lot of ngraph ops require this format,
-//so keeping things in nchw allows us to eliminate as many reshapes
-//as possible
+// Swim is used to push/"swim" reshapes towards paramaters.
+// This is typically done for binary ops when
+// one operand is in nchw, while the other one is nhwc
+// we prefer nchw since a lot of ngraph ops require this format,
+// so keeping things in nchw allows us to eliminate as many reshapes
+// as possible
void swim(Input<Node> input, shared_ptr<op::Reshape> reshape)
{
Swimmer sw{input, reshape};
list<Swimmer> work_queue;
work_queue.push_back(sw);
-//TODO: if we support more ops (especially, with >1 args)
-//we will need to keep track of nodes we visited and their reshapes
+// TODO: if we support more ops (especially, with >1 args)
+// we will need to keep track of nodes we visited and their reshapes
while (work_queue.size() > 0)
{
auto csw = work_queue.front();
@@ -226,10 +226,10 @@ void swim(Input<Node> input, shared_ptr<op::Reshape> reshape)
broadcast_input, broadcast_reshape->get_shape(), new_broadcast_axes);
csw.input.replace_source_output(new_broadcast->output(0));
}
-//TODO: Add cases to push through Reshape and BinaryElementwiseArithmetic
+// TODO: Add cases to push through Reshape and BinaryElementwiseArithmetic
else
{
-//materialize
+// materialize
auto new_reshape = csw.reshape->copy_with_new_args({n});
NGRAPH_DEBUG << "Materializing new reshape " << describe_reshape(new_reshape);
csw.input.replace_source_output(new_reshape->output(0));
@@ -237,10 +237,10 @@ void swim(Input<Node> input, shared_ptr<op::Reshape> reshape)
}
}
-//convert_binary_to_default_order is used when one of the arguments
-//of a binary op isn't in the default format (i.e. nhwc instead of nchw)
-//We have to normalize this other argument to nchw by swimming nchw towards parameters
-//as far as we can
+// convert_binary_to_default_order is used when one of the arguments
+// of a binary op isn't in the default format (i.e. nhwc instead of nchw)
+// We have to normalize this other argument to nchw by swimming nchw towards parameters
+// as far as we can
static void convert_binary_to_default_order(shared_ptr<Node> binary,
const Input<Node>& input,
shared_ptr<Node> right,
@@ -256,7 +256,7 @@ static void convert_binary_to_default_order(shared_ptr<Node> binary,
auto new_reshape = make_reshape(left, perm_to_def, new_shape);
NGRAPH_DEBUG << "left : About to swim " << describe_reshape(new_reshape) << " up to "
<< left->get_name();
-//this should now insert and swim reshape on right
+// this should now insert and swim reshape on right
swim(input, new_reshape);
mark_reshape_for_deletion(reorders.at(right), reshapes_to_delete);
write_reshapemap(reorders, binary, read_reshapemap(reorders, right));
@@ -266,7 +266,7 @@ static void materialize_shapes(shared_ptr<Node> n,
ReshapeMap& reorders,
set<shared_ptr<Node>>& reshapes_to_delete)
{
-//skip multiple output nodes and deal with GOEs exclusively
+// skip multiple output nodes and deal with GOEs exclusively
if (n->get_output_size() > 1)
{
return;
@@ -274,7 +274,7 @@ static void materialize_shapes(shared_ptr<Node> n,
for (size_t i = 0; i < n->get_arguments().size(); i++)
{
-//materialize all pending reshapes, flush pending reshapes
+// materialize all pending reshapes, flush pending reshapes
auto arg = n->get_argument(i);
if (reorders.count(arg) != 0)
{
@@ -288,7 +288,7 @@ static void materialize_shapes(shared_ptr<Node> n,
// Insert if arg needs to be transposed.
insert_reshape(n, arg_reshape, i);
}
-//no swimming up
+// no swimming up
}
}
write_reshapemap(reorders, n, create_default_reshape(n));
@@ -312,12 +312,12 @@ static void sink_reshape(shared_ptr<op::Reshape> reshape,
}
else
{
-//combine both reshapes
+// combine both reshapes
auto new_reshape = combine_reshapes(orig_reshape, reshape);
-//remove original reshape now it's combined with a new one
-//should be safe to remove an already detached node
+// remove original reshape now it's combined with a new one
+// should be safe to remove an already detached node
mark_reshape_for_deletion(orig_reshape, reshapes_to_delete);
-//replace reshape with combined one
+// replace reshape with combined one
ngraph::replace_node(reshape, new_reshape);
mark_reshape_for_deletion(new_reshape, reshapes_to_delete);
write_reshapemap(reorders, new_reshape, new_reshape);
@@ -345,7 +345,7 @@ static void sink_binary(shared_ptr<op::util::BinaryElementwiseArithmetic> binary
NGRAPH_DEBUG << "Propagating " << describe_reshape(reorders.at(left)) << " for "
<< binary->get_name();
write_reshapemap(reorders, binary, read_reshapemap(reorders, left));
-//at this point, both reshapes will be eventually removed
+// at this point, both reshapes will be eventually removed
mark_reshape_for_deletion(reorders.at(left), reshapes_to_delete);
mark_reshape_for_deletion(reorders.at(right), reshapes_to_delete);
}
@@ -477,7 +477,7 @@ static void sink_concat(shared_ptr<op::Concat> n,
auto new_axis = order.at(n->get_concatenation_axis());
auto new_concat = make_shared<op::Concat>(new_args, new_axis);
-//put back the original arguments
+// put back the original arguments
for (size_t i = 0; i < new_concat->get_input_size(); i++)
{
ngraph::replace_node(new_args.at(i), n->get_argument(i));
@@ -507,26 +507,26 @@ static void sink_dequantize(shared_ptr<op::Dequantize> dequantize,
write_reshapemap(reorders, new_dequantize, arg_reshape);
}
-//The goal of ReshapeSinking is to remove
-//round-trip reshapes(i.e. nhwc->nchw(nchw-only-op)->nhwc)
-//around nchw-only-op (e.g.Convolution, Batchnorm, Avg/MaxPool)
-//This is achieved by both **sinking**, propagating reshapes
-//through ops towards op::Results,
-//or **swimming** Reshapes up towards op::Parameter
-//For each op type we support we can either combine
-//two reshapes by replacing the existing Reshape,
-//materialize pending reshapes if they can't be propagated through op
+// The goal of ReshapeSinking is to remove
+// round-trip reshapes(i.e. nhwc->nchw(nchw-only-op)->nhwc)
+// around nchw-only-op (e.g.Convolution, Batchnorm, Avg/MaxPool)
+// This is achieved by both **sinking**, propagating reshapes
+// through ops towards op::Results,
+// or **swimming** Reshapes up towards op::Parameter
+// For each op type we support we can either combine
+// two reshapes by replacing the existing Reshape,
+// materialize pending reshapes if they can't be propagated through op
bool ngraph::pass::ReshapeSinking::run_on_function(shared_ptr<ngraph::Function> f)
{
ReshapeMap reorders;
NodeVector results;
set<shared_ptr<Node>> reshapes_to_delete;
-//STEP 1 : Sink or Swim reshapes away for op clusters
+// STEP 1 : Sink or Swim reshapes away for op clusters
for (auto n : f->get_ordered_ops())
{
NGRAPH_DEBUG << "Start: Processing node " << n->get_name();
-//collect all Result nodes for a sanity check
+// collect all Result nodes for a sanity check
if (n->is_output())
{
results.push_back(n);
@@ -592,13 +592,13 @@ bool ngraph::pass::ReshapeSinking::run_on_function(shared_ptr<ngraph::Function>
NGRAPH_DEBUG << "End: Processing node " << n->get_name();
}
-//STEP 2: purge all the reshapes we either sunk or swam.
+// STEP 2: purge all the reshapes we either sunk or swam.
for (auto r : reshapes_to_delete)
{
delete_reshape(r);
}
-//make sure shapes are always materialized before results
+// make sure shapes are always materialized before results
for (auto r : results)
{
NGRAPH_CHECK(r->get_shape() == r->get_argument(0)->get_shape() &&
@@ -609,7 +609,7 @@ bool ngraph::pass::ReshapeSinking::run_on_function(shared_ptr<ngraph::Function>
*r->get_argument(0));
}
-//STEP 3: fix wrong shape info wholesale
+// STEP 3: fix wrong shape info wholesale
for (auto n : f->get_ordered_ops())
{
n->revalidate_and_infer_types();
...
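The ReshapeSinking comments above describe removing round-trip layout reshapes around nchw-only ops. A toy sketch of the core observation, that two axis permutations cancel when their composition is the identity (illustrative only, not the pass itself; `compose` and `is_identity` are made up for this example):

#include <cstddef>
#include <vector>

using AxisOrder = std::vector<std::size_t>;

// Composes two input-axis orders: applying `first` and then `second` is the same as applying
// the returned order once.
AxisOrder compose(const AxisOrder& first, const AxisOrder& second)
{
    AxisOrder result(second.size());
    for (std::size_t i = 0; i < second.size(); ++i)
    {
        result[i] = first[second[i]];
    }
    return result;
}

bool is_identity(const AxisOrder& order)
{
    for (std::size_t i = 0; i < order.size(); ++i)
    {
        if (order[i] != i)
        {
            return false;
        }
    }
    return true;
}

// nhwc->nchw is {0, 3, 1, 2} and nchw->nhwc is {0, 2, 3, 1}; their composition is the identity,
// which is why the reshape pair around an nchw-only op can be dropped once sunk/swum together:
// is_identity(compose({0, 3, 1, 2}, {0, 2, 3, 1})) == true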
@@ -81,8 +81,8 @@ bool pass::ZeroDimTensorElimination::run_on_function(shared_ptr<Function> f)
{
bool replaced = false;
auto cvals = vector<string>(0);
-// we need to go over all nodes since we could have sum or any other 0-length-tensor-to scalar op
-// as an internal node (i.e. a node that isn't an argument to `op::Result`)
+// we need to go over all nodes since we could have sum or any other 0-length-tensor-to scalar
+// op as an internal node (i.e. a node that isn't an argument to `op::Result`)
for (auto n : f->get_ordered_ops())
{
// don't try to replace `op::Result`
...