Unverified Commit ff8a2008 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Add ReplaceSlice serialization (#339)

parent 69a6fb09
......@@ -45,6 +45,7 @@
#include "ngraph/ops/power.hpp"
#include "ngraph/ops/reduce.hpp"
#include "ngraph/ops/remainder.hpp"
#include "ngraph/ops/replace_slice.hpp"
#include "ngraph/ops/reshape.hpp"
#include "ngraph/ops/select.hpp"
#include "ngraph/ops/sign.hpp"
......@@ -68,7 +69,6 @@ static std::shared_ptr<ngraph::Function>
static json write(const ngraph::Function&);
static json write(const ngraph::Node&);
static json write(const ngraph::element::Type&);
// This stupidity is caused by the fact that we do not pass element types
// by value but by reference even though they can be compared. There is no reason to pass
......@@ -250,9 +250,15 @@ string ngraph::serialize(shared_ptr<ngraph::Function> func, size_t indent)
shared_ptr<ngraph::Function> ngraph::deserialize(istream& in)
{
json js = json::array();
std::stringstream ss;
ss << in.rdbuf();
return deserialize(ss.str());
}
shared_ptr<ngraph::Function> ngraph::deserialize(const string& s)
{
json js = json::parse(s);
shared_ptr<Function> rc;
in >> js;
unordered_map<string, shared_ptr<Function>> function_map;
for (json func : js)
{
......@@ -492,6 +498,14 @@ static shared_ptr<ngraph::Function>
{
node = make_shared<op::Remainder>(args[0], args[1]);
}
else if (node_op == "ReplaceSlice")
{
auto lower_bounds = node_js.at("lower_bounds").get<vector<size_t>>();
auto upper_bounds = node_js.at("upper_bounds").get<vector<size_t>>();
auto strides = node_js.at("strides").get<vector<size_t>>();
node = make_shared<op::ReplaceSlice>(
args[0], args[1], lower_bounds, upper_bounds, strides);
}
else if (node_op == "Reshape")
{
auto input_order = node_js.at("input_order").get<vector<size_t>>();
......@@ -699,6 +713,13 @@ static json write(const Node& n)
else if (node_op == "Remainder")
{
}
else if (node_op == "ReplaceSlice")
{
auto tmp = dynamic_cast<const op::ReplaceSlice*>(&n);
node["lower_bounds"] = tmp->get_lower_bounds();
node["upper_bounds"] = tmp->get_upper_bounds();
node["strides"] = tmp->get_strides();
}
else if (node_op == "Reshape")
{
auto tmp = dynamic_cast<const op::Reshape*>(&n);
......
......@@ -25,4 +25,5 @@ namespace ngraph
{
std::string serialize(std::shared_ptr<ngraph::Function>, size_t indent = 0);
std::shared_ptr<ngraph::Function> deserialize(std::istream&);
std::shared_ptr<ngraph::Function> deserialize(const std::string&);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment