Commit ef129a77 authored by Tomasz Socha's avatar Tomasz Socha Committed by Scott Cyphers

[SPEC] Add auto_broadcast parameter to SquaredDifference (#3856)

* [SPEC] Add auto_broadcast parameter to SquaredDifference

* Rename set_autobroadcast->set_autob
parent 08ca928a
......@@ -26,8 +26,11 @@ using namespace ngraph;
constexpr NodeTypeInfo op::SquaredDifference::type_info;
op::SquaredDifference::SquaredDifference(const Output<Node>& x1, const Output<Node>& x2)
op::SquaredDifference::SquaredDifference(const Output<Node>& x1,
const Output<Node>& x2,
const AutoBroadcastSpec& auto_broadcast)
: FusedOp({x1, x2})
, m_autobroadcast(auto_broadcast)
{
constructor_validate_and_infer_types();
}
......@@ -37,19 +40,14 @@ NodeVector op::SquaredDifference::decompose_op() const
const auto x1 = input_value(0);
const auto x2 = input_value(1);
const auto broadcasted = numpy_style_broadcast_values({x1, x2});
const auto difference = broadcasted.at(0) - broadcasted.at(1);
const auto difference = make_shared<op::Subtract>(x1, x2, m_autobroadcast);
return {difference * difference};
}
shared_ptr<Node> op::SquaredDifference::copy_with_new_args(const NodeVector& new_args) const
{
NODE_VALIDATION_CHECK(this,
new_args.size() == 2,
"Expected 2 elements in new_args for the SquaredDifference op but got ",
new_args.size());
check_new_args_count(this, new_args);
return make_shared<SquaredDifference>(new_args.at(0), new_args.at(1));
return make_shared<SquaredDifference>(new_args.at(0), new_args.at(1), get_autob());
}
......@@ -38,12 +38,24 @@ namespace ngraph
///
/// \param x1 First input tensor
/// \param x2 Second input tensor
SquaredDifference(const Output<Node>& x1, const Output<Node>& x2);
/// \param auto_broadcast Auto broadcast specification
SquaredDifference(const Output<Node>& x1,
const Output<Node>& x2,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastType::NUMPY);
virtual NodeVector decompose_op() const override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
const AutoBroadcastSpec& get_autob() const override { return m_autobroadcast; }
void set_autob(const AutoBroadcastSpec& auto_broadcast)
{
m_autobroadcast = auto_broadcast;
}
private:
AutoBroadcastSpec m_autobroadcast;
};
}
}
} // namespace op
} // namespace ngraph
......@@ -2612,7 +2612,8 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
}
case OP_TYPEID::SquaredDifference:
{
node = make_shared<op::SquaredDifference>(args[0], args[1]);
node = make_shared<op::SquaredDifference>(
args[0], args[1], read_auto_broadcast(node_js, "auto_broadcast"));
break;
}
case OP_TYPEID::Squeeze:
......@@ -4080,7 +4081,14 @@ json JSONSerializer::serialize_node(const Node& n)
}
case OP_TYPEID::Sqrt: { break;
}
case OP_TYPEID::SquaredDifference: { break;
case OP_TYPEID::SquaredDifference:
{
auto tmp = static_cast<const op::SquaredDifference*>(&n);
if (tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
{
node["auto_broadcast"] = write_auto_broadcast(tmp->get_autob());
}
break;
}
case OP_TYPEID::Squeeze: { break;
}
......
......@@ -34,7 +34,7 @@ TEST(type_prop, squared_difference)
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("axes are incompatible"));
EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument shapes are inconsistent"));
}
const auto clamp = make_shared<op::SquaredDifference>(x1, x3);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment