Unverified commit 6667bc0e authored by Fenglei, committed by GitHub

Merge branch 'master' into tfl/gpu_dot_back

parents eec717c0 34a8b27d
@@ -58,8 +58,6 @@ namespace ngraph
{
}
static void emit_mkldnn_preamble(codegen::CodeWriter& writer);
private:
static std::string emit_vector(const TensorViewWrapper&,
const std::string& name = "");
......
@@ -293,18 +293,6 @@ void runtime::cpu::CPU_ExternalFunction::compile()
codegen::CodeWriter writer;
bool include_mkldnn_headers = false;
for (shared_ptr<Function> current_function : pass_manager.get_state().get_functions())
{
for (shared_ptr<Node> node : current_function->get_ordered_ops())
{
if (ngraph::runtime::cpu::mkldnn_utils::IsMKLDNNOp(*node))
{
include_mkldnn_headers = true;
}
}
}
writer +=
R"(// Generated by the nGraph CPU backend
#include <cmath>
@@ -354,11 +342,6 @@ using namespace ngraph::runtime;
writer << "#include <tbb/flow_graph.h>\n";
}
if (include_mkldnn_headers)
{
runtime::cpu::CPU_Emitter::emit_mkldnn_preamble(writer);
}
string pch_header_source = writer.get_code();
// The "dso_handle" symbol is required by __cxa_atexit()
......
@@ -23,6 +23,7 @@
#include "ngraph/runtime/cpu/cpu_tensor_view_wrapper.hpp"
#include "ngraph/runtime/cpu/mkldnn_invoke.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/types/element_type.hpp"
using namespace ngraph::runtime::cpu;
@@ -77,6 +78,15 @@ mkldnn::memory::desc MKLDNNEmitter::build_memory_descriptor(const TensorViewWrap
return build_memory_descriptor(tvw, layout->get_mkldnn_format());
}
mkldnn::memory::desc MKLDNNEmitter::build_memory_descriptor(const Shape& shape,
const ngraph::element::Type& et,
mkldnn::memory::format fmt) const
{
return mkldnn::memory::desc(mkldnn::memory::dims(shape.begin(), shape.end()),
mkldnn_utils::get_mkldnn_data_type(et),
fmt);
}
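
Aside (not part of the diff): the new overload is a thin wrapper over the MKL-DNN v0.x memory::desc constructor, converting the nGraph Shape into mkldnn::memory::dims. A minimal standalone sketch, assuming MKL-DNN v0.x headers; the shape and format here are illustrative:

#include <mkldnn.hpp>

int main()
{
    // Equivalent of build_memory_descriptor(Shape{2, 3, 4, 5}, element::f32,
    // memory::format::nchw): the Shape's begin/end iterators feed memory::dims,
    // and get_mkldnn_data_type maps element::f32 to data_type::f32.
    mkldnn::memory::dims dims{2, 3, 4, 5};
    mkldnn::memory::desc md(
        dims, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
    return 0;
}
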
mkldnn::memory MKLDNNEmitter::build_memory_primitive(const TensorViewWrapper& tvw) const
{
return mkldnn::memory({build_memory_descriptor(tvw), mkldnn_utils::global_cpu_engine}, nullptr);
@@ -462,6 +472,27 @@ size_t MKLDNNEmitter::build_relu_forward(const mkldnn::memory::desc& input_desc,
return primitive_index;
}
size_t MKLDNNEmitter::build_relu_backward(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& delta_desc,
const mkldnn::memory::desc& result_desc)
{
size_t input_index = build_memory_primitive(input_desc);
size_t delta_index = build_memory_primitive(delta_desc);
size_t result_index = build_memory_primitive(result_desc);
size_t primitive_index = insert_primitive(new mkldnn::relu_backward(
{{mkldnn::algorithm::eltwise_relu, delta_desc, input_desc, 0, 0},
mkldnn_utils::global_cpu_engine,
{{mkldnn::prop_kind::forward, mkldnn::algorithm::eltwise_relu, input_desc, 0, 0},
mkldnn_utils::global_cpu_engine}},
*m_mkldnn_primitives[input_index],
*m_mkldnn_primitives[delta_index],
*m_mkldnn_primitives[result_index]));
m_primitive_deps[primitive_index] = {input_index, delta_index, result_index};
return primitive_index;
}
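
Aside (not part of the diff): MKL-DNN builds backward eltwise primitives from a backward descriptor plus a forward primitive descriptor that serves as a hint, which is what the nested brace-initializers above spell out. A standalone sketch of the same pattern, assuming MKL-DNN v0.x; the shape and format are illustrative, and the null data pointers stand in for buffers bound later at run time:

#include <mkldnn.hpp>

int main()
{
    mkldnn::engine eng(mkldnn::engine::cpu, 0);
    mkldnn::memory::desc md(
        {8, 16, 4, 4}, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);

    // Forward primitive descriptor, used only as a hint for the backward pass.
    mkldnn::relu_forward::primitive_desc fwd_pd(
        {mkldnn::prop_kind::forward, mkldnn::algorithm::eltwise_relu, md, 0, 0}, eng);
    mkldnn::relu_backward::primitive_desc bwd_pd(
        {mkldnn::algorithm::eltwise_relu, md, md, 0, 0}, eng, fwd_pd);

    mkldnn::memory src({md, eng}, nullptr);    // forward input x
    mkldnn::memory delta({md, eng}, nullptr);  // incoming gradient dy
    mkldnn::memory result({md, eng}, nullptr); // computed gradient dx
    mkldnn::relu_backward bwd(bwd_pd, src, delta, result);
    return 0;
}
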
size_t MKLDNNEmitter::build_sigmoid_forward(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc)
{
@@ -509,3 +540,80 @@ size_t MKLDNNEmitter::build_elementwise_add(
m_primitive_deps[add_index] = {input0_data_index, input1_data_index, result_index};
return add_index;
}
size_t MKLDNNEmitter::build_batchnorm_forward(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& weights_desc,
const mkldnn::memory::desc& result_desc,
const mkldnn::memory::desc& mean_desc,
const mkldnn::memory::desc& variance_desc,
const double eps)
{
size_t input_index = build_memory_primitive(input_desc);
size_t weights_index = build_memory_primitive(weights_desc);
size_t result_index = build_memory_primitive(result_desc);
size_t mean_index = build_memory_primitive(mean_desc);
size_t variance_index = build_memory_primitive(variance_desc);
size_t batchnorm_index = insert_primitive(new mkldnn::batch_normalization_forward(
{{mkldnn::prop_kind::forward_training,
input_desc,
eps,
mkldnn::batch_normalization_flag::use_scale_shift},
mkldnn_utils::global_cpu_engine},
mkldnn::primitive::at(*m_mkldnn_primitives[input_index]),
mkldnn::primitive::at(*m_mkldnn_primitives[weights_index]),
static_cast<mkldnn::memory>(*m_mkldnn_primitives[result_index]),
*m_mkldnn_primitives[mean_index],
*m_mkldnn_primitives[variance_index]));
m_primitive_deps[batchnorm_index] = {
input_index, weights_index, result_index, mean_index, variance_index};
return batchnorm_index;
}
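
Aside (not part of the diff): with use_scale_shift, MKL-DNN expects gamma and beta packed into a single 2xC "weights" tensor (row 0 = scale, row 1 = shift), which is what weights_desc describes here. A sketch of that packing, assuming MKL-DNN v0.x; the channel count is illustrative:

#include <mkldnn.hpp>
#include <vector>

int main()
{
    const int C = 16;
    mkldnn::engine eng(mkldnn::engine::cpu, 0);

    std::vector<float> scale_shift(2 * C);
    for (int c = 0; c < C; ++c)
    {
        scale_shift[c] = 1.0f;     // row 0: gamma (scale)
        scale_shift[C + c] = 0.0f; // row 1: beta (shift)
    }

    mkldnn::memory::desc w_md(
        {2, C}, mkldnn::memory::data_type::f32, mkldnn::memory::format::nc);
    mkldnn::memory weights({w_md, eng}, scale_shift.data());
    return 0;
}
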
size_t MKLDNNEmitter::build_batchnorm_backward(const mkldnn::memory::desc& weights_desc,
const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& mean_desc,
const mkldnn::memory::desc& variance_desc,
const mkldnn::memory::desc& delta_desc,
const mkldnn::memory::desc& dinput_desc,
const mkldnn::memory::desc& dweights_desc,
const double eps)
{
size_t weights_index = build_memory_primitive(weights_desc);
size_t input_index = build_memory_primitive(input_desc);
size_t mean_index = build_memory_primitive(mean_desc);
size_t variance_index = build_memory_primitive(variance_desc);
size_t delta_index = build_memory_primitive(delta_desc);
size_t dinput_index = build_memory_primitive(dinput_desc);
size_t dweights_index = build_memory_primitive(dweights_desc);
size_t batchnorm_index = insert_primitive(new mkldnn::batch_normalization_backward(
{{mkldnn::prop_kind::backward,
delta_desc,
input_desc,
eps,
mkldnn::batch_normalization_flag::use_scale_shift},
mkldnn_utils::global_cpu_engine,
{{mkldnn::prop_kind::forward_training,
input_desc,
eps,
mkldnn::batch_normalization_flag::use_scale_shift},
mkldnn_utils::global_cpu_engine}},
*m_mkldnn_primitives[input_index],
*m_mkldnn_primitives[mean_index],
*m_mkldnn_primitives[variance_index],
*m_mkldnn_primitives[delta_index],
*m_mkldnn_primitives[weights_index],
*m_mkldnn_primitives[dinput_index],
*m_mkldnn_primitives[dweights_index]));
m_primitive_deps[batchnorm_index] = {weights_index,
input_index,
mean_index,
variance_index,
delta_index,
dinput_index,
dweights_index};
return batchnorm_index;
}
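
Aside (not part of the diff): as with relu_backward, the backward batch-norm primitive descriptor is built from a backward descriptor plus a forward primitive descriptor acting as a hint; the diff_weights output then carries dgamma and dbeta in the same packed 2xC layout as the forward weights. A sketch of the descriptor plumbing, assuming MKL-DNN v0.x; shape and epsilon are illustrative:

#include <mkldnn.hpp>

int main()
{
    mkldnn::engine eng(mkldnn::engine::cpu, 0);
    const double eps = 1e-5;
    mkldnn::memory::desc data_md(
        {8, 16, 4, 4}, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);

    mkldnn::batch_normalization_forward::primitive_desc fwd_pd(
        {mkldnn::prop_kind::forward_training,
         data_md,
         eps,
         mkldnn::batch_normalization_flag::use_scale_shift},
        eng);
    mkldnn::batch_normalization_backward::primitive_desc bwd_pd(
        {mkldnn::prop_kind::backward,
         data_md, // delta (diff) descriptor
         data_md, // input descriptor
         eps,
         mkldnn::batch_normalization_flag::use_scale_shift},
        eng,
        fwd_pd);
    return 0;
}
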
@@ -25,6 +25,7 @@
#include "ngraph/coordinate_diff.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/strides.hpp"
#include "ngraph/types/element_type.hpp"
namespace ngraph
{
@@ -60,6 +61,9 @@ namespace ngraph
mkldnn::memory::desc build_memory_descriptor(const TensorViewWrapper& tvw,
mkldnn::memory::format fmt) const;
mkldnn::memory::desc build_memory_descriptor(const TensorViewWrapper& tvw) const;
mkldnn::memory::desc build_memory_descriptor(const Shape& shape,
const ngraph::element::Type& et,
mkldnn::memory::format fmt) const;
mkldnn::memory build_memory_primitive(const TensorViewWrapper& tvw) const;
size_t build_memory_primitive(const mkldnn::memory::desc& desc);
@@ -142,6 +146,10 @@ namespace ngraph
size_t build_relu_forward(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc);
size_t build_relu_backward(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& delta_desc,
const mkldnn::memory::desc& result_desc);
size_t build_sigmoid_forward(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc);
@@ -152,6 +160,22 @@ namespace ngraph
const std::vector<float>& scale_vector,
const std::vector<mkldnn::memory::primitive_desc>& input_pd);
size_t build_batchnorm_forward(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& weights_desc,
const mkldnn::memory::desc& result_desc,
const mkldnn::memory::desc& mean_desc,
const mkldnn::memory::desc& variance_desc,
const double eps);
size_t build_batchnorm_backward(const mkldnn::memory::desc& weights_desc,
const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& mean_desc,
const mkldnn::memory::desc& variance_desc,
const mkldnn::memory::desc& delta_desc,
const mkldnn::memory::desc& dinput_desc,
const mkldnn::memory::desc& dweights_desc,
const double eps);
private:
std::vector<mkldnn::primitive*> m_mkldnn_primitives;
std::vector<mkldnn::stream> m_mkldnn_streams;
......
@@ -344,6 +344,16 @@ namespace ngraph
op_annotations->set_mkldnn_op(true);
batchnorm->set_op_annotations(op_annotations);
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::BatchNormBackprop)
{
auto batchnorm = static_cast<op::BatchNormBackprop*>(node);
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
batchnorm->set_op_annotations(op_annotations);
}
}
}
}
@@ -357,6 +367,8 @@ static const runtime::cpu::pass::AssignOpMap s_dispatcher{
{TI(ngraph::op::AvgPoolBackprop),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::AvgPoolBackprop>},
{TI(ngraph::op::BatchNorm), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::BatchNorm>},
{TI(ngraph::op::BatchNormBackprop),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::BatchNormBackprop>},
{TI(ngraph::op::Convolution),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Convolution>},
{TI(ngraph::op::ConvolutionBackpropData),
......
@@ -1009,6 +1009,36 @@ namespace ngraph
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::BatchNormBackprop)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
{
auto delta_layout =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node.get(), 5);
vector<memory::format> prim_input_formats;
vector<memory::format> prim_output_formats;
prim_input_formats.push_back(memory::format::x); // gamma
prim_input_formats.push_back(memory::format::x); // beta
prim_input_formats.push_back(delta_layout); // input
prim_input_formats.push_back(memory::format::x); // mean
prim_input_formats.push_back(memory::format::x); // variance
prim_input_formats.push_back(delta_layout); // delta
prim_output_formats.push_back(delta_layout); // dinput
prim_output_formats.push_back(memory::format::x); // dgamma
prim_output_formats.push_back(memory::format::x); // dbeta
node =
insert_input_conversions(external_function, node, prim_input_formats);
set_output_layouts(node, prim_output_formats);
}
else
{
throw ngraph_error("Batchnorm Backprop only supported in MKLDNN for now");
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::Add)
{
@@ -1056,6 +1086,8 @@ static const runtime::cpu::pass::LayoutOpMap s_dispatcher{
{TI(ngraph::op::ConvolutionBiasBackpropFiltersBias),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::ConvolutionBiasBackpropFiltersBias>},
{TI(ngraph::op::BatchNorm), &runtime::cpu::pass::CPULayout::layout<ngraph::op::BatchNorm>},
{TI(ngraph::op::BatchNormBackprop),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::BatchNormBackprop>},
{TI(ngraph::op::GetOutputElement),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::GetOutputElement>},
{TI(ngraph::op::Relu), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Relu>},
......