Commit e765956a authored by Jayaram Bobba's avatar Jayaram Bobba Committed by Scott Cyphers

IAT: Collapse dimensions around arithmetic reduction operations (#1763)

* Collapse dimensions for arithmetic reduction ops to support faster kernels

* Propagate in-place constants and allow in-place reshapes for more cases

* style fix

* Additional checks for parameter and constant to help backends that dont propagate in-place parameter and constant inputs

* Allow non-destructive pass through onlyu if memory sharing is disabled

* Address PR feedback

* Bug fix for collapse dimensions in case of null reduction
parent 1beec46b
......@@ -51,8 +51,10 @@ bool pass::MemoryLayout::run_on_function(shared_ptr<ngraph::Function> function)
auto input = &node->get_inputs().at(oi_pair.input).get_tensor();
auto input_node = node->get_inputs().at(oi_pair.input).get_output().get_node();
// an input tensor can be reused if this is the last use
if (node->liveness_free_list.count(input) != 0 &&
// For destructive kernel, this should be the last use
// Non-destructive kernels can pass through if memory sharing is disabled
if ((node->liveness_free_list.count(input) != 0 ||
(m_disable_memory_sharing && !oi_pair.destructive)) &&
node->liveness_new_list.count(output) != 0)
{
in_place_outputs.insert({output, input});
......
......@@ -551,7 +551,7 @@ using namespace ngraph::runtime;
{
for (shared_ptr<Node> node : function_ordered_ops.at(current_function))
{
const ngraph::op::Constant* c = dynamic_cast<ngraph::op::Constant*>(node.get());
ngraph::op::Constant* c = dynamic_cast<ngraph::op::Constant*>(node.get());
if (c)
{
m_active_constants.push_back(node);
......@@ -677,6 +677,15 @@ using namespace ngraph::runtime;
"(*(ctx->G), [&](const tbb::flow::continue_msg &msg)\n{});\n";
}
for (shared_ptr<Node> node : ordered_ops)
{
if (dynamic_cast<ngraph::op::Constant*>(node.get()))
{
shared_ptr<descriptor::Tensor> tv = node->get_outputs()[0].get_tensor_ptr();
propagate_in_place_constant(&node->get_outputs().at(0), tv->get_name(), false);
}
}
// Add inputs to the variable name map
size_t arg_index = 0;
for (shared_ptr<ngraph::op::Parameter> param : current_function->get_parameters())
......@@ -1102,6 +1111,53 @@ void runtime::cpu::CPU_ExternalFunction::propagate_in_place_input(
}
}
void runtime::cpu::CPU_ExternalFunction::propagate_in_place_constant(
ngraph::descriptor::Output* output, std::string input_name, bool dex)
{
std::deque<ngraph::descriptor::Output*> stack;
stack.push_front(output);
while (stack.size() > 0)
{
ngraph::descriptor::Output* it = stack.front();
stack.pop_front();
for (auto input : it->get_inputs())
{
auto c_op = std::dynamic_pointer_cast<ngraph::op::Op>(input->get_node());
if (!c_op || c_op->is_output())
{
continue;
}
if (auto op_annotations = c_op->get_op_annotations())
{
for (auto oi_pair : op_annotations->get_in_place_oi_pairs())
{
if (oi_pair.input == input->get_index() && !oi_pair.destructive)
{
size_t output_index = oi_pair.output;
auto& output_tensor = c_op->get_outputs().at(output_index).get_tensor();
if (dex)
{
tensor_alias[output_tensor.get_name()] = input_name;
}
else
{
m_variable_name_map[output_tensor.get_name()] = input_name;
}
m_tensor_roles[output_tensor.get_name()] = CPUTensorRole::CONSTANT;
NGRAPH_DEBUG << " CPU: Forwarding " << input_name << " through "
<< output_tensor.get_name();
stack.push_back(&c_op->get_outputs().at(output_index));
}
}
}
}
}
}
void runtime::cpu::CPU_ExternalFunction::propagate_in_place_output(
ngraph::descriptor::Output* res_src_output, std::string output_name, bool dex)
{
......@@ -1239,6 +1295,7 @@ void runtime::cpu::CPU_ExternalFunction::build()
tensor_data[tv->get_name()] =
const_cast<void*>(static_pointer_cast<ngraph::op::Constant>(node)->get_data_ptr());
m_tensor_roles[tv->get_name()] = CPUTensorRole::CONSTANT;
propagate_in_place_constant(&node->get_outputs().at(0), tv->get_name(), true);
}
}
......
......@@ -174,6 +174,11 @@ namespace ngraph
// Register passes that are common to codegen and DEX
void register_common_passes(ngraph::pass::Manager& pass_manager);
// For non-destructive passthrough kernels, propagate function
// constant buffers to internal ops
void propagate_in_place_constant(ngraph::descriptor::Output* output,
std::string input_name,
bool dex);
// For non-destructive passthrough kernels, propagate function
// input buffers to internal ops
void propagate_in_place_input(ngraph::descriptor::Output* output,
......
......@@ -22,32 +22,39 @@
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/max.hpp"
#include "ngraph/op/min.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/util.hpp"
using namespace ngraph;
struct CollapsedDims
struct CollapsedShape
{
std::vector<size_t> output_shape;
std::vector<bool> is_operated_axis;
std::vector<size_t> axis_set;
std::vector<size_t> input_shape;
Shape fshape; // Collapsed shape with operated axes
Shape rshape; // Collapsed shape without operated axes
AxisVector axis_set; // operated axis in fshape
};
// Fold and collapse axes of output_shape.
// Fold and collapse axes of shape.
// Contiguous axes that are not being operated on can be collapsed.
// Contiguous axes that are being operated on are collapsed optionally.
// Skip size 1 dimensions.
static void collapse_dims(std::vector<size_t>& output_shape,
// E.g.,
// Shape{3, 3, 2}, AxisSet{0, 1} -> Shape{9, 2}, AxisSet{0}
// Shape{2, 4, 6, 6}, AxisSet{2, 3} -> Shape{8, 36}, AxisSet{1}
static void collapse_dims(std::vector<size_t>& shape,
std::set<size_t> operated_axes,
struct CollapsedDims& cdims,
struct CollapsedShape& cshape,
bool collapse_operated_axes)
{
size_t collapse_size = 1;
bool operated_axes_run = false;
std::vector<bool> fshape_operated_axis;
bool collapsing = false;
for (int output_idx = static_cast<int>(output_shape.size()) - 1; output_idx >= 0; output_idx--)
for (int output_idx = static_cast<int>(shape.size()) - 1; output_idx >= 0; output_idx--)
{
auto is_operated_axis = operated_axes.count(output_idx) == 1;
auto end_run = (operated_axes_run != is_operated_axis) ||
......@@ -56,88 +63,158 @@ static void collapse_dims(std::vector<size_t>& output_shape,
{
if (collapse_size != 1)
{
cdims.output_shape.push_back(collapse_size);
cdims.is_operated_axis.push_back(operated_axes_run);
cshape.fshape.push_back(collapse_size);
fshape_operated_axis.push_back(operated_axes_run);
collapse_size = 1;
}
}
collapse_size *= output_shape[output_idx];
collapse_size *= shape[output_idx];
operated_axes_run = is_operated_axis;
collapsing = true;
}
// Last run
if (collapse_size != 1)
{
cdims.output_shape.push_back(collapse_size);
cdims.is_operated_axis.push_back(operated_axes_run);
cshape.fshape.push_back(collapse_size);
fshape_operated_axis.push_back(operated_axes_run);
}
std::reverse(cdims.output_shape.begin(), cdims.output_shape.end());
std::reverse(cdims.is_operated_axis.begin(), cdims.is_operated_axis.end());
std::reverse(cshape.fshape.begin(), cshape.fshape.end());
std::reverse(fshape_operated_axis.begin(), fshape_operated_axis.end());
for (size_t i = 0; i < cdims.is_operated_axis.size(); i++)
for (size_t i = 0; i < fshape_operated_axis.size(); i++)
{
if (cdims.is_operated_axis[i])
if (fshape_operated_axis[i])
{
cdims.axis_set.push_back(i);
cshape.axis_set.push_back(i);
}
else
{
cdims.input_shape.push_back(cdims.output_shape[i]);
cshape.rshape.push_back(cshape.fshape[i]);
}
}
}
bool runtime::cpu::pass::CPUCollapseDims::run_on_function(std::shared_ptr<ngraph::Function> f)
static bool collapse_broadcast(std::shared_ptr<Node> n)
{
bool replaced = false;
for (auto n : f->get_ordered_ops())
auto node = std::dynamic_pointer_cast<op::Broadcast>(n).get();
auto input_shape = node->get_argument(0)->get_shape();
auto output_shape = node->get_shape();
auto operated_axes = node->get_broadcast_axes();
struct CollapsedShape cshape;
collapse_dims(output_shape, operated_axes, cshape, true);
if (cshape.axis_set.size() == 0)
{
if (std::dynamic_pointer_cast<op::Broadcast>(n))
{
auto node = std::dynamic_pointer_cast<op::Broadcast>(n).get();
auto input_shape = node->get_argument(0)->get_shape();
auto output_shape = node->get_shape();
auto operated_axes = node->get_broadcast_axes();
// Null broadcast operation, replace with reshape
AxisVector axis_order = ngraph::get_default_order(input_shape);
auto reshape =
std::make_shared<op::Reshape>(node->get_argument(0), axis_order, n->get_shape());
ngraph::replace_node(n, reshape);
replaced = true;
}
else if (output_shape.size() != cshape.fshape.size())
{
// Reshape arg to collapsed input_shape
AxisVector input_axis_order = ngraph::get_default_order(input_shape);
auto reshape_input = std::make_shared<op::Reshape>(
node->get_argument(0), input_axis_order, Shape(cshape.rshape));
struct CollapsedDims cdims;
auto broadcast = std::make_shared<op::Broadcast>(
reshape_input, Shape(cshape.fshape), AxisSet(cshape.axis_set));
collapse_dims(output_shape, operated_axes, cdims, true);
// Reshape collapsed output to original output_shape
AxisVector output_axis_order = ngraph::get_default_order(cshape.fshape);
auto reshape_output =
std::make_shared<op::Reshape>(broadcast, output_axis_order, output_shape);
ngraph::replace_node(n, reshape_output);
replaced = true;
}
if (cdims.axis_set.size() == 0)
{
// Null broadcast operation, replace with reshape
AxisVector axis_order = ngraph::get_default_order(input_shape);
auto reshape = std::make_shared<op::Reshape>(
node->get_argument(0), axis_order, n->get_shape());
ngraph::replace_node(n, reshape);
replaced = true;
}
else if (output_shape.size() != cdims.output_shape.size())
{
// Reshape arg to collapsed input_shape
AxisVector input_axis_order = ngraph::get_default_order(input_shape);
auto reshape_input = std::make_shared<op::Reshape>(
node->get_argument(0), input_axis_order, Shape(cdims.input_shape));
auto broadcast = std::make_shared<op::Broadcast>(
reshape_input, Shape(cdims.output_shape), AxisSet(cdims.axis_set));
// Reshape collapsed output to original output_shape
AxisVector output_axis_order = ngraph::get_default_order(cdims.output_shape);
auto reshape_output =
std::make_shared<op::Reshape>(broadcast, output_axis_order, output_shape);
ngraph::replace_node(n, reshape_output);
replaced = true;
}
if (replaced)
{
NGRAPH_DEBUG << "CollapseDims: Replaced broadcast " << input_shape << " " << operated_axes
<< " " << output_shape << " with " << Shape(cshape.rshape) << " "
<< AxisSet(cshape.axis_set) << " " << Shape(cshape.fshape);
}
return replaced;
}
if (replaced)
{
NGRAPH_DEBUG << "CollapseDims: Replaced broadcast " << input_shape << " "
<< operated_axes << " " << output_shape << " with "
<< Shape(cdims.input_shape) << " " << AxisSet(cdims.axis_set) << " "
<< Shape(cdims.output_shape);
}
template <typename T>
static bool collapse_reduction(std::shared_ptr<Node> n)
{
bool replaced = false;
auto node = std::dynamic_pointer_cast<T>(n).get();
auto input_shape = node->get_argument(0)->get_shape();
auto output_shape = node->get_shape();
auto operated_axes = node->get_reduction_axes();
struct CollapsedShape cshape;
collapse_dims(input_shape, operated_axes, cshape, true);
if (cshape.axis_set.size() == 0)
{
// Null reduction operation
AxisVector axis_order = ngraph::get_default_order(input_shape);
auto reshape =
std::make_shared<op::Reshape>(node->get_argument(0), axis_order, n->get_shape());
ngraph::replace_node(n, reshape);
replaced = true;
}
else if (input_shape.size() != cshape.fshape.size())
{
// Reshape arg to collapsed input_shape
AxisVector input_axis_order = ngraph::get_default_order(input_shape);
auto reshape_input = std::make_shared<op::Reshape>(
node->get_argument(0), input_axis_order, Shape(cshape.fshape));
auto reduction = std::make_shared<T>(reshape_input, AxisSet(cshape.axis_set));
// Reshape collapsed output to original output_shape
AxisVector output_axis_order = ngraph::get_default_order(cshape.rshape);
auto reshape_output =
std::make_shared<op::Reshape>(reduction, output_axis_order, output_shape);
ngraph::replace_node(n, reshape_output);
replaced = true;
}
if (replaced)
{
NGRAPH_DEBUG << "CollapseDims: Replaced arithmetic reduction " << input_shape << " "
<< operated_axes << " " << output_shape << " with " << Shape(cshape.fshape)
<< " " << AxisSet(cshape.axis_set) << " " << Shape(cshape.rshape);
}
return replaced;
}
bool runtime::cpu::pass::CPUCollapseDims::run_on_function(std::shared_ptr<ngraph::Function> f)
{
bool replaced = false;
for (auto n : f->get_ordered_ops())
{
if (std::dynamic_pointer_cast<op::Broadcast>(n))
{
replaced |= collapse_broadcast(n);
}
else if (std::dynamic_pointer_cast<op::Max>(n))
{
replaced |= collapse_reduction<op::Max>(n);
}
else if (std::dynamic_pointer_cast<op::Min>(n))
{
replaced |= collapse_reduction<op::Min>(n);
}
else if (std::dynamic_pointer_cast<op::Product>(n))
{
replaced |= collapse_reduction<op::Product>(n);
}
else if (std::dynamic_pointer_cast<op::Sum>(n))
{
replaced |= collapse_reduction<op::Sum>(n);
}
}
......
......@@ -478,3 +478,37 @@ TEST(cpu_test, reshape_layout_optimizations7)
}
EXPECT_EQ(count_ops_of_type<runtime::cpu::op::ConvertLayout>(cpu_f), 0);
}
TEST(cpu_test, collapse_dims1)
{
// Expand multiple dimensions. Ensure no extra conversions downstream
auto make_function = []() -> std::shared_ptr<Function> {
auto A = make_shared<op::Parameter>(element::f32, Shape{1, 4, 10, 6, 10});
auto sum1 = make_shared<op::Sum>(A, AxisVector{1}); // Shape{1, 10, 6, 10}
auto sum2 = make_shared<op::Sum>(sum1, AxisVector{0}); // Shape{10, 6, 10}
return make_shared<Function>(NodeVector{sum2}, op::ParameterVector{A});
};
auto backend = runtime::Backend::create("CPU");
auto cpu_f = make_function();
auto int_f = make_function();
test::Uniform<float> rng(-100.0f, 100.0f);
vector<vector<float>> args;
for (shared_ptr<op::Parameter> param : cpu_f->get_parameters())
{
vector<float> tensor_val(shape_size(param->get_shape()));
rng.initialize(tensor_val);
args.push_back(tensor_val);
}
auto int_results = execute(int_f, args, "INTERPRETER");
auto cpu_results = execute(cpu_f, args, "CPU");
for (size_t i = 0; i < cpu_results.size(); i++)
{
EXPECT_TRUE(test::all_close(cpu_results.at(i), int_results.at(i)));
}
// sum1 will have two reshapes added around it. sum2 will be replaced
// with a reshape
EXPECT_EQ(count_ops_of_type<op::Reshape>(cpu_f), 3);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment