Commit cd74b8f0 authored by Jaikrishnan Menon's avatar Jaikrishnan Menon

CPU: Propagate staged primitives to the runtime context

parent c9c01be1
...@@ -142,6 +142,8 @@ void runtime::cpu::CPU_CallFrame::setup_runtime_context() ...@@ -142,6 +142,8 @@ void runtime::cpu::CPU_CallFrame::setup_runtime_context()
{ {
ctx->op_durations = new int64_t[m_external_function->get_op_attrs().size()]; ctx->op_durations = new int64_t[m_external_function->get_op_attrs().size()];
} }
const auto& mkldnn_emitter = m_external_function->get_mkldnn_emitter();
ctx->mkldnn_primitives = mkldnn_emitter->get_mkldnn_primitives().data();
} }
void runtime::cpu::CPU_CallFrame::cleanup_runtime_context() void runtime::cpu::CPU_CallFrame::cleanup_runtime_context()
......
...@@ -1826,11 +1826,11 @@ void runtime::cpu::CPU_Emitter::EMITTER_DECL(EmitConvolution) ...@@ -1826,11 +1826,11 @@ void runtime::cpu::CPU_Emitter::EMITTER_DECL(EmitConvolution)
convolution->get_padding_above()); convolution->get_padding_above());
auto& deps = mkldnn_emitter->get_primitive_deps(conv_index); auto& deps = mkldnn_emitter->get_primitive_deps(conv_index);
writer << "mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0]) << ", " << args[0].get_name() << ");\n"; writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0]) << ", " << args[0].get_name() << ");\n";
writer << "mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1]) << ", " << args[1].get_name() << ");\n"; writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1]) << ", " << args[1].get_name() << ");\n";
writer << "mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[2]) << ", " << out[0].get_name() << ");\n"; writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[2]) << ", " << out[0].get_name() << ");\n";
writer << "mkldnn_utils::mkldnn_invoke(ctx, " << to_string(conv_index) << ");\n"; writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, " << to_string(conv_index) << ");\n";
#endif #endif
} }
else if (filter_dilated && !data_dilated && arg0_rank == 4 && arg1_rank == 4 && else if (filter_dilated && !data_dilated && arg0_rank == 4 && arg1_rank == 4 &&
......
...@@ -268,6 +268,7 @@ void runtime::cpu::CPU_ExternalFunction::compile() ...@@ -268,6 +268,7 @@ void runtime::cpu::CPU_ExternalFunction::compile()
#include "ngraph/runtime/cpu/cpu_eigen_utils.hpp" #include "ngraph/runtime/cpu/cpu_eigen_utils.hpp"
#include "ngraph/runtime/cpu/cpu_kernels.hpp" #include "ngraph/runtime/cpu/cpu_kernels.hpp"
#include "ngraph/runtime/cpu/cpu_runtime_context.hpp" #include "ngraph/runtime/cpu/cpu_runtime_context.hpp"
#include "ngraph/runtime/cpu/mkldnn_invoke.hpp"
#include "ngraph/runtime/kernel/avg_pool.hpp" #include "ngraph/runtime/kernel/avg_pool.hpp"
#include "ngraph/runtime/kernel/broadcast.hpp" #include "ngraph/runtime/kernel/broadcast.hpp"
#include "ngraph/runtime/kernel/concat.hpp" #include "ngraph/runtime/kernel/concat.hpp"
......
...@@ -36,7 +36,7 @@ namespace ngraph ...@@ -36,7 +36,7 @@ namespace ngraph
struct CPURuntimeContext struct CPURuntimeContext
{ {
int64_t* op_durations; int64_t* op_durations;
mkldnn::primitive** mkldnn_primitives; mkldnn::primitive* const* mkldnn_primitives;
}; };
} }
} }
......
...@@ -59,8 +59,12 @@ mkldnn::memory MKLDNNEmitter::build_memory_primitive(const TensorViewWrapper& tv ...@@ -59,8 +59,12 @@ mkldnn::memory MKLDNNEmitter::build_memory_primitive(const TensorViewWrapper& tv
size_t MKLDNNEmitter::build_memory_primitive(const mkldnn::memory::desc& desc) size_t MKLDNNEmitter::build_memory_primitive(const mkldnn::memory::desc& desc)
{ {
// The MKL-DNN C++ API forces proper initialization of a memory primitive
// with a non-null pointer (unlike the C API)
// Primitives are initialized at runtime so we use a known-invalid address here
// to bypass this check
return insert_primitive( return insert_primitive(
new mkldnn::memory({desc, mkldnn_utils::global_cpu_engine}, nullptr) new mkldnn::memory({desc, mkldnn_utils::global_cpu_engine}, reinterpret_cast<void*>(0x42))
); );
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment