Commit 228570eb authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

Fix broken serialize and deserialize for Sum and Product (#4050)

parent 22008460
...@@ -2386,11 +2386,16 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js) ...@@ -2386,11 +2386,16 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
} }
case OP_TYPEID::Product: case OP_TYPEID::Product:
{ {
auto reduction_axes = deserialize_axis_set(node_js.at("reduction_axes")); set<size_t> reduction_axes =
get_or_default<set<size_t>>(node_js, "reduction_axes", set<size_t>());
if (reduction_axes.empty()) if (reduction_axes.empty())
{
node = make_shared<op::v0::Product>(args[0], args[1]); node = make_shared<op::v0::Product>(args[0], args[1]);
}
else else
{
node = make_shared<op::v0::Product>(args[0], reduction_axes); node = make_shared<op::v0::Product>(args[0], reduction_axes);
}
break; break;
} }
case OP_TYPEID::ReduceProd_v1: case OP_TYPEID::ReduceProd_v1:
...@@ -2807,11 +2812,16 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js) ...@@ -2807,11 +2812,16 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
} }
case OP_TYPEID::Sum: case OP_TYPEID::Sum:
{ {
auto reduction_axes = deserialize_axis_set(node_js.at("reduction_axes")); set<size_t> reduction_axes =
get_or_default<set<size_t>>(node_js, "reduction_axes", set<size_t>());
if (reduction_axes.empty()) if (reduction_axes.empty())
{
node = make_shared<op::v0::Sum>(args[0], args[1]); node = make_shared<op::v0::Sum>(args[0], args[1]);
}
else else
{
node = make_shared<op::v0::Sum>(args[0], reduction_axes); node = make_shared<op::v0::Sum>(args[0], reduction_axes);
}
break; break;
} }
case OP_TYPEID::Tan: case OP_TYPEID::Tan:
...@@ -4202,7 +4212,11 @@ json JSONSerializer::serialize_node(const Node& n) ...@@ -4202,7 +4212,11 @@ json JSONSerializer::serialize_node(const Node& n)
} }
case OP_TYPEID::PRelu: { break; case OP_TYPEID::PRelu: { break;
} }
case OP_TYPEID::Product: { break; case OP_TYPEID::Product:
{
auto tmp = static_cast<const op::Product*>(&n);
node["reduction_axes"] = tmp->get_reduction_axes();
break;
} }
case OP_TYPEID::ReduceProd_v1: case OP_TYPEID::ReduceProd_v1:
{ {
...@@ -4480,7 +4494,11 @@ json JSONSerializer::serialize_node(const Node& n) ...@@ -4480,7 +4494,11 @@ json JSONSerializer::serialize_node(const Node& n)
} }
break; break;
} }
case OP_TYPEID::Sum: { break; case OP_TYPEID::Sum:
{
auto tmp = static_cast<const op::Sum*>(&n);
node["reduction_axes"] = tmp->get_reduction_axes();
break;
} }
case OP_TYPEID::ReduceSum_v1: case OP_TYPEID::ReduceSum_v1:
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment