Commit 6bca3efd authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

Refactoring MMB (#1224)

* rank3xrank2 cpu_emitter version 1

* refactoring matmulbias

* add comment
parent 5927bbe4
......@@ -39,8 +39,8 @@ namespace ngraph
const ngraph::op::MatmulBias* mm = static_cast<const ngraph::op::MatmulBias*>(node);
const auto& arg0_shape = mm->get_arg0_shape();
const auto& arg1_shape = mm->get_arg1_shape();
const auto& arg0_shape = mm->get_a_shape();
const auto& arg1_shape = mm->get_b_shape();
const auto& arg2_shape = node->get_shape();
auto m = arg0_shape[0];
......@@ -51,14 +51,14 @@ namespace ngraph
auto lda = arg0_shape[1];
auto ldb = arg1_shape[1];
if (mm->get_is_arg0_transposed())
if (mm->get_is_a_transposed())
{
transpose_A = true;
m = arg0_shape[1];
k = arg0_shape[0];
}
if (mm->get_is_arg1_transposed())
if (mm->get_is_b_transposed())
{
transpose_B = true;
n = arg1_shape[0];
......
This diff is collapsed.
......@@ -32,6 +32,8 @@ namespace ngraph
bool get_is_a_transposed() const { return m_transpose_a; }
bool get_is_b_transposed() const { return m_transpose_b; }
Shape get_a_shape() const { return get_argument(0)->get_shape(); }
Shape get_b_shape() const { return get_argument(1)->get_shape(); }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
......@@ -35,10 +35,10 @@ namespace ngraph
bool transpose_x,
AxisSet axes = AxisSet{});
bool get_is_arg0_transposed() const { return m_transpose_w; }
bool get_is_arg1_transposed() const { return m_transpose_x; }
Shape get_arg0_shape() const { return m_shape_w; }
Shape get_arg1_shape() const { return m_shape_x; }
bool get_is_a_transposed() const { return m_transpose_w; }
bool get_is_b_transposed() const { return m_transpose_x; }
Shape get_a_shape() const { return m_shape_w; }
Shape get_b_shape() const { return m_shape_x; }
const AxisSet& get_broadcast_axes() const { return m_broadcast_axes; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
......@@ -153,10 +153,10 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_matmulbias()
auto mmb = std::make_shared<op::MatmulBias>(pattern_map[W],
pattern_map[x],
m_bias,
m_matmul->get_arg0_shape(),
m_matmul->get_arg1_shape(),
m_matmul->get_is_arg0_transposed(),
m_matmul->get_is_arg1_transposed(),
m_matmul->get_a_shape(),
m_matmul->get_b_shape(),
m_matmul->get_is_a_transposed(),
m_matmul->get_is_b_transposed(),
m_broadcast->get_broadcast_axes());
ngraph::replace_node(m.get_match_root(), mmb);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment