Commit d4982dd1 authored by Adam Rogowiec's avatar Adam Rogowiec Committed by Sang Ik Lee

[ONNX] Lp/InstanceNormalization operations. (#3087)

* Add LpNormalization operator along with unit tests.

* Add validation macro based on NGRAPH_CHECK.

* Add InstanceNormalization operation along with unit tests.

* Update supported ops table.

* Fix merge error.
parent e8e3db24
...@@ -37,6 +37,7 @@ add_library(onnx_import STATIC ...@@ -37,6 +37,7 @@ add_library(onnx_import STATIC
core/operator_set.hpp core/operator_set.hpp
core/tensor.hpp core/tensor.hpp
core/value_info.hpp core/value_info.hpp
exceptions.cpp
exceptions.hpp exceptions.hpp
op/acos.hpp op/acos.hpp
op/acosh.cpp op/acosh.cpp
...@@ -103,11 +104,15 @@ add_library(onnx_import STATIC ...@@ -103,11 +104,15 @@ add_library(onnx_import STATIC
op/hardmax.cpp op/hardmax.cpp
op/hardmax.hpp op/hardmax.hpp
op/identity.hpp op/identity.hpp
op/instance_norm.cpp
op/instance_norm.hpp
op/leaky_relu.cpp op/leaky_relu.cpp
op/leaky_relu.hpp op/leaky_relu.hpp
op/less.hpp op/less.hpp
op/log.hpp op/log.hpp
op/log_softmax.hpp op/log_softmax.hpp
op/lp_norm.cpp
op/lp_norm.hpp
op/lp_pool.cpp op/lp_pool.cpp
op/lp_pool.hpp op/lp_pool.hpp
op/lrn.cpp op/lrn.cpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <sstream>
#include "exceptions.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace error
{
namespace detail
{
std::string get_error_msg_prefix(const Node& node)
{
std::stringstream ss;
ss << "While validating ONNX node '" << node << "'";
return ss.str();
}
}
}
}
}
...@@ -16,7 +16,11 @@ ...@@ -16,7 +16,11 @@
#pragma once #pragma once
#include <string>
#include "core/node.hpp"
#include "ngraph/assertion.hpp" #include "ngraph/assertion.hpp"
#include "ngraph/check.hpp"
#include "ngraph/except.hpp" #include "ngraph/except.hpp"
namespace ngraph namespace ngraph
...@@ -25,6 +29,11 @@ namespace ngraph ...@@ -25,6 +29,11 @@ namespace ngraph
{ {
namespace error namespace error
{ {
namespace detail
{
std::string get_error_msg_prefix(const Node& node);
}
struct NotSupported : AssertionFailure struct NotSupported : AssertionFailure
{ {
explicit NotSupported(const std::string& what_arg) explicit NotSupported(const std::string& what_arg)
...@@ -41,6 +50,17 @@ namespace ngraph ...@@ -41,6 +50,17 @@ namespace ngraph
} }
}; };
class NodeValidationFailure : public CheckFailure
{
public:
NodeValidationFailure(const CheckLocInfo& check_loc_info,
const Node& node,
const std::string& explanation)
: CheckFailure(check_loc_info, detail::get_error_msg_prefix(node), explanation)
{
}
};
} // namespace error } // namespace error
} // namespace onnx_import } // namespace onnx_import
...@@ -54,3 +74,7 @@ namespace ngraph ...@@ -54,3 +74,7 @@ namespace ngraph
NGRAPH_ASSERT_STREAM_DO_NOT_USE_IN_NEW_CODE(ngraph::onnx_import::error::InvalidArgument, \ NGRAPH_ASSERT_STREAM_DO_NOT_USE_IN_NEW_CODE(ngraph::onnx_import::error::InvalidArgument, \
cond_) \ cond_) \
<< (node_) << " " << (node_) << " "
#define CHECK_VALID_NODE(node_, cond_, ...) \
NGRAPH_CHECK_HELPER( \
::ngraph::onnx_import::error::NodeValidationFailure, (node_), (cond_), ##__VA_ARGS__)
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <cmath>
#include <cstddef>
#include "exceptions.hpp"
#include "instance_norm.hpp"
#include "ngraph/axis_set.hpp"
#include "ngraph/builder/reduce_ops.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/sqrt.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "utils/common.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector instance_norm(const Node& node)
{
const std::shared_ptr<ngraph::Node> data{node.get_ng_inputs().at(0)};
std::shared_ptr<ngraph::Node> scale{node.get_ng_inputs().at(1)};
std::shared_ptr<ngraph::Node> bias{node.get_ng_inputs().at(2)};
const float epsilon{node.get_attribute_value<float>("epsilon", 1e-5f)};
CHECK_VALID_NODE(node,
(scale->get_shape().size() == 1 &&
scale->get_shape()[0] == data->get_shape().at(1)),
"Scale input must be one dimensional vector of number of "
"input data channels size.");
CHECK_VALID_NODE(node,
(bias->get_shape().size() == 1 &&
bias->get_shape()[0] == data->get_shape().at(1)),
"Bias input must be one dimensional vector of number of "
"input data channels size.");
// all dimensions except spatial/feature
const AxisSet reduction_axes{
common::get_monotonic_range<std::size_t>(data->get_shape().size(), 2)};
const std::shared_ptr<ngraph::Node> eps_node =
std::make_shared<ngraph::op::Constant>(data->get_element_type(),
data->get_shape(),
std::vector<float>{epsilon});
scale = ngraph::op::legacy_style_broadcast_for_binary_operation(data, scale, 1)
.at(1);
bias = ngraph::op::legacy_style_broadcast_for_binary_operation(data, bias, 1)
.at(1);
std::shared_ptr<ngraph::Node> mean = builder::mean(data, reduction_axes);
mean = std::make_shared<ngraph::op::Broadcast>(
mean, data->get_shape(), reduction_axes);
std::shared_ptr<ngraph::Node> variance =
builder::variance(data, reduction_axes);
variance = std::make_shared<ngraph::op::Broadcast>(
variance, data->get_shape(), reduction_axes);
const auto sqrt = std::make_shared<ngraph::op::Sqrt>(variance + eps_node);
return {scale * (data - mean) / sqrt + bias};
}
} // namespace set_1
} //namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "core/node.hpp"
#include "ngraph/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
/// \brief Creates nGraph node representing ONNX InstanceNormalization operator.
///
/// \note The resulting node represents following equation:
/// y = scale * (x - mean) / sqrt(variance + epsilon) + B
/// where mean and variance are computed per instance per channel.
///
/// \param[in] node The input ONNX node representing this operation.
///
/// \return Vector of nodes containting resulting nGraph nodes.
///
NodeVector instance_norm(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <cmath>
#include <cstddef>
#include <cstdint>
#include "exceptions.hpp"
#include "lp_norm.hpp"
#include "ngraph/axis_set.hpp"
#include "ngraph/builder/norm.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/divide.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector lp_norm(const Node& node)
{
const std::shared_ptr<ngraph::Node> data{node.get_ng_inputs().at(0)};
std::int64_t axis{node.get_attribute_value<std::int64_t>("axis", -1)};
const std::int64_t p_norm{node.get_attribute_value<std::int64_t>("p", 2)};
if (axis < 0)
{
axis += data->get_shape().size();
}
ASSERT_VALID_ARGUMENT(node, p_norm == 1 || p_norm == 2)
<< "Invalid `p` attribute value: " << p_norm
<< "Only normalization of 1st or 2nd order is supported.";
const AxisSet reduction_axes{static_cast<std::size_t>(axis)};
std::shared_ptr<ngraph::Node> norm = ngraph::builder::lp_norm(
data, reduction_axes, static_cast<std::size_t>(p_norm));
norm = std::make_shared<ngraph::op::Broadcast>(
norm, data->get_shape(), reduction_axes);
return {data / norm};
}
} // namespace set_1
} //namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "core/node.hpp"
#include "ngraph/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
/// \brief Creates nGraph node representing ONNX LpNormalization operator.
///
/// Suppose A contains spatial dimensions of input tensor, then
/// for matrix A we have p-norm defined as following double sum over
/// all elements:
/// ||A||_p = ||vec(A)||_p = [sum_{i=1}^m sum_{j=1}^n abs(a_{i,j})^p]^{1/p}
///
/// \param[in] node The input ONNX node representing this operation.
///
/// \return Vector of nodes containting resulting nGraph nodes.
///
NodeVector lp_norm(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import
} // namespace ngraph
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <cmath> #include <cmath>
#include <cstddef> #include <cstddef>
#include <cstdint> #include <cstdint>
#include <memory>
#include "exceptions.hpp" #include "exceptions.hpp"
#include "lp_pool.hpp" #include "lp_pool.hpp"
......
...@@ -16,8 +16,6 @@ ...@@ -16,8 +16,6 @@
#pragma once #pragma once
#include <memory>
#include "core/node.hpp" #include "core/node.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
......
...@@ -52,12 +52,14 @@ opset versions starting from `1` to `6` and to the latest opset version. ...@@ -52,12 +52,14 @@ opset versions starting from `1` to `6` and to the latest opset version.
| GlobalMaxPool | 1- | | GlobalMaxPool | 1- |
| Greater | 1-7-9 | | Greater | 1-7-9 |
| HardSigmoid | 1-6- | | HardSigmoid | 1-6- |
| Identity | 1- | | Identity | 1- |
| InstanceNormalization | 1- |
| LRN | 1- | | LRN | 1- |
| LeakyRelu | 1-6- | | LeakyRelu | 1-6- |
| Less | 1-7-9 | | Less | 1-7-9 |
| Log | 1-6- | | Log | 1-6- |
| LogSoftmax | 1- | | LogSoftmax | 1- |
| LpNormalization | 1- |
| MatMul | 1-9 | | MatMul | 1-9 |
| Max | 1-6-8- | | Max | 1-6-8- |
| MaxPool | 1-8- | | MaxPool | 1-8- |
...@@ -153,8 +155,4 @@ opset versions starting from `1` to `6` and to the latest opset version. ...@@ -153,8 +155,4 @@ opset versions starting from `1` to `6` and to the latest opset version.
|------|-----------------|--------|--------|---------| |------|-----------------|--------|--------|---------|
| Add, Sub, Mul, Div | 1-6 | | | We currently don't support legacy broadcasting rules for binary ops. | | Add, Sub, Mul, Div | 1-6 | | | We currently don't support legacy broadcasting rules for binary ops. |
| Cast | 1-6- | | 427 | Errors while casting to bool | | Cast | 1-6- | | 427 | Errors while casting to bool |
| EyeLike | (9) | | 439 | Make constant node. |
| Hardmax | - | | 431 | Use make constant and Argmax. See `test_ops_unary.py::test_hardmax()` | | Hardmax | - | | 431 | Use make constant and Argmax. See `test_ops_unary.py::test_hardmax()` |
| LpNormalization | - | | 436 | Just an equation. Only Lp{1,2} need to be supported. |
| InstanceNormalization | - | | 436 | Just an equation. For per channel computation may _slice/op/concat_ pattern need to be used. |
| Shrink | (9) | | 449 | Just an easy equation. |
...@@ -64,10 +64,12 @@ ...@@ -64,10 +64,12 @@
#include "op/hard_sigmoid.hpp" #include "op/hard_sigmoid.hpp"
#include "op/hardmax.hpp" #include "op/hardmax.hpp"
#include "op/identity.hpp" #include "op/identity.hpp"
#include "op/instance_norm.hpp"
#include "op/leaky_relu.hpp" #include "op/leaky_relu.hpp"
#include "op/less.hpp" #include "op/less.hpp"
#include "op/log.hpp" #include "op/log.hpp"
#include "op/log_softmax.hpp" #include "op/log_softmax.hpp"
#include "op/lp_norm.hpp"
#include "op/lp_pool.hpp" #include "op/lp_pool.hpp"
#include "op/lrn.hpp" #include "op/lrn.hpp"
#include "op/lstm.hpp" #include "op/lstm.hpp"
...@@ -273,10 +275,12 @@ namespace ngraph ...@@ -273,10 +275,12 @@ namespace ngraph
REGISTER_OPERATOR("Hardmax", 1, hardmax); REGISTER_OPERATOR("Hardmax", 1, hardmax);
REGISTER_OPERATOR("HardSigmoid", 1, hard_sigmoid); REGISTER_OPERATOR("HardSigmoid", 1, hard_sigmoid);
REGISTER_OPERATOR("Identity", 1, identity); REGISTER_OPERATOR("Identity", 1, identity);
REGISTER_OPERATOR("InstanceNormalization", 1, instance_norm);
REGISTER_OPERATOR("LeakyRelu", 1, leaky_relu); REGISTER_OPERATOR("LeakyRelu", 1, leaky_relu);
REGISTER_OPERATOR("Less", 1, less); REGISTER_OPERATOR("Less", 1, less);
REGISTER_OPERATOR("Log", 1, log); REGISTER_OPERATOR("Log", 1, log);
REGISTER_OPERATOR("LogSoftmax", 1, log_softmax); REGISTER_OPERATOR("LogSoftmax", 1, log_softmax);
REGISTER_OPERATOR("LpNormalization", 1, lp_norm);
REGISTER_OPERATOR("LRN", 1, lrn); REGISTER_OPERATOR("LRN", 1, lrn);
REGISTER_OPERATOR("LSTM", 1, lstm); REGISTER_OPERATOR("LSTM", 1, lstm);
REGISTER_OPERATOR("MatMul", 1, matmul); REGISTER_OPERATOR("MatMul", 1, matmul);
......
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "x"
input: "scale"
input: "bias"
output: "y"
op_type: "InstanceNormalization"
attribute {
name: "epsilon"
f: 0.01
type: FLOAT
}
}
name: "instance_norm_graph"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 3
}
dim {
dim_value: 4
}
}
}
}
}
input {
name: "scale"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
}
}
}
}
input {
name: "bias"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 3
}
dim {
dim_value: 4
}
}
}
}
}
}
opset_import {
version: 1
}
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "x"
output: "y"
op_type: "LpNormalization"
}
name: "lp_norm_graph"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 3
}
dim {
dim_value: 4
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 3
}
dim {
dim_value: 4
}
}
}
}
}
}
opset_import {
version: 1
}
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "x"
output: "y"
op_type: "LpNormalization"
attribute {
name: "axis"
i: 0
type: INT
}
attribute {
name: "p"
i: 1
type: INT
}
}
name: "lp_norm_graph"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 3
}
dim {
dim_value: 4
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 3
}
dim {
dim_value: 4
}
}
}
}
}
}
opset_import {
version: 1
}
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "x"
output: "y"
op_type: "LpNormalization"
attribute {
name: "axis"
i: 0
type: INT
}
attribute {
name: "p"
i: 2
type: INT
}
}
name: "lp_norm_graph"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 3
}
dim {
dim_value: 4
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 3
}
dim {
dim_value: 4
}
}
}
}
}
}
opset_import {
version: 1
}
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include <fstream> #include <fstream>
#include <iterator> #include <iterator>
#include <limits> #include <limits>
#include <numeric>
#include <sstream> #include <sstream>
#include <stdexcept> #include <stdexcept>
#include <vector> #include <vector>
...@@ -1483,45 +1484,96 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_shrink_int) ...@@ -1483,45 +1484,96 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_shrink_int)
test_case.run(); test_case.run();
} }
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lp_norm_p1)
{
const auto lp_norm_fn = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/lp_norm_p1.prototxt"));
Shape data_shape{2, 3, 4};
std::vector<float> data(shape_size(data_shape));
std::iota(std::begin(data), std::end(data), 1);
auto test_case = ngraph::test::NgraphTestCase(lp_norm_fn, "${BACKEND_NAME}");
test_case.add_input<float>(data);
test_case.add_expected_output<float>(
data_shape, {0.07142857f, 0.125f, 0.16666667f, 0.2f, 0.22727273f, 0.25f,
0.26923078f, 0.2857143f, 0.3f, 0.3125f, 0.32352942f, 0.33333334f,
0.9285714f, 0.875f, 0.8333333f, 0.8f, 0.77272725f, 0.75f,
0.7307692f, 0.71428573f, 0.7f, 0.6875f, 0.6764706f, 0.6666667f});
test_case.run();
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lp_norm_p2)
{
const auto lp_norm_fn = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/lp_norm_p2.prototxt"));
Shape data_shape{2, 3, 4};
std::vector<float> data(shape_size(data_shape));
std::iota(std::begin(data), std::end(data), 1);
auto test_case = ngraph::test::NgraphTestCase(lp_norm_fn, "${BACKEND_NAME}");
test_case.add_input<float>(data);
test_case.add_expected_output<float>(
data_shape, {0.0766965f, 0.14142136f, 0.19611613f, 0.24253564f, 0.28216633f, 0.31622776f,
0.34570536f, 0.37139067f, 0.39391932f, 0.41380295f, 0.4314555f, 0.4472136f,
0.9970545f, 0.98994946f, 0.9805807f, 0.97014254f, 0.9593655f, 0.9486833f,
0.9383431f, 0.9284767f, 0.91914505f, 0.9103665f, 0.9021342f, 0.8944272f});
test_case.run();
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lp_norm_default)
{
const auto lp_norm_fn = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/lp_norm_default.prototxt"));
Shape data_shape{2, 3, 4};
std::vector<float> data(shape_size(data_shape));
std::iota(std::begin(data), std::end(data), 1);
auto test_case = ngraph::test::NgraphTestCase(lp_norm_fn, "${BACKEND_NAME}");
test_case.add_input<float>(data);
test_case.add_expected_output<float>(
data_shape, {0.18257418f, 0.36514837f, 0.5477225f, 0.73029673f, 0.37904903f, 0.45485884f,
0.5306686f, 0.60647845f, 0.42616236f, 0.47351375f, 0.5208651f, 0.5682165f,
0.4469492f, 0.48132992f, 0.51571065f, 0.5500913f, 0.45862272f, 0.48560053f,
0.5125783f, 0.53955615f, 0.46609157f, 0.4882864f, 0.51048124f, 0.5326761f});
test_case.run();
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_instance_normalization)
{
const auto instance_norm_fn = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/instance_norm.prototxt"));
Shape data_shape{1, 2, 3, 4};
std::vector<float> data(shape_size(data_shape));
std::iota(std::begin(data), std::end(data), 1);
auto test_case = ngraph::test::NgraphTestCase(instance_norm_fn, "${BACKEND_NAME}");
test_case.add_input<float>(data);
test_case.add_input<float>(std::vector<float>{2.134f, 3.256f});
test_case.add_input<float>(std::vector<float>{0.765f, 1.055f});
test_case.add_expected_output<float>(
data_shape, {-2.6335807f, -2.015657f, -1.3977331f, -0.77980936f, -0.16188562f, 0.45603812f,
1.0739619f, 1.6918856f, 2.3098092f, 2.927733f, 3.5456567f, 4.1635804f,
-4.130463f, -3.1876516f, -2.2448401f, -1.3020288f, -0.35921717f, 0.5835942f,
1.5264057f, 2.469217f, 3.4120288f, 4.35484f, 5.2976513f, 6.240463f});
test_case.run();
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_eye_like) NGRAPH_TEST(onnx_${BACKEND_NAME}, model_eye_like)
{ {
const auto eye_like_fn = onnx_import::import_onnx_model( const auto eye_like_fn = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/eye_like.prototxt")); file_util::path_join(SERIALIZED_ZOO, "onnx/eye_like.prototxt"));
auto test_case = ngraph::test::NgraphTestCase(eye_like_fn, "${BACKEND_NAME}"); auto test_case = ngraph::test::NgraphTestCase(eye_like_fn, "${BACKEND_NAME}");
test_case.add_input<float>({ test_case.add_input<float>({0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f});
0.f, test_case.add_expected_output<float>(
0.f, Shape{3, 4}, {0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f});
0.f,
0.f,
0.f,
0.f,
0.f,
0.f,
0.f,
0.f,
0.f,
0.f,
});
test_case.add_expected_output<float>(Shape{3, 4},
{
0.f,
0.f,
0.f,
0.f,
1.f,
0.f,
0.f,
0.f,
0.f,
1.f,
0.f,
0.f,
});
test_case.run(); test_case.run();
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment