Commit 2e6d6a60 authored by Ashok Emani's avatar Ashok Emani Committed by Scott Cyphers

add serialize_types and deserialize_types API (#3795)

* add serialize_attrs and deserialize_attrs API

* add doc comments

* update function signature
parent 42a3a0a4
......@@ -458,6 +458,33 @@ std::string ngraph::serialize(std::shared_ptr<ngraph::Function> func, size_t ind
return ::serialize(func, indent, false);
}
std::string
ngraph::serialize_types(const std::vector<std::pair<PartialShape, element::Type>>& types)
{
json attrs = json::array();
for (const auto& n : types)
{
json j;
j["shape"] = write_partial_shape(n.first);
j["type"] = write_element_type(n.second);
attrs.push_back(j);
}
return attrs.dump();
}
std::vector<std::pair<PartialShape, element::Type>>
ngraph::deserialize_types(const std::string& str)
{
std::vector<std::pair<PartialShape, element::Type>> outs;
json js = json::parse(str);
for (auto& j : js)
{
auto s = read_partial_shape(j["shape"]);
auto t = read_element_type(j["type"]);
outs.emplace_back(s, t);
}
return outs;
}
shared_ptr<ngraph::Function> ngraph::deserialize(istream& in)
{
shared_ptr<Function> rc;
......
......@@ -30,6 +30,13 @@ namespace ngraph
/// indent level specified.
std::string serialize(std::shared_ptr<ngraph::Function> func, size_t indent = 0);
/// \brief Serialize given vector of shapes/types
/// \param types The vector of shape/types to serialize
std::string serialize_types(const std::vector<std::pair<PartialShape, element::Type>>& types);
/// \brief Deerialize a string into vector of shapes/types
/// \param str The serialized string to deseriailze
std::vector<std::pair<PartialShape, element::Type>> deserialize_types(const std::string& str);
/// \brief Serialize a Function to a json file
/// \param path The path to the output file
/// \param func The Function to serialize
......
......@@ -96,6 +96,42 @@ TEST(serialize, main)
EXPECT_EQ((vector<float>{50, 72, 98, 128}), read_vector<float>(result));
}
TEST(serialize, main_attrs)
{
// First create "f(A,B,C) = (A+B)*C".
Shape shape{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::Parameter>(element::f32, shape);
auto C = make_shared<op::Parameter>(element::f32, shape);
auto f = make_shared<Function>((A + B) * C, ParameterVector{A, B, C}, "f");
std::vector<std::pair<PartialShape, element::Type>> types;
auto results = f->get_results();
for (auto& n : results)
types.emplace_back(n->get_output_partial_shape(0), n->output(0).get_element_type());
auto s_types = serialize_types(types);
for (auto& attrs : deserialize_types(s_types))
{
EXPECT_EQ(size_t(attrs.first.rank()), shape.size());
EXPECT_EQ(size_t(attrs.first[0]), 2);
EXPECT_EQ(size_t(attrs.first[1]), 2);
EXPECT_EQ(element::f32, attrs.second);
}
auto params = f->get_parameters();
types.clear();
for (auto& n : params)
types.emplace_back(n->get_output_partial_shape(0), n->output(0).get_element_type());
s_types = serialize_types(types);
for (auto& attrs : deserialize_types(s_types))
{
EXPECT_EQ(size_t(attrs.first.rank()), shape.size());
EXPECT_EQ(size_t(attrs.first[0]), 2);
EXPECT_EQ(size_t(attrs.first[1]), 2);
EXPECT_EQ(element::f32, attrs.second);
}
}
TEST(serialize, friendly_name)
{
// First create "f(A,B,C) = (A+B)*C".
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment