Commit 5a32dfe4 authored by tsocha's avatar tsocha Committed by Robert Kimball

[ONNX] Global Pool ops (#1823)

* [ONNX] Global Pool ops

* Remove is_global flag

* Style fix

* Review fix pt. 1

* Change comments style
parent aeb92efd
...@@ -59,6 +59,10 @@ add_library(onnx_import STATIC ...@@ -59,6 +59,10 @@ add_library(onnx_import STATIC
op/floor.hpp op/floor.hpp
op/gemm.cpp op/gemm.cpp
op/gemm.hpp op/gemm.hpp
op/global_average_pool.cpp
op/global_average_pool.hpp
op/global_max_pool.cpp
op/global_max_pool.hpp
op/greater.hpp op/greater.hpp
op/hard_sigmoid.cpp op/hard_sigmoid.cpp
op/hard_sigmoid.hpp op/hard_sigmoid.hpp
......
...@@ -180,7 +180,7 @@ namespace ngraph ...@@ -180,7 +180,7 @@ namespace ngraph
} }
template <> template <>
inline const std::string& get_value(const onnx::AttributeProto& attribute) inline std::string get_value(const onnx::AttributeProto& attribute)
{ {
if (unlikely(attribute.type() != onnx::AttributeProto_AttributeType_STRING)) if (unlikely(attribute.type() != onnx::AttributeProto_AttributeType_STRING))
{ {
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/node.hpp"
#include "ngraph/node_vector.hpp"
#include "ngraph/op/avg_pool.hpp"
#include "utils/convpool.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector global_average_pool(const Node& node)
{
    // Delegate to the shared pooling builder. make_ng_pool recognizes the
    // "Global" prefix in the ONNX op type and expands the kernel to cover
    // the input's full spatial extent, producing a global average pool.
    const auto& onnx_node = node;
    return convpool::make_ng_pool<ngraph::op::AvgPool>(onnx_node);
}
} // namespace set_1
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node_vector.hpp"
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
/// \brief Convert an ONNX GlobalAveragePool operation to an nGraph node.
///
/// Global pooling averages over the entire spatial extent of the input;
/// the kernel shape is derived from the input tensor rather than taken
/// from node attributes.
///
/// \param node The ONNX node object representing this operation.
///
/// \return The vector containing nGraph nodes producing output of ONNX GlobalAveragePool
/// operation.
NodeVector global_average_pool(const Node& node);
} // namespace set_1
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/node.hpp"
#include "ngraph/node_vector.hpp"
#include "ngraph/op/max_pool.hpp"
#include "utils/convpool.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector global_max_pool(const Node& node)
{
    // Forward to the common pooling builder. Because the op type contains
    // "Global", make_ng_pool sets the pooling window to the whole spatial
    // shape of the input, yielding a global max pool.
    const auto& onnx_node = node;
    return convpool::make_ng_pool<ngraph::op::MaxPool>(onnx_node);
}
} // namespace set_1
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node_vector.hpp"
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
/// \brief Convert an ONNX GlobalMaxPool operation to an nGraph node.
///
/// Global pooling takes the maximum over the entire spatial extent of the
/// input; the kernel shape is derived from the input tensor rather than
/// taken from node attributes.
///
/// \param node The ONNX node object representing this operation.
///
/// \return The vector containing nGraph nodes producing output of ONNX GlobalMaxPool
/// operation.
NodeVector global_max_pool(const Node& node);
} // namespace set_1
} // namespace op
} // namespace onnx_import
} // namespace ngraph
...@@ -37,6 +37,8 @@ ...@@ -37,6 +37,8 @@
#include "op/flatten.hpp" #include "op/flatten.hpp"
#include "op/floor.hpp" #include "op/floor.hpp"
#include "op/gemm.hpp" #include "op/gemm.hpp"
#include "op/global_average_pool.hpp"
#include "op/global_max_pool.hpp"
#include "op/greater.hpp" #include "op/greater.hpp"
#include "op/hard_sigmoid.hpp" #include "op/hard_sigmoid.hpp"
#include "op/identity.hpp" #include "op/identity.hpp"
...@@ -148,6 +150,8 @@ namespace ngraph ...@@ -148,6 +150,8 @@ namespace ngraph
REGISTER_OPERATOR("Flatten", 1, flatten); REGISTER_OPERATOR("Flatten", 1, flatten);
REGISTER_OPERATOR("Floor", 1, floor); REGISTER_OPERATOR("Floor", 1, floor);
REGISTER_OPERATOR("Gemm", 1, gemm); REGISTER_OPERATOR("Gemm", 1, gemm);
REGISTER_OPERATOR("GlobalAveragePool", 1, global_average_pool);
REGISTER_OPERATOR("GlobalMaxPool", 1, global_max_pool);
REGISTER_OPERATOR("Greater", 1, greater); REGISTER_OPERATOR("Greater", 1, greater);
REGISTER_OPERATOR("HardSigmoid", 1, hard_sigmoid); REGISTER_OPERATOR("HardSigmoid", 1, hard_sigmoid);
REGISTER_OPERATOR("Identity", 1, identity); REGISTER_OPERATOR("Identity", 1, identity);
......
...@@ -31,7 +31,9 @@ namespace ngraph ...@@ -31,7 +31,9 @@ namespace ngraph
{ {
Shape get_kernel_shape(const Node& node) Shape get_kernel_shape(const Node& node)
{ {
return node.get_attribute_value<std::vector<std::size_t>>("kernel_shape", {1, 1}); std::size_t input_spacial_dims = node.get_ng_inputs()[0]->get_shape().size() - 2;
return node.get_attribute_value<std::vector<std::size_t>>(
"kernel_shape", std::vector<std::size_t>(input_spacial_dims, 1UL));
} }
namespace detail namespace detail
...@@ -121,7 +123,7 @@ namespace ngraph ...@@ -121,7 +123,7 @@ namespace ngraph
pads = CoordinateDiff(static_cast<std::ptrdiff_t>(kernel_shape.size()), 0UL); pads = CoordinateDiff(static_cast<std::ptrdiff_t>(kernel_shape.size()), 0UL);
} }
if (pads.size() <= 3) if (pads.size() != kernel_shape.size() * 2)
{ {
// Paddings specified in (H, W, C) format. // Paddings specified in (H, W, C) format.
return {pads, pads}; return {pads, pads};
......
...@@ -16,7 +16,10 @@ ...@@ -16,7 +16,10 @@
#pragma once #pragma once
#include <string>
#include "ngraph/coordinate_diff.hpp" #include "ngraph/coordinate_diff.hpp"
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include "core/attribute.hpp" #include "core/attribute.hpp"
...@@ -84,13 +87,11 @@ namespace ngraph ...@@ -84,13 +87,11 @@ namespace ngraph
return get_pads(node, get_kernel_shape(node)); return get_pads(node, get_kernel_shape(node));
} }
/** /// \brief Create an nGraph pooling operation based on an ONNX pooling op.
* @brief Create an nGraph pooling operation based on an ONNX pooling op. ///
* /// \param T Class of an nGraph pooling operation (e.g. AveragePool, MaxPool)
* @tparam T Class of an nGraph pooling operation (e.g. AveragePool, MaxPool) /// \param node incoming ONNX operation
* @param node incoming ONNX operation /// \return nGraph node equivalent of the ONNX operation
* @return nGraph node equivalent of the ONNX operation
*/
template <class T> template <class T>
inline NodeVector make_ng_pool(const Node& node) inline NodeVector make_ng_pool(const Node& node)
{ {
...@@ -98,19 +99,44 @@ namespace ngraph ...@@ -98,19 +99,44 @@ namespace ngraph
auto data = node.get_ng_inputs().at(0); auto data = node.get_ng_inputs().at(0);
// Parse ONNX op attributes // Parse ONNX op attributes
Shape kernel_shape = convpool::get_kernel_shape(node); Shape kernel_shape;
if (node.op_type().find("Global") != std::string::npos)
{
kernel_shape = node.get_ng_inputs()[0]->get_shape();
// Remove N and C dimensions and leave only spatial dims.
kernel_shape.erase(std::begin(kernel_shape),
std::next(std::begin(kernel_shape), 2));
}
else
{
kernel_shape = convpool::get_kernel_shape(node);
}
auto strides = convpool::get_strides(node); auto strides = convpool::get_strides(node);
auto dilations = convpool::get_dilations(node); auto dilations = convpool::get_dilations(node);
auto paddings = convpool::get_pads(node); auto paddings = convpool::get_pads(node);
bool count_include_pad = node.get_attribute_value<int64_t>("count_include_pad", 0);
// Convert padding from CoordinateDiff to Shape objects // Convert padding from CoordinateDiff to Shape objects
const CoordinateDiff& padding_above{paddings.first}; const CoordinateDiff& padding_above{paddings.first};
const CoordinateDiff& padding_below{paddings.second}; const CoordinateDiff& padding_below{paddings.second};
Shape padding_below_shape{std::begin(padding_below), std::end(padding_below)}; Shape padding_below_shape{std::begin(padding_below), std::end(padding_below)};
Shape padding_above_shape{std::begin(padding_above), std::end(padding_above)}; Shape padding_above_shape{std::begin(padding_above), std::end(padding_above)};
return {std::make_shared<T>( if (count_include_pad)
data, kernel_shape, strides, padding_below_shape, padding_above_shape)}; {
return {std::make_shared<ngraph::op::AvgPool>(data,
kernel_shape,
strides,
padding_below_shape,
padding_above_shape,
count_include_pad)};
}
else
{
return {std::make_shared<T>(
data, kernel_shape, strides, padding_below_shape, padding_above_shape)};
}
} }
} // namespace convpool } // namespace convpool
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment