Commit b9ff5d1f authored by Nishant Patel's avatar Nishant Patel Committed by Scott Cyphers

Add QuantizedConcat (#2060)

* Add QuantizedConcat

* Remove unused variables and add check for size of mins and maxes vector

* Resolve conflicts

* Merged with master and addressed some PR feedback

* Avoid float comparison

* make min/max vector, add dequant/quanti

* fix dequant/quant scales

* fix CI build issue
parent 26bba737
......@@ -70,6 +70,7 @@ set (SRC
op/experimental/quantized_conv_relu.cpp
op/experimental/quantized_conv.cpp
op/experimental/quantized_max_pool.cpp
op/experimental/quantized_concat.cpp
op/experimental/shape_of.cpp
op/floor.cpp
op/get_output_element.cpp
......
......@@ -18,8 +18,12 @@
#include "ngraph/builder/make_constant.hpp"
#include "ngraph/builder/quantization.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/max.hpp"
#include "ngraph/op/min.hpp"
#include "ngraph/op/reshape.hpp"
#include "quantization_util.hpp"
using namespace std;
......@@ -88,6 +92,43 @@ namespace ngraph
return make_shared<op::Dequantize>(input, scale, zero, real_type, axes);
}
std::shared_ptr<Node> ScaledQuantizedConcat(const NodeVector& args,
size_t concatenation_axis,
const NodeVector& mins,
const NodeVector& maxs)
{
quantization_util::check_concat(args, mins, maxs);
auto quant_type = args[0]->get_element_type();
// output scale
auto min = std::make_shared<op::Min>(std::make_shared<op::Concat>(mins, 0),
ngraph::AxisSet{0});
auto max = std::make_shared<op::Max>(std::make_shared<op::Concat>(maxs, 0),
ngraph::AxisSet{0});
auto out_scale = quantization_util::get_scale(min, max, quant_type);
NodeVector rescaled_args(args.size());
for (size_t i = 0; i < args.size(); ++i)
{
auto q_type = args[i]->get_element_type();
auto in_scale = std::make_shared<ngraph::op::Reshape>(
quantization_util::get_scale(mins[i], maxs[i], q_type), AxisVector{0}, Shape{});
auto zero = make_constant(q_type, in_scale->get_shape(), 0);
rescaled_args[i] =
make_shared<op::Dequantize>(args[i], in_scale, zero, element::f32, AxisSet{});
rescaled_args[i] =
make_shared<op::Quantize>(rescaled_args[i],
out_scale,
zero,
q_type,
AxisSet{},
op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN);
}
return make_shared<op::QuantizedConcat>(rescaled_args, concatenation_axis);
}
std::shared_ptr<Node> ScaledQuantizedAvgPool(std::shared_ptr<Node> input,
const Shape& window_shape,
const Strides& window_movement_strides,
......
......@@ -20,6 +20,7 @@
#include "ngraph/node.hpp"
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/experimental/quantized_avg_pool.hpp"
#include "ngraph/op/experimental/quantized_concat.hpp"
#include "ngraph/op/experimental/quantized_conv.hpp"
#include "ngraph/op/experimental/quantized_conv_bias.hpp"
#include "ngraph/op/experimental/quantized_conv_relu.hpp"
......@@ -43,6 +44,11 @@ namespace ngraph
const ngraph::element::Type& type,
const ngraph::AxisSet& axes);
std::shared_ptr<Node> ScaledQuantizedConcat(const NodeVector& args,
size_t concatenation_axis,
const NodeVector& mins,
const NodeVector& maxes);
std::shared_ptr<Node> ScaledQuantizedAvgPool(std::shared_ptr<Node> input,
const Shape& window_shape,
const Strides& window_movement_strides,
......
......@@ -249,6 +249,33 @@ namespace ngraph
return max_abs_range / target_range;
}
void
check_concat(const NodeVector& args, const NodeVector& mins, const NodeVector& maxs)
{
auto size = args.size();
if (size != mins.size() || size != maxs.size())
{
throw ngraph_error("Min and Max node vectors must be of same length");
}
for (size_t i = 0; i < size; i++)
{
auto min = mins[i];
auto max = maxs[i];
auto type = min->get_element_type();
if (type != max->get_element_type())
{
throw ngraph_error("check_concat: min and max must have same type");
}
if (min->get_shape() != Shape{1} || min->get_shape() != Shape{1})
{
throw ngraph_error("check_concat: min and max must have same shape " +
vector_to_string(min->get_shape()) +
vector_to_string(max->get_shape()));
}
}
}
}
}
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <cassert>
#include <memory>
#include "ngraph/op/concat.hpp"
#include "ngraph/op/slice.hpp"
#include "quantized_concat.hpp"
using namespace std;
using namespace ngraph;
op::QuantizedConcat::QuantizedConcat(const NodeVector& args, size_t concatenation_axis)
: Op("QuantizedConcat", check_single_output_args(args))
, m_concatenation_axis(concatenation_axis)
{
constructor_validate_and_infer_types();
}
void op::QuantizedConcat::validate_and_infer_types()
{
NODE_VALIDATION_ASSERT(this, m_inputs.size() >= 1) << "At least one argument required.";
PartialShape inputs_shape_scheme{PartialShape::dynamic()};
element::Type inputs_et{element::dynamic};
Dimension concatenation_axis_output_dim{0};
for (auto i = 0; i < get_inputs().size(); i++)
{
PartialShape this_input_shape = get_input_partial_shape(i);
Dimension this_input_rank = this_input_shape.rank();
if (this_input_rank.is_static())
{
NODE_VALIDATION_ASSERT(this, m_concatenation_axis < size_t(this_input_rank))
<< "QuantizedConcatenation axis (" << m_concatenation_axis
<< ") is out of bounds for "
<< "argument " << i << ", which has shape " << this_input_shape << ".";
concatenation_axis_output_dim += this_input_shape[m_concatenation_axis];
this_input_shape[m_concatenation_axis] = Dimension::dynamic();
NODE_VALIDATION_ASSERT(this,
PartialShape::merge_into(inputs_shape_scheme, this_input_shape))
<< "Argument shapes are inconsistent; they must have the same rank, and must have "
<< "equal dimension everywhere except on the concatenation axis (axis "
<< m_concatenation_axis << ").";
NODE_VALIDATION_ASSERT(
this, element::Type::merge(inputs_et, inputs_et, get_input_element_type(i)))
<< "Argument element types are inconsistent.";
}
else
{
concatenation_axis_output_dim += Dimension::dynamic();
}
}
PartialShape concatenated_shape = inputs_shape_scheme;
if (concatenated_shape.rank().is_static())
{
concatenated_shape[m_concatenation_axis] = concatenation_axis_output_dim;
}
set_output_type(0, inputs_et, concatenated_shape);
}
shared_ptr<Node> op::QuantizedConcat::copy_with_new_args(const NodeVector& new_args) const
{
return make_shared<QuantizedConcat>(new_args, m_concatenation_axis);
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
/// \brief QuantizedConcatenation operation.
class QuantizedConcat : public Op
{
public:
/// \brief Constructs a concatenation operation.
///
/// \param args The nodes producing the input tensors.
/// \param concatenation_axis The axis along which to concatenate the input tensors.
QuantizedConcat(const NodeVector& args, size_t concatenation_axis);
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
/// \return The concatenation axis.
size_t get_concatenation_axis() const { return m_concatenation_axis; }
protected:
const size_t m_concatenation_axis;
};
}
}
......@@ -42,6 +42,8 @@ set(SRC
builder/concat.cpp
builder/convert.cpp
builder/convert_layout.cpp
builder/quantized_conv.cpp
builder/quantized_concat.cpp
builder/convolution.cpp
builder/dot.cpp
builder/embedding_lookup.cpp
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/experimental/quantized_concat.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/cpu/kernel/concat.hpp"
#include "ngraph/runtime/cpu/mkldnn_invoke.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
using namespace std;
using namespace ngraph;
namespace ngraph
{
namespace runtime
{
namespace cpu
{
template <>
void Builder::BUILDER_DECL(ngraph::op::QuantizedConcat)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
auto& functors = external_function->get_functors();
vector<reference_wrapper<void*>> arg_tensors;
for (auto& arg : args)
{
if (shape_size(arg.get_shape()))
{
arg_tensors.emplace_back(
external_function->get_tensor_data(arg.get_name()));
}
}
auto nargs = args.size();
auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
std::vector<mkldnn::memory::desc> inputs_data_desc;
for (size_t i = 0; i < args.size(); i++)
{
inputs_data_desc.push_back(mkldnn_utils::get_input_mkldnn_md(node, i));
}
auto result_desc = mkldnn_utils::get_output_mkldnn_md(node, 0);
size_t concat_dim = (static_cast<const ngraph::op::QuantizedConcat*>(node))
->get_concatenation_axis();
auto concat_index =
mkldnn_emitter->build_concat(inputs_data_desc, result_desc, concat_dim);
auto& deps = mkldnn_emitter->get_primitive_deps(concat_index);
auto functor = [&, arg_tensors, nargs, concat_index](
CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
for (size_t i = 0; i < nargs; i++)
{
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[i], arg_tensors[i]);
}
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[nargs], out_tensor);
cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, concat_index);
};
functors.emplace_back(functor);
}
else
{
throw ngraph_error("unsupported parameters for QuantizedConcat via DEX");
}
}
REGISTER_OP_BUILDER(QuantizedConcat);
}
}
}
......@@ -52,6 +52,7 @@
#include "ngraph/op/exp.hpp"
#include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/op/experimental/quantized_avg_pool.hpp"
#include "ngraph/op/experimental/quantized_concat.hpp"
#include "ngraph/op/experimental/quantized_conv_bias.hpp"
#include "ngraph/op/experimental/quantized_conv_relu.hpp"
#include "ngraph/op/experimental/quantized_max_pool.hpp"
......@@ -1071,7 +1072,7 @@ namespace ngraph
size_t concat_index = 0;
size_t concat_dim =
(dynamic_cast<const ngraph::op::Concat*>(node))->get_concatenation_axis();
(static_cast<const ngraph::op::Concat*>(node))->get_concatenation_axis();
concat_index =
mkldnn_emitter->build_concat(inputs_data_desc, result_desc, concat_dim);
auto& deps = mkldnn_emitter->get_primitive_deps(concat_index);
......@@ -1090,7 +1091,7 @@ namespace ngraph
else
{
auto axis =
(dynamic_cast<const ngraph::op::Concat*>(node))->get_concatenation_axis();
(static_cast<const ngraph::op::Concat*>(node))->get_concatenation_axis();
std::vector<std::string> arg_names;
std::vector<Shape> arg_shapes;
......@@ -4040,6 +4041,44 @@ namespace ngraph
}
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::QuantizedConcat)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
std::vector<mkldnn::memory::desc> inputs_data_desc;
for (size_t i = 0; i < args.size(); i++)
{
inputs_data_desc.push_back(mkldnn_utils::get_input_mkldnn_md(node, i));
}
auto result_desc = mkldnn_utils::get_output_mkldnn_md(node, 0);
size_t concat_index = 0;
size_t concat_dim = (static_cast<const ngraph::op::QuantizedConcat*>(node))
->get_concatenation_axis();
concat_index =
mkldnn_emitter->build_concat(inputs_data_desc, result_desc, concat_dim);
auto& deps = mkldnn_emitter->get_primitive_deps(concat_index);
size_t i;
for (i = 0; i < args.size(); i++)
{
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[i])
<< ", " << args[i].get_name() << ");\n";
}
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[i])
<< ", " << out[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
<< to_string(concat_index) << ");\n";
}
else
{
throw ngraph_error("unsupported parameters for QuantizedConcat via DEX");
}
}
#undef TI
}
}
......
......@@ -68,6 +68,7 @@
#include "ngraph/op/exp.hpp"
#include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/op/experimental/quantized_avg_pool.hpp"
#include "ngraph/op/experimental/quantized_concat.hpp"
#include "ngraph/op/experimental/quantized_conv.hpp"
#include "ngraph/op/experimental/quantized_conv_bias.hpp"
#include "ngraph/op/experimental/quantized_conv_relu.hpp"
......@@ -409,6 +410,7 @@ static const runtime::cpu::OpMap dispatcher{
{TI(ngraph::op::Dequantize), &runtime::cpu::CPU_Emitter::emit<ngraph::op::Dequantize>},
{TI(ngraph::op::GroupConvolutionBias),
&runtime::cpu::CPU_Emitter::emit<op::GroupConvolutionBias>},
{TI(ngraph::op::QuantizedConcat), &runtime::cpu::CPU_Emitter::emit<op::QuantizedConcat>},
};
static void
......
......@@ -34,6 +34,7 @@
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/experimental/quantized_avg_pool.hpp"
#include "ngraph/op/experimental/quantized_concat.hpp"
#include "ngraph/op/experimental/quantized_conv.hpp"
#include "ngraph/op/experimental/quantized_conv_bias.hpp"
#include "ngraph/op/experimental/quantized_conv_relu.hpp"
......@@ -120,6 +121,39 @@ namespace ngraph
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::QuantizedConcat)
{
auto quantized_concat = static_cast<op::QuantizedConcat*>(node);
if ((node->get_input_element_type(0) == element::i8 ||
node->get_input_element_type(0) == element::u8) &&
((node->get_input_shape(0)).size() == 4 ||
(node->get_input_shape(0)).size() == 2))
{
// MKLDNN seems to throw an exception when given tensors with 0-length
// dimensions, so don't assign it in such cases.
bool any_zero = false;
for (size_t i = 0; i < node->get_input_size(); i++)
{
if (shape_size(node->get_input_shape(i)) == 0)
{
any_zero = true;
break;
}
}
if (!any_zero)
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
quantized_concat->set_op_annotations(op_annotations);
}
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::Convolution)
{
......@@ -895,6 +929,8 @@ static const runtime::cpu::pass::AssignOpMap s_dispatcher{
{TI(ngraph::op::Quantize), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Quantize>},
{TI(ngraph::op::Dequantize),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Dequantize>},
{TI(ngraph::op::QuantizedConcat),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::QuantizedConcat>},
{TI(ngraph::op::GetOutputElement),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::GetOutputElement>},
};
......
......@@ -35,6 +35,7 @@
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/experimental/quantized_avg_pool.hpp"
#include "ngraph/op/experimental/quantized_concat.hpp"
#include "ngraph/op/experimental/quantized_conv.hpp"
#include "ngraph/op/experimental/quantized_conv_bias.hpp"
#include "ngraph/op/experimental/quantized_conv_relu.hpp"
......@@ -1899,44 +1900,65 @@ namespace ngraph
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::Concat)
template <typename T>
void ConcatLayout(std::shared_ptr<ngraph::Node> node,
vector<memory::desc>& i_mds,
vector<memory::desc>& o_mds)
{
if (mkldnn_utils::use_mkldnn_kernel(node.get()))
auto concat = static_cast<const T*>(node.get());
size_t concat_dim = concat->get_concatenation_axis();
auto result_desc = mkldnn_utils::create_default_mkldnn_md(
node.get(), 0, true, memory::format::any);
std::vector<mkldnn::memory::primitive_desc> inputs_pd;
for (size_t i = 0; i < node->get_input_size(); i++)
{
auto concat = static_cast<const ngraph::op::Concat*>(node.get());
size_t concat_dim = concat->get_concatenation_axis();
auto result_desc = mkldnn_utils::create_default_mkldnn_md(
node.get(), 0, true, memory::format::any);
std::vector<mkldnn::memory::primitive_desc> inputs_pd;
auto input_md = mkldnn_utils::get_input_mkldnn_md(node.get(), i);
inputs_pd.push_back(
mkldnn::memory::primitive_desc(input_md, executor::global_cpu_engine));
}
try
{
auto prim_desc = concat::primitive_desc(
result_desc, static_cast<int>(concat_dim), inputs_pd);
for (size_t i = 0; i < node->get_input_size(); i++)
{
auto input_md = mkldnn_utils::get_input_mkldnn_md(node.get(), i);
inputs_pd.push_back(mkldnn::memory::primitive_desc(
input_md, executor::global_cpu_engine));
i_mds.push_back(inputs_pd[i].desc());
}
try
{
auto prim_desc = concat::primitive_desc(
result_desc, static_cast<int>(concat_dim), inputs_pd);
o_mds.push_back(prim_desc.dst_primitive_desc().desc());
}
catch (const mkldnn::error& e)
{
throw ngraph_error(e.message);
}
}
vector<memory::desc> i_mds;
vector<memory::desc> o_mds;
for (size_t i = 0; i < node->get_input_size(); i++)
{
i_mds.push_back(inputs_pd[i].desc());
}
o_mds.push_back(prim_desc.dst_primitive_desc().desc());
node = insert_input_conversions(external_function, node, i_mds);
set_output_layouts(node, o_mds);
}
catch (const mkldnn::error& e)
{
throw ngraph_error(e.message);
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::Concat)
{
if (mkldnn_utils::use_mkldnn_kernel(node.get()))
{
vector<memory::desc> i_mds;
vector<memory::desc> o_mds;
ConcatLayout<ngraph::op::Concat>(node, i_mds, o_mds);
node = insert_input_conversions(external_function, node, i_mds);
set_output_layouts(node, o_mds);
}
else
{
set_native_layouts(external_function, node);
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::QuantizedConcat)
{
if (mkldnn_utils::use_mkldnn_kernel(node.get()))
{
vector<memory::desc> i_mds;
vector<memory::desc> o_mds;
ConcatLayout<ngraph::op::QuantizedConcat>(node, i_mds, o_mds);
node = insert_input_conversions(external_function, node, i_mds);
set_output_layouts(node, o_mds);
}
else
{
......@@ -2096,6 +2118,8 @@ static const runtime::cpu::pass::LayoutOpMap s_dispatcher{
&runtime::cpu::pass::CPULayout::layout<ngraph::op::QuantizedConvolutionBiasSignedAdd>},
{TI(ngraph::op::GroupConvolutionBias),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::GroupConvolutionBias>},
{TI(ngraph::op::QuantizedConcat),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::QuantizedConcat>},
};
bool runtime::cpu::pass::CPULayout::run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes)
......
......@@ -1036,3 +1036,100 @@ TEST(builder, dynamic_scaled_DQ)
EXPECT_EQ((vector<float>{0.13725491, 0.59215689, 0.60392159, 0.8588236}),
read_vector<float>(result2));
}
TEST(builder, scaled_quantize_concat_unsigned)
{
Shape shape_a{2, 2};
auto A = make_shared<op::Parameter>(element::u8, shape_a);
auto An = make_shared<op::Parameter>(element::f32, Shape{1});
auto Ax = make_shared<op::Parameter>(element::f32, Shape{1});
Shape shape_b{3, 2};
auto B = make_shared<op::Parameter>(element::u8, shape_b);
auto Bn = make_shared<op::Parameter>(element::f32, Shape{1});
auto Bx = make_shared<op::Parameter>(element::f32, Shape{1});
Shape shape_c{3, 2};
auto C = make_shared<op::Parameter>(element::u8, shape_c);
auto Cn = make_shared<op::Parameter>(element::f32, Shape{1});
auto Cx = make_shared<op::Parameter>(element::f32, Shape{1});
Shape shape_r{8, 2};
auto QConcat = ngraph::builder::ScaledQuantizedConcat(
NodeVector{A, B, C}, 0, NodeVector{An, Bn, Cn}, NodeVector{Ax, Bx, Cx});
auto f = make_shared<Function>(NodeVector{QConcat},
ParameterVector{A, B, C, An, Bn, Cn, Ax, Bx, Cx});
auto backend = runtime::Backend::create("CPU");
// Create some tensors for input/output
auto a = backend->create_tensor(element::u8, shape_a);
copy_data(a, vector<uint8_t>{2, 4, 8, 16});
auto b = backend->create_tensor(element::u8, shape_b);
copy_data(b, vector<uint8_t>{2, 2, 4, 8, 16, 15});
auto c = backend->create_tensor(element::u8, shape_c);
copy_data(c, vector<uint8_t>{2, 3, 5, 7, 11, 16});
// min/max vectors
auto an = backend->create_tensor(element::f32, Shape{1});
copy_data(an, vector<float>{2.0});
auto ax = backend->create_tensor(element::f32, Shape{1});
copy_data(ax, vector<float>{16.0});
auto bn = backend->create_tensor(element::f32, Shape{1});
copy_data(bn, vector<float>{2.0});
auto bx = backend->create_tensor(element::f32, Shape{1});
copy_data(bx, vector<float>{16.0});
auto cn = backend->create_tensor(element::f32, Shape{1});
copy_data(cn, vector<float>{2.0});
auto cx = backend->create_tensor(element::f32, Shape{1});
copy_data(cx, vector<float>{16.0});
// result
auto result = backend->create_tensor(element::u8, shape_r);
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a, b, c, an, bn, cn, ax, bx, cx});
EXPECT_EQ((vector<uint8_t>{2, 4, 8, 16, 2, 2, 4, 8, 16, 15, 2, 3, 5, 7, 11, 16}),
read_vector<uint8_t>(result));
}
TEST(builder, scaled_quantize_concat_signed)
{
Shape shape_a{2, 2};
auto A = make_shared<op::Parameter>(element::i8, shape_a);
auto An = make_shared<op::Parameter>(element::f32, Shape{1});
auto Ax = make_shared<op::Parameter>(element::f32, Shape{1});
Shape shape_b{3, 2};
auto B = make_shared<op::Parameter>(element::i8, shape_b);
auto Bn = make_shared<op::Parameter>(element::f32, Shape{1});
auto Bx = make_shared<op::Parameter>(element::f32, Shape{1});
Shape shape_c{3, 2};
auto C = make_shared<op::Parameter>(element::i8, shape_c);
auto Cn = make_shared<op::Parameter>(element::f32, Shape{1});
auto Cx = make_shared<op::Parameter>(element::f32, Shape{1});
Shape shape_r{8, 2};
auto QConcat = ngraph::builder::ScaledQuantizedConcat(
NodeVector{A, B, C}, 0, NodeVector{An, Bn, Cn}, NodeVector{Ax, Bx, Cx});
auto f = make_shared<Function>(NodeVector{QConcat},
ParameterVector{A, B, C, An, Bn, Cn, Ax, Bx, Cx});
auto backend = runtime::Backend::create("CPU");
// Create some tensors for input/output
auto a = backend->create_tensor(element::i8, shape_a);
copy_data(a, vector<int8_t>{-2, 4, 8, 16});
auto b = backend->create_tensor(element::i8, shape_b);
copy_data(b, vector<int8_t>{-2, 2, 4, 8, 16, 15});
auto c = backend->create_tensor(element::i8, shape_c);
copy_data(c, vector<int8_t>{-2, 3, 5, 7, 11, 16});
// min/max vectors
auto an = backend->create_tensor(element::f32, Shape{1});
copy_data(an, vector<float>{2.0});
auto ax = backend->create_tensor(element::f32, Shape{1});
copy_data(ax, vector<float>{16.0});
auto bn = backend->create_tensor(element::f32, Shape{1});
copy_data(bn, vector<float>{2.0});
auto bx = backend->create_tensor(element::f32, Shape{1});
copy_data(bx, vector<float>{16.0});
auto cn = backend->create_tensor(element::f32, Shape{1});
copy_data(cn, vector<float>{2.0});
auto cx = backend->create_tensor(element::f32, Shape{1});
copy_data(cx, vector<float>{16.0});
// result
auto result = backend->create_tensor(element::i8, shape_r);
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a, b, c, an, bn, cn, ax, bx, cx});
EXPECT_EQ((vector<int8_t>{-2, 4, 8, 16, -2, 2, 4, 8, 16, 15, -2, 3, 5, 7, 11, 16}),
read_vector<int8_t>(result));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment