Commit 9d66d9a7 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

serialize logic for reverse_sequence (#1125)

* serialize logic for reverse_sequence

* Added serializer support for Softmax
parent 22e783ff
......@@ -72,6 +72,7 @@
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/result.hpp"
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/op/select_and_scatter.hpp"
#include "ngraph/op/sign.hpp"
......@@ -806,6 +807,13 @@ static shared_ptr<ngraph::Function>
auto reversed_axes = node_js.at("reversed_axes").get<set<size_t>>();
node = make_shared<op::Reverse>(args[0], reversed_axes);
}
else if (node_op == "ReverseSequence")
{
auto batch_axis = node_js.at("batch_axis").get<size_t>();
auto sequence_axis = node_js.at("sequence_axis").get<size_t>();
node =
make_shared<op::ReverseSequence>(args[0], args[1], batch_axis, sequence_axis);
}
else if (node_op == "Select")
{
node = make_shared<op::Select>(args[0], args[1], args[2]);
......@@ -850,8 +858,8 @@ static shared_ptr<ngraph::Function>
}
else if (node_op == "Softmax")
{
auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
node = make_shared<op::Softmax>(args[0], reduction_axes);
auto softmax_axes = node_js.at("softmax_axes").get<set<size_t>>();
node = make_shared<op::Softmax>(args[0], softmax_axes);
}
else if (node_op == "Sqrt")
{
......@@ -1238,6 +1246,12 @@ static json write(const Node& n, bool binary_constant_data)
auto tmp = dynamic_cast<const op::Reverse*>(&n);
node["reversed_axes"] = tmp->get_reversed_axes();
}
else if (node_op == "ReverseSequence")
{
auto tmp = dynamic_cast<const op::ReverseSequence*>(&n);
node["batch_axis"] = tmp->get_batch_axis();
node["sequence_axis"] = tmp->get_sequence_axis();
}
else if (node_op == "Select")
{
}
......@@ -1276,6 +1290,11 @@ static json write(const Node& n, bool binary_constant_data)
auto tmp = dynamic_cast<const op::Sum*>(&n);
node["reduction_axes"] = tmp->get_reduction_axes();
}
else if (node_op == "Softmax")
{
auto tmp = dynamic_cast<const op::Softmax*>(&n);
node["softmax_axes"] = tmp->get_axes();
}
else if (node_op == "Tan")
{
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment