Commit e51c5824 authored by Michał Karzyński, committed by Robert Kimball

[ONNX] Add MatMulInteger op (#3011)

* Unit tests for MatMulInteger

* Add ONNX MatMulInteger op

* Add QuantizedLinearMatmulInteger builder

* Additional unit test

* Exclude tests on NVIDIA GPU backend

* Add 4D test case

* Enable >2D MatMulInteger

* Refactoring to MatMulFactory - step 1

* Refactoring to MatMulFactory - step 2

* Remove `using namespace ngraph` to make `Node` unambiguous.

* Disable quantized ops tests on GPU backend

* Remove unused `includes`

* Remove redundant dynamic_pointer_cast

* Remove redundant `move`

* Add const correctness

* Code review comments

* Style apply

* Add documentation

* Use more complex shapes in tests
parent 34f6c57c
......@@ -48,14 +48,9 @@ namespace ngraph
const shared_ptr<Node>& output_scale,
const shared_ptr<Node>& output_zero_point)
{
auto input0_zero = dynamic_pointer_cast<ngraph::op::Constant>(input0_zero_point);
auto input1_zero = dynamic_pointer_cast<ngraph::op::Constant>(input1_zero_point);
auto output_zero = dynamic_pointer_cast<ngraph::op::Constant>(output_zero_point);
// Check if the zero points are constant and zero
if (input0_zero != nullptr && input1_zero != nullptr && output_zero != nullptr &&
ngraph::is_zero(input0_zero) && ngraph::is_zero(input1_zero) &&
ngraph::is_zero(output_zero))
if (ngraph::is_zero(input0_zero_point) && ngraph::is_zero(input1_zero_point) &&
ngraph::is_zero(output_zero_point))
{
auto requantization_scale = (input0_scale * input1_scale) / output_scale;
return make_shared<op::QuantizedDot>(input0, input1, requantization_scale);
......@@ -77,14 +72,13 @@ namespace ngraph
axes);
auto dot = make_shared<op::Dot>(dq_input0, dq_input1, 1);
auto q_dot = make_shared<op::Quantize>(
return make_shared<op::Quantize>(
dot,
output_scale,
output_zero_point,
output_zero_point->get_element_type(),
axes,
op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN);
return move(q_dot);
}
}
......@@ -94,6 +88,51 @@ namespace ngraph
auto output_scale = make_constant(element::f32, Shape{}, 1);
return make_shared<op::QuantizedDot>(input0, input1, output_scale, false, false);
}
shared_ptr<Node>
QuantizedLinearMatmulInteger(const std::shared_ptr<Node>& input0,
const std::shared_ptr<Node>& input1,
const std::shared_ptr<Node>& input0_zero_point,
const std::shared_ptr<Node>& input1_zero_point)
{
// Check if zero points are constant and zero
if (ngraph::is_zero(input0_zero_point) && ngraph::is_zero(input1_zero_point))
{
return QuantizedLinearMatmulInteger(input0, input1);
}
else
{
// Fall back to performing matmul on dequantized floating-point values
const auto input0_scale = make_constant(element::f32, Shape{}, 1);
const auto input1_scale = make_constant(element::f32, Shape{}, 1);
const auto output_scale = make_constant(element::f32, Shape{}, 1);
const auto output_zero_point = make_constant(element::i32, Shape{}, 0);
const AxisSet axes;
const auto dq_input0 =
make_shared<op::Dequantize>(input0,
input0_scale,
input0_zero_point,
input0_scale->get_element_type(),
axes);
const auto dq_input1 =
make_shared<op::Dequantize>(input1,
input1_scale,
input1_zero_point,
input1_scale->get_element_type(),
axes);
const auto dot = make_shared<op::Dot>(dq_input0, dq_input1, 1);
return make_shared<op::Quantize>(
dot,
output_scale,
output_zero_point,
output_zero_point->get_element_type(),
axes,
op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN);
}
}
}
}
}
......@@ -37,6 +37,12 @@ namespace ngraph
std::shared_ptr<Node> QuantizedLinearMatmulInteger(const std::shared_ptr<Node>& input0,
const std::shared_ptr<Node>& input1);
std::shared_ptr<Node>
QuantizedLinearMatmulInteger(const std::shared_ptr<Node>& input0,
const std::shared_ptr<Node>& input1,
const std::shared_ptr<Node>& input0_zero_point,
const std::shared_ptr<Node>& input1_zero_point);
}
}
}
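
For orientation, a minimal usage sketch of the two builder overloads declared above. This is hypothetical caller code, not part of this commit; the Parameter/Constant setup is illustrative only.

#include "ngraph/builder/quantization/quantized_linear_matmul.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/parameter.hpp"

using namespace ngraph;

std::shared_ptr<Node> build_matmul_integer_subgraph()
{
    // uint8 inputs, matching the MatMulInteger test models added below.
    auto a = std::make_shared<op::Parameter>(element::u8, Shape{4, 3});
    auto b = std::make_shared<op::Parameter>(element::u8, Shape{3, 2});

    // Scalar zero points. When both are constant zero, the builder takes the
    // QuantizedDot fast path; otherwise it dequantizes, multiplies and re-quantizes.
    auto a_zero = op::Constant::create(element::u8, Shape{}, {12});
    auto b_zero = op::Constant::create(element::u8, Shape{}, {0});

    return builder::quantization::QuantizedLinearMatmulInteger(a, b, a_zero, b_zero);
}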
......@@ -112,6 +112,8 @@ add_library(onnx_import STATIC
op/lstm.hpp
op/matmul.cpp
op/matmul.hpp
op/matmul_integer.cpp
op/matmul_integer.hpp
op/max_pool.cpp
op/max_pool.hpp
op/max.hpp
......@@ -129,11 +131,12 @@ add_library(onnx_import STATIC
op/pow.hpp
op/prelu.cpp
op/prelu.hpp
op/qlinear_matmul.cpp
op/qlinear_matmul.hpp
op/quant_conv.cpp
op/quant_conv.hpp
op/quantize_linear.cpp
op/quantize_linear.hpp
op/quantized_matmul.hpp
op/reciprocal.cpp
op/reciprocal.hpp
op/reduce.cpp
......@@ -185,6 +188,8 @@ add_library(onnx_import STATIC
utils/common.hpp
utils/convpool.cpp
utils/convpool.hpp
utils/matmul_factory.cpp
utils/matmul_factory.hpp
utils/reduction.cpp
utils/reduction.hpp
utils/reshape.cpp
......
......@@ -14,53 +14,8 @@
// limitations under the License.
//*****************************************************************************
#include <cstddef>
#include <iterator>
#include <memory>
#include <vector>
#include "exceptions.hpp"
#include "matmul.hpp"
#include "ngraph/builder/quantization/quantized_linear_matmul.hpp"
#include "ngraph/coordinate.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/util.hpp"
#include "utils/reshape.hpp"
/// \brief Slice a sub-matrix from the input tensor.
///
/// \param[in] node The input tensor. Must be of rank 3 at most.
/// \param[in] idx The index along the first axis at which to slice the sub-matrix.
///
/// \return The node representing the sub-matrix.
///
static std::shared_ptr<ngraph::Node> get_sub_matrix(const std::shared_ptr<ngraph::Node>& node,
std::size_t idx)
{
const ngraph::Shape& shape{node->get_shape()};
if (shape.size() < 3)
{
return node;
}
// The bounds below define the sub_matrix through ranges for each input node axis.
ngraph::Coordinate lower_bounds(shape.size());
ngraph::Coordinate upper_bounds = shape;
// We assume the `node` tensor is of rank 3, so we slice the sub-matrix lying in the last
// two dimensions, at index `idx` of the first axis.
lower_bounds.at(0) = idx;
upper_bounds.at(0) = idx + 1;
auto sub_matrix = std::shared_ptr<ngraph::Node>{
std::make_shared<ngraph::op::Slice>(node, lower_bounds, upper_bounds)};
// Remove first single entry dim.
return ngraph::onnx_import::reshape::squeeze(sub_matrix);
}
#include "frontend/onnx_import/utils/matmul_factory.hpp"
namespace ngraph
{
......@@ -70,143 +25,11 @@ namespace ngraph
{
namespace set_1
{
NodeVector make_matmul_op(const Node& node, bool quantized)
NodeVector matmul(const Node& node)
{
const NodeVector& ng_inputs{node.get_ng_inputs()};
auto left = std::shared_ptr<ngraph::Node>{ng_inputs.at(0)};
auto right = std::shared_ptr<ngraph::Node>{};
auto scale = std::shared_ptr<ngraph::Node>{};
if (quantized)
{
right = ng_inputs.at(3);
scale = ng_inputs.at(6);
}
else
{
right = ng_inputs.at(1);
}
std::size_t left_rank{left->get_shape().size()};
std::size_t right_rank{right->get_shape().size()};
if (left_rank == 0 || right_rank == 0)
{
NGRAPH_WARN
<< (node) << " "
<< "The ONNX standard doesn't allow scalar operands; however, nGraph "
"accepts them. Consider using element-wise multiplication instead "
"to conform with the ONNX standard.";
}
// First (easy) case, already handled internally by the nGraph Dot operator:
// multiply two tensors where both have rank 2 or lower.
if (left_rank <= 2 && right_rank <= 2)
{
if (quantized)
{
return {ngraph::builder::quantization::QuantizedLinearMatmul(
left,
right,
ng_inputs.at(1),
ng_inputs.at(2),
ng_inputs.at(4),
ng_inputs.at(5),
ng_inputs.at(6),
ng_inputs.at(7))};
}
else
{
return {std::make_shared<ngraph::op::Dot>(left, right)};
}
}
// Second case:
// multiply two tensors where at least one of them has rank 3 or greater.
// Broadcast the input arguments only if neither of them is a vector.
if (left_rank > 1 && right_rank > 1)
{
const NodeVector& broadcasted_nodes =
ngraph::op::numpy_style_broadcast_for_matmul_operation(left, right);
left = broadcasted_nodes.at(0);
right = broadcasted_nodes.at(1);
}
const auto& left_shape = left->get_shape();
const auto& right_shape = right->get_shape();
// Collapse both tensors' _stack of matrices_ axes (all except the last two).
// This simplifies the subsequent dot product calculations.
if (left_shape.size() > 3)
{
left = reshape::collapse(left, 0, left_shape.size() - 3);
}
if (right_shape.size() > 3)
{
right = reshape::collapse(right, 0, right_shape.size() - 3);
}
// Perform multiple small dot products
std::size_t groups = left->get_shape().at(0);
// If we haven't broadcast earlier, it means that one of the inputs is a vector,
// so the number of groups is defined by the shape of the larger tensor.
if (right->get_shape().size() > left->get_shape().size())
{
groups = right->get_shape().at(0);
}
NodeVector small_dots(groups);
for (std::size_t g = 0; g < groups; ++g)
{
const auto& sliced_left = get_sub_matrix(left, g);
auto sliced_right = get_sub_matrix(right, g);
auto sub_dot = std::shared_ptr<ngraph::Node>{};
if (quantized)
{
sub_dot = ngraph::builder::quantization::QuantizedLinearMatmul(
sliced_left,
sliced_right,
ng_inputs.at(1),
ng_inputs.at(2),
ng_inputs.at(4),
ng_inputs.at(5),
ng_inputs.at(6),
ng_inputs.at(7));
}
else
{
sub_dot = std::make_shared<ngraph::op::Dot>(sliced_left, sliced_right);
}
// Expand the sub_dot result with a single empty outermost axis, so that
// the sub_dots can later be concatenated along this axis.
small_dots.at(g) = reshape::expand_dims(sub_dot);
}
// Concatenate sub_dots on groups axis.
auto result = std::make_shared<ngraph::op::Concat>(small_dots, 0);
if (left_shape.size() <= 3 && right_shape.size() <= 3)
{
return {result};
}
// Expand the result's _stack of matrices_ axes to get the expected result shape.
else
{
const Shape& shape{result->get_shape()};
Shape result_shape(std::next(std::begin(shape)), std::end(shape));
result_shape.insert(
std::begin(result_shape),
std::begin(left_shape),
std::next(std::begin(left_shape), left_shape.size() - 2));
return {std::make_shared<ngraph::op::Reshape>(
result, ngraph::get_default_order(shape.size()), result_shape)};
}
auto factory = matmul::MatmulFactory(node);
return factory.make_matmul_op();
}
NodeVector matmul(const Node& node) { return make_matmul_op(node, false); }
} // namespace set_1
} //namespace op
......
......@@ -27,7 +27,6 @@ namespace ngraph
{
namespace set_1
{
NodeVector make_matmul_op(const Node& node, bool quantized);
NodeVector matmul(const Node& node);
} // namespace set_1
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "matmul_integer.hpp"
#include "frontend/onnx_import/utils/matmul_factory.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector matmul_integer(const Node& node)
{
auto factory = matmul::MatmulIntegerFactory(node);
return factory.make_matmul_op();
}
} // namespace set_1
} //namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "core/node.hpp"
#include "ngraph/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
/// \brief Performs the ONNX MatMulInteger operation.
///
/// \param node The ONNX node object representing this operation.
///
/// \return The vector containing nGraph nodes producing the output of the quantized
/// ONNX matrix multiplication operation.
NodeVector matmul_integer(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "qlinear_matmul.hpp"
#include "frontend/onnx_import/utils/matmul_factory.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector qlinear_matmul(const Node& node)
{
auto factory = matmul::QLinearMatmulFactory(node);
return factory.make_matmul_op();
}
} // namespace set_1
} //namespace op
} // namespace onnx_import
} // namespace ngraph
......@@ -17,7 +17,6 @@
#pragma once
#include "core/node.hpp"
#include "ngraph/frontend/onnx_import/op/matmul.hpp"
#include "ngraph/node.hpp"
namespace ngraph
......@@ -28,7 +27,7 @@ namespace ngraph
{
namespace set_1
{
NodeVector quantized_matmul(const Node& node) { return make_matmul_op(node, true); }
NodeVector qlinear_matmul(const Node& node);
} // namespace set_1
} //namespace op
......
......@@ -70,6 +70,7 @@
#include "op/lrn.hpp"
#include "op/lstm.hpp"
#include "op/matmul.hpp"
#include "op/matmul_integer.hpp"
#include "op/max.hpp"
#include "op/max_pool.hpp"
#include "op/mean.hpp"
......@@ -83,9 +84,9 @@
#include "op/pad.hpp"
#include "op/pow.hpp"
#include "op/prelu.hpp"
#include "op/qlinear_matmul.hpp"
#include "op/quant_conv.hpp"
#include "op/quantize_linear.hpp"
#include "op/quantized_matmul.hpp"
#include "op/reciprocal.hpp"
#include "op/reduce.hpp"
#include "op/relu.hpp"
......@@ -273,6 +274,7 @@ namespace ngraph
REGISTER_OPERATOR("LRN", 1, lrn);
REGISTER_OPERATOR("LSTM", 1, lstm);
REGISTER_OPERATOR("MatMul", 1, matmul);
REGISTER_OPERATOR("MatMulInteger", 1, matmul_integer);
REGISTER_OPERATOR("MaxPool", 1, max_pool);
REGISTER_OPERATOR("Max", 1, max);
REGISTER_OPERATOR("Max", 8, max);
......@@ -290,7 +292,7 @@ namespace ngraph
REGISTER_OPERATOR("Pow", 1, pow);
REGISTER_OPERATOR("PRelu", 1, prelu);
REGISTER_OPERATOR("QLinearConv", 1, quant_conv);
REGISTER_OPERATOR("QLinearMatMul", 1, quantized_matmul);
REGISTER_OPERATOR("QLinearMatMul", 1, qlinear_matmul);
REGISTER_OPERATOR("QuantizeLinear", 1, quantize_linear);
REGISTER_OPERATOR("Reciprocal", 1, reciprocal);
REGISTER_OPERATOR("ReduceLogSum", 1, reduce_log_sum);
......
//*****************************************************************************
// Copyright 2018-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <cstddef>
#include <iterator>
#include <memory>
#include <vector>
#include "matmul_factory.hpp"
#include "ngraph/builder/make_constant.hpp"
#include "ngraph/builder/quantization/quantized_linear_matmul.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "utils/reshape.hpp"
using namespace ngraph::onnx_import::matmul;
/// \brief Slice a sub-matrix from the input tensor.
///
/// \param[in] node The input tensor. Must be of rank 3 at most.
/// \param[in] idx The index along the first axis at which to slice the sub-matrix.
///
/// \return The node representing the sub-matrix.
///
static std::shared_ptr<ngraph::Node> get_sub_matrix(const std::shared_ptr<ngraph::Node>& node,
std::size_t idx)
{
const ngraph::Shape& shape{node->get_shape()};
if (shape.size() < 3)
{
return node;
}
// The bounds below define the sub_matrix through ranges for each input node axis.
ngraph::Coordinate lower_bounds(shape.size());
ngraph::Coordinate upper_bounds = shape;
// We assume the `node` tensor is of rank 3, so we slice the sub-matrix lying in the last
// two dimensions, at index `idx` of the first axis.
lower_bounds.at(0) = idx;
upper_bounds.at(0) = idx + 1;
auto sub_matrix = std::shared_ptr<ngraph::Node>{
std::make_shared<ngraph::op::Slice>(node, lower_bounds, upper_bounds)};
// Remove first single entry dim.
return ngraph::onnx_import::reshape::squeeze(sub_matrix);
}
std::shared_ptr<ngraph::Node> MatmulFactory::get_left()
{
return m_inputs.at(0);
}
std::shared_ptr<ngraph::Node> MatmulFactory::get_right()
{
return m_inputs.at(1);
}
ngraph::NodeVector MatmulFactory::make_matmul_op()
{
auto left = get_left();
auto right = get_right();
std::size_t left_rank{left->get_shape().size()};
std::size_t right_rank{right->get_shape().size()};
if (left_rank == 0 || right_rank == 0)
{
NGRAPH_WARN << (m_onnx_node) << " "
<< "The ONNX standard doesn't allow scalar operands; however, nGraph "
"accepts them. Consider using element-wise multiplication instead "
"to conform with the ONNX standard.";
}
// First (easy) case, already handled internally by the nGraph Dot operator:
// multiply two tensors where both have rank 2 or lower.
if (left_rank <= 2 && right_rank <= 2)
{
return NodeVector{make_dot(left, right)};
}
// Second case:
// multiply two tensors where at least one of them has rank 3 or greater.
// Broadcast the input arguments only if neither of them is a vector.
if (left_rank > 1 && right_rank > 1)
{
const NodeVector& broadcasted_nodes =
ngraph::op::numpy_style_broadcast_for_matmul_operation(left, right);
left = broadcasted_nodes.at(0);
right = broadcasted_nodes.at(1);
}
const auto& left_shape = left->get_shape();
const auto& right_shape = right->get_shape();
// Collapse both tensors' _stack of matrices_ axes (all except the last two).
// This simplifies the subsequent dot product calculations.
if (left_shape.size() > 3)
{
left = onnx_import::reshape::collapse(left, 0, left_shape.size() - 3);
}
if (right_shape.size() > 3)
{
right = onnx_import::reshape::collapse(right, 0, right_shape.size() - 3);
}
// Perform multiple small dot products
std::size_t groups = left->get_shape().at(0);
// If we haven't broadcast earlier, it means that one of the inputs is a vector,
// so the number of groups is defined by the shape of the larger tensor.
if (right->get_shape().size() > left->get_shape().size())
{
groups = right->get_shape().at(0);
}
NodeVector small_dots(groups);
for (std::size_t g = 0; g < groups; ++g)
{
const auto sliced_left = get_sub_matrix(left, g);
const auto sliced_right = get_sub_matrix(right, g);
auto sub_dot = make_dot(sliced_left, sliced_right);
// Expand the sub_dot result with a single empty outermost axis, so that
// the sub_dots can later be concatenated along this axis.
small_dots.at(g) = onnx_import::reshape::expand_dims(sub_dot);
}
// Concatenate sub_dots on groups axis.
auto result = std::make_shared<ngraph::op::Concat>(small_dots, 0);
if (left_shape.size() <= 3 && right_shape.size() <= 3)
{
return {result};
}
// Expand the result's _stack of matrices_ axes to get the expected result shape.
else
{
const Shape& shape{result->get_shape()};
Shape result_shape(std::next(std::begin(shape)), std::end(shape));
result_shape.insert(std::begin(result_shape),
std::begin(left_shape),
std::next(std::begin(left_shape), left_shape.size() - 2));
return {std::make_shared<ngraph::op::Reshape>(
result, ngraph::get_default_order(shape.size()), result_shape)};
}
}
std::shared_ptr<ngraph::Node> MatmulFactory::make_dot(const std::shared_ptr<ngraph::Node>& left,
const std::shared_ptr<ngraph::Node>& right)
{
return std::make_shared<ngraph::op::Dot>(left, right);
}
std::shared_ptr<ngraph::Node> QLinearMatmulFactory::get_right()
{
return m_inputs.at(3);
}
std::shared_ptr<ngraph::Node>
QLinearMatmulFactory::make_dot(const std::shared_ptr<ngraph::Node>& left,
const std::shared_ptr<ngraph::Node>& right)
{
return ngraph::builder::quantization::QuantizedLinearMatmul(left,
right,
m_inputs.at(1),
m_inputs.at(2),
m_inputs.at(4),
m_inputs.at(5),
m_inputs.at(6),
m_inputs.at(7));
}
std::shared_ptr<ngraph::Node>
MatmulIntegerFactory::make_dot(const std::shared_ptr<ngraph::Node>& left,
const std::shared_ptr<ngraph::Node>& right)
{
auto num_inputs = m_inputs.size();
if (num_inputs == 2)
{
return ngraph::builder::quantization::QuantizedLinearMatmulInteger(left, right);
}
auto left_zero_point = m_inputs.at(2);
auto right_zero_point = ngraph::builder::make_constant(right->get_element_type(), Shape{}, 0);
if (num_inputs == 4)
{
right_zero_point = m_inputs.at(3);
}
return ngraph::builder::quantization::QuantizedLinearMatmulInteger(
left, right, left_zero_point, right_zero_point);
}
//*****************************************************************************
// Copyright 2018-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace matmul
{
/// \brief Factory class which generates an nGraph sub-graph based on an ONNX MatMul operation.
///
/// \note
/// The sub-graph is needed to adjust nGraph's Dot operation semantics to semantics
/// expected by ONNX, which are modeled on NumPy's "stacks of arrays" approach.
/// Differences are apparent with matrices of rank > 2.
///
/// This default implementation, `MatmulFactory`, builds the MatMul sub-graph for floating-point data.
/// The subclasses `QLinearMatmulFactory` and `MatmulIntegerFactory` implement the quantized variants.
class MatmulFactory
{
public:
explicit MatmulFactory(const Node& node)
: m_onnx_node(node)
, m_inputs(node.get_ng_inputs())
{
}
virtual ~MatmulFactory() = default;
/// \brief Create a sub-graph representing an ONNX MatMul operation.
///
/// \return NodeVector containing the sub-graph output node.
virtual NodeVector make_matmul_op();
/// \return Node representing the left operand.
virtual std::shared_ptr<ngraph::Node> get_left();
/// \return Node representing the right operand.
virtual std::shared_ptr<ngraph::Node> get_right();
/// \return Node representing the nGraph Dot operation used to construct MatMul.
virtual std::shared_ptr<ngraph::Node>
make_dot(const std::shared_ptr<ngraph::Node>& left,
const std::shared_ptr<ngraph::Node>& right);
protected:
const Node& m_onnx_node;
const NodeVector m_inputs;
};
/// \brief Factory class which generates an nGraph sub-graph based on an ONNX QLinearMatMul operation.
class QLinearMatmulFactory : public MatmulFactory
{
public:
explicit QLinearMatmulFactory(const Node& node)
: MatmulFactory(node)
{
}
std::shared_ptr<ngraph::Node> get_right() override;
std::shared_ptr<ngraph::Node>
make_dot(const std::shared_ptr<ngraph::Node>& left,
const std::shared_ptr<ngraph::Node>& right) override;
};
/// \brief Factory class which generates an nGraph sub-graph based on an ONNX MatMulInteger operation.
class MatmulIntegerFactory : public MatmulFactory
{
public:
explicit MatmulIntegerFactory(const Node& node)
: MatmulFactory(node)
{
}
std::shared_ptr<ngraph::Node>
make_dot(const std::shared_ptr<ngraph::Node>& left,
const std::shared_ptr<ngraph::Node>& right) override;
};
}
}
}
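
To illustrate the shape handling implemented by MatmulFactory::make_matmul_op above (broadcast, collapse the stack-of-matrices axes, multiply group by group, then restore the stacked shape), here is a rough, self-contained sketch that operates on plain shapes only; it approximates the output-shape logic and is not the nGraph implementation.

#include <cassert>
#include <cstddef>
#include <iostream>
#include <vector>

using Shape = std::vector<std::size_t>;

// Rough sketch of the output-shape logic: both operands are broadcast to a
// common "stack of matrices" shape, each group yields an M x N matrix, and the
// stack axes are kept in front of M x N. Assumes both ranks >= 2 and matching
// leading (stack) dimensions.
Shape stacked_matmul_output_shape(const Shape& left, const Shape& right)
{
    assert(left.size() >= 2 && right.size() >= 2);
    const std::size_t m = left[left.size() - 2];
    const std::size_t k = left[left.size() - 1];
    assert(k == right[right.size() - 2]); // inner dimensions must agree
    const std::size_t n = right[right.size() - 1];

    // Result keeps the left operand's stack-of-matrices axes, then M x N.
    Shape result(left.begin(), left.end() - 2);
    result.push_back(m);
    result.push_back(n);
    return result;
}

int main()
{
    // Matches the matmul_integer_4d test model: [1,2,3,4] x [1,2,4,3] -> [1,2,3,3].
    Shape out = stacked_matmul_output_shape({1, 2, 3, 4}, {1, 2, 4, 3});
    for (std::size_t d : out)
        std::cout << d << ' ';
    std::cout << '\n'; // prints: 1 2 3 3
}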
......@@ -130,6 +130,13 @@ model_quant_conv_linear_2d
model_quant_conv_linear_3d
model_qlinear_matmul
model_qlinear_matmul_3d
model_matmul_integer
model_matmul_integer_zero_point_zero
model_matmul_integer_no_zero_point
model_matmul_integer_scalar
model_matmul_integer_4d
model_matmul_integer_4d_zero_point
model_matmul_integer_4d_no_zero_point
# This should be implemented
create_tensor_2_input
......
......@@ -108,3 +108,7 @@ fake_quantize_with_clip_across_channels
model_dequantize_linear_1d_zero_scale_int8
model_dequantize_linear_1d_zero_scale_int8_4d
model_quant_conv_linear
model_matmul_integer
model_matmul_integer_no_zero_point
model_matmul_integer_zero_point_zero
model_matmul_integer_scalar
......@@ -2,7 +2,8 @@
model_quant_conv_linear
model_qlinear_matmul
model_qlinear_matmul_3d
model_matmul_integer_no_zero_point
model_matmul_integer_4d_no_zero_point
fake_quantize
fake_quantize_with_clip
fake_quantize_with_clip_across_channels
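# In the ONNX TensorProto.DataType enumeration used by elem_type in the test models below,
# 2 denotes UINT8 (the quantized inputs) and 6 denotes INT32 (the accumulated output).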
ir_version: 5
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "a"
input: "b"
input: "a_zero_point"
input: "b_zero_point"
output: "y"
name: "node1"
op_type: "MatMulInteger"
doc_string: "MatMulInteger"
domain: ""
}
name: "test"
input {
name: "a"
type {
tensor_type {
elem_type: 2
shape {
dim {
dim_value: 4
}
dim {
dim_value: 3
}
}
}
}
}
input {
name: "b"
type {
tensor_type {
elem_type: 2
shape {
dim {
dim_value: 3
}
dim {
dim_value: 2
}
}
}
}
}
input {
name: "a_zero_point"
type {
tensor_type {
elem_type: 2
shape {
}
}
}
}
input {
name: "b_zero_point"
type {
tensor_type {
elem_type: 2
shape {
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 4
}
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
domain: ""
version: 10
}
ir_version: 5
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "a"
input: "b"
input: "a_zero_point"
input: "b_zero_point"
output: "y"
name: "node1"
op_type: "MatMulInteger"
doc_string: "MatMulInteger"
domain: ""
}
name: "test"
input {
name: "a"
type {
tensor_type {
elem_type: 2
shape {
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 3
}
dim {
dim_value: 4
}
}
}
}
}
input {
name: "b"
type {
tensor_type {
elem_type: 2
shape {
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 4
}
dim {
dim_value: 3
}
}
}
}
}
input {
name: "a_zero_point"
type {
tensor_type {
elem_type: 2
shape {
}
}
}
}
input {
name: "b_zero_point"
type {
tensor_type {
elem_type: 2
shape {
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 3
}
dim {
dim_value: 3
}
}
}
}
}
}
opset_import {
domain: ""
version: 10
}
ir_version: 5
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "a"
input: "b"
output: "y"
name: "node1"
op_type: "MatMulInteger"
doc_string: "MatMulInteger"
domain: ""
}
name: "test"
input {
name: "a"
type {
tensor_type {
elem_type: 2
shape {
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 3
}
dim {
dim_value: 4
}
}
}
}
}
input {
name: "b"
type {
tensor_type {
elem_type: 2
shape {
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 4
}
dim {
dim_value: 3
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 3
}
dim {
dim_value: 3
}
}
}
}
}
}
opset_import {
domain: ""
version: 10
}
ir_version: 5
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "a"
input: "b"
output: "y"
name: "node1"
op_type: "MatMulInteger"
doc_string: "MatMulInteger"
domain: ""
}
name: "test"
input {
name: "a"
type {
tensor_type {
elem_type: 2
shape {
dim {
dim_value: 4
}
dim {
dim_value: 3
}
}
}
}
}
input {
name: "b"
type {
tensor_type {
elem_type: 2
shape {
dim {
dim_value: 3
}
dim {
dim_value: 2
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 4
}
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
domain: ""
version: 10
}
ir_version: 5
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "a"
input: "b"
input: "a_zero_point"
input: "b_zero_point"
output: "y"
name: "node1"
op_type: "MatMulInteger"
doc_string: "MatMulInteger"
domain: ""
}
name: "test"
input {
name: "a"
type {
tensor_type {
elem_type: 2
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
}
}
}
}
input {
name: "b"
type {
tensor_type {
elem_type: 2
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
}
}
}
}
input {
name: "a_zero_point"
type {
tensor_type {
elem_type: 2
shape {
}
}
}
}
input {
name: "b_zero_point"
type {
tensor_type {
elem_type: 2
shape {
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
}
}
}
}
}
opset_import {
domain: ""
version: 10
}
......@@ -338,3 +338,167 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_qlinear_matmul_3d)
std::vector<uint8_t>{168, 115, 255, 1, 66, 151, 168, 115, 255, 1, 66, 151}); // T3
test_case.run();
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_matmul_integer)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/matmul_integer.prototxt"));
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
test_case.add_input(std::vector<uint8_t>{11, 7, 3, 10, 6, 2, 9, 5, 1, 8, 4, 0}); // a
test_case.add_input(std::vector<uint8_t>{1, 4, 2, 5, 3, 6}); // b
test_case.add_input(std::vector<uint8_t>{12}); // a_zero_point
test_case.add_input(std::vector<uint8_t>{0}); // b_zero_point
test_case.add_expected_output(
{4, 2}, std::vector<int32_t>{-38, -83, -44, -98, -50, -113, -56, -128}); // y
test_case.run();
}
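
For reference, the expected values follow from the MatMulInteger semantics (subtract each operand's zero point, then accumulate in int32). A standalone sketch reproducing the first output row of the test above from the same data:

#include <cstdint>
#include <iostream>

int main()
{
    // First row of `a` with a_zero_point = 12, and the columns of `b` with
    // b_zero_point = 0, as in the model_matmul_integer test above.
    const int32_t a_row[3] = {11, 7, 3};
    const int32_t b_col0[3] = {1, 2, 3};
    const int32_t b_col1[3] = {4, 5, 6};
    const int32_t a_zero = 12;
    const int32_t b_zero = 0;

    int32_t y00 = 0;
    int32_t y01 = 0;
    for (int i = 0; i < 3; ++i)
    {
        y00 += (a_row[i] - a_zero) * (b_col0[i] - b_zero);
        y01 += (a_row[i] - a_zero) * (b_col1[i] - b_zero);
    }
    std::cout << y00 << ' ' << y01 << '\n'; // prints: -38 -83
}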
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_matmul_integer_zero_point_zero)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/matmul_integer.prototxt"));
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
test_case.add_input(std::vector<uint8_t>{11, 7, 3, 10, 6, 2, 9, 5, 1, 8, 4, 0}); // a
test_case.add_input(std::vector<uint8_t>{1, 4, 2, 5, 3, 6}); // b
test_case.add_input(std::vector<uint8_t>{0}); // a_zero_point
test_case.add_input(std::vector<uint8_t>{0}); // b_zero_point
test_case.add_expected_output({4, 2},
std::vector<int32_t>{34, 97, 28, 82, 22, 67, 16, 52}); // y
test_case.run();
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_matmul_integer_no_zero_point)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/matmul_integer_no_zero_point.prototxt"));
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
test_case.add_input(std::vector<uint8_t>{11, 7, 3, 10, 6, 2, 9, 5, 1, 8, 4, 0}); // a
test_case.add_input(std::vector<uint8_t>{1, 4, 2, 5, 3, 6}); // b
test_case.add_expected_output({4, 2},
std::vector<int32_t>{34, 97, 28, 82, 22, 67, 16, 52}); // y
test_case.run();
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_matmul_integer_scalar)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/matmul_integer_scalar.prototxt"));
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
test_case.add_input(std::vector<uint8_t>{11}); // a
test_case.add_input(std::vector<uint8_t>{13}); // b
test_case.add_input(std::vector<uint8_t>{12}); // a_zero_point
test_case.add_input(std::vector<uint8_t>{12}); // b_zero_point
test_case.add_expected_output({1, 1}, std::vector<int32_t>{-1}); // y
test_case.run();
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_matmul_integer_4d)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/matmul_integer_4d.prototxt"));
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
test_case.add_input(std::vector<uint8_t>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}); // a
test_case.add_input(std::vector<uint8_t>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}); // b
test_case.add_input(std::vector<uint8_t>{0}); // a_zero_point
test_case.add_input(std::vector<uint8_t>{0}); // b_zero_point
test_case.add_expected_output<int32_t>(Shape{1, 2, 3, 3},
{42,
48,
54,
114,
136,
158,
186,
224,
262,
906,
960,
1014,
1170,
1240,
1310,
1434,
1520,
1606}); // y
test_case.run();
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_matmul_integer_4d_zero_point)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/matmul_integer_4d.prototxt"));
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
test_case.add_input(std::vector<uint8_t>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}); // a
test_case.add_input(std::vector<uint8_t>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}); // b
test_case.add_input(std::vector<uint8_t>{1}); // a_zero_point
test_case.add_input(std::vector<uint8_t>{1}); // b_zero_point
test_case.add_expected_output<int32_t>(Shape{1, 2, 3, 3},
{22,
24,
26,
78,
96,
114,
134,
168,
202,
790,
840,
890,
1038,
1104,
1170,
1286,
1368,
1450}); // y
test_case.run();
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_matmul_integer_4d_no_zero_point)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/matmul_integer_4d_no_zero_point.prototxt"));
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
test_case.add_input(std::vector<uint8_t>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}); // a
test_case.add_input(std::vector<uint8_t>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}); // b
test_case.add_expected_output<int32_t>(Shape{1, 2, 3, 3},
{42,
48,
54,
114,
136,
158,
186,
224,
262,
906,
960,
1014,
1170,
1240,
1310,
1434,
1520,
1606}); // y
test_case.run();
}