Commit e3b442aa authored by baojun's avatar baojun Committed by Scott Cyphers

set output type for dynshape (#4072)

parent 72493caf
......@@ -13,18 +13,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/frontend/fluid/operators/matmul.hpp"
#include <memory>
#include <numeric>
#include "ngraph/builder/matmul_factory.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/frontend/fluid/operators/matmul.hpp"
#include "ngraph/op/reshape.hpp"
using namespace std;
using namespace ngraph;
using namespace ngraph::fluid;
constexpr NodeTypeInfo fluid::MatMul::type_info;
fluid::MatMul::MatMul(const Output<Node>& A,
constexpr NodeTypeInfo MatMul::type_info;
MatMul::MatMul(const Output<Node>& A,
const Output<Node>& B,
const bool& transpose_a,
const bool& transpose_b)
......@@ -39,6 +40,7 @@ template <class Input>
void DecomposeLogic(Input& input, bool transpose, bool reverse = false)
{
auto rank = input.get_shape().size();
if (rank < 2)
{
if (rank)
......@@ -60,6 +62,7 @@ void DecomposeLogic(Input& input, bool transpose, bool reverse = false)
}
rank = 2;
}
if (transpose)
{
vector<size_t> axes_order(rank);
......@@ -75,48 +78,59 @@ inline NodeVector remove_1(std::shared_ptr<ngraph::Node> input_node)
AxisVector axis(input_shape.size());
iota(axis.begin(), axis.end(), 0);
Shape shape(input_shape.begin(), input_shape.end());
auto b_remove = std::remove(shape.begin(), shape.end(), 1);
shape.erase(b_remove, shape.end());
Output<Node> node(input_node);
auto reshape = make_shared<op::Reshape>(node, axis, shape);
NodeVector final_vector{reshape};
return final_vector;
}
void fluid::MatMul::pre_validate_and_infer_types()
void MatMul::pre_validate_and_infer_types()
{
element::Type input_element_type = get_input_element_type(0);
PartialShape pshape_A = get_input_partial_shape(0);
PartialShape pshape_B = get_input_partial_shape(1);
NODE_VALIDATION_CHECK(this,
input_element_type.is_dynamic() || input_element_type.is_real(),
"Argument element type must be f16, bf16, f32, f64 or dynamic (got ",
input_element_type,
").");
if (is_dynamic())
if (pshape_A.is_dynamic() || pshape_B.is_dynamic())
{
set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
set_output_type(0, input_element_type, PartialShape::dynamic());
}
}
NodeVector fluid::MatMul::decompose_op() const
NodeVector MatMul::decompose_op() const
{
auto A = input_value(0);
auto B = input_value(1);
DecomposeLogic(A, m_transpose_a);
DecomposeLogic(B, m_transpose_b, true);
builder::MatmulFactory factory({A, B});
auto node_vector_matmul = factory.make_matmul_op();
auto first_item_node_vector = node_vector_matmul[0];
auto b = first_item_node_vector->get_shape().begin();
auto e = first_item_node_vector->get_shape().end();
auto it = std::find(b, e, 1);
if (it != e)
{
node_vector_matmul = remove_1(first_item_node_vector);
}
return node_vector_matmul;
}
shared_ptr<Node> fluid::MatMul::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> MatMul::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<MatMul>(new_args.at(0), new_args.at(1), m_transpose_a, m_transpose_b);
......
......@@ -20,18 +20,21 @@
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/fused_op.hpp"
using namespace std;
using namespace ngraph;
namespace ngraph
{
namespace fluid
{
/// \brief Operator performing Matrix Multiplication.
class NGRAPH_API MatMul : public ngraph::op::util::FusedOp
class NGRAPH_API MatMul : public op::util::FusedOp
{
public:
static constexpr NodeTypeInfo type_info{"MatMul", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
MatMul() = default;
/// \brief Constructs an ScaleShift operation.
/// \brief Constructs a MatMul operation.
///
/// \param A Matrix A
/// \param B Matrix B
......@@ -43,10 +46,10 @@ namespace ngraph
const bool& transpose_b = 0);
virtual NodeVector decompose_op() const override;
void pre_validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
bool get_transpose_a() const { return m_transpose_a; }
bool get_transpose_b() const { return m_transpose_b; }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment