Commit 6bca3efd, authored by Nick Korovaiko and committed by Scott Cyphers

Refactoring MMB (#1224)

* rank3xrank2 cpu_emitter version 1

* refactoring matmulbias

* add comment
parent 5927bbe4
...@@ -39,8 +39,8 @@ namespace ngraph ...@@ -39,8 +39,8 @@ namespace ngraph
const ngraph::op::MatmulBias* mm = static_cast<const ngraph::op::MatmulBias*>(node); const ngraph::op::MatmulBias* mm = static_cast<const ngraph::op::MatmulBias*>(node);
const auto& arg0_shape = mm->get_arg0_shape(); const auto& arg0_shape = mm->get_a_shape();
const auto& arg1_shape = mm->get_arg1_shape(); const auto& arg1_shape = mm->get_b_shape();
const auto& arg2_shape = node->get_shape(); const auto& arg2_shape = node->get_shape();
auto m = arg0_shape[0]; auto m = arg0_shape[0];
...@@ -51,14 +51,14 @@ namespace ngraph ...@@ -51,14 +51,14 @@ namespace ngraph
auto lda = arg0_shape[1]; auto lda = arg0_shape[1];
auto ldb = arg1_shape[1]; auto ldb = arg1_shape[1];
if (mm->get_is_arg0_transposed()) if (mm->get_is_a_transposed())
{ {
transpose_A = true; transpose_A = true;
m = arg0_shape[1]; m = arg0_shape[1];
k = arg0_shape[0]; k = arg0_shape[0];
} }
if (mm->get_is_arg1_transposed()) if (mm->get_is_b_transposed())
{ {
transpose_B = true; transpose_B = true;
n = arg1_shape[0]; n = arg1_shape[0];
......
This diff is collapsed.
...@@ -32,6 +32,8 @@ namespace ngraph ...@@ -32,6 +32,8 @@ namespace ngraph
bool get_is_a_transposed() const { return m_transpose_a; } bool get_is_a_transposed() const { return m_transpose_a; }
bool get_is_b_transposed() const { return m_transpose_b; } bool get_is_b_transposed() const { return m_transpose_b; }
Shape get_a_shape() const { return get_argument(0)->get_shape(); }
Shape get_b_shape() const { return get_argument(1)->get_shape(); }
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
...@@ -35,10 +35,10 @@ namespace ngraph ...@@ -35,10 +35,10 @@ namespace ngraph
bool transpose_x, bool transpose_x,
AxisSet axes = AxisSet{}); AxisSet axes = AxisSet{});
bool get_is_arg0_transposed() const { return m_transpose_w; } bool get_is_a_transposed() const { return m_transpose_w; }
bool get_is_arg1_transposed() const { return m_transpose_x; } bool get_is_b_transposed() const { return m_transpose_x; }
Shape get_arg0_shape() const { return m_shape_w; } Shape get_a_shape() const { return m_shape_w; }
Shape get_arg1_shape() const { return m_shape_x; } Shape get_b_shape() const { return m_shape_x; }
const AxisSet& get_broadcast_axes() const { return m_broadcast_axes; } const AxisSet& get_broadcast_axes() const { return m_broadcast_axes; }
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
...@@ -153,10 +153,10 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_matmulbias() ...@@ -153,10 +153,10 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_matmulbias()
auto mmb = std::make_shared<op::MatmulBias>(pattern_map[W], auto mmb = std::make_shared<op::MatmulBias>(pattern_map[W],
pattern_map[x], pattern_map[x],
m_bias, m_bias,
m_matmul->get_arg0_shape(), m_matmul->get_a_shape(),
m_matmul->get_arg1_shape(), m_matmul->get_b_shape(),
m_matmul->get_is_arg0_transposed(), m_matmul->get_is_a_transposed(),
m_matmul->get_is_arg1_transposed(), m_matmul->get_is_b_transposed(),
m_broadcast->get_broadcast_axes()); m_broadcast->get_broadcast_axes());
ngraph::replace_node(m.get_match_root(), mmb); ngraph::replace_node(m.get_match_root(), mmb);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment