Commit e3b442aa authored by baojun's avatar baojun Committed by Scott Cyphers

set output type for dynshape (#4072)

parent 72493caf
...@@ -13,18 +13,19 @@ ...@@ -13,18 +13,19 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
#include "ngraph/frontend/fluid/operators/matmul.hpp"
#include <memory> #include <memory>
#include <numeric> #include <numeric>
#include "ngraph/builder/matmul_factory.hpp" #include "ngraph/builder/matmul_factory.hpp"
#include "ngraph/builder/reshape.hpp" #include "ngraph/builder/reshape.hpp"
#include "ngraph/frontend/fluid/operators/matmul.hpp"
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph::fluid;
constexpr NodeTypeInfo fluid::MatMul::type_info; constexpr NodeTypeInfo MatMul::type_info;
fluid::MatMul::MatMul(const Output<Node>& A, MatMul::MatMul(const Output<Node>& A,
const Output<Node>& B, const Output<Node>& B,
const bool& transpose_a, const bool& transpose_a,
const bool& transpose_b) const bool& transpose_b)
...@@ -39,6 +40,7 @@ template <class Input> ...@@ -39,6 +40,7 @@ template <class Input>
void DecomposeLogic(Input& input, bool transpose, bool reverse = false) void DecomposeLogic(Input& input, bool transpose, bool reverse = false)
{ {
auto rank = input.get_shape().size(); auto rank = input.get_shape().size();
if (rank < 2) if (rank < 2)
{ {
if (rank) if (rank)
...@@ -60,6 +62,7 @@ void DecomposeLogic(Input& input, bool transpose, bool reverse = false) ...@@ -60,6 +62,7 @@ void DecomposeLogic(Input& input, bool transpose, bool reverse = false)
} }
rank = 2; rank = 2;
} }
if (transpose) if (transpose)
{ {
vector<size_t> axes_order(rank); vector<size_t> axes_order(rank);
...@@ -75,48 +78,59 @@ inline NodeVector remove_1(std::shared_ptr<ngraph::Node> input_node) ...@@ -75,48 +78,59 @@ inline NodeVector remove_1(std::shared_ptr<ngraph::Node> input_node)
AxisVector axis(input_shape.size()); AxisVector axis(input_shape.size());
iota(axis.begin(), axis.end(), 0); iota(axis.begin(), axis.end(), 0);
Shape shape(input_shape.begin(), input_shape.end()); Shape shape(input_shape.begin(), input_shape.end());
auto b_remove = std::remove(shape.begin(), shape.end(), 1); auto b_remove = std::remove(shape.begin(), shape.end(), 1);
shape.erase(b_remove, shape.end()); shape.erase(b_remove, shape.end());
Output<Node> node(input_node); Output<Node> node(input_node);
auto reshape = make_shared<op::Reshape>(node, axis, shape); auto reshape = make_shared<op::Reshape>(node, axis, shape);
NodeVector final_vector{reshape}; NodeVector final_vector{reshape};
return final_vector; return final_vector;
} }
void fluid::MatMul::pre_validate_and_infer_types() void MatMul::pre_validate_and_infer_types()
{ {
element::Type input_element_type = get_input_element_type(0); element::Type input_element_type = get_input_element_type(0);
PartialShape pshape_A = get_input_partial_shape(0);
PartialShape pshape_B = get_input_partial_shape(1);
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
input_element_type.is_dynamic() || input_element_type.is_real(), input_element_type.is_dynamic() || input_element_type.is_real(),
"Argument element type must be f16, bf16, f32, f64 or dynamic (got ", "Argument element type must be f16, bf16, f32, f64 or dynamic (got ",
input_element_type, input_element_type,
")."); ").");
if (is_dynamic())
if (pshape_A.is_dynamic() || pshape_B.is_dynamic())
{ {
set_output_type(0, get_input_element_type(0), PartialShape::dynamic()); set_output_type(0, input_element_type, PartialShape::dynamic());
} }
} }
NodeVector fluid::MatMul::decompose_op() const NodeVector MatMul::decompose_op() const
{ {
auto A = input_value(0); auto A = input_value(0);
auto B = input_value(1); auto B = input_value(1);
DecomposeLogic(A, m_transpose_a); DecomposeLogic(A, m_transpose_a);
DecomposeLogic(B, m_transpose_b, true); DecomposeLogic(B, m_transpose_b, true);
builder::MatmulFactory factory({A, B}); builder::MatmulFactory factory({A, B});
auto node_vector_matmul = factory.make_matmul_op(); auto node_vector_matmul = factory.make_matmul_op();
auto first_item_node_vector = node_vector_matmul[0]; auto first_item_node_vector = node_vector_matmul[0];
auto b = first_item_node_vector->get_shape().begin(); auto b = first_item_node_vector->get_shape().begin();
auto e = first_item_node_vector->get_shape().end(); auto e = first_item_node_vector->get_shape().end();
auto it = std::find(b, e, 1); auto it = std::find(b, e, 1);
if (it != e) if (it != e)
{ {
node_vector_matmul = remove_1(first_item_node_vector); node_vector_matmul = remove_1(first_item_node_vector);
} }
return node_vector_matmul; return node_vector_matmul;
} }
shared_ptr<Node> fluid::MatMul::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> MatMul::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<MatMul>(new_args.at(0), new_args.at(1), m_transpose_a, m_transpose_b); return make_shared<MatMul>(new_args.at(0), new_args.at(1), m_transpose_a, m_transpose_b);
......
...@@ -20,18 +20,21 @@ ...@@ -20,18 +20,21 @@
#include "ngraph/op/op.hpp" #include "ngraph/op/op.hpp"
#include "ngraph/op/util/fused_op.hpp" #include "ngraph/op/util/fused_op.hpp"
using namespace std;
using namespace ngraph;
namespace ngraph namespace ngraph
{ {
namespace fluid namespace fluid
{ {
/// \brief Operator performing Matrix Multiplication. /// \brief Operator performing Matrix Multiplication.
class NGRAPH_API MatMul : public ngraph::op::util::FusedOp class NGRAPH_API MatMul : public op::util::FusedOp
{ {
public: public:
static constexpr NodeTypeInfo type_info{"MatMul", 0}; static constexpr NodeTypeInfo type_info{"MatMul", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; } const NodeTypeInfo& get_type_info() const override { return type_info; }
MatMul() = default; MatMul() = default;
/// \brief Constructs an ScaleShift operation. /// \brief Constructs a MatMul operation.
/// ///
/// \param A Matrix A /// \param A Matrix A
/// \param B Matrix B /// \param B Matrix B
...@@ -43,10 +46,10 @@ namespace ngraph ...@@ -43,10 +46,10 @@ namespace ngraph
const bool& transpose_b = 0); const bool& transpose_b = 0);
virtual NodeVector decompose_op() const override; virtual NodeVector decompose_op() const override;
void pre_validate_and_infer_types() override; void pre_validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
copy_with_new_args(const NodeVector& new_args) const override;
bool get_transpose_a() const { return m_transpose_a; } bool get_transpose_a() const { return m_transpose_a; }
bool get_transpose_b() const { return m_transpose_b; } bool get_transpose_b() const { return m_transpose_b; }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment