Commit 8491030d authored by Pruthvi's avatar Pruthvi Committed by Scott Cyphers

add safety checks for mkldnn-assignement for quantized inner product (#2815)

* - add safety checks for mkldnn-assignement for quantized inner product

* - add asserts for unsupported data types in builder & emitter code of Quantized Dot
parent d07e38e0
...@@ -36,6 +36,12 @@ namespace ngraph ...@@ -36,6 +36,12 @@ namespace ngraph
{ {
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node)) if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{ {
if (node->get_input_element_type(0) == element::u8 &&
node->get_input_element_type(1) == element::u8)
{
throw ngraph_error(
"Unsupported data types for QuantizedDot MKLDNN kernel.");
}
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name()); auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name());
auto& arg1_tensor = external_function->get_tensor_data(args[1].get_name()); auto& arg1_tensor = external_function->get_tensor_data(args[1].get_name());
......
...@@ -2494,6 +2494,13 @@ namespace ngraph ...@@ -2494,6 +2494,13 @@ namespace ngraph
{ {
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node)) if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{ {
if (node->get_input_element_type(0) == element::u8 &&
node->get_input_element_type(1) == element::u8)
{
throw ngraph_error(
"Unsupported data types for QuantizedDot MKLDNN kernel.");
}
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto qip_index = mkldnn_emitter->build_inner_product<ngraph::op::QuantizedDot>( auto qip_index = mkldnn_emitter->build_inner_product<ngraph::op::QuantizedDot>(
node, args, out); node, args, out);
......
...@@ -781,15 +781,23 @@ namespace ngraph ...@@ -781,15 +781,23 @@ namespace ngraph
template <> template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::QuantizedDotBias) void CPUAssignment::ASSIGN_DECL(ngraph::op::QuantizedDotBias)
{
if (node->get_input_element_type(0) == element::u8 &&
node->get_input_element_type(1) == element::i8)
{ {
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node); runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node);
} }
}
template <> template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::QuantizedDot) void CPUAssignment::ASSIGN_DECL(ngraph::op::QuantizedDot)
{
if (node->get_input_element_type(0) == element::u8 &&
node->get_input_element_type(1) == element::i8)
{ {
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node); runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node);
} }
}
template <> template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::Dequantize) void CPUAssignment::ASSIGN_DECL(ngraph::op::Dequantize)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment