Commit 01a3d6b3 authored by Jaikrishnan Menon's avatar Jaikrishnan Menon

CPU: Switch to primitive pointers

parent 165de8f2
......@@ -22,6 +22,12 @@
using namespace ngraph::runtime::cpu;
size_t MKLDNNEmitter::insert_primitive(mkldnn::primitive* primitive)
{
mkldnn_primitives.emplace_back(primitive);
return (mkldnn_primitives.size() - 1);
}
mkldnn::memory::desc MKLDNNEmitter::build_memory_descriptor(const TensorViewWrapper& tvw, mkldnn::memory::format fmt)
{
return mkldnn::memory::desc(mkldnn::memory::dims(tvw.get_shape().begin(), tvw.get_shape().end()),
......@@ -41,34 +47,43 @@ mkldnn::memory MKLDNNEmitter::build_memory_primitive(const TensorViewWrapper& tv
return mkldnn::memory({build_memory_descriptor(tvw), mkldnn_utils::global_cpu_engine}, nullptr);
}
mkldnn::memory MKLDNNEmitter::build_memory_primitive(const mkldnn::memory::desc& desc)
size_t MKLDNNEmitter::build_memory_primitive(const mkldnn::memory::desc& desc)
{
return mkldnn::memory({desc, mkldnn_utils::global_cpu_engine}, nullptr);
return insert_primitive(
new mkldnn::memory({desc, mkldnn_utils::global_cpu_engine}, nullptr)
);
}
size_t MKLDNNEmitter::build_convolution_forward(const mkldnn::memory::desc& input_data_desc,
const mkldnn::memory::desc& weights_desc,
const mkldnn::memory::desc& result_desc,
const ngraph::Strides& strides,
const ngraph::CoordinateDiff& padding_below,
const ngraph::CoordinateDiff& padding_above)
const mkldnn::memory::desc& weights_desc,
const mkldnn::memory::desc& result_desc,
const ngraph::Strides& strides,
const ngraph::CoordinateDiff& padding_below,
const ngraph::CoordinateDiff& padding_above)
{
mkldnn_primitives.push_back(mkldnn::convolution_forward(
{
size_t input_index = build_memory_primitive(input_data_desc);
size_t weights_index = build_memory_primitive(weights_desc);
size_t result_index = build_memory_primitive(result_desc);
size_t conv_index = insert_primitive(
new mkldnn::convolution_forward(
{
mkldnn::prop_kind::forward, mkldnn::algorithm::convolution_direct,
input_data_desc, weights_desc, result_desc,
mkldnn::memory::dims(strides.begin(), strides.end()),
mkldnn::memory::dims(padding_below.begin(), padding_below.end()),
mkldnn::memory::dims(padding_above.begin(), padding_above.end()),
mkldnn::padding_kind::zero
{
mkldnn::prop_kind::forward, mkldnn::algorithm::convolution_direct,
input_data_desc, weights_desc, result_desc,
mkldnn::memory::dims(strides.begin(), strides.end()),
mkldnn::memory::dims(padding_below.begin(), padding_below.end()),
mkldnn::memory::dims(padding_above.begin(), padding_above.end()),
mkldnn::padding_kind::zero
},
mkldnn_utils::global_cpu_engine
},
mkldnn_utils::global_cpu_engine
},
build_memory_primitive(input_data_desc),
build_memory_primitive(weights_desc),
build_memory_primitive(result_desc)));
*mkldnn_primitives[input_index],
*mkldnn_primitives[weights_index],
*mkldnn_primitives[result_index])
);
return (mkldnn_primitives.size() - 1);
primitive_deps[conv_index] = {input_index, weights_index, result_index};
return conv_index;
}
......@@ -15,6 +15,7 @@
#pragma once
#include <memory>
#include <unordered_map>
#include <vector>
#include <mkldnn.hpp>
......@@ -39,10 +40,11 @@ namespace ngraph
}
// TODO(jmenon): Get rid of TensorViewWrappers at some point
size_t insert_primitive(mkldnn::primitive* primitive);
mkldnn::memory::desc build_memory_descriptor(const TensorViewWrapper& tvw, mkldnn::memory::format fmt);
mkldnn::memory::desc build_memory_descriptor(const TensorViewWrapper& tvw);
mkldnn::memory build_memory_primitive(const TensorViewWrapper& tvw);
mkldnn::memory build_memory_primitive(const mkldnn::memory::desc& desc);
size_t build_memory_primitive(const mkldnn::memory::desc& desc);
size_t build_convolution_forward(const mkldnn::memory::desc& input_data_desc,
const mkldnn::memory::desc& weights_desc,
......@@ -52,8 +54,9 @@ namespace ngraph
const ngraph::CoordinateDiff& padding_above);
private:
std::shared_ptr<CPU_ExternalFunction> external_function;
std::vector<mkldnn::primitive> mkldnn_primitives;
std::vector<mkldnn::primitive*> mkldnn_primitives;
std::vector<mkldnn::stream> mkldnn_streams;
std::unordered_map<size_t, std::vector<size_t>> primitive_deps;
};
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment