Unverified Commit 2218cf9f authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

add missing ops to serializer (#351)

parent fe33af85
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include "ngraph/ops/concatenate.hpp" #include "ngraph/ops/concatenate.hpp"
#include "ngraph/ops/constant.hpp" #include "ngraph/ops/constant.hpp"
#include "ngraph/ops/convert.hpp" #include "ngraph/ops/convert.hpp"
#include "ngraph/ops/convolution.hpp"
#include "ngraph/ops/cos.hpp" #include "ngraph/ops/cos.hpp"
#include "ngraph/ops/cosh.hpp" #include "ngraph/ops/cosh.hpp"
#include "ngraph/ops/divide.hpp" #include "ngraph/ops/divide.hpp"
...@@ -38,11 +39,14 @@ ...@@ -38,11 +39,14 @@
#include "ngraph/ops/less.hpp" #include "ngraph/ops/less.hpp"
#include "ngraph/ops/less_eq.hpp" #include "ngraph/ops/less_eq.hpp"
#include "ngraph/ops/log.hpp" #include "ngraph/ops/log.hpp"
#include "ngraph/ops/max_pool.hpp"
#include "ngraph/ops/maximum.hpp" #include "ngraph/ops/maximum.hpp"
#include "ngraph/ops/minimum.hpp" #include "ngraph/ops/minimum.hpp"
#include "ngraph/ops/multiply.hpp" #include "ngraph/ops/multiply.hpp"
#include "ngraph/ops/negative.hpp" #include "ngraph/ops/negative.hpp"
#include "ngraph/ops/not.hpp"
#include "ngraph/ops/not_equal.hpp" #include "ngraph/ops/not_equal.hpp"
#include "ngraph/ops/one_hot.hpp"
#include "ngraph/ops/power.hpp" #include "ngraph/ops/power.hpp"
#include "ngraph/ops/reduce.hpp" #include "ngraph/ops/reduce.hpp"
#include "ngraph/ops/remainder.hpp" #include "ngraph/ops/remainder.hpp"
...@@ -360,6 +364,15 @@ static shared_ptr<ngraph::Function> ...@@ -360,6 +364,15 @@ static shared_ptr<ngraph::Function>
auto& target_type = read_element_type(node_js.at("target_type")); auto& target_type = read_element_type(node_js.at("target_type"));
node = make_shared<op::Convert>(args[0], target_type); node = make_shared<op::Convert>(args[0], target_type);
} }
else if (node_op == "Convolution")
{
auto window_movement_strides =
node_js.at("window_movement_strides").get<vector<size_t>>();
auto window_dilation_strides =
node_js.at("window_dilation_strides").get<vector<size_t>>();
node = make_shared<op::Convolution>(
args[0], args[1], window_movement_strides, window_dilation_strides);
}
else if (node_op == "Cos") else if (node_op == "Cos")
{ {
node = make_shared<op::Cos>(args[0]); node = make_shared<op::Cos>(args[0]);
...@@ -418,6 +431,13 @@ static shared_ptr<ngraph::Function> ...@@ -418,6 +431,13 @@ static shared_ptr<ngraph::Function>
{ {
node = make_shared<op::Log>(args[0]); node = make_shared<op::Log>(args[0]);
} }
else if (node_op == "MaxPool")
{
auto window_shape = node_js.at("window_shape").get<vector<size_t>>();
auto window_movement_strides =
node_js.at("window_movement_strides").get<vector<size_t>>();
node = make_shared<op::MaxPool>(args[0], window_shape, window_movement_strides);
}
else if (node_op == "Maximum") else if (node_op == "Maximum")
{ {
node = make_shared<op::Maximum>(args[0], args[1]); node = make_shared<op::Maximum>(args[0], args[1]);
...@@ -438,6 +458,16 @@ static shared_ptr<ngraph::Function> ...@@ -438,6 +458,16 @@ static shared_ptr<ngraph::Function>
{ {
node = make_shared<op::NotEqual>(args[0], args[1]); node = make_shared<op::NotEqual>(args[0], args[1]);
} }
else if (node_op == "Not")
{
node = make_shared<op::Not>(args[0]);
}
else if (node_op == "OneHot")
{
auto shape = node_js.at("shape").get<vector<size_t>>();
auto one_hot_axis = node_js.at("one_hot_axis").get<size_t>();
node = make_shared<op::OneHot>(args[0], shape, one_hot_axis);
}
else if (node_op == "Parameter") else if (node_op == "Parameter")
{ {
auto type_node_js = auto type_node_js =
...@@ -513,6 +543,9 @@ static shared_ptr<ngraph::Function> ...@@ -513,6 +543,9 @@ static shared_ptr<ngraph::Function>
{ {
node = make_shared<op::Tanh>(args[0]); node = make_shared<op::Tanh>(args[0]);
} }
// else if (node_op == "XLAGetTupleElement")
// {
// }
else else
{ {
stringstream ss; stringstream ss;
...@@ -600,6 +633,12 @@ static json write(const Node& n) ...@@ -600,6 +633,12 @@ static json write(const Node& n)
auto tmp = dynamic_cast<const op::Convert*>(&n); auto tmp = dynamic_cast<const op::Convert*>(&n);
node["target_type"] = write_element_type(tmp->get_convert_element_type()); node["target_type"] = write_element_type(tmp->get_convert_element_type());
} }
else if (node_op == "Convolution")
{
auto tmp = dynamic_cast<const op::Convolution*>(&n);
node["window_movement_strides"] = tmp->get_window_movement_strides();
node["window_dilation_strides"] = tmp->get_window_dilation_strides();
}
else if (node_op == "Cos") else if (node_op == "Cos")
{ {
} }
...@@ -643,6 +682,12 @@ static json write(const Node& n) ...@@ -643,6 +682,12 @@ static json write(const Node& n)
else if (node_op == "Log") else if (node_op == "Log")
{ {
} }
else if (node_op == "MaxPool")
{
auto tmp = dynamic_cast<const op::MaxPool*>(&n);
node["window_shape"] = tmp->get_window_shape();
node["window_movement_strides"] = tmp->get_window_movement_strides();
}
else if (node_op == "Maximum") else if (node_op == "Maximum")
{ {
} }
...@@ -658,6 +703,15 @@ static json write(const Node& n) ...@@ -658,6 +703,15 @@ static json write(const Node& n)
else if (node_op == "NotEqual") else if (node_op == "NotEqual")
{ {
} }
else if (node_op == "Not")
{
}
else if (node_op == "OneHot")
{
auto tmp = dynamic_cast<const op::OneHot*>(&n);
node["shape"] = tmp->get_shape();
node["one_hot_axis"] = tmp->get_one_hot_axis();
}
else if (node_op == "Parameter") else if (node_op == "Parameter")
{ {
auto tmp = dynamic_cast<const op::Parameter*>(&n); auto tmp = dynamic_cast<const op::Parameter*>(&n);
...@@ -722,6 +776,9 @@ static json write(const Node& n) ...@@ -722,6 +776,9 @@ static json write(const Node& n)
else if (node_op == "Tanh") else if (node_op == "Tanh")
{ {
} }
else if (node_op == "XLAGetTupleElement")
{
}
return node; return node;
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment