Commit cd74b8f0 authored by Jaikrishnan Menon's avatar Jaikrishnan Menon

CPU: Propagate staged primitives to the runtime context

parent c9c01be1
......@@ -142,6 +142,8 @@ void runtime::cpu::CPU_CallFrame::setup_runtime_context()
{
ctx->op_durations = new int64_t[m_external_function->get_op_attrs().size()];
}
const auto& mkldnn_emitter = m_external_function->get_mkldnn_emitter();
ctx->mkldnn_primitives = mkldnn_emitter->get_mkldnn_primitives().data();
}
void runtime::cpu::CPU_CallFrame::cleanup_runtime_context()
......
......@@ -1826,11 +1826,11 @@ void runtime::cpu::CPU_Emitter::EMITTER_DECL(EmitConvolution)
convolution->get_padding_above());
auto& deps = mkldnn_emitter->get_primitive_deps(conv_index);
writer << "mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0]) << ", " << args[0].get_name() << ");\n";
writer << "mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1]) << ", " << args[1].get_name() << ");\n";
writer << "mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[2]) << ", " << out[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0]) << ", " << args[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1]) << ", " << args[1].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[2]) << ", " << out[0].get_name() << ");\n";
writer << "mkldnn_utils::mkldnn_invoke(ctx, " << to_string(conv_index) << ");\n";
writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, " << to_string(conv_index) << ");\n";
#endif
}
else if (filter_dilated && !data_dilated && arg0_rank == 4 && arg1_rank == 4 &&
......
......@@ -268,6 +268,7 @@ void runtime::cpu::CPU_ExternalFunction::compile()
#include "ngraph/runtime/cpu/cpu_eigen_utils.hpp"
#include "ngraph/runtime/cpu/cpu_kernels.hpp"
#include "ngraph/runtime/cpu/cpu_runtime_context.hpp"
#include "ngraph/runtime/cpu/mkldnn_invoke.hpp"
#include "ngraph/runtime/kernel/avg_pool.hpp"
#include "ngraph/runtime/kernel/broadcast.hpp"
#include "ngraph/runtime/kernel/concat.hpp"
......
......@@ -36,7 +36,7 @@ namespace ngraph
struct CPURuntimeContext
{
int64_t* op_durations;
mkldnn::primitive** mkldnn_primitives;
mkldnn::primitive* const* mkldnn_primitives;
};
}
}
......
......@@ -59,8 +59,12 @@ mkldnn::memory MKLDNNEmitter::build_memory_primitive(const TensorViewWrapper& tv
size_t MKLDNNEmitter::build_memory_primitive(const mkldnn::memory::desc& desc)
{
// The MKL-DNN C++ API forces proper initialization of a memory primitive
// with a non-null pointer (unlike the C API)
// Primitives are initialized at runtime so we use a known-invalid address here
// to bypass this check
return insert_primitive(
new mkldnn::memory({desc, mkldnn_utils::global_cpu_engine}, nullptr)
new mkldnn::memory({desc, mkldnn_utils::global_cpu_engine}, reinterpret_cast<void*>(0x42))
);
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment