Commit 51de579e authored by Jaikrishnan Menon's avatar Jaikrishnan Menon

CPU: Memory primitive initialization

parent 01a3d6b3
...@@ -1818,11 +1818,17 @@ void runtime::cpu::CPU_Emitter::EMITTER_DECL(EmitConvolution) ...@@ -1818,11 +1818,17 @@ void runtime::cpu::CPU_Emitter::EMITTER_DECL(EmitConvolution)
auto weights_desc = mkldnn_emitter->build_memory_descriptor(args[1], mkldnn::memory::format::oihw); auto weights_desc = mkldnn_emitter->build_memory_descriptor(args[1], mkldnn::memory::format::oihw);
auto result_desc = mkldnn_emitter->build_memory_descriptor(out[0], mkldnn::memory::format::nchw); auto result_desc = mkldnn_emitter->build_memory_descriptor(out[0], mkldnn::memory::format::nchw);
size_t conv_index = mkldnn_emitter->build_convolution_forward(input_data_desc, weights_desc, result_desc, size_t conv_index = mkldnn_emitter->build_convolution_forward(input_data_desc,
weights_desc,
result_desc,
convolution->get_window_movement_strides(), convolution->get_window_movement_strides(),
convolution->get_padding_below(), convolution->get_padding_below(),
convolution->get_padding_above()); convolution->get_padding_above());
auto& deps = mkldnn_emitter->get_primitive_deps(conv_index);
writer << "mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0]) << ", " << args[0].get_name() << ");\n";
writer << "mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1]) << ", " << args[1].get_name() << ");\n";
writer << "mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[2]) << ", " << out[0].get_name() << ");\n";
writer << "mkldnn_utils::mkldnn_invoke(ctx, " << to_string(conv_index) << ");\n"; writer << "mkldnn_utils::mkldnn_invoke(ctx, " << to_string(conv_index) << ");\n";
#endif #endif
......
...@@ -28,21 +28,26 @@ size_t MKLDNNEmitter::insert_primitive(mkldnn::primitive* primitive) ...@@ -28,21 +28,26 @@ size_t MKLDNNEmitter::insert_primitive(mkldnn::primitive* primitive)
return (mkldnn_primitives.size() - 1); return (mkldnn_primitives.size() - 1);
} }
mkldnn::memory::desc MKLDNNEmitter::build_memory_descriptor(const TensorViewWrapper& tvw, mkldnn::memory::format fmt) const std::vector<size_t>& MKLDNNEmitter::get_primitive_deps(size_t index) const
{
return primitive_deps.at(index);
}
mkldnn::memory::desc MKLDNNEmitter::build_memory_descriptor(const TensorViewWrapper& tvw, mkldnn::memory::format fmt) const
{ {
return mkldnn::memory::desc(mkldnn::memory::dims(tvw.get_shape().begin(), tvw.get_shape().end()), return mkldnn::memory::desc(mkldnn::memory::dims(tvw.get_shape().begin(), tvw.get_shape().end()),
mkldnn_utils::GetDataType(tvw.get_element_type()), mkldnn_utils::GetDataType(tvw.get_element_type()),
fmt); fmt);
} }
mkldnn::memory::desc MKLDNNEmitter::build_memory_descriptor(const TensorViewWrapper& tvw) mkldnn::memory::desc MKLDNNEmitter::build_memory_descriptor(const TensorViewWrapper& tvw) const
{ {
auto layout = std::static_pointer_cast<LayoutDescriptor>(tvw.get_tensor_view()->get_tensor_view_layout()); auto layout = std::static_pointer_cast<LayoutDescriptor>(tvw.get_tensor_view()->get_tensor_view_layout());
return build_memory_descriptor(tvw, layout->get_mkldnn_format()); return build_memory_descriptor(tvw, layout->get_mkldnn_format());
} }
mkldnn::memory MKLDNNEmitter::build_memory_primitive(const TensorViewWrapper& tvw) mkldnn::memory MKLDNNEmitter::build_memory_primitive(const TensorViewWrapper& tvw) const
{ {
return mkldnn::memory({build_memory_descriptor(tvw), mkldnn_utils::global_cpu_engine}, nullptr); return mkldnn::memory({build_memory_descriptor(tvw), mkldnn_utils::global_cpu_engine}, nullptr);
} }
...@@ -62,9 +67,9 @@ size_t MKLDNNEmitter::build_convolution_forward(const mkldnn::memory::desc& inpu ...@@ -62,9 +67,9 @@ size_t MKLDNNEmitter::build_convolution_forward(const mkldnn::memory::desc& inpu
const ngraph::CoordinateDiff& padding_above) const ngraph::CoordinateDiff& padding_above)
{ {
size_t input_index = build_memory_primitive(input_data_desc); size_t input_data_index = build_memory_primitive(input_data_desc);
size_t weights_index = build_memory_primitive(weights_desc); size_t weights_index = build_memory_primitive(weights_desc);
size_t result_index = build_memory_primitive(result_desc); size_t result_index = build_memory_primitive(result_desc);
size_t conv_index = insert_primitive( size_t conv_index = insert_primitive(
new mkldnn::convolution_forward( new mkldnn::convolution_forward(
...@@ -79,11 +84,11 @@ size_t MKLDNNEmitter::build_convolution_forward(const mkldnn::memory::desc& inpu ...@@ -79,11 +84,11 @@ size_t MKLDNNEmitter::build_convolution_forward(const mkldnn::memory::desc& inpu
}, },
mkldnn_utils::global_cpu_engine mkldnn_utils::global_cpu_engine
}, },
*mkldnn_primitives[input_index], *mkldnn_primitives[input_data_index],
*mkldnn_primitives[weights_index], *mkldnn_primitives[weights_index],
*mkldnn_primitives[result_index]) *mkldnn_primitives[result_index])
); );
primitive_deps[conv_index] = {input_index, weights_index, result_index}; primitive_deps[conv_index] = {input_data_index, weights_index, result_index};
return conv_index; return conv_index;
} }
...@@ -39,19 +39,21 @@ namespace ngraph ...@@ -39,19 +39,21 @@ namespace ngraph
{ {
} }
// TODO(jmenon): Get rid of TensorViewWrappers at some point
size_t insert_primitive(mkldnn::primitive* primitive); size_t insert_primitive(mkldnn::primitive* primitive);
mkldnn::memory::desc build_memory_descriptor(const TensorViewWrapper& tvw, mkldnn::memory::format fmt); const std::vector<size_t>& get_primitive_deps(size_t index) const;
mkldnn::memory::desc build_memory_descriptor(const TensorViewWrapper& tvw);
mkldnn::memory build_memory_primitive(const TensorViewWrapper& tvw); // TODO(jmenon): Get rid of TensorViewWrappers at some point
mkldnn::memory::desc build_memory_descriptor(const TensorViewWrapper& tvw, mkldnn::memory::format fmt) const;
mkldnn::memory::desc build_memory_descriptor(const TensorViewWrapper& tvw) const;
mkldnn::memory build_memory_primitive(const TensorViewWrapper& tvw) const;
size_t build_memory_primitive(const mkldnn::memory::desc& desc); size_t build_memory_primitive(const mkldnn::memory::desc& desc);
size_t build_convolution_forward(const mkldnn::memory::desc& input_data_desc, size_t build_convolution_forward(const mkldnn::memory::desc& input_data_desc,
const mkldnn::memory::desc& weights_desc, const mkldnn::memory::desc& weights_desc,
const mkldnn::memory::desc& result_desc, const mkldnn::memory::desc& result_desc,
const ngraph::Strides& strides, const ngraph::Strides& strides,
const ngraph::CoordinateDiff& padding_below, const ngraph::CoordinateDiff& padding_below,
const ngraph::CoordinateDiff& padding_above); const ngraph::CoordinateDiff& padding_above);
private: private:
std::shared_ptr<CPU_ExternalFunction> external_function; std::shared_ptr<CPU_ExternalFunction> external_function;
std::vector<mkldnn::primitive*> mkldnn_primitives; std::vector<mkldnn::primitive*> mkldnn_primitives;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment