Unverified Commit 1f662004 authored by Michał Karzyński's avatar Michał Karzyński Committed by GitHub

[ONNX] Average and Max Pooling (#1489)

parent 14f5fd6f
...@@ -32,6 +32,8 @@ add_library(onnx_import STATIC ...@@ -32,6 +32,8 @@ add_library(onnx_import STATIC
core/value_info.hpp core/value_info.hpp
exceptions.hpp exceptions.hpp
op/add.hpp op/add.hpp
op/average_pool.cpp
op/average_pool.hpp
op/batch_norm.cpp op/batch_norm.cpp
op/batch_norm.hpp op/batch_norm.hpp
op/constant.cpp op/constant.cpp
...@@ -40,6 +42,8 @@ add_library(onnx_import STATIC ...@@ -40,6 +42,8 @@ add_library(onnx_import STATIC
op/gemm.cpp op/gemm.cpp
op/gemm.hpp op/gemm.hpp
op/matmul.hpp op/matmul.hpp
op/max_pool.cpp
op/max_pool.hpp
op/mul.hpp op/mul.hpp
op/relu.hpp op/relu.hpp
op/split.cpp op/split.cpp
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/node.hpp"
#include "ngraph/node_vector.hpp"
#include "ngraph/op/avg_pool.hpp"
#include "utils/convpool.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
NodeVector average_pool(const Node& node)
{
return convpool::make_ng_pool<ngraph::op::AvgPool>(node);
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/node_vector.hpp"
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
/**
* @brief Convert ONNX AveragePool operation to an nGraph node.
*
* @param node The ONNX node object representing this operation.
*
* @return The vector containing Ngraph nodes producing output of ONNX AveragePool
* operation.
*/
NodeVector average_pool(const Node& node);
} // namespace op
} // namespace onnx_import
} // namespace ngraph
...@@ -118,9 +118,9 @@ namespace ngraph ...@@ -118,9 +118,9 @@ namespace ngraph
std::to_string(groups)}; std::to_string(groups)};
} }
auto strides = attribute::get_strides(node); auto strides = convpool::get_strides(node);
auto dilations = attribute::get_dilations(node); auto dilations = convpool::get_dilations(node);
auto paddings = attribute::get_pads(node); auto paddings = convpool::get_pads(node);
const auto& padding_below = paddings.first; const auto& padding_below = paddings.first;
const auto& padding_above = paddings.second; const auto& padding_above = paddings.second;
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/op/max_pool.hpp"
#include "ngraph/node.hpp"
#include "ngraph/node_vector.hpp"
#include "utils/convpool.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
NodeVector max_pool(const Node& node)
{
return convpool::make_ng_pool<ngraph::op::MaxPool>(node);
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/node_vector.hpp"
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
/**
* @brief Convert ONNX MaxPool operation to an nGraph node.
*
* @param node The ONNX node object representing this operation.
*
* @return The vector containing Ngraph nodes producing output of ONNX MaxPool
* operation.
*/
NodeVector max_pool(const Node& node);
} // namespace op
} // namespace onnx_import
} // namespace ngraph
...@@ -19,11 +19,13 @@ ...@@ -19,11 +19,13 @@
#include "core/attribute.hpp" #include "core/attribute.hpp"
#include "op/add.hpp" #include "op/add.hpp"
#include "op/average_pool.hpp"
#include "op/batch_norm.hpp" #include "op/batch_norm.hpp"
#include "op/constant.hpp" #include "op/constant.hpp"
#include "op/conv.hpp" #include "op/conv.hpp"
#include "op/gemm.hpp" #include "op/gemm.hpp"
#include "op/matmul.hpp" #include "op/matmul.hpp"
#include "op/max_pool.hpp"
#include "op/mul.hpp" #include "op/mul.hpp"
#include "op/relu.hpp" #include "op/relu.hpp"
#include "op/split.hpp" #include "op/split.hpp"
...@@ -72,12 +74,15 @@ namespace ngraph ...@@ -72,12 +74,15 @@ namespace ngraph
ops_bridge() ops_bridge()
{ {
m_map.emplace("Add", std::bind(op::add, std::placeholders::_1)); m_map.emplace("Add", std::bind(op::add, std::placeholders::_1));
m_map.emplace("AveragePool",
std::bind(op::average_pool, std::placeholders::_1));
m_map.emplace("BatchNormalization", m_map.emplace("BatchNormalization",
std::bind(op::batch_norm, std::placeholders::_1)); std::bind(op::batch_norm, std::placeholders::_1));
m_map.emplace("Constant", std::bind(op::constant, std::placeholders::_1)); m_map.emplace("Constant", std::bind(op::constant, std::placeholders::_1));
m_map.emplace("Conv", std::bind(op::conv, std::placeholders::_1)); m_map.emplace("Conv", std::bind(op::conv, std::placeholders::_1));
m_map.emplace("Gemm", std::bind(op::gemm, std::placeholders::_1)); m_map.emplace("Gemm", std::bind(op::gemm, std::placeholders::_1));
m_map.emplace("MatMul", std::bind(op::matmul, std::placeholders::_1)); m_map.emplace("MatMul", std::bind(op::matmul, std::placeholders::_1));
m_map.emplace("MaxPool", std::bind(op::max_pool, std::placeholders::_1));
m_map.emplace("Mul", std::bind(op::mul, std::placeholders::_1)); m_map.emplace("Mul", std::bind(op::mul, std::placeholders::_1));
m_map.emplace("Relu", std::bind(op::relu, std::placeholders::_1)); m_map.emplace("Relu", std::bind(op::relu, std::placeholders::_1));
m_map.emplace("Split", std::bind(op::split, std::placeholders::_1)); m_map.emplace("Split", std::bind(op::split, std::placeholders::_1));
......
...@@ -14,12 +14,12 @@ ...@@ -14,12 +14,12 @@
* limitations under the License. * limitations under the License.
*******************************************************************************/ *******************************************************************************/
#include "convpool.hpp"
#include <cmath> #include <cmath>
#include "ngraph/coordinate_diff.hpp" #include "ngraph/coordinate_diff.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include "convpool.hpp"
#include "core/attribute.hpp" #include "core/attribute.hpp"
#include "core/node.hpp" #include "core/node.hpp"
...@@ -27,7 +27,7 @@ namespace ngraph ...@@ -27,7 +27,7 @@ namespace ngraph
{ {
namespace onnx_import namespace onnx_import
{ {
namespace attribute namespace convpool
{ {
Shape get_kernel_shape(const Node& node) Shape get_kernel_shape(const Node& node)
{ {
...@@ -118,7 +118,7 @@ namespace ngraph ...@@ -118,7 +118,7 @@ namespace ngraph
} }
if (pads.empty()) if (pads.empty())
{ {
pads = {static_cast<std::ptrdiff_t>(kernel_shape.size()), 0UL}; pads = CoordinateDiff(static_cast<std::ptrdiff_t>(kernel_shape.size()), 0UL);
} }
if (pads.size() <= 3) if (pads.size() <= 3)
...@@ -133,6 +133,6 @@ namespace ngraph ...@@ -133,6 +133,6 @@ namespace ngraph
} }
} }
} // namespace attribute } // namespace convpool
} // namespace onnx_import } // namespace onnx_import
} // namespace ngraph } // namespace ngraph
...@@ -26,7 +26,7 @@ namespace ngraph ...@@ -26,7 +26,7 @@ namespace ngraph
{ {
namespace onnx_import namespace onnx_import
{ {
namespace attribute namespace convpool
{ {
/** /**
* @brief Get shape of kernel (filter) in pixels. * @brief Get shape of kernel (filter) in pixels.
...@@ -94,7 +94,37 @@ namespace ngraph ...@@ -94,7 +94,37 @@ namespace ngraph
{ {
return get_pads(node, get_kernel_shape(node)); return get_pads(node, get_kernel_shape(node));
} }
} // namespace attribute
/**
* @brief Create an nGraph pooling operation based on an ONNX pooling op.
*
* @tparam T Class of an nGraph pooling operation (e.g. AveragePool, MaxPool)
* @param node incoming ONNX opearation
* @return nGraph node equivalent of the ONNX operation
*/
template <class T>
inline NodeVector make_ng_pool(const Node& node)
{
// Fetch input node for the pooling operation
auto data = node.get_ng_inputs().at(0);
// Parse ONNX op attributes
Shape kernel_shape = convpool::get_kernel_shape(node);
auto strides = convpool::get_strides(node);
auto dilations = convpool::get_dilations(node);
auto paddings = convpool::get_pads(node);
// Convert padding from CoordinateDiff to Shape objects
const CoordinateDiff& padding_below{paddings.first};
const CoordinateDiff& padding_above{paddings.second};
Shape padding_below_shape{std::begin(padding_below), std::end(padding_below)};
Shape padding_above_shape{std::begin(padding_above), std::end(padding_above)};
return {std::make_shared<T>(
data, kernel_shape, strides, padding_below_shape, padding_above_shape)};
}
} // namespace convpool
} // namespace onnx_import } // namespace onnx_import
......
...@@ -35,7 +35,7 @@ namespace ngraph ...@@ -35,7 +35,7 @@ namespace ngraph
} }
else else
{ {
for (int i = 0; i < axes_order.size(); ++i) for (auto i = 0; i < axes_order.size(); ++i)
{ {
out_shape[i] = node->get_shape().at(axes_order.at(i)); out_shape[i] = node->get_shape().at(axes_order.at(i));
} }
......
...@@ -48,7 +48,7 @@ namespace ngraph ...@@ -48,7 +48,7 @@ namespace ngraph
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Shape& padding_below, const Shape& padding_below,
const Shape& padding_above, const Shape& padding_above,
bool include_padding_in_avg_computation); bool include_padding_in_avg_computation = false);
/// \brief Constructs a batched, unpadded average pooling operation (i.e., all padding shapes are set to 0). /// \brief Constructs a batched, unpadded average pooling operation (i.e., all padding shapes are set to 0).
/// ///
......
...@@ -56,17 +56,17 @@ TEST(onnx, model_add_abc_initializers) ...@@ -56,17 +56,17 @@ TEST(onnx, model_add_abc_initializers)
TEST(onnx, model_addmul_abc) TEST(onnx, model_addmul_abc)
{ {
auto function{ngraph::onnx_import::import_onnx_function( auto function = ngraph::onnx_import::import_onnx_function(
ngraph::file_util::path_join(SERIALIZED_ZOO, "onnx/addmul_abc.onnx"))}; ngraph::file_util::path_join(SERIALIZED_ZOO, "onnx/addmul_abc.onnx"));
std::vector<std::vector<float>> inputs; std::vector<std::vector<float>> inputs;
ngraph::Shape shape{1, 2, 2}; ngraph::Shape shape{1, 2, 2};
inputs.emplace_back(ngraph::test::NDArray<float, 3>({{{9, 10}}, {{11, 12}}}).get_vector()); inputs.emplace_back(test::NDArray<float, 3>({{{9, 10}}, {{11, 12}}}).get_vector());
inputs.emplace_back(ngraph::test::NDArray<float, 3>({{{5, 6}}, {{7, 8}}}).get_vector()); inputs.emplace_back(test::NDArray<float, 3>({{{5, 6}}, {{7, 8}}}).get_vector());
inputs.emplace_back(ngraph::test::NDArray<float, 3>({{{1, 2}}, {{3, 4}}}).get_vector()); inputs.emplace_back(test::NDArray<float, 3>({{{1, 2}}, {{3, 4}}}).get_vector());
auto expected_output = ngraph::test::NDArray<float, 3>({{{46, 62}}, {{80, 100}}}).get_vector(); auto expected_output = test::NDArray<float, 3>({{{46, 62}}, {{80, 100}}}).get_vector();
auto result_vectors = execute(function, inputs, "INTERPRETER"); auto result_vectors = execute(function, inputs, "INTERPRETER");
EXPECT_TRUE(test::all_close_f(expected_output, result_vectors.front())); EXPECT_TRUE(test::all_close_f(expected_output, result_vectors.front()));
...@@ -130,7 +130,7 @@ namespace ...@@ -130,7 +130,7 @@ namespace
std::vector<std::vector<float>> args; std::vector<std::vector<float>> args;
// data (1, 1, 7, 5) input tensor // data (1, 1, 7, 5) input tensor
args.emplace_back(ngraph::test::NDArray<float, 4>{{{{{0.f, 1.f, 2.f, 3.f, 4.f}, args.emplace_back(test::NDArray<float, 4>{{{{{0.f, 1.f, 2.f, 3.f, 4.f},
{5.f, 6.f, 7.f, 8.f, 9.f}, {5.f, 6.f, 7.f, 8.f, 9.f},
{10.f, 11.f, 12.f, 13.f, 14.f}, {10.f, 11.f, 12.f, 13.f, 14.f},
{15.f, 16.f, 17.f, 18.f, 19.f}, {15.f, 16.f, 17.f, 18.f, 19.f},
...@@ -141,7 +141,7 @@ namespace ...@@ -141,7 +141,7 @@ namespace
// filters (1, 1, 3, 3) aka convolution weights // filters (1, 1, 3, 3) aka convolution weights
args.emplace_back( args.emplace_back(
ngraph::test::NDArray<float, 4>{{{{{1.f, 1.f, 1.f}, {1.f, 1.f, 1.f}, {1.f, 1.f, 1.f}}}}} test::NDArray<float, 4>{{{{{1.f, 1.f, 1.f}, {1.f, 1.f, 1.f}, {1.f, 1.f, 1.f}}}}}
.get_vector()); .get_vector());
return execute(function, args, "INTERPRETER"); return execute(function, args, "INTERPRETER");
...@@ -155,7 +155,7 @@ TEST(onnx, model_conv2d_strides_padding) ...@@ -155,7 +155,7 @@ TEST(onnx, model_conv2d_strides_padding)
ngraph::file_util::path_join(SERIALIZED_ZOO, "onnx/conv_with_strides_padding.onnx")); ngraph::file_util::path_join(SERIALIZED_ZOO, "onnx/conv_with_strides_padding.onnx"));
// (1, 1, 4, 3) // (1, 1, 4, 3)
auto expected_output = ngraph::test::NDArray<float, 4>({{{{12.f, 27.f, 24.f}, auto expected_output = test::NDArray<float, 4>({{{{12.f, 27.f, 24.f},
{63.f, 108.f, 81.f}, {63.f, 108.f, 81.f},
{123.f, 198.f, 141.f}, {123.f, 198.f, 141.f},
{112.f, 177.f, 124.f}}}}) {112.f, 177.f, 124.f}}}})
...@@ -173,8 +173,7 @@ TEST(onnx, model_conv2d_strides_no_padding) ...@@ -173,8 +173,7 @@ TEST(onnx, model_conv2d_strides_no_padding)
// (1, 1, 3, 2) // (1, 1, 3, 2)
auto expected_output = auto expected_output =
ngraph::test::NDArray<float, 4>({{{{54.f, 72.f}, {144.f, 162.f}, {234.f, 252.f}}}}) test::NDArray<float, 4>({{{{54.f, 72.f}, {144.f, 162.f}, {234.f, 252.f}}}}).get_vector();
.get_vector();
auto result = conv2d_execute(function); auto result = conv2d_execute(function);
EXPECT_EQ(expected_output, result.front()); EXPECT_EQ(expected_output, result.front());
...@@ -187,14 +186,84 @@ TEST(onnx, model_conv2d_strides_assymetric_padding) ...@@ -187,14 +186,84 @@ TEST(onnx, model_conv2d_strides_assymetric_padding)
SERIALIZED_ZOO, "onnx/conv_with_strides_and_asymmetric_padding.onnx")); SERIALIZED_ZOO, "onnx/conv_with_strides_and_asymmetric_padding.onnx"));
// (1, 1, 4, 2) // (1, 1, 4, 2)
auto expected_output = ngraph::test::NDArray<float, 4>( auto expected_output =
{{{{21.f, 33.f}, {99.f, 117.f}, {189.f, 207.f}, {171.f, 183.f}}}}) test::NDArray<float, 4>({{{{21.f, 33.f}, {99.f, 117.f}, {189.f, 207.f}, {171.f, 183.f}}}})
.get_vector(); .get_vector();
auto result = conv2d_execute(function); auto result = conv2d_execute(function);
EXPECT_EQ(expected_output, result.front()); EXPECT_EQ(expected_output, result.front());
} }
TEST(onnx, model_average_pool_2d)
{
// Pooling with strides=2 and no padding
auto model = ngraph::onnx_import::import_onnx_function(
ngraph::file_util::path_join(SERIALIZED_ZOO, "onnx/average_pool_2d.onnx"));
// input data shape (1, 1, 4, 4)
Inputs inputs;
inputs.push_back(test::NDArray<float, 4>({{{{0.f, 1.f, 2.f, 3.f},
{4.f, 5.f, 6.f, 7.f},
{8.f, 9.f, 10.f, 11.f},
{12.f, 13.f, 14.f, 15.f}}}})
.get_vector());
// (1, 1, 2, 2)
auto expected_output = test::NDArray<float, 4>({{{{2.5f, 4.5f}, {10.5f, 12.5f}}}}).get_vector();
Outputs outputs{execute(model, inputs, "INTERPRETER")};
EXPECT_EQ(expected_output, outputs.front());
}
TEST(onnx, model_average_pool_2d_pads)
{
// Pooling with strides=2 and padding=1
auto model = ngraph::onnx_import::import_onnx_function(
ngraph::file_util::path_join(SERIALIZED_ZOO, "onnx/average_pool_2d_pads.onnx"));
// input data shape (1, 1, 4, 4)
Inputs inputs;
inputs.push_back(test::NDArray<float, 4>({{{{0.f, 1.f, 2.f, 3.f},
{4.f, 5.f, 6.f, 7.f},
{8.f, 9.f, 10.f, 11.f},
{12.f, 13.f, 14.f, 15.f}}}})
.get_vector());
// (1, 1, 3, 3)
auto expected_output =
test::NDArray<float, 4>({{{{0.f, 1.5f, 3.f}, {6.f, 7.5f, 9.f}, {12.f, 13.5f, 15.f}}}})
.get_vector();
Outputs outputs = execute(model, inputs, "INTERPRETER");
EXPECT_EQ(expected_output, outputs.front());
}
TEST(onnx, model_max_pool_2d_pads)
{
// Pooling with strides=2 and padding=1
auto model = ngraph::onnx_import::import_onnx_function(
ngraph::file_util::path_join(SERIALIZED_ZOO, "onnx/max_pool_2d_pads.onnx"));
// input data shape (1, 1, 4, 4)
Inputs inputs;
inputs.push_back(test::NDArray<float, 4>({{{{0.f, 1.f, 2.f, 3.f},
{4.f, 5.f, 6.f, 7.f},
{8.f, 9.f, 10.f, 11.f},
{12.f, 13.f, 14.f, 15.f}}}})
.get_vector());
// (1, 1, 3, 3)
auto expected_output =
test::NDArray<float, 4>({{{{0.f, 2.f, 3.f}, {8.f, 10.f, 11.f}, {12.f, 14.f, 15.f}}}})
.get_vector();
Outputs outputs{execute(model, inputs, "INTERPRETER")};
EXPECT_EQ(expected_output, outputs.front());
}
TEST(onnx, model_batchnorm_default) TEST(onnx, model_batchnorm_default)
{ {
// Batch Normalization with default parameters // Batch Normalization with default parameters
...@@ -240,16 +309,16 @@ TEST(onnx, model_relu) ...@@ -240,16 +309,16 @@ TEST(onnx, model_relu)
TEST(onnx, model_gemm_abc) TEST(onnx, model_gemm_abc)
{ {
auto function{ngraph::onnx_import::import_onnx_function( auto function = ngraph::onnx_import::import_onnx_function(
ngraph::file_util::path_join(SERIALIZED_ZOO, "onnx/gemm_abc.onnx"))}; ngraph::file_util::path_join(SERIALIZED_ZOO, "onnx/gemm_abc.onnx"));
std::vector<std::vector<float>> inputs; std::vector<std::vector<float>> inputs;
inputs.emplace_back(ngraph::test::NDArray<float, 2>( inputs.emplace_back(test::NDArray<float, 2>(
{{1, 2, 3, 4, 5, 6}, {7, 8, 9, 10, 11, 12}, {13, 14, 15, 16, 17, 18}}) {{1, 2, 3, 4, 5, 6}, {7, 8, 9, 10, 11, 12}, {13, 14, 15, 16, 17, 18}})
.get_vector()); .get_vector());
inputs.emplace_back(ngraph::test::NDArray<float, 2>({{19, 20, 21, 22}, inputs.emplace_back(test::NDArray<float, 2>({{19, 20, 21, 22},
{23, 24, 25, 26}, {23, 24, 25, 26},
{27, 28, 29, 30}, {27, 28, 29, 30},
{31, 32, 33, 34}, {31, 32, 33, 34},
...@@ -258,10 +327,10 @@ TEST(onnx, model_gemm_abc) ...@@ -258,10 +327,10 @@ TEST(onnx, model_gemm_abc)
.get_vector()); .get_vector());
inputs.emplace_back( inputs.emplace_back(
ngraph::test::NDArray<float, 2>({{1, 1, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}}).get_vector()); test::NDArray<float, 2>({{1, 1, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}}).get_vector());
auto expected_output = auto expected_output =
ngraph::test::NDArray<float, 2>( test::NDArray<float, 2>(
{{340, 350.5, 361, 371.5}, {862, 890.5, 919, 947.5}, {1384, 1430.5, 1477, 1523.5}}) {{340, 350.5, 361, 371.5}, {862, 890.5, 919, 947.5}, {1384, 1430.5, 1477, 1523.5}})
.get_vector(); .get_vector();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment