Commit f1c29c9c authored by Jaikrishnan Menon's avatar Jaikrishnan Menon Committed by Scott Cyphers

DEX: Softmax (#1341)

* Add helper macros to select from a partial set of ranks and element types

* CPU Direct Execution: Implement Softmax

* Add softmax builder to the build script

* Update
parent 74a7ef7f
......@@ -49,6 +49,7 @@ set(SRC
builder/select.cpp
builder/select_and_scatter.cpp
builder/sigmoid.cpp
builder/softmax.cpp
builder/sum.cpp
kernel/eigen_thread_pool.cpp
kernel/pad.cpp
......
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/op/softmax.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/cpu/kernel/softmax.hpp"
#include "ngraph/runtime/cpu/mkldnn_invoke.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
using namespace std;
using namespace ngraph;
namespace ngraph
{
namespace runtime
{
namespace cpu
{
// Builder specialization for Softmax: emits a CPURuntimeContext functor that
// computes softmax over the requested axes, dispatching to either the MKL-DNN
// primitive (single-axis only) or the Eigen kernels in kernel/softmax.hpp.
template <>
void Builder::BUILDER_DECL(ngraph::op::Softmax)
{
auto softmax = static_cast<const ngraph::op::Softmax*>(node);
auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto arg_shape = args[0].get_shape();
// References into tensor_data stay valid for the lifetime of the external
// function, so the functors below capture them by reference.
auto& arg_tensor = tensor_data[args[0].get_name()];
auto& out_tensor = tensor_data[out[0].get_name()];
auto axes = softmax->get_axes();
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
// The MKL-DNN softmax primitive reduces along exactly one axis.
if (axes.size() != 1)
{
throw ngraph_error("MKLDNN supports softmax only across single axis");
}
int softmax_axis = static_cast<int>(*(axes.begin()));
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto input_desc = mkldnn_utils::get_input_mkldnn_md(node, 0);
auto result_desc = mkldnn_utils::get_output_mkldnn_md(node, 0);
// Build the primitive once at compile time; the functor only binds
// memory pointers and invokes it at run time.
size_t softmax_index = mkldnn_emitter->build_softmax_forward(
input_desc, result_desc, softmax_axis);
auto& deps = mkldnn_emitter->get_primitive_deps(softmax_index);
// deps[0] = input memory, deps[1] = output memory (order set by
// build_softmax_forward).
auto functor = [&, softmax_index](CPURuntimeContext* ctx) {
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[0], arg_tensor);
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[1], out_tensor);
cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, softmax_index);
};
functors.emplace_back(functor);
}
else
{
if (axes.size() == arg_shape.size())
{
// Softmax over every axis: use the all-axes Eigen kernel, selected
// by element type and rank at build time.
std::function<decltype(runtime::cpu::kernel::softmax_all<float, 1>)> kernel;
PARTIAL_SELECT_KERNEL_BY_RANK(kernel,
args[0].get_element_type(),
args[0].get_shape().size(),
runtime::cpu::kernel::softmax_all);
auto functor = [&, kernel, arg_shape](CPURuntimeContext* ctx) {
kernel(arg_tensor, out_tensor, arg_shape);
};
functors.emplace_back(functor);
}
else if (axes.size() == 1)
{
// Softmax over a single axis of a higher-rank tensor.
std::function<decltype(runtime::cpu::kernel::softmax_1rd<float, 1>)> kernel;
PARTIAL_SELECT_KERNEL_BY_RANK(kernel,
args[0].get_element_type(),
args[0].get_shape().size(),
runtime::cpu::kernel::softmax_1rd);
auto functor = [&, kernel, arg_shape, axes](CPURuntimeContext* ctx) {
kernel(arg_tensor, out_tensor, arg_shape, axes);
};
functors.emplace_back(functor);
}
else
{
// Multi-axis (but not all-axes) reduction is not implemented here.
throw ngraph_error("Unsupported Softmax");
}
}
}
REGISTER_OP_BUILDER(Softmax);
}
}
}
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#define EIGEN_USE_THREADS
#include <unsupported/Eigen/CXX11/Tensor>
#include "ngraph/axis_set.hpp"
#include "ngraph/runtime/cpu/kernel/eigen_thread_pool.hpp"
namespace ngraph
{
namespace runtime
{
namespace cpu
{
namespace kernel
{
template <typename ElementType, unsigned int Rank>
void softmax_all(void* input, void* output, const Shape& input_shape)
{
Eigen::array<Eigen::Index, Rank> in_dims, rdims;
rdims.fill(1);
for (int i = 0; i < Rank; i++)
{
in_dims[i] = input_shape[i];
}
Eigen::TensorMap<Eigen::Tensor<ElementType, Rank, Eigen::RowMajor>> out(
static_cast<ElementType *>(output), in_dims),
in(static_cast<ElementType *>(input), in_dims);
out.device(eigen::global_thread_pool_device) =
(in - in.maximum().eval().reshape(rdims).broadcast(in_dims)).exp();
out.device(eigen::global_thread_pool_device) =
out * out.sum().inverse().eval().reshape(rdims).broadcast(in_dims);
}
template <typename ElementType, unsigned int Rank, unsigned int AxisCount>
void softmax(void* input,
void* output,
const Shape& input_shape,
const AxisSet& softmax_axes)
{
Eigen::array<Eigen::Index, Rank> in_dims, rdims, bcast;
Eigen::array<Eigen::Index, AxisCount> axes;
rdims.fill(1);
for (int i = 0; i < Rank; i++)
{
in_dims[i] = input_shape[i];
}
for (int i = 0; i < Rank; i++)
{
if (softmax_axes.count(i))
{
rdims[i] = 1;
}
else
{
rdims[i] = in_dims[i];
}
}
for (int i = 0; i < Rank; i++)
{
bcast[i] = in_dims[i] / rdims[i];
}
int i = 0;
for (auto axis : softmax_axes)
{
axes[i++] = axis;
}
Eigen::TensorMap<Eigen::Tensor<ElementType, Rank, Eigen::RowMajor>> out(
static_cast<ElementType *>(output), in_dims),
in(static_cast<ElementType *>(input), in_dims);
out.device(eigen::global_thread_pool_device) =
(in - in.maximum(axes).eval().reshape(rdims).broadcast(bcast)).exp();
out.device(eigen::global_thread_pool_device) =
out * out.sum(axes).inverse().eval().reshape(rdims).broadcast(bcast);
}
/// Convenience entry point for softmax over exactly one axis: forwards
/// to the generic kernel with AxisCount fixed at 1. Used by the builder's
/// rank-based kernel selection, which cannot deduce AxisCount itself.
template <typename ElementType, unsigned int Rank>
void softmax_1rd(void* input,
                 void* output,
                 const Shape& input_shape,
                 const AxisSet& softmax_axes)
{
    return softmax<ElementType, Rank, 1>(input, output, input_shape, softmax_axes);
}
}
}
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment