Unverified Commit 4db318a3 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by GitHub

get_output_elements (#1154)

* get_get_output_elements

* fix comp error

* address scott's feedback
parent f7a34a02
...@@ -47,6 +47,11 @@ namespace ngraph ...@@ -47,6 +47,11 @@ namespace ngraph
{ {
} }
NodeVector(size_t size)
: std::vector<std::shared_ptr<Node>>(size)
{
}
NodeVector& operator=(const NodeVector& other) = default; NodeVector& operator=(const NodeVector& other) = default;
NodeVector() {} NodeVector() {}
......
...@@ -265,15 +265,9 @@ void ngraph::op::BatchNorm::generate_adjoints(autodiff::Adjoints& adjoints, ...@@ -265,15 +265,9 @@ void ngraph::op::BatchNorm::generate_adjoints(autodiff::Adjoints& adjoints,
//and get_n() is used to sort the inputs in the same order as Batchnorm's outputs //and get_n() is used to sort the inputs in the same order as Batchnorm's outputs
//Next, Mean and Variance (`at(1)` and `at(2)`) are extracted //Next, Mean and Variance (`at(1)` and `at(2)`) are extracted
//Please see `add_output` in `BatchNorm::BatchNorm` for more details //Please see `add_output` in `BatchNorm::BatchNorm` for more details
std::vector<std::shared_ptr<Node>> goes(get_outputs().size());
if (this->get_training_flag() && get_input_size() == 3) if (this->get_training_flag() && get_input_size() == 3)
{ {
for (auto goe_input : get_output_inputs(0)) auto goes = op::get_output_elements(this->shared_from_this());
{
auto goe = std::dynamic_pointer_cast<op::GetOutputElement>(goe_input->get_node());
goes.at(goe->get_n()) = goe_input->get_node();
}
mean = goes.at(1); mean = goes.at(1);
var = goes.at(2); var = goes.at(2);
} }
......
...@@ -53,3 +53,15 @@ void op::GetOutputElement::generate_adjoints(autodiff::Adjoints& adjoints, const ...@@ -53,3 +53,15 @@ void op::GetOutputElement::generate_adjoints(autodiff::Adjoints& adjoints, const
adjoints.add_delta(get_inputs().at(0).get_output().get_node(), delta, get_n()); adjoints.add_delta(get_inputs().at(0).get_output().get_node(), delta, get_n());
} }
NodeVector op::get_output_elements(const std::shared_ptr<Node>& mon)
{
NodeVector goes(mon->get_outputs().size());
for (auto goe_input : mon->get_output_inputs(0))
{
auto goe = std::dynamic_pointer_cast<op::GetOutputElement>(goe_input->get_node());
goes.at(goe->get_n()) = goe_input->get_node();
}
return goes;
}
...@@ -22,6 +22,8 @@ namespace ngraph ...@@ -22,6 +22,8 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector get_output_elements(const std::shared_ptr<Node>& mon);
/// \brief Operation to get an output from a node. /// \brief Operation to get an output from a node.
class GetOutputElement : public Node class GetOutputElement : public Node
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment