Unverified Commit 4db318a3 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by GitHub

get_output_elements (#1154)

* get_get_output_elements

* fix comp error

* address scott's feedback
parent f7a34a02
......@@ -47,6 +47,11 @@ namespace ngraph
{
}
NodeVector(size_t size)
: std::vector<std::shared_ptr<Node>>(size)
{
}
NodeVector& operator=(const NodeVector& other) = default;
NodeVector() {}
......
......@@ -265,15 +265,9 @@ void ngraph::op::BatchNorm::generate_adjoints(autodiff::Adjoints& adjoints,
//and get_n() is used to sort the inputs in the same order as Batchnorm's outputs
//Next, Mean and Variance (`at(1)` and `at(2)`) are extracted
//Please see `add_output` in `BatchNorm::BatchNorm` for more details
std::vector<std::shared_ptr<Node>> goes(get_outputs().size());
if (this->get_training_flag() && get_input_size() == 3)
{
for (auto goe_input : get_output_inputs(0))
{
auto goe = std::dynamic_pointer_cast<op::GetOutputElement>(goe_input->get_node());
goes.at(goe->get_n()) = goe_input->get_node();
}
auto goes = op::get_output_elements(this->shared_from_this());
mean = goes.at(1);
var = goes.at(2);
}
......
......@@ -53,3 +53,15 @@ void op::GetOutputElement::generate_adjoints(autodiff::Adjoints& adjoints, const
adjoints.add_delta(get_inputs().at(0).get_output().get_node(), delta, get_n());
}
NodeVector op::get_output_elements(const std::shared_ptr<Node>& mon)
{
NodeVector goes(mon->get_outputs().size());
for (auto goe_input : mon->get_output_inputs(0))
{
auto goe = std::dynamic_pointer_cast<op::GetOutputElement>(goe_input->get_node());
goes.at(goe->get_n()) = goe_input->get_node();
}
return goes;
}
......@@ -22,6 +22,8 @@ namespace ngraph
{
namespace op
{
NodeVector get_output_elements(const std::shared_ptr<Node>& mon);
/// \brief Operation to get an output from a node.
class GetOutputElement : public Node
{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment