Commit b29f7220 authored by Jaikrishnan Menon's avatar Jaikrishnan Menon Committed by Scott Cyphers

Optimize 4D Reshape (#836)

* CPU: Optimize 4D "nGraph" Reshapes (shuffle+reshape)

* CPU: Add kernel sources

* CPU: Replace 2D with 3D reshape

* CPU: Fixes

* CPU: Simplify
parent 877ac969
......@@ -209,6 +209,7 @@ if (NGRAPH_CPU_ENABLE AND LLVM_INCLUDE_DIR AND
runtime/cpu/kernel/pad.cpp
runtime/cpu/kernel/reduce_max.cpp
runtime/cpu/kernel/reduce_sum.cpp
runtime/cpu/kernel/reshape.cpp
runtime/cpu/op/conv_bias.cpp
runtime/cpu/op/conv_relu.cpp
runtime/cpu/op/convert_layout.cpp
......
......@@ -1318,13 +1318,36 @@ namespace ngraph
writer << " );\n";
}
#else
kernel::emit_reshape(writer,
args[0].get_element_type().c_type_string(),
args[0].get_name(),
out[0].get_name(),
args[0].get_shape(),
out[0].get_shape(),
reshape->get_input_order());
if (args[0].get_element_type() == element::f32 && args[0].get_shape().size() == 3 &&
out[0].get_shape().size() == 3)
{
writer << "cpu::kernel::reshape_3d_3d_float32(" << args[0].get_name() << ", "
<< out[0].get_name() << ", "
<< "{" << join(args[0].get_shape()) << "}, "
<< "{" << join(reshape->get_input_order()) << "}, "
<< "{" << join(out[0].get_shape()) << "}"
<< ");\n";
}
else if (args[0].get_element_type() == element::f32 &&
args[0].get_shape().size() == 4 && out[0].get_shape().size() == 4)
{
writer << "cpu::kernel::reshape_4d_4d_float32(" << args[0].get_name() << ", "
<< out[0].get_name() << ", "
<< "{" << join(args[0].get_shape()) << "}, "
<< "{" << join(reshape->get_input_order()) << "}, "
<< "{" << join(out[0].get_shape()) << "}"
<< ");\n";
}
else
{
kernel::emit_reshape(writer,
args[0].get_element_type().c_type_string(),
args[0].get_name(),
out[0].get_name(),
args[0].get_shape(),
out[0].get_shape(),
reshape->get_input_order());
}
#endif
writer.block_end();
}
......
......@@ -109,6 +109,7 @@ namespace ngraph
{
class Shape;
class AxisSet;
class AxisVector;
namespace runtime
{
......@@ -150,6 +151,18 @@ namespace ngraph
const Shape& input_shape,
const Shape& output_shape,
const AxisSet& reduction_axes);
void reshape_3d_3d_float32(float* input,
float* output,
const Shape& input_shape,
const AxisVector& input_axis_order,
const Shape& output_shape);
void reshape_4d_4d_float32(float* input,
float* output,
const Shape& input_shape,
const AxisVector& input_axis_order,
const Shape& output_shape);
}
}
}
......
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "reshape.hpp"
namespace ngraph
{
namespace runtime
{
namespace cpu
{
namespace kernel
{
void reshape_3d_3d_float32(float* input,
float* output,
const Shape& input_shape,
const AxisVector& input_axis_order,
const Shape& output_shape)
{
reshape<float, 3, 3>(
input, output, input_shape, input_axis_order, output_shape);
}
void reshape_4d_4d_float32(float* input,
float* output,
const Shape& input_shape,
const AxisVector& input_axis_order,
const Shape& output_shape)
{
reshape<float, 4, 4>(
input, output, input_shape, input_axis_order, output_shape);
}
}
}
}
}
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#define EIGEN_USE_THREADS
#include <unsupported/Eigen/CXX11/Tensor>
#include "ngraph/axis_vector.hpp"
#include "ngraph/runtime/cpu/kernel/eigen_thread_pool.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
{
namespace runtime
{
namespace cpu
{
namespace kernel
{
template <typename ElementType, unsigned int InRank, unsigned int OutRank>
void reshape(ElementType* input,
ElementType* output,
const Shape& input_shape,
const AxisVector& input_axis_order,
const Shape& output_shape)
{
Eigen::array<Eigen::Index, OutRank> out_dims;
Eigen::array<Eigen::Index, InRank> in_dims;
Eigen::array<Eigen::Index, InRank> axis_order;
for (int i = 0; i < OutRank; i++)
{
out_dims[i] = output_shape[i];
}
for (int i = 0; i < InRank; i++)
{
in_dims[i] = input_shape[i];
axis_order[i] = input_axis_order[i];
}
Eigen::TensorMap<Eigen::Tensor<ElementType, OutRank, Eigen::RowMajor>> out(
output, out_dims);
Eigen::TensorMap<Eigen::Tensor<ElementType, InRank, Eigen::RowMajor>> in(
input, in_dims);
out.device(eigen::global_thread_pool_device) =
in.shuffle(axis_order).reshape(out_dims);
}
}
}
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment