Unverified Commit 2218cf9f authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

add missing ops to serializer (#351)

parent fe33af85
......@@ -24,6 +24,7 @@
#include "ngraph/ops/concatenate.hpp"
#include "ngraph/ops/constant.hpp"
#include "ngraph/ops/convert.hpp"
#include "ngraph/ops/convolution.hpp"
#include "ngraph/ops/cos.hpp"
#include "ngraph/ops/cosh.hpp"
#include "ngraph/ops/divide.hpp"
......@@ -38,11 +39,14 @@
#include "ngraph/ops/less.hpp"
#include "ngraph/ops/less_eq.hpp"
#include "ngraph/ops/log.hpp"
#include "ngraph/ops/max_pool.hpp"
#include "ngraph/ops/maximum.hpp"
#include "ngraph/ops/minimum.hpp"
#include "ngraph/ops/multiply.hpp"
#include "ngraph/ops/negative.hpp"
#include "ngraph/ops/not.hpp"
#include "ngraph/ops/not_equal.hpp"
#include "ngraph/ops/one_hot.hpp"
#include "ngraph/ops/power.hpp"
#include "ngraph/ops/reduce.hpp"
#include "ngraph/ops/remainder.hpp"
......@@ -360,6 +364,15 @@ static shared_ptr<ngraph::Function>
auto& target_type = read_element_type(node_js.at("target_type"));
node = make_shared<op::Convert>(args[0], target_type);
}
else if (node_op == "Convolution")
{
auto window_movement_strides =
node_js.at("window_movement_strides").get<vector<size_t>>();
auto window_dilation_strides =
node_js.at("window_dilation_strides").get<vector<size_t>>();
node = make_shared<op::Convolution>(
args[0], args[1], window_movement_strides, window_dilation_strides);
}
else if (node_op == "Cos")
{
node = make_shared<op::Cos>(args[0]);
......@@ -418,6 +431,13 @@ static shared_ptr<ngraph::Function>
{
node = make_shared<op::Log>(args[0]);
}
else if (node_op == "MaxPool")
{
auto window_shape = node_js.at("window_shape").get<vector<size_t>>();
auto window_movement_strides =
node_js.at("window_movement_strides").get<vector<size_t>>();
node = make_shared<op::MaxPool>(args[0], window_shape, window_movement_strides);
}
else if (node_op == "Maximum")
{
node = make_shared<op::Maximum>(args[0], args[1]);
......@@ -438,6 +458,16 @@ static shared_ptr<ngraph::Function>
{
node = make_shared<op::NotEqual>(args[0], args[1]);
}
else if (node_op == "Not")
{
node = make_shared<op::Not>(args[0]);
}
else if (node_op == "OneHot")
{
auto shape = node_js.at("shape").get<vector<size_t>>();
auto one_hot_axis = node_js.at("one_hot_axis").get<size_t>();
node = make_shared<op::OneHot>(args[0], shape, one_hot_axis);
}
else if (node_op == "Parameter")
{
auto type_node_js =
......@@ -513,6 +543,9 @@ static shared_ptr<ngraph::Function>
{
node = make_shared<op::Tanh>(args[0]);
}
// else if (node_op == "XLAGetTupleElement")
// {
// }
else
{
stringstream ss;
......@@ -600,6 +633,12 @@ static json write(const Node& n)
auto tmp = dynamic_cast<const op::Convert*>(&n);
node["target_type"] = write_element_type(tmp->get_convert_element_type());
}
else if (node_op == "Convolution")
{
auto tmp = dynamic_cast<const op::Convolution*>(&n);
node["window_movement_strides"] = tmp->get_window_movement_strides();
node["window_dilation_strides"] = tmp->get_window_dilation_strides();
}
else if (node_op == "Cos")
{
}
......@@ -643,6 +682,12 @@ static json write(const Node& n)
else if (node_op == "Log")
{
}
else if (node_op == "MaxPool")
{
auto tmp = dynamic_cast<const op::MaxPool*>(&n);
node["window_shape"] = tmp->get_window_shape();
node["window_movement_strides"] = tmp->get_window_movement_strides();
}
else if (node_op == "Maximum")
{
}
......@@ -658,6 +703,15 @@ static json write(const Node& n)
else if (node_op == "NotEqual")
{
}
else if (node_op == "Not")
{
}
else if (node_op == "OneHot")
{
auto tmp = dynamic_cast<const op::OneHot*>(&n);
node["shape"] = tmp->get_shape();
node["one_hot_axis"] = tmp->get_one_hot_axis();
}
else if (node_op == "Parameter")
{
auto tmp = dynamic_cast<const op::Parameter*>(&n);
......@@ -722,6 +776,9 @@ static json write(const Node& n)
else if (node_op == "Tanh")
{
}
else if (node_op == "XLAGetTupleElement")
{
}
return node;
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment