Commit 1fe02337 authored by tsocha's avatar tsocha Committed by Robert Kimball

[ONNX] Non-linear operators (#1580)

* [ONNX] Non-linear operators

* Review fix pt. 1

* Review fix pt. 2

* Non-linear tests

* style check

* Exception fix

* Test fix
parent 17af4266
......@@ -44,12 +44,16 @@ add_library(onnx_import STATIC
op/conv.cpp
op/conv.hpp
op/div.hpp
op/elu.cpp
op/elu.hpp
op/equal.hpp
op/flatten.cpp
op/flatten.hpp
op/gemm.cpp
op/gemm.hpp
op/greater.hpp
op/leaky_relu.cpp
op/leaky_relu.hpp
op/less.hpp
op/matmul.hpp
op/max_pool.cpp
......@@ -62,6 +66,8 @@ add_library(onnx_import STATIC
op/not.hpp
op/or.hpp
op/pow.hpp
op/prelu.cpp
op/prelu.hpp
op/reduce.cpp
op/reduce.hpp
op/relu.hpp
......@@ -69,12 +75,18 @@ add_library(onnx_import STATIC
op/reshape.hpp
op/shape.cpp
op/shape.hpp
op/selu.cpp
op/selu.hpp
op/sigmoid.hpp
op/softmax.cpp
op/softmax.hpp
op/split.cpp
op/split.hpp
op/sub.hpp
op/sum.hpp
op/tanh.hpp
op/thresholded_relu.cpp
op/thresholded_relu.hpp
op/unsqueeze.cpp
op/unsqueeze.hpp
op/xor.hpp
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <memory>
#include <vector>
#include "ngraph/node.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/exp.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/subtract.hpp"
#include "utils/broadcasting.hpp"
#include "elu.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
NodeVector elu(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
double alpha = node.get_attribute_value<double>("alpha", 1);
std::shared_ptr<ngraph::Node> alpha_node = std::make_shared<ngraph::op::Constant>(
data->get_element_type(), Shape{}, std::vector<double>{alpha});
alpha_node = make_broadcast_node(alpha_node, data->get_shape());
std::shared_ptr<ngraph::Node> zero_node = std::make_shared<ngraph::op::Constant>(
data->get_element_type(), Shape{}, std::vector<double>{0});
zero_node = make_broadcast_node(zero_node, data->get_shape());
return {std::make_shared<ngraph::op::Maximum>(data, zero_node) +
alpha_node * std::make_shared<ngraph::op::Exp>(
std::make_shared<ngraph::op::Minimum>(data, zero_node)) -
alpha_node};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node_vector.hpp"
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
NodeVector elu(const Node& node);
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <memory>
#include "ngraph/node.hpp"
#include "ngraph/node_vector.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/multiply.hpp"
#include "exceptions.hpp"
#include "core/node.hpp"
#include "utils/broadcasting.hpp"
#include "leaky_relu.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
NodeVector leaky_relu(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
double alpha = node.get_attribute_value<double>("alpha", 0.01);
ASSERT_VALID_ARGUMENT(node, ((alpha >= 0) && (alpha <= 1)))
<< " alpha value should be in range (0,1)";
std::shared_ptr<ngraph::Node> alpha_node = std::make_shared<ngraph::op::Constant>(
data->get_element_type(), Shape{}, std::vector<double>{alpha});
alpha_node = make_broadcast_node(alpha_node, data->get_shape());
return {std::make_shared<ngraph::op::Maximum>(data * alpha_node, data)};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node_vector.hpp"
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
NodeVector leaky_relu(const Node& node);
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <iterator>
#include <memory>
#include "ngraph/node.hpp"
#include "ngraph/node_vector.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/reshape.hpp"
#include "core/node.hpp"
#include "utils/broadcasting.hpp"
#include "prelu.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
NodeVector prelu(const Node& node)
{
NodeVector ng_inputs{node.get_ng_inputs()};
auto data = ng_inputs.at(0);
auto data_shape = data->get_shape();
std::shared_ptr<ngraph::Node> slope = ng_inputs.at(1);
auto slope_shape = slope->get_shape();
if ((slope_shape.size() == 1) && (slope_shape.at(0) != 1))
{
auto it =
std::find(std::begin(data_shape), std::end(data_shape), slope_shape.at(0));
auto index = std::distance(std::begin(data_shape), it);
slope = make_broadcast_node(slope, data->get_shape(), index);
}
else
{
auto params = numpy_style_broadcast_for_binary_operation(slope, data);
slope = params.at(0);
}
return {std::make_shared<ngraph::op::Maximum>(data * slope, data)};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node_vector.hpp"
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
NodeVector prelu(const Node& node);
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <memory>
#include <vector>
#include "ngraph/node.hpp"
#include "ngraph/node_vector.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/exp.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/subtract.hpp"
#include "core/node.hpp"
#include "utils/broadcasting.hpp"
#include "selu.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
NodeVector selu(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
double alpha = node.get_attribute_value<double>("alpha", 1.67326319217681884765625);
double gamma = node.get_attribute_value<double>("gamma", 1.05070102214813232421875);
std::shared_ptr<ngraph::Node> alpha_node = std::make_shared<ngraph::op::Constant>(
data->get_element_type(), ngraph::Shape{}, std::vector<double>{alpha});
alpha_node = make_broadcast_node(alpha_node, data->get_shape());
std::shared_ptr<ngraph::Node> gamma_node = std::make_shared<ngraph::op::Constant>(
data->get_element_type(), ngraph::Shape{}, std::vector<double>{gamma});
gamma_node = make_broadcast_node(gamma_node, data->get_shape());
std::shared_ptr<ngraph::Node> zero_node = std::make_shared<ngraph::op::Constant>(
data->get_element_type(), ngraph::Shape{}, std::vector<double>{0});
zero_node = make_broadcast_node(zero_node, data->get_shape());
return {gamma_node *
(std::make_shared<ngraph::op::Maximum>(data, zero_node) +
alpha_node * std::make_shared<ngraph::op::Exp>(
std::make_shared<ngraph::op::Minimum>(data, zero_node)) -
alpha_node)};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node_vector.hpp"
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
NodeVector selu(const Node& node);
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "ngraph/node_vector.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
inline NodeVector sigmoid(const Node& node)
{
return {std::make_shared<ngraph::op::Sigmoid>(node.get_ng_inputs().at(0))};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "ngraph/node_vector.hpp"
#include "ngraph/op/tanh.hpp"
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
inline NodeVector tanh(const Node& node)
{
return {std::make_shared<ngraph::op::Tanh>(node.get_ng_inputs().at(0))};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <memory>
#include <vector>
#include "ngraph/node.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/multiply.hpp"
#include "core/node.hpp"
#include "utils/broadcasting.hpp"
#include "thresholded_relu.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
NodeVector thresholded_relu(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
double alpha = node.get_attribute_value<double>("alpha", 1.0);
std::shared_ptr<ngraph::Node> alpha_node = std::make_shared<ngraph::op::Constant>(
data->get_element_type(), ngraph::Shape{}, std::vector<double>{alpha});
alpha_node = make_broadcast_node(alpha_node, data->get_shape());
auto data_map = std::make_shared<ngraph::op::Convert>(
std::make_shared<ngraph::op::Greater>(data, alpha_node),
data->get_element_type());
return {data * data_map};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node_vector.hpp"
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
NodeVector thresholded_relu(const Node& node);
} // namespace op
} // namespace onnx_import
} // namespace ngraph
......@@ -26,10 +26,12 @@
#include "op/constant.hpp"
#include "op/conv.hpp"
#include "op/div.hpp"
#include "op/elu.hpp"
#include "op/equal.hpp"
#include "op/flatten.hpp"
#include "op/gemm.hpp"
#include "op/greater.hpp"
#include "op/leaky_relu.hpp"
#include "op/less.hpp"
#include "op/matmul.hpp"
#include "op/max.hpp"
......@@ -40,14 +42,19 @@
#include "op/not.hpp"
#include "op/or.hpp"
#include "op/pow.hpp"
#include "op/prelu.hpp"
#include "op/reduce.hpp"
#include "op/relu.hpp"
#include "op/reshape.hpp"
#include "op/selu.hpp"
#include "op/shape.hpp"
#include "op/sigmoid.hpp"
#include "op/softmax.hpp"
#include "op/split.hpp"
#include "op/sub.hpp"
#include "op/sum.hpp"
#include "op/tanh.hpp"
#include "op/thresholded_relu.hpp"
#include "op/unsqueeze.hpp"
#include "op/xor.hpp"
#include "ops_bridge.hpp"
......@@ -104,10 +111,12 @@ namespace ngraph
m_map.emplace("Constant", std::bind(op::constant, std::placeholders::_1));
m_map.emplace("Conv", std::bind(op::conv, std::placeholders::_1));
m_map.emplace("Div", std::bind(op::div, std::placeholders::_1));
m_map.emplace("Elu", std::bind(op::elu, std::placeholders::_1));
m_map.emplace("Equal", std::bind(op::equal, std::placeholders::_1));
m_map.emplace("Flatten", std::bind(op::flatten, std::placeholders::_1));
m_map.emplace("Gemm", std::bind(op::gemm, std::placeholders::_1));
m_map.emplace("Greater", std::bind(op::greater, std::placeholders::_1));
m_map.emplace("LeakyRelu", std::bind(op::leaky_relu, std::placeholders::_1));
m_map.emplace("Less", std::bind(op::less, std::placeholders::_1));
m_map.emplace("MatMul", std::bind(op::matmul, std::placeholders::_1));
m_map.emplace("MaxPool", std::bind(op::max_pool, std::placeholders::_1));
......@@ -118,6 +127,7 @@ namespace ngraph
m_map.emplace("Not", std::bind(op::logical_not, std::placeholders::_1));
m_map.emplace("Or", std::bind(op::logical_or, std::placeholders::_1));
m_map.emplace("Pow", std::bind(op::pow, std::placeholders::_1));
m_map.emplace("PRelu", std::bind(op::prelu, std::placeholders::_1));
m_map.emplace("ReduceLogSum",
std::bind(op::reduce_log_sum, std::placeholders::_1));
m_map.emplace("ReduceLogSumExp",
......@@ -134,10 +144,15 @@ namespace ngraph
m_map.emplace("Relu", std::bind(op::relu, std::placeholders::_1));
m_map.emplace("Reshape", std::bind(op::reshape, std::placeholders::_1));
m_map.emplace("Shape", std::bind(op::shape, std::placeholders::_1));
m_map.emplace("Selu", std::bind(op::selu, std::placeholders::_1));
m_map.emplace("Sigmoid", std::bind(op::sigmoid, std::placeholders::_1));
m_map.emplace("Softmax", std::bind(op::softmax, std::placeholders::_1));
m_map.emplace("Split", std::bind(op::split, std::placeholders::_1));
m_map.emplace("Sub", std::bind(op::sub, std::placeholders::_1));
m_map.emplace("Sum", std::bind(op::sum, std::placeholders::_1));
m_map.emplace("Tanh", std::bind(op::tanh, std::placeholders::_1));
m_map.emplace("ThresholdedRelu",
std::bind(op::thresholded_relu, std::placeholders::_1));
m_map.emplace("Unsqueeze", std::bind(op::unsqueeze, std::placeholders::_1));
m_map.emplace("Xor", std::bind(op::logical_xor, std::placeholders::_1));
}
......
......@@ -16,6 +16,8 @@
#pragma once
#include <memory>
#include "ngraph/axis_set.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/broadcast.hpp"
......@@ -89,6 +91,17 @@ namespace ngraph
return std::make_shared<ngraph::op::Broadcast>(
node, new_shape, calculate_broadcast_axes(new_shape, node->get_shape()));
}
inline std::shared_ptr<ngraph::Node>
make_broadcast_node(const std::shared_ptr<ngraph::Node>& node,
ngraph::Shape new_shape,
std::size_t start_match_axis)
{
return std::make_shared<ngraph::op::Broadcast>(
node,
new_shape,
calculate_broadcast_axes(new_shape, node->get_shape(), start_match_axis));
}
} // namespace onnx_import
} // namespace ngraph
 backend-test:f
"
xy" LeakyRelu*
alpha=test_leakyreluZ
x



b
y



B
\ No newline at end of file
 backend-test:y

x
slopey"PRelutest_prelu_exampleZ
x



Z
slope



b
y



B
\ No newline at end of file
 backend-test:Q

xy"Sigmoid test_sigmoidZ
x



b
y



B
\ No newline at end of file
 backend-test:K
xy"Tanh test_tanhZ
x



b
y



B
\ No newline at end of file
......@@ -934,3 +934,326 @@ TEST(onnx, model_shape)
execute<float, int64_t>(function, inputs, "INTERPRETER");
EXPECT_TRUE(test::all_close(expected_output.front(), outputs.front()));
}
TEST(onnx, model_elu)
{
auto function =
onnx_import::import_onnx_function(file_util::path_join(SERIALIZED_ZOO, "onnx/elu.onnx"));
Inputs inputs;
inputs.emplace_back(
test::NDArray<float, 3>(
{{{-9, -8, -7, -6, -5}, {-4, -3, -2, -1, 0}, {1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}},
{{-4, -3, -2, -1, 0}, {1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}, {11, 12, 13, 14, 15}},
{{1, 1, 1, 1, 1}, {-1, -1, -1, -1, -1}, {0, 0, 0, 0, 0}, {2, 2, 2, 2, 2}}})
.get_vector());
Outputs expected_output{test::NDArray<float, 3>({{{-1.999753180391830f,
-1.999329074744190f,
-1.998176236068890f,
-1.995042495646670f,
-1.986524106001830f},
{-1.963368722222530f,
-1.900425863264270f,
-1.729329433526770f,
-1.264241117657120f,
0},
{1, 2, 3, 4, 5},
{6, 7, 8, 9, 10}},
{{-1.963368722222530f,
-1.900425863264270f,
-1.729329433526770f,
-1.264241117657120f,
0},
{1, 2, 3, 4, 5},
{6, 7, 8, 9, 10},
{11, 12, 13, 14, 15}},
{{1, 1, 1, 1, 1},
{-1.264241117657120f,
-1.264241117657120f,
-1.264241117657120f,
-1.264241117657120f,
-1.264241117657120f},
{0, 0, 0, 0, 0},
{2, 2, 2, 2, 2}}})
.get_vector()};
Outputs outputs{execute(function, inputs, "INTERPRETER")};
EXPECT_TRUE(test::all_close_f(expected_output.front(), outputs.front()));
}
TEST(onnx, model_leaky_relu)
{
auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/leaky_relu.onnx"));
Inputs inputs;
inputs.emplace_back(
test::NDArray<float, 3>(
{{{-9, -8, -7, -6, -5}, {-4, -3, -2, -1, 0}, {1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}},
{{-4, -3, -2, -1, 0}, {1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}, {11, 12, 13, 14, 15}},
{{1, 1, 1, 1, 1}, {-1, -1, -1, -1, -1}, {0, 0, 0, 0, 0}, {2, 2, 2, 2, 2}}})
.get_vector());
Outputs expected_output{test::NDArray<float, 3>({{{-0.9f, -0.8f, -0.7f, -0.6f, -0.5f},
{-0.4f, -0.3f, -0.2f, -0.1f, 0},
{1, 2, 3, 4, 5},
{6, 7, 8, 9, 10}},
{{-0.4f, -0.3f, -0.2f, -0.1f, 0},
{1, 2, 3, 4, 5},
{6, 7, 8, 9, 10},
{11, 12, 13, 14, 15}},
{{1, 1, 1, 1, 1},
{-0.1f, -0.1f, -0.1f, -0.1f, -0.1f},
{0, 0, 0, 0, 0},
{2, 2, 2, 2, 2}}})
.get_vector()};
Outputs outputs{execute(function, inputs, "INTERPRETER")};
EXPECT_TRUE(test::all_close_f(expected_output.front(), outputs.front()));
}
TEST(onnx, prelu)
{
auto function =
onnx_import::import_onnx_function(file_util::path_join(SERIALIZED_ZOO, "onnx/prelu.onnx"));
Inputs inputs;
inputs.emplace_back(
test::NDArray<float, 3>(
{{{-9, -8, -7, -6, -5}, {-4, -3, -2, -1, 0}, {1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}},
{{-4, -3, -2, -1, 0}, {1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}, {11, 12, 13, 14, 15}},
{{1, 1, 1, 1, 1}, {-1, -1, -1, -1, -1}, {0, 0, 0, 0, 0}, {2, 2, 2, 2, 2}}})
.get_vector());
inputs.emplace_back(test::NDArray<float, 3>(
{{{1, 0, 1, 0, 1}, {0, 1, 0, 1, 0}, {1, 0, 1, 0, 1}, {0, 1, 0, 1, 0}},
{{0, 1, 0, 1, 0}, {1, 0, 1, 0, 1}, {0, 1, 0, 1, 0}, {1, 0, 1, 0, 1}},
{{1, 0, 1, 0, 1}, {0, 1, 0, 1, 0}, {1, 0, 1, 0, 1}, {0, 1, 0, 1, 0}}})
.get_vector());
Outputs expected_output{
test::NDArray<float, 3>(
{{{-9, 0, -7, 0, -5}, {0, -3, 0, -1, 0}, {1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}},
{{0, -3, 0, -1, 0}, {1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}, {11, 12, 13, 14, 15}},
{{1, 1, 1, 1, 1}, {0, -1, 0, -1, 0}, {0, 0, 0, 0, 0}, {2, 2, 2, 2, 2}}})
.get_vector()};
Outputs outputs{execute(function, inputs, "INTERPRETER")};
EXPECT_TRUE(test::all_close_f(expected_output.front(), outputs.front()));
}
TEST(onnx, model_selu)
{
auto function =
onnx_import::import_onnx_function(file_util::path_join(SERIALIZED_ZOO, "onnx/selu.onnx"));
Inputs inputs;
inputs.emplace_back(
test::NDArray<float, 3>(
{{{-9, -8, -7, -6, -5}, {-4, -3, -2, -1, 0}, {1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}},
{{-4, -3, -2, -1, 0}, {1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}, {11, 12, 13, 14, 15}},
{{1, 1, 1, 1, 1}, {-1, -1, -1, -1, -1}, {0, 0, 0, 0, 0}, {2, 2, 2, 2, 2}}})
.get_vector());
Outputs expected_output{
test::NDArray<float, 3>(
{{{-5.99925954117548f,
-5.99798722423258f,
-5.99452870820667f,
-5.98512748694000f,
-5.95957231800549f},
{-5.89010616666759f, -5.70127758979282f, -5.18798830058032f, -3.79272335297135f, 0},
{3, 6, 9, 12, 15},
{18, 21, 24, 27, 30}},
{{-5.89010616666759f, -5.70127758979282f, -5.18798830058032f, -3.79272335297135f, 0},
{3, 6, 9, 12, 15},
{18, 21, 24, 27, 30},
{33, 36, 39, 42, 45}},
{{3, 3, 3, 3, 3},
{-3.79272335297135f,
-3.79272335297135f,
-3.79272335297135f,
-3.79272335297135f,
-3.79272335297135f},
{0, 0, 0, 0, 0},
{6, 6, 6, 6, 6}}})
.get_vector()};
Outputs outputs{execute(function, inputs, "INTERPRETER")};
EXPECT_TRUE(test::all_close_f(expected_output.front(), outputs.front()));
}
TEST(onnx, model_sigmoid)
{
auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/sigmoid.onnx"));
Inputs inputs;
inputs.emplace_back(
test::NDArray<float, 3>(
{{{-9, -8, -7, -6, -5}, {-4, -3, -2, -1, 0}, {1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}},
{{-4, -3, -2, -1, 0}, {1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}, {11, 12, 13, 14, 15}},
{{1, 1, 1, 1, 1}, {-1, -1, -1, -1, -1}, {0, 0, 0, 0, 0}, {2, 2, 2, 2, 2}}})
.get_vector());
Outputs expected_output{test::NDArray<float, 3>({{{0.00012339457598623f,
0.00033535013046648f,
0.00091105119440065f,
0.00247262315663477f,
0.00669285092428486f},
{0.01798620996209160f,
0.04742587317756680f,
0.119202922022118f,
0.268941421369995f,
0.5f},
{0.731058578630005f,
0.880797077977882f,
0.952574126822433f,
0.982013790037908f,
0.993307149075715f},
{0.997527376843365f,
0.999088948805599f,
0.999664649869534f,
0.999876605424014f,
0.999954602131298f}},
{{0.01798620996209160f,
0.04742587317756680f,
0.119202922022118f,
0.268941421369995f,
0.5f},
{0.731058578630005f,
0.880797077977882f,
0.952574126822433f,
0.982013790037908f,
0.993307149075715f},
{0.997527376843365f,
0.999088948805599f,
0.999664649869534f,
0.999876605424014f,
0.999954602131298f},
{0.999983298578152f,
0.999993855825398f,
0.999997739675702f,
0.999999168471972f,
0.999999694097773f}},
{{0.731058578630005f,
0.731058578630005f,
0.731058578630005f,
0.731058578630005f,
0.731058578630005f},
{0.268941421369995f,
0.268941421369995f,
0.268941421369995f,
0.268941421369995f,
0.268941421369995f},
{0.5f, 0.5f, 0.5f, 0.5f, 0.5f},
{0.880797077977882f,
0.880797077977882f,
0.880797077977882f,
0.880797077977882f,
0.880797077977882f}}})
.get_vector()};
Outputs outputs{execute(function, inputs, "INTERPRETER")};
EXPECT_TRUE(test::all_close_f(expected_output.front(), outputs.front()));
}
TEST(onnx, model_tanh)
{
auto function =
onnx_import::import_onnx_function(file_util::path_join(SERIALIZED_ZOO, "onnx/tanh.onnx"));
Inputs inputs;
inputs.emplace_back(
test::NDArray<float, 3>(
{{{-9, -8, -7, -6, -5}, {-4, -3, -2, -1, 0}, {1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}},
{{-4, -3, -2, -1, 0}, {1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}, {11, 12, 13, 14, 15}},
{{1, 1, 1, 1, 1}, {-1, -1, -1, -1, -1}, {0, 0, 0, 0, 0}, {2, 2, 2, 2, 2}}})
.get_vector());
Outputs expected_output{test::NDArray<float, 3>({{{-0.999999969540041f,
-0.999999774929676f,
-0.999998336943945f,
-0.999987711650796f,
-0.999909204262595f},
{-0.999329299739067f,
-0.995054753686731f,
-0.964027580075817f,
-0.761594155955765f,
0},
{0.761594155955765f,
0.964027580075817f,
0.995054753686731f,
0.999329299739067f,
0.999909204262595f},
{0.999987711650796f,
0.999998336943945f,
0.999999774929676f,
0.999999969540041f,
0.999999995877693f}},
{{-0.999329299739067f,
-0.995054753686731f,
-0.964027580075817f,
-0.761594155955765f,
0},
{0.761594155955765f,
0.964027580075817f,
0.995054753686731f,
0.999329299739067f,
0.999909204262595f},
{0.999987711650796f,
0.999998336943945f,
0.999999774929676f,
0.999999969540041f,
0.999999995877693f},
{0.999999999442106f,
0.999999999924497f,
0.999999999989782f,
0.999999999998617f,
0.999999999999813f}},
{{0.761594155955765f,
0.761594155955765f,
0.761594155955765f,
0.761594155955765f,
0.761594155955765f},
{-0.761594155955765f,
-0.761594155955765f,
-0.761594155955765f,
-0.761594155955765f,
-0.761594155955765f},
{0, 0, 0, 0, 0},
{0.964027580075817f,
0.964027580075817f,
0.964027580075817f,
0.964027580075817f,
0.964027580075817f}}})
.get_vector()};
Outputs outputs{execute(function, inputs, "INTERPRETER")};
EXPECT_TRUE(test::all_close_f(expected_output.front(), outputs.front()));
}
TEST(onnx, model_thresholded_relu)
{
auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/thresholded_relu.onnx"));
Inputs inputs;
inputs.emplace_back(
test::NDArray<float, 3>(
{{{-9, -8, -7, -6, -5}, {-4, -3, -2, -1, 0}, {1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}},
{{-4, -3, -2, -1, 0}, {1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}, {11, 12, 13, 14, 15}},
{{1, 1, 1, 1, 1}, {-1, -1, -1, -1, -1}, {0, 0, 0, 0, 0}, {2, 2, 2, 2, 2}}})
.get_vector());
Outputs expected_output{
test::NDArray<float, 3>(
{{{0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 3, 4, 5}, {6, 7, 8, 9, 10}},
{{0, 0, 0, 0, 0}, {0, 0, 3, 4, 5}, {6, 7, 8, 9, 10}, {11, 12, 13, 14, 15}},
{{0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}}})
.get_vector()};
Outputs outputs{execute(function, inputs, "INTERPRETER")};
EXPECT_TRUE(test::all_close_f(expected_output.front(), outputs.front()));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment