Commit d4a12feb authored by baojun, committed by Sang Ik Lee

Update fluid matmul fprop (#4099)

* update matmul fprop

* clean up
parent d2c6c27e
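
For context, the updated forward pass routes rank-1/2 inputs through op::Dot and higher-rank inputs through op::BatchMatMul after flattening them to 3-D, then squeezes unit dimensions out of the result. Below is a minimal, standalone sketch of that shape bookkeeping only; the helper name matmul_output_shape and the example shapes are illustrative assumptions, not part of this commit.

// Standalone sketch of the output-shape rule the new forward pass follows.
// matmul_output_shape is a hypothetical helper, not part of the commit.
#include <cstddef>
#include <iostream>
#include <utility>
#include <vector>

using Shape = std::vector<std::size_t>;

Shape matmul_output_shape(Shape x, Shape y, bool transpose_x, bool transpose_y)
{
    // Promote scalars/vectors to rank 2 so both operands have row/column dims.
    while (x.size() < 2) x.insert(x.begin(), 1);
    while (y.size() < 2) y.insert(y.begin(), 1);
    if (transpose_x) std::swap(x[x.size() - 1], x[x.size() - 2]);
    if (transpose_y) std::swap(y[y.size() - 1], y[y.size() - 2]);

    // Batch dims come from the higher-rank operand; the trailing two dims
    // follow the usual (M, K) x (K, N) -> (M, N) rule.
    Shape out = (x.size() >= y.size()) ? x : y;
    out[out.size() - 2] = x[x.size() - 2];
    out[out.size() - 1] = y[y.size() - 1];

    // Drop unit dimensions (except a leading one), mirroring the final
    // reshape in MatMul::decompose_op.
    for (std::size_t i = out.size() - 1; i > 0; i--)
        if (out[i] == 1) out.erase(out.begin() + i);
    return out;
}

int main()
{
    // Batched example: (2, 3, 4) x (4, 5) -> (2, 3, 5).
    for (std::size_t d : matmul_output_shape({2, 3, 4}, {4, 5}, false, false))
        std::cout << d << ' ';
    std::cout << '\n';
}
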
@@ -16,8 +16,6 @@
#include <memory>
#include <numeric>
#include "ngraph/builder/matmul_factory.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/frontend/fluid/operators/matmul.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/dot.hpp"
@@ -29,131 +27,6 @@
using namespace std;
using namespace ngraph::fluid;
constexpr NodeTypeInfo MatMul::type_info;

MatMul::MatMul(const Output<Node>& A,
               const Output<Node>& B,
               const bool transpose_a,
               const bool transpose_b)
    : FusedOp(OutputVector{A, B})
    , m_transpose_a{transpose_a}
    , m_transpose_b{transpose_b}
{
    constructor_validate_and_infer_types();
}

void decompose_logic(Output<Node>& input, bool transpose, bool reverse = false)
{
    auto rank = input.get_shape().size();
    if (rank < 2)
    {
        if (rank)
        {
            if (reverse)
            {
                input =
                    make_shared<op::Reshape>(input, AxisVector{0}, Shape{input.get_shape()[0], 1});
            }
            else
            {
                input =
                    make_shared<op::Reshape>(input, AxisVector{0}, Shape{1, input.get_shape()[0]});
            }
        }
        else
        {
            input = make_shared<op::Reshape>(input, AxisVector{}, Shape{1, 1});
        }
        rank = 2;
    }
    if (transpose)
    {
        vector<size_t> axes_order(rank);
        iota(axes_order.begin(), axes_order.end(), 0);
        swap(axes_order[rank - 1], axes_order[rank - 2]);
        input = builder::reorder_axes(input, axes_order);
    }
}

NodeVector remove_1(shared_ptr<Node> input_node)
{
    auto input_shape = input_node->get_shape();
    AxisVector axis(input_shape.size());
    iota(axis.begin(), axis.end(), 0);
    Shape shape(input_shape.begin(), input_shape.end());
    auto b_remove = remove(shape.begin(), shape.end(), 1);
    shape.erase(b_remove, shape.end());
    Output<Node> node(input_node);
    auto reshape = make_shared<op::Reshape>(node, axis, shape);
    NodeVector final_vector{reshape};
    return final_vector;
}

void MatMul::pre_validate_and_infer_types()
{
    element::Type input_element_type = get_input_element_type(0);
    PartialShape pshape_A = get_input_partial_shape(0);
    PartialShape pshape_B = get_input_partial_shape(1);

    NODE_VALIDATION_CHECK(this,
                          input_element_type.is_dynamic() || input_element_type.is_real(),
                          "Argument element type must be f16, bf16, f32, f64 or dynamic (got ",
                          input_element_type,
                          ").");

    if (pshape_A.is_dynamic() || pshape_B.is_dynamic())
    {
        set_output_type(0, input_element_type, PartialShape::dynamic());
    }
}

NodeVector MatMul::decompose_op() const
{
    auto A = input_value(0);
    auto B = input_value(1);

    decompose_logic(A, m_transpose_a);
    decompose_logic(B, m_transpose_b, true);

    builder::MatmulFactory factory({A, B});
    auto node_vector_matmul = factory.make_matmul_op();
    auto first_item_node_vector = node_vector_matmul[0];
    auto b = first_item_node_vector->get_shape().begin();
    auto e = first_item_node_vector->get_shape().end();
    auto it = find(b, e, 1);

    if (it != e)
    {
        node_vector_matmul = remove_1(first_item_node_vector);
    }

    return node_vector_matmul;
}

shared_ptr<Node> MatMul::copy_with_new_args(const NodeVector& new_args) const
{
    check_new_args_count(this, new_args);
    return make_shared<MatMul>(new_args.at(0), new_args.at(1), m_transpose_a, m_transpose_b);
}

constexpr NodeTypeInfo MatMulGrad::type_info;

MatMulGrad::MatMulGrad(const Output<Node>& A,
                       const Output<Node>& B,
                       const Output<Node>& Out,
                       const bool transpose_a,
                       const bool transpose_b)
    : FusedOp(OutputVector{A, B, Out})
    , m_transpose_a{transpose_a}
    , m_transpose_b{transpose_b}
{
    constructor_validate_and_infer_types();
}

shared_ptr<Node> broadcast_to_3d(const shared_ptr<Node>& input, size_t axis0)
{
    auto shape = input->get_shape();
@@ -249,6 +122,116 @@ shared_ptr<Node> reshape_to_original(shared_ptr<Node> input, const Shape& shape)
    return make_shared<op::Reshape>(input, get_default_order(input_shape), shape);
}

constexpr NodeTypeInfo MatMul::type_info;

MatMul::MatMul(const Output<Node>& A,
               const Output<Node>& B,
               const bool transpose_a,
               const bool transpose_b)
    : FusedOp(OutputVector{A, B})
    , m_transpose_a{transpose_a}
    , m_transpose_b{transpose_b}
{
    constructor_validate_and_infer_types();
}

void MatMul::pre_validate_and_infer_types()
{
    element::Type input_element_type = get_input_element_type(0);
    PartialShape pshape_A = get_input_partial_shape(0);
    PartialShape pshape_B = get_input_partial_shape(1);

    NODE_VALIDATION_CHECK(this,
                          input_element_type.is_dynamic() || input_element_type.is_real(),
                          "Argument element type must be f16, bf16, f32, f64 or dynamic (got ",
                          input_element_type,
                          ").");

    if (pshape_A.is_dynamic() || pshape_B.is_dynamic())
    {
        set_output_type(0, input_element_type, PartialShape::dynamic());
    }
}

NodeVector MatMul::decompose_op() const
{
    auto x = input_value(0).get_node_shared_ptr();
    auto y = input_value(1).get_node_shared_ptr();

    auto x_shape = x->get_shape();
    auto y_shape = y->get_shape();

    size_t nx = x_shape.size();
    size_t ny = y_shape.size();

    x = transpose_and_flatten3d(x, m_transpose_a, true);
    y = transpose_and_flatten3d(y, m_transpose_b, false);

    auto y_shape3 = y->get_shape();
    auto x_shape3 = x->get_shape();

    shared_ptr<Node> out;
    Shape out_shape;

    if (nx > 2 || ny > 2)
    {
        Shape out_shape = x_shape;
        if (nx != 3)
        {
            x = broadcast_to_3d(x, y_shape3[0]);
            out_shape = y_shape;
        }
        if (ny != 3)
        {
            y = broadcast_to_3d(y, x_shape3[0]);
            out_shape = x_shape;
        }
        auto nout = out_shape.size();
        auto out3 = make_shared<op::BatchMatMul>(x, y);
        auto out3_shape = out3->get_shape();
        out_shape[nout - 1] = out3_shape[2];
        out_shape[nout - 2] = out3_shape[1];

        out = make_shared<op::Reshape>(out3, AxisVector{0, 1, 2}, out_shape);
    }
    else
    {
        out = make_shared<op::Dot>(x, y);
    }

    out_shape = out->get_shape();
    auto axis_vector = get_default_order(out_shape);
    for (size_t i = out_shape.size() - 1; i > 0; i--)
    {
        if (out_shape[i] == 1)
        {
            out_shape.erase(out_shape.begin() + i);
        }
    }

    auto out_reshaped = make_shared<op::Reshape>(out, axis_vector, out_shape);

    return {out_reshaped};
}

shared_ptr<Node> MatMul::copy_with_new_args(const NodeVector& new_args) const
{
    check_new_args_count(this, new_args);
    return make_shared<MatMul>(new_args.at(0), new_args.at(1), m_transpose_a, m_transpose_b);
}

constexpr NodeTypeInfo MatMulGrad::type_info;

MatMulGrad::MatMulGrad(const Output<Node>& A,
                       const Output<Node>& B,
                       const Output<Node>& Out,
                       const bool transpose_a,
                       const bool transpose_b)
    : FusedOp(OutputVector{A, B, Out})
    , m_transpose_a{transpose_a}
    , m_transpose_b{transpose_b}
{
    constructor_validate_and_infer_types();
}

void MatMulGrad::pre_validate_and_infer_types()
{
    element::Type input_element_type = get_input_element_type(0);
......