Unverified Commit 0f2a22e7 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Serializer emits simple element_type (#331)

* cleanup

* cleanup

* Update the serializer to emit a small, simple element_type representation; backwards compatible.

* allow for selecting indenting when serializing
parent 752396cd
......@@ -63,17 +63,17 @@ using namespace ngraph;
using namespace std;
using json = nlohmann::json;
std::shared_ptr<ngraph::Function>
static std::shared_ptr<ngraph::Function>
read_function(const json&, std::unordered_map<std::string, std::shared_ptr<Function>>&);
json write(const ngraph::Function&);
json write(const ngraph::Node&);
json write(const ngraph::element::Type&);
static json write(const ngraph::Function&);
static json write(const ngraph::Node&);
static json write(const ngraph::element::Type&);
// This stupidity is caused by the fact that we do not pass element types
// by value but by reference, even though they can be compared. There is no reason to pass
// them by reference EVERYWHERE, but here we are...
const element::Type& to_ref(const element::Type& t)
static const element::Type& to_ref(const element::Type& t)
{
if (t == element::boolean)
{
......@@ -125,87 +125,107 @@ const element::Type& to_ref(const element::Type& t)
// Serialize an element::Type to JSON.
// NOTE(review): this span is diff residue — the four object-style field
// assignments are the PRE-change serializer (removed lines) and the single
// string assignment is its POST-change replacement; in the new code only
// the last assignment survives, encoding the type as its C type string.
static json write_element_type(const ngraph::element::Type& n)
{
json j;
j["bitwidth"] = n.bitwidth();
j["is_real"] = n.is_real();
j["is_signed"] = n.is_signed();
j["c_type_string"] = n.c_type_string();
// New format: the whole element type is just its C type string.
j = n.c_type_string();
return j;
}
// Reconstruct an element::Type from its JSON encoding and return the
// canonical reference via to_ref().
// NOTE(review): diff residue — the first four unconditional reads are the
// PRE-change parser (removed lines); the initialized locals plus the
// is_object() branch are the POST-change, backwards-compatible parser.
static const element::Type& read_element_type(const json& j)
{
// Old format: the encoding was always an object with explicit fields.
size_t bitwidth = j.at("bitwidth").get<size_t>();
bool is_real = j.at("is_real").get<bool>();
bool is_signed = j.at("is_signed").get<bool>();
string c_type_string = j.at("c_type_string").get<string>();
size_t bitwidth = 0;
// NOTE(review): is_real/is_signed start uninitialized; if the string
// below matches no entry in get_known_types() they are passed to the
// Type constructor uninitialized — TODO initialize them like bitwidth.
bool is_real;
bool is_signed;
string c_type_string;
if (j.is_object())
{
// Backwards compatibility: old serializations stored an object.
bitwidth = j.at("bitwidth").get<size_t>();
is_real = j.at("is_real").get<bool>();
is_signed = j.at("is_signed").get<bool>();
c_type_string = j.at("c_type_string").get<string>();
}
else
{
// New format: a bare string holding the C type name; resolve it
// against the table of predefined element types.
string c_type = j.get<string>();
for (const element::Type* t : element::Type::get_known_types())
{
if (t->c_type_string() == c_type)
{
bitwidth = t->bitwidth();
is_real = t->is_real();
is_signed = t->is_signed();
c_type_string = t->c_type_string();
break;
}
}
}
// to_ref maps the freshly constructed value onto one of the global
// element::Type constants (see the comment above to_ref).
return to_ref(element::Type(bitwidth, is_real, is_signed, c_type_string));
}
// Serialize a ValueType: a TupleType becomes a JSON array of its element
// types (recursively); a TensorViewType becomes an object with
// "element_type" and "shape".
// NOTE(review): diff residue — the PRE-change early-return/throw form and
// the POST-change single-exit form are interleaved below (note the two
// adjacent `if`/`else if` lines and the duplicate inner `json j;`).
static json write_value_type(std::shared_ptr<const ValueType> vt)
{
json j;
if (auto tt = std::dynamic_pointer_cast<const ngraph::TupleType>(vt))
{
json j;
// Tuple: recurse into each element value type.
for (auto e : tt->get_element_types())
{
j.push_back(write_value_type(e));
}
return j;
}
if (auto tvt = std::dynamic_pointer_cast<const ngraph::TensorViewType>(vt))
else if (auto tvt = std::dynamic_pointer_cast<const ngraph::TensorViewType>(vt))
{
json j;
j["element_type"] = write_element_type(tvt->get_element_type());
j["shape"] = tvt->get_shape();
return j;
}
throw "UNREACHABLE!";
return j;
}
// Build a TensorViewType from JSON: `type` names the element-type key and
// `sshape` the shape key; an absent shape key is treated as the empty
// (scalar) shape.
// NOTE(review): diff residue — the old const char* signature/body and the
// new static const string& signature/body are interleaved below.
std::shared_ptr<const ValueType>
extract_type_shape(const json& j, const char* type, const char* sshape)
static std::shared_ptr<const ValueType>
extract_type_shape(const json& j, const string& type, const string& sshape)
{
const auto& et = read_element_type(j.at(type));
auto shape = j.count(sshape)
? j.at(sshape).get<vector<size_t>>()
: Shape{} /*HACK, so we could call read_value_type uniformly @ each callsite*/;
const element::Type& et = read_element_type(j.at(type));
Shape shape =
j.count(sshape)
? j.at(sshape).get<vector<size_t>>()
: Shape{} /*HACK, so we could call read_value_type uniformly @ each callsite*/;
auto tvt = make_shared<TensorViewType>(et, shape);
return tvt;
}
// Parse a ValueType from JSON: an array denotes a TupleType (elements
// parsed recursively), anything else a TensorViewType.
// NOTE(review): diff residue — the function was renamed from
// read_value_type_ to read_value_type and rewritten with a single exit;
// both variants' lines appear interleaved below.
std::shared_ptr<const ValueType> read_value_type_(const json& j)
static std::shared_ptr<const ValueType> read_value_type(const json& j)
{
std::shared_ptr<const ValueType> rc;
if (j.is_array())
{
std::vector<std::shared_ptr<const ValueType>> vts;
for (auto& e : j)
{
vts.push_back(read_value_type_(e));
vts.push_back(read_value_type(e));
}
return make_shared<const TupleType>(vts);
rc = make_shared<const TupleType>(vts);
}
else
{
return extract_type_shape(j, "element_type", "shape");
rc = extract_type_shape(j, "element_type", "shape");
}
return rc;
}
// Parse a node's value type: prefer the newer "value_type" key (which
// supports tuples); otherwise fall back to the legacy flat
// `type`/`sshape` keys via extract_type_shape.
// NOTE(review): diff residue — the old const char* early-return form and
// the new static const string& single-exit form are interleaved below.
std::shared_ptr<const ValueType>
read_value_type(const json& j, const char* type, const char* sshape)
static std::shared_ptr<const ValueType>
read_value_type(const json& j, const string& type, const string& sshape)
{
std::shared_ptr<const ValueType> rc;
if (j.count("value_type")) //new serialization format supporting tuples
{
return read_value_type_(j.at("value_type"));
rc = read_value_type(j.at("value_type"));
}
return extract_type_shape(j, type, sshape);
else
{
rc = extract_type_shape(j, type, sshape);
}
return rc;
}
// Serialize a Function to a JSON string; indent == 0 produces compact
// single-line output, any other value pretty-prints with that many spaces
// per level (nlohmann::json::dump semantics).
// NOTE(review): diff residue — the old no-indent signature is followed by
// the new one, and the hunk marker below means part of this function's
// body is not visible in this excerpt.
string ngraph::serialize(shared_ptr<ngraph::Function> func)
string ngraph::serialize(shared_ptr<ngraph::Function> func, size_t indent)
{
json j;
vector<json> functions;
......@@ -216,7 +236,16 @@ string ngraph::serialize(shared_ptr<ngraph::Function> func)
j.push_back(*it);
}
return j.dump();
string rc;
if (indent == 0)
{
// Compact output.
rc = j.dump();
}
else
{
// Pretty-printed output with `indent` spaces per level.
rc = j.dump(indent);
}
return rc;
}
shared_ptr<ngraph::Function> ngraph::deserialize(istream& in)
......@@ -237,7 +266,7 @@ shared_ptr<ngraph::Function> ngraph::deserialize(istream& in)
return rc;
}
json write(const Function& f)
static json write(const Function& f)
{
json function;
function["name"] = f.get_name();
......@@ -290,7 +319,7 @@ json write(const Function& f)
return function;
}
shared_ptr<ngraph::Function>
static shared_ptr<ngraph::Function>
read_function(const json& func_js, unordered_map<string, shared_ptr<Function>>& function_map)
{
shared_ptr<ngraph::Function> rc;
......@@ -535,7 +564,7 @@ shared_ptr<ngraph::Function>
return rc;
}
json write(const Node& n)
static json write(const Node& n)
{
json node;
node["name"] = n.get_name();
......
......@@ -23,6 +23,6 @@
namespace ngraph
{
// NOTE(review): diff residue — the old declaration (no indent parameter)
// is immediately followed by its replacement.
std::string serialize(std::shared_ptr<ngraph::Function>);
// Serialize `func` to JSON; indent == 0 (the default) yields compact
// output, otherwise the JSON is pretty-printed with `indent` spaces.
std::string serialize(std::shared_ptr<ngraph::Function>, size_t indent = 0);
// Reconstruct a Function from a JSON stream produced by serialize().
std::shared_ptr<ngraph::Function> deserialize(std::istream&);
}
......@@ -34,6 +34,22 @@ const element::Type element::u16(16, false, false, "uint16_t");
const element::Type element::u32(32, false, false, "uint32_t");
const element::Type element::u64(64, false, false, "uint64_t");
std::vector<const element::Type*> element::Type::get_known_types()
{
std::vector<const element::Type*> rc = {&element::boolean,
&element::f32,
&element::f64,
&element::i8,
&element::i16,
&element::i32,
&element::i64,
&element::u8,
&element::u16,
&element::u32,
&element::u64};
return rc;
}
element::Type::Type()
: m_bitwidth{0}
, m_is_real{0}
......
......@@ -64,6 +64,7 @@ namespace ngraph
bool operator!=(const Type& other) const { return !(*this == other); }
bool operator<(const Type& other) const;
friend std::ostream& operator<<(std::ostream&, const Type&);
static std::vector<const Type*> get_known_types();
/// Returns true if the type is floating point, else false.
bool get_is_real() const { return m_is_real; }
......
......@@ -17,6 +17,7 @@
#include "gtest/gtest.h"
#include "ngraph/file_util.hpp"
#include "ngraph/json.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/serializer.hpp"
......@@ -47,7 +48,7 @@ TEST(serialize, tuple)
auto f = make_shared<XLAFunction>(
make_shared<op::XLATuple>(Nodes{(A + B), (A - B), (C * A)}), ttt, op::Parameters{A, B, C});
string js = serialize(f);
string js = serialize(f, 4);
{
ofstream f("serialize_function_tuple.js");
f << js;
......@@ -89,7 +90,7 @@ TEST(serialize, main)
op::Parameters{X1, Y1, Z1},
"h");
string js = serialize(h);
string js = serialize(h, 4);
{
ofstream f("serialize_function.js");
......@@ -122,3 +123,16 @@ TEST(serialize, main)
cf->call({x, z, y}, {result});
EXPECT_EQ((vector<float>{50, 72, 98, 128}), result->get_vector<float>());
}
// Smoke test: previously serialized model files from the zoo must still
// deserialize without throwing (backwards compatibility of the format).
TEST(serialize, existing_models)
{
    const vector<string> model_files{"mxnet/mnist_mlp_forward.json",
                                     "mxnet/10_bucket_LSTM.json"};
    for (const string& file : model_files)
    {
        // Load the stored JSON and round it through the deserializer;
        // the test passes if no exception escapes.
        const string path = file_util::path_join(SERIALIZED_ZOO, file);
        stringstream ss(file_util::read_file_to_string(path));
        shared_ptr<Function> f = ngraph::deserialize(ss);
    }
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment