Commit 835304fc authored by tsocha's avatar tsocha Committed by Michał Karzyński

[ONNX] Add optional outputs to batch_norm op (#2527)

parent d52d676b
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <cstdint> #include <cstdint>
#include <memory> #include <memory>
#include "ngraph/frontend/onnx_import/core/null_node.hpp"
#include "ngraph/frontend/onnx_import/exceptions.hpp" #include "ngraph/frontend/onnx_import/exceptions.hpp"
#include "ngraph/frontend/onnx_import/op/batch_norm.hpp" #include "ngraph/frontend/onnx_import/op/batch_norm.hpp"
#include "ngraph/node_vector.hpp" #include "ngraph/node_vector.hpp"
...@@ -46,16 +47,30 @@ namespace ngraph ...@@ -46,16 +47,30 @@ namespace ngraph
// float momentum{node.get_attribute_value<float>("momentum", 0.9f)}; // float momentum{node.get_attribute_value<float>("momentum", 0.9f)};
ASSERT_IS_SUPPORTED(node, is_test) << "only 'is_test' mode is supported."; ASSERT_IS_SUPPORTED(node, is_test) << "only 'is_test' mode is supported.";
// optional outputs
auto after_bn_mean = std::make_shared<NullNode>();
auto after_bn_var = std::make_shared<NullNode>();
auto saved_mean = std::make_shared<NullNode>();
auto saved_var = std::make_shared<NullNode>();
if (inputs.size() >= 5) if (inputs.size() >= 5)
{ {
mean = inputs.at(3); mean = inputs.at(3);
var = inputs.at(4); var = inputs.at(4);
return {std::make_shared<ngraph::op::BatchNormInference>( return {std::make_shared<ngraph::op::BatchNormInference>(
x, scale, bias, mean, var, epsilon)}; x, scale, bias, mean, var, epsilon),
after_bn_mean,
after_bn_var,
saved_mean,
saved_var};
} }
return { return {
std::make_shared<ngraph::op::BatchNormTraining>(x, scale, bias, epsilon)}; std::make_shared<ngraph::op::BatchNormTraining>(x, scale, bias, epsilon),
after_bn_mean,
after_bn_var,
saved_mean,
saved_var};
} }
} // namespace set_1 } // namespace set_1
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment