Commit 1a6f8487 authored by Ashok Emani's avatar Ashok Emani Committed by Scott Cyphers

QuantizedDot and QuantizedDotBias ops CPU impl (#2592)

* QuantizedDot and QuantizedDotBias ops CPU impl

* add builders and unittests

* fix CI issue

* fix GPU emitter build

* Remove duplicate line.
parent 60ca608c
......@@ -140,6 +140,10 @@ set (SRC
op/experimental/quantized_max_pool.hpp
op/experimental/shape_of.cpp
op/experimental/shape_of.hpp
op/experimental/quantized_dot.cpp
op/experimental/quantized_dot.hpp
op/experimental/quantized_dot_bias.cpp
op/experimental/quantized_dot_bias.hpp
op/floor.cpp
op/floor.hpp
op/get_output_element.cpp
......
......@@ -399,5 +399,69 @@ namespace ngraph
with_relu);
return make_shared<op::Convert>(qconv, element::u8);
}
}
}
std::shared_ptr<Node> ScaledQuantizedDotBias(std::shared_ptr<Node> input,
std::shared_ptr<Node> filters,
std::shared_ptr<Node> bias,
std::shared_ptr<Node> min_input,
std::shared_ptr<Node> max_input,
std::shared_ptr<Node> min_filter,
std::shared_ptr<Node> max_filter,
std::shared_ptr<Node> min_freezed_output,
std::shared_ptr<Node> max_freezed_output,
const bool requantize,
const bool with_relu)
{
auto requantization_scale =
quantization_util::get_dot_scale(min_input,
max_input,
min_filter,
max_filter,
min_freezed_output,
max_freezed_output,
input->get_element_type(),
with_relu ? element::u8 : element::i8,
requantize);
if (bias->get_element_type() != element::i32)
{
auto zero = make_constant(element::i32, min_input->get_shape(), 0);
AxisSet quantization_axes;
auto bias_scale =
quantization_util::get_bias_scale(min_input, max_input, min_filter, max_filter);
op::Quantize::RoundMode round_mode =
op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN;
bias = make_shared<op::Quantize>(
bias, bias_scale, zero, element::i32, quantization_axes, round_mode);
}
return make_shared<op::QuantizedDotBias>(
input, filters, bias, requantization_scale, requantize, with_relu);
}
std::shared_ptr<Node> ScaledQuantizedDot(std::shared_ptr<Node> input,
std::shared_ptr<Node> filters,
std::shared_ptr<Node> min_input,
std::shared_ptr<Node> max_input,
std::shared_ptr<Node> min_filter,
std::shared_ptr<Node> max_filter,
std::shared_ptr<Node> min_freezed_output,
std::shared_ptr<Node> max_freezed_output,
const bool requantize,
const bool with_relu)
{
auto requantization_scale =
quantization_util::get_dot_scale(min_input,
max_input,
min_filter,
max_filter,
min_freezed_output,
max_freezed_output,
input->get_element_type(),
with_relu ? element::u8 : element::i8,
requantize);
return make_shared<op::QuantizedDot>(
input, filters, requantization_scale, requantize, with_relu);
}
} // namespace builder
} // namespace ngraph
......@@ -24,6 +24,8 @@
#include "ngraph/op/experimental/quantized_conv.hpp"
#include "ngraph/op/experimental/quantized_conv_bias.hpp"
#include "ngraph/op/experimental/quantized_conv_relu.hpp"
#include "ngraph/op/experimental/quantized_dot.hpp"
#include "ngraph/op/experimental/quantized_dot_bias.hpp"
#include "ngraph/op/experimental/quantized_max_pool.hpp"
#include "ngraph/op/quantize.hpp"
......@@ -151,5 +153,29 @@ namespace ngraph
std::shared_ptr<Node> min_freezed_output_conv_2,
std::shared_ptr<Node> max_freezed_output_conv_2,
const bool with_relu);
}
}
std::shared_ptr<Node> ScaledQuantizedDotBias(std::shared_ptr<Node> input,
std::shared_ptr<Node> filters,
std::shared_ptr<Node> bias,
std::shared_ptr<Node> min_input,
std::shared_ptr<Node> max_input,
std::shared_ptr<Node> min_filter,
std::shared_ptr<Node> max_filter,
std::shared_ptr<Node> min_freezed_output,
std::shared_ptr<Node> max_freezed_output,
const bool requantize = true,
const bool with_relu = false);
std::shared_ptr<Node> ScaledQuantizedDot(std::shared_ptr<Node> input,
std::shared_ptr<Node> filters,
std::shared_ptr<Node> min_input,
std::shared_ptr<Node> max_input,
std::shared_ptr<Node> min_filter,
std::shared_ptr<Node> max_filter,
std::shared_ptr<Node> min_freezed_output,
std::shared_ptr<Node> max_freezed_output,
const bool requantize = true,
const bool with_relu = false);
} // namespace builder
} // namespace ngraph
......@@ -276,6 +276,47 @@ namespace ngraph
}
}
}
}
}
}
std::shared_ptr<Node> get_dot_scale(std::shared_ptr<Node> min_input,
std::shared_ptr<Node> max_input,
std::shared_ptr<Node> min_filter,
std::shared_ptr<Node> max_filter,
std::shared_ptr<Node> min_freezed_output,
std::shared_ptr<Node> max_freezed_output,
const ngraph::element::Type& input_type,
const ngraph::element::Type& output_type,
const bool requantize = true)
{
auto type = min_input->get_element_type();
if (type != max_input->get_element_type() ||
type != min_filter->get_element_type() ||
type != max_filter->get_element_type() ||
type != min_freezed_output->get_element_type() ||
type != max_freezed_output->get_element_type())
{
throw ngraph_error("get_dot_scale: min and max must have same type");
}
auto shape = min_input->get_shape();
if (shape != max_input->get_shape() || shape != min_filter->get_shape() ||
shape != max_filter->get_shape() || shape != min_freezed_output->get_shape() ||
shape != max_freezed_output->get_shape())
{
throw ngraph_error("get_dot_scale: min and max must have same shape");
}
auto data_scale = get_scale(min_input, max_input, input_type);
auto weight_scale = get_scale(min_filter, max_filter, element::i8);
auto out_scale = get_scale(min_freezed_output, max_freezed_output, output_type);
if (requantize)
{
return data_scale * weight_scale / out_scale;
}
else
{
return data_scale * weight_scale;
}
}
} // namespace quantization_util
} // namespace builder
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <functional>
#include <memory>
#include <utility>
#include "ngraph/shape.hpp"
#include "quantized_dot.hpp"
using namespace std;
using namespace ngraph;
op::QuantizedDot::QuantizedDot(const shared_ptr<Node>& data,
const shared_ptr<Node>& weights,
const shared_ptr<Node>& scale,
bool requantize,
bool with_relu)
: Op("QuantizedDot", check_single_output_args({data, weights, scale}))
, m_requantize(requantize)
, m_with_relu(with_relu)
{
constructor_validate_and_infer_types();
auto& data_shape = data->get_shape();
auto& weights_shape = weights->get_shape();
NODE_VALIDATION_CHECK(this,
data_shape.size() == 2 && weights_shape.size() == 2 &&
data_shape[1] == weights_shape[1],
"only valid tensors of rank 2 supported. data shape ",
data_shape,
" weights shape ",
weights_shape);
auto output_et = requantize ? (with_relu ? element::u8 : element::i8) : element::f32;
set_output_type(0, output_et, Shape{data_shape[0], weights_shape[0]});
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <utility>
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
class QuantizedDot : public Op
{
public:
QuantizedDot(const std::shared_ptr<Node>& data,
const std::shared_ptr<Node>& weights,
const std::shared_ptr<Node>& scale,
bool requantize = true,
bool with_relu = false);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override
{
check_new_args_count(this, new_args);
return std::make_shared<QuantizedDot>(
new_args.at(0), new_args.at(1), new_args.at(2), m_requantize, m_with_relu);
}
bool with_relu() const { return m_with_relu; }
bool requantize() const { return m_requantize; }
protected:
bool m_requantize;
bool m_with_relu;
};
} // namespace op
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <functional>
#include <memory>
#include <utility>
#include "ngraph/shape.hpp"
#include "quantized_dot_bias.hpp"
using namespace std;
using namespace ngraph;
op::QuantizedDotBias::QuantizedDotBias(const shared_ptr<Node>& data,
const shared_ptr<Node>& weights,
const shared_ptr<Node>& bias,
const shared_ptr<Node>& scale,
bool requantize,
bool with_relu)
: Op("QuantizedDotBias", check_single_output_args({data, weights, bias, scale}))
, m_requantize(requantize)
, m_with_relu(with_relu)
{
constructor_validate_and_infer_types();
auto& data_shape = data->get_shape();
auto& weights_shape = weights->get_shape();
auto& bias_shape = bias->get_shape();
NODE_VALIDATION_CHECK(this,
data_shape.size() == 2 && weights_shape.size() == 2 &&
data_shape[1] == weights_shape[1],
"only valid tensors of rank 2 supported. data ",
data_shape,
" weights ",
weights_shape);
NODE_VALIDATION_CHECK(this,
bias_shape.size() == 1 && bias_shape[0] == weights_shape[0],
"invalid bias ",
bias_shape);
auto output_et = requantize ? (with_relu ? element::u8 : element::i8) : element::f32;
set_output_type(0, output_et, Shape{data_shape[0], weights_shape[0]});
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <utility>
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
class QuantizedDotBias : public Op
{
public:
QuantizedDotBias(const std::shared_ptr<Node>& data,
const std::shared_ptr<Node>& weights,
const std::shared_ptr<Node>& bias,
const std::shared_ptr<Node>& scale,
bool requantize = true,
bool with_relu = false);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override
{
check_new_args_count(this, new_args);
return std::make_shared<QuantizedDotBias>(new_args.at(0),
new_args.at(1),
new_args.at(2),
new_args.at(3),
m_requantize,
m_with_relu);
}
bool with_relu() const { return m_with_relu; }
bool requantize() const { return m_requantize; }
protected:
bool m_requantize;
bool m_with_relu;
};
} // namespace op
} // namespace ngraph
......@@ -115,6 +115,8 @@ NGRAPH_OP(QuantizedConvolutionBiasAdd, ngraph::op)
NGRAPH_OP(QuantizedConvolutionBiasSignedAdd, ngraph::op)
NGRAPH_OP(QuantizedConvolutionRelu, ngraph::op)
NGRAPH_OP(QuantizedConvolution, ngraph::op)
NGRAPH_OP(QuantizedDotBias, ngraph::op)
NGRAPH_OP(QuantizedDot, ngraph::op)
NGRAPH_OP(QuantizedMaxPool, ngraph::op)
NGRAPH_OP(Relu, ngraph::op)
NGRAPH_OP(ReluBackprop, ngraph::op)
......
......@@ -42,8 +42,6 @@ set(SRC
builder/concat.cpp
builder/convert.cpp
builder/convert_layout.cpp
builder/quantized_conv.cpp
builder/quantized_concat.cpp
builder/convolution.cpp
builder/dot.cpp
builder/embedding_lookup.cpp
......@@ -63,6 +61,8 @@ set(SRC
builder/quantization.cpp
builder/quantized_avg_pool.cpp
builder/quantized_conv.cpp
builder/quantized_concat.cpp
builder/quantized_dot.cpp
builder/quantized_max_pool.cpp
builder/reshape.cpp
builder/reverse.cpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/experimental/quantized_dot.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/experimental/quantized_dot_bias.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/cpu/cpu_executor.hpp"
#include "ngraph/runtime/cpu/mkldnn_invoke.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
using namespace std;
using namespace ngraph;
namespace ngraph
{
namespace runtime
{
namespace cpu
{
template <>
void Builder::BUILDER_DECL(ngraph::op::QuantizedDotBias)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
auto& functors = external_function->get_functors();
auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name());
auto& arg1_tensor = external_function->get_tensor_data(args[1].get_name());
auto& arg2_tensor = external_function->get_tensor_data(args[2].get_name());
auto& arg3_tensor = external_function->get_tensor_data(args[3].get_name());
auto& out0_tensor = external_function->get_tensor_data(out[0].get_name());
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto scales_size = shape_size(args[3].get_shape());
auto ip_desc =
mkldnn_emitter
->get_inner_product_forward_desc<ngraph::op::QuantizedDotBias>(node);
auto ip_attr =
mkldnn_emitter
->get_inner_product_forward_attr<ngraph::op::QuantizedDotBias>(node);
size_t ip_index = mkldnn_emitter->inner_product_forward_init(true);
auto& deps = mkldnn_emitter->get_primitive_deps(ip_index);
auto functor = [&, scales_size, ip_desc, ip_attr, deps, ip_index](
CPURuntimeContext* ctx, CPUExecutionContext* ectx) mutable {
if (ctx->first_iteration)
{
vector<float> dyn_scales;
dyn_scales.assign(static_cast<float*>(arg3_tensor),
static_cast<float*>(arg3_tensor) + scales_size);
ip_attr.set_output_scales(0, dyn_scales);
mkldnn_emitter->build_inner_product_forward<true>(
ip_desc, ip_attr, executor::global_cpu_engine, ip_index);
}
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[0], arg0_tensor);
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[1], arg1_tensor);
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[2], arg2_tensor);
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[3], out0_tensor);
cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, ip_index);
};
functors.emplace_back(functor);
}
else
{
throw ngraph_error("unsupported parameters for QuantizedDotBias via DEX");
}
}
template <>
void Builder::BUILDER_DECL(ngraph::op::QuantizedDot)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
auto& functors = external_function->get_functors();
auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name());
auto& arg1_tensor = external_function->get_tensor_data(args[1].get_name());
auto& arg2_tensor = external_function->get_tensor_data(args[2].get_name());
auto& out0_tensor = external_function->get_tensor_data(out[0].get_name());
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto scales_size = shape_size(args[2].get_shape());
auto ip_desc =
mkldnn_emitter->get_inner_product_forward_desc<ngraph::op::QuantizedDot>(
node);
auto ip_attr =
mkldnn_emitter->get_inner_product_forward_attr<ngraph::op::QuantizedDot>(
node);
size_t ip_index = mkldnn_emitter->inner_product_forward_init(false);
auto& deps = mkldnn_emitter->get_primitive_deps(ip_index);
auto functor = [&, scales_size, ip_desc, ip_attr, deps, ip_index](
CPURuntimeContext* ctx, CPUExecutionContext* ectx) mutable {
if (ctx->first_iteration)
{
vector<float> dyn_scales;
dyn_scales.assign(static_cast<float*>(arg2_tensor),
static_cast<float*>(arg2_tensor) + scales_size);
ip_attr.set_output_scales(0, dyn_scales);
mkldnn_emitter->build_inner_product_forward<false>(
ip_desc, ip_attr, executor::global_cpu_engine, ip_index);
}
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[0], arg0_tensor);
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[1], arg1_tensor);
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[2], out0_tensor);
cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, ip_index);
};
functors.emplace_back(functor);
}
else
{
throw ngraph_error("unsupported parameters for QuantizedDot via DEX");
}
}
REGISTER_OP_BUILDER(QuantizedDotBias);
REGISTER_OP_BUILDER(QuantizedDot);
}
}
}
......@@ -55,6 +55,8 @@
#include "ngraph/op/experimental/quantized_concat.hpp"
#include "ngraph/op/experimental/quantized_conv_bias.hpp"
#include "ngraph/op/experimental/quantized_conv_relu.hpp"
#include "ngraph/op/experimental/quantized_dot.hpp"
#include "ngraph/op/experimental/quantized_dot_bias.hpp"
#include "ngraph/op/experimental/quantized_max_pool.hpp"
#include "ngraph/op/floor.hpp"
#include "ngraph/op/get_output_element.hpp"
......@@ -2498,6 +2500,60 @@ namespace ngraph
}
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::QuantizedDotBias)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto qip_index =
mkldnn_emitter->build_inner_product<ngraph::op::QuantizedDotBias>(
node, args, out);
auto& deps = mkldnn_emitter->get_primitive_deps(qip_index);
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0])
<< ", " << args[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1])
<< ", " << args[1].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[2])
<< ", " << args[2].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[3])
<< ", " << out[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
<< to_string(qip_index) << ");\n";
}
else
{
throw ngraph_error("QuantizedDotBias is only supported with MKLDNN kernel.");
}
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::QuantizedDot)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto qip_index = mkldnn_emitter->build_inner_product<ngraph::op::QuantizedDot>(
node, args, out);
auto& deps = mkldnn_emitter->get_primitive_deps(qip_index);
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0])
<< ", " << args[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1])
<< ", " << args[1].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[2])
<< ", " << out[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
<< to_string(qip_index) << ");\n";
}
else
{
throw ngraph_error("unsupported parameters for QuantizedDot");
}
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::ConvolutionBias)
{
......
......@@ -72,6 +72,8 @@
#include "ngraph/op/experimental/quantized_conv.hpp"
#include "ngraph/op/experimental/quantized_conv_bias.hpp"
#include "ngraph/op/experimental/quantized_conv_relu.hpp"
#include "ngraph/op/experimental/quantized_dot.hpp"
#include "ngraph/op/experimental/quantized_dot_bias.hpp"
#include "ngraph/op/experimental/quantized_max_pool.hpp"
#include "ngraph/op/floor.hpp"
#include "ngraph/op/get_output_element.hpp"
......@@ -353,6 +355,8 @@ static const runtime::cpu::OpMap dispatcher{
&runtime::cpu::CPU_Emitter::emit<op::QuantizedConvolutionBiasAdd>},
{TI(ngraph::op::QuantizedConvolutionBiasSignedAdd),
&runtime::cpu::CPU_Emitter::emit<op::QuantizedConvolutionBiasSignedAdd>},
{TI(ngraph::op::QuantizedDotBias), &runtime::cpu::CPU_Emitter::emit<op::QuantizedDotBias>},
{TI(ngraph::op::QuantizedDot), &runtime::cpu::CPU_Emitter::emit<op::QuantizedDot>},
{TI(ngraph::op::ConvolutionRelu), &runtime::cpu::CPU_Emitter::emit<op::ConvolutionRelu>},
{TI(ngraph::op::QuantizedConvolution),
&runtime::cpu::CPU_Emitter::emit<op::QuantizedConvolution>},
......
......@@ -1988,6 +1988,24 @@ size_t MKLDNNEmitter::convolution_forward_init(bool with_bias)
return m_mkldnn_primitives.size() - 1;
}
size_t MKLDNNEmitter::inner_product_forward_init(bool with_bias)
{
size_t size = m_mkldnn_primitives.size();
if (with_bias)
{
// Inputs, Weights, Bias, Results, inner_product
m_mkldnn_primitives.resize(size + 5, nullptr);
m_primitive_deps[m_mkldnn_primitives.size() - 1] = {size, size + 1, size + 2, size + 3};
}
else
{
// Inputs, Weights, Results, inner_product
m_mkldnn_primitives.resize(size + 4, nullptr);
m_primitive_deps[m_mkldnn_primitives.size() - 1] = {size, size + 1, size + 2};
}
return m_mkldnn_primitives.size() - 1;
}
size_t MKLDNNEmitter::reserve_primitive_space(size_t count, bool new_workspace)
{
size_t size = m_mkldnn_primitives.size();
......@@ -2002,3 +2020,76 @@ size_t MKLDNNEmitter::reserve_primitive_space(size_t count, bool new_workspace)
}
return m_mkldnn_primitives.size() - 1;
}
size_t MKLDNNEmitter::build_quantized_inner_product_forward(
const mkldnn::memory::desc& input_data_desc,
const mkldnn::memory::desc& weights_desc,
const mkldnn::memory::desc& bias_desc,
const mkldnn::memory::desc& result_desc,
const float scale,
const mkldnn::post_ops& pops)
{
size_t input_data_index = build_memory_primitive(input_data_desc);
size_t weights_index = build_memory_primitive(weights_desc);
size_t bias_index = build_memory_primitive(bias_desc);
size_t result_index = build_memory_primitive(result_desc);
std::vector<float> output_scale;
output_scale.push_back(scale);
// mkldnn inner_product attr
mkldnn::primitive_attr ip_attr;
ip_attr.set_post_ops(pops);
/* Specify the rounding mode */
ip_attr.set_int_output_round_mode(mkldnn::round_mode::round_nearest);
/* Specify the scales array and corresponding mask */
ip_attr.set_output_scales(0, output_scale);
// mkldnn inner_product
size_t ip_index =
insert_primitive(new mkldnn::inner_product_forward({{
mkldnn::prop_kind::forward_scoring,
input_data_desc,
weights_desc,
bias_desc,
result_desc,
},
ip_attr,
executor::global_cpu_engine},
*m_mkldnn_primitives[input_data_index],
*m_mkldnn_primitives[weights_index],
*m_mkldnn_primitives[bias_index],
*m_mkldnn_primitives[result_index]));
m_primitive_deps[ip_index] = {input_data_index, weights_index, bias_index, result_index};
return ip_index;
}
size_t MKLDNNEmitter::build_quantized_inner_product_forward(
const mkldnn::memory::desc& input_data_desc,
const mkldnn::memory::desc& weights_desc,
const mkldnn::memory::desc& result_desc,
const float scale,
const mkldnn::post_ops& pops)
{
size_t input_data_index = build_memory_primitive(input_data_desc);
size_t weights_index = build_memory_primitive(weights_desc);
size_t result_index = build_memory_primitive(result_desc);
std::vector<float> output_scale;
output_scale.push_back(scale);
// mkldnn inner_product attr
mkldnn::primitive_attr ip_attr;
ip_attr.set_post_ops(pops);
/* Specify the rounding mode */
ip_attr.set_int_output_round_mode(mkldnn::round_mode::round_nearest);
/* Specify the scales array and corresponding mask */
ip_attr.set_output_scales(0, output_scale);
// mkldnn inner_product
size_t ip_index = insert_primitive(new mkldnn::inner_product_forward(
{{
mkldnn::prop_kind::forward_scoring, input_data_desc, weights_desc, result_desc,
},
ip_attr,
executor::global_cpu_engine},
*m_mkldnn_primitives[input_data_index],
*m_mkldnn_primitives[weights_index],
*m_mkldnn_primitives[result_index]));
m_primitive_deps[ip_index] = {input_data_index, weights_index, result_index};
return ip_index;
}
This diff is collapsed.
......@@ -38,6 +38,8 @@
#include "ngraph/op/experimental/quantized_conv.hpp"
#include "ngraph/op/experimental/quantized_conv_bias.hpp"
#include "ngraph/op/experimental/quantized_conv_relu.hpp"
#include "ngraph/op/experimental/quantized_dot.hpp"
#include "ngraph/op/experimental/quantized_dot_bias.hpp"
#include "ngraph/op/experimental/quantized_max_pool.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/lrn.hpp"
......@@ -741,6 +743,18 @@ namespace ngraph
quantized_conv_bias->set_op_annotations(op_annotations);
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::QuantizedDotBias)
{
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node);
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::QuantizedDot)
{
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node);
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::Dequantize)
{
......@@ -931,6 +945,10 @@ static const runtime::cpu::pass::AssignOpMap s_dispatcher{
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Dequantize>},
{TI(ngraph::op::QuantizedConcat),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::QuantizedConcat>},
{TI(ngraph::op::QuantizedDot),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::QuantizedDot>},
{TI(ngraph::op::QuantizedDotBias),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::QuantizedDotBias>},
{TI(ngraph::op::GetOutputElement),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::GetOutputElement>},
};
......
......@@ -39,6 +39,8 @@
#include "ngraph/op/experimental/quantized_conv.hpp"
#include "ngraph/op/experimental/quantized_conv_bias.hpp"
#include "ngraph/op/experimental/quantized_conv_relu.hpp"
#include "ngraph/op/experimental/quantized_dot.hpp"
#include "ngraph/op/experimental/quantized_dot_bias.hpp"
#include "ngraph/op/experimental/quantized_max_pool.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/lrn.hpp"
......@@ -453,6 +455,82 @@ namespace ngraph
o_mds.push_back(prim_desc.dst_primitive_desc().desc());
}
template <typename T, bool use_bias>
void InnerProductLayout(std::shared_ptr<ngraph::Node> node,
vector<memory::desc>& i_mds,
vector<memory::desc>& o_mds)
{
auto arg0_shape = node->get_input_shape(0);
auto arg1_shape = node->get_input_shape(1);
auto result_shape = node->get_output_shape(0);
memory::data_type et =
mkldnn_utils::get_mkldnn_data_type(node->get_input_element_type(0));
memory::data_type et_weights = runtime::cpu::mkldnn_utils::get_mkldnn_data_type(
node->get_input_element_type(1));
memory::data_type et_result = runtime::cpu::mkldnn_utils::get_mkldnn_data_type(
node->get_output_element_type(0));
engine cpu_engine(engine::cpu, 0);
memory::dims mkldnn_arg0_shape(arg0_shape.begin(), arg0_shape.end());
memory::dims mkldnn_arg1_shape(arg1_shape.begin(), arg1_shape.end());
memory::dims mkldnn_result_shape(result_shape.begin(), result_shape.end());
const memory::desc input_data_desc(mkldnn_arg0_shape, et, memory::format::any);
const memory::desc weights_desc(
mkldnn_arg1_shape, et_weights, memory::format::any);
const memory::desc result_desc(
mkldnn_result_shape, et_result, memory::format::any);
std::unique_ptr<inner_product_forward::desc> fwd_desc{nullptr};
if (use_bias)
{
memory::data_type et_bias =
mkldnn_utils::get_mkldnn_data_type(node->get_input_element_type(2));
auto arg2_shape = node->get_input_shape(2);
memory::dims mkldnn_arg2_shape(arg2_shape.begin(), arg2_shape.end());
const memory::desc bias_desc(
mkldnn_arg2_shape, et_bias, memory::format::any);
try
{
fwd_desc.reset(new inner_product_forward::desc(prop_kind::forward,
input_data_desc,
weights_desc,
bias_desc, // with bias
result_desc));
}
catch (const mkldnn::error& e)
{
throw ngraph_error(
"setting layouts on inner_product failed with MKLDNN error: " +
e.message);
}
}
else
{
try
{
fwd_desc.reset(new inner_product_forward::desc(
prop_kind::forward, input_data_desc, weights_desc, result_desc));
}
catch (const mkldnn::error& e)
{
throw ngraph_error(
"setting layouts on inner_product failed with MKLDNN error: " +
e.message);
}
}
inner_product_forward::primitive_desc prim_desc(*fwd_desc, cpu_engine);
i_mds.push_back(prim_desc.src_primitive_desc().desc());
i_mds.push_back(prim_desc.weights_primitive_desc().desc());
if (use_bias)
{
i_mds.push_back(prim_desc.bias_primitive_desc().desc());
}
o_mds.push_back(prim_desc.dst_primitive_desc().desc());
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::QuantizedConvolution)
{
......@@ -628,6 +706,52 @@ namespace ngraph
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::QuantizedDotBias)
{
if (mkldnn_utils::use_mkldnn_kernel(node.get()))
{
vector<memory::desc> i_mds;
vector<memory::desc> o_mds;
InnerProductLayout<ngraph::op::QuantizedDotBias, true>(node, i_mds, o_mds);
auto scale_input_md = mkldnn_utils::create_default_mkldnn_md(
node.get(), 3, false, memory::format::x);
i_mds.push_back(scale_input_md);
node = insert_input_conversions(external_function, node, i_mds);
set_output_layouts(node, o_mds);
}
else
{
set_native_layouts(external_function, node);
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::QuantizedDot)
{
if (mkldnn_utils::use_mkldnn_kernel(node.get()))
{
vector<memory::desc> i_mds;
vector<memory::desc> o_mds;
InnerProductLayout<ngraph::op::QuantizedDot, false>(node, i_mds, o_mds);
auto scale_input_md = mkldnn_utils::create_default_mkldnn_md(
node.get(), 2, false, memory::format::x);
i_mds.push_back(scale_input_md);
node = insert_input_conversions(external_function, node, i_mds);
set_output_layouts(node, o_mds);
}
else
{
set_native_layouts(external_function, node);
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::ConvolutionRelu)
{
......@@ -2166,6 +2290,10 @@ static const runtime::cpu::pass::LayoutOpMap s_dispatcher{
&runtime::cpu::pass::CPULayout::layout<ngraph::op::GroupConvolutionBias>},
{TI(ngraph::op::QuantizedConcat),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::QuantizedConcat>},
{TI(ngraph::op::QuantizedDotBias),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::QuantizedDotBias>},
{TI(ngraph::op::QuantizedDot),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::QuantizedDot>},
};
bool runtime::cpu::pass::CPULayout::run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes)
......
......@@ -1007,6 +1007,8 @@ private:
case OP_TYPEID::QuantizedConvolutionRelu:
case OP_TYPEID::QuantizedConvolution:
case OP_TYPEID::QuantizedMaxPool:
case OP_TYPEID::QuantizedDotBias:
case OP_TYPEID::QuantizedDot:
{
throw unsupported_op("Unsupported op '" + node.description() + "'.");
}
......
......@@ -62,6 +62,8 @@
#include "ngraph/op/experimental/quantized_conv.hpp"
#include "ngraph/op/experimental/quantized_conv_bias.hpp"
#include "ngraph/op/experimental/quantized_conv_relu.hpp"
#include "ngraph/op/experimental/quantized_dot.hpp"
#include "ngraph/op/experimental/quantized_dot_bias.hpp"
#include "ngraph/op/experimental/quantized_max_pool.hpp"
#include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/floor.hpp"
......@@ -921,6 +923,16 @@ std::string runtime::gpu::GPU_Emitter::emit_QuantizedConvolutionRelu(EMIT_ARGS)
throw unsupported_op("Unsupported op '" + node->description() + "'");
}
std::string runtime::gpu::GPU_Emitter::emit_QuantizedDot(EMIT_ARGS)
{
throw unsupported_op("Unsupported op '" + node->description() + "'");
}
std::string runtime::gpu::GPU_Emitter::emit_QuantizedDotBias(EMIT_ARGS)
{
throw unsupported_op("Unsupported op '" + node->description() + "'");
}
std::string runtime::gpu::GPU_Emitter::emit_QuantizedMaxPool(EMIT_ARGS)
{
throw unsupported_op("Unsupported op '" + node->description() + "'");
......
......@@ -2003,6 +2003,8 @@ shared_ptr<runtime::Executable>
case OP_TYPEID::QuantizedConvolutionBiasSignedAdd:
case OP_TYPEID::QuantizedConvolutionRelu:
case OP_TYPEID::QuantizedConvolution:
case OP_TYPEID::QuantizedDot:
case OP_TYPEID::QuantizedDotBias:
case OP_TYPEID::QuantizedMaxPool:
case OP_TYPEID::ReplaceSlice:
case OP_TYPEID::GenerateMask:
......
......@@ -1019,6 +1019,8 @@ private:
case OP_TYPEID::QuantizedConvolutionRelu:
case OP_TYPEID::QuantizedConvolution:
case OP_TYPEID::QuantizedMaxPool:
case OP_TYPEID::QuantizedDotBias:
case OP_TYPEID::QuantizedDot:
{
throw unsupported_op("Unsupported op '" + node.description() +
"' in Interpreter back end.");
......
......@@ -52,6 +52,8 @@
#include "ngraph/op/experimental/quantized_conv.hpp"
#include "ngraph/op/experimental/quantized_conv_bias.hpp"
#include "ngraph/op/experimental/quantized_conv_relu.hpp"
#include "ngraph/op/experimental/quantized_dot.hpp"
#include "ngraph/op/experimental/quantized_dot_bias.hpp"
#include "ngraph/op/experimental/quantized_max_pool.hpp"
#include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/floor.hpp"
......@@ -1030,6 +1032,10 @@ static shared_ptr<ngraph::Function>
data_dilation_strides.get<std::vector<size_t>>());
break;
}
case OP_TYPEID::QuantizedDotBias: { break;
}
case OP_TYPEID::QuantizedDot: { break;
}
case OP_TYPEID::QuantizedMaxPool:
{
auto window_shape = node_js.at("window_shape").get<vector<size_t>>();
......@@ -1645,6 +1651,10 @@ static json write(const Node& n, bool binary_constant_data)
node["data_dilation_strides"] = tmp->get_data_dilation_strides();
break;
}
case OP_TYPEID::QuantizedDotBias: { break;
}
case OP_TYPEID::QuantizedDot: { break;
}
case OP_TYPEID::QuantizedMaxPool:
{
auto tmp = dynamic_cast<const op::QuantizedMaxPool*>(&n);
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment