Unverified Commit 4487d60e authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Call generic CPU kernels for Broadcast and Reshape (#4130)

* Call generic CPU kernels for Broadcast and Reshape

* Move generic kernels to opt_kernel

* Use strides in broadcast
Co-authored-by: Scott Cyphers <diyessi@users.noreply.github.com>
parent ad6d1d43
......@@ -16,7 +16,7 @@
#include "constant_folding.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/runtime/reference/broadcast.hpp"
#include "ngraph/runtime/opt_kernel/broadcast.hpp"
#include "ngraph/type/element_type.hpp"
using namespace std;
......@@ -45,7 +45,7 @@ shared_ptr<op::Constant> fold_constant_broadcast(shared_ptr<op::Constant> consta
auto static_bcast_axes = broadcast_v1->get_broadcast_axes();
if (static_bcast_axes.first)
{
runtime::reference::broadcast<T>(constant->get_data_ptr<T>(),
runtime::opt_kernel::broadcast<T>(constant->get_data_ptr<T>(),
data_ptr,
constant->get_shape(),
out_shape,
......@@ -58,7 +58,7 @@ shared_ptr<op::Constant> fold_constant_broadcast(shared_ptr<op::Constant> consta
}
else if (auto broadcast_v0 = as_type_ptr<op::v0::Broadcast>(broadcast))
{
runtime::reference::broadcast<T>(constant->get_data_ptr<T>(),
runtime::opt_kernel::broadcast<T>(constant->get_data_ptr<T>(),
data_ptr,
constant->get_shape(),
out_shape,
......
......@@ -16,7 +16,7 @@
#include "constant_folding.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/runtime/reference/reshape.hpp"
#include "ngraph/runtime/opt_kernel/reshape.hpp"
using namespace std;
using namespace ngraph;
......@@ -41,7 +41,7 @@ shared_ptr<op::Constant> fold_constant_reshape(shared_ptr<op::Constant> constant
}
else
{
runtime::reference::reshape<T>(constant->get_data_ptr<T>(),
runtime::opt_kernel::reshape<T>(constant->get_data_ptr<T>(),
data_ptr,
constant->get_shape(),
reshape->get_input_order(),
......
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <Eigen/Dense>
#include <cmath>
#include <utility>
#ifdef PARALLEL
#include <omp.h>
#endif
#include "ngraph/runtime/reference/broadcast.hpp"
#include "ngraph/shape_util.hpp"
#include "ngraph/util.hpp"
namespace ngraph
{
namespace runtime
{
namespace gcpu
{
namespace kernel
{
#ifdef PARALLEL
// Partition `size` iterations evenly across the current OpenMP team and
// return the half-open [start, finish) slice owned by the calling thread.
static std::tuple<size_t, size_t> get_start_finish(size_t size)
{
    const size_t num_threads = omp_get_num_threads();
    const size_t thread_id = omp_get_thread_num();
    return std::make_tuple(thread_id * size / num_threads,
                           (thread_id + 1) * size / num_threads);
}
#endif
// Broadcast a 1-D input along the single non-broadcast axis of a 2-D output.
// The loop counter of that axis doubles as the index into `in`.
template <typename T>
void broadcast_2d(const T* in,
                  T* out,
                  const Shape& in_shape,
                  const Shape& out_shape,
                  const AxisSet& broadcast_axes)
{
    size_t pos[2];
    // Pick the counter of the output axis the input runs along.
    size_t* src_pos;
    if (broadcast_axes.find(0) == broadcast_axes.end())
    {
        src_pos = &pos[0];
    }
    else
    {
        src_pos = &pos[1];
    }
    for (pos[0] = 0; pos[0] < out_shape[0]; ++pos[0])
    {
        for (pos[1] = 0; pos[1] < out_shape[1]; ++pos[1])
        {
            out[pos[0] * out_shape[1] + pos[1]] = in[*src_pos];
        }
    }
}
// #define PARALLEL
// Broadcast a 1-D input along the single non-broadcast axis of a 3-D output.
// When PARALLEL is defined, the outermost output axis is split across an
// OpenMP team via get_start_finish(); otherwise one thread covers it all.
template <typename T>
void broadcast_3d(const T* in,
T* out,
const Shape& in_shape,
const Shape& out_shape,
const AxisSet& broadcast_axes)
{
#ifdef PARALLEL
#pragma omp parallel
#endif
{
// This thread's slice of output axis 0: [start, finish).
size_t start;
size_t finish;
#ifdef PARALLEL
std::tie(start, finish) = get_start_finish(out_shape[0]);
#else
start = 0;
finish = out_shape[0];
#endif
size_t index[3];
// Point out_index at the loop counter of the one output axis that is not
// broadcast; that counter doubles as the index into `in`.
// NOTE(review): assumes at least one axis is missing from broadcast_axes
// (true for the rank-1 caller below); otherwise this stays null.
size_t* out_index = 0;
for (size_t i = 0; i < 3; i++)
{
if (broadcast_axes.count(i) == 0)
{
out_index = &index[i];
break;
}
}
for (index[0] = start; index[0] < finish; ++index[0])
{
for (index[1] = 0; index[1] < out_shape[1]; ++index[1])
{
for (index[2] = 0; index[2] < out_shape[2]; ++index[2])
{
// Row-major linearization of (index[0], index[1], index[2]).
out[index[0] * out_shape[1] * out_shape[2] +
index[1] * out_shape[2] + index[2]] = in[*out_index];
}
}
}
}
}
// Broadcast a 1-D input along the single non-broadcast axis of a 4-D output.
//
// \param in             1-D source data
// \param out            destination buffer holding shape_size(out_shape) elements
// \param in_shape       shape of `in` (unused here; kept for a uniform signature)
// \param out_shape      4-D destination shape
// \param broadcast_axes output axes being replicated; exactly one of the four
//                       axes is expected to be absent from this set
template <typename T>
void broadcast_4d(const T* in,
                  T* out,
                  const Shape& in_shape,
                  const Shape& out_shape,
                  const AxisSet& broadcast_axes)
{
    size_t index[4];
    // Point out_index at the loop counter of the one non-broadcast axis;
    // that counter doubles as the index into `in`.
    size_t* out_index = nullptr;
    for (size_t i = 0; i < 4; i++)
    {
        if (broadcast_axes.count(i) == 0)
        {
            out_index = &index[i];
            break;
        }
    }
    // Hoist the loop-invariant row-major stride products out of the loops.
    const size_t stride2 = out_shape[3];
    const size_t stride1 = out_shape[2] * stride2;
    const size_t stride0 = out_shape[1] * stride1;
    for (index[0] = 0; index[0] < out_shape[0]; ++index[0])
    {
        for (index[1] = 0; index[1] < out_shape[1]; ++index[1])
        {
            for (index[2] = 0; index[2] < out_shape[2]; ++index[2])
            {
                for (index[3] = 0; index[3] < out_shape[3]; ++index[3])
                {
                    out[index[0] * stride0 + index[1] * stride1 +
                        index[2] * stride2 + index[3]] = in[*out_index];
                }
            }
        }
    }
}
// Broadcast a 1-D input along the single non-broadcast axis of a 5-D output.
template <typename T>
void broadcast_5d(const T* in,
                  T* out,
                  const Shape& in_shape,
                  const Shape& out_shape,
                  const AxisSet& broadcast_axes)
{
    size_t idx[5];
    // The loop counter of the one non-broadcast axis doubles as the input index.
    size_t* src_idx = nullptr;
    for (size_t axis = 0; axis < 5; ++axis)
    {
        if (broadcast_axes.count(axis) == 0)
        {
            src_idx = &idx[axis];
            break;
        }
    }
    // Row-major strides of the output, computed once up front.
    const size_t s3 = out_shape[4];
    const size_t s2 = out_shape[3] * s3;
    const size_t s1 = out_shape[2] * s2;
    const size_t s0 = out_shape[1] * s1;
    for (idx[0] = 0; idx[0] < out_shape[0]; ++idx[0])
    {
        for (idx[1] = 0; idx[1] < out_shape[1]; ++idx[1])
        {
            for (idx[2] = 0; idx[2] < out_shape[2]; ++idx[2])
            {
                for (idx[3] = 0; idx[3] < out_shape[3]; ++idx[3])
                {
                    for (idx[4] = 0; idx[4] < out_shape[4]; ++idx[4])
                    {
                        out[idx[0] * s0 + idx[1] * s1 + idx[2] * s2 +
                            idx[3] * s3 + idx[4]] = in[*src_idx];
                    }
                }
            }
        }
    }
}
// Broadcast a 1-D input along the single non-broadcast axis of a 6-D output.
template <typename T>
void broadcast_6d(const T* in,
T* out,
const Shape& in_shape,
const Shape& out_shape,
const AxisSet& broadcast_axes)
{
size_t index[6];
// Point out_index at the loop counter of the one output axis that is not
// broadcast; that counter doubles as the index into `in`.
size_t* out_index = 0;
for (size_t i = 0; i < 6; i++)
{
if (broadcast_axes.count(i) == 0)
{
out_index = &index[i];
break;
}
}
for (index[0] = 0; index[0] < out_shape[0]; ++index[0])
{
for (index[1] = 0; index[1] < out_shape[1]; ++index[1])
{
for (index[2] = 0; index[2] < out_shape[2]; ++index[2])
{
for (index[3] = 0; index[3] < out_shape[3]; ++index[3])
{
for (index[4] = 0; index[4] < out_shape[4]; ++index[4])
{
for (index[5] = 0; index[5] < out_shape[5]; ++index[5])
{
// Row-major linearization of the 6-D index tuple.
out[index[0] * out_shape[1] * out_shape[2] *
out_shape[3] * out_shape[4] * out_shape[5] +
index[1] * out_shape[2] * out_shape[3] *
out_shape[4] * out_shape[5] +
index[2] * out_shape[3] * out_shape[4] *
out_shape[5] +
index[3] * out_shape[4] * out_shape[5] +
index[4] * out_shape[5] + index[5]] =
in[*out_index];
}
}
}
}
}
}
}
// Broadcast entry point: handles scalar inputs inline, dispatches 1-D inputs
// to a rank-specialized kernel when one exists (output rank 2..6), and falls
// back to the generic reference implementation for everything else.
template <typename T>
void broadcast(const T* in,
               T* out,
               const Shape& in_shape,
               const Shape& out_shape,
               const AxisSet& broadcast_axes)
{
    // Scalar input: replicate the single element across the whole output.
    if (in_shape.size() == 0)
    {
        const size_t count = shape_size(out_shape);
        for (size_t i = 0; i < count; ++i)
        {
            out[i] = in[0];
        }
        return;
    }
    if (in_shape.size() == 1)
    {
        switch (out_shape.size())
        {
        case 2: broadcast_2d<T>(in, out, in_shape, out_shape, broadcast_axes); return;
        case 3: broadcast_3d<T>(in, out, in_shape, out_shape, broadcast_axes); return;
        case 4: broadcast_4d<T>(in, out, in_shape, out_shape, broadcast_axes); return;
        case 5: broadcast_5d<T>(in, out, in_shape, out_shape, broadcast_axes); return;
        case 6: broadcast_6d<T>(in, out, in_shape, out_shape, broadcast_axes); return;
        default: break;
        }
    }
    // No specialized kernel: use the reference implementation.
    runtime::reference::broadcast<T>(in, out, in_shape, out_shape, broadcast_axes);
}
}
}
}
}
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#define EIGEN_USE_THREADS
#include <unsupported/Eigen/CXX11/Tensor>
#include "ngraph/axis_vector.hpp"
#include "ngraph/runtime/reference/reshape.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
{
namespace runtime
{
namespace gcpu
{
namespace kernel
{
// Rank-0 reshape: the input is a single scalar — copy it through.
template <typename T>
void reshape_in0(const T* in,
                 T* out,
                 const Shape& in_shape,
                 const AxisVector& in_axis_order,
                 const Shape& out_shape)
{
    out[0] = in[0];
}
template <typename T>
void reshape_in1(const T* in,
T* out,
const Shape& in_shape,
const AxisVector& in_axis_order,
const Shape& out_shape)
{
size_t size[1];
size_t in_index[1];
size_t* map_index[1];
for (size_t i = 0; i < 1; i++)
{
size[i] = in_shape[in_axis_order[i]];
map_index[in_axis_order[i]] = &in_index[i];
}
for (in_index[0] = 0; in_index[0] < size[0]; ++in_index[0])
{
*out++ = in[*map_index[0]];
}
}
// Rank-2 reshape/transpose: walk the input axes in `in_axis_order`, writing
// output elements sequentially. `axis_pos[d]` aliases the loop counter that
// currently indexes original input axis d.
template <typename T>
void reshape_in2(const T* in,
                 T* out,
                 const Shape& in_shape,
                 const AxisVector& in_axis_order,
                 const Shape& out_shape)
{
    size_t counters[2];
    size_t extent[2];
    size_t* axis_pos[2];
    for (size_t d = 0; d < 2; ++d)
    {
        extent[d] = in_shape[in_axis_order[d]];
        axis_pos[in_axis_order[d]] = &counters[d];
    }
    T* dst = out;
    for (counters[0] = 0; counters[0] < extent[0]; ++counters[0])
    {
        for (counters[1] = 0; counters[1] < extent[1]; ++counters[1])
        {
            *dst++ = in[*axis_pos[0] * in_shape[1] + *axis_pos[1]];
        }
    }
}
// Rank-3 reshape/transpose: iterate the input axes in `in_axis_order`,
// emitting output elements sequentially. map_index[d] aliases the loop
// counter that currently indexes original input axis d.
template <typename T>
void reshape_in3(const T* in,
T* out,
const Shape& in_shape,
const AxisVector& in_axis_order,
const Shape& out_shape)
{
size_t size[3];
size_t in_index[3];
size_t* map_index[3];
for (size_t i = 0; i < 3; i++)
{
// size[i] is the extent of the i-th axis in iteration (permuted) order.
size[i] = in_shape[in_axis_order[i]];
map_index[in_axis_order[i]] = &in_index[i];
}
for (in_index[0] = 0; in_index[0] < size[0]; ++in_index[0])
{
for (in_index[1] = 0; in_index[1] < size[1]; ++in_index[1])
{
for (in_index[2] = 0; in_index[2] < size[2]; ++in_index[2])
{
// Row-major read from the input at the permuted coordinates.
*out++ = in[*map_index[0] * in_shape[1] * in_shape[2] +
*map_index[1] * in_shape[2] + *map_index[2]];
}
}
}
}
// Rank-4 reshape/transpose: iterate the input axes in `in_axis_order`,
// emitting output elements sequentially. map_index[d] aliases the loop
// counter that currently indexes original input axis d.
template <typename T>
void reshape_in4(const T* in,
T* out,
const Shape& in_shape,
const AxisVector& in_axis_order,
const Shape& out_shape)
{
size_t size[4];
size_t in_index[4];
size_t* map_index[4];
for (size_t i = 0; i < 4; i++)
{
size[i] = in_shape[in_axis_order[i]];
map_index[in_axis_order[i]] = &in_index[i];
}
for (in_index[0] = 0; in_index[0] < size[0]; ++in_index[0])
{
for (in_index[1] = 0; in_index[1] < size[1]; ++in_index[1])
{
for (in_index[2] = 0; in_index[2] < size[2]; ++in_index[2])
{
for (in_index[3] = 0; in_index[3] < size[3]; ++in_index[3])
{
// Row-major read from the input at the permuted coordinates.
*out++ =
in[*map_index[0] * in_shape[1] * in_shape[2] * in_shape[3] +
*map_index[1] * in_shape[2] * in_shape[3] +
*map_index[2] * in_shape[3] + *map_index[3]];
}
}
}
}
}
// Rank-5 reshape/transpose: iterate the input axes in `in_axis_order`,
// emitting output elements sequentially. map_index[d] aliases the loop
// counter that currently indexes original input axis d.
template <typename T>
void reshape_in5(const T* in,
T* out,
const Shape& in_shape,
const AxisVector& in_axis_order,
const Shape& out_shape)
{
size_t size[5];
size_t in_index[5];
size_t* map_index[5];
for (size_t i = 0; i < 5; i++)
{
size[i] = in_shape[in_axis_order[i]];
map_index[in_axis_order[i]] = &in_index[i];
}
for (in_index[0] = 0; in_index[0] < size[0]; ++in_index[0])
{
for (in_index[1] = 0; in_index[1] < size[1]; ++in_index[1])
{
for (in_index[2] = 0; in_index[2] < size[2]; ++in_index[2])
{
for (in_index[3] = 0; in_index[3] < size[3]; ++in_index[3])
{
for (in_index[4] = 0; in_index[4] < size[4]; ++in_index[4])
{
// Row-major read from the input at the permuted coordinates.
*out++ = in[*map_index[0] * in_shape[1] * in_shape[2] *
in_shape[3] * in_shape[4] +
*map_index[1] * in_shape[2] * in_shape[3] *
in_shape[4] +
*map_index[2] * in_shape[3] * in_shape[4] +
*map_index[3] * in_shape[4] + *map_index[4]];
}
}
}
}
}
}
// Rank-6 reshape/transpose: iterate the input axes in `in_axis_order`,
// emitting output elements sequentially. map_index[d] aliases the loop
// counter that currently indexes original input axis d.
template <typename T>
void reshape_in6(const T* in,
T* out,
const Shape& in_shape,
const AxisVector& in_axis_order,
const Shape& out_shape)
{
size_t size[6];
size_t in_index[6];
size_t* map_index[6];
for (size_t i = 0; i < 6; i++)
{
size[i] = in_shape[in_axis_order[i]];
map_index[in_axis_order[i]] = &in_index[i];
}
for (in_index[0] = 0; in_index[0] < size[0]; ++in_index[0])
{
for (in_index[1] = 0; in_index[1] < size[1]; ++in_index[1])
{
for (in_index[2] = 0; in_index[2] < size[2]; ++in_index[2])
{
for (in_index[3] = 0; in_index[3] < size[3]; ++in_index[3])
{
for (in_index[4] = 0; in_index[4] < size[4]; ++in_index[4])
{
for (in_index[5] = 0; in_index[5] < size[5]; ++in_index[5])
{
// Row-major read from the input at the permuted coordinates.
*out++ =
in[*map_index[0] * in_shape[1] * in_shape[2] *
in_shape[3] * in_shape[4] * in_shape[5] +
*map_index[1] * in_shape[2] * in_shape[3] *
in_shape[4] * in_shape[5] +
*map_index[2] * in_shape[3] * in_shape[4] *
in_shape[5] +
*map_index[3] * in_shape[4] * in_shape[5] +
*map_index[4] * in_shape[5] + *map_index[5]];
}
}
}
}
}
}
}
// Reshape entry point: dispatch on input rank. Ranks 0-6 have specialized
// kernels; anything larger falls back to the generic reference implementation.
template <typename T>
void reshape(const T* in,
             T* out,
             const Shape& in_shape,
             const AxisVector& in_axis_order,
             const Shape& out_shape)
{
    const size_t rank = in_shape.size();
    if (rank == 0)
    {
        reshape_in0<T>(in, out, in_shape, in_axis_order, out_shape);
    }
    else if (rank == 1)
    {
        reshape_in1<T>(in, out, in_shape, in_axis_order, out_shape);
    }
    else if (rank == 2)
    {
        reshape_in2<T>(in, out, in_shape, in_axis_order, out_shape);
    }
    else if (rank == 3)
    {
        reshape_in3<T>(in, out, in_shape, in_axis_order, out_shape);
    }
    else if (rank == 4)
    {
        reshape_in4<T>(in, out, in_shape, in_axis_order, out_shape);
    }
    else if (rank == 5)
    {
        reshape_in5<T>(in, out, in_shape, in_axis_order, out_shape);
    }
    else if (rank == 6)
    {
        reshape_in6<T>(in, out, in_shape, in_axis_order, out_shape);
    }
    else
    {
        reference::reshape(in, out, in_shape, in_axis_order, out_shape);
    }
}
}
}
}
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <cmath>
#include <utility>
#include "ngraph/runtime/reference/broadcast.hpp"
#include "ngraph/shape_util.hpp"
#include "ngraph/util.hpp"
namespace ngraph
{
namespace runtime
{
namespace opt_kernel
{
// Broadcast a 1-D input along output axis `out_axis` of a 2-D output.
// The loop counter for that axis doubles as the index into `in`.
template <typename T>
void broadcast_2d(
    const T* in, T* out, const Shape& in_shape, const Shape& out_shape, size_t out_axis)
{
    size_t idx[2];
    // Row-major stride of axis 0 is simply the extent of axis 1.
    const size_t row_stride = out_shape[1];
    for (idx[0] = 0; idx[0] < out_shape[0]; ++idx[0])
    {
        for (idx[1] = 0; idx[1] < out_shape[1]; ++idx[1])
        {
            out[idx[0] * row_stride + idx[1]] = in[idx[out_axis]];
        }
    }
}
// #define PARALLEL
// Broadcast a 1-D input along output axis `out_axis` of a 3-D output.
// The loop counter for that axis doubles as the index into `in`.
template <typename T>
void broadcast_3d(
    const T* in, T* out, const Shape& in_shape, const Shape& out_shape, size_t out_axis)
{
    size_t idx[3];
    // Row-major strides of the output, computed up front.
    const size_t stride1 = out_shape[2];
    const size_t stride0 = out_shape[1] * stride1;
    for (idx[0] = 0; idx[0] < out_shape[0]; ++idx[0])
    {
        for (idx[1] = 0; idx[1] < out_shape[1]; ++idx[1])
        {
            for (idx[2] = 0; idx[2] < out_shape[2]; ++idx[2])
            {
                out[idx[0] * stride0 + idx[1] * stride1 + idx[2]] = in[idx[out_axis]];
            }
        }
    }
}
// Broadcast a 1-D input along output axis `out_axis` of a 4-D output.
// The loop counter for that axis (aliased by `in_index`) doubles as the
// index into `in`.
template <typename T>
void broadcast_4d(
const T* in, T* out, const Shape& in_shape, const Shape& out_shape, size_t out_axis)
{
size_t index[4];
size_t& in_index = index[out_axis];
auto out_strides = row_major_strides(out_shape);
for (index[0] = 0; index[0] < out_shape[0]; ++index[0])
{
for (index[1] = 0; index[1] < out_shape[1]; ++index[1])
{
for (index[2] = 0; index[2] < out_shape[2]; ++index[2])
{
for (index[3] = 0; index[3] < out_shape[3]; ++index[3])
{
// Row-major linearization via the precomputed strides.
out[index[0] * out_strides[0] + index[1] * out_strides[1] +
index[2] * out_strides[2] + index[3]] = in[in_index];
}
}
}
}
}
// Broadcast a 1-D input along output axis `out_axis` of a 5-D output.
// The loop counter for that axis (aliased by `in_index`) doubles as the
// index into `in`.
template <typename T>
void broadcast_5d(
const T* in, T* out, const Shape& in_shape, const Shape& out_shape, size_t out_axis)
{
size_t index[5];
size_t& in_index = index[out_axis];
auto out_strides = row_major_strides(out_shape);
for (index[0] = 0; index[0] < out_shape[0]; ++index[0])
{
for (index[1] = 0; index[1] < out_shape[1]; ++index[1])
{
for (index[2] = 0; index[2] < out_shape[2]; ++index[2])
{
for (index[3] = 0; index[3] < out_shape[3]; ++index[3])
{
for (index[4] = 0; index[4] < out_shape[4]; ++index[4])
{
// Row-major linearization via the precomputed strides.
out[index[0] * out_strides[0] + index[1] * out_strides[1] +
index[2] * out_strides[2] + index[3] * out_strides[3] +
index[4]] = in[in_index];
}
}
}
}
}
}
// Broadcast a 1-D input along output axis `out_axis` of a 6-D output.
// The loop counter for that axis (aliased by `in_index`) doubles as the
// index into `in`.
template <typename T>
void broadcast_6d(
const T* in, T* out, const Shape& in_shape, const Shape& out_shape, size_t out_axis)
{
size_t index[6];
size_t& in_index = index[out_axis];
auto out_strides = row_major_strides(out_shape);
for (index[0] = 0; index[0] < out_shape[0]; ++index[0])
{
for (index[1] = 0; index[1] < out_shape[1]; ++index[1])
{
for (index[2] = 0; index[2] < out_shape[2]; ++index[2])
{
for (index[3] = 0; index[3] < out_shape[3]; ++index[3])
{
for (index[4] = 0; index[4] < out_shape[4]; ++index[4])
{
for (index[5] = 0; index[5] < out_shape[5]; ++index[5])
{
// Row-major linearization via the precomputed strides.
out[index[0] * out_strides[0] + index[1] * out_strides[1] +
index[2] * out_strides[2] + index[3] * out_strides[3] +
index[4] * out_strides[4] + index[5]] = in[in_index];
}
}
}
}
}
}
}
// Broadcast entry point: handles scalar inputs inline, dispatches 1-D inputs
// to a rank-specialized kernel (output rank 2..6), and falls back to the
// generic reference implementation for everything else.
template <typename T>
void broadcast(const T* in,
               T* out,
               const Shape& in_shape,
               const Shape& out_shape,
               const AxisSet& broadcast_axes)
{
    // Scalar input: replicate the single element across the whole output.
    if (in_shape.size() == 0)
    {
        const size_t count = shape_size(out_shape);
        for (size_t i = 0; i < count; ++i)
        {
            out[i] = in[0];
        }
        return;
    }
    if (in_shape.size() == 1)
    {
        // The input runs along the first output axis that is not broadcast.
        size_t output_axis = 0;
        for (size_t i = 0; i < out_shape.size(); i++)
        {
            if (broadcast_axes.count(i) == 0)
            {
                output_axis = i;
                break;
            }
        }
        switch (out_shape.size())
        {
        case 2: broadcast_2d<T>(in, out, in_shape, out_shape, output_axis); return;
        case 3: broadcast_3d<T>(in, out, in_shape, out_shape, output_axis); return;
        case 4: broadcast_4d<T>(in, out, in_shape, out_shape, output_axis); return;
        case 5: broadcast_5d<T>(in, out, in_shape, out_shape, output_axis); return;
        case 6: broadcast_6d<T>(in, out, in_shape, out_shape, output_axis); return;
        default: break;
        }
    }
    // No specialized kernel: use the reference implementation.
    runtime::reference::broadcast<T>(in, out, in_shape, out_shape, broadcast_axes);
}
}
}
}
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/axis_vector.hpp"
#include "ngraph/runtime/reference/reshape.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
{
namespace runtime
{
namespace opt_kernel
{
// Rank-0 reshape: the input is a single scalar — copy it through.
template <typename T>
void reshape_in0(const T* in,
                 T* out,
                 const Shape& in_shape,
                 const AxisVector& in_axis_order,
                 const Shape& out_shape)
{
    out[0] = in[0];
}
template <typename T>
void reshape_in1(const T* in,
T* out,
const Shape& in_shape,
const AxisVector& in_axis_order,
const Shape& out_shape)
{
size_t size[1];
size_t in_index[1];
size_t* map_index[1];
for (size_t i = 0; i < 1; i++)
{
size[i] = in_shape[in_axis_order[i]];
map_index[in_axis_order[i]] = &in_index[i];
}
for (in_index[0] = 0; in_index[0] < size[0]; ++in_index[0])
{
*out++ = in[*map_index[0]];
}
}
// Rank-2 reshape/transpose: walk the input axes in `in_axis_order`, writing
// output elements sequentially. `axis_pos[d]` aliases the loop counter that
// currently indexes original input axis d.
template <typename T>
void reshape_in2(const T* in,
                 T* out,
                 const Shape& in_shape,
                 const AxisVector& in_axis_order,
                 const Shape& out_shape)
{
    size_t counters[2];
    size_t extent[2];
    size_t* axis_pos[2];
    for (size_t d = 0; d < 2; ++d)
    {
        extent[d] = in_shape[in_axis_order[d]];
        axis_pos[in_axis_order[d]] = &counters[d];
    }
    T* dst = out;
    for (counters[0] = 0; counters[0] < extent[0]; ++counters[0])
    {
        for (counters[1] = 0; counters[1] < extent[1]; ++counters[1])
        {
            *dst++ = in[*axis_pos[0] * in_shape[1] + *axis_pos[1]];
        }
    }
}
// Rank-3 reshape/transpose: iterate the input axes in `in_axis_order`,
// emitting output elements sequentially. map_index[d] aliases the loop
// counter that currently indexes original input axis d.
template <typename T>
void reshape_in3(const T* in,
T* out,
const Shape& in_shape,
const AxisVector& in_axis_order,
const Shape& out_shape)
{
size_t size[3];
size_t in_index[3];
size_t* map_index[3];
for (size_t i = 0; i < 3; i++)
{
// size[i] is the extent of the i-th axis in iteration (permuted) order.
size[i] = in_shape[in_axis_order[i]];
map_index[in_axis_order[i]] = &in_index[i];
}
for (in_index[0] = 0; in_index[0] < size[0]; ++in_index[0])
{
for (in_index[1] = 0; in_index[1] < size[1]; ++in_index[1])
{
for (in_index[2] = 0; in_index[2] < size[2]; ++in_index[2])
{
// Row-major read from the input at the permuted coordinates.
*out++ = in[*map_index[0] * in_shape[1] * in_shape[2] +
*map_index[1] * in_shape[2] + *map_index[2]];
}
}
}
}
// Rank-4 reshape/transpose: iterate the input axes in `in_axis_order`,
// emitting output elements sequentially. map_index[d] aliases the loop
// counter that currently indexes original input axis d.
template <typename T>
void reshape_in4(const T* in,
T* out,
const Shape& in_shape,
const AxisVector& in_axis_order,
const Shape& out_shape)
{
size_t size[4];
size_t in_index[4];
size_t* map_index[4];
for (size_t i = 0; i < 4; i++)
{
size[i] = in_shape[in_axis_order[i]];
map_index[in_axis_order[i]] = &in_index[i];
}
for (in_index[0] = 0; in_index[0] < size[0]; ++in_index[0])
{
for (in_index[1] = 0; in_index[1] < size[1]; ++in_index[1])
{
for (in_index[2] = 0; in_index[2] < size[2]; ++in_index[2])
{
for (in_index[3] = 0; in_index[3] < size[3]; ++in_index[3])
{
// Row-major read from the input at the permuted coordinates.
*out++ =
in[*map_index[0] * in_shape[1] * in_shape[2] * in_shape[3] +
*map_index[1] * in_shape[2] * in_shape[3] +
*map_index[2] * in_shape[3] + *map_index[3]];
}
}
}
}
}
// Rank-5 reshape/transpose: iterate the input axes in `in_axis_order`,
// emitting output elements sequentially. map_index[d] aliases the loop
// counter that currently indexes original input axis d.
template <typename T>
void reshape_in5(const T* in,
T* out,
const Shape& in_shape,
const AxisVector& in_axis_order,
const Shape& out_shape)
{
size_t size[5];
size_t in_index[5];
size_t* map_index[5];
for (size_t i = 0; i < 5; i++)
{
size[i] = in_shape[in_axis_order[i]];
map_index[in_axis_order[i]] = &in_index[i];
}
for (in_index[0] = 0; in_index[0] < size[0]; ++in_index[0])
{
for (in_index[1] = 0; in_index[1] < size[1]; ++in_index[1])
{
for (in_index[2] = 0; in_index[2] < size[2]; ++in_index[2])
{
for (in_index[3] = 0; in_index[3] < size[3]; ++in_index[3])
{
for (in_index[4] = 0; in_index[4] < size[4]; ++in_index[4])
{
// Row-major read from the input at the permuted coordinates.
*out++ =
in[*map_index[0] * in_shape[1] * in_shape[2] * in_shape[3] *
in_shape[4] +
*map_index[1] * in_shape[2] * in_shape[3] * in_shape[4] +
*map_index[2] * in_shape[3] * in_shape[4] +
*map_index[3] * in_shape[4] + *map_index[4]];
}
}
}
}
}
}
// Rank-6 reshape/transpose: iterate the input axes in `in_axis_order`,
// emitting output elements sequentially. map_index[d] aliases the loop
// counter that currently indexes original input axis d.
template <typename T>
void reshape_in6(const T* in,
T* out,
const Shape& in_shape,
const AxisVector& in_axis_order,
const Shape& out_shape)
{
size_t size[6];
size_t in_index[6];
size_t* map_index[6];
for (size_t i = 0; i < 6; i++)
{
size[i] = in_shape[in_axis_order[i]];
map_index[in_axis_order[i]] = &in_index[i];
}
for (in_index[0] = 0; in_index[0] < size[0]; ++in_index[0])
{
for (in_index[1] = 0; in_index[1] < size[1]; ++in_index[1])
{
for (in_index[2] = 0; in_index[2] < size[2]; ++in_index[2])
{
for (in_index[3] = 0; in_index[3] < size[3]; ++in_index[3])
{
for (in_index[4] = 0; in_index[4] < size[4]; ++in_index[4])
{
for (in_index[5] = 0; in_index[5] < size[5]; ++in_index[5])
{
// Row-major read from the input at the permuted coordinates.
*out++ = in[*map_index[0] * in_shape[1] * in_shape[2] *
in_shape[3] * in_shape[4] * in_shape[5] +
*map_index[1] * in_shape[2] * in_shape[3] *
in_shape[4] * in_shape[5] +
*map_index[2] * in_shape[3] * in_shape[4] *
in_shape[5] +
*map_index[3] * in_shape[4] * in_shape[5] +
*map_index[4] * in_shape[5] + *map_index[5]];
}
}
}
}
}
}
}
// Reshape entry point: dispatch on input rank. Ranks 0-6 have specialized
// kernels; anything larger falls back to the generic reference implementation.
template <typename T>
void reshape(const T* in,
             T* out,
             const Shape& in_shape,
             const AxisVector& in_axis_order,
             const Shape& out_shape)
{
    const size_t rank = in_shape.size();
    if (rank == 0)
    {
        reshape_in0<T>(in, out, in_shape, in_axis_order, out_shape);
    }
    else if (rank == 1)
    {
        reshape_in1<T>(in, out, in_shape, in_axis_order, out_shape);
    }
    else if (rank == 2)
    {
        reshape_in2<T>(in, out, in_shape, in_axis_order, out_shape);
    }
    else if (rank == 3)
    {
        reshape_in3<T>(in, out, in_shape, in_axis_order, out_shape);
    }
    else if (rank == 4)
    {
        reshape_in4<T>(in, out, in_shape, in_axis_order, out_shape);
    }
    else if (rank == 5)
    {
        reshape_in5<T>(in, out, in_shape, in_axis_order, out_shape);
    }
    else if (rank == 6)
    {
        reshape_in6<T>(in, out, in_shape, in_axis_order, out_shape);
    }
    else
    {
        reference::reshape(in, out, in_shape, in_axis_order, out_shape);
    }
}
}
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment