Commit ee74e576 authored by Nick Korovaiko, committed by Scott Cyphers

DEX Sigmoid Multiply (#1389)

* dex sigmoid multiply

* sigmoid multiply

* refactor compute logic into standalone kernels

* address jayaram's feedback
parent 3d664abd
@@ -14,11 +14,12 @@
 * limitations under the License.
 *******************************************************************************/
//#include "ngraph/runtime/cpu/kernel/avg_pool.hpp"
#include "ngraph/op/sigmoid.hpp" #include "ngraph/op/sigmoid.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp" #include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/cpu/kernel/sigmoid_multiply.hpp"
#include "ngraph/runtime/cpu/mkldnn_invoke.hpp" #include "ngraph/runtime/cpu/mkldnn_invoke.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp" #include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/runtime/cpu/op/sigmoid_mul.hpp"
using namespace std;
using namespace ngraph;

@@ -107,8 +108,69 @@ namespace ngraph
                functors.emplace_back(functor);
            }

            template <>
            void Builder::BUILDER_DECL(ngraph::op::SigmoidMultiply)
            {
                auto& functors = external_function->get_functors();
                auto& tensor_data = external_function->get_tensor_data();
                auto& arg0_tensor = tensor_data[args[0].get_name()];
                auto& arg1_tensor = tensor_data[args[1].get_name()];
                auto& out_tensor = tensor_data[out[0].get_name()];
                auto tensor_size = shape_size(args[0].get_shape());
                auto sigmoid_mul = static_cast<const ngraph::op::SigmoidMultiply*>(node);
                const size_t index =
                    static_cast<size_t>(sigmoid_mul->get_input_func_type(0)) *
                    static_cast<size_t>(ngraph::op::SigmoidMultiply::FunctionType::NumTypes) +
                    static_cast<size_t>(sigmoid_mul->get_input_func_type(1));

                auto functor = [&, index, tensor_size](CPURuntimeContext* ctx) {
                    ngraph::runtime::cpu::kernel::sigmoid_multiply(
                        arg0_tensor, arg1_tensor, out_tensor, tensor_size, index);
                };
                functors.emplace_back(functor);
            }

            template <>
            void Builder::BUILDER_DECL(ngraph::op::SigmoidMultiplyBackprop)
            {
                auto& functors = external_function->get_functors();
                auto& tensor_data = external_function->get_tensor_data();
                auto& arg0_tensor = tensor_data[args[0].get_name()];
                auto& arg1_tensor = tensor_data[args[1].get_name()];
                auto& arg2_tensor = tensor_data[args[2].get_name()];
                auto& out0_tensor = tensor_data[out[0].get_name()];
                auto& out1_tensor = tensor_data[out[1].get_name()];
                auto tensor_size = shape_size(args[0].get_shape());
                auto sigmoid_mul = static_cast<const ngraph::op::SigmoidMultiplyBackprop*>(node);
                const size_t index =
                    static_cast<size_t>(sigmoid_mul->get_input_func_type(0)) *
                    static_cast<size_t>(ngraph::op::SigmoidMultiply::FunctionType::NumTypes) +
                    static_cast<size_t>(sigmoid_mul->get_input_func_type(1));

                auto functor = [&, index, tensor_size](CPURuntimeContext* ctx) {
                    ngraph::runtime::cpu::kernel::sigmoid_multiply_backprop(
                        arg0_tensor,
                        arg1_tensor,
                        arg2_tensor,
                        out0_tensor,
                        out1_tensor,
                        tensor_size,
                        index);
                };
                functors.emplace_back(functor);
            }

            REGISTER_OP_BUILDER(Sigmoid);
            REGISTER_OP_BUILDER(SigmoidBackprop);
            REGISTER_OP_BUILDER(SigmoidMultiply);
            REGISTER_OP_BUILDER(SigmoidMultiplyBackprop);
        }
    }
}
[Collapsed file diff — likely the new kernel header ngraph/runtime/cpu/kernel/sigmoid_multiply.hpp referenced by the builder above.]
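Since the kernel itself is not visible in this view, the following is only a minimal, self-contained sketch of what such a standalone elementwise kernel could look like, based on the builder code above: the builder flattens the two per-input function types into a single dispatch index (func_type(input0) * NumTypes + func_type(input1)) and hands it to the kernel along with the element count. The names, the decode scheme, and the float-only element type below are assumptions for illustration, not the actual nGraph implementation.

// Hypothetical sketch (NOT the actual ngraph/runtime/cpu/kernel/sigmoid_multiply.hpp):
// computes out[i] = f0(arg0[i]) * f1(arg1[i]), where (f0, f1) is decoded from the
// combined dispatch index produced by the builder.
#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

namespace sketch
{
    enum class FunctionType : size_t
    {
        Logistic = 0,
        Tanh = 1,
        Identity = 2,
        NumTypes = 3
    };

    inline float apply(FunctionType t, float x)
    {
        switch (t)
        {
        case FunctionType::Logistic: return 1.0f / (1.0f + std::exp(-x));
        case FunctionType::Tanh: return std::tanh(x);
        default: return x; // Identity
        }
    }

    void sigmoid_multiply(
        const float* arg0, const float* arg1, float* out, size_t count, size_t index)
    {
        // Decode the flat index back into the per-input function types.
        const size_t n = static_cast<size_t>(FunctionType::NumTypes);
        const FunctionType f0 = static_cast<FunctionType>(index / n);
        const FunctionType f1 = static_cast<FunctionType>(index % n);
        for (size_t i = 0; i < count; ++i)
        {
            out[i] = apply(f0, arg0[i]) * apply(f1, arg1[i]);
        }
    }
}

int main()
{
    // Example: Logistic(x0) * Tanh(x1), i.e. index = 0 * 3 + 1 = 1.
    std::vector<float> x0 = {0.0f, 1.0f, -1.0f};
    std::vector<float> x1 = {0.5f, 0.5f, 0.5f};
    std::vector<float> y(x0.size());
    sketch::sigmoid_multiply(x0.data(), x1.data(), y.data(), y.size(), 1);
    for (float v : y)
    {
        std::cout << v << "\n";
    }
    return 0;
}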
@@ -34,7 +34,8 @@ namespace ngraph
            {
                Logistic,
                Tanh,
                Identity,
                NumTypes
            };
            /// Input nodes are expected to be actual inputs where the corresponding input
            /// FunctionType will be applied to those inputs in the fused operation.
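Note on the header change above: the added NumTypes sentinel is what lets the builder flatten the pair of per-input function types into one dispatch index. With three real types (Logistic, Tanh, Identity), index = f0 * NumTypes + f1 covers all 9 combinations; for example, (Tanh, Logistic) maps to 1 * 3 + 0 = 3.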