Commit 46c21a0d authored by Swapna's avatar Swapna Committed by Scott Cyphers

Added API to get input node@index for current node (#4186)

* Add simpler API to get input node

* style check

* Change shr_ to shared_
Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent 2f37d151
...@@ -132,7 +132,7 @@ void Function::map_unordered_ops(std::function<void(Node*)> f) const ...@@ -132,7 +132,7 @@ void Function::map_unordered_ops(std::function<void(Node*)> f) const
f(op); f(op);
for (size_t i = 0; i < op->get_input_size(); ++i) for (size_t i = 0; i < op->get_input_size(); ++i)
{ {
remaining_ops.push(op->input(i).get_source_output().get_node()); remaining_ops.push(op->get_input_node_ptr(i));
} }
for (auto& cdep : op->get_control_dependencies()) for (auto& cdep : op->get_control_dependencies())
{ {
......
...@@ -86,9 +86,9 @@ void ngraph::traverse_nodes(const NodeVector& subgraph_results, ...@@ -86,9 +86,9 @@ void ngraph::traverse_nodes(const NodeVector& subgraph_results,
if (instances_seen.insert(n).second) if (instances_seen.insert(n).second)
{ {
f(n->shared_from_this()); f(n->shared_from_this());
for (auto input : n->inputs()) for (size_t i = 0; i < n->inputs().size(); i++)
{ {
stack.push(input.get_source_output().get_node()); stack.push(n->get_input_node_ptr(i));
} }
if (include_control_deps) if (include_control_deps)
...@@ -706,9 +706,9 @@ static bool check_for_cycles_bkwd(std::shared_ptr<ngraph::Node> node, ...@@ -706,9 +706,9 @@ static bool check_for_cycles_bkwd(std::shared_ptr<ngraph::Node> node,
{ {
path.push_back(node); path.push_back(node);
path_set.insert(node); path_set.insert(node);
for (auto& input : node->inputs()) for (size_t i = 0; i < node->inputs().size(); i++)
{ {
auto arg = input.get_source_output().get_node_shared_ptr(); auto arg = node->get_input_node_shared_ptr(i);
if (path_set.find(arg) != path_set.end()) if (path_set.find(arg) != path_set.end())
{ {
for (auto it : path) for (auto it : path)
......
...@@ -272,7 +272,7 @@ namespace ngraph ...@@ -272,7 +272,7 @@ namespace ngraph
size_t arg_count = node->get_input_size(); size_t arg_count = node->get_input_size();
for (size_t i = 0; i < arg_count; ++i) for (size_t i = 0; i < arg_count; ++i)
{ {
Node* dep = node->input(arg_count - i - 1).get_source_output().get_node(); Node* dep = node->get_input_node_ptr(arg_count - i - 1);
if (nodes_done.count(dep) == 0) if (nodes_done.count(dep) == 0)
{ {
can_add = false; can_add = false;
...@@ -332,7 +332,7 @@ namespace ngraph ...@@ -332,7 +332,7 @@ namespace ngraph
size_t arg_count = node->get_input_size(); size_t arg_count = node->get_input_size();
for (size_t i = 0; i < arg_count; ++i) for (size_t i = 0; i < arg_count; ++i)
{ {
Node* dep = node->input(arg_count - i - 1).get_source_output().get_node(); Node* dep = node->get_input_node_ptr(arg_count - i - 1);
if (nodes_done.count(dep) == 0 && nodes_to_emit.count(node) != 0) if (nodes_done.count(dep) == 0 && nodes_to_emit.count(node) != 0)
{ {
can_add = false; can_add = false;
......
...@@ -491,6 +491,20 @@ std::shared_ptr<Node> Node::get_argument(size_t index) const ...@@ -491,6 +491,20 @@ std::shared_ptr<Node> Node::get_argument(size_t index) const
return m_inputs[index].get_output().get_node(); return m_inputs[index].get_output().get_node();
} }
Node* Node::get_input_node_ptr(size_t index) const
{
NGRAPH_CHECK(
index < m_inputs.size(), "index '", index, "' out of range in get_argument(size_t index)");
return m_inputs[index].get_output().get_node().get();
}
std::shared_ptr<Node> Node::get_input_node_shared_ptr(size_t index) const
{
NGRAPH_CHECK(
index < m_inputs.size(), "index '", index, "' out of range in get_argument(size_t index)");
return m_inputs[index].get_output().get_node();
}
NodeVector Node::get_arguments() const NodeVector Node::get_arguments() const
{ {
NodeVector result; NodeVector result;
......
...@@ -393,6 +393,9 @@ namespace ngraph ...@@ -393,6 +393,9 @@ namespace ngraph
// Will be deprecated // Will be deprecated
std::shared_ptr<Node> get_argument(size_t index) const; std::shared_ptr<Node> get_argument(size_t index) const;
Node* get_input_node_ptr(size_t index) const;
std::shared_ptr<Node> get_input_node_shared_ptr(size_t index) const;
protected: protected:
// Will be replaced with an OutputVector version // Will be replaced with an OutputVector version
virtual std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const = 0; virtual std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const = 0;
......
...@@ -154,8 +154,8 @@ void op::BatchMatMulTranspose::generate_adjoints(autodiff::Adjoints& adjoints, ...@@ -154,8 +154,8 @@ void op::BatchMatMulTranspose::generate_adjoints(autodiff::Adjoints& adjoints,
{ {
auto delta = deltas.at(0); // NxIxK auto delta = deltas.at(0); // NxIxK
auto arg0 = input(0).get_source_output().get_node_shared_ptr(); // NxIxJ (maybe transposed) auto arg0 = get_input_node_shared_ptr(0); // NxIxJ (maybe transposed)
auto arg1 = input(1).get_source_output().get_node_shared_ptr(); // NxJxK (maybe transposed) auto arg1 = get_input_node_shared_ptr(1); // NxJxK (maybe transposed)
// If arg1 is already transposed, it does not need to be transposed again // If arg1 is already transposed, it does not need to be transposed again
auto delta_dot_arg1 = auto delta_dot_arg1 =
......
...@@ -67,10 +67,8 @@ void pass::ConstantFolding::construct_constant_quantize() ...@@ -67,10 +67,8 @@ void pass::ConstantFolding::construct_constant_quantize()
NGRAPH_CHECK(revalidate_and_ensure_static(quantize_op)); NGRAPH_CHECK(revalidate_and_ensure_static(quantize_op));
auto args = quant_match->get_arguments(); auto args = quant_match->get_arguments();
auto scale = static_pointer_cast<op::Constant>( auto scale = static_pointer_cast<op::Constant>(quant_match->get_input_node_shared_ptr(1));
quant_match->input(1).get_source_output().get_node_shared_ptr()); auto offset = static_pointer_cast<op::Constant>(quant_match->get_input_node_shared_ptr(2));
auto offset = static_pointer_cast<op::Constant>(
quant_match->input(2).get_source_output().get_node_shared_ptr());
auto type = quant_match->get_element_type(); auto type = quant_match->get_element_type();
......
...@@ -60,7 +60,7 @@ bool pass::MemoryLayout::run_on_function(shared_ptr<Function> function) ...@@ -60,7 +60,7 @@ bool pass::MemoryLayout::run_on_function(shared_ptr<Function> function)
{ {
auto output = &node->output(oi_pair.output).get_tensor(); auto output = &node->output(oi_pair.output).get_tensor();
auto input = &node->input(oi_pair.input).get_tensor(); auto input = &node->input(oi_pair.input).get_tensor();
auto input_node = node->input(oi_pair.input).get_source_output().get_node(); auto input_node = node->get_input_node_ptr(oi_pair.input);
// For destructive kernel, this should be the last use // For destructive kernel, this should be the last use
// Non-destructive kernels can pass through if memory sharing is disabled // Non-destructive kernels can pass through if memory sharing is disabled
......
...@@ -107,7 +107,7 @@ bool pass::ShapeRelevance::run_on_function(std::shared_ptr<Function> f) ...@@ -107,7 +107,7 @@ bool pass::ShapeRelevance::run_on_function(std::shared_ptr<Function> f)
{ {
continue; continue;
} }
auto source_node = node->input(i).get_source_output().get_node(); auto source_node = node->get_input_node_ptr(i);
if (already_visited.count(source_node) == 0) if (already_visited.count(source_node) == 0)
{ {
to_visit.push_front(source_node); to_visit.push_front(source_node);
......
...@@ -53,8 +53,7 @@ shared_ptr<Node> op::Dropout::copy_with_new_args(const NodeVector& new_args) con ...@@ -53,8 +53,7 @@ shared_ptr<Node> op::Dropout::copy_with_new_args(const NodeVector& new_args) con
bool op::Dropout::get_use_seed() const bool op::Dropout::get_use_seed() const
{ {
bool use_seed = false; bool use_seed = false;
if (auto const_op = if (auto const_op = as_type_ptr<op::Constant>(get_input_node_shared_ptr(2)))
as_type_ptr<op::Constant>(input(2).get_source_output().get_node_shared_ptr()))
{ {
auto use_seed_ptr = static_cast<const int32_t*>(const_op->get_data_ptr()); auto use_seed_ptr = static_cast<const int32_t*>(const_op->get_data_ptr());
use_seed = static_cast<const bool>(*use_seed_ptr); use_seed = static_cast<const bool>(*use_seed_ptr);
...@@ -65,8 +64,7 @@ bool op::Dropout::get_use_seed() const ...@@ -65,8 +64,7 @@ bool op::Dropout::get_use_seed() const
uint64_t op::Dropout::get_seed() const uint64_t op::Dropout::get_seed() const
{ {
uint64_t seed = 0; uint64_t seed = 0;
if (auto const_op = if (auto const_op = as_type_ptr<op::Constant>(get_input_node_shared_ptr(3)))
as_type_ptr<op::Constant>(input(3).get_source_output().get_node_shared_ptr()))
{ {
auto seed_ptr = static_cast<const uint64_t*>(const_op->get_data_ptr()); auto seed_ptr = static_cast<const uint64_t*>(const_op->get_data_ptr());
seed = *seed_ptr; seed = *seed_ptr;
......
...@@ -605,12 +605,8 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_batch_norm_relu() ...@@ -605,12 +605,8 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_batch_norm_relu()
<< m.get_match_root()->get_name(); << m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
auto m_bn = auto m_bn = std::static_pointer_cast<ngraph::op::BatchNormTraining>(
std::static_pointer_cast<ngraph::op::BatchNormTraining>(m.get_match_root() m.get_match_root()->get_input_node_shared_ptr(0)->get_input_node_shared_ptr(0));
->get_argument(0)
->input(0)
.get_source_output()
.get_node_shared_ptr());
if (!mkldnn_utils::can_use_mkldnn_batchnorm_fprop(m_bn.get())) if (!mkldnn_utils::can_use_mkldnn_batchnorm_fprop(m_bn.get()))
{ {
...@@ -671,7 +667,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_batch_norm_relu_global_sta ...@@ -671,7 +667,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_batch_norm_relu_global_sta
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
auto bn_match = m.get_match_root()->input(0).get_source_output().get_node_shared_ptr(); auto bn_match = m.get_match_root()->get_input_node_shared_ptr(0);
if (bn_match->get_users().size() > 1) if (bn_match->get_users().size() > 1)
{ {
NGRAPH_DEBUG << "Relu isn't the only user of BatchNorm's output"; NGRAPH_DEBUG << "Relu isn't the only user of BatchNorm's output";
......
...@@ -419,7 +419,7 @@ bool ngraph::runtime::cpu::pass::CPUConvertLayoutConstantFolding::run_on_functio ...@@ -419,7 +419,7 @@ bool ngraph::runtime::cpu::pass::CPUConvertLayoutConstantFolding::run_on_functio
continue; continue;
} }
auto arg = m_convertlayout->input(0).get_source_output().get_node_shared_ptr(); auto arg = m_convertlayout->get_input_node_shared_ptr(0);
if (is_type<ngraph::op::Constant>(arg)) if (is_type<ngraph::op::Constant>(arg))
{ {
auto m_input = static_pointer_cast<ngraph::op::Constant>(arg); auto m_input = static_pointer_cast<ngraph::op::Constant>(arg);
......
...@@ -245,7 +245,7 @@ static void replace_collapse_node_user(std::shared_ptr<Node> collapsed_node, ...@@ -245,7 +245,7 @@ static void replace_collapse_node_user(std::shared_ptr<Node> collapsed_node,
NGRAPH_DEBUG << "node_name: " << node->get_name(); NGRAPH_DEBUG << "node_name: " << node->get_name();
for (size_t i = 0; i < node->get_input_size(); i++) for (size_t i = 0; i < node->get_input_size(); i++)
{ {
if (node->input(i).get_source_output().get_node_shared_ptr() == collapsed_node) if (node->get_input_node_shared_ptr(i) == collapsed_node)
{ {
node->set_argument(i, new_output); node->set_argument(i, new_output);
} }
...@@ -399,8 +399,7 @@ void ngraph::runtime::cpu::pass::LSTMFusion::construct_lstm_fprop() ...@@ -399,8 +399,7 @@ void ngraph::runtime::cpu::pass::LSTMFusion::construct_lstm_fprop()
{ {
// swap the inputs if the cell_state and hidden state does not // swap the inputs if the cell_state and hidden state does not
// belong to the same Lstm // belong to the same Lstm
if (hidden_state->input(0).get_source_output().get_node() != if (hidden_state->get_input_node_ptr(0) != cell_state->get_input_node_ptr(0))
cell_state->input(0).get_source_output().get_node())
{ {
swap_lstm_inputs(); swap_lstm_inputs();
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment