Commit 165de8f2 authored by Jaikrishnan Menon

CPU: Conv builder

parent 1fcfbca7
...@@ -1814,7 +1814,17 @@ void runtime::cpu::CPU_Emitter::EMITTER_DECL(EmitConvolution) ...@@ -1814,7 +1814,17 @@ void runtime::cpu::CPU_Emitter::EMITTER_DECL(EmitConvolution)
writer << "}\n"; writer << "}\n";
#else #else
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto input_data_desc = mkldnn_emitter->build_memory_descriptor(args[0]); auto input_data_desc = mkldnn_emitter->build_memory_descriptor(args[0], mkldnn::memory::format::nchw);
auto weights_desc = mkldnn_emitter->build_memory_descriptor(args[1], mkldnn::memory::format::oihw);
auto result_desc = mkldnn_emitter->build_memory_descriptor(out[0], mkldnn::memory::format::nchw);
size_t conv_index = mkldnn_emitter->build_convolution_forward(input_data_desc, weights_desc, result_desc,
convolution->get_window_movement_strides(),
convolution->get_padding_below(),
convolution->get_padding_above());
writer << "mkldnn_utils::mkldnn_invoke(ctx, " << to_string(conv_index) << ");\n";
#endif #endif
} }
else if (filter_dilated && !data_dilated && arg0_rank == 4 && arg1_rank == 4 && else if (filter_dilated && !data_dilated && arg0_rank == 4 && arg1_rank == 4 &&
......
...@@ -69,3 +69,8 @@ bool runtime::cpu::TensorViewWrapper::is_output() const ...@@ -69,3 +69,8 @@ bool runtime::cpu::TensorViewWrapper::is_output() const
{ {
return m_tensor_view->get_tensor().is_output(); return m_tensor_view->get_tensor().is_output();
} }
// Expose the underlying descriptor::TensorView held by this wrapper.
const std::shared_ptr<descriptor::TensorView>
    runtime::cpu::TensorViewWrapper::get_tensor_view() const
{
    return m_tensor_view;
}
...@@ -45,6 +45,7 @@ public: ...@@ -45,6 +45,7 @@ public:
const std::string& get_name() const; const std::string& get_name() const;
const std::string& get_type() const; const std::string& get_type() const;
bool is_output() const; bool is_output() const;
const std::shared_ptr<descriptor::TensorView> get_tensor_view() const;
private: private:
std::shared_ptr<descriptor::TensorView> m_tensor_view; std::shared_ptr<descriptor::TensorView> m_tensor_view;
......
...@@ -12,16 +12,63 @@ ...@@ -12,16 +12,63 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include <memory>
#include "mkldnn_emitter.hpp" #include "mkldnn_emitter.hpp"
#include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp"
#include "ngraph/runtime/cpu/cpu_tensor_view_wrapper.hpp" #include "ngraph/runtime/cpu/cpu_tensor_view_wrapper.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp" #include "ngraph/runtime/cpu/mkldnn_utils.hpp"
using namespace ngraph::runtime::cpu; using namespace ngraph::runtime::cpu;
mkldnn::memory::desc MKLDNNEmitter::build_memory_descriptor(const TensorViewWrapper& tvw,
                                                            mkldnn::memory::format fmt)
{
    // Translate the tensor view's shape and element type into an MKLDNN
    // memory descriptor, using the caller-supplied layout format.
    const auto& shape = tvw.get_shape();
    return mkldnn::memory::desc(mkldnn::memory::dims(shape.begin(), shape.end()),
                                mkldnn_utils::GetDataType(tvw.get_element_type()),
                                fmt);
}
mkldnn::memory::desc MKLDNNEmitter::build_memory_descriptor(const TensorViewWrapper& tvw)
{
    // Derive the MKLDNN format from the tensor view's layout descriptor and
    // delegate to the format-taking overload.
    // NOTE(review): the static cast to LayoutDescriptor is unchecked — assumes
    // every CPU tensor view carries a cpu::LayoutDescriptor; confirm.
    auto cpu_layout = std::static_pointer_cast<LayoutDescriptor>(
        tvw.get_tensor_view()->get_tensor_view_layout());
    return build_memory_descriptor(tvw, cpu_layout->get_mkldnn_format());
}
mkldnn::memory MKLDNNEmitter::build_memory_primitive(const TensorViewWrapper& tvw)
{
    // Build a memory primitive for this tensor view on the global CPU engine.
    // The data handle is nullptr here — presumably the actual pointer is
    // attached at invocation time; confirm against the runtime.
    mkldnn::memory::primitive_desc pd(build_memory_descriptor(tvw),
                                      mkldnn_utils::global_cpu_engine);
    return mkldnn::memory(pd, nullptr);
}
mkldnn::memory MKLDNNEmitter::build_memory_primitive(const mkldnn::memory::desc& desc)
{
    // Wrap a raw memory descriptor in a primitive on the global CPU engine,
    // with no data handle attached (nullptr).
    mkldnn::memory::primitive_desc pd(desc, mkldnn_utils::global_cpu_engine);
    return mkldnn::memory(pd, nullptr);
}
size_t MKLDNNEmitter::build_convolution_forward(const mkldnn::memory::desc& input_data_desc,
                                                const mkldnn::memory::desc& weights_desc,
                                                const mkldnn::memory::desc& result_desc,
                                                const ngraph::Strides& strides,
                                                const ngraph::CoordinateDiff& padding_below,
                                                const ngraph::CoordinateDiff& padding_above)
{
    // Describe a direct forward convolution over the given input/weights/result
    // layouts, strides and zero-padding.
    mkldnn::convolution_forward::desc fwd_desc(
        mkldnn::prop_kind::forward,
        mkldnn::algorithm::convolution_direct,
        input_data_desc,
        weights_desc,
        result_desc,
        mkldnn::memory::dims(strides.begin(), strides.end()),
        mkldnn::memory::dims(padding_below.begin(), padding_below.end()),
        mkldnn::memory::dims(padding_above.begin(), padding_above.end()),
        mkldnn::padding_kind::zero);
    mkldnn::convolution_forward::primitive_desc fwd_pd(fwd_desc,
                                                       mkldnn_utils::global_cpu_engine);

    // Register the primitive; its memory arguments carry nullptr data handles,
    // presumably re-bound at invocation time — confirm against the runtime.
    mkldnn_primitives.push_back(mkldnn::convolution_forward(fwd_pd,
                                                            build_memory_primitive(input_data_desc),
                                                            build_memory_primitive(weights_desc),
                                                            build_memory_primitive(result_desc)));

    // The returned index identifies this primitive for the generated invoke call.
    return mkldnn_primitives.size() - 1;
}
...@@ -19,6 +19,8 @@ ...@@ -19,6 +19,8 @@
#include <mkldnn.hpp> #include <mkldnn.hpp>
#include "ngraph/common.hpp"
namespace ngraph namespace ngraph
{ {
namespace runtime namespace runtime
...@@ -37,8 +39,17 @@ namespace ngraph ...@@ -37,8 +39,17 @@ namespace ngraph
} }
// TODO(jmenon): Get rid of TensorViewWrappers at some point // TODO(jmenon): Get rid of TensorViewWrappers at some point
mkldnn::memory::desc build_memory_descriptor(const TensorViewWrapper& tvw, mkldnn::memory::format fmt);
mkldnn::memory::desc build_memory_descriptor(const TensorViewWrapper& tvw); mkldnn::memory::desc build_memory_descriptor(const TensorViewWrapper& tvw);
mkldnn::memory build_memory_primitive(const TensorViewWrapper& tvw);
mkldnn::memory build_memory_primitive(const mkldnn::memory::desc& desc);
size_t build_convolution_forward(const mkldnn::memory::desc& input_data_desc,
const mkldnn::memory::desc& weights_desc,
const mkldnn::memory::desc& result_desc,
const ngraph::Strides& strides,
const ngraph::CoordinateDiff& padding_below,
const ngraph::CoordinateDiff& padding_above);
private: private:
std::shared_ptr<CPU_ExternalFunction> external_function; std::shared_ptr<CPU_ExternalFunction> external_function;
std::vector<mkldnn::primitive> mkldnn_primitives; std::vector<mkldnn::primitive> mkldnn_primitives;
......
...@@ -21,6 +21,6 @@ mkldnn::engine ngraph::runtime::cpu::mkldnn_utils::global_cpu_engine(mkldnn::eng ...@@ -21,6 +21,6 @@ mkldnn::engine ngraph::runtime::cpu::mkldnn_utils::global_cpu_engine(mkldnn::eng
// Invoke the MKLDNN primitive registered at primitive_index against the given
// runtime context.  Currently a stub with an intentionally empty body.
// NOTE(review): the emitter generates calls to "mkldnn_utils::mkldnn_invoke"
// (see EmitConvolution), which does not match this symbol's name
// "mkldnn_invoke_primitive" — confirm the generated call site resolves.
extern "C" void
ngraph::runtime::cpu::mkldnn_utils::mkldnn_invoke_primitive(CPURuntimeContext* ctx,
                                                            size_t primitive_index)
{
}
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#pragma once #pragma once
#include <cstddef>
namespace ngraph namespace ngraph
{ {
namespace runtime namespace runtime
...@@ -25,7 +27,7 @@ namespace ngraph ...@@ -25,7 +27,7 @@ namespace ngraph
namespace mkldnn_utils namespace mkldnn_utils
{ {
extern "C" void mkldnn_invoke_primitive(CPURuntimeContext* ctx, extern "C" void mkldnn_invoke_primitive(CPURuntimeContext* ctx,
unsigned int primitive_index); size_t primitive_index);
} }
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment