Unverified Commit 0ec3b01e authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Merge branch 'master' into cyphers/patterndesc

parents a489a2ba 9476e0f4
......@@ -118,13 +118,14 @@ autodiff::Adjoints::Adjoints(const OutputVector& ys, const OutputVector& cs)
auto node = nodes_to_check.front();
nodes_to_check.pop_front();
// Look for nodes that will be available when this node is done
for (auto arg : node->get_arguments())
for (auto input : node->inputs())
{
auto count_it = parent_counts.find(arg);
auto input_source_node = input.get_source_output().get_node_shared_ptr();
auto count_it = parent_counts.find(input_source_node);
count_it->second--;
if (0 == count_it->second)
{
nodes_to_check.push_front(arg);
nodes_to_check.push_front(input_source_node);
}
}
OutputVector deltas = get(node);
......
......@@ -86,9 +86,9 @@ void ngraph::traverse_nodes(const NodeVector& subgraph_results,
if (instances_seen.insert(n).second)
{
f(n->shared_from_this());
for (auto& arg : n->get_arguments())
for (auto input : n->inputs())
{
stack.push(arg.get());
stack.push(input.get_source_output().get_node());
}
if (include_control_deps)
......
......@@ -111,9 +111,9 @@ void op::Concat::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVecto
size_t pos = 0;
for (auto arg : get_arguments())
for (auto input : inputs())
{
auto arg_shape = arg->get_shape();
auto arg_shape = input.get_shape();
auto slice_width = arg_shape[m_concatenation_axis];
......@@ -123,7 +123,7 @@ void op::Concat::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVecto
arg_delta_slice_upper[m_concatenation_axis] = next_pos;
adjoints.add_delta(
arg,
input.get_source_output(),
make_shared<op::Slice>(
delta, arg_delta_slice_lower, arg_delta_slice_upper, arg_delta_slice_strides));
......
......@@ -231,7 +231,7 @@ void op::MaxPoolBackprop::validate_and_infer_types()
shared_ptr<Node> op::MaxPoolBackprop::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
if (this->get_arguments().size() == 3)
if (this->get_input_size() == 3)
{
return make_shared<op::MaxPoolBackprop>(new_args.at(0),
new_args.at(1),
......@@ -259,7 +259,7 @@ void op::MaxPool::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVect
auto delta = deltas.at(0);
auto operand = get_argument(0);
auto operand = input(0).get_source_output();
auto backprop =
make_shared<op::MaxPoolBackprop>(operand,
delta,
......
......@@ -235,7 +235,8 @@ shared_ptr<op::Constant> fold_constant_pad(shared_ptr<op::Constant> constant,
{
auto out_shape = pad->get_shape();
vector<T> out_vec(shape_size(out_shape));
auto pad_value = std::static_pointer_cast<op::Constant>(pad->get_argument(1));
auto pad_value = std::static_pointer_cast<op::Constant>(
pad->input(1).get_source_output().get_node_shared_ptr());
if (func != nullptr)
{
......@@ -1375,9 +1376,10 @@ void pass::ConstantFolding::construct_constant_dequantize()
auto constant_match = dynamic_pointer_cast<op::Constant>(pattern_map[constant_label]);
auto dequant_match = pattern_map[dequant];
auto dequantize_op = dynamic_pointer_cast<op::Dequantize>(dequant_match);
auto args = dequant_match->get_arguments();
auto scale = dynamic_pointer_cast<op::Constant>(args[1]);
auto offset = dynamic_pointer_cast<op::Constant>(args[2]);
auto scale = dynamic_pointer_cast<op::Constant>(
dequant_match->input(1).get_source_output().get_node_shared_ptr());
auto offset = dynamic_pointer_cast<op::Constant>(
dequant_match->input(2).get_source_output().get_node_shared_ptr());
auto type = constant_match->get_element_type();
......@@ -1452,8 +1454,10 @@ void pass::ConstantFolding::construct_constant_quantize()
auto quant_match = pattern_map[quant];
auto quantize_op = dynamic_pointer_cast<op::Quantize>(quant_match);
auto args = quant_match->get_arguments();
auto scale = static_pointer_cast<op::Constant>(args[1]);
auto offset = static_pointer_cast<op::Constant>(args[2]);
auto scale = static_pointer_cast<op::Constant>(
quant_match->input(1).get_source_output().get_node_shared_ptr());
auto offset = static_pointer_cast<op::Constant>(
quant_match->input(2).get_source_output().get_node_shared_ptr());
auto type = quant_match->get_element_type();
......
......@@ -89,7 +89,7 @@ static bool cse_reshape(shared_ptr<Node> a, shared_ptr<Node> b)
auto reshape_a = static_pointer_cast<ngraph::op::Reshape>(a);
auto reshape_b = static_pointer_cast<ngraph::op::Reshape>(b);
return (a->get_argument(0) == b->get_argument(0)) &&
return (a->input(0).get_source_output() == b->input(0).get_source_output()) &&
(reshape_a->get_input_order() == reshape_b->get_input_order()) &&
(reshape_a->get_output_shape() == reshape_b->get_output_shape());
}
......@@ -100,7 +100,7 @@ static bool cse_broadcast(shared_ptr<Node> a, shared_ptr<Node> b)
auto broadcast_a = static_pointer_cast<ngraph::op::Broadcast>(a);
auto broadcast_b = static_pointer_cast<ngraph::op::Broadcast>(b);
return (a->get_argument(0) == b->get_argument(0)) &&
return (a->input(0).get_source_output() == b->input(0).get_source_output()) &&
(broadcast_a->get_broadcast_axes() == broadcast_b->get_broadcast_axes()) &&
(broadcast_a->get_broadcast_shape() == broadcast_b->get_broadcast_shape());
}
......@@ -108,15 +108,17 @@ static bool cse_unarywise(shared_ptr<Node> a, shared_ptr<Node> b)
{
NGRAPH_DEBUG << "In cse_unarywise for " << a->get_name() << " and " << b->get_name();
return a->get_argument(0) == b->get_argument(0);
return a->input(0).get_source_output() == b->input(0).get_source_output();
}
static bool cse_binarywise(shared_ptr<Node> a, shared_ptr<Node> b)
{
NGRAPH_DEBUG << "In cse_binary for " << a->get_name() << " and " << b->get_name();
return (a->get_argument(0) == b->get_argument(0) && a->get_argument(1) == b->get_argument(1)) ||
(a->get_argument(1) == b->get_argument(0) && a->get_argument(0) == b->get_argument(1));
return (a->input(0).get_source_output() == b->input(0).get_source_output() &&
a->input(1).get_source_output() == b->input(1).get_source_output()) ||
(a->input(1).get_source_output() == b->input(0).get_source_output() &&
a->input(0).get_source_output() == b->input(1).get_source_output());
}
static bool cse_reduction(shared_ptr<Node> a, shared_ptr<Node> b)
......@@ -126,7 +128,7 @@ static bool cse_reduction(shared_ptr<Node> a, shared_ptr<Node> b)
auto ar_a = static_pointer_cast<op::util::ArithmeticReduction>(a);
auto ar_b = static_pointer_cast<op::util::ArithmeticReduction>(b);
return ar_a->get_argument(0) == ar_b->get_argument(0) &&
return ar_a->input(0).get_source_output() == ar_b->input(0).get_source_output() &&
ar_a->get_reduction_axes() == ar_b->get_reduction_axes();
}
......@@ -137,7 +139,7 @@ static bool cse_one_hot(shared_ptr<Node> a, shared_ptr<Node> b)
auto one_hot_a = static_pointer_cast<ngraph::op::OneHot>(a);
auto one_hot_b = static_pointer_cast<ngraph::op::OneHot>(b);
return (a->get_argument(0) == b->get_argument(0)) &&
return (a->input(0).get_source_output() == b->input(0).get_source_output()) &&
(one_hot_a->get_one_hot_axis() == one_hot_b->get_one_hot_axis()) &&
(a->get_shape() == b->get_shape());
}
......@@ -247,7 +249,11 @@ namespace std
arg_ids.push_back(type_hash);
auto cargs = k.get_node()->get_arguments();
std::vector<Output<Node>> cargs;
for (auto input : k.get_node()->inputs())
{
cargs.push_back(input.get_source_output());
}
// TODO: Do we need another map, so we could
// specify how to compute hash for each op?
......@@ -258,7 +264,8 @@ namespace std
for (auto arg : cargs)
{
arg_ids.push_back(arg->get_instance_id());
arg_ids.push_back(arg.get_node_shared_ptr()->get_instance_id());
arg_ids.push_back(arg.get_index());
}
auto hashc = ngraph::hash_combine(arg_ids);
......
......@@ -49,12 +49,13 @@ bool pass::PropagateCacheability::run_on_function(shared_ptr<Function> function)
else
{
bool cacheable = true;
for (auto arg : node->get_arguments())
for (auto input : node->inputs())
{
NGRAPH_DEBUG << "propagate cacheability: arg is " << arg->get_name();
if (arg->is_op())
auto input_value_node = input.get_source_output().get_node_shared_ptr();
NGRAPH_DEBUG << "propagate cacheability: arg is " << *input_value_node;
if (input_value_node->is_op())
{
auto arg_op = static_pointer_cast<op::Op>(arg);
auto arg_op = static_pointer_cast<op::Op>(input_value_node);
auto arg_op_annotations = arg_op->get_op_annotations();
NGRAPH_CHECK(arg_op_annotations);
if (!arg_op_annotations->is_cacheable())
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment