Unverified Commit c89b1a84 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by GitHub

init version (#298)

clean-up, formatting fixes

cleanup2

clean up 3

tests pass

remove printf

switching to the old version of TupleType::==
parent 2a2525f7
......@@ -141,6 +141,69 @@ static const element::Type& read_element_type(const json& j)
return to_ref(element::Type(bitwidth, is_real, is_signed, c_type_string));
}
static json write_value_type(std::shared_ptr<const ValueType> vt)
{
if (auto tt = std::dynamic_pointer_cast<const ngraph::TupleType>(vt))
{
json j;
for (auto e : tt->get_element_types())
{
j.push_back(write_value_type(e));
}
return j;
}
if (auto tvt = std::dynamic_pointer_cast<const ngraph::TensorViewType>(vt))
{
json j;
j["element_type"] = write_element_type(tvt->get_element_type());
j["shape"] = tvt->get_shape();
return j;
}
throw "UNREACHABLE!";
}
std::shared_ptr<const ValueType>
extract_type_shape(const json& j, const char* type, const char* sshape)
{
const auto& et = read_element_type(j.at(type));
auto shape = j.count(sshape)
? j.at(sshape).get<vector<size_t>>()
: Shape{} /*HACK, so we could call read_value_type uniformly @ each callsite*/;
auto tvt = make_shared<TensorViewType>(et, shape);
return tvt;
}
std::shared_ptr<const ValueType> read_value_type_(const json& j)
{
if (j.is_array())
{
std::vector<std::shared_ptr<const ValueType>> vts;
for (auto& e : j)
{
vts.push_back(read_value_type_(e));
}
return make_shared<const TupleType>(vts);
}
else
{
return extract_type_shape(j, "element_type", "shape");
}
}
std::shared_ptr<const ValueType>
read_value_type(const json& j, const char* type, const char* sshape)
{
if (j.count("value_type")) //new serialization format supporting tuples
{
return read_value_type_(j.at("value_type"));
}
return extract_type_shape(j, type, sshape);
}
string ngraph::serialize(shared_ptr<ngraph::Function> func)
{
json j;
......@@ -177,8 +240,7 @@ json write(const Function& f)
{
json function;
function["name"] = f.get_name();
function["result_type"] = write_element_type(f.get_result_type()->get_element_type());
function["result_shape"] = f.get_result_type()->get_shape();
function["value_type"] = write_value_type(f.get_result_type());
for (auto param : f.get_parameters())
{
function["parameters"].push_back(param->get_name());
......@@ -235,14 +297,13 @@ shared_ptr<ngraph::Function>
string func_name = func_js.at("name").get<string>();
vector<string> func_result = func_js.at("result").get<vector<string>>();
vector<string> func_parameters = func_js.at("parameters").get<vector<string>>();
const element::Type& result_type = read_element_type(func_js.at("result_type"));
vector<size_t> result_shape = func_js.at("result_shape").get<vector<size_t>>();
const auto& rvt = read_value_type(func_js, "result_type", "result_shape");
unordered_map<string, shared_ptr<Node>> node_map;
for (json node_js : func_js.at("ops"))
{
string node_name = node_js.at("name").get<string>();
string node_op = node_js.at("op").get<string>();
const element::Type& node_etype = read_element_type(node_js.at("element_type"));
auto nvt = read_value_type(node_js, "element_type", "shape");
vector<string> node_inputs = node_js.at("inputs").get<vector<string>>();
vector<string> node_outputs = node_js.at("outputs").get<vector<string>>();
shared_ptr<Node> node;
......@@ -298,7 +359,7 @@ shared_ptr<ngraph::Function>
{
auto shape = node_js.at("shape").get<vector<size_t>>();
auto value = node_js.at("value").get<vector<string>>();
node = make_shared<op::Constant>(node_etype, shape, value);
node = make_shared<op::Constant>(nvt->get_element_type(), shape, value);
}
else if (node_op == "Convert")
{
......@@ -386,7 +447,7 @@ shared_ptr<ngraph::Function>
else if (node_op == "Parameter")
{
auto shape = node_js.at("shape");
node = make_shared<op::Parameter>(node_etype, shape);
node = make_shared<op::Parameter>(nvt->get_element_type(), shape);
}
else if (node_op == "Power")
{
......@@ -466,8 +527,8 @@ shared_ptr<ngraph::Function>
{
params.push_back(dynamic_pointer_cast<op::Parameter>(node_map.at(param_name)));
}
auto rt = make_shared<TensorViewType>(result_type, result_shape);
rc = make_shared<Function>(result, rt, params, func_name);
rc = make_shared<Function>(result, rvt, params, func_name);
function_map[func_name] = rc;
return rc;
......@@ -478,7 +539,7 @@ json write(const Node& n)
json node;
node["name"] = n.get_name();
node["op"] = n.description();
node["element_type"] = write_element_type(n.get_element_type());
node["value_type"] = write_value_type(n.get_value_type());
json inputs = json::array();
json outputs = json::array();
for (const descriptor::Input& input : n.get_inputs())
......
File mode changed from 100644 to 100755
......@@ -32,6 +32,27 @@ static void copy_data(shared_ptr<runtime::TensorView> tv, const vector<T>& data)
tv->write(data.data(), 0, data_size);
}
TEST(serialize, tuple)
{
auto shape = Shape{2, 2};
auto tensor_view_type = make_shared<TensorViewType>(element::Int64::element_type(), shape);
auto A = make_shared<op::Parameter>(tensor_view_type);
auto B = make_shared<op::Parameter>(tensor_view_type);
auto C = make_shared<op::Parameter>(tensor_view_type);
auto f = make_shared<Function>(make_shared<op::Tuple>(Nodes{(A + B), (A - B), (C * A)}),
op::Parameters{A, B, C});
string js = serialize(f);
{
ofstream f("serialize_function_tuple.js");
f << js;
}
istringstream in(js);
shared_ptr<Function> sfunc = deserialize(in);
}
TEST(serialize, main)
{
// First create "f(A,B,C) = (A+B)*C".
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment