Commit 6a0101a2 authored by tsocha, committed by Scott Cyphers

[ONNX] Add Quantized MatMul OP (#2715)

* Add Quantized MatMul OP

* Rename qmatmul -> quantized_matmul

* Use auto everywhere
parent f70353ab
@@ -129,6 +129,7 @@ add_library(onnx_import STATIC
op/quant_conv.hpp
op/quantize_linear.cpp
op/quantize_linear.hpp
op/quantized_matmul.hpp
op/reciprocal.cpp
op/reciprocal.hpp
op/reduce.cpp
@@ -25,6 +25,7 @@
#include "ngraph/log.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/experimental/quantized_dot.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/shape.hpp"
@@ -68,11 +69,26 @@ namespace ngraph
{
namespace set_1
{
NodeVector matmul(const Node& node)
NodeVector make_matmul_op(const Node& node, bool quantized)
{
const NodeVector& ng_inputs{node.get_ng_inputs()};
auto left = std::shared_ptr<ngraph::Node>{ng_inputs.at(0)};
auto right = std::shared_ptr<ngraph::Node>{ng_inputs.at(1)};
auto right = std::shared_ptr<ngraph::Node>{};
auto scale = std::shared_ptr<ngraph::Node>{};
if (quantized)
{
NGRAPH_WARN
<< "[" << node.get_name()
<< "] Zero point different from 0 is not supported. Assuming Zero "
"point is 0";
right = ng_inputs.at(3);
scale = ng_inputs.at(6);
}
else
{
right = ng_inputs.at(1);
}
std::size_t left_rank{left->get_shape().size()};
std::size_t right_rank{right->get_shape().size()};
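For context on the at(0), at(3) and at(6) lookups above: ONNX QLinearMatMul carries eight inputs, and this importer uses only the two quantized operands and the output scale, assuming every zero point is 0 (hence the warning). The enum below is purely illustrative and not part of the commit:

// Illustrative only: the ONNX QLinearMatMul input order that make_matmul_op relies on.
// Zero-point inputs (indices 2, 5 and 7) are ignored; the importer assumes they are 0.
enum QLinearMatMulInput
{
    A = 0,            // quantized left operand  -> ng_inputs.at(0)
    A_SCALE = 1,
    A_ZERO_POINT = 2,
    B = 3,            // quantized right operand -> ng_inputs.at(3)
    B_SCALE = 4,
    B_ZERO_POINT = 5,
    Y_SCALE = 6,      // output scale handed to QuantizedDot -> ng_inputs.at(6)
    Y_ZERO_POINT = 7
};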
@@ -89,8 +105,19 @@ namespace ngraph
// Multiply two tensors where both of them have rank lower than or equal to 2.
if (left_rank <= 2 && right_rank <= 2)
{
return {
std::make_shared<ngraph::op::Dot>(ng_inputs.at(0), ng_inputs.at(1))};
if (quantized)
{
right = std::make_shared<ngraph::op::Reshape>(
right,
AxisVector{1, 0},
Shape(right->get_shape().rbegin(), right->get_shape().rend()));
return {std::make_shared<ngraph::op::QuantizedDot>(left, right, scale)};
}
else
{
return {std::make_shared<ngraph::op::Dot>(left, right)};
}
}
// Second case:
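The Reshape with AxisVector{1, 0} and a reversed output Shape in the quantized branch above is simply a 2-D transpose of the right-hand operand, presumably because the QuantizedDot constructor used here expects its second argument transposed (the commit does not state the reason). A minimal standalone sketch of that transpose; the helper name is invented and the snippet is illustrative, not part of the commit:

#include <memory>

#include "ngraph/axis_vector.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/shape.hpp"

// Illustrative helper (hypothetical name): transpose a rank-2 nGraph node,
// mirroring the two Reshape calls in the quantized branches of make_matmul_op.
static std::shared_ptr<ngraph::Node> transpose_2d(const std::shared_ptr<ngraph::Node>& node)
{
    const ngraph::Shape& shape = node->get_shape();
    return std::make_shared<ngraph::op::Reshape>(
        node,
        ngraph::AxisVector{1, 0},                     // swap the two axes
        ngraph::Shape(shape.rbegin(), shape.rend())); // reversed output shape
}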
@@ -132,9 +159,24 @@ namespace ngraph
for (std::size_t g = 0; g < groups; ++g)
{
const auto& sliced_left = get_sub_matrix(left, g);
const auto& sliced_right = get_sub_matrix(right, g);
auto sub_dot = std::make_shared<ngraph::op::Dot>(sliced_left, sliced_right);
auto sliced_right = get_sub_matrix(right, g);
auto sub_dot = std::shared_ptr<ngraph::Node>{};
if (quantized)
{
sliced_right = std::make_shared<ngraph::op::Reshape>(
sliced_right,
AxisVector{1, 0},
Shape(sliced_right->get_shape().rbegin(),
sliced_right->get_shape().rend()));
sub_dot = std::make_shared<ngraph::op::QuantizedDot>(
sliced_left, sliced_right, scale);
}
else
{
sub_dot = std::make_shared<ngraph::op::Dot>(sliced_left, sliced_right);
}
// Expand the sub_dot result with a single outermost axis of length 1, so the
// sub_dots can later be concatenated along this axis.
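The expand-and-concatenate step that the comment above describes lies outside the visible hunk. As a rough sketch of that pattern, assuming the per-group results are collected into a NodeVector and joined along the new axis after the loop (names and details are illustrative, not taken from matmul.cpp):

#include <memory>
#include <numeric>

#include "ngraph/axis_vector.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/shape.hpp"

// Sketch only: prepend an outermost axis of length 1 to a per-group result so
// that all results can later be concatenated along axis 0.
static std::shared_ptr<ngraph::Node>
    expand_outermost_axis(const std::shared_ptr<ngraph::Node>& sub_dot)
{
    ngraph::Shape expanded_shape(sub_dot->get_shape());
    expanded_shape.insert(expanded_shape.begin(), 1);

    // Identity input order (0, 1, ..., rank-1) required by Reshape.
    ngraph::AxisVector order(sub_dot->get_shape().size());
    std::iota(order.begin(), order.end(), 0);

    return std::make_shared<ngraph::op::Reshape>(sub_dot, order, expanded_shape);
}

// The expanded per-group nodes would then be joined with something like
// std::make_shared<ngraph::op::Concat>(small_dots, 0), where small_dots is a
// NodeVector accumulated inside the loop (name hypothetical).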
@@ -162,6 +204,7 @@ namespace ngraph
}
}
NodeVector matmul(const Node& node) { return make_matmul_op(node, false); }
} // namespace set_1
} //namespace op
@@ -27,6 +27,7 @@ namespace ngraph
{
namespace set_1
{
NodeVector make_matmul_op(const Node& node, bool quantized);
NodeVector matmul(const Node& node);
} // namespace set_1
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "core/node.hpp"
#include "ngraph/frontend/onnx_import/op/matmul.hpp"
#include "ngraph/node_vector.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector quantized_matmul(const Node& node) { return make_matmul_op(node, true); }
} // namespace set_1
} //namespace op
} // namespace onnx_import
} // namespace ngraph
@@ -82,6 +82,7 @@
#include "op/prelu.hpp"
#include "op/quant_conv.hpp"
#include "op/quantize_linear.hpp"
#include "op/quantized_matmul.hpp"
#include "op/reciprocal.hpp"
#include "op/reduce.hpp"
#include "op/relu.hpp"
@@ -283,6 +284,7 @@ namespace ngraph
REGISTER_OPERATOR("Pow", 1, pow);
REGISTER_OPERATOR("PRelu", 1, prelu);
REGISTER_OPERATOR("QLinearConv", 1, quant_conv);
REGISTER_OPERATOR("QLinearMatMul", 1, quantized_matmul);
REGISTER_OPERATOR("QuantizeLinear", 1, quantize_linear);
REGISTER_OPERATOR("Reciprocal", 1, reciprocal);
REGISTER_OPERATOR("ReduceLogSum", 1, reduce_log_sum);