Commit 0be9266f authored by Michał Karzyński's avatar Michał Karzyński Committed by Robert Kimball

[ONNX] Conv operation (#1472)

* [ONNX] Refactor exceptions

* [ONNX] Attribute helper functions

* [ONNX] Convolution operation
parent e19cc4a7
...@@ -35,18 +35,26 @@ add_library(onnx_import_interface OBJECT ...@@ -35,18 +35,26 @@ add_library(onnx_import_interface OBJECT
add_library(onnx_import STATIC add_library(onnx_import STATIC
onnx.pb.cc onnx.pb.cc
attribute.cpp attribute.cpp
attribute.hpp
exceptions.hpp exceptions.hpp
graph.cpp graph.cpp
graph.hpp
model.hpp model.hpp
node.cpp node.cpp
node.hpp
op/add.hpp op/add.hpp
op/batch_norm.hpp op/batch_norm.hpp
op/constant.cpp op/constant.cpp
op/constant.hpp op/constant.hpp
op/conv.cpp
op/relu.hpp op/relu.hpp
op/split.hpp op/split.hpp
ops_bridge.cpp ops_bridge.cpp
tensor.hpp tensor.hpp
utils/broadcasting.cpp
utils/broadcasting.hpp
utils/convpool.cpp
utils/convpool.hpp
value_info.hpp) value_info.hpp)
add_dependencies(onnx_import onnx_import_interface) add_dependencies(onnx_import onnx_import_interface)
......
/******************************************************************************* /*******************************************************************************
* Copyright 2018 Intel Corporation * Copyright 2018 Intel Corporation
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*******************************************************************************/ *******************************************************************************/
#pragma once #pragma once
#include "ngraph/except.hpp" #include "ngraph/except.hpp"
namespace ngraph namespace ngraph
{ {
namespace onnx_import namespace onnx_import
{ {
namespace error namespace error
{ {
struct NotSupported : ngraph_error struct NotSupported : ngraph_error
{ {
explicit NotSupported(const std::string& op_name, explicit NotSupported(const std::string& op_name,
const std::string& name, const std::string& name,
const std::string& message) const std::string& message)
: ngraph_error{op_name + " node (" + name + "): " + message} : ngraph_error{op_name + " node (" + name + "): " + message}
{ {
} }
}; };
namespace parameter namespace parameter
{ {
struct Value : ngraph_error struct Value : ngraph_error
{ {
Value(const std::string& op_name, Value(const std::string& op_name,
const std::string& name, const std::string& name,
const std::string& message) const std::string& message)
: ngraph_error{op_name + " node (" + name + "): " + message} : ngraph_error{op_name + " node (" + name + "): " + message}
{ {
} }
}; };
} // namespace paramter } // namespace paramter
} // namespace error } // namespace error
} // namespace onnx_import } // namespace onnx_import
} // namespace ngraph } // namespace ngraph
\ No newline at end of file
/******************************************************************************* /*******************************************************************************
* Copyright 2018 Intel Corporation * Copyright 2018 Intel Corporation
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*******************************************************************************/ *******************************************************************************/
#pragma once #pragma once
#include <memory> #include <memory>
#include "ngraph/frontend/onnx_import/exceptions.hpp" #include "ngraph/frontend/onnx_import/exceptions.hpp"
#include "ngraph/frontend/onnx_import/node.hpp" #include "ngraph/frontend/onnx_import/node.hpp"
#include "ngraph/node_vector.hpp" #include "ngraph/node_vector.hpp"
#include "ngraph/op/batch_norm.hpp" #include "ngraph/op/batch_norm.hpp"
namespace ngraph namespace ngraph
{ {
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{ {
inline NodeVector batch_norm(const Node& node) inline NodeVector batch_norm(const Node& node)
{ {
NodeVector inputs{node.get_ng_inputs()}; NodeVector inputs{node.get_ng_inputs()};
auto x = inputs.at(0); auto x = inputs.at(0);
auto scale = inputs.at(1); auto scale = inputs.at(1);
auto bias = inputs.at(2); auto bias = inputs.at(2);
std::shared_ptr<ngraph::Node> mean{nullptr}; std::shared_ptr<ngraph::Node> mean{nullptr};
std::shared_ptr<ngraph::Node> var{nullptr}; std::shared_ptr<ngraph::Node> var{nullptr};
int is_test{node.get_attribute_value<int>("is_test", 1)}; int is_test{node.get_attribute_value<int>("is_test", 1)};
int spatial{node.get_attribute_value<int>("spatial", 1)}; int spatial{node.get_attribute_value<int>("spatial", 1)};
double epsilon{node.get_attribute_value<double>("epsilon", 1e-5)}; double epsilon{node.get_attribute_value<double>("epsilon", 1e-5)};
// TODO: Implement learning mode support // TODO: Implement learning mode support
// float momentum{node.get_attribute_value<float>("momentum", 0.9f)}; // float momentum{node.get_attribute_value<float>("momentum", 0.9f)};
bool training = false; bool training = false;
if (!is_test) if (!is_test)
{ {
throw error::NotSupported("BatchNormalization", throw error::NotSupported("BatchNormalization",
node.get_name(), node.get_name(),
"only 'is_test' mode is currently supported."); "only 'is_test' mode is currently supported.");
} }
if (!spatial) if (!spatial)
{ {
throw error::NotSupported("BatchNormalization", throw error::NotSupported("BatchNormalization",
node.get_name(), node.get_name(),
"only 'spatial' mode is currently supported."); "only 'spatial' mode is currently supported.");
} }
if (inputs.size() >= 5) if (inputs.size() >= 5)
{ {
mean = inputs.at(3); mean = inputs.at(3);
var = inputs.at(4); var = inputs.at(4);
return {std::make_shared<ngraph::op::BatchNorm>( return {std::make_shared<ngraph::op::BatchNorm>(
epsilon, scale, bias, x, mean, var, training)}; epsilon, scale, bias, x, mean, var, training)};
} }
return {std::make_shared<ngraph::op::BatchNorm>(epsilon, scale, bias, x)}; return {std::make_shared<ngraph::op::BatchNorm>(epsilon, scale, bias, x)};
} }
} // namespace op } // namespace op
} // namespace onnx_import } // namespace onnx_import
} // namespace ngraph } // namespace ngraph
\ No newline at end of file
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <cstddef>
#include <memory>
#include <vector>
#include "ngraph/frontend/onnx_import/op/conv.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/slice.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace
{
std::shared_ptr<ngraph::op::Op>
make_ng_convolution(const std::shared_ptr<ngraph::Node>& data,
const std::shared_ptr<ngraph::Node>& filters,
const ngraph::Strides& strides,
const ngraph::Strides& dilations,
const ngraph::CoordinateDiff& padding_below,
const ngraph::CoordinateDiff& padding_above,
int groups)
{
if (groups > 1)
{
// Split one convolution op to N ops where N is the number of groups
// and concat results after computation.
// reference: https://github.com/NervanaSystems/ngraph-mxnet/blob/fdd692/src/ngraph/ngraph_emitter.cc#L822-L856
std::size_t n_data_channels{data->get_shape().at(1)};
std::size_t n_filters_channels{filters->get_shape().at(0)};
// TODO: ensure n_data_channels % groups = 0
std::size_t data_group_size{n_data_channels / groups};
std::size_t filters_group_size{n_filters_channels / groups};
NodeVector convolution_nodes;
// initial bounds for splice
std::vector<std::size_t> data_lower_bounds(data->get_shape().size());
std::vector<std::size_t> data_upper_bounds{data->get_shape()};
std::vector<std::size_t> filters_lower_bounds(filters->get_shape().size());
std::vector<std::size_t> filters_upper_bounds{filters->get_shape()};
for (std::size_t group{0}; group < groups; ++group)
{
// slice data
data_lower_bounds[1] = group * data_group_size;
data_upper_bounds[1] = (group + 1) * data_group_size;
auto sliced_data = std::make_shared<ngraph::op::Slice>(
data, data_lower_bounds, data_upper_bounds);
// slice filters
filters_lower_bounds[0] = group * filters_group_size;
filters_upper_bounds[0] = (group + 1) * filters_group_size;
auto sliced_filters = std::make_shared<ngraph::op::Slice>(
filters, filters_lower_bounds, filters_upper_bounds);
convolution_nodes.push_back(
std::make_shared<ngraph::op::Convolution>(sliced_data,
sliced_filters,
strides,
dilations,
padding_below,
padding_above));
}
std::size_t concatenation_axis = 1;
return std::make_shared<ngraph::op::Concat>(convolution_nodes,
concatenation_axis);
}
else
{
return std::make_shared<ngraph::op::Convolution>(
data, filters, strides, dilations, padding_below, padding_above);
}
}
} // namespace
NodeVector conv(const Node& node)
{
const NodeVector& inputs = node.get_ng_inputs();
auto data = inputs.at(0);
auto filters = inputs.at(1);
int groups{node.get_attribute_value<int>("group", 1)};
// TODO: update to ASSERTION CHECK
if (groups < 0 || groups > data->get_shape().at(1) ||
groups > filters->get_shape().at(0))
{
throw error::parameter::Value{"Conv",
node.get_name(),
"incorrect value of 'group' attribute: " +
std::to_string(groups)};
}
auto strides{attribute::get_strides(node)};
auto dilations{attribute::get_dilations(node)};
auto paddings{attribute::get_pads(node)};
const auto& padding_below{paddings.first};
const auto& padding_above{paddings.second};
auto conv_node{make_ng_convolution(
data, filters, strides, dilations, padding_below, padding_above, groups)};
// no bias param
if (inputs.size() < 3)
{
return {conv_node};
}
auto bias{inputs.at(2)};
const Shape& new_shape = conv_node->get_shape();
auto broadcasted_bias{std::make_shared<ngraph::op::Broadcast>(
bias, new_shape, calculate_broadcast_axes(new_shape, bias->get_shape(), 1))};
return {std::make_shared<ngraph::op::Add>(conv_node, broadcasted_bias)};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "frontend/onnx_import/utils/broadcasting.hpp"
#include "frontend/onnx_import/utils/convpool.hpp"
#include "ngraph/frontend/onnx_import/exceptions.hpp"
#include "ngraph/frontend/onnx_import/node.hpp"
#include "ngraph/coordinate_diff.hpp"
#include "ngraph/node.hpp"
#include "ngraph/node_vector.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/strides.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace detail
{
std::shared_ptr<ngraph::op::Op>
make_ng_convolution(const std::shared_ptr<ngraph::Node>& data,
const std::shared_ptr<ngraph::Node>& filters,
const ngraph::Strides& strides,
const ngraph::Strides& dilations,
const ngraph::CoordinateDiff& padding_below,
const ngraph::CoordinateDiff& padding_above,
int groups);
}
/**
* @brief Performs ONNX Conv operation.
*
* @param node The ONNX node object representing this operation.
*
* @return The vector containing Ngraph nodes producing output of ONNX convolution
* operation.
*/
NodeVector conv(const Node& node);
} // namespace op
} // namespace onnx_import
} // namespace ngraph
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "ngraph/frontend/onnx_import/op/add.hpp" #include "ngraph/frontend/onnx_import/op/add.hpp"
#include "ngraph/frontend/onnx_import/op/batch_norm.hpp" #include "ngraph/frontend/onnx_import/op/batch_norm.hpp"
#include "ngraph/frontend/onnx_import/op/constant.hpp" #include "ngraph/frontend/onnx_import/op/constant.hpp"
#include "ngraph/frontend/onnx_import/op/conv.hpp"
#include "ngraph/frontend/onnx_import/op/relu.hpp" #include "ngraph/frontend/onnx_import/op/relu.hpp"
#include "ngraph/frontend/onnx_import/op/split.hpp" #include "ngraph/frontend/onnx_import/op/split.hpp"
#include "ops_bridge.hpp" #include "ops_bridge.hpp"
...@@ -71,6 +72,7 @@ namespace ngraph ...@@ -71,6 +72,7 @@ namespace ngraph
m_map.emplace("BatchNormalization", m_map.emplace("BatchNormalization",
std::bind(op::batch_norm, std::placeholders::_1)); std::bind(op::batch_norm, std::placeholders::_1));
m_map.emplace("Constant", std::bind(op::constant, std::placeholders::_1)); m_map.emplace("Constant", std::bind(op::constant, std::placeholders::_1));
m_map.emplace("Conv", std::bind(op::conv, std::placeholders::_1));
m_map.emplace("Relu", std::bind(op::relu, std::placeholders::_1)); m_map.emplace("Relu", std::bind(op::relu, std::placeholders::_1));
m_map.emplace("Split", std::bind(op::split, std::placeholders::_1)); m_map.emplace("Split", std::bind(op::split, std::placeholders::_1));
} }
......
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <numeric>
#include <vector>
#include "broadcasting.hpp"
namespace ngraph
{
namespace onnx_import
{
AxisSet calculate_broadcast_axes(const Shape& output_shape,
const Shape& input_shape,
std::size_t start_match_axis)
{
std::vector<size_t> result(output_shape.size() - input_shape.size());
// Populate the result vector with monotonic increasing series from 0 until
// output_shape_size, excluding values in range [start_match_axis, start_match_axis + input_shape.size()
std::iota(std::begin(result), std::begin(result) + start_match_axis, 0);
std::iota(std::begin(result) + start_match_axis,
std::end(result),
start_match_axis + input_shape.size());
return result;
}
} // namespace onnx_import
} // namespace ngraph
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/axis_set.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
{
namespace onnx_import
{
/**
* @brief Generate a list of broadcast axes.
*
* @details Informally, a broadcast "adds" axes to the input tensor, replicating
* elements from the input tensor as needed to fill the new dimensions.
* Function calculate which of the output axes are added in this way.
*
* @param output_shape The new shape for the output tensor.
* @param input_shape The shape of input tensor.
* @param start_match_axis The axis along which we want to replicate elements.
* The starting axis position (0-based) int the output
* shape from which the current shape of the tensor
* matches the desired new shape.
*
* @return The indices of added axes.
*/
AxisSet calculate_broadcast_axes(const Shape& output_shape,
const Shape& input_shape,
std::size_t start_match_axis);
/**
* @brief Generate a list of broadcast along axes.
*
* @details Broadcast "adds" elements along axes to the input tensor, replicating
* elements from the input tensor as needed to fill the new dimensions.
* Function calculate which of the output axes are added in this way.
*
* This function will attempt to match shapes, assuming the current shape
* matches the rightmost positions of the desired new shape. This behaviour
* is similar to NumPy's broadcasting.
*
* @param output_shape The new shape for the output tensor.
* @param input_shape The shape of input tensor.
*
* @return The indices of added axes.
*/
inline AxisSet calculate_broadcast_axes(const Shape& output_shape, const Shape& input_shape)
{
return calculate_broadcast_axes(
output_shape, input_shape, output_shape.size() - input_shape.size());
}
} // namespace onnx_import
} // namespace ngraph
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "convpool.hpp"
#include <cmath>
namespace ngraph
{
namespace onnx_import
{
namespace attribute
{
namespace
{
CoordinateDiff get_auto_pads(const Shape& kernel_shape, const std::string& auto_pad)
{
CoordinateDiff pads;
// Add padding to the input to match the size of output size.
auto pad_value = [](size_t dim) {
return (static_cast<float>(dim) - 1.f) / 2.f;
};
if (auto_pad == "SAME_UPPER")
{
for (size_t dim : kernel_shape)
{
pads.emplace_back(std::floor(pad_value(dim)));
}
for (size_t dim : kernel_shape)
{
pads.emplace_back(std::ceil(pad_value(dim)));
}
}
else if (auto_pad == "SAME_LOWER")
{
for (size_t dim : kernel_shape)
{
pads.emplace_back(std::ceil(pad_value(dim)));
}
for (size_t dim : kernel_shape)
{
pads.emplace_back(std::floor(pad_value(dim)));
}
}
return pads;
}
} // namespace
std::pair<CoordinateDiff, CoordinateDiff> get_pads(const Node& node,
const Shape& kernel_shape)
{
CoordinateDiff pads;
try
{
auto pads_int64 = node.get_attribute_value<std::vector<int64_t>>("pads");
pads = CoordinateDiff{std::begin(pads_int64), std::end(pads_int64)};
}
catch (const error::node::UnknownAttribute&)
{
std::string auto_pad{node.get_attribute_value<std::string>("auto_pad", "")};
if (!auto_pad.empty())
{
pads = get_auto_pads(kernel_shape, auto_pad);
}
}
if (pads.empty())
{
pads = {static_cast<std::ptrdiff_t>(kernel_shape.size()), 0UL};
}
if (pads.size() <= 3)
{
// Paddings specified in (H, W, C) format.
return {pads, pads};
}
else
{
return {{std::begin(pads) + pads.size() / 2, std::end(pads)},
{std::begin(pads), std::begin(pads) + pads.size() / 2}};
}
}
} // namespace attribute
} // namespace onnx_import
} // namespace ngraph
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/coordinate_diff.hpp"
#include "ngraph/frontend/onnx_import/attribute.hpp"
#include "ngraph/frontend/onnx_import/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace attribute
{
/**
* @brief Get padding values for the operation described by an ONNX node.
* @details If `auto_pad` attribute is specified as SAME_UPPER or SAME_LOWER, or VALID
* values are calculated. Otherwise values are taken from the `pads` attribute.
*
* `pads` value should follow [x1_begin, x2_begin..., x1_end, x2_end,...].
*
* @param node The Node ptr representing ONNX operation.
* @param kernel_shape The shape of the kernel which we retrieve pads for.
*
* @return A pair of (padding_above, padding_below), which elements contains number of
* pixels to pad in respective dimensions (height, width, depth).
*/
std::pair<CoordinateDiff, CoordinateDiff> get_pads(const Node& node,
const Shape& kernel_shape);
/**
* @brief Get padding values for the operation described by an ONNX node.
* @details If `auto_pad` attribute is specified as SAME_UPPER or SAME_LOWER, or VALID
* values are calculated. Otherwise values are taken from the `pads` attribute.
*
* `pads` value should follow [x1_begin, x2_begin..., x1_end, x2_end,...].
*
* @param node The Node ptr representing ONNX operation.
*
* @return A pair of (padding_above, padding_below), which elements contains number of
* pixels to pad in respective dimensions (height, width, depth).
*/
inline std::pair<CoordinateDiff, CoordinateDiff> get_pads(const Node& node)
{
return get_pads(node, get_kernel_shape(node));
}
} // namespace attribute
} // namespace onnx_import
} // namespace ngraph
 backend-test:
K
x
Wy"Conv*
kernel_shape@@*
pads@@@@*
strides@@test_conv_with_strides_paddingZ
x




Z
W




b
y




B
\ No newline at end of file
ONNXNgraphImporter:
N
A
B
CD"Conv*
kernel_shape@@*
pads@@@@*
strides@@ compute_graphZ
A




Z
B




Z
C

b
D




B
\ No newline at end of file
...@@ -104,6 +104,79 @@ TEST(onnx, model_split_variable_parts_2d) ...@@ -104,6 +104,79 @@ TEST(onnx, model_split_variable_parts_2d)
} }
} }
namespace
{
std::vector<std::vector<float>>
conv2d_execute(const std::shared_ptr<ngraph::Function>& function)
{
std::vector<std::vector<float>> args;
// data (1, 1, 7, 5) input tensor
args.emplace_back(ngraph::test::NDArray<float, 4>{{{{{0.f, 1.f, 2.f, 3.f, 4.f},
{5.f, 6.f, 7.f, 8.f, 9.f},
{10.f, 11.f, 12.f, 13.f, 14.f},
{15.f, 16.f, 17.f, 18.f, 19.f},
{20.f, 21.f, 22.f, 23.f, 24.f},
{25.f, 26.f, 27.f, 28.f, 29.f},
{30.f, 31.f, 32.f, 33.f, 34.f}}}}}
.get_vector());
// filters (1, 1, 3, 3) aka convolution weights
args.emplace_back(
ngraph::test::NDArray<float, 4>{{{{{1.f, 1.f, 1.f}, {1.f, 1.f, 1.f}, {1.f, 1.f, 1.f}}}}}
.get_vector());
return execute(function, args, "INTERPRETER");
}
} // namespace
TEST(onnx, mode_conv2d_strides_padding)
{
// Convolution with strides=2 and padding=1
auto function{ngraph::onnx_import::import_onnx_function(
ngraph::file_util::path_join(SERIALIZED_ZOO, "onnx/conv_with_strides_padding.onnx"))};
// (1, 1, 4, 3)
auto expected_output = ngraph::test::NDArray<float, 4>({{{{12.f, 27.f, 24.f},
{63.f, 108.f, 81.f},
{123.f, 198.f, 141.f},
{112.f, 177.f, 124.f}}}})
.get_vector();
auto result{conv2d_execute(function)};
EXPECT_EQ(expected_output, result.front());
}
TEST(onnx, model_conv2d_strides_no_padding)
{
// Convolution with strides=2 and padding=1
auto function{ngraph::onnx_import::import_onnx_function(
ngraph::file_util::path_join(SERIALIZED_ZOO, "onnx/conv_with_strides_no_padding.onnx"))};
// (1, 1, 3, 2)
auto expected_output =
ngraph::test::NDArray<float, 4>({{{{54.f, 72.f}, {144.f, 162.f}, {234.f, 252.f}}}})
.get_vector();
auto result{conv2d_execute(function)};
EXPECT_EQ(expected_output, result.front());
}
TEST(onnx, model_conv2d_strides_assymetric_padding)
{
// Convolution with strides=2 and padding=1
auto function{ngraph::onnx_import::import_onnx_function(ngraph::file_util::path_join(
SERIALIZED_ZOO, "onnx/conv_with_strides_and_assymmetric_padding.onnx"))};
// (1, 1, 4, 2)
auto expected_output = ngraph::test::NDArray<float, 4>(
{{{{21.f, 33.f}, {99.f, 117.f}, {189.f, 207.f}, {171.f, 183.f}}}})
.get_vector();
auto result{conv2d_execute(function)};
EXPECT_EQ(expected_output, result.front());
}
TEST(onnx, model_batchnorm_default) TEST(onnx, model_batchnorm_default)
{ {
// Batch Normalization with default parameters // Batch Normalization with default parameters
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment