Commit 228570eb authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

Fix broken serialize and deserialize for Sum and Product (#4050)

parent 22008460
......@@ -2386,11 +2386,16 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
}
case OP_TYPEID::Product:
{
auto reduction_axes = deserialize_axis_set(node_js.at("reduction_axes"));
set<size_t> reduction_axes =
get_or_default<set<size_t>>(node_js, "reduction_axes", set<size_t>());
if (reduction_axes.empty())
{
node = make_shared<op::v0::Product>(args[0], args[1]);
}
else
{
node = make_shared<op::v0::Product>(args[0], reduction_axes);
}
break;
}
case OP_TYPEID::ReduceProd_v1:
......@@ -2807,11 +2812,16 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
}
case OP_TYPEID::Sum:
{
auto reduction_axes = deserialize_axis_set(node_js.at("reduction_axes"));
set<size_t> reduction_axes =
get_or_default<set<size_t>>(node_js, "reduction_axes", set<size_t>());
if (reduction_axes.empty())
{
node = make_shared<op::v0::Sum>(args[0], args[1]);
}
else
{
node = make_shared<op::v0::Sum>(args[0], reduction_axes);
}
break;
}
case OP_TYPEID::Tan:
......@@ -4202,7 +4212,11 @@ json JSONSerializer::serialize_node(const Node& n)
}
case OP_TYPEID::PRelu: { break;
}
case OP_TYPEID::Product: { break;
case OP_TYPEID::Product:
{
auto tmp = static_cast<const op::Product*>(&n);
node["reduction_axes"] = tmp->get_reduction_axes();
break;
}
case OP_TYPEID::ReduceProd_v1:
{
......@@ -4480,7 +4494,11 @@ json JSONSerializer::serialize_node(const Node& n)
}
break;
}
case OP_TYPEID::Sum: { break;
case OP_TYPEID::Sum:
{
auto tmp = static_cast<const op::Sum*>(&n);
node["reduction_axes"] = tmp->get_reduction_axes();
break;
}
case OP_TYPEID::ReduceSum_v1:
{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment