Commit e6267708 authored by Nishant Patel's avatar Nishant Patel Committed by Robert Kimball

Add support for Dequantize op via mkldnn for IA backend (codegen + DEX) (#1565)

* Add support for Dequantize op via mkldnn for IA backend (codegen + DEX)

* Remove unused variable

* Static cast target range
parent 00a76f3b
...@@ -36,6 +36,7 @@ set(SRC ...@@ -36,6 +36,7 @@ set(SRC
builder/convert.cpp builder/convert.cpp
builder/convert_layout.cpp builder/convert_layout.cpp
builder/convolution.cpp builder/convolution.cpp
builder/dequantize.cpp
builder/dot.cpp builder/dot.cpp
builder/function_call.cpp builder/function_call.cpp
builder/lstm.cpp builder/lstm.cpp
...@@ -77,6 +78,7 @@ set(SRC ...@@ -77,6 +78,7 @@ set(SRC
op/conv_bias.cpp op/conv_bias.cpp
op/conv_relu.cpp op/conv_relu.cpp
op/convert_layout.cpp op/convert_layout.cpp
op/dequantize.cpp
op/loop_kernel.cpp op/loop_kernel.cpp
op/lstm.cpp op/lstm.cpp
op/matmul_bias.cpp op/matmul_bias.cpp
......
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/runtime/cpu/op/dequantize.hpp"
#include <vector>
#include "ngraph/op/constant.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/cpu/mkldnn_invoke.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
using namespace std;
using namespace ngraph;
namespace ngraph
{
namespace runtime
{
namespace cpu
{
template <>
void Builder::BUILDER_DECL(ngraph::op::Dequantize)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
auto& functors = external_function->get_functors();
auto& arg_tensor = external_function->get_tensor_data(args[0].get_name());
auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto input_desc = mkldnn_utils::get_input_mkldnn_md(node, 0);
auto result_desc = mkldnn_utils::get_output_mkldnn_md(node, 0);
size_t dequantize_index =
mkldnn_emitter->build_dequantization(node, input_desc, result_desc);
auto& deps = mkldnn_emitter->get_primitive_deps(dequantize_index);
auto functor = [&, dequantize_index](CPURuntimeContext* ctx) {
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[0], arg_tensor);
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[1], out_tensor);
cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, dequantize_index);
};
functors.emplace_back(functor);
}
else
{
throw ngraph_error("unsupported parameters for DequantizeOp via DEX");
}
}
REGISTER_OP_BUILDER(Dequantize);
}
}
}
...@@ -102,6 +102,7 @@ ...@@ -102,6 +102,7 @@
#include "ngraph/runtime/cpu/op/conv_bias.hpp" #include "ngraph/runtime/cpu/op/conv_bias.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp" #include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp" #include "ngraph/runtime/cpu/op/convert_layout.hpp"
#include "ngraph/runtime/cpu/op/dequantize.hpp"
#include "ngraph/runtime/cpu/op/group_conv.hpp" #include "ngraph/runtime/cpu/op/group_conv.hpp"
#include "ngraph/runtime/cpu/op/loop_kernel.hpp" #include "ngraph/runtime/cpu/op/loop_kernel.hpp"
#include "ngraph/runtime/cpu/op/lstm.hpp" #include "ngraph/runtime/cpu/op/lstm.hpp"
...@@ -2773,6 +2774,32 @@ namespace ngraph ...@@ -2773,6 +2774,32 @@ namespace ngraph
} }
} }
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Dequantize)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto input_data_desc = mkldnn_utils::get_input_mkldnn_md(node, 0);
auto result_desc = mkldnn_utils::get_output_mkldnn_md(node, 0);
size_t dequantize_index =
mkldnn_emitter->build_dequantization(node, input_data_desc, result_desc);
auto& deps = mkldnn_emitter->get_primitive_deps(dequantize_index);
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0])
<< ", " << args[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1])
<< ", " << out[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
<< to_string(dequantize_index) << ");\n";
}
else
{
throw ngraph_error("unsupported parameters for DequantizeOp");
}
}
template <> template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::ConvolutionBackpropFilters) void CPU_Emitter::EMITTER_DECL(ngraph::op::ConvolutionBackpropFilters)
{ {
......
...@@ -142,6 +142,7 @@ ...@@ -142,6 +142,7 @@
#include "ngraph/runtime/cpu/op/conv_bias.hpp" #include "ngraph/runtime/cpu/op/conv_bias.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp" #include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp" #include "ngraph/runtime/cpu/op/convert_layout.hpp"
#include "ngraph/runtime/cpu/op/dequantize.hpp"
#include "ngraph/runtime/cpu/op/group_conv.hpp" #include "ngraph/runtime/cpu/op/group_conv.hpp"
#include "ngraph/runtime/cpu/op/loop_kernel.hpp" #include "ngraph/runtime/cpu/op/loop_kernel.hpp"
#include "ngraph/runtime/cpu/op/lstm.hpp" #include "ngraph/runtime/cpu/op/lstm.hpp"
...@@ -292,6 +293,7 @@ static const runtime::cpu::OpMap dispatcher{ ...@@ -292,6 +293,7 @@ static const runtime::cpu::OpMap dispatcher{
{TI(ngraph::op::Ceiling), &runtime::cpu::CPU_Emitter::emit<op::Ceiling>}, {TI(ngraph::op::Ceiling), &runtime::cpu::CPU_Emitter::emit<op::Ceiling>},
{TI(ngraph::op::Sqrt), &runtime::cpu::CPU_Emitter::emit<op::Sqrt>}, {TI(ngraph::op::Sqrt), &runtime::cpu::CPU_Emitter::emit<op::Sqrt>},
{TI(ngraph::op::Convolution), &runtime::cpu::CPU_Emitter::emit<op::Convolution>}, {TI(ngraph::op::Convolution), &runtime::cpu::CPU_Emitter::emit<op::Convolution>},
{TI(ngraph::op::Dequantize), &runtime::cpu::CPU_Emitter::emit<op::Dequantize>},
{TI(ngraph::op::ConvolutionBackpropFilters), {TI(ngraph::op::ConvolutionBackpropFilters),
&runtime::cpu::CPU_Emitter::emit<op::ConvolutionBackpropFilters>}, &runtime::cpu::CPU_Emitter::emit<op::ConvolutionBackpropFilters>},
{TI(ngraph::op::ConvolutionBackpropData), {TI(ngraph::op::ConvolutionBackpropData),
......
...@@ -19,10 +19,12 @@ ...@@ -19,10 +19,12 @@
#include "mkldnn_emitter.hpp" #include "mkldnn_emitter.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp" #include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp"
#include "ngraph/runtime/cpu/cpu_tensor_view_wrapper.hpp" #include "ngraph/runtime/cpu/cpu_tensor_view_wrapper.hpp"
#include "ngraph/runtime/cpu/mkldnn_invoke.hpp" #include "ngraph/runtime/cpu/mkldnn_invoke.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp" #include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/runtime/cpu/op/dequantize.hpp"
#include "ngraph/type/element_type.hpp" #include "ngraph/type/element_type.hpp"
using namespace ngraph::runtime::cpu; using namespace ngraph::runtime::cpu;
...@@ -119,6 +121,32 @@ size_t MKLDNNEmitter::build_memory_primitive(const mkldnn::memory::desc& desc) ...@@ -119,6 +121,32 @@ size_t MKLDNNEmitter::build_memory_primitive(const mkldnn::memory::desc& desc)
return index; return index;
} }
size_t MKLDNNEmitter::build_dequantization(const ngraph::Node* node,
const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc)
{
auto dequantize = static_cast<const ngraph::op::Dequantize*>(node);
auto min_const_op = std::static_pointer_cast<ngraph::op::Constant>(dequantize->get_argument(1));
auto max_const_op = std::static_pointer_cast<ngraph::op::Constant>(dequantize->get_argument(2));
float min_range = *(static_cast<float const*>(min_const_op->get_data_ptr()));
float max_range = *(static_cast<float const*>(max_const_op->get_data_ptr()));
const float max_abs = std::max(std::abs(min_range), std::abs(max_range));
bool is_signed = (dequantize->get_dequantize_et()).is_signed();
const float target_range =
static_cast<float>((is_signed ? std::pow(2, 7) : std::pow(2, 8)) - 1);
const float scale_factor = max_abs / target_range;
std::vector<float> scales;
scales.push_back(scale_factor);
mkldnn::primitive_attr attr;
attr.set_output_scales(0, scales);
attr.set_int_output_round_mode(mkldnn::round_mode::round_nearest);
size_t dequantize_index = 0;
dequantize_index = this->build_quantize_reorder(input_desc, result_desc, attr);
return dequantize_index;
}
mkldnn::memory::format MKLDNNEmitter::query_convolution_forward_weight_format( mkldnn::memory::format MKLDNNEmitter::query_convolution_forward_weight_format(
const mkldnn::memory::desc& input_data_desc, const mkldnn::memory::desc& input_data_desc,
const mkldnn::memory::desc& weights_desc_any, const mkldnn::memory::desc& weights_desc_any,
...@@ -638,6 +666,23 @@ size_t MKLDNNEmitter::build_reorder(const mkldnn::memory::desc& input_desc, ...@@ -638,6 +666,23 @@ size_t MKLDNNEmitter::build_reorder(const mkldnn::memory::desc& input_desc,
return primitive_index; return primitive_index;
} }
size_t MKLDNNEmitter::build_quantize_reorder(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc,
const mkldnn::primitive_attr attr)
{
size_t input_index = build_memory_primitive(input_desc);
size_t result_index = build_memory_primitive(result_desc);
auto reorder_desc =
mkldnn::reorder::primitive_desc({input_desc, mkldnn_utils::global_cpu_engine},
{result_desc, mkldnn_utils::global_cpu_engine},
attr);
size_t primitive_index = insert_primitive(new mkldnn::reorder(
reorder_desc, *m_mkldnn_primitives[input_index], *m_mkldnn_primitives[result_index]));
m_primitive_deps[primitive_index] = {input_index, result_index};
return primitive_index;
}
size_t MKLDNNEmitter::build_lrn_forward(const mkldnn::memory::desc& input_desc, size_t MKLDNNEmitter::build_lrn_forward(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc, const mkldnn::memory::desc& result_desc,
float alpha, float alpha,
......
...@@ -504,6 +504,14 @@ namespace ngraph ...@@ -504,6 +504,14 @@ namespace ngraph
const mkldnn::memory::desc& result_desc, const mkldnn::memory::desc& result_desc,
float alpha); float alpha);
size_t build_quantize_reorder(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc,
const mkldnn::primitive_attr attr);
size_t build_dequantization(const ngraph::Node* node,
const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc);
private: private:
std::vector<mkldnn::primitive*> m_mkldnn_primitives; std::vector<mkldnn::primitive*> m_mkldnn_primitives;
std::vector<mkldnn::stream> m_mkldnn_streams; std::vector<mkldnn::stream> m_mkldnn_streams;
......
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/runtime/cpu/op/dequantize.hpp"
#include "ngraph/op/constant.hpp"
ngraph::op::Dequantize::Dequantize(std::shared_ptr<Node> input,
std::shared_ptr<Node> min,
std::shared_ptr<Node> max,
const element::Type& type)
: Op("Dequantize", check_single_output_args({input, min, max}))
, m_element_type(type)
{
constructor_validate_and_infer_types();
if (input->get_element_type() != element::u8 && input->get_element_type() != element::i8)
{
throw ngraph_error("Dequantization supported only for i8/u8!");
}
if (min->get_element_type() != min->get_element_type())
{
throw ngraph_error("Min's element type isn't equal to max's!");
}
if (min->get_shape().size() != 0)
{
throw ngraph_error("Min is not a scalar!");
}
if (max->get_shape().size() != 0)
{
throw ngraph_error("Max is not a scalar!");
}
if (!(std::dynamic_pointer_cast<op::Constant>(min) &&
std::dynamic_pointer_cast<op::Constant>(max)))
{
throw ngraph_error("Min and max have to be constants!");
}
set_output_type(0, element::f32, input->get_shape());
}
std::shared_ptr<ngraph::Node>
ngraph::op::Dequantize::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 3)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Dequantize>(
new_args.at(0), new_args.at(1), new_args.at(2), m_element_type);
}
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/node_vector.hpp"
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
class Dequantize : public Op
{
public:
Dequantize(std::shared_ptr<Node> input,
std::shared_ptr<Node> min,
std::shared_ptr<Node> max,
const element::Type& type);
const element::Type& get_dequantize_et() const { return m_element_type; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
private:
const element::Type m_element_type;
};
}
}
...@@ -40,6 +40,7 @@ ...@@ -40,6 +40,7 @@
#include "ngraph/runtime/cpu/op/bounded_relu.hpp" #include "ngraph/runtime/cpu/op/bounded_relu.hpp"
#include "ngraph/runtime/cpu/op/conv_bias.hpp" #include "ngraph/runtime/cpu/op/conv_bias.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp" #include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/dequantize.hpp"
#include "ngraph/runtime/cpu/op/group_conv.hpp" #include "ngraph/runtime/cpu/op/group_conv.hpp"
#include "ngraph/runtime/cpu/op/lstm.hpp" #include "ngraph/runtime/cpu/op/lstm.hpp"
#include "ngraph/runtime/cpu/op/max_pool_with_indices.hpp" #include "ngraph/runtime/cpu/op/max_pool_with_indices.hpp"
...@@ -659,6 +660,20 @@ namespace ngraph ...@@ -659,6 +660,20 @@ namespace ngraph
bounded_relu->set_op_annotations(op_annotations); bounded_relu->set_op_annotations(op_annotations);
} }
} }
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::Dequantize)
{
if (node->get_input_element_type(0) == element::u8 ||
node->get_input_element_type(0) == element::i8)
{
auto dequantize = static_cast<op::Dequantize*>(node);
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
dequantize->set_op_annotations(op_annotations);
}
}
} }
} }
} }
...@@ -712,6 +727,8 @@ static const runtime::cpu::pass::AssignOpMap s_dispatcher{ ...@@ -712,6 +727,8 @@ static const runtime::cpu::pass::AssignOpMap s_dispatcher{
{TI(ngraph::op::Lstm), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Lstm>}, {TI(ngraph::op::Lstm), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Lstm>},
{TI(ngraph::op::Rnn), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Rnn>}, {TI(ngraph::op::Rnn), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Rnn>},
{TI(ngraph::op::Softmax), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Softmax>}, {TI(ngraph::op::Softmax), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Softmax>},
{TI(ngraph::op::Dequantize),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Dequantize>},
}; };
bool runtime::cpu::pass::CPUAssignment::run_on_call_graph( bool runtime::cpu::pass::CPUAssignment::run_on_call_graph(
......
...@@ -49,6 +49,7 @@ ...@@ -49,6 +49,7 @@
#include "ngraph/runtime/cpu/op/conv_bias.hpp" #include "ngraph/runtime/cpu/op/conv_bias.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp" #include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp" #include "ngraph/runtime/cpu/op/convert_layout.hpp"
#include "ngraph/runtime/cpu/op/dequantize.hpp"
#include "ngraph/runtime/cpu/op/group_conv.hpp" #include "ngraph/runtime/cpu/op/group_conv.hpp"
#include "ngraph/runtime/cpu/op/lstm.hpp" #include "ngraph/runtime/cpu/op/lstm.hpp"
#include "ngraph/runtime/cpu/op/max_pool_with_indices.hpp" #include "ngraph/runtime/cpu/op/max_pool_with_indices.hpp"
...@@ -1538,6 +1539,20 @@ namespace ngraph ...@@ -1538,6 +1539,20 @@ namespace ngraph
set_native_layouts(external_function, node); set_native_layouts(external_function, node);
} }
} }
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::Dequantize)
{
if (mkldnn_utils::use_mkldnn_kernel(node.get()))
{
//TODO : Propogate Layout
set_native_layouts(external_function, node);
}
else
{
throw ngraph_error("Dequantized op is only supported in MKLDNN for now.");
}
}
} }
} }
} }
...@@ -1593,6 +1608,7 @@ static const runtime::cpu::pass::LayoutOpMap s_dispatcher{ ...@@ -1593,6 +1608,7 @@ static const runtime::cpu::pass::LayoutOpMap s_dispatcher{
{TI(ngraph::op::Rnn), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Rnn>}, {TI(ngraph::op::Rnn), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Rnn>},
{TI(ngraph::op::Softmax), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Softmax>}, {TI(ngraph::op::Softmax), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Softmax>},
{TI(ngraph::op::BoundedRelu), &runtime::cpu::pass::CPULayout::layout<ngraph::op::BoundedRelu>}, {TI(ngraph::op::BoundedRelu), &runtime::cpu::pass::CPULayout::layout<ngraph::op::BoundedRelu>},
{TI(ngraph::op::Dequantize), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Dequantize>},
}; };
bool runtime::cpu::pass::CPULayout::run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes) bool runtime::cpu::pass::CPULayout::run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes)
......
...@@ -58,7 +58,7 @@ if (NGRAPH_INTERPRETER_ENABLE) ...@@ -58,7 +58,7 @@ if (NGRAPH_INTERPRETER_ENABLE)
endif() endif()
if (NGRAPH_CPU_ENABLE) if (NGRAPH_CPU_ENABLE)
set(SRC ${SRC} core_fusion.cpp) set(SRC ${SRC} core_fusion.cpp quantize_cpu.cpp)
endif() endif()
add_subdirectory(models) add_subdirectory(models)
......
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <algorithm>
#include <cinttypes>
#include <cmath>
#include <cstdlib>
#include <string>
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/runtime/cpu/op/dequantize.hpp"
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
#include "util/ndarray.hpp"
#include "util/random.hpp"
#include "util/test_control.hpp"
#include "util/test_tools.hpp"
using namespace std;
using namespace ngraph;
template <typename T>
void DequantizeTest(int input, float min, float max, float expected_output)
{
vector<T> a_data = {static_cast<T>(input)};
Shape shape_a{1};
auto A = make_shared<op::Parameter>(element::from<T>(), shape_a);
auto B = op::Constant::create(element::f32, Shape{}, {min});
auto C = op::Constant::create(element::f32, Shape{}, {max});
auto r = make_shared<op::Dequantize>(A, B, C, element::from<T>());
auto f = make_shared<Function>(r, op::ParameterVector{A});
auto backend = runtime::Backend::create("CPU");
// Create some tensors for input/output
auto a = backend->create_tensor(element::from<T>(), Shape{});
copy_data(a, a_data);
auto result = backend->create_tensor(element::f32, Shape{});
backend->call(f, {result}, {a});
EXPECT_EQ((vector<float>{expected_output}), read_vector<float>(result));
}
TEST(quantize_cpu, dequantize_from_uint8)
{
DequantizeTest<uint8_t>(255, 100.0f, 300.0f, 300.0);
}
TEST(quantize_cpu, dequantize_from_uint8_smallrange)
{
DequantizeTest<uint8_t>(255, -2.0f, 5.0f, 5.0);
}
TEST(quantize_cpu, dequantize_from_int8_smallrange)
{
DequantizeTest<int8_t>(-127, -2.0f, 1.0f, -2.0);
}
TEST(quantize_cpu, dequantize_from_int8)
{
DequantizeTest<int8_t>(42, -1.0f, 300.0f, static_cast<float>(99.212601));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment