Commit e5d9b540 authored by Jayaram Bobba, committed by Scott Cyphers

IAT: Collapse dims for Dot ops (#1991)

* Collapse dimensions for inputs to Dot

* Remove Eigen kernels for higher-dimension dots, since those dots now collapse down to the cblas_gemm kernels

* Move the collapse-dims pass after the fusion passes so it does not interfere with fusion patterns. Use cblas_gemm for 2D dot (a standalone sketch of the collapse rewrite follows below)
parent f33317cc
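The rewrite at the heart of this change can be illustrated standalone: a higher-rank Dot with one reduction axis computes the same values as a 2D matrix multiply over the collapsed shapes, so it can be rewritten as Reshape -> 2D Dot -> Reshape, and the 2D case then hits the cblas_gemm path. The plain C++ below is only a sketch of that equivalence for a (d0, d1, k) x (k, n) case; the function name and shapes are made up for the example, and this is not the nGraph pass or kernel code.

#include <cstddef>
#include <vector>

// Illustration: Dot of A with shape (d0, d1, k) against B with shape (k, n),
// one reduction axis.  Folding the two leading axes of A into m = d0 * d1
// turns it into an ordinary (m, k) x (k, n) matrix multiply; with row-major
// storage the fold and the final reshape back to (d0, d1, n) touch no data.
std::vector<float> dot_collapsed(const std::vector<float>& A, // (d0, d1, k), row-major
                                 const std::vector<float>& B, // (k, n), row-major
                                 size_t d0, size_t d1, size_t k, size_t n)
{
    const size_t m = d0 * d1;          // collapsed leading dimensions of A
    std::vector<float> C(m * n, 0.0f); // (d0, d1, n) viewed as (m, n)
    for (size_t i = 0; i < m; ++i)
        for (size_t j = 0; j < n; ++j)
            for (size_t p = 0; p < k; ++p)
                C[i * n + j] += A[i * k + p] * B[p * n + j];
    return C;
}

This is also why the 3D-specific Eigen kernels can be dropped in this commit: once the collapse pass runs, those shapes no longer reach the builder.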
@@ -132,13 +132,13 @@ namespace ngraph
return;
}
if ((arg0_shape.size() == 3) && (arg1_shape.size() == 3) &&
if ((arg0_shape.size() == 1) && (arg1_shape.size() == 2) &&
reduction_axes_count == 1)
{
std::function<decltype(runtime::cpu::kernel::dot_3d_3d_1rd<float>)> kernel;
std::function<decltype(runtime::cpu::kernel::dot_1d_2d_1rd<float>)> kernel;
SELECT_KERNEL(
kernel, out[0].get_element_type(), runtime::cpu::kernel::dot_3d_3d_1rd);
kernel, out[0].get_element_type(), runtime::cpu::kernel::dot_1d_2d_1rd);
auto functor = [&, kernel, arg0_shape, arg1_shape, result_shape](
CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
@@ -154,100 +154,37 @@ namespace ngraph
return;
}
if ((arg0_shape.size() == 3) && (arg1_shape.size() == 2) &&
reduction_axes_count == 1)
if (out[0].get_element_type() == element::f32 && (arg0_shape.size() == 2) &&
(arg1_shape.size() == 2) && reduction_axes_count == 1)
{
if (args[0].get_element_type() == element::f32)
{
auto shape_a = args[0].get_shape();
auto shape_b = args[1].get_shape();
const int64_t m = shape_a[1];
const int64_t k = shape_a[2];
const int64_t n = shape_b[1];
// this also works when mat_a is shape (1, m, k)
const int64_t offset_a = m * k;
// we do not offset mat_b
const int64_t offset_b = 0;
const int64_t offset_c = m * n;
const int64_t group_count = 1;
const int64_t group_size = shape_a[0];
auto functor =
[&, offset_a, offset_b, offset_c, m, n, k, group_size, group_count](
CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
cblas::Transpose transpose = cblas::Transpose::None;
float alpha = 1.0f;
vector<const float*> a;
for (size_t i = 0; i < group_size; i++)
{
a.emplace_back(static_cast<const float*>(arg0_tensor) +
i * offset_a);
}
const float** a_array = a.data();
int64_t lda_array = std::max(int64_t(1), k);
vector<const float*> b;
for (size_t i = 0; i < group_size; i++)
{
b.emplace_back(static_cast<const float*>(arg1_tensor) +
i * offset_b);
}
const float** b_array = b.data();
const int64_t ldb_array = std::max(int64_t(1), n);
float beta = 0.0f;
vector<float*> c;
for (size_t i = 0; i < group_size; i++)
{
c.emplace_back(static_cast<float*>(out_tensor) + i * offset_c);
}
float** c_array = c.data();
const int64_t ldc_array = std::max(int64_t(1), n);
cblas_sgemm_batch(cblas::Layout::RowMajor,
&transpose,
&transpose,
&m,
&n,
&k,
&alpha,
a_array,
&lda_array,
b_array,
&ldb_array,
&beta,
c_array,
&ldc_array,
group_count,
&group_size);
};
functors.emplace_back(functor);
return;
}
else
{
std::function<decltype(runtime::cpu::kernel::dot_3d_2d_1rd<float>)> kernel;
SELECT_KERNEL(
kernel, out[0].get_element_type(), runtime::cpu::kernel::dot_3d_2d_1rd);
auto functor = [&, kernel, arg0_shape, arg1_shape, result_shape](
auto m = arg0_shape[0];
auto n = arg1_shape[1];
auto k = arg0_shape[1];
bool transpose_A = false, transpose_B = false;
auto lda = arg0_shape[1];
auto ldb = arg1_shape[1];
const float beta = 0.0f;
auto functor =
[&, transpose_A, transpose_B, m, n, k, lda, ldb, beta, result_shape](
CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
kernel(arg0_tensor,
arg1_tensor,
out_tensor,
arg0_shape,
arg1_shape,
result_shape,
ectx->arena);
cblas::cblas_sgemm(
cblas::Layout::RowMajor,
transpose_A ? cblas::Transpose::Transpose : cblas::Transpose::None,
transpose_B ? cblas::Transpose::Transpose : cblas::Transpose::None,
m,
n,
k,
1.0f,
static_cast<float*>(arg0_tensor),
max(1UL, lda),
static_cast<float*>(arg1_tensor),
max(1UL, ldb),
beta,
static_cast<float*>(out_tensor),
max(1UL, result_shape[1]));
};
functors.emplace_back(functor);
return;
}
functors.emplace_back(functor);
return;
}
std::function<decltype(runtime::cpu::kernel::dot<float>)> kernel;
......
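For reference, the m/n/k and leading-dimension choices in the new 2D branch map directly onto the standard CBLAS interface. The sketch below goes through plain cblas.h from any CBLAS-providing BLAS rather than the cblas:: wrappers used above, and the matrix contents are invented for the example: row-major, no transposes, lda = k, ldb = ldc = n, alpha = 1, beta = 0.

#include <cblas.h>
#include <vector>

int main()
{
    const int m = 2, k = 3, n = 2;
    std::vector<float> A = {1, 2, 3,   // (m, k), row-major
                            4, 5, 6};
    std::vector<float> B = {1, 0,      // (k, n), row-major
                            0, 1,
                            1, 1};
    std::vector<float> C(m * n, 0.0f); // (m, n) result
    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
                m, n, k,
                1.0f, A.data(), /*lda=*/k,
                B.data(), /*ldb=*/n,
                0.0f, C.data(), /*ldc=*/n);
    // C now holds {4, 5, 10, 11}
    return 0;
}

The max(1UL, ...) clamps in the builder keep the leading dimensions at least 1, as BLAS requires, even for degenerate shapes; for ordinary shapes they reduce to exactly the values shown here.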
@@ -1038,7 +1038,6 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(ngraph::pass::Ma
// pass_manager.register_pass<runtime::cpu::pass::ConcatInputs>();
pass_manager.register_pass<runtime::cpu::pass::CPURnnMatFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUBatchFusion>();
pass_manager.register_pass<ngraph::pass::CoreFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUHorizontalFusion>();
......
@@ -128,6 +128,19 @@ namespace ngraph
input0, input1, output, input0_shape, input1_shape, output_shape, arena);
}
template <typename ElementType>
void dot_1d_2d_1rd(void* input0,
void* input1,
void* output,
const Shape& input0_shape,
const Shape& input1_shape,
const Shape& output_shape,
int arena)
{
dot<ElementType, 1, 2, 1>(
input0, input1, output, input0_shape, input1_shape, output_shape, arena);
}
template <typename ElementType>
void dot_3d_3d_1rd(void* input0,
void* input1,
......
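The new dot_1d_2d_1rd entry point forwards to the templated dot kernel with input ranks 1 and 2 and a single reduction axis, i.e. a vector-matrix product. A plain reference loop for those semantics (not the Eigen-backed kernel nGraph actually dispatches to) looks like this:

#include <cstddef>
#include <vector>

// Reference semantics of a 1D x 2D dot with one reduction axis:
// out[j] = sum_p v[p] * M[p][j], with M stored row-major as (k, n).
std::vector<float> dot_1d_2d_reference(const std::vector<float>& v, // shape (k)
                                       const std::vector<float>& M, // shape (k, n)
                                       size_t k, size_t n)
{
    std::vector<float> out(n, 0.0f);
    for (size_t p = 0; p < k; ++p)
        for (size_t j = 0; j < n; ++j)
            out[j] += v[p] * M[p * n + j];
    return out;
}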
@@ -22,6 +22,7 @@
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/max.hpp"
#include "ngraph/op/min.hpp"
#include "ngraph/op/product.hpp"
@@ -191,6 +192,55 @@ static bool collapse_reduction(std::shared_ptr<Node> n)
return replaced;
}
template <typename T>
static bool collapse_dot(std::shared_ptr<Node> n)
{
bool replaced = false;
auto node = std::static_pointer_cast<T>(n).get();
auto A_shape = node->get_argument(0)->get_shape();
auto B_shape = node->get_argument(1)->get_shape();
auto reduction_count = node->get_reduction_axes_count();
AxisSet operated_axes_A, operated_axes_B;
for (size_t i = 0; i < reduction_count; i++)
{
operated_axes_A.insert(A_shape.size() - i - 1);
operated_axes_B.insert(i);
}
struct CollapsedShape cshape_A, cshape_B;
collapse_dims(A_shape, operated_axes_A, cshape_A, true);
collapse_dims(B_shape, operated_axes_B, cshape_B, true);
if (A_shape != cshape_A.fshape || B_shape != cshape_B.fshape)
{
// Reshape A to cshape_A.fshape
AxisVector A_axis_order = ngraph::get_default_order(A_shape);
auto reshape_A = std::make_shared<op::Reshape>(
node->get_argument(0), A_axis_order, Shape(cshape_A.fshape));
// Reshape B to cshape_B.fshape
AxisVector B_axis_order = ngraph::get_default_order(B_shape);
auto reshape_B = std::make_shared<op::Reshape>(
node->get_argument(1), B_axis_order, Shape(cshape_B.fshape));
auto cdot =
std::make_shared<op::Dot>(reshape_A, reshape_B, reduction_count ? 1 : reduction_count);
auto reshape_output = std::make_shared<op::Reshape>(
cdot, ngraph::get_default_order(cdot->get_shape()), node->get_shape());
ngraph::replace_node(n, reshape_output);
NGRAPH_DEBUG << "CollapseDims: Replaced dot " << A_shape << " . " << B_shape
<< " reduction count: " << reduction_count << " with "
<< Shape(cshape_A.fshape) << " . " << Shape(cshape_B.fshape);
replaced = true;
}
return replaced;
}
bool runtime::cpu::pass::CPUCollapseDims::run_on_function(std::shared_ptr<ngraph::Function> f)
{
bool replaced = false;
@@ -216,6 +266,10 @@ bool runtime::cpu::pass::CPUCollapseDims::run_on_function(std::shared_ptr<ngraph
{
replaced |= collapse_reduction<op::Sum>(n);
}
else if (std::dynamic_pointer_cast<op::Dot>(n))
{
replaced |= collapse_dot<op::Dot>(n);
}
}
return replaced;
......
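To make collapse_dot concrete, consider A_shape {2, 3, 4}, B_shape {4, 5} and reduction_count 1: the operated axes are {2} on A and {0} on B, and, assuming collapse_dims folds runs of non-reduction axes into a single axis, cshape_A.fshape becomes {6, 4} while cshape_B.fshape stays {4, 5}; the {6, 5} result of the collapsed 2D Dot is then reshaped back to {2, 3, 5}. The helper below is a hypothetical stand-in (collapse_dims itself is defined outside this diff) that reproduces just that folding for the first input:

#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical stand-in for the folding collapse_dot relies on: multiply all
// leading (non-reduction) axes into a single axis and keep the trailing
// reduction axes.  The real collapse_dims helper is more general.
std::vector<size_t> fold_leading_axes(const std::vector<size_t>& shape, size_t reduction_count)
{
    const size_t split = shape.size() - reduction_count;
    const size_t folded = std::accumulate(
        shape.begin(), shape.begin() + split, size_t{1}, std::multiplies<size_t>());
    std::vector<size_t> out{folded};
    out.insert(out.end(), shape.begin() + split, shape.end());
    return out;
}

// fold_leading_axes({2, 3, 4}, 1) == {6, 4}: the rewritten graph computes
// Dot({6, 4}, {4, 5}, 1) and reshapes the {6, 5} result back to {2, 3, 5},
// which is exactly the 2D shape that now takes the cblas_sgemm path.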