Unverified Commit 28602f31 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Yet another serialization option (#619)

* Add cpio file read/write class and unit tests

add reserializer

Add unit test for serialize constants to cpio file. Fix bug in serializer if function has no parameters.
parent 9dea9576
......@@ -117,6 +117,7 @@ set (SRC
util.cpp
graph_util.cpp
placement.cpp
cpio.cpp
)
message(STATUS ${CMAKE_CURRENT_SOURCE_DIR}/ops)
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/cpio.hpp"
#include "ngraph/log.hpp"
using namespace ngraph;
using namespace std;
static uint16_t read_u16(istream& stream, bool big_endian = false)
{
uint8_t ch[2];
uint16_t rc;
stream.read(reinterpret_cast<char*>(&ch[0]), 2);
if (big_endian)
{
rc = static_cast<uint16_t>((ch[0] << 8) + ch[1]);
}
else
{
rc = static_cast<uint16_t>((ch[1] << 8) + ch[0]);
}
return rc;
}
static uint32_t read_u32(istream& stream, bool big_endian = false)
{
uint32_t rc;
uint16_t sh[2];
sh[0] = read_u16(stream, big_endian);
sh[1] = read_u16(stream, big_endian);
rc = (sh[0] << 16) + sh[1];
return rc;
}
static void write_u16(ostream& stream, uint16_t value)
{
const char* p = reinterpret_cast<const char*>(&value);
stream.write(p, 2);
}
static void write_u32(ostream& stream, uint32_t value)
{
uint16_t* v = reinterpret_cast<uint16_t*>(&value);
write_u16(stream, v[1]);
write_u16(stream, v[0]);
}
cpio::Header cpio::Header::read(istream& stream)
{
uint8_t ch;
stream.read(reinterpret_cast<char*>(&ch), 1);
Header rc;
switch (ch)
{
case 0x71: // Big Endian
stream.read(reinterpret_cast<char*>(&ch), 1);
if (ch != 0xC7)
{
throw runtime_error("CPIO magic error");
}
rc.dev = read_u16(stream, true);
rc.ino = read_u16(stream, true);
rc.mode = read_u16(stream, true);
rc.uid = read_u16(stream, true);
rc.gid = read_u16(stream, true);
rc.nlink = read_u16(stream, true);
rc.rdev = read_u16(stream, true);
rc.mtime = read_u32(stream, true);
rc.namesize = read_u16(stream, true);
rc.filesize = read_u32(stream, true);
break;
case 0xC7: // Little Endian
stream.read(reinterpret_cast<char*>(&ch), 1);
if (ch != 0x71)
{
throw runtime_error("CPIO magic error");
}
rc.dev = read_u16(stream);
rc.ino = read_u16(stream);
rc.mode = read_u16(stream);
rc.uid = read_u16(stream);
rc.gid = read_u16(stream);
rc.nlink = read_u16(stream);
rc.rdev = read_u16(stream);
rc.mtime = read_u32(stream);
rc.namesize = read_u16(stream);
rc.filesize = read_u32(stream);
break;
case '0': throw runtime_error("CPIO ASCII unsupported");
default: throw runtime_error("CPIO invalid file");
}
return rc;
}
void cpio::Header::write(ostream& stream, const string& name, uint32_t size)
{
// namesize includes the null string terminator so + 1
uint16_t namesize = static_cast<uint16_t>(name.size()) + 1;
write_u16(stream, 0x71C7); // magic
write_u16(stream, 0); // dev
write_u16(stream, 0); // ino
write_u16(stream, 0); // mode
write_u16(stream, 0); // uid
write_u16(stream, 0); // gid
write_u16(stream, 0); // nlink
write_u16(stream, 0); // rdev
write_u32(stream, 0); // mtime
write_u16(stream, namesize); // namesize
write_u32(stream, size); // filesize
stream.write(name.c_str(), namesize + (namesize % 2));
}
cpio::Writer::Writer()
: m_stream(nullptr)
{
}
cpio::Writer::Writer(ostream& out)
: Writer()
{
open(out);
}
cpio::Writer::Writer(const string& filename)
: Writer()
{
open(filename);
}
cpio::Writer::~Writer()
{
close();
}
void cpio::Writer::open(ostream& out)
{
m_stream = &out;
}
void cpio::Writer::open(const string& filename)
{
m_stream = &m_my_stream;
m_my_stream.open(filename, ios_base::binary | ios_base::out);
}
void cpio::Writer::close()
{
write("TRAILER!!!", nullptr, 0);
if (m_my_stream.is_open())
{
m_my_stream.close();
}
}
void cpio::Writer::write(const string& record_name, const void* data, uint32_t size_in_bytes)
{
if (m_stream)
{
Header::write(*m_stream, record_name, size_in_bytes);
m_stream->write(static_cast<const char*>(data), size_in_bytes);
if (size_in_bytes % 2)
{
char ch = 0;
m_stream->write(&ch, 1);
}
}
else
{
throw runtime_error("cpio writer output not set");
}
}
cpio::Reader::Reader()
: m_stream(nullptr)
{
}
cpio::Reader::Reader(istream& in)
: Reader()
{
open(in);
}
cpio::Reader::Reader(const string& filename)
: Reader()
{
open(filename);
}
cpio::Reader::~Reader()
{
}
void cpio::Reader::open(istream& in)
{
m_stream = &in;
}
void cpio::Reader::open(const string& filename)
{
m_stream = &m_my_stream;
m_my_stream.open(filename, ios_base::binary | ios_base::in);
}
void cpio::Reader::close()
{
if (m_my_stream.is_open())
{
m_my_stream.close();
}
}
const vector<cpio::FileInfo>& cpio::Reader::get_file_info()
{
if (m_file_info.empty())
{
while (*m_stream)
{
Header header = Header::read(*m_stream);
auto buffer = new char[header.namesize];
m_stream->read(buffer, header.namesize);
// namesize includes the null string terminator so -1
string file_name = string(buffer, header.namesize - 1);
delete[] buffer;
// skip any pad characters
if (header.namesize % 2)
{
m_stream->seekg(1, ios_base::cur);
}
if (file_name == "TRAILER!!!")
{
break;
}
size_t offset = m_stream->tellg();
m_file_info.emplace_back(file_name, header.filesize, offset);
m_stream->seekg((header.filesize % 2) + header.filesize, ios_base::cur);
}
}
return m_file_info;
}
void cpio::Reader::read(const string& file_name, void* data, size_t size_in_bytes)
{
for (const FileInfo& info : get_file_info())
{
if (info.get_name() == file_name)
{
if (size_in_bytes != info.get_size())
{
throw runtime_error("Buffer size does not match file size");
}
m_stream->seekg(info.get_offset(), ios_base::beg);
m_stream->read(reinterpret_cast<char*>(data), size_in_bytes);
break;
}
}
}
const string& cpio::FileInfo::get_name() const
{
return m_name;
}
size_t cpio::FileInfo::get_size() const
{
return m_size;
}
size_t cpio::FileInfo::get_offset() const
{
return m_offset;
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
namespace ngraph
{
namespace cpio
{
class Header;
class FileInfo;
class Writer;
class Reader;
}
}
class ngraph::cpio::Header
{
public:
uint16_t magic;
uint16_t dev;
uint16_t ino;
uint16_t mode;
uint16_t uid;
uint16_t gid;
uint16_t nlink;
uint16_t rdev;
uint32_t mtime;
uint16_t namesize;
uint32_t filesize;
static Header read(std::istream&);
static void write(std::ostream&, const std::string& name, uint32_t size);
private:
};
class ngraph::cpio::FileInfo
{
public:
FileInfo(const std::string& name, size_t size, size_t offset)
: m_name(name)
, m_size(size)
, m_offset(offset)
{
}
const std::string& get_name() const;
size_t get_size() const;
size_t get_offset() const;
private:
std::string m_name;
size_t m_size;
size_t m_offset;
};
class ngraph::cpio::Writer
{
public:
Writer();
Writer(std::ostream& out);
Writer(const std::string& filename);
~Writer();
void open(std::ostream& out);
void open(const std::string& filename);
void close();
void write(const std::string& file_name, const void* data, uint32_t size_in_bytes);
private:
std::ostream* m_stream;
std::ofstream m_my_stream;
};
class ngraph::cpio::Reader
{
public:
Reader();
Reader(std::istream& in);
Reader(const std::string& filename);
~Reader();
void open(std::istream& in);
void open(const std::string& filename);
void close();
const std::vector<FileInfo>& get_file_info();
void read(const std::string& file_name, void* data, size_t size_in_bytes);
private:
std::istream* m_stream;
std::ifstream m_my_stream;
std::vector<cpio::FileInfo> m_file_info;
};
......@@ -36,7 +36,8 @@ namespace ngraph
///
/// \param type The element type of the tensor constant.
/// \param shape The shape of the tensor constant.
/// \param values A vector of literals for initializing the tensor constant. The size of values must match the size of the shape.
/// \param values A vector of literals for initializing the tensor constant. The size
/// of values must match the size of the shape.
template <typename T>
Constant(const element::Type& type, Shape shape, const std::vector<T>& values)
: Node("Constant", {})
......@@ -84,8 +85,8 @@ namespace ngraph
write_values(dvalues);
}
/// \brief Constructs a tensor constant with the same initialization value copied across the tensor.
/// This constructor is mainly to support deserialization of constants.
/// \brief Constructs a tensor constant with the same initialization value copied across
// the tensor. This constructor is to support deserialization of constants.
///
/// \param type The element type of the tensor constant.
/// \param shape The shape of the tensor constant.
......
......@@ -14,7 +14,11 @@
* limitations under the License.
*******************************************************************************/
#include "ngraph/serializer.hpp"
#include <fstream>
#include <functional>
#include "ngraph/cpio.hpp"
#include "ngraph/file_util.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/ops/abs.hpp"
#include "ngraph/ops/acos.hpp"
......@@ -78,12 +82,14 @@
#include "ngraph/ops/sum.hpp"
#include "ngraph/ops/tan.hpp"
#include "ngraph/ops/tanh.hpp"
#include "ngraph/serializer.hpp"
#include "ngraph/util.hpp"
#include "nlohmann/json.hpp"
using namespace ngraph;
using namespace std;
using json = nlohmann::json;
using const_data_callback_t = shared_ptr<Node>(const string&, const element::Type&, const Shape&);
template <typename T>
T get_or_default(nlohmann::json& j, const std::string& key, const T& default_value)
......@@ -101,10 +107,12 @@ T get_or_default(nlohmann::json& j, const std::string& key, const T& default_val
}
static std::shared_ptr<ngraph::Function>
read_function(const json&, std::unordered_map<std::string, std::shared_ptr<Function>>&);
read_function(const json&,
std::unordered_map<std::string, std::shared_ptr<Function>>&,
function<const_data_callback_t>);
static json write(const ngraph::Function&);
static json write(const ngraph::Node&);
static json write(const ngraph::Function&, bool binary_constant_data);
static json write(const ngraph::Node&, bool binary_constant_data);
static json write_element_type(const ngraph::element::Type& n)
{
......@@ -144,12 +152,40 @@ static element::Type read_element_type(const json& j)
return element::Type(bitwidth, is_real, is_signed, c_type_string);
}
string ngraph::serialize(shared_ptr<ngraph::Function> func, size_t indent)
void ngraph::serialize(const string& path, shared_ptr<ngraph::Function> func, size_t indent)
{
ofstream out(path);
serialize(out, func, indent);
}
void ngraph::serialize(ostream& out, shared_ptr<ngraph::Function> func, size_t indent)
{
string j = serialize(func, indent, true);
cpio::Writer writer(out);
writer.write(func->get_name(), j.c_str(), static_cast<uint32_t>(j.size()));
traverse_functions(func, [&](shared_ptr<ngraph::Function> f) {
traverse_nodes(const_cast<Function*>(f.get()), [&](shared_ptr<Node> node) {
if (auto c = dynamic_pointer_cast<op::Constant>(node))
{
uint32_t size = static_cast<uint32_t>(shape_size(c->get_output_shape(0)) *
c->get_output_element_type(0).size());
writer.write(c->get_name(), c->get_data_ptr(), size);
}
});
});
writer.close();
}
string
ngraph::serialize(shared_ptr<ngraph::Function> func, size_t indent, bool binary_constant_data)
{
json j;
vector<json> functions;
traverse_functions(func,
[&](shared_ptr<ngraph::Function> f) { functions.push_back(write(*f)); });
traverse_functions(func, [&](shared_ptr<ngraph::Function> f) {
functions.push_back(write(*f, binary_constant_data));
});
for (auto it = functions.rbegin(); it != functions.rend(); it++)
{
j.push_back(*it);
......@@ -176,27 +212,71 @@ shared_ptr<ngraph::Function> ngraph::deserialize(istream& in)
shared_ptr<ngraph::Function> ngraph::deserialize(const string& s)
{
json js = json::parse(s);
shared_ptr<Function> rc;
unordered_map<string, shared_ptr<Function>> function_map;
for (json func : js)
if (file_util::exists(s))
{
cpio::Reader reader(s);
vector<cpio::FileInfo> file_info = reader.get_file_info();
if (file_info.size() > 0)
{
// The first file is the model
uint32_t size = static_cast<uint32_t>(file_info[0].get_size());
char* data = new char[size];
reader.read(file_info[0].get_name(), data, size);
string jstr(data, size);
delete[] data;
json js = json::parse(jstr);
unordered_map<string, shared_ptr<Function>> function_map;
for (json func : js)
{
shared_ptr<Function> f = read_function(
func,
function_map,
[&](const string& const_name, const element::Type& et, const Shape& shape) {
shared_ptr<Node> const_node;
for (const cpio::FileInfo& info : file_info)
{
if (info.get_name() == const_name)
{
void* const_data = malloc(info.get_size());
reader.read(const_name, const_data, info.get_size());
const_node = make_shared<op::Constant>(et, shape, const_data);
free(const_data);
break;
}
}
return const_node;
});
rc = f;
}
}
}
else
{
shared_ptr<Function> f = read_function(func, function_map);
rc = f;
json js = json::parse(s);
unordered_map<string, shared_ptr<Function>> function_map;
for (json func : js)
{
shared_ptr<Function> f = read_function(func, function_map, nullptr);
rc = f;
}
}
return rc;
}
static json write(const Function& f)
static json write(const Function& f, bool binary_constant_data)
{
json function;
function["name"] = f.get_name();
vector<string> parameter_list;
for (auto param : f.get_parameters())
{
function["parameters"].push_back(param->get_name());
parameter_list.push_back(param->get_name());
}
function["parameters"] = parameter_list;
// TODO Functions can return multiple results
for (size_t i = 0; i < f.get_output_size(); ++i)
{
......@@ -239,14 +319,16 @@ static json write(const Function& f)
json nodes;
for (shared_ptr<Node> node : result_list)
{
nodes.push_back(write(*node));
nodes.push_back(write(*node, binary_constant_data));
}
function["ops"] = nodes;
return function;
}
static shared_ptr<ngraph::Function>
read_function(const json& func_js, unordered_map<string, shared_ptr<Function>>& function_map)
read_function(const json& func_js,
unordered_map<string, shared_ptr<Function>>& function_map,
function<const_data_callback_t> const_data_callback)
{
shared_ptr<ngraph::Function> rc;
......@@ -357,8 +439,15 @@ static shared_ptr<ngraph::Function>
node_js.count("element_type") == 0 ? node_js.at("value_type") : node_js;
auto element_type = read_element_type(type_node_js.at("element_type"));
auto shape = type_node_js.at("shape");
auto value = node_js.at("value").get<vector<string>>();
node = make_shared<op::Constant>(element_type, shape, value);
try
{
auto value = node_js.at("value").get<vector<string>>();
node = make_shared<op::Constant>(element_type, shape, value);
}
catch (...)
{
node = const_data_callback(node_name, element_type, shape);
}
}
else if (node_op == "Convert")
{
......@@ -779,7 +868,7 @@ static shared_ptr<ngraph::Function>
return rc;
}
static json write(const Node& n)
static json write(const Node& n, bool binary_constant_data)
{
json node;
node["name"] = n.get_name();
......@@ -875,7 +964,10 @@ static json write(const Node& n)
else if (node_op == "Constant")
{
auto tmp = dynamic_cast<const op::Constant*>(&n);
node["value"] = tmp->get_value_strings();
if (!binary_constant_data)
{
node["value"] = tmp->get_value_strings();
}
node["shape"] = tmp->get_shape();
node["element_type"] = write_element_type(tmp->get_element_type());
}
......
......@@ -23,7 +23,12 @@
namespace ngraph
{
std::string serialize(std::shared_ptr<ngraph::Function>, size_t indent = 0);
std::string serialize(std::shared_ptr<ngraph::Function>,
size_t indent = 0,
bool binary_constant_data = false);
void serialize(const std::string& path, std::shared_ptr<ngraph::Function>, size_t indent = 0);
void serialize(std::ostream& out, std::shared_ptr<ngraph::Function>, size_t indent = 0);
std::shared_ptr<ngraph::Function> deserialize(std::istream&);
std::shared_ptr<ngraph::Function> deserialize(const std::string&);
}
......@@ -15,3 +15,4 @@
# ******************************************************************************
add_subdirectory(nbench)
add_subdirectory(reserialize)
# ******************************************************************************
# Copyright 2017-2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
if(MKLDNN_INCLUDE_DIR)
link_directories(${MKLDNN_LIB_DIR})
endif()
if (NGRAPH_CPU_ENABLE)
set (SRC
reserialize.cpp
)
add_executable(reserialize ${SRC})
add_dependencies(reserialize ngraph)
target_link_libraries(reserialize ngraph)
endif()
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
// tool to benchmark any ngraph json model with given backend.
// compile and run with:
// g++ ./nbench.cpp -std=c++11 -I$HOME/ngraph_dist/include -L$HOME/ngraph_dist/lib -lngraph -o nbench
// env LD_LIBRARY_PATH=$HOME/ngraph_dist/lib env NGRAPH_INTERPRETER_EMIT_TIMING=1 ./nbench
// sample models are under ../../test/models
#include <fstream>
#include <iostream>
#include <string>
#include "ngraph/serializer.hpp"
#include "ngraph/util.hpp"
using namespace std;
void help()
{
cout << R"###(
DESCRIPTION
Reserialize a serialized model
SYNOPSIS
reserialize [-i|--input <input file>] [-o|--output <output file>]
OPTIONS
-i or --input input serialized model
-o or --output output serialized model
)###";
}
int main(int argc, char** argv)
{
string input;
string output;
for (size_t i = 1; i < argc; i++)
{
string arg = argv[i];
if (arg == "-o" || arg == "--output")
{
output = argv[++i];
}
else if (arg == "-i" || arg == "--input")
{
input = argv[++i];
}
else if (arg == "-h" || arg == "--help")
{
help();
return 0;
}
}
ifstream f(input);
if (f)
{
ngraph::stopwatch timer;
timer.start();
shared_ptr<ngraph::Function> function = ngraph::deserialize(f);
timer.stop();
cout << "deserialize took " << timer.get_milliseconds() << "ms\n";
timer.start();
ngraph::serialize(output, function, 2);
timer.stop();
cout << "serialize took " << timer.get_milliseconds() << "ms\n";
}
else
{
cout << "failed to open '" << input << "' for input\n";
return 2;
}
return 0;
}
......@@ -32,6 +32,7 @@ set (SRC
build_graph.cpp
copy.cpp
core_fusion.cpp
cpio.cpp
eigen.cpp
element_type.cpp
file_util.cpp
......@@ -59,6 +60,7 @@ set (SRC
)
add_subdirectory(models)
add_subdirectory(files)
#================================================================================================
# To auto generate a suite of unit tests for a backend add a line like this
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <gtest/gtest.h>
#include "ngraph/cpio.hpp"
#include "ngraph/file_util.hpp"
#include "ngraph/log.hpp"
using namespace ngraph;
using namespace std;
TEST(cpio, read)
{
const string test_file = file_util::path_join(TEST_FILES, "test.cpio");
cpio::Reader reader(test_file);
auto file_info = reader.get_file_info();
ASSERT_EQ(3, file_info.size());
EXPECT_STREQ(file_info[0].get_name().c_str(), "test1.txt");
EXPECT_STREQ(file_info[1].get_name().c_str(), "test2.txt");
EXPECT_STREQ(file_info[2].get_name().c_str(), "test3.txt");
EXPECT_EQ(file_info[0].get_size(), 5);
EXPECT_EQ(file_info[1].get_size(), 14);
EXPECT_EQ(file_info[2].get_size(), 44);
{
int index = 0;
char* data = static_cast<char*>(malloc(file_info[index].get_size()));
reader.read(file_info[index].get_name(), data, file_info[index].get_size());
string content = string(data, file_info[index].get_size());
EXPECT_STREQ(content.c_str(), "12345");
}
{
int index = 1;
char* data = static_cast<char*>(malloc(file_info[index].get_size()));
reader.read(file_info[index].get_name(), data, file_info[index].get_size());
string content = string(data, file_info[index].get_size());
EXPECT_STREQ(content.c_str(), "this is a test");
}
{
int index = 2;
char* data = static_cast<char*>(malloc(file_info[index].get_size()));
reader.read(file_info[index].get_name(), data, file_info[index].get_size());
string content = string(data, file_info[index].get_size());
EXPECT_STREQ(content.c_str(), "the quick brown fox jumped over the lazy dog");
}
}
TEST(cpio, write)
{
const string test_file = "test1.cpio";
string s1 = "this is a test";
string s2 = "the quick brown fox jumps over the lazy dog";
{
cpio::Writer writer(test_file);
{
writer.write("file1.txt", s1.data(), static_cast<uint32_t>(s1.size()));
}
{
writer.write("file.txt", s2.data(), static_cast<uint32_t>(s2.size()));
}
}
{
cpio::Reader reader(test_file);
auto file_info = reader.get_file_info();
ASSERT_EQ(2, file_info.size());
EXPECT_STREQ(file_info[0].get_name().c_str(), "file1.txt");
EXPECT_STREQ(file_info[1].get_name().c_str(), "file.txt");
EXPECT_EQ(file_info[0].get_size(), 14);
EXPECT_EQ(file_info[1].get_size(), 43);
{
int index = 0;
char* data = static_cast<char*>(malloc(file_info[index].get_size()));
reader.read(file_info[index].get_name(), data, file_info[index].get_size());
string content = string(data, file_info[index].get_size());
EXPECT_STREQ(content.c_str(), s1.c_str());
}
{
int index = 1;
char* data = static_cast<char*>(malloc(file_info[index].get_size()));
reader.read(file_info[index].get_name(), data, file_info[index].get_size());
string content = string(data, file_info[index].get_size());
EXPECT_STREQ(content.c_str(), s2.c_str());
}
}
}
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DTEST_FILES=\\\"${CMAKE_CURRENT_SOURCE_DIR}\\\"" PARENT_SCOPE)
......@@ -133,6 +133,31 @@ TEST(serialize, default_value)
EXPECT_EQ(x3, 3);
}
TEST(serialize, constant)
{
const string tmp_file = "serialize_constant.cpio";
Shape shape{2, 2, 2};
auto A = op::Constant::create(element::f32, shape, {1, 2, 3, 4, 5, 6, 7, 8});
auto f = make_shared<Function>(A, op::ParameterVector{});
EXPECT_EQ((vector<float>{1, 2, 3, 4, 5, 6, 7, 8}), A->get_vector<float>());
serialize(tmp_file, f);
auto g = deserialize(tmp_file);
file_util::remove_file(tmp_file);
bool found = false;
for (shared_ptr<Node> node : g->get_ops())
{
shared_ptr<op::Constant> c = dynamic_pointer_cast<op::Constant>(node);
if (c)
{
found = true;
EXPECT_EQ((vector<float>{1, 2, 3, 4, 5, 6, 7, 8}), c->get_vector<float>());
break;
}
}
EXPECT_TRUE(found);
}
TEST(benchmark, serialize)
{
stopwatch timer;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment