Commit 55d11bb4 authored by Nick Korovaiko, committed by Scott Cyphers

Generalize MatMulBias (2nd attempt) (#597)

* generalize matmulbias
* fixes
* disable logging
* unit-test failures
parent 5c7e9844
...@@ -150,3 +150,50 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern()
    auto m = std::make_shared<ngraph::pattern::Matcher>(reshape2, callback);
    this->add_matcher(m);
}
void ngraph::pass::ReshapeElimination::construct_dot_transpose_pattern()
{
    //dot(A,B).T = dot(B.T, A.T)
    auto dot_pred = [](std::shared_ptr<Node> n) {
        return static_cast<bool>(std::dynamic_pointer_cast<op::Dot>(n));
    };

    auto pdot = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 1}, dot_pred);
    auto preshape = std::make_shared<op::Reshape>(pdot, AxisVector{1, 0}, Shape{1, 2});

    ngraph::pattern::gr_callback_fn callback = [](pattern::Matcher& m) {
        NGRAPH_DEBUG << "In callback for construct_dot_transpose_pattern against node = "
                     << m.match_root()->get_name();

        std::shared_ptr<Node> nn;

        auto mtranspose = std::dynamic_pointer_cast<op::Reshape>(m.match_root());
        //this also checks the rank
        if (mtranspose->get_input_order() != AxisVector{1, 0})
        {
            NGRAPH_DEBUG << "Reshape isn't transpose. "
                         << vector_to_string(mtranspose->get_input_order());
            return nn;
        }

        auto mdot = mtranspose->get_input_op(0);
        if (mdot->get_shape().size() != 2)
        {
            NGRAPH_DEBUG << "Dot has the wrong shape. " << vector_to_string(mdot->get_shape());
            return nn;
        }

        auto arg0 = mdot->get_input_op(0);
        auto reshape0_shape = Shape{arg0->get_shape().at(1), arg0->get_shape().at(0)};
        auto reshape0 = std::make_shared<op::Reshape>(arg0, AxisVector{1, 0}, reshape0_shape);

        auto arg1 = mdot->get_input_op(1);
        auto reshape1_shape = Shape{arg1->get_shape().at(1), arg1->get_shape().at(0)};
        auto reshape1 = std::make_shared<op::Reshape>(arg1, AxisVector{1, 0}, reshape1_shape);

        auto tdot = std::shared_ptr<Node>(new op::Dot(reshape1, reshape0));
        return tdot;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(preshape, callback);
    this->add_matcher(m);
}
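The pass above relies on the identity dot(A, B).T = dot(B.T, A.T). A minimal sketch of the substitution the callback performs, with illustrative shapes and names that are not part of the pass itself (assumes the usual ngraph op headers and namespaces):

    // Matched subgraph: the transpose of a 2-D dot product.
    auto A = std::make_shared<op::Parameter>(element::f32, Shape{2, 4});
    auto B = std::make_shared<op::Parameter>(element::f32, Shape{4, 1});
    auto matched = std::make_shared<op::Reshape>(
        std::make_shared<op::Dot>(A, B), AxisVector{1, 0}, Shape{1, 2});

    // Replacement built by the callback: transpose both arguments and swap them.
    auto Bt = std::make_shared<op::Reshape>(B, AxisVector{1, 0}, Shape{1, 4});
    auto At = std::make_shared<op::Reshape>(A, AxisVector{1, 0}, Shape{4, 2});
    auto replacement = std::make_shared<op::Dot>(Bt, At); // also Shape{1, 2}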
/*******************************************************************************
 * Copyright 2018 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *******************************************************************************/

#pragma once

#include "ngraph/pass/graph_rewrite.hpp"

namespace ngraph
{
    namespace pass
    {
        class ReshapeElimination;
    }
}

class ngraph::pass::ReshapeElimination : public ngraph::pass::GraphRewrite
{
public:
    ReshapeElimination()
        : GraphRewrite()
    {
        construct_dot_transpose_pattern();
        construct_identity_reshape_pattern();
        construct_reshapex2_pattern();
    }

private:
    void construct_dot_transpose_pattern();
    void construct_identity_reshape_pattern();
    void construct_reshapex2_pattern();
};
...@@ -240,7 +240,7 @@ namespace ngraph
const Shape& arg0_shape = cg->get_arg0_shape(); //W
const Shape& arg1_shape = cg->get_arg1_shape(); //x
const Shape& arg2_shape = node->get_shape();    //bias (C)

static const char* ctranspose = "cblas::Transpose::Transpose, ";
static const char* cnotranspose = "cblas::Transpose::None, ";
...@@ -270,16 +270,23 @@ namespace ngraph
writer << "{ // " << node->get_name() << "\n";
writer.indent++;

const char* cbeta = "0.0f";

if (args.size() > 2)
{
    writer << "memcpy(" << out[0].get_name() << ", " << args[2].get_name() << ", "
           << out[0].get_size() * out[0].get_element_type().size() << ");\n";
    cbeta = "1.0f";
}

writer << "cblas::cblas_sgemm("
       << "cblas::Layout::RowMajor, " << tranpose_a << tranpose_b << m << ", " << n
       << ", " << k << ",\n"
       << " 1.0f, " << args[0].get_name() << ", " << max(1UL, lda) << ", "
       << args[1].get_name() << ", " << max(1UL, ldb) << ", " << cbeta << ",\n"
       << " " << out[0].get_name() << ", " << max(1UL, arg2_shape[1])
       << ");\n";
writer.indent--;
writer << "}\n";
}
...
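The emitter change above makes the bias copy and the GEMM beta conditional on a third argument being present: with a bias, the output buffer is pre-filled with the bias and beta is 1.0f; without one, the memcpy is skipped and beta stays 0.0f. A standalone sketch of the call pattern the generated code follows — the function name, variables, and the plain CBLAS C API used here are illustrative, not what the emitter writes verbatim:

    #include <cblas.h>
    #include <cstring>

    // C = 1.0f * A * B + beta * C, where C is pre-loaded with the bias when one exists.
    void matmul_maybe_bias(const float* A, const float* B, const float* bias,
                           float* C, int m, int n, int k)
    {
        float beta = 0.0f;
        if (bias != nullptr) // corresponds to args.size() > 2 in the emitter
        {
            std::memcpy(C, bias, sizeof(float) * m * n);
            beta = 1.0f;
        }
        cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
                    m, n, k,
                    1.0f, A, k, // lda = k for a row-major m x k matrix
                    B, n,       // ldb = n for a row-major k x n matrix
                    beta, C, n);
    }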
/*******************************************************************************
 * Copyright 2018 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *******************************************************************************/

#include "matmul_bias.hpp"
#include "ngraph/log.hpp"
#include "ngraph/util.hpp"

std::shared_ptr<ngraph::Node>
    ngraph::op::MatmulBias::copy_with_new_args(const NodeVector& new_args) const
{
    if (new_args.size() != 2 && new_args.size() != 3)
    {
        throw ngraph_error("Incorrect number of new arguments");
    }

    return std::make_shared<MatmulBias>(new_args.at(0),
                                        new_args.at(1),
                                        new_args.size() == 3 ? new_args.at(2) : nullptr,
                                        m_shape_w,
                                        m_shape_x,
                                        m_transpose_w,
                                        m_transpose_x);
}

ngraph::op::MatmulBias::MatmulBias(std::shared_ptr<ngraph::Node> W,
                                   std::shared_ptr<ngraph::Node> x,
                                   std::shared_ptr<ngraph::Node> b,
                                   Shape shape_w,
                                   Shape shape_x,
                                   bool transpose_w,
                                   bool transpose_x)
    : RequiresTensorViewArgs("MatMulBias",
                             b == nullptr ? std::vector<std::shared_ptr<Node>>{W, x}
                                          : std::vector<std::shared_ptr<Node>>{W, x, b})
    , m_shape_w(shape_w)
    , m_shape_x(shape_x)
    , m_transpose_w(transpose_w)
    , m_transpose_x(transpose_x)
{
    if (shape_w.size() != 2)
    {
        NGRAPH_DEBUG << "W shape = " << vector_to_string(shape_w);
        throw ngraph_error("W.shape.rank != 2 while creating MatmulBias");
    }

    if (shape_x.size() != 2)
    {
        NGRAPH_DEBUG << "x shape = " << vector_to_string(shape_x);
        throw ngraph_error("x.shape.rank != 2 while creating MatmulBias");
    }

    size_t dot_dimension_w = (transpose_w) ? 0 : 1;
    size_t dot_dimension_x = (transpose_x) ? 1 : 0;

    NGRAPH_DEBUG << "dot_dimension_w = " << dot_dimension_w
                 << " , dot_dimension_x = " << dot_dimension_x;
    NGRAPH_DEBUG << "W shape = " << vector_to_string(shape_w)
                 << " , x shape = " << vector_to_string(shape_x);

    if (shape_w.at(dot_dimension_w) != shape_x.at(dot_dimension_x))
    {
        throw ngraph_error("product dimensions are not equal while creating MatmulBias");
    }

    Shape dot_shape{shape_w.at(1 - dot_dimension_w), shape_x.at(1 - dot_dimension_x)};
    NGRAPH_DEBUG << "dot_shape shape = " << vector_to_string(dot_shape);

    if (b)
    {
        NGRAPH_DEBUG << "b shape = " << vector_to_string(b->get_shape());
    }

    add_output(W->get_element_type(), dot_shape);
}
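With the bias input now optional, the constructor above can be called either way. A minimal sketch, with illustrative parameter shapes (assumes using namespace ngraph and the matmul_bias.hpp header):

    auto W = std::make_shared<op::Parameter>(element::f32, Shape{2, 4});
    auto x = std::make_shared<op::Parameter>(element::f32, Shape{4, 1});
    auto b = std::make_shared<op::Parameter>(element::f32, Shape{2, 1});

    // Plain matrix multiply: pass nullptr for the bias.
    auto mm = std::make_shared<op::MatmulBias>(
        W, x, nullptr, W->get_shape(), x->get_shape(), false, false);

    // Fused multiply-add: pass the (already broadcast) bias as the third input.
    auto mmb = std::make_shared<op::MatmulBias>(
        W, x, b, W->get_shape(), x->get_shape(), false, false);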
/*******************************************************************************
 * Copyright 2017-2018 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *******************************************************************************/

#include "cpu_fusion.hpp"
#include <algorithm>
#include <iostream>
#include <numeric>
#include <unordered_set>
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/ops/add.hpp"
#include "ngraph/ops/batch_norm.hpp"
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/constant.hpp"
#include "ngraph/ops/divide.hpp"
#include "ngraph/ops/dot.hpp"
#include "ngraph/ops/multiply.hpp"
#include "ngraph/ops/parameter.hpp"
#include "ngraph/ops/reshape.hpp"
#include "ngraph/ops/sqrt.hpp"
#include "ngraph/ops/subtract.hpp"
#include "ngraph/ops/sum.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/runtime/cpu/ops/matmul_bias.hpp"

static bool init_cblas_arg(std::shared_ptr<ngraph::Node> reshape,
                           std::shared_ptr<ngraph::Node> arg,
                           bool& transpose_w,
                           ngraph::Shape& shape_w)
{
    auto r_w = std::dynamic_pointer_cast<ngraph::op::Reshape>(reshape);

    if (!r_w)
    {
        if (arg->get_shape().size() != 2)
        {
            NGRAPH_DEBUG << arg->get_name() << " 's rank != 2 "
                         << ngraph::vector_to_string(arg->get_shape());
            return false;
        }
        return true; //nothing to do; reshape isn't a reshape
    }

    if (r_w->get_shape().size() != 2)
    {
        NGRAPH_DEBUG << "Reshape for " << reshape->get_name() << " doesn't reshape into matrix"
                     << ngraph::vector_to_string(r_w->get_shape());
        return false;
    }

    auto io = r_w->get_input_order();
    if (r_w->get_shape().size() != arg->get_shape().size()) //reshape
    {
        ngraph::AxisVector dio(io.size());
        std::iota(begin(dio), end(dio), 0);

        if (io != dio) //we can't reshape and transpose at the same time
        {
            NGRAPH_DEBUG << "Reshape for " << reshape->get_name() << " is not in default order "
                         << ngraph::vector_to_string(io);
            NGRAPH_DEBUG << "r_w shape = " << ngraph::vector_to_string(r_w->get_shape());
            NGRAPH_DEBUG << "arg shape = " << ngraph::vector_to_string(arg->get_shape());
            return false;
        }

        shape_w = r_w->get_shape();
    }
    else
    {
        if (io == ngraph::AxisVector{1, 0})
        {
            transpose_w = true;
        }
        //otherwise no-op reshape
    }

    return true;
}

template <typename T>
static std::vector<T> apply_permutation(std::vector<T> input, ngraph::AxisVector order)
{
    if (input.size() != order.size())
    {
        throw "input and order sizes don't match!";
    }

    std::vector<T> output(input.size());

    for (size_t i = 0; i < order.size(); i++)
    {
        output[i] = input.at(order.at(i));
    }

    return output;
}

void ngraph::runtime::cpu::pass::CPUFusion::construct_matmulbias_pattern()
{
    Shape shape_w{2, 4};
    Shape shape_x{4, 1};
    Shape shape_b{1};
    auto W = std::make_shared<pattern::op::Label>(element::f32, shape_w);
    auto x = std::make_shared<pattern::op::Label>(element::f32, shape_x);
    auto b = std::make_shared<pattern::op::Label>(element::f32, shape_b);

    auto pmmb = std::make_shared<op::MatmulBias>(
        W, x, nullptr, W->get_shape(), x->get_shape(), false, false);
    auto pbroadcast = std::make_shared<op::Broadcast>(b, pmmb->get_shape(), AxisSet{0});
    auto padd = pmmb + pbroadcast;

    ngraph::pattern::gr_callback_fn callback = [W, x](pattern::Matcher& m) {
        NGRAPH_DEBUG << "In callback for construct_matmulbias_pattern against node = "
                     << m.match_root()->get_name();

        auto mpattern = m.match_root(); //add
        auto m_matmul = mpattern->get_input_op(0);
        auto m_broadcast = mpattern->get_input_op(1);
        auto pattern_map = m.get_pattern_map();

        return m_matmul->copy_with_new_args(
            NodeVector{pattern_map[W], pattern_map[x], m_broadcast});
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(padd, callback);
    this->add_matcher(m);
}

void ngraph::runtime::cpu::pass::CPUFusion::construct_matmul_pattern()
{
    Shape shape_w{2, 4};
    Shape shape_x{4, 1};
    Shape shape_b{1};
    Shape shape_dot{2, 1};

    auto W = std::make_shared<pattern::op::Label>(element::f32, shape_w);
    auto x = std::make_shared<pattern::op::Label>(element::f32, shape_x);

    auto reshape_pred = [](std::shared_ptr<Node> n) {
        return static_cast<bool>(std::dynamic_pointer_cast<op::Reshape>(n));
    };

    auto skip_w = std::make_shared<pattern::op::Any>(W, reshape_pred);
    auto skip_x = std::make_shared<pattern::op::Any>(x, reshape_pred);

    auto pdot = std::make_shared<op::Dot>(skip_w, skip_x);

    ngraph::pattern::gr_callback_fn callback = [W, x](pattern::Matcher& m) {
        NGRAPH_DEBUG << "In callback for construct_matmul_pattern against node = "
                     << m.match_root()->get_name();
        auto pattern_map = m.get_pattern_map();
        std::shared_ptr<Node> nn;

        auto mpattern = m.match_root();
        auto dot = m.match_root();

        if (mpattern->get_element_type() != element::f32)
        {
            NGRAPH_DEBUG << "mpattern = " << mpattern->get_name() << " type is not float!";
            return nn;
        }

        if (dot->get_shape().size() != 2)
        {
            NGRAPH_DEBUG << "dot = " << dot->get_name() << " shape is not equal to 2!";
            return nn;
        }

        if (shape_size(dot->get_shape()) == 0)
        {
            NGRAPH_DEBUG << "dot has a zero dimension";
            return nn;
        }

        bool transpose_w = false;
        Shape shape_arg0{pattern_map[W]->get_shape()};
        if (!init_cblas_arg(dot->get_input_op(0), pattern_map[W], transpose_w, shape_arg0))
        {
            return nn;
        }

        bool transpose_x = false;
        Shape shape_arg1{pattern_map[x]->get_shape()};
        if (!init_cblas_arg(dot->get_input_op(1), pattern_map[x], transpose_x, shape_arg1))
        {
            return nn;
        }

        auto cg = std::shared_ptr<Node>(new op::MatmulBias(pattern_map[W],
                                                           pattern_map[x],
                                                           nullptr,
                                                           shape_arg0,
                                                           shape_arg1,
                                                           transpose_w,
                                                           transpose_x));
        return cg;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(pdot, callback);
    this->add_matcher(m);
}

void ngraph::runtime::cpu::pass::CPUFusion::construct_fprop_bn()
{
    // construct variance
    auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
    auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 3});
    auto input_sq = std::make_shared<op::Multiply>(input, input);
    auto sum_input = std::make_shared<op::Sum>(input, AxisSet{0});
    auto square_sumed_input = std::make_shared<op::Multiply>(sum_input, sum_input);
    auto sum_squared_input = std::make_shared<op::Sum>(input_sq, AxisSet{0});
    auto avg_input_sum_sq = std::make_shared<op::Divide>(square_sumed_input, N);
    auto xmu = std::make_shared<op::Subtract>(sum_squared_input, avg_input_sum_sq);
    auto variance = std::make_shared<op::Divide>(xmu, N);
    auto variance_label =
        std::make_shared<pattern::op::Label>(variance, nullptr, NodeVector{variance});
    auto variance_with_broadcast =
        std::make_shared<op::Broadcast>(variance_label, Shape{2, 3}, AxisSet{0});

    // construct mean
    auto sum_input1 = std::make_shared<op::Sum>(input, AxisSet{0});
    auto mean = std::make_shared<op::Divide>(sum_input1, N);
    auto mean_label = std::make_shared<pattern::op::Label>(mean, nullptr, NodeVector{mean});
    auto mean_with_broadcast = std::make_shared<op::Broadcast>(mean_label, Shape{2, 3}, AxisSet{0});
    auto input_diff_mean = std::make_shared<op::Subtract>(input, mean_with_broadcast);

    // Eps
    auto eps_label = std::make_shared<pattern::op::Label>(element::f32, Shape{3});
    auto eps_with_broadcast = std::make_shared<op::Broadcast>(eps_label, Shape{2, 3}, AxisSet{0});

    auto add1 = std::make_shared<op::Add>(eps_with_broadcast, variance_with_broadcast);
    auto sqrt_variance_eps = std::make_shared<op::Sqrt>(add1);
    auto divide_mean_variance = std::make_shared<op::Divide>(input_diff_mean, sqrt_variance_eps);

    // Gamma
    auto gamma_label = std::make_shared<pattern::op::Label>(element::f32, Shape{3});
    auto gamma_with_broadcast =
        std::make_shared<op::Broadcast>(gamma_label, Shape{2, 3}, AxisSet{0});
    auto multiply_gamma =
        std::make_shared<op::Multiply>(gamma_with_broadcast, divide_mean_variance);

    // Beta
    auto beta_label = std::make_shared<pattern::op::Label>(element::f32, Shape{3});
    auto beta_with_broadcast = std::make_shared<op::Broadcast>(beta_label, Shape{2, 3}, AxisSet{0});

    auto add_beta = std::make_shared<op::Add>(beta_with_broadcast, multiply_gamma);
    // This completes the fprop bn pattern

    // Define a callback that needs to be called once the DFG matches the pattern
    ngraph::pattern::gr_callback_fn callback =
        [variance_label, mean_label, input, eps_label, gamma_label, beta_label](
            pattern::Matcher& m) {
            NGRAPH_DEBUG << "In a callback for construct_fprop_bn pattern against "
                         << m.match_root()->get_name();

            std::shared_ptr<Node> nn = nullptr;
            //TODO - add asserts based on the matched node
            auto pattern_map = m.get_pattern_map();
            NGRAPH_DEBUG << "Input: " << pattern_map[input]->get_name() << " "
                         << pattern_map[input]->get_shape().size();
            NGRAPH_DEBUG << "Variance: " << pattern_map[variance_label]->get_name() << " "
                         << pattern_map[variance_label]->get_shape().size();
            NGRAPH_DEBUG << "Mean: " << pattern_map[mean_label]->get_name() << " "
                         << pattern_map[mean_label]->get_shape().size();
            NGRAPH_DEBUG << "eps: " << pattern_map[eps_label]->get_name() << " "
                         << pattern_map[eps_label]->get_shape().size();
            NGRAPH_DEBUG << "gamma: " << pattern_map[gamma_label]->get_name() << " "
                         << pattern_map[gamma_label]->get_shape().size();
            NGRAPH_DEBUG << "beta: " << pattern_map[beta_label]->get_name() << " "
                         << pattern_map[beta_label]->get_shape().size();

            // don't fuse if the input doesn't have 4 dims
            if (pattern_map[input]->get_shape().size() != 4)
            {
                NGRAPH_DEBUG << "Input to bn does not have rank 4, so not fusing";
                return nn;
            }
            Shape bn_output_shape{m.match_root()->get_shape()};
            Shape m_bn_mean_shape{pattern_map[mean_label]->get_shape()};
            Shape m_bn_variance_shape{pattern_map[variance_label]->get_shape()};

            // get epsilon value
            auto eps_ptr = std::dynamic_pointer_cast<op::Constant>(pattern_map[eps_label]);
            double epsilon = *(reinterpret_cast<const double*>(eps_ptr->get_data_ptr()));
            auto bn_node = std::shared_ptr<Node>(new op::BatchNorm(epsilon,
                                                                   pattern_map[gamma_label],
                                                                   pattern_map[beta_label],
                                                                   pattern_map[input],
                                                                   pattern_map[mean_label],
                                                                   pattern_map[variance_label]));

            return bn_node;
        };

    auto m = std::make_shared<ngraph::pattern::Matcher>(add_beta, callback);
    this->add_matcher(m);
}
...@@ -38,11 +38,13 @@ public:
    CPUFusion()
        : GraphRewrite()
    {
        construct_matmul_pattern();
        construct_matmulbias_pattern();
        construct_fprop_bn();
    }

private:
    void construct_matmul_pattern();
    void construct_matmulbias_pattern();
    void construct_fprop_bn();
};
...@@ -133,6 +133,42 @@ TEST(cpu_fusion, gemm_cpu)
    ASSERT_TRUE(read_vector<float>(result) == expected);
}

TEST(cpu_fusion, gemm_cpu_no_bias)
{
    auto shapeA = Shape{3, 2};
    auto shapeB = Shape{2, 3};
    auto shapeC = Shape{2, 2};
    auto A = make_shared<op::Parameter>(element::f32, shapeA);
    auto B = make_shared<op::Parameter>(element::f32, shapeB);

    auto reshape_w = make_shared<op::Reshape>(A, AxisVector{1, 0}, Shape{2, 3});
    auto reshape_x = make_shared<op::Reshape>(B, AxisVector{1, 0}, Shape{3, 2});

    auto cg =
        make_shared<op::MatmulBias>(A, B, nullptr, A->get_shape(), B->get_shape(), true, true);

    auto f = make_shared<Function>(cg, op::ParameterVector{A, B});

    auto manager = runtime::Manager::get("CPU");
    auto external = manager->compile(f);
    auto backend = manager->allocate_backend();
    auto cf = backend->make_call_frame(external);

    shared_ptr<runtime::TensorView> a = backend->make_primary_tensor_view(element::f32, shapeA);
    shared_ptr<runtime::TensorView> b = backend->make_primary_tensor_view(element::f32, shapeB);
    shared_ptr<runtime::TensorView> result =
        backend->make_primary_tensor_view(element::f32, shapeC);

    vector<float> dataA{1.0f, 4.0f, 1.0f, 4.0f, 1.0f, 4.0f};
    vector<float> dataB{3.0f, 3.0f, 3.0f, 9.0f, 9.0f, 9.0f};
    copy_data(a, dataA);
    copy_data(b, dataB);

    cf->call({a, b}, {result});
    vector<float> expected{9, 27, 36, 108};
    ASSERT_TRUE(read_vector<float>(result) == expected);
}

TEST(cpu_fusion, cpu_fusion_pass_basic)
{
    Shape shape{};
...@@ -154,6 +190,50 @@ TEST(cpu_fusion, cpu_fusion_pass_basic)
    ASSERT_NE(std::dynamic_pointer_cast<op::MatmulBias>(graph->get_input_op(0)), nullptr);
}

TEST(cpu_fusion, cpu_fusion_pass_matmul_bias)
{
    Shape shape_w{2, 4};
    Shape shape_x{4, 1};
    Shape shape_b{1};

    auto W = make_shared<op::Parameter>(element::f32, shape_w);
    auto x = make_shared<op::Parameter>(element::f32, shape_x);
    auto b = make_shared<op::Parameter>(element::f32, shape_b);

    auto mmb = std::make_shared<op::MatmulBias>(
        W, x, nullptr, W->get_shape(), x->get_shape(), false, false);
    auto broadcast = std::make_shared<op::Broadcast>(b, mmb->get_shape(), AxisSet{0});
    auto add = mmb + broadcast;

    auto graph = make_shared<op::Abs>(add);

    pass::Manager pass_manager;
    pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
    auto func = make_shared<Function>(graph, op::ParameterVector{W, x, b});
    pass_manager.run_passes(func);

    auto gmm = graph->get_input_op(0);
    ASSERT_TRUE(std::dynamic_pointer_cast<op::MatmulBias>(gmm));
    ASSERT_EQ(gmm->get_input_op(2), broadcast);
}

TEST(cpu_fusion, cpu_fusion_pass_matmul_no_bias)
{
    Shape shape_w{4, 2};
    Shape shape_x{1, 4};
    auto W = make_shared<op::Parameter>(element::f32, shape_w);
    auto x = make_shared<op::Parameter>(element::f32, shape_x);

    auto reshape_w = std::make_shared<op::Reshape>(W, AxisVector{1, 0}, Shape{2, 4});
    auto reshape_x = std::make_shared<op::Reshape>(x, AxisVector{1, 0}, Shape{4, 1});
    auto re_dot = make_shared<op::Dot>(reshape_w, reshape_x);
    auto graph = make_shared<op::Abs>(re_dot);

    pass::Manager pass_manager;
    pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
    auto func = make_shared<Function>(graph, op::ParameterVector{W, x});
    pass_manager.run_passes(func);

    size_t mmb = count_ops_of_type<op::MatmulBias>(func);
    ASSERT_EQ(mmb, 1);
}

TEST(cpu_fusion, gemm_mlp)
{
    const string json_path = file_util::path_join(SERIALIZED_ZOO, "mxnet/mnist_mlp_forward.json");
...@@ -163,8 +243,8 @@ TEST(cpu_fusion, gemm_mlp)
    pass::Manager pass_manager;
    pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
    pass_manager.run_passes(func);
    size_t mmb = count_ops_of_type<op::MatmulBias>(func);
    ASSERT_EQ(mmb, 3);
}

//TODO: Move this test to backend_test.in.cpp once we have the INTERPRETER
...
/*******************************************************************************
 * Copyright 2018 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *******************************************************************************/

#include <algorithm>
#include <cstdio>
#include <iostream>
#include <list>
#include <memory>

#include "gtest/gtest.h"
#include "ngraph/file_util.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/ops/sum.hpp"
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/reshape_elimination.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/serializer.hpp"
#include "ngraph/util.hpp"
#include "nlohmann/json.hpp"
#include "util/matcher.hpp"
#include "util/test_tools.hpp"

using namespace ngraph;
using namespace std;

TEST(reshape_elimination, remove_reshape)
{
    pass::Manager pass_manager;
    pass_manager.register_pass<pass::ReshapeElimination>();
    const string json_path = file_util::path_join(SERIALIZED_ZOO, "mxnet/bn_fprop.json");
    const string json_string = file_util::read_file_to_string(json_path);
    stringstream ss(json_string);
    shared_ptr<Function> func = ngraph::deserialize(ss);
    size_t count_before = count_ops_of_type<op::Reshape>(func);
    pass_manager.run_passes(func);
    size_t count_after = count_ops_of_type<op::Reshape>(func);
    ASSERT_TRUE(count_after < count_before);
}

TEST(reshape_elimination, remove_tranpose)
{
    pass::Manager pass_manager;
    pass_manager.register_pass<pass::ReshapeElimination>();
    const string json_path = file_util::path_join(SERIALIZED_ZOO, "mxnet/tranpose.json");
    const string json_string = file_util::read_file_to_string(json_path);
    stringstream ss(json_string);
    shared_ptr<Function> func = ngraph::deserialize(ss);
    size_t count_before = count_ops_of_type<op::Reshape>(func);
    pass_manager.run_passes(func);
    size_t count_after = count_ops_of_type<op::Reshape>(func);
    ASSERT_TRUE(count_after < count_before);
}

TEST(reshape_elimination, bn_bprop_rewrite)
{
    pass::Manager pass_manager;
    pass_manager.register_pass<pass::ReshapeElimination>();
    const string json_path = file_util::path_join(SERIALIZED_ZOO, "mxnet/bn_bprop.json");
    const string json_string = file_util::read_file_to_string(json_path);
    stringstream ss(json_string);
    shared_ptr<Function> func = ngraph::deserialize(ss);
    size_t count_before = count_ops_of_type<op::Reshape>(func);
    pass_manager.run_passes(func);
    size_t count_after = count_ops_of_type<op::Reshape>(func);
    ASSERT_TRUE(count_after < count_before);
}

TEST(reshape_elimination, dot_transpose_to_dot_w_transpose_args)
{
    Shape shape_w{2, 4};
    Shape shape_x{4, 1};
    auto W = make_shared<op::Parameter>(element::f32, shape_w);
    auto x = make_shared<op::Parameter>(element::f32, shape_x);

    auto dot = make_shared<op::Dot>(W, x);
    auto reshape_dot = std::make_shared<op::Reshape>(dot, AxisVector{1, 0}, Shape{1, 2});
    auto graph = make_shared<op::Abs>(reshape_dot);

    pass::Manager pass_manager;
    pass_manager.register_pass<pass::ReshapeElimination>();
    auto func = make_shared<Function>(graph, op::ParameterVector{W, x});
    pass_manager.run_passes(func);

    auto gdot = graph->get_input_op(0);
    ASSERT_TRUE(std::dynamic_pointer_cast<op::Dot>(gdot));
    ASSERT_TRUE(std::dynamic_pointer_cast<op::Reshape>(gdot->get_input_op(0)));
    ASSERT_TRUE(std::dynamic_pointer_cast<op::Reshape>(gdot->get_input_op(1)));
    ASSERT_EQ(gdot->get_input_op(0)->get_input_op(0), x);
    ASSERT_EQ(gdot->get_input_op(1)->get_input_op(0), W);
    ASSERT_EQ(gdot->get_shape(), (Shape{1, 2}));
}