Commit 46c21a0d authored by Swapna, committed by Scott Cyphers

Added API to get input node@index for current node (#4186)

* Add simpler API to get input node

* style check

* Change shr_ to shared_
Co-authored-by: Scott Cyphers <diyessi@users.noreply.github.com>
parent 2f37d151
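In short: the commit adds two convenience accessors on `Node`, so callers reach the producer of input `i` in one call instead of chaining through the `Input`/`Output` API. A minimal before/after sketch, assuming the nGraph headers and op constructors of this tree (the two-node graph is hypothetical, not part of the commit):

```cpp
#include <memory>
#include <ngraph/ngraph.hpp>

using namespace ngraph;

int main()
{
    // Hypothetical two-node graph: an Abs fed by a Parameter.
    auto param = std::make_shared<op::Parameter>(element::f32, Shape{2, 2});
    auto abs = std::make_shared<op::Abs>(param);

    // Old spelling: Input -> source Output -> producing Node.
    Node* before = abs->input(0).get_source_output().get_node();

    // New spelling added by this commit: one call per lookup.
    Node* after = abs->get_input_node_ptr(0);
    std::shared_ptr<Node> after_shared = abs->get_input_node_shared_ptr(0);

    // All three handles refer to the same Parameter node.
    return (before == after && after == after_shared.get()) ? 0 : 1;
}
```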
@@ -132,7 +132,7 @@ void Function::map_unordered_ops(std::function<void(Node*)> f) const
             f(op);
             for (size_t i = 0; i < op->get_input_size(); ++i)
             {
-                remaining_ops.push(op->input(i).get_source_output().get_node());
+                remaining_ops.push(op->get_input_node_ptr(i));
             }
             for (auto& cdep : op->get_control_dependencies())
             {
@@ -86,9 +86,9 @@ void ngraph::traverse_nodes(const NodeVector& subgraph_results,
         if (instances_seen.insert(n).second)
         {
             f(n->shared_from_this());
-            for (auto input : n->inputs())
+            for (size_t i = 0; i < n->inputs().size(); i++)
             {
-                stack.push(input.get_source_output().get_node());
+                stack.push(n->get_input_node_ptr(i));
             }
             if (include_control_deps)
@@ -706,9 +706,9 @@ static bool check_for_cycles_bkwd(std::shared_ptr<ngraph::Node> node,
 {
     path.push_back(node);
     path_set.insert(node);
-    for (auto& input : node->inputs())
+    for (size_t i = 0; i < node->inputs().size(); i++)
     {
-        auto arg = input.get_source_output().get_node_shared_ptr();
+        auto arg = node->get_input_node_shared_ptr(i);
         if (path_set.find(arg) != path_set.end())
         {
             for (auto it : path)
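The three hunks above are the same motif: a depth-first walk that pushes each node's producers onto a work stack. A self-contained sketch of that pattern with the new accessor; `visit_upstream` is our name for illustration, not a function in this commit:

```cpp
#include <functional>
#include <stack>
#include <unordered_set>
#include <ngraph/ngraph.hpp>

using namespace ngraph;

// DFS over producers, mirroring the loops patched above.
void visit_upstream(Node* start, const std::function<void(Node*)>& f)
{
    std::unordered_set<Node*> seen;
    std::stack<Node*> work;
    work.push(start);
    while (!work.empty())
    {
        Node* n = work.top();
        work.pop();
        if (seen.insert(n).second)
        {
            f(n);
            // One call per input, replacing input(i).get_source_output().get_node().
            for (size_t i = 0; i < n->get_input_size(); ++i)
            {
                work.push(n->get_input_node_ptr(i));
            }
        }
    }
}
```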
@@ -272,7 +272,7 @@ namespace ngraph
                 size_t arg_count = node->get_input_size();
                 for (size_t i = 0; i < arg_count; ++i)
                 {
-                    Node* dep = node->input(arg_count - i - 1).get_source_output().get_node();
+                    Node* dep = node->get_input_node_ptr(arg_count - i - 1);
                     if (nodes_done.count(dep) == 0)
                     {
                         can_add = false;
@@ -332,7 +332,7 @@ namespace ngraph
                 size_t arg_count = node->get_input_size();
                 for (size_t i = 0; i < arg_count; ++i)
                 {
-                    Node* dep = node->input(arg_count - i - 1).get_source_output().get_node();
+                    Node* dep = node->get_input_node_ptr(arg_count - i - 1);
                     if (nodes_done.count(dep) == 0 && nodes_to_emit.count(node) != 0)
                     {
                         can_add = false;
@@ -491,6 +491,20 @@ std::shared_ptr<Node> Node::get_argument(size_t index) const
     return m_inputs[index].get_output().get_node();
 }
 
+Node* Node::get_input_node_ptr(size_t index) const
+{
+    NGRAPH_CHECK(index < m_inputs.size(),
+                 "index '", index, "' out of range in get_input_node_ptr(size_t index)");
+    return m_inputs[index].get_output().get_node().get();
+}
+
+std::shared_ptr<Node> Node::get_input_node_shared_ptr(size_t index) const
+{
+    NGRAPH_CHECK(index < m_inputs.size(),
+                 "index '", index, "' out of range in get_input_node_shared_ptr(size_t index)");
+    return m_inputs[index].get_output().get_node();
+}
+
 NodeVector Node::get_arguments() const
 {
     NodeVector result;
@@ -393,6 +393,9 @@ namespace ngraph
         // Will be deprecated
         std::shared_ptr<Node> get_argument(size_t index) const;
 
+        Node* get_input_node_ptr(size_t index) const;
+        std::shared_ptr<Node> get_input_node_shared_ptr(size_t index) const;
+
     protected:
         // Will be replaced with an OutputVector version
         virtual std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const = 0;
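The two accessors declared above differ only in ownership: `get_input_node_ptr` hands back a non-owning raw pointer (no refcount traffic, valid while the consuming node holds its inputs), while `get_input_node_shared_ptr` bumps the reference count for callers that keep the handle. Both bounds-check the index via `NGRAPH_CHECK`. A hedged sketch of when each fits; both helper names here are hypothetical:

```cpp
#include <memory>
#include <ngraph/ngraph.hpp>

using namespace ngraph;

// Raw pointer: a transient inspection while `node` keeps its inputs alive.
size_t count_constant_inputs(const std::shared_ptr<Node>& node)
{
    size_t count = 0;
    for (size_t i = 0; i < node->get_input_size(); ++i)
    {
        if (dynamic_cast<op::Constant*>(node->get_input_node_ptr(i)) != nullptr)
        {
            ++count;
        }
    }
    return count;
}

// shared_ptr: the producer handle may outlive this scope.
std::shared_ptr<Node> keep_first_producer(const std::shared_ptr<Node>& node)
{
    return node->get_input_node_shared_ptr(0); // throws via NGRAPH_CHECK if no inputs
}
```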
@@ -154,8 +154,8 @@ void op::BatchMatMulTranspose::generate_adjoints(autodiff::Adjoints& adjoints,
 {
     auto delta = deltas.at(0); // NxIxK
 
-    auto arg0 = input(0).get_source_output().get_node_shared_ptr(); // NxIxJ (maybe transposed)
-    auto arg1 = input(1).get_source_output().get_node_shared_ptr(); // NxJxK (maybe transposed)
+    auto arg0 = get_input_node_shared_ptr(0); // NxIxJ (maybe transposed)
+    auto arg1 = get_input_node_shared_ptr(1); // NxJxK (maybe transposed)
 
     // If arg1 is already transposed, it does not need to be transposed again
     auto delta_dot_arg1 =
@@ -67,10 +67,8 @@ void pass::ConstantFolding::construct_constant_quantize()
         NGRAPH_CHECK(revalidate_and_ensure_static(quantize_op));
 
         auto args = quant_match->get_arguments();
-        auto scale = static_pointer_cast<op::Constant>(
-            quant_match->input(1).get_source_output().get_node_shared_ptr());
-        auto offset = static_pointer_cast<op::Constant>(
-            quant_match->input(2).get_source_output().get_node_shared_ptr());
+        auto scale = static_pointer_cast<op::Constant>(quant_match->get_input_node_shared_ptr(1));
+        auto offset = static_pointer_cast<op::Constant>(quant_match->get_input_node_shared_ptr(2));
 
         auto type = quant_match->get_element_type();
@@ -60,7 +60,7 @@ bool pass::MemoryLayout::run_on_function(shared_ptr<Function> function)
                 {
                     auto output = &node->output(oi_pair.output).get_tensor();
                     auto input = &node->input(oi_pair.input).get_tensor();
-                    auto input_node = node->input(oi_pair.input).get_source_output().get_node();
+                    auto input_node = node->get_input_node_ptr(oi_pair.input);
 
                     // For destructive kernel, this should be the last use
                     // Non-destructive kernels can pass through if memory sharing is disabled
@@ -107,7 +107,7 @@ bool pass::ShapeRelevance::run_on_function(std::shared_ptr<Function> f)
             {
                 continue;
             }
-            auto source_node = node->input(i).get_source_output().get_node();
+            auto source_node = node->get_input_node_ptr(i);
             if (already_visited.count(source_node) == 0)
             {
                 to_visit.push_front(source_node);
@@ -53,8 +53,7 @@ shared_ptr<Node> op::Dropout::copy_with_new_args(const NodeVector& new_args) const
 bool op::Dropout::get_use_seed() const
 {
     bool use_seed = false;
-    if (auto const_op =
-            as_type_ptr<op::Constant>(input(2).get_source_output().get_node_shared_ptr()))
+    if (auto const_op = as_type_ptr<op::Constant>(get_input_node_shared_ptr(2)))
     {
         auto use_seed_ptr = static_cast<const int32_t*>(const_op->get_data_ptr());
         use_seed = static_cast<const bool>(*use_seed_ptr);
@@ -65,8 +64,7 @@ bool op::Dropout::get_use_seed() const
 uint64_t op::Dropout::get_seed() const
 {
     uint64_t seed = 0;
-    if (auto const_op =
-            as_type_ptr<op::Constant>(input(3).get_source_output().get_node_shared_ptr()))
+    if (auto const_op = as_type_ptr<op::Constant>(get_input_node_shared_ptr(3)))
     {
         auto seed_ptr = static_cast<const uint64_t*>(const_op->get_data_ptr());
         seed = *seed_ptr;
@@ -605,12 +605,8 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_batch_norm_relu()
                      << m.get_match_root()->get_name();
         auto pattern_map = m.get_pattern_map();
-        auto m_bn =
-            std::static_pointer_cast<ngraph::op::BatchNormTraining>(m.get_match_root()
-                                                                        ->get_argument(0)
-                                                                        ->input(0)
-                                                                        .get_source_output()
-                                                                        .get_node_shared_ptr());
+        auto m_bn = std::static_pointer_cast<ngraph::op::BatchNormTraining>(
+            m.get_match_root()->get_input_node_shared_ptr(0)->get_input_node_shared_ptr(0));
 
         if (!mkldnn_utils::can_use_mkldnn_batchnorm_fprop(m_bn.get()))
         {
@@ -671,7 +667,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_batch_norm_relu_global_stats()
         auto pattern_map = m.get_pattern_map();
-        auto bn_match = m.get_match_root()->input(0).get_source_output().get_node_shared_ptr();
+        auto bn_match = m.get_match_root()->get_input_node_shared_ptr(0);
         if (bn_match->get_users().size() > 1)
         {
             NGRAPH_DEBUG << "Relu isn't the only user of BatchNorm's output";
@@ -419,7 +419,7 @@ bool ngraph::runtime::cpu::pass::CPUConvertLayoutConstantFolding::run_on_function(
             continue;
         }
-        auto arg = m_convertlayout->input(0).get_source_output().get_node_shared_ptr();
+        auto arg = m_convertlayout->get_input_node_shared_ptr(0);
         if (is_type<ngraph::op::Constant>(arg))
         {
             auto m_input = static_pointer_cast<ngraph::op::Constant>(arg);
@@ -245,7 +245,7 @@ static void replace_collapse_node_user(std::shared_ptr<Node> collapsed_node,
     NGRAPH_DEBUG << "node_name: " << node->get_name();
     for (size_t i = 0; i < node->get_input_size(); i++)
     {
-        if (node->input(i).get_source_output().get_node_shared_ptr() == collapsed_node)
+        if (node->get_input_node_shared_ptr(i) == collapsed_node)
         {
             node->set_argument(i, new_output);
         }
@@ -399,8 +399,7 @@ void ngraph::runtime::cpu::pass::LSTMFusion::construct_lstm_fprop()
     {
         // swap the inputs if the cell_state and hidden state do not
         // belong to the same Lstm
-        if (hidden_state->input(0).get_source_output().get_node() !=
-            cell_state->input(0).get_source_output().get_node())
+        if (hidden_state->get_input_node_ptr(0) != cell_state->get_input_node_ptr(0))
         {
             swap_lstm_inputs();
         }