Commit 1f1ab184 authored by Jayaram Bobba, committed by Scott Cyphers

Dex non-mkldnn version of clipped relu (#1376)

* Dex non-mkldnn version of clipped relu

* Change to static_cast
parent a8fb4fe0
@@ -32,17 +32,12 @@ namespace ngraph
            template <>
            void Builder::BUILDER_DECL(ngraph::op::BoundedRelu)
            {
                if (!runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
                {
                    throw ngraph_error(
                        "BoundedRelu is supported only through MKLDNN and doesnt have reference "
                        "INTERPRETER implementation");
                }
                auto& functors = external_function->get_functors();
                auto& tensor_data = external_function->get_tensor_data();
                auto& input_tensor = tensor_data[args[0].get_name()];
                auto& out_tensor = tensor_data[out[0].get_name()];
                size_t count = out[0].get_size();
                if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
                {
@@ -56,6 +51,19 @@ namespace ngraph
                    };
                    functors.emplace_back(functor);
                }
                else
                {
                    std::function<decltype(runtime::cpu::kernel::bounded_relu<float>)> kernel;
                    SELECT_KERNEL(
                        kernel, out[0].get_element_type(), runtime::cpu::kernel::bounded_relu);
                    auto alpha = static_cast<const op::BoundedRelu*>(node)->get_alpha();
                    auto functor = [&, kernel, alpha, count](CPURuntimeContext* ctx) {
                        kernel(input_tensor, out_tensor, alpha, count);
                    };
                    functors.emplace_back(functor);
                }
            }
            REGISTER_OP_BUILDER(BoundedRelu);
        }
......
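For context: with this change the builder no longer throws when MKLDNN is not selected for the node. Instead, SELECT_KERNEL picks a `bounded_relu` kernel specialized for the output element type, and the registered functor captures that kernel together with the op's `alpha` and the element count. A minimal standalone sketch of that pattern follows; the kernel name, the `alpha` value, and the `main` driver are illustrative assumptions, not nGraph API:

```cpp
#include <algorithm>
#include <cstddef>
#include <functional>
#include <iostream>
#include <vector>

// Stand-in for the type-specialized kernel that SELECT_KERNEL would pick.
template <typename T>
void bounded_relu_kernel(const void* input, void* output, T alpha, size_t count)
{
    const T* in = static_cast<const T*>(input);
    T* out = static_cast<T*>(output);
    for (size_t i = 0; i < count; ++i)
    {
        out[i] = std::min(std::max(in[i], T(0)), alpha);
    }
}

int main()
{
    std::vector<float> in = {-2.0f, 0.5f, 3.0f, 9.0f};
    std::vector<float> out(in.size());

    float alpha = 6.0f; // clip threshold taken from the op (e.g. ReLU6); assumed here
    size_t count = in.size();

    // The builder's pattern: bind the selected kernel plus the op attributes
    // into a functor and queue it for later execution.
    std::vector<std::function<void()>> functors;
    auto functor = [&in, &out, alpha, count]() {
        bounded_relu_kernel<float>(in.data(), out.data(), alpha, count);
    };
    functors.emplace_back(functor);

    for (auto& f : functors)
        f();

    for (float v : out)
        std::cout << v << ' '; // prints: 0 0.5 3 6
    std::cout << '\n';
}
```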
@@ -45,6 +45,22 @@ namespace ngraph
                    out.device(eigen::global_thread_pool_device) = in0.cwiseMax(ElementType(0));
                }

                template <typename ElementType>
                void bounded_relu(void* input0, void* output, ElementType alpha, size_t count)
                {
                    Eigen::array<Eigen::Index, 1> out_dims, in_dims;
                    out_dims[0] = in_dims[0] = count;
                    Eigen::TensorMap<Eigen::Tensor<ElementType, 1, Eigen::RowMajor>> out(
                        static_cast<ElementType*>(output), out_dims);
                    Eigen::TensorMap<Eigen::Tensor<ElementType, 1, Eigen::RowMajor>> in0(
                        static_cast<ElementType*>(input0), in_dims);
                    out.device(eigen::global_thread_pool_device) =
                        in0.cwiseMax(ElementType(0)).cwiseMin(alpha);
                }

                template <typename ElementType>
                void relu_backprop(void* arg, void* delta_arg, void* out, size_t count)
                {
......
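The kernel itself is a one-dimensional Eigen expression: the raw input and output pointers are mapped as 1-D tensors and `cwiseMax(0).cwiseMin(alpha)` clamps each element to [0, alpha]. A hedged standalone sketch of the same expression, evaluated on Eigen's default device rather than the global thread pool used above (the demo driver and compile command are assumptions):

```cpp
// Compile with Eigen's unsupported Tensor module available, e.g.:
//   g++ -std=c++14 -I /path/to/eigen bounded_relu_demo.cpp
#include <unsupported/Eigen/CXX11/Tensor>
#include <cstddef>
#include <iostream>

// Same element-wise expression as the kernel above, but evaluated on the
// default (single-threaded) device instead of the global thread pool.
template <typename ElementType>
void bounded_relu(void* input0, void* output, ElementType alpha, size_t count)
{
    Eigen::array<Eigen::Index, 1> dims;
    dims[0] = static_cast<Eigen::Index>(count);
    Eigen::TensorMap<Eigen::Tensor<ElementType, 1, Eigen::RowMajor>> out(
        static_cast<ElementType*>(output), dims);
    Eigen::TensorMap<Eigen::Tensor<ElementType, 1, Eigen::RowMajor>> in0(
        static_cast<ElementType*>(input0), dims);
    // Clipped ReLU: clamp every element to the interval [0, alpha].
    out = in0.cwiseMax(ElementType(0)).cwiseMin(alpha);
}

int main()
{
    float in[4] = {-1.0f, 0.25f, 2.0f, 7.5f};
    float out[4];
    bounded_relu<float>(in, out, 6.0f, 4);
    for (float v : out)
        std::cout << v << ' '; // prints: 0 0.25 2 6
    std::cout << '\n';
}
```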
@@ -2603,7 +2603,7 @@ static void check_bounded_relu(Shape param_shape, float constant_val)
    auto cpu_f = make_function(param_shape, constant_val);
    auto int_f = make_function(param_shape, constant_val);
    test::Uniform<float> rng(0.0f, 1.0f);
    test::Uniform<float> rng(-10.0f, 10.0f);
    vector<vector<float>> args;
    for (shared_ptr<op::Parameter> param : int_f->get_parameters())
......
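The test tweak widens the random inputs from (0, 1) to (-10, 10). With values only in (0, 1) and a clip threshold of 1 or more, neither the clamp at 0 nor the clamp at alpha ever fires, so the bounded behaviour goes unexercised; the wider range hits all three regions of the function. A rough illustration of that coverage argument (the threshold value here is an assumption, not taken from the test):

```cpp
#include <iostream>
#include <random>

// Illustrative only: counts how many samples from (-10, 10) land in each
// region of a clipped ReLU with an assumed threshold alpha = 6.0f.
int main()
{
    std::mt19937 gen(0);
    std::uniform_real_distribution<float> rng(-10.0f, 10.0f);
    const float alpha = 6.0f;

    int below_zero = 0, in_range = 0, above_alpha = 0;
    for (int i = 0; i < 1000; ++i)
    {
        float x = rng(gen);
        if (x < 0.0f)
            ++below_zero;  // clamped to 0
        else if (x > alpha)
            ++above_alpha; // clamped to alpha
        else
            ++in_range;    // passed through unchanged
    }
    std::cout << below_zero << " clamped low, " << in_range << " passed through, "
              << above_alpha << " clamped high\n";
}
```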