Commit c625abfd authored by Jaikrishnan Menon, committed by Scott Cyphers

CPU Direct Execution: Add MKL-DNN EW Add (#1392)

parent 82e5fee3
......@@ -27,6 +27,7 @@ set(SRC
    cpu_tensor_view_wrapper.cpp
    cpu_tensor_view.cpp
    cpu_tracing.cpp
    builder/add.cpp
    builder/avg_pool.cpp
    builder/batch_norm.cpp
    builder/broadcast.cpp
......
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <vector>
#include "ngraph/op/add.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/cpu/kernel/add.hpp"
#include "ngraph/runtime/cpu/mkldnn_invoke.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
using namespace std;
using namespace ngraph;
namespace ngraph
{
    namespace runtime
    {
        namespace cpu
        {
            template <>
            void Builder::BUILDER_DECL(ngraph::op::Add)
            {
                if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
                {
                    auto& functors = external_function->get_functors();
                    auto& tensor_data = external_function->get_tensor_data();

                    // Two unit scales: the primitive computes 1.0 * arg0 + 1.0 * arg1.
                    vector<float> scale_vector(2, 1);
                    vector<mkldnn::memory::primitive_desc> inputs_pd;

                    auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
                    auto input0_data_desc = mkldnn_utils::get_input_mkldnn_md(node, 0);
                    auto input1_data_desc = mkldnn_utils::get_input_mkldnn_md(node, 1);
                    auto result_desc = mkldnn_utils::get_output_mkldnn_md(node, 0);
                    inputs_pd.push_back(mkldnn::memory::primitive_desc(
                        input0_data_desc, runtime::cpu::mkldnn_utils::global_cpu_engine));
                    inputs_pd.push_back(mkldnn::memory::primitive_desc(
                        input1_data_desc, runtime::cpu::mkldnn_utils::global_cpu_engine));

                    size_t add_index = mkldnn_emitter->build_elementwise_add(
                        input0_data_desc, input1_data_desc, result_desc, scale_vector, inputs_pd);

                    // deps[0], deps[1]: input memory primitives; deps[2]: output memory primitive.
                    auto& deps = mkldnn_emitter->get_primitive_deps(add_index);
                    auto& arg0_tensor = tensor_data[args[0].get_name()];
                    auto& arg1_tensor = tensor_data[args[1].get_name()];
                    auto& out_tensor = tensor_data[out[0].get_name()];

                    // At run time, point the MKL-DNN memory primitives at the current
                    // tensor buffers and invoke the prebuilt primitive.
                    auto functor = [&, add_index](CPURuntimeContext* ctx) {
                        cpu::mkldnn_utils::set_memory_ptr(ctx, deps[0], arg0_tensor);
                        cpu::mkldnn_utils::set_memory_ptr(ctx, deps[1], arg1_tensor);
                        cpu::mkldnn_utils::set_memory_ptr(ctx, deps[2], out_tensor);
                        cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, add_index);
                    };
                    functors.emplace_back(functor);
                }
                else
                {
                    // Non-MKL-DNN path: fall back to the generic element-wise kernel.
                    BUILD_BINARY_ELEMWISE_FUNCTOR(runtime::cpu::kernel::add);
                }
            }

            REGISTER_OP_BUILDER(Add);
        }
    }
}
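For context, here is a minimal standalone sketch (not part of this commit) of element-wise add with the MKL-DNN 0.x API: a sum primitive over two inputs with unit scales, which is what the scale_vector and inputs_pd arguments passed to build_elementwise_add above suggest the emitter constructs internally. The shapes, memory format, and main() harness are illustrative assumptions only.

#include <mkldnn.hpp>
#include <vector>

int main()
{
    using namespace mkldnn;
    engine cpu_engine(engine::cpu, 0);

    // Two 2x3 float tensors in plain "nc" layout (illustrative shapes).
    memory::dims shape{2, 3};
    memory::desc md(shape, memory::data_type::f32, memory::format::nc);
    memory::primitive_desc pd(md, cpu_engine);

    std::vector<float> a{1, 2, 3, 4, 5, 6}, b{6, 5, 4, 3, 2, 1}, c(6);
    memory src0(pd, a.data()), src1(pd, b.data()), dst(pd, c.data());

    // out = 1.0 * src0 + 1.0 * src1, i.e. element-wise add via the sum primitive.
    std::vector<float> scales(2, 1.0f);
    std::vector<memory::primitive_desc> inputs_pd{pd, pd};
    auto sum_pd = sum::primitive_desc(md, scales, inputs_pd);

    std::vector<primitive::at> inputs;
    inputs.push_back(src0);
    inputs.push_back(src1);
    auto add = sum(sum_pd, inputs, dst);

    stream(stream::kind::eager).submit({add}).wait();
    // c now holds {7, 7, 7, 7, 7, 7}.
    return 0;
}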
......@@ -26,7 +26,6 @@
#include "ngraph/node.hpp"
#include "ngraph/op/abs.hpp"
#include "ngraph/op/acos.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/and.hpp"
#include "ngraph/op/asin.hpp"
#include "ngraph/op/atan.hpp"
......@@ -67,7 +66,6 @@
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
#include "ngraph/runtime/cpu/kernel/abs.hpp"
#include "ngraph/runtime/cpu/kernel/acos.hpp"
#include "ngraph/runtime/cpu/kernel/add.hpp"
#include "ngraph/runtime/cpu/kernel/and.hpp"
#include "ngraph/runtime/cpu/kernel/asin.hpp"
#include "ngraph/runtime/cpu/kernel/atan.hpp"
......@@ -119,12 +117,6 @@ namespace ngraph
    {
        namespace cpu
        {
            template <>
            void Builder::BUILDER_DECL(ngraph::op::Add)
            {
                BUILD_BINARY_ELEMWISE_FUNCTOR(runtime::cpu::kernel::add);
            }

            template <>
            void Builder::BUILDER_DECL(ngraph::op::Subtract)
            {
......@@ -386,7 +378,6 @@ namespace ngraph
            REGISTER_OP_BUILDER(Constant);
            REGISTER_OP_BUILDER(Result);
            REGISTER_OP_BUILDER(Add);
            REGISTER_OP_BUILDER(Subtract);
            REGISTER_OP_BUILDER(Multiply);
            REGISTER_OP_BUILDER(Divide);
......
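Stepping back, the direct-execution builders in this backend share one pattern: each BUILDER_DECL specialization appends a functor bound to the node's tensor data onto the external function's functor list, and executing the compiled function replays those functors in order against a CPURuntimeContext. A simplified, hypothetical sketch of that dispatch loop (the names below are illustrative, not the actual nGraph classes):

#include <functional>
#include <vector>

struct CPURuntimeContext; // opaque per-invocation state: tensor pointers, MKL-DNN primitives, ...

using CPUFunctor = std::function<void(CPURuntimeContext*)>;

// Build time: op builders such as the Add builder above push one functor per op.
// Run time: the compiled function simply runs them in topological order.
void run(const std::vector<CPUFunctor>& functors, CPURuntimeContext* ctx)
{
    for (const auto& functor : functors)
    {
        functor(ctx);
    }
}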