Commit 1daac094 authored by Jayaram Bobba's avatar Jayaram Bobba Committed by Robert Kimball

Don't collapse unit size dimensions for dot ops (#2031)

parent e54156cf
...@@ -42,14 +42,13 @@ struct CollapsedShape ...@@ -42,14 +42,13 @@ struct CollapsedShape
// Fold and collapse axes of shape. // Fold and collapse axes of shape.
// Contiguous axes that are not being operated on can be collapsed. // Contiguous axes that are not being operated on can be collapsed.
// Contiguous axes that are being operated on are collapsed optionally. // Contiguous axes that are being operated on are collapsed optionally.
// Skip size 1 dimensions.
// E.g., // E.g.,
// Shape{3, 3, 2}, AxisSet{0, 1} -> Shape{9, 2}, AxisSet{0} // Shape{3, 3, 2}, AxisSet{0, 1} -> Shape{9, 2}, AxisSet{0}
// Shape{2, 4, 6, 6}, AxisSet{2, 3} -> Shape{8, 36}, AxisSet{1} // Shape{2, 4, 6, 6}, AxisSet{2, 3} -> Shape{8, 36}, AxisSet{1}
static void collapse_dims(std::vector<size_t>& shape, static void collapse_dims(std::vector<size_t>& shape,
std::set<size_t> operated_axes, std::set<size_t> operated_axes,
struct CollapsedShape& cshape, struct CollapsedShape& cshape,
bool collapse_operated_axes) bool skip_unit_size = true)
{ {
size_t collapse_size = 1; size_t collapse_size = 1;
bool operated_axes_run = false; bool operated_axes_run = false;
...@@ -58,11 +57,10 @@ static void collapse_dims(std::vector<size_t>& shape, ...@@ -58,11 +57,10 @@ static void collapse_dims(std::vector<size_t>& shape,
for (int output_idx = static_cast<int>(shape.size()) - 1; output_idx >= 0; output_idx--) for (int output_idx = static_cast<int>(shape.size()) - 1; output_idx >= 0; output_idx--)
{ {
auto is_operated_axis = operated_axes.count(output_idx) == 1; auto is_operated_axis = operated_axes.count(output_idx) == 1;
auto end_run = (operated_axes_run != is_operated_axis) || auto end_run = (operated_axes_run != is_operated_axis);
(is_operated_axis && !collapse_operated_axes);
if (collapsing && end_run) if (collapsing && end_run)
{ {
if (collapse_size != 1) if (collapse_size != 1 || !skip_unit_size)
{ {
cshape.fshape.push_back(collapse_size); cshape.fshape.push_back(collapse_size);
fshape_operated_axis.push_back(operated_axes_run); fshape_operated_axis.push_back(operated_axes_run);
...@@ -75,7 +73,7 @@ static void collapse_dims(std::vector<size_t>& shape, ...@@ -75,7 +73,7 @@ static void collapse_dims(std::vector<size_t>& shape,
collapsing = true; collapsing = true;
} }
// Last run // Last run
if (collapse_size != 1) if (collapse_size != 1 || !skip_unit_size)
{ {
cshape.fshape.push_back(collapse_size); cshape.fshape.push_back(collapse_size);
fshape_operated_axis.push_back(operated_axes_run); fshape_operated_axis.push_back(operated_axes_run);
...@@ -106,7 +104,7 @@ static bool collapse_broadcast(std::shared_ptr<Node> n) ...@@ -106,7 +104,7 @@ static bool collapse_broadcast(std::shared_ptr<Node> n)
struct CollapsedShape cshape; struct CollapsedShape cshape;
collapse_dims(output_shape, operated_axes, cshape, true); collapse_dims(output_shape, operated_axes, cshape);
if (cshape.axis_set.size() == 0) if (cshape.axis_set.size() == 0)
{ {
...@@ -155,7 +153,7 @@ static bool collapse_reduction(std::shared_ptr<Node> n) ...@@ -155,7 +153,7 @@ static bool collapse_reduction(std::shared_ptr<Node> n)
struct CollapsedShape cshape; struct CollapsedShape cshape;
collapse_dims(input_shape, operated_axes, cshape, true); collapse_dims(input_shape, operated_axes, cshape);
if (cshape.axis_set.size() == 0) if (cshape.axis_set.size() == 0)
{ {
...@@ -210,8 +208,8 @@ static bool collapse_dot(std::shared_ptr<Node> n) ...@@ -210,8 +208,8 @@ static bool collapse_dot(std::shared_ptr<Node> n)
} }
struct CollapsedShape cshape_A, cshape_B; struct CollapsedShape cshape_A, cshape_B;
collapse_dims(A_shape, operated_axes_A, cshape_A, true); collapse_dims(A_shape, operated_axes_A, cshape_A, false);
collapse_dims(B_shape, operated_axes_B, cshape_B, true); collapse_dims(B_shape, operated_axes_B, cshape_B, false);
if (A_shape != cshape_A.fshape || B_shape != cshape_B.fshape) if (A_shape != cshape_A.fshape || B_shape != cshape_B.fshape)
{ {
......
...@@ -511,3 +511,34 @@ TEST(cpu_test, collapse_dims1) ...@@ -511,3 +511,34 @@ TEST(cpu_test, collapse_dims1)
// with a reshape // with a reshape
EXPECT_EQ(count_ops_of_type<op::Reshape>(cpu_f), 3); EXPECT_EQ(count_ops_of_type<op::Reshape>(cpu_f), 3);
} }
TEST(cpu_test, collapse_dims2)
{
// Collapse dims around a dot where one of the inputs is a scalar
auto make_function = []() -> std::shared_ptr<Function> {
auto A = make_shared<op::Parameter>(element::f32, Shape{1, 3, 1, 1});
auto B = make_shared<op::Parameter>(element::f32, Shape{1, 1});
auto dot = make_shared<op::Dot>(A, B, 1);
return make_shared<Function>(NodeVector{dot}, op::ParameterVector{A, B});
};
auto backend = runtime::Backend::create("CPU");
auto cpu_f = make_function();
auto int_f = make_function();
test::Uniform<float> rng(-100.0f, 100.0f);
vector<vector<float>> args;
for (shared_ptr<op::Parameter> param : cpu_f->get_parameters())
{
vector<float> tensor_val(shape_size(param->get_shape()));
rng.initialize(tensor_val);
args.push_back(tensor_val);
}
auto int_results = execute(int_f, args, "INTERPRETER");
auto cpu_results = execute(cpu_f, args, "CPU");
for (size_t i = 0; i < cpu_results.size(); i++)
{
EXPECT_TRUE(test::all_close(cpu_results.at(i), int_results.at(i)));
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment