Unverified Commit 0ec3b01e authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Merge branch 'master' into cyphers/patterndesc

parents a489a2ba 9476e0f4
...@@ -118,13 +118,14 @@ autodiff::Adjoints::Adjoints(const OutputVector& ys, const OutputVector& cs) ...@@ -118,13 +118,14 @@ autodiff::Adjoints::Adjoints(const OutputVector& ys, const OutputVector& cs)
auto node = nodes_to_check.front(); auto node = nodes_to_check.front();
nodes_to_check.pop_front(); nodes_to_check.pop_front();
// Look for nodes that will be available when this node is done // Look for nodes that will be available when this node is done
for (auto arg : node->get_arguments()) for (auto input : node->inputs())
{ {
auto count_it = parent_counts.find(arg); auto input_source_node = input.get_source_output().get_node_shared_ptr();
auto count_it = parent_counts.find(input_source_node);
count_it->second--; count_it->second--;
if (0 == count_it->second) if (0 == count_it->second)
{ {
nodes_to_check.push_front(arg); nodes_to_check.push_front(input_source_node);
} }
} }
OutputVector deltas = get(node); OutputVector deltas = get(node);
......
...@@ -86,9 +86,9 @@ void ngraph::traverse_nodes(const NodeVector& subgraph_results, ...@@ -86,9 +86,9 @@ void ngraph::traverse_nodes(const NodeVector& subgraph_results,
if (instances_seen.insert(n).second) if (instances_seen.insert(n).second)
{ {
f(n->shared_from_this()); f(n->shared_from_this());
for (auto& arg : n->get_arguments()) for (auto input : n->inputs())
{ {
stack.push(arg.get()); stack.push(input.get_source_output().get_node());
} }
if (include_control_deps) if (include_control_deps)
......
...@@ -111,9 +111,9 @@ void op::Concat::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVecto ...@@ -111,9 +111,9 @@ void op::Concat::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVecto
size_t pos = 0; size_t pos = 0;
for (auto arg : get_arguments()) for (auto input : inputs())
{ {
auto arg_shape = arg->get_shape(); auto arg_shape = input.get_shape();
auto slice_width = arg_shape[m_concatenation_axis]; auto slice_width = arg_shape[m_concatenation_axis];
...@@ -123,7 +123,7 @@ void op::Concat::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVecto ...@@ -123,7 +123,7 @@ void op::Concat::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVecto
arg_delta_slice_upper[m_concatenation_axis] = next_pos; arg_delta_slice_upper[m_concatenation_axis] = next_pos;
adjoints.add_delta( adjoints.add_delta(
arg, input.get_source_output(),
make_shared<op::Slice>( make_shared<op::Slice>(
delta, arg_delta_slice_lower, arg_delta_slice_upper, arg_delta_slice_strides)); delta, arg_delta_slice_lower, arg_delta_slice_upper, arg_delta_slice_strides));
......
...@@ -231,7 +231,7 @@ void op::MaxPoolBackprop::validate_and_infer_types() ...@@ -231,7 +231,7 @@ void op::MaxPoolBackprop::validate_and_infer_types()
shared_ptr<Node> op::MaxPoolBackprop::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::MaxPoolBackprop::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
if (this->get_arguments().size() == 3) if (this->get_input_size() == 3)
{ {
return make_shared<op::MaxPoolBackprop>(new_args.at(0), return make_shared<op::MaxPoolBackprop>(new_args.at(0),
new_args.at(1), new_args.at(1),
...@@ -259,7 +259,7 @@ void op::MaxPool::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVect ...@@ -259,7 +259,7 @@ void op::MaxPool::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVect
auto delta = deltas.at(0); auto delta = deltas.at(0);
auto operand = get_argument(0); auto operand = input(0).get_source_output();
auto backprop = auto backprop =
make_shared<op::MaxPoolBackprop>(operand, make_shared<op::MaxPoolBackprop>(operand,
delta, delta,
......
...@@ -235,7 +235,8 @@ shared_ptr<op::Constant> fold_constant_pad(shared_ptr<op::Constant> constant, ...@@ -235,7 +235,8 @@ shared_ptr<op::Constant> fold_constant_pad(shared_ptr<op::Constant> constant,
{ {
auto out_shape = pad->get_shape(); auto out_shape = pad->get_shape();
vector<T> out_vec(shape_size(out_shape)); vector<T> out_vec(shape_size(out_shape));
auto pad_value = std::static_pointer_cast<op::Constant>(pad->get_argument(1)); auto pad_value = std::static_pointer_cast<op::Constant>(
pad->input(1).get_source_output().get_node_shared_ptr());
if (func != nullptr) if (func != nullptr)
{ {
...@@ -1375,9 +1376,10 @@ void pass::ConstantFolding::construct_constant_dequantize() ...@@ -1375,9 +1376,10 @@ void pass::ConstantFolding::construct_constant_dequantize()
auto constant_match = dynamic_pointer_cast<op::Constant>(pattern_map[constant_label]); auto constant_match = dynamic_pointer_cast<op::Constant>(pattern_map[constant_label]);
auto dequant_match = pattern_map[dequant]; auto dequant_match = pattern_map[dequant];
auto dequantize_op = dynamic_pointer_cast<op::Dequantize>(dequant_match); auto dequantize_op = dynamic_pointer_cast<op::Dequantize>(dequant_match);
auto args = dequant_match->get_arguments(); auto scale = dynamic_pointer_cast<op::Constant>(
auto scale = dynamic_pointer_cast<op::Constant>(args[1]); dequant_match->input(1).get_source_output().get_node_shared_ptr());
auto offset = dynamic_pointer_cast<op::Constant>(args[2]); auto offset = dynamic_pointer_cast<op::Constant>(
dequant_match->input(2).get_source_output().get_node_shared_ptr());
auto type = constant_match->get_element_type(); auto type = constant_match->get_element_type();
...@@ -1452,8 +1454,10 @@ void pass::ConstantFolding::construct_constant_quantize() ...@@ -1452,8 +1454,10 @@ void pass::ConstantFolding::construct_constant_quantize()
auto quant_match = pattern_map[quant]; auto quant_match = pattern_map[quant];
auto quantize_op = dynamic_pointer_cast<op::Quantize>(quant_match); auto quantize_op = dynamic_pointer_cast<op::Quantize>(quant_match);
auto args = quant_match->get_arguments(); auto args = quant_match->get_arguments();
auto scale = static_pointer_cast<op::Constant>(args[1]); auto scale = static_pointer_cast<op::Constant>(
auto offset = static_pointer_cast<op::Constant>(args[2]); quant_match->input(1).get_source_output().get_node_shared_ptr());
auto offset = static_pointer_cast<op::Constant>(
quant_match->input(2).get_source_output().get_node_shared_ptr());
auto type = quant_match->get_element_type(); auto type = quant_match->get_element_type();
......
...@@ -89,7 +89,7 @@ static bool cse_reshape(shared_ptr<Node> a, shared_ptr<Node> b) ...@@ -89,7 +89,7 @@ static bool cse_reshape(shared_ptr<Node> a, shared_ptr<Node> b)
auto reshape_a = static_pointer_cast<ngraph::op::Reshape>(a); auto reshape_a = static_pointer_cast<ngraph::op::Reshape>(a);
auto reshape_b = static_pointer_cast<ngraph::op::Reshape>(b); auto reshape_b = static_pointer_cast<ngraph::op::Reshape>(b);
return (a->get_argument(0) == b->get_argument(0)) && return (a->input(0).get_source_output() == b->input(0).get_source_output()) &&
(reshape_a->get_input_order() == reshape_b->get_input_order()) && (reshape_a->get_input_order() == reshape_b->get_input_order()) &&
(reshape_a->get_output_shape() == reshape_b->get_output_shape()); (reshape_a->get_output_shape() == reshape_b->get_output_shape());
} }
...@@ -100,7 +100,7 @@ static bool cse_broadcast(shared_ptr<Node> a, shared_ptr<Node> b) ...@@ -100,7 +100,7 @@ static bool cse_broadcast(shared_ptr<Node> a, shared_ptr<Node> b)
auto broadcast_a = static_pointer_cast<ngraph::op::Broadcast>(a); auto broadcast_a = static_pointer_cast<ngraph::op::Broadcast>(a);
auto broadcast_b = static_pointer_cast<ngraph::op::Broadcast>(b); auto broadcast_b = static_pointer_cast<ngraph::op::Broadcast>(b);
return (a->get_argument(0) == b->get_argument(0)) && return (a->input(0).get_source_output() == b->input(0).get_source_output()) &&
(broadcast_a->get_broadcast_axes() == broadcast_b->get_broadcast_axes()) && (broadcast_a->get_broadcast_axes() == broadcast_b->get_broadcast_axes()) &&
(broadcast_a->get_broadcast_shape() == broadcast_b->get_broadcast_shape()); (broadcast_a->get_broadcast_shape() == broadcast_b->get_broadcast_shape());
} }
...@@ -108,15 +108,17 @@ static bool cse_unarywise(shared_ptr<Node> a, shared_ptr<Node> b) ...@@ -108,15 +108,17 @@ static bool cse_unarywise(shared_ptr<Node> a, shared_ptr<Node> b)
{ {
NGRAPH_DEBUG << "In cse_unarywise for " << a->get_name() << " and " << b->get_name(); NGRAPH_DEBUG << "In cse_unarywise for " << a->get_name() << " and " << b->get_name();
return a->get_argument(0) == b->get_argument(0); return a->input(0).get_source_output() == b->input(0).get_source_output();
} }
static bool cse_binarywise(shared_ptr<Node> a, shared_ptr<Node> b) static bool cse_binarywise(shared_ptr<Node> a, shared_ptr<Node> b)
{ {
NGRAPH_DEBUG << "In cse_binary for " << a->get_name() << " and " << b->get_name(); NGRAPH_DEBUG << "In cse_binary for " << a->get_name() << " and " << b->get_name();
return (a->get_argument(0) == b->get_argument(0) && a->get_argument(1) == b->get_argument(1)) || return (a->input(0).get_source_output() == b->input(0).get_source_output() &&
(a->get_argument(1) == b->get_argument(0) && a->get_argument(0) == b->get_argument(1)); a->input(1).get_source_output() == b->input(1).get_source_output()) ||
(a->input(1).get_source_output() == b->input(0).get_source_output() &&
a->input(0).get_source_output() == b->input(1).get_source_output());
} }
static bool cse_reduction(shared_ptr<Node> a, shared_ptr<Node> b) static bool cse_reduction(shared_ptr<Node> a, shared_ptr<Node> b)
...@@ -126,7 +128,7 @@ static bool cse_reduction(shared_ptr<Node> a, shared_ptr<Node> b) ...@@ -126,7 +128,7 @@ static bool cse_reduction(shared_ptr<Node> a, shared_ptr<Node> b)
auto ar_a = static_pointer_cast<op::util::ArithmeticReduction>(a); auto ar_a = static_pointer_cast<op::util::ArithmeticReduction>(a);
auto ar_b = static_pointer_cast<op::util::ArithmeticReduction>(b); auto ar_b = static_pointer_cast<op::util::ArithmeticReduction>(b);
return ar_a->get_argument(0) == ar_b->get_argument(0) && return ar_a->input(0).get_source_output() == ar_b->input(0).get_source_output() &&
ar_a->get_reduction_axes() == ar_b->get_reduction_axes(); ar_a->get_reduction_axes() == ar_b->get_reduction_axes();
} }
...@@ -137,7 +139,7 @@ static bool cse_one_hot(shared_ptr<Node> a, shared_ptr<Node> b) ...@@ -137,7 +139,7 @@ static bool cse_one_hot(shared_ptr<Node> a, shared_ptr<Node> b)
auto one_hot_a = static_pointer_cast<ngraph::op::OneHot>(a); auto one_hot_a = static_pointer_cast<ngraph::op::OneHot>(a);
auto one_hot_b = static_pointer_cast<ngraph::op::OneHot>(b); auto one_hot_b = static_pointer_cast<ngraph::op::OneHot>(b);
return (a->get_argument(0) == b->get_argument(0)) && return (a->input(0).get_source_output() == b->input(0).get_source_output()) &&
(one_hot_a->get_one_hot_axis() == one_hot_b->get_one_hot_axis()) && (one_hot_a->get_one_hot_axis() == one_hot_b->get_one_hot_axis()) &&
(a->get_shape() == b->get_shape()); (a->get_shape() == b->get_shape());
} }
...@@ -247,7 +249,11 @@ namespace std ...@@ -247,7 +249,11 @@ namespace std
arg_ids.push_back(type_hash); arg_ids.push_back(type_hash);
auto cargs = k.get_node()->get_arguments(); std::vector<Output<Node>> cargs;
for (auto input : k.get_node()->inputs())
{
cargs.push_back(input.get_source_output());
}
// TODO: Do we need another map, so we could // TODO: Do we need another map, so we could
// specify how to compute hash for each op? // specify how to compute hash for each op?
...@@ -258,7 +264,8 @@ namespace std ...@@ -258,7 +264,8 @@ namespace std
for (auto arg : cargs) for (auto arg : cargs)
{ {
arg_ids.push_back(arg->get_instance_id()); arg_ids.push_back(arg.get_node_shared_ptr()->get_instance_id());
arg_ids.push_back(arg.get_index());
} }
auto hashc = ngraph::hash_combine(arg_ids); auto hashc = ngraph::hash_combine(arg_ids);
......
...@@ -49,12 +49,13 @@ bool pass::PropagateCacheability::run_on_function(shared_ptr<Function> function) ...@@ -49,12 +49,13 @@ bool pass::PropagateCacheability::run_on_function(shared_ptr<Function> function)
else else
{ {
bool cacheable = true; bool cacheable = true;
for (auto arg : node->get_arguments()) for (auto input : node->inputs())
{ {
NGRAPH_DEBUG << "propagate cacheability: arg is " << arg->get_name(); auto input_value_node = input.get_source_output().get_node_shared_ptr();
if (arg->is_op()) NGRAPH_DEBUG << "propagate cacheability: arg is " << *input_value_node;
if (input_value_node->is_op())
{ {
auto arg_op = static_pointer_cast<op::Op>(arg); auto arg_op = static_pointer_cast<op::Op>(input_value_node);
auto arg_op_annotations = arg_op->get_op_annotations(); auto arg_op_annotations = arg_op->get_op_annotations();
NGRAPH_CHECK(arg_op_annotations); NGRAPH_CHECK(arg_op_annotations);
if (!arg_op_annotations->is_cacheable()) if (!arg_op_annotations->is_cacheable())
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment