Commit 3b578db4 authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

change function to take reference rather than shared_ptr (#1238)

parent f78133d2
......@@ -267,7 +267,7 @@ void ngraph::op::BatchNorm::generate_adjoints(autodiff::Adjoints& adjoints,
//Please see `add_output` in `BatchNorm::BatchNorm` for more details
if (this->get_training_flag() && get_input_size() == 3)
{
auto goes = op::get_output_elements(this->shared_from_this());
auto goes = op::get_output_elements(*this);
mean = goes.at(1);
var = goes.at(2);
}
......
......@@ -54,11 +54,11 @@ void op::GetOutputElement::generate_adjoints(autodiff::Adjoints& adjoints, const
adjoints.add_delta(get_inputs().at(0).get_output().get_node(), delta, get_n());
}
NodeVector op::get_output_elements(const std::shared_ptr<Node>& mon)
NodeVector op::get_output_elements(const Node& mon)
{
NodeVector goes(mon->get_outputs().size());
NodeVector goes(mon.get_outputs().size());
for (auto goe_input : mon->get_output_inputs(0))
for (auto goe_input : mon.get_output_inputs(0))
{
auto goe = std::dynamic_pointer_cast<op::GetOutputElement>(goe_input->get_node());
goes.at(goe->get_n()) = goe_input->get_node();
......
......@@ -22,7 +22,7 @@ namespace ngraph
{
namespace op
{
NodeVector get_output_elements(const std::shared_ptr<Node>& mon);
NodeVector get_output_elements(const Node& mon);
/// \brief Operation to get an output from a node.
class GetOutputElement : public Node
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment