Unverified Commit a517ba09 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Change serializer to use op_tbl (#1612)

* wip

* wip

* serializer update to use op_tbl

* fix deserialize

* remove Remainder
parent 51bcb92d
...@@ -16,8 +16,6 @@ ...@@ -16,8 +16,6 @@
include(ExternalProject) include(ExternalProject)
message(STATUS "Fetching LLVM from llvm.org")
find_package(ZLIB REQUIRED) find_package(ZLIB REQUIRED)
# Override default LLVM binaries # Override default LLVM binaries
......
...@@ -70,7 +70,6 @@ ...@@ -70,7 +70,6 @@
#include "ngraph/op/reduce.hpp" #include "ngraph/op/reduce.hpp"
#include "ngraph/op/reduce_window.hpp" #include "ngraph/op/reduce_window.hpp"
#include "ngraph/op/relu.hpp" #include "ngraph/op/relu.hpp"
#include "ngraph/op/remainder.hpp"
#include "ngraph/op/replace_slice.hpp" #include "ngraph/op/replace_slice.hpp"
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
#include "ngraph/op/result.hpp" #include "ngraph/op/result.hpp"
...@@ -100,6 +99,31 @@ using namespace std; ...@@ -100,6 +99,31 @@ using namespace std;
using json = nlohmann::json; using json = nlohmann::json;
using const_data_callback_t = shared_ptr<Node>(const string&, const element::Type&, const Shape&); using const_data_callback_t = shared_ptr<Node>(const string&, const element::Type&, const Shape&);
// This expands the op list in op_tbl.hpp into a list of enumerations that look like this:
// Abs,
// Acos,
// ...
#define NGRAPH_OP(a) a,
enum class OP_TYPEID
{
#include "ngraph/op/op_tbl.hpp"
};
#undef NGRAPH_OP
static OP_TYPEID get_typeid(const string& s)
{
// This expands the op list in op_tbl.hpp into a list of enumerations that look like this:
// {"Abs", OP_TYPEID::Abs},
// {"Acos", OP_TYPEID::Acos},
// ...
#define NGRAPH_OP(a) {#a, OP_TYPEID::a},
static const unordered_map<string, OP_TYPEID> typeid_map{
#include "ngraph/op/op_tbl.hpp"
};
#undef NGRAPH_OP
return typeid_map.at(s);
}
template <typename T> template <typename T>
T get_or_default(nlohmann::json& j, const std::string& key, const T& default_value) T get_or_default(nlohmann::json& j, const std::string& key, const T& default_value)
{ {
...@@ -343,48 +367,62 @@ static shared_ptr<ngraph::Function> ...@@ -343,48 +367,62 @@ static shared_ptr<ngraph::Function>
{ {
args.push_back(node_map.at(name)); args.push_back(node_map.at(name));
} }
#pragma GCC diagnostic push
if (node_op == "Abs") #pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
// #pragma GCC diagnostic error "-Wimplicit-fallthrough"
switch (get_typeid(node_op))
{
case OP_TYPEID::Abs:
{ {
node = make_shared<op::Abs>(args[0]); node = make_shared<op::Abs>(args[0]);
break;
} }
else if (node_op == "Acos") case OP_TYPEID::Acos:
{ {
node = make_shared<op::Acos>(args[0]); node = make_shared<op::Acos>(args[0]);
break;
} }
else if (node_op == "Add") case OP_TYPEID::Add:
{ {
node = make_shared<op::Add>(args[0], args[1]); node = make_shared<op::Add>(args[0], args[1]);
break;
} }
else if (node_op == "AllReduce") case OP_TYPEID::AllReduce:
{ {
node = make_shared<op::AllReduce>(args[0]); node = make_shared<op::AllReduce>(args[0]);
break;
} }
else if (node_op == "And") case OP_TYPEID::And:
{ {
node = make_shared<op::And>(args[0], args[1]); node = make_shared<op::And>(args[0], args[1]);
break;
} }
else if (node_op == "ArgMin") case OP_TYPEID::ArgMin:
{ {
auto axis = node_js.at("axis").get<size_t>(); auto axis = node_js.at("axis").get<size_t>();
auto target_type = read_element_type(node_js.at("index_element_type")); auto target_type = read_element_type(node_js.at("index_element_type"));
node = make_shared<op::ArgMin>(args[0], axis, target_type); node = make_shared<op::ArgMin>(args[0], axis, target_type);
break;
} }
else if (node_op == "ArgMax") case OP_TYPEID::ArgMax:
{ {
auto axis = node_js.at("axis").get<size_t>(); auto axis = node_js.at("axis").get<size_t>();
auto target_type = read_element_type(node_js.at("index_element_type")); auto target_type = read_element_type(node_js.at("index_element_type"));
node = make_shared<op::ArgMax>(args[0], axis, target_type); node = make_shared<op::ArgMax>(args[0], axis, target_type);
break;
} }
else if (node_op == "Asin") case OP_TYPEID::Asin:
{ {
node = make_shared<op::Asin>(args[0]); node = make_shared<op::Asin>(args[0]);
break;
} }
else if (node_op == "Atan") case OP_TYPEID::Atan:
{ {
node = make_shared<op::Atan>(args[0]); node = make_shared<op::Atan>(args[0]);
break;
} }
else if (node_op == "AvgPool") case OP_TYPEID::AvgPool:
{ {
auto window_shape = node_js.at("window_shape").get<vector<size_t>>(); auto window_shape = node_js.at("window_shape").get<vector<size_t>>();
auto window_movement_strides = auto window_movement_strides =
...@@ -399,8 +437,9 @@ static shared_ptr<ngraph::Function> ...@@ -399,8 +437,9 @@ static shared_ptr<ngraph::Function>
padding_below, padding_below,
padding_above, padding_above,
include_padding_in_avg_computation); include_padding_in_avg_computation);
break;
} }
else if (node_op == "AvgPoolBackprop") case OP_TYPEID::AvgPoolBackprop:
{ {
auto forward_arg_shape = node_js.at("forward_arg_shape").get<vector<size_t>>(); auto forward_arg_shape = node_js.at("forward_arg_shape").get<vector<size_t>>();
auto window_shape = node_js.at("window_shape").get<vector<size_t>>(); auto window_shape = node_js.at("window_shape").get<vector<size_t>>();
...@@ -417,8 +456,9 @@ static shared_ptr<ngraph::Function> ...@@ -417,8 +456,9 @@ static shared_ptr<ngraph::Function>
padding_below, padding_below,
padding_above, padding_above,
include_padding_in_avg_computation); include_padding_in_avg_computation);
break;
} }
else if (node_op == "BatchNorm") case OP_TYPEID::BatchNorm:
{ {
auto epsilon = node_js.at("eps").get<double>(); auto epsilon = node_js.at("eps").get<double>();
bool training = get_or_default<bool>(node_js, "training", true); bool training = get_or_default<bool>(node_js, "training", true);
...@@ -436,29 +476,34 @@ static shared_ptr<ngraph::Function> ...@@ -436,29 +476,34 @@ static shared_ptr<ngraph::Function>
node = make_shared<op::BatchNorm>( node = make_shared<op::BatchNorm>(
epsilon, args[0], args[1], args[2], args[3], args[4]); epsilon, args[0], args[1], args[2], args[3], args[4]);
} }
break;
} }
else if (node_op == "BatchNormBackprop") case OP_TYPEID::BatchNormBackprop:
{ {
auto epsilon = node_js.at("eps").get<double>(); auto epsilon = node_js.at("eps").get<double>();
node = make_shared<op::BatchNormBackprop>( node = make_shared<op::BatchNormBackprop>(
epsilon, args[0], args[1], args[2], args[3], args[4], args[5]); epsilon, args[0], args[1], args[2], args[3], args[4], args[5]);
break;
} }
else if (node_op == "Broadcast") case OP_TYPEID::Broadcast:
{ {
auto shape = node_js.at("shape").get<vector<size_t>>(); auto shape = node_js.at("shape").get<vector<size_t>>();
auto axes = node_js.at("axes").get<set<size_t>>(); auto axes = node_js.at("axes").get<set<size_t>>();
node = make_shared<op::Broadcast>(args[0], shape, axes); node = make_shared<op::Broadcast>(args[0], shape, axes);
break;
} }
else if (node_op == "Ceiling") case OP_TYPEID::Ceiling:
{ {
node = make_shared<op::Ceiling>(args[0]); node = make_shared<op::Ceiling>(args[0]);
break;
} }
else if (node_op == "Concat") case OP_TYPEID::Concat:
{ {
auto axis = node_js.at("axis").get<size_t>(); auto axis = node_js.at("axis").get<size_t>();
node = make_shared<op::Concat>(args, axis); node = make_shared<op::Concat>(args, axis);
break;
} }
else if (node_op == "Constant") case OP_TYPEID::Constant:
{ {
auto type_node_js = auto type_node_js =
node_js.count("element_type") == 0 ? node_js.at("value_type") : node_js; node_js.count("element_type") == 0 ? node_js.at("value_type") : node_js;
...@@ -473,13 +518,15 @@ static shared_ptr<ngraph::Function> ...@@ -473,13 +518,15 @@ static shared_ptr<ngraph::Function>
{ {
node = const_data_callback(node_name, element_type, shape); node = const_data_callback(node_name, element_type, shape);
} }
break;
} }
else if (node_op == "Convert") case OP_TYPEID::Convert:
{ {
auto target_type = read_element_type(node_js.at("target_type")); auto target_type = read_element_type(node_js.at("target_type"));
node = make_shared<op::Convert>(args[0], target_type); node = make_shared<op::Convert>(args[0], target_type);
break;
} }
else if (node_op == "Convolution") case OP_TYPEID::Convolution:
{ {
auto window_movement_strides = auto window_movement_strides =
node_js.at("window_movement_strides").get<vector<size_t>>(); node_js.at("window_movement_strides").get<vector<size_t>>();
...@@ -516,8 +563,9 @@ static shared_ptr<ngraph::Function> ...@@ -516,8 +563,9 @@ static shared_ptr<ngraph::Function>
padding_above, padding_above,
data_dilation_strides_maybe.get<std::vector<size_t>>()); data_dilation_strides_maybe.get<std::vector<size_t>>());
} }
break;
} }
else if (node_op == "ConvolutionBackpropData") case OP_TYPEID::ConvolutionBackpropData:
{ {
auto data_batch_shape = node_js.at("data_batch_shape").get<vector<size_t>>(); auto data_batch_shape = node_js.at("data_batch_shape").get<vector<size_t>>();
auto window_movement_strides_forward = auto window_movement_strides_forward =
...@@ -538,8 +586,9 @@ static shared_ptr<ngraph::Function> ...@@ -538,8 +586,9 @@ static shared_ptr<ngraph::Function>
padding_below_forward, padding_below_forward,
padding_above_forward, padding_above_forward,
data_dilation_strides_forward); data_dilation_strides_forward);
break;
} }
else if (node_op == "ConvolutionBackpropFilters") case OP_TYPEID::ConvolutionBackpropFilters:
{ {
auto filters_shape = node_js.at("filters_shape").get<vector<size_t>>(); auto filters_shape = node_js.at("filters_shape").get<vector<size_t>>();
auto window_movement_strides_forward = auto window_movement_strides_forward =
...@@ -560,20 +609,24 @@ static shared_ptr<ngraph::Function> ...@@ -560,20 +609,24 @@ static shared_ptr<ngraph::Function>
padding_below_forward, padding_below_forward,
padding_above_forward, padding_above_forward,
data_dilation_strides_forward); data_dilation_strides_forward);
break;
} }
else if (node_op == "Cos") case OP_TYPEID::Cos:
{ {
node = make_shared<op::Cos>(args[0]); node = make_shared<op::Cos>(args[0]);
break;
} }
else if (node_op == "Cosh") case OP_TYPEID::Cosh:
{ {
node = make_shared<op::Cosh>(args[0]); node = make_shared<op::Cosh>(args[0]);
break;
} }
else if (node_op == "Divide") case OP_TYPEID::Divide:
{ {
node = make_shared<op::Divide>(args[0], args[1]); node = make_shared<op::Divide>(args[0], args[1]);
break;
} }
else if (node_op == "Dot") case OP_TYPEID::Dot:
{ {
// For backwards compatibility, reduction_axes_count is optional. // For backwards compatibility, reduction_axes_count is optional.
auto obj = node_js["reduction_axes_count"]; auto obj = node_js["reduction_axes_count"];
...@@ -586,63 +639,76 @@ static shared_ptr<ngraph::Function> ...@@ -586,63 +639,76 @@ static shared_ptr<ngraph::Function>
size_t reduction_axes_count = obj.get<size_t>(); size_t reduction_axes_count = obj.get<size_t>();
node = make_shared<op::Dot>(args[0], args[1], reduction_axes_count); node = make_shared<op::Dot>(args[0], args[1], reduction_axes_count);
} }
break;
} }
else if (node_op == "Equal") case OP_TYPEID::Equal:
{ {
node = make_shared<op::Equal>(args[0], args[1]); node = make_shared<op::Equal>(args[0], args[1]);
break;
} }
else if (node_op == "Exp") case OP_TYPEID::Exp:
{ {
node = make_shared<op::Exp>(args[0]); node = make_shared<op::Exp>(args[0]);
break;
} }
else if (node_op == "Floor") case OP_TYPEID::Floor:
{ {
node = make_shared<op::Floor>(args[0]); node = make_shared<op::Floor>(args[0]);
break;
} }
else if (node_op == "FunctionCall") case OP_TYPEID::FunctionCall:
{ {
string function_name = node_js.at("function").get<string>(); string function_name = node_js.at("function").get<string>();
shared_ptr<Function> f_ptr = function_map.at(function_name); shared_ptr<Function> f_ptr = function_map.at(function_name);
node = make_shared<op::FunctionCall>(f_ptr, args); node = make_shared<op::FunctionCall>(f_ptr, args);
break;
} }
else if (node_op == "GetOutputElement") case OP_TYPEID::GetOutputElement:
{ {
node = make_shared<op::GetOutputElement>(args[0], node_js.at("n").get<size_t>()); node = make_shared<op::GetOutputElement>(args[0], node_js.at("n").get<size_t>());
break;
} }
else if (node_op == "Greater") case OP_TYPEID::Greater:
{ {
node = make_shared<op::Greater>(args[0], args[1]); node = make_shared<op::Greater>(args[0], args[1]);
break;
} }
else if (node_op == "GreaterEq") case OP_TYPEID::GreaterEq:
{ {
node = make_shared<op::GreaterEq>(args[0], args[1]); node = make_shared<op::GreaterEq>(args[0], args[1]);
break;
} }
else if (node_op == "Less") case OP_TYPEID::Less:
{ {
node = make_shared<op::Less>(args[0], args[1]); node = make_shared<op::Less>(args[0], args[1]);
break;
} }
else if (node_op == "LessEq") case OP_TYPEID::LessEq:
{ {
node = make_shared<op::LessEq>(args[0], args[1]); node = make_shared<op::LessEq>(args[0], args[1]);
break;
} }
else if (node_op == "Log") case OP_TYPEID::Log:
{ {
node = make_shared<op::Log>(args[0]); node = make_shared<op::Log>(args[0]);
break;
} }
else if (node_op == "LRN") case OP_TYPEID::LRN:
{ {
auto alpha = node_js.at("alpha").get<double>(); auto alpha = node_js.at("alpha").get<double>();
auto beta = node_js.at("beta").get<double>(); auto beta = node_js.at("beta").get<double>();
auto bias = node_js.at("bias").get<double>(); auto bias = node_js.at("bias").get<double>();
auto nsize = node_js.at("nsize").get<size_t>(); auto nsize = node_js.at("nsize").get<size_t>();
node = make_shared<op::LRN>(args[0], alpha, beta, bias, nsize); node = make_shared<op::LRN>(args[0], alpha, beta, bias, nsize);
break;
} }
else if (node_op == "Max") case OP_TYPEID::Max:
{ {
auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>(); auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
node = make_shared<op::Max>(args[0], reduction_axes); node = make_shared<op::Max>(args[0], reduction_axes);
break;
} }
else if (node_op == "MaxPool") case OP_TYPEID::MaxPool:
{ {
auto window_shape = node_js.at("window_shape").get<vector<size_t>>(); auto window_shape = node_js.at("window_shape").get<vector<size_t>>();
auto window_movement_strides = auto window_movement_strides =
...@@ -675,8 +741,9 @@ static shared_ptr<ngraph::Function> ...@@ -675,8 +741,9 @@ static shared_ptr<ngraph::Function>
{ {
node = make_shared<op::MaxPool>(args[0], window_shape, window_movement_strides); node = make_shared<op::MaxPool>(args[0], window_shape, window_movement_strides);
} }
break;
} }
else if (node_op == "MaxPoolBackprop") case OP_TYPEID::MaxPoolBackprop:
{ {
auto window_shape = node_js.at("window_shape").get<vector<size_t>>(); auto window_shape = node_js.at("window_shape").get<vector<size_t>>();
auto window_movement_strides = auto window_movement_strides =
...@@ -689,55 +756,66 @@ static shared_ptr<ngraph::Function> ...@@ -689,55 +756,66 @@ static shared_ptr<ngraph::Function>
window_movement_strides, window_movement_strides,
padding_below, padding_below,
padding_above); padding_above);
break;
} }
else if (node_op == "Maximum") case OP_TYPEID::Maximum:
{ {
node = make_shared<op::Maximum>(args[0], args[1]); node = make_shared<op::Maximum>(args[0], args[1]);
break;
} }
else if (node_op == "Min") case OP_TYPEID::Min:
{ {
auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>(); auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
node = make_shared<op::Min>(args[0], reduction_axes); node = make_shared<op::Min>(args[0], reduction_axes);
break;
} }
else if (node_op == "Minimum") case OP_TYPEID::Minimum:
{ {
node = make_shared<op::Minimum>(args[0], args[1]); node = make_shared<op::Minimum>(args[0], args[1]);
break;
} }
else if (node_op == "Multiply") case OP_TYPEID::Multiply:
{ {
node = make_shared<op::Multiply>(args[0], args[1]); node = make_shared<op::Multiply>(args[0], args[1]);
break;
} }
else if (node_op == "Negative") case OP_TYPEID::Negative:
{ {
node = make_shared<op::Negative>(args[0]); node = make_shared<op::Negative>(args[0]);
break;
} }
else if (node_op == "NotEqual") case OP_TYPEID::NotEqual:
{ {
node = make_shared<op::NotEqual>(args[0], args[1]); node = make_shared<op::NotEqual>(args[0], args[1]);
break;
} }
else if (node_op == "Not") case OP_TYPEID::Not:
{ {
node = make_shared<op::Not>(args[0]); node = make_shared<op::Not>(args[0]);
break;
} }
else if (node_op == "OneHot") case OP_TYPEID::OneHot:
{ {
auto shape = node_js.at("shape").get<vector<size_t>>(); auto shape = node_js.at("shape").get<vector<size_t>>();
auto one_hot_axis = node_js.at("one_hot_axis").get<size_t>(); auto one_hot_axis = node_js.at("one_hot_axis").get<size_t>();
node = make_shared<op::OneHot>(args[0], shape, one_hot_axis); node = make_shared<op::OneHot>(args[0], shape, one_hot_axis);
break;
} }
else if (node_op == "Or") case OP_TYPEID::Or:
{ {
node = make_shared<op::Or>(args[0], args[1]); node = make_shared<op::Or>(args[0], args[1]);
break;
} }
else if (node_op == "Pad") case OP_TYPEID::Pad:
{ {
auto padding_below = node_js.at("padding_below").get<vector<size_t>>(); auto padding_below = node_js.at("padding_below").get<vector<size_t>>();
auto padding_above = node_js.at("padding_above").get<vector<size_t>>(); auto padding_above = node_js.at("padding_above").get<vector<size_t>>();
auto padding_interior = node_js.at("padding_interior").get<vector<size_t>>(); auto padding_interior = node_js.at("padding_interior").get<vector<size_t>>();
node = make_shared<op::Pad>( node = make_shared<op::Pad>(
args[0], args[1], padding_below, padding_above, padding_interior); args[0], args[1], padding_below, padding_above, padding_interior);
break;
} }
else if (node_op == "Parameter") case OP_TYPEID::Parameter:
{ {
auto type_node_js = auto type_node_js =
node_js.count("element_type") == 0 ? node_js.at("value_type") : node_js; node_js.count("element_type") == 0 ? node_js.at("value_type") : node_js;
...@@ -745,24 +823,28 @@ static shared_ptr<ngraph::Function> ...@@ -745,24 +823,28 @@ static shared_ptr<ngraph::Function>
auto shape = type_node_js.at("shape"); auto shape = type_node_js.at("shape");
auto cacheable = get_or_default<bool>(node_js, "cacheable", false); auto cacheable = get_or_default<bool>(node_js, "cacheable", false);
node = make_shared<op::Parameter>(element_type, shape, cacheable); node = make_shared<op::Parameter>(element_type, shape, cacheable);
break;
} }
else if (node_op == "Power") case OP_TYPEID::Power:
{ {
node = make_shared<op::Power>(args[0], args[1]); node = make_shared<op::Power>(args[0], args[1]);
break;
} }
else if (node_op == "Product") case OP_TYPEID::Product:
{ {
auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>(); auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
node = make_shared<op::Product>(args[0], reduction_axes); node = make_shared<op::Product>(args[0], reduction_axes);
break;
} }
else if (node_op == "Reduce") case OP_TYPEID::Reduce:
{ {
auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>(); auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
string function_name = node_js.at("function").get<string>(); string function_name = node_js.at("function").get<string>();
shared_ptr<Function> f_ptr = function_map.at(function_name); shared_ptr<Function> f_ptr = function_map.at(function_name);
node = make_shared<op::Reduce>(args[0], args[1], f_ptr, reduction_axes); node = make_shared<op::Reduce>(args[0], args[1], f_ptr, reduction_axes);
break;
} }
else if (node_op == "ReduceWindow") case OP_TYPEID::ReduceWindow:
{ {
auto window_shape = node_js.at("window_shape").get<vector<size_t>>(); auto window_shape = node_js.at("window_shape").get<vector<size_t>>();
auto window_movement_strides = auto window_movement_strides =
...@@ -771,54 +853,59 @@ static shared_ptr<ngraph::Function> ...@@ -771,54 +853,59 @@ static shared_ptr<ngraph::Function>
shared_ptr<Function> f_ptr = function_map.at(function_name); shared_ptr<Function> f_ptr = function_map.at(function_name);
node = make_shared<op::ReduceWindow>( node = make_shared<op::ReduceWindow>(
args[0], args[1], f_ptr, window_shape, window_movement_strides); args[0], args[1], f_ptr, window_shape, window_movement_strides);
break;
} }
else if (node_op == "Remainder") case OP_TYPEID::Relu:
{
node = make_shared<op::Remainder>(args[0], args[1]);
}
else if (node_op == "Relu")
{ {
node = make_shared<op::Relu>(args[0]); node = make_shared<op::Relu>(args[0]);
break;
} }
else if (node_op == "ReluBackprop") case OP_TYPEID::ReluBackprop:
{ {
node = make_shared<op::ReluBackprop>(args[0], args[1]); node = make_shared<op::ReluBackprop>(args[0], args[1]);
break;
} }
else if (node_op == "ReplaceSlice") case OP_TYPEID::ReplaceSlice:
{ {
auto lower_bounds = node_js.at("lower_bounds").get<vector<size_t>>(); auto lower_bounds = node_js.at("lower_bounds").get<vector<size_t>>();
auto upper_bounds = node_js.at("upper_bounds").get<vector<size_t>>(); auto upper_bounds = node_js.at("upper_bounds").get<vector<size_t>>();
auto strides = node_js.at("strides").get<vector<size_t>>(); auto strides = node_js.at("strides").get<vector<size_t>>();
node = make_shared<op::ReplaceSlice>( node = make_shared<op::ReplaceSlice>(
args[0], args[1], lower_bounds, upper_bounds, strides); args[0], args[1], lower_bounds, upper_bounds, strides);
break;
} }
else if (node_op == "Reshape") case OP_TYPEID::Reshape:
{ {
auto input_order = node_js.at("input_order").get<vector<size_t>>(); auto input_order = node_js.at("input_order").get<vector<size_t>>();
auto output_shape = node_js.at("output_shape").get<vector<size_t>>(); auto output_shape = node_js.at("output_shape").get<vector<size_t>>();
node = make_shared<op::Reshape>(args[0], input_order, output_shape); node = make_shared<op::Reshape>(args[0], input_order, output_shape);
break;
} }
else if (node_op == "Result") case OP_TYPEID::Result:
{ {
node = make_shared<op::Result>(args[0]); node = make_shared<op::Result>(args[0]);
break;
} }
else if (node_op == "Reverse") case OP_TYPEID::Reverse:
{ {
auto reversed_axes = node_js.at("reversed_axes").get<set<size_t>>(); auto reversed_axes = node_js.at("reversed_axes").get<set<size_t>>();
node = make_shared<op::Reverse>(args[0], reversed_axes); node = make_shared<op::Reverse>(args[0], reversed_axes);
break;
} }
else if (node_op == "ReverseSequence") case OP_TYPEID::ReverseSequence:
{ {
auto batch_axis = node_js.at("batch_axis").get<size_t>(); auto batch_axis = node_js.at("batch_axis").get<size_t>();
auto sequence_axis = node_js.at("sequence_axis").get<size_t>(); auto sequence_axis = node_js.at("sequence_axis").get<size_t>();
node = node =
make_shared<op::ReverseSequence>(args[0], args[1], batch_axis, sequence_axis); make_shared<op::ReverseSequence>(args[0], args[1], batch_axis, sequence_axis);
break;
} }
else if (node_op == "Select") case OP_TYPEID::Select:
{ {
node = make_shared<op::Select>(args[0], args[1], args[2]); node = make_shared<op::Select>(args[0], args[1], args[2]);
break;
} }
else if (node_op == "SelectAndScatter") case OP_TYPEID::SelectAndScatter:
{ {
string selection_function_name = node_js.at("selection_function").get<string>(); string selection_function_name = node_js.at("selection_function").get<string>();
shared_ptr<Function> selection_f_ptr = function_map.at(selection_function_name); shared_ptr<Function> selection_f_ptr = function_map.at(selection_function_name);
...@@ -836,78 +923,95 @@ static shared_ptr<ngraph::Function> ...@@ -836,78 +923,95 @@ static shared_ptr<ngraph::Function>
scatter_f_ptr, scatter_f_ptr,
window_shape, window_shape,
window_movement_strides); window_movement_strides);
break;
} }
else if (node_op == "Sigmoid") case OP_TYPEID::Sigmoid:
{ {
node = make_shared<op::Sigmoid>(args[0]); node = make_shared<op::Sigmoid>(args[0]);
break;
} }
else if (node_op == "SigmoidBackprop") case OP_TYPEID::SigmoidBackprop:
{ {
node = make_shared<op::SigmoidBackprop>(args[0], args[1]); node = make_shared<op::SigmoidBackprop>(args[0], args[1]);
break;
} }
else if (node_op == "Sign") case OP_TYPEID::Sign:
{ {
node = make_shared<op::Sign>(args[0]); node = make_shared<op::Sign>(args[0]);
break;
} }
else if (node_op == "Sin") case OP_TYPEID::Sin:
{ {
node = make_shared<op::Sin>(args[0]); node = make_shared<op::Sin>(args[0]);
break;
} }
else if (node_op == "Sinh") case OP_TYPEID::Sinh:
{ {
node = make_shared<op::Sinh>(args[0]); node = make_shared<op::Sinh>(args[0]);
break;
} }
else if (node_op == "Slice") case OP_TYPEID::Slice:
{ {
auto lower_bounds = node_js.at("lower_bounds").get<vector<size_t>>(); auto lower_bounds = node_js.at("lower_bounds").get<vector<size_t>>();
auto upper_bounds = node_js.at("upper_bounds").get<vector<size_t>>(); auto upper_bounds = node_js.at("upper_bounds").get<vector<size_t>>();
auto strides = node_js.at("strides").get<vector<size_t>>(); auto strides = node_js.at("strides").get<vector<size_t>>();
node = make_shared<op::Slice>(args[0], lower_bounds, upper_bounds, strides); node = make_shared<op::Slice>(args[0], lower_bounds, upper_bounds, strides);
break;
} }
else if (node_op == "Softmax") case OP_TYPEID::Softmax:
{ {
auto softmax_axes = node_js.at("softmax_axes").get<set<size_t>>(); auto softmax_axes = node_js.at("softmax_axes").get<set<size_t>>();
node = make_shared<op::Softmax>(args[0], softmax_axes); node = make_shared<op::Softmax>(args[0], softmax_axes);
break;
} }
else if (node_op == "Sqrt") case OP_TYPEID::Sqrt:
{ {
node = make_shared<op::Sqrt>(args[0]); node = make_shared<op::Sqrt>(args[0]);
break;
} }
else if (node_op == "Subtract") case OP_TYPEID::Subtract:
{ {
node = make_shared<op::Subtract>(args[0], args[1]); node = make_shared<op::Subtract>(args[0], args[1]);
break;
} }
else if (node_op == "Sum") case OP_TYPEID::Sum:
{ {
auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>(); auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
node = make_shared<op::Sum>(args[0], reduction_axes); node = make_shared<op::Sum>(args[0], reduction_axes);
break;
} }
else if (node_op == "Tan") case OP_TYPEID::Tan:
{ {
node = make_shared<op::Tan>(args[0]); node = make_shared<op::Tan>(args[0]);
break;
} }
else if (node_op == "Tanh") case OP_TYPEID::Tanh:
{ {
node = make_shared<op::Tanh>(args[0]); node = make_shared<op::Tanh>(args[0]);
break;
} }
else if (node_op == "TopK") case OP_TYPEID::TopK:
{ {
auto top_k_axis = node_js.at("top_k_axis").get<size_t>(); auto top_k_axis = node_js.at("top_k_axis").get<size_t>();
auto k = node_js.at("k").get<size_t>(); auto k = node_js.at("k").get<size_t>();
auto compute_max = node_js.at("compute_max").get<bool>(); auto compute_max = node_js.at("compute_max").get<bool>();
auto target_type = read_element_type(node_js.at("index_element_type")); auto target_type = read_element_type(node_js.at("index_element_type"));
node = make_shared<op::TopK>(args[0], top_k_axis, target_type, k, compute_max); node = make_shared<op::TopK>(args[0], top_k_axis, target_type, k, compute_max);
break;
} }
else if (node_op == "StopGradient") case OP_TYPEID::StopGradient:
{ {
node = make_shared<op::StopGradient>(args[0]); node = make_shared<op::StopGradient>(args[0]);
break;
} }
else default:
{ {
stringstream ss; stringstream ss;
ss << "unsupported op " << node_op; ss << "unsupported op " << node_op;
throw runtime_error(ss.str()); throw runtime_error(ss.str());
} }
}
#pragma GCC diagnostic pop
for (const string& name : control_deps_inputs) for (const string& name : control_deps_inputs)
{ {
...@@ -1011,37 +1115,41 @@ static json write(const Node& n, bool binary_constant_data) ...@@ -1011,37 +1115,41 @@ static json write(const Node& n, bool binary_constant_data)
} }
string node_op = n.description(); string node_op = n.description();
if (node_op == "Abs") #pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
// #pragma GCC diagnostic error "-Wimplicit-fallthrough"
switch (get_typeid(node_op))
{ {
case OP_TYPEID::Abs: { break;
} }
else if (node_op == "Acos") case OP_TYPEID::Acos: { break;
{
} }
else if (node_op == "Add") case OP_TYPEID::Add: { break;
{
} }
else if (node_op == "ArgMin") case OP_TYPEID::ArgMin:
{ {
auto tmp = dynamic_cast<const op::ArgMin*>(&n); auto tmp = dynamic_cast<const op::ArgMin*>(&n);
node["axis"] = tmp->get_reduction_axis(); node["axis"] = tmp->get_reduction_axis();
node["index_element_type"] = write_element_type(tmp->get_element_type()); node["index_element_type"] = write_element_type(tmp->get_element_type());
break;
} }
else if (node_op == "ArgMax") case OP_TYPEID::ArgMax:
{ {
auto tmp = dynamic_cast<const op::ArgMax*>(&n); auto tmp = dynamic_cast<const op::ArgMax*>(&n);
node["axis"] = tmp->get_reduction_axis(); node["axis"] = tmp->get_reduction_axis();
node["index_element_type"] = write_element_type(tmp->get_element_type()); node["index_element_type"] = write_element_type(tmp->get_element_type());
break;
} }
else if (node_op == "AllReduce") case OP_TYPEID::AllReduce: { break;
{
} }
else if (node_op == "Asin") case OP_TYPEID::And: { break;
{
} }
else if (node_op == "Atan") case OP_TYPEID::Asin: { break;
{
} }
else if (node_op == "AvgPool") case OP_TYPEID::Atan: { break;
}
case OP_TYPEID::AvgPool:
{ {
auto tmp = dynamic_cast<const op::AvgPool*>(&n); auto tmp = dynamic_cast<const op::AvgPool*>(&n);
node["window_shape"] = tmp->get_window_shape(); node["window_shape"] = tmp->get_window_shape();
...@@ -1049,8 +1157,9 @@ static json write(const Node& n, bool binary_constant_data) ...@@ -1049,8 +1157,9 @@ static json write(const Node& n, bool binary_constant_data)
node["padding_below"] = tmp->get_padding_below(); node["padding_below"] = tmp->get_padding_below();
node["padding_above"] = tmp->get_padding_above(); node["padding_above"] = tmp->get_padding_above();
node["include_padding_in_avg_computation"] = tmp->get_include_padding_in_avg_computation(); node["include_padding_in_avg_computation"] = tmp->get_include_padding_in_avg_computation();
break;
} }
else if (node_op == "AvgPoolBackprop") case OP_TYPEID::AvgPoolBackprop:
{ {
auto tmp = dynamic_cast<const op::AvgPoolBackprop*>(&n); auto tmp = dynamic_cast<const op::AvgPoolBackprop*>(&n);
node["forward_arg_shape"] = tmp->get_forward_arg_shape(); node["forward_arg_shape"] = tmp->get_forward_arg_shape();
...@@ -1059,33 +1168,37 @@ static json write(const Node& n, bool binary_constant_data) ...@@ -1059,33 +1168,37 @@ static json write(const Node& n, bool binary_constant_data)
node["padding_below"] = tmp->get_padding_below(); node["padding_below"] = tmp->get_padding_below();
node["padding_above"] = tmp->get_padding_above(); node["padding_above"] = tmp->get_padding_above();
node["include_padding_in_avg_computation"] = tmp->get_include_padding_in_avg_computation(); node["include_padding_in_avg_computation"] = tmp->get_include_padding_in_avg_computation();
break;
} }
else if (node_op == "BatchNorm") case OP_TYPEID::BatchNorm:
{ {
auto tmp = dynamic_cast<const op::BatchNorm*>(&n); auto tmp = dynamic_cast<const op::BatchNorm*>(&n);
node["eps"] = tmp->get_eps_value(); node["eps"] = tmp->get_eps_value();
node["training"] = tmp->get_training_flag(); node["training"] = tmp->get_training_flag();
break;
} }
else if (node_op == "BatchNormBackprop") case OP_TYPEID::BatchNormBackprop:
{ {
auto tmp = dynamic_cast<const op::BatchNormBackprop*>(&n); auto tmp = dynamic_cast<const op::BatchNormBackprop*>(&n);
node["eps"] = tmp->get_eps_value(); node["eps"] = tmp->get_eps_value();
break;
} }
else if (node_op == "Broadcast") case OP_TYPEID::Broadcast:
{ {
auto tmp = dynamic_cast<const op::Broadcast*>(&n); auto tmp = dynamic_cast<const op::Broadcast*>(&n);
node["axes"] = tmp->get_broadcast_axes(); node["axes"] = tmp->get_broadcast_axes();
node["shape"] = tmp->get_broadcast_shape(); node["shape"] = tmp->get_broadcast_shape();
break;
} }
else if (node_op == "Ceiling") case OP_TYPEID::Ceiling: { break;
{
} }
else if (node_op == "Concat") case OP_TYPEID::Concat:
{ {
auto tmp = dynamic_cast<const op::Concat*>(&n); auto tmp = dynamic_cast<const op::Concat*>(&n);
node["axis"] = tmp->get_concatenation_axis(); node["axis"] = tmp->get_concatenation_axis();
break;
} }
else if (node_op == "Constant") case OP_TYPEID::Constant:
{ {
auto tmp = dynamic_cast<const op::Constant*>(&n); auto tmp = dynamic_cast<const op::Constant*>(&n);
if (!binary_constant_data) if (!binary_constant_data)
...@@ -1094,13 +1207,15 @@ static json write(const Node& n, bool binary_constant_data) ...@@ -1094,13 +1207,15 @@ static json write(const Node& n, bool binary_constant_data)
} }
node["shape"] = tmp->get_shape(); node["shape"] = tmp->get_shape();
node["element_type"] = write_element_type(tmp->get_element_type()); node["element_type"] = write_element_type(tmp->get_element_type());
break;
} }
else if (node_op == "Convert") case OP_TYPEID::Convert:
{ {
auto tmp = dynamic_cast<const op::Convert*>(&n); auto tmp = dynamic_cast<const op::Convert*>(&n);
node["target_type"] = write_element_type(tmp->get_convert_element_type()); node["target_type"] = write_element_type(tmp->get_convert_element_type());
break;
} }
else if (node_op == "Convolution") case OP_TYPEID::Convolution:
{ {
auto tmp = dynamic_cast<const op::Convolution*>(&n); auto tmp = dynamic_cast<const op::Convolution*>(&n);
node["window_movement_strides"] = tmp->get_window_movement_strides(); node["window_movement_strides"] = tmp->get_window_movement_strides();
...@@ -1108,8 +1223,9 @@ static json write(const Node& n, bool binary_constant_data) ...@@ -1108,8 +1223,9 @@ static json write(const Node& n, bool binary_constant_data)
node["padding_below"] = tmp->get_padding_below(); node["padding_below"] = tmp->get_padding_below();
node["padding_above"] = tmp->get_padding_above(); node["padding_above"] = tmp->get_padding_above();
node["data_dilation_strides"] = tmp->get_data_dilation_strides(); node["data_dilation_strides"] = tmp->get_data_dilation_strides();
break;
} }
else if (node_op == "ConvolutionBackpropData") case OP_TYPEID::ConvolutionBackpropData:
{ {
auto tmp = dynamic_cast<const op::ConvolutionBackpropData*>(&n); auto tmp = dynamic_cast<const op::ConvolutionBackpropData*>(&n);
node["data_batch_shape"] = tmp->get_data_batch_shape(); node["data_batch_shape"] = tmp->get_data_batch_shape();
...@@ -1118,8 +1234,9 @@ static json write(const Node& n, bool binary_constant_data) ...@@ -1118,8 +1234,9 @@ static json write(const Node& n, bool binary_constant_data)
node["padding_below_forward"] = tmp->get_padding_below_forward(); node["padding_below_forward"] = tmp->get_padding_below_forward();
node["padding_above_forward"] = tmp->get_padding_above_forward(); node["padding_above_forward"] = tmp->get_padding_above_forward();
node["data_dilation_strides_forward"] = tmp->get_data_dilation_strides_forward(); node["data_dilation_strides_forward"] = tmp->get_data_dilation_strides_forward();
break;
} }
else if (node_op == "ConvolutionBackpropFilters") case OP_TYPEID::ConvolutionBackpropFilters:
{ {
auto tmp = dynamic_cast<const op::ConvolutionBackpropFilters*>(&n); auto tmp = dynamic_cast<const op::ConvolutionBackpropFilters*>(&n);
node["filters_shape"] = tmp->get_filters_shape(); node["filters_shape"] = tmp->get_filters_shape();
...@@ -1128,246 +1245,242 @@ static json write(const Node& n, bool binary_constant_data) ...@@ -1128,246 +1245,242 @@ static json write(const Node& n, bool binary_constant_data)
node["padding_below_forward"] = tmp->get_padding_below_forward(); node["padding_below_forward"] = tmp->get_padding_below_forward();
node["padding_above_forward"] = tmp->get_padding_above_forward(); node["padding_above_forward"] = tmp->get_padding_above_forward();
node["data_dilation_strides_forward"] = tmp->get_data_dilation_strides_forward(); node["data_dilation_strides_forward"] = tmp->get_data_dilation_strides_forward();
break;
} }
else if (node_op == "Cos") case OP_TYPEID::Cos: { break;
{
} }
else if (node_op == "Cosh") case OP_TYPEID::Cosh: { break;
{
} }
else if (node_op == "Divide") case OP_TYPEID::Divide: { break;
{
} }
else if (node_op == "Dot") case OP_TYPEID::Dot:
{ {
auto tmp = dynamic_cast<const op::Dot*>(&n); auto tmp = dynamic_cast<const op::Dot*>(&n);
node["reduction_axes_count"] = tmp->get_reduction_axes_count(); node["reduction_axes_count"] = tmp->get_reduction_axes_count();
break;
} }
else if (node_op == "Equal") case OP_TYPEID::Equal: { break;
{
} }
else if (node_op == "Exp") case OP_TYPEID::Exp: { break;
{
} }
else if (node_op == "Floor") case OP_TYPEID::Floor: { break;
{
} }
else if (node_op == "FunctionCall") case OP_TYPEID::FunctionCall:
{ {
node["function"] = n.get_functions()[0]->get_name(); node["function"] = n.get_functions()[0]->get_name();
break;
} }
else if (node_op == "GetOutputElement") case OP_TYPEID::GetOutputElement:
{ {
auto tmp = dynamic_cast<const op::GetOutputElement*>(&n); auto tmp = dynamic_cast<const op::GetOutputElement*>(&n);
node["n"] = tmp->get_n(); node["n"] = tmp->get_n();
break;
} }
else if (node_op == "Greater") case OP_TYPEID::Greater: { break;
{
} }
else if (node_op == "GreaterEq") case OP_TYPEID::GreaterEq: { break;
{
} }
else if (node_op == "Less") case OP_TYPEID::Less: { break;
{
} }
else if (node_op == "LessEq") case OP_TYPEID::LessEq: { break;
{
} }
else if (node_op == "Log") case OP_TYPEID::Log: { break;
{
} }
else if (node_op == "LRN") case OP_TYPEID::LRN:
{ {
auto tmp = dynamic_cast<const op::LRN*>(&n); auto tmp = dynamic_cast<const op::LRN*>(&n);
node["alpha"] = tmp->get_alpha(); node["alpha"] = tmp->get_alpha();
node["beta"] = tmp->get_beta(); node["beta"] = tmp->get_beta();
node["bias"] = tmp->get_bias(); node["bias"] = tmp->get_bias();
node["nsize"] = tmp->get_nsize(); node["nsize"] = tmp->get_nsize();
break;
} }
else if (node_op == "Max") case OP_TYPEID::Max:
{ {
auto tmp = dynamic_cast<const op::Max*>(&n); auto tmp = dynamic_cast<const op::Max*>(&n);
node["reduction_axes"] = tmp->get_reduction_axes(); node["reduction_axes"] = tmp->get_reduction_axes();
break;
} }
else if (node_op == "MaxPool") case OP_TYPEID::MaxPool:
{ {
auto tmp = dynamic_cast<const op::MaxPool*>(&n); auto tmp = dynamic_cast<const op::MaxPool*>(&n);
node["window_shape"] = tmp->get_window_shape(); node["window_shape"] = tmp->get_window_shape();
node["window_movement_strides"] = tmp->get_window_movement_strides(); node["window_movement_strides"] = tmp->get_window_movement_strides();
node["padding_below"] = tmp->get_padding_below(); node["padding_below"] = tmp->get_padding_below();
node["padding_above"] = tmp->get_padding_above(); node["padding_above"] = tmp->get_padding_above();
break;
} }
else if (node_op == "MaxPoolBackprop") case OP_TYPEID::MaxPoolBackprop:
{ {
auto tmp = dynamic_cast<const op::MaxPoolBackprop*>(&n); auto tmp = dynamic_cast<const op::MaxPoolBackprop*>(&n);
node["window_shape"] = tmp->get_window_shape(); node["window_shape"] = tmp->get_window_shape();
node["window_movement_strides"] = tmp->get_window_movement_strides(); node["window_movement_strides"] = tmp->get_window_movement_strides();
node["padding_below"] = tmp->get_padding_below(); node["padding_below"] = tmp->get_padding_below();
node["padding_above"] = tmp->get_padding_above(); node["padding_above"] = tmp->get_padding_above();
break;
} }
else if (node_op == "Maximum") case OP_TYPEID::Maximum: { break;
{
} }
else if (node_op == "Min") case OP_TYPEID::Min:
{ {
auto tmp = dynamic_cast<const op::Min*>(&n); auto tmp = dynamic_cast<const op::Min*>(&n);
node["reduction_axes"] = tmp->get_reduction_axes(); node["reduction_axes"] = tmp->get_reduction_axes();
break;
} }
else if (node_op == "Minimum") case OP_TYPEID::Minimum: { break;
{
} }
else if (node_op == "Multiply") case OP_TYPEID::Multiply: { break;
{
} }
else if (node_op == "Negative") case OP_TYPEID::Negative: { break;
{
} }
else if (node_op == "NotEqual") case OP_TYPEID::NotEqual: { break;
{
} }
else if (node_op == "Not") case OP_TYPEID::Not: { break;
{
} }
else if (node_op == "OneHot") case OP_TYPEID::OneHot:
{ {
auto tmp = dynamic_cast<const op::OneHot*>(&n); auto tmp = dynamic_cast<const op::OneHot*>(&n);
node["shape"] = tmp->get_shape(); node["shape"] = tmp->get_shape();
node["one_hot_axis"] = tmp->get_one_hot_axis(); node["one_hot_axis"] = tmp->get_one_hot_axis();
break;
}
case OP_TYPEID::Or: { break;
} }
else if (node_op == "Pad") case OP_TYPEID::Pad:
{ {
auto tmp = dynamic_cast<const op::Pad*>(&n); auto tmp = dynamic_cast<const op::Pad*>(&n);
node["padding_below"] = tmp->get_padding_below(); node["padding_below"] = tmp->get_padding_below();
node["padding_above"] = tmp->get_padding_above(); node["padding_above"] = tmp->get_padding_above();
node["padding_interior"] = tmp->get_padding_interior(); node["padding_interior"] = tmp->get_padding_interior();
break;
} }
else if (node_op == "Parameter") case OP_TYPEID::Parameter:
{ {
auto tmp = dynamic_cast<const op::Parameter*>(&n); auto tmp = dynamic_cast<const op::Parameter*>(&n);
node["shape"] = tmp->get_shape(); node["shape"] = tmp->get_shape();
node["cacheable"] = tmp->get_cacheable(); node["cacheable"] = tmp->get_cacheable();
node["element_type"] = write_element_type(tmp->get_element_type()); node["element_type"] = write_element_type(tmp->get_element_type());
break;
} }
else if (node_op == "Product") case OP_TYPEID::Product:
{ {
auto tmp = dynamic_cast<const op::Product*>(&n); auto tmp = dynamic_cast<const op::Product*>(&n);
node["reduction_axes"] = tmp->get_reduction_axes(); node["reduction_axes"] = tmp->get_reduction_axes();
break;
} }
else if (node_op == "Power") case OP_TYPEID::Power: { break;
{
} }
else if (node_op == "Reduce") case OP_TYPEID::Reduce:
{ {
auto tmp = dynamic_cast<const op::Reduce*>(&n); auto tmp = dynamic_cast<const op::Reduce*>(&n);
node["function"] = tmp->get_functions()[0]->get_name(); node["function"] = tmp->get_functions()[0]->get_name();
node["reduction_axes"] = tmp->get_reduction_axes(); node["reduction_axes"] = tmp->get_reduction_axes();
break;
} }
else if (node_op == "ReduceWindow") case OP_TYPEID::ReduceWindow:
{ {
auto tmp = dynamic_cast<const op::ReduceWindow*>(&n); auto tmp = dynamic_cast<const op::ReduceWindow*>(&n);
node["function"] = tmp->get_functions()[0]->get_name(); node["function"] = tmp->get_functions()[0]->get_name();
node["window_shape"] = tmp->get_window_shape(); node["window_shape"] = tmp->get_window_shape();
node["window_movement_strides"] = tmp->get_window_movement_strides(); node["window_movement_strides"] = tmp->get_window_movement_strides();
break;
} }
else if (node_op == "Relu") case OP_TYPEID::Relu: { break;
{
}
else if (node_op == "ReluBackprop")
{
} }
else if (node_op == "Remainder") case OP_TYPEID::ReluBackprop: { break;
{
} }
else if (node_op == "ReplaceSlice") case OP_TYPEID::ReplaceSlice:
{ {
auto tmp = dynamic_cast<const op::ReplaceSlice*>(&n); auto tmp = dynamic_cast<const op::ReplaceSlice*>(&n);
node["lower_bounds"] = tmp->get_lower_bounds(); node["lower_bounds"] = tmp->get_lower_bounds();
node["upper_bounds"] = tmp->get_upper_bounds(); node["upper_bounds"] = tmp->get_upper_bounds();
node["strides"] = tmp->get_strides(); node["strides"] = tmp->get_strides();
break;
} }
else if (node_op == "Reshape") case OP_TYPEID::Reshape:
{ {
auto tmp = dynamic_cast<const op::Reshape*>(&n); auto tmp = dynamic_cast<const op::Reshape*>(&n);
node["input_order"] = tmp->get_input_order(); node["input_order"] = tmp->get_input_order();
node["output_shape"] = tmp->get_output_shape(); node["output_shape"] = tmp->get_output_shape();
break;
} }
else if (node_op == "Result") case OP_TYPEID::Result: { break;
{
} }
else if (node_op == "Reverse") case OP_TYPEID::Reverse:
{ {
auto tmp = dynamic_cast<const op::Reverse*>(&n); auto tmp = dynamic_cast<const op::Reverse*>(&n);
node["reversed_axes"] = tmp->get_reversed_axes(); node["reversed_axes"] = tmp->get_reversed_axes();
break;
} }
else if (node_op == "ReverseSequence") case OP_TYPEID::ReverseSequence:
{ {
auto tmp = dynamic_cast<const op::ReverseSequence*>(&n); auto tmp = dynamic_cast<const op::ReverseSequence*>(&n);
node["batch_axis"] = tmp->get_batch_axis(); node["batch_axis"] = tmp->get_batch_axis();
node["sequence_axis"] = tmp->get_sequence_axis(); node["sequence_axis"] = tmp->get_sequence_axis();
break;
} }
else if (node_op == "Select") case OP_TYPEID::Select: { break;
{
} }
else if (node_op == "SelectAndScatter") case OP_TYPEID::SelectAndScatter:
{ {
auto tmp = dynamic_cast<const op::SelectAndScatter*>(&n); auto tmp = dynamic_cast<const op::SelectAndScatter*>(&n);
node["selection_function"] = tmp->get_functions()[0]->get_name(); node["selection_function"] = tmp->get_functions()[0]->get_name();
node["scatter_function"] = tmp->get_functions()[1]->get_name(); node["scatter_function"] = tmp->get_functions()[1]->get_name();
node["window_shape"] = tmp->get_window_shape(); node["window_shape"] = tmp->get_window_shape();
node["window_movement_strides"] = tmp->get_window_movement_strides(); node["window_movement_strides"] = tmp->get_window_movement_strides();
break;
} }
else if (node_op == "Sigmoid") case OP_TYPEID::Sigmoid: { break;
{
} }
else if (node_op == "SigmoidBackprop") case OP_TYPEID::SigmoidBackprop: { break;
{
} }
else if (node_op == "Sign") case OP_TYPEID::Sign: { break;
{
} }
else if (node_op == "Sin") case OP_TYPEID::Sin: { break;
{
} }
else if (node_op == "Sinh") case OP_TYPEID::Sinh: { break;
{
} }
else if (node_op == "Slice") case OP_TYPEID::Slice:
{ {
auto tmp = dynamic_cast<const op::Slice*>(&n); auto tmp = dynamic_cast<const op::Slice*>(&n);
node["lower_bounds"] = tmp->get_lower_bounds(); node["lower_bounds"] = tmp->get_lower_bounds();
node["upper_bounds"] = tmp->get_upper_bounds(); node["upper_bounds"] = tmp->get_upper_bounds();
node["strides"] = tmp->get_strides(); node["strides"] = tmp->get_strides();
break;
} }
else if (node_op == "Sqrt") case OP_TYPEID::Sqrt: { break;
{
} }
else if (node_op == "Subtract") case OP_TYPEID::StopGradient: { break;
{
} }
else if (node_op == "Sum") case OP_TYPEID::Subtract: { break;
}
case OP_TYPEID::Sum:
{ {
auto tmp = dynamic_cast<const op::Sum*>(&n); auto tmp = dynamic_cast<const op::Sum*>(&n);
node["reduction_axes"] = tmp->get_reduction_axes(); node["reduction_axes"] = tmp->get_reduction_axes();
break;
} }
else if (node_op == "Softmax") case OP_TYPEID::Softmax:
{ {
auto tmp = dynamic_cast<const op::Softmax*>(&n); auto tmp = dynamic_cast<const op::Softmax*>(&n);
node["softmax_axes"] = tmp->get_axes(); node["softmax_axes"] = tmp->get_axes();
break;
} }
else if (node_op == "Tan") case OP_TYPEID::Tan: { break;
{
} }
else if (node_op == "Tanh") case OP_TYPEID::Tanh: { break;
{
} }
else if (node_op == "TopK") case OP_TYPEID::TopK:
{ {
auto tmp = dynamic_cast<const op::TopK*>(&n); auto tmp = dynamic_cast<const op::TopK*>(&n);
node["top_k_axis"] = tmp->get_top_k_axis(); node["top_k_axis"] = tmp->get_top_k_axis();
node["index_element_type"] = write_element_type(tmp->get_index_element_type()); node["index_element_type"] = write_element_type(tmp->get_index_element_type());
node["k"] = tmp->get_k(); node["k"] = tmp->get_k();
node["compute_max"] = tmp->get_compute_max(); node["compute_max"] = tmp->get_compute_max();
break;
}
} }
#pragma GCC diagnostic pop
return node; return node;
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment