Commit f60dd831 authored by Jaikrishnan Menon

CPU Direct Execution: Implement Dot

parent 045c1898
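
For context: in the CPU backend's direct-execution mode, each op builder appends a functor to a list that the backend later invokes in order, instead of generating and compiling C++ for the graph. A minimal sketch of that execution model follows (the bare-bones CPURuntimeContext here is an assumption; the real one carries runtime state):

#include <functional>
#include <vector>

struct CPURuntimeContext
{
    // Simplified stand-in; the real context carries per-invocation state.
};

int main()
{
    // Builders such as Builder::BUILDER_DECL(ngraph::op::Dot) below append
    // one functor per op when the function is compiled...
    std::vector<std::function<void(CPURuntimeContext*)>> functors;
    functors.emplace_back([](CPURuntimeContext* ctx) { /* e.g. a dot kernel call */ });

    // ...and execution is just running the list in order: no code generation.
    CPURuntimeContext ctx;
    for (auto& functor : functors)
    {
        functor(&ctx);
    }
    return 0;
}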
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <cstring>

#include "ngraph/op/dot.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/cpu/kernel/dot.hpp"

using namespace std;
using namespace ngraph;

namespace ngraph
{
    namespace runtime
    {
        namespace cpu
        {
            template <>
            void Builder::BUILDER_DECL(ngraph::op::Dot)
            {
                auto dot = static_cast<const ngraph::op::Dot*>(node);

                auto& functors = external_function->get_functors();
                auto& tensor_data = external_function->get_tensor_data();

                auto arg0_shape = args[0].get_shape();
                auto arg1_shape = args[1].get_shape();
                auto result_shape = out[0].get_shape();

                auto& arg0_tensor = tensor_data[args[0].get_name()];
                auto& arg1_tensor = tensor_data[args[1].get_name()];
                auto& out_tensor = tensor_data[out[0].get_name()];

                auto reduction_axes_count = dot->get_reduction_axes_count();

                // Zero-sized output: emit a no-op functor.
                if (!shape_size(result_shape))
                {
                    auto functor = [](CPURuntimeContext* ctx) {};
                    functors.emplace_back(functor);
                    return;
                }

                // Zero-sized input: the contraction is over an empty axis, so
                // the (non-empty) output is all zeros.
                if (!shape_size(arg0_shape) || !shape_size(arg1_shape))
                {
                    auto size = shape_size(result_shape) * out[0].get_element_type().size();
                    auto functor = [&, size](CPURuntimeContext* ctx) {
                        memset(out_tensor, 0, size);
                    };
                    functors.emplace_back(functor);
                    return;
                }

                // Scalar operand: scale every element of the other argument.
                if (arg0_shape.empty() || arg1_shape.empty())
                {
                    auto first = (arg0_shape.empty() ? args[0] : args[1]);
                    auto second = (arg0_shape.empty() ? args[1] : args[0]);

                    auto& first_tensor = tensor_data[first.get_name()];
                    auto& second_tensor = tensor_data[second.get_name()];

                    std::function<decltype(runtime::cpu::kernel::dot_scalar<float>)> kernel;

                    SELECT_KERNEL(
                        kernel, out[0].get_element_type(), runtime::cpu::kernel::dot_scalar);

                    auto element_count = shape_size(second.get_shape());
                    auto functor = [&, kernel, element_count](CPURuntimeContext* ctx) {
                        kernel(first_tensor, second_tensor, out_tensor, element_count);
                    };
                    functors.emplace_back(functor);
                    return;
                }

                // Vector . vector -> scalar.
                if ((arg0_shape.size() == 1) && (arg1_shape.size() == 1) &&
                    reduction_axes_count == 1)
                {
                    std::function<decltype(runtime::cpu::kernel::dot_1d_1d_1rd<float>)> kernel;

                    SELECT_KERNEL(
                        kernel, out[0].get_element_type(), runtime::cpu::kernel::dot_1d_1d_1rd);

                    auto functor =
                        [&, kernel, arg0_shape, arg1_shape, result_shape](CPURuntimeContext* ctx) {
                            kernel(arg0_tensor,
                                   arg1_tensor,
                                   out_tensor,
                                   arg0_shape,
                                   arg1_shape,
                                   result_shape);
                        };
                    functors.emplace_back(functor);
                    return;
                }

                // Matrix x vector.
                if ((arg0_shape.size() == 2) && (arg1_shape.size() == 1) &&
                    reduction_axes_count == 1)
                {
                    std::function<decltype(runtime::cpu::kernel::dot_2d_1d_1rd<float>)> kernel;

                    SELECT_KERNEL(
                        kernel, out[0].get_element_type(), runtime::cpu::kernel::dot_2d_1d_1rd);

                    auto functor =
                        [&, kernel, arg0_shape, arg1_shape, result_shape](CPURuntimeContext* ctx) {
                            kernel(arg0_tensor,
                                   arg1_tensor,
                                   out_tensor,
                                   arg0_shape,
                                   arg1_shape,
                                   result_shape);
                        };
                    functors.emplace_back(functor);
                    return;
                }

                // Batched 3-D x 3-D contraction over one axis.
                if ((arg0_shape.size() == 3) && (arg1_shape.size() == 3) &&
                    reduction_axes_count == 1)
                {
                    std::function<decltype(runtime::cpu::kernel::dot_3d_3d_1rd<float>)> kernel;

                    SELECT_KERNEL(
                        kernel, out[0].get_element_type(), runtime::cpu::kernel::dot_3d_3d_1rd);

                    auto functor =
                        [&, kernel, arg0_shape, arg1_shape, result_shape](CPURuntimeContext* ctx) {
                            kernel(arg0_tensor,
                                   arg1_tensor,
                                   out_tensor,
                                   arg0_shape,
                                   arg1_shape,
                                   result_shape);
                        };
                    functors.emplace_back(functor);
                    return;
                }

                // 3-D tensor x matrix.
                if ((arg0_shape.size() == 3) && (arg1_shape.size() == 2) &&
                    reduction_axes_count == 1)
                {
                    std::function<decltype(runtime::cpu::kernel::dot_3d_2d_1rd<float>)> kernel;

                    SELECT_KERNEL(
                        kernel, out[0].get_element_type(), runtime::cpu::kernel::dot_3d_2d_1rd);

                    auto functor =
                        [&, kernel, arg0_shape, arg1_shape, result_shape](CPURuntimeContext* ctx) {
                            kernel(arg0_tensor,
                                   arg1_tensor,
                                   out_tensor,
                                   arg0_shape,
                                   arg1_shape,
                                   result_shape);
                        };
                    functors.emplace_back(functor);
                    return;
                }

                // General case: fall back to the reference implementation.
                std::function<decltype(runtime::cpu::kernel::dot<float>)> kernel;

                SELECT_KERNEL(kernel, out[0].get_element_type(), runtime::cpu::kernel::dot);

                auto functor =
                    [&, kernel, arg0_shape, arg1_shape, result_shape, reduction_axes_count](
                        CPURuntimeContext* ctx) {
                        kernel(arg0_tensor,
                               arg1_tensor,
                               out_tensor,
                               arg0_shape,
                               arg1_shape,
                               result_shape,
                               reduction_axes_count);
                    };
                functors.emplace_back(functor);
            }

            REGISTER_OP_BUILDER(Dot);
        }
    }
}
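
The SELECT_KERNEL macro used above comes from cpu_builder.hpp and is not part of this commit. What follows is only a sketch of the kind of element-type dispatch such a macro performs; the select_dot_scalar helper, the reduced ElemType enum, and the dot_scalar_sketch kernel are all illustrative assumptions, not nGraph API:

#include <cstddef>
#include <functional>
#include <stdexcept>

enum class ElemType { f32, f64 }; // reduced stand-in for ngraph's element::Type

template <typename T>
void dot_scalar_sketch(void* in0, void* in1, void* out, size_t count)
{
    auto scalar = static_cast<T*>(in0)[0];
    auto src = static_cast<T*>(in1);
    auto dst = static_cast<T*>(out);
    for (size_t i = 0; i < count; i++)
    {
        dst[i] = scalar * src[i]; // scale every element by the scalar operand
    }
}

// Hypothetical expansion of a SELECT_KERNEL-style macro for the dot_scalar
// case: bind the std::function to the template instantiation that matches
// the runtime element type.
inline void select_dot_scalar(std::function<void(void*, void*, void*, size_t)>& kernel,
                              ElemType et)
{
    switch (et)
    {
    case ElemType::f32: kernel = dot_scalar_sketch<float>; break;
    case ElemType::f64: kernel = dot_scalar_sketch<double>; break;
    default: throw std::runtime_error("unsupported element type");
    }
}

The kernels the builder dispatches to live in ngraph/runtime/cpu/kernel/dot.hpp, the second file in this commit: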
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once

#define EIGEN_USE_THREADS
#include <unsupported/Eigen/CXX11/Tensor>

#include "ngraph/runtime/cpu/kernel/eigen_thread_pool.hpp"
#include "ngraph/runtime/reference/dot.hpp"
#include "ngraph/shape.hpp"

namespace ngraph
{
    namespace runtime
    {
        namespace cpu
        {
            namespace kernel
            {
                // Rank-templated contraction: contracts the trailing DotDims
                // axes of input0 with the leading DotDims axes of input1 using
                // Eigen's threaded tensor contraction.
                template <typename ElementType,
                          unsigned int Input0Rank,
                          unsigned int Input1Rank,
                          unsigned int DotDims>
                void dot(void* input0,
                         void* input1,
                         void* output,
                         const Shape& input0_shape,
                         const Shape& input1_shape,
                         const Shape& output_shape)
                {
                    constexpr unsigned int OutRank = Input0Rank + Input1Rank - 2 * DotDims;

                    Eigen::array<Eigen::Index, OutRank> out_dims;
                    Eigen::array<Eigen::Index, Input0Rank> in0_dims;
                    Eigen::array<Eigen::Index, Input1Rank> in1_dims;
                    Eigen::array<Eigen::IndexPair<Eigen::Index>, DotDims> dot_dims;

                    for (int i = 0; i < OutRank; i++)
                    {
                        out_dims[i] = output_shape[i];
                    }

                    for (int i = 0; i < Input0Rank; i++)
                    {
                        in0_dims[i] = input0_shape[i];
                    }

                    for (int i = 0; i < Input1Rank; i++)
                    {
                        in1_dims[i] = input1_shape[i];
                    }

                    // Pair axis (Input0Rank - DotDims + i) of input0 with axis i of input1.
                    for (int i = 0; i < DotDims; i++)
                    {
                        dot_dims[i].first = Input0Rank - DotDims + i;
                        dot_dims[i].second = i;
                    }

                    Eigen::TensorMap<Eigen::Tensor<ElementType, OutRank, Eigen::RowMajor>> out(
                        static_cast<ElementType*>(output), out_dims);
                    Eigen::TensorMap<Eigen::Tensor<ElementType, Input0Rank, Eigen::RowMajor>> in0(
                        static_cast<ElementType*>(input0), in0_dims);
                    Eigen::TensorMap<Eigen::Tensor<ElementType, Input1Rank, Eigen::RowMajor>> in1(
                        static_cast<ElementType*>(input1), in1_dims);

                    out.device(eigen::global_thread_pool_device) = in0.contract(in1, dot_dims);
                }

                // Scalar x tensor: broadcast-multiply the scalar over input1.
                template <typename ElementType>
                void dot_scalar(void* input0, void* input1, void* output, size_t element_count)
                {
                    Eigen::array<Eigen::Index, 1> out_dims;
                    Eigen::array<Eigen::Index, 1> in1_dims;

                    out_dims[0] = element_count;
                    in1_dims[0] = element_count;

                    Eigen::TensorMap<Eigen::Tensor<ElementType, 1, Eigen::RowMajor>> out(
                        static_cast<ElementType*>(output), out_dims);
                    auto in0 = static_cast<ElementType*>(input0);
                    Eigen::TensorMap<Eigen::Tensor<ElementType, 1, Eigen::RowMajor>> in1(
                        static_cast<ElementType*>(input1), in1_dims);

                    out.device(eigen::global_thread_pool_device) = in0[0] * in1;
                }

                // Shape-specialized wrappers used by the builder.
                template <typename ElementType>
                void dot_1d_1d_1rd(void* input0,
                                   void* input1,
                                   void* output,
                                   const Shape& input0_shape,
                                   const Shape& input1_shape,
                                   const Shape& output_shape)
                {
                    dot<ElementType, 1, 1, 1>(
                        input0, input1, output, input0_shape, input1_shape, output_shape);
                }

                template <typename ElementType>
                void dot_2d_1d_1rd(void* input0,
                                   void* input1,
                                   void* output,
                                   const Shape& input0_shape,
                                   const Shape& input1_shape,
                                   const Shape& output_shape)
                {
                    dot<ElementType, 2, 1, 1>(
                        input0, input1, output, input0_shape, input1_shape, output_shape);
                }

                template <typename ElementType>
                void dot_3d_3d_1rd(void* input0,
                                   void* input1,
                                   void* output,
                                   const Shape& input0_shape,
                                   const Shape& input1_shape,
                                   const Shape& output_shape)
                {
                    dot<ElementType, 3, 3, 1>(
                        input0, input1, output, input0_shape, input1_shape, output_shape);
                }

                template <typename ElementType>
                void dot_3d_2d_1rd(void* input0,
                                   void* input1,
                                   void* output,
                                   const Shape& input0_shape,
                                   const Shape& input1_shape,
                                   const Shape& output_shape)
                {
                    dot<ElementType, 3, 2, 1>(
                        input0, input1, output, input0_shape, input1_shape, output_shape);
                }

                // General fallback: delegate to the serial reference kernel.
                template <typename ElementType>
                void dot(void* arg0,
                         void* arg1,
                         void* out,
                         const Shape& arg0_shape,
                         const Shape& arg1_shape,
                         const Shape& out_shape,
                         size_t reduction_axes_count)
                {
                    reference::dot(static_cast<const ElementType*>(arg0),
                                   static_cast<const ElementType*>(arg1),
                                   static_cast<ElementType*>(out),
                                   arg0_shape,
                                   arg1_shape,
                                   out_shape,
                                   reduction_axes_count);
                }
            }
        }
    }
}
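
To see what the rank-templated dot above computes, here is a standalone example of the same Eigen contraction for the matrix-vector case. It uses only Eigen and evaluates on the default single-threaded device rather than nGraph's eigen::global_thread_pool_device:

#include <unsupported/Eigen/CXX11/Tensor>
#include <iostream>

int main()
{
    // A 2x3 row-major matrix and a 3-vector, as dot_2d_1d_1rd sees them.
    float a_data[6] = {1, 2, 3, 4, 5, 6};
    float b_data[3] = {1, 0, 1};
    float out_data[2] = {0, 0};

    Eigen::TensorMap<Eigen::Tensor<float, 2, Eigen::RowMajor>> a(a_data, 2, 3);
    Eigen::TensorMap<Eigen::Tensor<float, 1, Eigen::RowMajor>> b(b_data, 3);
    Eigen::TensorMap<Eigen::Tensor<float, 1, Eigen::RowMajor>> out(out_data, 2);

    // DotDims = 1: pair the last axis of a (index 1) with the first axis of b
    // (index 0), exactly as the dot_dims loop in the kernel sets up.
    Eigen::array<Eigen::IndexPair<Eigen::Index>, 1> dot_dims;
    dot_dims[0].first = 1;
    dot_dims[0].second = 0;

    // Output rank = 2 + 1 - 2*1 = 1, matching OutRank in the kernel.
    out = a.contract(b, dot_dims);

    std::cout << out_data[0] << " " << out_data[1] << std::endl; // prints: 4 10
    return 0;
}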