Commit 3b578db4 authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

change function to take reference rather than shared_ptr (#1238)

parent f78133d2
...@@ -267,7 +267,7 @@ void ngraph::op::BatchNorm::generate_adjoints(autodiff::Adjoints& adjoints, ...@@ -267,7 +267,7 @@ void ngraph::op::BatchNorm::generate_adjoints(autodiff::Adjoints& adjoints,
//Please see `add_output` in `BatchNorm::BatchNorm` for more details //Please see `add_output` in `BatchNorm::BatchNorm` for more details
if (this->get_training_flag() && get_input_size() == 3) if (this->get_training_flag() && get_input_size() == 3)
{ {
auto goes = op::get_output_elements(this->shared_from_this()); auto goes = op::get_output_elements(*this);
mean = goes.at(1); mean = goes.at(1);
var = goes.at(2); var = goes.at(2);
} }
......
...@@ -54,11 +54,11 @@ void op::GetOutputElement::generate_adjoints(autodiff::Adjoints& adjoints, const ...@@ -54,11 +54,11 @@ void op::GetOutputElement::generate_adjoints(autodiff::Adjoints& adjoints, const
adjoints.add_delta(get_inputs().at(0).get_output().get_node(), delta, get_n()); adjoints.add_delta(get_inputs().at(0).get_output().get_node(), delta, get_n());
} }
NodeVector op::get_output_elements(const std::shared_ptr<Node>& mon) NodeVector op::get_output_elements(const Node& mon)
{ {
NodeVector goes(mon->get_outputs().size()); NodeVector goes(mon.get_outputs().size());
for (auto goe_input : mon->get_output_inputs(0)) for (auto goe_input : mon.get_output_inputs(0))
{ {
auto goe = std::dynamic_pointer_cast<op::GetOutputElement>(goe_input->get_node()); auto goe = std::dynamic_pointer_cast<op::GetOutputElement>(goe_input->get_node());
goes.at(goe->get_n()) = goe_input->get_node(); goes.at(goe->get_n()) = goe_input->get_node();
......
...@@ -22,7 +22,7 @@ namespace ngraph ...@@ -22,7 +22,7 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector get_output_elements(const std::shared_ptr<Node>& mon); NodeVector get_output_elements(const Node& mon);
/// \brief Operation to get an output from a node. /// \brief Operation to get an output from a node.
class GetOutputElement : public Node class GetOutputElement : public Node
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment