Unverified Commit a517ba09 authored by Robert Kimball, committed by GitHub

Change serializer to use op_tbl (#1612)

* wip

* wip

* serializer update to use op_tbl

* fix deserialize

* remove Remainder
parent 51bcb92d
@@ -16,8 +16,6 @@
include(ExternalProject)
message(STATUS "Fetching LLVM from llvm.org")
# Override default LLVM binaries
@@ -70,7 +70,6 @@
#include "ngraph/op/reduce.hpp"
#include "ngraph/op/reduce_window.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/replace_slice.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/result.hpp"
@@ -100,6 +99,31 @@ using namespace std;
using json = nlohmann::json;
using const_data_callback_t = shared_ptr<Node>(const string&, const element::Type&, const Shape&);
// This expands the op list in op_tbl.hpp into a list of enumerations that look like this:
// Abs,
// Acos,
// ...
#define NGRAPH_OP(a) a,
enum class OP_TYPEID
{
#include "ngraph/op/op_tbl.hpp"
};
#undef NGRAPH_OP
static OP_TYPEID get_typeid(const string& s)
{
// This expands the op list in op_tbl.hpp into a list of map entries that look like this:
// {"Abs", OP_TYPEID::Abs},
// {"Acos", OP_TYPEID::Acos},
// ...
#define NGRAPH_OP(a) {#a, OP_TYPEID::a},
static const unordered_map<string, OP_TYPEID> typeid_map{
#include "ngraph/op/op_tbl.hpp"
};
#undef NGRAPH_OP
return typeid_map.at(s);
}
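// Note: unordered_map::at throws std::out_of_range here when s names an op
// that is not listed in op_tbl.hpp, so unknown op strings fail fast.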
template <typename T>
T get_or_default(nlohmann::json& j, const std::string& key, const T& default_value)
{
@@ -343,48 +367,62 @@ static shared_ptr<ngraph::Function>
{
args.push_back(node_map.at(name));
}
#pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
// #pragma GCC diagnostic error "-Wimplicit-fallthrough"
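// Promoting -Wswitch and -Wswitch-enum to errors makes the switch below
// exhaustive by construction: an op added to op_tbl.hpp without a matching
// case no longer compiles, instead of silently landing in the default case.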
switch (get_typeid(node_op))
{
case OP_TYPEID::Abs:
{
node = make_shared<op::Abs>(args[0]);
break;
}
case OP_TYPEID::Acos:
{
node = make_shared<op::Acos>(args[0]);
break;
}
case OP_TYPEID::Add:
{
node = make_shared<op::Add>(args[0], args[1]);
break;
}
case OP_TYPEID::AllReduce:
{
node = make_shared<op::AllReduce>(args[0]);
break;
}
case OP_TYPEID::And:
{
node = make_shared<op::And>(args[0], args[1]);
break;
}
case OP_TYPEID::ArgMin:
{
auto axis = node_js.at("axis").get<size_t>();
auto target_type = read_element_type(node_js.at("index_element_type"));
node = make_shared<op::ArgMin>(args[0], axis, target_type);
break;
}
case OP_TYPEID::ArgMax:
{
auto axis = node_js.at("axis").get<size_t>();
auto target_type = read_element_type(node_js.at("index_element_type"));
node = make_shared<op::ArgMax>(args[0], axis, target_type);
break;
}
case OP_TYPEID::Asin:
{
node = make_shared<op::Asin>(args[0]);
break;
}
case OP_TYPEID::Atan:
{
node = make_shared<op::Atan>(args[0]);
break;
}
case OP_TYPEID::AvgPool:
{
auto window_shape = node_js.at("window_shape").get<vector<size_t>>();
auto window_movement_strides =
@@ -399,8 +437,9 @@ static shared_ptr<ngraph::Function>
padding_below,
padding_above,
include_padding_in_avg_computation);
break;
}
case OP_TYPEID::AvgPoolBackprop:
{
auto forward_arg_shape = node_js.at("forward_arg_shape").get<vector<size_t>>();
auto window_shape = node_js.at("window_shape").get<vector<size_t>>();
@@ -417,8 +456,9 @@ static shared_ptr<ngraph::Function>
padding_below,
padding_above,
include_padding_in_avg_computation);
break;
}
case OP_TYPEID::BatchNorm:
{
auto epsilon = node_js.at("eps").get<double>();
bool training = get_or_default<bool>(node_js, "training", true);
@@ -436,29 +476,34 @@ static shared_ptr<ngraph::Function>
node = make_shared<op::BatchNorm>(
epsilon, args[0], args[1], args[2], args[3], args[4]);
}
break;
}
case OP_TYPEID::BatchNormBackprop:
{
auto epsilon = node_js.at("eps").get<double>();
node = make_shared<op::BatchNormBackprop>(
epsilon, args[0], args[1], args[2], args[3], args[4], args[5]);
break;
}
case OP_TYPEID::Broadcast:
{
auto shape = node_js.at("shape").get<vector<size_t>>();
auto axes = node_js.at("axes").get<set<size_t>>();
node = make_shared<op::Broadcast>(args[0], shape, axes);
break;
}
case OP_TYPEID::Ceiling:
{
node = make_shared<op::Ceiling>(args[0]);
break;
}
case OP_TYPEID::Concat:
{
auto axis = node_js.at("axis").get<size_t>();
node = make_shared<op::Concat>(args, axis);
break;
}
case OP_TYPEID::Constant:
{
auto type_node_js =
node_js.count("element_type") == 0 ? node_js.at("value_type") : node_js;
@@ -473,13 +518,15 @@ static shared_ptr<ngraph::Function>
{
node = const_data_callback(node_name, element_type, shape);
}
break;
}
case OP_TYPEID::Convert:
{
auto target_type = read_element_type(node_js.at("target_type"));
node = make_shared<op::Convert>(args[0], target_type);
break;
}
case OP_TYPEID::Convolution:
{
auto window_movement_strides =
node_js.at("window_movement_strides").get<vector<size_t>>();
@@ -516,8 +563,9 @@ static shared_ptr<ngraph::Function>
padding_above,
data_dilation_strides_maybe.get<std::vector<size_t>>());
}
break;
}
case OP_TYPEID::ConvolutionBackpropData:
{
auto data_batch_shape = node_js.at("data_batch_shape").get<vector<size_t>>();
auto window_movement_strides_forward =
@@ -538,8 +586,9 @@ static shared_ptr<ngraph::Function>
padding_below_forward,
padding_above_forward,
data_dilation_strides_forward);
break;
}
case OP_TYPEID::ConvolutionBackpropFilters:
{
auto filters_shape = node_js.at("filters_shape").get<vector<size_t>>();
auto window_movement_strides_forward =
@@ -560,20 +609,24 @@ static shared_ptr<ngraph::Function>
padding_below_forward,
padding_above_forward,
data_dilation_strides_forward);
break;
}
case OP_TYPEID::Cos:
{
node = make_shared<op::Cos>(args[0]);
break;
}
case OP_TYPEID::Cosh:
{
node = make_shared<op::Cosh>(args[0]);
break;
}
case OP_TYPEID::Divide:
{
node = make_shared<op::Divide>(args[0], args[1]);
break;
}
case OP_TYPEID::Dot:
{
// For backwards compatibility, reduction_axes_count is optional.
auto obj = node_js["reduction_axes_count"];
@@ -586,63 +639,76 @@ static shared_ptr<ngraph::Function>
size_t reduction_axes_count = obj.get<size_t>();
node = make_shared<op::Dot>(args[0], args[1], reduction_axes_count);
}
break;
}
case OP_TYPEID::Equal:
{
node = make_shared<op::Equal>(args[0], args[1]);
break;
}
case OP_TYPEID::Exp:
{
node = make_shared<op::Exp>(args[0]);
break;
}
case OP_TYPEID::Floor:
{
node = make_shared<op::Floor>(args[0]);
break;
}
case OP_TYPEID::FunctionCall:
{
string function_name = node_js.at("function").get<string>();
shared_ptr<Function> f_ptr = function_map.at(function_name);
node = make_shared<op::FunctionCall>(f_ptr, args);
break;
}
case OP_TYPEID::GetOutputElement:
{
node = make_shared<op::GetOutputElement>(args[0], node_js.at("n").get<size_t>());
break;
}
case OP_TYPEID::Greater:
{
node = make_shared<op::Greater>(args[0], args[1]);
break;
}
case OP_TYPEID::GreaterEq:
{
node = make_shared<op::GreaterEq>(args[0], args[1]);
break;
}
case OP_TYPEID::Less:
{
node = make_shared<op::Less>(args[0], args[1]);
break;
}
case OP_TYPEID::LessEq:
{
node = make_shared<op::LessEq>(args[0], args[1]);
break;
}
case OP_TYPEID::Log:
{
node = make_shared<op::Log>(args[0]);
break;
}
case OP_TYPEID::LRN:
{
auto alpha = node_js.at("alpha").get<double>();
auto beta = node_js.at("beta").get<double>();
auto bias = node_js.at("bias").get<double>();
auto nsize = node_js.at("nsize").get<size_t>();
node = make_shared<op::LRN>(args[0], alpha, beta, bias, nsize);
break;
}
case OP_TYPEID::Max:
{
auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
node = make_shared<op::Max>(args[0], reduction_axes);
break;
}
case OP_TYPEID::MaxPool:
{
auto window_shape = node_js.at("window_shape").get<vector<size_t>>();
auto window_movement_strides =
@@ -675,8 +741,9 @@ static shared_ptr<ngraph::Function>
{
node = make_shared<op::MaxPool>(args[0], window_shape, window_movement_strides);
}
break;
}
case OP_TYPEID::MaxPoolBackprop:
{
auto window_shape = node_js.at("window_shape").get<vector<size_t>>();
auto window_movement_strides =
@@ -689,55 +756,66 @@ static shared_ptr<ngraph::Function>
window_movement_strides,
padding_below,
padding_above);
break;
}
case OP_TYPEID::Maximum:
{
node = make_shared<op::Maximum>(args[0], args[1]);
break;
}
case OP_TYPEID::Min:
{
auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
node = make_shared<op::Min>(args[0], reduction_axes);
break;
}
case OP_TYPEID::Minimum:
{
node = make_shared<op::Minimum>(args[0], args[1]);
break;
}
case OP_TYPEID::Multiply:
{
node = make_shared<op::Multiply>(args[0], args[1]);
break;
}
case OP_TYPEID::Negative:
{
node = make_shared<op::Negative>(args[0]);
break;
}
case OP_TYPEID::NotEqual:
{
node = make_shared<op::NotEqual>(args[0], args[1]);
break;
}
case OP_TYPEID::Not:
{
node = make_shared<op::Not>(args[0]);
break;
}
case OP_TYPEID::OneHot:
{
auto shape = node_js.at("shape").get<vector<size_t>>();
auto one_hot_axis = node_js.at("one_hot_axis").get<size_t>();
node = make_shared<op::OneHot>(args[0], shape, one_hot_axis);
break;
}
case OP_TYPEID::Or:
{
node = make_shared<op::Or>(args[0], args[1]);
break;
}
case OP_TYPEID::Pad:
{
auto padding_below = node_js.at("padding_below").get<vector<size_t>>();
auto padding_above = node_js.at("padding_above").get<vector<size_t>>();
auto padding_interior = node_js.at("padding_interior").get<vector<size_t>>();
node = make_shared<op::Pad>(
args[0], args[1], padding_below, padding_above, padding_interior);
break;
}
case OP_TYPEID::Parameter:
{
auto type_node_js =
node_js.count("element_type") == 0 ? node_js.at("value_type") : node_js;
@@ -745,24 +823,28 @@
auto shape = type_node_js.at("shape");
auto cacheable = get_or_default<bool>(node_js, "cacheable", false);
node = make_shared<op::Parameter>(element_type, shape, cacheable);
break;
}
case OP_TYPEID::Power:
{
node = make_shared<op::Power>(args[0], args[1]);
break;
}
case OP_TYPEID::Product:
{
auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
node = make_shared<op::Product>(args[0], reduction_axes);
break;
}
case OP_TYPEID::Reduce:
{
auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
string function_name = node_js.at("function").get<string>();
shared_ptr<Function> f_ptr = function_map.at(function_name);
node = make_shared<op::Reduce>(args[0], args[1], f_ptr, reduction_axes);
break;
}
case OP_TYPEID::ReduceWindow:
{
auto window_shape = node_js.at("window_shape").get<vector<size_t>>();
auto window_movement_strides =
@@ -771,54 +853,59 @@ static shared_ptr<ngraph::Function>
shared_ptr<Function> f_ptr = function_map.at(function_name);
node = make_shared<op::ReduceWindow>(
args[0], args[1], f_ptr, window_shape, window_movement_strides);
break;
}
case OP_TYPEID::Relu:
{
node = make_shared<op::Relu>(args[0]);
break;
}
case OP_TYPEID::ReluBackprop:
{
node = make_shared<op::ReluBackprop>(args[0], args[1]);
break;
}
case OP_TYPEID::ReplaceSlice:
{
auto lower_bounds = node_js.at("lower_bounds").get<vector<size_t>>();
auto upper_bounds = node_js.at("upper_bounds").get<vector<size_t>>();
auto strides = node_js.at("strides").get<vector<size_t>>();
node = make_shared<op::ReplaceSlice>(
args[0], args[1], lower_bounds, upper_bounds, strides);
break;
}
case OP_TYPEID::Reshape:
{
auto input_order = node_js.at("input_order").get<vector<size_t>>();
auto output_shape = node_js.at("output_shape").get<vector<size_t>>();
node = make_shared<op::Reshape>(args[0], input_order, output_shape);
break;
}
case OP_TYPEID::Result:
{
node = make_shared<op::Result>(args[0]);
break;
}
case OP_TYPEID::Reverse:
{
auto reversed_axes = node_js.at("reversed_axes").get<set<size_t>>();
node = make_shared<op::Reverse>(args[0], reversed_axes);
break;
}
case OP_TYPEID::ReverseSequence:
{
auto batch_axis = node_js.at("batch_axis").get<size_t>();
auto sequence_axis = node_js.at("sequence_axis").get<size_t>();
node =
make_shared<op::ReverseSequence>(args[0], args[1], batch_axis, sequence_axis);
break;
}
case OP_TYPEID::Select:
{
node = make_shared<op::Select>(args[0], args[1], args[2]);
break;
}
case OP_TYPEID::SelectAndScatter:
{
string selection_function_name = node_js.at("selection_function").get<string>();
shared_ptr<Function> selection_f_ptr = function_map.at(selection_function_name);
@@ -836,78 +923,95 @@ static shared_ptr<ngraph::Function>
scatter_f_ptr,
window_shape,
window_movement_strides);
break;
}
case OP_TYPEID::Sigmoid:
{
node = make_shared<op::Sigmoid>(args[0]);
break;
}
case OP_TYPEID::SigmoidBackprop:
{
node = make_shared<op::SigmoidBackprop>(args[0], args[1]);
break;
}
case OP_TYPEID::Sign:
{
node = make_shared<op::Sign>(args[0]);
break;
}
case OP_TYPEID::Sin:
{
node = make_shared<op::Sin>(args[0]);
break;
}
case OP_TYPEID::Sinh:
{
node = make_shared<op::Sinh>(args[0]);
break;
}
case OP_TYPEID::Slice:
{
auto lower_bounds = node_js.at("lower_bounds").get<vector<size_t>>();
auto upper_bounds = node_js.at("upper_bounds").get<vector<size_t>>();
auto strides = node_js.at("strides").get<vector<size_t>>();
node = make_shared<op::Slice>(args[0], lower_bounds, upper_bounds, strides);
break;
}
case OP_TYPEID::Softmax:
{
auto softmax_axes = node_js.at("softmax_axes").get<set<size_t>>();
node = make_shared<op::Softmax>(args[0], softmax_axes);
break;
}
case OP_TYPEID::Sqrt:
{
node = make_shared<op::Sqrt>(args[0]);
break;
}
case OP_TYPEID::Subtract:
{
node = make_shared<op::Subtract>(args[0], args[1]);
break;
}
case OP_TYPEID::Sum:
{
auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
node = make_shared<op::Sum>(args[0], reduction_axes);
break;
}
case OP_TYPEID::Tan:
{
node = make_shared<op::Tan>(args[0]);
break;
}
case OP_TYPEID::Tanh:
{
node = make_shared<op::Tanh>(args[0]);
break;
}
case OP_TYPEID::TopK:
{
auto top_k_axis = node_js.at("top_k_axis").get<size_t>();
auto k = node_js.at("k").get<size_t>();
auto compute_max = node_js.at("compute_max").get<bool>();
auto target_type = read_element_type(node_js.at("index_element_type"));
node = make_shared<op::TopK>(args[0], top_k_axis, target_type, k, compute_max);
break;
}
case OP_TYPEID::StopGradient:
{
node = make_shared<op::StopGradient>(args[0]);
break;
}
default:
{
stringstream ss;
ss << "unsupported op " << node_op;
throw runtime_error(ss.str());
}
}
#pragma GCC diagnostic pop
for (const string& name : control_deps_inputs)
{
@@ -1011,37 +1115,41 @@ static json write(const Node& n, bool binary_constant_data)
}
string node_op = n.description();
#pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
// #pragma GCC diagnostic error "-Wimplicit-fallthrough"
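// Same exhaustiveness guard as in the deserializer: with -Wswitch-enum as an
// error, every OP_TYPEID enumerator needs an explicit case below.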
switch (get_typeid(node_op))
{
case OP_TYPEID::Abs: { break;
}
case OP_TYPEID::Acos: { break;
}
case OP_TYPEID::Add: { break;
}
case OP_TYPEID::ArgMin:
{
auto tmp = dynamic_cast<const op::ArgMin*>(&n);
node["axis"] = tmp->get_reduction_axis();
node["index_element_type"] = write_element_type(tmp->get_element_type());
break;
}
case OP_TYPEID::ArgMax:
{
auto tmp = dynamic_cast<const op::ArgMax*>(&n);
node["axis"] = tmp->get_reduction_axis();
node["index_element_type"] = write_element_type(tmp->get_element_type());
break;
}
case OP_TYPEID::AllReduce: { break;
}
case OP_TYPEID::And: { break;
}
case OP_TYPEID::Asin: { break;
}
case OP_TYPEID::Atan: { break;
}
case OP_TYPEID::AvgPool:
{
auto tmp = dynamic_cast<const op::AvgPool*>(&n);
node["window_shape"] = tmp->get_window_shape();
@@ -1049,8 +1157,9 @@ static json write(const Node& n, bool binary_constant_data)
node["padding_below"] = tmp->get_padding_below();
node["padding_above"] = tmp->get_padding_above();
node["include_padding_in_avg_computation"] = tmp->get_include_padding_in_avg_computation();
break;
}
case OP_TYPEID::AvgPoolBackprop:
{
auto tmp = dynamic_cast<const op::AvgPoolBackprop*>(&n);
node["forward_arg_shape"] = tmp->get_forward_arg_shape();
@@ -1059,33 +1168,37 @@ static json write(const Node& n, bool binary_constant_data)
node["padding_below"] = tmp->get_padding_below();
node["padding_above"] = tmp->get_padding_above();
node["include_padding_in_avg_computation"] = tmp->get_include_padding_in_avg_computation();
break;
}
case OP_TYPEID::BatchNorm:
{
auto tmp = dynamic_cast<const op::BatchNorm*>(&n);
node["eps"] = tmp->get_eps_value();
node["training"] = tmp->get_training_flag();
break;
}
case OP_TYPEID::BatchNormBackprop:
{
auto tmp = dynamic_cast<const op::BatchNormBackprop*>(&n);
node["eps"] = tmp->get_eps_value();
break;
}
case OP_TYPEID::Broadcast:
{
auto tmp = dynamic_cast<const op::Broadcast*>(&n);
node["axes"] = tmp->get_broadcast_axes();
node["shape"] = tmp->get_broadcast_shape();
break;
}
case OP_TYPEID::Ceiling: { break;
}
case OP_TYPEID::Concat:
{
auto tmp = dynamic_cast<const op::Concat*>(&n);
node["axis"] = tmp->get_concatenation_axis();
break;
}
case OP_TYPEID::Constant:
{
auto tmp = dynamic_cast<const op::Constant*>(&n);
if (!binary_constant_data)
@@ -1094,13 +1207,15 @@ static json write(const Node& n, bool binary_constant_data)
}
node["shape"] = tmp->get_shape();
node["element_type"] = write_element_type(tmp->get_element_type());
break;
}
case OP_TYPEID::Convert:
{
auto tmp = dynamic_cast<const op::Convert*>(&n);
node["target_type"] = write_element_type(tmp->get_convert_element_type());
break;
}
case OP_TYPEID::Convolution:
{
auto tmp = dynamic_cast<const op::Convolution*>(&n);
node["window_movement_strides"] = tmp->get_window_movement_strides();
@@ -1108,8 +1223,9 @@ static json write(const Node& n, bool binary_constant_data)
node["padding_below"] = tmp->get_padding_below();
node["padding_above"] = tmp->get_padding_above();
node["data_dilation_strides"] = tmp->get_data_dilation_strides();
break;
}
case OP_TYPEID::ConvolutionBackpropData:
{
auto tmp = dynamic_cast<const op::ConvolutionBackpropData*>(&n);
node["data_batch_shape"] = tmp->get_data_batch_shape();
@@ -1118,8 +1234,9 @@ static json write(const Node& n, bool binary_constant_data)
node["padding_below_forward"] = tmp->get_padding_below_forward();
node["padding_above_forward"] = tmp->get_padding_above_forward();
node["data_dilation_strides_forward"] = tmp->get_data_dilation_strides_forward();
break;
}
case OP_TYPEID::ConvolutionBackpropFilters:
{
auto tmp = dynamic_cast<const op::ConvolutionBackpropFilters*>(&n);
node["filters_shape"] = tmp->get_filters_shape();
@@ -1128,246 +1245,242 @@ static json write(const Node& n, bool binary_constant_data)
node["padding_below_forward"] = tmp->get_padding_below_forward();
node["padding_above_forward"] = tmp->get_padding_above_forward();
node["data_dilation_strides_forward"] = tmp->get_data_dilation_strides_forward();
break;
}
case OP_TYPEID::Cos: { break;
}
case OP_TYPEID::Cosh: { break;
}
case OP_TYPEID::Divide: { break;
}
case OP_TYPEID::Dot:
{
auto tmp = dynamic_cast<const op::Dot*>(&n);
node["reduction_axes_count"] = tmp->get_reduction_axes_count();
break;
}
case OP_TYPEID::Equal: { break;
}
case OP_TYPEID::Exp: { break;
}
case OP_TYPEID::Floor: { break;
}
case OP_TYPEID::FunctionCall:
{
node["function"] = n.get_functions()[0]->get_name();
break;
}
case OP_TYPEID::GetOutputElement:
{
auto tmp = dynamic_cast<const op::GetOutputElement*>(&n);
node["n"] = tmp->get_n();
break;
}
case OP_TYPEID::Greater: { break;
}
case OP_TYPEID::GreaterEq: { break;
}
case OP_TYPEID::Less: { break;
}
case OP_TYPEID::LessEq: { break;
}
case OP_TYPEID::Log: { break;
}
case OP_TYPEID::LRN:
{
auto tmp = dynamic_cast<const op::LRN*>(&n);
node["alpha"] = tmp->get_alpha();
node["beta"] = tmp->get_beta();
node["bias"] = tmp->get_bias();
node["nsize"] = tmp->get_nsize();
break;
}
case OP_TYPEID::Max:
{
auto tmp = dynamic_cast<const op::Max*>(&n);
node["reduction_axes"] = tmp->get_reduction_axes();
break;
}
case OP_TYPEID::MaxPool:
{
auto tmp = dynamic_cast<const op::MaxPool*>(&n);
node["window_shape"] = tmp->get_window_shape();
node["window_movement_strides"] = tmp->get_window_movement_strides();
node["padding_below"] = tmp->get_padding_below();
node["padding_above"] = tmp->get_padding_above();
break;
}
case OP_TYPEID::MaxPoolBackprop:
{
auto tmp = dynamic_cast<const op::MaxPoolBackprop*>(&n);
node["window_shape"] = tmp->get_window_shape();
node["window_movement_strides"] = tmp->get_window_movement_strides();
node["padding_below"] = tmp->get_padding_below();
node["padding_above"] = tmp->get_padding_above();
break;
}
case OP_TYPEID::Maximum: { break;
}
case OP_TYPEID::Min:
{
auto tmp = dynamic_cast<const op::Min*>(&n);
node["reduction_axes"] = tmp->get_reduction_axes();
break;
}
case OP_TYPEID::Minimum: { break;
}
case OP_TYPEID::Multiply: { break;
}
case OP_TYPEID::Negative: { break;
}
case OP_TYPEID::NotEqual: { break;
}
case OP_TYPEID::Not: { break;
}
case OP_TYPEID::OneHot:
{
auto tmp = dynamic_cast<const op::OneHot*>(&n);
node["shape"] = tmp->get_shape();
node["one_hot_axis"] = tmp->get_one_hot_axis();
break;
}
case OP_TYPEID::Or: { break;
}
case OP_TYPEID::Pad:
{
auto tmp = dynamic_cast<const op::Pad*>(&n);
node["padding_below"] = tmp->get_padding_below();
node["padding_above"] = tmp->get_padding_above();
node["padding_interior"] = tmp->get_padding_interior();
break;
}
case OP_TYPEID::Parameter:
{
auto tmp = dynamic_cast<const op::Parameter*>(&n);
node["shape"] = tmp->get_shape();
node["cacheable"] = tmp->get_cacheable();
node["element_type"] = write_element_type(tmp->get_element_type());
break;
}
case OP_TYPEID::Product:
{
auto tmp = dynamic_cast<const op::Product*>(&n);
node["reduction_axes"] = tmp->get_reduction_axes();
break;
}
case OP_TYPEID::Power: { break;
}
case OP_TYPEID::Reduce:
{
auto tmp = dynamic_cast<const op::Reduce*>(&n);
node["function"] = tmp->get_functions()[0]->get_name();
node["reduction_axes"] = tmp->get_reduction_axes();
break;
}
case OP_TYPEID::ReduceWindow:
{
auto tmp = dynamic_cast<const op::ReduceWindow*>(&n);
node["function"] = tmp->get_functions()[0]->get_name();
node["window_shape"] = tmp->get_window_shape();
node["window_movement_strides"] = tmp->get_window_movement_strides();
break;
}
case OP_TYPEID::Relu: { break;
}
case OP_TYPEID::ReluBackprop: { break;
}
case OP_TYPEID::ReplaceSlice:
{
auto tmp = dynamic_cast<const op::ReplaceSlice*>(&n);
node["lower_bounds"] = tmp->get_lower_bounds();
node["upper_bounds"] = tmp->get_upper_bounds();
node["strides"] = tmp->get_strides();
break;
}
case OP_TYPEID::Reshape:
{
auto tmp = dynamic_cast<const op::Reshape*>(&n);
node["input_order"] = tmp->get_input_order();
node["output_shape"] = tmp->get_output_shape();
break;
}
case OP_TYPEID::Result: { break;
}
case OP_TYPEID::Reverse:
{
auto tmp = dynamic_cast<const op::Reverse*>(&n);
node["reversed_axes"] = tmp->get_reversed_axes();
break;
}
case OP_TYPEID::ReverseSequence:
{
auto tmp = dynamic_cast<const op::ReverseSequence*>(&n);
node["batch_axis"] = tmp->get_batch_axis();
node["sequence_axis"] = tmp->get_sequence_axis();
break;
}
case OP_TYPEID::Select: { break;
}
case OP_TYPEID::SelectAndScatter:
{
auto tmp = dynamic_cast<const op::SelectAndScatter*>(&n);
node["selection_function"] = tmp->get_functions()[0]->get_name();
node["scatter_function"] = tmp->get_functions()[1]->get_name();
node["window_shape"] = tmp->get_window_shape();
node["window_movement_strides"] = tmp->get_window_movement_strides();
break;
}
case OP_TYPEID::Sigmoid: { break;
}
case OP_TYPEID::SigmoidBackprop: { break;
}
case OP_TYPEID::Sign: { break;
}
case OP_TYPEID::Sin: { break;
}
case OP_TYPEID::Sinh: { break;
}
case OP_TYPEID::Slice:
{
auto tmp = dynamic_cast<const op::Slice*>(&n);
node["lower_bounds"] = tmp->get_lower_bounds();
node["upper_bounds"] = tmp->get_upper_bounds();
node["strides"] = tmp->get_strides();
break;
}
case OP_TYPEID::Sqrt: { break;
}
case OP_TYPEID::StopGradient: { break;
}
case OP_TYPEID::Subtract: { break;
}
case OP_TYPEID::Sum:
{
auto tmp = dynamic_cast<const op::Sum*>(&n);
node["reduction_axes"] = tmp->get_reduction_axes();
break;
}
case OP_TYPEID::Softmax:
{
auto tmp = dynamic_cast<const op::Softmax*>(&n);
node["softmax_axes"] = tmp->get_axes();
break;
}
case OP_TYPEID::Tan: { break;
}
case OP_TYPEID::Tanh: { break;
}
case OP_TYPEID::TopK:
{
auto tmp = dynamic_cast<const op::TopK*>(&n);
node["top_k_axis"] = tmp->get_top_k_axis();
node["index_element_type"] = write_element_type(tmp->get_index_element_type());
node["k"] = tmp->get_k();
node["compute_max"] = tmp->get_compute_max();
break;
}
}
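// Note: no default case above; ops that serialize no extra attributes get an
// explicit empty case instead, so the -Wswitch-enum check spans the whole op table.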
#pragma GCC diagnostic pop
return node;
}
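
For readers unfamiliar with the X-macro pattern this commit adopts, here is a
minimal, self-contained sketch of the technique. It assumes a hypothetical
three-op table; the real code re-#includes ngraph/op/op_tbl.hpp at each
expansion site, while this sketch inlines the list as an OP_TBL macro for
brevity:

#include <iostream>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for ngraph/op/op_tbl.hpp: one NGRAPH_OP entry per op.
#define OP_TBL NGRAPH_OP(Abs) NGRAPH_OP(Add) NGRAPH_OP(Multiply)

// Expansion 1: the op list becomes the enumerators of OP_TYPEID.
#define NGRAPH_OP(a) a,
enum class OP_TYPEID
{
    OP_TBL
};
#undef NGRAPH_OP

// Expansion 2: the same list becomes the entries of a string-to-id map, so the
// enum and the map can never drift apart.
static OP_TYPEID get_typeid(const std::string& s)
{
#define NGRAPH_OP(a) {#a, OP_TYPEID::a},
    static const std::unordered_map<std::string, OP_TYPEID> typeid_map{OP_TBL};
#undef NGRAPH_OP
    return typeid_map.at(s); // throws std::out_of_range for unknown names
}

int main()
{
    // Dispatch the way the serializer does; with -Wswitch-enum promoted to an
    // error, removing any case below would break the build.
    switch (get_typeid("Add"))
    {
    case OP_TYPEID::Abs: std::cout << "Abs\n"; break;
    case OP_TYPEID::Add: std::cout << "Add\n"; break;
    case OP_TYPEID::Multiply: std::cout << "Multiply\n"; break;
    }
    return 0;
}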