Unverified Commit cf770aa5 authored by Jayaram Bobba's avatar Jayaram Bobba Committed by GitHub

Merge branch 'master' into jbobba/batchnorm-inference

parents 7ce15121 8520e846
...@@ -146,6 +146,7 @@ void runtime::cpu::CPU_CallFrame::setup_runtime_context() ...@@ -146,6 +146,7 @@ void runtime::cpu::CPU_CallFrame::setup_runtime_context()
} }
const auto& mkldnn_emitter = m_external_function->get_mkldnn_emitter(); const auto& mkldnn_emitter = m_external_function->get_mkldnn_emitter();
ctx->mkldnn_primitives = mkldnn_emitter->get_mkldnn_primitives().data(); ctx->mkldnn_primitives = mkldnn_emitter->get_mkldnn_primitives().data();
ctx->mkldnn_workspaces = mkldnn_emitter->get_mkldnn_workspaces().data();
} }
void runtime::cpu::CPU_CallFrame::cleanup_runtime_context() void runtime::cpu::CPU_CallFrame::cleanup_runtime_context()
......
This diff is collapsed.
...@@ -37,6 +37,7 @@ namespace ngraph ...@@ -37,6 +37,7 @@ namespace ngraph
{ {
int64_t* op_durations; int64_t* op_durations;
mkldnn::primitive* const* mkldnn_primitives; mkldnn::primitive* const* mkldnn_primitives;
char* const* mkldnn_workspaces;
}; };
} }
} }
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp" #include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp"
#include "ngraph/runtime/cpu/cpu_tensor_view_wrapper.hpp" #include "ngraph/runtime/cpu/cpu_tensor_view_wrapper.hpp"
#include "ngraph/runtime/cpu/mkldnn_invoke.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp" #include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/types/element_type.hpp" #include "ngraph/types/element_type.hpp"
...@@ -37,12 +38,24 @@ const std::vector<mkldnn::primitive*>& MKLDNNEmitter::get_mkldnn_primitives() co ...@@ -37,12 +38,24 @@ const std::vector<mkldnn::primitive*>& MKLDNNEmitter::get_mkldnn_primitives() co
return m_mkldnn_primitives; return m_mkldnn_primitives;
} }
const std::vector<char*>& MKLDNNEmitter::get_mkldnn_workspaces()
{
return m_workspace_bufs;
}
size_t MKLDNNEmitter::insert_primitive(mkldnn::primitive* primitive) size_t MKLDNNEmitter::insert_primitive(mkldnn::primitive* primitive)
{ {
m_mkldnn_primitives.emplace_back(primitive); m_mkldnn_primitives.emplace_back(primitive);
return (m_mkldnn_primitives.size() - 1); return (m_mkldnn_primitives.size() - 1);
} }
size_t MKLDNNEmitter::insert_workspace(std::unique_ptr<MKLDNNWorkspace>& workspace)
{
m_workspace_bufs.push_back(workspace.get()->buf);
m_workspaces.push_back(std::move(workspace));
return (m_workspaces.size() - 1);
}
const std::vector<size_t>& MKLDNNEmitter::get_primitive_deps(size_t index) const const std::vector<size_t>& MKLDNNEmitter::get_primitive_deps(size_t index) const
{ {
return m_primitive_deps.at(index); return m_primitive_deps.at(index);
...@@ -331,6 +344,105 @@ size_t MKLDNNEmitter::build_pooling_forward(mkldnn::algorithm pooling_algorithm, ...@@ -331,6 +344,105 @@ size_t MKLDNNEmitter::build_pooling_forward(mkldnn::algorithm pooling_algorithm,
return primitive_index; return primitive_index;
} }
size_t MKLDNNEmitter::build_pooling_backward(mkldnn::algorithm pooling_algorithm,
const mkldnn::memory::desc& diff_dst_desc,
const mkldnn::memory::desc& diff_src_desc,
const ngraph::Strides& window_strides,
const ngraph::Shape& window_shape,
const ngraph::Shape& padding_below,
const ngraph::Shape& padding_above)
{
size_t input_index = build_memory_primitive(diff_dst_desc);
size_t result_index = build_memory_primitive(diff_src_desc);
size_t primitive_index = insert_primitive(new mkldnn::pooling_backward(
{{pooling_algorithm,
diff_src_desc,
diff_dst_desc,
mkldnn::memory::dims(window_strides.begin(), window_strides.end()),
mkldnn::memory::dims(window_shape.begin(), window_shape.end()),
mkldnn::memory::dims(padding_below.begin(), padding_below.end()),
mkldnn::memory::dims(padding_above.begin(), padding_above.end()),
mkldnn::padding_kind::zero},
mkldnn_utils::global_cpu_engine,
{{mkldnn::prop_kind::forward_training,
pooling_algorithm,
diff_src_desc,
diff_dst_desc,
mkldnn::memory::dims(window_strides.begin(), window_strides.end()),
mkldnn::memory::dims(window_shape.begin(), window_shape.end()),
mkldnn::memory::dims(padding_below.begin(), padding_below.end()),
mkldnn::memory::dims(padding_above.begin(), padding_above.end()),
mkldnn::padding_kind::zero},
mkldnn_utils::global_cpu_engine}},
*m_mkldnn_primitives[input_index],
*m_mkldnn_primitives[result_index]));
m_primitive_deps[primitive_index] = {input_index, result_index};
return primitive_index;
}
size_t MKLDNNEmitter::build_max_pooling_backward(mkldnn::algorithm pooling_algorithm,
const mkldnn::memory::desc& fprop_src_desc,
const mkldnn::memory::desc& diff_dst_desc,
const mkldnn::memory::desc& diff_src_desc,
const ngraph::Strides& window_strides,
const ngraph::Shape& window_shape,
const ngraph::Shape& padding_below,
const ngraph::Shape& padding_above)
{
size_t fprop_src_index = build_memory_primitive(fprop_src_desc);
size_t diff_dst_index = build_memory_primitive(diff_dst_desc);
size_t diff_src_index = build_memory_primitive(diff_src_desc);
mkldnn::pooling_forward::primitive_desc fwd_pd{
{mkldnn::prop_kind::forward_training,
pooling_algorithm,
diff_src_desc,
diff_dst_desc,
mkldnn::memory::dims(window_strides.begin(), window_strides.end()),
mkldnn::memory::dims(window_shape.begin(), window_shape.end()),
mkldnn::memory::dims(padding_below.begin(), padding_below.end()),
mkldnn::memory::dims(padding_above.begin(), padding_above.end()),
mkldnn::padding_kind::zero},
mkldnn_utils::global_cpu_engine};
auto ws_index = build_memory_primitive(fwd_pd.workspace_primitive_desc().desc());
// Allocate workspace
// TODO (jbobba): Might need to align memory
auto ws = std::unique_ptr<MKLDNNWorkspace>(
new MKLDNNWorkspace(fwd_pd.workspace_primitive_desc().get_size()));
auto ws_buf_index = insert_workspace(ws);
size_t fwd_primitive_index = insert_primitive(new mkldnn::pooling_forward(
fwd_pd,
*m_mkldnn_primitives[fprop_src_index],
*m_mkldnn_primitives
[diff_src_index], // HACK - Uses diff_src buffer. Safe since diff_src > fprop_dst
*m_mkldnn_primitives[ws_index]));
size_t bwd_primitive_index = insert_primitive(new mkldnn::pooling_backward(
{{pooling_algorithm,
diff_src_desc,
diff_dst_desc,
mkldnn::memory::dims(window_strides.begin(), window_strides.end()),
mkldnn::memory::dims(window_shape.begin(), window_shape.end()),
mkldnn::memory::dims(padding_below.begin(), padding_below.end()),
mkldnn::memory::dims(padding_above.begin(), padding_above.end()),
mkldnn::padding_kind::zero},
mkldnn_utils::global_cpu_engine,
fwd_pd},
*m_mkldnn_primitives[diff_dst_index],
*m_mkldnn_primitives[ws_index],
*m_mkldnn_primitives[diff_src_index]));
m_primitive_deps[fwd_primitive_index] = {
fprop_src_index, diff_src_index, ws_index, ws_buf_index};
m_primitive_deps[bwd_primitive_index] = {
diff_dst_index, ws_index, diff_src_index, ws_buf_index};
return bwd_primitive_index;
}
size_t MKLDNNEmitter::build_reorder(const mkldnn::memory::desc& input_desc, size_t MKLDNNEmitter::build_reorder(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc) const mkldnn::memory::desc& result_desc)
{ {
......
...@@ -36,6 +36,14 @@ namespace ngraph ...@@ -36,6 +36,14 @@ namespace ngraph
class CPU_ExternalFunction; class CPU_ExternalFunction;
class TensorViewWrapper; class TensorViewWrapper;
class MKLDNNWorkspace
{
public:
MKLDNNWorkspace(size_t size) { buf = reinterpret_cast<char*>(malloc(size)); }
~MKLDNNWorkspace() { free(buf); }
char* buf;
};
class MKLDNNEmitter class MKLDNNEmitter
{ {
public: public:
...@@ -43,8 +51,10 @@ namespace ngraph ...@@ -43,8 +51,10 @@ namespace ngraph
~MKLDNNEmitter(); ~MKLDNNEmitter();
const std::vector<mkldnn::primitive*>& get_mkldnn_primitives() const; const std::vector<mkldnn::primitive*>& get_mkldnn_primitives() const;
const std::vector<char*>& get_mkldnn_workspaces();
size_t insert_primitive(mkldnn::primitive* primitive); size_t insert_primitive(mkldnn::primitive* primitive);
size_t insert_workspace(std::unique_ptr<MKLDNNWorkspace>& workspace);
const std::vector<size_t>& get_primitive_deps(size_t index) const; const std::vector<size_t>& get_primitive_deps(size_t index) const;
// TODO(jmenon): Get rid of TensorViewWrappers at some point // TODO(jmenon): Get rid of TensorViewWrappers at some point
...@@ -113,6 +123,23 @@ namespace ngraph ...@@ -113,6 +123,23 @@ namespace ngraph
const ngraph::Shape& padding_below, const ngraph::Shape& padding_below,
const ngraph::Shape& padding_above); const ngraph::Shape& padding_above);
size_t build_pooling_backward(mkldnn::algorithm pooling_algorithm,
const mkldnn::memory::desc& diff_dst_desc,
const mkldnn::memory::desc& diff_src_desc,
const ngraph::Strides& window_strides,
const ngraph::Shape& window_shape,
const ngraph::Shape& padding_below,
const ngraph::Shape& padding_above);
size_t build_max_pooling_backward(mkldnn::algorithm pooling_algorithm,
const mkldnn::memory::desc& fprop_src_desc,
const mkldnn::memory::desc& diff_dst_desc,
const mkldnn::memory::desc& diff_src_desc,
const ngraph::Strides& window_strides,
const ngraph::Shape& window_shape,
const ngraph::Shape& padding_below,
const ngraph::Shape& padding_above);
size_t build_reorder(const mkldnn::memory::desc& input_desc, size_t build_reorder(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc); const mkldnn::memory::desc& result_desc);
...@@ -153,6 +180,8 @@ namespace ngraph ...@@ -153,6 +180,8 @@ namespace ngraph
std::vector<mkldnn::primitive*> m_mkldnn_primitives; std::vector<mkldnn::primitive*> m_mkldnn_primitives;
std::vector<mkldnn::stream> m_mkldnn_streams; std::vector<mkldnn::stream> m_mkldnn_streams;
std::unordered_map<size_t, std::vector<size_t>> m_primitive_deps; std::unordered_map<size_t, std::vector<size_t>> m_primitive_deps;
std::vector<std::unique_ptr<MKLDNNWorkspace>> m_workspaces;
std::vector<char*> m_workspace_bufs;
}; };
} }
} }
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include "ngraph/ops/avg_pool.hpp" #include "ngraph/ops/avg_pool.hpp"
#include "ngraph/ops/batch_norm.hpp" #include "ngraph/ops/batch_norm.hpp"
#include "ngraph/ops/convolution.hpp" #include "ngraph/ops/convolution.hpp"
#include "ngraph/ops/max_pool.hpp"
#include "ngraph/ops/relu.hpp" #include "ngraph/ops/relu.hpp"
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp" #include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp" #include "ngraph/runtime/cpu/mkldnn_utils.hpp"
...@@ -245,10 +246,48 @@ namespace ngraph ...@@ -245,10 +246,48 @@ namespace ngraph
} }
} }
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::MaxPool)
{
auto max_pool = static_cast<op::MaxPool*>(node);
auto arg0_shape = node->get_input_shape(0);
auto arg0_rank = arg0_shape.size();
auto result_shape = node->get_output_shape(0);
if (arg0_rank == 4 && max_pool->get_window_shape().size() == 2 &&
node->get_input_element_type(0) == element::f32)
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
max_pool->set_op_annotations(op_annotations);
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::MaxPoolBackprop)
{
auto max_pool = static_cast<op::MaxPoolBackprop*>(node);
auto arg1_shape = node->get_input_shape(1);
auto arg1_rank = arg1_shape.size();
auto result_shape = node->get_output_shape(0);
if (arg1_rank == 4 && max_pool->get_window_shape().size() == 2 &&
node->get_input_element_type(1) == element::f32)
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
max_pool->set_op_annotations(op_annotations);
}
}
template <> template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::Relu) void CPUAssignment::ASSIGN_DECL(ngraph::op::Relu)
{ {
auto avg_pool = static_cast<op::Relu*>(node); auto relu = static_cast<op::Relu*>(node);
auto arg0_shape = node->get_input_shape(0); auto arg0_shape = node->get_input_shape(0);
auto arg0_rank = arg0_shape.size(); auto arg0_rank = arg0_shape.size();
...@@ -260,7 +299,7 @@ namespace ngraph ...@@ -260,7 +299,7 @@ namespace ngraph
auto op_annotations = auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>(); std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true); op_annotations->set_mkldnn_op(true);
avg_pool->set_op_annotations(op_annotations); relu->set_op_annotations(op_annotations);
} }
} }
...@@ -280,18 +319,19 @@ namespace ngraph ...@@ -280,18 +319,19 @@ namespace ngraph
template <> template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::ReluBackprop) void CPUAssignment::ASSIGN_DECL(ngraph::op::ReluBackprop)
{ {
auto avg_pool = static_cast<op::ReluBackprop*>(node); auto relu_bprop = static_cast<op::ReluBackprop*>(node);
auto arg0_shape = node->get_input_shape(0); auto arg0_shape = node->get_input_shape(0);
auto arg0_rank = arg0_shape.size(); auto arg0_rank = arg0_shape.size();
auto result_shape = node->get_output_shape(0); auto result_shape = node->get_output_shape(0);
if (arg0_rank == 4 && node->get_input_element_type(0) == element::f32) if ((arg0_rank == 4 || arg0_rank == 2) &&
node->get_input_element_type(0) == element::f32)
{ {
auto op_annotations = auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>(); std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true); op_annotations->set_mkldnn_op(true);
avg_pool->set_op_annotations(op_annotations); relu_bprop->set_op_annotations(op_annotations);
} }
} }
...@@ -323,6 +363,9 @@ namespace ngraph ...@@ -323,6 +363,9 @@ namespace ngraph
static const runtime::cpu::pass::AssignOpMap s_dispatcher{ static const runtime::cpu::pass::AssignOpMap s_dispatcher{
{TI(ngraph::op::Add), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Add>}, {TI(ngraph::op::Add), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Add>},
{TI(ngraph::op::AvgPool), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::AvgPool>},
{TI(ngraph::op::AvgPoolBackprop),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::AvgPoolBackprop>},
{TI(ngraph::op::BatchNorm), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::BatchNorm>}, {TI(ngraph::op::BatchNorm), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::BatchNorm>},
{TI(ngraph::op::BatchNormBackprop), {TI(ngraph::op::BatchNormBackprop),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::BatchNormBackprop>}, &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::BatchNormBackprop>},
...@@ -332,13 +375,13 @@ static const runtime::cpu::pass::AssignOpMap s_dispatcher{ ...@@ -332,13 +375,13 @@ static const runtime::cpu::pass::AssignOpMap s_dispatcher{
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ConvolutionBackpropData>}, &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ConvolutionBackpropData>},
{TI(ngraph::op::ConvolutionBackpropFilters), {TI(ngraph::op::ConvolutionBackpropFilters),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ConvolutionBackpropFilters>}, &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ConvolutionBackpropFilters>},
{TI(ngraph::op::MaxPool), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::MaxPool>},
{TI(ngraph::op::MaxPoolBackprop),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::MaxPoolBackprop>},
{TI(ngraph::op::ConvolutionBias), {TI(ngraph::op::ConvolutionBias),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ConvolutionBias>}, &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ConvolutionBias>},
{TI(ngraph::op::ConvolutionBiasBackpropFiltersBias), {TI(ngraph::op::ConvolutionBiasBackpropFiltersBias),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ConvolutionBiasBackpropFiltersBias>}, &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ConvolutionBiasBackpropFiltersBias>},
{TI(ngraph::op::AvgPool), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::AvgPool>},
{TI(ngraph::op::AvgPoolBackprop),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::AvgPoolBackprop>},
{TI(ngraph::op::Relu), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Relu>}, {TI(ngraph::op::Relu), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Relu>},
{TI(ngraph::op::ReluBackprop), {TI(ngraph::op::ReluBackprop),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ReluBackprop>}, &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ReluBackprop>},
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment