Unverified Commit ba2cbdd6 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Add missing ops to serializer (#1060)

* update serializer for all new ops
parent f6b84d67
...@@ -213,6 +213,7 @@ cpio::Reader::~Reader() ...@@ -213,6 +213,7 @@ cpio::Reader::~Reader()
void cpio::Reader::open(istream& in) void cpio::Reader::open(istream& in)
{ {
m_stream = ∈ m_stream = ∈
m_stream->seekg(0, ios_base::beg);
} }
void cpio::Reader::open(const string& filename) void cpio::Reader::open(const string& filename)
...@@ -280,6 +281,41 @@ void cpio::Reader::read(const string& file_name, void* data, size_t size_in_byte ...@@ -280,6 +281,41 @@ void cpio::Reader::read(const string& file_name, void* data, size_t size_in_byte
} }
} }
bool cpio::is_cpio(const string& path)
{
ifstream in(path, ios_base::binary | ios_base::in);
return is_cpio(in);
}
bool cpio::is_cpio(istream& in)
{
size_t offset = in.tellg();
in.seekg(0, ios_base::beg);
bool rc = false;
uint8_t ch;
in.read(reinterpret_cast<char*>(&ch), 1);
switch (ch)
{
case 0x71: // Big Endian
in.read(reinterpret_cast<char*>(&ch), 1);
if (ch == 0xC7)
{
rc = true;
}
break;
case 0xC7: // Little Endian
in.read(reinterpret_cast<char*>(&ch), 1);
if (ch == 0x71)
{
rc = true;
}
break;
default: break;
}
in.seekg(offset, ios_base::beg);
return rc;
}
const string& cpio::FileInfo::get_name() const const string& cpio::FileInfo::get_name() const
{ {
return m_name; return m_name;
......
...@@ -29,6 +29,9 @@ namespace ngraph ...@@ -29,6 +29,9 @@ namespace ngraph
class FileInfo; class FileInfo;
class Writer; class Writer;
class Reader; class Reader;
bool is_cpio(const std::string&);
bool is_cpio(std::istream&);
} }
} }
......
...@@ -212,18 +212,11 @@ std::string ngraph::serialize(std::shared_ptr<ngraph::Function> func, size_t ind ...@@ -212,18 +212,11 @@ std::string ngraph::serialize(std::shared_ptr<ngraph::Function> func, size_t ind
} }
shared_ptr<ngraph::Function> ngraph::deserialize(istream& in) shared_ptr<ngraph::Function> ngraph::deserialize(istream& in)
{
std::stringstream ss;
ss << in.rdbuf();
return deserialize(ss.str());
}
shared_ptr<ngraph::Function> ngraph::deserialize(const string& s)
{ {
shared_ptr<Function> rc; shared_ptr<Function> rc;
if (file_util::exists(s)) if (cpio::is_cpio(in))
{ {
cpio::Reader reader(s); cpio::Reader reader(in);
vector<cpio::FileInfo> file_info = reader.get_file_info(); vector<cpio::FileInfo> file_info = reader.get_file_info();
if (file_info.size() > 0) if (file_info.size() > 0)
{ {
...@@ -260,6 +253,25 @@ shared_ptr<ngraph::Function> ngraph::deserialize(const string& s) ...@@ -260,6 +253,25 @@ shared_ptr<ngraph::Function> ngraph::deserialize(const string& s)
} }
} }
else else
{
// json file?
std::stringstream ss;
ss << in.rdbuf();
rc = deserialize(ss.str());
}
return rc;
}
shared_ptr<ngraph::Function> ngraph::deserialize(const string& s)
{
shared_ptr<Function> rc;
if (file_util::exists(s))
{
// s is a file and not a json string
ifstream in(s, ios_base::binary | ios_base::in);
rc = deserialize(in);
}
else
{ {
json js = json::parse(s); json js = json::parse(s);
unordered_map<string, shared_ptr<Function>> function_map; unordered_map<string, shared_ptr<Function>> function_map;
......
...@@ -140,6 +140,7 @@ TEST(serialize, constant) ...@@ -140,6 +140,7 @@ TEST(serialize, constant)
EXPECT_EQ((vector<float>{1, 2, 3, 4, 5, 6, 7, 8}), A->get_vector<float>()); EXPECT_EQ((vector<float>{1, 2, 3, 4, 5, 6, 7, 8}), A->get_vector<float>());
serialize(tmp_file, f); serialize(tmp_file, f);
auto g = deserialize(tmp_file); auto g = deserialize(tmp_file);
ASSERT_NE(g, nullptr);
file_util::remove_file(tmp_file); file_util::remove_file(tmp_file);
bool found = false; bool found = false;
for (shared_ptr<Node> node : g->get_ops()) for (shared_ptr<Node> node : g->get_ops())
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment