Commit 12a0dca5 authored by Michał Karzyński's avatar Michał Karzyński Committed by Robert Kimball

[ONNX] Variadic ops (#1509)

* [ONNX] Sum op

* [ONNX] Generic variadic op template

* Add support for Min op

* clang-format

* Add support for Max op

* Add support for Mean op

* Docs, code cleanup

* Docs, code cleanup
parent 6f1664df
......@@ -46,6 +46,10 @@ add_library(onnx_import STATIC
op/matmul.hpp
op/max_pool.cpp
op/max_pool.hpp
op/max.hpp
op/mean.cpp
op/mean.hpp
op/min.hpp
op/mul.hpp
op/relu.hpp
op/softmax.cpp
......@@ -53,13 +57,15 @@ add_library(onnx_import STATIC
op/split.cpp
op/split.hpp
op/sub.hpp
op/sum.hpp
ops_bridge.cpp
utils/broadcasting.cpp
utils/broadcasting.hpp
utils/convpool.cpp
utils/convpool.hpp
utils/reshape.cpp
utils/reshape.hpp)
utils/reshape.hpp
utils/variadic.hpp)
add_dependencies(onnx_import onnx_import_interface)
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/node_vector.hpp"
#include "ngraph/op/maximum.hpp"
#include "core/node.hpp"
#include "utils/variadic.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
/// \brief Convert an ONNX Max operation into an nGraph subgraph.
///
/// \param node The ONNX node representing the variadic Max operation.
/// \return A NodeVector with a single node producing the element-wise
///         maximum of all the node's inputs.
inline NodeVector max(const Node& node)
{
    auto result = variadic::make_ng_variadic_op<ngraph::op::Maximum>(node);
    return result;
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/divide.hpp"
#include "mean.hpp"
#include "utils/variadic.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
/// \brief Convert an ONNX Mean operation into an nGraph subgraph.
///
/// Computes the element-wise sum of all inputs and divides it by the
/// number of inputs.
///
/// \param node The ONNX node representing the variadic Mean operation.
/// \return A NodeVector with a single node producing the element-wise mean.
NodeVector mean(const Node& node)
{
    const auto input_count = node.get_ng_inputs().size();
    auto total = variadic::make_ng_variadic_op<ngraph::op::Add>(node).front();
    auto total_shape = total->get_shape();

    // Build a divisor constant holding the input count, with the same
    // shape and element type as the computed sum.
    std::vector<int> divisor_values(shape_size(total_shape), input_count);
    auto divisor = ngraph::op::Constant::create(
        total->get_element_type(), total_shape, divisor_values);

    return {total / divisor};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/node_vector.hpp"
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
/// \brief Convert an ONNX Mean operation into an nGraph subgraph.
///
/// \param node The ONNX node representing the variadic Mean operation.
/// \return A NodeVector with a single node producing the element-wise mean
///         of all the node's inputs.
NodeVector mean(const Node& node);
} // namespace op
} // namespace onnx_import
} // namespace ngraph
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/node_vector.hpp"
#include "ngraph/op/minimum.hpp"
#include "core/node.hpp"
#include "utils/variadic.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
/// \brief Convert an ONNX Min operation into an nGraph subgraph.
///
/// \param node The ONNX node representing the variadic Min operation.
/// \return A NodeVector with a single node producing the element-wise
///         minimum of all the node's inputs.
inline NodeVector min(const Node& node)
{
    auto result = variadic::make_ng_variadic_op<ngraph::op::Minimum>(node);
    return result;
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/node_vector.hpp"
#include "ngraph/op/add.hpp"
#include "core/node.hpp"
#include "utils/variadic.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
/// \brief Convert an ONNX Sum operation into an nGraph subgraph.
///
/// \param node The ONNX node representing the variadic Sum operation.
/// \return A NodeVector with a single node producing the element-wise sum
///         of all the node's inputs.
inline NodeVector sum(const Node& node)
{
    auto result = variadic::make_ng_variadic_op<ngraph::op::Add>(node);
    return result;
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
......@@ -26,12 +26,16 @@
#include "op/div.hpp"
#include "op/gemm.hpp"
#include "op/matmul.hpp"
#include "op/max.hpp"
#include "op/max_pool.hpp"
#include "op/mean.hpp"
#include "op/min.hpp"
#include "op/mul.hpp"
#include "op/relu.hpp"
#include "op/softmax.hpp"
#include "op/split.hpp"
#include "op/sub.hpp"
#include "op/sum.hpp"
#include "ops_bridge.hpp"
namespace ngraph
......@@ -87,11 +91,15 @@ namespace ngraph
m_map.emplace("Gemm", std::bind(op::gemm, std::placeholders::_1));
m_map.emplace("MatMul", std::bind(op::matmul, std::placeholders::_1));
m_map.emplace("MaxPool", std::bind(op::max_pool, std::placeholders::_1));
m_map.emplace("Max", std::bind(op::max, std::placeholders::_1));
m_map.emplace("Mean", std::bind(op::mean, std::placeholders::_1));
m_map.emplace("Min", std::bind(op::min, std::placeholders::_1));
m_map.emplace("Mul", std::bind(op::mul, std::placeholders::_1));
m_map.emplace("Relu", std::bind(op::relu, std::placeholders::_1));
m_map.emplace("Softmax", std::bind(op::softmax, std::placeholders::_1));
m_map.emplace("Split", std::bind(op::split, std::placeholders::_1));
m_map.emplace("Sub", std::bind(op::sub, std::placeholders::_1));
m_map.emplace("Sum", std::bind(op::sum, std::placeholders::_1));
}
NodeVector operator()(const Node& node) const
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <numeric>
#include "ngraph/coordinate_diff.hpp"
#include "ngraph/node_vector.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/node_vector.hpp"
#include "ngraph/op/add.hpp"
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace variadic
{
/// \brief Create an nGraph version of an ONNX variadic operation.
///        The result is a left-to-right chain of binary operations:
///        ((in0 op in1) op in2) op ...
///
/// \tparam T Class of an nGraph binary operation (e.g. Add, Minimum, Maximum)
/// \param node incoming ONNX operation
/// \return nGraph node equivalent of the ONNX operation
template <class T>
inline NodeVector make_ng_variadic_op(const Node& node)
{
    const NodeVector ng_inputs{node.get_ng_inputs()};

    // Fold the inputs pairwise into a single result node. With exactly one
    // input the loop body never runs and the input is returned unchanged.
    std::shared_ptr<ngraph::Node> result{ng_inputs.front()};
    for (auto input_it = std::next(std::begin(ng_inputs));
         input_it != std::end(ng_inputs);
         ++input_it)
    {
        result = std::make_shared<T>(result, *input_it);
    }
    return {result};
}
} // namespace variadic
} // namespace onnx_import
} // namespace ngraph
 backend-test:‘
%
data_0
data_1
data_2result"Maxtest_max_exampleZ
data_0

Z
data_1

Z
data_2

b
result

B
\ No newline at end of file
 backend-test:
&
data_0
data_1
data_2result"Meantest_mean_exampleZ
data_0

Z
data_1

Z
data_2

b
result

B
\ No newline at end of file
 backend-test:v

data_0
data_1result"Mintest_min_two_inputsZ
data_0

Z
data_1

b
result

B
\ No newline at end of file
 backend-test:
%
data_0
data_1
data_2result"Sumtest_sum_exampleZ
data_0

Z
data_1

Z
data_2

b
result

B
\ No newline at end of file
 backend-test:W

data_0result"Sumtest_sum_one_inputZ
data_0

b
result

B
\ No newline at end of file
......@@ -307,6 +307,82 @@ TEST(onnx, model_relu)
EXPECT_TRUE(test::all_close_f(expected_outputs.front(), outputs.front()));
}
TEST(onnx, model_sum)
{
    // Sum of three inputs, each of shape (3,).
    auto function = ngraph::onnx_import::import_onnx_function(
        ngraph::file_util::path_join(SERIALIZED_ZOO, "onnx/sum.onnx"));

    Inputs inputs{{3.f, 0.f, 2.f}, {1.f, 3.f, 4.f}, {2.f, 6.f, 6.f}};
    Outputs expected_outputs{{6.f, 9.f, 12.f}};

    Outputs outputs{execute(function, inputs, "INTERPRETER")};
    EXPECT_TRUE(test::all_close_f(expected_outputs.front(), outputs.front()));
}
TEST(onnx, model_sum_one_input)
{
    // Sum with a single input of shape (3,) is the identity.
    auto function = ngraph::onnx_import::import_onnx_function(
        ngraph::file_util::path_join(SERIALIZED_ZOO, "onnx/sum_one_input.onnx"));

    Inputs inputs;
    inputs.emplace_back(std::vector<float>{3.f, 0.f, 2.f});
    Outputs expected_outputs{{3.f, 0.f, 2.f}};

    Outputs outputs{execute(function, inputs, "INTERPRETER")};
    EXPECT_TRUE(test::all_close_f(expected_outputs.front(), outputs.front()));
}
TEST(onnx, model_min_two_inputs)
{
    // Element-wise minimum of two inputs of shape (3,).
    auto function = ngraph::onnx_import::import_onnx_function(
        ngraph::file_util::path_join(SERIALIZED_ZOO, "onnx/min_two_inputs.onnx"));

    Inputs inputs{{1.f, 2.f, 1.f}, {1.f, 4.f, 4.f}};
    Outputs expected_outputs{{1.f, 2.f, 1.f}};

    Outputs outputs{execute(function, inputs, "INTERPRETER")};
    EXPECT_TRUE(test::all_close_f(expected_outputs.front(), outputs.front()));
}
TEST(onnx, model_max)
{
    // Element-wise maximum of three inputs of shape (3,).
    auto function = ngraph::onnx_import::import_onnx_function(
        ngraph::file_util::path_join(SERIALIZED_ZOO, "onnx/max.onnx"));

    Inputs inputs{{3.f, 2.f, 1.f}, {1.f, 4.f, 4.f}, {2.f, 5.f, 3.f}};
    Outputs expected_outputs{{3.f, 5.f, 4.f}};

    Outputs outputs{execute(function, inputs, "INTERPRETER")};
    EXPECT_TRUE(test::all_close_f(expected_outputs.front(), outputs.front()));
}
TEST(onnx, model_mean)
{
    // Element-wise mean of three inputs of shape (3,).
    auto function = ngraph::onnx_import::import_onnx_function(
        ngraph::file_util::path_join(SERIALIZED_ZOO, "onnx/mean.onnx"));

    Inputs inputs{{3.f, 0.f, 2.f}, {1.f, 3.f, 4.f}, {2.f, 6.f, 6.f}};
    Outputs expected_outputs{{2.f, 3.f, 4.f}};

    Outputs outputs{execute(function, inputs, "INTERPRETER")};
    EXPECT_TRUE(test::all_close_f(expected_outputs.front(), outputs.front()));
}
TEST(onnx, model_gemm_abc)
{
auto function = ngraph::onnx_import::import_onnx_function(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment