Unverified Commit 0804f5e2 authored by Michał Karzyński's avatar Michał Karzyński Committed by GitHub

[ONNX] Code cleanup (#1476)

* Move batch_norm implementation to a .cpp file
* Move split implementation to a .cpp file
parent 04230095
...@@ -43,11 +43,13 @@ add_library(onnx_import STATIC ...@@ -43,11 +43,13 @@ add_library(onnx_import STATIC
node.cpp node.cpp
node.hpp node.hpp
op/add.hpp op/add.hpp
op/batch_norm.cpp
op/batch_norm.hpp op/batch_norm.hpp
op/constant.cpp op/constant.cpp
op/constant.hpp op/constant.hpp
op/conv.cpp op/conv.cpp
op/relu.hpp op/relu.hpp
op/split.cpp
op/split.hpp op/split.hpp
ops_bridge.cpp ops_bridge.cpp
tensor.hpp tensor.hpp
......
...@@ -51,4 +51,4 @@ namespace ngraph ...@@ -51,4 +51,4 @@ namespace ngraph
} // namespace onnx_import } // namespace onnx_import
} // namespace ngraph } // namespace ngraph
\ No newline at end of file
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <memory>
#include "ngraph/node_vector.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/frontend/onnx_import/exceptions.hpp"
#include "ngraph/frontend/onnx_import/node.hpp"
#include "ngraph/frontend/onnx_import/op/batch_norm.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
NodeVector batch_norm(const Node& node)
{
NodeVector inputs{node.get_ng_inputs()};
auto x = inputs.at(0);
auto scale = inputs.at(1);
auto bias = inputs.at(2);
std::shared_ptr<ngraph::Node> mean{nullptr};
std::shared_ptr<ngraph::Node> var{nullptr};
int is_test{node.get_attribute_value<int>("is_test", 1)};
int spatial{node.get_attribute_value<int>("spatial", 1)};
double epsilon{node.get_attribute_value<double>("epsilon", 1e-5)};
// TODO: Implement learning mode support
// float momentum{node.get_attribute_value<float>("momentum", 0.9f)};
bool training = false;
if (!is_test)
{
throw error::NotSupported("BatchNormalization",
node.get_name(),
"only 'is_test' mode is currently supported.");
}
if (!spatial)
{
throw error::NotSupported("BatchNormalization",
node.get_name(),
"only 'spatial' mode is currently supported.");
}
if (inputs.size() >= 5)
{
mean = inputs.at(3);
var = inputs.at(4);
return {std::make_shared<ngraph::op::BatchNorm>(
epsilon, scale, bias, x, mean, var, training)};
}
return {std::make_shared<ngraph::op::BatchNorm>(epsilon, scale, bias, x)};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
...@@ -16,13 +16,8 @@ ...@@ -16,13 +16,8 @@
#pragma once #pragma once
#include <memory>
#include "ngraph/frontend/onnx_import/exceptions.hpp"
#include "ngraph/frontend/onnx_import/node.hpp" #include "ngraph/frontend/onnx_import/node.hpp"
#include "ngraph/node_vector.hpp" #include "ngraph/node_vector.hpp"
#include "ngraph/op/batch_norm.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -30,48 +25,9 @@ namespace ngraph ...@@ -30,48 +25,9 @@ namespace ngraph
{ {
namespace op namespace op
{ {
inline NodeVector batch_norm(const Node& node) NodeVector batch_norm(const Node& node);
{
NodeVector inputs{node.get_ng_inputs()};
auto x = inputs.at(0);
auto scale = inputs.at(1);
auto bias = inputs.at(2);
std::shared_ptr<ngraph::Node> mean{nullptr};
std::shared_ptr<ngraph::Node> var{nullptr};
int is_test{node.get_attribute_value<int>("is_test", 1)};
int spatial{node.get_attribute_value<int>("spatial", 1)};
double epsilon{node.get_attribute_value<double>("epsilon", 1e-5)};
// TODO: Implement learning mode support
// float momentum{node.get_attribute_value<float>("momentum", 0.9f)};
bool training = false;
if (!is_test)
{
throw error::NotSupported("BatchNormalization",
node.get_name(),
"only 'is_test' mode is currently supported.");
}
if (!spatial)
{
throw error::NotSupported("BatchNormalization",
node.get_name(),
"only 'spatial' mode is currently supported.");
}
if (inputs.size() >= 5)
{
mean = inputs.at(3);
var = inputs.at(4);
return {std::make_shared<ngraph::op::BatchNorm>(
epsilon, scale, bias, x, mean, var, training)};
}
return {std::make_shared<ngraph::op::BatchNorm>(epsilon, scale, bias, x)};
}
} // namespace op } // namespace op
} // namespace onnx_import } // namespace onnx_import
} // namespace ngraph } // namespace ngraph
\ No newline at end of file
...@@ -14,16 +14,17 @@ ...@@ -14,16 +14,17 @@
* limitations under the License. * limitations under the License.
*******************************************************************************/ *******************************************************************************/
#include <cstddef> #include "ngraph/op/add.hpp"
#include <memory> #include "ngraph/op/broadcast.hpp"
#include <vector>
#include "ngraph/frontend/onnx_import/op/conv.hpp"
#include "ngraph/op/concat.hpp" #include "ngraph/op/concat.hpp"
#include "ngraph/op/convolution.hpp" #include "ngraph/op/convolution.hpp"
#include "ngraph/op/slice.hpp" #include "ngraph/op/slice.hpp"
#include "ngraph/frontend/onnx_import/exceptions.hpp"
#include "ngraph/frontend/onnx_import/op/conv.hpp"
#include "ngraph/frontend/onnx_import/utils/broadcasting.hpp"
#include "ngraph/frontend/onnx_import/utils/convpool.hpp"
namespace ngraph namespace ngraph
{ {
namespace onnx_import namespace onnx_import
......
...@@ -16,18 +16,8 @@ ...@@ -16,18 +16,8 @@
#pragma once #pragma once
#include "frontend/onnx_import/utils/broadcasting.hpp"
#include "frontend/onnx_import/utils/convpool.hpp"
#include "ngraph/frontend/onnx_import/exceptions.hpp"
#include "ngraph/frontend/onnx_import/node.hpp"
#include "ngraph/coordinate_diff.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/node_vector.hpp" #include "ngraph/node_vector.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/strides.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -35,18 +25,6 @@ namespace ngraph ...@@ -35,18 +25,6 @@ namespace ngraph
{ {
namespace op namespace op
{ {
namespace detail
{
std::shared_ptr<ngraph::op::Op>
make_ng_convolution(const std::shared_ptr<ngraph::Node>& data,
const std::shared_ptr<ngraph::Node>& filters,
const ngraph::Strides& strides,
const ngraph::Strides& dilations,
const ngraph::CoordinateDiff& padding_below,
const ngraph::CoordinateDiff& padding_above,
int groups);
}
/** /**
* @brief Performs ONNX Conv operation. * @brief Performs ONNX Conv operation.
* *
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/op/slice.hpp"
#include "ngraph/frontend/onnx_import/op/split.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace error
{
namespace op
{
namespace split
{
namespace detail
{
struct Error : ngraph_error
{
explicit Error(const std::string& name, const std::string& message)
: ngraph_error{"Split node (" + name + "): " + message}
{
}
};
}
struct OutOfRange : detail::Error
{
explicit OutOfRange(const std::string& name)
: Error{name,
"provided split axis is out of input tensor dimensions range."}
{
}
};
struct Parts : detail::Error
{
explicit Parts(const std::string& name,
std::size_t parts,
std::size_t axis_length)
: Error{name,
"tensor cannot be split into " + std::to_string(parts) +
" equal parts, along axis of length " +
std::to_string(axis_length)}
{
}
};
struct Sum : detail::Error
{
explicit Sum(const std::string& name, std::size_t parts, std::size_t axis)
: Error{name,
"provided lengths of split parts does not sum up to "
"length of axis we split on: " +
std::to_string(parts) + " != " + std::to_string(axis)}
{
}
};
} // namespace split
} // namespace op
} // namespace error
namespace op
{
namespace detail
{
template <typename T>
inline T get_valid_array_index(T left, T right)
{
return (left >= 0) ? std::min(left, right)
: std::max(static_cast<T>(0), right + left);
}
inline std::shared_ptr<ngraph::op::Slice>
make_ng_slice(const std::shared_ptr<ngraph::Node>& node,
std::vector<std::size_t> axes,
std::vector<std::size_t> starts,
std::vector<std::size_t> ends)
{
std::vector<std::size_t> upper_bounds{node->get_shape()};
std::vector<std::size_t> lower_bounds(upper_bounds.size());
for (std::size_t index{0}; index < axes.size(); ++index)
{
std::size_t axis{axes.at(index)};
lower_bounds.at(axis) =
get_valid_array_index(starts.at(index), node->get_shape().at(axis));
upper_bounds.at(axis) =
get_valid_array_index(ends.at(index), node->get_shape().at(axis));
}
return std::make_shared<ngraph::op::Slice>(node, lower_bounds, upper_bounds);
}
} // namespace detail
NodeVector split(const Node& node)
{
std::shared_ptr<ngraph::Node> input = node.get_ng_inputs().at(0);
std::size_t count_outputs{node.get_output_names().size()};
int64_t axis{node.get_attribute_value<int64_t>("axis", 0)};
std::size_t axis_to_split{static_cast<std::size_t>(axis)};
if (axis < 0)
{
axis_to_split = input->get_shape().size() + axis;
}
else if (axis_to_split >= input->get_shape().size())
{
throw error::op::split::OutOfRange{node.get_name()};
}
std::size_t length_axis_to_split{input->get_shape().at(axis_to_split)};
std::vector<std::size_t> length_parts;
try
{
length_parts = node.get_attribute_value<std::vector<std::size_t>>("split");
}
catch (const std::exception&)
{
if (length_axis_to_split % count_outputs)
{
throw error::op::split::Parts{
node.get_name(), count_outputs, length_axis_to_split};
}
length_parts.assign(count_outputs, length_axis_to_split / count_outputs);
}
std::size_t start_index{0};
NodeVector outputs;
for (const auto& length_part : length_parts)
{
std::size_t end_index{start_index + length_part};
outputs.push_back(
detail::make_ng_slice(input, {axis_to_split}, {start_index}, {end_index}));
start_index = end_index;
}
return outputs;
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
...@@ -16,144 +16,17 @@ ...@@ -16,144 +16,17 @@
#pragma once #pragma once
#include "ngraph/frontend/onnx_import/node.hpp"
#include "ngraph/node_vector.hpp" #include "ngraph/node_vector.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/frontend/onnx_import/node.hpp"
namespace ngraph namespace ngraph
{ {
namespace onnx_import namespace onnx_import
{ {
namespace error
{
namespace op
{
namespace split
{
namespace detail
{
struct Error : ngraph_error
{
explicit Error(const std::string& name, const std::string& message)
: ngraph_error{"Split node (" + name + "): " + message}
{
}
};
}
struct OutOfRange : detail::Error
{
explicit OutOfRange(const std::string& name)
: Error{name,
"provided split axis is out of input tensor dimensions range."}
{
}
};
struct Parts : detail::Error
{
explicit Parts(const std::string& name,
std::size_t parts,
std::size_t axis_length)
: Error{name,
"tensor cannot be split into " + std::to_string(parts) +
" equal parts, along axis of length " +
std::to_string(axis_length)}
{
}
};
struct Sum : detail::Error
{
explicit Sum(const std::string& name, std::size_t parts, std::size_t axis)
: Error{name,
"provided lengths of split parts does not sum up to "
"length of axis we split on: " +
std::to_string(parts) + " != " + std::to_string(axis)}
{
}
};
} // namespace split
} // namespace op
} // namespace error
namespace op namespace op
{ {
namespace detail NodeVector split(const Node& node);
{
template <typename T>
inline T get_valid_array_index(T left, T right)
{
return (left >= 0) ? std::min(left, right)
: std::max(static_cast<T>(0), right + left);
}
inline std::shared_ptr<ngraph::op::Slice>
make_ng_slice(const std::shared_ptr<ngraph::Node>& node,
std::vector<std::size_t> axes,
std::vector<std::size_t> starts,
std::vector<std::size_t> ends)
{
std::vector<std::size_t> upper_bounds{node->get_shape()};
std::vector<std::size_t> lower_bounds(upper_bounds.size());
for (std::size_t index{0}; index < axes.size(); ++index)
{
std::size_t axis{axes.at(index)};
lower_bounds.at(axis) =
get_valid_array_index(starts.at(index), node->get_shape().at(axis));
upper_bounds.at(axis) =
get_valid_array_index(ends.at(index), node->get_shape().at(axis));
}
return std::make_shared<ngraph::op::Slice>(node, lower_bounds, upper_bounds);
}
} // namespace detail
inline NodeVector split(const Node& node)
{
std::shared_ptr<ngraph::Node> input = node.get_ng_inputs().at(0);
std::size_t count_outputs{node.get_output_names().size()};
int64_t axis{node.get_attribute_value<int64_t>("axis", 0)};
std::size_t axis_to_split{static_cast<std::size_t>(axis)};
if (axis < 0)
{
axis_to_split = input->get_shape().size() + axis;
}
else if (axis_to_split >= input->get_shape().size())
{
throw error::op::split::OutOfRange{node.get_name()};
}
std::size_t length_axis_to_split{input->get_shape().at(axis_to_split)};
std::vector<std::size_t> length_parts;
try
{
length_parts = node.get_attribute_value<std::vector<std::size_t>>("split");
}
catch (const std::exception&)
{
if (length_axis_to_split % count_outputs)
{
throw error::op::split::Parts{
node.get_name(), count_outputs, length_axis_to_split};
}
length_parts.assign(count_outputs, length_axis_to_split / count_outputs);
}
std::size_t start_index{0};
NodeVector outputs;
for (const auto& length_part : length_parts)
{
std::size_t end_index{start_index + length_part};
outputs.push_back(
detail::make_ng_slice(input, {axis_to_split}, {start_index}, {end_index}));
start_index = end_index;
}
return outputs;
}
} // namespace op } // namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -166,7 +166,7 @@ TEST(onnx, model_conv2d_strides_assymetric_padding) ...@@ -166,7 +166,7 @@ TEST(onnx, model_conv2d_strides_assymetric_padding)
{ {
// Convolution with strides=2 and padding=1 // Convolution with strides=2 and padding=1
auto function{ngraph::onnx_import::import_onnx_function(ngraph::file_util::path_join( auto function{ngraph::onnx_import::import_onnx_function(ngraph::file_util::path_join(
SERIALIZED_ZOO, "onnx/conv_with_strides_and_assymmetric_padding.onnx"))}; SERIALIZED_ZOO, "onnx/conv_with_strides_and_asymmetric_padding.onnx"))};
// (1, 1, 4, 2) // (1, 1, 4, 2)
auto expected_output = ngraph::test::NDArray<float, 4>( auto expected_output = ngraph::test::NDArray<float, 4>(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment