Commit 7df89837 authored by Tomasz Socha, committed by Scott Cyphers

[FUSED] Add new MatMul fused operator (#3330)

* Move required reshape helpers to builders

* Remove warning from matmul_factory

* Get rid of onnx dependency from matmul factory

* Move matmul_factory into builders

* Add implementation of fused MatMul

* Add MatMul to serializer

* Add type_prop tests

* Remove reference.

* Make more methods private

* Use protected instead of private

* Fix compilation issues

* Change construction of matmul_factory in matmul op

* Add MatMul operator into switch of is_supported_impl function

* Change transpose flags from int to bool

* Review Fix I

* Update MatMul op

* Use OutputVector instead of NodeVector in MatmulFactories

* Fix usage of OutputVector

* Convert more shared_ptrs to Outputs

* Fix comments after merge

* Fix comments after merge II

* Fix comments after merge III
parent 22af2395
......@@ -27,6 +27,8 @@ set (SRC
builder/dequantize_builder.cpp
builder/dequantize_builder.hpp
builder/make_constant.hpp
builder/matmul_factory.cpp
builder/matmul_factory.hpp
builder/norm.cpp
builder/norm.hpp
builder/numpy_transpose.cpp
......@@ -329,6 +331,8 @@ set (SRC
op/fused/gru_cell.hpp
op/fused/lstm_cell.cpp
op/fused/lstm_cell.hpp
op/fused/matmul.cpp
op/fused/matmul.hpp
op/fused/mvn.cpp
op/fused/mvn.hpp
op/fused/normalize_l2.cpp
......
......@@ -18,17 +18,18 @@
#include <iterator>
#include <memory>
#include "matmul_factory.hpp"
#include "ngraph/builder/make_constant.hpp"
#include "ngraph/builder/matmul_factory.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/quantized_dot.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "utils/reshape.hpp"
using namespace ngraph::onnx_import::matmul;
using namespace ngraph;
using namespace std;
/// \brief Slice a sub-matrix from the input tensor.
///
......@@ -37,59 +38,49 @@ using namespace ngraph::onnx_import::matmul;
///
/// \return The node representing the sub-matrix.
///
static std::shared_ptr<ngraph::Node> get_sub_matrix(const std::shared_ptr<ngraph::Node>& node,
std::size_t idx)
static Output<Node> get_sub_matrix(const Output<Node>& node, size_t idx)
{
const ngraph::Shape& shape{node->get_shape()};
const Shape& shape{node.get_shape()};
if (shape.size() < 3)
{
return node;
return node.get_node_shared_ptr();
}
// The bounds below define the sub_matrix through ranges over each input node axis.
ngraph::Coordinate lower_bounds(shape.size());
ngraph::Coordinate upper_bounds = shape;
Coordinate lower_bounds(shape.size());
Coordinate upper_bounds = shape;
// We assume the `node` tensor is of rank 3, thus we slice the sub-matrix lying in the
// last two dimensions at index `idx` of the first axis.
lower_bounds.at(0) = idx;
upper_bounds.at(0) = idx + 1;
auto sub_matrix = std::shared_ptr<ngraph::Node>{
std::make_shared<ngraph::op::Slice>(node, lower_bounds, upper_bounds)};
auto sub_matrix = Output<Node>{make_shared<op::Slice>(node, lower_bounds, upper_bounds)};
// Remove first single entry dim.
return ngraph::onnx_import::reshape::squeeze(sub_matrix);
return builder::squeeze(sub_matrix);
}
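A minimal sketch of what get_sub_matrix yields for a rank-3 input (the parameter shape here is an assumption chosen for illustration):
auto batch = make_shared<op::Parameter>(element::f32, Shape{4, 2, 3});
auto sub = get_sub_matrix(batch, 1);
// op::Slice cuts out Shape{1, 2, 3}; builder::squeeze then drops the
// leading singleton axis, so `sub` has Shape{2, 3}.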
std::shared_ptr<ngraph::Node> MatmulFactory::get_left()
Output<Node> builder::MatmulFactory::get_left()
{
return m_inputs.at(0);
}
std::shared_ptr<ngraph::Node> MatmulFactory::get_right()
Output<Node> builder::MatmulFactory::get_right()
{
return m_inputs.at(1);
}
ngraph::NodeVector MatmulFactory::make_matmul_op()
NodeVector builder::MatmulFactory::make_matmul_op()
{
auto left = get_left();
auto right = get_right();
std::size_t left_rank{left->get_shape().size()};
std::size_t right_rank{right->get_shape().size()};
if (left_rank == 0 || right_rank == 0)
{
NGRAPH_WARN << (m_onnx_node) << " "
<< "ONNX standard doesn't allow scalar operands, however nGraph "
"accepts them. Consider use of element-wise multiplication instead "
"to conform with ONNX standard.";
}
size_t left_rank{left.get_shape().size()};
size_t right_rank{right.get_shape().size()};
// First (easy) case, already handled internally by the nGraph Dot operator:
// multiply two tensors where both have rank lower than or equal to 2.
if (left_rank <= 2 && right_rank <= 2)
{
return NodeVector{make_dot(left, right)};
return {make_dot(left, right).get_node_shared_ptr()};
}
// Second case:
......@@ -98,37 +89,37 @@ ngraph::NodeVector MatmulFactory::make_matmul_op()
// Broadcast the input arguments only if neither of them is a vector.
if (left_rank > 1 && right_rank > 1)
{
const NodeVector& broadcasted_nodes =
ngraph::op::numpy_style_broadcast_for_matmul_operation(left, right);
const NodeVector& broadcasted_nodes = op::numpy_style_broadcast_for_matmul_operation(
left.get_node_shared_ptr(), right.get_node_shared_ptr());
left = broadcasted_nodes.at(0);
right = broadcasted_nodes.at(1);
}
const auto& left_shape = left->get_shape();
const auto& right_shape = right->get_shape();
const auto& left_shape = left.get_shape();
const auto& right_shape = right.get_shape();
// Collapse both tensors' _stack of matrices_ axes (all except the last two).
// This simplifies the subsequent dot product calculations.
if (left_shape.size() > 3)
{
left = onnx_import::reshape::collapse(left, 0, left_shape.size() - 3);
left = builder::collapse(left, 0, left_shape.size() - 3);
}
if (right_shape.size() > 3)
{
right = onnx_import::reshape::collapse(right, 0, right_shape.size() - 3);
right = builder::collapse(right, 0, right_shape.size() - 3);
}
// Perform multiple small dot products
std::size_t groups = left->get_shape().at(0);
size_t groups = left.get_shape().at(0);
// If we haven't broadcast earlier, this means that one of the inputs is a vector,
// thus the number of groups is defined by the shape of the bigger tensor.
if (right->get_shape().size() > left->get_shape().size())
if (right.get_shape().size() > left.get_shape().size())
{
groups = right->get_shape().at(0);
groups = right.get_shape().at(0);
}
NodeVector small_dots(groups);
for (std::size_t g = 0; g < groups; ++g)
for (size_t g = 0; g < groups; ++g)
{
const auto sliced_left = get_sub_matrix(left, g);
const auto sliced_right = get_sub_matrix(right, g);
......@@ -136,11 +127,11 @@ ngraph::NodeVector MatmulFactory::make_matmul_op()
// Expand the sub_dot result with a single empty outermost axis, in order to
// later concatenate the sub_dots along this axis.
small_dots.at(g) = onnx_import::reshape::expand_dims(sub_dot);
small_dots.at(g) = builder::expand_dims(sub_dot);
}
// Concatenate sub_dots on groups axis.
auto result = std::make_shared<ngraph::op::Concat>(small_dots, 0);
auto result = make_shared<op::Concat>(small_dots, 0);
if (left_shape.size() <= 3 && right_shape.size() <= 3)
{
......@@ -150,39 +141,35 @@ ngraph::NodeVector MatmulFactory::make_matmul_op()
else
{
const Shape& shape{result->get_shape()};
Shape result_shape(std::next(std::begin(shape)), std::end(shape));
result_shape.insert(std::begin(result_shape),
std::begin(left_shape),
std::next(std::begin(left_shape), left_shape.size() - 2));
return {std::make_shared<ngraph::op::Reshape>(
result, ngraph::get_default_order(shape.size()), result_shape)};
Shape result_shape(next(begin(shape)), end(shape));
result_shape.insert(
begin(result_shape), begin(left_shape), next(begin(left_shape), left_shape.size() - 2));
return {make_shared<op::Reshape>(result, get_default_order(shape.size()), result_shape)};
}
}
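A hedged walkthrough of make_matmul_op for two rank-3 inputs (shapes assumed for illustration):
// left: Shape{2, 3, 4}, right: Shape{2, 4, 5} -> groups = 2
// For each g: get_sub_matrix slices Shape{3, 4} and Shape{4, 5},
// make_dot yields Shape{3, 5}, and builder::expand_dims prepends an
// axis, giving Shape{1, 3, 5}. op::Concat on axis 0 then restores the
// groups axis, so the final output has Shape{2, 3, 5}.
builder::MatmulFactory factory(OutputVector{left, right});
NodeVector result = factory.make_matmul_op();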
std::shared_ptr<ngraph::Node> MatmulFactory::make_dot(const std::shared_ptr<ngraph::Node>& left,
const std::shared_ptr<ngraph::Node>& right)
Output<Node> builder::MatmulFactory::make_dot(const Output<Node>& left, const Output<Node>& right)
{
return std::make_shared<ngraph::op::Dot>(left, right);
return make_shared<op::Dot>(left, right);
}
std::shared_ptr<ngraph::Node> QLinearMatmulFactory::get_right()
Output<Node> builder::QLinearMatmulFactory::get_right()
{
return m_inputs.at(3);
}
std::shared_ptr<ngraph::Node>
QLinearMatmulFactory::make_dot(const std::shared_ptr<ngraph::Node>& left,
const std::shared_ptr<ngraph::Node>& right)
Output<Node> builder::QLinearMatmulFactory::make_dot(const Output<Node>& left,
const Output<Node>& right)
{
ngraph::element::Type output_type;
if (left->get_element_type() == ngraph::element::u8 &&
right->get_element_type() == ngraph::element::i8)
if (left.get_element_type() == ngraph::element::u8 &&
right.get_element_type() == ngraph::element::i8)
{
output_type = ngraph::element::i8;
}
else if (left->get_element_type() == ngraph::element::u8 &&
right->get_element_type() == ngraph::element::u8)
else if (left.get_element_type() == ngraph::element::u8 &&
right.get_element_type() == ngraph::element::u8)
{
output_type = ngraph::element::u8;
}
......@@ -202,15 +189,14 @@ std::shared_ptr<ngraph::Node>
ngraph::AxisSet{});
}
std::shared_ptr<ngraph::Node>
MatmulIntegerFactory::make_dot(const std::shared_ptr<ngraph::Node>& left,
const std::shared_ptr<ngraph::Node>& right)
Output<Node> builder::MatmulIntegerFactory::make_dot(const Output<Node>& left,
const Output<Node>& right)
{
auto num_inputs = m_inputs.size();
auto scale_one = ngraph::builder::make_constant(ngraph::element::f32, Shape{}, 1);
auto output_zero_point = ngraph::builder::make_constant(ngraph::element::i32, Shape{}, 0);
auto left_zero_point = ngraph::builder::make_constant(left->get_element_type(), Shape{}, 0);
auto right_zero_point = ngraph::builder::make_constant(right->get_element_type(), Shape{}, 0);
auto left_zero_point = ngraph::builder::make_constant(left.get_element_type(), Shape{}, 0);
auto right_zero_point = ngraph::builder::make_constant(right.get_element_type(), Shape{}, 0);
if (num_inputs == 2)
{
return std::make_shared<ngraph::op::QuantizedDot>(left,
......@@ -228,10 +214,10 @@ std::shared_ptr<ngraph::Node>
ngraph::AxisSet{});
}
left_zero_point = m_inputs.at(2);
left_zero_point = m_inputs.at(2).get_node_shared_ptr();
if (num_inputs == 4)
{
right_zero_point = m_inputs.at(3);
right_zero_point = m_inputs.at(3).get_node_shared_ptr();
}
return std::make_shared<ngraph::op::QuantizedDot>(left,
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node.hpp"
namespace ngraph
{
namespace builder
{
/// \brief Factory class which generates an nGraph sub-graph performing MatMul operation.
///
/// This default implementation `MatmulFactory` creates a `MatMul` operation for
/// floating-point data.
/// Subclasses: `QLinearMatmulFactory` and `MatmulIntegerFactory` implement quantized
/// versions.
class MatmulFactory
{
public:
explicit MatmulFactory(const OutputVector& inputs)
: m_inputs(inputs)
{
}
virtual ~MatmulFactory() = default;
/// \brief Create a sub-graph representing an ONNX MatMul operation.
///
/// \return NodeVector containing the sub-graph output node.
virtual NodeVector make_matmul_op();
protected:
/// \return Output representing the left operand.
virtual Output<Node> get_left();
/// \return Output representing the right operand.
virtual Output<Node> get_right();
/// \return Output representing the nGraph Dot operation used to construct MatMul.
virtual Output<Node> make_dot(const Output<Node>& left, const Output<Node>& right);
const OutputVector m_inputs;
};
/// \brief Factory class which generates an nGraph sub-graph based on an ONNX QLinearMatMul
/// operation.
class QLinearMatmulFactory : public MatmulFactory
{
public:
explicit QLinearMatmulFactory(const OutputVector& inputs)
: MatmulFactory(inputs)
{
}
protected:
Output<Node> get_right() override;
Output<Node> make_dot(const Output<Node>& left, const Output<Node>& right) override;
};
/// \brief Factory class which generates an nGraph sub-graph based on an ONNX MatMulInteger
/// operation.
class MatmulIntegerFactory : public MatmulFactory
{
public:
explicit MatmulIntegerFactory(const OutputVector& inputs)
: MatmulFactory(inputs)
{
}
protected:
Output<Node> make_dot(const Output<Node>& left, const Output<Node>& right) override;
};
} // namespace builder
} // namespace ngraph
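For reference, the expected input layouts, inferred from the accessors above together with the corresponding ONNX operator definitions (the orderings below come from the ONNX spec, not from anything this header enforces):
// MatmulFactory         inputs: {A, B}                       get_right() = at(1)
// QLinearMatmulFactory  inputs (ONNX QLinearMatMul order):
//   {a, a_scale, a_zero_point, b, b_scale, b_zero_point,
//    y_scale, y_zero_point}                                  get_right() = at(3)
// MatmulIntegerFactory  inputs (ONNX MatMulInteger order):
//   {A, B, a_zero_point, b_zero_point}                       zero points at(2)/at(3)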
......@@ -120,3 +120,51 @@ shared_ptr<Node> builder::flatten(const Output<Node>& value, const Output<Node>&
// result := DynReshape(value, flattened_dims)
return make_shared<op::DynReshape>(value, flattened_dims);
}
shared_ptr<Node> builder::squeeze(const Output<Node>& value, vector<size_t> axes)
{
if (axes.empty())
{
return value.get_node_shared_ptr();
}
Shape in_shape{value.get_shape()};
for (size_t idx = 0; idx < axes.size(); ++idx)
{
in_shape.at(axes.at(idx)) = 0; // mark each axis listed in `axes` for removal
}
Shape output_shape;
for (auto axis : in_shape)
{
if (axis != 0)
{
output_shape.push_back(axis);
}
}
return builder::reshape(value, output_shape);
}
shared_ptr<Node>
builder::collapse(const Output<Node>& value, const size_t start_axis, const size_t end_axis)
{
auto shape = value.get_shape();
size_t collapsed_axis_size = accumulate(next(begin(shape), start_axis),
next(begin(shape), end_axis + 1),
1UL,
multiplies<size_t>());
Shape output_shape{collapsed_axis_size};
output_shape.insert(end(output_shape), next(begin(shape), end_axis + 1), end(shape));
return builder::reshape(value, output_shape);
}
shared_ptr<Node> builder::expand_dims(const Output<Node>& value, size_t axis)
{
Shape output_shape(value.get_shape());
// Add empty axis at specified position.
auto empty_axis_it = begin(output_shape);
advance(empty_axis_it, axis);
output_shape.insert(empty_axis_it, 1);
return make_shared<op::Reshape>(
value, get_default_order(value.get_shape().size()), output_shape);
}
......@@ -55,7 +55,7 @@ namespace ngraph
/// \brief Flatten a value into a 2D matrix, with a static dividing axis.
///
/// \param value The tensor to be flattened.
/// \param axis The axis dividing the shape: dimensions before it are flattened
/// into the first output dimension, the remaining ones into the second.
///
/// \return The new value will be a 2D matrix representing the flattened input node.
std::shared_ptr<Node> flatten(const Output<Node>& value, int axis);
......@@ -68,5 +68,40 @@ namespace ngraph
///
/// \return The new value will be a 2D matrix representing the flattened input node.
std::shared_ptr<Node> flatten(const Output<Node>& value, const Output<Node>& axis);
/// \brief Remove empty axes from input tensor.
///
/// \param[in] value The value to be squeezed.
/// \param[in] axes The vector defining indexes of axes to be removed.
///
/// \return The squeezed node.
///
std::shared_ptr<Node> squeeze(const Output<Node>& value,
std::vector<std::size_t> axes = {0});
/// \brief Collapse specified axes into single one.
///
/// \note Collapsed axes form a contiguous range starting from the outermost axis.
///
/// \param[in] value The value to be reshaped.
/// \param[in] start_axis The start axis index.
/// \param[in] end_axis The end axis (inclusive) index.
///
/// \return The node with collapsed specified axes.
///
std::shared_ptr<Node> collapse(const Output<Node>& value,
const std::size_t start_axis,
const std::size_t end_axis);
/// \brief Expands node tensor shape with empty axis at
/// specified position.
///
/// \param[in] value The value to be expanded.
/// \param[in] axis The position in the expanded axes where the
/// new axis is placed.
///
/// \return The node with added empty axis.
///
std::shared_ptr<Node> expand_dims(const Output<Node>& value, std::size_t axis = 0);
} // namespace builder
} // namespace ngraph
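A small sketch of the shape effects of the three new helpers (the parameter shape is an assumption for illustration):
auto v = make_shared<op::Parameter>(element::f32, Shape{1, 2, 1, 3});
auto s = builder::squeeze(v, {0, 2});  // Shape{2, 3}: axes 0 and 2 removed
auto c = builder::collapse(v, 0, 2);   // Shape{2, 3}: axes 0..2 merged (1*2*1 = 2)
auto e = builder::expand_dims(s, 1);   // Shape{2, 1, 3}: empty axis inserted at 1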
......@@ -204,8 +204,6 @@ add_library(onnx_import STATIC
utils/common.hpp
utils/convpool.cpp
utils/convpool.hpp
utils/matmul_factory.cpp
utils/matmul_factory.hpp
utils/pooling_factory.cpp
utils/pooling_factory.hpp
utils/reduction.cpp
......
......@@ -26,6 +26,7 @@
#include "exceptions.hpp"
#include "lstm.hpp"
#include "ngraph/axis_set.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/builder/split.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
......@@ -243,12 +244,12 @@ namespace ngraph
const LSTMAttributes& attributes)
: m_X{X} // Since we have forward LSTM we can squeeze `num_directions` axis
// from inputs.
, m_W(reshape::squeeze(W))
, m_R(reshape::squeeze(R))
, m_B(reshape::squeeze(B))
, m_P(reshape::squeeze(P))
, m_initial_h(reshape::squeeze(initial_h))
, m_initial_c(reshape::squeeze(initial_c))
, m_W(builder::squeeze(W))
, m_R(builder::squeeze(R))
, m_B(builder::squeeze(B))
, m_P(builder::squeeze(P))
, m_initial_h(builder::squeeze(initial_h))
, m_initial_c(builder::squeeze(initial_c))
, m_seq_lengths(seq_lengths)
, m_attributes(attributes)
{
......@@ -300,7 +301,7 @@ namespace ngraph
for (auto& in_x : in_seqs)
{
// remove first empty dim, after above split.
in_x = reshape::squeeze(in_x);
in_x = builder::squeeze(in_x);
}
std::int32_t time_step{1};
......@@ -331,7 +332,7 @@ namespace ngraph
// This results in zeroing out values in batches with sequence shorter
// than current time_step.
h_list.push_back(
get_masked_node(reshape::expand_dims(H), time_step, 1));
get_masked_node(builder::expand_dims(H), time_step, 1));
// Reference implementation in ONNX Runtime doesn't mask values of Y_h
// and Y_c outputs, thus here we make sure that only appropriate batches
// (in respect to its sequence length) are updated. Those batches which
......@@ -354,12 +355,12 @@ namespace ngraph
// Expand Y so that it has expected shape:
// [seq_length, num_directions, batch_size, hidden_size]
Y = reshape::expand_dims(Y, 1);
Y = builder::expand_dims(Y, 1);
// expand H_t and C_t so that it has expected shape:
// [num_directions, batch_size, hidden_size]
auto Y_h = reshape::expand_dims(H_t);
auto Y_c = reshape::expand_dims(C_t);
auto Y_h = builder::expand_dims(H_t);
auto Y_c = builder::expand_dims(C_t);
return {Y, Y_h, Y_c};
}
......
......@@ -15,7 +15,7 @@
//*****************************************************************************
#include "matmul.hpp"
#include "frontend/onnx_import/utils/matmul_factory.hpp"
#include "ngraph/builder/matmul_factory.hpp"
namespace ngraph
{
......@@ -27,7 +27,20 @@ namespace ngraph
{
NodeVector matmul(const Node& node)
{
auto factory = matmul::MatmulFactory(node);
auto ng_inputs = node.get_ng_inputs();
auto factory = builder::MatmulFactory(
(OutputVector(std::begin(ng_inputs), std::end(ng_inputs))));
std::size_t left_rank{ng_inputs.at(0)->get_shape().size()};
std::size_t right_rank{ng_inputs.at(1)->get_shape().size()};
if (left_rank == 0 || right_rank == 0)
{
NGRAPH_WARN
<< (node) << " "
<< "ONNX standard doesn't allow scalar operands, however nGraph "
"accepts them. Consider use of element-wise multiplication instead "
"to conform with ONNX standard.";
}
return factory.make_matmul_op();
}
} // namespace set_1
......
......@@ -15,7 +15,7 @@
//*****************************************************************************
#include "matmul_integer.hpp"
#include "frontend/onnx_import/utils/matmul_factory.hpp"
#include "ngraph/builder/matmul_factory.hpp"
namespace ngraph
{
......@@ -27,7 +27,20 @@ namespace ngraph
{
NodeVector matmul_integer(const Node& node)
{
auto factory = matmul::MatmulIntegerFactory(node);
auto ng_inputs = node.get_ng_inputs();
auto factory = builder::MatmulIntegerFactory(
OutputVector(std::begin(ng_inputs), std::end(ng_inputs)));
std::size_t left_rank{ng_inputs.at(0)->get_shape().size()};
std::size_t right_rank{ng_inputs.at(1)->get_shape().size()};
if (left_rank == 0 || right_rank == 0)
{
NGRAPH_WARN
<< (node) << " "
<< "ONNX standard doesn't allow scalar operands, however nGraph "
"accepts them. Consider use of element-wise multiplication instead "
"to conform with ONNX standard.";
}
return factory.make_matmul_op();
}
} // namespace set_1
......
......@@ -15,7 +15,7 @@
//*****************************************************************************
#include "qlinear_matmul.hpp"
#include "frontend/onnx_import/utils/matmul_factory.hpp"
#include "ngraph/builder/matmul_factory.hpp"
namespace ngraph
{
......@@ -27,7 +27,20 @@ namespace ngraph
{
NodeVector qlinear_matmul(const Node& node)
{
auto factory = matmul::QLinearMatmulFactory(node);
auto ng_inputs = node.get_ng_inputs();
auto factory = builder::QLinearMatmulFactory(
(OutputVector(std::begin(ng_inputs), std::end(ng_inputs))));
std::size_t left_rank{ng_inputs.at(0)->get_shape().size()};
std::size_t right_rank{ng_inputs.at(1)->get_shape().size()};
if (left_rank == 0 || right_rank == 0)
{
NGRAPH_WARN
<< (node) << " "
<< "ONNX standard doesn't allow scalar operands, however nGraph "
"accepts them. Consider use of element-wise multiplication instead "
"to conform with ONNX standard.";
}
return factory.make_matmul_op();
}
} // namespace set_1
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace matmul
{
/// \brief Factory class which generates an nGraph sub-graph based on an ONNX MatMul
/// operation.
///
/// \note
/// The sub-graph is needed to adjust nGraph's Dot operation semantics to semantics
/// expected by ONNX, which are modeled on NumPy's "stacks of arrays" approach.
/// Differences are apparent with matrices of rank > 2.
///
/// This default implementation `MatmulFactory` creates a `MatMul` operation for
/// floating-point data. Subclasses: `QLinearMatmulFactory` and `MatmulIntegerFactory`
/// implement quantized versions.
class MatmulFactory
{
public:
explicit MatmulFactory(const Node& node)
: m_onnx_node(node)
, m_inputs(node.get_ng_inputs())
{
}
virtual ~MatmulFactory() = default;
/// \brief Create a sub-graph representing an ONNX MatMul operation.
///
/// \return NodeVector containing the sub-graph output node.
virtual NodeVector make_matmul_op();
/// \return Node representing the left operand.
virtual std::shared_ptr<ngraph::Node> get_left();
/// \return Node representing the right operand.
virtual std::shared_ptr<ngraph::Node> get_right();
/// \return Node representing the nGraph Dot operation used to construct MatMul.
virtual std::shared_ptr<ngraph::Node>
make_dot(const std::shared_ptr<ngraph::Node>& left,
const std::shared_ptr<ngraph::Node>& right);
protected:
const Node& m_onnx_node;
const NodeVector m_inputs;
};
/// \brief Factory class which generates an nGraph sub-graph based on an ONNX
/// QLinearMatMul operation.
class QLinearMatmulFactory : public MatmulFactory
{
public:
explicit QLinearMatmulFactory(const Node& node)
: MatmulFactory(node)
{
}
std::shared_ptr<ngraph::Node> get_right() override;
std::shared_ptr<ngraph::Node>
make_dot(const std::shared_ptr<ngraph::Node>& left,
const std::shared_ptr<ngraph::Node>& right) override;
};
/// \brief Factory class which generates an nGraph sub-graph based on an ONNX
/// MatMulInteger operation.
class MatmulIntegerFactory : public MatmulFactory
{
public:
explicit MatmulIntegerFactory(const Node& node)
: MatmulFactory(node)
{
}
std::shared_ptr<ngraph::Node>
make_dot(const std::shared_ptr<ngraph::Node>& left,
const std::shared_ptr<ngraph::Node>& right) override;
};
}
}
}
......@@ -85,60 +85,6 @@ namespace ngraph
return inferred_dims;
}
std::shared_ptr<ngraph::Node> squeeze(const std::shared_ptr<ngraph::Node>& node,
std::vector<std::size_t> axes)
{
if (axes.empty())
{
return node;
}
Shape in_shape{node->get_shape()};
for (std::size_t idx = 0; idx < axes.size(); ++idx)
{
in_shape.at(axes.at(idx)) = 0;
}
Shape output_shape;
for (auto axis : in_shape)
{
if (axis != 0)
{
output_shape.push_back(axis);
}
}
return ngraph::builder::reshape(node, output_shape);
}
std::shared_ptr<ngraph::Node> collapse(const std::shared_ptr<ngraph::Node>& node,
const std::size_t start_axis,
const std::size_t end_axis)
{
auto shape = node->get_shape();
std::size_t collapsed_axis_size =
std::accumulate(std::next(std::begin(shape), start_axis),
std::next(std::begin(shape), end_axis + 1),
1UL,
std::multiplies<std::size_t>());
Shape output_shape{collapsed_axis_size};
output_shape.insert(std::end(output_shape),
std::next(std::begin(shape), end_axis + 1),
std::end(shape));
return ngraph::builder::reshape(node, output_shape);
}
std::shared_ptr<ngraph::Node> expand_dims(const std::shared_ptr<ngraph::Node>& node,
std::size_t axis)
{
Shape output_shape(node->get_shape());
// Add empty axis at specified position.
auto empty_axis_it = std::begin(output_shape);
std::advance(empty_axis_it, axis);
output_shape.insert(empty_axis_it, 1);
return std::make_shared<ngraph::op::Reshape>(
node, ngraph::get_default_order(node->get_shape().size()), output_shape);
}
std::shared_ptr<ngraph::Node>
interpret_as_scalar(const std::shared_ptr<ngraph::Node>& node)
{
......
......@@ -50,42 +50,6 @@ namespace ngraph
const std::vector<std::size_t>& input_shape,
const std::vector<std::size_t>& output_shape);
/// \brief Remove empty axes from input tensor.
///
/// \param[in] node The node to be squeezed.
/// \param[in] axes The vector defining indexes of axes to be removed.
///
/// \return The squeezed node.
///
std::shared_ptr<ngraph::Node> squeeze(const std::shared_ptr<ngraph::Node>& node,
std::vector<std::size_t> axes = {0});
/// \brief Collapse specified axes into single one.
///
/// \note Collapsed axes form a contiguous range starting from the outermost axis.
///
/// \param[in] node The node to be reshaped.
/// \param[in] start_axis The start axis index.
/// \param[in] end_axis The end axis (inclusive) index.
///
/// \return The node with collapsed specified axes.
///
std::shared_ptr<ngraph::Node> collapse(const std::shared_ptr<ngraph::Node>& node,
const std::size_t start_axis,
const std::size_t end_axis);
/// \brief Expands node tensor shape with empty axis at
/// specified position.
///
/// \param[in] node The node to be expanded.
/// \param[in] axis The position in the expanded axes where the
/// new axis is placed.
///
/// \return The node with added empty axis.
///
std::shared_ptr<ngraph::Node> expand_dims(const std::shared_ptr<ngraph::Node>& node,
std::size_t axis = 0);
/// \brief Handle a node which represents a scalar value.
///
/// \note Some ONNX nodes, which should provide scalar values are given as
......
......@@ -135,6 +135,7 @@ namespace ngraph
#include "ngraph/op/fused/gru_cell.hpp"
#include "ngraph/op/fused/hard_sigmoid.hpp"
#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/fused/matmul.hpp"
#include "ngraph/op/fused/mvn.hpp"
#include "ngraph/op/fused/normalize_l2.hpp"
#include "ngraph/op/fused/prelu.hpp"
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <memory>
#include <numeric>
#include "matmul.hpp"
#include "ngraph/builder/matmul_factory.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/op/reshape.hpp"
using namespace std;
using namespace ngraph;
const string op::MatMul::type_name{"MatMul"};
op::MatMul::MatMul(const Output<Node>& A,
const Output<Node>& B,
const bool& transpose_a,
const bool& transpose_b)
: FusedOp(OutputVector{A, B})
, m_transpose_a{transpose_a}
, m_transpose_b{transpose_b}
{
constructor_validate_and_infer_types();
}
NodeVector op::MatMul::decompose_op() const
{
auto A = input_value(0);
auto B = input_value(1);
// The specification expects that A & B have at least 2 dimensions.
// Missing dimensions are padded with 1.
int a_rank = A.get_shape().size();
if (a_rank < 2)
{
A = a_rank == 0 ? make_shared<op::Reshape>(A, AxisVector{}, Shape{1, 1})
: make_shared<op::Reshape>(A, AxisVector{1}, Shape{1, A.get_shape()[0]});
a_rank = 2;
}
int b_rank = B.get_shape().size();
if (b_rank < 2)
{
B = b_rank == 0 ? make_shared<op::Reshape>(B, AxisVector{}, Shape{1, 1})
: make_shared<op::Reshape>(B, AxisVector{1}, Shape{1, B.get_shape()[0]});
b_rank = 2;
}
if (m_transpose_a)
{
vector<size_t> axes_order(a_rank);
// generate default axes_order.
iota(axes_order.begin(), axes_order.end(), 0);
// transpose the last two dims (the matrix dimensions)
swap(axes_order[a_rank - 1], axes_order[a_rank - 2]);
A = builder::reorder_axes(A, axes_order);
}
if (m_transpose_b)
{
vector<size_t> axes_order(b_rank);
iota(axes_order.begin(), axes_order.end(), 0);
swap(axes_order[b_rank - 1], axes_order[b_rank - 2]);
B = builder::reorder_axes(B, axes_order);
}
builder::MatmulFactory factory({A, B});
return factory.make_matmul_op();
}
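For instance, with a rank-4 input and the corresponding transpose flag set, the axes_order computed above would come out as follows (a sketch of the intermediate values, not additional operator code):
vector<size_t> axes_order(4);
iota(axes_order.begin(), axes_order.end(), 0); // {0, 1, 2, 3}
swap(axes_order[3], axes_order[2]);            // {0, 1, 3, 2}
// builder::reorder_axes then swaps the two innermost dimensions,
// e.g. Shape{2, 2, 6, 3} -> Shape{2, 2, 3, 6}.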
shared_ptr<Node> op::MatMul::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<MatMul>(new_args.at(0), new_args.at(1), m_transpose_a, m_transpose_b);
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/fused_op.hpp"
namespace ngraph
{
namespace op
{
/// \brief Operator performing Matrix Multiplication.
class MatMul : public ngraph::op::util::FusedOp
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
MatMul() = default;
/// \brief Constructs a MatMul operation.
///
/// \param A Matrix A
/// \param B Matrix B
/// \param transpose_a If matrix A should be transposed.
/// \param transpose_b If matrix B should be transposed.
MatMul(const Output<Node>& A,
const Output<Node>& B,
const bool& transpose_a = 0,
const bool& transpose_b = 0);
virtual NodeVector decompose_op() const override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
bool get_transpose_a() const { return m_transpose_a; }
bool get_transpose_b() const { return m_transpose_b; }
private:
const bool m_transpose_a;
const bool m_transpose_b;
};
} // namespace op
} // namespace ngraph
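A minimal usage sketch, mirroring the type_prop tests below (parameter shapes assumed):
auto A = make_shared<op::Parameter>(element::f32, Shape{3, 6});
auto B = make_shared<op::Parameter>(element::f32, Shape{4, 6});
// Multiply A by the transpose of B; after decomposition the result has Shape{3, 4}.
auto matmul = make_shared<op::MatMul>(A, B, false, true);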
......@@ -38,6 +38,7 @@ NGRAPH_OP(GroupConvolutionTranspose, ngraph::op)
NGRAPH_OP(GRUCell, ngraph::op)
NGRAPH_OP(HardSigmoid, ngraph::op)
NGRAPH_OP(LSTMCell, ngraph::op)
NGRAPH_OP(MatMul, ngraph::op)
NGRAPH_OP(MVN, ngraph::op)
NGRAPH_OP(NormalizeL2, ngraph::op)
NGRAPH_OP(PRelu, ngraph::op)
......
......@@ -89,6 +89,7 @@
#include "ngraph/op/fused/group_conv_transpose.hpp"
#include "ngraph/op/fused/gru_cell.hpp"
#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/fused/matmul.hpp"
#include "ngraph/op/fused/mvn.hpp"
#include "ngraph/op/fused/normalize_l2.hpp"
#include "ngraph/op/fused/rnn_cell.hpp"
......@@ -2072,6 +2073,7 @@ shared_ptr<runtime::Executable>
case OP_TYPEID::GRUCell:
case OP_TYPEID::HardSigmoid:
case OP_TYPEID::LSTMCell:
case OP_TYPEID::MatMul:
case OP_TYPEID::MVN:
case OP_TYPEID::NormalizeL2:
case OP_TYPEID::PRelu:
......@@ -2197,6 +2199,7 @@ bool runtime::intelgpu::IntelGPUBackend::is_supported_impl(const Node& node)
case OP_TYPEID::GroupConvolutionTranspose:
case OP_TYPEID::GRUCell:
case OP_TYPEID::LSTMCell:
case OP_TYPEID::MatMul:
case OP_TYPEID::MVN:
case OP_TYPEID::NormalizeL2:
case OP_TYPEID::PRelu:
......
......@@ -80,6 +80,7 @@
#include "ngraph/op/fused/gru_cell.hpp"
#include "ngraph/op/fused/hard_sigmoid.hpp"
#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/fused/matmul.hpp"
#include "ngraph/op/fused/mvn.hpp"
#include "ngraph/op/fused/normalize_l2.hpp"
#include "ngraph/op/fused/prelu.hpp"
......@@ -1402,6 +1403,13 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
input_forget);
break;
}
case OP_TYPEID::MatMul:
{
bool transpose_a = node_js.at("transpose_a").get<bool>();
bool transpose_b = node_js.at("transpose_b").get<bool>();
node = make_shared<op::MatMul>(args[0], args[1], transpose_a, transpose_b);
break;
}
case OP_TYPEID::Max:
{
auto reduction_axes = deserialize_axis_set(node_js.at("reduction_axes"));
......@@ -2557,6 +2565,13 @@ json JSONSerializer::serialize_node(const Node& n)
node["input_forget"] = tmp->get_input_forget();
break;
}
case OP_TYPEID::MatMul:
{
auto tmp = dynamic_cast<const op::MatMul*>(&n);
node["transpose_a"] = tmp->get_transpose_a();
node["transpose_b"] = tmp->get_transpose_b();
break;
}
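For illustration, the MatMul-specific portion of a serialized node would then look like the fragment below (the surrounding field layout is an assumption about the serializer's general JSON format):
// {
//   "op": "MatMul",            // assumed: the serializer's type field
//   "transpose_a": false,
//   "transpose_b": true
// }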
case OP_TYPEID::Max:
{
auto tmp = dynamic_cast<const op::Max*>(&n);
......
......@@ -114,6 +114,7 @@ set(SRC
type_prop/index_reduction.cpp
type_prop/lrn.cpp
type_prop/lstm_cell.cpp
type_prop/matmul.cpp
type_prop/max_pool.cpp
type_prop/mvn.cpp
type_prop/normalize.cpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
TEST(type_prop, matmul_2D_same)
{
auto A = make_shared<op::Parameter>(element::f32, Shape{2, 2});
auto B = make_shared<op::Parameter>(element::f32, Shape{2, 2});
auto matmul = make_shared<op::MatMul>(A, B);
ASSERT_EQ(matmul->get_element_type(), element::f32);
ASSERT_EQ(matmul->get_shape(), (Shape{2, 2}));
}
TEST(type_prop, matmul_4D_same)
{
auto A = make_shared<op::Parameter>(element::f32, Shape{2, 2, 3, 3});
auto B = make_shared<op::Parameter>(element::f32, Shape{2, 2, 3, 3});
auto matmul = make_shared<op::MatMul>(A, B);
ASSERT_EQ(matmul->get_element_type(), element::f32);
ASSERT_EQ(matmul->get_shape(), (Shape{2, 2, 3, 3}));
}
TEST(type_prop, matmul_2D)
{
auto A = make_shared<op::Parameter>(element::f32, Shape{3, 6});
auto B = make_shared<op::Parameter>(element::f32, Shape{6, 4});
auto matmul = make_shared<op::MatMul>(A, B);
ASSERT_EQ(matmul->get_element_type(), element::f32);
ASSERT_EQ(matmul->get_shape(), (Shape{3, 4}));
}
TEST(type_prop, matmul_4D)
{
auto A = make_shared<op::Parameter>(element::f32, Shape{2, 2, 3, 6});
auto B = make_shared<op::Parameter>(element::f32, Shape{2, 2, 6, 4});
auto matmul = make_shared<op::MatMul>(A, B);
ASSERT_EQ(matmul->get_element_type(), element::f32);
ASSERT_EQ(matmul->get_shape(), (Shape{2, 2, 3, 4}));
}
TEST(type_prop, matmul_2D_transpose_a)
{
auto A = make_shared<op::Parameter>(element::f32, Shape{6, 3});
auto B = make_shared<op::Parameter>(element::f32, Shape{6, 4});
auto matmul = make_shared<op::MatMul>(A, B, true);
ASSERT_EQ(matmul->get_element_type(), element::f32);
ASSERT_EQ(matmul->get_shape(), (Shape{3, 4}));
}
TEST(type_prop, matmul_4D_transpose_a)
{
auto A = make_shared<op::Parameter>(element::f32, Shape{2, 2, 6, 3});
auto B = make_shared<op::Parameter>(element::f32, Shape{2, 2, 6, 4});
auto matmul = make_shared<op::MatMul>(A, B, true);
ASSERT_EQ(matmul->get_element_type(), element::f32);
ASSERT_EQ(matmul->get_shape(), (Shape{2, 2, 3, 4}));
}
TEST(type_prop, matmul_2D_transpose_b)
{
auto A = make_shared<op::Parameter>(element::f32, Shape{3, 6});
auto B = make_shared<op::Parameter>(element::f32, Shape{4, 6});
auto matmul = make_shared<op::MatMul>(A, B, false, true);
ASSERT_EQ(matmul->get_element_type(), element::f32);
ASSERT_EQ(matmul->get_shape(), (Shape{3, 4}));
}
TEST(type_prop, matmul_4D_transpose_b)
{
auto A = make_shared<op::Parameter>(element::f32, Shape{2, 2, 3, 6});
auto B = make_shared<op::Parameter>(element::f32, Shape{2, 2, 4, 6});
auto matmul = make_shared<op::MatMul>(A, B, false, true);
ASSERT_EQ(matmul->get_element_type(), element::f32);
ASSERT_EQ(matmul->get_shape(), (Shape{2, 2, 3, 4}));
}