Commit 835304fc authored by tsocha's avatar tsocha Committed by Michał Karzyński

[ONNX] Add optional outputs to batch_norm op (#2527)

parent d52d676b
......@@ -17,6 +17,7 @@
#include <cstdint>
#include <memory>
#include "ngraph/frontend/onnx_import/core/null_node.hpp"
#include "ngraph/frontend/onnx_import/exceptions.hpp"
#include "ngraph/frontend/onnx_import/op/batch_norm.hpp"
#include "ngraph/node_vector.hpp"
......@@ -46,16 +47,30 @@ namespace ngraph
// float momentum{node.get_attribute_value<float>("momentum", 0.9f)};
ASSERT_IS_SUPPORTED(node, is_test) << "only 'is_test' mode is supported.";
// optional outputs
auto after_bn_mean = std::make_shared<NullNode>();
auto after_bn_var = std::make_shared<NullNode>();
auto saved_mean = std::make_shared<NullNode>();
auto saved_var = std::make_shared<NullNode>();
if (inputs.size() >= 5)
{
mean = inputs.at(3);
var = inputs.at(4);
return {std::make_shared<ngraph::op::BatchNormInference>(
x, scale, bias, mean, var, epsilon)};
x, scale, bias, mean, var, epsilon),
after_bn_mean,
after_bn_var,
saved_mean,
saved_var};
}
return {
std::make_shared<ngraph::op::BatchNormTraining>(x, scale, bias, epsilon)};
std::make_shared<ngraph::op::BatchNormTraining>(x, scale, bias, epsilon),
after_bn_mean,
after_bn_var,
saved_mean,
saved_var};
}
} // namespace set_1
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment