Commit 7fdd8d70 authored by Jaikrishnan Menon's avatar Jaikrishnan Menon

CPU: Formatting fixes

parent cd74b8f0
...@@ -1814,23 +1814,31 @@ void runtime::cpu::CPU_Emitter::EMITTER_DECL(EmitConvolution) ...@@ -1814,23 +1814,31 @@ void runtime::cpu::CPU_Emitter::EMITTER_DECL(EmitConvolution)
writer << "}\n"; writer << "}\n";
#else #else
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto input_data_desc = mkldnn_emitter->build_memory_descriptor(args[0], mkldnn::memory::format::nchw); auto input_data_desc =
auto weights_desc = mkldnn_emitter->build_memory_descriptor(args[1], mkldnn::memory::format::oihw); mkldnn_emitter->build_memory_descriptor(args[0], mkldnn::memory::format::nchw);
auto result_desc = mkldnn_emitter->build_memory_descriptor(out[0], mkldnn::memory::format::nchw); auto weights_desc =
mkldnn_emitter->build_memory_descriptor(args[1], mkldnn::memory::format::oihw);
size_t conv_index = mkldnn_emitter->build_convolution_forward(input_data_desc, auto result_desc =
weights_desc, mkldnn_emitter->build_memory_descriptor(out[0], mkldnn::memory::format::nchw);
result_desc,
convolution->get_window_movement_strides(), size_t conv_index =
convolution->get_padding_below(), mkldnn_emitter->build_convolution_forward(input_data_desc,
convolution->get_padding_above()); weights_desc,
result_desc,
convolution->get_window_movement_strides(),
convolution->get_padding_below(),
convolution->get_padding_above());
auto& deps = mkldnn_emitter->get_primitive_deps(conv_index); auto& deps = mkldnn_emitter->get_primitive_deps(conv_index);
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0]) << ", " << args[0].get_name() << ");\n"; writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0]) << ", "
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1]) << ", " << args[1].get_name() << ");\n"; << args[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[2]) << ", " << out[0].get_name() << ");\n"; writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1]) << ", "
<< args[1].get_name() << ");\n";
writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, " << to_string(conv_index) << ");\n"; writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[2]) << ", "
<< out[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, " << to_string(conv_index)
<< ");\n";
#endif #endif
} }
else if (filter_dilated && !data_dilated && arg0_rank == 4 && arg1_rank == 4 && else if (filter_dilated && !data_dilated && arg0_rank == 4 && arg1_rank == 4 &&
......
...@@ -70,7 +70,8 @@ bool runtime::cpu::TensorViewWrapper::is_output() const ...@@ -70,7 +70,8 @@ bool runtime::cpu::TensorViewWrapper::is_output() const
return m_tensor_view->get_tensor().is_output(); return m_tensor_view->get_tensor().is_output();
} }
const std::shared_ptr<descriptor::TensorView> runtime::cpu::TensorViewWrapper::get_tensor_view() const const std::shared_ptr<descriptor::TensorView>
runtime::cpu::TensorViewWrapper::get_tensor_view() const
{ {
return m_tensor_view; return m_tensor_view;
} }
...@@ -38,16 +38,19 @@ const std::vector<size_t>& MKLDNNEmitter::get_primitive_deps(size_t index) const ...@@ -38,16 +38,19 @@ const std::vector<size_t>& MKLDNNEmitter::get_primitive_deps(size_t index) const
return primitive_deps.at(index); return primitive_deps.at(index);
} }
mkldnn::memory::desc MKLDNNEmitter::build_memory_descriptor(const TensorViewWrapper& tvw, mkldnn::memory::format fmt) const mkldnn::memory::desc MKLDNNEmitter::build_memory_descriptor(const TensorViewWrapper& tvw,
mkldnn::memory::format fmt) const
{ {
return mkldnn::memory::desc(mkldnn::memory::dims(tvw.get_shape().begin(), tvw.get_shape().end()), return mkldnn::memory::desc(
mkldnn_utils::GetDataType(tvw.get_element_type()), mkldnn::memory::dims(tvw.get_shape().begin(), tvw.get_shape().end()),
fmt); mkldnn_utils::GetDataType(tvw.get_element_type()),
fmt);
} }
mkldnn::memory::desc MKLDNNEmitter::build_memory_descriptor(const TensorViewWrapper& tvw) const mkldnn::memory::desc MKLDNNEmitter::build_memory_descriptor(const TensorViewWrapper& tvw) const
{ {
auto layout = std::static_pointer_cast<LayoutDescriptor>(tvw.get_tensor_view()->get_tensor_view_layout()); auto layout =
std::static_pointer_cast<LayoutDescriptor>(tvw.get_tensor_view()->get_tensor_view_layout());
return build_memory_descriptor(tvw, layout->get_mkldnn_format()); return build_memory_descriptor(tvw, layout->get_mkldnn_format());
} }
...@@ -64,8 +67,7 @@ size_t MKLDNNEmitter::build_memory_primitive(const mkldnn::memory::desc& desc) ...@@ -64,8 +67,7 @@ size_t MKLDNNEmitter::build_memory_primitive(const mkldnn::memory::desc& desc)
// Primitives are initialized at runtime so we use a known-invalid address here // Primitives are initialized at runtime so we use a known-invalid address here
// to bypass this check // to bypass this check
return insert_primitive( return insert_primitive(
new mkldnn::memory({desc, mkldnn_utils::global_cpu_engine}, reinterpret_cast<void*>(0x42)) new mkldnn::memory({desc, mkldnn_utils::global_cpu_engine}, reinterpret_cast<void*>(0x42)));
);
} }
size_t MKLDNNEmitter::build_convolution_forward(const mkldnn::memory::desc& input_data_desc, size_t MKLDNNEmitter::build_convolution_forward(const mkldnn::memory::desc& input_data_desc,
...@@ -77,26 +79,23 @@ size_t MKLDNNEmitter::build_convolution_forward(const mkldnn::memory::desc& inpu ...@@ -77,26 +79,23 @@ size_t MKLDNNEmitter::build_convolution_forward(const mkldnn::memory::desc& inpu
{ {
size_t input_data_index = build_memory_primitive(input_data_desc); size_t input_data_index = build_memory_primitive(input_data_desc);
size_t weights_index = build_memory_primitive(weights_desc); size_t weights_index = build_memory_primitive(weights_desc);
size_t result_index = build_memory_primitive(result_desc); size_t result_index = build_memory_primitive(result_desc);
size_t conv_index = insert_primitive( size_t conv_index = insert_primitive(new mkldnn::convolution_forward(
new mkldnn::convolution_forward( {{mkldnn::prop_kind::forward,
{ mkldnn::algorithm::convolution_direct,
{ input_data_desc,
mkldnn::prop_kind::forward, mkldnn::algorithm::convolution_direct, weights_desc,
input_data_desc, weights_desc, result_desc, result_desc,
mkldnn::memory::dims(strides.begin(), strides.end()), mkldnn::memory::dims(strides.begin(), strides.end()),
mkldnn::memory::dims(padding_below.begin(), padding_below.end()), mkldnn::memory::dims(padding_below.begin(), padding_below.end()),
mkldnn::memory::dims(padding_above.begin(), padding_above.end()), mkldnn::memory::dims(padding_above.begin(), padding_above.end()),
mkldnn::padding_kind::zero mkldnn::padding_kind::zero},
}, mkldnn_utils::global_cpu_engine},
mkldnn_utils::global_cpu_engine
},
*mkldnn_primitives[input_data_index], *mkldnn_primitives[input_data_index],
*mkldnn_primitives[weights_index], *mkldnn_primitives[weights_index],
*mkldnn_primitives[result_index]) *mkldnn_primitives[result_index]));
);
primitive_deps[conv_index] = {input_data_index, weights_index, result_index}; primitive_deps[conv_index] = {input_data_index, weights_index, result_index};
return conv_index; return conv_index;
......
...@@ -45,7 +45,8 @@ namespace ngraph ...@@ -45,7 +45,8 @@ namespace ngraph
const std::vector<size_t>& get_primitive_deps(size_t index) const; const std::vector<size_t>& get_primitive_deps(size_t index) const;
// TODO(jmenon): Get rid of TensorViewWrappers at some point // TODO(jmenon): Get rid of TensorViewWrappers at some point
mkldnn::memory::desc build_memory_descriptor(const TensorViewWrapper& tvw, mkldnn::memory::format fmt) const; mkldnn::memory::desc build_memory_descriptor(const TensorViewWrapper& tvw,
mkldnn::memory::format fmt) const;
mkldnn::memory::desc build_memory_descriptor(const TensorViewWrapper& tvw) const; mkldnn::memory::desc build_memory_descriptor(const TensorViewWrapper& tvw) const;
mkldnn::memory build_memory_primitive(const TensorViewWrapper& tvw) const; mkldnn::memory build_memory_primitive(const TensorViewWrapper& tvw) const;
size_t build_memory_primitive(const mkldnn::memory::desc& desc); size_t build_memory_primitive(const mkldnn::memory::desc& desc);
...@@ -56,6 +57,7 @@ namespace ngraph ...@@ -56,6 +57,7 @@ namespace ngraph
const ngraph::Strides& strides, const ngraph::Strides& strides,
const ngraph::CoordinateDiff& padding_below, const ngraph::CoordinateDiff& padding_below,
const ngraph::CoordinateDiff& padding_above); const ngraph::CoordinateDiff& padding_above);
private: private:
std::shared_ptr<CPU_ExternalFunction> external_function; std::shared_ptr<CPU_ExternalFunction> external_function;
std::vector<mkldnn::primitive*> mkldnn_primitives; std::vector<mkldnn::primitive*> mkldnn_primitives;
......
...@@ -20,16 +20,16 @@ ...@@ -20,16 +20,16 @@
mkldnn::engine ngraph::runtime::cpu::mkldnn_utils::global_cpu_engine(mkldnn::engine::cpu, 0); mkldnn::engine ngraph::runtime::cpu::mkldnn_utils::global_cpu_engine(mkldnn::engine::cpu, 0);
extern "C" void ngraph::runtime::cpu::mkldnn_utils::set_memory_ptr(CPURuntimeContext* ctx, size_t primitive_index, extern "C" void ngraph::runtime::cpu::mkldnn_utils::set_memory_ptr(CPURuntimeContext* ctx,
size_t primitive_index,
void* ptr) void* ptr)
{ {
auto primitive = static_cast<mkldnn::memory*>(ctx->mkldnn_primitives[primitive_index]); auto primitive = static_cast<mkldnn::memory*>(ctx->mkldnn_primitives[primitive_index]);
primitive->set_data_handle(ptr); primitive->set_data_handle(ptr);
} }
extern "C" void extern "C" void ngraph::runtime::cpu::mkldnn_utils::mkldnn_invoke_primitive(CPURuntimeContext* ctx,
ngraph::runtime::cpu::mkldnn_utils::mkldnn_invoke_primitive(CPURuntimeContext* ctx, size_t primitive_index)
size_t primitive_index)
{ {
mkldnn::stream s(mkldnn::stream::kind::eager); mkldnn::stream s(mkldnn::stream::kind::eager);
s.submit({*ctx->mkldnn_primitives[primitive_index]}).wait(); s.submit({*ctx->mkldnn_primitives[primitive_index]}).wait();
......
...@@ -26,8 +26,8 @@ namespace ngraph ...@@ -26,8 +26,8 @@ namespace ngraph
namespace mkldnn_utils namespace mkldnn_utils
{ {
extern "C" void set_memory_ptr(CPURuntimeContext* ctx, size_t primitive_index, extern "C" void
void* ptr); set_memory_ptr(CPURuntimeContext* ctx, size_t primitive_index, void* ptr);
extern "C" void mkldnn_invoke_primitive(CPURuntimeContext* ctx, extern "C" void mkldnn_invoke_primitive(CPURuntimeContext* ctx,
size_t primitive_index); size_t primitive_index);
} }
......
...@@ -45,23 +45,24 @@ namespace ngraph ...@@ -45,23 +45,24 @@ namespace ngraph
TI(ngraph::op::ConvolutionBackpropFilters), TI(ngraph::op::ConvolutionBackpropFilters),
TI(ngraph::op::MaxPool)}; TI(ngraph::op::MaxPool)};
static const std::unordered_map<std::string, const mkldnn::memory::data_type> s_data_type_map{ static const std::unordered_map<std::string, const mkldnn::memory::data_type>
{"char", mkldnn::memory::data_type::s8}, s_data_type_map{{"char", mkldnn::memory::data_type::s8},
{"float", mkldnn::memory::data_type::f32}, {"float", mkldnn::memory::data_type::f32},
{"double", mkldnn::memory::data_type::data_undef}, {"double", mkldnn::memory::data_type::data_undef},
{"int8_t", mkldnn::memory::data_type::s8}, {"int8_t", mkldnn::memory::data_type::s8},
{"int16_t", mkldnn::memory::data_type::s16}, {"int16_t", mkldnn::memory::data_type::s16},
{"int32_t", mkldnn::memory::data_type::s32}, {"int32_t", mkldnn::memory::data_type::s32},
{"int64_t", mkldnn::memory::data_type::data_undef}, {"int64_t", mkldnn::memory::data_type::data_undef},
{"uint8_t", mkldnn::memory::data_type::u8}, {"uint8_t", mkldnn::memory::data_type::u8},
{"uint16_t", mkldnn::memory::data_type::data_undef}, {"uint16_t", mkldnn::memory::data_type::data_undef},
{"uint32_t", mkldnn::memory::data_type::data_undef}, {"uint32_t", mkldnn::memory::data_type::data_undef},
{"uint64_t", mkldnn::memory::data_type::data_undef}}; {"uint64_t", mkldnn::memory::data_type::data_undef}};
mkldnn::memory::data_type GetDataType(const ngraph::element::Type& et) mkldnn::memory::data_type GetDataType(const ngraph::element::Type& et)
{ {
auto it = s_data_type_map.find(et.c_type_string()); auto it = s_data_type_map.find(et.c_type_string());
if (it == s_data_type_map.end() || it->second == mkldnn::memory::data_type::data_undef) if (it == s_data_type_map.end() ||
it->second == mkldnn::memory::data_type::data_undef)
throw ngraph_error("No MKLDNN data type exists for the given element type"); throw ngraph_error("No MKLDNN data type exists for the given element type");
return it->second; return it->second;
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment