Commit d05b5e39 authored by Sandeep, committed by Scott Cyphers

move sigmoid to core fusion (#1132)

* declare sigmoid for core fusion

* add simple test for sigmoid

* info fusion status

* cp op as main op

* builds as expected

* move sigmoid fusion code

* add reference kernel

* sigmoid bprop reference kernel and clang-format

* add delta to bprop

* fprop called

* compiles bprop

* move tests

* serializer support

* address comments in code

* add doc

* naming similar to core ops

* fix failing test

* fix failing test

* address clang issue

* more changes

* change test macro
parent 18e58ea9
.. sigmoid.rst:
#######
Sigmoid
#######
.. code-block:: cpp

   Sigmoid // Elementwise sigmoid operation
Inputs
------
+-----------------+-------------------------+--------------------------------+
| Name | Element Type | Shape |
+=================+=========================+================================+
| ``arg`` | Any | Any |
+-----------------+-------------------------+--------------------------------+
Outputs
-------
+-----------------+-------------------------+--------------------------------+
| Name | Element Type | Shape |
+=================+=========================+================================+
| ``output`` | Same as ``arg`` | Same as ``arg`` |
+-----------------+-------------------------+--------------------------------+
Mathematical Definition
=======================
.. math::

   \mathtt{output}_{i_0, \ldots, i_{n-1}} =
   \frac{1}{1 + e^{-\mathtt{arg}_{i_0, \ldots, i_{n-1}}}}
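Example
=======

A minimal usage sketch, mirroring the unit tests added in this commit (the shape is illustrative):

.. code-block:: cpp

   auto input = std::make_shared<op::Parameter>(element::f32, Shape{1, 1, 4});
   auto sigmoid = std::make_shared<op::Sigmoid>(input);
   auto func = std::make_shared<Function>(sigmoid, op::ParameterVector{input});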
C++ Interface
=============
.. doxygenclass:: ngraph::op::Sigmoid
:project: ngraph
:members:
......@@ -87,6 +87,7 @@ set (SRC
op/reverse_sequence.cpp
op/select_and_scatter.cpp
op/select.cpp
op/sigmoid.cpp
op/sign.cpp
op/sin.cpp
op/sinh.cpp
......
......@@ -113,6 +113,7 @@
#include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/op/select_and_scatter.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/sign.hpp"
#include "ngraph/op/sin.hpp"
#include "ngraph/op/sinh.hpp"
......
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/log.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
shared_ptr<Node> op::Sigmoid::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Sigmoid>(new_args.at(0));
}
op::Sigmoid::Sigmoid(shared_ptr<Node> arg)
: UnaryElementwiseArithmetic("Sigmoid", {arg})
{
set_value_type_checked(arg->get_element_type(), arg->get_shape());
}
op::SigmoidBackprop::SigmoidBackprop(shared_ptr<Node> arg, shared_ptr<Node> delta)
: RequiresTensorViewArgs("SigmoidBackprop", {arg, delta})
{
if (arg->get_element_type() != delta->get_element_type())
{
throw ngraph_error("Argument and delta element types for Sigmoid backprop do not match");
}
if (arg->get_shape() != delta->get_shape())
{
throw ngraph_error("Argument and delta shape for Sigmoid backprop do not match");
}
set_value_type_checked(delta->get_element_type(), delta->get_shape());
}
shared_ptr<Node> op::SigmoidBackprop::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<SigmoidBackprop>(new_args.at(0), new_args.at(1));
}
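// Autodiff note: d(sigmoid(x))/dx = sigmoid(x) * (1 - sigmoid(x)); generate_adjoints
// below delegates this to SigmoidBackprop rather than expanding the expression inline.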
void op::Sigmoid::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
auto delta = deltas.at(0);
auto backprop = make_shared<op::SigmoidBackprop>(get_argument(0), delta);
adjoints.add_delta(get_argument(0), backprop);
}
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/op/util/requires_tensor_view_args.hpp"
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
#include "ngraph/util.hpp"
namespace ngraph
{
namespace op
{
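/// \brief Elementwise sigmoid operation.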
class Sigmoid : public util::UnaryElementwiseArithmetic
{
public:
Sigmoid(std::shared_ptr<Node> arg);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
};
/// \brief Elementwise SigmoidBackprop operation.
///
class SigmoidBackprop : public util::RequiresTensorViewArgs
{
public:
/// \brief Constructs a SigmoidBackprop operation.
///
/// \param arg Node that produces the Sigmoid forward input tensor.
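/// \param delta Node that produces the delta (adjoint) tensor.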
SigmoidBackprop(std::shared_ptr<ngraph::Node> arg, std::shared_ptr<ngraph::Node> delta);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
};
}
}
......@@ -27,12 +27,14 @@
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/exp.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/negative.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/sqrt.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/pass/graph_rewrite.hpp"
......@@ -81,6 +83,101 @@ void pass::CoreFusion::construct_relu()
this->add_matcher(m);
}
void pass::CoreFusion::construct_sigmoid()
{
// Construct the sigmoid fprop pattern: 1 / (1 + exp(-x))
auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{3, 4});
auto neg_input = std::make_shared<op::Negative>(input);
auto exp_neg_input = std::make_shared<op::Exp>(neg_input);
// Broadcast the scalar constant to the input shape
auto constant = std::make_shared<pattern::op::Label>(element::f32, Shape{});
auto broadcast_constant = std::make_shared<op::Broadcast>(constant, Shape{3, 4}, AxisSet{0, 1});
auto add_exp = std::make_shared<op::Add>(exp_neg_input, broadcast_constant);
auto divide_1_over_exp = std::make_shared<op::Divide>(broadcast_constant, add_exp);
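// Sketch of the pattern just built: with `broadcast_constant` bound to the
// scalar 1 broadcast to the input shape (the label itself does not constrain
// the value), divide_1_over_exp is
//   1 / (1 + exp(-input)),
// i.e. the logistic sigmoid.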
// Define a callback that is invoked once the DFG matches the pattern
ngraph::pattern::graph_rewrite_callback callback = [input](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_fprop_sigmoid pattern against "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
if (m.get_match_root()->get_element_type() != element::f32)
{
NGRAPH_DEBUG << "mpattern = " << m.get_match_root()->get_name()
<< " type is not float!";
return false;
}
if (m.get_match_root()->get_outputs().size() != pattern_map[input]->get_outputs().size())
{
NGRAPH_DEBUG << "mpattern = " << m.get_match_root()->get_name()
<< "input= " << pattern_map[input]->get_name() << "size dont match!";
return false;
}
auto sigmoid_node = std::make_shared<op::Sigmoid>(pattern_map[input]);
ngraph::replace_node(m.get_match_root(), sigmoid_node);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(divide_1_over_exp, callback);
this->add_matcher(m);
}
void pass::CoreFusion::construct_sigmoid_bprop()
{
// Construct the sigmoid bprop pattern
auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{3, 4});
auto neg_input = std::make_shared<op::Negative>(input);
auto exp_neg_input = std::make_shared<op::Exp>(neg_input);
// Broadcast the scalar constant to the input shape
auto constant = std::make_shared<pattern::op::Label>(element::f32, Shape{});
auto broadcast_constant = std::make_shared<op::Broadcast>(constant, Shape{3, 4}, AxisSet{0, 1});
auto add_exp = std::make_shared<op::Add>(exp_neg_input, broadcast_constant);
auto sigmoid_fwd = std::make_shared<pattern::op::Label>(element::f32, Shape{3, 4});
auto delta = std::make_shared<pattern::op::Label>(element::f32, Shape{3, 4});
auto neg_delta = std::make_shared<op::Negative>(delta);
auto multiply_sigmoid_delta = std::make_shared<op::Multiply>(sigmoid_fwd, neg_delta);
auto divide_2 = std::make_shared<op::Divide>(multiply_sigmoid_delta, add_exp);
auto multiply_2 = std::make_shared<op::Multiply>(divide_2, exp_neg_input);
auto negative_2 = std::make_shared<op::Negative>(multiply_2);
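// Sketch of the algebra, assuming sigmoid_fwd binds sigmoid(x) and the
// broadcast constant is 1: negative_2 is
//   -((sigmoid_fwd * -delta) / (1 + exp(-x)) * exp(-x))
//     = delta * sigmoid(x) * (exp(-x) / (1 + exp(-x)))
//     = delta * sigmoid(x) * (1 - sigmoid(x)),
// which is exactly what op::SigmoidBackprop computes.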
// Define a callback that is invoked once the DFG matches the pattern
ngraph::pattern::graph_rewrite_callback callback = [input, delta](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_fprop_sigmoid pattern against "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
if (m.get_match_root()->get_element_type() != element::f32)
{
NGRAPH_DEBUG << "mpattern = " << m.get_match_root()->get_name()
<< " type is not float!";
return false;
}
if (m.get_match_root()->get_shape().size() != pattern_map[input]->get_shape().size())
{
NGRAPH_DEBUG << "mpattern = " << m.get_match_root()->get_name()
<< "input= " << pattern_map[input]->get_name() << "size dont match!";
return false;
}
auto dsigmoid =
std::make_shared<op::SigmoidBackprop>(pattern_map[input], pattern_map[delta]);
ngraph::replace_node(m.get_match_root(), dsigmoid);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(negative_2, callback);
this->add_matcher(m);
}
void pass::CoreFusion::construct_folded_batch_norm()
{
Shape shape{2, 2, 1, 1};
......
......@@ -34,9 +34,13 @@ public:
{
construct_relu();
construct_folded_batch_norm();
construct_sigmoid();
construct_sigmoid_bprop();
construct_optimized_strided_conv();
}
void construct_relu();
void construct_folded_batch_norm();
void construct_sigmoid();
void construct_sigmoid_bprop();
void construct_optimized_strided_conv();
};
......@@ -47,6 +47,7 @@
#include "ngraph/op/product.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/remainder.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/sign.hpp"
#include "ngraph/op/sin.hpp"
#include "ngraph/op/sinh.hpp"
......@@ -106,6 +107,7 @@ static std::unordered_map<std::type_index,
{TI(op::Log), cse_unarywise},
{TI(op::Negative), cse_unarywise},
{TI(op::Relu), cse_unarywise},
{TI(op::Sigmoid), cse_unarywise},
{TI(op::Sign), cse_unarywise},
{TI(op::Sin), cse_unarywise},
{TI(op::Sinh), cse_unarywise},
......
......@@ -48,7 +48,6 @@ set(SRC
op/max_pool_with_indices.cpp
op/rnn.cpp
op/sigmoid_mul.cpp
op/sigmoid.cpp
pass/cpu_assignment.cpp
pass/cpu_concat_inputs.cpp
pass/cpu_fusion.cpp
......
......@@ -39,6 +39,7 @@
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/sqrt.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp"
......@@ -51,7 +52,6 @@
#include "ngraph/runtime/cpu/op/conv_bias.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/matmul_bias.hpp"
#include "ngraph/runtime/cpu/op/sigmoid.hpp"
#include "ngraph/runtime/cpu/op/sigmoid_mul.hpp"
static bool init_cblas_arg(std::shared_ptr<ngraph::Node> reshape,
......@@ -603,101 +603,6 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_conv_backprop_
this->add_matcher(std::make_shared<ngraph::pattern::Matcher>(conv_label, callback));
}
void ngraph::runtime::cpu::pass::CPUFusion::construct_sigmoid()
{
// Construct the sigmoid fprop pattern: 1 / (1 + exp(-x))
auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{3, 4});
auto neg_input = std::make_shared<op::Negative>(input);
auto exp_neg_input = std::make_shared<op::Exp>(neg_input);
// Broadcast the scalar constant to the input shape
auto constant = std::make_shared<pattern::op::Label>(element::f32, Shape{});
auto broadcast_constant = std::make_shared<op::Broadcast>(constant, Shape{3, 4}, AxisSet{0, 1});
auto add_exp = std::make_shared<op::Add>(exp_neg_input, broadcast_constant);
auto divide_1_over_exp = std::make_shared<op::Divide>(broadcast_constant, add_exp);
// Define a callback that is invoked once the DFG matches the pattern
ngraph::pattern::graph_rewrite_callback callback = [input](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_fprop_sigmoid pattern against "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
if (m.get_match_root()->get_element_type() != element::f32)
{
NGRAPH_DEBUG << "mpattern = " << m.get_match_root()->get_name()
<< " type is not float!";
return false;
}
if (m.get_match_root()->get_outputs().size() != pattern_map[input]->get_outputs().size())
{
NGRAPH_DEBUG << "mpattern = " << m.get_match_root()->get_name()
<< "input= " << pattern_map[input]->get_name() << "size dont match!";
return false;
}
auto sigmoid_node = std::make_shared<op::Sigmoid>(pattern_map[input]);
ngraph::replace_node(m.get_match_root(), sigmoid_node);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(divide_1_over_exp, callback);
this->add_matcher(m);
}
void ngraph::runtime::cpu::pass::CPUFusion::construct_sigmoid_bprop()
{
// Construct the sigmoid bprop pattern
auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{3, 4});
auto neg_input = std::make_shared<op::Negative>(input);
auto exp_neg_input = std::make_shared<op::Exp>(neg_input);
// Broadcast the scalar constant to the input shape
auto constant = std::make_shared<pattern::op::Label>(element::f32, Shape{});
auto broadcast_constant = std::make_shared<op::Broadcast>(constant, Shape{3, 4}, AxisSet{0, 1});
auto add_exp = std::make_shared<op::Add>(exp_neg_input, broadcast_constant);
auto sigmoid_fwd = std::make_shared<pattern::op::Label>(element::f32, Shape{3, 4});
auto delta = std::make_shared<pattern::op::Label>(element::f32, Shape{3, 4});
auto neg_delta = std::make_shared<op::Negative>(delta);
auto multiply_sigmoid_delta = std::make_shared<op::Multiply>(sigmoid_fwd, neg_delta);
auto divide_2 = std::make_shared<op::Divide>(multiply_sigmoid_delta, add_exp);
auto multiply_2 = std::make_shared<op::Multiply>(divide_2, exp_neg_input);
auto negative_2 = std::make_shared<op::Negative>(multiply_2);
// Define a callback that is invoked once the DFG matches the pattern
ngraph::pattern::graph_rewrite_callback callback = [input, delta](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_fprop_sigmoid pattern against "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
if (m.get_match_root()->get_element_type() != element::f32)
{
NGRAPH_DEBUG << "mpattern = " << m.get_match_root()->get_name()
<< " type is not float!";
return false;
}
if (m.get_match_root()->get_shape().size() != pattern_map[input]->get_shape().size())
{
NGRAPH_DEBUG << "mpattern = " << m.get_match_root()->get_name()
<< "input= " << pattern_map[input]->get_name() << "size dont match!";
return false;
}
auto dsigmoid =
std::make_shared<op::SigmoidBackprop>(pattern_map[input], pattern_map[delta]);
ngraph::replace_node(m.get_match_root(), dsigmoid);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(negative_2, callback);
this->add_matcher(m);
}
void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias()
{
Shape shape{2, 2, 1, 1};
......
......@@ -57,8 +57,6 @@ public:
construct_zero_padded_reshaped_conv();
construct_zero_padded_conv();
construct_zero_padded_conv_backprop_filters();
construct_sigmoid();
construct_sigmoid_bprop();
construct_conv_bias_bprop();
construct_batch_norm_relu();
construct_batch_norm_relu_global_stats();
......@@ -82,8 +80,6 @@ private:
void construct_conv_bias();
void construct_conv_bias_bprop();
void construct_fprop_bn();
void construct_sigmoid();
void construct_sigmoid_bprop();
void construct_sigmoid_multiply();
void construct_zero_padded_reshaped_conv();
void construct_zero_padded_conv();
......
......@@ -36,6 +36,7 @@
#include "ngraph/op/op.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/result.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/softmax.hpp"
#include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp"
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
......@@ -49,7 +50,6 @@
#include "ngraph/runtime/cpu/op/lstm.hpp"
#include "ngraph/runtime/cpu/op/max_pool_with_indices.hpp"
#include "ngraph/runtime/cpu/op/rnn.hpp"
#include "ngraph/runtime/cpu/op/sigmoid.hpp"
using namespace std;
using namespace mkldnn;
......
......@@ -23,6 +23,11 @@ select_and_scatter_with_overlap
select_and_scatter_without_overlap
#custom_mem is not implemented on GPU
tensorview_custom_mem
#sigmoid not implemented
sigmoid_n1c1h2w2
sigmoid_n1c1h4
sigmoid_bprop_n1c1h4
backwards_sigmoid
#integer is not supported by cuDNN on backward pooling
backwards_maxpool_n4_c1_hw4_2x2_max
backwards_maxpool_n2_c1_hw5_3x3_str2_max
......
......@@ -102,6 +102,7 @@
#include "ngraph/runtime/reference/reverse_sequence.hpp"
#include "ngraph/runtime/reference/select.hpp"
#include "ngraph/runtime/reference/select_and_scatter.hpp"
#include "ngraph/runtime/reference/sigmoid.hpp"
#include "ngraph/runtime/reference/sign.hpp"
#include "ngraph/runtime/reference/sin.hpp"
#include "ngraph/runtime/reference/sinh.hpp"
......@@ -867,6 +868,18 @@ private:
select_and_scatter->get_window_shape(),
select_and_scatter->get_window_movement_strides());
}
else if (node_op == "Sigmoid")
{
reference::sigmoid<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
}
else if (node_op == "SigmoidBackprop")
{
reference::sigmoid_backprop<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
out[0]->get_element_count());
}
else if (node_op == "Sign")
{
reference::sign<T>(
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <cmath>
#include <cstddef>
namespace ngraph
{
namespace runtime
{
namespace reference
{
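// Elementwise logistic sigmoid: out[i] = 1 / (1 + exp(-arg[i])).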
template <typename T>
void sigmoid(const T* arg, T* out, size_t count)
{
T exp_value;
for (size_t i = 0; i < count; i++)
{
exp_value = std::exp(-arg[i]);
out[i] = 1 / (1 + exp_value);
}
}
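// Sigmoid backprop: out[i] = delta_arg[i] * s * (1 - s), where s = sigmoid(arg[i]).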
template <typename T>
void sigmoid_backprop(const T* arg, T* delta_arg, T* out, size_t count)
{
T exp_value;
T func_x;
for (size_t i = 0; i < count; i++)
{
exp_value = std::exp(-arg[i]);
func_x = 1 / (1 + exp_value);
out[i] = delta_arg[i] * func_x * (1 - func_x);
}
}
}
}
}
......@@ -75,6 +75,7 @@
#include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/op/select_and_scatter.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/sign.hpp"
#include "ngraph/op/sin.hpp"
#include "ngraph/op/sinh.hpp"
......@@ -837,6 +838,14 @@ static shared_ptr<ngraph::Function>
window_shape,
window_movement_strides);
}
else if (node_op == "Sigmoid")
{
node = make_shared<op::Sigmoid>(args[0]);
}
else if (node_op == "SigmoidBackprop")
{
node = make_shared<op::SigmoidBackprop>(args[0], args[1]);
}
else if (node_op == "Sign")
{
node = make_shared<op::Sign>(args[0]);
......@@ -1263,6 +1272,12 @@ static json write(const Node& n, bool binary_constant_data)
node["window_shape"] = tmp->get_window_shape();
node["window_movement_strides"] = tmp->get_window_movement_strides();
}
else if (node_op == "Sigmoid")
{
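// Sigmoid has no attributes to serialize beyond its inputs.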
}
else if (node_op == "SigmoidBackprop")
{
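// SigmoidBackprop likewise has no attributes to serialize.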
}
else if (node_op == "Sign")
{
}
......
......@@ -1081,6 +1081,34 @@ NGRAPH_TEST(${BACKEND_NAME}, backwards_select_nested)
}
}
NGRAPH_TEST(${BACKEND_NAME}, backwards_sigmoid)
{
auto backend = runtime::Backend::create("${BACKEND_NAME}");
test::Uniform<float> rng_neg(-1.0f, -0.01f);
test::Uniform<float> rng_pos(0.01f, 1.0f);
Shape shape{2, 3};
auto x0 = rng_neg.initialize(backend->create_tensor<float>(shape));
auto x1 = rng_pos.initialize(backend->create_tensor<float>(shape));
auto make_graph = [shape]() {
auto X = make_shared<op::Parameter>(element::f32, shape);
return make_shared<Function>(make_shared<op::Sigmoid>(X),
std::vector<std::shared_ptr<op::Parameter>>{X});
};
for (auto i = 0; i < ${TEST_LOOPS}; i++)
{
auto x_neg = rng_neg.initialize(backend->create_tensor<float>(shape));
EXPECT_TRUE(autodiff_numeric_compare<float>(backend, make_graph, {x_neg}, .01f, .01f));
auto x_pos = rng_pos.initialize(backend->create_tensor<float>(shape));
EXPECT_TRUE(autodiff_numeric_compare<float>(backend, make_graph, {x_pos}, .01f, .01f));
}
}
NGRAPH_TEST(${BACKEND_NAME}, backwards_sign)
{
auto backend = runtime::Backend::create("${BACKEND_NAME}");
......
......@@ -7675,6 +7675,70 @@ NGRAPH_TEST(${BACKEND_NAME}, min_3d_eliminate_zero_dim)
EXPECT_EQ((vector<float>{inf, inf, inf, inf, inf, inf}), read_vector<float>(result));
}
NGRAPH_TEST(${BACKEND_NAME}, sigmoid_n1c1h2w2)
{
auto input = make_shared<op::Parameter>(element::f32, Shape{1, 1, 2, 2});
auto sigmoid_node = make_shared<op::Sigmoid>(input);
auto func = make_shared<Function>(sigmoid_node, op::ParameterVector{input});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
shared_ptr<runtime::TensorView> a = backend->create_tensor(element::f32, input->get_shape());
shared_ptr<runtime::TensorView> result =
backend->create_tensor(element::f32, input->get_shape());
vector<float> dataA{1.0f, 4.0f, 1.0f, 4.0f};
copy_data(a, dataA);
backend->call(func, {result}, {a});
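// sigmoid(1) ~= 0.73105858, sigmoid(4) ~= 0.98201379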
vector<float> expected{0.73105858f, 0.98201379f, 0.73105858f, 0.98201379f};
ASSERT_TRUE(read_vector<float>(result) == expected);
}
NGRAPH_TEST(${BACKEND_NAME}, sigmoid_n1c1h4)
{
auto input = make_shared<op::Parameter>(element::f32, Shape{1, 1, 4});
auto sigmoid_node = make_shared<op::Sigmoid>(input);
auto func = make_shared<Function>(sigmoid_node, op::ParameterVector{input});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
shared_ptr<runtime::TensorView> a = backend->create_tensor(element::f32, input->get_shape());
shared_ptr<runtime::TensorView> result =
backend->create_tensor(element::f32, input->get_shape());
vector<float> dataA{1.0f, 4.0f, 1.0f, 4.0f};
copy_data(a, dataA);
backend->call(func, {result}, {a});
vector<float> expected{0.73105858f, 0.98201379f, 0.73105858f, 0.98201379f};
ASSERT_TRUE(read_vector<float>(result) == expected);
}
NGRAPH_TEST(${BACKEND_NAME}, sigmoid_bprop_n1c1h4)
{
auto input = make_shared<op::Parameter>(element::f32, Shape{1, 1, 4});
auto delta = make_shared<op::Parameter>(element::f32, Shape{1, 1, 4});
auto sigmoid_node = make_shared<op::SigmoidBackprop>(input, delta);
auto func = make_shared<Function>(sigmoid_node, op::ParameterVector{input, delta});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
shared_ptr<runtime::TensorView> a = backend->create_tensor(element::f32, input->get_shape());
shared_ptr<runtime::TensorView> b = backend->create_tensor(element::f32, delta->get_shape());
shared_ptr<runtime::TensorView> result =
backend->create_tensor(element::f32, input->get_shape());
vector<float> dataA{1.0f, 4.0f, 1.0f, 4.0f};
vector<float> dataB{1.0f, 1.0f, 1.0f, 1.0f};
copy_data(a, dataA);
copy_data(b, dataB);
backend->call(func, {result}, {a, b});
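// delta * s * (1 - s): s(1) * (1 - s(1)) ~= 0.196612, s(4) * (1 - s(4)) ~= 0.0176627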
vector<float> expected{0.196612f, 0.0176627f, 0.196612f, 0.0176627f};
EXPECT_TRUE(test::all_close(expected, read_vector<float>(result)));
}
NGRAPH_TEST(${BACKEND_NAME}, relu_2Dfprop)
{
auto shape_a = Shape{2, 5};
......
......@@ -36,6 +36,7 @@
#include "ngraph/serializer.hpp"
#include "ngraph/util.hpp"
#include "nlohmann/json.hpp"
#include "util/autodiff/backprop_function.hpp"
#include "util/matcher.hpp"
#include "util/test_tools.hpp"
......@@ -56,6 +57,32 @@ TEST(core_fusion, core_fusion_pass_basic)
ASSERT_NE(std::dynamic_pointer_cast<op::Relu>(graph->get_argument(0)), nullptr);
}
TEST(core_fusion, sigmoid_fprop_fusion)
{
pass::Manager pass_manager;
pass_manager.register_pass<pass::CoreFusion>();
const string json_path = file_util::path_join(SERIALIZED_ZOO, "mxnet/Graph_fprop_sigmoid.json");
const string json_string = file_util::read_file_to_string(json_path);
stringstream ss(json_string);
shared_ptr<Function> func = ngraph::deserialize(ss);
pass_manager.run_passes(func);
size_t ccg = count_ops_of_type<op::Sigmoid>(func);
ASSERT_EQ(ccg, 1);
}
TEST(core_fusion, sigmoid_bprop_fusion)
{
const string json_path = file_util::path_join(SERIALIZED_ZOO, "mxnet/Graph_fprop_sigmoid.json");
const string json_string = file_util::read_file_to_string(json_path);
stringstream ss(json_string);
shared_ptr<Function> func = ngraph::deserialize(ss);
auto df = autodiff::backprop_function(func);
auto backend = runtime::Backend::create("CPU");
backend->compile(df);
size_t ccg = count_ops_of_type<op::SigmoidBackprop>(df);
ASSERT_EQ(ccg, 1);
}
TEST(core_fusion, sparsity_opt_56x56)
{
Shape win_size_3{1, 1, 3, 3};
......
......@@ -33,9 +33,11 @@
#include "ngraph/op/negative.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/pass/algebraic_simplification.hpp"
#include "ngraph/pass/core_fusion.hpp"
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/reshape_elimination.hpp"
......@@ -55,7 +57,6 @@
#include "ngraph/runtime/cpu/op/lstm.hpp"
#include "ngraph/runtime/cpu/op/matmul_bias.hpp"
#include "ngraph/runtime/cpu/op/rnn.hpp"
#include "ngraph/runtime/cpu/op/sigmoid.hpp"
#include "ngraph/runtime/cpu/op/sigmoid_mul.hpp"
#include "ngraph/runtime/cpu/pass/cpu_concat_inputs.hpp"
#include "ngraph/runtime/cpu/pass/cpu_fusion.hpp"
......@@ -703,97 +704,6 @@ TEST(cpu_fusion, conv_bias_bprop)
ASSERT_EQ(ccg, 1);
}
TEST(cpu_fusion, sigmoid_fprop_fusion)
{
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(
runtime::cpu::pass::CPUFusion::REGULAR_FUSIONS);
const string json_path = file_util::path_join(SERIALIZED_ZOO, "mxnet/Graph_fprop_sigmoid.json");
const string json_string = file_util::read_file_to_string(json_path);
stringstream ss(json_string);
shared_ptr<Function> func = ngraph::deserialize(ss);
pass_manager.run_passes(func);
size_t ccg = count_ops_of_type<op::Sigmoid>(func);
ASSERT_EQ(ccg, 1);
}
TEST(cpu_fusion, sigmoid_n1c1h2w2)
{
auto input = make_shared<op::Parameter>(element::f32, Shape{1, 1, 2, 2});
auto sigmoid_node = make_shared<op::Sigmoid>(input);
auto func = make_shared<Function>(sigmoid_node, op::ParameterVector{input});
auto backend = runtime::Backend::create("CPU");
shared_ptr<runtime::TensorView> a = backend->create_tensor(element::f32, input->get_shape());
shared_ptr<runtime::TensorView> result =
backend->create_tensor(element::f32, input->get_shape());
vector<float> dataA{1.0f, 4.0f, 1.0f, 4.0f};
copy_data(a, dataA);
backend->call(func, {result}, {a});
vector<float> expected{0.73105858f, 0.98201379f, 0.73105858f, 0.98201379f};
ASSERT_TRUE(read_vector<float>(result) == expected);
}
TEST(cpu_fusion, sigmoid_n1c1h4)
{
auto input = make_shared<op::Parameter>(element::f32, Shape{1, 1, 4});
auto sigmoid_node = make_shared<op::Sigmoid>(input);
auto func = make_shared<Function>(sigmoid_node, op::ParameterVector{input});
auto backend = runtime::Backend::create("CPU");
shared_ptr<runtime::TensorView> a = backend->create_tensor(element::f32, input->get_shape());
shared_ptr<runtime::TensorView> result =
backend->create_tensor(element::f32, input->get_shape());
vector<float> dataA{1.0f, 4.0f, 1.0f, 4.0f};
copy_data(a, dataA);
backend->call(func, {result}, {a});
vector<float> expected{0.73105858f, 0.98201379f, 0.73105858f, 0.98201379f};
ASSERT_TRUE(read_vector<float>(result) == expected);
}
TEST(cpu_fusion, sigmoid_bprop_fusion)
{
const string json_path = file_util::path_join(SERIALIZED_ZOO, "mxnet/Graph_fprop_sigmoid.json");
const string json_string = file_util::read_file_to_string(json_path);
stringstream ss(json_string);
shared_ptr<Function> func = ngraph::deserialize(ss);
auto df = autodiff::backprop_function(func);
auto backend = runtime::Backend::create("CPU");
backend->compile(df);
size_t ccg = count_ops_of_type<op::SigmoidBackprop>(df);
ASSERT_EQ(ccg, 1);
}
TEST(cpu_fusion, sigmoid_bprop_n1c1h4)
{
auto input = make_shared<op::Parameter>(element::f32, Shape{1, 1, 4});
auto delta = make_shared<op::Parameter>(element::f32, Shape{1, 1, 4});
auto sigmoid_node = make_shared<op::SigmoidBackprop>(input, delta);
auto func = make_shared<Function>(sigmoid_node, op::ParameterVector{input, delta});
auto backend = runtime::Backend::create("CPU");
shared_ptr<runtime::TensorView> a = backend->create_tensor(element::f32, input->get_shape());
shared_ptr<runtime::TensorView> b = backend->create_tensor(element::f32, delta->get_shape());
shared_ptr<runtime::TensorView> result =
backend->create_tensor(element::f32, input->get_shape());
vector<float> dataA{1.0f, 4.0f, 1.0f, 4.0f};
vector<float> dataB{1.0f, 1.0f, 1.0f, 1.0f};
copy_data(a, dataA);
copy_data(b, dataB);
backend->call(func, {result}, {a, b});
vector<float> expected{0.196612f, 0.0176627f, 0.196612f, 0.0176627f};
EXPECT_TRUE(test::all_close(expected, read_vector<float>(result)));
}
TEST(cpu_fusion, batchnorm_fprop_relu_b1c2h2w2)
{
auto input_shape = Shape{1, 2, 2, 2};
......@@ -2185,6 +2095,7 @@ TEST(cpu_fusion, graph_partition_one_group)
TEST(cpu_fusion, sigmoid_multiply_fusion)
{
pass::Manager pass_manager;
pass_manager.register_pass<pass::CoreFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
const string json_path = file_util::path_join(SERIALIZED_ZOO, "mxnet/3_lstm_cell_forward.json");
const string json_string = file_util::read_file_to_string(json_path);
......