Commit 7e310e20 authored by Nishant Patel, committed by Robert Kimball

Support dynamic scales for QConvs and Dequantize (#2171)

* Support dynamic scales for QConvs and Dequantize

* Remove constant folding

* Add an additional dynamic_quantize unit test

* Add another MXNet quantize unit test

* Add additional dynamic_dequantize tests

* Fix shape error

* Add a dynamic signed_quantize unit test

* Pass the correct scale

* Refactoring

* Add dynamic scale support for QCBA and QCBSA

* Refactor to create MKLDNN primitives on the first iteration

* Remove stray code

* Remove unused variables

* Remove extraneous line
parent c153ea8a
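Throughout this commit, "dynamic scales" means the scale argument of a quantized op is a runtime tensor (e.g. an op::Parameter) rather than an op::Constant, so the backend cannot bake it into an MKLDNN primitive at compile time. A minimal sketch of the distinction, assuming the nGraph v0.x Dequantize constructor (input, scale, zero point, target type, axes); shapes are illustrative only:

// Sketch (not part of the diff): a Dequantize whose scale is a Parameter,
// so its value is only known when the function executes.
auto input = std::make_shared<op::Parameter>(element::i8, Shape{1, 2});
auto scale = std::make_shared<op::Parameter>(element::f32, Shape{}); // dynamic
auto zero = op::Constant::create(element::i8, Shape{}, {0});
auto dq = std::make_shared<op::Dequantize>(input, scale, zero, element::f32, AxisSet{});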
@@ -45,6 +45,7 @@ namespace ngraph
     const Strides& get_data_dilation_strides() const { return m_data_dilation_strides; }
     std::shared_ptr<Node> get_filters() { return get_argument(1); }
     std::shared_ptr<Node> get_data_batch() { return get_argument(0); }
+    bool with_relu() const { return true; }
     virtual std::shared_ptr<Node>
         copy_with_new_args(const NodeVector& new_args) const override;
@@ -52,16 +52,57 @@ namespace ngraph
     auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
     auto input_desc = mkldnn_utils::get_input_mkldnn_md(node, 0);
     auto result_desc = mkldnn_utils::get_output_mkldnn_md(node, 0);
-    size_t dequantize_index =
-        mkldnn_emitter->build_dequantization(node, input_desc, result_desc);
-    auto& deps = mkldnn_emitter->get_primitive_deps(dequantize_index);
-    functor = [&, dequantize_index](CPURuntimeContext* ctx,
-                                    CPUExecutionContext* ectx) {
-        cpu::mkldnn_utils::set_memory_ptr(ctx, deps[0], arg0_tensor);
-        cpu::mkldnn_utils::set_memory_ptr(ctx, deps[1], out_tensor);
-        cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, dequantize_index);
-    };
-    functors.emplace_back(functor);
+    auto scale_const_op = std::dynamic_pointer_cast<ngraph::op::Constant>(
+        dequantize->get_argument(1));
+    std::vector<float> scales;
+    if (scale_const_op == nullptr)
+    {
+        auto& arg1_tensor = external_function->get_tensor_data(args[1].get_name());
+        auto scales_size = shape_size(args[1].get_shape());
+        size_t dequantize_index =
+            mkldnn_emitter->build_dequantization(node, input_desc, result_desc);
+        auto& deps = mkldnn_emitter->get_primitive_deps(dequantize_index);
+        functor = [&, input_desc, result_desc, scales_size, dequantize_index](
+            CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
+            // Create the MKLDNN reorder primitive during the first iteration.
+            // Assumes the scales don't change for the duration of the graph.
+            if (ctx->first_iteration)
+            {
+                mkldnn::primitive_attr attr;
+                vector<float> dyn_scales;
+                dyn_scales.assign(static_cast<float*>(arg1_tensor),
+                                  static_cast<float*>(arg1_tensor) + scales_size);
+                attr.set_output_scales(0, dyn_scales);
+                attr.set_int_output_round_mode(mkldnn::round_mode::round_nearest);
+                auto reorder_desc = mkldnn::reorder::primitive_desc(
+                    {input_desc, executor::global_cpu_engine},
+                    {result_desc, executor::global_cpu_engine},
+                    attr);
+                *ctx->mkldnn_primitives[dequantize_index] =
+                    mkldnn::reorder(reorder_desc,
+                                    *ctx->mkldnn_primitives[deps[0]],
+                                    *ctx->mkldnn_primitives[deps[1]]);
+            }
+            cpu::mkldnn_utils::set_memory_ptr(ctx, deps[0], arg0_tensor);
+            cpu::mkldnn_utils::set_memory_ptr(ctx, deps[1], out_tensor);
+            cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, dequantize_index);
+        };
+        functors.emplace_back(functor);
+    }
+    else
+    {
+        size_t dequantize_index =
+            mkldnn_emitter->build_dequantization(node, input_desc, result_desc);
+        auto& deps = mkldnn_emitter->get_primitive_deps(dequantize_index);
+        functor = [&, dequantize_index](CPURuntimeContext* ctx,
+                                        CPUExecutionContext* ectx) {
+            cpu::mkldnn_utils::set_memory_ptr(ctx, deps[0], arg0_tensor);
+            cpu::mkldnn_utils::set_memory_ptr(ctx, deps[1], out_tensor);
+            cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, dequantize_index);
+        };
+        functors.emplace_back(functor);
+    }
 }
 else
 {
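For background on the attr calls in the functor above (MKLDNN 0.x semantics, not introduced by this change): a reorder built with output scales computes dst = scale * src element-wise, and mask 0 selects a single common scale for all elements:

// Background sketch: mask 0 = one common output scale.
mkldnn::primitive_attr attr;
attr.set_output_scales(/*mask=*/0, std::vector<float>{0.5f});
attr.set_int_output_round_mode(mkldnn::round_mode::round_nearest);
// A reorder from an int8 memory desc to an f32 memory desc with this
// attr performs the dequantization y = 0.5f * x.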
@@ -223,6 +264,7 @@ namespace ngraph
     vector<float> dyn_scales;
     dyn_scales.assign(static_cast<float*>(arg1_tensor),
                       static_cast<float*>(arg1_tensor) + scales_size);
+    dyn_scales[0] = 1.0 / dyn_scales[0];
     attr.set_output_scales(0, dyn_scales);
     attr.set_int_output_round_mode(mkldnn::round_mode::round_nearest);
     auto reorder_desc = mkldnn::reorder::primitive_desc(
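The added reciprocal is the quantize-side counterpart of the dequantize path above: nGraph's Quantize divides by the scale, while an MKLDNN output scale multiplies, so the attr needs 1/scale. A scalar illustration with made-up values:

// Illustration only: why quantize passes the reciprocal.
// nGraph Quantize: q = round(x / scale); MKLDNN reorder: y = round(s * x).
float scale = 0.5f;                // quantization step
float output_scale = 1.0f / scale; // the value MKLDNN expects
// round(3.14f / 0.5f) == round(2.0f * 3.14f) == 6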
@@ -148,10 +148,14 @@ size_t MKLDNNEmitter::build_dequantization(const ngraph::Node* node,
 {
     auto dequantize = static_cast<const ngraph::op::Dequantize*>(node);
     auto scale_const_op =
-        std::static_pointer_cast<ngraph::op::Constant>(dequantize->get_argument(1));
-    float scale = *(static_cast<float const*>(scale_const_op->get_data_ptr()));
+        std::dynamic_pointer_cast<ngraph::op::Constant>(dequantize->get_argument(1));
+    std::vector<float> scale = {1.0f};
+    if (scale_const_op != nullptr)
+    {
+        scale = scale_const_op->get_vector<float>();
+    }
     std::vector<float> scales;
-    scales.push_back(scale);
+    scales.push_back(scale[0]);
     size_t dequantize_index = 0;
     dequantize_index = this->build_quantize_reorder(input_desc, result_desc, scales);
     return dequantize_index;
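Note how this pairs with the builder dispatch above: when the scale is not a Constant, build_dequantization reserves the primitive with a placeholder scale of 1.0, and the functor patches the real scale into ctx->mkldnn_primitives[dequantize_index] on the first iteration.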
@@ -1203,3 +1207,21 @@ size_t MKLDNNEmitter::build_bounded_relu(const mkldnn::memory::desc& input_desc,
     m_primitive_deps[primitive_index] = {input_index, result_index};
     return primitive_index;
 }
+
+size_t MKLDNNEmitter::convolution_forward_init(bool with_bias)
+{
+    size_t size = m_mkldnn_primitives.size();
+    if (with_bias)
+    {
+        // Inputs, Weights, Bias, Results, Conv
+        m_mkldnn_primitives.resize(size + 5, nullptr);
+        m_primitive_deps[m_mkldnn_primitives.size() - 1] = {size, size + 1, size + 2, size + 3};
+    }
+    else
+    {
+        // Inputs, Weights, Results, Conv
+        m_mkldnn_primitives.resize(size + 4, nullptr);
+        m_primitive_deps[m_mkldnn_primitives.size() - 1] = {size, size + 1, size + 2};
+    }
+    return m_mkldnn_primitives.size() - 1;
+}
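A sketch of how this init helper pairs with the first-iteration pattern above. This is a hypothetical caller, not code from the commit; the tensor names and the comment inside the first_iteration branch stand in for whatever emitter code builds the convolution once the runtime scales are readable:

// Hypothetical usage: reserve primitive slots at compile time, then
// construct the convolution lazily once the dynamic scale is known.
size_t conv_index = mkldnn_emitter->convolution_forward_init(/*with_bias=*/false);
auto& deps = mkldnn_emitter->get_primitive_deps(conv_index);
functor = [&, conv_index](CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
    if (ctx->first_iteration)
    {
        // Read the scale tensor and build the conv primitive into
        // ctx->mkldnn_primitives[conv_index] (cf. the reorder case above).
    }
    cpu::mkldnn_utils::set_memory_ptr(ctx, deps[0], arg0_tensor); // inputs
    cpu::mkldnn_utils::set_memory_ptr(ctx, deps[1], arg1_tensor); // weights
    cpu::mkldnn_utils::set_memory_ptr(ctx, deps[2], out_tensor);  // results
    cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, conv_index);
};
functors.emplace_back(functor);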
@@ -45,6 +45,7 @@ namespace ngraph
     const Strides& get_data_dilation_strides() const { return m_data_dilation_strides; }
     std::shared_ptr<Node> get_filters() { return get_argument(1); }
     std::shared_ptr<Node> get_data_batch() { return get_argument(0); }
+    bool with_relu() const { return true; }
     virtual std::shared_ptr<Node>
         copy_with_new_args(const NodeVector& new_args) const override;