Unverified Commit ff8a2008 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Add ReplaceSlice serialization (#339)

parent 69a6fb09
...@@ -45,6 +45,7 @@ ...@@ -45,6 +45,7 @@
#include "ngraph/ops/power.hpp" #include "ngraph/ops/power.hpp"
#include "ngraph/ops/reduce.hpp" #include "ngraph/ops/reduce.hpp"
#include "ngraph/ops/remainder.hpp" #include "ngraph/ops/remainder.hpp"
#include "ngraph/ops/replace_slice.hpp"
#include "ngraph/ops/reshape.hpp" #include "ngraph/ops/reshape.hpp"
#include "ngraph/ops/select.hpp" #include "ngraph/ops/select.hpp"
#include "ngraph/ops/sign.hpp" #include "ngraph/ops/sign.hpp"
...@@ -68,7 +69,6 @@ static std::shared_ptr<ngraph::Function> ...@@ -68,7 +69,6 @@ static std::shared_ptr<ngraph::Function>
static json write(const ngraph::Function&); static json write(const ngraph::Function&);
static json write(const ngraph::Node&); static json write(const ngraph::Node&);
static json write(const ngraph::element::Type&);
// This stupidity is caused by the fact that we do not pass element types // This stupidity is caused by the fact that we do not pass element types
// by value but by reference even though they can be compared. There is no reason to pass // by value but by reference even though they can be compared. There is no reason to pass
...@@ -250,9 +250,15 @@ string ngraph::serialize(shared_ptr<ngraph::Function> func, size_t indent) ...@@ -250,9 +250,15 @@ string ngraph::serialize(shared_ptr<ngraph::Function> func, size_t indent)
shared_ptr<ngraph::Function> ngraph::deserialize(istream& in) shared_ptr<ngraph::Function> ngraph::deserialize(istream& in)
{ {
json js = json::array(); std::stringstream ss;
ss << in.rdbuf();
return deserialize(ss.str());
}
shared_ptr<ngraph::Function> ngraph::deserialize(const string& s)
{
json js = json::parse(s);
shared_ptr<Function> rc; shared_ptr<Function> rc;
in >> js;
unordered_map<string, shared_ptr<Function>> function_map; unordered_map<string, shared_ptr<Function>> function_map;
for (json func : js) for (json func : js)
{ {
...@@ -492,6 +498,14 @@ static shared_ptr<ngraph::Function> ...@@ -492,6 +498,14 @@ static shared_ptr<ngraph::Function>
{ {
node = make_shared<op::Remainder>(args[0], args[1]); node = make_shared<op::Remainder>(args[0], args[1]);
} }
else if (node_op == "ReplaceSlice")
{
auto lower_bounds = node_js.at("lower_bounds").get<vector<size_t>>();
auto upper_bounds = node_js.at("upper_bounds").get<vector<size_t>>();
auto strides = node_js.at("strides").get<vector<size_t>>();
node = make_shared<op::ReplaceSlice>(
args[0], args[1], lower_bounds, upper_bounds, strides);
}
else if (node_op == "Reshape") else if (node_op == "Reshape")
{ {
auto input_order = node_js.at("input_order").get<vector<size_t>>(); auto input_order = node_js.at("input_order").get<vector<size_t>>();
...@@ -699,6 +713,13 @@ static json write(const Node& n) ...@@ -699,6 +713,13 @@ static json write(const Node& n)
else if (node_op == "Remainder") else if (node_op == "Remainder")
{ {
} }
else if (node_op == "ReplaceSlice")
{
auto tmp = dynamic_cast<const op::ReplaceSlice*>(&n);
node["lower_bounds"] = tmp->get_lower_bounds();
node["upper_bounds"] = tmp->get_upper_bounds();
node["strides"] = tmp->get_strides();
}
else if (node_op == "Reshape") else if (node_op == "Reshape")
{ {
auto tmp = dynamic_cast<const op::Reshape*>(&n); auto tmp = dynamic_cast<const op::Reshape*>(&n);
......
...@@ -25,4 +25,5 @@ namespace ngraph ...@@ -25,4 +25,5 @@ namespace ngraph
{ {
std::string serialize(std::shared_ptr<ngraph::Function>, size_t indent = 0); std::string serialize(std::shared_ptr<ngraph::Function>, size_t indent = 0);
std::shared_ptr<ngraph::Function> deserialize(std::istream&); std::shared_ptr<ngraph::Function> deserialize(std::istream&);
std::shared_ptr<ngraph::Function> deserialize(const std::string&);
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment