Commit 1b71fdca authored by tsocha's avatar tsocha Committed by Robert Kimball

[ONNX] Enable Pad operator and batch normalization non-spatial mode (#2078)

* ONNX Pad operator

* Disable Batch Norm is spatial checking

* Remove unused variable
parent ee96e8d1
...@@ -102,6 +102,8 @@ add_library(onnx_import STATIC ...@@ -102,6 +102,8 @@ add_library(onnx_import STATIC
op/neg.hpp op/neg.hpp
op/not.hpp op/not.hpp
op/or.hpp op/or.hpp
op/pad.cpp
op/pad.hpp
op/pow.hpp op/pow.hpp
op/prelu.cpp op/prelu.cpp
op/prelu.hpp op/prelu.hpp
......
...@@ -40,13 +40,11 @@ namespace ngraph ...@@ -40,13 +40,11 @@ namespace ngraph
std::shared_ptr<ngraph::Node> var{nullptr}; std::shared_ptr<ngraph::Node> var{nullptr};
std::int64_t is_test{node.get_attribute_value<std::int64_t>("is_test", 1)}; std::int64_t is_test{node.get_attribute_value<std::int64_t>("is_test", 1)};
std::int64_t spatial{node.get_attribute_value<std::int64_t>("spatial", 1)};
double epsilon{node.get_attribute_value<double>("epsilon", 1e-5)}; double epsilon{node.get_attribute_value<double>("epsilon", 1e-5)};
// TODO: Implement learning mode support // TODO: Implement learning mode support
// float momentum{node.get_attribute_value<float>("momentum", 0.9f)}; // float momentum{node.get_attribute_value<float>("momentum", 0.9f)};
ASSERT_IS_SUPPORTED(node, is_test) << "only 'is_test' mode is supported."; ASSERT_IS_SUPPORTED(node, is_test) << "only 'is_test' mode is supported.";
ASSERT_IS_SUPPORTED(node, spatial) << "only 'spatial' mode is supported.";
if (inputs.size() >= 5) if (inputs.size() >= 5)
{ {
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <memory>
#include "ngraph/coordinate_diff.hpp"
#include "ngraph/frontend/onnx_import/op/pad.hpp"
#include "ngraph/frontend/onnx_import/utils/convpool.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector pad(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
const Shape& data_shape = data->get_shape();
double value = node.get_attribute_value<double>("value", 0);
auto paddings = convpool::get_pads(node, data_shape);
ngraph::CoordinateDiff padding_below = paddings.first;
ngraph::CoordinateDiff padding_above = paddings.second;
return {std::make_shared<ngraph::op::Pad>(
data,
std::make_shared<ngraph::op::Constant>(
data->get_element_type(), ngraph::Shape{}, std::vector<double>{value}),
Shape(padding_below.begin(), padding_below.end()),
Shape(padding_above.begin(), padding_above.end()),
Shape(data_shape.size(), 0))};
}
} // namespace set_1
} //namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/frontend/onnx_import/core/node.hpp"
#include "ngraph/node_vector.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector pad(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import
} // namespace ngraph
...@@ -66,6 +66,8 @@ ...@@ -66,6 +66,8 @@
#include "op/neg.hpp" #include "op/neg.hpp"
#include "op/not.hpp" #include "op/not.hpp"
#include "op/or.hpp" #include "op/or.hpp"
#include "op/pad.cpp"
#include "op/pad.hpp"
#include "op/pow.hpp" #include "op/pow.hpp"
#include "op/prelu.hpp" #include "op/prelu.hpp"
#include "op/reciprocal.hpp" #include "op/reciprocal.hpp"
...@@ -195,6 +197,7 @@ namespace ngraph ...@@ -195,6 +197,7 @@ namespace ngraph
REGISTER_OPERATOR("Neg", 1, neg); REGISTER_OPERATOR("Neg", 1, neg);
REGISTER_OPERATOR("Not", 1, logical_not); REGISTER_OPERATOR("Not", 1, logical_not);
REGISTER_OPERATOR("Or", 1, logical_or); REGISTER_OPERATOR("Or", 1, logical_or);
REGISTER_OPERATOR("Pad", 1, pad);
REGISTER_OPERATOR("Pow", 1, pow); REGISTER_OPERATOR("Pow", 1, pow);
REGISTER_OPERATOR("PRelu", 1, prelu); REGISTER_OPERATOR("PRelu", 1, prelu);
REGISTER_OPERATOR("Reciprocal", 1, reciprocal); REGISTER_OPERATOR("Reciprocal", 1, reciprocal);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment