Commit 55d11bb4 authored by Nick Korovaiko, committed by Scott Cyphers

Generalize MatMulBias (2nd attempt) (#597)

* Generalize MatmulBias

Fixes

Disable logging

* Fix unit-test failures
parent 5c7e9844
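This commit makes the bias input of op::MatmulBias optional, so the fused op can represent a plain GEMM (a Dot, possibly with transposed arguments) as well as GEMM-plus-bias. A minimal sketch of the two construction modes, based on the constructor signature used in the hunks and tests below (shapes are illustrative; this is not repository code):

// Bias-less GEMM: the third constructor argument is nullptr.
auto W = std::make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto x = std::make_shared<op::Parameter>(element::f32, Shape{4, 1});
auto mm = std::make_shared<op::MatmulBias>(
    W, x, nullptr, W->get_shape(), x->get_shape(), false, false);

// GEMM + bias: the third argument is a node already broadcast to the output shape.
auto b = std::make_shared<op::Parameter>(element::f32, Shape{2, 1});
auto mmb = std::make_shared<op::MatmulBias>(
    W, x, b, W->get_shape(), x->get_shape(), false, false);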
@@ -150,3 +150,50 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern()
auto m = std::make_shared<ngraph::pattern::Matcher>(reshape2, callback);
this->add_matcher(m);
}
void ngraph::pass::ReshapeElimination::construct_dot_transpose_pattern()
{
//dot(A,B).T = dot (B.T, A.T)
auto dot_pred = [](std::shared_ptr<Node> n) {
return static_cast<bool>(std::dynamic_pointer_cast<op::Dot>(n));
};
auto pdot = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 1}, dot_pred);
auto preshape = std::make_shared<op::Reshape>(pdot, AxisVector{1, 0}, Shape{1, 2});
ngraph::pattern::gr_callback_fn callback = [](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for construct_dot_transpose_pattern against node = "
<< m.match_root()->get_name();
std::shared_ptr<Node> nn;
auto mtranspose = std::dynamic_pointer_cast<op::Reshape>(m.match_root());
//this also checks the rank
if (mtranspose->get_input_order() != AxisVector{1, 0})
{
NGRAPH_DEBUG << "Reshape isn't transpose. "
<< vector_to_string(mtranspose->get_input_order());
return nn;
}
auto mdot = mtranspose->get_input_op(0);
if (mdot->get_shape().size() != 2)
{
NGRAPH_DEBUG << "Dot has the wrong shape. " << vector_to_string(mdot->get_shape());
return nn;
}
auto arg0 = mdot->get_input_op(0);
auto reshape0_shape = Shape{arg0->get_shape().at(1), arg0->get_shape().at(0)};
auto reshape0 = std::make_shared<op::Reshape>(arg0, AxisVector{1, 0}, reshape0_shape);
auto arg1 = mdot->get_input_op(1);
auto reshape1_shape = Shape{arg1->get_shape().at(1), arg1->get_shape().at(0)};
auto reshape1 = std::make_shared<op::Reshape>(arg1, AxisVector{1, 0}, reshape1_shape);
auto tdot = std::shared_ptr<Node>(new op::Dot(reshape1, reshape0));
return tdot;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(preshape, callback);
this->add_matcher(m);
}
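The rewrite above uses the transpose identity dot(A, B)^T = dot(B^T, A^T): instead of reshaping (transposing) the Dot result, it transposes and swaps the Dot arguments. A sketch of the before/after graphs with the same shapes as the reshape_elimination test at the end of this diff (illustrative, not repository code):

// Before: the pattern this pass matches.
auto W = std::make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto x = std::make_shared<op::Parameter>(element::f32, Shape{4, 1});
auto dot = std::make_shared<op::Dot>(W, x);                                  // shape {2, 1}
auto t = std::make_shared<op::Reshape>(dot, AxisVector{1, 0}, Shape{1, 2});  // shape {1, 2}

// After: what the callback returns -- Dot over transposed, swapped arguments.
auto xt = std::make_shared<op::Reshape>(x, AxisVector{1, 0}, Shape{1, 4});
auto Wt = std::make_shared<op::Reshape>(W, AxisVector{1, 0}, Shape{4, 2});
auto t2 = std::make_shared<op::Dot>(xt, Wt);                                 // shape {1, 2}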
@@ -32,11 +32,13 @@ public:
ReshapeElimination()
: GraphRewrite()
{
construct_dot_transpose_pattern();
construct_identity_reshape_pattern();
construct_reshapex2_pattern();
}
private:
void construct_dot_transpose_pattern();
void construct_identity_reshape_pattern();
void construct_reshapex2_pattern();
};
@@ -240,7 +240,7 @@ namespace ngraph
const Shape& arg0_shape = cg->get_arg0_shape(); //W
const Shape& arg1_shape = cg->get_arg1_shape(); //x
const Shape& arg2_shape = args[2].get_shape(); //bias (C)
const Shape& arg2_shape = node->get_shape(); //bias (C)
static const char* ctranspose = "cblas::Transpose::Transpose, ";
static const char* cnotranspose = "cblas::Transpose::None, ";
@@ -270,16 +270,23 @@ namespace ngraph
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
const char* cbeta = "0.0f";
if (args.size() > 2)
{
writer << "memcpy(" << out[0].get_name() << ", " << args[2].get_name() << ", "
<< out[0].get_size() * out[0].get_element_type().size() << ");\n";
cbeta = "1.0f";
}
writer << "cblas::cblas_sgemm("
<< "cblas::Layout::RowMajor, " << tranpose_a << tranpose_b << m << ", " << n
<< ", " << k << ",\n"
<< " 1.0f, " << args[0].get_name() << ", " << max(1UL, lda) << ", "
<< args[1].get_name() << ", " << max(1UL, ldb) << ", 1.0f,\n"
<< args[1].get_name() << ", " << max(1UL, ldb) << ", " << cbeta << ",\n"
<< " " << out[0].get_name() << ", " << max(1UL, arg2_shape[1])
<< ");\n";
writer.indent--;
writer << "}\n";
}
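The emitter change above folds the optional bias into the GEMM call itself: cblas_sgemm computes C = alpha * op(A) * op(B) + beta * C, so when a third argument is present the bias is first memcpy'd into the output buffer and cbeta is switched from 0.0f to 1.0f, making the GEMM accumulate on top of it. A hand-written sketch of the emitted call sequence (buffer and parameter names here are hypothetical, not the generated identifiers):

#include <cblas.h>
#include <cstring>

// Row-major {m,k} x {k,n} -> {m,n}, with an optional bias already broadcast to {m,n}.
void gemm_with_optional_bias(const float* A, const float* B, const float* bias,
                             float* out, int m, int n, int k, bool has_bias)
{
    float beta = 0.0f;
    if (has_bias)
    {
        std::memcpy(out, bias, sizeof(float) * m * n); // seed the output with the bias
        beta = 1.0f;                                   // GEMM then accumulates into it
    }
    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k,
                1.0f, A, k, // lda = k for non-transposed, row-major A
                B, n,       // ldb = n for non-transposed, row-major B
                beta, out, n);
}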
@@ -21,13 +21,14 @@
std::shared_ptr<ngraph::Node>
ngraph::op::MatmulBias::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 2)
if (new_args.size() != 2 && new_args.size() != 3)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<MatmulBias>(new_args.at(0),
new_args.at(1),
new_args.at(1),
new_args.size() == 3 ? new_args.at(2) : nullptr,
m_shape_w,
m_shape_x,
m_transpose_w,
@@ -41,7 +42,9 @@ ngraph::op::MatmulBias::MatmulBias(std::shared_ptr<ngraph::Node> W,
Shape shape_x,
bool transpose_w,
bool transpose_x)
: RequiresTensorViewArgs("MatMulBias", {W, x, b})
: RequiresTensorViewArgs("MatMulBias",
b == nullptr ? std::vector<std::shared_ptr<Node>>{W, x}
: std::vector<std::shared_ptr<Node>>{W, x, b})
, m_shape_w(shape_w)
, m_shape_x(shape_x)
, m_transpose_w(transpose_w)
@@ -74,8 +77,12 @@ ngraph::op::MatmulBias::MatmulBias(std::shared_ptr<ngraph::Node> W,
}
Shape dot_shape{shape_w.at(1 - dot_dimension_w), shape_x.at(1 - dot_dimension_x)};
NGRAPH_DEBUG << "dot_shape shape = " << vector_to_string(dot_shape)
<< " , b shape = " << vector_to_string(b->get_shape());
NGRAPH_DEBUG << "dot_shape shape = " << vector_to_string(dot_shape);
if (b)
{
NGRAPH_DEBUG << "b shape = " << vector_to_string(b->get_shape());
}
add_output(W->get_element_type(), dot_shape);
}
@@ -49,6 +49,12 @@ static bool init_cblas_arg(std::shared_ptr<ngraph::Node> reshape,
if (!r_w)
{
if (arg->get_shape().size() != 2)
{
NGRAPH_DEBUG << arg->get_name() << " 's rank != 2 "
<< ngraph::vector_to_string(arg->get_shape());
return false;
}
return true; //nth to do; reshape isn't a reshape
}
@@ -106,7 +112,38 @@ static std::vector<T> apply_permutation(std::vector<T> input, ngraph::AxisVector
return output;
}
void ngraph::runtime::cpu::pass::CPUFusion::construct_gemm_pattern()
void ngraph::runtime::cpu::pass::CPUFusion::construct_matmulbias_pattern()
{
Shape shape_w{2, 4};
Shape shape_x{4, 1};
Shape shape_b{1};
auto W = std::make_shared<pattern::op::Label>(element::f32, shape_w);
auto x = std::make_shared<pattern::op::Label>(element::f32, shape_x);
auto b = std::make_shared<pattern::op::Label>(element::f32, shape_b);
auto pmmb = std::make_shared<op::MatmulBias>(
W, x, nullptr, W->get_shape(), x->get_shape(), false, false);
auto pbroadcast = std::make_shared<op::Broadcast>(b, pmmb->get_shape(), AxisSet{0});
auto padd = pmmb + pbroadcast;
ngraph::pattern::gr_callback_fn callback = [W, x](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for construct_matmulbias_pattern against node = "
<< m.match_root()->get_name();
auto mpattern = m.match_root(); //add
auto m_matmul = mpattern->get_input_op(0);
auto m_broadcast = mpattern->get_input_op(1);
auto pattern_map = m.get_pattern_map();
return m_matmul->copy_with_new_args(
NodeVector{pattern_map[W], pattern_map[x], m_broadcast});
};
auto m = std::make_shared<ngraph::pattern::Matcher>(padd, callback);
this->add_matcher(m);
}
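Fusion is now split into two stages: construct_matmul_pattern (below) rewrites a Dot whose arguments may come through transpose-Reshapes into a bias-less MatmulBias, and construct_matmulbias_pattern (above) then matches MatmulBias + Broadcast and folds the broadcast bias in through copy_with_new_args. A comment-only sketch of the progression, using the pattern shapes above (illustrative):

// Stage 1 (construct_matmul_pattern):
//   Dot(W{2,4}, x{4,1})                                   ->  MatmulBias(W, x, nullptr, ...)
// Stage 2 (construct_matmulbias_pattern):
//   MatmulBias(W, x, nullptr, ...) + Broadcast(b{1}->{2,1})
//                                                         ->  MatmulBias(W, x, Broadcast(b), ...)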
void ngraph::runtime::cpu::pass::CPUFusion::construct_matmul_pattern()
{
Shape shape_w{2, 4};
Shape shape_x{4, 1};
@@ -124,30 +161,34 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_gemm_pattern()
auto skip_x = std::make_shared<pattern::op::Any>(x, reshape_pred);
auto pdot = std::make_shared<op::Dot>(skip_w, skip_x);
auto b = std::make_shared<pattern::op::Label>(element::f32, shape_b);
auto pbroadcast = std::make_shared<op::Broadcast>(b, shape_dot, AxisSet{0});
auto padd = pdot + pbroadcast;
ngraph::pattern::gr_callback_fn callback = [W, x, b](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for construct_gemm_pattern against node = "
ngraph::pattern::gr_callback_fn callback = [W, x](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for construct_matmul_pattern against node = "
<< m.match_root()->get_name();
auto pattern_map = m.get_pattern_map();
std::shared_ptr<Node> nn = nullptr;
std::shared_ptr<Node> nn;
auto mpattern = m.match_root();
auto dot = m.match_root();
if (mpattern->get_element_type() != element::f32)
{
NGRAPH_DEBUG << "mpattern = " << mpattern->get_name() << " type is not float!";
return nn;
}
auto dot = mpattern->get_input_op(0);
if (dot->get_shape().size() != 2)
{
NGRAPH_DEBUG << "dot = " << dot->get_name() << " shape is not equal to 2!";
return nn;
}
if (shape_size(dot->get_shape()) == 0)
{
NGRAPH_DEBUG << "dot has a zero dimension";
return nn;
}
bool transpose_w = false;
Shape shape_arg0{pattern_map[W]->get_shape()};
if (!init_cblas_arg(dot->get_input_op(0), pattern_map[W], transpose_w, shape_arg0))
@@ -164,7 +205,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_gemm_pattern()
auto cg = std::shared_ptr<Node>(new op::MatmulBias(pattern_map[W],
pattern_map[x],
mpattern->get_input_op(1),
nullptr,
shape_arg0,
shape_arg1,
transpose_w,
@@ -172,7 +213,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_gemm_pattern()
return cg;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(padd, callback);
auto m = std::make_shared<ngraph::pattern::Matcher>(pdot, callback);
this->add_matcher(m);
}
@@ -38,11 +38,13 @@ public:
CPUFusion()
: GraphRewrite()
{
construct_gemm_pattern();
construct_matmul_pattern();
construct_matmulbias_pattern();
construct_fprop_bn();
}
private:
void construct_gemm_pattern();
void construct_matmul_pattern();
void construct_matmulbias_pattern();
void construct_fprop_bn();
};
@@ -133,6 +133,42 @@ TEST(cpu_fusion, gemm_cpu)
ASSERT_TRUE(read_vector<float>(result) == expected);
}
TEST(cpu_fusion, gemm_cpu_no_bias)
{
auto shapeA = Shape{3, 2};
auto shapeB = Shape{2, 3};
auto shapeC = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shapeA);
auto B = make_shared<op::Parameter>(element::f32, shapeB);
auto reshape_w = make_shared<op::Reshape>(A, AxisVector{1, 0}, Shape{2, 3});
auto reshape_x = make_shared<op::Reshape>(B, AxisVector{1, 0}, Shape{3, 2});
auto cg =
make_shared<op::MatmulBias>(A, B, nullptr, A->get_shape(), B->get_shape(), true, true);
auto f = make_shared<Function>(cg, op::ParameterVector{A, B});
auto manager = runtime::Manager::get("CPU");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
shared_ptr<runtime::TensorView> a = backend->make_primary_tensor_view(element::f32, shapeA);
shared_ptr<runtime::TensorView> b = backend->make_primary_tensor_view(element::f32, shapeB);
shared_ptr<runtime::TensorView> result =
backend->make_primary_tensor_view(element::f32, shapeC);
vector<float> dataA{1.0f, 4.0f, 1.0f, 4.0f, 1.0f, 4.0f};
vector<float> dataB{3.0f, 3.0f, 3.0f, 9.0f, 9.0f, 9.0f};
copy_data(a, dataA);
copy_data(b, dataB);
cf->call({a, b}, {result});
vector<float> expected{9, 27, 36, 108};
ASSERT_TRUE(read_vector<float>(result) == expected);
}
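For reference, the expected values in gemm_cpu_no_bias above follow from the double transpose: with transpose_w = transpose_x = true the op computes A^T * B^T. A worked check from the row-major test data (dataA gives A = [[1,4],[1,4],[1,4]], dataB gives B = [[3,3,3],[9,9,9]]):

A^\top B^\top =
\begin{pmatrix} 1 & 1 & 1 \\ 4 & 4 & 4 \end{pmatrix}
\begin{pmatrix} 3 & 9 \\ 3 & 9 \\ 3 & 9 \end{pmatrix}
=
\begin{pmatrix} 9 & 27 \\ 36 & 108 \end{pmatrix}

which flattens (row-major) to the expected vector {9, 27, 36, 108}.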
TEST(cpu_fusion, cpu_fusion_pass_basic)
{
Shape shape{};
@@ -154,6 +190,50 @@ TEST(cpu_fusion, cpu_fusion_pass_basic)
ASSERT_NE(std::dynamic_pointer_cast<op::MatmulBias>(graph->get_input_op(0)), nullptr);
}
TEST(cpu_fusion, cpu_fusion_pass_matmul_bias)
{
Shape shape_w{2, 4};
Shape shape_x{4, 1};
Shape shape_b{1};
auto W = make_shared<op::Parameter>(element::f32, shape_w);
auto x = make_shared<op::Parameter>(element::f32, shape_x);
auto b = make_shared<op::Parameter>(element::f32, shape_b);
auto mmb = std::make_shared<op::MatmulBias>(
W, x, nullptr, W->get_shape(), x->get_shape(), false, false);
auto broadcast = std::make_shared<op::Broadcast>(b, mmb->get_shape(), AxisSet{0});
auto add = mmb + broadcast;
auto graph = make_shared<op::Abs>(add);
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
auto func = make_shared<Function>(graph, op::ParameterVector{W, x, b});
pass_manager.run_passes(func);
auto gmm = graph->get_input_op(0);
ASSERT_TRUE(std::dynamic_pointer_cast<op::MatmulBias>(gmm));
ASSERT_EQ(gmm->get_input_op(2), broadcast);
}
TEST(cpu_fusion, cpu_fusion_pass_matmul_no_bias)
{
Shape shape_w{4, 2};
Shape shape_x{1, 4};
auto W = make_shared<op::Parameter>(element::f32, shape_w);
auto x = make_shared<op::Parameter>(element::f32, shape_x);
auto reshape_w = std::make_shared<op::Reshape>(W, AxisVector{1, 0}, Shape{2, 4});
auto reshape_x = std::make_shared<op::Reshape>(x, AxisVector{1, 0}, Shape{4, 1});
auto re_dot = make_shared<op::Dot>(reshape_w, reshape_x);
auto graph = make_shared<op::Abs>(re_dot);
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
auto func = make_shared<Function>(graph, op::ParameterVector{W, x});
pass_manager.run_passes(func);
size_t mmb = count_ops_of_type<op::MatmulBias>(func);
ASSERT_EQ(mmb, 1);
}
TEST(cpu_fusion, gemm_mlp)
{
const string json_path = file_util::path_join(SERIALIZED_ZOO, "mxnet/mnist_mlp_forward.json");
@@ -163,8 +243,8 @@ TEST(cpu_fusion, gemm_mlp)
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
pass_manager.run_passes(func);
size_t ccg = count_ops_of_type<op::MatmulBias>(func);
ASSERT_EQ(ccg, 3);
size_t mmb = count_ops_of_type<op::MatmulBias>(func);
ASSERT_EQ(mmb, 3);
}
//TODO: Move this test to backend_test.in.cpp once we have the INTERPRETER
@@ -82,3 +82,27 @@ TEST(reshape_elimination, bn_bprop_rewrite)
size_t count_after = count_ops_of_type<op::Reshape>(func);
ASSERT_TRUE(count_after < count_before);
}
TEST(reshape_elimination, dot_transpose_to_dot_w_transpose_args)
{
Shape shape_w{2, 4};
Shape shape_x{4, 1};
auto W = make_shared<op::Parameter>(element::f32, shape_w);
auto x = make_shared<op::Parameter>(element::f32, shape_x);
auto dot = make_shared<op::Dot>(W, x);
auto reshape_dot = std::make_shared<op::Reshape>(dot, AxisVector{1, 0}, Shape{1, 2});
auto graph = make_shared<op::Abs>(reshape_dot);
pass::Manager pass_manager;
pass_manager.register_pass<pass::ReshapeElimination>();
auto func = make_shared<Function>(graph, op::ParameterVector{W, x});
pass_manager.run_passes(func);
auto gdot = graph->get_input_op(0);
ASSERT_TRUE(std::dynamic_pointer_cast<op::Dot>(gdot));
ASSERT_TRUE(std::dynamic_pointer_cast<op::Reshape>(gdot->get_input_op(0)));
ASSERT_TRUE(std::dynamic_pointer_cast<op::Reshape>(gdot->get_input_op(1)));
ASSERT_EQ(gdot->get_input_op(0)->get_input_op(0), x);
ASSERT_EQ(gdot->get_input_op(1)->get_input_op(0), W);
ASSERT_EQ(gdot->get_shape(), (Shape{1, 2}));
}