Commit 9c1c5b59 authored by Jayaram Bobba's avatar Jayaram Bobba Committed by Scott Cyphers

Added MKLDNN concat in DEX (#1379)

parent 94889200
......@@ -17,6 +17,8 @@
#include "ngraph/op/concat.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/cpu/kernel/concat.hpp"
#include "ngraph/runtime/cpu/mkldnn_invoke.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
using namespace std;
using namespace ngraph;
......@@ -35,12 +37,6 @@ namespace ngraph
auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
std::function<decltype(runtime::cpu::kernel::concat<float, 1>)> kernel;
SELECT_KERNEL_BY_RANK(kernel,
out[0].get_element_type(),
out[0].get_shape().size(),
runtime::cpu::kernel::concat);
vector<reference_wrapper<void*>> arg_tensors;
vector<Shape> arg_shapes;
......@@ -56,12 +52,51 @@ namespace ngraph
auto& out_tensor = tensor_data[out[0].get_name()];
auto out_shape = out[0].get_shape();
auto functor =
[&, kernel, arg_tensors, arg_shapes, out_shape, axis](CPURuntimeContext* ctx) {
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
std::vector<mkldnn::memory::desc> inputs_data_desc;
for (size_t i = 0; i < args.size(); i++)
{
inputs_data_desc.push_back(mkldnn_utils::get_input_mkldnn_md(node, i));
}
auto result_desc = mkldnn_utils::get_output_mkldnn_md(node, 0);
size_t concat_dim =
(dynamic_cast<const ngraph::op::Concat*>(node))->get_concatenation_axis();
auto nargs = args.size();
auto concat_index =
mkldnn_emitter->build_concat(inputs_data_desc, result_desc, concat_dim);
auto& deps = mkldnn_emitter->get_primitive_deps(concat_index);
auto functor = [&, arg_tensors, nargs, concat_index](CPURuntimeContext* ctx) {
for (size_t i = 0; i < nargs; i++)
{
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[i], arg_tensors[i]);
}
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[nargs], out_tensor);
cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, concat_index);
};
functors.emplace_back(functor);
}
else
{
std::function<decltype(runtime::cpu::kernel::concat<float, 1>)> kernel;
SELECT_KERNEL_BY_RANK(kernel,
out[0].get_element_type(),
out[0].get_shape().size(),
runtime::cpu::kernel::concat);
auto functor = [&, kernel, arg_tensors, arg_shapes, out_shape, axis](
CPURuntimeContext* ctx) {
kernel(arg_tensors, arg_shapes, out_tensor, out_shape, axis);
};
functors.emplace_back(functor);
}
}
REGISTER_OP_BUILDER(Concat);
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment