Unverified Commit 5885c09a authored by Pruthvi's avatar Pruthvi Committed by GitHub

Pruthvi/sigmoid (#614)

* - Added sigmoid fusion pass
- added mkldnn emitter code for sigmoid

* - corrected sigmoid expected values
- add layout assignment for sigmoid op

* - added assert's in cpu fusion for sigmoid
- style fix

* remove debug prints

* NGMX-371 #comment addressed PR comments - Added sigmoid unit test case with 3D input ii) support in cpu_emmiter for sigmoid to handle all input shapes

* NGMX-371 #comment use shape_size() to calculate the 1d input size
parent e46184a1
......@@ -186,6 +186,7 @@ if (NGRAPH_CPU_ENABLE AND LLVM_INCLUDE_DIR AND
runtime/cpu/mkldnn_invoke.cpp
runtime/cpu/mkldnn_utils.cpp
runtime/cpu/ops/convert_layout.cpp
runtime/cpu/ops/sigmoid.cpp
runtime/cpu/ops/matmul_bias.cpp
runtime/cpu/pass/cpu_assignment.cpp
runtime/cpu/pass/cpu_fusion.cpp
......
......@@ -92,6 +92,7 @@
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/runtime/cpu/ops/convert_layout.hpp"
#include "ngraph/runtime/cpu/ops/matmul_bias.hpp"
#include "ngraph/runtime/cpu/ops/sigmoid.hpp"
#include "ngraph/types/element_type.hpp"
#include "ngraph/util.hpp"
......@@ -3318,6 +3319,37 @@ namespace ngraph
}
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Sigmoid)
{
auto input_shape = args[0].get_shape();
auto result_shape = out[0].get_shape();
int input_1d_size = static_cast<int>(shape_size(input_shape));
int result_1d_size = static_cast<int>(shape_size(result_shape));
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto input_desc = mkldnn::memory::desc(
{input_1d_size},
mkldnn_utils::get_mkldnn_data_type(args[0].get_element_type()),
mkldnn::memory::format::x);
auto result_desc = mkldnn::memory::desc(
{result_1d_size},
mkldnn_utils::get_mkldnn_data_type(out[0].get_element_type()),
mkldnn::memory::format::x);
size_t sigmoid_index =
mkldnn_emitter->build_sigmoid_forward(input_desc, result_desc);
auto& deps = mkldnn_emitter->get_primitive_deps(sigmoid_index);
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0]) << ", "
<< args[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1]) << ", "
<< out[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
<< to_string(sigmoid_index) << ");\n";
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Softmax)
{
......
......@@ -110,6 +110,7 @@
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/runtime/cpu/ops/convert_layout.hpp"
#include "ngraph/runtime/cpu/ops/matmul_bias.hpp"
#include "ngraph/runtime/cpu/ops/sigmoid.hpp"
#include "ngraph/runtime/cpu/pass/cpu_assignment.hpp"
#include "ngraph/runtime/cpu/pass/cpu_fusion.hpp"
#include "ngraph/runtime/cpu/pass/cpu_layout.hpp"
......@@ -243,6 +244,7 @@ static const runtime::cpu::OpMap dispatcher{
{TI(ngraph::op::Min), &runtime::cpu::CPU_Emitter::emit<op::Min>},
{TI(ngraph::op::Relu), &runtime::cpu::CPU_Emitter::emit<op::Relu>},
{TI(ngraph::op::ReluBackprop), &runtime::cpu::CPU_Emitter::emit<op::ReluBackprop>},
{TI(ngraph::op::Sigmoid), &runtime::cpu::CPU_Emitter::emit<op::Sigmoid>},
{TI(ngraph::op::Softmax), &runtime::cpu::CPU_Emitter::emit<op::Softmax>},
};
......
......@@ -280,6 +280,26 @@ size_t MKLDNNEmitter::build_relu_forward(const mkldnn::memory::desc& input_desc,
return primitive_index;
}
size_t MKLDNNEmitter::build_sigmoid_forward(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc)
{
size_t input_index = build_memory_primitive(input_desc);
size_t result_index = build_memory_primitive(result_desc);
size_t primitive_index =
insert_primitive(new mkldnn::eltwise_forward({{mkldnn::prop_kind::forward_training,
mkldnn::algorithm::eltwise_logistic,
input_desc,
0,
0},
mkldnn_utils::global_cpu_engine},
*m_mkldnn_primitives[input_index],
*m_mkldnn_primitives[result_index]));
m_primitive_deps[primitive_index] = {input_index, result_index};
return primitive_index;
}
size_t MKLDNNEmitter::build_elementwise_add(
const mkldnn::memory::desc& input0_data_desc,
const mkldnn::memory::desc& input1_data_desc,
......
......@@ -97,6 +97,9 @@ namespace ngraph
size_t build_relu_forward(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc);
size_t build_sigmoid_forward(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc);
size_t build_elementwise_add(
const mkldnn::memory::desc& input0_data_desc,
const mkldnn::memory::desc& input1_data_desc,
......
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/runtime/cpu/ops/sigmoid.hpp"
#include "ngraph/log.hpp"
#include "ngraph/util.hpp"
std::shared_ptr<ngraph::Node>
ngraph::op::Sigmoid::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Sigmoid>(new_args.at(0));
}
ngraph::op::Sigmoid::Sigmoid(std::shared_ptr<ngraph::Node> input)
: RequiresTensorViewArgs("Sigmoid", {input})
, m_shape_input(input->get_shape())
{
add_output(input->get_element_type(), m_shape_input);
}
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/ops/util/requires_tensor_view_args.hpp"
namespace ngraph
{
namespace op
{
class Sigmoid : public util::RequiresTensorViewArgs
{
public:
Sigmoid(std::shared_ptr<Node> input);
Shape get_input_shape() const { return m_shape_input; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
private:
Shape m_shape_input;
};
}
}
......@@ -32,6 +32,7 @@
#include "ngraph/ops/relu.hpp"
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/runtime/cpu/ops/sigmoid.hpp"
using namespace std;
using namespace ngraph;
......@@ -209,6 +210,19 @@ namespace ngraph
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::Sigmoid)
{
auto sigmoid = static_cast<op::Sigmoid*>(node);
if (node->get_input_element_type(0) == element::f32)
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
sigmoid->set_op_annotations(op_annotations);
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::ReluBackprop)
{
......@@ -258,6 +272,7 @@ static const runtime::cpu::pass::AssignOpMap s_dispatcher{
{TI(ngraph::op::Relu), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Relu>},
{TI(ngraph::op::ReluBackprop),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ReluBackprop>},
{TI(ngraph::op::Sigmoid), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Sigmoid>},
};
bool runtime::cpu::pass::CPUAssignment::run_on_call_graph(
......
......@@ -30,8 +30,10 @@
#include "ngraph/ops/convolution.hpp"
#include "ngraph/ops/divide.hpp"
#include "ngraph/ops/dot.hpp"
#include "ngraph/ops/exp.hpp"
#include "ngraph/ops/get_output_element.hpp"
#include "ngraph/ops/multiply.hpp"
#include "ngraph/ops/negative.hpp"
#include "ngraph/ops/pad.hpp"
#include "ngraph/ops/parameter.hpp"
#include "ngraph/ops/reshape.hpp"
......@@ -42,6 +44,7 @@
#include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/runtime/cpu/ops/matmul_bias.hpp"
#include "ngraph/runtime/cpu/ops/sigmoid.hpp"
static bool init_cblas_arg(std::shared_ptr<ngraph::Node> reshape,
std::shared_ptr<ngraph::Node> arg,
......@@ -521,3 +524,45 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_conv()
this->add_matcher(std::make_shared<ngraph::pattern::Matcher>(conv_label, callback));
}
void ngraph::runtime::cpu::pass::CPUFusion::construct_sigmoid()
{
//construct variance
auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{3, 4});
auto neg_input = std::make_shared<op::Negative>(input);
auto exp_neg_input = std::make_shared<op::Exp>(neg_input);
// broadcast input
auto constant = std::make_shared<pattern::op::Label>(element::f32, Shape{});
auto broadcast_constant = std::make_shared<op::Broadcast>(constant, Shape{3, 4}, AxisSet{0, 1});
auto add_exp = std::make_shared<op::Add>(exp_neg_input, broadcast_constant);
auto divide_1_over_exp = std::make_shared<op::Divide>(broadcast_constant, add_exp);
//Define a call back that needs to called once the DFG matches the pattern
ngraph::pattern::gr_callback_fn callback =
[input](pattern::Matcher& m) -> std::shared_ptr<Node> {
NGRAPH_DEBUG << "In a callback for construct_fprop_sigmoid pattern against "
<< m.match_root()->get_name();
auto pattern_map = m.get_pattern_map();
if (m.match_root()->get_element_type() != element::f32)
{
NGRAPH_DEBUG << "mpattern = " << m.match_root()->get_name() << " type is not float!";
return nullptr;
}
if (m.match_root()->get_outputs().size() != pattern_map[input]->get_outputs().size())
{
NGRAPH_DEBUG << "mpattern = " << m.match_root()->get_name()
<< "input= " << pattern_map[input]->get_name() << "size dont match!";
return nullptr;
}
auto sigmoid_node = std::make_shared<op::Sigmoid>(pattern_map[input]);
return sigmoid_node;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(divide_1_over_exp, callback);
this->add_matcher(m);
}
......@@ -43,12 +43,14 @@ public:
construct_fprop_bn();
construct_zero_padded_reshaped_conv();
construct_zero_padded_conv();
construct_sigmoid();
}
private:
void construct_matmul_pattern();
void construct_matmulbias_pattern();
void construct_fprop_bn();
void construct_sigmoid();
void construct_zero_padded_reshaped_conv();
void construct_zero_padded_conv();
};
......@@ -38,6 +38,7 @@
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/runtime/cpu/ops/convert_layout.hpp"
#include "ngraph/runtime/cpu/ops/sigmoid.hpp"
using namespace std;
using namespace mkldnn;
......@@ -670,6 +671,23 @@ namespace ngraph
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::Sigmoid)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
{
auto input_layout =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node.get(), 0);
vector<memory::format> prim_output_formats;
prim_output_formats.push_back(input_layout);
set_output_layouts(node, prim_output_formats);
}
else
{
set_default_layouts(external_function, node);
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::ReluBackprop)
{
......@@ -765,6 +783,7 @@ static const runtime::cpu::pass::LayoutOpMap s_dispatcher{
{TI(ngraph::op::Result), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Result>},
{TI(ngraph::op::ReluBackprop),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::ReluBackprop>},
{TI(ngraph::op::Sigmoid), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Sigmoid>},
};
bool runtime::cpu::pass::CPULayout::run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes)
......
......@@ -37,6 +37,7 @@
#include "ngraph/pass/reshape_elimination.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/runtime/cpu/ops/matmul_bias.hpp"
#include "ngraph/runtime/cpu/ops/sigmoid.hpp"
#include "ngraph/runtime/cpu/pass/cpu_fusion.hpp"
#include "ngraph/serializer.hpp"
#include "ngraph/util.hpp"
......@@ -671,3 +672,64 @@ TEST(cpu_fusion, non_zero_padded_conv)
ASSERT_EQ(count_ops_of_type<op::Pad>(func), 1);
}
TEST(cpu_fusion, sigmoid_fprop_fusion)
{
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
const string json_path = file_util::path_join(SERIALIZED_ZOO, "mxnet/Graph_fprop_sigmoid.json");
const string json_string = file_util::read_file_to_string(json_path);
stringstream ss(json_string);
shared_ptr<Function> func = ngraph::deserialize(ss);
pass_manager.run_passes(func);
size_t ccg = count_ops_of_type<op::Sigmoid>(func);
ASSERT_EQ(ccg, 1);
}
TEST(cpu_fusion, sigmoid_n1c1h2w2)
{
auto input = make_shared<op::Parameter>(element::f32, Shape{1, 1, 2, 2});
auto sigmoid_node = make_shared<op::Sigmoid>(input);
auto func = make_shared<Function>(sigmoid_node, op::ParameterVector{input});
auto manager = runtime::Manager::get("CPU");
auto external = manager->compile(func);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
shared_ptr<runtime::TensorView> a =
backend->make_primary_tensor_view(element::f32, input->get_shape());
shared_ptr<runtime::TensorView> result =
backend->make_primary_tensor_view(element::f32, input->get_shape());
vector<float> dataA{1.0f, 4.0f, 1.0f, 4.0f};
copy_data(a, dataA);
cf->call({a}, {result});
vector<float> expected{0.73105858f, 0.98201379f, 0.73105858f, 0.98201379f};
ASSERT_TRUE(read_vector<float>(result) == expected);
}
TEST(cpu_fusion, sigmoid_n1c1h4)
{
auto input = make_shared<op::Parameter>(element::f32, Shape{1, 1, 4});
auto sigmoid_node = make_shared<op::Sigmoid>(input);
auto func = make_shared<Function>(sigmoid_node, op::ParameterVector{input});
auto manager = runtime::Manager::get("CPU");
auto external = manager->compile(func);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
shared_ptr<runtime::TensorView> a =
backend->make_primary_tensor_view(element::f32, input->get_shape());
shared_ptr<runtime::TensorView> result =
backend->make_primary_tensor_view(element::f32, input->get_shape());
vector<float> dataA{1.0f, 4.0f, 1.0f, 4.0f};
copy_data(a, dataA);
cf->call({a}, {result});
vector<float> expected{0.73105858f, 0.98201379f, 0.73105858f, 0.98201379f};
ASSERT_TRUE(read_vector<float>(result) == expected);
}
[{"name":"Function_0","ops":[{"element_type":"float","inputs":[],"name":"Parameter_0","op":"Parameter","outputs":["Parameter_0_0"],"shape":[3,4]},{"element_type":"float","inputs":[],"name":"Constant_1","op":"Constant","outputs":["Constant_1_0"],"shape":[],"value":["1"]},{"inputs":["Parameter_0"],"name":"Negative_3","op":"Negative","outputs":["Negative_3_0"]},{"axes":[0,1],"inputs":["Constant_1"],"name":"Broadcast_2","op":"Broadcast","outputs":["Broadcast_2_0"],"shape":[3,4]},{"inputs":["Negative_3"],"name":"Exp_4","op":"Exp","outputs":["Exp_4_0"]},{"inputs":["Broadcast_2","Exp_4"],"name":"Add_5","op":"Add","outputs":["Add_5_0"]},{"inputs":["Broadcast_2","Add_5"],"name":"Divide_6","op":"Divide","outputs":["Divide_6_0"]}],"parameters":["Parameter_0"],"result":["Divide_6"]}]
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment