Commit 01a3d6b3 authored by Jaikrishnan Menon's avatar Jaikrishnan Menon

CPU: Switch to primitive pointers

parent 165de8f2
...@@ -22,6 +22,12 @@ ...@@ -22,6 +22,12 @@
using namespace ngraph::runtime::cpu; using namespace ngraph::runtime::cpu;
size_t MKLDNNEmitter::insert_primitive(mkldnn::primitive* primitive)
{
mkldnn_primitives.emplace_back(primitive);
return (mkldnn_primitives.size() - 1);
}
mkldnn::memory::desc MKLDNNEmitter::build_memory_descriptor(const TensorViewWrapper& tvw, mkldnn::memory::format fmt) mkldnn::memory::desc MKLDNNEmitter::build_memory_descriptor(const TensorViewWrapper& tvw, mkldnn::memory::format fmt)
{ {
return mkldnn::memory::desc(mkldnn::memory::dims(tvw.get_shape().begin(), tvw.get_shape().end()), return mkldnn::memory::desc(mkldnn::memory::dims(tvw.get_shape().begin(), tvw.get_shape().end()),
...@@ -41,9 +47,11 @@ mkldnn::memory MKLDNNEmitter::build_memory_primitive(const TensorViewWrapper& tv ...@@ -41,9 +47,11 @@ mkldnn::memory MKLDNNEmitter::build_memory_primitive(const TensorViewWrapper& tv
return mkldnn::memory({build_memory_descriptor(tvw), mkldnn_utils::global_cpu_engine}, nullptr); return mkldnn::memory({build_memory_descriptor(tvw), mkldnn_utils::global_cpu_engine}, nullptr);
} }
mkldnn::memory MKLDNNEmitter::build_memory_primitive(const mkldnn::memory::desc& desc) size_t MKLDNNEmitter::build_memory_primitive(const mkldnn::memory::desc& desc)
{ {
return mkldnn::memory({desc, mkldnn_utils::global_cpu_engine}, nullptr); return insert_primitive(
new mkldnn::memory({desc, mkldnn_utils::global_cpu_engine}, nullptr)
);
} }
size_t MKLDNNEmitter::build_convolution_forward(const mkldnn::memory::desc& input_data_desc, size_t MKLDNNEmitter::build_convolution_forward(const mkldnn::memory::desc& input_data_desc,
...@@ -54,7 +62,12 @@ size_t MKLDNNEmitter::build_convolution_forward(const mkldnn::memory::desc& inpu ...@@ -54,7 +62,12 @@ size_t MKLDNNEmitter::build_convolution_forward(const mkldnn::memory::desc& inpu
const ngraph::CoordinateDiff& padding_above) const ngraph::CoordinateDiff& padding_above)
{ {
mkldnn_primitives.push_back(mkldnn::convolution_forward( size_t input_index = build_memory_primitive(input_data_desc);
size_t weights_index = build_memory_primitive(weights_desc);
size_t result_index = build_memory_primitive(result_desc);
size_t conv_index = insert_primitive(
new mkldnn::convolution_forward(
{ {
{ {
mkldnn::prop_kind::forward, mkldnn::algorithm::convolution_direct, mkldnn::prop_kind::forward, mkldnn::algorithm::convolution_direct,
...@@ -66,9 +79,11 @@ size_t MKLDNNEmitter::build_convolution_forward(const mkldnn::memory::desc& inpu ...@@ -66,9 +79,11 @@ size_t MKLDNNEmitter::build_convolution_forward(const mkldnn::memory::desc& inpu
}, },
mkldnn_utils::global_cpu_engine mkldnn_utils::global_cpu_engine
}, },
build_memory_primitive(input_data_desc), *mkldnn_primitives[input_index],
build_memory_primitive(weights_desc), *mkldnn_primitives[weights_index],
build_memory_primitive(result_desc))); *mkldnn_primitives[result_index])
);
return (mkldnn_primitives.size() - 1); primitive_deps[conv_index] = {input_index, weights_index, result_index};
return conv_index;
} }
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#pragma once #pragma once
#include <memory> #include <memory>
#include <unordered_map>
#include <vector> #include <vector>
#include <mkldnn.hpp> #include <mkldnn.hpp>
...@@ -39,10 +40,11 @@ namespace ngraph ...@@ -39,10 +40,11 @@ namespace ngraph
} }
// TODO(jmenon): Get rid of TensorViewWrappers at some point // TODO(jmenon): Get rid of TensorViewWrappers at some point
size_t insert_primitive(mkldnn::primitive* primitive);
mkldnn::memory::desc build_memory_descriptor(const TensorViewWrapper& tvw, mkldnn::memory::format fmt); mkldnn::memory::desc build_memory_descriptor(const TensorViewWrapper& tvw, mkldnn::memory::format fmt);
mkldnn::memory::desc build_memory_descriptor(const TensorViewWrapper& tvw); mkldnn::memory::desc build_memory_descriptor(const TensorViewWrapper& tvw);
mkldnn::memory build_memory_primitive(const TensorViewWrapper& tvw); mkldnn::memory build_memory_primitive(const TensorViewWrapper& tvw);
mkldnn::memory build_memory_primitive(const mkldnn::memory::desc& desc); size_t build_memory_primitive(const mkldnn::memory::desc& desc);
size_t build_convolution_forward(const mkldnn::memory::desc& input_data_desc, size_t build_convolution_forward(const mkldnn::memory::desc& input_data_desc,
const mkldnn::memory::desc& weights_desc, const mkldnn::memory::desc& weights_desc,
...@@ -52,8 +54,9 @@ namespace ngraph ...@@ -52,8 +54,9 @@ namespace ngraph
const ngraph::CoordinateDiff& padding_above); const ngraph::CoordinateDiff& padding_above);
private: private:
std::shared_ptr<CPU_ExternalFunction> external_function; std::shared_ptr<CPU_ExternalFunction> external_function;
std::vector<mkldnn::primitive> mkldnn_primitives; std::vector<mkldnn::primitive*> mkldnn_primitives;
std::vector<mkldnn::stream> mkldnn_streams; std::vector<mkldnn::stream> mkldnn_streams;
std::unordered_map<size_t, std::vector<size_t>> primitive_deps;
}; };
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment