Commit 43172068 authored by Ashok Emani's avatar Ashok Emani Committed by Scott Cyphers

add set_parameters, set_results and serialize types (#4168)

Co-authored-by: 's avatarasemx <998264+asemx@users.noreply.github.com>
Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent 602660c2
......@@ -114,6 +114,16 @@ size_t runtime::Executable::get_preferred_pipeline_depth() const
return 2;
}
void runtime::Executable::set_parameters(const ngraph::ParameterVector& params)
{
m_parameters = params;
}
void runtime::Executable::set_results(const ngraph::ResultVector& results)
{
m_results = results;
}
void runtime::Executable::set_parameters_and_results(const Function& func)
{
m_parameters = func.get_parameters();
......
......@@ -73,6 +73,14 @@ public:
/// \returns preferred pipeline_depth
virtual size_t get_preferred_pipeline_depth() const;
/// \brief Set the input Parameters
/// \param params ngraph::ParameterVector of all input parameters
void set_parameters(const ngraph::ParameterVector& params);
/// \brief Set the output Results
/// \param results ngraph::ResultVector of all output results
void set_results(const ngraph::ResultVector& results);
/// \brief Save this compiled Executable to an output stream.
/// Saved stream may be read with Backend::load
virtual void save(std::ostream& output_stream);
......
......@@ -359,6 +359,33 @@ std::string ngraph::serialize(std::shared_ptr<ngraph::Function> func, size_t ind
return ::serialize(func, indent, false);
}
std::string
ngraph::serialize_types(const std::vector<std::pair<PartialShape, element::Type>>& types)
{
json attrs = json::array();
for (const auto& n : types)
{
json j;
j["shape"] = write_partial_shape(n.first);
j["type"] = write_element_type(n.second);
attrs.push_back(j);
}
return attrs.dump();
}
std::vector<std::pair<PartialShape, element::Type>>
ngraph::deserialize_types(const std::string& str)
{
std::vector<std::pair<PartialShape, element::Type>> outs;
json js = json::parse(str);
for (auto& j : js)
{
auto s = read_partial_shape(j["shape"]);
auto t = read_element_type(j["type"]);
outs.emplace_back(s, t);
}
return outs;
}
shared_ptr<ngraph::Function> ngraph::deserialize(istream& in)
{
shared_ptr<Function> rc;
......
......@@ -30,6 +30,13 @@ namespace ngraph
/// indent level specified.
std::string serialize(std::shared_ptr<ngraph::Function> func, size_t indent = 0);
/// \brief Serialize given vector of shapes/types
/// \param types The vector of shape/types to serialize
std::string serialize_types(const std::vector<std::pair<PartialShape, element::Type>>& types);
/// \brief Deerialize a string into vector of shapes/types
/// \param str The serialized string to deseriailze
std::vector<std::pair<PartialShape, element::Type>> deserialize_types(const std::string& str);
/// \brief Serialize a Function to a json file
/// \param path The path to the output file
/// \param func The Function to serialize
......
......@@ -96,6 +96,42 @@ TEST(serialize, main)
EXPECT_EQ((vector<float>{50, 72, 98, 128}), read_vector<float>(result));
}
TEST(serialize, types)
{
// First create "f(A,B,C) = (A+B)*C".
Shape shape{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::Parameter>(element::f32, shape);
auto C = make_shared<op::Parameter>(element::f32, shape);
auto f = make_shared<Function>((A + B) * C, ParameterVector{A, B, C}, "f");
std::vector<std::pair<PartialShape, element::Type>> types;
auto results = f->get_results();
for (auto& n : results)
types.emplace_back(n->get_output_partial_shape(0), n->output(0).get_element_type());
auto s_types = serialize_types(types);
for (auto& attrs : deserialize_types(s_types))
{
EXPECT_EQ(size_t(attrs.first.rank()), shape.size());
EXPECT_EQ(size_t(attrs.first[0]), 2);
EXPECT_EQ(size_t(attrs.first[1]), 2);
EXPECT_EQ(element::f32, attrs.second);
}
auto params = f->get_parameters();
types.clear();
for (auto& n : params)
types.emplace_back(n->get_output_partial_shape(0), n->output(0).get_element_type());
s_types = serialize_types(types);
for (auto& attrs : deserialize_types(s_types))
{
EXPECT_EQ(size_t(attrs.first.rank()), shape.size());
EXPECT_EQ(size_t(attrs.first[0]), 2);
EXPECT_EQ(size_t(attrs.first[1]), 2);
EXPECT_EQ(element::f32, attrs.second);
}
}
TEST(serialize, friendly_name)
{
// First create "f(A,B,C) = (A+B)*C".
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment