Commit 859a0fed authored by Nishant Patel, committed by Scott Cyphers

Generic Dot & builders for QLinearMatmul (#2857)

* Add builders for QLinearMatmul for onnxruntime

* Generic Dot

* Change the onnx bridge and address PR feedback

* Fix date

* Fix CI failure

* Change variable filter to input1

* const & reference

* update branch

* Comment

* Introduce QuantizedMatmul

* change pattern

* QDot tests passing

* style

* PR feedback

* Fix pattern

* style
parent 127508ee
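In short, this commit introduces a generic QuantizedLinearMatmul builder whose arguments follow the eight-input ONNX QLinearMatMul operator (a, a_scale, a_zero_point, b, b_scale, b_zero_point, y_scale, y_zero_point). A minimal usage sketch, with illustrative names that are not taken from the diff, matching the header declared below:

// Sketch only: argument order follows the builder header added in this commit.
auto y = ngraph::builder::quantization::QuantizedLinearMatmul(
    a, b, a_scale, a_zero_point, b_scale, b_zero_point, y_scale, y_zero_point);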
......@@ -33,8 +33,8 @@ set (SRC
builder/quantization.hpp
builder/quantization/quantized_linear_convolution.cpp
builder/quantization/quantized_linear_convolution.hpp
builder/quantization/quantized_linear_dot.cpp
builder/quantization/quantized_linear_dot.hpp
builder/quantization/quantized_linear_matmul.cpp
builder/quantization/quantized_linear_matmul.hpp
builder/quantization_util.hpp
builder/reduce_ops.cpp
builder/reduce_ops.hpp
......
//*****************************************************************************
// Copyright 2018-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/builder/quantization/quantized_linear_matmul.hpp"
#include "ngraph/axis_set.hpp"
#include "ngraph/builder/make_constant.hpp"
#include "ngraph/builder/quantization.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/experimental/quantized_dot.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/quantize.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/type/element_type.hpp"
using namespace std;
using namespace ngraph;
namespace ngraph
{
namespace builder
{
namespace quantization
{
// TODO: this code is falling back to fp32 dot
// 1) add support in reference kernel for zero point
shared_ptr<Node> QuantizedLinearMatmul(const shared_ptr<Node>& input0,
const shared_ptr<Node>& input1,
const shared_ptr<Node>& input0_scale,
const shared_ptr<Node>& input0_zero_point,
const shared_ptr<Node>& input1_scale,
const shared_ptr<Node>& input1_zero_point,
const shared_ptr<Node>& output_scale,
const shared_ptr<Node>& output_zero_point)
{
auto input0_zero = dynamic_pointer_cast<ngraph::op::Constant>(input0_zero_point);
auto input1_zero = dynamic_pointer_cast<ngraph::op::Constant>(input1_zero_point);
auto output_zero = dynamic_pointer_cast<ngraph::op::Constant>(output_zero_point);
// Check if zero point is constant and zero
if (input0_zero != nullptr && input1_zero != nullptr && output_zero != nullptr &&
ngraph::is_zero(input0_zero) && ngraph::is_zero(input1_zero) &&
ngraph::is_zero(output_zero))
{
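// With all zero points constant and zero, real = scale * q, so the quantized output is
// y_q = (input0_scale * input1_scale / output_scale) * dot(q0, q1); the scale ratio below
// is handed to QuantizedDot as its requantization scale.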
auto requantization_scale = (input0_scale * input1_scale) / output_scale;
return make_shared<op::QuantizedDot>(input0, input1, requantization_scale);
}
else
{
AxisSet axes;
auto dq_input0 = make_shared<op::Dequantize>(input0,
input0_scale,
input0_zero_point,
input0_scale->get_element_type(),
axes);
auto dq_input1 = make_shared<op::Dequantize>(input1,
input1_scale,
input1_zero_point,
input1_scale->get_element_type(),
axes);
auto dot = make_shared<op::Dot>(dq_input0, dq_input1, 1);
auto q_dot = make_shared<op::Quantize>(
dot,
output_scale,
output_zero_point,
output_zero_point->get_element_type(),
axes,
op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN);
return q_dot;
}
}
shared_ptr<Node> QuantizedLinearMatmulInteger(const shared_ptr<Node>& input0,
const shared_ptr<Node>& input1)
{
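// MatMulInteger-style integer matmul: the output stays int32 (requantize = false),
// so a unit output scale is passed through.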
auto output_scale = make_constant(element::f32, Shape{}, 1);
return make_shared<op::QuantizedDot>(input0, input1, output_scale, false, false);
}
}
}
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
// Copyright 2018-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -25,8 +25,18 @@ namespace ngraph
{
namespace quantization
{
std::shared_ptr<Node> QuantizedDotInteger(std::shared_ptr<Node> input,
std::shared_ptr<Node> filter);
std::shared_ptr<Node>
QuantizedLinearMatmul(const std::shared_ptr<Node>& input0,
const std::shared_ptr<Node>& input1,
const std::shared_ptr<Node>& input0_scale,
const std::shared_ptr<Node>& input0_zero_point,
const std::shared_ptr<Node>& input1_scale,
const std::shared_ptr<Node>& input1_zero_point,
const std::shared_ptr<Node>& output_scale,
const std::shared_ptr<Node>& output_zero_point);
std::shared_ptr<Node> QuantizedLinearMatmulInteger(const std::shared_ptr<Node>& input0,
const std::shared_ptr<Node>& input1);
}
}
}
......@@ -21,11 +21,11 @@
#include "exceptions.hpp"
#include "matmul.hpp"
#include "ngraph/builder/quantization/quantized_linear_matmul.hpp"
#include "ngraph/coordinate.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/experimental/quantized_dot.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/util/broadcasting.hpp"
......@@ -79,10 +79,6 @@ namespace ngraph
auto scale = std::shared_ptr<ngraph::Node>{};
if (quantized)
{
NGRAPH_WARN
<< "[" << node.get_name()
<< "] Zero point different from 0 is not supported. Assuming Zero "
"point is 0";
right = ng_inputs.at(3);
scale = ng_inputs.at(6);
}
......@@ -109,12 +105,15 @@ namespace ngraph
{
if (quantized)
{
right = std::make_shared<ngraph::op::Reshape>(
return {ngraph::builder::quantization::QuantizedLinearMatmul(
left,
right,
AxisVector{1, 0},
Shape(right->get_shape().rbegin(), right->get_shape().rend()));
return {std::make_shared<ngraph::op::QuantizedDot>(left, right, scale)};
ng_inputs.at(1),
ng_inputs.at(2),
ng_inputs.at(4),
ng_inputs.at(5),
ng_inputs.at(6),
ng_inputs.at(7))};
}
else
{
......@@ -167,13 +166,15 @@ namespace ngraph
if (quantized)
{
sliced_right = std::make_shared<ngraph::op::Reshape>(
sub_dot = ngraph::builder::quantization::QuantizedLinearMatmul(
sliced_left,
sliced_right,
AxisVector{1, 0},
Shape(sliced_right->get_shape().rbegin(),
sliced_right->get_shape().rend()));
sub_dot = std::make_shared<ngraph::op::QuantizedDot>(
sliced_left, sliced_right, scale);
ng_inputs.at(1),
ng_inputs.at(2),
ng_inputs.at(4),
ng_inputs.at(5),
ng_inputs.at(6),
ng_inputs.at(7));
}
else
{
......
......@@ -37,15 +37,19 @@ op::QuantizedDot::QuantizedDot(const shared_ptr<Node>& data,
auto& data_shape = data->get_shape();
auto& weights_shape = weights->get_shape();
// QuantizedDot does [n, ic] * [oc, ic] = [n, oc]
// QuantizedDot does [m, n] * [n, k] = [m, k]
NODE_VALIDATION_CHECK(this,
data_shape.size() == 2 && weights_shape.size() == 2 &&
data_shape[1] == weights_shape[1],
data_shape[1] == weights_shape[0],
"only valid tensors of rank 2 supported. data shape ",
data_shape,
" weights shape ",
weights_shape);
auto output_et = requantize ? (with_relu ? element::u8 : element::i8) : element::i32;
set_output_type(0, output_et, Shape{data_shape[0], weights_shape[0]});
if (data->get_element_type() == element::u8 && weights->get_element_type() == element::u8)
{
output_et = element::u8;
}
set_output_type(0, output_et, Shape{data_shape[0], weights_shape[1]});
}
......@@ -68,6 +68,7 @@ set(SRC
builder/quantized_concat.cpp
builder/quantized_dot.cpp
builder/quantized_max_pool.cpp
builder/quantized_matmul.cpp
builder/reshape.cpp
builder/reverse.cpp
builder/reverse_sequence.cpp
......@@ -104,6 +105,7 @@ set(SRC
op/lstm.cpp
op/matmul_bias.cpp
op/max_pool_with_indices.cpp
op/quantized_matmul.cpp
op/rnn.cpp
op/sigmoid_mul.cpp
op/update_slice.cpp
......
......@@ -230,9 +230,10 @@ namespace ngraph
return;
}
std::function<decltype(runtime::cpu::kernel::dot_ref<float>)> kernel;
std::function<decltype(runtime::cpu::kernel::dot_ref<float, float, float>)> kernel;
SELECT_KERNEL(kernel, out[0].get_element_type(), runtime::cpu::kernel::dot_ref);
SELECT_KERNEL_3ARGS(
kernel, out[0].get_element_type(), runtime::cpu::kernel::dot_ref);
auto functor = [&,
kernel,
......@@ -250,7 +251,8 @@ namespace ngraph
arg0_shape,
arg1_shape,
result_shape,
reduction_axes_count);
reduction_axes_count,
1.0f); // Requantization scale (1 for non-quantized dot)
};
functors.emplace_back(functor);
}
......
......@@ -19,6 +19,7 @@
#include "ngraph/op/experimental/quantized_dot_bias.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/cpu/cpu_executor.hpp"
#include "ngraph/runtime/cpu/kernel/dot.hpp"
#include "ngraph/runtime/cpu/mkldnn_invoke.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
......@@ -114,70 +115,51 @@ namespace ngraph
template <>
void Builder::BUILDER_DECL(ngraph::op::QuantizedDot)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
auto& functors = external_function->get_functors();
auto arg0_buffer_index =
external_function->get_buffer_index(args[0].get_name());
auto arg1_buffer_index =
external_function->get_buffer_index(args[1].get_name());
auto arg2_buffer_index =
external_function->get_buffer_index(args[2].get_name());
auto out0_buffer_index = external_function->get_buffer_index(out[0].get_name());
auto& functors = external_function->get_functors();
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto scales_size = shape_size(args[2].get_shape());
auto arg0_shape = args[0].get_shape();
auto arg1_shape = args[1].get_shape();
auto result_shape = out[0].get_shape();
auto ip_desc =
mkldnn_emitter->get_inner_product_forward_desc<ngraph::op::QuantizedDot>(
node);
auto ip_attr =
mkldnn_emitter->get_inner_product_forward_attr<ngraph::op::QuantizedDot>(
node);
size_t ip_index = mkldnn_emitter->inner_product_forward_init(false);
auto& deps = mkldnn_emitter->get_primitive_deps(ip_index);
auto arg0_buffer_index = external_function->get_buffer_index(args[0].get_name());
auto arg1_buffer_index = external_function->get_buffer_index(args[1].get_name());
auto arg2_buffer_index = external_function->get_buffer_index(args[2].get_name());
auto out0_buffer_index = external_function->get_buffer_index(out[0].get_name());
auto functor = [&,
scales_size,
ip_desc,
ip_attr,
deps,
ip_index,
arg0_buffer_index,
arg1_buffer_index,
arg2_buffer_index,
out0_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) mutable {
if (ctx->first_iteration)
{
vector<float> dyn_scales;
dyn_scales.assign(
static_cast<float*>(ctx->buffer_data[arg2_buffer_index]),
static_cast<float*>(ctx->buffer_data[arg2_buffer_index]) +
scales_size);
ip_attr.set_output_scales(0, dyn_scales);
mkldnn_emitter->build_inner_product_forward<false>(
ctx->mkldnn_primitives,
ip_desc,
ip_attr,
executor::global_cpu_engine,
deps,
ip_index);
}
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[0], ctx->buffer_data[arg0_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[1], ctx->buffer_data[arg1_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[2], ctx->buffer_data[out0_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, ip_index);
};
functors.emplace_back(functor);
}
else
if (shape_size(args[2].get_shape()) != 1)
{
throw ngraph_error("unsupported parameters for QuantizedDot via DEX");
throw ngraph_error("Scale size should be 1 for QuantizedDot");
}
std::function<decltype(
runtime::cpu::kernel::dot_ref<uint8_t, uint8_t, uint8_t, int32_t>)>
kernel;
kernel = runtime::cpu::kernel::dot_ref<uint8_t, uint8_t, uint8_t, int32_t>;
auto functor = [&,
kernel,
arg0_shape,
arg1_shape,
result_shape,
arg0_buffer_index,
arg1_buffer_index,
arg2_buffer_index,
out0_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
float dyn_scale = *(static_cast<float*>(ctx->buffer_data[arg2_buffer_index]));
kernel(ctx->buffer_data[arg0_buffer_index],
ctx->buffer_data[arg1_buffer_index],
ctx->buffer_data[out0_buffer_index],
arg0_shape,
arg1_shape,
result_shape,
1,
dyn_scale);
};
functors.emplace_back(functor);
}
REGISTER_OP_BUILDER(QuantizedDotBias);
REGISTER_OP_BUILDER(QuantizedDot);
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/runtime/cpu/op/quantized_matmul.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/cpu/cpu_executor.hpp"
#include "ngraph/runtime/cpu/kernel/dot.hpp"
#include "ngraph/runtime/cpu/mkldnn_invoke.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
using namespace std;
using namespace ngraph;
namespace ngraph
{
namespace runtime
{
namespace cpu
{
template <>
void Builder::BUILDER_DECL(ngraph::op::QuantizedMatmul)
{
auto& functors = external_function->get_functors();
auto arg0_shape = args[0].get_shape();
auto arg1_shape = args[1].get_shape();
auto result_shape = out[0].get_shape();
auto arg0_buffer_index = external_function->get_buffer_index(args[0].get_name());
auto arg1_buffer_index = external_function->get_buffer_index(args[1].get_name());
auto arg2_buffer_index = external_function->get_buffer_index(args[2].get_name());
auto out0_buffer_index = external_function->get_buffer_index(out[0].get_name());
auto scales_size = shape_size(args[2].get_shape());
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto ip_desc =
mkldnn_emitter->get_inner_product_forward_desc<ngraph::op::QuantizedMatmul>(
node);
auto ip_attr =
mkldnn_emitter->get_inner_product_forward_attr<ngraph::op::QuantizedMatmul>(
node);
size_t ip_index = mkldnn_emitter->inner_product_forward_init(false);
auto& deps = mkldnn_emitter->get_primitive_deps(ip_index);
auto functor = [&,
scales_size,
ip_desc,
ip_attr,
deps,
ip_index,
arg0_buffer_index,
arg1_buffer_index,
arg2_buffer_index,
out0_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) mutable {
if (ctx->first_iteration)
{
vector<float> dyn_scales;
dyn_scales.assign(
static_cast<float*>(ctx->buffer_data[arg2_buffer_index]),
static_cast<float*>(ctx->buffer_data[arg2_buffer_index]) +
scales_size);
ip_attr.set_output_scales(0, dyn_scales);
mkldnn_emitter->build_inner_product_forward<false>(
ctx->mkldnn_primitives,
ip_desc,
ip_attr,
executor::global_cpu_engine,
deps,
ip_index);
}
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[0], ctx->buffer_data[arg0_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[1], ctx->buffer_data[arg1_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[2], ctx->buffer_data[out0_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, ip_index);
};
functors.emplace_back(functor);
}
else
{
throw ngraph_error("Unsupported QuantizedMatmul");
}
}
REGISTER_OP_BUILDER(QuantizedMatmul);
}
}
}
......@@ -172,6 +172,7 @@
#include "ngraph/runtime/cpu/op/lstm.hpp"
#include "ngraph/runtime/cpu/op/matmul_bias.hpp"
#include "ngraph/runtime/cpu/op/max_pool_with_indices.hpp"
#include "ngraph/runtime/cpu/op/quantized_matmul.hpp"
#include "ngraph/runtime/cpu/op/rnn.hpp"
#include "ngraph/runtime/cpu/op/sigmoid.hpp"
#include "ngraph/runtime/cpu/op/sigmoid_mul.hpp"
......@@ -372,6 +373,7 @@ static const runtime::cpu::OpMap dispatcher{
&runtime::cpu::CPU_Emitter::emit<op::QuantizedConvolutionBiasSignedAdd>},
{TI(ngraph::op::QuantizedDotBias), &runtime::cpu::CPU_Emitter::emit<op::QuantizedDotBias>},
{TI(ngraph::op::QuantizedDot), &runtime::cpu::CPU_Emitter::emit<op::QuantizedDot>},
{TI(ngraph::op::QuantizedMatmul), &runtime::cpu::CPU_Emitter::emit<op::QuantizedMatmul>},
{TI(ngraph::op::ConvolutionRelu), &runtime::cpu::CPU_Emitter::emit<op::ConvolutionRelu>},
{TI(ngraph::op::QuantizedConvolution),
&runtime::cpu::CPU_Emitter::emit<op::QuantizedConvolution>},
......
......@@ -20,6 +20,7 @@
#include <unsupported/Eigen/CXX11/Tensor>
#include "ngraph/runtime/cpu/cpu_executor.hpp"
#include "ngraph/runtime/reference/convolution.hpp"
#include "ngraph/runtime/reference/dot.hpp"
#include "ngraph/shape.hpp"
......@@ -167,22 +168,28 @@ namespace ngraph
input0, input1, output, input0_shape, input1_shape, output_shape, arena);
}
template <typename ElementType>
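// Generic form: input, output, and accumulation types are independent. ACCUMULATION
// defaults to a widened form of OUTPUT so the sum does not overflow (the quantized
// builder instantiates dot_ref<uint8_t, uint8_t, uint8_t, int32_t>), and the
// accumulated sum is scaled by requant_scale before being written out.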
template <typename INPUT0,
typename INPUT1,
typename OUTPUT,
typename ACCUMULATION =
typename ngraph::runtime::reference::widen<OUTPUT>::type>
void dot_ref(void* arg0,
void* arg1,
void* out,
const Shape& arg0_shape,
const Shape& arg1_shape,
const Shape& out_shape,
size_t reduction_axes_count)
size_t reduction_axes_count,
const float requant_scale)
{
reference::dot(static_cast<const ElementType*>(arg0),
static_cast<const ElementType*>(arg1),
static_cast<ElementType*>(out),
reference::dot(static_cast<const INPUT0*>(arg0),
static_cast<const INPUT1*>(arg1),
static_cast<OUTPUT*>(out),
arg0_shape,
arg1_shape,
out_shape,
reduction_axes_count);
reduction_axes_count,
requant_scale);
}
}
}
......
......@@ -54,6 +54,7 @@
#include "ngraph/runtime/cpu/op/deconv.hpp"
#include "ngraph/runtime/cpu/op/group_conv_bias.hpp"
#include "ngraph/runtime/cpu/op/leaky_relu.hpp"
#include "ngraph/runtime/cpu/op/quantized_matmul.hpp"
#include "ngraph/runtime/cpu/op/rnn_utils.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/strides.hpp"
......@@ -835,7 +836,7 @@ namespace ngraph
{
size_t index = 0;
if (std::is_same<OP, ngraph::op::QuantizedConvolution>() ||
std::is_same<OP, ngraph::op::QuantizedDot>() ||
std::is_same<OP, ngraph::op::QuantizedMatmul>() ||
std::is_same<OP, ngraph::op::QuantizedConvolutionRelu>())
{
index = 2;
......@@ -909,7 +910,7 @@ namespace ngraph
template <typename OP>
bool is_quantized_inner_product()
{
if (std::is_same<OP, ngraph::op::QuantizedDot>() ||
if (std::is_same<OP, ngraph::op::QuantizedMatmul>() ||
std::is_same<OP, ngraph::op::QuantizedDotBias>())
{
return true;
......
//*****************************************************************************
// Copyright 2018-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <functional>
#include <memory>
#include <utility>
#include "ngraph/shape.hpp"
#include "quantized_matmul.hpp"
using namespace std;
using namespace ngraph;
op::QuantizedMatmul::QuantizedMatmul(const shared_ptr<Node>& data,
const shared_ptr<Node>& weights,
const shared_ptr<Node>& scale,
bool requantize,
bool with_relu)
: Op("QuantizedMatmul", check_single_output_args({data, weights, scale}))
, m_requantize(requantize)
, m_with_relu(with_relu)
{
constructor_validate_and_infer_types();
auto& data_shape = data->get_shape();
auto& weights_shape = weights->get_shape();
// QuantizedMatmul does [n, ic] * [oc, ic] = [n, oc]
NODE_VALIDATION_CHECK(this,
data_shape.size() == 2 && weights_shape.size() == 2 &&
data_shape[1] == weights_shape[1],
"only valid tensors of rank 2 supported. data shape ",
data_shape,
" weights shape ",
weights_shape);
auto output_et = requantize ? (with_relu ? element::u8 : element::i8) : element::i32;
set_output_type(0, output_et, Shape{data_shape[0], weights_shape[0]});
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
// Copyright 2018-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -14,25 +14,37 @@
// limitations under the License.
//*****************************************************************************
#include "ngraph/builder/quantization/quantized_linear_dot.hpp"
#include "ngraph/builder/make_constant.hpp"
#include "ngraph/builder/quantization.hpp"
#include "ngraph/op/experimental/quantized_dot.hpp"
#pragma once
using namespace std;
using namespace ngraph;
#include <utility>
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace builder
namespace op
{
namespace quantization
class QuantizedMatmul : public Op
{
shared_ptr<Node> QuantizedDotInteger(shared_ptr<Node> input, shared_ptr<Node> filter)
public:
QuantizedMatmul(const std::shared_ptr<Node>& data,
const std::shared_ptr<Node>& weights,
const std::shared_ptr<Node>& scale,
bool requantize = true,
bool with_relu = false);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override
{
auto output_scale = make_constant(element::f32, Shape{}, 1);
return make_shared<op::QuantizedDot>(input, filter, output_scale, false, false);
check_new_args_count(this, new_args);
return std::make_shared<QuantizedMatmul>(
new_args.at(0), new_args.at(1), new_args.at(2), m_requantize, m_with_relu);
}
}
}
}
bool with_relu() const { return m_with_relu; }
bool requantize() const { return m_requantize; }
protected:
bool m_requantize;
bool m_with_relu;
};
} // namespace op
} // namespace ngraph
......@@ -37,7 +37,6 @@
#include "ngraph/op/experimental/quantized_conv.hpp"
#include "ngraph/op/experimental/quantized_conv_bias.hpp"
#include "ngraph/op/experimental/quantized_conv_relu.hpp"
#include "ngraph/op/experimental/quantized_dot.hpp"
#include "ngraph/op/experimental/quantized_dot_bias.hpp"
#include "ngraph/op/experimental/quantized_max_pool.hpp"
#include "ngraph/op/fused/conv_fused.hpp"
......@@ -61,6 +60,7 @@
#include "ngraph/runtime/cpu/op/leaky_relu.hpp"
#include "ngraph/runtime/cpu/op/lstm.hpp"
#include "ngraph/runtime/cpu/op/max_pool_with_indices.hpp"
#include "ngraph/runtime/cpu/op/quantized_matmul.hpp"
#include "ngraph/runtime/cpu/op/rnn.hpp"
#include "ngraph/runtime/cpu/op/sigmoid.hpp"
#include "ngraph/runtime/cpu/op/update_slice.hpp"
......@@ -790,7 +790,7 @@ namespace ngraph
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::QuantizedDot)
void CPUAssignment::ASSIGN_DECL(ngraph::op::QuantizedMatmul)
{
if (node->get_input_element_type(0) == element::u8 &&
node->get_input_element_type(1) == element::i8)
......@@ -989,8 +989,8 @@ static const runtime::cpu::pass::AssignOpMap s_dispatcher{
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Dequantize>},
{TI(ngraph::op::QuantizedConcat),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::QuantizedConcat>},
{TI(ngraph::op::QuantizedDot),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::QuantizedDot>},
{TI(ngraph::op::QuantizedMatmul),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::QuantizedMatmul>},
{TI(ngraph::op::QuantizedDotBias),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::QuantizedDotBias>},
{TI(ngraph::op::GetOutputElement),
......
......@@ -41,6 +41,7 @@
#include "ngraph/op/experimental/quantized_conv.hpp"
#include "ngraph/op/experimental/quantized_conv_bias.hpp"
#include "ngraph/op/experimental/quantized_conv_relu.hpp"
#include "ngraph/op/experimental/quantized_dot.hpp"
#include "ngraph/op/experimental/quantized_max_pool.hpp"
#include "ngraph/op/fused/conv_fused.hpp"
#include "ngraph/op/fused/group_conv.hpp"
......@@ -75,6 +76,7 @@
#include "ngraph/runtime/cpu/op/leaky_relu.hpp"
#include "ngraph/runtime/cpu/op/lstm.hpp"
#include "ngraph/runtime/cpu/op/matmul_bias.hpp"
#include "ngraph/runtime/cpu/op/quantized_matmul.hpp"
#include "ngraph/runtime/cpu/op/rnn_utils.hpp"
#include "ngraph/runtime/cpu/op/sigmoid_mul.hpp"
#include "ngraph/runtime/cpu/op/update_slice.hpp"
......@@ -2519,3 +2521,43 @@ void ngraph::runtime::cpu::pass::CPUQuantFusion::construct_qconvb_add()
auto m = std::make_shared<pattern::Matcher>(prelu, "CPUQuantFusion.QConvBiasSignedAdd");
this->add_matcher(m, callback);
}
// Convert a QuantizedDot which takes [m,n]*[n,k] to
// QuantizedMatmul which reorders input1 and does [m,n]*[k,n]
// which is what mkldnn wants
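// (mkldnn's inner product expects its weights as [output_channels, input_channels],
// i.e. [k, n] here, hence input1 is reshaped before building QuantizedMatmul.)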
void ngraph::runtime::cpu::pass::CPUQuantFusion::construct_quantized_matmul()
{
Shape shape_input0{2, 3};
Shape shape_input1{3, 4};
auto input0 = std::make_shared<pattern::op::Label>(element::u8, shape_input0);
auto input1 = std::make_shared<pattern::op::Label>(element::i8, shape_input1);
auto scale = std::make_shared<pattern::op::Label>(element::f32, Shape{});
auto q_dot = std::make_shared<ngraph::op::QuantizedDot>(input0, input1, scale);
auto callback = [input0, input1, scale](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for Qdot against node = " << m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto qdot = std::dynamic_pointer_cast<ngraph::op::QuantizedDot>(m.get_match_root());
auto input_0 = pattern_map[input0];
auto input_1 = pattern_map[input1];
auto scale_new = pattern_map[scale];
if (input_0->get_element_type() == element::u8 &&
input_1->get_element_type() == element::u8)
{
return false;
}
auto reshape_input1 = std::make_shared<op::Reshape>(
input_1, AxisVector{0, 1}, Shape{input_1->get_shape()[1], input_1->get_shape()[0]});
auto qmatmul = std::make_shared<ngraph::op::QuantizedMatmul>(
input_0, reshape_input1, scale_new, qdot->requantize(), qdot->with_relu());
ngraph::replace_node(m.get_match_root(), qmatmul);
return true;
};
auto m = std::make_shared<pattern::Matcher>(q_dot, "CPUQuantFusion.QDot");
this->add_matcher(m, callback);
}
......@@ -124,6 +124,7 @@ public:
construct_qconcat();
construct_qconvb_add();
construct_dq_q();
construct_quantized_matmul();
}
private:
......@@ -133,4 +134,5 @@ private:
void construct_qconcat();
void construct_dq_q();
void construct_qconvb_add();
void construct_quantized_matmul();
};
......@@ -39,7 +39,6 @@
#include "ngraph/op/experimental/quantized_conv.hpp"
#include "ngraph/op/experimental/quantized_conv_bias.hpp"
#include "ngraph/op/experimental/quantized_conv_relu.hpp"
#include "ngraph/op/experimental/quantized_dot.hpp"
#include "ngraph/op/experimental/quantized_dot_bias.hpp"
#include "ngraph/op/experimental/quantized_max_pool.hpp"
#include "ngraph/op/fused/conv_fused.hpp"
......@@ -69,6 +68,7 @@
#include "ngraph/runtime/cpu/op/leaky_relu.hpp"
#include "ngraph/runtime/cpu/op/lstm.hpp"
#include "ngraph/runtime/cpu/op/max_pool_with_indices.hpp"
#include "ngraph/runtime/cpu/op/quantized_matmul.hpp"
#include "ngraph/runtime/cpu/op/rnn.hpp"
using namespace std;
......@@ -738,13 +738,13 @@ namespace ngraph
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::QuantizedDot)
void CPULayout::LAYOUT_DECL(ngraph::op::QuantizedMatmul)
{
if (mkldnn_utils::use_mkldnn_kernel(node.get()))
{
vector<memory::desc> i_mds;
vector<memory::desc> o_mds;
InnerProductLayout<ngraph::op::QuantizedDot, false>(node, i_mds, o_mds);
InnerProductLayout<ngraph::op::QuantizedMatmul, false>(node, i_mds, o_mds);
auto scale_input_md = mkldnn_utils::create_default_mkldnn_md(
node.get(), 2, false, memory::format::x);
......@@ -2394,8 +2394,8 @@ static const runtime::cpu::pass::LayoutOpMap s_dispatcher{
&runtime::cpu::pass::CPULayout::layout<ngraph::op::QuantizedConcat>},
{TI(ngraph::op::QuantizedDotBias),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::QuantizedDotBias>},
{TI(ngraph::op::QuantizedDot),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::QuantizedDot>},
{TI(ngraph::op::QuantizedMatmul),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::QuantizedMatmul>},
};
bool runtime::cpu::pass::CPULayout::run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes)
......
......@@ -19,6 +19,9 @@
#include <cmath>
#include <utility>
#include <cfenv>
#include <functional>
#include "convolution.hpp"
#include "ngraph/coordinate_transform.hpp"
#include "ngraph/shape_util.hpp"
......@@ -28,15 +31,21 @@ namespace ngraph
{
namespace reference
{
template <typename T>
void dot(const T* arg0,
const T* arg1,
T* out,
template <typename INPUT0,
typename INPUT1,
typename OUTPUT,
typename ACCUMULATION = typename widen<OUTPUT>::type>
void dot(const INPUT0* arg0,
const INPUT1* arg1,
OUTPUT* out,
const Shape& arg0_shape,
const Shape& arg1_shape,
const Shape& out_shape,
size_t reduction_axes_count)
size_t reduction_axes_count,
const float requant_scale = 1.0f)
{
auto old_mode = std::fegetround();
std::fesetround(FE_TONEAREST);
// Get the sizes of the dot axes. It's easiest to pull them from arg1 because they're
// right up front.
Shape dot_axis_sizes(reduction_axes_count);
......@@ -84,7 +93,7 @@ namespace ngraph
arg1_projected_coord.begin(), arg1_projected_coord.end(), out_coord_it);
// Zero out to start the sum.
T sum = 0;
ACCUMULATION sum = 0;
size_t out_index = output_transform.index(out_coord);
......@@ -113,8 +122,9 @@ namespace ngraph
}
// Write the sum back.
out[out_index] = sum;
out[out_index] = static_cast<OUTPUT>(sum * requant_scale);
}
std::fesetround(old_mode);
}
}
}
......
......@@ -23,7 +23,7 @@
#include "gtest/gtest.h"
#include "ngraph/builder/quantization.hpp"
#include "ngraph/builder/quantization/quantized_linear_convolution.hpp"
#include "ngraph/builder/quantization/quantized_linear_dot.hpp"
#include "ngraph/builder/quantization/quantized_linear_matmul.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/pass/constant_folding.hpp"
......@@ -1221,13 +1221,13 @@ TEST(builder, scaled_QDotInteger)
{
Shape shape_a{1, 2}; // input shape
vector<uint8_t> a_data = {2, 3};
Shape shape_b{3, 2}; // filter shape
Shape shape_b{2, 3}; // filter shape
vector<int8_t> b_data = {0, 1, 2, 3, 4, 5};
auto A = make_shared<op::Parameter>(element::u8, shape_a);
auto B = make_shared<op::Parameter>(element::i8, shape_b);
Shape shape_r{1, 3}; // output shape
auto QD = ngraph::builder::quantization::QuantizedDotInteger(A, B);
auto QD = ngraph::builder::quantization::QuantizedLinearMatmulInteger(A, B);
auto f = make_shared<Function>(NodeVector{QD}, ParameterVector{A, B});
constant_fold(f);
auto backend = runtime::Backend::create("CPU");
......@@ -1469,3 +1469,41 @@ TEST(builder, scaled_QC_u8u8)
39 * 2} /*{1, 28, -3, 16, -7, -14, 3, -7, -3}*/),
read_vector<uint8_t>(result));
}
TEST(builder, scaled_QDot_u8u8)
{
Shape shape_a{1, 2}; // input shape
vector<uint8_t> a_data = {2, 3};
Shape shape_b{2, 3}; // filter shape
vector<uint8_t> b_data = {0, 2, 4, 1, 3, 5};
auto A = make_shared<op::Parameter>(element::u8, shape_a);
auto B = make_shared<op::Parameter>(element::u8, shape_b);
auto input_scale = op::Constant::create(element::f32, Shape{}, {2});
auto input_zero_point = op::Constant::create(element::u8, Shape{}, {0});
auto filter_scale = op::Constant::create(element::f32, Shape{}, {1});
auto filter_zero_point = op::Constant::create(element::u8, Shape{}, {0});
auto output_scale = op::Constant::create(element::f32, Shape{}, {2});
auto output_zero_point = op::Constant::create(element::u8, Shape{}, {0});
Shape shape_r{1, 3}; // output shape
auto QD = ngraph::builder::quantization::QuantizedLinearMatmul(A,
B,
input_scale,
input_zero_point,
filter_scale,
filter_zero_point,
output_scale,
output_zero_point);
auto f = make_shared<Function>(NodeVector{QD}, ParameterVector{A, B});
constant_fold(f);
auto backend = runtime::Backend::create("CPU");
// Create some tensors for input/output
auto a = backend->create_tensor(element::u8, shape_a);
copy_data(a, a_data);
auto b = backend->create_tensor(element::u8, shape_b);
copy_data(b, b_data);
auto result = backend->create_tensor(element::u8, shape_r);
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a, b});
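// Expected values: dot([2, 3], [[0, 2, 4], [1, 3, 5]]) = [3, 13, 23], and the
// requantization scale input_scale * filter_scale / output_scale = 2 * 1 / 2 = 1
// leaves the accumulations unchanged.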
EXPECT_EQ((vector<uint8_t>{3, 13, 23}), read_vector<uint8_t>(result));
}