Unverified Commit ba2cbdd6 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Add missing ops to serializer (#1060)

* update serializer for all new ops
parent f6b84d67
......@@ -213,6 +213,7 @@ cpio::Reader::~Reader()
void cpio::Reader::open(istream& in)
{
m_stream = ∈
m_stream->seekg(0, ios_base::beg);
}
void cpio::Reader::open(const string& filename)
......@@ -280,6 +281,41 @@ void cpio::Reader::read(const string& file_name, void* data, size_t size_in_byte
}
}
bool cpio::is_cpio(const string& path)
{
ifstream in(path, ios_base::binary | ios_base::in);
return is_cpio(in);
}
bool cpio::is_cpio(istream& in)
{
size_t offset = in.tellg();
in.seekg(0, ios_base::beg);
bool rc = false;
uint8_t ch;
in.read(reinterpret_cast<char*>(&ch), 1);
switch (ch)
{
case 0x71: // Big Endian
in.read(reinterpret_cast<char*>(&ch), 1);
if (ch == 0xC7)
{
rc = true;
}
break;
case 0xC7: // Little Endian
in.read(reinterpret_cast<char*>(&ch), 1);
if (ch == 0x71)
{
rc = true;
}
break;
default: break;
}
in.seekg(offset, ios_base::beg);
return rc;
}
const string& cpio::FileInfo::get_name() const
{
return m_name;
......
......@@ -29,6 +29,9 @@ namespace ngraph
class FileInfo;
class Writer;
class Reader;
bool is_cpio(const std::string&);
bool is_cpio(std::istream&);
}
}
......
......@@ -212,18 +212,11 @@ std::string ngraph::serialize(std::shared_ptr<ngraph::Function> func, size_t ind
}
shared_ptr<ngraph::Function> ngraph::deserialize(istream& in)
{
std::stringstream ss;
ss << in.rdbuf();
return deserialize(ss.str());
}
shared_ptr<ngraph::Function> ngraph::deserialize(const string& s)
{
shared_ptr<Function> rc;
if (file_util::exists(s))
if (cpio::is_cpio(in))
{
cpio::Reader reader(s);
cpio::Reader reader(in);
vector<cpio::FileInfo> file_info = reader.get_file_info();
if (file_info.size() > 0)
{
......@@ -260,6 +253,25 @@ shared_ptr<ngraph::Function> ngraph::deserialize(const string& s)
}
}
else
{
// json file?
std::stringstream ss;
ss << in.rdbuf();
rc = deserialize(ss.str());
}
return rc;
}
shared_ptr<ngraph::Function> ngraph::deserialize(const string& s)
{
shared_ptr<Function> rc;
if (file_util::exists(s))
{
// s is a file and not a json string
ifstream in(s, ios_base::binary | ios_base::in);
rc = deserialize(in);
}
else
{
json js = json::parse(s);
unordered_map<string, shared_ptr<Function>> function_map;
......
......@@ -140,6 +140,7 @@ TEST(serialize, constant)
EXPECT_EQ((vector<float>{1, 2, 3, 4, 5, 6, 7, 8}), A->get_vector<float>());
serialize(tmp_file, f);
auto g = deserialize(tmp_file);
ASSERT_NE(g, nullptr);
file_util::remove_file(tmp_file);
bool found = false;
for (shared_ptr<Node> node : g->get_ops())
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment