Commit bb6de284 authored by gaurides's avatar gaurides Committed by Robert Kimball

Add in_place support for ReplaceSlice (#1559)

* Add in_place suport for ReplaceSlice

* Add emit_replace_slice_inplace kernel

* changed file permissions to original

* Formatted code using maint/apply-code-format.sh

* Removed data type check and removed dead code

* Removed setting mkldnn_op(true). ReplaceSlice is not mkldnn op
parent ba59b80b
......@@ -70,7 +70,6 @@ bool pass::MemoryLayout::run_on_function(shared_ptr<ngraph::Function> function)
size_t offset = in_place_outputs.count(tensor)
? in_place_outputs.at(tensor)->get_pool_offset()
: mm.allocate(tensor->size());
tensor->set_pool_offset(offset);
}
......
......@@ -2451,16 +2451,31 @@ namespace ngraph
writer << " {" << join(out[0].get_shape()) << "});\n";
}
#else
kernel::emit_replace_slice(writer,
args[0].get_element_type().c_type_string(),
args[0].get_name(),
args[1].get_name(),
out[0].get_name(),
args[1].get_shape(),
out[0].get_shape(),
replace_slice->get_lower_bounds(),
replace_slice->get_upper_bounds(),
replace_slice->get_strides());
if (args[0].get_name() != out[0].get_name())
{
kernel::emit_replace_slice(writer,
args[0].get_element_type().c_type_string(),
args[0].get_name(),
args[1].get_name(),
out[0].get_name(),
args[1].get_shape(),
out[0].get_shape(),
replace_slice->get_lower_bounds(),
replace_slice->get_upper_bounds(),
replace_slice->get_strides());
}
else
{
kernel::emit_replace_slice_inplace(writer,
args[0].get_element_type().c_type_string(),
args[0].get_name(),
args[1].get_name(),
args[1].get_shape(),
args[0].get_shape(),
replace_slice->get_lower_bounds(),
replace_slice->get_upper_bounds(),
replace_slice->get_strides());
}
#endif
writer.block_end();
}
......
......@@ -196,6 +196,24 @@ void ngraph::runtime::cpu::kernel::emit_replace_slice(codegen::CodeWriter& write
emit_pointwise_copy(writer, element_type, arg1, out, input_transform, output_transform);
}
void ngraph::runtime::cpu::kernel::emit_replace_slice_inplace(
codegen::CodeWriter& writer,
const string& element_type,
const string& arg0, // replacement context
const string& arg1, // replacement value
const Shape& arg1_shape,
const Shape& arg0_shape,
const Coordinate& lower_bounds,
const Coordinate& upper_bounds,
const Strides& strides)
{
// Step 1: Overwrite the slice for replacement.
CoordinateTransform input_transform(arg1_shape);
CoordinateTransform output_transform(arg0_shape, lower_bounds, upper_bounds, strides);
emit_pointwise_copy(writer, element_type, arg1, arg0, input_transform, output_transform);
}
void ngraph::runtime::cpu::kernel::emit_slice(codegen::CodeWriter& writer,
const string& element_type,
const string& arg0, // replacement context
......
......@@ -54,6 +54,15 @@ namespace ngraph
const Coordinate& lower_bounds,
const Coordinate& upper_bounds,
const Strides& strides);
void emit_replace_slice_inplace(codegen::CodeWriter& writer,
const std::string& element_type,
const std::string& arg0, // replacement context
const std::string& arg1, // replacement value
const Shape& arg1_shape,
const Shape& out_shape,
const Coordinate& lower_bounds,
const Coordinate& upper_bounds,
const Strides& strides);
void emit_slice(codegen::CodeWriter& writer,
const std::string& element_type,
const std::string& arg0, // replacement context
......
......@@ -33,6 +33,7 @@
#include "ngraph/op/lrn.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/replace_slice.hpp"
#include "ngraph/op/softmax.hpp"
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
......@@ -514,6 +515,22 @@ namespace ngraph
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::ReplaceSlice)
{
auto replace_slice = static_cast<op::ReplaceSlice*>(node);
// ReplaceSlice is independent of data type. Hence not checking type
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
if (get_user_count(node->get_argument(0).get()) == 1)
{
// Safe to overwrite input
op_annotations->add_in_place_oi_pair({0, 0, true});
}
replace_slice->set_op_annotations(op_annotations);
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::LRN)
{
......@@ -788,6 +805,8 @@ static const runtime::cpu::pass::AssignOpMap s_dispatcher{
{TI(ngraph::op::QuantizedAvgPool),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::QuantizedAvgPool>},
{TI(ngraph::op::Softmax), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Softmax>},
{TI(ngraph::op::ReplaceSlice),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ReplaceSlice>},
{TI(ngraph::op::ConvolutionAdd),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ConvolutionAdd>},
{TI(ngraph::op::Dequantize),
......
......@@ -3991,6 +3991,33 @@ NGRAPH_TEST(${BACKEND_NAME}, replace_slice_scalar)
EXPECT_EQ((vector<float>{808}), read_vector<float>(result));
}
NGRAPH_TEST(${BACKEND_NAME}, replace_slice_matrix_inplace)
{
Shape shape_a{4, 4};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
auto abs_A = make_shared<op::Abs>(A);
Shape shape_b{3, 2};
auto B = make_shared<op::Parameter>(element::f32, shape_b);
Shape shape_r{4, 4};
auto r = make_shared<op::ReplaceSlice>(abs_A, B, Coordinate{0, 1}, Coordinate{3, 3});
auto abs_r = make_shared<op::Abs>(r);
auto f = make_shared<Function>(abs_r, op::ParameterVector{A, B});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape_a);
copy_data(a, vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
auto b = backend->create_tensor(element::f32, shape_b);
copy_data(b, vector<float>{102, 103, 106, 107, 110, 111});
auto result = backend->create_tensor(element::f32, shape_r);
backend->call_with_validate(f, {result}, {a, b});
EXPECT_EQ((vector<float>{1, 102, 103, 4, 5, 106, 107, 8, 9, 110, 111, 12, 13, 14, 15, 16}),
read_vector<float>(result));
}
NGRAPH_TEST(${BACKEND_NAME}, replace_slice_matrix)
{
Shape shape_a{4, 4};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment