Commit a16c4961, authored by Jayaram Bobba and committed by Robert Kimball

Update slice kernels (#2180)

* initial commit for update slice op

* Finished up update_slice fusion and added codegen support

* style fixes

* Added unit test for in-place update-slice strided

* change pattern name
parent e0933553
......@@ -42,7 +42,6 @@ set(SRC
builder/concat.cpp
builder/convert.cpp
builder/convert_layout.cpp
builder/quantized_conv.cpp
builder/convolution.cpp
builder/dot.cpp
builder/function_call.cpp
......@@ -60,8 +59,10 @@ set(SRC
builder/reduce_function.cpp
builder/reduce_function_window.cpp
builder/replace_slice.cpp
builder/quantized_max_pool.cpp
builder/quantization.cpp
builder/quantized_avg_pool.cpp
builder/quantized_conv.cpp
builder/quantized_max_pool.cpp
builder/reshape.cpp
builder/reverse.cpp
builder/reverse_sequence.cpp
......@@ -70,11 +71,11 @@ set(SRC
builder/select_and_scatter.cpp
builder/sigmoid.cpp
builder/slice.cpp
builder/state.cpp
builder/softmax.cpp
builder/sum.cpp
builder/topk.cpp
builder/state.cpp
builder/quantization.cpp
builder/update_slice.cpp
kernel/pad.cpp
kernel/reduce_max.cpp
kernel/reduce_sum.cpp
......@@ -85,12 +86,13 @@ set(SRC
op/batch_dot.cpp
op/batch_norm_relu.cpp
op/bounded_relu.cpp
op/group_conv.cpp
op/group_conv_bias.cpp
op/halide_op.cpp
op/conv_add.cpp
op/conv_bias.cpp
op/conv_relu.cpp
op/convert_layout.cpp
op/group_conv.cpp
op/group_conv_bias.cpp
op/halide_op.cpp
op/leaky_relu.cpp
op/loop_kernel.cpp
op/lstm.cpp
......@@ -98,7 +100,7 @@ set(SRC
op/max_pool_with_indices.cpp
op/rnn.cpp
op/sigmoid_mul.cpp
op/conv_add.cpp
op/update_slice.cpp
pass/cpu_assignment.cpp
pass/cpu_collapse_dims.cpp
pass/cpu_fusion.cpp
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <cstring>
#include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/cpu/kernel/update_slice.hpp"
#include "ngraph/runtime/cpu/op/update_slice.hpp"
using namespace std;
using namespace ngraph;
namespace ngraph
{
namespace runtime
{
namespace cpu
{
// Builds the runtime functor for op::UpdateSlice and appends it to the
// external function's functor list. Three cases are handled: rank-0 inputs,
// unit-stride slices, and strided slices.
template <>
void Builder::BUILDER_DECL(ngraph::op::UpdateSlice)
{
    auto& functors = external_function->get_functors();

    // The tensor slots are held by reference, and the functors below capture
    // them by reference as well, so the values are read when a functor runs.
    auto& dst_in = external_function->get_tensor_data(args[0].get_name());
    auto& upd_in = external_function->get_tensor_data(args[1].get_name());
    auto& result = external_function->get_tensor_data(out[0].get_name());

    const auto* op = static_cast<const ngraph::op::UpdateSlice*>(node);

    // Shapes and slice attributes are copied so the functors can own them.
    auto dst_shape = args[0].get_shape();
    auto upd_shape = args[1].get_shape();
    auto step = op->get_strides();
    auto lo = op->get_lower_bounds();
    auto hi = op->get_upper_bounds();

    // Rank-0 case: a single element is copied from arg1 to the output.
    // NOTE(review): this overwrites the output with arg1, whereas the rank > 0
    // kernels below write arg0 + arg1 into the slice - confirm this is the
    // intended result for scalar inputs.
    if (dst_shape.empty())
    {
        const size_t nbytes = args[0].get_element_type().size();
        functors.emplace_back(
            [&, nbytes](CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
                memcpy(result, upd_in, nbytes);
            });
        return;
    }

    // Unit-stride case: only the lower bounds are needed by the kernel.
    if (!ngraph::is_strided(step))
    {
        std::function<decltype(runtime::cpu::kernel::update_slice<float, 2>)> kernel;
        SELECT_KERNEL_BY_RANK(kernel,
                              args[0].get_element_type(),
                              dst_shape.size(),
                              runtime::cpu::kernel::update_slice);
        functors.emplace_back([&, kernel, dst_shape, upd_shape, lo](
            CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
            kernel(dst_in, upd_in, result, dst_shape, upd_shape, lo, ectx->arena);
        });
        return;
    }

    // Strided case: the kernel additionally takes upper bounds and strides.
    std::function<decltype(runtime::cpu::kernel::strided_update_slice<float, 2>)> kernel;
    SELECT_KERNEL_BY_RANK(kernel,
                          args[0].get_element_type(),
                          dst_shape.size(),
                          runtime::cpu::kernel::strided_update_slice);
    functors.emplace_back([&, kernel, dst_shape, upd_shape, lo, hi, step](
        CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
        kernel(dst_in, upd_in, result, dst_shape, upd_shape, lo, hi, step, ectx->arena);
    });
}
REGISTER_OP_BUILDER(UpdateSlice);
}
}
}
......@@ -120,6 +120,7 @@
#include "ngraph/runtime/cpu/op/rnn.hpp"
#include "ngraph/runtime/cpu/op/sigmoid.hpp"
#include "ngraph/runtime/cpu/op/sigmoid_mul.hpp"
#include "ngraph/runtime/cpu/op/update_slice.hpp"
#include "ngraph/type/element_type.hpp"
#include "ngraph/util.hpp"
......@@ -2499,6 +2500,49 @@ namespace ngraph
writer.block_end();
}
// Emits a call to cpu::kernel::update_slice (unit strides) or
// cpu::kernel::strided_update_slice (otherwise) for an op::UpdateSlice node.
// Both calls share their leading arguments; the strided form additionally
// receives the upper bounds and strides. The arena argument is emitted as 0.
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::UpdateSlice)
{
    const auto* op = static_cast<const ngraph::op::UpdateSlice*>(node);
    const Shape& dst_shape = args[0].get_shape();
    const Shape& upd_shape = args[1].get_shape();
    const bool strided = ngraph::is_strided(op->get_strides());

    writer.block_begin();

    // Kernel name, template arguments and the arguments common to both kernels.
    writer << (strided ? "cpu::kernel::strided_update_slice<" : "cpu::kernel::update_slice<")
           << args[0].get_element_type().c_type_string() << ", "
           << dst_shape.size() << ">(\n"
           << " " << args[0].get_name() << ",\n"
           << " " << args[1].get_name() << ",\n"
           << " " << out[0].get_name() << ",\n"
           << " {" << join(dst_shape) << "},\n"
           << " {" << join(upd_shape) << "},\n"
           << " {"
           << join(op->get_lower_bounds()) << "},\n";

    // Extra arguments taken only by the strided kernel.
    if (strided)
    {
        writer << " {"
               << join(op->get_upper_bounds()) << "},\n"
               << " {"
               << join(op->get_strides()) << "},\n";
    }

    writer << "0);\n";
    writer.block_end();
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::ReplaceSlice)
{
......
......@@ -161,6 +161,7 @@
#include "ngraph/runtime/cpu/op/rnn.hpp"
#include "ngraph/runtime/cpu/op/sigmoid.hpp"
#include "ngraph/runtime/cpu/op/sigmoid_mul.hpp"
#include "ngraph/runtime/cpu/op/update_slice.hpp"
#include "ngraph/runtime/cpu/pass/cpu_assignment.hpp"
#include "ngraph/runtime/cpu/pass/cpu_collapse_dims.hpp"
#include "ngraph/runtime/cpu/pass/cpu_fusion.hpp"
......@@ -331,6 +332,7 @@ static const runtime::cpu::OpMap dispatcher{
{TI(ngraph::op::Acos), &runtime::cpu::CPU_Emitter::emit<op::Acos>},
{TI(ngraph::op::Atan), &runtime::cpu::CPU_Emitter::emit<op::Atan>},
{TI(ngraph::op::ReplaceSlice), &runtime::cpu::CPU_Emitter::emit<op::ReplaceSlice>},
{TI(ngraph::op::UpdateSlice), &runtime::cpu::CPU_Emitter::emit<op::UpdateSlice>},
{TI(ngraph::op::OneHot), &runtime::cpu::CPU_Emitter::emit<op::OneHot>},
{TI(ngraph::op::Floor), &runtime::cpu::CPU_Emitter::emit<op::Floor>},
{TI(ngraph::op::Ceiling), &runtime::cpu::CPU_Emitter::emit<op::Ceiling>},
......
......@@ -124,9 +124,11 @@ namespace mkl
namespace ngraph
{
class Shape;
class AxisSet;
class AxisVector;
class Coordinate;
class Shape;
class Strides;
namespace runtime
{
......@@ -195,6 +197,26 @@ namespace ngraph
const AxisVector& input_axis_order,
const Shape& output_shape,
int arena);
// Declaration of the unit-stride update-slice kernel.
//   input0 / input1 / output : raw tensor buffers, interpreted as ElementType
//   input0_shape             : shape of input0 (and of output)
//   input1_shape             : shape of input1, i.e. the extent of the slice
//   lower_bounds             : per-axis start of the slice inside input0
//   arena                    : index used to pick the executor device
template <typename ElementType, unsigned int Rank>
void update_slice(void* input0,
void* input1,
void* output,
const Shape& input0_shape,
const Shape& input1_shape,
const Coordinate& lower_bounds,
int arena);
// Declaration of the strided update-slice kernel. Takes the same arguments as
// update_slice plus the per-axis upper bounds and strides of the slice.
template <typename ElementType, unsigned int Rank>
void strided_update_slice(void* input0,
void* input1,
void* output,
const Shape& input0_shape,
const Shape& input1_shape,
const Coordinate& lower_bounds,
const Coordinate& upper_bounds,
const Strides& slice_strides,
int arena);
}
}
}
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#define EIGEN_USE_THREADS
#include <unsupported/Eigen/CXX11/Tensor>
#include "ngraph/coordinate.hpp"
#include "ngraph/runtime/cpu/cpu_executor.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
{
namespace runtime
{
namespace cpu
{
namespace kernel
{
template <typename ElementType, unsigned int Rank>
void update_slice(void* input0,
void* input1,
void* output,
const Shape& input0_shape,
const Shape& input1_shape,
const Coordinate& lower_bounds,
int arena)
{
Eigen::array<Eigen::Index, Rank> in0_dims, in1_dims;
Eigen::array<Eigen::Index, Rank> indices;
for (int i = 0; i < Rank; i++)
{
in0_dims[i] = input0_shape[i];
in1_dims[i] = input1_shape[i];
indices[i] = lower_bounds[i];
}
Eigen::TensorMap<Eigen::Tensor<ElementType, Rank, Eigen::RowMajor>> out(
static_cast<ElementType*>(output), in0_dims);
Eigen::TensorMap<Eigen::Tensor<ElementType, Rank, Eigen::RowMajor>> in0(
static_cast<ElementType*>(input0), in0_dims);
Eigen::TensorMap<Eigen::Tensor<ElementType, Rank, Eigen::RowMajor>> in1(
static_cast<ElementType*>(input1), in1_dims);
if (input0 != output)
{
out.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
arena)) = in0;
}
out.slice(indices, in1_dims)
.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
arena)) = in0.slice(indices, in1_dims) + in1;
}
template <typename ElementType, unsigned int Rank>
void strided_update_slice(void* input0,
void* input1,
void* output,
const Shape& input0_shape,
const Shape& input1_shape,
const Coordinate& lower_bounds,
const Coordinate& upper_bounds,
const Strides& slice_strides,
int arena)
{
Eigen::array<Eigen::Index, Rank> in0_dims, in1_dims;
Eigen::array<Eigen::Index, Rank> start_indices, stop_indices, strides;
for (int i = 0; i < Rank; i++)
{
in0_dims[i] = input0_shape[i];
in1_dims[i] = input1_shape[i];
start_indices[i] = lower_bounds[i];
stop_indices[i] = upper_bounds[i];
strides[i] = slice_strides[i];
}
Eigen::TensorMap<Eigen::Tensor<ElementType, Rank, Eigen::RowMajor>> out(
static_cast<ElementType*>(output), in0_dims);
Eigen::TensorMap<Eigen::Tensor<ElementType, Rank, Eigen::RowMajor>> in0(
static_cast<ElementType*>(input0), in0_dims);
Eigen::TensorMap<Eigen::Tensor<ElementType, Rank, Eigen::RowMajor>> in1(
static_cast<ElementType*>(input1), in1_dims);
if (input0 != output)
{
out.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
arena)) = in0;
}
out.stridedSlice(start_indices, stop_indices, strides)
.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
arena)) = in0.stridedSlice(start_indices, stop_indices, strides) + in1;
}
}
}
}
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/runtime/cpu/op/update_slice.hpp"
using namespace std;
using namespace ngraph;
// Constructs an UpdateSlice node from the tensor to update (arg0), the values
// to add into the slice (arg1), and explicit slice bounds and strides. The
// attributes are stored as given and then validated, together with the input
// shapes and element types, by constructor_validate_and_infer_types().
op::UpdateSlice::UpdateSlice(const shared_ptr<Node>& arg0,
const shared_ptr<Node>& arg1,
const Coordinate& lower_bounds,
const Coordinate& upper_bounds,
const Strides& strides)
: Op("UpdateSlice", check_single_output_args({arg0, arg1}))
, m_lower_bounds(lower_bounds)
, m_upper_bounds(upper_bounds)
, m_strides(strides)
{
constructor_validate_and_infer_types();
}
// Unit-stride convenience constructor: forwards to the five-argument
// constructor with a stride of 1 for every axis of `lower_bounds`. Validation
// and type inference are performed by the delegated constructor.
op::UpdateSlice::UpdateSlice(const shared_ptr<Node>& arg0,
                             const shared_ptr<Node>& arg1,
                             const Coordinate& lower_bounds,
                             const Coordinate& upper_bounds)
    : UpdateSlice(arg0, arg1, lower_bounds, upper_bounds, Strides(lower_bounds.size(), 1))
{
}
// Validates the two inputs against the slice attributes and sets the type and
// shape of output 0. Checks run in the order written below; the first failing
// NODE_VALIDATION_ASSERT determines the reported message.
void op::UpdateSlice::validate_and_infer_types()
{
// An empty stride vector with lower_bounds/upper_bounds filled in means that we need to
// construct the default value (a stride of 1 on every axis).
if (m_strides.size() == 0)
{
m_strides = Strides(m_lower_bounds.size(), 1);
}
const PartialShape& arg0_shape = get_input_partial_shape(0);
const PartialShape& arg1_shape = get_input_partial_shape(1);
// The ranks of the two inputs must be mergeable; the merged rank is reused below.
Dimension merged_args_rank;
NODE_VALIDATION_ASSERT(this,
Dimension::merge(merged_args_rank, arg0_shape.rank(), arg1_shape.rank()))
<< "Argument ranks do not match (arg0 shape: " << arg0_shape
<< ", arg1 shape: " << arg1_shape << ").";
// The element types of the two inputs must be mergeable; the merged type
// becomes the output element type.
element::Type arg0_et = get_input_element_type(0);
element::Type arg1_et = get_input_element_type(1);
element::Type merged_args_et;
NODE_VALIDATION_ASSERT(this, element::Type::merge(merged_args_et, arg0_et, arg1_et))
<< "Argument element types do not match (arg0 element type: " << arg0_et
<< ", arg1 element type: " << arg1_et << ").";
// The three slice attributes must all have the same number of entries.
NODE_VALIDATION_ASSERT(this,
m_lower_bounds.size() == m_upper_bounds.size() &&
m_lower_bounds.size() == m_strides.size())
<< "Ranks of lower bounds (" << m_lower_bounds << "), upper bounds (" << m_upper_bounds
<< ") and strides (" << m_strides << ") do not match.";
size_t output_rank = m_upper_bounds.size();
// Per axis: lower bound must not exceed upper bound, and the stride must be
// non-zero (which also keeps the division further down well defined).
for (size_t i = 0; i < output_rank; i++)
{
NODE_VALIDATION_ASSERT(this, m_lower_bounds[i] <= m_upper_bounds[i])
<< "Lower bound for slice is greater than upper bound at axis " << i
<< " (lower bounds: " << m_lower_bounds << ", upper bounds: " << m_upper_bounds << ").";
NODE_VALIDATION_ASSERT(this, m_strides[i] != 0) << "Stride for slice is zero at axis " << i
<< " (strides: " << m_strides << ").";
}
// If the merged input rank is known, it must equal the attribute rank.
NODE_VALIDATION_ASSERT(this,
merged_args_rank.is_dynamic() || size_t(merged_args_rank) == output_rank)
<< "Argument ranks do not match the rank of the lower bounds (" << m_lower_bounds
<< "), upper bounds (" << m_upper_bounds << "), and strides (" << m_strides << ").";
// Per axis: check the upper bound against arg0 (when that dimension is known)
// and compute the slice extent ceil((upper - lower) / stride).
std::vector<Dimension> sliced_dims(output_rank);
for (size_t i = 0; i < output_rank; i++)
{
NODE_VALIDATION_ASSERT(this,
arg0_shape.rank().is_dynamic() || arg0_shape[i].is_dynamic() ||
m_upper_bounds[i] <= size_t(arg0_shape[i]))
<< "Upper bound for slice at axis " << i << " is out of range "
<< "(upper bounds: " << m_upper_bounds << ", argument shape: " << arg0_shape << ").";
size_t sliced_dim = m_upper_bounds[i] - m_lower_bounds[i];
sliced_dim = sliced_dim / m_strides[i] + ((sliced_dim % m_strides[i] == 0) ? 0 : 1);
sliced_dims[i] = sliced_dim;
}
// arg1 must be compatible with the computed slice shape.
PartialShape slice_shape{sliced_dims};
NODE_VALIDATION_ASSERT(this, arg1_shape.compatible(slice_shape))
<< "Shape of replacement tensor (" << arg1_shape << ") does not match the slice shape "
<< "(" << slice_shape << ").";
// Slight corner case here: if arg0 was rank-unknown, we can go ahead and set the output rank
// because the attribs will have given us enough info. Otherwise the output
// shape is simply arg0's shape.
PartialShape result_shape =
(arg0_shape.rank().is_static())
? arg0_shape
: PartialShape(std::vector<Dimension>(output_rank, Dimension::dynamic()));
set_output_type(0, merged_args_et, result_shape);
}
// Creates a new UpdateSlice over the supplied arguments, carrying over this
// node's lower bounds, upper bounds and strides. The argument count is checked
// by check_new_args_count() before the arguments are accessed.
shared_ptr<Node> op::UpdateSlice::copy_with_new_args(const NodeVector& new_args) const
{
    check_new_args_count(this, new_args);

    const auto& target = new_args.at(0);
    const auto& increment = new_args.at(1);
    return make_shared<UpdateSlice>(
        target, increment, m_lower_bounds, m_upper_bounds, m_strides);
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/coordinate.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/strides.hpp"
namespace ngraph
{
namespace op
{
/// \brief Takes two input tensors of identical rank, with the second tensor no larger than the first in any dimension, and returns a copy of
///        the first input tensor with the specified slice incremented by the second input tensor.
///
/// ## Parameters
///
/// |                | Description                                                                                                                                                                                     |
/// | -------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
/// | `lower_bounds` | The (inclusive) lower-bound coordinates \f$l_i\f$ for the tensor slice to be updated. For example, a lower-bound of \f$(1,2)\f$ means to start the slice at row 1 and column 2.                 |
/// | `upper_bounds` | The (non-inclusive) upper-bound coordinates \f$u_i\f$ for the tensor slice to be updated. For example, an upper-bound of \f$(5,4)\f$ means to end the slice before row 5 and column 4.          |
/// | `strides`      | The strides \f$s_i\f$ for the tensor slice to be updated. For example, in the matrix case, strides of \f$(1,3)\f$ means to take every row, and every third column (starting at the lower bound). |
///
/// ## Inputs
///
/// |        | Type                                                                          | Description                                                                                                                            |
/// | ------ | ----------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------- |
/// | `arg0` | \f$E[d_1,\dots,d_n]~(n \geq 0)\f$                                             | A tensor of any shape and element type.                                                                                                |
/// | `arg1` | \f$E[d'_1,\dots,d'_n]\f$ where \f$d'_i = \lceil(u_i - l_i)\, /\, s_i\rceil\f$ | A tensor of the same element type and rank as `arg0`, whose shape is determined by the lower and upper slice bounds and slice strides. |
///
/// ## Output
///
/// | Type                   | Description                                                                                                                                                                                                                                                              |
/// | ---------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
/// | \f$E[d_1,\dots,d_n]\f$ | The tensor \f$T\f$ where \f$T[i_1,\dots,i_n] = \texttt{arg0}[i_1,\dots,i_n] + \texttt{arg1}[j_1,\dots,j_n]\f$ if \f$j_1,\dots,j_n\f$ is in bounds for `arg1` and for all \f$m\f$, \f$i_m = l_m + j_m s_m\f$, otherwise \f$T[i_1,\dots,i_n] = \texttt{arg0}[i_1,\dots,i_n]\f$. |
class UpdateSlice : public Op
{
public:
/// \brief Constructs a tensor slice update operation.
///
/// \param arg0 The tensor whose slice is to be incremented.
/// \param arg1 The tensor to add into the slice of `arg0`.
/// \param lower_bounds The axiswise lower bounds of the slice (inclusive).
/// \param upper_bounds The axiswise upper bounds of the slice (exclusive).
/// \param strides The slicing strides; for example, strides of `{n,m}` means to take
///                every nth row and every mth column of `arg0` as part of the
///                slice to be updated.
UpdateSlice(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1,
const Coordinate& lower_bounds,
const Coordinate& upper_bounds,
const Strides& strides);
/// \brief Constructs a tensor slice update operation with unit strides; i.e., every element inside the bounding box will be incremented.
///
/// \param arg0 The tensor whose slice is to be incremented.
/// \param arg1 The tensor to add into the slice of `arg0`.
/// \param lower_bounds The axiswise lower bounds of the slice (inclusive).
/// \param upper_bounds The axiswise upper bounds of the slice (exclusive).
UpdateSlice(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1,
const Coordinate& lower_bounds,
const Coordinate& upper_bounds);
/// \brief Creates a copy of this operation that uses `new_args` as its inputs.
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
/// \return The inclusive lower-bound coordinates.
const Coordinate& get_lower_bounds() const { return m_lower_bounds; }
/// \return The exclusive upper-bound coordinates.
const Coordinate& get_upper_bounds() const { return m_upper_bounds; }
/// \return The slicing strides.
const Strides& get_strides() const { return m_strides; }
protected:
/// \brief Checks the inputs against the slice attributes and sets the output type and shape.
void validate_and_infer_types() override;
// Slice description: inclusive start, exclusive stop and step for each axis.
Coordinate m_lower_bounds;
Coordinate m_upper_bounds;
Strides m_strides;
};
}
}
......@@ -58,6 +58,7 @@
#include "ngraph/runtime/cpu/op/max_pool_with_indices.hpp"
#include "ngraph/runtime/cpu/op/rnn.hpp"
#include "ngraph/runtime/cpu/op/sigmoid.hpp"
#include "ngraph/runtime/cpu/op/update_slice.hpp"
using namespace std;
using namespace ngraph;
......@@ -427,6 +428,21 @@ namespace ngraph
replace_slice->set_op_annotations(op_annotations);
}
// Attaches CPU op annotations to an op::UpdateSlice node. When the node is the
// only user of its first argument, an in-place pair {0, 0, true} is added so
// the input may be overwritten.
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::UpdateSlice)
{
    auto annotations = std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();

    const bool sole_user_of_arg0 = get_user_count(node->get_argument(0).get()) == 1;
    if (sole_user_of_arg0)
    {
        // Nobody else reads argument 0, so it is safe to overwrite it.
        // NOTE(review): the triple presumably means {output, input,
        // destructive} - confirm against CPUOpAnnotations.
        annotations->add_in_place_oi_pair({0, 0, true});
    }

    static_cast<op::UpdateSlice*>(node)->set_op_annotations(annotations);
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::LRN)
{
......@@ -838,6 +854,8 @@ static const runtime::cpu::pass::AssignOpMap s_dispatcher{
{TI(ngraph::op::Slice), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Slice>},
{TI(ngraph::op::ReplaceSlice),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ReplaceSlice>},
{TI(ngraph::op::UpdateSlice),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::UpdateSlice>},
{TI(ngraph::op::ConvolutionAdd),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ConvolutionAdd>},
{TI(ngraph::op::QuantizedConvolutionRelu),
......
......@@ -42,6 +42,7 @@
#include "ngraph/op/pad.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/replace_slice.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/slice.hpp"
......@@ -64,6 +65,7 @@
#include "ngraph/runtime/cpu/op/lstm.hpp"
#include "ngraph/runtime/cpu/op/matmul_bias.hpp"
#include "ngraph/runtime/cpu/op/sigmoid_mul.hpp"
#include "ngraph/runtime/cpu/op/update_slice.hpp"
#include "ngraph/util.hpp"
extern template ngraph::Shape ngraph::apply_permutation<ngraph::Shape>(ngraph::Shape input,
......@@ -1806,3 +1808,51 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_fuse_lstm_recurrent_state(
auto m = std::make_shared<ngraph::pattern::Matcher>(concat_label, callback);
this->add_matcher(m);
}
void ngraph::runtime::cpu::pass::CPUFusion::construct_update_slice()
{
Shape shape_a{2, 32, 2};
Shape shape_b{1, 32, 2};
auto input = std::make_shared<pattern::op::Label>(element::f32, shape_a);
auto slice = std::make_shared<op::Slice>(input, Coordinate{1, 0, 0}, Coordinate{2, 32, 2});
auto slice_label = std::make_shared<pattern::op::Label>(slice, nullptr, NodeVector{slice});
auto update_input = std::make_shared<pattern::op::Label>(element::f32, shape_b);
auto update = std::make_shared<op::Add>(update_input, slice_label);
auto replace_slice = std::make_shared<op::ReplaceSlice>(
input, update, Coordinate{1, 0, 0}, Coordinate{2, 32, 2});
ngraph::pattern::graph_rewrite_callback callback = [input, update_input, slice_label](
pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for update_slice = " << m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto slice_m = std::static_pointer_cast<op::Slice>(pattern_map[slice_label]);
auto replace_m = std::static_pointer_cast<op::ReplaceSlice>(m.get_match_root());
if (replace_m->get_lower_bounds() != slice_m->get_lower_bounds() ||
replace_m->get_upper_bounds() != slice_m->get_upper_bounds() ||
replace_m->get_strides() != slice_m->get_strides())
{
NGRAPH_DEBUG
<< "Update slice cannot be created, slice and replace_slice are not compatible";
return false;
}
if (slice_m->get_users().size() > 1 || replace_m->get_argument(1)->get_users().size() > 1)
{
NGRAPH_DEBUG << "Update slice cannot be created, intermediate values required";
return false;
}
auto update_slice = std::make_shared<op::UpdateSlice>(pattern_map[input],
pattern_map[update_input],
replace_m->get_lower_bounds(),
replace_m->get_upper_bounds(),
replace_m->get_strides());
ngraph::replace_node(m.get_match_root(), update_slice);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(
replace_slice, callback, "CPUFusion.UpdateSlice");
this->add_matcher(m);
}
......@@ -79,6 +79,7 @@ public:
// construct_conv_add() should always be after construct_conv_bias()
construct_conv_add();
construct_conv_add_relu();
construct_update_slice();
construct_fuse_lstm_recurrent_state();
}
}
......@@ -107,5 +108,6 @@ private:
void construct_conv_bias_affine_folding();
void construct_groupconv_batchnorm_global_stats_folding();
void construct_groupconv_batchnorm_global_stats_folding_relu();
void construct_update_slice();
void construct_fuse_lstm_recurrent_state();
};
......@@ -62,6 +62,7 @@
#include "ngraph/runtime/cpu/op/matmul_bias.hpp"
#include "ngraph/runtime/cpu/op/rnn.hpp"
#include "ngraph/runtime/cpu/op/sigmoid_mul.hpp"
#include "ngraph/runtime/cpu/op/update_slice.hpp"
#include "ngraph/runtime/cpu/pass/cpu_fusion.hpp"
#include "ngraph/runtime/cpu/pass/cpu_loop_kernel_fusion.hpp"
#include "ngraph/runtime/cpu/pass/cpu_mat_fusion.hpp"
......@@ -3122,6 +3123,201 @@ TEST(cpu_fusion, fuse_leaky_relu)
EXPECT_TRUE(test::all_close(cpu2_results.at(0), expected_result));
}
// Checks that ReplaceSlice(input, Slice(input) + update) is fused into
// UpdateSlice when the Slice and ReplaceSlice regions agree, is left alone
// when they differ, and that the CPU backend matches INTERPRETER.
TEST(cpu_fusion, fuse_update_slice)
{
    // matching_region == false makes the Slice read a different region than
    // the one ReplaceSlice writes.
    auto build_graph = [](bool matching_region = true) {
        auto data = std::make_shared<op::Parameter>(element::f32, Shape{4, 32, 16});
        Shape write_lower{1, 0, 0};
        Shape write_upper{2, 32, 16};
        auto read = std::make_shared<op::Slice>(data,
                                                matching_region ? write_lower : Shape{3, 0, 0},
                                                matching_region ? write_upper : Shape{4, 32, 16});
        auto delta = std::make_shared<op::Parameter>(element::f32, Shape{1, 32, 16});
        auto sum = std::make_shared<op::Add>(read, delta);
        auto written = std::make_shared<op::ReplaceSlice>(data, sum, write_lower, write_upper);
        return make_shared<Function>(NodeVector{written}, ParameterVector{data, delta});
    };

    // Fusion pass only: count the resulting UpdateSlice nodes.
    auto fusable = build_graph(true);
    auto unfusable = build_graph(false);
    pass::Manager fusion_only;
    fusion_only.register_pass<runtime::cpu::pass::CPUFusion>();
    fusion_only.run_passes(fusable);
    fusion_only.run_passes(unfusable);
    EXPECT_EQ(1, count_ops_of_type<op::UpdateSlice>(fusable));
    EXPECT_EQ(0, count_ops_of_type<op::UpdateSlice>(unfusable));

    // End-to-end: same random inputs through both backends.
    auto reference_fn = build_graph();
    auto cpu_fn = build_graph();
    test::Uniform<float> rng(0.0f, 1.0f);
    vector<vector<float>> inputs;
    for (const auto& param : reference_fn->get_parameters())
    {
        inputs.emplace_back(shape_size(param->get_shape()));
        rng.initialize(inputs.back());
    }
    auto reference_out = execute(reference_fn, inputs, "INTERPRETER");
    auto cpu_out = execute(cpu_fn, inputs, "CPU");
    const size_t result_count = cpu_out.size();
    for (size_t idx = 0; idx != result_count; ++idx)
    {
        EXPECT_TRUE(test::all_close(cpu_out.at(idx), reference_out.at(idx)));
    }
}
// Same pattern as fuse_update_slice, but fed by and feeding into Abs nodes.
// Fusion must be skipped when the Add result is also a function output, and
// the CPU backend must match INTERPRETER.
TEST(cpu_fusion, fuse_update_slice_inplace)
{
    // expose_add == false keeps the Add internal (fusable); otherwise the Add
    // is returned as a second result, which gives it an extra user.
    auto build_graph = [](bool fuse = true) {
        auto data = std::make_shared<op::Parameter>(element::f32, Shape{4, 32, 16});
        auto producer = std::make_shared<op::Abs>(data);
        Shape region_lower{1, 0, 0};
        Shape region_upper{2, 32, 16};
        auto read = std::make_shared<op::Slice>(producer, region_lower, region_upper);
        auto delta = std::make_shared<op::Parameter>(element::f32, Shape{1, 32, 16});
        auto sum = std::make_shared<op::Add>(read, delta);
        auto written =
            std::make_shared<op::ReplaceSlice>(producer, sum, region_lower, region_upper);
        auto consumer = std::make_shared<op::Abs>(written);

        NodeVector results{consumer};
        if (!fuse)
        {
            results.push_back(sum);
        }
        return make_shared<Function>(results, ParameterVector{data, delta});
    };

    // Fusion pass only: count the resulting UpdateSlice nodes.
    auto fusable = build_graph(true);
    auto unfusable = build_graph(false);
    pass::Manager fusion_only;
    fusion_only.register_pass<runtime::cpu::pass::CPUFusion>();
    fusion_only.run_passes(fusable);
    fusion_only.run_passes(unfusable);
    EXPECT_EQ(1, count_ops_of_type<op::UpdateSlice>(fusable));
    EXPECT_EQ(0, count_ops_of_type<op::UpdateSlice>(unfusable));

    // End-to-end: same random inputs through both backends.
    auto reference_fn = build_graph();
    auto cpu_fn = build_graph();
    test::Uniform<float> rng(0.0f, 1.0f);
    vector<vector<float>> inputs;
    for (const auto& param : reference_fn->get_parameters())
    {
        inputs.emplace_back(shape_size(param->get_shape()));
        rng.initialize(inputs.back());
    }
    auto reference_out = execute(reference_fn, inputs, "INTERPRETER");
    auto cpu_out = execute(cpu_fn, inputs, "CPU");
    const size_t result_count = cpu_out.size();
    for (size_t idx = 0; idx != result_count; ++idx)
    {
        EXPECT_TRUE(test::all_close(cpu_out.at(idx), reference_out.at(idx)));
    }
}
// Strided variant of fuse_update_slice: the Slice and ReplaceSlice use strides
// {1, 2, 2}. Fusion is expected only when both address the same region, and
// the CPU backend must match INTERPRETER.
TEST(cpu_fusion, fuse_update_slice_strided)
{
    // matching_region == false makes the Slice read a different region than
    // the one ReplaceSlice writes.
    auto build_graph = [](bool matching_region = true) {
        auto data = std::make_shared<op::Parameter>(element::f32, Shape{4, 32, 16});
        Shape write_lower{1, 0, 0};
        Shape write_upper{2, 32, 16};
        Strides step{1, 2, 2};
        auto read = std::make_shared<op::Slice>(data,
                                                matching_region ? write_lower : Shape{3, 0, 0},
                                                matching_region ? write_upper : Shape{4, 32, 16},
                                                step);
        auto delta = std::make_shared<op::Parameter>(element::f32, Shape{1, 16, 8});
        auto sum = std::make_shared<op::Add>(read, delta);
        auto written =
            std::make_shared<op::ReplaceSlice>(data, sum, write_lower, write_upper, step);
        return make_shared<Function>(NodeVector{written}, ParameterVector{data, delta});
    };

    // Fusion pass only: count the resulting UpdateSlice nodes.
    auto fusable = build_graph(true);
    auto unfusable = build_graph(false);
    pass::Manager fusion_only;
    fusion_only.register_pass<runtime::cpu::pass::CPUFusion>();
    fusion_only.run_passes(fusable);
    fusion_only.run_passes(unfusable);
    EXPECT_EQ(1, count_ops_of_type<op::UpdateSlice>(fusable));
    EXPECT_EQ(0, count_ops_of_type<op::UpdateSlice>(unfusable));

    // End-to-end: same random inputs through both backends.
    auto reference_fn = build_graph();
    auto cpu_fn = build_graph();
    test::Uniform<float> rng(0.0f, 1.0f);
    vector<vector<float>> inputs;
    for (const auto& param : reference_fn->get_parameters())
    {
        inputs.emplace_back(shape_size(param->get_shape()));
        rng.initialize(inputs.back());
    }
    auto reference_out = execute(reference_fn, inputs, "INTERPRETER");
    auto cpu_out = execute(cpu_fn, inputs, "CPU");
    const size_t result_count = cpu_out.size();
    for (size_t idx = 0; idx != result_count; ++idx)
    {
        EXPECT_TRUE(test::all_close(cpu_out.at(idx), reference_out.at(idx)));
    }
}
// Strided variant of fuse_update_slice_inplace: strides {1, 4, 2}, with Abs
// nodes before and after the pattern. Fusion must be skipped when the Add
// result is also a function output, and the CPU backend must match
// INTERPRETER.
TEST(cpu_fusion, fuse_update_slice_strided_inplace)
{
    // fuse == false returns the Add as a second result, which gives it an
    // extra user.
    auto build_graph = [](bool fuse = true) {
        auto data = std::make_shared<op::Parameter>(element::f32, Shape{4, 32, 16});
        auto producer = std::make_shared<op::Abs>(data);
        Shape region_lower{1, 0, 0};
        Shape region_upper{2, 32, 16};
        Strides step{1, 4, 2};
        auto read = std::make_shared<op::Slice>(producer, region_lower, region_upper, step);
        auto delta = std::make_shared<op::Parameter>(element::f32, Shape{1, 8, 8});
        auto sum = std::make_shared<op::Add>(read, delta);
        auto written =
            std::make_shared<op::ReplaceSlice>(producer, sum, region_lower, region_upper, step);
        auto consumer = std::make_shared<op::Abs>(written);

        NodeVector results{consumer};
        if (!fuse)
        {
            results.push_back(sum);
        }
        return make_shared<Function>(results, ParameterVector{data, delta});
    };

    // Fusion pass only: count the resulting UpdateSlice nodes.
    auto fusable = build_graph(true);
    auto unfusable = build_graph(false);
    pass::Manager fusion_only;
    fusion_only.register_pass<runtime::cpu::pass::CPUFusion>();
    fusion_only.run_passes(fusable);
    fusion_only.run_passes(unfusable);
    EXPECT_EQ(1, count_ops_of_type<op::UpdateSlice>(fusable));
    EXPECT_EQ(0, count_ops_of_type<op::UpdateSlice>(unfusable));

    // End-to-end: same random inputs through both backends.
    auto reference_fn = build_graph();
    auto cpu_fn = build_graph();
    test::Uniform<float> rng(0.0f, 1.0f);
    vector<vector<float>> inputs;
    for (const auto& param : reference_fn->get_parameters())
    {
        inputs.emplace_back(shape_size(param->get_shape()));
        rng.initialize(inputs.back());
    }
    auto reference_out = execute(reference_fn, inputs, "INTERPRETER");
    auto cpu_out = execute(cpu_fn, inputs, "CPU");
    const size_t result_count = cpu_out.size();
    for (size_t idx = 0; idx != result_count; ++idx)
    {
        EXPECT_TRUE(test::all_close(cpu_out.at(idx), reference_out.at(idx)));
    }
}
TEST(cpu_fusion, dot_batch_forward)
{
const Shape shape_a{2, 3, 2};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment