Commit d9a9ae69 authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

Optimize Constant for deserialization (#4208)

* Move non-templated constructor implementation to the source file

* Optimize constant constructor for uniform constant

* Cleanup

* Much faster deserialize constant

* Adding unit tests

* Unit tests

* Update unit test

* Cleanup

* style

* Cleanup nbench output

* Fix specializations
Co-authored-by: 's avatarSang Ik Lee <sang.ik.lee@intel.com>
Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent 8b246c5d
......@@ -64,41 +64,227 @@ op::Constant::Constant(const element::Type& type,
", expected ",
shape_size(m_shape),
".");
if (values.size())
constructor_validate_and_infer_types();
if (values.size() == 1 && shape_size(m_shape) != 1)
{
if (type.is_integral())
// broadcast single value
switch (m_element_type)
{
if (type.is_signed())
{
std::vector<int64_t> dvalues = parse_string<int64_t>(values);
if (values.size() == 1 && shape_size(m_shape) != 1)
{
dvalues = std::vector<int64_t>(shape_size(m_shape), dvalues[0]);
}
write_values(dvalues);
}
else
case element::Type_t::boolean:
{
bool value = stoi(values[0]) != 0;
bool* target = m_data->get_ptr<bool>();
std::fill(target, target + shape_size(m_shape), value);
break;
}
case element::Type_t::bf16:
{
bfloat16 value = parse_string<float>(values[0]);
bfloat16* target = m_data->get_ptr<bfloat16>();
std::fill(target, target + shape_size(m_shape), value);
break;
}
case element::Type_t::f16:
{
float16 value = parse_string<float>(values[0]);
float16* target = m_data->get_ptr<float16>();
std::fill(target, target + shape_size(m_shape), value);
break;
}
case element::Type_t::f32:
{
float value = parse_string<float>(values[0]);
float* target = m_data->get_ptr<float>();
std::fill(target, target + shape_size(m_shape), value);
break;
}
case element::Type_t::f64:
{
double value = parse_string<double>(values[0]);
double* target = m_data->get_ptr<double>();
std::fill(target, target + shape_size(m_shape), value);
break;
}
case element::Type_t::i8:
{
int8_t value = parse_string<int64_t>(values[0]);
int8_t* target = m_data->get_ptr<int8_t>();
std::fill(target, target + shape_size(m_shape), value);
break;
}
case element::Type_t::i16:
{
int16_t value = parse_string<int64_t>(values[0]);
int16_t* target = m_data->get_ptr<int16_t>();
std::fill(target, target + shape_size(m_shape), value);
break;
}
case element::Type_t::i32:
{
int32_t value = parse_string<int64_t>(values[0]);
int32_t* target = m_data->get_ptr<int32_t>();
std::fill(target, target + shape_size(m_shape), value);
break;
}
case element::Type_t::i64:
{
int64_t value = parse_string<int64_t>(values[0]);
int64_t* target = m_data->get_ptr<int64_t>();
std::fill(target, target + shape_size(m_shape), value);
break;
}
case element::Type_t::u8:
{
uint8_t value = parse_string<uint64_t>(values[0]);
uint8_t* target = m_data->get_ptr<uint8_t>();
std::fill(target, target + shape_size(m_shape), value);
break;
}
case element::Type_t::u16:
{
uint16_t value = parse_string<uint64_t>(values[0]);
uint16_t* target = m_data->get_ptr<uint16_t>();
std::fill(target, target + shape_size(m_shape), value);
break;
}
case element::Type_t::u32:
{
uint32_t value = parse_string<uint64_t>(values[0]);
uint32_t* target = m_data->get_ptr<uint32_t>();
std::fill(target, target + shape_size(m_shape), value);
break;
}
case element::Type_t::u64:
{
uint64_t value = parse_string<uint64_t>(values[0]);
uint64_t* target = m_data->get_ptr<uint64_t>();
std::fill(target, target + shape_size(m_shape), value);
break;
}
case element::Type_t::undefined:
{
throw std::runtime_error("deserialize unsupported type undefined");
}
case element::Type_t::dynamic:
{
throw std::runtime_error("deserialize unsupported type dynamic");
}
case element::Type_t::u1: { throw std::runtime_error("deserialize unsupported type u1");
}
}
m_all_elements_bitwise_identical = true;
}
else
{
switch (m_element_type)
{
case element::Type_t::boolean:
{
vector<uint8_t> value = parse_string<uint8_t>(values);
uint8_t* target = m_data->get_ptr<uint8_t>();
std::copy(value.begin(), value.end(), target);
break;
}
case element::Type_t::bf16:
{
vector<float> value = parse_string<float>(values);
bfloat16* target = m_data->get_ptr<bfloat16>();
for (size_t i = 0; i < value.size(); i++)
{
std::vector<uint64_t> dvalues = parse_string<uint64_t>(values);
if (values.size() == 1 && shape_size(m_shape) != 1)
{
dvalues = std::vector<uint64_t>(shape_size(m_shape), dvalues[0]);
}
write_values(dvalues);
target[i] = value[i];
}
break;
}
else
case element::Type_t::f16:
{
std::vector<double> dvalues = parse_string<double>(values);
if (values.size() == 1 && shape_size(m_shape) != 1)
vector<float> value = parse_string<float>(values);
float16* target = m_data->get_ptr<float16>();
for (size_t i = 0; i < value.size(); i++)
{
dvalues = std::vector<double>(shape_size(m_shape), dvalues[0]);
target[i] = value[i];
}
write_values(dvalues);
break;
}
case element::Type_t::f32:
{
vector<float> value = parse_string<float>(values);
float* target = m_data->get_ptr<float>();
std::copy(value.begin(), value.end(), target);
break;
}
case element::Type_t::f64:
{
vector<double> value = parse_string<double>(values);
double* target = m_data->get_ptr<double>();
std::copy(value.begin(), value.end(), target);
break;
}
case element::Type_t::i8:
{
vector<int8_t> value = parse_string<int8_t>(values);
int8_t* target = m_data->get_ptr<int8_t>();
std::copy(value.begin(), value.end(), target);
break;
}
case element::Type_t::i16:
{
vector<int16_t> value = parse_string<int16_t>(values);
int16_t* target = m_data->get_ptr<int16_t>();
std::copy(value.begin(), value.end(), target);
break;
}
case element::Type_t::i32:
{
vector<int32_t> value = parse_string<int32_t>(values);
int32_t* target = m_data->get_ptr<int32_t>();
std::copy(value.begin(), value.end(), target);
break;
}
case element::Type_t::i64:
{
vector<int64_t> value = parse_string<int64_t>(values);
int64_t* target = m_data->get_ptr<int64_t>();
std::copy(value.begin(), value.end(), target);
break;
}
case element::Type_t::u8:
{
vector<uint8_t> value = parse_string<uint8_t>(values);
uint8_t* target = m_data->get_ptr<uint8_t>();
std::copy(value.begin(), value.end(), target);
break;
}
case element::Type_t::u16:
{
vector<uint16_t> value = parse_string<uint16_t>(values);
uint16_t* target = m_data->get_ptr<uint16_t>();
std::copy(value.begin(), value.end(), target);
break;
}
case element::Type_t::u32:
{
vector<uint32_t> value = parse_string<uint32_t>(values);
uint32_t* target = m_data->get_ptr<uint32_t>();
std::copy(value.begin(), value.end(), target);
break;
}
case element::Type_t::u64:
{
vector<uint64_t> value = parse_string<uint64_t>(values);
uint64_t* target = m_data->get_ptr<uint64_t>();
std::copy(value.begin(), value.end(), target);
break;
}
case element::Type_t::undefined:
throw std::runtime_error("deserialize unsupported type undefined");
case element::Type_t::dynamic:
throw std::runtime_error("deserialize unsupported type dynamic");
case element::Type_t::u1: throw std::runtime_error("deserialize unsupported type u1");
}
m_all_elements_bitwise_identical = are_all_data_elements_bitwise_identical();
}
constructor_validate_and_infer_types();
m_all_elements_bitwise_identical = are_all_data_elements_bitwise_identical();
}
op::Constant::Constant(const element::Type& type, const Shape& shape, const void* data)
......
......@@ -87,9 +87,7 @@ namespace ngraph
Shape shape,
const std::vector<std::string>& values);
/// \brief Constructs a tensor constant with the same initialization value copied
/// across the tensor. This constructor is to support deserialization of
/// constants.
/// \brief Constructs a tensor constant with the supplied data
///
/// \param type The element type of the tensor constant.
/// \param shape The shape of the tensor constant.
......
......@@ -401,6 +401,36 @@ namespace ngraph
}
return result;
}
template <>
int8_t parse_string<int8_t>(const std::string& s)
{
char* err;
int8_t result = strtol(s.c_str(), &err, 10);
// Check that (1) parsing succeeded and (2) the entire string was used.
if (*err != 0)
{
throw std::runtime_error("Could not parse literal '" + s + "'");
}
return result;
}
template <>
uint8_t parse_string<uint8_t>(const std::string& s)
{
char* err;
uint8_t result = strtol(s.c_str(), &err, 10);
// Check that (1) parsing succeeded and (2) the entire string was used.
if (*err != 0)
{
throw std::runtime_error("Could not parse literal '" + s + "'");
}
return result;
}
}
std::ostream& operator<<(std::ostream& os, const ngraph::NodeVector& nv)
......
......@@ -163,6 +163,14 @@ namespace ngraph
template <>
double parse_string<double>(const std::string& s);
/// template specializations for int8_t and uint8_t to handle the fact that default
/// implementation ends up treating values as characters so that the number "0" turns into
/// the parsed value 48, which is it's ASCII value
template <>
int8_t parse_string<int8_t>(const std::string& s);
template <>
uint8_t parse_string<uint8_t>(const std::string& s);
/// Parses a list of strings containing literals of the underlying type.
template <typename T>
std::vector<T> parse_string(const std::vector<std::string>& ss)
......
......@@ -431,7 +431,13 @@ OPTIONS
if (!backend.empty())
{
cout << "\n---- Benchmark ----\n";
stopwatch t1;
t1.start();
shared_ptr<Function> f = deserialize(model);
stringstream ss;
ss.imbue(locale(""));
ss << t1.get_milliseconds();
cout << "deserialize took " << ss.str() << "ms\n";
vector<runtime::PerformanceCounter> perf_data;
if (double_buffer)
{
......
......@@ -51,6 +51,7 @@ set(SRC
build_graph.cpp
builder_autobroadcast.cpp
check.cpp
constant.cpp
constant_folding.cpp
concat_fusion.cpp
control_dependencies.cpp
......
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <memory>
#include <gtest/gtest.h>
#include "ngraph/ngraph.hpp"
using namespace ngraph;
using namespace std;
//
// boolean
//
TEST(constant, boolean_string)
{
Shape shape{4};
op::Constant c(element::boolean, shape, vector<string>{"1", "0", "1", "0"});
auto v = c.get_vector<char>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], 1);
EXPECT_EQ(v[1], 0);
EXPECT_EQ(v[2], 1);
EXPECT_EQ(v[3], 0);
const char* p = c.get_data_ptr<char>();
EXPECT_EQ(p[0], 1);
EXPECT_EQ(p[1], 0);
EXPECT_EQ(p[2], 1);
EXPECT_EQ(p[3], 0);
}
TEST(constant, boolean_string_broadcast)
{
Shape shape{4};
op::Constant c(element::boolean, shape, vector<string>{"1"});
auto v = c.get_vector<char>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], 1);
EXPECT_EQ(v[1], 1);
EXPECT_EQ(v[2], 1);
EXPECT_EQ(v[3], 1);
const char* p = c.get_data_ptr<char>();
EXPECT_EQ(p[0], 1);
EXPECT_EQ(p[1], 1);
EXPECT_EQ(p[2], 1);
EXPECT_EQ(p[3], 1);
}
TEST(constant, boolean_vector)
{
Shape shape{4};
op::Constant c(element::boolean, shape, vector<char>{1, 0, 1, 0});
auto v = c.get_vector<char>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], 1);
EXPECT_EQ(v[1], 0);
EXPECT_EQ(v[2], 1);
EXPECT_EQ(v[3], 0);
const char* p = c.get_data_ptr<char>();
EXPECT_EQ(p[0], 1);
EXPECT_EQ(p[1], 0);
EXPECT_EQ(p[2], 1);
EXPECT_EQ(p[3], 0);
}
TEST(constant, boolean_vector_broadcast)
{
Shape shape{4};
op::Constant c(element::boolean, shape, vector<char>{1});
auto v = c.get_vector<char>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], 1);
EXPECT_EQ(v[1], 1);
EXPECT_EQ(v[2], 1);
EXPECT_EQ(v[3], 1);
const char* p = c.get_data_ptr<char>();
EXPECT_EQ(p[0], 1);
EXPECT_EQ(p[1], 1);
EXPECT_EQ(p[2], 1);
EXPECT_EQ(p[3], 1);
}
//
// float
//
TEST(constant, float_string)
{
Shape shape{4};
op::Constant c(element::f32, shape, vector<string>{"1", "0", "1", "0"});
auto v = c.get_vector<float>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], 1);
EXPECT_EQ(v[1], 0);
EXPECT_EQ(v[2], 1);
EXPECT_EQ(v[3], 0);
const float* p = c.get_data_ptr<float>();
EXPECT_EQ(p[0], 1);
EXPECT_EQ(p[1], 0);
EXPECT_EQ(p[2], 1);
EXPECT_EQ(p[3], 0);
}
TEST(constant, float_string_broadcast)
{
Shape shape{4};
op::Constant c(element::f32, shape, vector<string>{"1"});
auto v = c.get_vector<float>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], 1);
EXPECT_EQ(v[1], 1);
EXPECT_EQ(v[2], 1);
EXPECT_EQ(v[3], 1);
const float* p = c.get_data_ptr<float>();
EXPECT_EQ(p[0], 1);
EXPECT_EQ(p[1], 1);
EXPECT_EQ(p[2], 1);
EXPECT_EQ(p[3], 1);
}
TEST(constant, float_vector)
{
Shape shape{4};
op::Constant c(element::f32, shape, vector<float>{1, 0, 1, 0});
auto v = c.get_vector<float>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], 1);
EXPECT_EQ(v[1], 0);
EXPECT_EQ(v[2], 1);
EXPECT_EQ(v[3], 0);
const float* p = c.get_data_ptr<float>();
EXPECT_EQ(p[0], 1);
EXPECT_EQ(p[1], 0);
EXPECT_EQ(p[2], 1);
EXPECT_EQ(p[3], 0);
}
TEST(constant, float_vector_broadcast)
{
Shape shape{4};
op::Constant c(element::f32, shape, vector<float>{1});
auto v = c.get_vector<float>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], 1);
EXPECT_EQ(v[1], 1);
EXPECT_EQ(v[2], 1);
EXPECT_EQ(v[3], 1);
const float* p = c.get_data_ptr<float>();
EXPECT_EQ(p[0], 1);
EXPECT_EQ(p[1], 1);
EXPECT_EQ(p[2], 1);
EXPECT_EQ(p[3], 1);
}
//
// double
//
TEST(constant, double_string)
{
Shape shape{4};
op::Constant c(element::f64, shape, vector<string>{"1", "0", "1", "0"});
auto v = c.get_vector<double>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], 1);
EXPECT_EQ(v[1], 0);
EXPECT_EQ(v[2], 1);
EXPECT_EQ(v[3], 0);
const double* p = c.get_data_ptr<double>();
EXPECT_EQ(p[0], 1);
EXPECT_EQ(p[1], 0);
EXPECT_EQ(p[2], 1);
EXPECT_EQ(p[3], 0);
}
TEST(constant, double_string_broadcast)
{
Shape shape{4};
op::Constant c(element::f64, shape, vector<string>{"1"});
auto v = c.get_vector<double>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], 1);
EXPECT_EQ(v[1], 1);
EXPECT_EQ(v[2], 1);
EXPECT_EQ(v[3], 1);
const double* p = c.get_data_ptr<double>();
EXPECT_EQ(p[0], 1);
EXPECT_EQ(p[1], 1);
EXPECT_EQ(p[2], 1);
EXPECT_EQ(p[3], 1);
}
TEST(constant, double_vector)
{
Shape shape{4};
op::Constant c(element::f64, shape, vector<double>{1, 0, 1, 0});
auto v = c.get_vector<double>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], 1);
EXPECT_EQ(v[1], 0);
EXPECT_EQ(v[2], 1);
EXPECT_EQ(v[3], 0);
const double* p = c.get_data_ptr<double>();
EXPECT_EQ(p[0], 1);
EXPECT_EQ(p[1], 0);
EXPECT_EQ(p[2], 1);
EXPECT_EQ(p[3], 0);
}
TEST(constant, double_vector_broadcast)
{
Shape shape{4};
op::Constant c(element::f64, shape, vector<double>{1});
auto v = c.get_vector<double>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], 1);
EXPECT_EQ(v[1], 1);
EXPECT_EQ(v[2], 1);
EXPECT_EQ(v[3], 1);
const double* p = c.get_data_ptr<double>();
EXPECT_EQ(p[0], 1);
EXPECT_EQ(p[1], 1);
EXPECT_EQ(p[2], 1);
EXPECT_EQ(p[3], 1);
}
//
// int8
//
TEST(constant, int8_string)
{
Shape shape{4};
op::Constant c(element::i8, shape, vector<string>{"1", "0", "1", "0"});
auto v = c.get_vector<int8_t>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], 1);
EXPECT_EQ(v[1], 0);
EXPECT_EQ(v[2], 1);
EXPECT_EQ(v[3], 0);
const int8_t* p = c.get_data_ptr<int8_t>();
EXPECT_EQ(p[0], 1);
EXPECT_EQ(p[1], 0);
EXPECT_EQ(p[2], 1);
EXPECT_EQ(p[3], 0);
}
TEST(constant, int8_string_broadcast)
{
Shape shape{4};
op::Constant c(element::i8, shape, vector<string>{"1"});
auto v = c.get_vector<int8_t>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], 1);
EXPECT_EQ(v[1], 1);
EXPECT_EQ(v[2], 1);
EXPECT_EQ(v[3], 1);
const int8_t* p = c.get_data_ptr<int8_t>();
EXPECT_EQ(p[0], 1);
EXPECT_EQ(p[1], 1);
EXPECT_EQ(p[2], 1);
EXPECT_EQ(p[3], 1);
}
TEST(constant, int8_vector)
{
Shape shape{4};
op::Constant c(element::i8, shape, vector<int8_t>{1, 0, 1, 0});
auto v = c.get_vector<int8_t>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], 1);
EXPECT_EQ(v[1], 0);
EXPECT_EQ(v[2], 1);
EXPECT_EQ(v[3], 0);
const int8_t* p = c.get_data_ptr<int8_t>();
EXPECT_EQ(p[0], 1);
EXPECT_EQ(p[1], 0);
EXPECT_EQ(p[2], 1);
EXPECT_EQ(p[3], 0);
}
TEST(constant, int8_vector_broadcast)
{
Shape shape{4};
op::Constant c(element::i8, shape, vector<int8_t>{1});
auto v = c.get_vector<int8_t>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], 1);
EXPECT_EQ(v[1], 1);
EXPECT_EQ(v[2], 1);
EXPECT_EQ(v[3], 1);
const int8_t* p = c.get_data_ptr<int8_t>();
EXPECT_EQ(p[0], 1);
EXPECT_EQ(p[1], 1);
EXPECT_EQ(p[2], 1);
EXPECT_EQ(p[3], 1);
}
//
// int16
//
TEST(constant, int16_string)
{
Shape shape{4};
op::Constant c(element::i16, shape, vector<string>{"1", "0", "1", "0"});
auto v = c.get_vector<int16_t>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], 1);
EXPECT_EQ(v[1], 0);
EXPECT_EQ(v[2], 1);
EXPECT_EQ(v[3], 0);
const int16_t* p = c.get_data_ptr<int16_t>();
EXPECT_EQ(p[0], 1);
EXPECT_EQ(p[1], 0);
EXPECT_EQ(p[2], 1);
EXPECT_EQ(p[3], 0);
}
TEST(constant, int16_string_broadcast)
{
Shape shape{4};
op::Constant c(element::i16, shape, vector<string>{"1"});
auto v = c.get_vector<int16_t>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], 1);
EXPECT_EQ(v[1], 1);
EXPECT_EQ(v[2], 1);
EXPECT_EQ(v[3], 1);
const int16_t* p = c.get_data_ptr<int16_t>();
EXPECT_EQ(p[0], 1);
EXPECT_EQ(p[1], 1);
EXPECT_EQ(p[2], 1);
EXPECT_EQ(p[3], 1);
}
TEST(constant, int16_vector)
{
Shape shape{4};
op::Constant c(element::i16, shape, vector<int16_t>{1, 0, 1, 0});
auto v = c.get_vector<int16_t>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], 1);
EXPECT_EQ(v[1], 0);
EXPECT_EQ(v[2], 1);
EXPECT_EQ(v[3], 0);
const int16_t* p = c.get_data_ptr<int16_t>();
EXPECT_EQ(p[0], 1);
EXPECT_EQ(p[1], 0);
EXPECT_EQ(p[2], 1);
EXPECT_EQ(p[3], 0);
}
TEST(constant, int16_vector_broadcast)
{
Shape shape{4};
op::Constant c(element::i16, shape, vector<int16_t>{1});
auto v = c.get_vector<int16_t>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], 1);
EXPECT_EQ(v[1], 1);
EXPECT_EQ(v[2], 1);
EXPECT_EQ(v[3], 1);
const int16_t* p = c.get_data_ptr<int16_t>();
EXPECT_EQ(p[0], 1);
EXPECT_EQ(p[1], 1);
EXPECT_EQ(p[2], 1);
EXPECT_EQ(p[3], 1);
}
//
// int32
//
TEST(constant, int32_string)
{
Shape shape{4};
op::Constant c(element::i32, shape, vector<string>{"1", "0", "1", "0"});
auto v = c.get_vector<int32_t>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], 1);
EXPECT_EQ(v[1], 0);
EXPECT_EQ(v[2], 1);
EXPECT_EQ(v[3], 0);
const int32_t* p = c.get_data_ptr<int32_t>();
EXPECT_EQ(p[0], 1);
EXPECT_EQ(p[1], 0);
EXPECT_EQ(p[2], 1);
EXPECT_EQ(p[3], 0);
}
TEST(constant, int32_string_broadcast)
{
Shape shape{4};
op::Constant c(element::i32, shape, vector<string>{"1"});
auto v = c.get_vector<int32_t>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], 1);
EXPECT_EQ(v[1], 1);
EXPECT_EQ(v[2], 1);
EXPECT_EQ(v[3], 1);
const int32_t* p = c.get_data_ptr<int32_t>();
EXPECT_EQ(p[0], 1);
EXPECT_EQ(p[1], 1);
EXPECT_EQ(p[2], 1);
EXPECT_EQ(p[3], 1);
}
TEST(constant, int32_vector)
{
Shape shape{4};
op::Constant c(element::i32, shape, vector<int32_t>{1, 0, 1, 0});
auto v = c.get_vector<int32_t>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], 1);
EXPECT_EQ(v[1], 0);
EXPECT_EQ(v[2], 1);
EXPECT_EQ(v[3], 0);
const int32_t* p = c.get_data_ptr<int32_t>();
EXPECT_EQ(p[0], 1);
EXPECT_EQ(p[1], 0);
EXPECT_EQ(p[2], 1);
EXPECT_EQ(p[3], 0);
}
TEST(constant, int32_vector_broadcast)
{
Shape shape{4};
op::Constant c(element::i32, shape, vector<int32_t>{1});
auto v = c.get_vector<int32_t>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], 1);
EXPECT_EQ(v[1], 1);
EXPECT_EQ(v[2], 1);
EXPECT_EQ(v[3], 1);
const int32_t* p = c.get_data_ptr<int32_t>();
EXPECT_EQ(p[0], 1);
EXPECT_EQ(p[1], 1);
EXPECT_EQ(p[2], 1);
EXPECT_EQ(p[3], 1);
}
//
// int64
//
TEST(constant, int64_string)
{
Shape shape{4};
op::Constant c(element::i64, shape, vector<string>{"1", "0", "1", "0"});
auto v = c.get_vector<int64_t>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], 1);
EXPECT_EQ(v[1], 0);
EXPECT_EQ(v[2], 1);
EXPECT_EQ(v[3], 0);
const int64_t* p = c.get_data_ptr<int64_t>();
EXPECT_EQ(p[0], 1);
EXPECT_EQ(p[1], 0);
EXPECT_EQ(p[2], 1);
EXPECT_EQ(p[3], 0);
}
TEST(constant, int64_string_broadcast)
{
Shape shape{4};
op::Constant c(element::i64, shape, vector<string>{"1"});
auto v = c.get_vector<int64_t>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], 1);
EXPECT_EQ(v[1], 1);
EXPECT_EQ(v[2], 1);
EXPECT_EQ(v[3], 1);
const int64_t* p = c.get_data_ptr<int64_t>();
EXPECT_EQ(p[0], 1);
EXPECT_EQ(p[1], 1);
EXPECT_EQ(p[2], 1);
EXPECT_EQ(p[3], 1);
}
TEST(constant, int64_vector)
{
Shape shape{4};
op::Constant c(element::i64, shape, vector<int64_t>{1, 0, 1, 0});
auto v = c.get_vector<int64_t>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], 1);
EXPECT_EQ(v[1], 0);
EXPECT_EQ(v[2], 1);
EXPECT_EQ(v[3], 0);
const int64_t* p = c.get_data_ptr<int64_t>();
EXPECT_EQ(p[0], 1);
EXPECT_EQ(p[1], 0);
EXPECT_EQ(p[2], 1);
EXPECT_EQ(p[3], 0);
}
TEST(constant, int64_vector_broadcast)
{
Shape shape{4};
op::Constant c(element::i64, shape, vector<int64_t>{1});
auto v = c.get_vector<int64_t>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], 1);
EXPECT_EQ(v[1], 1);
EXPECT_EQ(v[2], 1);
EXPECT_EQ(v[3], 1);
const int64_t* p = c.get_data_ptr<int64_t>();
EXPECT_EQ(p[0], 1);
EXPECT_EQ(p[1], 1);
EXPECT_EQ(p[2], 1);
EXPECT_EQ(p[3], 1);
}
//
// uint8
//
TEST(constant, uint8_string)
{
Shape shape{4};
op::Constant c(element::u8, shape, vector<string>{"1", "0", "1", "0"});
auto v = c.get_vector<uint8_t>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], 1);
EXPECT_EQ(v[1], 0);
EXPECT_EQ(v[2], 1);
EXPECT_EQ(v[3], 0);
const uint8_t* p = c.get_data_ptr<uint8_t>();
EXPECT_EQ(p[0], 1);
EXPECT_EQ(p[1], 0);
EXPECT_EQ(p[2], 1);
EXPECT_EQ(p[3], 0);
}
TEST(constant, uint8_string_broadcast)
{
Shape shape{4};
op::Constant c(element::u8, shape, vector<string>{"1"});
auto v = c.get_vector<uint8_t>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], 1);
EXPECT_EQ(v[1], 1);
EXPECT_EQ(v[2], 1);
EXPECT_EQ(v[3], 1);
const uint8_t* p = c.get_data_ptr<uint8_t>();
EXPECT_EQ(p[0], 1);
EXPECT_EQ(p[1], 1);
EXPECT_EQ(p[2], 1);
EXPECT_EQ(p[3], 1);
}
TEST(constant, uint8_vector)
{
Shape shape{4};
op::Constant c(element::u8, shape, vector<uint8_t>{1, 0, 1, 0});
auto v = c.get_vector<uint8_t>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], 1);
EXPECT_EQ(v[1], 0);
EXPECT_EQ(v[2], 1);
EXPECT_EQ(v[3], 0);
const uint8_t* p = c.get_data_ptr<uint8_t>();
EXPECT_EQ(p[0], 1);
EXPECT_EQ(p[1], 0);
EXPECT_EQ(p[2], 1);
EXPECT_EQ(p[3], 0);
}
TEST(constant, uint8_vector_broadcast)
{
Shape shape{4};
op::Constant c(element::u8, shape, vector<uint8_t>{1});
auto v = c.get_vector<uint8_t>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], 1);
EXPECT_EQ(v[1], 1);
EXPECT_EQ(v[2], 1);
EXPECT_EQ(v[3], 1);
const uint8_t* p = c.get_data_ptr<uint8_t>();
EXPECT_EQ(p[0], 1);
EXPECT_EQ(p[1], 1);
EXPECT_EQ(p[2], 1);
EXPECT_EQ(p[3], 1);
}
//
// uint16
//
TEST(constant, uint16_string)
{
Shape shape{4};
op::Constant c(element::u16, shape, vector<string>{"1", "0", "1", "0"});
auto v = c.get_vector<uint16_t>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], 1);
EXPECT_EQ(v[1], 0);
EXPECT_EQ(v[2], 1);
EXPECT_EQ(v[3], 0);
const uint16_t* p = c.get_data_ptr<uint16_t>();
EXPECT_EQ(p[0], 1);
EXPECT_EQ(p[1], 0);
EXPECT_EQ(p[2], 1);
EXPECT_EQ(p[3], 0);
}
TEST(constant, uint16_string_broadcast)
{
Shape shape{4};
op::Constant c(element::u16, shape, vector<string>{"1"});
auto v = c.get_vector<uint16_t>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], 1);
EXPECT_EQ(v[1], 1);
EXPECT_EQ(v[2], 1);
EXPECT_EQ(v[3], 1);
const uint16_t* p = c.get_data_ptr<uint16_t>();
EXPECT_EQ(p[0], 1);
EXPECT_EQ(p[1], 1);
EXPECT_EQ(p[2], 1);
EXPECT_EQ(p[3], 1);
}
TEST(constant, uint16_vector)
{
Shape shape{4};
op::Constant c(element::u16, shape, vector<uint16_t>{1, 0, 1, 0});
auto v = c.get_vector<uint16_t>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], 1);
EXPECT_EQ(v[1], 0);
EXPECT_EQ(v[2], 1);
EXPECT_EQ(v[3], 0);
const uint16_t* p = c.get_data_ptr<uint16_t>();
EXPECT_EQ(p[0], 1);
EXPECT_EQ(p[1], 0);
EXPECT_EQ(p[2], 1);
EXPECT_EQ(p[3], 0);
}
TEST(constant, uint16_vector_broadcast)
{
Shape shape{4};
op::Constant c(element::u16, shape, vector<uint16_t>{1});
auto v = c.get_vector<uint16_t>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], 1);
EXPECT_EQ(v[1], 1);
EXPECT_EQ(v[2], 1);
EXPECT_EQ(v[3], 1);
const uint16_t* p = c.get_data_ptr<uint16_t>();
EXPECT_EQ(p[0], 1);
EXPECT_EQ(p[1], 1);
EXPECT_EQ(p[2], 1);
EXPECT_EQ(p[3], 1);
}
//
// uint32
//
TEST(constant, uint32_string)
{
Shape shape{4};
op::Constant c(element::u32, shape, vector<string>{"1", "0", "1", "0"});
auto v = c.get_vector<uint32_t>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], 1);
EXPECT_EQ(v[1], 0);
EXPECT_EQ(v[2], 1);
EXPECT_EQ(v[3], 0);
const uint32_t* p = c.get_data_ptr<uint32_t>();
EXPECT_EQ(p[0], 1);
EXPECT_EQ(p[1], 0);
EXPECT_EQ(p[2], 1);
EXPECT_EQ(p[3], 0);
}
TEST(constant, uint32_string_broadcast)
{
Shape shape{4};
op::Constant c(element::u32, shape, vector<string>{"1"});
auto v = c.get_vector<uint32_t>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], 1);
EXPECT_EQ(v[1], 1);
EXPECT_EQ(v[2], 1);
EXPECT_EQ(v[3], 1);
const uint32_t* p = c.get_data_ptr<uint32_t>();
EXPECT_EQ(p[0], 1);
EXPECT_EQ(p[1], 1);
EXPECT_EQ(p[2], 1);
EXPECT_EQ(p[3], 1);
}
TEST(constant, uint32_vector)
{
Shape shape{4};
op::Constant c(element::u32, shape, vector<uint32_t>{1, 0, 1, 0});
auto v = c.get_vector<uint32_t>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], 1);
EXPECT_EQ(v[1], 0);
EXPECT_EQ(v[2], 1);
EXPECT_EQ(v[3], 0);
const uint32_t* p = c.get_data_ptr<uint32_t>();
EXPECT_EQ(p[0], 1);
EXPECT_EQ(p[1], 0);
EXPECT_EQ(p[2], 1);
EXPECT_EQ(p[3], 0);
}
TEST(constant, uint32_vector_broadcast)
{
Shape shape{4};
op::Constant c(element::u32, shape, vector<uint32_t>{1});
auto v = c.get_vector<uint32_t>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], 1);
EXPECT_EQ(v[1], 1);
EXPECT_EQ(v[2], 1);
EXPECT_EQ(v[3], 1);
const uint32_t* p = c.get_data_ptr<uint32_t>();
EXPECT_EQ(p[0], 1);
EXPECT_EQ(p[1], 1);
EXPECT_EQ(p[2], 1);
EXPECT_EQ(p[3], 1);
}
//
// uint64
//
TEST(constant, uint64_string)
{
Shape shape{4};
op::Constant c(element::u64, shape, vector<string>{"1", "0", "1", "0"});
auto v = c.get_vector<uint64_t>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], 1);
EXPECT_EQ(v[1], 0);
EXPECT_EQ(v[2], 1);
EXPECT_EQ(v[3], 0);
const uint64_t* p = c.get_data_ptr<uint64_t>();
EXPECT_EQ(p[0], 1);
EXPECT_EQ(p[1], 0);
EXPECT_EQ(p[2], 1);
EXPECT_EQ(p[3], 0);
}
TEST(constant, uint64_string_broadcast)
{
Shape shape{4};
op::Constant c(element::u64, shape, vector<string>{"1"});
auto v = c.get_vector<uint64_t>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], 1);
EXPECT_EQ(v[1], 1);
EXPECT_EQ(v[2], 1);
EXPECT_EQ(v[3], 1);
const uint64_t* p = c.get_data_ptr<uint64_t>();
EXPECT_EQ(p[0], 1);
EXPECT_EQ(p[1], 1);
EXPECT_EQ(p[2], 1);
EXPECT_EQ(p[3], 1);
}
TEST(constant, uint64_vector)
{
Shape shape{4};
op::Constant c(element::u64, shape, vector<uint64_t>{1, 0, 1, 0});
auto v = c.get_vector<uint64_t>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], 1);
EXPECT_EQ(v[1], 0);
EXPECT_EQ(v[2], 1);
EXPECT_EQ(v[3], 0);
const uint64_t* p = c.get_data_ptr<uint64_t>();
EXPECT_EQ(p[0], 1);
EXPECT_EQ(p[1], 0);
EXPECT_EQ(p[2], 1);
EXPECT_EQ(p[3], 0);
}
TEST(constant, uint64_vector_broadcast)
{
Shape shape{4};
op::Constant c(element::u64, shape, vector<uint64_t>{1});
auto v = c.get_vector<uint64_t>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], 1);
EXPECT_EQ(v[1], 1);
EXPECT_EQ(v[2], 1);
EXPECT_EQ(v[3], 1);
const uint64_t* p = c.get_data_ptr<uint64_t>();
EXPECT_EQ(p[0], 1);
EXPECT_EQ(p[1], 1);
EXPECT_EQ(p[2], 1);
EXPECT_EQ(p[3], 1);
}
//
// bfloat16
//
TEST(constant, bfloat16_string)
{
Shape shape{4};
op::Constant c(element::bf16, shape, vector<string>{"1", "0", "1", "0"});
auto v = c.get_vector<bfloat16>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], bfloat16(1));
EXPECT_EQ(v[1], bfloat16(0));
EXPECT_EQ(v[2], bfloat16(1));
EXPECT_EQ(v[3], bfloat16(0));
const bfloat16* p = c.get_data_ptr<bfloat16>();
EXPECT_EQ(p[0], bfloat16(1));
EXPECT_EQ(p[1], bfloat16(0));
EXPECT_EQ(p[2], bfloat16(1));
EXPECT_EQ(p[3], bfloat16(0));
}
TEST(constant, bfloat16_string_broadcast)
{
Shape shape{4};
op::Constant c(element::bf16, shape, vector<string>{"1"});
auto v = c.get_vector<bfloat16>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], bfloat16(1));
EXPECT_EQ(v[1], bfloat16(1));
EXPECT_EQ(v[2], bfloat16(1));
EXPECT_EQ(v[3], bfloat16(1));
const bfloat16* p = c.get_data_ptr<bfloat16>();
EXPECT_EQ(p[0], bfloat16(1));
EXPECT_EQ(p[1], bfloat16(1));
EXPECT_EQ(p[2], bfloat16(1));
EXPECT_EQ(p[3], bfloat16(1));
}
TEST(constant, bfloat16_vector)
{
Shape shape{4};
op::Constant c(element::bf16, shape, vector<bfloat16>{1, 0, 1, 0});
auto v = c.get_vector<bfloat16>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], bfloat16(1));
EXPECT_EQ(v[1], bfloat16(0));
EXPECT_EQ(v[2], bfloat16(1));
EXPECT_EQ(v[3], bfloat16(0));
const bfloat16* p = c.get_data_ptr<bfloat16>();
EXPECT_EQ(p[0], bfloat16(1));
EXPECT_EQ(p[1], bfloat16(0));
EXPECT_EQ(p[2], bfloat16(1));
EXPECT_EQ(p[3], bfloat16(0));
}
TEST(constant, bfloat16_vector_broadcast)
{
Shape shape{4};
op::Constant c(element::bf16, shape, vector<bfloat16>{1});
auto v = c.get_vector<bfloat16>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], bfloat16(1));
EXPECT_EQ(v[1], bfloat16(1));
EXPECT_EQ(v[2], bfloat16(1));
EXPECT_EQ(v[3], bfloat16(1));
const bfloat16* p = c.get_data_ptr<bfloat16>();
EXPECT_EQ(p[0], bfloat16(1));
EXPECT_EQ(p[1], bfloat16(1));
EXPECT_EQ(p[2], bfloat16(1));
EXPECT_EQ(p[3], bfloat16(1));
}
//
// float16
//
TEST(constant, float16_string)
{
Shape shape{4};
op::Constant c(element::f16, shape, vector<string>{"1", "0", "1", "0"});
auto v = c.get_vector<float16>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], float16(1));
EXPECT_EQ(v[1], float16(0));
EXPECT_EQ(v[2], float16(1));
EXPECT_EQ(v[3], float16(0));
const float16* p = c.get_data_ptr<float16>();
EXPECT_EQ(p[0], float16(1));
EXPECT_EQ(p[1], float16(0));
EXPECT_EQ(p[2], float16(1));
EXPECT_EQ(p[3], float16(0));
}
TEST(constant, float16_string_broadcast)
{
Shape shape{4};
op::Constant c(element::f16, shape, vector<string>{"1"});
auto v = c.get_vector<float16>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], float16(1));
EXPECT_EQ(v[1], float16(1));
EXPECT_EQ(v[2], float16(1));
EXPECT_EQ(v[3], float16(1));
const float16* p = c.get_data_ptr<float16>();
EXPECT_EQ(p[0], float16(1));
EXPECT_EQ(p[1], float16(1));
EXPECT_EQ(p[2], float16(1));
EXPECT_EQ(p[3], float16(1));
}
TEST(constant, float16_vector)
{
Shape shape{4};
op::Constant c(element::f16, shape, vector<float16>{1, 0, 1, 0});
auto v = c.get_vector<float16>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], float16(1));
EXPECT_EQ(v[1], float16(0));
EXPECT_EQ(v[2], float16(1));
EXPECT_EQ(v[3], float16(0));
const float16* p = c.get_data_ptr<float16>();
EXPECT_EQ(p[0], float16(1));
EXPECT_EQ(p[1], float16(0));
EXPECT_EQ(p[2], float16(1));
EXPECT_EQ(p[3], float16(0));
}
TEST(constant, float16_vector_broadcast)
{
Shape shape{4};
op::Constant c(element::f16, shape, vector<float16>{1});
auto v = c.get_vector<float16>();
ASSERT_EQ(v.size(), shape_size(shape));
EXPECT_EQ(v[0], float16(1));
EXPECT_EQ(v[1], float16(1));
EXPECT_EQ(v[2], float16(1));
EXPECT_EQ(v[3], float16(1));
const float16* p = c.get_data_ptr<float16>();
EXPECT_EQ(p[0], float16(1));
EXPECT_EQ(p[1], float16(1));
EXPECT_EQ(p[2], float16(1));
EXPECT_EQ(p[3], float16(1));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment